/* Copyright 2022, Contributors To LensorOS.
 * All rights reserved.
 *
 * This file is part of LensorOS.
 *
 * LensorOS is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * LensorOS is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with LensorOS. If not, see <https://www.gnu.org/licenses/>.
 */

#ifndef _LENSOROS_FUNCTIONAL_H
#define _LENSOROS_FUNCTIONAL_H

#include <stddef.h>
#include <stdint.h>

#include <bit>

namespace std {

namespace __detail {

constexpr inline uint32_t __hash_32bit(uint32_t __x) {
    __x = ((__x >> 16) ^ __x) * 0x45d9f3b;
    __x = ((__x >> 16) ^ __x) * 0x45d9f3b;
    __x = (__x >> 16) ^ __x;
    return __x;
}
constexpr inline uint64_t __hash_64bit(uint64_t __x) {
    __x = (__x ^ (__x >> 30)) * UINT64_C(0xbf58476d1ce4e5b9);
    __x = (__x ^ (__x >> 27)) * UINT64_C(0x94d049bb133111eb);
    __x = __x ^ (__x >> 31);
    return __x;
}
inline size_t __hash_ptr(void* __ptr) {
    static_assert(sizeof(void*) == 8, "Sorry, std::hash requires an 8-byte address size");
    return __hash_64bit(uint64_t(__ptr));
}

} // namespace __detail

// Primary template. It is declared without a definition in this header, so
// only the specialisations that follow can actually be instantiated.
template <class _Key>
struct hash;

// Hash for the type of `nullptr`. The key is converted to a null `void*` and
// then hashed through the same address helper the pointer specialisation uses.
template <>
struct hash<std::nullptr_t> {
    using result_type = size_t;
    using argument_type = std::nullptr_t;

    result_type operator() (argument_type __key) const noexcept {
        void* const __null_address = static_cast<void*>(__key);
        return std::__detail::__hash_ptr(__null_address);
    }
};

// Hash for pointers to objects. The hash depends only on the address stored
// in the pointer; the pointee is never read.
template <class _T>
struct hash<_T*> {
    using argument_type = _T*;
    using result_type = size_t;
    result_type operator() (argument_type __key) const noexcept {
        // _T may be const and/or volatile qualified. Convert to
        // `const volatile void*` (reachable from every object pointer) and
        // then drop the qualifiers, so the call is well-formed for
        // hash<const X*>, hash<volatile X*>, ... as well as hash<X*>.
        return std::__detail::__hash_ptr(
            const_cast<void*>(static_cast<const volatile void*>(__key)));
    }
};

#define _IntegerHashImpl(__t) template<> struct hash<__t> {             \
    using argument_type = __t;                                          \
    using result_type = size_t;                                         \
    result_type operator() (argument_type __key) const noexcept {       \
        if constexpr (sizeof(argument_type) <= 4) return std::__detail::__hash_32bit(__key); \
        if constexpr (sizeof(argument_type) <= 8) return std::__detail::__hash_64bit(__key); \
        static_assert(sizeof(argument_type) <= 8, "Unsupported "#__t" size, sorry"); \
    }                                                                   \
    }

_IntegerHashImpl(int);
_IntegerHashImpl(unsigned int);
_IntegerHashImpl(short);
_IntegerHashImpl(unsigned short);
_IntegerHashImpl(long);
_IntegerHashImpl(unsigned long);
_IntegerHashImpl(long long);
_IntegerHashImpl(unsigned long long);
_IntegerHashImpl(bool);
_IntegerHashImpl(char);
_IntegerHashImpl(signed char);
_IntegerHashImpl(unsigned char);
_IntegerHashImpl(char8_t);
_IntegerHashImpl(char16_t);
_IntegerHashImpl(char32_t);
_IntegerHashImpl(wchar_t);

// Hash for `float`: the value's 32-bit object representation is run through
// the 32-bit mixer.
template<> struct hash<float> {
    using argument_type = float;
    using result_type = size_t;
    result_type operator() (argument_type __key) const noexcept {
        static_assert(sizeof(argument_type) == 4, "Sorry, float must be 4 bytes for current implementation of std::hash");
        // +0.0f and -0.0f compare equal but do not share a bit pattern, and
        // keys that compare equal must produce the same hash. Collapse both
        // zeros onto +0.0f before looking at the bits.
        if (__key == 0.0f) __key = 0.0f;
        return std::__detail::__hash_32bit(std::bit_cast<uint32_t>(__key));
    }
};
// Hash for `double`: the value's 64-bit object representation is run through
// the 64-bit mixer.
template<> struct hash<double> {
    using argument_type = double;
    using result_type = size_t;
    result_type operator() (argument_type __key) const noexcept {
        static_assert(sizeof(argument_type) == 8, "Sorry, double must be 8 bytes for current implementation of std::hash");
        // +0.0 and -0.0 compare equal but do not share a bit pattern, and
        // keys that compare equal must produce the same hash. Collapse both
        // zeros onto +0.0 before looking at the bits.
        if (__key == 0.0) __key = 0.0;
        return std::__detail::__hash_64bit(std::bit_cast<uint64_t>(__key));
    }
};
// TODO: Provide a definition. For now this specialisation is only declared,
// so std::hash<long double> is an incomplete type and cannot be instantiated.
template<> struct hash<long double>;


// Function objects wrapping the built-in operators.
//
// Every object comes in two flavours:
//   * `name<_T>`    - the call operator takes its operands as `const _T&`.
//   * `name<void>`  - what `name<>` selects through the default template
//                     argument. Its call operator deduces the operand types,
//                     forwards them to the operator unchanged, and returns
//                     whatever the operator expression yields. It also exposes
//                     the `is_transparent` member type.

// Helper: the `name<void>` specialisation for a binary operator.
#define _DefineTransparentBinaryFunctionObject(name, op) template <>    \
    struct name<void> {                                                 \
        using is_transparent = void;                                    \
        template <class _Lhs, class _Rhs>                               \
        constexpr auto operator() (_Lhs&& __lhs, _Rhs&& __rhs) const    \
            -> decltype(static_cast<_Lhs&&>(__lhs) op static_cast<_Rhs&&>(__rhs)) { \
            return static_cast<_Lhs&&>(__lhs) op static_cast<_Rhs&&>(__rhs); \
        }                                                               \
    }

// Helper: the `name<void>` specialisation for a unary prefix operator.
#define _DefineTransparentUnaryFunctionObject(name, op) template <>     \
    struct name<void> {                                                 \
        using is_transparent = void;                                    \
        template <class _Arg>                                           \
        constexpr auto operator() (_Arg&& __value) const                \
            -> decltype(op static_cast<_Arg&&>(__value)) {              \
            return op static_cast<_Arg&&>(__value);                     \
        }                                                               \
    }

// Binary operator whose result is converted to the operand type `_T`
// (arithmetic and bitwise operators).
#define _DefineBinaryOperatorFunctionObject(name, op) template <class _T = void> \
    struct name {                                                       \
        constexpr _T operator() (const _T& __lhs, const _T& __rhs) const { \
            return __lhs op __rhs;                                      \
        }                                                               \
    };                                                                  \
    _DefineTransparentBinaryFunctionObject(name, op)

// Binary operator whose result is a `bool` (comparisons and logical
// connectives).
#define _DefineBinaryComparisonFunctionObject(name, op) template <class _T = void> \
    struct name {                                                       \
        constexpr bool operator() (const _T& __lhs, const _T& __rhs) const { \
            return __lhs op __rhs;                                      \
        }                                                               \
    };                                                                  \
    _DefineTransparentBinaryFunctionObject(name, op)

// Unary prefix operator whose result is converted to the operand type `_T`.
#define _DefineUnaryOperatorFunctionObject(name, op) template <class _T = void> \
    struct name {                                                       \
        constexpr _T operator() (const _T& __value) const {             \
            return op __value;                                          \
        }                                                               \
    };                                                                  \
    _DefineTransparentUnaryFunctionObject(name, op)


// Arithmetic.
_DefineBinaryOperatorFunctionObject(plus, +);
_DefineBinaryOperatorFunctionObject(minus, -);
_DefineBinaryOperatorFunctionObject(multiplies, *);
_DefineBinaryOperatorFunctionObject(divides, /);
_DefineBinaryOperatorFunctionObject(modulus, %);
_DefineUnaryOperatorFunctionObject(negate, -);

// Comparisons.
// NOTE(review): ISO C++ spells the last two `less` and `less_equal`. The
// names below are left untouched so that existing users keep compiling.
_DefineBinaryComparisonFunctionObject(equal_to, ==);
_DefineBinaryComparisonFunctionObject(not_equal_to, !=);
_DefineBinaryComparisonFunctionObject(greater, >);
_DefineBinaryComparisonFunctionObject(lesser, <);
_DefineBinaryComparisonFunctionObject(greater_equal, >=);
_DefineBinaryComparisonFunctionObject(lesser_equal, <=);

// Logical operations. These yield a truth value, so they return `bool`
// rather than `_T`.
_DefineBinaryComparisonFunctionObject(logical_and, &&);
_DefineBinaryComparisonFunctionObject(logical_or, ||);
template <class _T = void>
struct logical_not {
    constexpr bool operator() (const _T& __value) const {
        return !__value;
    }
};
_DefineTransparentUnaryFunctionObject(logical_not, !);

// Bitwise operations.
_DefineBinaryOperatorFunctionObject(bit_and, &);
_DefineBinaryOperatorFunctionObject(bit_or, |);
_DefineBinaryOperatorFunctionObject(bit_xor, ^);
_DefineUnaryOperatorFunctionObject(bit_not, ~);


} // namespace std

#endif // _LENSOROS_FUNCTIONAL_H
