🐛 [Bug] Engine caching and loading from cache #4226

@asfiyab-nvidia

Bug Description

Two bugs were observed:

  1. Engine is not cached in the directory supplied to engine_cache_dir
  2. Engine loading from cache fails

To Reproduce

Run the script below:

"""
Usage:
    1. Run on DLFW 26.03 container
    2. Install Torch-TensorRT nightly wheel to get fused_rms_norm support: python -m pip install --pre torch-tensorrt==2.12.0.dev20260420 --index-url https://download.pytorch.org/whl/nightly/cu130 --extra-index-url https://pypi.org/simple --force-reinstall
    3. Apply patch from https://github.com/asfiyab-nvidia/PyTorch-TensorRT/commit/9cc2803dddc749c957332c779d10772116c2f8ff to Torch-TensorRT installation
    4. Force-reinstall TensorRT, since step 2 breaks the tensorrt import: pip3 install --upgrade tensorrt --force-reinstall
    5. Install other dependencies: pip3 install --upgrade diffusers transformers accelerate
    6. python run_flux.py
"""

import time

import torch
import torch_tensorrt as torchtrt
from diffusers import FluxPipeline
from torch_tensorrt.dynamo._engine_cache import DiskEngineCache
from transformers import CLIPTextModel, T5EncoderModel

MODEL_ID = ""


def load_pipeline() -> FluxPipeline:
    print(f"Loading {MODEL_ID} ...")
    t0 = time.time()
    text_encoder = CLIPTextModel.from_pretrained(
        f"{MODEL_ID}/text_encoder", torch_dtype=torch.bfloat16
    )
    text_encoder_2 = T5EncoderModel.from_pretrained(
        f"{MODEL_ID}/text_encoder_2", torch_dtype=torch.bfloat16
    )
    pipe = FluxPipeline.from_pretrained(
        MODEL_ID,
        text_encoder=text_encoder,
        text_encoder_2=text_encoder_2,
        torch_dtype=torch.bfloat16,
        use_safetensors=False,
    ).to("cuda")
    print(f"  loaded in {time.time() - t0:.1f}s")
    return pipe


def compile_transformer(pipe: FluxPipeline) -> torch.export.ExportedProgram:
    print("\nExporting transformer ...")
    transformer = pipe.transformer

    dummy_inputs = {
        "hidden_states": torch.randn(1, 4096, 64, dtype=torch.bfloat16, device="cuda"),
        "encoder_hidden_states": torch.randn(1, 512, 4096, dtype=torch.bfloat16, device="cuda"),
        "pooled_projections": torch.randn(1, 768, dtype=torch.bfloat16, device="cuda"),
        "timestep": torch.randn(1, dtype=torch.bfloat16, device="cuda"),
        "guidance": torch.randn(1, dtype=torch.float32, device="cuda"),
        "img_ids": torch.randn(4096, 3, dtype=torch.bfloat16, device="cuda"),
        "txt_ids": torch.randn(512, 3, dtype=torch.bfloat16, device="cuda"),
    }

    ep_transformer = torch.export.export(
        transformer, (), kwargs=dummy_inputs, strict=True,
    )

    ep_transformer.graph_module.print_readable()

    with torchtrt.dynamo.Debugger(
        log_level="debug",
        logging_dir="logs",
        engine_builder_monitor=False,
        save_layer_info=True,
        save_engine_profile=True,
    ):
        my_cache = DiskEngineCache(
            engine_cache_dir="./engine_cache_dir",
            engine_cache_size=20 * 1024**3,  # 20 GB
        )
        compiled_transformer = torchtrt.dynamo.compile(
            ep_transformer,
            inputs=dummy_inputs,
            truncate_double=True,
            require_full_compilation=True,
            cache_built_engines=True,
            reuse_cached_engines=True,
            immutable_weights=False,
            engine_cache=my_cache,
        )

    return compiled_transformer

def main() -> None:
    pipe = load_pipeline()
    _ = compile_transformer(pipe)


if __name__ == "__main__":
    main()
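
To confirm bug 1 after the first run, one can inspect the cache directory directly. Below is a minimal sketch; the helper name and the assumption that cached engines appear as blob.bin files under ./engine_cache_dir are illustrative, not part of the Torch-TensorRT API:

```python
import os

def list_cached_engines(cache_dir: str) -> list[str]:
    """Return paths of cached engine blobs (*.bin) found under cache_dir."""
    if not os.path.isdir(cache_dir):
        return []
    found = []
    for root, _dirs, files in os.walk(cache_dir):
        for name in files:
            if name.endswith(".bin"):
                found.append(os.path.join(root, name))
    return sorted(found)

if __name__ == "__main__":
    cached = list_cached_engines("./engine_cache_dir")
    print(f"{len(cached)} cached engine blob(s): {cached}")
```

With the bug present, this prints 0 blobs even though compilation succeeded.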

Expected behavior

  1. blob.bin to be saved in ./engine_cache_dir
  2. re-running the script should trigger loading the engine from cache, but instead it fails with the error below:
compiled_transformer = torchtrt.dynamo.compile(
                           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch_tensorrt/dynamo/_compiler.py", line 772, in compile
    trt_gm = compile_module(
             ^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch_tensorrt/dynamo/_compiler.py", line 1118, in compile_module
    trt_module = convert_module(
                 ^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch_tensorrt/dynamo/conversion/_conversion.py", line 348, in convert_module
    serialized_interpreter_result = interpret_module_to_result(
                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch_tensorrt/dynamo/conversion/_conversion.py", line 254, in interpret_module_to_result
    serialized_interpreter_result = pull_cached_engine(
                                    ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch_tensorrt/dynamo/conversion/_conversion.py", line 169, in pull_cached_engine
    _refit_single_trt_engine_with_gm(
  File "/usr/local/lib/python3.12/dist-packages/torch_tensorrt/_features.py", line 177, in wrapper
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch_tensorrt/dynamo/_refit.py", line 209, in _refit_single_trt_engine_with_gm
    assert len(missing_weights) == 0 and len(unset_weights) == 0, (
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: Fast refitting failed due to incomplete mapping (502 missing, 1274 unset)
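
For context on the assertion: fast refit only succeeds when every refittable weight in the cached engine can be matched by name to a weight recovered from the graph module. A lowering pass that renames or fuses weights (as the fused_rms_norm patch may) can break that match. A toy sketch of the name-matching idea follows; all identifiers are made up for illustration and are not the actual Torch-TensorRT mapping logic:

```python
def fast_refit_mapping_gaps(engine_weight_names, module_weight_names):
    """Return engine weight names that have no matching module weight.

    A non-empty result corresponds to the 'missing' count in the
    AssertionError above: refit cannot complete the mapping.
    """
    return sorted(set(engine_weight_names) - set(module_weight_names))

# Hypothetical example: a fusion pass renamed the norm weight, so the
# engine's original name no longer appears in the module's state.
engine_names = {"norm1.weight", "attn.q_proj.weight"}
module_names = {"fused_rms_norm_0.weight", "attn.q_proj.weight"}
print(fast_refit_mapping_gaps(engine_names, module_names))  # ['norm1.weight']
```

In the real failure, 502 engine weights had no mapping and 1274 were never set, so the assertion in _refit.py fired.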

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT Version (e.g. 1.0.0): 2.12.0.dev20260420
  • PyTorch Version (e.g. 1.0):
  • CPU Architecture: x86_64
  • OS (e.g., Linux): Linux
  • How you installed PyTorch (conda, pip, libtorch, source): pip
  • Build command you used (if compiling from source):
  • Are you using local sources or building from archives:
  • Python version: 3.12
  • CUDA version:
  • GPU models and configuration: H100
  • Any other relevant information:

Additional context

Labels

bug (Something isn't working)