
.. _program_listing_file_cpp_include_torch_tensorrt_torch_tensorrt.h:

Program Listing for File torch_tensorrt.h
=========================================

|exhale_lsh| :ref:`Return to documentation for file <file_cpp_include_torch_tensorrt_torch_tensorrt.h>` (``cpp/include/torch_tensorrt/torch_tensorrt.h``)

.. |exhale_lsh| unicode:: U+021B0 .. UPWARDS ARROW WITH TIP LEFTWARDS

.. code-block:: cpp

   /*
    * Copyright (c) NVIDIA Corporation.
    * All rights reserved.
    *
    * This library is licensed under the BSD-style license found in the
    * LICENSE file in the root directory of this source tree.
    */

   #pragma once

   #include <cuda_runtime.h>
   #include <iostream>
   #include <memory>
   #include <set>
   #include <string>
   #include <vector>
   #include "torch/custom_class.h"

   #include "torch_tensorrt/macros.h"

   // TODO: Consider including the defining headers instead of forward-declaring these types.
   #ifndef DOXYGEN_SHOULD_SKIP_THIS
   // Forward declarations for types that live in the torch, c10 and nvinfer1
   // namespaces. The whole region is compiled out when DOXYGEN_SHOULD_SKIP_THIS
   // is defined.
   namespace torch {
   namespace jit {
   struct Module;
   struct Graph;
   } // namespace jit
   } // namespace torch

   namespace c10 {
   template <class T>
   class ArrayRef;
   enum class ScalarType : int8_t;
   enum class DeviceType : int8_t;
   } // namespace c10

   namespace nvinfer1 {
   class IInt8Calibrator;
   } // namespace nvinfer1
   #endif // DOXYGEN_SHOULD_SKIP_THIS

   namespace torch_tensorrt {
   /**
    * Value-type wrapper around DataType::Value.
    *
    * Converts implicitly from and to the wrapped enumerator and can additionally
    * be constructed from a c10::ScalarType.
    */
   class DataType {
    public:
     /// Enumerators this wrapper can hold; stored as int8_t.
     enum Value : int8_t {
       kFloat,
       kHalf,
       kChar,
       kInt,
       kBool,
       kUnknown
     };

     /// Constructs a DataType holding kUnknown (via the initializer on `value`).
     DataType() = default;
     /// Wraps the enumerator `t`.
     constexpr DataType(Value t) : value(t) {}
     /// Constructs a DataType from a c10::ScalarType; defined out of line.
     TORCHTRT_API DataType(c10::ScalarType t);
     /// Returns the wrapped enumerator.
     operator Value() const {
       return value;
     }
     /// Deleted so that a DataType is not tested directly as a boolean.
     /// NOTE(review): not const-qualified, so it is only selected for non-const
     /// objects; confirm that is intended.
     explicit operator bool() = delete;
     /// True when both objects wrap the same enumerator.
     constexpr bool operator==(DataType other) const {
       return value == other.value;
     }
     /// True when the wrapped enumerator equals `other`.
     constexpr bool operator==(DataType::Value other) const {
       return value == other;
     }
     /// True when the two objects wrap different enumerators.
     constexpr bool operator!=(DataType other) const {
       return value != other.value;
     }
     /// True when the wrapped enumerator differs from `other`.
     constexpr bool operator!=(DataType::Value other) const {
       return value != other;
     }

    private:
     friend TORCHTRT_API std::ostream& operator<<(std::ostream& os, const DataType& dtype);
     /// Wrapped enumerator. Initialized to kUnknown so that a default-constructed
     /// DataType never holds an indeterminate value.
     Value value = kUnknown;
   };

   /**
    * Target-device settings: the device type, the GPU and DLA core indices, and
    * whether GPU fallback is allowed.
    */
   struct Device {
     /**
      * Value-type wrapper around Device::DeviceType::Value.
      *
      * Converts implicitly from and to the wrapped enumerator and can additionally
      * be constructed from a c10::DeviceType.
      */
     class DeviceType {
      public:
       /// Enumerators this wrapper can hold; stored as int8_t.
       enum Value : int8_t {
         kGPU,
         kDLA,
       };

       /// Constructs a DeviceType holding kGPU (via the initializer on `value`).
       DeviceType() = default;
       /// Wraps the enumerator `t`.
       constexpr DeviceType(Value t) : value(t) {}
       /// Constructs a DeviceType from a c10::DeviceType; defined out of line.
       /// NOTE(review): unlike DataType(c10::ScalarType) this constructor is not
       /// marked with the export macro; confirm that is intended.
       DeviceType(c10::DeviceType t);
       /// Returns the wrapped enumerator.
       operator Value() const {
         return value;
       }
       /// Deleted so that a DeviceType is not tested directly as a boolean.
       explicit operator bool() = delete;
       /// True when both objects wrap the same enumerator.
       constexpr bool operator==(DeviceType other) const {
         return value == other.value;
       }
       /// True when the wrapped enumerator equals `other`. Without this overload
       /// `device_type == DeviceType::kGPU` is ambiguous between the member
       /// operator above and the built-in enum comparison.
       constexpr bool operator==(DeviceType::Value other) const {
         return value == other;
       }
       /// True when the two objects wrap different enumerators.
       constexpr bool operator!=(DeviceType other) const {
         return value != other.value;
       }
       /// True when the wrapped enumerator differs from `other`.
       constexpr bool operator!=(DeviceType::Value other) const {
         return value != other;
       }

      private:
       /// Wrapped enumerator. Initialized to kGPU, the same default Device() uses,
       /// so a default-constructed DeviceType never holds an indeterminate value.
       Value value = kGPU;
     };

     /// Kind of device to target; Device() sets it to kGPU.
     DeviceType device_type;

     /*
      * Target gpu id
      */
     int64_t gpu_id;

     /*
      * When using DLA core on NVIDIA AGX platforms gpu_id should be set as Xavier device
      */
     int64_t dla_core;

     /// Whether GPU fallback is allowed; Device() sets it to false.
     bool allow_gpu_fallback;

     /// Default settings: kGPU, gpu_id 0, dla_core 0, no GPU fallback.
     Device() : device_type(DeviceType::kGPU), gpu_id(0), dla_core(0), allow_gpu_fallback(false) {}
   };

   /// Engine capability options; stored as int8_t. The explicit values match the
   /// implicit numbering (0, 1, 2).
   enum class EngineCapability : int8_t {
     kSTANDARD = 0,
     kSAFETY = 1,
     kDLA_STANDALONE = 2,
   };

   /**
    * Value-type wrapper around TensorFormat::Value.
    *
    * Converts implicitly from and to the wrapped enumerator and can additionally
    * be constructed from an at::MemoryFormat.
    */
   class TensorFormat {
    public:
     /// Enumerators this wrapper can hold; stored as int8_t.
     enum Value : int8_t {
       kContiguous,
       kChannelsLast,
       kUnknown,
     };

     /// Constructs a TensorFormat holding kUnknown (via the initializer on `value`).
     TensorFormat() = default;
     /// Wraps the enumerator `t`.
     constexpr TensorFormat(Value t) : value(t) {}
     /// Constructs a TensorFormat from an at::MemoryFormat; defined out of line.
     TORCHTRT_API TensorFormat(at::MemoryFormat t);
     /// Returns the wrapped enumerator.
     operator Value() const {
       return value;
     }
     /// Deleted so that a TensorFormat is not tested directly as a boolean.
     /// NOTE(review): not const-qualified, so it is only selected for non-const
     /// objects; confirm that is intended.
     explicit operator bool() = delete;
     /// True when both objects wrap the same enumerator.
     constexpr bool operator==(TensorFormat other) const {
       return value == other.value;
     }
     /// True when the wrapped enumerator equals `other`.
     constexpr bool operator==(TensorFormat::Value other) const {
       return value == other;
     }
     /// True when the two objects wrap different enumerators.
     constexpr bool operator!=(TensorFormat other) const {
       return value != other.value;
     }
     /// True when the wrapped enumerator differs from `other`.
     constexpr bool operator!=(TensorFormat::Value other) const {
       return value != other;
     }

    private:
     friend TORCHTRT_API std::ostream& operator<<(std::ostream& os, const TensorFormat& format);
     /// Wrapped enumerator. Initialized to kUnknown so that a default-constructed
     /// TensorFormat never holds an indeterminate value.
     Value value = kUnknown;
   };

   /**
    * Specification of one input: its shape (either a single shape or a
    * min / opt / max triple), its data type and its tensor format.
    *
    * Every constructor except the default one is defined out of line.
    */
   struct Input : torch::CustomClassHolder {
     /// Shape passed as `min_shape` to the min/opt/max constructors.
     std::vector<int64_t> min_shape;
     /// Shape passed as `opt_shape` to the min/opt/max constructors.
     std::vector<int64_t> opt_shape;
     /// Shape passed as `max_shape` to the min/opt/max constructors.
     std::vector<int64_t> max_shape;
     /// Shape passed as `shape` to the single-shape constructors.
     std::vector<int64_t> shape;
     /// Data type of the input. Starts as kUnknown so that Input() does not leave
     /// it indeterminate.
     DataType dtype = DataType::kUnknown;
     /// Tensor format of the input. Starts as kContiguous, the same default the
     /// `format` arguments below use.
     TensorFormat format = TensorFormat::kContiguous;

     /// Creates an Input with empty shapes and the member defaults declared in
     /// this struct.
     Input() {}

     /// Single-shape input; `format` defaults to kContiguous.
     TORCHTRT_API Input(std::vector<int64_t> shape, TensorFormat format = TensorFormat::kContiguous);

     /// Single-shape input with an explicit data type.
     TORCHTRT_API Input(std::vector<int64_t> shape, DataType dtype, TensorFormat format = TensorFormat::kContiguous);

     /// Single-shape input taken from a c10::ArrayRef.
     TORCHTRT_API Input(c10::ArrayRef<int64_t> shape, TensorFormat format = TensorFormat::kContiguous);

     /// Single-shape input taken from a c10::ArrayRef, with an explicit data type.
     TORCHTRT_API Input(c10::ArrayRef<int64_t> shape, DataType dtype, TensorFormat format = TensorFormat::kContiguous);

     /// Input described by a min / opt / max shape triple.
     TORCHTRT_API Input(
         std::vector<int64_t> min_shape,
         std::vector<int64_t> opt_shape,
         std::vector<int64_t> max_shape,
         TensorFormat format = TensorFormat::kContiguous);

     /// Input described by a min / opt / max shape triple, with an explicit data type.
     TORCHTRT_API Input(
         std::vector<int64_t> min_shape,
         std::vector<int64_t> opt_shape,
         std::vector<int64_t> max_shape,
         DataType dtype,
         TensorFormat format = TensorFormat::kContiguous);

     /// Min / opt / max shape triple taken from c10::ArrayRef values.
     TORCHTRT_API Input(
         c10::ArrayRef<int64_t> min_shape,
         c10::ArrayRef<int64_t> opt_shape,
         c10::ArrayRef<int64_t> max_shape,
         TensorFormat format = TensorFormat::kContiguous);

     /// Min / opt / max shape triple taken from c10::ArrayRef values, with an
     /// explicit data type.
     TORCHTRT_API Input(
         c10::ArrayRef<int64_t> min_shape,
         c10::ArrayRef<int64_t> opt_shape,
         c10::ArrayRef<int64_t> max_shape,
         DataType dtype,
         TensorFormat format = TensorFormat::kContiguous);

     /// Input built from a tensor. Presumably the specification is derived from
     /// the tensor's properties; confirm against the implementation.
     TORCHTRT_API Input(at::Tensor tensor);

    private:
     friend TORCHTRT_API std::ostream& operator<<(std::ostream& os, const Input& input);
     /// Starts as false so that Input() does not leave it indeterminate.
     /// Presumably set by the min / opt / max constructors; confirm against the
     /// implementation.
     bool input_is_dynamic = false;
   };

   /**
    * Holds an input specification in two forms: a nested signature stored as a
    * torch::jit::IValue and a flat list of Input entries.
    */
   struct GraphInputs {
     torch::jit::IValue input_signature; // nested Input, full input spec
     std::vector<Input> inputs; // flattened input spec
   };

   /// Returns a std::string; judging by the name it describes the build.
   /// Defined out of line; confirm the exact contents there.
   TORCHTRT_API std::string get_build_info();

   /// Returns nothing; presumably emits the same information as
   /// get_build_info(). Confirm the destination against the implementation.
   TORCHTRT_API void dump_build_info();

   /// Takes the id of a GPU; presumably selects it as the active device.
   /// Defined out of line; confirm against the implementation.
   TORCHTRT_API void set_device(const int gpu_id);

   namespace torchscript {
   /**
    * Settings object taken by the torchscript compile / convert entry points.
    *
    * All constructors are defined out of line; each receives a description of
    * the inputs. Every other setting starts from the in-class default below.
    */
   struct CompileSpec {
     /// From a list of fixed sizes, each given as std::vector<int64_t>.
     TORCHTRT_API CompileSpec(std::vector<std::vector<int64_t>> fixed_sizes);

     /// From a list of fixed sizes, each given as c10::ArrayRef<int64_t>.
     TORCHTRT_API CompileSpec(std::vector<c10::ArrayRef<int64_t>> fixed_sizes);

     /// From a list of Input specifications.
     TORCHTRT_API CompileSpec(std::vector<Input> inputs);

     /// From an input signature stored as a torch::jit::IValue.
     TORCHTRT_API CompileSpec(torch::jit::IValue input_signature);

     // Defaults should reflect TensorRT defaults for BuilderConfig

     /// Input specification held by this spec.
     GraphInputs graph_inputs;

     /// Set of enabled precisions; starts with only DataType::kFloat.
     std::set<DataType> enabled_precisions = {DataType::kFloat};

     /// Off by default.
     bool disable_tf32{false};

     /// Off by default.
     bool sparse_weights{false};

     /// Off by default.
     bool refit{false};

     /// Off by default.
     bool debug{false};

     /// Off by default.
     bool truncate_long_and_double{false};

     /// Target device settings; uses Device's own defaults.
     Device device;

     /// Defaults to EngineCapability::kSTANDARD.
     EngineCapability capability{EngineCapability::kSTANDARD};

     /// Defaults to 1.
     uint64_t num_avg_timing_iters{1};

     /// Defaults to 0.
     uint64_t workspace_size{0};

     /// Defaults to 1048576 (2^20).
     uint64_t dla_sram_size{uint64_t{1} << 20};

     /// Defaults to 1073741824 (2^30).
     uint64_t dla_local_dram_size{uint64_t{1} << 30};

     /// Defaults to 536870912 (2^29).
     uint64_t dla_global_dram_size{uint64_t{1} << 29};

     /// Raw pointer to a calibrator; null by default.
     nvinfer1::IInt8Calibrator* ptq_calibrator{nullptr};

     /// Off by default.
     bool require_full_compilation{false};

     /// Defaults to 3.
     uint64_t min_block_size{3};

     /// List of op names; empty by default.
     std::vector<std::string> torch_executed_ops;

     /// List of module names; empty by default.
     std::vector<std::string> torch_executed_modules;
   };

   /// Takes a module and a method name and returns a bool; judging by the name,
   /// it reports whether that method's operators are supported. Defined out of line.
   TORCHTRT_API bool check_method_operator_support(const torch::jit::Module& module, std::string method_name);

   /// Takes a module and a CompileSpec and returns a torch::jit::Module.
   /// Defined out of line.
   TORCHTRT_API torch::jit::Module compile(const torch::jit::Module& module, CompileSpec info);

   /// Takes a module, a method name and a CompileSpec and returns a std::string;
   /// judging by the name, the string holds a TensorRT engine for that method
   /// (confirm the format against the implementation).
   TORCHTRT_API std::string convert_method_to_trt_engine(
       const torch::jit::Module& module,
       std::string method_name,
       CompileSpec info);

   /// Takes an engine held in a std::string plus a Device and returns a
   /// torch::jit::Module. Both binding-name lists default to empty vectors.
   TORCHTRT_API torch::jit::Module embed_engine_in_new_module(
       const std::string& engine,
       Device device,
       const std::vector<std::string>& input_binding_names = std::vector<std::string>(),
       const std::vector<std::string>& output_binding_names = std::vector<std::string>());
   } // namespace torchscript
   } // namespace torch_tensorrt
