
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "tutorials/_rendered_examples/dynamo/engine_caching_bert_example.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_tutorials__rendered_examples_dynamo_engine_caching_bert_example.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_tutorials__rendered_examples_dynamo_engine_caching_bert_example.py:


.. _engine_caching_bert_example:

Engine Caching (BERT)
=======================

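A small engine caching example with BERT: it compiles ``bert-base-uncased`` with the Torch-TensorRT ``torch.compile`` backend several times and compares the compilation time with and without the engine cache.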

.. GENERATED FROM PYTHON SOURCE LINES 10-76

.. code-block:: python


    import numpy as np
    import torch
    import torch_tensorrt
    from engine_caching_example import remove_timing_cache
    from transformers import BertModel

    # Seed NumPy and PyTorch so the randomly generated inputs are reproducible.
    np.random.seed(0)
    torch.manual_seed(0)

    # Pretrained BERT on the GPU in eval mode; return_dict=False so outputs are plain tuples.
    model = BertModel.from_pretrained("bert-base-uncased", return_dict=False).cuda().eval()

    # Two (1, 14) int32 CUDA tensors of random 0/1 values, passed positionally to the model.
    inputs = [
        torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda") for _ in range(2)
    ]


    def compile_bert(iterations=3):
        """Compile ``model`` with the Torch-TensorRT backend and time each round.

        Iteration 0 compiles with engine caching disabled and serves as the
        baseline. Every later iteration enables both saving built engines to and
        reusing engines from the engine cache directory. Iteration 1 is expected
        to build and store the engine (so it may be slower than the baseline),
        and iterations 2+ are expected to load it from the cache.

        The engine cache directory is emptied once before the loop. Otherwise a
        cache left over from a previous run of this script could turn
        iteration 1 into a cache hit and distort the comparison.

        Args:
            iterations: Number of compile-and-run rounds to time.

        Prints the per-iteration times (compilation plus the first forward pass),
        in milliseconds.
        """
        import shutil

        engine_cache_dir = "/tmp/torch_trt_bert_engine_cache"
        # Start from an empty engine cache so the "build and save" round really
        # builds, even when the script is run more than once.
        shutil.rmtree(engine_cache_dir, ignore_errors=True)

        times = []
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)

        for i in range(iterations):
            # Remove the timing cache and reset dynamo for the engine caching
            # measurement; only the engine cache may carry over between rounds.
            remove_timing_cache()
            torch._dynamo.reset()

            # Round 0 is the no-cache baseline; later rounds use the engine cache.
            use_engine_cache = i > 0
            compilation_kwargs = {
                "use_python_runtime": False,
                "enabled_precisions": {torch.float},
                "truncate_double": True,
                "debug": False,
                "min_block_size": 1,
                "immutable_weights": False,
                "cache_built_engines": use_engine_cache,
                "reuse_cached_engines": use_engine_cache,
                "engine_cache_dir": engine_cache_dir,
                "engine_cache_size": 1 << 30,  # 1GB
            }

            start.record()
            optimized_model = torch.compile(
                model,
                backend="torch_tensorrt",
                options=compilation_kwargs,
            )
            # torch.compile compiles lazily, so the engine is built (or loaded
            # from the cache) during this first call, inside the timed region.
            optimized_model(*inputs)
            end.record()
            torch.cuda.synchronize()
            times.append(start.elapsed_time(end))

        print("-----compile bert-----> compilation time:\n", times, "milliseconds")


    if __name__ == "__main__":
        # Three timed rounds: no cache, cache build-and-save, cache reuse.
        compile_bert(iterations=3)


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes  0.000 seconds)


.. _sphx_glr_download_tutorials__rendered_examples_dynamo_engine_caching_bert_example.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example




    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: engine_caching_bert_example.py <engine_caching_bert_example.py>`

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: engine_caching_bert_example.ipynb <engine_caching_bert_example.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
