
.. _program_listing_file_cpp_include_torch_tensorrt_ptq.h:

Program Listing for File ptq.h
==============================

|exhale_lsh| :ref:`Return to documentation for file <file_cpp_include_torch_tensorrt_ptq.h>` (``cpp/include/torch_tensorrt/ptq.h``)

.. |exhale_lsh| unicode:: U+021B0 .. UPWARDS ARROW WITH TIP LEFTWARDS

.. code-block:: cpp

   /*
    * Copyright (c) NVIDIA Corporation.
    * All rights reserved.
    *
    * This library is licensed under the BSD-style license found in the
    * LICENSE file in the root directory of this source tree.
    */
   #pragma once

   #include <fstream>
   #include <iostream>
   #include <iterator>
   #include <memory>
   #include <sstream>
   #include <string>
   #include <vector>

   #include "NvInfer.h"
   #include "torch/torch.h"
   #include "torch_tensorrt/logging.h"

   #ifndef DOXYGEN_SHOULD_SKIP_THIS
   // Forward declarations of the nvinfer1 calibrator types named in this header.
   // Everything inside this guard is hidden when DOXYGEN_SHOULD_SKIP_THIS is defined.
   namespace nvinfer1 {
   class IInt8Calibrator;
   class IInt8EntropyCalibrator2;
   } // namespace nvinfer1

   namespace torch_tensorrt {
   namespace ptq {
   // Helper used by Int8Calibrator::getBatch, which passes through the binding arrays it
   // received together with the next stored batch tensor and returns this function's result.
   // Only declared here; the definition is not part of this header.
   // NOTE(review): presumably copies `data` into the TensorRT input binding - confirm
   // against the definition.
   bool get_batch_impl(void* bindings[], const char* names[], int nbBindings, torch::Tensor& data);
   }
   } // namespace torch_tensorrt
   #endif // DOXYGEN_SHOULD_SKIP_THIS

   namespace torch_tensorrt {
   namespace ptq {

   template <typename Algorithm, typename DataLoaderUniquePtr>
   class Int8Calibrator : Algorithm {
     using DataLoader = typename DataLoaderUniquePtr::element_type;
     using Batch = typename DataLoader::super::BatchType;

    public:
     Int8Calibrator(DataLoaderUniquePtr dataloader, const std::string& cache_file_path, bool use_cache)
         : dataloader_(dataloader.get()), cache_file_path_(cache_file_path), use_cache_(use_cache) {
       for (auto batch : *dataloader_) {
         batched_data_.push_back(batch.data);
       }
       it_ = batched_data_.begin();
     }

     int getBatchSize() const noexcept override {
       // HACK: Torch-TensorRT only uses explict batch sizing, INT8 Calibrator does not
       // work when reporting the batch size here and having explicity batching.
       // So we just report batch size 1 (warnings will still be printed out).
       return 1;
       // return static_cast<int>(dataloader_->options().batch_size);
     }

     bool getBatch(void* bindings[], const char* names[], int nbBindings) noexcept override {
       if (it_ != batched_data_.end()) {
         auto status = get_batch_impl(bindings, names, nbBindings, *it_);
         it_ = ++it_;
         return status;
       } else {
         // Reset iterator if incase calibrator is going to be used again
         it_ = batched_data_.begin();
         return false;
       }
     }

     const void* readCalibrationCache(size_t& length) noexcept override {
       if (use_cache_) {
         std::stringstream ss;
         ss << "Reading Calibration Cache from " << cache_file_path_;
         logging::log(logging::Level::kINFO, ss.str());

         cache_.clear();
         std::ifstream input(cache_file_path_, std::ios::binary);
         input >> std::noskipws;
         if (input.good()) {
           std::copy(std::istream_iterator<char>(input), std::istream_iterator<char>(), std::back_inserter(cache_));
           logging::log(logging::Level::kDEBUG, "Cache read");
         }
         length = cache_.size();
         return length ? cache_.data() : nullptr;
       }
       return nullptr;
     }

     void writeCalibrationCache(const void* cache, size_t length) noexcept override {
       std::ofstream cache_file(cache_file_path_, std::ios::binary);
       cache_file.write(reinterpret_cast<const char*>(cache), length);
       std::stringstream ss;
       ss << "Saved Calibration Cache to " << cache_file_path_;
       logging::log(logging::Level::kINFO, ss.str());
     }

     operator nvinfer1::IInt8Calibrator*() {
       return reinterpret_cast<nvinfer1::IInt8Calibrator*>(this);
     }

    private:
     DataLoader* dataloader_;
     const std::string& cache_file_path_;
     size_t cache_size_ = 0;
     bool use_cache_;
     std::vector<char> cache_;
     std::vector<torch::Tensor> batched_data_;
     std::vector<torch::Tensor>::iterator it_;
   };

   template <typename Algorithm>
   class Int8CacheCalibrator : Algorithm {
    public:
     Int8CacheCalibrator(const std::string& cache_file_path) : cache_file_path_(cache_file_path) {}

     int getBatchSize() const noexcept override {
       // HACK: Torch-TensorRT only uses explict batch sizing, INT8 Calibrator does not
       // work when reporting the batch size here and having explicity batching.
       // So we just report batch size 1 (warnings will still be printed out).
       return 1;
     }

     bool getBatch(void* bindings[], const char* names[], int nbBindings) noexcept override {
       return false;
     }

     const void* readCalibrationCache(size_t& length) noexcept override {
       std::stringstream ss;
       ss << "Reading Calibration Cache from " << cache_file_path_;
       logging::log(logging::Level::kINFO, ss.str());

       cache_.clear();
       std::ifstream input(cache_file_path_, std::ios::binary);
       input >> std::noskipws;
       if (input.good()) {
         std::copy(std::istream_iterator<char>(input), std::istream_iterator<char>(), std::back_inserter(cache_));
         logging::log(logging::Level::kDEBUG, "Cache read");
       }
       length = cache_.size();
       return length ? cache_.data() : nullptr;
     }

     void writeCalibrationCache(const void* cache, size_t length) noexcept override {
       std::ofstream cache_file(cache_file_path_, std::ios::binary);
       cache_file.write(reinterpret_cast<const char*>(cache), length);
       std::stringstream ss;
       ss << "Saved Calibration Cache to " << cache_file_path_;
       logging::log(logging::Level::kINFO, ss.str());
     }

     operator nvinfer1::IInt8Calibrator*() {
       return reinterpret_cast<nvinfer1::IInt8Calibrator*>(this);
     }

    private:
     const std::string& cache_file_path_;
     size_t cache_size_ = 0;
     std::vector<char> cache_;
   };

   /**
    * @brief Construct an Int8Calibrator for the given dataloader.
    *
    * @tparam Algorithm Calibrator interface the result implements; defaults to
    * nvinfer1::IInt8EntropyCalibrator2.
    * @tparam DataLoader Type of `dataloader`; it is used as Int8Calibrator's
    * DataLoaderUniquePtr parameter.
    * @param dataloader Dataloader handle, moved into the calibrator's constructor.
    * @param cache_file_path Path of the calibration cache file.
    * @param use_cache Whether the calibrator should read the cache file.
    * @return The newly constructed calibrator, by value.
    */
   template <typename Algorithm = nvinfer1::IInt8EntropyCalibrator2, typename DataLoader>
   TORCHTRT_API inline Int8Calibrator<Algorithm, DataLoader> make_int8_calibrator(
       DataLoader dataloader,
       const std::string& cache_file_path,
       bool use_cache) {
     using Calibrator = Int8Calibrator<Algorithm, DataLoader>;
     return Calibrator(std::move(dataloader), cache_file_path, use_cache);
   }

   /**
    * @brief Construct an Int8CacheCalibrator for the given cache file.
    *
    * @tparam Algorithm Calibrator interface the result implements; defaults to
    * nvinfer1::IInt8EntropyCalibrator2.
    * @param cache_file_path Path of the calibration cache file.
    * @return The newly constructed calibrator, by value.
    */
   template <typename Algorithm = nvinfer1::IInt8EntropyCalibrator2>
   TORCHTRT_API inline Int8CacheCalibrator<Algorithm> make_int8_cache_calibrator(const std::string& cache_file_path) {
     using Calibrator = Int8CacheCalibrator<Algorithm>;
     return Calibrator(cache_file_path);
   }

   } // namespace ptq
   } // namespace torch_tensorrt
