#!/usr/bin/env python
# Copyright (C) 2023 OVH SAS
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import argparse
import os
import socket

from functools import partial

from oio.common.amqp import AmqpConnector, AMQPError, ExchangeType
from oio.common.logger import get_logger

from oio.event.beanstalk import (
    BeanstalkdListener,
)


class AmqpSender(AmqpConnector):
    """AMQP connector that publishes messages, retrying on AMQP errors."""

    def publish(self, jobid, data, exchange, routing_key, attempts=3):
        """Publish one message, retrying when an AMQP error occurs.

        :param jobid: identifier of the job being forwarded; only used in
            the warning logged when an attempt fails.
        :param data: message body, passed as-is to ``basic_publish``.
        :param exchange: name of the exchange to publish to.
        :param routing_key: routing key attached to the message.
        :param attempts: maximum number of publish attempts (at least 1).
        :returns: the one-element list ``[1]``.
        :raises ValueError: if ``attempts`` is lower than 1.
        :raises AMQPError: the error of the last attempt, when every
            attempt failed.
        """
        if attempts < 1:
            # Without this guard the loop body never runs and the `else`
            # branch would execute `raise None`, i.e. an opaque TypeError.
            raise ValueError(f"attempts must be at least 1, got {attempts}")
        last_err = None
        for _ in range(attempts):
            try:
                self.maybe_reconnect()
                self._channel.basic_publish(
                    exchange=exchange,
                    routing_key=routing_key,
                    body=data,
                )
                break
            except AMQPError as err:
                last_err = err
                # Drop the connection; maybe_reconnect() is called again at
                # the start of the next attempt.
                self._close_conn(True)
                self.logger.warning("Failed to publish (jobid=%s): %s", jobid, err)
        else:
            # No attempt reached `break`: report the most recent failure.
            raise last_err
        # NOTE(review): presumably the caller (BeanstalkdListener.fetch_jobs)
        # iterates over the handler's result, hence a non-empty list rather
        # than None -- confirm against the listener before changing this.
        return [1]


def make_arg_parser():
    """Build the command-line parser of the Beanstalkd-to-RabbitMQ forwarder.

    The defaults of ``--dst-endpoint`` and ``--namespace`` are read from the
    ``OIO_RABBITMQ_ENDPOINT`` and ``OIO_NS`` environment variables when the
    parser is built, with hard-coded fallbacks when they are unset.

    :returns: a configured ``argparse.ArgumentParser``.
    """
    rabbitmq_default = os.environ.get(
        "OIO_RABBITMQ_ENDPOINT", "amqp://guest:guest@127.0.0.1:5672/%2F"
    )
    namespace_default = os.environ.get("OIO_NS", "OPENIO")

    # One (flags, add_argument keyword arguments) pair per option, listed in
    # the order they must appear in the generated help.
    option_table = [
        (
            ["--src-tube"],
            {
                "default": "oio",
                "help": (
                    "Name of the Beanstalkd tube to read messages from "
                    "(default is 'oio')"
                ),
            },
        ),
        (
            ["--src-endpoint"],
            {
                "default": "beanstalk://127.0.0.1:6005",
                "help": "Endpoint of the Beanstalkd server",
            },
        ),
        (
            ["--declare-exchange"],
            {
                "action": "store_true",
                "help": "Declare the exchange before binding it",
            },
        ),
        (
            ["--declare-queue"],
            {
                "action": "store_true",
                "help": "Declare the queue before binding it",
            },
        ),
        (
            ["--dst-endpoint"],
            {
                "default": rabbitmq_default,
                # '%%' is an escaped percent sign: argparse %-formats help.
                "help": (
                    "Endpoint of the RabbitMQ server. "
                    "Can be a list separated by ';' "
                    "(env: OIO_RABBITMQ_ENDPOINT, "
                    "default: amqp://guest:guest@127.0.0.1:5672/%%2F)"
                ),
            },
        ),
        (
            ["--dst-exchange"],
            {
                "default": "oio",
                "help": (
                    "Name of the RabbitMQ exchange to send messages to "
                    "(default is 'oio')"
                ),
            },
        ),
        (
            ["--dst-queue"],
            {
                "default": "oio",
                "help": (
                    "Name of the RabbitMQ queue to send messages to "
                    "(default is 'oio')"
                ),
            },
        ),
        (
            ["--namespace", "--ns"],
            {
                "default": namespace_default,
                "help": "Namespace name",
            },
        ),
        (
            ["--no-unbind"],
            {
                # The flag is negative: it stores False into `unbind`,
                # which therefore defaults to True.
                "action": "store_false",
                "dest": "unbind",
                "help": "Do not unbind the queue from the exchange when quitting",
            },
        ),
        (
            ["--verbose", "-v"],
            {
                "action": "store_true",
                "help": "More verbose output",
            },
        ),
    ]

    arg_parser = argparse.ArgumentParser(
        description=(
            "\n    Read messages from Beanstalkd, forward them to RabbitMQ.\n    "
        )
    )
    for flags, settings in option_table:
        arg_parser.add_argument(*flags, **settings)
    return arg_parser


def main():
    """Copy jobs from a Beanstalkd tube to a RabbitMQ exchange.

    Steps, in order:

    1. parse the command line and set up the logger;
    2. optionally declare the destination exchange and/or queue (failures
       are only logged);
    3. bind the destination queue to the exchange with a routing key made
       of the host name and the PID of this process;
    4. hand every job fetched from Beanstalkd to ``AmqpSender.publish``;
    5. on exit, unbind the queue again unless ``--no-unbind`` was given.

    The loop stops on ``KeyboardInterrupt`` (which is caught, and the number
    of copied messages is logged) or on any exception raised while fetching
    or publishing (which propagates to the caller).
    """
    args = make_arg_parser().parse_args()
    logger = get_logger(None, verbose=args.verbose)
    hostname = socket.gethostname()
    # Host name + PID: the key differs between instances of this tool.
    routing_key = f"oio-beanstalkd-to-rabbitmq-{hostname}-{os.getpid()}"
    bsl = BeanstalkdListener(args.src_endpoint, args.src_tube, logger=logger)
    rab = AmqpSender(endpoints=args.dst_endpoint, logger=logger)

    if args.declare_exchange:
        rab.maybe_reconnect()
        try:
            rab._channel.exchange_declare(
                exchange=args.dst_exchange,
                exchange_type=ExchangeType.topic,
                durable=True,
            )
        except AMQPError as err:
            # We do not fail here, we suppose it already exists with other parameters
            logger.warning("Failed to declare exchange: %s", err)

    if args.declare_queue:
        rab.maybe_reconnect()
        try:
            rab._channel.queue_declare(
                queue=args.dst_queue,
                arguments={"x-queue-type": "quorum"},
                durable=True,
            )
        except AMQPError as err:
            # We do not fail here, we suppose it already exists with other parameters
            logger.warning("Failed to declare queue: %s", err)

    # A failure to bind is fatal: without the binding, published messages
    # would not be routed to the destination queue.
    rab.maybe_reconnect()
    rab._channel.queue_bind(
        exchange=args.dst_exchange,
        queue=args.dst_queue,
        routing_key=routing_key,
    )

    # The listener calls `send(jobid, data)`; exchange and key are fixed here.
    send = partial(rab.publish, exchange=args.dst_exchange, routing_key=routing_key)
    logger.info(
        "Starting to copy messages from %s tube=%s to %s routing_key=%s",
        args.src_endpoint,
        args.src_tube,
        args.dst_endpoint,
        routing_key,
    )
    jobs = 0
    try:
        for _job in bsl.fetch_jobs(send):
            jobs += 1
    except KeyboardInterrupt:
        logger.info("%d messages copied", jobs)
    finally:
        if args.unbind:
            # Best-effort cleanup: an AMQP error raised here must neither
            # replace an exception propagating from the copy loop nor turn
            # a clean Ctrl-C shutdown into a traceback.
            try:
                rab.maybe_reconnect()
                rab._channel.queue_unbind(
                    queue=args.dst_queue,
                    exchange=args.dst_exchange,
                    routing_key=routing_key,
                )
            except AMQPError as err:
                logger.warning("Failed to unbind queue: %s", err)


if __name__ == "__main__":
    main()
