
.. _program_listing_file_cpp_api_include_trtorch_ptq.h:

Program Listing for File ptq.h
==============================

|exhale_lsh| :ref:`Return to documentation for file <file_cpp_api_include_trtorch_ptq.h>` (``cpp/api/include/trtorch/ptq.h``)

.. |exhale_lsh| unicode:: U+021B0 .. UPWARDS ARROW WITH TIP LEFTWARDS

.. code-block:: cpp

   #pragma once

   #include <algorithm>
   #include <fstream>
   #include <iostream>
   #include <iterator>
   #include <memory>
   #include <sstream>
   #include <string>
   #include <utility>
   #include <vector>

   #ifndef DOXYGEN_SHOULD_SKIP_THIS
   // Forward declarations of the nvinfer1 calibrator interface classes that are
   // named by the conversion operators of the calibrators declared below.
   namespace nvinfer1 {
   class IInt8Calibrator;
   class IInt8EntropyCalibrator2;
   }

   // Forward declaration of the iterator template used for the batch iterator
   // member of Int8Calibrator.
   namespace torch {
   namespace data {
   template<typename Example>
   class Iterator;
   }
   }
   #endif //DOXYGEN_SHOULD_SKIP_THIS

   namespace trtorch {
   namespace ptq {

   /**
    * @brief INT8 calibrator that feeds batches from a data loader and can
    * read/write a calibration cache file.
    *
    * The calibrator takes ownership of the data loader handed to the
    * constructor and keeps it alive for its own lifetime, so the iterator it
    * holds never outlives the loader. As a consequence the calibrator is only
    * copyable when DataLoaderUniquePtr itself is copyable.
    *
    * @tparam Algorithm Calibration interface to implement (used as the base
    *         class; its virtual functions are overridden here).
    * @tparam DataLoaderUniquePtr Smart pointer type owning the data loader. It
    *         must expose `element_type`, and the pointee must provide
    *         `begin()`, `end()` and `super::BatchType`.
    */
   template<typename Algorithm, typename DataLoaderUniquePtr>
   class Int8Calibrator : Algorithm {
       using DataLoader = typename DataLoaderUniquePtr::element_type;
       using Batch = typename DataLoader::super::BatchType;
       // Declared type of a batch's `data` field; used to keep the staged
       // device copy alive between getBatch() calls.
       using BatchData = decltype(Batch::data);
   public:
       /**
        * @brief Construct a calibrator around a data loader.
        *
        * @param dataloader Owning pointer to the data loader; ownership is
        *        transferred to the calibrator.
        * @param cache_file_path Path of the calibration cache file (copied).
        * @param use_cache When true, readCalibrationCache() tries to load the
        *        cache from cache_file_path; when false it reports no cache.
        */
       Int8Calibrator(DataLoaderUniquePtr dataloader, const std::string& cache_file_path, bool use_cache)
         : dataloader_(std::move(dataloader)), it_(dataloader_->end()), cache_file_path_(cache_file_path), use_cache_(use_cache) {}

       /**
        * @brief Report the calibration batch size.
        * @return Always 1.
        */
       int getBatchSize() const override {
           // HACK: TRTorch only uses explicit batch sizing, INT8 Calibrator does not
           // work when reporting the batch size here and having explicit batching.
           // So we just report batch size 1 (warnings will still be printed out).
           return 1;
           //return static_cast<int>(dataloader_->options().batch_size);
       }

       /**
        * @brief Provide the next calibration batch.
        *
        * The batch data is moved to the CUDA device and made contiguous once,
        * and every binding slot receives the pointer to that buffer. The
        * buffer is held in a member so the pointers stay valid until the next
        * call to getBatch() (or destruction of the calibrator).
        *
        * @param bindings Output array receiving one device pointer per binding.
        * @param names Binding names (unused).
        * @param nbBindings Number of entries in bindings.
        * @return true if a batch was provided, false once the data loader is
        *         exhausted.
        */
       bool getBatch(void* bindings[], const char* names[], int nbBindings) override {
           // HACK: doesn't seem like the first try in the initializer list works,
           // so the iterator is (re)created lazily on the first request.
           if (!it_created_) {
               it_ = dataloader_->begin();
               it_created_ = true;
           }

           if (it_ == dataloader_->end()) {
               return false;
           }

           // Stage the batch on the device once and keep the result alive in
           // a member: the caller reads the pointers after this function
           // returns, so a function-local tensor would be released too early.
           staged_data_ = (*it_).data.to(at::kCUDA).contiguous();
           for (int i = 0; i < nbBindings; i++) {
               bindings[i] = staged_data_.data_ptr();
           }

           ++it_;
           return true;
       }

       /**
        * @brief Load the calibration cache from disk, if caching is enabled.
        *
        * @param length Set to the number of bytes in the returned cache, or 0
        *        when no cache is returned.
        * @return Pointer to the cached bytes (owned by this object), or
        *         nullptr when caching is disabled or the file is missing/empty.
        */
       const void* readCalibrationCache(size_t& length) override {
           length = 0;
           if (!use_cache_) {
               return nullptr;
           }

           std::stringstream ss;
           ss << "Reading Calibration Cache from " << cache_file_path_;
           logging::log(logging::Level::kINFO, ss.str());

           cache_.clear();
           std::ifstream cache_file(cache_file_path_, std::ios::binary);
           if (cache_file.good()) {
               // istreambuf_iterator reads raw bytes, so no whitespace
               // skipping flags need to be adjusted.
               cache_.assign(std::istreambuf_iterator<char>(cache_file),
                             std::istreambuf_iterator<char>());
               logging::log(logging::Level::kDEBUG, std::string("Cache read"));
           }

           cache_size_ = cache_.size();
           length = cache_size_;
           return cache_size_ ? cache_.data() : nullptr;
       }

       /**
        * @brief Write the calibration cache to cache_file_path.
        *
        * @param cache Pointer to the cache bytes to store.
        * @param length Number of bytes to write.
        */
       void writeCalibrationCache(const void* cache, size_t length) override {
           std::ofstream cache_file(cache_file_path_, std::ios::binary);
           cache_file.write(reinterpret_cast<const char*>(cache), length);
           std::stringstream ss;
           ss << "Saved Calibration Cache to " << cache_file_path_;
           logging::log(logging::Level::kINFO, ss.str());
       }

       /**
        * @brief Convert this calibrator to a pointer to the calibrator interface.
        *
        * NOTE(review): the reinterpret_cast assumes Algorithm derives from
        * nvinfer1::IInt8Calibrator with that base at offset zero - confirm.
        */
       operator nvinfer1::IInt8Calibrator* () {
           return reinterpret_cast<nvinfer1::IInt8Calibrator*>(this);
       }

   private:
       // Owns the data loader; declared before it_ so the iterator is
       // destroyed first.
       DataLoaderUniquePtr dataloader_;
       torch::data::Iterator<Batch> it_;
       // Stored by value so a temporary passed to the constructor cannot dangle.
       const std::string cache_file_path_;
       size_t cache_size_ = 0;
       bool use_cache_;
       std::vector<char> cache_;
       bool it_created_ = false;
       // Device-side copy of the most recent batch handed out by getBatch().
       BatchData staged_data_;
   };

   /**
    * @brief INT8 calibrator that only serves a previously saved calibration
    * cache and never provides calibration batches.
    *
    * @tparam Algorithm Calibration interface to implement (used as the base
    *         class; its virtual functions are overridden here).
    */
   template<typename Algorithm>
   class Int8CacheCalibrator : Algorithm {
   public:
       /**
        * @brief Construct a cache-only calibrator.
        * @param cache_file_path Path of the calibration cache file (copied).
        */
       Int8CacheCalibrator(const std::string& cache_file_path)
         : cache_file_path_(cache_file_path) {}

       /**
        * @brief Report the calibration batch size.
        * @return Always 1.
        */
       int getBatchSize() const override {
           // HACK: TRTorch only uses explicit batch sizing, INT8 Calibrator does not
           // work when reporting the batch size here and having explicit batching.
           // So we just report batch size 1 (warnings will still be printed out).
           return 1;
       }

       /**
        * @brief This calibrator has no data source, so no batch is ever provided.
        * @return Always false.
        */
       bool getBatch(void* bindings[], const char* names[], int nbBindings) override {
           return false;
       }

       /**
        * @brief Load the calibration cache from cache_file_path.
        *
        * @param length Set to the number of bytes in the returned cache, or 0
        *        when no cache is returned.
        * @return Pointer to the cached bytes (owned by this object), or
        *         nullptr when the file cannot be opened or is empty.
        */
       const void* readCalibrationCache(size_t& length) override {
           length = 0;

           std::stringstream ss;
           ss << "Reading Calibration Cache from " << cache_file_path_;
           logging::log(logging::Level::kINFO, ss.str());

           cache_.clear();
           std::ifstream cache_file(cache_file_path_, std::ios::in | std::ios::binary);
           if (cache_file.good()) {
               // Probe the file size by seeking to the END; only reserve when
               // the probe succeeded (tellg() yields -1 on failure, which must
               // never reach reserve()).
               cache_file.seekg(0, std::ios::end);
               const std::streamoff file_size = cache_file.tellg();
               cache_file.seekg(0, std::ios::beg);
               if (file_size > 0) {
                   cache_.reserve(static_cast<size_t>(file_size));
               }

               logging::log(logging::Level::kDEBUG, std::string("Trying to read cache"));
               cache_.assign(std::istreambuf_iterator<char>(cache_file),
                             std::istreambuf_iterator<char>());
               logging::log(logging::Level::kDEBUG, std::string("Cache read"));
           }

           cache_size_ = cache_.size();
           length = cache_size_;
           return cache_size_ ? cache_.data() : nullptr;
       }


       /**
        * @brief Write the calibration cache to cache_file_path.
        *
        * @param cache Pointer to the cache bytes to store.
        * @param length Number of bytes to write.
        */
       void writeCalibrationCache(const void* cache, size_t length) override {
           std::ofstream cache_file(cache_file_path_, std::ios::binary);
           cache_file.write(reinterpret_cast<const char*>(cache), length);
           std::stringstream ss;
           ss << "Saved Calibration Cache to " << cache_file_path_;
           logging::log(logging::Level::kINFO, ss.str());
       }

       /**
        * @brief Convert this calibrator to a pointer to the calibrator interface.
        *
        * NOTE(review): the reinterpret_cast assumes Algorithm derives from
        * nvinfer1::IInt8Calibrator with that base at offset zero - confirm.
        */
       operator nvinfer1::IInt8Calibrator* () {
           return reinterpret_cast<nvinfer1::IInt8Calibrator*>(this);
       }

   private:
       // Stored by value so a temporary passed to the constructor cannot dangle.
       const std::string cache_file_path_;
       size_t cache_size_ = 0;
       std::vector<char> cache_;
   };

   } // namespace ptq
   } // namespace trtorch
