#!/usr/bin/env python
# Copyright (C) 2024 OVH SAS
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

from oio.common.green import eventlet_monkey_patch

# Called before every other import of this module.
# NOTE(review): presumably so that the modules imported below pick up the
# patched (cooperative) versions of the standard library; confirm against
# oio.common.green.
eventlet_monkey_patch()

import argparse
import json
import signal
import time
from os.path import exists, join
from collections import OrderedDict
from multiprocessing.queues import Empty

from oio.api.object_storage import ObjectStorageApi
from oio.common.configuration import read_conf
from oio.common.constants import MULTIUPLOAD_SUFFIX
from oio.common.exceptions import NoSuchContainer
from oio.common.easy_value import int_value
from oio.common.green import GreenPool, LightQueue, sleep
from oio.common.kafka import KafkaSender
from oio.common.logger import get_logger
from oio.common.statsd import get_statsd
from oio.common.timestamp import Timestamp
from oio.common.utils import cid_from_name
from oio.container.sharding import ContainerSharding
from oio.event.evob import EventTypes
from oio.lifecycle.metrics import (
    LifecycleMetricTracker,
    LifecycleAction,
    LifecycleStep,
    statsd_key,
)


# Topic the checkpoint events are sent to when the configuration has no
# "topic" entry.
CHECKPOINT_TOPIC_DEFAULT = "oio-lifecycle-checkpoint"


class Aborted(Exception):
    """Error type denoting an aborted operation."""

    # NOTE(review): nothing in the visible code raises or catches this
    # exception; check for external users before removing it.


class CheckpointCollector:
    """
    List the buckets having a given feature and emit a lifecycle checkpoint
    event for each container found for them.

    Buckets are processed concurrently in a green pool. The last committed
    (account, bucket) marker is persisted in a per-run progress file, which is
    reloaded at construction so the listing restarts from that marker.
    """

    def __init__(self, conf, logger, feature, run_id):
        """
        :param conf: configuration mapping. Keys read here: "namespace"
            (mandatory), "concurrency", "topic", "endpoint" and
            "progress_dir". The whole mapping is also handed to the Kafka,
            sharding, metrics and statsd helpers.
        :param logger: logger used by the collector and its clients
        :param feature: name of the bucket feature used to list buckets
        :param run_id: identifier of the run; used in events, metrics and in
            the names of the progress and error files
        """
        self._conf = conf
        self._logger = logger
        self._running = False
        self._feature_name = feature
        self._marker = None

        # Configuration
        self._concurrency = int_value(self._conf.get("concurrency"), 100)
        self._topic = self._conf.get("topic", CHECKPOINT_TOPIC_DEFAULT)

        # Threading
        self._pool = GreenPool(self._concurrency)
        self._result_queue = LightQueue()
        # Set by run() once the bucket listing is over: until then the
        # progress tracker must not stop, even if it is momentarily idle.
        self._listing_done = False

        # Event producer
        self._kafka_producer = KafkaSender(
            self._conf.get("endpoint"),
            self._logger,
            app_conf=self._conf,
        )

        # Oio clients
        namespace = conf["namespace"]
        self._api = ObjectStorageApi(namespace, logger=logger)
        self._sharding_client = ContainerSharding(
            self._conf, logger=logger, pool_manager=self._api.container.pool_manager
        )

        # Metrics helper
        self._metrics = LifecycleMetricTracker(self._conf)

        # Statsd helpers
        self._statsd = get_statsd(self._conf)

        # Progress tracking
        self._progress_dir = self._conf.get("progress_dir", ".")
        self._run_id = run_id
        # (account, bucket) -> None while in progress, then True or the error.
        # Insertion order is the order in which buckets started processing.
        self._progress = OrderedDict()
        self._last_commited_marker = None

        self._reload_progress()

    def _reload_progress(self):
        """
        Restore the resume marker from the progress file of this run, if any.

        The file is expected to hold a single "account;bucket" line. A
        missing, empty or malformed file leaves the markers untouched.
        """
        self._logger.info("Retrieve progress from file: '%s'", self.progress_file)
        if not exists(self.progress_file):
            return
        with open(self.progress_file, "r", encoding="utf-8") as progress_file:
            # The marker is written with a trailing newline: strip it,
            # otherwise it ends up inside the bucket name.
            line = progress_file.readline().rstrip("\n")
        if not line:
            return
        parts = line.split(";")
        if len(parts) != 2:
            self._logger.warning("Ignoring malformed progress line: %r", line)
            return
        account, bucket = parts
        self._logger.info("Reload marker account=%s, bucket=%s", account, bucket)
        self._last_commited_marker = self._marker = (account, bucket)

    @property
    def progress_file(self):
        """Path of the file holding the last committed marker of this run."""
        return join(self._progress_dir, f"checkpoint-collector.progress.{self._run_id}")

    @property
    def error_file(self):
        """Path of the file listing the buckets which failed during this run."""
        return join(self._progress_dir, f"checkpoint-collector.error.{self._run_id}")

    def _make_payload(self, account, bucket, cid, shard_info):
        """
        Build the JSON-encoded checkpoint event.

        :param account: account name put in the event
        :param bucket: bucket (container) name put in the event
        :param cid: container id, always reported as "root_cid"
        :param shard_info: optional mapping with "cid", "lower" and "upper";
            when it has a "cid", that one is reported as "cid" instead of
            the root one
        :returns: the event serialized as a JSON string
        """
        shard_info = shard_info or {}
        return json.dumps(
            {
                "event": EventTypes.LIFECYCLE_CHECKPOINT,
                "when": time.time(),
                "data": {
                    "run_id": self._run_id,
                    "account": account,
                    "bucket": bucket,
                    "cid": shard_info.get("cid") or cid,
                    "root_cid": cid,
                    "bounds": {
                        "lower": shard_info.get("lower", ""),
                        "upper": shard_info.get("upper", ""),
                    },
                },
            }
        )

    def _increment_snapshot_counter(self, ctx, cid):
        """
        Count one submitted checkpoint, both in statsd and in the lifecycle
        metrics.

        :param ctx: (account, bucket) tuple the counter is attached to;
            nothing is counted when it is empty
        :param cid: id of the container the checkpoint was submitted for
        """
        if not ctx:
            return
        account, bucket = ctx
        self._statsd.incr(
            statsd_key(
                self._run_id, LifecycleStep.SUBMITTED, LifecycleAction.CHECKPOINT
            )
        )
        self._metrics.increment_counter(
            self._run_id,
            account,
            bucket,
            cid,
            LifecycleStep.SUBMITTED,
            LifecycleAction.CHECKPOINT,
        )

    def _process_container(self, account, container, ctx=None):
        """
        Emit the checkpoint event of a container and update the counters.

        :param account: account owning the container
        :param container: name of the container
        :param ctx: optional (account, bucket) tuple used for the metrics
        """
        cid = cid_from_name(account, container)
        # Produce event for root container
        self._produce_event(account, container, cid)
        self._increment_snapshot_counter(ctx, cid)

    def _produce_event(self, account, container, cid, shard_info=None):
        """Build the checkpoint event and send it (flushed) to the topic."""
        payload = self._make_payload(account, container, cid, shard_info)
        self._logger.info("Produce event %s", container)
        self._kafka_producer.send(self._topic, payload, flush=True)

    def _process_bucket(self, account, bucket):
        """
        Process the containers of a bucket and report the outcome.

        The bucket is looked up in the account itself and in the account
        with the multipart upload suffix; missing containers are skipped.
        Any other failure is logged and reported instead of being raised.
        An (account, bucket, error) tuple is always pushed to the result
        queue, with error set to None on success.
        """
        error = None
        # Register the bucket as "in progress" before doing anything else
        self._progress[(account, bucket)] = None
        self._logger.info("Processing %s %s", account, bucket)
        ctx = (account, bucket)
        try:
            for acct_suffix in ("", MULTIUPLOAD_SUFFIX):
                _account = f"{account}{acct_suffix}"
                try:
                    self._logger.info(
                        "Get info on container: acct=%s, ref=%s", _account, bucket
                    )
                    # Only used as an existence check
                    self._api.container_show(_account, bucket)
                    # Process the container which has just been checked
                    # (not the base account again, which would emit a
                    # duplicate event for the root container). Metrics stay
                    # attached to the bucket through ctx.
                    # NOTE(review): confirm the suffix is expected on the
                    # account rather than on the container name.
                    self._process_container(_account, bucket, ctx=ctx)
                except NoSuchContainer:
                    continue
        except Exception as exc:
            self._logger.error("Failed to process bucket %s, reason: %s", bucket, exc)
            error = exc
        self._result_queue.put((account, bucket, error))

    def _fetch_buckets(self, marker=None):
        """
        Iterate over the buckets having the feature, page by page.

        :param marker: where to start the listing from, either a string or
            an (account, bucket) tuple (joined with "|")
        :returns: a generator of (account, bucket) tuples
        """
        if isinstance(marker, tuple):
            marker = "|".join(marker)

        while True:
            resp = self._api.bucket.buckets_list_by_feature(
                self._feature_name, marker=marker, limit=100
            )
            for entry in resp.get("buckets", []):
                yield entry["account"], entry["bucket"]
            if not resp.get("truncated", False):
                break
            marker = resp.get("next_marker")
            # A truncated page without marker cannot be continued
            if not marker:
                break

    def _compute_progress(self, account, bucket, status):
        """
        Record the status of a bucket and move the committed marker forward.

        Entries are removed from the head of the progress map as long as
        they are completed; the last one removed becomes the committed
        marker. The marker thus never goes past a bucket still in progress.

        :param status: outcome of the bucket, anything but None
        """
        self._progress[(account, bucket)] = status
        marker = None

        while self._progress:
            key, value = next(iter(self._progress.items()))
            if value is None:
                # Oldest bucket still in progress: stop here
                break
            self._progress.popitem(last=False)
            marker = key

        if marker is not None:
            self._last_commited_marker = marker

    def _write_marker(self, progress_file, marker):
        """Replace the content of the open progress file with the marker."""
        self._logger.info("Update progress file %s", marker)
        account, bucket = marker
        progress_file.truncate(0)
        progress_file.seek(0)
        progress_file.write(f"{account};{bucket}\n")
        # Do not leave the marker in the write buffer: the file is only
        # useful if the run is interrupted.
        progress_file.flush()

    def _fetch_progression(self):
        """
        Consume the results of the processed buckets until the listing is
        over and no bucket is in progress anymore.

        Failures are appended to the error file (one "account;bucket;error"
        line each) and the progress file is rewritten each time the
        committed marker changes.
        """
        last_marker = None
        with open(self.progress_file, "w", encoding="utf-8") as progress_file:
            with open(self.error_file, "w", encoding="utf-8") as error_file:
                # Opening the progress file truncated it: write back the
                # marker reloaded at startup so it is not lost if the run
                # stops before any new commit.
                if self._last_commited_marker:
                    self._write_marker(progress_file, self._last_commited_marker)
                    last_marker = self._last_commited_marker
                while True:
                    try:
                        result = self._result_queue.get(timeout=1)
                    except Empty:
                        if self._listing_done and not self._progress:
                            break
                        # Nothing to consume yet: buckets are still being
                        # listed or processed.
                        continue
                    account, bucket, error = result
                    if error:
                        error_file.write(f"{account};{bucket};{error}\n")
                        error_file.flush()
                    self._compute_progress(account, bucket, error or True)
                    if (
                        self._last_commited_marker
                        and self._last_commited_marker != last_marker
                    ):
                        self._write_marker(progress_file, self._last_commited_marker)
                        last_marker = self._last_commited_marker

    def __stop(self):
        """Ask the collector to stop (signal handler helper)."""
        self._logger.info("Stopping")
        self._running = False

    def run(self):
        """
        List the buckets, process each of them in the pool and wait for the
        completion of all tasks, then close the event producer.

        SIGINT and SIGTERM handlers are installed: on reception, the listing
        stops and the pending tasks are cancelled.
        """
        # Install signal handlers
        signal.signal(signal.SIGINT, lambda _sig, _stack: self.__stop())
        signal.signal(signal.SIGTERM, lambda _sig, _stack: self.__stop())
        self._running = True
        self._listing_done = False

        task_progression = self._pool.spawn(self._fetch_progression)

        tasks = [task_progression]

        def cancel_pending_tasks():
            # NOTE(review): presumably only affects tasks which have not
            # started yet (eventlet GreenThread.cancel); confirm.
            self._logger.warning("Aborting pending tasks")
            for task in tasks:
                task.cancel()

        self._logger.info(
            "Listing bucket with feature '%s' activated", self._feature_name
        )
        # When resuming from a marker, an empty listing is not suspicious
        has_process_bucket = bool(self._marker)
        try:
            for account, bucket in self._fetch_buckets(self._marker):
                if not self._running:
                    cancel_pending_tasks()
                    break
                task = self._pool.spawn(self._process_bucket, account, bucket)
                tasks.append(task)
                has_process_bucket = True
        finally:
            # Whatever happened, let the progress tracker terminate
            self._listing_done = True
        if not has_process_bucket:
            self._logger.warning(
                "No buckets found with feature '%s'. Ensure the feature is tracked"
                "(features_whitelist)",
                self._feature_name,
            )

        while (self._pool.running() + self._pool.waiting()) > 0:
            if not self._running:
                cancel_pending_tasks()
                break
            sleep(1)
        # Let all threads end
        self._pool.waitall()

        self._kafka_producer.close()


def make_arg_parser():
    """
    Build the command line parser of the checkpoint collector.

    :returns: an argparse.ArgumentParser accepting "--verbose"/"-v",
        "--run-id" and the positional "configuration" path
    """
    arg_parser = argparse.ArgumentParser(
        description="""
    Generate events to create checkpoints for lifecycle enabled containers
    """
    )

    # Each entry: (names of the argument, keyword settings)
    arguments = (
        (
            ("--verbose", "-v"),
            {"action": "store_true", "help": "More verbose output"},
        ),
        (
            ("--run-id",),
            {
                "help": "Run identifier",
                # Defaults to the value of a freshly built Timestamp
                "default": str(Timestamp().timestamp),
            },
        ),
        (
            ("configuration",),
            {"help": "Path to the legacy configuration file"},
        ),
    )
    for names, settings in arguments:
        arg_parser.add_argument(*names, **settings)
    return arg_parser


def main():
    """
    Command line entry point: parse the arguments, load the
    "checkpoint-collector" section of the configuration file and run a
    collector for the "lifecycle" feature.
    """
    options = make_arg_parser().parse_args()
    settings = read_conf(options.configuration, "checkpoint-collector")
    log = get_logger(settings, verbose=options.verbose)

    CheckpointCollector(settings, log, "lifecycle", options.run_id).run()


# Run the collector only when this file is executed as a script.
if __name__ == "__main__":
    main()
