/* Copyright 2022, Contributors To LensorOS.
 * All rights reserved.
 *
 * This file is part of LensorOS.
 *
 * LensorOS is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * LensorOS is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with LensorOS. If not, see <https://www.gnu.org/licenses/>.
 */

#ifndef LENSOROS_MEMORY_HH
#define LENSOROS_MEMORY_HH

#include <atomic>
#include <new>
#include <type_traits>
#include <utility>

namespace std {

/// Obtain the actual address of \p __val, even when its type overloads the
/// unary operator&.
///
/// This uses the compiler's address-of builtin instead of the classic
/// reinterpret_cast-through-char trick: a reinterpret_cast is never permitted
/// in a constant expression, so the cast-based version could not honour the
/// constexpr specifier on this function.
///
/// \param __val  The object (or function) whose address is wanted.
/// \return       A pointer to \p __val.
template <typename _T>
constexpr _T* addressof(_T& __val) noexcept {
    return __builtin_addressof(__val);
}

/// Const rvalues are rejected: the resulting pointer would refer to a
/// temporary.
template <typename _T>
const _T* addressof(const _T&&) = delete;

/// Move-only smart pointer that owns a single heap object and destroys it
/// with `delete` when the unique_ptr itself is destroyed or re-assigned.
template <typename _T>
class unique_ptr {
    _T* __ptr{};

    /// Take ownership of \p __p and destroy whatever was owned before.
    ///
    /// The new pointer is stored *before* the old object is deleted so that
    /// the old object's destructor never observes (or re-deletes) a pointer
    /// that is already on its way out. Assigning the pointer that is already
    /// held is a no-op.
    void __assign(_T* __p) noexcept {
        _T* __old = __ptr;
        __ptr = __p;
        if (__old != __p) delete __old;
    }

public:
    /// Construct an empty unique_ptr.
    unique_ptr() = default;

    /// Take ownership of a raw pointer.
    explicit unique_ptr(_T* __ptr) : __ptr(__ptr) {}

    /// Construct an empty unique_ptr from nullptr.
    unique_ptr(nullptr_t) noexcept : unique_ptr() {}

    /// Copying a unique_ptr would kind of defeat the purpose of it.
    unique_ptr(const unique_ptr&) = delete;
    unique_ptr& operator=(const unique_ptr&) = delete;

    /// Move construction steals the pointer and leaves the source empty.
    unique_ptr(unique_ptr&& __other) noexcept : __ptr(__other.__ptr) { __other.__ptr = nullptr; }

    /// Move assignment.
    ///
    /// The source is emptied *first* (via release()). This matters in two
    /// cases:
    ///  - self-move (`p = std::move(p)`) keeps the object instead of nulling
    ///    the pointer and leaking it;
    ///  - the source may live inside the object we currently own (e.g.
    ///    `head = std::move(head->next)`); it must be detached before that
    ///    object is destroyed.
    unique_ptr& operator=(unique_ptr&& __other) noexcept {
        __assign(__other.release());
        return *this;
    }

    /// We can also assign nullptr_t to this; destroys the owned object.
    unique_ptr& operator=(nullptr_t) noexcept {
        __assign(nullptr);
        return *this;
    }

    /// Destroy the owned object, if any.
    ~unique_ptr() { delete __ptr; }

    /// Get the raw pointer without giving up ownership.
    _T* get() const noexcept { return __ptr; }

    /// Give up ownership: returns the raw pointer and leaves this empty.
    _T* release() noexcept {
        auto __tmp = __ptr;
        __ptr = nullptr;
        return __tmp;
    }

    /// Destroy the owned object (if any) and take ownership of \p __p.
    void reset(_T* __p = nullptr) noexcept { __assign(__p); }

    _T& operator*() const noexcept { return *__ptr; }
    _T* operator->() const noexcept { return __ptr; }

    /// True if an object is owned.
    explicit operator bool() const noexcept { return __ptr != nullptr; }

    /// Comparisons are by raw pointer value.
    bool operator==(const unique_ptr& __rhs) const noexcept { return __ptr == __rhs.__ptr; }
    friend bool operator==(const unique_ptr& __rhs, const nullptr_t) noexcept { return __rhs.__ptr == nullptr; }
    friend bool operator==(const nullptr_t, const unique_ptr& __rhs) noexcept { return __rhs.__ptr == nullptr; }
};

/// Allocate a new _T on the heap and wrap it in a unique_ptr.
///
/// The arguments are perfectly forwarded into a braced initialiser, i.e. the
/// object is created with `new _T{args...}` (list-initialisation).
template <typename _T, typename... _Args>
unique_ptr<_T> make_unique(_Args&&... __args) {
    _T* __raw = new _T{std::forward<_Args>(__args)...};
    return unique_ptr<_T>(__raw);
}

template <typename _T> class shared_ptr;

/// Bookkeeping record shared by every shared_ptr and weak_ptr that refer to
/// the same managed object. Everything in here is private; only shared_ptr,
/// weak_ptr and make_shared (declared friends below) may touch it.
///
/// Two counters are maintained, both starting at 1:
///  - __ref_count counts shared owners. When it drops to zero the managed
///    object is destroyed through __deleter.
///  - __weak_count counts weak references *plus* one per shared owner (see
///    __add_shared). When it drops to zero the control block itself is freed.
class __shared_ptr_ctrl_block {
    template <typename _T>
    friend class shared_ptr;

    template <typename _T>
    friend struct weak_ptr;

    template <typename _T, typename... _Args>
    friend shared_ptr<_T> make_shared(_Args&&... __args);

    /// Type-erased pointer to the managed object, handed to __deleter.
    void *__obj = nullptr;

    /// Destroys the managed object. The bool argument receives
    /// __jointly_allocated. Defaults to a no-op.
    void (*__deleter)(void*, bool) = [](auto, auto) {};

    std::atomic<size_t> __ref_count{1};
    std::atomic<size_t> __weak_count{1};

    /// True if this block was placed at the start of a char[] allocation
    /// (which is what __release_weak assumes when it frees the block).
    bool __jointly_allocated{};

    explicit __shared_ptr_ctrl_block() noexcept = default;
    explicit __shared_ptr_ctrl_block(bool __jointly_allocated) noexcept : __jointly_allocated(__jointly_allocated) {}
    ~__shared_ptr_ctrl_block() noexcept = default;

    /// Control blocks have identity; they can be neither copied nor moved.
    __shared_ptr_ctrl_block(const __shared_ptr_ctrl_block& __other) = delete;
    __shared_ptr_ctrl_block& operator=(const __shared_ptr_ctrl_block& __other) = delete;
    __shared_ptr_ctrl_block(__shared_ptr_ctrl_block&& __other) = delete;
    __shared_ptr_ctrl_block& operator=(__shared_ptr_ctrl_block&& __other) = delete;

    /// Register one more weak reference.
    void __add_weak() noexcept { atomic_fetch_add(&__weak_count, size_t{1}); }

    /// Register one more shared owner. This bumps the weak count as well so
    /// that the control block outlives the call to the object's destructor.
    void __add_shared() noexcept {
        atomic_fetch_add(&__ref_count, size_t{1});
        __add_weak();
    }

    /// Drop one weak reference; frees the control block if it was the last.
    void __release_weak() noexcept {
        const auto __previous = atomic_fetch_sub(&__weak_count, size_t{1});
        if (__previous != 1) return;

        /// Separately allocated block: the object is already gone, so a
        /// plain delete of ourselves is all that is left to do.
        if (!__jointly_allocated) {
            delete this;
            return;
        }

        /// Jointly allocated block: the storage is a char[] shared with the
        /// object, so run our destructor by hand and then delete[] the
        /// underlying byte array.
        this->~__shared_ptr_ctrl_block();
        delete[] reinterpret_cast<char*>(this);
    }

    /// Drop one shared owner. Destroys the managed object if this was the
    /// last owner, then gives up the weak reference that owner was holding.
    void __release_shared() {
        const bool __was_last_owner = atomic_fetch_sub(&__ref_count, size_t{1}) == 1;
        if (__was_last_owner) __deleter(__obj, __jointly_allocated);

        __release_weak();
    }

    /// Current number of shared owners.
    size_t __use_count() const noexcept { return atomic_load(&__ref_count); }
};

template <typename _T>
class shared_ptr {
    _T* __ptr{};
    __shared_ptr_ctrl_block* __ctrl{};

    template <typename _U, typename... _Args>
    friend shared_ptr<_U> make_shared(_Args&&... __args);

    template <typename _U>
    friend class shared_ptr;

    template <typename _U>
    friend struct weak_ptr;

    void __delete() {
        if (__ctrl) __ctrl->__release_shared();
    }

    template <typename _U>
    static auto __deleter_for() noexcept {
        return [](void* __raw_obj, bool __jointly_allocated) {
            auto __obj = reinterpret_cast<_U*>(__raw_obj);
            if (__jointly_allocated) __obj->~_U();
            else delete __obj;
        };
    }

    /// Should only ever be used internally.
    shared_ptr(_T* __ptr, __shared_ptr_ctrl_block* __ctrl) noexcept : __ptr(__ptr), __ctrl(__ctrl) {}

public:
    using element_type = _T;

    shared_ptr() noexcept = default;
    shared_ptr(nullptr_t) noexcept : shared_ptr() {}

    /// Wrap a raw pointer in a shared_ptr.
    explicit shared_ptr(_T* __ptr) : __ptr(__ptr), __ctrl(new __shared_ptr_ctrl_block{}) {
        __ctrl->__obj = __ptr;
        __ctrl->__deleter = __deleter_for<_T>();
    }

    /// Copy constructor.
    shared_ptr(const shared_ptr& __other) noexcept : __ptr(__other.__ptr), __ctrl(__other.__ctrl) {
        if (__ctrl) { __ctrl->__add_shared(); }
    }

    /// Move constructor.
    shared_ptr(shared_ptr&& __other) noexcept : __ptr(__other.__ptr), __ctrl(__other.__ctrl) {
        __other.__ptr = nullptr;
        __other.__ctrl = nullptr;
    }

    /// Aliasing constructor.
    template <typename _U>
    shared_ptr(const shared_ptr<_U>& __other, _T* __ptr) noexcept : __ptr(__ptr), __ctrl(__other.__ctrl) {
        if (__ctrl) { __ctrl->__add_shared(); }
    }

    /// Aliasing move constructor.
    template <typename _U>
    shared_ptr(shared_ptr<_U>&& __other, _T* __ptr) noexcept : __ptr(__ptr), __ctrl(__other.__ctrl) {
        __other.__ptr = nullptr;
        __other.__ctrl = nullptr;
    }

    /// Copy assignment.
    shared_ptr& operator=(const shared_ptr& __other) {
        if (this == &__other) { return *this; }
        __delete();
        __ptr = __other.__ptr;
        __ctrl = __other.__ctrl;
        if (__ctrl) { __ctrl->__add_shared(); }
        return *this;
    }

    /// Move assignment.
    shared_ptr& operator=(shared_ptr&& __other) {
        if (this == &__other) { return *this; }
        __delete();
        __ptr = __other.__ptr;
        __ctrl = __other.__ctrl;
        __other.__ptr = nullptr;
        __other.__ctrl = nullptr;
        return *this;
    }

    /// Zero out the pointer.
    shared_ptr& operator=(nullptr_t) {
        __delete();
        __ptr = nullptr;
        __ctrl = nullptr;
        return *this;
    }

    /// Decrement the reference count when this is destroyed.
    ~shared_ptr() noexcept { __delete(); }

    /// Get the reference count.
    size_t use_count() const noexcept {
        if (__ctrl) return __ctrl->__use_count();
        else return 0;
    }

    _T* get() const noexcept { return __ptr; }
    _T& operator*() const noexcept { return *__ptr; }
    _T* operator->() const noexcept { return __ptr; }
    explicit operator bool() const noexcept { return __ptr; }

    bool operator==(const shared_ptr& __rhs) const noexcept { return __ptr == __rhs.__ptr; }
    friend bool operator==(const shared_ptr& __rhs, const nullptr_t) noexcept { return __rhs.__ptr == nullptr; }
    friend bool operator==(const nullptr_t, const shared_ptr& __rhs) noexcept { return __rhs.__ptr == nullptr; }
};

/// Constructs a shared_ptr with the given arguments. This function is the preferred
/// way to construct a shared_ptr, because it allocates the refcount and the object
/// in one allocation and next to each other on the heap.
///
/// Layout of the single char[] allocation:
///   [ control block | padding up to alignof(max_align_t) | object ]
///
/// The arguments are perfectly forwarded to _T's constructor (direct,
/// parenthesised initialisation).
template <typename _T, typename... _Args>
shared_ptr<_T> make_shared(_Args&&... __args) {
    /// Make sure that everything is aligned properly.
    static_assert(alignof(_T) <= alignof(max_align_t));
    static_assert(alignof(__shared_ptr_ctrl_block) <= alignof(max_align_t));

    /// Offset of the object from the start of the allocation: the size of the
    /// control block rounded up to the next multiple of the maximum
    /// fundamental alignment (no padding is added if it already is one).
    static constexpr size_t __align = alignof(max_align_t);
    static constexpr size_t __obj_offset =
        (sizeof(__shared_ptr_ctrl_block) + __align - 1) / __align * __align;
    static_assert(__obj_offset % __align == 0);
    static_assert(__obj_offset >= sizeof(__shared_ptr_ctrl_block));

    /// Construct the refcount and the object. The control block is flagged as
    /// jointly allocated so that it is released with delete[] on the char
    /// array and the object is only destructed, never deleted.
    auto __ptr = new char [__obj_offset + sizeof(_T)];
    auto __ctrl = new (__ptr) __shared_ptr_ctrl_block{true};
    /// std::forward is qualified on purpose so that ADL cannot pick up an
    /// unrelated `forward` from the namespaces of the argument types.
    auto __obj = new (__ptr + __obj_offset) _T(std::forward<_Args>(__args)...);
    __ctrl->__obj = __obj;
    __ctrl->__deleter = shared_ptr<_T>::template __deleter_for<_T>();

    /// Return the shared_ptr. The internal constructor adopts the initial
    /// reference of the freshly created control block.
    return shared_ptr<_T>(__obj, __ctrl);
}

/// Non-owning reference to an object managed by shared_ptr.
///
/// A weak_ptr holds one weak reference on the control block. This keeps the
/// control block (but not the managed object) alive, and lets lock() check
/// whether the object still exists and, if so, obtain a shared_ptr to it.
template <typename _T>
struct weak_ptr {
    using element_type = _T;

    template <typename _U>
    friend struct weak_ptr;

private:
    _T* __ptr{};
    __shared_ptr_ctrl_block* __ctrl{};

    /// Drop our weak reference (if we hold one).
    void __delete() {
        if (__ctrl) __ctrl->__release_weak();
    }

public:
    /// Construct an empty weak_ptr.
    constexpr weak_ptr() noexcept {}

    /// Copy a weak_ptr.
    weak_ptr(const weak_ptr& __other) noexcept : __ptr(__other.__ptr), __ctrl(__other.__ctrl) {
        if (__ctrl) { __ctrl->__add_weak(); }
    }

    /// Copy a weak_ptr of a different type.
    template <typename _U>
    weak_ptr(const weak_ptr<_U>& __other) noexcept : __ptr(__other.__ptr), __ctrl(__other.__ctrl) {
        if (__ctrl) { __ctrl->__add_weak(); }
    }

    /// Create a weak_ptr from a shared_ptr.
    template <typename _U>
    weak_ptr(const shared_ptr<_U>& __other) noexcept : __ptr(__other.__ptr), __ctrl(__other.__ctrl) {
        if (__ctrl) { __ctrl->__add_weak(); }
    }

    /// Move a weak_ptr. The source gives up its weak reference to us.
    weak_ptr(weak_ptr&& __other) noexcept : __ptr(__other.__ptr), __ctrl(__other.__ctrl) {
        __other.__ptr = nullptr;
        __other.__ctrl = nullptr;
    }

    /// Move a weak_ptr of a different type.
    template <typename _U>
    weak_ptr(weak_ptr<_U>&& __other) noexcept : __ptr(__other.__ptr), __ctrl(__other.__ctrl) {
        __other.__ptr = nullptr;
        __other.__ctrl = nullptr;
    }

    /// Destroy the weak_ptr.
    ~weak_ptr() noexcept { __delete(); }

    /// Copy assignment.
    weak_ptr& operator=(const weak_ptr& __other) noexcept {
        if (this == &__other) { return *this; }
        __delete();
        __ptr = __other.__ptr;
        __ctrl = __other.__ctrl;
        if (__ctrl) { __ctrl->__add_weak(); }
        return *this;
    }

    /// Move assignment.
    weak_ptr& operator=(weak_ptr&& __other) noexcept {
        if (this == &__other) { return *this; }
        __delete();
        __ptr = __other.__ptr;
        __ctrl = __other.__ctrl;
        __other.__ptr = nullptr;
        __other.__ctrl = nullptr;
        return *this;
    }

    /// Assign from a shared_ptr.
    template <typename _U>
    weak_ptr& operator=(const shared_ptr<_U>& __other) noexcept {
        __delete();
        __ptr = __other.__ptr;
        __ctrl = __other.__ctrl;
        if (__ctrl) { __ctrl->__add_weak(); }
        return *this;
    }

    /// Assign from a weak_ptr of a different type.
    ///
    /// This is a copy, so we must register our own weak reference on the
    /// control block; otherwise our destructor would release a reference that
    /// still belongs to \p __other and the control block would be freed while
    /// it is still in use.
    template <typename _U>
    weak_ptr& operator=(const weak_ptr<_U>& __other) noexcept {
        __delete();
        __ptr = __other.__ptr;
        __ctrl = __other.__ctrl;
        if (__ctrl) { __ctrl->__add_weak(); }
        return *this;
    }

    /// Move a weak_ptr of a different type. The source's weak reference is
    /// transferred to us, so no counter changes.
    template <typename _U>
    weak_ptr& operator=(weak_ptr<_U>&& __other) noexcept {
        __delete();
        __ptr = __other.__ptr;
        __ctrl = __other.__ctrl;
        __other.__ptr = nullptr;
        __other.__ctrl = nullptr;
        return *this;
    }

    /// Reset the weak_ptr.
    void reset() noexcept {
        __delete();
        __ptr = nullptr;
        __ctrl = nullptr;
    }

    /// Get the number of shared_ptrs that point to the same object.
    long use_count() const noexcept {
        if (__ctrl) { return __ctrl->__use_count(); }
        return 0;
    }

    /// Check if the weak_ptr is expired, i.e. it is empty or no shared owner
    /// is left.
    bool expired() const noexcept {
        if (__ctrl) { return __ctrl->__use_count() == 0; }
        return true;
    }

    /// Convert to a shared_ptr. Returns an empty shared_ptr if this weak_ptr
    /// is empty or the managed object has already been destroyed.
    shared_ptr<_T> lock() const noexcept {
        /// If there is no control block, then we don't even need to try.
        if (!__ctrl) { return {}; }

        /// Get the current refcount. If, at any point, it is zero, then
        /// the object has been deleted and we should return an empty
        /// shared_ptr.
        size_t __ref_count;
        do {
            __ref_count = __ctrl->__use_count();
            if (__ref_count == 0) { return {}; }
        }

        /// Try to increment it. If we fail, then we try again. If we succeed,
        /// then we can create a new shared_ptr. The increment is done with a
        /// compare-exchange so that a count that has reached zero is never
        /// brought back up.
        while (!atomic_compare_exchange_weak(&__ctrl->__ref_count, &__ref_count, __ref_count + 1));

        /// Don't forget to add another weak reference for the new shared_ptr.
        /// This is safe because if we get here, then we hold a weak reference
        /// to it and are therefore allowed to increment it.
        __ctrl->__add_weak();

        /// Return the shared_ptr. This constructor does not do any incrementing.
        return shared_ptr<_T>(__ptr, __ctrl);
    }
};

/// static_cast the pointer held by \p __other to shared_ptr<_T>'s element
/// type. The result shares ownership with \p __other (aliasing constructor).
template <typename _T, typename _U>
shared_ptr<_T> static_pointer_cast(const shared_ptr<_U>& __other) noexcept {
    using __target = typename shared_ptr<_T>::element_type;
    __target* __converted = static_cast<__target*>(__other.get());
    return shared_ptr<_T>(__other, __converted);
}

/// static_cast the pointer held by \p __other to shared_ptr<_T>'s element
/// type. Ownership is transferred from \p __other to the result (aliasing
/// move constructor), leaving \p __other empty.
template <typename _T, typename _U>
shared_ptr<_T> static_pointer_cast(shared_ptr<_U>&& __other) noexcept {
    using __target = typename shared_ptr<_T>::element_type;
    /// Read the pointer before ownership is handed over.
    __target* __converted = static_cast<__target*>(__other.get());
    return shared_ptr<_T>(std::move(__other), __converted);
}

/// dynamic_cast the pointer held by \p __other to shared_ptr<_T>'s element
/// type.
///
/// \return A shared_ptr sharing ownership with \p __other if the cast
///         succeeds; an *empty* shared_ptr if it fails. A failed cast must
///         not acquire ownership, otherwise the result would keep the object
///         alive while exposing a null pointer.
template <typename _T, typename _U>
shared_ptr<_T> dynamic_pointer_cast(const shared_ptr<_U>& __other) noexcept {
    auto __cast = dynamic_cast<typename shared_ptr<_T>::element_type*>(__other.get());
    if (!__cast) { return shared_ptr<_T>{}; }
    return shared_ptr<_T>(__other, __cast);
}

/// dynamic_cast the pointer held by \p __other to shared_ptr<_T>'s element
/// type, transferring ownership on success.
///
/// \return A shared_ptr that has taken over \p __other's ownership if the
///         cast succeeds (leaving \p __other empty); an *empty* shared_ptr
///         if it fails, in which case \p __other is left untouched so the
///         caller does not lose the object.
template <typename _T, typename _U>
shared_ptr<_T> dynamic_pointer_cast(shared_ptr<_U>&& __other) noexcept {
    auto __cast = dynamic_cast<typename shared_ptr<_T>::element_type*>(__other.get());
    if (!__cast) { return shared_ptr<_T>{}; }
    return shared_ptr<_T>(std::move(__other), __cast);
}

/// reinterpret_cast the pointer held by \p __other to shared_ptr<_T>'s
/// element type. The result shares ownership with \p __other (aliasing
/// constructor).
template <typename _T, typename _U>
shared_ptr<_T> reinterpret_pointer_cast(const shared_ptr<_U>& __other) noexcept {
    using __target = typename shared_ptr<_T>::element_type;
    __target* __converted = reinterpret_cast<__target*>(__other.get());
    return shared_ptr<_T>(__other, __converted);
}

/// reinterpret_cast the pointer held by \p __other to shared_ptr<_T>'s
/// element type. Ownership is transferred from \p __other to the result
/// (aliasing move constructor), leaving \p __other empty.
template <typename _T, typename _U>
shared_ptr<_T> reinterpret_pointer_cast(shared_ptr<_U>&& __other) noexcept {
    using __target = typename shared_ptr<_T>::element_type;
    /// Read the pointer before ownership is handed over.
    __target* __converted = reinterpret_cast<__target*>(__other.get());
    return shared_ptr<_T>(std::move(__other), __converted);
}

} // namespace std

#endif // LENSOROS_MEMORY_HH
