/* Copyright 2022, Contributors To LensorOS.
 * All rights reserved.
 *
 * This file is part of LensorOS.
 *
 * LensorOS is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * LensorOS is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with LensorOS. If not, see <https://www.gnu.org/licenses/>.
 */

#ifndef LENSOROS_VECTOR_
#define LENSOROS_VECTOR_

#ifdef __kernel__
#   include "memory/heap.h"
#endif

#include <algorithm>
#include <iterator>
#include <memory>
#include <type_traits>
#include <utility>

namespace std {

/// A growable, contiguous array modelled on std::vector.
///
/// Storage is one `new value_type[capacity]` array, so every slot up to the
/// capacity always holds a live object. Consequences visible in this code:
///   - value_type must be default-constructible and assignable;
///   - pop_back(), clear() and erase() only adjust the size; the elements
///     themselves are destroyed when the buffer is released (reallocation or
///     destruction of the vector).
///
/// \tparam _T            Element type (must not be a reference).
/// \tparam __index_type  Type accepted by operator[] and at(); it is converted
///                       to size_t before indexing.
template <
    typename _T,
    typename __index_type = size_t,
    typename = _Requires<_NotRef<_T>>
>
struct __vector {
    using value_type = _T;
    using pointer = _T*;
    using const_pointer = const _T*;
    using reference = _T&;
    using const_reference = const _T&;
    using size_type = size_t;
    using difference_type = ptrdiff_t;
    using iterator = _T*;
    using const_iterator = const _T*;
    using reverse_iterator = std::reverse_iterator<iterator>;
    using const_reverse_iterator = std::reverse_iterator<const_iterator>;

private:
    value_type* __ptr{};
    size_type __sz{};
    size_type __cap{};

    /// Capacity to grow to when room for one more element is needed:
    /// double the current capacity, starting from 1.
    size_type __grown_capacity() const noexcept { return __cap == 0 ? 1 : __cap * 2; }

    /// Release our buffer and take over __other's, leaving __other empty.
    /// Must not be called with __other == *this.
    void __steal(__vector& __other) noexcept {
        delete[] __ptr;
        __ptr = __other.__ptr;
        __sz = __other.__sz;
        __cap = __other.__cap;
        __other.__ptr = nullptr;
        __other.__sz = 0;
        __other.__cap = 0;
    }

public:
    /// =======================================================================
    ///  Constructors.
    /// =======================================================================
    /// Construct an empty vector.
    __vector() noexcept = default;

    /// Construct a __vector with __n value-initialised elements.
    /// The trailing `()` matters: without it, elements of trivial types
    /// would be left indeterminate.
    explicit __vector(size_type __n) : __sz(__n), __cap(__n) { __ptr = new value_type[__n](); }

    /// Construct a __vector with __n copies of __val.
    __vector(size_type __n, const value_type& __val) : __sz(__n), __cap(__n) {
        __ptr = new value_type[__n];
        for (size_type i = 0; i < __n; ++i) { __ptr[i] = __val; }
    }

    /// Construct a __vector with the contents of the range [__first, __last).
    ///
    /// _It must support `__last - __first`. The defaulted second template
    /// parameter removes this overload when _It cannot be dereferenced, so
    /// that `__vector<int>(5, 0)` selects the (count, value) constructor
    /// instead of being treated as an iterator pair.
    template<typename _It, typename = decltype(**static_cast<_It*>(nullptr))>
    __vector(_It __first, _It __last) : __sz(__last - __first), __cap(__last - __first) {
        __ptr = new value_type[__sz];
        for (size_type i = 0; i < __sz; ++i) { __ptr[i] = *__first++; }
    }

    /// =======================================================================
    ///  Copy and move constructors, assignment operators, and destructor.
    /// =======================================================================
    /// Copy-construct. The new vector's capacity equals __other's *size*,
    /// which is exactly what gets allocated.
    __vector(const __vector& __other) : __sz(__other.__sz), __cap(__other.__sz) {
        if (__cap == 0) { return; }
        __ptr = new value_type[__cap];
        for (size_type i = 0; i < __sz; ++i) { __ptr[i] = __other.__ptr[i]; }
    }

    /// Move-construct: take over __other's buffer and leave it empty.
    __vector(__vector&& __other) noexcept : __ptr(__other.__ptr), __sz(__other.__sz), __cap(__other.__cap) {
        __other.__ptr = nullptr;
        __other.__sz = 0;
        __other.__cap = 0;
    }

    /// Copy-assign. Reuses the existing buffer when it is large enough;
    /// otherwise builds a copy and swaps it in, releasing the old buffer.
    __vector& operator=(const __vector& __other) {
        if (this == std::addressof(__other)) { return *this; }
        if (__other.__sz > __cap) {
            __vector __copy(__other);
            __steal(__copy);
            return *this;
        }
        for (size_type i = 0; i < __other.__sz; ++i) { __ptr[i] = __other.__ptr[i]; }
        __sz = __other.__sz;
        return *this;
    }

    /// Move-assign: release our buffer, then take over __other's.
    __vector& operator=(__vector&& __other) noexcept {
        if (this == std::addressof(__other)) { return *this; }
        __steal(__other);
        return *this;
    }

    /// Destructor.
    ~__vector() { delete[] __ptr; }

    /// =======================================================================
    ///  Iterators.
    /// =======================================================================
    iterator begin() noexcept { return __ptr; }
    const_iterator begin() const noexcept { return __ptr; }

    iterator end() noexcept { return __ptr + __sz; }
    const_iterator end() const noexcept { return __ptr + __sz; }

    reverse_iterator rbegin() noexcept { return reverse_iterator(end()); }
    const_reverse_iterator rbegin() const noexcept { return const_reverse_iterator(end()); }

    reverse_iterator rend() noexcept { return reverse_iterator(begin()); }
    const_reverse_iterator rend() const noexcept { return const_reverse_iterator(begin()); }

    const_iterator cbegin() const noexcept { return __ptr; }
    const_iterator cend() const noexcept { return __ptr + __sz; }
    const_reverse_iterator crbegin() const noexcept { return const_reverse_iterator(end()); }
    const_reverse_iterator crend() const noexcept { return const_reverse_iterator(begin()); }

    /// =======================================================================
    ///  Size and capacity.
    /// =======================================================================
    [[nodiscard]] bool empty() const noexcept { return __sz == 0; }
    size_type size() const noexcept { return __sz; }
    size_type max_size() const noexcept { return SIZE_MAX / sizeof(value_type); }
    size_type capacity() const noexcept { return __cap; }

    /// Change the size to __n. Elements added by growing are value-initialised
    /// (slots past the old size may still hold stale values from earlier
    /// pop_back()/erase() calls, so they are explicitly reset).
    void resize(size_type __n) {
        if (__n > __cap) { reserve(__n); }
        for (size_type i = __sz; i < __n; ++i) { __ptr[i] = value_type(); }
        __sz = __n;
    }

    /// Change the size to __n; elements added by growing are copies of __val.
    void resize(size_type __n, const value_type& __val) {
        if (__n > __cap) {
            // __val may refer to one of our own elements, which reserve()
            // is about to relocate and free; copy it first.
            value_type __fill(__val);
            reserve(__n);
            for (size_type i = __sz; i < __n; ++i) { __ptr[i] = __fill; }
        } else {
            for (size_type i = __sz; i < __n; ++i) { __ptr[i] = __val; }
        }
        __sz = __n;
    }

    /// Ensure capacity() >= __n. Reallocates (moving the existing elements)
    /// only when __n exceeds the current capacity; never shrinks.
    void reserve(size_type __n) {
        if (__n > max_size()) { /* TODO: Crash horribly */ }
        if (__n > __cap) {
            auto __new_ptr = new value_type[__n];
            for (size_type i = 0; i < __sz; ++i) { __new_ptr[i] = std::move(__ptr[i]); }
            delete[] __ptr;
            __ptr = __new_ptr;
            __cap = __n;
        }
    }

    /// Reduce capacity() to size(). An empty vector releases its buffer
    /// entirely rather than allocating a zero-length array.
    void shrink_to_fit() {
        if (__sz < __cap) {
            value_type* __new_ptr = nullptr;
            if (__sz != 0) {
                __new_ptr = new value_type[__sz];
                for (size_type i = 0; i < __sz; ++i) { __new_ptr[i] = std::move(__ptr[i]); }
            }
            delete[] __ptr;
            __ptr = __new_ptr;
            __cap = __sz;
        }
    }

    /// =======================================================================
    ///  Element access.
    /// =======================================================================
    /// None of these accessors check bounds; at() is currently identical to
    /// operator[].
    reference operator[](__index_type __n) noexcept { return __ptr[static_cast<size_t>(__n)]; }
    const_reference operator[](__index_type __n) const noexcept { return __ptr[static_cast<size_t>(__n)]; }
    reference at(__index_type __n) { return __ptr[static_cast<size_t>(__n)]; }
    const_reference at(__index_type __n) const { return __ptr[static_cast<size_t>(__n)]; }
    reference front() noexcept { return __ptr[0]; }
    const_reference front() const noexcept { return __ptr[0]; }
    reference back() noexcept { return __ptr[__sz - 1]; }
    const_reference back() const noexcept { return __ptr[__sz - 1]; }

    /// =======================================================================
    ///  Data access.
    /// =======================================================================
    value_type* data() noexcept { return __ptr; }
    const value_type* data() const noexcept { return __ptr; }

    /// =======================================================================
    ///  Pushing, popping, and emplacing.
    /// =======================================================================
    /// Construct a value_type from __args and append it.
    /// \return A reference to the appended element.
    template <typename ..._Args>
    reference emplace_back(_Args&&... __args) {
        // Build the value before any reallocation: the arguments may refer
        // to elements of this vector.
        value_type __tmp = value_type(std::forward<_Args>(__args)...);
        if (__sz == __cap) { reserve(__grown_capacity()); }
        __ptr[__sz] = std::move(__tmp);
        return __ptr[__sz++];
    }

    /// Append a copy of __val.
    void push_back(const value_type& __val) {
        if (__sz == __cap) {
            // __val may be one of our own elements (v.push_back(v[0]));
            // copy it before reserve() frees the old storage.
            value_type __copy(__val);
            reserve(__grown_capacity());
            __ptr[__sz++] = std::move(__copy);
            return;
        }
        __ptr[__sz++] = __val;
    }

    /// Append __val by moving from it.
    void push_back(value_type&& __val) {
        if (__sz == __cap) {
            // Same aliasing concern as the const& overload.
            value_type __tmp(std::move(__val));
            reserve(__grown_capacity());
            __ptr[__sz++] = std::move(__tmp);
            return;
        }
        __ptr[__sz++] = std::move(__val);
    }

    /// Drop the last element. The vector must not be empty (not checked).
    void pop_back() { --__sz; }

    /// Set the size to zero; capacity is unchanged.
    void clear() { __sz = 0; }

    /// =======================================================================
    ///  Inserting and erasing.
    /// =======================================================================
    /// Insert __val before __pos.
    /// \return An iterator to the inserted element.
    iterator insert(const_iterator __pos, value_type&& __val) {
        if (__pos > end() || __pos < begin()) { /* TODO: Crash horribly */ }
        // Remember the position as an index: reserve() may move the buffer.
        const auto __idx = __pos - __ptr;
        // Take the value out first: __val may refer to an element that is
        // about to be relocated or shifted.
        value_type __tmp(std::move(__val));
        if (__sz == __cap) { reserve(__grown_capacity()); }
        for (difference_type i = static_cast<difference_type>(__sz); i > __idx; --i) {
            __ptr[i] = std::move(__ptr[i - 1]);
        }
        __ptr[__idx] = std::move(__tmp);
        ++__sz;
        return __ptr + __idx;
    }

    /// Insert a copy of __val before __pos.
    /// \return An iterator to the inserted element.
    iterator insert(const_iterator __pos, const value_type& __val) {
        return insert(__pos, value_type(__val));
    }

    /// Insert the range [__first, __last) before __pos. _It must support
    /// `__last - __first` and `__first[i]`. An empty range is a no-op.
    /// \return An iterator to the first inserted element (or to __pos's
    ///         position if the range was empty).
    template<typename _It>
    iterator insert(const_iterator __pos, _It __first, _It __last) {
        if (__pos > end() || __pos < begin()) { /* TODO: Crash horribly */ }
        const auto __idx = static_cast<size_type>(__pos - __ptr);
        const auto __dist = __last - __first;
        if (__dist <= 0) { return __ptr + __idx; }
        const auto __n = static_cast<size_type>(__dist);
        if (__sz + __n > __cap) { reserve(std::max(__sz + __n, __grown_capacity())); }

        // Shift the tail [__idx, __sz) up by __n, back to front. Counting
        // `i` down to (but not including) __idx avoids unsigned wrap-around
        // when __idx is 0.
        for (size_type i = __sz; i > __idx; --i) { __ptr[i + __n - 1] = std::move(__ptr[i - 1]); }
        for (size_type i = 0; i < __n; ++i) { __ptr[__idx + i] = __first[i]; }

        __sz += __n;
        return __ptr + __idx;
    }

    /// Remove the elements in [__first, __last), moving the tail down.
    /// \return An iterator to the element that now follows the erased range.
    iterator erase(const_iterator __first, const_iterator __last) {
        if (__first > end() || __first < begin() || __last > end() || __last < begin()) { /* TODO: Crash horribly */ }
        const auto __idx = static_cast<size_type>(__first - __ptr);
        const auto __n = static_cast<size_type>(__last - __first);
        for (size_type i = __idx; i + __n < __sz; ++i) { __ptr[i] = std::move(__ptr[i + __n]); }
        __sz -= __n;
        return __ptr + __idx;
    }

    /// Remove the element at __pos, which must be dereferenceable (not end()).
    iterator erase(const_iterator __pos) { return erase(__pos, __pos + 1); }
};


/// Remove every element of __container that compares equal to __value.
/// \return The number of elements removed.
template
<typename _ElementType,
 typename __index_type = size_t,
 class _ComparedType
 >
constexpr typename __vector<_ElementType, __index_type>::size_type
erase(__vector<_ElementType, __index_type>& __container, const _ComparedType& __value) {
    const size_t __old_size = __container.size();

    // Compact the kept elements to the front, then chop off the leftover tail.
    const auto __new_end = remove(__container.begin(), __container.end(), __value);
    __container.erase(__new_end, __container.end());

    const size_t __removed = __old_size - __container.size();
    return __removed;
}

/// Remove every element of __container for which __predicate returns true.
/// \return The number of elements removed.
template
<typename _ElementType,
 typename __index_type = size_t,
 class _Predicate
 >
constexpr typename __vector<_ElementType, __index_type>::size_type
erase_if(__vector<_ElementType, __index_type>& __container, _Predicate __predicate) {
    const size_t __old_size = __container.size();

    // Compact the kept elements to the front, then chop off the leftover tail.
    const auto __new_end = remove_if(__container.begin(), __container.end(), __predicate);
    __container.erase(__new_end, __container.end());

    const size_t __removed = __old_size - __container.size();
    return __removed;
}


/// The general-purpose vector: a __vector whose operator[] and at() take a
/// size_t index.
template <typename _T>
using vector = __vector<_T, size_t>;

} // namespace std

#endif // LENSOROS_VECTOR_
