// Copyright (C) 2024-2025  ilobilo

#pragma once

#include_next <functional>
#include <cassert>
#include <cstddef>
#include <new>
#include <type_traits>
#include <utility>

namespace std
{
    template<typename Func>
    class function;

    // Type-erased, copyable wrapper around any callable that can be invoked
    // as Ret (Args...).
    //
    // Small targets are constructed inside the wrapper itself (in _stack);
    // everything else is allocated with new. The wrapper owns its target in
    // both cases and destroys it in clear() / the destructor.
    template<typename Ret, typename ...Args>
    class function<Ret (Args...)>
    {
        private:
        // Capacity and alignment of the in-object buffer. The alignment matches
        // the alignment requested for callablefunc so that placement-new into
        // _stack is always correctly aligned.
        static inline constexpr std::size_t stack_size = sizeof(void *) * 2;
        static inline constexpr std::size_t stack_align = 16;

        // Where the current target lives.
        enum class st_type { none, stack, heap };

        // Interface through which the erased target is invoked, duplicated and
        // relocated without knowing its concrete type.
        class callable
        {
            public:
            virtual ~callable() = default;
            virtual Ret invoke(Args &&...) = 0;

            // Copy-constructs a duplicate with new and returns it.
            virtual callable *clone_heap() const = 0;
            // Copy-constructs a duplicate into the storage at `where`.
            virtual callable *clone_at(void *where) const = 0;
            // Move-constructs a replacement into the storage at `where`;
            // the source object must still be destroyed by the caller.
            virtual callable *move_at(void *where) = 0;
        };

        // Concrete holder for a target of type Func.
        template<typename Func>
        class alignas(stack_align) callablefunc final : public callable
        {
            public:
            explicit callablefunc(const Func &t) : t_(t) { }
            explicit callablefunc(Func &&t) : t_(std::move(t)) { }
            ~callablefunc() override = default;

            Ret invoke(Args &&...args) override
            {
                // A void wrapper may hold a target that returns a value;
                // the result is discarded in that case.
                if constexpr (std::is_void_v<Ret>)
                    static_cast<void>(t_(std::forward<Args>(args)...));
                else
                    return t_(std::forward<Args>(args)...);
            }

            callable *clone_heap() const override
            {
                return new callablefunc(t_);
            }

            callable *clone_at(void *where) const override
            {
                return ::new (where) callablefunc(t_);
            }

            callable *move_at(void *where) override
            {
                return ::new (where) callablefunc(std::move(t_));
            }

            private:
            Func t_;
        };

        // True when callablefunc<Func> (not just Func: the holder also carries
        // the virtual-dispatch overhead and padding) fits the inline buffer.
        // Inline targets must be nothrow-movable so that moving a function
        // can itself be noexcept.
        template<typename Func>
        static constexpr bool fits_inline() noexcept
        {
            using stored = callablefunc<Func>;
            return sizeof(stored) <= stack_size
                && alignof(stored) <= stack_align
                && std::is_nothrow_move_constructible_v<Func>;
        }

        alignas(stack_align) std::byte _stack[stack_size];
        // Points at the live target (inside _stack or on the heap), or is
        // nullptr when empty. Keeping the pointer returned by new avoids
        // reinterpreting the raw buffer as a callable.
        callable *_target = nullptr;
        st_type _st_type = st_type::none;

        // Stores a new target. Precondition: *this is empty.
        template<typename Func>
        void emplace(Func &&t)
        {
            using functor = std::decay_t<Func>;
            using stored = callablefunc<functor>;

            if constexpr (fits_inline<functor>())
            {
                _target = ::new (static_cast<void *>(_stack)) stored(std::forward<Func>(t));
                _st_type = st_type::stack;
            }
            else
            {
                _target = new stored(std::forward<Func>(t));
                _st_type = st_type::heap;
            }
        }

        // Duplicates other's target into *this. Precondition: *this is empty.
        void copy_from(const function &other)
        {
            switch (other._st_type)
            {
                case st_type::stack:
                    _target = other._target->clone_at(_stack);
                    break;
                case st_type::heap:
                    _target = other._target->clone_heap();
                    break;
                case st_type::none:
                    return;
            }
            _st_type = other._st_type;
        }

        // Transfers other's target into *this and leaves other empty.
        // Precondition: *this is empty.
        void move_from(function &other) noexcept
        {
            switch (other._st_type)
            {
                case st_type::stack:
                    // Inline targets cannot be stolen by pointer: relocate the
                    // object into our own buffer, then destroy the source.
                    _target = other._target->move_at(_stack);
                    _st_type = st_type::stack;
                    other.clear();
                    break;
                case st_type::heap:
                    _target = other._target;
                    _st_type = st_type::heap;
                    other._target = nullptr;
                    other._st_type = st_type::none;
                    break;
                case st_type::none:
                    break;
            }
        }

        public:
        // Constructs an empty wrapper.
        function() noexcept { }

        // Wraps a copy of the callable t.
        template<typename Func>
        function(Func t)
        {
            emplace(std::move(t));
        }

        // Deep-copies the target so that both wrappers own independent objects.
        function(const function &other)
        {
            copy_from(other);
        }

        // Takes over other's target; other is empty afterwards.
        function(function &&other) noexcept
        {
            move_from(other);
        }

        ~function() { clear(); }

        // Destroys the current target (if any) and leaves the wrapper empty.
        void clear() noexcept
        {
            switch (_st_type)
            {
                case st_type::stack:
                    // Inline storage is part of *this: run the destructor only.
                    _target->~callable();
                    break;
                case st_type::heap:
                    delete _target;
                    break;
                case st_type::none:
                    break;
            }

            _target = nullptr;
            _st_type = st_type::none;
        }

        function &operator=(const function &other)
        {
            if (this != &other)
            {
                // Copy first so that *this is untouched if copying fails.
                function copy(other);
                clear();
                move_from(copy);
            }
            return *this;
        }

        function &operator=(function &&other) noexcept
        {
            if (this != &other)
            {
                clear();
                move_from(other);
            }
            return *this;
        }

        // Replaces the current target with a copy of the callable t.
        template<typename Func>
        function &operator=(Func t)
        {
            clear();
            emplace(std::move(t));
            return *this;
        }

        // Invokes the target. Calling an empty wrapper is a programming error
        // and trips the assertion.
        Ret operator()(Args ...args) const
        {
            assert(_st_type != st_type::none);
            return _target->invoke(std::forward<Args>(args)...);
        }

        // True when a target is stored.
        explicit operator bool() const noexcept
        {
            return _st_type != st_type::none;
        }
    };
} // namespace std