/*
 * Copyright 2016-present ScyllaDB
 */

/*
 * SPDX-License-Identifier: AGPL-3.0-or-later
 */

#define BOOST_TEST_MODULE core

#include "utils/assert.hh"
#include <boost/test/unit_test.hpp>

#include <seastar/util/variant_utils.hh>

#include <fmt/ranges.h>

#include <vector>
#include <optional>
#include <fmt/ranges.h>

#include "test/lib/test_utils.hh"
#include "bytes.hh"
#include "bytes_ostream.hh"

// Plain two-field struct used as the basic payload throughout these tests.
struct simple_compound {
    // TODO: change this to test for #905
    uint32_t foo;
    uint32_t bar;

    // Equal iff both fields match.
    bool operator==(const simple_compound& other) const {
        return !(foo != other.foo || bar != other.bar);
    }
};

// Wraps a simple_compound and counts, per thread, how many instances have been
// constructed through the simple_compound constructor, so a test can detect
// whether an operation built an object.
// NOTE(review): presumably declared non-final in the IDL that generates
// idl/idl_test.dist.hh — confirm there.
class non_final_composite_test_object {
public:
    static thread_local int construction_count;

    // Left implicit: only direct-initialization is used here, but generated
    // code may rely on the conversion — verify before adding `explicit`.
    non_final_composite_test_object(simple_compound x)
        : _x(x)
    {
        construction_count += 1;
    }

    simple_compound x() const {
        return _x;
    }

private:
    simple_compound _x;
};

// Wraps a simple_compound and counts, per thread, how many instances have been
// constructed through the simple_compound constructor, so a test can detect
// whether an operation built an object.
// NOTE(review): presumably declared final in the IDL that generates
// idl/idl_test.dist.hh — confirm there.
class final_composite_test_object {
public:
    static thread_local int construction_count;

    // Left implicit: only direct-initialization is used here, but generated
    // code may rely on the conversion — verify before adding `explicit`.
    final_composite_test_object(simple_compound x)
        : _x(x)
    {
        construction_count += 1;
    }

    simple_compound x() const {
        return _x;
    }

private:
    simple_compound _x;
};

// Out-of-line definitions of the per-thread construction counters; both start at zero.
thread_local int non_final_composite_test_object::construction_count{0};
thread_local int final_composite_test_object::construction_count{0};

// fmt support for simple_compound. It renders as " { foo: <foo>, bar: <bar> }"
// (note the leading space). The spec parser is inherited from
// formatter<string_view>, but format() does not consult the parsed specs.
template <> struct fmt::formatter<simple_compound> : fmt::formatter<string_view> {
    auto format(const simple_compound& sc, fmt::format_context& ctx) const {
        auto out = ctx.out();
        return fmt::format_to(out, " {{ foo: {}, bar: {} }}", sc.foo, sc.bar);
    }
};

// Stream output for simple_compound (needed by BOOST_REQUIRE_EQUAL). It
// delegates to the fmt formatter and writes the text unformatted, so stream
// width/fill settings are not applied.
std::ostream& operator<<(std::ostream& os, const simple_compound& sc)
{
    const std::string text = fmt::format("{}", sc);
    os.write(text.data(), static_cast<std::streamsize>(text.size()));
    return os;
}

// Aggregate with an optional member followed by a required one; used to test
// serialization of std::optional fields.
struct compound_with_optional {
    std::optional<simple_compound> first;
    simple_compound second;

    // Member-wise equality. Per std::optional semantics, two disengaged
    // `first` members compare equal.
    bool operator==(const compound_with_optional& other) const {
        return second == other.second && first == other.first;
    }
};

// Stream output for compound_with_optional (needed by BOOST_REQUIRE_EQUAL).
// `first` prints as "<disengaged>" when the optional is empty; otherwise both
// members print through the simple_compound fmt formatter.
std::ostream& operator<<(std::ostream& os, const compound_with_optional& v)
{
    os << " { first: ";
    if (!v.first.has_value()) {
        fmt::print(os, "<disengaged>");
    } else {
        fmt::print(os, "{}", v.first.value());
    }
    // "}}}" is "{}" for the value followed by an escaped closing brace.
    fmt::print(os, ", second: {}}}", v.second);
    return os;
}

// A vector of simple_compound wrapped in its own struct, to test nested
// vector serialization.
struct wrapped_vector {
    std::vector<simple_compound> vector;

    // Element-wise equality of the wrapped vectors.
    // NOTE(review): in C++20 this could be
    // `bool operator==(const wrapped_vector&) const = default;`.
    bool operator==(const wrapped_vector& other) const {
        return this->vector == other.vector;
    }
};

// Stream output for wrapped_vector (needed by BOOST_REQUIRE_EQUAL). It uses
// fmt's range formatting and writes the text unformatted, so stream width/fill
// settings are not applied.
std::ostream& operator<<(std::ostream& os, const wrapped_vector& v)
{
    const std::string text = fmt::format("{}", v.vector);
    os.write(text.data(), static_cast<std::streamsize>(text.size()));
    return os;
}

// Aggregate holding a plain vector of simple_compound and a vector wrapped in
// its own struct. test_vector brace-initializes it as
// `{ vec1, wrapped_vector { vec2 } }`, so the member order below is the
// aggregate-initialization order.
// NOTE(review): presumably the order must also match the IDL definition that
// generates idl/idl_test.dist.hh — verify before reordering.
struct vectors_of_compounds {
    std::vector<simple_compound> first;
    wrapped_vector second;
};

// Member-less struct, used to check that serializing a type with no fields works.
struct empty_struct {};

// Member-less counterpart of empty_struct.
// NOTE(review): presumably declared final in the IDL, per its name — confirm.
struct empty_final_struct {};

class fragment_generator {
    std::vector<bytes> data;
public:
    using fragment_type = bytes_view;
    using iterator = std::vector<bytes>::iterator;
    using const_iterator = std::vector<bytes>::const_iterator;
    fragment_generator(size_t fragment_count, size_t fragment_size) : data(fragment_count, bytes(fragment_size, 'x')) {
    }
    iterator begin() {
        return data.begin();
    }
    iterator end() {
        return data.end();
    }
    const_iterator begin() const {
        return data.begin();
    }
    const_iterator end() const {
        return data.end();
    }
    size_t size_bytes() const {
        return data.empty() ? 0 : data.size() * data.front().size();
    }
    bool empty() const {
        return data.empty();
    }
    bytes to_bytes() const {
        return data.empty() ? bytes() : bytes(data.size() * data.front().size(), 'x');
    }
};

// Holds a single value of type T, where T may itself be const-qualified
// (test_const_template_arg uses T = const simple_compound). The converting
// constructor is left implicit because const_template_arg_test_object is
// brace-initialized from plain simple_compound values.
template <typename T>
struct const_template_arg_wrapper {
    T x;

    const_template_arg_wrapper(const T& t) : x(t) {}

    // Equal when the wrapped values compare equal.
    bool operator==(const const_template_arg_wrapper& rhs) const { return x == rhs.x; }
};

// Serializable object whose only member is a vector of wrappers instantiated
// with a const-qualified template argument.
struct const_template_arg_test_object {
    std::vector<const_template_arg_wrapper<const simple_compound>> first;

    // Equal when both vectors have the same length and pairwise-equal elements.
    bool operator == (const const_template_arg_test_object& rhs) const {
        const auto& items = this->first;
        return items == rhs.first;
    }
};

#include "idl/idl_test.dist.hh"
#include "idl/idl_test.dist.impl.hh"

// Serializes a simple_compound two ways — ser::serialize() on the struct and
// the generated field-by-field writer. It checks that both produce the same 12
// bytes, then decodes them as a value and as a view.
BOOST_AUTO_TEST_CASE(test_simple_compound)
{
    const simple_compound expected = { 0xdeadbeef, 0xbadc0ffe };

    // Path 1: whole-object serialization.
    bytes_ostream by_value;
    ser::serialize(by_value, expected);
    BOOST_REQUIRE_EQUAL(by_value.size(), 12);

    // Path 2: the generated writer, one field at a time.
    bytes_ostream by_writer;
    ser::writer_of_writable_simple_compound<bytes_ostream> writer(by_writer);
    std::move(writer).write_foo(expected.foo).write_bar(expected.bar).end_writable_simple_compound();

    auto value_bytes = by_value.linearize();
    auto writer_bytes = by_writer.linearize();
    BOOST_REQUIRE_EQUAL(value_bytes, writer_bytes);

    auto value_in = ser::as_input_stream(value_bytes);
    auto decoded = ser::deserialize(value_in, std::type_identity<simple_compound>());
    BOOST_REQUIRE_EQUAL(expected, decoded);

    auto writer_in = ser::as_input_stream(writer_bytes);
    auto view = ser::deserialize(writer_in, std::type_identity<ser::writable_simple_compound_view>());
    BOOST_REQUIRE_EQUAL(expected.foo, view.foo());
    BOOST_REQUIRE_EQUAL(expected.bar, view.bar());
}

// Round-trips a vectors_of_compounds (a plain vector plus a vector wrapped in a
// struct) through ser::serialize() and through the generated writer. It checks
// that both encodings are identical (136 bytes), then reads them back as values
// and as views.
BOOST_AUTO_TEST_CASE(test_vector)
{
    std::vector<simple_compound> vec1 = {
        { 1, 2 },
        { 3, 4 },
        { 5, 6 },
        { 7, 8 },
        { 9, 10 },
    };
    std::vector<simple_compound> vec2 = {
        { 11, 12 },
        { 13, 14 },
        { 15, 16 },
        { 17, 18 },
        { 19, 20 },
    };
    vectors_of_compounds voc = { vec1, wrapped_vector { vec2 } };

    bytes_ostream buf1;
    ser::serialize(buf1, voc);
    BOOST_REQUIRE_EQUAL(buf1.size(), 136);

    // Build the same object with the generated writer: `first` element by
    // element through per-field writers, `second` via its wrapped vector.
    bytes_ostream buf2;
    ser::writer_of_writable_vectors_of_compounds<bytes_ostream> wowvoc(buf2);
    auto first_writer = std::move(wowvoc).start_first();
    for (auto& c : vec1) {
        first_writer.add().write_foo(c.foo).write_bar(c.bar).end_writable_simple_compound();
    }
    auto second_writer = std::move(first_writer).end_first().start_second().start_vector();
    for (auto& c : vec2) {
        second_writer.add_vector(c);
    }
    std::move(second_writer).end_vector().end_second().end_writable_vectors_of_compounds();
    BOOST_REQUIRE_EQUAL(buf1.linearize(), buf2.linearize());

    auto bv1 = buf1.linearize();
    auto in1 = ser::as_input_stream(bv1);
    auto deser_voc = ser::deserialize(in1, std::type_identity<vectors_of_compounds>());
    BOOST_REQUIRE_EQUAL(voc.first, deser_voc.first);
    BOOST_REQUIRE_EQUAL(voc.second, deser_voc.second);

    auto bv2 = buf2.linearize();
    auto in2 = ser::as_input_stream(bv2);
    auto voc_view = ser::deserialize(in2, std::type_identity<ser::writable_vectors_of_compounds_view>());

    auto first_range = voc_view.first();
    auto first_view = std::vector<ser::writable_simple_compound_view>(first_range.begin(), first_range.end());
    BOOST_REQUIRE_EQUAL(vec1.size(), first_view.size());
    for (size_t i = 0; i < first_view.size(); i++) {
        // Also check through a copy of the view element. BOOST_REQUIRE_EQUAL
        // (rather than a hard assertion) reports both values on mismatch.
        auto fv = first_view[i];
        BOOST_REQUIRE_EQUAL(vec1[i].foo, fv.foo());
        BOOST_REQUIRE_EQUAL(vec1[i].foo, first_view[i].foo());
        BOOST_REQUIRE_EQUAL(vec1[i].bar, first_view[i].bar());
    }

    auto second_range = voc_view.second().vector();
    auto second_view = std::vector<simple_compound>(second_range.begin(), second_range.end());
    BOOST_REQUIRE_EQUAL(vec2.size(), second_view.size());
    for (size_t i = 0; i < second_view.size(); i++) {
        BOOST_REQUIRE_EQUAL(vec2[i], second_view[i]);
    }
}

// Writes a writable_variants object whose three variant members each hold a
// different alternative. It then reads the object back through the generated
// view and checks that each member decodes to the alternative that was written.
BOOST_AUTO_TEST_CASE(test_variant)
{
    std::vector<simple_compound> vec = {
        { 1, 2 },
        { 3, 4 },
        { 5, 6 },
        { 7, 8 },
        { 9, 10 },
    };

    simple_compound sc = { 0xdeadbeef, 0xbadc0ffe };
    simple_compound sc2 = { 0x12344321, 0x56788765 };

    bytes_ostream buf;
    ser::writer_of_writable_variants<bytes_ostream> wowv(buf);
    // id = 17; `first` holds a plain simple_compound (sc); `second` holds a
    // writable_vector whose `vector` field is filled element by element.
    auto second_writer = std::move(wowv).write_id(17).write_first_simple_compound(sc).start_second_writable_vector().start_vector();
    for (auto&& v : vec) {
        second_writer.add_vector(v);
    }
    // Close `second`, then write `third` as a writable_final_simple_compound
    // built field by field from sc2.
    auto third_writer = std::move(second_writer).end_vector().end_writable_vector().start_third_writable_final_simple_compound();
    std::move(third_writer).write_foo(sc2.foo).write_bar(sc2.bar).end_writable_final_simple_compound().end_writable_variants();
    BOOST_REQUIRE_EQUAL(buf.size(), 120);

    auto bv = buf.linearize();
    auto in = ser::as_input_stream(bv);
    auto wv_view = ser::deserialize(in, std::type_identity<ser::writable_variants_view>());
    BOOST_REQUIRE_EQUAL(wv_view.id(), 17);

    // Each visitor below returns a value only for the one alternative it
    // expects. It throws std::runtime_error for every other alternative
    // (including ser::unknown_variant_type), so a mis-decoded member fails the test.
    struct expect_compound : boost::static_visitor<simple_compound> {
        simple_compound operator()(ser::writable_vector_view&) const {
            throw std::runtime_error("got writable_vector, expected simple_compound");
        }
        simple_compound operator()(simple_compound& sc) const {
            return sc;
        }
        simple_compound operator()(ser::writable_final_simple_compound_view&) const {
            throw std::runtime_error("got writable_final_simple_compound, expected simple_compound");
        }
        simple_compound operator()(ser::unknown_variant_type&) const {
            throw std::runtime_error("unknown type, expected simple_compound");
        }
    };
    auto v1 = wv_view.first();
    auto&& compound = boost::apply_visitor(expect_compound(), v1);
    BOOST_REQUIRE_EQUAL(compound, sc);

    // `second` must decode as a writable_vector view; materialize its elements.
    struct expect_vector : boost::static_visitor<std::vector<simple_compound>> {
        std::vector<simple_compound> operator()(ser::writable_vector_view& wvv) const {
            auto range = wvv.vector();
            return std::vector<simple_compound>(range.begin(), range.end());
        }
        std::vector<simple_compound> operator()(simple_compound&) const {
            throw std::runtime_error("got simple_compound, expected writable_vector");
        }
        std::vector<simple_compound> operator()(ser::writable_final_simple_compound_view&) const {
            throw std::runtime_error("got writable_final_simple_compound, expected writable_vector");
        }
        std::vector<simple_compound> operator()(ser::unknown_variant_type&) const {
            throw std::runtime_error("unknown type, expected writable_vector");
        }
    };

    auto v2 = wv_view.second();
    auto&& vector = boost::apply_visitor(expect_vector(), v2);
    BOOST_REQUIRE_EQUAL(vector, vec);

    // `third` must decode as a writable_final_simple_compound view; rebuild a
    // simple_compound from its fields for comparison.
    struct expect_writable_compound : boost::static_visitor<simple_compound> {
        simple_compound operator()(ser::writable_vector_view&) const {
            throw std::runtime_error("got writable_vector, expected writable_final_simple_compound");
        }
        simple_compound operator()(simple_compound&) const {
            throw std::runtime_error("got simple_compound, expected writable_final_simple_compound");
        }
        simple_compound operator()(ser::writable_final_simple_compound_view& scv) const {
            return simple_compound { scv.foo(), scv.bar() };
        }
        simple_compound operator()(ser::unknown_variant_type&) const {
            throw std::runtime_error("unknown type, expected writable_final_simple_compound");
        }
    };
    auto v3 = wv_view.third();
    auto&& compound2 = boost::apply_visitor(expect_writable_compound(), v3);
    BOOST_REQUIRE_EQUAL(compound2, sc2);
}

// Round-trips a compound_with_optional with its optional member engaged
// (29 bytes) and disengaged (17 bytes), checking the encoded size and equality
// after deserialization. Input streams are built with ser::as_input_stream(),
// like the other tests, rather than a hand-rolled reinterpret_cast.
BOOST_AUTO_TEST_CASE(test_compound_with_optional)
{
    simple_compound foo = { 0xdeadbeef, 0xbadc0ffe };
    simple_compound bar = { 0x12345678, 0x87654321 };

    compound_with_optional one = { foo, bar };

    bytes_ostream buf1;
    ser::serialize(buf1, one);
    BOOST_REQUIRE_EQUAL(buf1.size(), 29);

    auto bv1 = buf1.linearize();
    auto in1 = ser::as_input_stream(bv1);
    auto deser_one = ser::deserialize(in1, std::type_identity<compound_with_optional>());
    BOOST_REQUIRE_EQUAL(one, deser_one);

    compound_with_optional two = { {}, foo };

    bytes_ostream buf2;
    ser::serialize(buf2, two);
    BOOST_REQUIRE_EQUAL(buf2.size(), 17);

    auto bv2 = buf2.linearize();
    auto in2 = ser::as_input_stream(bv2);
    auto deser_two = ser::deserialize(in2, std::type_identity<compound_with_optional>());
    BOOST_REQUIRE_EQUAL(two, deser_two);
}

// ser::skip() must move past a serialized object without constructing it.
// For both the non-final and the final composite, the per-class construction
// counter must be unchanged by the skip. The skip must also consume the whole
// single-object buffer; otherwise a skip that did nothing would pass the
// counter check too.
BOOST_AUTO_TEST_CASE(test_skip_does_not_deserialize)
{
    {
        non_final_composite_test_object x({1, 2});

        bytes_ostream buf;
        ser::serialize(buf, x);

        auto in = ser::as_input_stream(buf.linearize());
        auto prev = non_final_composite_test_object::construction_count;

        ser::skip(in, std::type_identity<non_final_composite_test_object>());

        BOOST_REQUIRE_EQUAL(prev, non_final_composite_test_object::construction_count);
        BOOST_REQUIRE_EQUAL(in.size(), 0);
    }

    {
        final_composite_test_object x({1, 2});

        bytes_ostream buf;
        ser::serialize(buf, x);

        auto in = ser::as_input_stream(buf.linearize());
        auto prev = final_composite_test_object::construction_count;

        ser::skip(in, std::type_identity<final_composite_test_object>());

        BOOST_REQUIRE_EQUAL(prev, final_composite_test_object::construction_count);
        BOOST_REQUIRE_EQUAL(in.size(), 0);
    }
}

// Serializing and deserializing member-less structs (empty_struct and
// empty_final_struct) must work without throwing.
BOOST_AUTO_TEST_CASE(test_empty_struct)
{
    {
        bytes_ostream out;
        ser::serialize(out, empty_struct());
        auto in = ser::as_input_stream(out.linearize());
        ser::deserialize(in, std::type_identity<empty_struct>());
    }

    {
        bytes_ostream out;
        ser::serialize(out, empty_final_struct());
        auto in = ser::as_input_stream(out.linearize());
        ser::deserialize(in, std::type_identity<empty_final_struct>());
    }
}

// A just_a_variant object is written twice, once with each alternative. Each
// time it is read back through the generated view, and the test checks that
// seastar::visit dispatches to the handler for the alternative that was written.
BOOST_AUTO_TEST_CASE(test_just_a_variant)
{
    // Round 1: write the variant as a writable_simple_compound, field by field.
    bytes_ostream buf;
    ser::writer_of_just_a_variant(buf)
        .start_variant_writable_simple_compound()
            .write_foo(0x1234abcd)
            .write_bar(0x1111ffff)
        .end_writable_simple_compound()
    .end_just_a_variant();

    auto in = ser::as_input_stream(buf);
    auto view = ser::deserialize(in, std::type_identity<ser::just_a_variant_view>());
    // `fired` confirms the expected handler actually ran; the others fail the
    // test if reached.
    bool fired = false;
    seastar::visit(view.variant(), [&] (ser::writable_simple_compound_view v) {
            fired = true;
            BOOST_CHECK_EQUAL(v.foo(), 0x1234abcd);
            BOOST_CHECK_EQUAL(v.bar(), 0x1111ffff);
        },
        [&] (simple_compound) { BOOST_FAIL("should not reach"); },
        [&] (ser::unknown_variant_type) { BOOST_FAIL("should not reach"); }
    );
    BOOST_CHECK(fired);

    // Round 2: reuse the buffer and write the variant as a plain simple_compound value.
    buf = bytes_ostream();
    ser::writer_of_just_a_variant(buf)
        .write_variant_simple_compound(simple_compound { 0xaaaabbbb, 0xccccdddd })
    .end_just_a_variant();

    in = ser::as_input_stream(buf);
    view = ser::deserialize(in, std::type_identity<ser::just_a_variant_view>());
    fired = false;
    seastar::visit(view.variant(), [&] (simple_compound v) {
            fired = true;
            BOOST_CHECK_EQUAL(v.foo, 0xaaaabbbb);
            BOOST_CHECK_EQUAL(v.bar, 0xccccdddd);
        },
        [&] (ser::writable_simple_compound_view) { BOOST_FAIL("should not reach"); },
        [&] (ser::unknown_variant_type) { BOOST_FAIL("should not reach"); }
    );
    BOOST_CHECK(fired);
}

// ser::serialize_fragmented() of a fragmented 'x'-filled range must
// deserialize as `bytes` equal to the contiguous equivalent. It is exercised
// for several (fragment count, fragment size) shapes, including the empty range.
BOOST_AUTO_TEST_CASE(test_fragmented_write)
{
    const std::pair<size_t, size_t> shapes[] = {{9, 1025}, {6, 8999}, {2, 29521}, {1, 60001}, {0, 0}};
    for (const auto& [count, size] : shapes) {
        bytes_ostream buf;
        ser::serialize_fragmented(buf, fragment_generator(count, size));

        auto in = ser::as_input_stream(buf);
        bytes decoded = ser::deserialize(in, std::type_identity<bytes>());

        const bytes expected = fragment_generator(count, size).to_bytes();
        BOOST_CHECK_EQUAL(decoded, expected);
    }
}

// Round-trips an object containing vector<const_template_arg_wrapper<const simple_compound>>,
// checking the encoded size (40 bytes) and equality after deserialization.
BOOST_AUTO_TEST_CASE(test_const_template_arg)
{
    const simple_compound a{ 0xdeadbeef, 0xbadc0ffe };
    const simple_compound b{ 0xbaaaaaad, 0xdeadc0de };
    // Each simple_compound converts implicitly to the const-T wrapper.
    const_template_arg_test_object obj { .first = { a, b } };

    bytes_ostream out;
    ser::serialize(out, obj);
    BOOST_REQUIRE_EQUAL(out.size(), 40);

    auto in = ser::as_input_stream(out);
    auto decoded = ser::deserialize(in, std::type_identity<const_template_arg_test_object>());
    // No operator<< is defined here for const_template_arg_test_object, so
    // compare with BOOST_REQUIRE rather than BOOST_REQUIRE_EQUAL.
    BOOST_REQUIRE(obj == decoded);
}
