
.. _program_listing_file_cpp_api_include_trtorch_trtorch.h:

Program Listing for File trtorch.h
==================================

|exhale_lsh| :ref:`Return to documentation for file <file_cpp_api_include_trtorch_trtorch.h>` (``cpp/api/include/trtorch/trtorch.h``)

.. |exhale_lsh| unicode:: U+021B0 .. UPWARDS ARROW WITH TIP LEFTWARDS

.. code-block:: cpp

   /*
    * Copyright (c) NVIDIA Corporation.
    * All rights reserved.
    *
    * This library is licensed under the BSD-style license found in the
    * LICENSE file in the root directory of this source tree.
    */

   #pragma once

   #include <string>
   #include <vector>
   #include <memory>

   // Forward declarations of the PyTorch / c10 / TensorRT types that appear in
   // the signatures below.
   // TODO(review): the original note asked "Just include the .h?" -- i.e. whether
   // the real headers should be included instead of re-declaring these names.
   //
   // The whole section is compiled out when DOXYGEN_SHOULD_SKIP_THIS is defined,
   // so these foreign declarations stay out of the generated API docs.
   #ifndef DOXYGEN_SHOULD_SKIP_THIS
   namespace torch {
   namespace jit {
   struct Graph;
   struct Module;
   } // namespace jit
   } // namespace torch

   namespace c10 {
   // Opaque enum declarations: the underlying types written here must match
   // the ones used by the real c10 definitions, otherwise the redeclaration is
   // ill-formed once the real header is also visible.
   enum class DeviceType : int16_t;
   enum class ScalarType : int8_t;
   template <class>
   class ArrayRef;
   } // namespace c10

   namespace nvinfer1 {
   // Only IInt8EntropyCalibrator2 is forward declared here.
   // NOTE(review): ExtraInfo::ptq_calibrator uses nvinfer1::IInt8Calibrator,
   // which presumably becomes visible through "trtorch/ptq.h" -- confirm.
   class IInt8EntropyCalibrator2;
   } // namespace nvinfer1
   #endif //DOXYGEN_SHOULD_SKIP_THIS

   #include "trtorch/macros.h"
   #include "trtorch/logging.h"
   #include "trtorch/ptq.h"
   namespace trtorch {
   /// Settings bundle taken by CompileGraph() and ConvertGraphToTRTEngine().
   ///
   /// Every field except input_ranges has an in-class default, so a caller only
   /// has to supply the input sizes and then override whatever should differ.
   struct TRTORCH_API ExtraInfo {
       /// Three dimension lists (min / opt / max) describing one input.
       /// NOTE(review): presumably the minimum, optimal and maximum shapes of a
       /// TensorRT optimization profile -- confirm against the implementation.
       struct TRTORCH_API InputRange {
           std::vector<int64_t> min;
           std::vector<int64_t> opt;
           std::vector<int64_t> max;
           /// Single-size constructors; defined out of line.
           /// NOTE(review): presumably set min == opt == max -- confirm.
           InputRange(std::vector<int64_t> opt);
           InputRange(c10::ArrayRef<int64_t> opt);
           /// Explicit min / opt / max constructors; defined out of line.
           InputRange(std::vector<int64_t> min, std::vector<int64_t> opt, std::vector<int64_t> max);
           InputRange(c10::ArrayRef<int64_t> min, c10::ArrayRef<int64_t> opt, c10::ArrayRef<int64_t> max);
       };

       /// Enum-like wrapper: behaves like the nested Value enum (implicit
       /// conversion both ways, ==, !=) and can additionally be built from a
       /// c10::ScalarType.
       class DataType {
       public:
           enum Value : int8_t {
               kFloat,
               kHalf,
               kChar,
           };

           /// Default construction yields kFloat (see the member initializer).
           DataType() = default;
           /// Implicit on purpose so `DataType::kHalf` can be assigned directly.
           constexpr DataType(Value t) : value(t) {}
           /// Conversion from a c10::ScalarType; defined out of line.
           DataType(c10::ScalarType t);
           /// Implicit conversion back to the raw enum (e.g. for `switch`).
           constexpr operator Value() const { return value; }
           /// Deleted so a DataType cannot be used as a boolean condition.
           explicit operator bool() = delete;
           constexpr bool operator==(DataType other) const { return value == other.value; }
           constexpr bool operator!=(DataType other) const { return value != other.value; }
       private:
           // Initialized so that `DataType d;` never holds an indeterminate value.
           // kFloat is the zero enumerator, i.e. what `DataType{}` already gave.
           Value value = kFloat;
       };

       /// Enum-like wrapper mirroring DataType, for the target device kind; can
       /// additionally be built from a c10::DeviceType.
       class DeviceType {
       public:
           enum Value : int8_t {
               kGPU,
               kDLA,
           };

           /// Default construction yields kGPU (see the member initializer).
           DeviceType() = default;
           /// Implicit on purpose so `DeviceType::kDLA` can be assigned directly.
           constexpr DeviceType(Value t) : value(t) {}
           /// Conversion from a c10::DeviceType; defined out of line.
           DeviceType(c10::DeviceType t);
           /// Implicit conversion back to the raw enum (e.g. for `switch`).
           constexpr operator Value() const { return value; }
           /// Deleted so a DeviceType cannot be used as a boolean condition.
           explicit operator bool() = delete;
           constexpr bool operator==(DeviceType other) const { return value == other.value; }
           constexpr bool operator!=(DeviceType other) const { return value != other.value; }
       private:
           // Initialized so that `DeviceType d;` never holds an indeterminate
           // value. kGPU is the zero enumerator, i.e. what `DeviceType{}` gave.
           Value value = kGPU;
       };

       /// NOTE(review): enumerator names suggest these map onto TensorRT's
       /// engine capability modes -- confirm against the implementation.
       enum class EngineCapability : int8_t {
           kDEFAULT,
           kSAFE_GPU,
           kSAFE_DLA,
       };

       /// Takes the ranges by value and moves them into the member; all other
       /// fields keep their in-class defaults.
       ExtraInfo(std::vector<InputRange> input_ranges)
           : input_ranges(std::move(input_ranges)) {}
       /// Fixed-size constructors; defined out of line.
       /// NOTE(review): presumably build one InputRange per entry -- confirm.
       ExtraInfo(std::vector<std::vector<int64_t>> fixed_sizes);
       ExtraInfo(std::vector<c10::ArrayRef<int64_t>> fixed_sizes);

       // Defaults should reflect TensorRT defaults for BuilderConfig

       /// One entry per input.
       std::vector<InputRange> input_ranges;

       /// Defaults to kFloat.
       DataType op_precision = DataType::kFloat;

       /// Off by default.
       bool refit = false;

       /// Off by default.
       bool debug = false;

       /// Off by default.
       bool strict_types = false;

       /// On by default.
       /// NOTE(review): presumably only meaningful when device is kDLA -- confirm.
       bool allow_gpu_fallback = true;

       /// Defaults to kGPU.
       DeviceType device = DeviceType::kGPU;

       /// Defaults to kDEFAULT.
       EngineCapability capability = EngineCapability::kDEFAULT;

       /// Timing iteration counts; defaults are 2 (min) and 1 (avg).
       uint64_t num_min_timing_iters = 2;
       uint64_t num_avg_timing_iters = 1;

       /// Defaults to 0.
       /// NOTE(review): unit (presumably bytes) and the meaning of 0 are not
       /// visible here -- confirm.
       uint64_t workspace_size = 0;

       /// Defaults to 0.
       /// NOTE(review): the meaning of 0 is not visible here -- confirm.
       uint64_t max_batch_size = 0;

       /// Optional calibrator; nullptr by default.
       /// NOTE(review): raw pointer, ownership is not expressed by the type --
       /// presumably non-owning, so the calibrator must outlive its use; confirm.
       nvinfer1::IInt8Calibrator* ptq_calibrator = nullptr;
   };

   /// Returns a std::string by value; takes no arguments.
   /// NOTE(review): judging by the name, presumably a description of how the
   /// library was built (e.g. dependency versions) -- confirm in the definition.
   TRTORCH_API std::string get_build_info();


   /// Takes no arguments and returns nothing.
   /// NOTE(review): presumably emits the same text get_build_info() returns;
   /// the destination (stdout vs. the logger) is not visible here -- confirm.
   TRTORCH_API void dump_build_info();

   /// @param module       TorchScript module to inspect (read-only reference).
   /// @param method_name  Name of the method within `module` to check.
   /// @return bool.
   /// NOTE(review): presumably true when every operator used by the method can
   /// be converted -- confirm the exact meaning in the definition.
   TRTORCH_API bool CheckMethodOperatorSupport(const torch::jit::Module& module, std::string method_name);

   /// @param module  Source TorchScript module (read-only reference).
   /// @param info    Compilation settings, taken by value.
   /// @return A torch::jit::Module by value.
   /// NOTE(review): presumably a new module whose graph runs through a
   /// TensorRT engine; which method(s) get compiled is not visible here -- confirm.
   TRTORCH_API torch::jit::Module CompileGraph(const torch::jit::Module& module, ExtraInfo info);

   /// @param module       Source TorchScript module (read-only reference).
   /// @param method_name  Name of the method within `module` to convert.
   /// @param info         Compilation settings, taken by value.
   /// @return A std::string by value.
   /// NOTE(review): presumably the serialized TensorRT engine bytes rather than
   /// human-readable text -- confirm in the definition.
   TRTORCH_API std::string ConvertGraphToTRTEngine(const torch::jit::Module& module, std::string method_name, ExtraInfo info);

   namespace ptq {
   /// Factory for Int8Calibrator<Algorithm, DataLoader>.
   ///
   /// Algorithm defaults to nvinfer1::IInt8EntropyCalibrator2; DataLoader is
   /// deduced from the first argument.
   ///
   /// @param dataloader       Taken by value and moved into the calibrator.
   /// @param cache_file_path  Forwarded unchanged to the calibrator constructor.
   /// @param use_cache        Forwarded unchanged to the calibrator constructor.
   ///                         NOTE(review): presumably selects reading from
   ///                         cache_file_path -- confirm in ptq.h.
   /// @return The newly constructed calibrator, by value.
   template<typename Algorithm = nvinfer1::IInt8EntropyCalibrator2, typename DataLoader>
   TRTORCH_API inline Int8Calibrator<Algorithm, DataLoader> make_int8_calibrator(DataLoader dataloader, const std::string& cache_file_path, bool use_cache) {
       using CalibratorType = Int8Calibrator<Algorithm, DataLoader>;
       return CalibratorType(std::move(dataloader), cache_file_path, use_cache);
   }

   /// Factory for Int8CacheCalibrator<Algorithm>.
   ///
   /// Algorithm defaults to nvinfer1::IInt8EntropyCalibrator2.
   ///
   /// @param cache_file_path  Forwarded unchanged to the calibrator constructor.
   /// @return The newly constructed cache calibrator, by value.
   template<typename Algorithm = nvinfer1::IInt8EntropyCalibrator2>
   TRTORCH_API inline Int8CacheCalibrator<Algorithm> make_int8_cache_calibrator(const std::string& cache_file_path) {
       using CacheCalibratorType = Int8CacheCalibrator<Algorithm>;
       return CacheCalibratorType(cache_file_path);
   }
   } // namespace ptq
   } // namespace trtorch
