#include "core/conversion/evaluators/eval_util.h"
#include <ATen/ATen.h>
#include "ATen/InitialTensorOptions.h"
#include "ATen/core/List.h"
#include "ATen/core/functional.h"
#include "ATen/core/ivalue.h"
#include "ATen/core/jit_type.h"
#include "c10/util/irange.h"
#include "core/util/prelude.h"
#include "torch/torch.h"

namespace torch_tensorrt {
namespace core {
namespace conversion {
namespace evaluators {

/// Adds a gather layer that selects position `index` of `input_tensor` along axis 0.
///
/// @param ctx Conversion context that owns the network being built.
/// @param n Node being evaluated; only used to build the error message.
/// @param input_tensor Tensor to gather from.
/// @param index Position to select along axis 0.
/// @return Output tensor of the newly created gather layer.
/// @throws Raises through TORCHTRT_CHECK when the gather layer cannot be created.
nvinfer1::ITensor* index_layer(
    ConversionCtx* ctx,
    const torch::jit::Node* n,
    nvinfer1::ITensor* input_tensor,
    int64_t index) {
  // The gather layer consumes its indices as a tensor, so wrap the scalar index in a
  // one-element int32 tensor and freeze it into the network as a constant.
  at::Tensor index_tensor = torch::tensor({index}).to(torch::kI32);
  auto index_const = converters::tensor_to_const(ctx, index_tensor);

  auto gather = ctx->net->addGather(*input_tensor, *index_const, 0);
  TORCHTRT_CHECK(gather, "Unable to create gather layer from node: " << *n);
  return gather->getOutput(0);
}

/// Evaluates aten::size for an input whose shape is only known at runtime.
///
/// A shape tensor is built for input 0 of the node. Two forms are handled:
///  - The node has more than one input: input 1 is the dimension to query. The matching
///    entry of the shape tensor is gathered and returned wrapped in a TensorContainer.
///  - The node has a single input: a generic list with one entry per dimension is
///    returned. Dimensions reported as -1 are represented by a TensorContainer holding
///    the gathered shape entry; all other dimensions are stored as plain integers.
///
/// @param ctx Conversion context that owns the network being built.
/// @param n The aten::size node being evaluated.
/// @param args Evaluated arguments keyed by the node's input values.
/// @return IValue holding either a single TensorContainer or a generic list (see above).
/// @throws Raises through TORCHTRT_CHECK when the requested dimension is out of range or a
///         gather layer cannot be created.
c10::IValue dynamic_size_layer(ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) {
  LOG_DEBUG("Using dynamic version of aten::size evaluator");
  auto in = args.at(n->input(0)).ITensorOrFreeze(ctx);
  auto input_dims = in->getDimensions();
  LOG_DEBUG("Input dimensions: " << input_dims);
  nvinfer1::ITensor* shape_1d_tensor = torch_tensorrt::core::conversion::converters::getShapeOutput(
      ctx, in, std::string(util::node_info(n) + "_dynamic_shape_layer_cast").c_str());
  if (n->inputs().size() != 1) {
    auto maxDim = static_cast<int64_t>(input_dims.nbDims);
    const auto requested_dim = args.at(n->input(1)).unwrapToInt();
    // Handle negative axis by referring to nbDims of input Tensor
    const auto dim = requested_dim < 0 ? requested_dim + maxDim : requested_dim;
    // Reject dimensions outside the input's rank up front; otherwise the bad value would
    // be baked into the network as a gather index.
    TORCHTRT_CHECK(
        dim >= 0 && dim < maxDim,
        "Dimension out of range (expected to be in range of [" << -maxDim << ", " << maxDim - 1 << "], but got "
                                                                << requested_dim << ") for node: " << *n);
    LOG_DEBUG("Dimension to select: " << dim);
    shape_1d_tensor = index_layer(ctx, n, shape_1d_tensor, dim);
    LOG_DEBUG("Output tensor shape: " << shape_1d_tensor->getDimensions());

    auto tensor_holder = TensorContainer();
    tensor_holder.hold_tensor(shape_1d_tensor);
    // make_intrusive already yields a temporary, so no std::move is needed.
    return c10::IValue(c10::make_intrusive<TensorContainer>(tensor_holder));
  }

  auto input_size = c10::impl::GenericList(c10::AnyType::get());
  // Only express the dynamic dimension with a shape layer output.
  // The static dimensions are preserved in the input size.
  for (int32_t i = 0; i < input_dims.nbDims; i++) {
    if (input_dims.d[i] == -1) {
      auto dynamic_dim_tensor = index_layer(ctx, n, shape_1d_tensor, i);
      auto dynamic_dim_holder = TensorContainer();
      dynamic_dim_holder.hold_tensor(dynamic_dim_tensor);
      input_size.emplace_back(c10::IValue(c10::make_intrusive<TensorContainer>(dynamic_dim_holder)));
    } else {
      input_size.emplace_back(input_dims.d[i]);
    }
  }
  return c10::IValue(input_size);
}

/// Maps a possibly negative index onto a list of `list_size` elements.
///
/// Negative indices count from the end of the list (`-1` becomes `list_size - 1`);
/// non-negative indices are returned unchanged. No bounds checking is performed, so
/// the result may still lie outside `[0, list_size)`.
///
/// @param idx Index to normalize.
/// @param list_size Number of elements in the list.
/// @return `idx + list_size` when `idx` is negative, otherwise `idx`.
int64_t normalizeIndex(int64_t idx, int64_t list_size) {
  return idx < 0 ? list_size + idx : idx;
}

// TODO: Switch back to PyTorch canonical implementation
/// Extracts the compile-time constant carried by a graph value, if any.
///
/// @param v Graph value to inspect.
/// @return `c10::nullopt` when `v` is not produced by a `prim::Constant` node or has a
///         function type; otherwise the constant converted to an IValue based on the
///         value's type.
/// @throws std::runtime_error when the value's type is not one of the handled kinds.
c10::optional<torch::jit::IValue> toIValue(const torch::jit::Value* v) {
  const torch::jit::Node* producer = v->node();
  const c10::TypePtr& type = v->type();
  if (producer->kind() != torch::jit::prim::Constant || type->cast<c10::FunctionType>()) {
    return c10::nullopt;
  }

  const c10::Symbol value_attr = c10::Symbol::fromDomainAndUnqualString(c10::attr::value.domainString(), "value");

  // Evaluated lazily: the attribute kind is only queried by the branches that need it.
  const auto attr_has_kind = [&](torch::jit::AttributeKind kind) { return producer->kindOf(value_attr) == kind; };

  if (type->isSubtypeOf(c10::TensorType::get())) {
    return producer->t(value_attr);
  }
  if (type->isSubtypeOf(c10::BoolType::get())) {
    return static_cast<bool>(producer->i(value_attr));
  }
  if (type->isSubtypeOf(c10::NumberType::get())) {
    if (attr_has_kind(torch::jit::AttributeKind::i)) {
      return producer->i(value_attr);
    }
    if (attr_has_kind(torch::jit::AttributeKind::f)) {
      return producer->f(value_attr);
    }
    // Any other attribute kind falls through to the remaining checks.
  }

  // For the typed list cases below, the typed attribute accessor is tried first; if it
  // throws, the attribute is read as a generic IValue instead.
  if (type->isSubtypeOf(c10::ListType::ofInts())) {
    try {
      const auto& int_values = producer->is(value_attr);
      return int_values;
    } catch (const std::exception&) {
      const auto& generic_value = producer->ival(value_attr);
      return generic_value;
    }
  }
  if (type->isSubtypeOf(c10::ListType::ofFloats())) {
    try {
      const auto& float_values = producer->fs(value_attr);
      return float_values;
    } catch (const std::exception&) {
      const auto& generic_value = producer->ival(value_attr);
      return generic_value;
    }
  }
  if (type->isSubtypeOf(c10::ListType::ofBools())) {
    // Bool lists are read from the integer-list attribute and converted element-wise.
    const auto bool_values = c10::fmap<bool>(producer->is(value_attr));
    return bool_values;
  }
  if (type->isSubtypeOf(c10::ListType::ofTensors())) {
    try {
      const auto& tensor_values = producer->ts(value_attr);
      return tensor_values;
    } catch (const std::exception&) {
      const auto& generic_value = producer->ival(value_attr);
      return generic_value;
    }
  }
  if (type->isSubtypeOf(c10::ListType::ofStrings())) {
    try {
      const auto& string_values = producer->ss(value_attr);
      auto string_list = c10::impl::GenericList(c10::StringType::get());
      for (const auto& entry : string_values) {
        string_list.push_back(entry);
      }
      return string_list;
    } catch (const std::exception&) {
      const auto& generic_value = producer->ival(value_attr);
      return generic_value;
    }
  }

  // Remaining container types are only supported when stored as a generic IValue attribute.
  if (type->cast<c10::ListType>() && attr_has_kind(torch::jit::AttributeKind::ival)) {
    const auto& list = producer->ival(value_attr);
    TORCHTRT_ASSERT(list.isList(), "Is not a list");
    return list;
  }
  if (type->cast<c10::DictType>() && attr_has_kind(torch::jit::AttributeKind::ival)) {
    const auto& dict = producer->ival(value_attr);
    TORCHTRT_ASSERT(dict.isGenericDict(), "Is not a dict");
    return dict;
  }
  if (type->cast<c10::TupleType>() && attr_has_kind(torch::jit::AttributeKind::ival)) {
    const auto& tup = producer->ival(value_attr);
    TORCHTRT_ASSERT(tup.isTuple(), "Is not a tuple");
    return tup;
  }

  if (type == c10::StringType::get()) {
    const auto& text = producer->s(value_attr);
    return text;
  }
  if (type == c10::DeviceObjType::get()) {
    // Devices are stored as their string form.
    auto device = c10::Device(producer->s(value_attr));
    return device;
  }
  if (producer->mustBeNone()) {
    return torch::jit::IValue();
  }

  std::stringstream message;
  message << "constant literal not supported for: " << type->str();
  throw std::runtime_error(message.str());
}

/// Verifies that a list's innermost element type can be turned into a tensor.
///
/// Accepts subtypes of the number type and the bool type. Any other element type raises
/// an error; for a tensor element type on an empty list, the message also explains that
/// empty lists default to List[Tensor].
///
/// @param elem_type Innermost element type of the list.
/// @param empty_list Whether the list being converted is empty.
void checkListInputType(const c10::TypePtr& elem_type, bool empty_list) {
  const bool is_supported = elem_type->isSubtypeOf(c10::NumberType::get()) || elem_type == c10::BoolType::get();
  if (is_supported) {
    return;
  }

  std::stringstream error;
  error << "Input must be of ints, floats, or bools, "
        << "got " << elem_type->repr_str();
  // special case empty list torch.tensor([])
  if (empty_list && elem_type->isSubtypeOf(c10::TensorType::get())) {
    error << "\nEmpty lists default to List[Tensor]. Add a variable "
             "annotation to the assignment to create an empty list "
             "of another type (torch.jit.annotate(List[T, []]) where T "
             "is the type of elements in the list for Python 2)";
  }
  TORCHTRT_THROW_ERROR(error.str());
}

/// Raises an error unless a nested sequence has the expected length.
///
/// @param n Expected number of elements.
/// @param dim Dimension being checked; only used in the error message.
/// @param seq_size Actual number of elements found.
void checkSequenceSize(int64_t n, int64_t dim, int64_t seq_size) {
  if (seq_size == n) {
    return;
  }
  TORCHTRT_THROW_ERROR("Expected sequence of length " << n << " at dim " << dim << " (got " << seq_size << ")");
}

// TODO: Conditionally enable truncation based on user setting
/// Wraps a scalar in a tensor on the requested device.
///
/// This mirrors the helper in
/// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/ScalarOps.h, except that
/// floating point and integral scalars are stored as kFloat and kInt instead of being
/// upgraded to kDouble and kLong, since those two types are not supported in conversion.
/// Complex scalars become kComplexDouble and boolean scalars become kBool.
///
/// @param s Scalar to wrap.
/// @param device Device on which the tensor is created.
/// @return Tensor holding the scalar value.
at::Tensor scalar_to_tensor(const at::Scalar& s, const at::Device device) {
  // Pick the element type first; the scalar kinds are mutually exclusive.
  at::ScalarType dtype;
  if (s.isFloatingPoint()) {
    LOG_WARNING("Unable to process input type of at::kDouble, truncate type to at::kFloat in scalar_to_tensor_util ");
    dtype = at::kFloat;
  } else if (s.isComplex()) {
    dtype = at::kComplexDouble;
  } else if (s.isBoolean()) {
    dtype = at::kBool;
  } else {
    AT_ASSERT(s.isIntegral(false));
    LOG_WARNING("Unable to process input type of at::kLong, truncate type to at::kInt in scalar_to_tensor_util ");
    dtype = at::kInt;
  }

  // CPU tensors use the static scalar-tensor helper; every other device uses at::scalar_tensor.
  if (device == at::kCPU) {
    return at::detail::scalar_tensor_static(s, dtype, at::kCPU);
  }
  return at::scalar_tensor(s, at::device(device).dtype(dtype));
}

/// Copies the elements of an innermost list into tensor storage as values of type DTYPE.
///
/// @tparam DTYPE Type each element is converted to and stored as.
/// @param data Pointer to the first destination element.
/// @param sizes Sizes of the tensor being filled.
/// @param strides Strides of the tensor being filled, in elements.
/// @param dim Dimension that `obj` corresponds to.
/// @param elementSize Size in bytes of one tensor element.
/// @param obj List elements to store; its length must equal `sizes[dim]`.
template <typename DTYPE>
void storeLastDimension(
    char* data,
    const std::vector<int64_t>& sizes,
    const c10::ArrayRef<int64_t>& strides,
    int64_t dim,
    int elementSize,
    at::ArrayRef<torch::jit::IValue> obj) {
  const auto count = sizes[dim];
  checkSequenceSize(count, dim, obj.size());

  // Byte distance between consecutive elements along this dimension.
  const auto step_bytes = strides[dim] * elementSize;
  char* cursor = data;
  for (int64_t idx = 0; idx < count; ++idx) {
    *reinterpret_cast<DTYPE*>(cursor) = obj[idx].to<DTYPE>();
    cursor += step_bytes;
  }
}

/// Copies the elements of an innermost list of doubles into tensor storage as floats.
///
/// @param data Pointer to the first destination element.
/// @param sizes Sizes of the tensor being filled.
/// @param strides Strides of the tensor being filled, in elements.
/// @param dim Dimension that `obj` corresponds to.
/// @param elementSize Size in bytes of one tensor element.
/// @param obj List elements to store; its length must equal `sizes[dim]`.
void storeLastDimensionFloat(
    char* data,
    const std::vector<int64_t>& sizes,
    const c10::ArrayRef<int64_t>& strides,
    int64_t dim,
    int elementSize,
    at::ArrayRef<torch::jit::IValue> obj) {
  const auto count = sizes[dim];
  checkSequenceSize(count, dim, obj.size());

  // Byte distance between consecutive elements along this dimension.
  const auto step_bytes = strides[dim] * elementSize;
  char* cursor = data;
  for (int64_t idx = 0; idx < count; ++idx) {
    // Each element is read as a double and narrowed to float.
    *reinterpret_cast<float*>(cursor) = static_cast<float>(obj[idx].to<double>());
    cursor += step_bytes;
  }
}

/// Copies the elements of an innermost list of doubles into tensor storage as at::Half.
///
/// @param data Pointer to the first destination element.
/// @param sizes Sizes of the tensor being filled.
/// @param strides Strides of the tensor being filled, in elements.
/// @param dim Dimension that `obj` corresponds to.
/// @param elementSize Size in bytes of one tensor element.
/// @param obj List elements to store; its length must equal `sizes[dim]`.
void storeLastDimensionHalf(
    char* data,
    const std::vector<int64_t>& sizes,
    const c10::ArrayRef<int64_t>& strides,
    int64_t dim,
    int elementSize,
    at::ArrayRef<torch::jit::IValue> obj) {
  const auto count = sizes[dim];
  checkSequenceSize(count, dim, obj.size());

  // Byte distance between consecutive elements along this dimension.
  const auto step_bytes = strides[dim] * elementSize;
  char* cursor = data;
  for (int64_t idx = 0; idx < count; ++idx) {
    // Each element is read as a double and converted to half precision.
    *reinterpret_cast<at::Half*>(cursor) = at::convert<at::Half, double>(obj[idx].to<double>());
    cursor += step_bytes;
  }
}

/// Recursively copies a (possibly nested) list into pre-allocated tensor storage.
///
/// Outer dimensions recurse into each sub-list; the innermost dimension is written with
/// the store helper matching the list's element kind. For lists of doubles the helper is
/// chosen by the tensor's element size (double, float or half).
///
/// @param data Pointer to the first destination element for this sub-list.
/// @param sizes Sizes of the tensor being filled.
/// @param strides Strides of the tensor being filled, in elements.
/// @param dim Dimension that `obj` corresponds to.
/// @param tenElementSize Size in bytes of one tensor element.
/// @param obj List to copy; its length must equal `sizes[dim]`.
void recursiveStore(
    char* data,
    const std::vector<int64_t>& sizes,
    const c10::ArrayRef<int64_t>& strides,
    int64_t dim,
    int tenElementSize,
    const torch::jit::IValue& obj) {
  const auto rank = sizes.size();
  const auto extent = sizes[dim];
  auto elements = obj.toListRef();
  checkSequenceSize(extent, dim, elements.size());

  // Not yet at the innermost dimension: descend into every sub-list.
  if (dim + 1 < static_cast<long>(rank)) {
    char* cursor = data;
    for (int64_t idx = 0; idx < extent; ++idx) {
      recursiveStore(cursor, sizes, strides, dim + 1, tenElementSize, elements[idx]);
      cursor += strides[dim] * tenElementSize;
    }
    return;
  }

  if (obj.isIntList()) {
    storeLastDimension<int64_t>(data, sizes, strides, dim, tenElementSize, elements);
    return;
  }
  if (obj.isBoolList()) {
    storeLastDimension<bool>(data, sizes, strides, dim, tenElementSize, elements);
    return;
  }
  if (!obj.isDoubleList()) {
    TORCHTRT_THROW_ERROR("Found unsupported data type in arguments for aten::tensor");
  }

  // Lists of doubles may target double, float or half storage; tell them apart by element width.
  const auto width_of = [](at::ScalarType scalar_type) { return static_cast<int>(c10::elementSize(scalar_type)); };
  if (tenElementSize == width_of(at::ScalarType::Double)) {
    storeLastDimension<double>(data, sizes, strides, dim, tenElementSize, elements);
  } else if (tenElementSize == width_of(at::ScalarType::Float)) {
    storeLastDimensionFloat(data, sizes, strides, dim, tenElementSize, elements);
  } else if (tenElementSize == width_of(at::ScalarType::Half)) {
    storeLastDimensionHalf(data, sizes, strides, dim, tenElementSize, elements);
  } else {
    TORCHTRT_THROW_ERROR("Found unsupported data type in arguments for aten::tensor");
  }
}

/// Converts a tensor to the requested dtype and device, when either differs.
///
/// @param self Tensor to convert.
/// @param dtype Target scalar type, or None to keep the tensor's current type.
/// @param device Target device, or None to keep the tensor's current device.
/// @return `self` unchanged when it already matches; otherwise the converted tensor.
at::Tensor castTensorTo(at::Tensor self, const torch::jit::IValue& dtype, const torch::jit::IValue& device) {
  const at::ScalarType target_type = dtype.isNone() ? self.scalar_type() : dtype.toScalarType();
  const c10::Device target_device = device.isNone() ? self.device() : device.toDevice();

  const bool already_matches = target_type == self.scalar_type() && target_device == self.device();
  if (already_matches) {
    return self;
  }
  return self.to(target_device, target_type);
}

/// Derives a tensor shape from a (possibly nested) list.
///
/// The length of each nesting level is recorded by following the first element of every
/// list until an empty list or a non-list element is reached. Only first elements are
/// inspected here, so ragged inputs are not detected by this function.
///
/// @param seq List value to measure.
/// @return One size per nesting level, outermost first.
std::vector<int64_t> compute_sizes(const torch::jit::IValue& seq) {
  std::vector<int64_t> shape;
  auto level = seq.toList();
  shape.push_back(level.size());
  // Keep descending while the current level has a first element that is itself a list.
  while (level.size() != 0 && level.get(0).isList()) {
    level = level.get(0).toList();
    shape.push_back(level.size());
  }
  return shape;
}

/// Builds a tensor from a (possibly nested) list of ints, floats or bools.
///
/// @param data List value holding the tensor contents.
/// @param dtype Requested scalar type, or None to keep the type derived from the list.
/// @param device Requested device, or None to keep the device of the staging tensor.
/// @return Tensor holding the list contents, cast to `dtype` / `device` when given.
at::Tensor createTensorFromList(
    const torch::jit::IValue& data,
    const torch::jit::IValue& dtype,
    const torch::jit::IValue& device) {
  /// Unwrap nested list types until the innermost element type is reached
  auto elem_type = data.type();
  while (auto nested_list_type = elem_type->cast<c10::ListType>()) {
    elem_type = nested_list_type->getElementType();
  }

  /// Shape of the tensor to be created
  const auto shape = compute_sizes(data);
  const bool is_empty_flat_list = shape.size() == 1 && shape[0] == 0;
  checkListInputType(elem_type, is_empty_flat_list);

  // Lists of doubles are staged in the current default dtype instead of kDouble.
  const at::ScalarType list_scalar_type = c10::scalarTypeFromJitType(*elem_type);
  const at::ScalarType staging_type = list_scalar_type == at::ScalarType::Double
      ? at::typeMetaToScalarType(c10::get_default_dtype())
      : list_scalar_type;

  auto staging = at::empty(shape, at::initialTensorOptions().dtype(staging_type));
  if (staging.numel() != 0) {
    recursiveStore(static_cast<char*>(staging.data_ptr()), shape, staging.strides(), 0, staging.element_size(), data);
  }

  at::Tensor result = castTensorTo(staging, dtype, device);

  // Warn when an empty list without an explicit dtype yields a non-default element type.
  const auto default_type = at::typeMetaToScalarType(at::get_default_dtype());
  const bool warn_on_empty = dtype.isNone() && result.scalar_type() != default_type && result.numel() == 0;
  if (warn_on_empty) {
    LOG_WARNING(
        "Creating a tensor from an empty "
        << elem_type->repr_str() << "list will create a tensor of default floating point type  (currently "
        << default_type << ") in python but a tensor of type " << elem_type->repr_str() << " in torchscript.\n"
        << "Pass in a dtype argument to ensure consistent behavior");
  }

  return result;
}

/// Collects the size and tensor options for a node that creates a new tensor.
///
/// Options start as strided layout on CUDA. The dtype comes from input 2 of the node when
/// it is set; otherwise it is taken from the tensor passed as input 0. The size list is
/// read from input 1.
///
/// @param n Node being evaluated.
/// @param args Evaluated arguments keyed by the node's input values.
/// @return Pair of (requested sizes, tensor options).
std::pair<std::vector<int64_t>, torch::TensorOptions> newTensorImplementation(const torch::jit::Node* n, kwargs& args) {
  auto options = torch::TensorOptions().layout(torch::kStrided).device(torch::kCUDA);

  // Input 2 is the dtype
  auto& dtype_arg = args.at(n->input(2));
  const bool has_explicit_dtype = !dtype_arg.isNone() && !dtype_arg.IValue()->isNone();
  if (has_explicit_dtype) {
    options = options.dtype(c10::ScalarType(dtype_arg.unwrapToInt()));
  } else {
    // No dtype given: inherit it from input 0, which may be a TensorRT tensor or a torch tensor.
    auto self_var = args.at(n->input(0));
    if (self_var.isITensor()) {
      auto trt_type = self_var.ITensor()->getType();
      options = options.dtype(scalarTypeToTypeMeta(util::TRTDataTypeToScalarType(trt_type)));
    } else {
      options = options.dtype(self_var.unwrapToTensor().dtype());
    }
  }

  return std::make_pair(args.at(n->input(1)).unwrapToIntList().vec(), options);
}

/// Evaluates a node that creates a tensor shaped like its first input.
///
/// Options start as strided layout on CUDA with the dtype of input 0; input 1 of the node
/// overrides the dtype when it is set.
///
/// When shape tensors are allowed and the input is dynamic, the result is built inside the
/// network: `tensor_builder` creates a constant with every dimension equal to 1, and a slice
/// layer whose size comes from the input's shape tensor expands it. The slice output is
/// associated with the node's first output and an empty optional is returned. Otherwise
/// `tensor_builder` is called with the input's dimensions and its result is returned.
///
/// @param ctx Conversion context that owns the network being built.
/// @param n Node being evaluated.
/// @param args Evaluated arguments keyed by the node's input values.
/// @param tensor_builder Callback that creates a tensor from sizes and options.
/// @return The built tensor, or an empty optional when the output was added to the network.
c10::optional<torch::jit::IValue> newTensorLikeImplementation(
    ConversionCtx* ctx,
    const torch::jit::Node* n,
    kwargs& args,
    const std::function<torch::Tensor(const std::vector<int64_t>&, const torch::TensorOptions&)>& tensor_builder) {
  auto options = torch::TensorOptions().layout(torch::kStrided).device(torch::kCUDA);
  auto self_var = args.at(n->input(0));
  const bool self_is_itensor = self_var.isITensor();

  // Start from the dtype of the reference tensor.
  if (self_is_itensor) {
    options = options.dtype(util::TRTDataTypeToScalarType(self_var.ITensor()->getType()));
  } else {
    options = options.dtype(self_var.unwrapToTensor().dtype());
  }

  // Input 1 is the dtype
  auto& dtype_arg = args.at(n->input(1));
  if (!dtype_arg.isNone() && !dtype_arg.IValue()->isNone()) {
    options = options.dtype(c10::ScalarType(dtype_arg.unwrapToInt()));
  }

  const std::vector<int64_t> tensor_dims =
      self_is_itensor ? util::toVec(self_var.ITensor()->getDimensions()) : self_var.unwrapToTensor().sizes().vec();

  const bool build_in_network = ctx->settings.allow_shape_tensors && ctx->input_is_dynamic;
  if (!build_in_network) {
    return tensor_builder(tensor_dims, options);
  }

  auto self = args.at(n->input(0)).ITensorOrFreeze(ctx);
  const auto self_dims = self->getDimensions();

  // Constant with the same rank as the input and every dimension equal to 1.
  std::vector<int64_t> unit_shape(self_dims.nbDims, 1);
  auto unit_constant = tensor_builder(unit_shape, options);
  auto unit_itensor = converters::tensor_to_const(ctx, unit_constant);

  // broadcast constant to output shape
  std::vector<int64_t> zeros(self_dims.nbDims, 0);
  auto zero_dims = util::toDims(c10::IntArrayRef(zeros));
  nvinfer1::ITensor* shape_output = torch_tensorrt::core::conversion::converters::getShapeOutput(
      ctx, self, std::string(util::node_info(n) + "_shape").c_str());

  // slice implements expand: start and stride are all zeros, and the output size is
  // supplied by the shape tensor through slice input 2.
  auto slice_layer = ctx->net->addSlice(*unit_itensor, zero_dims, self_dims, zero_dims);
  TORCHTRT_CHECK(slice_layer, "Unable to create slice layer from node: " << *n);
  slice_layer->setInput(2, *shape_output);
  slice_layer->setName((util::node_info(n) + "_slice").c_str());

  auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], slice_layer->getOutput(0));
  LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
  return c10::nullopt;
}

} // namespace evaluators
} // namespace conversion
} // namespace core
} // namespace torch_tensorrt
