
.. _program_listing_file_cpp_include_trtorch_trtorch.h:

Program Listing for File trtorch.h
==================================

|exhale_lsh| :ref:`Return to documentation for file <file_cpp_include_trtorch_trtorch.h>` (``cpp/include/trtorch/trtorch.h``)

.. |exhale_lsh| unicode:: U+021B0 .. UPWARDS ARROW WITH TIP LEFTWARDS

.. code-block:: cpp

   /*
    * Copyright (c) NVIDIA Corporation.
    * All rights reserved.
    *
    * This library is licensed under the BSD-style license found in the
    * LICENSE file in the root directory of this source tree.
    */

   #pragma once

   #include <cuda_runtime.h>
   #include <iostream>
   #include <memory>
   #include <set>
   #include <string>
   #include <vector>

   // Opaque declarations of the Torch / TensorRT types that appear in this header's
   // signatures. TODO: decide whether to just include the real headers instead.
   #ifndef DOXYGEN_SHOULD_SKIP_THIS
   namespace torch {
   namespace jit {
   struct Graph;
   struct Module;
   } // namespace jit
   } // namespace torch

   namespace c10 {
   enum class DeviceType : int8_t;
   enum class ScalarType : int8_t;
   // Used (as at::MemoryFormat) by CompileSpec::TensorFormat. The underlying type has
   // to stay in sync with the definition in c10/core/MemoryFormat.h.
   enum class MemoryFormat : int8_t;
   template <class>
   class ArrayRef;
   } // namespace c10

   namespace at {
   // Makes the at::MemoryFormat spelling used below resolvable without Torch headers.
   using c10::MemoryFormat;
   } // namespace at

   namespace nvinfer1 {
   class IInt8Calibrator;
   }
   #endif // DOXYGEN_SHOULD_SKIP_THIS

   #include "trtorch/macros.h"
   namespace trtorch {
   struct TRTORCH_API CompileSpec {
     /**
      * Data type of a tensor or of a compute precision, as seen by TRTorch.
      *
      * Wraps the Value enum so that an instance can be built from a TRTorch enumerator
      * or from a c10::ScalarType, converts implicitly back to Value (e.g. for switch
      * statements) and compares against both DataType and Value.
      */
     class TRTORCH_API DataType {
      public:
       /**
        * Enumerators backing DataType, stored as int8_t.
        *
        * NOTE(review): the names look like they follow Torch scalar-type naming
        * (kHalf = 16-bit float, kChar = 8-bit integer, ...) - confirm against the
        * implementation of the c10::ScalarType constructor.
        */
       enum Value : int8_t {
         kFloat,
         kHalf,
         kChar,
         kInt,
         kBool,
         kUnknown
       };

       /// Default-initialization does not set the wrapped value; assign before reading it.
       DataType() = default;

       /// Wraps a TRTorch enumerator. Implicit, so `DataType::kHalf` can be passed where a DataType is expected.
       constexpr DataType(Value t) : value(t) {}

       /// Builds a DataType from a Torch scalar type (defined out of line).
       DataType(c10::ScalarType t);

       /// Implicit conversion back to the raw enumerator; constexpr like the comparison operators.
       constexpr operator Value() const {
         return value;
       }

       /// Deleted so that a DataType can never be tested as a boolean.
       explicit operator bool() = delete;

       /// Equality with another DataType.
       constexpr bool operator==(DataType other) const {
         return value == other.value;
       }

       /// Equality with a raw enumerator; the exact-match overload keeps `dtype == DataType::kFloat` unambiguous.
       constexpr bool operator==(DataType::Value other) const {
         return value == other;
       }

       /// Inequality with another DataType.
       constexpr bool operator!=(DataType other) const {
         return value != other.value;
       }

       /// Inequality with a raw enumerator.
       constexpr bool operator!=(DataType::Value other) const {
         return value != other;
       }

      private:
       /// Stream output for DataType; a friend so it can read the private value (defined out of line).
       friend std::ostream& operator<<(std::ostream& os, const DataType& dtype);

       /// The wrapped enumerator.
       Value value;
     };

     /**
      * Settings describing the target device.
      *
      * NOTE(review): unlike the sibling nested types, neither Device nor DeviceType is
      * marked TRTORCH_API although DeviceType(c10::DeviceType) is defined out of line -
      * confirm that the symbol is exported on every supported platform.
      */
     struct Device {
       /**
        * Kind of device to target (GPU or DLA).
        *
        * Wraps the Value enum so that it can be built from a TRTorch enumerator or a
        * c10::DeviceType and still be compared / switched on like a plain enum.
        */
       class DeviceType {
        public:
         /// Enumerators backing DeviceType, stored as int8_t.
         enum Value : int8_t {
           kGPU,
           kDLA,
         };

         /// Default-initialization does not set the wrapped value; assign before reading it.
         DeviceType() = default;

         /// Wraps a TRTorch enumerator. Implicit, so `DeviceType::kGPU` can be assigned directly.
         constexpr DeviceType(Value t) : value(t) {}

         /// Builds a DeviceType from a Torch device type (defined out of line).
         DeviceType(c10::DeviceType t);

         /// Implicit conversion back to the raw enumerator.
         constexpr operator Value() const {
           return value;
         }

         /// Deleted so that a DeviceType can never be tested as a boolean.
         explicit operator bool() = delete;

         /// Equality with another DeviceType.
         constexpr bool operator==(DeviceType other) const {
           return value == other.value;
         }

         /// Equality with a raw enumerator. Without this exact-match overload
         /// `device_type == DeviceType::kGPU` is ambiguous between the overload above
         /// (converting the right operand) and the built-in enum comparison
         /// (converting the left operand).
         constexpr bool operator==(DeviceType::Value other) const {
           return value == other;
         }

         /// Inequality with another DeviceType.
         constexpr bool operator!=(DeviceType other) const {
           return value != other.value;
         }

         /// Inequality with a raw enumerator; see operator==(Value) for why it exists.
         constexpr bool operator!=(DeviceType::Value other) const {
           return value != other;
         }

        private:
         /// The wrapped enumerator.
         Value value;
       };

       /// Kind of device to target; the default constructor selects kGPU.
       DeviceType device_type;

       /*
        * Target gpu id
        */
       int64_t gpu_id;

       /*
        * DLA core to use. When using a DLA core on NVIDIA AGX platforms, gpu_id should be
        * set to the Xavier device.
        */
       int64_t dla_core;

       /// NOTE(review): presumably allows work that cannot run on the DLA to run on the GPU instead - confirm.
       bool allow_gpu_fallback;

       /// Defaults: GPU target, gpu_id 0, dla_core 0, GPU fallback disabled.
       Device() : device_type(DeviceType::kGPU), gpu_id(0), dla_core(0), allow_gpu_fallback(false) {}
     };

     /**
      * Capability mode requested for the engine, stored as an int8_t.
      *
      * NOTE(review): the enumerator names look like they mirror TensorRT's engine
      * capability modes - confirm the mapping in the implementation.
      */
     enum class EngineCapability : int8_t {
       kSTANDARD = 0,
       kSAFETY = 1,
       kDLA_STANDALONE = 2,
     };

     /**
      * Memory layout of a tensor, as seen by TRTorch.
      *
      * Wraps the Value enum so that an instance can be built from a TRTorch enumerator
      * or from an at::MemoryFormat, converts implicitly back to Value and compares
      * against both TensorFormat and Value.
      */
     class TRTORCH_API TensorFormat {
      public:
       /// Enumerators backing TensorFormat, stored as int8_t.
       enum Value : int8_t {
         kContiguous,
         kChannelsLast,
         kUnknown,
       };

       /// Default-initialization does not set the wrapped value; assign before reading it.
       TensorFormat() = default;

       /// Wraps a TRTorch enumerator. Implicit, so `TensorFormat::kContiguous` can be used as a default argument.
       constexpr TensorFormat(Value t) : value(t) {}

       /// Builds a TensorFormat from a Torch memory format (defined out of line).
       TensorFormat(at::MemoryFormat t);

       /// Implicit conversion back to the raw enumerator; constexpr like the comparison operators.
       constexpr operator Value() const {
         return value;
       }

       /// Deleted so that a TensorFormat can never be tested as a boolean.
       explicit operator bool() = delete;

       /// Equality with another TensorFormat.
       constexpr bool operator==(TensorFormat other) const {
         return value == other.value;
       }

       /// Equality with a raw enumerator; the exact-match overload keeps `fmt == TensorFormat::kContiguous` unambiguous.
       constexpr bool operator==(TensorFormat::Value other) const {
         return value == other;
       }

       /// Inequality with another TensorFormat.
       constexpr bool operator!=(TensorFormat other) const {
         return value != other.value;
       }

       /// Inequality with a raw enumerator.
       constexpr bool operator!=(TensorFormat::Value other) const {
         return value != other;
       }

      private:
       /// Stream output for TensorFormat; a friend so it can read the private value (defined out of line).
       friend std::ostream& operator<<(std::ostream& os, const TensorFormat& format);

       /// The wrapped enumerator.
       Value value;
     };

     struct TRTORCH_API Input {
       std::vector<int64_t> min_shape;
       std::vector<int64_t> opt_shape;
       std::vector<int64_t> max_shape;
       std::vector<int64_t> shape;
       DataType dtype;
       TensorFormat format;

       Input(std::vector<int64_t> shape, TensorFormat format = TensorFormat::kContiguous);

       Input(std::vector<int64_t> shape, DataType dtype, TensorFormat format = TensorFormat::kContiguous);

       Input(c10::ArrayRef<int64_t> shape, TensorFormat format = TensorFormat::kContiguous);

       Input(c10::ArrayRef<int64_t> shape, DataType dtype, TensorFormat format = TensorFormat::kContiguous);

       Input(
           std::vector<int64_t> min_shape,
           std::vector<int64_t> opt_shape,
           std::vector<int64_t> max_shape,
           TensorFormat format = TensorFormat::kContiguous);

       Input(
           std::vector<int64_t> min_shape,
           std::vector<int64_t> opt_shape,
           std::vector<int64_t> max_shape,
           DataType dtype,
           TensorFormat format = TensorFormat::kContiguous);

       Input(
           c10::ArrayRef<int64_t> min_shape,
           c10::ArrayRef<int64_t> opt_shape,
           c10::ArrayRef<int64_t> max_shape,
           TensorFormat format = TensorFormat::kContiguous);

       Input(
           c10::ArrayRef<int64_t> min_shape,
           c10::ArrayRef<int64_t> opt_shape,
           c10::ArrayRef<int64_t> max_shape,
           DataType dtype,
           TensorFormat format = TensorFormat::kContiguous);

       bool get_explicit_set_dtype() {
         return explicit_set_dtype;
       }

      private:
       friend std::ostream& operator<<(std::ostream& os, const Input& input);
       bool input_is_dynamic;
       bool explicit_set_dtype;
     };

     /**
      * Legacy description of an input's shape range. Every constructor is marked
      * deprecated in favor of CompileSpec::Input; all of them are defined out of line.
      */
     struct TRTORCH_API InputRange {
       /// Lower bound of the shape range.
       std::vector<int64_t> min;
       /// Shape of the range to optimize for.
       std::vector<int64_t> opt;
       /// Upper bound of the shape range.
       std::vector<int64_t> max;

       /// Deprecated: range described by a single shape.
       [[deprecated("trtorch::CompileSpec::InputRange is being deprecated in favor of trtorch::CompileSpec::Input. trtorch::CompileSpec::InputRange will be removed in TRTorch v0.5.0")]]
       InputRange(std::vector<int64_t> opt);

       /// Deprecated: range described by a single shape given as a c10::ArrayRef.
       [[deprecated("trtorch::CompileSpec::InputRange is being deprecated in favor of trtorch::CompileSpec::Input. trtorch::CompileSpec::InputRange will be removed in TRTorch v0.5.0")]]
       InputRange(c10::ArrayRef<int64_t> opt);

       /// Deprecated: explicit min / opt / max shapes.
       [[deprecated("trtorch::CompileSpec::InputRange is being deprecated in favor of trtorch::CompileSpec::Input. trtorch::CompileSpec::InputRange will be removed in TRTorch v0.5.0")]]
       InputRange(std::vector<int64_t> min, std::vector<int64_t> opt, std::vector<int64_t> max);

       /// Deprecated: explicit min / opt / max shapes given as c10::ArrayRef values.
       [[deprecated("trtorch::CompileSpec::InputRange is being deprecated in favor of trtorch::CompileSpec::Input. trtorch::CompileSpec::InputRange will be removed in TRTorch v0.5.0")]]
       InputRange(c10::ArrayRef<int64_t> min, c10::ArrayRef<int64_t> opt, c10::ArrayRef<int64_t> max);
     };

     /**
      * Settings for falling back to Torch.
      *
      * NOTE(review): the exact effect of each field is decided by the implementation,
      * which is not visible here; the per-field notes below are read off the names.
      */
     struct TRTORCH_API TorchFallback {
       /// Whether fallback is enabled; off by default.
       bool enabled{false};

       /// Minimum block size; defaults to 1. Presumably counted in operations - confirm.
       uint64_t min_block_size{1};

       /// Names of operators that are forced to fall back.
       std::vector<std::string> forced_fallback_ops;

       /// Names of modules that are forced to fall back.
       std::vector<std::string> forced_fallback_modules;

       /// Fallback disabled, min_block_size 1, no forced ops or modules.
       TorchFallback() = default;

       /// Sets only the enabled flag. Implicit, so a bool converts to TorchFallback.
       TorchFallback(bool enabled) : enabled{enabled} {}

       /// Sets the enabled flag and the minimum block size.
       TorchFallback(bool enabled, uint64_t min_size) : enabled{enabled}, min_block_size{min_size} {}
     };

     /**
      * Deprecated: builds a spec from InputRange descriptions, which are moved into
      * input_ranges. Use the std::vector<Input> overload instead.
      */
     [[deprecated("trtorch::CompileSpec::CompileSpec(std::vector<InputRange> input_ranges) is being deprecated in favor of trtorch::CompileSpec::CompileSpec(std::vector<Input> inputs). Please use CompileSpec(std::vector<Input> inputs). trtorch::CompileSpec::CompileSpec(std::vector<InputRange> input_ranges) will be removed in TRTorch v0.5.0")]]
     CompileSpec(std::vector<InputRange> input_ranges) : input_ranges(std::move(input_ranges)) {}

     /// Builds a spec from a list of fixed sizes, each a std::vector<int64_t> (defined out of line).
     CompileSpec(std::vector<std::vector<int64_t>> fixed_sizes);

     /// Builds a spec from a list of fixed sizes, each a c10::ArrayRef<int64_t> (defined out of line).
     CompileSpec(std::vector<c10::ArrayRef<int64_t>> fixed_sizes);

     /// Builds a spec from Input descriptions, which are moved into inputs.
     CompileSpec(std::vector<Input> inputs) : inputs(std::move(inputs)) {}

     // Defaults should reflect TensorRT defaults for BuilderConfig
     //
     // NOTE(review): the per-field notes below are read off the names, types and default
     // values in this header; the exact effect of each setting is decided by the
     // implementation and should be confirmed there.

     // Descriptions of the inputs; filled by the std::vector<Input> constructor.
     std::vector<Input> inputs;

     // Deprecated predecessor of `inputs`; filled by the deprecated InputRange constructor.
     [[deprecated(
         "trtorch::CompileSpec::input_ranges is being deprecated in favor of trtorch::CompileSpec::inputs. trtorch::CompileSpec::input_ranges will be removed in TRTorch v0.5.0")]] std::
         vector<InputRange>
             input_ranges;

     // Deprecated single-precision setting, superseded by `enabled_precisions`. Defaults to kFloat.
     [[deprecated(
         "trtorch::CompileSpec::op_precision is being deprecated in favor of trtorch::CompileSpec::enabled_precisions, a set of all enabled precisions to use during compilation, trtorch::CompileSpec::op_precision will be removed in TRTorch v0.5.0")]] DataType
         op_precision = DataType::kFloat;

     // Set of all precisions enabled during compilation. Defaults to {kFloat}.
     std::set<DataType> enabled_precisions = {DataType::kFloat};

     // Presumably turns off use of TF32. Off by default.
     bool disable_tf32 = false;

     // Presumably enables sparse weights. Off by default.
     bool sparse_weights = false;

     // Presumably requests a refittable engine. Off by default.
     bool refit = false;

     // Presumably requests a debuggable engine. Off by default.
     bool debug = false;

     // Presumably truncates long / double values to narrower types. Off by default.
     bool truncate_long_and_double = false;

     // Presumably makes type constraints strict. Off by default.
     bool strict_types = false;

     // Target device settings; default-constructed (GPU, gpu_id 0, dla_core 0, no GPU fallback).
     Device device;

     // Torch fallback settings; default-constructed (disabled).
     TorchFallback torch_fallback;

     // Engine capability mode. Defaults to kSTANDARD.
     EngineCapability capability = EngineCapability::kSTANDARD;

     // Timing iteration counts (minimum / averaging). Default to 2 and 1.
     uint64_t num_min_timing_iters = 2;
     uint64_t num_avg_timing_iters = 1;

     // Workspace size. Defaults to 0; presumably in bytes, with 0 meaning "not set".
     uint64_t workspace_size = 0;

     // Maximum batch size. Defaults to 0; presumably 0 means "not set".
     uint64_t max_batch_size = 0;

     // INT8 calibrator, held as a raw pointer that is null by default. Nothing in this
     // header manages the pointee's lifetime.
     nvinfer1::IInt8Calibrator* ptq_calibrator = nullptr;
   };

   /// Returns the library's build information as a string.
   /// NOTE(review): the exact contents are decided by the implementation - confirm there.
   TRTORCH_API std::string get_build_info();

   /// Dumps the library's build information.
   /// NOTE(review): presumably prints the build-info string to standard output - confirm.
   TRTORCH_API void dump_build_info();

   /// Checks operator support for the method `method_name` of `module`.
   /// NOTE(review): presumably returns true only when every operator in the method can
   /// be converted - confirm against the implementation.
   TRTORCH_API bool CheckMethodOperatorSupport(const torch::jit::Module& module, std::string method_name);

   /// Compiles `module` according to the settings in `info` and returns the result as a
   /// new torch::jit::Module; the input module is taken by const reference.
   /// NOTE(review): which methods of the module get compiled is not visible here - confirm.
   TRTORCH_API torch::jit::Module CompileGraph(const torch::jit::Module& module, CompileSpec info);

   /// Converts the method `method_name` of `module` to a TensorRT engine using the
   /// settings in `info`, and returns the engine as a std::string.
   /// NOTE(review): presumably the string holds the serialized engine bytes - confirm.
   TRTORCH_API std::string ConvertGraphToTRTEngine(
       const torch::jit::Module& module,
       std::string method_name,
       CompileSpec info);

   /// Embeds `engine` in a new torch::jit::Module targeting `device` and returns that module.
   /// NOTE(review): `engine` is presumably a serialized TensorRT engine such as the
   /// string returned by ConvertGraphToTRTEngine - confirm.
   TRTORCH_API torch::jit::Module EmbedEngineInNewModule(const std::string& engine, CompileSpec::Device device);

   /// Sets the device identified by `gpu_id`.
   /// NOTE(review): presumably selects the active CUDA device - confirm against the implementation.
   TRTORCH_API void set_device(int gpu_id);

   } // namespace trtorch
