/*
 * Copyright 2015-present ScyllaDB
 */

/*
 * SPDX-License-Identifier: AGPL-3.0-or-later
 */

#define BOOST_TEST_MODULE core

#include <boost/test/unit_test.hpp>
#include <algorithm>
#include <cmath>
#include <random>
#include <ranges>
#include <vector>

#include "utils/dynamic_bitset.hh"

// Checks set()/clear()/test() on a 178-bit set: all bits start cleared, every
// even bit is then set, and finally every multiple of 4 is cleared again,
// leaving exactly the bits with i % 4 == 2 set.
BOOST_AUTO_TEST_CASE(test_set_clear_test) {
    constexpr size_t nbits = 178;
    utils::dynamic_bitset bits(nbits);

    // Compares every bit of `bits` against `should_be_set(i)`.
    auto check_bits = [&] (auto should_be_set) {
        for (size_t pos = 0; pos < nbits; ++pos) {
            if (should_be_set(pos)) {
                BOOST_REQUIRE(bits.test(pos));
            } else {
                BOOST_REQUIRE(!bits.test(pos));
            }
        }
    };

    // A freshly constructed bitset has no bit set.
    check_bits([] (size_t) { return false; });

    for (size_t pos = 0; pos < nbits; pos += 2) {
        bits.set(pos);
    }
    check_bits([] (size_t pos) { return pos % 2 == 0; });

    // Even bits that are also multiples of 4 get cleared; only i % 4 == 2 remain.
    for (size_t pos = 0; pos < nbits; pos += 4) {
        bits.clear(pos);
    }
    check_bits([] (size_t pos) { return pos % 4 == 2; });
}

// Exercises find_first_set()/find_next_set() on a 178-bit set in three states:
// empty; every even bit set (stride 2, last set bit 176); and with multiples
// of 4 cleared again (stride 4 starting at 2, last set bit 174).
BOOST_AUTO_TEST_CASE(test_find_first_next) {
    constexpr size_t nbits = 178;
    utils::dynamic_bitset bits(nbits);
    for (size_t pos = 0; pos < nbits; ++pos) {
        BOOST_REQUIRE(!bits.test(pos));
    }
    BOOST_REQUIRE_EQUAL(bits.find_first_set(), utils::dynamic_bitset::npos);

    // Requires find_first_set() == first, then follows find_next_set() and
    // requires each hop to advance by exactly `stride` until `last` is
    // reached; after that, find_next_set() must report npos.
    auto expect_evenly_spaced = [&] (size_t first, size_t stride, size_t last) {
        size_t pos = bits.find_first_set();
        BOOST_REQUIRE_EQUAL(pos, first);
        do {
            const auto following = bits.find_next_set(pos);
            BOOST_REQUIRE_EQUAL(pos + stride, following);
            pos = following;
        } while (pos < last);
        BOOST_REQUIRE_EQUAL(bits.find_next_set(pos), utils::dynamic_bitset::npos);
    };

    for (size_t pos = 0; pos < nbits; pos += 2) {
        bits.set(pos);
    }
    expect_evenly_spaced(0, 2, 176);

    for (size_t pos = 0; pos < nbits; pos += 4) {
        bits.clear(pos);
    }
    expect_evenly_spaced(2, 4, 174);
}

// Exercises find_last_set() on a 178-bit set: empty (npos), every even bit set
// (last is 176), then multiples of 4 cleared (last is 174).
// NOTE: despite the test name, only find_last_set() is checked here.
BOOST_AUTO_TEST_CASE(test_find_last_prev) {
    constexpr size_t nbits = 178;
    utils::dynamic_bitset bits(nbits);

    for (size_t pos = 0; pos < nbits; ++pos) {
        BOOST_REQUIRE(!bits.test(pos));
    }
    BOOST_REQUIRE_EQUAL(bits.find_last_set(), utils::dynamic_bitset::npos);

    for (size_t pos = 0; pos < nbits; pos += 2) {
        bits.set(pos);
    }
    // Valid indices are 0..177, so the highest even bit is 176.
    BOOST_REQUIRE_EQUAL(bits.find_last_set(), 176);

    for (size_t pos = 0; pos < nbits; pos += 4) {
        bits.clear(pos);
    }
    // 176 is a multiple of 4 and got cleared; the highest i % 4 == 2 bit is 174.
    BOOST_REQUIRE_EQUAL(bits.find_last_set(), 174);
}

// Randomized differential test: applies a random sequence of operations to a
// utils::dynamic_bitset of `size` bits and to a std::vector<bool> reference
// model. It checks that test(), find_next_set(), find_first_set() and
// find_last_set() agree with the model.
//
// Each iteration performs, with probability 1/(size+1), a global operation
// (clear every bit or set every bit). Otherwise it performs one of six
// single-bit set/clear/query operations. The iteration count is
// log(size) * 1000, with a floor of 1000.
//
// Throws (an int) on the first mismatch. Callers are expected to wrap the
// call, e.g. in BOOST_CHECK_NO_THROW.
static void test_random_ops(size_t size, std::default_random_engine& re) {
    // BOOST_REQUIRE and friends are very slow, just use regular throws instead.
    auto require = [] (bool b) {
        if (!b) {
            throw 0;
        }
    };
    auto require_equal = [&] (const auto& a, const auto& b) {
        require(a == b);
    };

    utils::dynamic_bitset db{size};
    // Reference model that `db` is checked against.
    std::vector<bool> bv(size, false);
    // Drawing 0 selects a global operation. The range is [0, size] rather than
    // [0, size-1] so that for size == 1 the single-bit (verifying) operations
    // are still selected, instead of every iteration being an unchecked global op.
    std::uniform_int_distribution<size_t> global_op_dist(0, size);
    std::uniform_int_distribution<size_t> bit_dist(0, size-1);
    std::uniform_int_distribution<int> global_op_selection_dist(0, 1);
    std::uniform_int_distribution<int> single_op_selection_dist(0, 5);
    auto is_set = [&] (size_t i) -> bool {
        return bv[i];
    };
    // True if any bit in the half-open range [from, to) is set in the model.
    auto any_set_in = [&] (size_t from, size_t to) -> bool {
        return std::ranges::any_of(std::views::iota(from, to), is_set);
    };

    // Scale the iteration count with log(size), but never go below 1000.
    // log(1) == 0 would otherwise skip the loop entirely for size == 1, and
    // log(0) is -inf, whose conversion to size_t is undefined.
    const size_t scaled = size > 1 ? static_cast<size_t>(std::log(static_cast<double>(size)) * 1000) : 0;
    const size_t limit = std::max<size_t>(scaled, 1000);
    for (size_t i = 0; i != limit; ++i) {
        if (global_op_dist(re) == 0) {
            // perform a global operation
            switch (global_op_selection_dist(re)) {
            case 0:
                for (size_t j = 0; j != size; ++j) {
                    db.clear(j);
                    bv[j] = false;
                }
                break;
            case 1:
                for (size_t j = 0; j != size; ++j) {
                    db.set(j);
                    bv[j] = true;
                }
                break;
            }
        } else {
            // perform a single-bit operation
            switch (single_op_selection_dist(re)) {
            case 0: {
                auto bit = bit_dist(re);
                db.set(bit);
                bv[bit] = true;
                break;
            }
            case 1: {
                auto bit = bit_dist(re);
                db.clear(bit);
                bv[bit] = false;
                break;
            }
            case 2: {
                auto bit = bit_dist(re);
                bool dbb = db.test(bit);
                bool bvb = bv[bit];
                require_equal(dbb, bvb);
                break;
            }
            case 3: {
                // find_next_set(bit) must return the first set bit strictly after `bit`.
                auto bit = bit_dist(re);
                auto next = db.find_next_set(bit);
                if (next == db.npos) {
                    require(!any_set_in(bit + 1, size));
                } else {
                    // Range-check first, so a bogus result fails cleanly instead
                    // of indexing bv out of bounds or building an invalid iota range.
                    require(next > bit && next < size);
                    require(!any_set_in(bit + 1, next));
                    require(is_set(next));
                }
                break;
            }
            case 4: {
                auto next = db.find_first_set();
                if (next == db.npos) {
                    require(!any_set_in(0, size));
                } else {
                    require(next < size);
                    require(!any_set_in(0, next));
                    require(is_set(next));
                }
                break;
            }
            case 5: {
                auto next = db.find_last_set();
                if (next == db.npos) {
                    require(!any_set_in(0, size));
                } else {
                    require(next < size);
                    require(!any_set_in(next + 1, size));
                    require(is_set(next));
                }
                break;
            }
            }
        }
    }
}


// Runs test_random_ops() for sizes just below, at and above 64, 4096 (64^2)
// and 262144 (64^3), plus 1 and 2000. The engine is seeded from
// std::random_device. The seed and the current size are attached as test
// context, so a failure can be reproduced. The engine is shared across sizes,
// so to reproduce, replace rd() with the reported seed and rerun the whole case.
BOOST_AUTO_TEST_CASE(test_random_operations) {
    std::random_device rd;
    const auto seed = rd();
    std::default_random_engine re(seed);
    for (auto size : { 1, 63, 64, 65, 2000, 4096-65, 4096-64, 4096-63, 4096-1, 4096, 4096+1, 262144-1, 262144, 262144+1}) {
        BOOST_TEST_CONTEXT("size=" << size << " seed=" << seed) {
            BOOST_CHECK_NO_THROW(test_random_ops(size, re));
        }
    }
}
