#!/usr/bin/env python3
""" Face Filterer for extraction in faceswap.py """

import logging

from lib.faces_detect import DetectedFace
from lib.logger import get_loglevel
from lib.vgg_face import VGGFace
from lib.utils import cv2_read_img
from plugins.extract.pipeline import Extractor

# Module-level logger shared by every function and class in this file
logger = logging.getLogger(__name__)  # pylint: disable=invalid-name


def avg(arr):
    """ Return the arithmetic mean of the values in ``arr``.

    ``arr`` must support both iteration and ``len()``. An empty sequence raises
    :class:`ZeroDivisionError`.
    """
    total = sum(arr)
    count = len(arr)
    return total * 1.0 / count


class FaceFilter():
    """ Face filter for extraction.

    Loads reference ("filter") and negative reference ("nfilter") images, runs them through the
    extraction pipeline to obtain an aligned face for each, and stores each face's VGG Face
    encoding. :func:`check` then compares an extracted face against those encodings to decide
    whether it should be kept.

    NB: we take only first face, so the reference file should only contain one face.

    Parameters
    ----------
    reference_file_paths: iterable
        Paths to images of the face(s) to keep. Each is read with ``cv2_read_img``.
    nreference_file_paths: iterable
        Paths to images of the face(s) to reject. Each is read with ``cv2_read_img``.
    detector:
        Detector passed to :class:`Extractor` (as its ``detector_name``).
    aligner:
        Aligner passed to :class:`Extractor` (as its ``aligner_name``).
    loglevel:
        Log level. Converted with ``get_loglevel`` and also passed to :class:`Extractor`.
    multiprocess: bool, optional
        Passed through to :class:`Extractor`. Default: ``False``
    threshold: float, optional
        Distance threshold used by :func:`check`. Default: ``0.4``
    """

    def __init__(self, reference_file_paths, nreference_file_paths, detector, aligner, loglevel,
                 multiprocess=False, threshold=0.4):
        logger.debug("Initializing %s: (reference_file_paths: %s, nreference_file_paths: %s, "
                     "detector: %s, aligner: %s, loglevel: %s, multiprocess: %s, threshold: %s)",
                     self.__class__.__name__, reference_file_paths, nreference_file_paths,
                     detector, aligner, loglevel, multiprocess, threshold)
        self.numeric_loglevel = get_loglevel(loglevel)
        self.vgg_face = VGGFace()
        # filepath -> {"type": "filter"|"nfilter", ...}; later steps add/remove keys in place
        self.filters = self.load_images(reference_file_paths, nreference_file_paths)
        self.align_faces(detector, aligner, loglevel, multiprocess)
        self.get_filter_encodings()
        self.threshold = threshold
        logger.debug("Initialized %s", self.__class__.__name__)

    @staticmethod
    def load_images(reference_file_paths, nreference_file_paths):
        """ Load the reference images.

        Parameters
        ----------
        reference_file_paths: iterable
            Paths to the "filter" images
        nreference_file_paths: iterable
            Paths to the "nfilter" images

        Returns
        -------
        dict
            Mapping of file path to ``{"image": <loaded image>, "type": "filter"|"nfilter"}``.
            A path that appears in both lists ends up as "nfilter", because the second loop
            overwrites the first.
        """
        retval = dict()
        for fpath in reference_file_paths:
            retval[fpath] = {"image": cv2_read_img(fpath, raise_error=True),
                             "type": "filter"}
        for fpath in nreference_file_paths:
            retval[fpath] = {"image": cv2_read_img(fpath, raise_error=True),
                             "type": "nfilter"}
        logger.debug("Loaded filter images: %s", {k: v["type"] for k, v in retval.items()})
        return retval

    # Extraction pipeline
    def align_faces(self, detector_name, aligner_name, loglevel, multiprocess):
        """ Use the requested detector and aligner to get faces and landmarks for the filter
        images, then build the aligned faces used as VGG Face input. """
        extractor = Extractor(detector_name, aligner_name, loglevel, multiprocess=multiprocess)
        self.run_extractor(extractor)
        del extractor
        self.load_aligned_face()

    def run_extractor(self, extractor):
        """ Run the extractor over every filter image and store the results in :attr:`filters`.

        On each of ``extractor.passes`` passes, the images are queued and the extractor is
        launched. The detected faces of each image are stored under ``"detected_faces"``. If
        more than one face is found, only the first is kept. On the final pass the landmarks
        are also stored, under ``"landmarks"``.

        If the extractor reports an ``"exception"`` for an item, the rest of that pass and all
        later passes are skipped.
        """
        exception = False
        for _ in range(extractor.passes):
            self.queue_images(extractor)
            # An exception reported in the previous pass stops further passes
            if exception:
                break
            extractor.launch()
            for faces in extractor.detected_faces():
                exception = faces.get("exception", False)
                if exception:
                    break
                filename = faces["filename"]
                detected_faces = faces["detected_faces"]

                if len(detected_faces) > 1:
                    logger.warning("Multiple faces found in %s file: '%s'. Using first detected "
                                   "face.", self.filters[filename]["type"], filename)
                    detected_faces = [detected_faces[0]]
                self.filters[filename]["detected_faces"] = detected_faces

                # Aligner output
                if extractor.final_pass:
                    landmarks = faces["landmarks"]
                    self.filters[filename]["landmarks"] = landmarks

    def queue_images(self, extractor):
        """ Queue every filter image on the extractor's input queue, followed by ``"EOF"``.

        If an image already has a non-empty ``"detected_faces"`` entry from an earlier pass,
        it is sent along with the image (presumably so later passes reuse the detections).
        """
        in_queue = extractor.input_queue
        for fname, img in self.filters.items():
            logger.debug("Adding to filter queue: '%s' (%s)", fname, img["type"])
            feed_dict = dict(filename=fname, image=img["image"])
            if img.get("detected_faces", None):
                feed_dict["detected_faces"] = img["detected_faces"]
            logger.debug("Queueing filename: '%s' items: %s",
                         fname, list(feed_dict.keys()))
            in_queue.put(feed_dict)
        logger.debug("Sending EOF to filter queue")
        in_queue.put("EOF")

    def load_aligned_face(self):
        """ Align the faces for vgg_face input.

        For each filter image, build a :class:`DetectedFace` from the first detected bounding
        box and the first landmarks entry, and align it at size 224. The result is stored under
        ``"face"`` and the source ``"image"`` is deleted (presumably to release memory).

        An image with no detected face raises ``IndexError`` (empty ``"detected_faces"``) or
        ``KeyError`` (missing ``"detected_faces"``/``"landmarks"``).
        """
        for filename, face in self.filters.items():
            logger.debug("Loading aligned face: '%s'", filename)
            bounding_box = face["detected_faces"][0]
            image = face["image"]
            landmarks = face["landmarks"][0]

            detected_face = DetectedFace()
            detected_face.from_bounding_box_dict(bounding_box, image)
            detected_face.landmarksXY = landmarks
            detected_face.load_aligned(image, size=224)
            face["face"] = detected_face.aligned_face
            del face["image"]
            logger.debug("Loaded aligned face: ('%s', shape: %s)",
                         filename, face["face"].shape)

    def get_filter_encodings(self):
        """ Compute the Keras VGG Face encoding of every filter face.

        Each encoding is stored under ``"encoding"`` in :attr:`filters`, and the aligned
        ``"face"`` is deleted once it has been encoded. Nothing is returned.
        """
        for filename, face in self.filters.items():
            logger.debug("Getting encodings for: '%s'", filename)
            encodings = self.vgg_face.predict(face["face"])
            logger.debug("Filter Filename: %s, encoding shape: %s", filename, encodings.shape)
            face["encoding"] = encodings
            del face["face"]

    def check(self, detected_face):
        """ Check whether an extracted face should be kept.

        The face's VGG Face encoding is compared with every filter encoding. The value
        returned by ``find_cosine_similiarity`` is treated here as a distance: lower means
        more alike.

        Rules, applied in order:

        * filter samples exist and their average distance exceeds the threshold: reject
        * only nfilter samples exist and their average distance is below the threshold: reject
        * both exist and the nearest nfilter sample is closer than the nearest filter
          sample: reject
        * both exist and the nfilter samples are closer on average: reject
        * both exist: k-nearest-neighbour vote, reject if fewer than half of the k nearest
          samples are filter samples
        * otherwise (including when there are no samples at all): accept

        Parameters
        ----------
        detected_face: :class:`lib.faces_detect.DetectedFace`
            The extracted face. Its ``aligned_face`` is encoded with VGG Face.

        Returns
        -------
        bool
            ``True`` to keep the face, ``False`` to reject it
        """
        logger.trace("Checking face with FaceFilter")
        distances = {"filter": list(), "nfilter": list()}
        encodings = self.vgg_face.predict(detected_face.aligned_face)
        for filt in self.filters.values():
            similarity = self.vgg_face.find_cosine_similiarity(filt["encoding"], encodings)
            distances[filt["type"]].append(similarity)

        has_filter = bool(distances["filter"])
        has_nfilter = bool(distances["nfilter"])
        avgs = {key: avg(val) if val else None for key, val in distances.items()}
        mins = {key: min(val) if val else None for key, val in distances.items()}
        # Filter
        if has_filter and avgs["filter"] > self.threshold:
            msg = "Rejecting filter face: {} > {}".format(round(avgs["filter"], 2), self.threshold)
            retval = False
        # nFilter no Filter. The has_nfilter guard avoids comparing None with the threshold
        # when no reference samples exist at all.
        elif not has_filter and has_nfilter and avgs["nfilter"] < self.threshold:
            msg = "Rejecting nFilter face: {} < {}".format(round(avgs["nfilter"], 2),
                                                           self.threshold)
            retval = False
        # Filter with nFilter
        elif has_filter and has_nfilter and mins["filter"] > mins["nfilter"]:
            msg = ("Rejecting face as distance from nfilter sample is smaller: (filter: {}, "
                   "nfilter: {})".format(round(mins["filter"], 2), round(mins["nfilter"], 2)))
            retval = False
        elif has_filter and has_nfilter and avgs["filter"] > avgs["nfilter"]:
            # Report the averages: they are what this rule compares
            msg = ("Rejecting face as average distance from nfilter sample is smaller: (filter: "
                   "{}, nfilter: {})".format(round(avgs["filter"], 2), round(avgs["nfilter"], 2)))
            retval = False
        elif has_filter and has_nfilter:
            # k-nn classifier: label filter samples 1 and nfilter samples 0, take the k
            # nearest (smallest distance) samples, and measure the share that are filter samples
            var_k = min(5, min(len(distances["filter"]), len(distances["nfilter"])) + 1)
            labelled = sorted([(1, dist) for dist in distances["filter"]] +
                              [(0, dist) for dist in distances["nfilter"]],
                              key=lambda item: item[1])
            var_n = sum(label for label, _ in labelled[:var_k])
            ratio = var_n / var_k
            if ratio < 0.5:
                msg = ("Rejecting face as k-nearest neighbors classification is less than "
                       "0.5: {}".format(round(ratio, 2)))
                retval = False
            else:
                msg = None
                retval = True
        else:
            msg = None
            retval = True
        if msg:
            logger.verbose(msg)
        else:
            logger.trace("Accepted face: (similarity: %s, threshold: %s)",
                         distances, self.threshold)
        return retval
