
.. _program_listing_file_cpp_api_include_trtorch_trtorch.h:

Program Listing for File trtorch.h
==================================

|exhale_lsh| :ref:`Return to documentation for file <file_cpp_api_include_trtorch_trtorch.h>` (``cpp/api/include/trtorch/trtorch.h``)

.. |exhale_lsh| unicode:: U+021B0 .. UPWARDS ARROW WITH TIP LEFTWARDS

.. code-block:: cpp

   /*
    * Copyright (c) NVIDIA Corporation.
    * All rights reserved.
    *
    * This library is licensed under the BSD-style license found in the
    * LICENSE file in the root directory of this source tree.
    */

   #pragma once

   #include <cuda_runtime.h>
   #include <memory>
   #include <string>
   #include <vector>

   // Forward declarations of the external types named by this API, so that the
   // header does not have to include the libraries that define them.
   // Open question from the original author: just include the .h instead?
   #ifndef DOXYGEN_SHOULD_SKIP_THIS
   namespace torch {
   namespace jit {
   struct Graph;
   struct Module;
   } // namespace jit
   } // namespace torch

   namespace c10 {
   // Opaque enum declarations: the underlying types written here must match the
   // ones used by the defining headers.
   enum class DeviceType : int16_t;
   enum class ScalarType : int8_t;
   template <class>
   class ArrayRef;
   } // namespace c10

   namespace nvinfer1 {
   class IInt8Calibrator;
   }
   #endif // DOXYGEN_SHOULD_SKIP_THIS

   #include "trtorch/macros.h"
   namespace trtorch {
   struct TRTORCH_API CompileSpec {
     struct TRTORCH_API InputRange {
       std::vector<int64_t> min;
       std::vector<int64_t> opt;
       std::vector<int64_t> max;
       InputRange(std::vector<int64_t> opt);
       InputRange(c10::ArrayRef<int64_t> opt);
       InputRange(std::vector<int64_t> min, std::vector<int64_t> opt, std::vector<int64_t> max);
       InputRange(c10::ArrayRef<int64_t> min, c10::ArrayRef<int64_t> opt, c10::ArrayRef<int64_t> max);
     };

     class TRTORCH_API DataType {
      public:
       enum Value : int8_t {
         kFloat,
         kHalf,
         kChar,
       };

       DataType() = default;
       constexpr DataType(Value t) : value(t) {}
       DataType(c10::ScalarType t);
       operator Value() const {
         return value;
       }
       explicit operator bool() = delete;
       constexpr bool operator==(DataType other) const {
         return value == other.value;
       }
       constexpr bool operator==(DataType::Value other) const {
         return value == other;
       }
       constexpr bool operator!=(DataType other) const {
         return value != other.value;
       }
       constexpr bool operator!=(DataType::Value other) const {
         return value != other;
       }

      private:
       Value value;
     };

     enum class EngineCapability : int8_t {
       kDEFAULT,
       kSAFE_GPU,
       kSAFE_DLA,
     };

     CompileSpec(std::vector<InputRange> input_ranges) : input_ranges(std::move(input_ranges)) {}
     CompileSpec(std::vector<std::vector<int64_t>> fixed_sizes);
     CompileSpec(std::vector<c10::ArrayRef<int64_t>> fixed_sizes);

     // Defaults should reflect TensorRT defaults for BuilderConfig

     std::vector<InputRange> input_ranges;

     DataType op_precision = DataType::kFloat;

     bool disable_tf32 = false;

     bool refit = false;

     bool debug = false;

     bool strict_types = false;

     /*
      * Setting data structure for Target device
      */
     struct Device {
       class DeviceType {
        public:
         enum Value : int8_t {
           kGPU,
           kDLA,
         };

         DeviceType() = default;
         constexpr DeviceType(Value t) : value(t) {}
         DeviceType(c10::DeviceType t);
         operator Value() const {
           return value;
         }
         explicit operator bool() = delete;
         constexpr bool operator==(DeviceType other) const {
           return value == other.value;
         }
         constexpr bool operator!=(DeviceType other) const {
           return value != other.value;
         }

        private:
         Value value;
       };

       DeviceType device_type;

       /*
        * Target gpu id
        */
       int64_t gpu_id;

       /*
        * When using DLA core on NVIDIA AGX platforms gpu_id should be set as Xavier device
        */
       int64_t dla_core;

       bool allow_gpu_fallback;

       Device() : device_type(DeviceType::kGPU), gpu_id(0), dla_core(0), allow_gpu_fallback(false) {}
     };

     /*
      * Target Device
      */
     Device device;

     EngineCapability capability = EngineCapability::kDEFAULT;

     uint64_t num_min_timing_iters = 2;
     uint64_t num_avg_timing_iters = 1;

     uint64_t workspace_size = 0;

     uint64_t max_batch_size = 0;

     nvinfer1::IInt8Calibrator* ptq_calibrator = nullptr;
   };

   // Returns build information as a string (presumably the versions of the
   // library and its dependencies; the definition is not in this header).
   TRTORCH_API std::string get_build_info();

   // Presumably emits the information from get_build_info() to an output
   // stream or log - TODO confirm the destination against the implementation.
   TRTORCH_API void dump_build_info();

   // Reports, as a bool, whether the method `method_name` of `module` is
   // supported (presumably: whether all of its operators can be converted).
   // NOTE(review): confirm the exact semantics against the implementation.
   TRTORCH_API bool CheckMethodOperatorSupport(const torch::jit::Module& module, std::string method_name);

   // Compiles `module` according to the settings in `info` (taken by value)
   // and returns a torch::jit::Module. NOTE(review): presumably the returned
   // module runs the TensorRT-backed graph - confirm against the implementation.
   TRTORCH_API torch::jit::Module CompileGraph(const torch::jit::Module& module, CompileSpec info);

   // Converts the method `method_name` of `module` using the settings in `info`
   // and returns the result as a std::string (presumably the serialized
   // TensorRT engine - confirm against the implementation).
   TRTORCH_API std::string ConvertGraphToTRTEngine(
       const torch::jit::Module& module,
       std::string method_name,
       CompileSpec info);
   // Selects the GPU identified by `gpu_id` (presumably as the active CUDA
   // device for subsequent calls - confirm against the implementation).
   TRTORCH_API void set_device(const int gpu_id);

   } // namespace trtorch
