
.. _program_listing_file_cpp_include_torch_tensorrt_torch_tensorrt.h:

Program Listing for File torch_tensorrt.h
=========================================

|exhale_lsh| :ref:`Return to documentation for file <file_cpp_include_torch_tensorrt_torch_tensorrt.h>` (``cpp/include/torch_tensorrt/torch_tensorrt.h``)

.. |exhale_lsh| unicode:: U+021B0 .. UPWARDS ARROW WITH TIP LEFTWARDS

.. code-block:: cpp

   /*
    * Copyright (c) NVIDIA Corporation.
    * All rights reserved.
    *
    * This library is licensed under the BSD-style license found in the
    * LICENSE file in the root directory of this source tree.
    */
   
   #pragma once
   
   #include <cuda_runtime.h>
   #include <iostream>
   #include <memory>
   #include <set>
   #include <string>
   #include <vector>
   #include "torch/custom_class.h"
   
   #include "torch_tensorrt/macros.h"
   
   // Just include the .h?
   #ifndef DOXYGEN_SHOULD_SKIP_THIS
   // Forward declarations of external types that the declarations further down
   // refer to by name. The whole section is compiled out when
   // DOXYGEN_SHOULD_SKIP_THIS is defined.
   namespace torch {
   namespace jit {
   struct Graph;
   struct Module;
   } // namespace jit
   } // namespace torch
   
   namespace c10 {
   // Opaque enum declarations: the fixed underlying type (int8_t) is what allows
   // these scoped enums to be declared without their enumerator lists.
   enum class DeviceType : int8_t;
   enum class ScalarType : int8_t;
   // Class template taking one type parameter; declared only, not defined here.
   template <class>
   class ArrayRef;
   } // namespace c10
   
   namespace nvinfer1 {
   class IInt8Calibrator;
   } // namespace nvinfer1
   #endif // DOXYGEN_SHOULD_SKIP_THIS
   
   namespace torch_tensorrt {
   /**
    * Thin value wrapper around a small enumeration of data types.
    *
    * Converts implicitly from and to Value and can also be constructed from a
    * c10::ScalarType (that constructor is only declared here). Conversion to
    * bool is deleted, so an instance cannot be tested like a flag.
    */
   class DataType {
    public:
     /// Underlying enumeration; its storage type is int8_t.
     enum Value : int8_t {
       kLong,
       kDouble,
       kFloat,
       kHalf,
       kChar,
       kInt,
       kBool,
       kUnknown
     };
   
     /// Defaulted. The stored value has no in-class initializer, so a
     /// default-initialized DataType holds an indeterminate value.
     DataType() = default;
   
     /// Wraps the given enumerator.
     constexpr DataType(Value v) : value(v) {}
   
     /// Construction from a c10::ScalarType; declared here, defined elsewhere.
     TORCHTRT_API DataType(c10::ScalarType t);
   
     /// Yields the wrapped enumerator.
     operator Value() const {
       return value;
     }
   
     /// Deleted: a DataType has no truth value.
     explicit operator bool() = delete;
   
     /// Equality against another DataType compares the wrapped enumerators.
     constexpr bool operator==(DataType rhs) const {
       return value == rhs.value;
     }
   
     /// Equality against a raw enumerator.
     constexpr bool operator==(DataType::Value rhs) const {
       return value == rhs;
     }
   
     /// Negation of operator==(DataType).
     constexpr bool operator!=(DataType rhs) const {
       return !(value == rhs.value);
     }
   
     /// Negation of operator==(DataType::Value).
     constexpr bool operator!=(DataType::Value rhs) const {
       return !(value == rhs);
     }
   
    private:
     /// Stream insertion is a friend so it can read the wrapped enumerator.
     friend TORCHTRT_API std::ostream& operator<<(std::ostream& os, const DataType& dtype);
   
     /// The wrapped enumerator.
     Value value;
   };
   
   /**
    * Describes the device targeted by a compilation: the kind of device, the
    * GPU and DLA core indices, and whether GPU fallback is allowed.
    */
   struct Device {
     /**
      * Value wrapper over the device kinds kGPU and kDLA.
      *
      * Converts implicitly from and to Value and can be constructed from a
      * c10::DeviceType (that constructor is only declared here). Conversion to
      * bool is deleted.
      */
     class DeviceType {
      public:
       /// Underlying enumeration; its storage type is int8_t.
       enum Value : int8_t {
         kGPU,
         kDLA,
       };
   
       /// Defaulted. The stored value has no in-class initializer, so a
       /// default-initialized DeviceType holds an indeterminate value.
       DeviceType() = default;
   
       /// Wraps the given enumerator.
       constexpr DeviceType(Value t) : value(t) {}
   
       /// Construction from a c10::DeviceType; declared here, defined elsewhere.
       DeviceType(c10::DeviceType t);
   
       /// Yields the wrapped enumerator.
       operator Value() const {
         return value;
       }
   
       /// Deleted: a DeviceType has no truth value.
       explicit operator bool() = delete;
   
       /// Equality against another DeviceType compares the wrapped enumerators.
       constexpr bool operator==(DeviceType other) const {
         return value == other.value;
       }
   
       /// Equality against a raw enumerator. Provided so that
       /// `device_type == DeviceType::kGPU` selects an exact-match overload
       /// instead of competing with the built-in enum comparison that is
       /// reachable through operator Value().
       constexpr bool operator==(DeviceType::Value other) const {
         return value == other;
       }
   
       /// Inequality against another DeviceType.
       constexpr bool operator!=(DeviceType other) const {
         return value != other.value;
       }
   
       /// Inequality against a raw enumerator; counterpart of operator==(Value).
       constexpr bool operator!=(DeviceType::Value other) const {
         return value != other;
       }
   
      private:
       /// The wrapped enumerator.
       Value value;
     };
   
     /// Kind of device; the default constructor sets it to DeviceType::kGPU.
     DeviceType device_type;
   
     /*
      * Target gpu id
      */
     int64_t gpu_id;
   
     /*
      * DLA core index. When a DLA core is used on NVIDIA AGX platforms, gpu_id
      * should be set to the Xavier device.
      */
     int64_t dla_core;
   
     /// Defaults to false. NOTE(review): by its name this permits falling back to
     /// the GPU -- confirm the exact semantics against the implementation.
     bool allow_gpu_fallback;
   
     /// Default target: GPU kind, gpu_id 0, dla_core 0, GPU fallback disabled.
     Device() : device_type(DeviceType::kGPU), gpu_id(0), dla_core(0), allow_gpu_fallback(false) {}
   };
   
   /// Engine capability selector, stored as an int8_t. Enumerators are numbered
   /// consecutively from zero in declaration order.
   enum class EngineCapability : int8_t {
     kSTANDARD = 0,
     kSAFETY = 1,
     kDLA_STANDALONE = 2,
   };
   
   /**
    * Thin value wrapper around a small enumeration of tensor memory formats.
    *
    * Converts implicitly from and to Value and can also be constructed from an
    * at::MemoryFormat (that constructor is only declared here). Conversion to
    * bool is deleted, so an instance cannot be tested like a flag.
    */
   class TensorFormat {
    public:
     /// Underlying enumeration; its storage type is int8_t.
     enum Value : int8_t {
       kContiguous,
       kChannelsLast,
       kUnknown,
     };
   
     /// Defaulted. The stored value has no in-class initializer, so a
     /// default-initialized TensorFormat holds an indeterminate value.
     TensorFormat() = default;
   
     /// Wraps the given enumerator.
     constexpr TensorFormat(Value v) : value(v) {}
   
     /// Construction from an at::MemoryFormat; declared here, defined elsewhere.
     TORCHTRT_API TensorFormat(at::MemoryFormat t);
   
     /// Yields the wrapped enumerator.
     operator Value() const {
       return value;
     }
   
     /// Deleted: a TensorFormat has no truth value.
     explicit operator bool() = delete;
   
     /// Equality against another TensorFormat compares the wrapped enumerators.
     constexpr bool operator==(TensorFormat rhs) const {
       return value == rhs.value;
     }
   
     /// Equality against a raw enumerator.
     constexpr bool operator==(TensorFormat::Value rhs) const {
       return value == rhs;
     }
   
     /// Negation of operator==(TensorFormat).
     constexpr bool operator!=(TensorFormat rhs) const {
       return !(value == rhs.value);
     }
   
     /// Negation of operator==(TensorFormat::Value).
     constexpr bool operator!=(TensorFormat::Value rhs) const {
       return !(value == rhs);
     }
   
    private:
     /// Stream insertion is a friend so it can read the wrapped enumerator.
     friend TORCHTRT_API std::ostream& operator<<(std::ostream& os, const TensorFormat& format);
   
     /// The wrapped enumerator.
     Value value;
   };
   
   /**
    * Input specification, derived from torch::CustomClassHolder.
    *
    * Only the default constructor is defined in this header; every other
    * constructor is declared here and defined elsewhere. The overloads accept
    * either a single shape or a min/opt/max shape triple (as std::vector or
    * c10::ArrayRef), optionally combined with a DataType, a tensor_domain and a
    * TensorFormat that defaults to TensorFormat::kContiguous.
    */
   struct Input : torch::CustomClassHolder {
     // Shape fields. NOTE(review): judging by the constructor parameters an Input
     // is described either by `shape` alone or by the min/opt/max triple --
     // confirm against the out-of-line constructor definitions.
     std::vector<int64_t> min_shape;
     std::vector<int64_t> opt_shape;
     std::vector<int64_t> max_shape;
     std::vector<int64_t> shape;
   
     /// Element type. Initialized to kUnknown so that a default-constructed Input
     /// never carries an indeterminate value.
     DataType dtype = DataType::kUnknown;
   
     /// Memory format. Initialized to kContiguous, matching the default argument
     /// of the constructors below.
     TensorFormat format = TensorFormat::kContiguous;
   
     /// NOTE(review): meaning of the values is not visible in this header;
     /// empty by default.
     std::vector<double> tensor_domain;
   
     /// Leaves every vector empty; the remaining members take their in-class
     /// default values.
     Input() {}
   
     // --- Single shape, given as std::vector ---------------------------------
     TORCHTRT_API Input(std::vector<int64_t> shape, TensorFormat format = TensorFormat::kContiguous);
   
     TORCHTRT_API Input(
         std::vector<int64_t> shape,
         std::vector<double> tensor_domain,
         TensorFormat format = TensorFormat::kContiguous);
   
     TORCHTRT_API Input(std::vector<int64_t> shape, DataType dtype, TensorFormat format = TensorFormat::kContiguous);
   
     TORCHTRT_API Input(
         std::vector<int64_t> shape,
         DataType dtype,
         std::vector<double> tensor_domain,
         TensorFormat format = TensorFormat::kContiguous);
   
     // --- Single shape, given as c10::ArrayRef -------------------------------
     TORCHTRT_API Input(c10::ArrayRef<int64_t> shape, TensorFormat format = TensorFormat::kContiguous);
   
     TORCHTRT_API Input(
         c10::ArrayRef<int64_t> shape,
         std::vector<double> tensor_domain,
         TensorFormat format = TensorFormat::kContiguous);
   
     TORCHTRT_API Input(c10::ArrayRef<int64_t> shape, DataType dtype, TensorFormat format = TensorFormat::kContiguous);
   
     TORCHTRT_API Input(
         c10::ArrayRef<int64_t> shape,
         DataType dtype,
         std::vector<double> tensor_domain,
         TensorFormat format = TensorFormat::kContiguous);
   
     // --- min/opt/max shape triple, given as std::vector ---------------------
     TORCHTRT_API Input(
         std::vector<int64_t> min_shape,
         std::vector<int64_t> opt_shape,
         std::vector<int64_t> max_shape,
         TensorFormat format = TensorFormat::kContiguous);
     TORCHTRT_API Input(
         std::vector<int64_t> min_shape,
         std::vector<int64_t> opt_shape,
         std::vector<int64_t> max_shape,
         std::vector<double> tensor_domain,
         TensorFormat format = TensorFormat::kContiguous);
   
     TORCHTRT_API Input(
         std::vector<int64_t> min_shape,
         std::vector<int64_t> opt_shape,
         std::vector<int64_t> max_shape,
         DataType dtype,
         TensorFormat format = TensorFormat::kContiguous);
   
     TORCHTRT_API Input(
         std::vector<int64_t> min_shape,
         std::vector<int64_t> opt_shape,
         std::vector<int64_t> max_shape,
         DataType dtype,
         std::vector<double> tensor_domain,
         TensorFormat format = TensorFormat::kContiguous);
   
     // --- min/opt/max shape triple, given as c10::ArrayRef -------------------
     TORCHTRT_API Input(
         c10::ArrayRef<int64_t> min_shape,
         c10::ArrayRef<int64_t> opt_shape,
         c10::ArrayRef<int64_t> max_shape,
         TensorFormat format = TensorFormat::kContiguous);
   
     TORCHTRT_API Input(
         c10::ArrayRef<int64_t> min_shape,
         c10::ArrayRef<int64_t> opt_shape,
         c10::ArrayRef<int64_t> max_shape,
         std::vector<double> tensor_domain,
         TensorFormat format = TensorFormat::kContiguous);
   
     TORCHTRT_API Input(
         c10::ArrayRef<int64_t> min_shape,
         c10::ArrayRef<int64_t> opt_shape,
         c10::ArrayRef<int64_t> max_shape,
         DataType dtype,
         TensorFormat format = TensorFormat::kContiguous);
   
     TORCHTRT_API Input(
         c10::ArrayRef<int64_t> min_shape,
         c10::ArrayRef<int64_t> opt_shape,
         c10::ArrayRef<int64_t> max_shape,
         DataType dtype,
         std::vector<double> tensor_domain,
         TensorFormat format = TensorFormat::kContiguous);
   
     // --- From an existing tensor --------------------------------------------
     TORCHTRT_API Input(at::Tensor tensor);
   
    private:
     /// Stream insertion is a friend so it can read the private state below.
     friend TORCHTRT_API std::ostream& operator<<(std::ostream& os, const Input& input);
   
     /// Initialized to false so the default constructor leaves it well defined.
     /// NOTE(review): presumably set by the min/opt/max constructors -- confirm.
     bool input_is_dynamic = false;
   };
   
   /// Bundles two representations of a graph's inputs: a nested signature held
   /// in an IValue and a flat list of Input objects.
   struct GraphInputs {
     torch::jit::IValue input_signature; // nested Input, full input spec
     std::vector<Input> inputs; // flattened input spec
   };
   
   /// Returns build information as a string. Declared here, defined elsewhere.
   TORCHTRT_API std::string get_build_info();
   
   /// Returns nothing. NOTE(review): by its name this emits the build
   /// information; the output destination is not visible in this header.
   TORCHTRT_API void dump_build_info();
   
   /// Takes a GPU id. NOTE(review): presumably selects the active device --
   /// confirm against the implementation.
   TORCHTRT_API void set_device(const int gpu_id);
   
   namespace torchscript {
   /**
    * Settings passed to the torchscript compile / convert entry points.
    *
    * Constructible from fixed sizes (as nested std::vector or as a vector of
    * c10::ArrayRef), from a list of Input objects, or from an IValue input
    * signature; all four constructors are declared here and defined elsewhere.
    * The comments on the data members record their default values only; their
    * exact effect is determined by the implementation, which is not shown here.
    */
   struct CompileSpec {
     TORCHTRT_API CompileSpec(std::vector<std::vector<int64_t>> fixed_sizes);
   
     TORCHTRT_API CompileSpec(std::vector<c10::ArrayRef<int64_t>> fixed_sizes);
   
     TORCHTRT_API CompileSpec(std::vector<Input> inputs);
   
     TORCHTRT_API CompileSpec(torch::jit::IValue input_signature);
     // Defaults should reflect TensorRT defaults for BuilderConfig
   
     /// Input description; no in-class default.
     GraphInputs graph_inputs;
   
     /// Defaults to the single-element set {DataType::kFloat}.
     std::set<DataType> enabled_precisions = {DataType::kFloat};
   
     /// Defaults to false.
     bool disable_tf32{false};
   
     /// Defaults to false.
     bool sparse_weights{false};
   
     /// Defaults to false.
     bool refit{false};
   
     /// Defaults to false.
     bool debug{false};
   
     /// Defaults to false.
     bool truncate_long_and_double{false};
   
     /// Defaults to false.
     bool allow_shape_tensors{false};
   
     /// Default-constructed Device.
     Device device;
   
     /// Defaults to EngineCapability::kSTANDARD.
     EngineCapability capability{EngineCapability::kSTANDARD};
   
     /// Defaults to 1.
     uint64_t num_avg_timing_iters{1};
   
     /// Defaults to 0.
     uint64_t workspace_size{0};
   
     /// Defaults to 2^20 = 1048576.
     uint64_t dla_sram_size{uint64_t{1} << 20};
   
     /// Defaults to 2^30 = 1073741824.
     uint64_t dla_local_dram_size{uint64_t{1} << 30};
   
     /// Defaults to 2^29 = 536870912.
     uint64_t dla_global_dram_size{uint64_t{1} << 29};
   
     /// Raw pointer to a calibrator; null by default.
     nvinfer1::IInt8Calibrator* ptq_calibrator{nullptr};
   
     /// Defaults to false.
     bool require_full_compilation{false};
   
     /// Defaults to 3.
     uint64_t min_block_size{3};
   
     /// Empty by default.
     std::vector<std::string> torch_executed_ops;
   
     /// Empty by default.
     std::vector<std::string> torch_executed_modules;
   };
   
   /// Takes a module and a method name and returns a bool.
   /// NOTE(review): by its name, reports whether the method's operators are
   /// supported -- confirm against the implementation.
   TORCHTRT_API bool check_method_operator_support(const torch::jit::Module& module, std::string method_name);
   
   /// Takes a module and a CompileSpec (by value) and returns a new
   /// torch::jit::Module; the input module is taken by const reference.
   TORCHTRT_API torch::jit::Module compile(const torch::jit::Module& module, CompileSpec info);
   
   /// Takes a module, a method name and a CompileSpec and returns a std::string.
   /// NOTE(review): by its name the string holds a TensorRT engine for the named
   /// method -- confirm the format against the implementation.
   TORCHTRT_API std::string convert_method_to_trt_engine(
       const torch::jit::Module& module,
       std::string method_name,
       CompileSpec info);
   
   /// Takes an engine string and a Device and returns a torch::jit::Module.
   /// Both binding-name lists default to empty vectors.
   TORCHTRT_API torch::jit::Module embed_engine_in_new_module(
       const std::string& engine,
       Device device,
       const std::vector<std::string>& input_binding_names = std::vector<std::string>(),
       const std::vector<std::string>& output_binding_names = std::vector<std::string>());
   } // namespace torchscript
   } // namespace torch_tensorrt
