// Part of the Carbon Language project, under the Apache License v2.0 with LLVM
// Exceptions. See /LICENSE for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "toolchain/sem_ir/builtin_function_kind.h"

#include <utility>

#include "toolchain/sem_ir/file.h"
#include "toolchain/sem_ir/ids.h"
#include "toolchain/sem_ir/typed_insts.h"

namespace Carbon::SemIR {

// A function that validates that a builtin was declared properly.
//
// Given the parameter types (`arg_types`) and return type (`return_type`) of
// a builtin function's declaration in `sem_ir`, returns whether they match the
// signature that the builtin expects.
using ValidateFn = auto(const File& sem_ir, llvm::ArrayRef<TypeId> arg_types,
                        TypeId return_type) -> bool;

namespace {
// Information about a builtin function.
struct BuiltinInfo {
  // The builtin's name, such as "int.sadd". This is the string that
  // `BuiltinFunctionKind::ForBuiltinName` matches against. Empty for `None`.
  llvm::StringLiteral name;
  // Checks a declaration's signature against the signature this builtin
  // expects. Null for `None`, which is not a real builtin.
  ValidateFn* validate;
};

// The maximum number of type parameters any builtin needs.
constexpr int MaxTypeParams = 2;

// State used when validating a builtin signature that persists between
// individual checks.
struct ValidateState {
  // The initializer of `type_params` spells out one `TypeId::Invalid` per
  // parameter, so it must be kept in sync with `MaxTypeParams` by hand. This
  // assertion makes a change to `MaxTypeParams` fail loudly here instead of
  // leaving extra parameters with an unintended initial value.
  static_assert(MaxTypeParams == 2,
                "Update the `type_params` initializer to match MaxTypeParams");

  // The type values of type parameters in the builtin signature. Invalid if
  // either no value has been deduced yet or the parameter is not used.
  TypeId type_params[MaxTypeParams] = {TypeId::Invalid, TypeId::Invalid};
};

// Constraint that a type is generic type parameter `I` of the builtin,
// satisfying `TypeConstraint`. See ValidateSignature for details.
template <int I, typename TypeConstraint>
struct TypeParam {
  static_assert(I >= 0 && I < MaxTypeParams);

  static auto Check(const File& sem_ir, ValidateState& state, TypeId type_id)
      -> bool {
    if (state.type_params[I].is_valid() && type_id != state.type_params[I]) {
      return false;
    }
    if (!TypeConstraint::Check(sem_ir, state, type_id)) {
      return false;
    }
    state.type_params[I] = type_id;
    return true;
  }
};

// Constraint satisfied only by the builtin type whose instruction is
// `BuiltinId`. Does not use or update the validation state. See
// ValidateSignature for details.
template <const InstId& BuiltinId>
struct BuiltinType {
  static auto Check(const File& sem_ir, ValidateState& /*state*/,
                    TypeId type_id) -> bool {
    auto type_inst_id = sem_ir.types().GetInstId(type_id);
    return type_inst_id == BuiltinId;
  }
};

// Constraint that the function has no return value: the return type must be
// a tuple type with no elements, that is, `()`.
struct NoReturn {
  static auto Check(const File& sem_ir, ValidateState& /*state*/,
                    TypeId type_id) -> bool {
    if (auto tuple = sem_ir.types().TryGetAs<SemIR::TupleType>(type_id)) {
      return sem_ir.type_blocks().Get(tuple->elements_id).empty();
    }
    // Any non-tuple type is a real return value.
    return false;
  }
};

// Constraint that a type is `bool`, that is, the builtin type whose
// instruction is `BoolType::SingletonInstId`.
using Bool = BuiltinType<BoolType::SingletonInstId>;

// Constraint that requires the type to be an integer type.
struct AnyInt {
  static auto Check(const File& sem_ir, ValidateState& state, TypeId type_id)
      -> bool {
    if (BuiltinType<IntLiteralType::SingletonInstId>::Check(sem_ir, state,
                                                            type_id)) {
      return true;
    }
    return sem_ir.types().Is<IntType>(type_id);
  }
};

// Constraint that requires the type to be a float type.
struct AnyFloat {
  static auto Check(const File& sem_ir, ValidateState& state, TypeId type_id)
      -> bool {
    if (BuiltinType<LegacyFloatType::SingletonInstId>::Check(sem_ir, state,
                                                             type_id)) {
      return true;
    }
    return sem_ir.types().Is<FloatType>(type_id);
  }
};

// Checks that the specified type matches the given type constraint.
template <typename TypeConstraint>
auto Check(const File& sem_ir, ValidateState& state, TypeId type_id) -> bool {
  while (type_id.is_valid()) {
    // Allow a type that satisfies the constraint.
    if (TypeConstraint::Check(sem_ir, state, type_id)) {
      return true;
    }

    // Also allow a class type that adapts a matching type.
    auto class_type = sem_ir.types().TryGetAs<ClassType>(type_id);
    if (!class_type) {
      break;
    }
    type_id = sem_ir.classes()
                  .Get(class_type->class_id)
                  .GetAdaptedType(sem_ir, class_type->specific_id);
  }
  return false;
}

// Constraint that requires the type to be the type type (the builtin whose
// instruction is `TypeType::SingletonInstId`). It is the result type of the
// `*.make_type*` builtins below.
using Type = BuiltinType<TypeType::SingletonInstId>;

}  // namespace

// Validates that this builtin has a signature matching the specified signature.
//
// `SignatureFnType` is a C++ function type that describes the signature that is
// expected for this builtin. For example, `auto (AnyInt, AnyInt) -> AnyInt`
// specifies that the builtin takes values of two integer types and returns a
// value of a third integer type. Types used within the signature should provide
// a `Check` function that validates that the Carbon type is expected:
//
//   auto Check(const File&, ValidateState&, TypeId) -> bool;
//
// To constrain that the same type is used in multiple places in the signature,
// `TypeParam<I, T>` can be used. For example:
//
//   auto (TypeParam<0, AnyInt>, AnyInt) -> TypeParam<0, AnyInt>
//
// describes a builtin that takes two integers, and whose return type matches
// its first parameter type. For convenience, typedefs for `TypeParam<I, T>`
// are used in the descriptions of the builtins.
template <typename SignatureFnType>
static auto ValidateSignature(const File& sem_ir,
                              llvm::ArrayRef<TypeId> arg_types,
                              TypeId return_type) -> bool {
  using SignatureTraits = llvm::function_traits<SignatureFnType*>;
  ValidateState state;

  // Must have expected number of arguments.
  if (arg_types.size() != SignatureTraits::num_args) {
    return false;
  }

  // Argument types must match.
  if (![&]<size_t... Indexes>(std::index_sequence<Indexes...>) {
        return ((Check<typename SignatureTraits::template arg_t<Indexes>>(
                    sem_ir, state, arg_types[Indexes])) &&
                ...);
      }(std::make_index_sequence<SignatureTraits::num_args>())) {
    return false;
  }

  // Result type must match.
  if (!Check<typename SignatureTraits::result_t>(sem_ir, state, return_type)) {
    return false;
  }

  return true;
}

// Descriptions of builtin functions follow. For each builtin, a corresponding
// `BuiltinInfo` constant is declared describing properties of that builtin.
namespace BuiltinFunctionInfo {

// Convenience name used in the builtin type signatures below for a first
// generic type parameter that is constrained to be an integer type.
using IntT = TypeParam<0, AnyInt>;

// Convenience name used in the builtin type signatures below for a second
// generic type parameter that is constrained to be an integer type.
using IntU = TypeParam<1, AnyInt>;

// Convenience name used in the builtin type signatures below for a first
// generic type parameter that is constrained to be a float type.
using FloatT = TypeParam<0, AnyFloat>;

// Not a builtin function: it has no name and no validator.
constexpr BuiltinInfo None = {.name = "", .validate = nullptr};

// Prints an integer argument and returns nothing.
constexpr BuiltinInfo PrintInt = {
    .name = "print.int",
    .validate = ValidateSignature<auto(AnyInt)->NoReturn>};

// Produces the `Core.IntLiteral` type.
constexpr BuiltinInfo IntLiteralMakeType = {
    .name = "int_literal.make_type",
    .validate = ValidateSignature<auto()->Type>};

// Produces the signed `iN` type for an integer bit width `N`.
// TODO: Should we use a more specific type as the type of the bit width?
constexpr BuiltinInfo IntMakeTypeSigned = {
    .name = "int.make_type_signed",
    .validate = ValidateSignature<auto(AnyInt)->Type>};

// Produces the unsigned `uN` type for an integer bit width `N`.
constexpr BuiltinInfo IntMakeTypeUnsigned = {
    .name = "int.make_type_unsigned",
    .validate = ValidateSignature<auto(AnyInt)->Type>};

// Produces a float type, such as `f64`, from an integer argument. Currently
// only supports `f64`.
constexpr BuiltinInfo FloatMakeType = {
    .name = "float.make_type",
    .validate = ValidateSignature<auto(AnyInt)->Type>};

// Produces the `bool` type.
constexpr BuiltinInfo BoolMakeType = {
    .name = "bool.make_type",
    .validate = ValidateSignature<auto()->Type>};

// Converts a value from one integer type to another, diagnosing values that
// don't fit in the destination type.
constexpr BuiltinInfo IntConvertChecked = {
    .name = "int.convert_checked",
    .validate = ValidateSignature<auto(AnyInt)->AnyInt>};

// "int.snegate": negates an integer.
constexpr BuiltinInfo IntSNegate = {
    .name = "int.snegate",
    .validate = ValidateSignature<auto(IntT)->IntT>};

// "int.sadd": adds two integers of the same type.
constexpr BuiltinInfo IntSAdd = {
    .name = "int.sadd",
    .validate = ValidateSignature<auto(IntT, IntT)->IntT>};

// "int.ssub": subtracts two integers of the same type.
constexpr BuiltinInfo IntSSub = {
    .name = "int.ssub",
    .validate = ValidateSignature<auto(IntT, IntT)->IntT>};

// "int.smul": multiplies two integers of the same type.
constexpr BuiltinInfo IntSMul = {
    .name = "int.smul",
    .validate = ValidateSignature<auto(IntT, IntT)->IntT>};

// "int.sdiv": divides two integers of the same type.
constexpr BuiltinInfo IntSDiv = {
    .name = "int.sdiv",
    .validate = ValidateSignature<auto(IntT, IntT)->IntT>};

// "int.smod": modulo of two integers of the same type.
constexpr BuiltinInfo IntSMod = {
    .name = "int.smod",
    .validate = ValidateSignature<auto(IntT, IntT)->IntT>};

// "int.unegate": unsigned integer negation.
constexpr BuiltinInfo IntUNegate = {
    .name = "int.unegate",
    .validate = ValidateSignature<auto(IntT)->IntT>};

// "int.uadd": unsigned integer addition.
constexpr BuiltinInfo IntUAdd = {
    .name = "int.uadd",
    .validate = ValidateSignature<auto(IntT, IntT)->IntT>};

// "int.usub": unsigned integer subtraction.
constexpr BuiltinInfo IntUSub = {
    .name = "int.usub",
    .validate = ValidateSignature<auto(IntT, IntT)->IntT>};

// "int.umul": unsigned integer multiplication.
constexpr BuiltinInfo IntUMul = {
    .name = "int.umul",
    .validate = ValidateSignature<auto(IntT, IntT)->IntT>};

// "int.udiv": unsigned integer division.
constexpr BuiltinInfo IntUDiv = {
    .name = "int.udiv",
    .validate = ValidateSignature<auto(IntT, IntT)->IntT>};

// "int.umod": unsigned integer modulo.
constexpr BuiltinInfo IntUMod = {
    .name = "int.umod",
    .validate = ValidateSignature<auto(IntT, IntT)->IntT>};

// "int.complement": integer bitwise complement.
constexpr BuiltinInfo IntComplement = {
    .name = "int.complement",
    .validate = ValidateSignature<auto(IntT)->IntT>};

// "int.and": integer bitwise and.
constexpr BuiltinInfo IntAnd = {
    .name = "int.and",
    .validate = ValidateSignature<auto(IntT, IntT)->IntT>};

// "int.or": integer bitwise or.
constexpr BuiltinInfo IntOr = {
    .name = "int.or",
    .validate = ValidateSignature<auto(IntT, IntT)->IntT>};

// "int.xor": integer bitwise xor.
constexpr BuiltinInfo IntXor = {
    .name = "int.xor",
    .validate = ValidateSignature<auto(IntT, IntT)->IntT>};

// "int.left_shift": integer left shift. The shift amount may have a different
// integer type from the value being shifted. The result has the type of the
// value being shifted.
constexpr BuiltinInfo IntLeftShift = {
    .name = "int.left_shift",
    .validate = ValidateSignature<auto(IntT, IntU)->IntT>};

// "int.right_shift": integer right shift. As with "int.left_shift", the shift
// amount may have a different integer type from the value being shifted.
constexpr BuiltinInfo IntRightShift = {
    .name = "int.right_shift",
    .validate = ValidateSignature<auto(IntT, IntU)->IntT>};

// "int.eq": compares two integers of the same type for equality.
constexpr BuiltinInfo IntEq = {
    .name = "int.eq",
    .validate = ValidateSignature<auto(IntT, IntT)->Bool>};

// "int.neq": compares two integers of the same type for inequality.
constexpr BuiltinInfo IntNeq = {
    .name = "int.neq",
    .validate = ValidateSignature<auto(IntT, IntT)->Bool>};

// "int.less": integer less-than comparison, producing a `bool`.
constexpr BuiltinInfo IntLess = {
    .name = "int.less",
    .validate = ValidateSignature<auto(IntT, IntT)->Bool>};

// "int.less_eq": integer less-than-or-equal comparison, producing a `bool`.
constexpr BuiltinInfo IntLessEq = {
    .name = "int.less_eq",
    .validate = ValidateSignature<auto(IntT, IntT)->Bool>};

// "int.greater": integer greater-than comparison, producing a `bool`.
constexpr BuiltinInfo IntGreater = {
    .name = "int.greater",
    .validate = ValidateSignature<auto(IntT, IntT)->Bool>};

// "int.greater_eq": integer greater-than-or-equal comparison, producing a
// `bool`.
constexpr BuiltinInfo IntGreaterEq = {
    .name = "int.greater_eq",
    .validate = ValidateSignature<auto(IntT, IntT)->Bool>};

// "float.negate": negates a float.
constexpr BuiltinInfo FloatNegate = {
    .name = "float.negate",
    .validate = ValidateSignature<auto(FloatT)->FloatT>};

// "float.add": adds two floats of the same type.
constexpr BuiltinInfo FloatAdd = {
    .name = "float.add",
    .validate = ValidateSignature<auto(FloatT, FloatT)->FloatT>};

// "float.sub": subtracts two floats of the same type.
constexpr BuiltinInfo FloatSub = {
    .name = "float.sub",
    .validate = ValidateSignature<auto(FloatT, FloatT)->FloatT>};

// "float.mul": multiplies two floats of the same type.
constexpr BuiltinInfo FloatMul = {
    .name = "float.mul",
    .validate = ValidateSignature<auto(FloatT, FloatT)->FloatT>};

// "float.div": divides two floats of the same type.
constexpr BuiltinInfo FloatDiv = {
    .name = "float.div",
    .validate = ValidateSignature<auto(FloatT, FloatT)->FloatT>};

// "float.eq": compares two floats of the same type for equality.
constexpr BuiltinInfo FloatEq = {
    .name = "float.eq",
    .validate = ValidateSignature<auto(FloatT, FloatT)->Bool>};

// "float.neq": compares two floats of the same type for inequality.
constexpr BuiltinInfo FloatNeq = {
    .name = "float.neq",
    .validate = ValidateSignature<auto(FloatT, FloatT)->Bool>};

// "float.less": float less-than comparison, producing a `bool`.
constexpr BuiltinInfo FloatLess = {
    .name = "float.less",
    .validate = ValidateSignature<auto(FloatT, FloatT)->Bool>};

// "float.less_eq": float less-than-or-equal comparison, producing a `bool`.
constexpr BuiltinInfo FloatLessEq = {
    .name = "float.less_eq",
    .validate = ValidateSignature<auto(FloatT, FloatT)->Bool>};

// "float.greater": float greater-than comparison, producing a `bool`.
constexpr BuiltinInfo FloatGreater = {
    .name = "float.greater",
    .validate = ValidateSignature<auto(FloatT, FloatT)->Bool>};

// "float.greater_eq": float greater-than-or-equal comparison, producing a
// `bool`.
constexpr BuiltinInfo FloatGreaterEq = {
    .name = "float.greater_eq",
    .validate = ValidateSignature<auto(FloatT, FloatT)->Bool>};

}  // namespace BuiltinFunctionInfo

// Names for each `BuiltinFunctionKind`, generated from the `.def` file in its
// order. Each entry is the corresponding `BuiltinFunctionInfo` constant's
// builtin name, such as "int.sadd".
CARBON_DEFINE_ENUM_CLASS_NAMES(BuiltinFunctionKind) = {
#define CARBON_SEM_IR_BUILTIN_FUNCTION_KIND(Name) \
  BuiltinFunctionInfo::Name.name,
#include "toolchain/sem_ir/builtin_function_kind.def"
};

// Returns the builtin function kind with the given name, or None if the name
// is unknown.
//
// Each kind listed in the `.def` file contributes one exact name comparison,
// tried in `.def` order. The first match wins.
auto BuiltinFunctionKind::ForBuiltinName(llvm::StringRef name)
    -> BuiltinFunctionKind {
#define CARBON_SEM_IR_BUILTIN_FUNCTION_KIND(Name) \
  if (name == BuiltinFunctionInfo::Name.name) {   \
    return BuiltinFunctionKind::Name;             \
  }
#include "toolchain/sem_ir/builtin_function_kind.def"
  // No builtin's name matched.
  return BuiltinFunctionKind::None;
}

// Returns whether a declaration with parameter types `arg_types` and return
// type `return_type` matches the signature this builtin kind expects.
auto BuiltinFunctionKind::IsValidType(const File& sem_ir,
                                      llvm::ArrayRef<TypeId> arg_types,
                                      TypeId return_type) const -> bool {
  // One validator per kind, generated in `.def` order so it can be indexed by
  // `AsInt()`.
  static constexpr ValidateFn* ValidateFns[] = {
#define CARBON_SEM_IR_BUILTIN_FUNCTION_KIND(Name) \
  BuiltinFunctionInfo::Name.validate,
#include "toolchain/sem_ir/builtin_function_kind.def"
  };
  ValidateFn* validate = ValidateFns[AsInt()];
  // `BuiltinFunctionInfo::None` has a null validator. A kind without one has
  // no valid signature, so don't call through a null function pointer.
  if (!validate) {
    return false;
  }
  return validate(sem_ir, arg_types, return_type);
}

auto BuiltinFunctionKind::IsCompTimeOnly() const -> bool {
  // The checked integer conversion is the only builtin kind reported as
  // compile-time only.
  const bool is_checked_conversion =
      *this == BuiltinFunctionKind::IntConvertChecked;
  return is_checked_conversion;
}

}  // namespace Carbon::SemIR
