/* Copyright 2022, Contributors To LensorOS.
 * All rights reserved.
 *
 * This file is part of LensorOS.
 *
 * LensorOS is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * LensorOS is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with LensorOS. If not, see <https://www.gnu.org/licenses/>.
 */

#ifndef _LENSOROS_EXTENSIONS_SPARSE_VECTOR
#define _LENSOROS_EXTENSIONS_SPARSE_VECTOR

#include <type_traits>
#include <vector>

namespace std {

/// Like std::vector, but erasing is O(1) and doesn't invalidate indices.
template <typename _T, auto __invalid, typename index_type = size_t>
class __sparse_vector_base {
protected:
    /// For now, we only allow __invalid to be something we can easily store.
    static_assert(sizeof(__invalid) <= sizeof(void*));
    static_assert(alignof(decltype(__invalid)) <= alignof(void*));

    vector<_T> __els;
    vector<size_t> __idxs;

    __sparse_vector_base() = default;
    __sparse_vector_base(const __sparse_vector_base& other)
        : __els(other.__els),
          __idxs(other.__idxs)
    {}

    struct __sentinel {};

    template <typename _P, typename = _Requires<_Pointer<_P>>>
    struct __iterator {
        _P __ptr{};
        _P __end{};

        __iterator() = default;
        __iterator(_P __ptr_, _P __end_) : __ptr(__ptr_), __end(__end_) {
            while (__ptr < __end && *__ptr == static_cast<_T>(__invalid)) { ++__ptr; }
        }

        __iterator& operator++() {
            do ++__ptr;
            while (__ptr < __end && *__ptr == static_cast<_T>(__invalid));
            return *this;
        }

        __iterator operator+(size_t __n) const {
            __iterator __it;
            __it.__ptr = __ptr + __n;
            __it.__end = __end;
            return __it;
        }

        __iterator operator-(size_t __n) const {
            __iterator __it;
            __it.__ptr = __ptr - __n;
            __it.__end = __end;
            return __it;
        }

        ptrdiff_t operator-(const __iterator& __other) const {
            if (__end != __other.__end) { /** TODO: Crash horribly **/ }
            return __ptr - __other.__ptr;
        }

        bool operator==(const __iterator& __other) const { return __ptr == __other.__ptr; }
        bool operator==(const __sentinel&) const { return __ptr == __end; }

        bool operator!=(const __iterator& __rhs) const { return __ptr != __rhs.__ptr; }
        bool operator!=(const __sentinel&) const { return __ptr != __end; }

        _T& operator*() { return *__ptr; }
        _T* operator->() { return __ptr; }
    };

    template <typename _P, typename = _Requires<_Pointer<_P>>>
    struct __kv_iterator {
        __iterator<_P> __it;
        _P __start{};

        __kv_iterator() = default;
        __kv_iterator(_P __ptr, _P __end) : __it(__ptr, __end), __start(__ptr) {}

        __kv_iterator& operator++() {
            ++__it;
            return *this;
        }

        bool operator!=(const __kv_iterator& __rhs) const { return __it != __rhs.__it; }
        bool operator!=(const __sentinel&) const { return __it != __sentinel{}; }

        pair<index_type, _T&> operator*() {
            return {static_cast<index_type>(__it.__ptr - __start), *__it};
        }
    };

    struct __kv_iterator_helper {
        __sparse_vector_base& __sv;
        __kv_iterator_helper(__sparse_vector_base& __sv) : __sv(__sv) {}

        __kv_iterator<decltype(__sv.__els.begin())> begin() {
            return __kv_iterator<decltype(__sv.__els.begin())>{__sv.__els.begin(), __sv.__els.end()};
        }
        __sentinel end() { return {}; }
    };

public:
    /// =======================================================================
    ///  Iterators.
    /// =======================================================================
    __iterator<decltype(__els.begin())> begin() noexcept { return __iterator<decltype(__els.begin())>{__els.begin(), __els.end()}; }
    __iterator<decltype(__els.begin())> begin() const noexcept { return __iterator<decltype(__els.begin())>{__els.begin(), __els.end()}; }
    __sentinel end() noexcept { return {}; }
    __sentinel end() const noexcept { return {}; }

    /// =======================================================================
    ///  Manipulating elements.
    /// =======================================================================
    /// Insert a new element into the vector.
    ///
    /// If the element is equal to the invalid element, no insertion is performed.
    ///
    /// \param __el The element to insert.
    /// \return A pair containing the index of the inserted element, if it was inserted,
    ///         and a bool that is true iff an element was inserted.
    template <typename _U>
    std::pair<index_type, bool> push_back(_U&& __el) {
        if (__el == static_cast<_T>(__invalid)) { return {static_cast<index_type>(0), false}; }
        auto [__idx, _] = allocate();
        __els[size_t(__idx)] = std::forward<_U>(__el);
        return {static_cast<index_type>(__idx), true};
    }

    /// Allocate a new element in the vector.
    ///
    /// This default-constructs a free position if it is unused, or allocates a new
    /// element at the end of the vector if all positions are used.
    ///
    /// \return A pair containing the index of the allocated element, and a bool that
    ///         is true iff an element was allocated.
    std::pair<index_type, bool> allocate() {
        if (!__idxs.empty()) {
            auto __idx = __idxs.back();
            __idxs.pop_back();
            return {index_type(__idx), false};
        } else {
            __els.push_back(static_cast<_T>(__invalid));
            return {index_type(__els.size() - 1), true};
        }
    }

    /// Place element `__el` at index `__idx` within sparse vector.
    bool replace(index_type __idx, _T __el) {
        const auto __i = static_cast<size_t>(__idx);
        if (__i > __els.size()) return false;
        if (__els[__i] == static_cast<_T>(__invalid)) std::erase_if(__idxs, [__i](auto __j) { return __j == __i; });
        __els[__i] = std::move(__el);
        return true;
    }

    bool erase(index_type __idx) {
        const auto __i = static_cast<size_t>(__idx);
        if (__i > __els.size()) { return false; }
        if (__els[__i] == static_cast<_T>(__invalid)) { return false; }

        __els[__i] = static_cast<_T>(__invalid);
        __idxs.push_back(__i);
        return true;
    }

    bool erase(__iterator<decltype(__els.begin())> __it) {
        if (__it == end()) { return false; }
        return erase(static_cast<index_type>(__it - begin()));
    }

    /// Clear and delete all elements.
    void clear_and_delete() {
        __els.clear();
        __idxs.clear();
        __els.shrink_to_fit();
    }

    __kv_iterator_helper pairs() noexcept { return __kv_iterator_helper{*this}; }

    /// =======================================================================
    ///  Miscellaneous
    /// =======================================================================
    /// Since the vector has holes, its size may not be the number of elements
    /// it contains; rather it is the maximum number of elements it may contain.
    size_t allocated_size() const noexcept { return __els.size(); }

    /// Defragment the vector and free unused memory.
    /// This will remove all holes in the vector, but also invalidate all indices.
    void compact() {
        /// Find the first hole.
        size_t __hole = 0;
        while (__hole < __els.size() && __els[__hole] != __invalid) { __hole++; }

        /// If there are no holes, we're done.
        if (__hole == __els.size()) { return; }

        /// Find the first non-hole after the hole.
        size_t __el = __hole + 1;
        while (__el < __els.size() && __els[__el] == __invalid) { __el++; }

        /// If the hole is at the end, we're done.
        if (__el == __els.size()) {
        __shrink:
            __els.resize(__hole);
            __els.shrink_to_fit();
            __idxs.clear();
            return;
        }

        /// Move elements into the hole from the right; skip over any holes.
        while (__el < __els.size()) {
            if (__els[__el] != __invalid) {
                __els[__hole] = move(__els[__el]);
                ++__hole;
            }
            ++__el;
        }

        /// Shrink the vector.
        goto __shrink;
    }
};

/// Generic sparse vector: element access yields a pointer to the slot.
template <typename _T, auto __invalid, typename index_type = size_t>
struct sparse_vector : __sparse_vector_base<_T, __invalid, index_type> {
    using __base = __sparse_vector_base<_T, __invalid, index_type>;
    using __base::__els;

    /// =======================================================================
    ///  Accessing elements.
    /// =======================================================================
    /// \return A pointer to the slot at `__idx`, or nullptr if `__idx` is out
    ///         of range. The slot is not checked for being empty.
    _T* operator[](index_type __idx) noexcept {
        const size_t __pos = static_cast<size_t>(__idx);
        return __pos < __els.size() ? __els.data() + __pos : nullptr;
    }

    /// Read-only counterpart of the accessor above.
    const _T* operator[](index_type __idx) const noexcept {
        const size_t __pos = static_cast<size_t>(__idx);
        return __pos < __els.size() ? __els.data() + __pos : nullptr;
    }
};

/// Sparse vector of unique pointers; a null pointer marks an empty slot.
template <typename _T, typename index_type>
struct sparse_vector<std::unique_ptr<_T>, nullptr, index_type>
    : __sparse_vector_base<std::unique_ptr<_T>, nullptr, index_type> {
    using __base = __sparse_vector_base<std::unique_ptr<_T>, nullptr, index_type>;
    using __base::__els;

    /// Null pointer handed out by operator[] for out-of-range indices.
    ///
    /// Declared `inline` so that this declaration is also the definition;
    /// without one, every out-of-range access is an undefined reference.
    /// NOTE: the non-const operator[] exposes this object by mutable
    /// reference, so callers must not assign through an out-of-range result.
    inline static std::unique_ptr<_T> __invalid_value{};

    /// =======================================================================
    ///  Accessing elements.
    /// =======================================================================
    /// \return A reference to the slot at `__idx`, or to `__invalid_value` if
    ///         `__idx` is out of range.
    std::unique_ptr<_T>& operator[](index_type __idx) noexcept {
        const auto __i = static_cast<size_t>(__idx);
        if (__i >= __els.size()) { return __invalid_value; }
        return __els[__i];
    }

    /// Read-only counterpart of the accessor above.
    const std::unique_ptr<_T>& operator[](index_type __idx) const noexcept {
        const auto __i = static_cast<size_t>(__idx);
        if (__i >= __els.size()) { return __invalid_value; }
        return __els[__i];
    }
};

/// Sparse vector of shared pointers; a null pointer marks an empty slot.
template <typename _T, typename index_type>
struct sparse_vector<std::shared_ptr<_T>, nullptr, index_type>
    : __sparse_vector_base<std::shared_ptr<_T>, nullptr, index_type> {
    using __base = __sparse_vector_base<std::shared_ptr<_T>, nullptr, index_type>;
    using __base::__els;

    /// =======================================================================
    ///  Accessing elements.
    /// =======================================================================
    /// \return A copy of the pointer stored at `__idx`, or a null pointer if
    ///         `__idx` is out of range.
    std::shared_ptr<_T> operator[](index_type __idx) const noexcept {
        const size_t __pos = static_cast<size_t>(__idx);
        if (__pos < __els.size()) { return __els[__pos]; }
        return std::shared_ptr<_T>{nullptr};
    }
};

} // namespace std

#endif // _LENSOROS_EXTENSIONS_SPARSE_VECTOR
