from typing import Dict
import os
import json
import threading
from qdrant_client import QdrantClient
from qdrant_client.models import (
    VectorParams, Distance, 
    PointStruct, Filter, FieldCondition, MatchText, MatchAny
)

from qwin_qdrant_config import qdrant_config

class QwinQdrantClient:
    QWIN_DOCS_COLLECTION = 'qwin_docs'
    
    _instance = None
    _lock = threading.Lock()
    
    def __new__(cls, *args, **kwargs):
        with cls._lock:
            if not cls._instance:
                cls._instance = super(QwinQdrantClient, cls).__new__(cls)
                
        return cls._instance    
    
    def __getstate__(self):
        state = self.__dict__.copy()
        state.pop('_lock', None)
        return state
    
    def __setstate__(self, state):
        self.__dict__.update(state)
        self._lock = threading.Lock()
    
    def __init__(self):
        self.config = qdrant_config()
        self.client = QdrantClient(url=self.config['qdrant']['url'])
        
    def create_collection(self, verbose: bool = False):
        try:
            self.client.delete_collection(collection_name=self.QWIN_DOCS_COLLECTION)
        except:
            pass
        if verbose:
            print(f'Qdrant collection {self.QWIN_DOCS_COLLECTION} deleted')
            
        self.client.create_collection(collection_name=self.QWIN_DOCS_COLLECTION,
                                      vectors_config=VectorParams(size=1536, distance=Distance.COSINE))
        if verbose:
            print(f'Qdrant collection {self.QWIN_DOCS_COLLECTION} created')
            
    def load_docs(self, doc_dir: str = '/docs'):
        curr_dir = os.path.dirname(os.path.realpath(__file__))
        doc_dir = curr_dir + doc_dir
        chains = os.listdir(doc_dir)
        
        doc_id = 1
        for chain in chains:
            folder, doc_loader = chain_doc_loaders[chain]
        
            chain_doc_dir = doc_dir + folder
            files = os.listdir(chain_doc_dir)
            
            for i, file in enumerate(files, 1):
                file_path = os.path.join(chain_doc_dir, file)
                if os.path.isfile(file_path):
                    game_doc = doc_loader(file_path)
                    
                    self.client.upsert(
                        collection_name=self.QWIN_DOCS_COLLECTION,
                        wait=True,
                        points=[
                            PointStruct(id=doc_id,
                                        vector = [.0] * 1536,
                                        payload=game_doc)
                        ]
                    )
                    doc_id += 1
                    
        # self.client.create_payload_index(collection_name=self.QWIN_DOCS_COLLECTION, 
        #                                  field_name='name',
        #                                  field_schema=TextIndexParams(
        #                                      type='text',
        #                                      tokenizer=TokenizerType.WORD,
        #                                      min_token_len=2,
        #                                      max_token_len=15,
        #                                      lowercase=True
        #                                  ))
        
        # self.client.create_payload_index(collection_name=self.QWIN_DOCS_COLLECTION, 
        #                                  field_name='text',
        #                                  field_schema=TextIndexParams(
        #                                      type='text',
        #                                      tokenizer=TokenizerType.WORD,
        #                                      min_token_len=2,
        #                                      max_token_len=15,
        #                                      lowercase=True
        #                                  ))
        
    def game_list(self):
        must = [FieldCondition(key='chain', match=MatchText(text='games')),
                FieldCondition(key='code', match=MatchText(text='game_list'))]

        result = self.client.scroll(
            collection_name=self.QWIN_DOCS_COLLECTION,
            scroll_filter=Filter(must=must)
        )    
        
        if result:
            return result[0][0].payload['text']
        
        return None
    
    def search_game_docs(self, game_codes: [str], key_words: [str] = []):
        # must = [FieldCondition(key='text', match=MatchText(text=kw)) for kw in key_words]
        must = [FieldCondition(key='game_code', match=MatchAny(any=game_codes))]

        result = self.client.scroll(
            collection_name=self.QWIN_DOCS_COLLECTION,
            scroll_filter=Filter(must=must)
        )
        
        return [item.payload['text'] for item in result[0]]

    def search_games_docs(self, key_words: [str]) -> [str]:
        must = [FieldCondition(key='chain', match=MatchText(text='games'))]

        result = self.client.scroll(
            collection_name=self.QWIN_DOCS_COLLECTION,
            scroll_filter=Filter(must=must)
        )

        return [item.payload['text'] for item in result[0]]
    
    def search_casino_docs(self):
        must = [FieldCondition(key='chain', match=MatchText(text='casino'))]
        
        result = self.client.scroll(
            collection_name=self.QWIN_DOCS_COLLECTION,
            scroll_filter=Filter(must=must)
        )
        
        return [item.payload['text'] for item in result[0]]
    
    def search_finance_docs(self):
        must = [FieldCondition(key='chain', match=MatchText(text='finance'))]
        
        result = self.client.scroll(
            collection_name=self.QWIN_DOCS_COLLECTION,
            scroll_filter=Filter(must=must)
        )
        
        return [item.payload['text'] for item in result[0]]
    
    def search_marketing_docs(self):
        must = [FieldCondition(key='chain', match=MatchText(text='marketing'))]
        
        result = self.client.scroll(
            collection_name=self.QWIN_DOCS_COLLECTION,
            scroll_filter=Filter(must=must)
        )
        
        return [item.payload['text'] for item in result[0]]
    
    def search_affiliate_docs(self):
        must = [FieldCondition(key='chain', match=MatchText(text='affiliate'))]
        
        result = self.client.scroll(
            collection_name=self.QWIN_DOCS_COLLECTION,
            scroll_filter=Filter(must=must)
        )
        
        return [item.payload['text'] for item in result[0]]
    
  
def doc_header(string: str) -> Dict[str, str] | None:
    try:
        return json.loads(string)
    except:
        print(f'Incorrect json string {string}')
        return None

def load_casino_doc(file_path: str) -> Dict:
    with open(file_path, 'r', encoding='utf-8') as f:
        lines = f.readlines()
        headers = doc_header(lines[0])

        assert headers is not None, f'parse headers, incorrect file format {file_path}'
        assert headers['chain'] is not None and headers['chain'] == 'casino', f'Cannot find chain key in headers, incorrect file format {file_path}'

        return {**headers, 'text': '.'.join(lines[1:])}

def load_finance_doc(file_path: str) -> Dict:
    with open(file_path, 'r', encoding='utf-8') as f:
        lines = f.readlines()
        headers = doc_header(lines[0])

        assert headers is not None, f'parse headers, incorrect file format {file_path}'
        assert headers['chain'] is not None and headers[
            'chain'] == 'finance', f'Cannot find chain key in headers, incorrect file format {file_path}'

        return {**headers, 'text': '.'.join(lines[1:])}

def load_marketing_doc(file_path: str) -> Dict:
    with open(file_path, 'r', encoding='utf-8') as f:
        lines = f.readlines()
        headers = doc_header(lines[0])

        assert headers is not None, f'parse headers, incorrect file format {file_path}'
        assert headers['chain'] is not None and headers[
            'chain'] == 'marketing', f'Cannot find chain key in headers, incorrect file format {file_path}'

        return {**headers, 'text': '.'.join(lines[1:])}
    
def load_affiliate_doc(file_path: str) -> Dict:
    with open(file_path, 'r', encoding='utf-8') as f:
        lines = f.readlines()
        headers = doc_header(lines[0])

        assert headers is not None, f'parse headers, incorrect file format {file_path}'
        assert headers['chain'] is not None and headers[
            'chain'] == 'affiliate', f'Cannot find chain key in headers, incorrect file format {file_path}'

        return {**headers, 'text': '.'.join(lines[1:])}
    
def load_game_doc(file_path: str) -> Dict:
    with open(file_path, 'r', encoding='utf-8') as f:
        lines = f.readlines()
        headers = doc_header(lines[0])

        assert headers is not None, f'parse headers, incorrect file format {file_path}'
        assert headers['chain'] is not None and headers[
            'chain'] == 'games', f'Cannot find chain key in headers, incorrect file format {file_path}'

        return {**headers, 'text': '.'.join(lines[1:])}
                    
                    
# Registry of document chains: chain name -> (folder suffix appended to the
# docs root directory, loader that parses one file of that chain).
chain_doc_loaders = dict(
    casino=('/casino', load_casino_doc),
    finance=('/finance', load_finance_doc),
    games=('/games', load_game_doc),
    marketing=('/marketing', load_marketing_doc),
    affiliate=('/affiliate', load_affiliate_doc),
)