import json
import os
import re
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import pandas as pd
from huggingface_hub import hf_hub_download
from numpy import asarray, exp, expand_dims, float32
from PIL import Image

# Matches a single backslash or parenthesis and captures it, so the match can
# be re-emitted with a backslash in front when a tag is escaped.
tag_escape_pattern = re.compile(r"([\\()])")

from . import dbimutils


class Interrogator:
    """Base class for image taggers backed by an ONNX Runtime session.

    The class stores the shared state (``model``, ``tags``, ``name`` and the
    list of execution ``providers``) and offers the tag post-processing
    helper. Subclasses are expected to implement :meth:`load` and
    :meth:`interrogate`.
    """

    @staticmethod
    def postprocess_tags(
        tags: Dict[str, float],
        threshold=0.35,
        additional_tags: List[str] = [],
        exclude_tags: List[str] = [],
        sort_by_alphabetical_order=False,
        add_confident_as_weight=False,
        replace_underscore=False,
        replace_underscore_excludes: List[str] = [],
        escape_tag=False,
    ) -> Dict[str, float]:
        """Filter, order and reformat a mapping of tag -> confidence.

        The input mapping is never modified; a new dict is returned. The list
        defaults are only read, never mutated, so sharing them between calls
        is harmless.

        Args:
            tags: Tag name -> confidence, as produced by ``interrogate``.
            threshold: Tags whose confidence is not ``>= threshold`` are
                dropped.
            additional_tags: Tags that are forced in with confidence ``1.0``
                (overriding any existing confidence for the same name).
            exclude_tags: Tag names to drop regardless of confidence.
            sort_by_alphabetical_order: Sort ascending by tag name when true,
                otherwise descending by confidence.
            add_confident_as_weight: Render each tag as ``(tag:confidence)``.
            replace_underscore: Replace ``_`` with a space in tag names.
            replace_underscore_excludes: Tag names that keep their
                underscores even when ``replace_underscore`` is set.
            escape_tag: Backslash-escape backslashes and parentheses in the
                tag name using the module-level ``tag_escape_pattern``.

        Returns:
            A new dict of formatted tag -> confidence in the requested order.
            If two tags end up with the same formatted name, the later one's
            confidence wins.
        """
        # Work on a copy: writing the forced tags into the caller's dict would
        # leak them into whatever raw result the caller keeps around.
        merged = dict(tags)
        for forced in additional_tags:
            merged[forced] = 1.0

        # Build the lookup sets once instead of scanning lists per tag.
        excluded = frozenset(exclude_tags)
        keep_underscore = frozenset(replace_underscore_excludes)

        # Index 0 sorts by name (ascending), index 1 by confidence (descending).
        sort_index = 0 if sort_by_alphabetical_order else 1
        ordered = sorted(
            merged.items(),
            key=lambda item: item[sort_index],
            reverse=not sort_by_alphabetical_order,
        )

        result: Dict[str, float] = {}
        for tag, confidence in ordered:
            if confidence >= threshold and tag not in excluded:
                formatted = tag

                # The exclusion is matched against the original tag name.
                if replace_underscore and tag not in keep_underscore:
                    formatted = formatted.replace("_", " ")

                if escape_tag:
                    formatted = tag_escape_pattern.sub(r"\\\1", formatted)

                if add_confident_as_weight:
                    formatted = f"({formatted}:{confidence})"

                result[formatted] = confidence

        return result

    def __init__(self, name: str) -> None:
        """Initialise an unloaded interrogator called ``name``.

        Importing ``InferenceSession`` here means construction fails early
        when ``onnxruntime`` is not installed.
        """
        from onnxruntime import InferenceSession

        self.model: InferenceSession | None = None
        self.tags = None
        self.name = name
        # Providers handed to InferenceSession by subclasses; see use_cpu().
        self.providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]

    def load(self):
        """Load the model and its tag list; must be provided by subclasses."""
        raise NotImplementedError()

    def unload(self) -> bool:
        """Drop the references to the model and tag list.

        The attributes are reset to ``None`` instead of being deleted so that
        later ``self.model is None`` checks keep working rather than raising
        ``AttributeError``.

        Returns:
            ``True`` if a model was loaded (and has now been released),
            ``False`` otherwise.
        """
        # getattr tolerates instances whose __init__ was never run.
        unloaded = getattr(self, "model", None) is not None

        self.model = None
        self.tags = None

        if unloaded:
            print(f"Unloaded {self.name}")

        return unloaded

    def use_cpu(self) -> None:
        """Restrict the provider list to the CPU execution provider."""
        self.providers = ["CPUExecutionProvider"]

    def interrogate(
        self, image: "Image.Image"
    ) -> Tuple[
        Dict[str, float], Dict[str, float]  # rating confidents  # tag confidents
    ]:
        """Tag ``image``; must be provided by subclasses.

        Returns:
            A ``(ratings, tags)`` pair, each mapping a name to a confidence.
        """
        raise NotImplementedError()


class WaifuDiffusionInterrogator(Interrogator):
    """Interrogator that runs an ONNX model with a CSV tag list.

    Both files are fetched with ``hf_hub_download``; any extra keyword
    arguments given to the constructor (``repo_id`` among them) are forwarded
    to that function unchanged.
    """

    def __init__(
        self,
        name: str,
        model_path="model.onnx",
        tags_path="selected_tags.csv",
        **kwargs,
    ) -> None:
        """Remember the file names and the ``hf_hub_download`` arguments.

        Args:
            name: Display name used in log messages.
            model_path: File name of the ONNX model inside the repository.
            tags_path: File name of the tag CSV inside the repository.
            **kwargs: Forwarded to ``hf_hub_download``; ``download`` reads
                ``kwargs["repo_id"]`` for its log message.
        """
        super().__init__(name)
        self.model_path = model_path
        self.tags_path = tags_path
        self.kwargs = kwargs

    def download(self) -> Tuple[os.PathLike, os.PathLike]:
        """Fetch the model and tag files and return their local paths.

        Raises:
            KeyError: If no ``repo_id`` keyword was given to the constructor.
        """
        print(f"Loading {self.name} model file from {self.kwargs['repo_id']}")

        model_path = Path(hf_hub_download(**self.kwargs, filename=self.model_path))
        tags_path = Path(hf_hub_download(**self.kwargs, filename=self.tags_path))
        return model_path, tags_path

    def load(self) -> None:
        """Download the files, create the session and read the tag CSV."""
        model_path, tags_path = self.download()

        from onnxruntime import InferenceSession

        session = InferenceSession(str(model_path), providers=self.providers)

        print(f"Loaded {self.name} model from {model_path}")

        tag_table = pd.read_csv(tags_path)

        # Publish both pieces only once everything succeeded, so a failed CSV
        # read cannot leave a model without tags behind.
        self.model = session
        self.tags = tag_table

    def interrogate(
        self, image: Image.Image
    ) -> Tuple[
        Dict[str, float], Dict[str, float]  # rating confidents  # tag confidents
    ]:
        """Run the model on ``image`` and split the output into ratings and tags.

        The model is loaded on first use.

        Returns:
            ``(ratings, tags)``: the first four rows of the tag table mapped
            to their confidence, and all remaining rows mapped to theirs.

        Raises:
            ValueError: If the number of model outputs differs from the
                number of rows in the tag table.
        """
        # init model (also covers a missing attribute or missing tag table)
        if getattr(self, "model", None) is None or getattr(self, "tags", None) is None:
            self.load()

        # code for converting the image and running the model is taken from the link below
        # thanks, SmilingWolf!
        # https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags/blob/main/app.py

        # convert an image to fit the model
        assert self.model is not None
        assert self.tags is not None
        model_input = self.model.get_inputs()[0]
        # The second entry of the 4-element input shape is used as the target
        # edge length (presumably an NHWC input -- confirm against the model).
        _, height, _, _ = model_input.shape

        # alpha to white
        rgba = image.convert("RGBA")
        canvas = Image.new("RGBA", rgba.size, "WHITE")
        canvas.paste(rgba, mask=rgba)
        image_array = np.asarray(canvas.convert("RGB"))

        # PIL RGB to OpenCV BGR
        image_array = image_array[:, :, ::-1]

        image_array = dbimutils.make_square(image_array, height)
        image_array = dbimutils.smart_resize(image_array, height)
        image_array = image_array.astype(np.float32)
        # add the leading batch axis
        image_array = np.expand_dims(image_array, 0)

        # evaluate model
        label_name = self.model.get_outputs()[0].name
        confidents = self.model.run([label_name], {model_input.name: image_array})[0]

        # Pair names with scores directly instead of adding a column to a
        # DataFrame derived by indexing, which pandas flags with
        # SettingWithCopyWarning and which copies a frame per call.
        names = self.tags["name"].tolist()
        scores = confidents[0].tolist()
        if len(names) != len(scores):
            raise ValueError(
                f"Model returned {len(scores)} confidences "
                f"but the tag list has {len(names)} entries"
            )

        # first 4 items are for rating (general, sensitive, questionable, explicit)
        ratings = dict(zip(names[:4], scores[:4]))

        # rest are regular tags
        tags = dict(zip(names[4:], scores[4:]))

        return ratings, tags


class MLDanbooruInterrogator(Interrogator):
    """Interrogator for the MLDanbooru model."""

    def __init__(
        self,
        name: str,
        repo_id: str,
        model_path: str,
        tags_path="classes.json",
    ) -> None:
        """Remember where the model and its class list live.

        Args:
            name: Display name used in log messages.
            repo_id: Repository passed to ``hf_hub_download``.
            model_path: File name of the ONNX model inside the repository.
            tags_path: File name of the JSON class list inside the repository.
        """
        super().__init__(name)
        self.model_path = model_path
        self.tags_path = tags_path
        self.repo_id = repo_id
        self.tags = None
        self.model = None

    def download(self) -> Tuple[str, str]:
        """Fetch the model and class-list files and return their local paths."""
        print(f"Loading {self.name} model file from {self.repo_id}")

        model_path = hf_hub_download(repo_id=self.repo_id, filename=self.model_path)
        tags_path = hf_hub_download(
            repo_id=self.repo_id,
            filename=self.tags_path,
        )
        return model_path, tags_path

    def load(self) -> None:
        """Download the files, create the session and read the JSON class list."""
        model_path, tags_path = self.download()

        from onnxruntime import InferenceSession

        session = InferenceSession(model_path, providers=self.providers)
        print(f"Loaded {self.name} model from {model_path}")

        with open(tags_path, "r", encoding="utf-8") as filen:
            class_names = json.load(filen)

        # Publish both pieces only once everything succeeded, so a failed JSON
        # read cannot leave a model without tags behind.
        self.model = session
        self.tags = class_names

    def interrogate(
        self, image: Image.Image
    ) -> Tuple[
        Dict[str, float], Dict[str, float]  # rating confidents  # tag confidents
    ]:
        """Run the model on ``image`` and return per-tag confidences.

        The model is loaded on first use.

        Returns:
            ``({}, tags)``: this model yields no ratings, so the first dict is
            always empty; ``tags`` maps each entry of the loaded class list to
            the sigmoid of the matching model output.
        """
        # init model; getattr keeps this working even if the attribute was
        # deleted instead of being reset to None
        if getattr(self, "model", None) is None or getattr(self, "tags", None) is None:
            self.load()

        image = dbimutils.fill_transparent(image)
        image = dbimutils.resize(image, 448)  # TODO CUSTOMIZE

        # scale pixel values by 1/255
        x = asarray(image, dtype=float32) / 255
        # HWC -> 1CHW
        x = x.transpose((2, 0, 1))
        x = expand_dims(x, 0)

        assert self.model is not None
        input_ = self.model.get_inputs()[0]
        output = self.model.get_outputs()[0]
        # evaluate model
        (y,) = self.model.run([output.name], {input_.name: x})

        # Element-wise sigmoid (not a softmax): each output is squashed
        # independently, so the confidences do not sum to one.
        y = 1 / (1 + exp(-y))

        assert self.tags is not None
        # zip stops at the shorter of the class list and the model output
        tags = {tag: float(conf) for tag, conf in zip(self.tags, y.flatten())}
        return {}, tags

    def large_batch_interrogate(self, images: List, dry_run=False) -> str:
        """Batch interrogation is not implemented for this model."""
        raise NotImplementedError()
