#pragma once

#include <ATen/ATen.h>
#include <string>
#include <vector>
#include "ATen/Tensor.h"
#include "core/ir/ir.h"
#include "core/util/prelude.h"
#include "torch/csrc/jit/ir/irparser.h"

// Default tolerances shared by the tensor-comparison helpers declared in this header.
// `constexpr` (instead of plain `const float`) makes these usable in constant
// expressions such as `static_assert`. Like a namespace-scope `const`, a namespace-scope
// `constexpr` variable has internal linkage, so every including translation unit gets
// its own copy without ODR violations.
// NOTE(review): these sit in the global namespace with generic names; moving them into
// torch_tensorrt::tests::util would be cleaner but could break unqualified uses elsewhere.

// Default `atol` argument of almostEqual (presumably an absolute tolerance — confirm in its definition).
constexpr float ATOL = 5e-3f;
// Default `rtol` argument of almostEqual (presumably a relative tolerance — confirm in its definition).
constexpr float RTOL = 5e-3f;
// Default `threshold` argument of cosineSimEqual.
constexpr float COSINE_THRESHOLD = 0.99f;
// Not referenced in this header; presumably a tighter (1e-5) tolerance for tests that
// need one — TODO confirm against callers.
constexpr float THRESHOLD_E5 = 1e-5f;

namespace torch_tensorrt {
namespace tests {
namespace util {

// ===========================================================================
// Tensor comparison predicates
// ===========================================================================

/// Compares `computed_tensor` against the reference `gt_tensor` using a
/// cosine-similarity criterion controlled by `threshold` (defaults to the
/// global COSINE_THRESHOLD).
/// NOTE(review): the exact pass/fail rule is in the definition — confirm there.
bool cosineSimEqual(
    const at::Tensor& computed_tensor,
    const at::Tensor& gt_tensor,
    float threshold = COSINE_THRESHOLD);

/// Compares `computed_tensor` against the reference `gt_tensor` within the
/// tolerances `atol` / `rtol` (default to the global ATOL / RTOL).
/// NOTE(review): presumably absolute / relative tolerances — confirm in the definition.
bool almostEqual(
    const at::Tensor& computed_tensor,
    const at::Tensor& gt_tensor,
    float atol = ATOL,
    float rtol = RTOL);

/// Reports whether `computed_tensor` and `gt_tensor` have matching shapes.
bool sameShape(const at::Tensor& computed_tensor, const at::Tensor& gt_tensor);

/// Reports whether `a` and `b` are equal (presumably with no tolerance — confirm).
bool exactlyEqual(const at::Tensor& a, const at::Tensor& b);

// ===========================================================================
// Test drivers
// ===========================================================================

/// Drives a pointwise-op test for the TorchScript IR text `graph_ir`.
/// `singleInput` selects one vs. two inputs; `shape1` / `shape2` and
/// `type1` / `type2` describe those inputs; `dynamicInput` and
/// `negative_input` toggle further input variants.
/// NOTE(review): parameter meanings are inferred from their names — confirm
/// against the definition.
void pointwise_test_helper(
    std::string graph_ir,
    bool singleInput,
    bool dynamicInput = false,
    std::vector<int64_t> shape1 = {5},
    std::vector<int64_t> shape2 = {5},
    bool negative_input = false,
    at::ScalarType type1 = at::kFloat,
    at::ScalarType type2 = at::kFloat);

// ===========================================================================
// Graph / engine / module runners
// ===========================================================================

/// Runs the engine held in `eng` on `inputs` and returns the output tensors.
/// NOTE(review): `eng` is presumably a serialized TensorRT engine — confirm.
std::vector<at::Tensor> RunEngine(std::string& eng, std::vector<at::Tensor> inputs);

/// Executes the JIT graph `g` (with `named_params`) on `inputs` and returns
/// the results.
std::vector<at::Tensor> RunGraph(std::shared_ptr<torch::jit::Graph>& g, core::ir::StaticParams& named_params, std::vector<at::Tensor> inputs);

/// Converts the JIT graph `g` to TensorRT, runs inference on `inputs`, and
/// returns the results. `dtype` defaults to nvinfer1::DataType::kFLOAT.
std::vector<at::Tensor> RunGraphEngine(std::shared_ptr<torch::jit::Graph>& g, core::ir::StaticParams& named_params, std::vector<at::Tensor> inputs, nvinfer1::DataType dtype = nvinfer1::DataType::kFLOAT);

/// Like RunGraphEngine, but builds the TensorRT engine for dynamic input
/// sizes. `dynamic_batch` and `allow_shape_tensors` are opt-in flags (both
/// default to false).
std::vector<at::Tensor> RunGraphEngineDynamic(std::shared_ptr<torch::jit::Graph>& g, core::ir::StaticParams& named_params, std::vector<at::Tensor> inputs, bool dynamic_batch = false, bool allow_shape_tensors = false);

/// Invokes the `forward` method of `mod` with `inputs` and returns its result.
torch::jit::IValue RunModuleForward(torch::jit::Module& mod, std::vector<torch::jit::IValue> inputs);

/// Converts the `forward` method of `mod` into a TensorRT engine, runs it on
/// `inputs`, and returns the results.
std::vector<at::Tensor> RunModuleForwardAsEngine(torch::jit::Module& mod, std::vector<at::Tensor> inputs);

// ===========================================================================
// Evaluator runners
// ===========================================================================

/// Evaluates the block `b` on `inputs` with the compiler's evaluator library.
std::vector<torch::jit::IValue> EvaluateGraph(const torch::jit::Block* b, std::vector<torch::jit::IValue> inputs);

/// Evaluates the graph `g` on `inputs` with the TorchScript JIT interpreter
/// (reference counterpart of EvaluateGraph).
std::vector<torch::jit::IValue> EvaluateGraphJIT(std::shared_ptr<torch::jit::Graph>& g, std::vector<torch::jit::IValue> inputs);

} // namespace util
} // namespace tests
} // namespace torch_tensorrt
