"""Core database classes."""

from contextlib import asynccontextmanager
from typing import Any, Callable

from motor.motor_asyncio import (
    AsyncIOMotorClientSession,
    AsyncIOMotorDatabase,
)
from pymongo import ReturnDocument
from pymongo.errors import DuplicateKeyError

from virtool.mongo.identifier import AbstractIdProvider
from virtool.mongo.utils import id_exists
from virtool.types import Document, Projection


class Collection:
    """A wrapper for a Motor collection.

    Commonly used Motor collection methods are re-exposed directly as attributes.
    ``find_one_and_update`` is wrapped so that it always returns the updated
    document. The insert methods assign an ``_id`` from the parent :class:`Mongo`
    object's ``id_provider`` to any document that does not already have one.

    """

    def __init__(self, mongo: "Mongo", name: str):
        self.mongo = mongo
        self.name = name
        self._collection = mongo.motor_database[name]

        # Methods that need no extra behaviour are passed straight through to Motor.
        self.aggregate = self._collection.aggregate
        self.bulk_write = self._collection.bulk_write
        self.count_documents = self._collection.count_documents
        self.create_index = self._collection.create_index
        self.delete_many = self._collection.delete_many
        self.delete_one = self._collection.delete_one
        self.distinct = self._collection.distinct
        self.drop_index = self._collection.drop_index
        self.drop_indexes = self._collection.drop_indexes
        self.find_one = self._collection.find_one
        self.find = self._collection.find
        self.rename = self._collection.rename
        self.replace_one = self._collection.replace_one
        self.update_many = self._collection.update_many
        self.update_one = self._collection.update_one

    async def find_one_and_update(
        self,
        query: dict,
        update: dict,
        projection: Projection | None = None,
        upsert: bool = False,
        session: AsyncIOMotorClientSession | None = None,
    ) -> Document | None:
        """Update a single document and return the updated result.

        :param query: a MongoDB query used to select the document to update
        :param update: a MongoDB update
        :param projection: a projection to apply to the document instead of the default
        :param upsert: insert a new document if no existing document is found
        :param session: an optional Motor session to use
        :return: the updated document, or ``None`` if no document matched

        """
        document = await self._collection.find_one_and_update(
            query,
            update,
            projection=projection,
            # Return the document as it is *after* the update is applied.
            return_document=ReturnDocument.AFTER,
            upsert=upsert,
            session=session,
        )

        # Normalize a missing (or empty, fully projected-away) result to ``None``.
        return document or None

    async def insert_one(
        self,
        document: Document,
        session: AsyncIOMotorClientSession | None = None,
    ) -> Document:
        """Insert the `document`.

        If no `_id` is included in the `document`, one is generated using the
        parent :class:`Mongo` object's ``id_provider``. Generated IDs that are
        already in use are discarded and a new one is drawn. If a provided `_id`
        already exists, a :class:`DuplicateKeyError` will be raised by MongoDB.

        :param document: the document to insert
        :param session: an optional Motor session to use
        :return: the inserted document

        """
        if "_id" in document:
            await self._collection.insert_one(document, session=session)
            return document

        # Draw IDs until one is found that is not already in use. A loop is used
        # instead of recursion so repeated collisions cannot exhaust the stack.
        document_id = self.mongo.id_provider.get()

        while await id_exists(self, document_id, session):
            document_id = self.mongo.id_provider.get()

        inserted = {**document, "_id": document_id}
        await self._collection.insert_one(inserted, session=session)

        return inserted

    async def insert_many(
        self,
        documents: list[Document],
        session: AsyncIOMotorClientSession | None = None,
    ) -> list[Document]:
        """Insert multiple documents.

        Documents without an `_id` are assigned one using the parent
        :class:`Mongo` object's ``id_provider``. If a provided `_id` is already in
        use, a :class:`DuplicateKeyError` is raised and nothing is inserted.

        :param documents: the documents to insert
        :param session: an optional Motor session to use
        :return: the inserted documents, including their `_id` fields

        """
        if not documents:
            # PyMongo rejects an empty ``insert_many``; there is nothing to insert.
            return []

        inserted = await self._bulk_set_document_ids(documents, session=session)

        await self._collection.insert_many(inserted, session=session)

        return inserted

    async def _bulk_set_document_ids(
        self,
        documents: list[Document],
        session: AsyncIOMotorClientSession | None = None,
    ) -> list[Document]:
        """Set the `_id` field for each document in `documents` that does not already
        have one.

        A document whose `_id` is missing or falsy is given a generated ID.
        Documents that already have an `_id` keep it. If any provided `_id` is
        already in use, a :class:`DuplicateKeyError` is raised. If only generated
        IDs collide, a fresh set of IDs is generated and checked again.

        The input documents are not mutated; shallow copies are returned.

        :param documents: the documents to set `_id` fields for
        :param session: an optional Motor session to use
        :return: copies of the documents with `_id` fields set
        :raises DuplicateKeyError: if a provided `_id` is already in use
        """
        provided_ids = [
            document["_id"] for document in documents if document.get("_id")
        ]

        while True:
            id_documents = [
                {**document, "_id": document.get("_id") or self.mongo.id_provider.get()}
                for document in documents
            ]

            # Fetch which of the candidate IDs already exist so that collisions on
            # provided IDs (an error) can be told apart from collisions on
            # generated IDs (retry with new ones).
            in_use = await self._collection.distinct(
                "_id",
                {"_id": {"$in": [document["_id"] for document in id_documents]}},
                session=session,
            )

            if not in_use:
                return id_documents

            if any(document_id in in_use for document_id in provided_ids):
                raise DuplicateKeyError(
                    "One or more provided _id values are already in use",
                    code=11000,
                )


class Mongo:
    """Access to the application's MongoDB database.

    Binds a :class:`Collection` wrapper for each collection used by the
    application and provides helpers for working with sessions and transactions.

    """

    def __init__(
        self,
        motor_database: AsyncIOMotorDatabase,
        id_provider: AbstractIdProvider,
    ):
        """
        :param motor_database: the Motor database to wrap
        :param id_provider: the provider used to generate `_id` values for
            inserted documents that lack one
        """
        self.motor_database = motor_database
        # Sessions are started by the client. Attribute access on a Motor database
        # returns a collection of that name, so ``motor_database.start_session``
        # would be a collection called "start_session", not a callable.
        self.start_session = motor_database.client.start_session
        self.id_provider = id_provider

        self.analyses = self.bind_collection("analyses")
        self.files = self.bind_collection("files")
        self.groups = self.bind_collection("groups")
        self.history = self.bind_collection("history")
        self.hmm = self.bind_collection("hmm")
        self.indexes = self.bind_collection("indexes")
        self.jobs = self.bind_collection("jobs")
        self.keys = self.bind_collection("keys")
        self.labels = self.bind_collection("labels")
        self.migrations = self.bind_collection("migrations")
        self.otus = self.bind_collection("otus")
        self.tasks = self.bind_collection("tasks")
        self.references = self.bind_collection("references")
        self.samples = self.bind_collection("samples")
        self.settings = self.bind_collection("settings")
        self.sequences = self.bind_collection("sequences")
        self.sessions = self.bind_collection("sessions")
        self.status = self.bind_collection("status")
        self.subtraction = self.bind_collection("subtraction")
        self.users = self.bind_collection("users")

    def bind_collection(self, name: str) -> Collection:
        """Return a :class:`Collection` wrapper for the collection called `name`."""
        return Collection(self, name)

    @asynccontextmanager
    async def create_session(self):
        """Start a session with an open transaction and yield the session.

        The session and its transaction are managed by Motor's async context
        managers. Leaving the block normally commits the transaction, and leaving
        it by exception aborts it.
        """
        async with await self.motor_database.client.start_session() as session:
            async with session.start_transaction():
                yield session

    async def with_transaction(self, func: Callable) -> Any:
        """Run the passed async function in a MongoDB transaction.

        `func` is passed to Motor's ``with_transaction``, which calls it with the
        session as its only argument.

        :param func: the coroutine function to run inside the transaction
        :return: the value returned by `func`
        """
        async with await self.motor_database.client.start_session() as session:
            return await session.with_transaction(func)
