// Copyright (C) 2024-2025  ilobilo

#pragma once

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <type_traits>

// Based on code from the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.

#pragma GCC system_header

namespace std
{
    namespace detail
    {
        // Trait that tells an engine's seed-sequence overloads apart from its
        // integer-seed and copy overloads. `value` is true when Sseq is not
        // convertible to Engine::result_type and, ignoring cv-qualifiers, is not
        // Engine itself.
        template<typename Sseq, typename Engine>
        struct is_seed_sequence
        {
            static constexpr const bool value = !is_convertible<Sseq, typename Engine::result_type>::value && !is_same<remove_cv_t<Sseq>, Engine>::value;
        };

        // Compile-time search for the highest set bit of X among bit positions R..0.
        // Only the specializations below are defined.
        template<typename UType, UType X, size_t R>
        struct log2_imp;

        // General step: yield R when bit R of X is set, otherwise continue with bit R - 1.
        template<unsigned long long X, size_t R>
        struct log2_imp<unsigned long long, X, R>
        {
            static const size_t value = X & ((unsigned long long)(1) << R) ? R : log2_imp<unsigned long long, X, R - 1>::value;
        };

        // End of the recursion: only bit 0 is left to inspect, so the result is 0.
        template<unsigned long long X>
        struct log2_imp<unsigned long long, X, 0>
        {
            static const size_t value = 0;
        };

        // X == 0 has no set bit at all; this specialization yields R + 1.
        template<size_t R>
        struct log2_imp<unsigned long long, 0, R>
        {
            static const size_t value = R + 1;
        };

        // 128-bit values: look at the upper 64 bits first and fall back to the lower 64.
        // Both template-ids are formed no matter which branch is selected, so each half
        // is converted explicitly: passing a 128-bit X >= 2^64 implicitly as an
        // `unsigned long long` template argument would be an ill-formed narrowing
        // conversion. The truncated low half is only selected when the upper half is zero.
        template<__uint128_t X, size_t R>
        struct log2_imp<__uint128_t, X, R>
        {
            static const size_t value = (X >> 64)
                ? (64 + log2_imp<unsigned long long, static_cast<unsigned long long>(X >> 64), 63>::value)
                : log2_imp<unsigned long long, static_cast<unsigned long long>(X), 63>::value;
        };

        // Index of the highest set bit of X (i.e. floor(log2(X)) for X != 0), starting the
        // search at the top bit of UType. Types no wider than unsigned long long are
        // handled by the 64-bit implementation, wider ones by the __uint128_t one.
        template<typename UType, UType X>
        struct log2
        {
            static const size_t value = log2_imp<typename conditional<sizeof(UType) <= sizeof(unsigned long long), unsigned long long, __uint128_t>::type, X, sizeof(UType) * __CHAR_BIT__ - 1>::value;
        };
    } // namespace detail

    template<typename Type, size_t w, size_t n, size_t m, size_t r,
        Type a, size_t u, Type d, size_t s,
        Type b, size_t t, Type c, size_t l, Type f>
    class mersenne_twister_engine;

    // Declared ahead of the class so the friend declarations inside it refer to these templates.
    template<typename Type, size_t WP, size_t NP, size_t MP, size_t RP,
        Type AP, size_t UP, Type DP, size_t SP,
        Type BP, size_t TP, Type CP, size_t LP, Type FP>
    bool operator==(const mersenne_twister_engine<Type, WP, NP, MP, RP, AP, UP, DP, SP, BP, TP, CP, LP, FP> &x,
                    const mersenne_twister_engine<Type, WP, NP, MP, RP, AP, UP, DP, SP, BP, TP, CP, LP, FP> &y);

    template<typename Type, size_t WP, size_t NP, size_t MP, size_t RP,
        Type AP, size_t UP, Type DP, size_t SP,
        Type BP, size_t TP, Type CP, size_t LP, Type FP>
    bool operator!=(const mersenne_twister_engine<Type, WP, NP, MP, RP, AP, UP, DP, SP, BP, TP, CP, LP, FP> &x,
                    const mersenne_twister_engine<Type, WP, NP, MP, RP, AP, UP, DP, SP, BP, TP, CP, LP, FP> &y);

    // Mersenne Twister pseudo-random number engine.
    //
    // Template parameters (each is re-exported as a static member below):
    //   Type - result_type of the engine
    //   w    - word_size                 n - state_size
    //   m    - shift_size                r - mask_bits
    //   a    - xor_mask                  f - initialization_multiplier
    //   u, d, s, b, t, c, l - tempering shifts and masks
    template<typename Type, size_t w, size_t n, size_t m, size_t r, Type a, size_t u, Type d, size_t s, Type b, size_t t, Type c, size_t l, Type f>
    class mersenne_twister_engine
    {
        public:
        // types
        using result_type = Type;

        private:
        result_type _x[n]; // state words, used as a ring buffer
        size_t _i;         // index into _x of the word consumed by the next operator() call

        static_assert(0 < m, "mersenne_twister_engine invalid parameters");
        static_assert(m <= n, "mersenne_twister_engine invalid parameters");
        // Number of value bits in result_type.
        static constexpr const result_type DT = numeric_limits<result_type>::digits;
        static_assert(w <= DT, "mersenne_twister_engine invalid parameters");
        static_assert(2 <= w, "mersenne_twister_engine invalid parameters");
        static_assert(r <= w, "mersenne_twister_engine invalid parameters");
        static_assert(u <= w, "mersenne_twister_engine invalid parameters");
        static_assert(s <= w, "mersenne_twister_engine invalid parameters");
        static_assert(t <= w, "mersenne_twister_engine invalid parameters");
        static_assert(l <= w, "mersenne_twister_engine invalid parameters");

        public:
        static constexpr const result_type Min = 0;
        // All-ones value of w bits; the w == DT case avoids shifting by the full type width.
        static constexpr const result_type Max = w == DT ? result_type(~0) : (result_type(1) << w) - result_type(1);
        static_assert(Min < Max, "mersenne_twister_engine invalid parameters");
        static_assert(a <= Max, "mersenne_twister_engine invalid parameters");
        static_assert(b <= Max, "mersenne_twister_engine invalid parameters");
        static_assert(c <= Max, "mersenne_twister_engine invalid parameters");
        static_assert(d <= Max, "mersenne_twister_engine invalid parameters");
        static_assert(f <= Max, "mersenne_twister_engine invalid parameters");

        // engine characteristics
        static constexpr const size_t word_size = w;
        static constexpr const size_t state_size = n;
        static constexpr const size_t shift_size = m;
        static constexpr const size_t mask_bits = r;
        static constexpr const result_type xor_mask = a;
        static constexpr const size_t tempering_u = u;
        static constexpr const result_type tempering_d = d;
        static constexpr const size_t tempering_s = s;
        static constexpr const result_type tempering_b = b;
        static constexpr const size_t tempering_t = t;
        static constexpr const result_type tempering_c = c;
        static constexpr const size_t tempering_l = l;
        static constexpr const result_type initialization_multiplier = f;

        // Smallest and largest value operator() can return.
        static constexpr result_type min() { return Min; }
        static constexpr result_type max() { return Max; }
        static constexpr const result_type default_seed = 5489u;

        // constructors and seeding functions

        // Seeds the engine with default_seed.
        mersenne_twister_engine() : mersenne_twister_engine(default_seed) { }

        // Seeds the engine from the single integer sd.
        explicit mersenne_twister_engine(result_type sd)
        {
            seed(sd);
        }

        // Seeds the engine from a seed sequence. Only participates in overload
        // resolution when detail::is_seed_sequence accepts Sseq.
        template<typename Sseq>
        explicit mersenne_twister_engine(Sseq &q, typename enable_if<detail::is_seed_sequence<Sseq, mersenne_twister_engine>::value>::type * = nullptr)
        {
            seed(q);
        }

        // Re-initialises the whole state from the integer sd.
        void seed(result_type sd = default_seed);

        // Re-initialises the whole state from q.generate(). Dispatches on the number of
        // 32-bit words needed per state word, 1 + (w - 1) / 32; overloads exist for 1 and 2.
        template<typename Sseq>
        typename enable_if<detail::is_seed_sequence<Sseq, mersenne_twister_engine>::value, void>::type seed(Sseq &q)
        {
            seed(q, integral_constant<unsigned, 1 + (w - 1) / 32>());
        }

        // Advances the state by one step and returns the next value.
        result_type operator()();

        // Advances the engine by z steps, discarding the generated values.
        void discard(unsigned long long z)
        {
            for (; z; --z)
                operator()();
        }

        template<typename UType, size_t WP, size_t NP, size_t MP, size_t RP,
            UType AP, size_t UP, UType DP,  size_t SP,
            UType BP, size_t TP, UType CP,size_t LP, UType FP>
        friend bool operator==(const mersenne_twister_engine<UType, WP, NP, MP, RP, AP, UP, DP, SP, BP, TP, CP, LP, FP> &x,
                            const mersenne_twister_engine<UType, WP, NP, MP, RP, AP, UP, DP, SP, BP, TP, CP, LP, FP> &y);

        template<typename UType, size_t WP, size_t NP, size_t MP, size_t RP,
            UType AP, size_t UP, UType DP, size_t SP,
            UType BP, size_t TP, UType CP, size_t LP, UType FP>
        friend bool operator!=(const mersenne_twister_engine<UType, WP, NP, MP, RP, AP, UP, DP, SP, BP, TP, CP, LP, FP> &x,
                            const mersenne_twister_engine<UType, WP, NP, MP, RP, AP, UP, DP, SP, BP, TP, CP, LP, FP> &y);

        private:
        // Seed-sequence implementations for one / two 32-bit words per state word.
        template<typename Sseq>
        void seed(Sseq &q, integral_constant<unsigned, 1>);
        template<typename Sseq>
        void seed(Sseq &q, integral_constant<unsigned, 2>);

        // Left shift confined to w bits. The second overload handles shift counts of
        // w or more, which always produce 0 and must not reach the built-in operator.
        template<size_t count>
        static typename enable_if<(count < w), result_type>::type lshift(result_type x)
        {
            return (x << count) & Max;
        }

        template<size_t count>
        static typename enable_if<(count >= w), result_type>::type lshift(result_type)
        {
            return result_type(0);
        }

        // Right shift; counts of DT or more (the full width of result_type) yield 0
        // instead of reaching the built-in operator.
        template<size_t count>
        static typename enable_if<(count < DT), result_type>::type rshift(result_type x)
        {
            return x >> count;
        }

        template<size_t count>
        static typename enable_if<(count >= DT), result_type>::type rshift(result_type)
        {
            return result_type(0);
        }
    };

    // Namespace-scope definitions of the static constexpr data members declared in the
    // class. Since C++17 such members are implicitly inline, which makes these
    // redeclarations redundant; they provide the definitions for earlier standards.
    template<typename Type, size_t w, size_t n, size_t m, size_t r,
        Type a, size_t u, Type d, size_t s,
        Type b, size_t t, Type c, size_t l, Type f>
    constexpr const size_t mersenne_twister_engine<Type, w, n, m, r, a, u, d, s, b, t, c, l, f>::word_size;

    template<typename Type, size_t w, size_t n, size_t m, size_t r,
        Type a, size_t u, Type d, size_t s,
        Type b, size_t t, Type c, size_t l, Type f>
    constexpr const size_t mersenne_twister_engine<Type, w, n, m, r, a, u, d, s, b, t, c, l, f>::state_size;

    template<typename Type, size_t w, size_t n, size_t m, size_t r,
        Type a, size_t u, Type d, size_t s,
        Type b, size_t t, Type c, size_t l, Type f>
    constexpr const size_t mersenne_twister_engine<Type, w, n, m, r, a, u, d, s, b, t, c, l, f>::shift_size;

    template<typename Type, size_t w, size_t n, size_t m, size_t r,
        Type a, size_t u, Type d, size_t s,
        Type b, size_t t, Type c, size_t l, Type f>
    constexpr const size_t mersenne_twister_engine<Type, w, n, m, r, a, u, d, s, b, t, c, l, f>::mask_bits;

    template<typename Type, size_t w, size_t n, size_t m, size_t r,
        Type a, size_t u, Type d, size_t s,
        Type b, size_t t, Type c, size_t l, Type f>
    constexpr const typename mersenne_twister_engine<Type, w, n, m, r, a, u, d, s, b, t, c, l, f>::result_type mersenne_twister_engine<Type, w, n, m, r, a, u, d, s, b, t, c, l, f>::xor_mask;

    template<typename Type, size_t w, size_t n, size_t m, size_t r,
        Type a, size_t u, Type d, size_t s,
        Type b, size_t t, Type c, size_t l, Type f>
    constexpr const size_t mersenne_twister_engine<Type, w, n, m, r, a, u, d, s, b, t, c, l, f>::tempering_u;

    template<typename Type, size_t w, size_t n, size_t m, size_t r,
        Type a, size_t u, Type d, size_t s,
        Type b, size_t t, Type c, size_t l, Type f>
    constexpr const typename mersenne_twister_engine<Type, w, n, m, r, a, u, d, s, b, t, c, l, f>::result_type mersenne_twister_engine<Type, w, n, m, r, a, u, d, s, b, t, c, l, f>::tempering_d;

    template<typename Type, size_t w, size_t n, size_t m, size_t r,
        Type a, size_t u, Type d, size_t s,
        Type b, size_t t, Type c, size_t l, Type f>
    constexpr const size_t mersenne_twister_engine<Type, w, n, m, r, a, u, d, s, b, t, c, l, f>::tempering_s;

    template<typename Type, size_t w, size_t n, size_t m, size_t r,
        Type a, size_t u, Type d, size_t s,
        Type b, size_t t, Type c, size_t l, Type f>
    constexpr const typename mersenne_twister_engine<Type, w, n, m, r, a, u, d, s, b, t, c, l, f>::result_type mersenne_twister_engine<Type, w, n, m, r, a, u, d, s, b, t, c, l, f>::tempering_b;

    template<typename Type, size_t w, size_t n, size_t m, size_t r,
        Type a, size_t u, Type d, size_t s,
        Type b, size_t t, Type c, size_t l, Type f>
    constexpr const size_t mersenne_twister_engine<Type, w, n, m, r, a, u, d, s, b, t, c, l, f>::tempering_t;

    template<typename Type, size_t w, size_t n, size_t m, size_t r,
        Type a, size_t u, Type d, size_t s,
        Type b, size_t t, Type c, size_t l, Type f>
    constexpr const typename mersenne_twister_engine<Type, w, n, m, r, a, u, d, s, b, t, c, l, f>::result_type mersenne_twister_engine<Type, w, n, m, r, a, u, d, s, b, t, c, l, f>::tempering_c;

    template<typename Type, size_t w, size_t n, size_t m, size_t r,
        Type a, size_t u, Type d, size_t s,
        Type b, size_t t, Type c, size_t l, Type f>
    constexpr const size_t mersenne_twister_engine<Type, w, n, m, r, a, u, d, s, b, t, c, l, f>::tempering_l;

    template<typename Type, size_t w, size_t n, size_t m, size_t r,
        Type a, size_t u, Type d, size_t s,
        Type b, size_t t, Type c, size_t l, Type f>
    constexpr const typename mersenne_twister_engine<Type, w, n, m, r, a, u, d, s, b, t, c, l, f>::result_type mersenne_twister_engine<Type, w, n, m, r, a, u, d, s, b, t, c, l, f>::initialization_multiplier;

    template<typename Type, size_t w, size_t n, size_t m, size_t r,
        Type a, size_t u, Type d, size_t s,
        Type b, size_t t, Type c, size_t l, Type f>
    constexpr const typename mersenne_twister_engine<Type, w, n, m, r, a, u, d, s, b, t, c, l, f>::result_type mersenne_twister_engine<Type, w, n, m, r, a, u, d, s, b, t, c, l, f>::default_seed;

    // Integer seeding: _x[0] is sd truncated to w bits and every following word is
    // derived from its predecessor with the multiplier f; the next output is produced
    // from _x[0].
    //
    // The multiply-add is allowed to wrap (the result is masked with Max afterwards),
    // hence the sanitizer opt-out. The attribute is written before the declarator
    // because GCC does not accept GNU attributes after the declarator of a function
    // definition, while both GCC and Clang accept this position.
    template<typename Type, size_t w, size_t n, size_t m, size_t r,
        Type a, size_t u, Type d, size_t s,
        Type b, size_t t, Type c, size_t l, Type f>
    __attribute__((__no_sanitize__("unsigned-integer-overflow")))
    void mersenne_twister_engine<Type, w, n, m, r, a, u, d, s, b, t, c, l, f>::seed(result_type sd)
    { // w >= 2 (static_assert in the class), so w - 2 is a valid shift count
        _x[0] = sd & Max;
        for (size_t i = 1; i < n; ++i)
            _x[i] = (f * (_x[i - 1] ^ rshift<w - 2>(_x[i - 1])) + i) & Max;
        _i = 0;
    }

    // Seed-sequence seeding with one 32-bit word per state word.
    template<typename Type, size_t w, size_t n, size_t m, size_t r,
        Type a, size_t u, Type d, size_t s,
        Type b, size_t t, Type c, size_t l, Type f>
    template<typename Sseq>
    void mersenne_twister_engine<Type, w, n, m, r, a, u, d, s, b, t, c, l, f>::seed(Sseq &q, integral_constant<unsigned, 1>)
    {
        const unsigned k = 1;
        uint32_t ar[n * k];
        q.generate(ar, ar + n * k);
        for (size_t i = 0; i < n; ++i)
            _x[i] = static_cast<result_type>(ar[i] & Max);
        const result_type mask = r == DT ? result_type(~0) : (result_type(1) << r) - result_type(1);
        _i = 0;
        // If the bits of _x[0] above the low r bits and all other words are zero,
        // replace _x[0] by a word with only the top bit (w - 1) set.
        if ((_x[0] & ~mask) == 0)
        {
            for (size_t i = 1; i < n; ++i)
                if (_x[i] != 0)
                    return;
            _x[0] = result_type(1) << (w - 1);
        }
    }

    // Seed-sequence seeding with two 32-bit words per state word; the word at the
    // even index supplies the low 32 bits, the one at the odd index the high 32 bits.
    template<typename Type, size_t w, size_t n, size_t m, size_t r,
        Type a, size_t u, Type d, size_t s,
        Type b, size_t t, Type c, size_t l, Type f>
    template<typename Sseq>
    void mersenne_twister_engine<Type, w, n, m, r, a, u, d, s, b, t, c, l, f>::seed(Sseq &q, integral_constant<unsigned, 2>)
    {
        const unsigned k = 2;
        uint32_t ar[n * k];
        q.generate(ar, ar + n * k);
        for (size_t i = 0; i < n; ++i)
            _x[i] = static_cast<result_type>((ar[2 * i] + ((uint64_t)ar[2 * i + 1] << 32)) & Max);
        const result_type mask = r == DT ? result_type(~0) : (result_type(1) << r) - result_type(1);
        _i = 0;
        // Same all-zero check as in the one-word overload.
        if ((_x[0] & ~mask) == 0)
        {
            for (size_t i = 1; i < n; ++i)
                if (_x[i] != 0)
                    return;
            _x[0] = result_type(1) << (w - 1);
        }
    }

    // Produces the next value: regenerates the state word at _i, tempers it and
    // moves _i to the following word of the ring.
    template<typename Type, size_t w, size_t n, size_t m, size_t r,
        Type a, size_t u, Type d, size_t s,
        Type b, size_t t, Type c, size_t l, Type f>
    Type mersenne_twister_engine<Type, w, n, m, r, a, u, d, s, b, t, c, l, f>::operator()()
    {
        const size_t j = (_i + 1) % n;
        const result_type mask = r == DT ? result_type(~0) : (result_type(1) << r) - result_type(1);
        // Bits above the low r bits come from the current word, the low r bits from the next one.
        const result_type YP = (_x[_i] & ~mask) | (_x[j] & mask);
        const size_t k = (_i + m) % n;
        // a * (YP & 1) is either 0 or a: the xor mask is applied only when bit 0 of YP is set.
        _x[_i] = _x[k] ^ rshift<1>(YP) ^ (a * (YP & 1));
        // Tempering of the freshly generated word.
        result_type z = _x[_i] ^ (rshift<u>(_x[_i]) & d);
        _i = j;
        z ^= lshift<s>(z) & b;
        z ^= lshift<t>(z) & c;
        return z ^ rshift<l>(z);
    }

    // Two engines compare equal when their n state words match when each state is read
    // as a ring starting at the engine's own current index. The branches below split
    // that circular comparison into contiguous ranges for equal().
    template<typename Type, size_t WP, size_t NP, size_t MP, size_t RP,
        Type AP, size_t UP, Type DP, size_t SP,
        Type BP, size_t TP, Type CP, size_t LP, Type FP>
    bool operator==(const mersenne_twister_engine<Type, WP, NP, MP, RP, AP, UP, DP, SP, BP, TP, CP, LP, FP> &x,
                    const mersenne_twister_engine<Type, WP, NP, MP, RP, AP, UP, DP, SP, BP, TP, CP, LP, FP> &y)
    {
        // Same position: the arrays line up element by element.
        if (x._i == y._i)
            return equal(x._x, x._x + NP, y._x);
        // One ring starts at index 0: two ranges suffice.
        if (x._i == 0 || y._i == 0)
        {
            size_t j = min(NP - x._i, NP - y._i);
            if (!equal(x._x + x._i, x._x + x._i + j, y._x + y._i))
                return false;
            if (x._i == 0)
                return equal(x._x + j, x._x + NP, y._x);
            return equal(x._x, x._x + (NP - j), y._x + j);
        }
        // Both rings start mid-array: three ranges, split where either ring wraps.
        if (x._i < y._i)
        {
            size_t j = NP - y._i;
            if (!equal(x._x + x._i, x._x + (x._i + j), y._x + y._i))
                return false;
            if (!equal(x._x + (x._i + j), x._x + NP, y._x))
                return false;
            return equal(x._x, x._x + x._i, y._x + (NP - (x._i + j)));
        }
        // Mirror image of the previous case with the roles of x and y swapped.
        size_t j = NP - x._i;
        if (!equal(y._x + y._i, y._x + (y._i + j), x._x + x._i))
            return false;
        if (!equal(y._x + (y._i + j), y._x + NP, x._x))
            return false;
        return equal(y._x, y._x + y._i, x._x + (NP - (y._i + j)));
    }

    // Negation of operator==.
    template<typename Type, size_t WP, size_t NP, size_t MP, size_t RP,
        Type AP, size_t UP, Type DP, size_t SP,
        Type BP, size_t TP, Type CP, size_t LP, Type FP>
    inline bool operator!=(const mersenne_twister_engine<Type, WP, NP, MP, RP, AP, UP, DP, SP, BP, TP, CP, LP, FP> &x,
                           const mersenne_twister_engine<Type, WP, NP, MP, RP, AP, UP, DP, SP, BP, TP, CP, LP, FP> &y)
    {
        return !(x == y);
    }

    // Predefined 32-bit engine.
    using mt19937 = mersenne_twister_engine<uint_fast32_t, 32, 624, 397, 31,
        0x9908B0DF, 11, 0xFFFFFFFF,
        7, 0x9D2C5680,
        15, 0xEFC60000,
        18, 1812433253>;
    // Predefined 64-bit engine.
    using mt19937_64 = mersenne_twister_engine<uint_fast64_t, 64, 312, 156, 31,
        0xB5026F5AA96619E9ULL, 29, 0x5555555555555555ULL,
        17, 0x71D67FFFEDA60000ULL,
        37, 0xFFF7EEE000000000ULL,
        43, 6364136223846793005ULL>;
} // namespace std