#include <string>
#include "core/compiler.h"
#include "gtest/gtest.h"
#include "tests/util/util.h"
#include "torch/csrc/jit/ir/irparser.h"

// Checks that aten::stack applied to two runtime tensor inputs yields matching
// results from RunGraph and RunGraphEngine for stack dims 3, -1 and -2.
TEST(Converters, ATenStackPureTensorConvertsCorrectly) {
  // Parses `graph`, feeds the same two random 4x4x4 CUDA tensors to RunGraph
  // and RunGraphEngine, and compares the first output of each within
  // THRESHOLD_E5.
  auto TestATenStackPureTensorConvertsCorrectly = [](const std::string& graph) {
    // Every variant fails at the same ASSERT line inside this lambda; attach
    // the IR text so a failure report identifies which graph was running.
    SCOPED_TRACE(graph);

    auto g = std::make_shared<torch::jit::Graph>();
    torch::jit::parseIR(graph, g.get());

    auto in1 = at::randint(1, 10, {4, 4, 4}, {at::kCUDA});
    auto in2 = at::randint(1, 10, {4, 4, 4}, {at::kCUDA});

    // No static parameters: both tensors are passed as runtime inputs.
    auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
    auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, in2});

    // The params are rebuilt before the engine run, mirroring the original
    // test; presumably so the second run does not reuse state from the first
    // -- TODO confirm against the helper's signature.
    params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
    auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, in2});

    // Fail cleanly instead of indexing an empty result container below.
    ASSERT_FALSE(jit_results.empty());
    ASSERT_FALSE(trt_results.empty());

    ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], THRESHOLD_E5));
  };

  // Stack dim = 3.
  const auto graph = R"IR(
      graph(%0 : Tensor,
            %1 : Tensor):
        %2 : Tensor[] = prim::ListConstruct(%0, %1)
        %3 : int = prim::Constant[value=3]()
        %4 : Tensor = aten::stack(%2, %3)
        return (%4))IR";
  // Stack dim = -1.
  const auto graph2 = R"IR(
      graph(%0 : Tensor,
            %1 : Tensor):
        %2 : Tensor[] = prim::ListConstruct(%0, %1)
        %3 : int = prim::Constant[value=-1]()
        %4 : Tensor = aten::stack(%2, %3)
        return (%4))IR";
  // Stack dim = -2.
  const auto graph3 = R"IR(
      graph(%0 : Tensor,
            %1 : Tensor):
        %2 : Tensor[] = prim::ListConstruct(%0, %1)
        %3 : int = prim::Constant[value=-2]()
        %4 : Tensor = aten::stack(%2, %3)
        return (%4))IR";

  TestATenStackPureTensorConvertsCorrectly(graph);
  TestATenStackPureTensorConvertsCorrectly(graph2);
  TestATenStackPureTensorConvertsCorrectly(graph3);
}

// Checks that aten::stack applied to two runtime tensor inputs yields matching
// results from RunGraph and RunGraphEngineDynamic for stack dims 1 and -1.
TEST(Converters, ATenStackPureTensorDynamicConvertsCorrectly) {
  // Parses `graph`, feeds the same two random 4x4x4 CUDA tensors to RunGraph
  // and RunGraphEngineDynamic, and compares the first output of each within
  // THRESHOLD_E5.
  auto TestATenStackPureTensorDynamicConvertsCorrectly = [](const std::string& graph) {
    // Both variants fail at the same ASSERT line inside this lambda; attach
    // the IR text so a failure report identifies which graph was running.
    SCOPED_TRACE(graph);

    auto g = std::make_shared<torch::jit::Graph>();
    torch::jit::parseIR(graph, g.get());

    auto in1 = at::randint(1, 10, {4, 4, 4}, {at::kCUDA});
    auto in2 = at::randint(1, 10, {4, 4, 4}, {at::kCUDA});

    // No static parameters: both tensors are passed as runtime inputs.
    auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
    auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, in2});

    // The params are rebuilt before the engine run, mirroring the original
    // test; presumably so the second run does not reuse state from the first
    // -- TODO confirm against the helper's signature.
    params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
    auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {in1, in2});

    // Fail cleanly instead of indexing an empty result container below.
    ASSERT_FALSE(jit_results.empty());
    ASSERT_FALSE(trt_results.empty());

    ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], THRESHOLD_E5));
  };

  // Stack dim = 1.
  const auto graph = R"IR(
      graph(%0 : Tensor,
            %1 : Tensor):
        %2 : Tensor[] = prim::ListConstruct(%0, %1)
        %3 : int = prim::Constant[value=1]()
        %4 : Tensor = aten::stack(%2, %3)
        return (%4))IR";
  // Stack dim = -1.
  const auto graph2 = R"IR(
      graph(%0 : Tensor,
            %1 : Tensor):
        %2 : Tensor[] = prim::ListConstruct(%0, %1)
        %3 : int = prim::Constant[value=-1]()
        %4 : Tensor = aten::stack(%2, %3)
        return (%4))IR";

  TestATenStackPureTensorDynamicConvertsCorrectly(graph);
  TestATenStackPureTensorDynamicConvertsCorrectly(graph2);
}

// Checks that aten::stack yields matching results from RunGraph and
// RunGraphEngine when one list element is a runtime input and the other is
// supplied through get_static_params, for stack dims 1, -1 and -3.
TEST(Converters, ATenStackDiffTensorConvertsCorrectly) {
  // Parses `graph`, passes `in1` as the only runtime input and `in2` as a
  // static parameter to both RunGraph and RunGraphEngine, and compares the
  // first output of each within THRESHOLD_E5.
  auto TestATenStackDiffTensorConvertsCorrectly = [](const std::string& graph) {
    // Every variant fails at the same ASSERT line inside this lambda; attach
    // the IR text so a failure report identifies which graph was running.
    SCOPED_TRACE(graph);

    auto g = std::make_shared<torch::jit::Graph>();
    torch::jit::parseIR(graph, g.get());

    auto in1 = at::randint(1, 10, {4, 4, 4}, {at::kCUDA});
    auto in2 = at::randint(1, 10, {4, 4, 4}, {at::kCUDA});

    // in2 is handed over as a static parameter rather than a runtime input.
    auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {in2});
    auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1});

    // The params are rebuilt before the engine run, mirroring the original
    // test; presumably so the second run does not reuse state from the first
    // -- TODO confirm against the helper's signature.
    params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {in2});
    auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1});

    // Fail cleanly instead of indexing an empty result container below.
    ASSERT_FALSE(jit_results.empty());
    ASSERT_FALSE(trt_results.empty());

    ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], THRESHOLD_E5));
  };

  // Stack dim = 1.
  const auto graph = R"IR(
      graph(%0 : Tensor,
            %1 : Float(4, 4, 4, strides=[16, 4, 1])):
        %2 : Tensor[] = prim::ListConstruct(%0, %1)
        %3 : int = prim::Constant[value=1]()
        %4 : Tensor = aten::stack(%2, %3)
        return (%4))IR";
  // Stack dim = -1.
  const auto graph2 = R"IR(
      graph(%0 : Tensor,
            %1 : Float(4, 4, 4, strides=[16, 4, 1])):
        %2 : Tensor[] = prim::ListConstruct(%0, %1)
        %3 : int = prim::Constant[value=-1]()
        %4 : Tensor = aten::stack(%2, %3)
        return (%4))IR";
  // Stack dim = -3.
  const auto graph3 = R"IR(
      graph(%0 : Tensor,
            %1 : Float(4, 4, 4, strides=[16, 4, 1])):
        %2 : Tensor[] = prim::ListConstruct(%0, %1)
        %3 : int = prim::Constant[value=-3]()
        %4 : Tensor = aten::stack(%2, %3)
        return (%4))IR";

  TestATenStackDiffTensorConvertsCorrectly(graph);
  TestATenStackDiffTensorConvertsCorrectly(graph2);
  TestATenStackDiffTensorConvertsCorrectly(graph3);
}
