// Copyright (C) 2024-2025  ilobilo

#pragma once

#include_next <memory>
#include <smart_ptr>
#include <new>

namespace std
{
    // Forward declaration only: this header refers to allocator_traits (in
    // deallocator, Rebind and allocate_unique below) but does not define it.
    template<typename Alloc>
    struct allocator_traits;

    namespace detail
    {
        // Selector used to pick a make_unique overload by the shape of the
        // requested type. Each shape exposes exactly one member type, so only
        // the matching overload has a well-formed return type.

        // Non-array types: exposes single_object.
        template<typename T>
        struct make_unique_t
        {
            using single_object = unique_ptr<T>;
        };

        // Arrays of unknown bound (T[]): exposes array.
        template<typename T>
        struct make_unique_t<T[]>
        {
            using array = unique_ptr<T[]>;
        };

        // Arrays of known bound (T[N]): exposes only a placeholder type.
        template<typename T, std::size_t N>
        struct make_unique_t<T[N]>
        {
            struct invalid_type { };
        };

        // Return type of the single-object make_unique overload.
        template<typename T>
        using unique_ptr_t = typename make_unique_t<T>::single_object;

        // Return type of the unknown-bound array make_unique overload.
        template<typename T>
        using unique_ptr_array_t = typename make_unique_t<T>::array;

        // Return type of the deleted known-bound array make_unique overload.
        template<typename T>
        using invalid_make_unique_t = typename make_unique_t<T>::invalid_type;
    } // namespace detail

    // Creates one Type with `new`, forwarding args to its constructor, and
    // returns a unique_ptr<Type> that owns it. Selected for non-array Type.
    template<typename Type, typename ...Args>
    inline constexpr detail::unique_ptr_t<Type> make_unique(Args &&...args)
    {
        Type *const object = new Type(forward<Args>(args)...);
        return unique_ptr<Type> { object };
    }

    // Creates a value-initialized array of `num` elements with `new[]` and
    // returns a unique_ptr that owns it. Selected when Type is an array of
    // unknown bound (T[]).
    template<typename Type>
    inline constexpr detail::unique_ptr_array_t<Type> make_unique(std::size_t num)
    {
        using Element = remove_extent_t<Type>;

        Element *const elements = new Element[num]();
        return unique_ptr<Type> { elements };
    }

    // Arrays of known bound (Type[N]) are rejected: this is the only overload
    // whose return type is well-formed for them, and it is deleted, so such a
    // call fails to compile.
    template<typename Type, typename ...Args>
    inline constexpr detail::invalid_make_unique_t<Type> make_unique(Args &&...) = delete;

    // Deleter that gives storage back to an allocator.
    //
    // Holds a copy of the allocator and the number of objects living in the
    // block. Invoking it destroys those objects in reverse order and then
    // deallocates the block with the same count.
    template<typename Allocator>
    struct deallocator
    {
        using value_type = typename Allocator::value_type;
        Allocator _alloc;
        // Number of constructed objects in the block. Defaults to zero so a
        // default-constructed deleter never reads an indeterminate count.
        std::size_t _size = 0;

        constexpr deallocator() noexcept = default;

        // alloc: allocator the block was obtained from (copied).
        // size:  number of constructed objects in the block.
        constexpr explicit deallocator(const Allocator &alloc, std::size_t size)
            : _alloc { alloc }, _size { size } { }

        // Destroys the _size objects starting at ptr, last to first, then
        // releases the block through the stored allocator.
        constexpr void operator()(value_type *ptr)
        {
            using Traits = std::allocator_traits<Allocator>;

            // Unsigned countdown: visits _size - 1 down to 0 without relying
            // on a size_t -> ptrdiff_t conversion.
            for (std::size_t i = _size; i-- > 0; )
                Traits::destroy(_alloc, ptr + i);

            Traits::deallocate(_alloc, ptr, _size);
        }
    };

    // The allocator type that allocator_traits reports when Alloc is rebound
    // to allocate objects of type Target.
    template<typename Alloc, typename Target>
    using Rebind = typename allocator_traits<Alloc>::template rebind_alloc<Target>;

    // Allocates storage for one Type through a copy of `allocator` rebound to
    // Type, constructs the object from `args`, and returns a unique_ptr whose
    // deallocator destroys the object and releases the storage again.
    //
    // If construction throws, the storage is returned to the allocator before
    // the exception propagates.
    template<typename Type, typename Allocator, typename ...Args> requires (!is_array_v<Type>)
    inline constexpr unique_ptr<Type, deallocator<Rebind<Allocator, Type>>> allocate_unique(const Allocator &allocator, Args &&...args)
    {
        using Alloc = Rebind<Allocator, Type>;
        using Traits = allocator_traits<Alloc>;

        // Owns the raw storage until construction has succeeded; disarmed by
        // setting ptr to nullptr.
        struct storage_guard
        {
            Alloc &alloc;
            Type *ptr;

            constexpr ~storage_guard()
            {
                if (ptr != nullptr)
                    Traits::deallocate(alloc, ptr, 1);
            }
        };

        Alloc alloc { allocator };
        storage_guard guard { alloc, Traits::allocate(alloc, 1) };
        Traits::construct(alloc, guard.ptr, forward<Args>(args)...);

        deallocator<Alloc> dealloc { alloc, 1 };

        // The object now exists: transfer the storage to the unique_ptr.
        Type *ptr = guard.ptr;
        guard.ptr = nullptr;
        return unique_ptr<Type, deallocator<Alloc>> { ptr, dealloc };
    }

    // Array form, selected when Type is an array of unknown bound (T[]).
    // Allocates storage for `size` elements through a copy of `allocator`
    // rebound to the element type, constructs each element with no arguments,
    // and returns a unique_ptr whose deallocator destroys all `size` elements
    // and releases the storage.
    //
    // If constructing an element throws, the elements built so far are
    // destroyed in reverse order and the storage is released before the
    // exception propagates.
    template<typename Type, typename Allocator> requires (is_array_v<Type> && extent_v<Type> == 0)
    inline constexpr unique_ptr<Type, deallocator<Rebind<Allocator, remove_extent_t<Type>>>> allocate_unique(const Allocator &allocator, std::size_t size)
    {
        using Elem = remove_extent_t<Type>;
        using Alloc = Rebind<Allocator, Elem>;
        using Traits = allocator_traits<Alloc>;

        // Owns the block and remembers how many elements exist so far;
        // disarmed by setting ptr to nullptr.
        struct build_guard
        {
            Alloc &alloc;
            Elem *ptr;
            std::size_t capacity;
            std::size_t built;

            constexpr ~build_guard()
            {
                if (ptr == nullptr)
                    return;

                while (built > 0)
                    Traits::destroy(alloc, ptr + --built);

                Traits::deallocate(alloc, ptr, capacity);
            }
        };

        Alloc alloc { allocator };
        build_guard guard { alloc, Traits::allocate(alloc, size), size, 0 };

        // `built` is advanced only after an element has been constructed, so
        // the guard never destroys an element that does not exist.
        for (; guard.built < size; ++guard.built)
            Traits::construct(alloc, guard.ptr + guard.built);

        deallocator<Alloc> dealloc { alloc, size };

        // Every element exists: transfer the block to the unique_ptr.
        Elem *ptr = guard.ptr;
        guard.ptr = nullptr;
        return unique_ptr<Type, deallocator<Alloc>> { ptr, dealloc };
    }

    // Stateless allocator that obtains raw storage from the global
    // ::operator new and returns it with the sized ::operator delete.
    // All instances compare equal.
    //
    // NOTE(review): only the plain (non-aligned) ::operator new overload is
    // used, so over-aligned Type gets no special treatment - confirm that is
    // acceptable for the types this is instantiated with.
    template<typename Type>
    struct allocator
    {
        using value_type = Type;
        using pointer = Type *;
        using const_pointer = const Type *;
        using reference = Type &;
        using const_reference = const Type &;
        using size_type = std::size_t;
        using difference_type = std::ptrdiff_t;

        constexpr allocator() noexcept = default;

        // Rebinding conversion. There is no state to copy, so the source
        // allocator is ignored (and left unnamed to avoid an unused-parameter
        // warning).
        template<typename U>
        constexpr allocator(const allocator<U> &) noexcept { }

        // Returns uninitialized storage for `count` objects of Type.
        //
        // If count * sizeof(Type) does not fit in std::size_t, the request is
        // saturated to the largest representable size instead of wrapping to a
        // small one, so the global operator new sees an unsatisfiable request
        // and reports failure its usual way rather than handing back a buffer
        // that is too small.
        [[nodiscard]] constexpr Type *allocate(std::size_t count)
        {
            constexpr std::size_t max_bytes = static_cast<std::size_t>(-1);
            const std::size_t bytes = (count > max_bytes / sizeof(Type))
                ? max_bytes
                : count * sizeof(Type);

            return static_cast<Type *>(::operator new(bytes));
        }

        // Releases storage previously returned by allocate(count).
        constexpr void deallocate(Type *ptr, std::size_t count)
        {
            ::operator delete(ptr, count * sizeof(Type));
        }

        // Stateless: any two allocators of the same type are interchangeable.
        friend constexpr bool operator==(const allocator &, const allocator &) noexcept
        {
            return true;
        }
    };
} // namespace std