/* Copyright 2022, Contributors To LensorOS.
* All rights reserved.
*
* This file is part of LensorOS.
*
* LensorOS is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LensorOS is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LensorOS. If not, see <https://www.gnu.org/licenses/>.
 */

#ifndef _LENSOR_OS_ALGORITHM
#define _LENSOR_OS_ALGORITHM

#include <utility>
#include <stddef.h>

namespace std {

/// Returns the smallest of two or more values, compared with `operator<`.
/// Every argument after the first two is folded in from left to right.
/// When neither of two candidates is less than the other, the later one
/// is the one that is kept.
template <typename integer, typename ...integers>
constexpr integer min(integer a, integer b, integers... ints) {
    if constexpr (sizeof...(ints) != 0) {
        // Reduce the leading pair to a single candidate, then compare
        // that candidate against the remaining arguments.
        return min(a < b ? a : b, ints...);
    } else {
        return a < b ? a : b;
    }
}

/// Returns the largest of two or more values, compared with `operator<`.
/// Every argument after the first two is folded in from left to right.
/// When neither of two candidates is less than the other, the earlier one
/// is the one that is kept.
template <typename integer, typename ...integers>
constexpr integer max(integer a, integer b, integers... ints) {
    // `a < b` means `b` is the larger value, so `b` must be selected.
    if constexpr (sizeof...(ints) == 0) return a < b ? b : a;
    else return max(a < b ? b : a, ints...);
}

/// Walks [first, last) and returns an iterator to the first element for
/// which `predicate` returns true. If there is no such element, the
/// returned iterator compares equal to `last`.
template <typename _It, typename _End, typename _Predicate>
constexpr _It find_if(_It first, _End last, _Predicate predicate) {
    while (first != last) {
        // Stop advancing as soon as an element satisfies the predicate.
        if (predicate(*first)) break;
        ++first;
    }
    return first;
}

/// Walks [first, last) and returns an iterator to the first element for
/// which `predicate` returns false. If every element satisfies the
/// predicate, the returned iterator compares equal to `last`.
template <typename _It, typename _End, typename _Predicate>
constexpr _It find_if_not(_It first, _End last, _Predicate predicate) {
    while (first != last) {
        if (predicate(*first)) {
            // Still inside the run of satisfying elements; keep going.
            ++first;
            continue;
        }
        return first;
    }
    return first;
}

/// Searches [begin, end) for the first element that compares equal to
/// `element` (evaluated as `*it == element`).
/// Returns an iterator to that element, or an iterator that compares
/// equal to `end` when nothing matches.
/// The searched-for value is only compared against and never copied, so
/// it does not need to be copy-constructible.
template <typename _It, typename _End, typename _ElementType>
constexpr _It find(_It begin, _End end, _ElementType&& element) {
    for (; begin != end; ++begin) {
        if (*begin == element) return begin;
    }
    return begin;
}

/// Removes all elements that are equal to value (using operator==).
/// Kept elements are move-assigned towards the front of the range in
/// their original relative order. Returns an iterator one past the last
/// kept element; whatever lies between that iterator and `__last` is
/// left behind as-is (possibly moved-from).
template<class _ForwardIt, class _T>
constexpr _ForwardIt remove(_ForwardIt __first, _ForwardIt __last, const _T& __value) {
    // Everything in front of the first match is already where it belongs.
    __first = find(__first, __last, __value);
    if (__first == __last) return __first;

    // `__first` is the next slot to fill; `__scan` runs ahead of it
    // looking for elements that need to be kept.
    for (auto __scan = __first; ++__scan != __last; ) {
        if (*__scan == __value) continue;
        *__first = move(*__scan);
        ++__first;
    }
    return __first;
}

/// Removes all elements for which predicate p returns true.
/// Kept elements are move-assigned towards the front of the range in
/// their original relative order. Returns an iterator one past the last
/// kept element; whatever lies between that iterator and `__last` is
/// left behind as-is (possibly moved-from).
template<class _ForwardIt, class _UnaryPredicate>
constexpr _ForwardIt remove_if(_ForwardIt __first, _ForwardIt __last, _UnaryPredicate __predicate) {
    // Everything in front of the first match is already where it belongs.
    __first = find_if(__first, __last, __predicate);
    if (__first == __last) return __first;

    // `__first` is the next slot to fill; `__scan` runs ahead of it
    // looking for elements that need to be kept.
    for (auto __scan = __first; ++__scan != __last; ) {
        if (__predicate(*__scan)) continue;
        *__first = move(*__scan);
        ++__first;
    }
    return __first;
}

/// Returns true when `__predicate` holds for every element of
/// [__first, __last); an empty range yields true.
/// Evaluation stops at the first element that fails the predicate.
template<class _InputIt, class _UnaryPredicate>
constexpr bool all_of(_InputIt __first, _InputIt __last, _UnaryPredicate __predicate) {
    while (__first != __last) {
        if (!__predicate(*__first)) return false;
        ++__first;
    }
    return true;
}

/// Returns true when `__predicate` holds for at least one element of
/// [__first, __last); an empty range yields false.
/// Evaluation stops at the first element that satisfies the predicate.
template<class _InputIt, class _UnaryPredicate>
constexpr bool any_of(_InputIt __first, _InputIt __last, _UnaryPredicate __predicate) {
    while (__first != __last) {
        if (__predicate(*__first)) return true;
        ++__first;
    }
    return false;
}

/// Returns true when `__predicate` holds for no element of
/// [__first, __last); an empty range yields true.
/// Evaluation stops at the first element that satisfies the predicate.
template<class _InputIt, class _UnaryPredicate>
constexpr bool none_of(_InputIt __first, _InputIt __last, _UnaryPredicate __predicate) {
    while (__first != __last) {
        if (__predicate(*__first)) return false;
        ++__first;
    }
    return true;
}

/// Invokes `__function` on every element of [__first, __last), in order,
/// passing the dereferenced iterator. Any value the function returns is
/// ignored. Returns the function object after the last invocation.
template< class _InputIt, class _UnaryFunction>
constexpr _UnaryFunction for_each(_InputIt __first, _InputIt __last, _UnaryFunction __function) {
    // Keep going while the end has NOT been reached.
    for (; __first != __last; ++__first) __function(*__first);
    return __function; // Implicit move
}

/// Invokes `__function` on `__n` consecutive elements starting at
/// `__first`, counting `__n` down until it converts to false.
/// Returns the iterator positioned just after the last visited element.
template<class _InputIt, class _Size, class _UnaryFunction>
constexpr _InputIt for_each_n(_InputIt __first, _Size __n, _UnaryFunction __function) {
    while (__n) {
        __function(*__first);
        ++__first;
        --__n;
    }
    return __first;
}

// FIXME: `ptrdiff_t` should be `typename iterator_traits<_InputIt>::difference_type`
/// Returns how many elements of [__first, __last) compare equal to
/// `__value` (using operator==).
template<class _InputIt, class _ValueType>
constexpr ptrdiff_t count(_InputIt __first, _InputIt __last, const _ValueType& __value) {
    ptrdiff_t __matches = 0;
    while (__first != __last) {
        if (*__first == __value) ++__matches;
        ++__first;
    }
    return __matches;
}

// FIXME: `ptrdiff_t` should be `typename iterator_traits<_InputIt>::difference_type`
/// Returns how many elements of [__first, __last) satisfy `__predicate`.
template<class _InputIt, class _UnaryPredicate>
constexpr ptrdiff_t count_if(_InputIt __first, _InputIt __last, _UnaryPredicate __predicate) {
    ptrdiff_t __matches = 0;
    while (__first != __last) {
        if (__predicate(*__first)) ++__matches;
        ++__first;
    }
    return __matches;
}

/// Copy-assigns each element of [__first, __last), front to back, into
/// the range that begins at `__destination_first`.
/// Returns the destination iterator one past the last element written.
template<class _InputIt, class _OutputIt>
constexpr _OutputIt copy(_InputIt __first, _InputIt __last, _OutputIt __destination_first) {
    while (__first != __last) {
        *__destination_first = *__first;
        ++__first;
        ++__destination_first;
    }
    return __destination_first;
}

/// Copy-assigns the elements of [__first, __last) that satisfy
/// `__predicate` into the range beginning at `__destination_first`,
/// keeping their relative order.
/// Returns the destination iterator one past the last element written.
template<class _InputIt, class _OutputIt, class _UnaryPredicate>
constexpr _OutputIt copy_if(_InputIt __first, _InputIt __last, _OutputIt __destination_first, _UnaryPredicate __predicate) {
    for (; __first != __last; ++__first) {
        if (__predicate(*__first)) {
            *__destination_first = *__first;
            ++__destination_first;
        }
    }
    return __destination_first;
}

/// Copy-assigns exactly `__n` elements, starting at `__first`, into the
/// range beginning at `__destination_first`.
/// Does nothing when `__n` is zero or negative.
/// Returns the destination iterator one past the last element written.
template<class _InputIt, class _Size, class _OutputIt>
constexpr _OutputIt copy_n(_InputIt __first, _Size __n, _OutputIt __destination_first) {
    if (__n > 0) {
        *__destination_first = *__first;
        ++__destination_first;
        // The source is advanced only right before each further read, so
        // it is never stepped beyond the last element that gets copied.
        for (--__n; __n > 0; --__n) {
            *__destination_first = *++__first;
            ++__destination_first;
        }
    }
    return __destination_first;
}

/// Move-assigns each element of [__first, __last), front to back, into
/// the range that begins at `__destination_first`.
/// Returns the destination iterator one past the last element written.
template<class _InputIt, class _OutputIt>
constexpr _OutputIt move(_InputIt __first, _InputIt __last, _OutputIt __destination_first) {
    for (; __first != __last; ++__first, ++__destination_first) {
        // Qualified on purpose: this must be the single-argument cast
        // from <utility>, never this range overload and never a `move`
        // found through argument-dependent lookup on the element type.
        *__destination_first = std::move(*__first);
    }
    return __destination_first;
}

/// Applies `__unary_operation` to each element of [__first, __last),
/// front to back, and assigns each result to successive positions of the
/// range beginning at `__destination_first`.
/// Returns the destination iterator one past the last element written.
template<class _InputIt, class _OutputIt, class _UnaryOperation>
constexpr _OutputIt transform(_InputIt __first, _InputIt __last, _OutputIt __destination_first, _UnaryOperation __unary_operation ) {
    while (__first != __last) {
        *__destination_first = __unary_operation(*__first);
        ++__first;
        ++__destination_first;
    }
    return __destination_first;
}

/// Applies `__binary_op` to pairs of elements taken in lock-step from
/// [__first1, __last1) and from the range beginning at `__first2`, and
/// assigns each result to successive positions of the range beginning at
/// `__destination_first`.
/// The second input range is read for as many elements as the first
/// range holds.
/// Returns the destination iterator one past the last element written.
template<class _InputIt1, class _InputIt2, class _OutputIt, class _BinaryOperation>
constexpr _OutputIt transform
(_InputIt1 __first1, _InputIt1 __last1,
 _InputIt2 __first2,
 _OutputIt __destination_first,
 _BinaryOperation __binary_op
 )
{
    for (; __first1 != __last1; ++__first1, ++__first2, ++__destination_first)
        *__destination_first = __binary_op(*__first1, *__first2);
    return __destination_first;
}

/// Assigns the result of a fresh call to `__generator` to each element of
/// [__first, __last), front to back.
template<class _ForwardIt, class _Generator>
constexpr void generate(_ForwardIt __first, _ForwardIt __last, _Generator __generator) {
    while (__first != __last) {
        *__first = __generator();
        ++__first;
    }
}

/// Assigns the result of a fresh call to `__generator` to each of the
/// first `__n` positions of the range beginning at `__first`.
/// Does nothing when `__n` is zero or negative.
/// Returns the iterator one past the last element written.
template<class _OutputIt, class _Size, class _Generator>
constexpr _OutputIt generate_n(_OutputIt __first, _Size __n, _Generator __generator) {
    for (; __n > 0; --__n, ++__first)
        *__first = __generator(); // Write through the iterator, not to it.
    return __first;
}

/// Overwrites every element of [__first, __last) that compares equal to
/// `__old_value` (using operator==) with a copy of `__new_value`.
template<class _ForwardIt, class _T>
constexpr void replace(_ForwardIt __first, _ForwardIt __last, const _T& __old_value, const _T& __new_value) {
    while (__first != __last) {
        if (*__first == __old_value) {
            *__first = __new_value;
        }
        ++__first;
    }
}

/// Overwrites every element of [__first, __last) that satisfies
/// `__predicate` with a copy of `__new_value`.
template<class _ForwardIt, class _UnaryPredicate, class _T>
constexpr void replace_if(_ForwardIt __first, _ForwardIt __last, _UnaryPredicate __predicate, const _T& __new_value) {
    while (__first != __last) {
        if (__predicate(*__first)) {
            *__first = __new_value;
        }
        ++__first;
    }
}

/// Swaps the values that the two given iterators point at.
/// The iterators themselves are taken by value and are not modified.
template<class _ForwardIt1, class _ForwardIt2>
constexpr void iter_swap(_ForwardIt1 __a, _ForwardIt2 __b) {
    // `swap` is called unqualified, so the overload is picked by ordinary
    // lookup from this scope plus argument-dependent lookup on the
    // dereferenced types.
    swap(*__a, *__b);
}

/// Reverses the order of the elements in [__begin, __end) in place by
/// swapping elements pairwise from both ends towards the middle.
/// Only increment, decrement and (in)equality are required of the
/// iterator type.
template<class _BidiIt>
constexpr void reverse(_BidiIt __begin, _BidiIt __end) {
    // Two stop conditions are needed: the iterators become equal either
    // before `__end` is stepped back (even number of elements left) or
    // right after it (odd number left, both now on the middle element).
    // Checking only one of them lets the iterators cross and run out of
    // the range.
    while (__begin != __end && __begin != --__end) {
        std::iter_swap(__begin, __end);
        ++__begin;
    }
}

/// Exchanges each element of [__first1, __last1) with the element at the
/// same offset in the range beginning at `__first2`.
/// Returns the iterator one past the last element swapped in the second
/// range.
template<class _ForwardIt1, class _ForwardIt2>
constexpr _ForwardIt2 swap_ranges(_ForwardIt1 __first1, _ForwardIt1 __last1, _ForwardIt2 __first2) {
    while (__first1 != __last1) {
        iter_swap(__first1, __first2);
        ++__first1;
        ++__first2;
    }
    return __first2;
}

/// Exchanges the contents of two equally sized arrays, element by
/// element, by delegating to `swap_ranges`.
template<class _T, size_t _N>
constexpr void swap(_T (&__a)[_N], _T (&__b)[_N]) noexcept {
    _T* __lhs_first = __a;
    _T* __lhs_last = __a + _N;
    _T* __rhs_first = __b;
    swap_ranges(__lhs_first, __lhs_last, __rhs_first);
}

// FIXME: `ptrdiff_t` should be `typename std::iterator_traits<ForwardIt>::difference_type`
/// Shifts the elements of [__first, __last) by `__n` positions towards
/// the beginning of the range, using move assignment.
///
/// - `__n <= 0`: nothing is moved and `__last` is returned.
/// - `__n` at least the length of the range: nothing is moved and
///   `__first` is returned.
/// - otherwise: the element originally at offset `__n + i` ends up at
///   offset `i`, and the iterator one past the last shifted element is
///   returned. Elements from there up to `__last` are left in whatever
///   state the move assignment left them.
///
/// Only increment and (in)equality are required of the iterator type.
template<class _ForwardIt>
constexpr _ForwardIt shift_left(_ForwardIt __first, _ForwardIt __last, ptrdiff_t __n ) {
    if (__n <= 0) return __last;

    // Find the first element that survives the shift. Stepping manually
    // (instead of subtracting iterators) also detects a shift that is
    // longer than the range.
    auto __source = __first;
    for (; __n > 0; --__n) {
        if (__source == __last) return __first;
        ++__source;
    }

    // Stop when the *source* runs out; running until `__first` reaches
    // `__last` would read past the end of the range.
    for (; __source != __last; ++__first, ++__source)
        *__first = std::move(*__source);
    return __first;
}

/// Reorders [__first, __last) so that every element satisfying
/// `__predicate` comes before every element that does not.
/// Returns an iterator to the first element of the second group (equal to
/// `__last` when every element satisfies the predicate).
/// Elements are rearranged by swapping, so their relative order is not
/// guaranteed to be kept.
///
/// Marked `constexpr` like the other algorithms in this header, so that
/// callers such as `sort` are not prevented from being constant-evaluated
/// by this function.
template<class _ForwardIt, class _UnaryPredicate>
constexpr _ForwardIt partition(_ForwardIt __first, _ForwardIt __last, _UnaryPredicate __predicate) {
    // Skip the leading run that is already in the first group; `__first`
    // then marks the first element that belongs in the second group.
    __first = find_if_not(__first, __last, __predicate);
    if (__first == __last) return __first;

    auto __it = __first;
    ++__it;
    for (; __it != __last; ++__it) {
        if (__predicate(*__it)) {
            // Pull the satisfying element back to the boundary and grow
            // the first group by one.
            iter_swap(__it, __first);
            ++__first;
        }
    }
    return __first;
}

/// Sorts [first, last) into ascending order according to `operator<`.
/// Recursive quicksort: the last element of the range is used as the
/// pivot, `partition` gathers the smaller elements in front of it, and
/// both sides are then sorted in turn.
template<class _RandomIt>
constexpr void sort(_RandomIt first, _RandomIt last) {
    // Nothing to order in an empty range.
    if (first >= last) return;

    // The final element supplies the pivot value.
    auto pivot_slot = last - 1;

    // Gather everything smaller than the pivot at the front ...
    auto boundary = partition(first, last,
                              [pivot_slot](const auto& candidate) {
                                  return candidate < *pivot_slot;
                              });

    // ... and drop the pivot in between the two groups.
    iter_swap(boundary, pivot_slot);

    sort(first, boundary);
    sort(boundary + 1, last);
}

} /* namespace std */

#endif // _LENSOR_OS_ALGORITHM
