import aiohttp
import aiofiles
import asyncio
import hashlib
import os
import re
import traceback
from .Log import Log



class ModelsTypes:
    """String constants naming the model types that can be used as a filter.

    ``Civitai.__get_models_json`` sends a list of these values as the ``types``
    query parameter of the ``models`` endpoint (default: ``[Checkpoint]``).
    """
    # NOTE(review): values are presumably the exact spellings the Civitai API
    # accepts for ``types`` - confirm against the API reference before adding more.
    Checkpoint = "Checkpoint"
    TextualInversion = "TextualInversion"
    Hypernetwork = "Hypernetwork"
    AestheticGradient = "AestheticGradient"
    LORA = "LORA"
    LoCon = "LoCon"
    Controlnet = "Controlnet"
    Poses = "Poses"


class ModelsSort:
    """String constants for the ``sort`` query parameter of the ``models`` endpoint.

    ``Civitai.__get_models_json`` uses ``Newest`` as its default.
    """
    # The values contain spaces on purpose: they are sent verbatim as the
    # parameter value.
    HighestRated = "Highest Rated"
    MostDownloaded = "Most Downloaded"
    Newest = "Newest"


class ModelsPeriod:
    """String constants for the ``period`` query parameter of the ``models`` endpoint.

    ``Civitai.__get_models_json`` uses ``AllTime`` as its default.
    """
    # NOTE(review): presumably the time window the chosen sort order is
    # computed over - confirm against the API reference.
    AllTime = "AllTime"
    Year = "Year"
    Month = "Month"
    Week = "Week"
    Day = "Day"



class Civitai():
    """Asynchronous helper around the Civitai REST API (``https://civitai.com/api/v1``).

    It lists models page by page, fetches the metadata of a single model and
    downloads model files to disk. The API token given to the constructor is
    sent as the ``token`` query parameter of list and download requests.
    """

    # Class-level defaults, kept so attribute access on the class itself keeps
    # working. Every instance gets its own values in ``__init__``.
    token = None
    types = []
    favorites = False


    def __init__(self, token):
        """Store the API token and reset the list filters to "no filter"."""
        self.token = token
        # Per-instance state: without this, all instances would share (and
        # could mutate) the single list stored on the class.
        self.types = []
        self.favorites = False



    def set_filters(self, types=None, favorites=False):
        """Set the filters applied by ``get_all`` / ``get_page_and_next``.

        Args:
            types: Model types to request (see ``ModelsTypes``). ``None`` or an
                empty list sends no type filter.
            favorites: When true, the ``favorites`` query parameter is sent as
                ``"true"``.
        """
        # ``None`` instead of a ``[]`` default: a mutable default would be one
        # shared list stored on every instance that relies on it.
        self.types = [] if types is None else types
        self.favorites = favorites



    async def get_all(self, wait_time = 1):
        """Load every page of the filtered model list.

        Args:
            wait_time: Seconds to sleep between two consecutive page requests.

        Returns:
            A list with the converted model dicts of all pages, in page order.
        """
        out_list = []
        page = 1
        cursor = ""
        while True:
            data = await self.get_page_and_next(page, cursor)
            # Collect before deciding whether to stop: the last page has no
            # next cursor but still carries models that must be returned.
            out_list.extend(data["models"])
            if not data["models"] or data["cursor"] is None:
                break
            page += 1
            cursor = data["cursor"]
            # Throttle only when another request actually follows.
            await asyncio.sleep(wait_time)
        return out_list



    async def get_page_and_next(self, page, cursor, separate_model_request = False):
        """Load one page of models and the cursor of the following page.

        Args:
            page: Page number, used for logging only.
            cursor: Cursor of the page to load (``""`` for the first page).
            separate_model_request: When true, every model is re-fetched with
                its own ``models/<id>`` request instead of converting the list
                entry directly.

        Returns:
            ``{"models": [...], "cursor": <next cursor or None>}``. Models
            whose data cannot be loaded or converted are logged and skipped.
        """
        Log.info("Civit AI", f"Load page {page}")

        out_list = []
        models = await self.__get_models_json(favorites=self.favorites, types=self.types, cursor=cursor)

        Log.ok("Civit AI", f"Page Loaded ({len(models['items'])} models)")

        for model in models["items"]:
            # Deliberately broad: one malformed model must not abort the page.
            try:
                if separate_model_request:
                    out_list.append(await self.get_model_data(model["id"]))
                else:
                    out_list.append(self.__convert_json_model_data(model))
            except Exception as err:
                Log.exception("Civit AI", f"Load model info error. Model: https://civitai.com/models/{model['id']}", err, traceback.format_exc())

        return {
            "models": out_list,
            "cursor": models.get("metadata", {}).get("nextCursor")
        }



    async def download_model(self, id, save_dir, name, hash = None):
        """Download the first file of a model to ``<save_dir>/<name>.<extension>``.

        The extension is taken from the ``content-disposition`` header of the
        download response.

        Args:
            id: Civitai model id.
            save_dir: Directory the file is written to.
            name: File name without extension; must not be empty.
            hash: Currently unused (hash verification is not implemented yet).

        Raises:
            Exception: If ``name`` is empty or the extension cannot be
                determined from the response headers.
        """
        # Validate the cheap argument first so no request is made for nothing.
        if not name:
            Log.error("Civit AI", f"Download model id: {id}. Unable to download - empty file name")
            raise Exception("Empty file name")

        model = await self.get_model_data(id)

        Log.info("Civit AI", f"Start download {id} model")
        # Context managers guarantee the session, the response and the output
        # file are released even when the download fails half-way.
        async with aiohttp.ClientSession() as session:
            async with session.get(model["download"], params={'token': self.token}) as response:
                extension = self.__get_extension_from_response(response)

                if extension is None:
                    Log.error("Civit AI", f"Download model {id} ({name}). Unable to download - unknown file extension")
                    raise Exception(f"Unknown file extension. Headers: {response.headers}")

                file_path = os.path.join(save_dir, f'{name}.{extension}')

                async with aiofiles.open(file_path, mode = "wb") as f:
                    # Stream in 1 MiB chunks; model files are too large to buffer.
                    async for chunk in response.content.iter_chunked(1024 * 1024):
                        await f.write(chunk)

        # TODO: Check hash

        Log.ok("Civit AI", f"Model {id} ({name}) successfully downloaded")



    async def download_model_from_url(self, url, save_dir, name, hash = None):
        """Download a model given a URL that contains ``models/<id>``.

        Raises:
            ValueError: If no model id can be found in ``url``.
        """
        id = self.__extract_model_id(url)
        return await self.download_model(id, save_dir, name, hash)



    async def get_model_data(self, id):
        """Fetch a single model by id and return it in the converted dict form."""
        Log.info("Civit AI", f"Load model info. Model ID: {id}")
        model = await self.__get_model_json(id)
        Log.ok("Civit AI", f"Model info loaded - {id} ({model['name']})")

        return self.__convert_json_model_data(model)



    @staticmethod
    def __extract_model_id(url):
        """Return the numeric model id (as a string) found after ``models/`` in ``url``.

        Raises:
            ValueError: If ``url`` contains no ``models/<digits>`` part.
        """
        match = re.search(r"models/([0-9]+)", url)
        if match is None:
            raise ValueError(f"Unable to find a model id in URL: {url}")
        return match[1]



    def __convert_json_model_data(self, model):
        """Flatten an API model object into the dict used by this class.

        Only the first entry of ``modelVersions`` is considered, and from it
        the first file whose type is not ``"Training Data"``.

        Raises:
            ValueError: If the model has no versions or no such file.
            KeyError: If an expected field is missing from the API data.
        """
        versions = model["modelVersions"]
        if not versions:
            raise ValueError(f"Model {model['id']} has no versions")

        first_model = versions[0]
        first_model_file = next((item for item in first_model["files"] if item["type"] != "Training Data"), None)
        if first_model_file is None:
            raise ValueError(f"Model {model['id']} has no downloadable file in its first version")

        return {
            "url": f"https://civitai.com/models/{model['id']}",
            "tags": model["tags"],
            "id": model["id"],
            "type": model["type"],
            "triggers": first_model.get("trainedWords", ""),
            "download": first_model_file["downloadUrl"],
            "format": first_model_file["metadata"]["format"],
            "name": model["name"],
            "base_model": first_model["baseModel"],
            "version": first_model["name"],
            "images": [o["url"] for o in first_model["images"]],
            "SHA256": first_model_file["hashes"]["SHA256"]
        }



    async def __get_json(self, page, params = None):
        """GET ``https://civitai.com/api/v1/<page>`` and return the decoded JSON.

        Args:
            page: Path below ``/api/v1/`` (e.g. ``"models"``).
            params: Optional query parameters.

        Raises:
            Exception: If the JSON answer contains an ``error`` entry.
        """
        headers = {"Content-type": "application/json"}

        # The session is a context manager too, so it is closed after the call
        # instead of being leaked with every request.
        async with aiohttp.ClientSession() as session:
            async with session.get(f'https://civitai.com/api/v1/{page}', params=params, headers=headers) as response:
                json_data = await response.json()

                if "error" in json_data:
                    # NOTE(review): the logged URL includes the query string,
                    # which carries the token for list requests.
                    Log.error("Civit AI", f"JSON get error ({response.real_url}). Error info: {json_data['error']}")
                    raise Exception(f"Civitai API error: {json_data['error']}")

        return json_data



    async def __get_models_json(self, cursor = "", types = None, sort = None, period = None, query = "", username = "", favorites = False, nsfw = True, limit = 100):
        """Request one page of the ``models`` endpoint.

        ``types``, ``sort`` and ``period`` default to ``[ModelsTypes.Checkpoint]``,
        ``ModelsSort.Newest`` and ``ModelsPeriod.AllTime`` when left as ``None``.
        Booleans are sent as the strings ``"true"`` / ``"false"``.
        """
        # Defaults are resolved per call so that no list object is shared
        # between calls through the signature.
        if types is None:
            types = [ModelsTypes.Checkpoint]
        if sort is None:
            sort = ModelsSort.Newest
        if period is None:
            period = ModelsPeriod.AllTime

        return await self.__get_json("models", {
            "cursor": cursor,
            "types": types,
            "sort": sort,
            "period": period,
            "query": query,
            "username": username,
            "favorites": "true" if favorites else "false",
            "nsfw": "true" if nsfw else "false",
            "limit": limit,
            "token": self.token
        })



    async def __get_model_json(self, id):
        """Return the raw API JSON of the model with the given id."""
        return await self.__get_json(f"models/{id}")



    async def get_model_from_url_json(self, url):
        """Return the raw API JSON of the model referenced by ``url``.

        Raises:
            ValueError: If no model id can be found in ``url``.
        """
        id = self.__extract_model_id(url)
        return await self.__get_model_json(id)



    def __get_extension_from_response(self, response):
        """Return the file extension announced by ``content-disposition``.

        Returns:
            The extension without the leading dot, or ``None`` when the header
            is missing, has no quoted ``filename``, or the file name has no
            extension.
        """
        cd = response.headers.get("content-disposition")
        if not cd:
            return None

        # ``[^"]+`` stops at the closing quote, so later quoted parameters in
        # the same header are not swallowed into the file name.
        match = re.search(r'filename="([^"]+)"', cd)
        if match is None:
            return None

        # splitext yields "" for names without a dot, which becomes None here
        # instead of treating the whole file name as an extension.
        extension = os.path.splitext(match[1])[1]
        return extension[1:] or None



    @staticmethod
    async def sha256(filename):
        """Return the upper-case hexadecimal SHA-256 digest of a file.

        Reading and hashing run in the event loop's default executor, in 1 MiB
        chunks, so the file is never held in memory as a whole and the event
        loop is not blocked while hashing.
        """
        def _hash_file():
            digest = hashlib.sha256()
            with open(filename, "rb") as f:
                for chunk in iter(lambda: f.read(1024 * 1024), b""):
                    digest.update(chunk)
            return digest.hexdigest().upper()

        return await asyncio.get_running_loop().run_in_executor(None, _hash_file)