/* Copyright 2022, Contributors To LensorOS.
 * All rights reserved.
 *
 * This file is part of LensorOS.
 *
 * LensorOS is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * LensorOS is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with LensorOS. If not, see <https://www.gnu.org/licenses/>.
 */

#ifndef _LENSOROS_UNORDERED_MAP
#define _LENSOROS_UNORDERED_MAP

#include <stddef.h>
#include <stdint.h>

#include <functional>
#include <vector>
#include <utility>

namespace std {

/// Hash map with separate chaining; the implementation behind
/// `std::unordered_map`.
///
/// Layout: `__data` is an array of sentinel "head" buckets. A head never
/// stores an element itself; every element lives in a heap-allocated
/// `__bucket` node linked behind the head of the bucket its key hashes to.
/// Elements are kept out of the array for two reasons:
///   - `value_type` is `std::pair<const _Key, _T>`, which cannot be assigned
///     to, so an element could never be written into an existing array slot.
///   - Node addresses stay stable, so references and iterators to one element
///     are unaffected by the insertion of other elements.
///
/// Requirements: `_Key` and `_T` must be default constructible (each head
/// holds a value-initialised, unused `value_type`), and `_Hash` / `_KeyEqual`
/// must be callable on const objects.
///
/// Limitations: there is no rehashing. The bucket count only changes when a
/// map whose bucket array was moved away has to recreate it.
template<
    class _Key,
    class _T,
    class _Hash = std::hash<_Key>,
    class _KeyEqual = std::equal_to<_Key>>
struct __unordered_map_impl {
    /// All read only operations, swap, std::swap NEVER invalidate iterators.
    /// clear, rehash, reserve, operator= ALWAYS invalidate iterators.
    /// insert, emplace, emplace_hint, operator[] MAY invalidate iterators.
    /// erase invalidates iterators ONLY to the element erased.

    using key_type = _Key;
    using mapped_type = _T;
    using value_type = std::pair<const _Key, _T>;
    using size_type = std::size_t;
    using difference_type = ptrdiff_t;
    using hasher = _Hash;
    using key_equal = _KeyEqual;
    using reference = value_type&;
    using const_reference = const value_type&;

    /// Number of head buckets created whenever the bucket array is built.
    static constexpr size_type __default_bucket_count = 512;

    hasher __hash_function;
    key_equal __keys_equal;

    // FIXME: Technically an unordered map doesn't need a doubly linked list, just singly.
    /// One link of a bucket chain. The same type is used for the sentinel
    /// heads stored in `__data` and for the heap nodes that carry elements.
    struct __bucket {
        /// True for heap nodes holding a live key/value pair, false for the
        /// sentinel heads stored in the bucket array.
        bool __occupied { false };
        /// The stored element. Unused (value-initialised) in sentinel heads.
        value_type __keyval;
        __bucket* __prev { nullptr };
        __bucket* __next { nullptr };

        /// Links `__new_bucket` directly behind this node.
        ///
        /// [A]<->[B]<->[C]
        /// B.insert_after(F);
        /// [A]<->[B]<->[F]<->[C]
        void insert_after(__bucket* __new_bucket) {
            // Wire up both sides of F first, then redirect the old B<->C
            // connection through it.
            __new_bucket->__prev = this;
            __new_bucket->__next = __next;
            if (__next) __next->__prev = __new_bucket;
            __next = __new_bucket;
        }

        /// Links `__new_bucket` behind the last node reachable through
        /// `__next` from this node.
        void insert_at_end(__bucket* __new_bucket) {
            __bucket* __last = this;
            while (__last->__next) __last = __last->__next;
            __new_bucket->__next = nullptr;
            __new_bucket->__prev = __last;
            __last->__next = __new_bucket;
        }

        /// Links `__new_bucket` directly in front of this node.
        ///
        /// [A]<->[B]<->[C]
        /// B.insert_before(F);
        /// [A]<->[F]<->[B]<->[C]
        void insert_before(__bucket* __new_bucket) {
            // Wire up both sides of F first, then redirect the old A<->B
            // connection through it.
            __new_bucket->__prev = __prev;
            __new_bucket->__next = this;
            if (__prev) __prev->__next = __new_bucket;
            __prev = __new_bucket;
        }

        /// Links `__new_bucket` in front of the first node reachable through
        /// `__prev` from this node.
        void insert_at_beginning(__bucket* __new_bucket) {
            __bucket* __front = this;
            while (__front->__prev) __front = __front->__prev;
            __new_bucket->__prev = nullptr;
            __new_bucket->__next = __front;
            __front->__prev = __new_bucket;
        }

        /// Frees up to `__n` nodes directly behind this node, or all of them
        /// if `__n` is zero. Nodes beyond the limit stay linked to this node.
        /// The freed nodes must have been allocated with `new`.
        ///
        /// [A]<->[B]<->[C]
        /// B.erase_after();
        /// [A]<->[B]
        void erase_after(size_type __n = 0) {
            if (!__n) __n = SIZE_MAX;
            __bucket* __it = __next;
            for (; __it && __n; --__n) {
                __bucket* __following = __it->__next;
                delete __it;
                __it = __following;
            }
            // Re-attach whatever survived so it is neither leaked nor left
            // pointing at freed memory.
            __next = __it;
            if (__it) __it->__prev = this;
        }

        /// Frees up to `__n` nodes directly in front of this node, or all of
        /// them if `__n` is zero. Nodes beyond the limit stay linked to this
        /// node. The freed nodes must have been allocated with `new`.
        ///
        /// [A]<->[B]<->[C]
        /// B.erase_before();
        /// [B]<->[C]
        void erase_before(size_type __n = 0) {
            if (!__n) __n = SIZE_MAX;
            __bucket* __it = __prev;
            for (; __it && __n; --__n) {
                __bucket* __preceding = __it->__prev;
                delete __it;
                __it = __preceding;
            }
            __prev = __it;
            if (__it) __it->__next = this;
        }

        /// Removes this node from a list but DOES NOT FREE itself.
        /// The node's own links are reset so it no longer points into the
        /// list it left.
        ///
        /// [A]<->[B]<->[C]
        /// B.erase();
        /// [A]<->[C]
        void erase() {
            if (__prev) __prev->__next = __next;
            if (__next) __next->__prev = __prev;
            __prev = nullptr;
            __next = nullptr;
        }

        /// Number of nodes from this node (inclusive) to the end of the list.
        size_type length() const {
            size_type __out { 0 };
            for (const __bucket* __it = this; __it; __it = __it->__next)
                ++__out;
            return __out;
        }
    };

    /// Sentinel heads, one per bucket. May be empty after being moved from;
    /// `__get_index` recreates it on demand.
    std::vector<__bucket> __data = std::vector<__bucket>(__default_bucket_count);

    // Amount of elements stored in the map.
    size_type __count { 0 };

    __unordered_map_impl() = default;

    /// Deep copy: every element of `__other` is copied into a fresh node.
    __unordered_map_impl(const __unordered_map_impl& __other)
        : __hash_function(__other.__hash_function)
        , __keys_equal(__other.__keys_equal)
        , __data(__other.__data.size()) {
        for (const __bucket& __head : __other.__data)
            for (const __bucket* __node = __head.__next; __node; __node = __node->__next)
                insert(__node->__keyval);
    }

    /// Takes over the bucket array and all nodes of `__other`, which is left
    /// empty but usable.
    __unordered_map_impl(__unordered_map_impl&& __other) noexcept
        : __hash_function(std::move(__other.__hash_function))
        , __keys_equal(std::move(__other.__keys_equal))
        , __data(std::move(__other.__data))
        , __count(__other.__count) {
        __other.__count = 0;
    }

    __unordered_map_impl& operator= (const __unordered_map_impl& __other) {
        if (this != &__other) {
            __unordered_map_impl __copy(__other);
            swap(__copy);
        }
        return *this;
    }

    __unordered_map_impl& operator= (__unordered_map_impl&& __other) noexcept {
        if (this != &__other) {
            // Free our own nodes first; `__other` then receives our emptied
            // bucket array and stays valid.
            clear();
            swap(__other);
        }
        return *this;
    }

    ~__unordered_map_impl() { clear(); }

    /// Exchanges contents with `__other` without touching any element.
    void swap(__unordered_map_impl& __other) noexcept {
        std::swap(__hash_function, __other.__hash_function);
        std::swap(__keys_equal, __other.__keys_equal);
        __data.swap(__other.__data);
        std::swap(__count, __other.__count);
    }

    /// Forward iterator over all elements. The past-the-end iterator has every
    /// field zeroed, so it compares equal to a default-constructed one.
    struct iterator {
        __bucket* __buckets { nullptr };
        size_type __count { 0 };
        size_type __index { 0 };
        __bucket* __current { nullptr };

        /// Turns this iterator into the past-the-end iterator.
        void clear() {
            __buckets = nullptr;
            __count = 0;
            __index = 0;
            __current = nullptr;
        }

        /// Moves to the first element of the next non-empty bucket after
        /// `__index`, or becomes the past-the-end iterator if there is none.
        void __update_to_next_occupied() {
            if (__buckets) {
                while (++__index < __count) {
                    if (__bucket* __first = __buckets[__index].__next) {
                        __current = __first;
                        return;
                    }
                }
            }
            clear();
        }

        iterator() = default;

        /// Iterator to the first element of the whole bucket array, or
        /// past-the-end if the array holds no elements.
        iterator(__bucket* __bucket_array, size_type __bucket_count)
            : __buckets(__bucket_array), __count(__bucket_count) {
            if (!__buckets || !__count) {
                clear();
                return;
            }
            // Bucket 0 has to be examined before advancing to later buckets.
            __current = __buckets->__next;
            if (!__current) __update_to_next_occupied();
        }

        /// Iterator to the first element of bucket `__i`, or past-the-end if
        /// that bucket is empty or out of range.
        iterator(__bucket* __bucket_array, size_type __bucket_count, size_type __i)
            : __buckets(__bucket_array), __count(__bucket_count), __index(__i) {
            if (__buckets && __index < __count)
                __current = __buckets[__index].__next;
            if (!__current) clear();
        }

        /// Iterator to node `__b`, which must belong to bucket `__i`.
        iterator(__bucket* __bucket_array, size_type __bucket_count, size_type __i, __bucket* __b)
            : __buckets(__bucket_array), __count(__bucket_count), __index(__i), __current(__b) {}

        // Prefix increment operator
        iterator& operator++ () {
            if (!__current) return *this; // Already past the end.
            if (__current->__next) __current = __current->__next;
            else __update_to_next_occupied();
            return *this;
        }

        // Postfix increment operator
        iterator operator++ (int) {
            iterator __previous = *this;
            ++*this;
            return __previous;
        }

        bool operator== (const iterator& __other) const {
            return __buckets == __other.__buckets
                && __count == __other.__count
                && __index == __other.__index
                && __current == __other.__current;
        }
        bool operator!= (const iterator& __other) const {
            return !(operator==(__other));
        }

        value_type& operator* () const {
            return __current->__keyval;
        }
        value_type* operator-> () const {
            return &__current->__keyval;
        }
    };

    iterator begin() {
        return iterator(__data.data(), __data.size());
    }

    iterator end() { return iterator(); }

    [[nodiscard]]
    bool empty() const noexcept {
        return __count == 0;
    }

    size_type size() const noexcept {
        return __count;
    }

    /// Destroys every element. The bucket array itself is kept so that the
    /// map can be reused straight away.
    void clear() noexcept {
        for (auto& __head : __data)
            __head.erase_after();
        __count = 0;
    }

    /// Bucket index for `__key`. Precondition: `__data` is not empty.
    [[nodiscard]]
    size_type __index_for(const key_type& __key) const {
        return __hash_function(__key) % __data.size();
    }

    /// Bucket index for `__key`; recreates the bucket array first if it is
    /// missing (moved-from map).
    [[nodiscard]]
    size_type __get_index(const key_type& __key) {
        if (__data.empty()) __data.resize(__default_bucket_count);
        return __index_for(__key);
    }

    /// Node in bucket `__index` whose key equals `__key`, or nullptr.
    [[nodiscard]]
    __bucket* __get_matching_bucket(size_type __index, const key_type& __key) const {
        if (__index >= __data.size()) return nullptr;
        __bucket* __node = (__data.data() + __index)->__next;
        while (__node && !__keys_equal(__key, __node->__keyval.first))
            __node = __node->__next;
        return __node;
    }
    /// Node whose key equals `__key`, or nullptr.
    [[nodiscard]]
    __bucket* __get_matching_bucket(const key_type& __key) const {
        if (__data.empty()) return nullptr;
        return __get_matching_bucket(__index_for(__key), __key);
    }

    /// Shared body of both `insert` overloads; forwards `__value` into the new
    /// node only if no element with an equal key exists yet.
    template <class _Pair>
    std::pair<iterator, bool> __insert_impl(_Pair&& __value) {
        const size_type __index = __get_index(__value.first);
        // Walk the chain looking for the key while remembering its last
        // link, which is where a new node gets appended.
        __bucket* __tail = __data.data() + __index;
        for (__bucket* __node = __tail->__next; __node; __node = __node->__next) {
            if (__keys_equal(__value.first, __node->__keyval.first))
                return std::pair<iterator, bool>(
                    iterator(__data.data(), __data.size(), __index, __node), false);
            __tail = __node;
        }
        __bucket* __fresh = new __bucket{ true, std::forward<_Pair>(__value) };
        __tail->insert_after(__fresh);
        ++__count;
        return std::pair<iterator, bool>(
            iterator(__data.data(), __data.size(), __index, __fresh), true);
    }

    /// Inserts `__value` unless an element with an equal key already exists.
    /// Returns an iterator to the element with that key and whether an
    /// insertion took place.
    std::pair<iterator, bool> insert(value_type&& __value) {
        return __insert_impl(std::move(__value));
    }
    std::pair<iterator, bool> insert(const value_type& __value) {
        return __insert_impl(__value);
    }

    /// Reference to the value mapped to `__key`, inserting a
    /// value-initialised one first if the key is not present.
    mapped_type& operator[] (const key_type& __key) {
        return insert(value_type(__key, mapped_type())).first.__current->__keyval.second;
    }
    mapped_type& operator[] (key_type&& __key) {
        return operator[](__key);
    }

    /// Return number of elements matching key (either 1 or 0).
    size_type count(const key_type& __key) const {
        return __get_matching_bucket(__key) ? 1 : 0;
    }

    /// Iterator to the element with key `__key`, or `end()` if there is none.
    iterator find(const key_type& __key) {
        if (__data.empty()) return end();
        const size_type __index = __index_for(__key);
        __bucket* __node = __get_matching_bucket(__index, __key);
        if (!__node) return end();
        return iterator(__data.data(), __data.size(), __index, __node);
    }

    bool contains(const key_type& __key) const {
        return count(__key) != 0;
    }

    /// Number of elements stored in bucket `__n` (0 if out of range).
    size_type bucket_size(size_type __n) const {
        if (__n >= __data.size()) return 0;
        // The sentinel head is part of the chain but is not an element.
        return (__data.data() + __n)->length() - 1;
    }
    size_type bucket_count() const {
        return __data.size();
    }
    // FIXME: Is this correct?
    size_type max_bucket_count() const {
        return SIZE_MAX;
    }

    size_type bucket(const key_type& __key) {
        return __get_index(__key);
    }

    hasher hash_function() const { return __hash_function; }
    key_equal key_eq() const { return __keys_equal; }
};

/// Public name of the container: `unordered_map<K, V, Hash, Eq>` is simply
/// an alias for `__unordered_map_impl` with the same four arguments. The
/// hasher defaults to `std::hash` and the key comparison to `std::equal_to`,
/// both over the key type.
template<
    typename _KeyT,
    typename _MappedT,
    typename _HashFn = std::hash<_KeyT>,
    typename _EqualFn = std::equal_to<_KeyT>>
using unordered_map = __unordered_map_impl<_KeyT, _MappedT, _HashFn, _EqualFn>;

} // namespace std

#endif /* _LENSOROS_UNORDERED_MAP */
