"""
Usage:
1. Run on DLFW 26.03 container
2. Install Torch-TensorRT nightly wheel to get fused_rms_norm support: python -m pip install --pre torch-tensorrt==2.12.0.dev20260420 --index-url https://download.pytorch.org/whl/nightly/cu130 --extra-index-url https://pypi.org/simple --force-reinstall
3. Apply patch from https://github.com/asfiyab-nvidia/PyTorch-TensorRT/commit/9cc2803dddc749c957332c779d10772116c2f8ff to Torch-TensorRT installation
4. Force-reinstall TensorRT (step 2 breaks the tensorrt import): pip3 install --upgrade tensorrt --force-reinstall
5. Install the other dependencies: pip3 install --upgrade diffusers transformers accelerate
6. python run_flux.py
"""
import time
import torch
import torch_tensorrt as torchtrt
from diffusers import FluxPipeline
from transformers import CLIPTextModel, T5EncoderModel
from torch_tensorrt.dynamo._engine_cache import DiskEngineCache
MODEL_ID = ""  # set to the local path or Hugging Face repo ID of the FLUX checkpoint
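
# Optional sanity check (sketch, not part of the repro): steps 2 and 4 in the
# usage notes both force-reinstall wheels, so print the versions that actually
# resolved before running.
import tensorrt
print(f"torch {torch.__version__} | torch_tensorrt {torchtrt.__version__} | tensorrt {tensorrt.__version__}")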


def load_pipeline() -> FluxPipeline:
    print(f"Loading {MODEL_ID} ...")
    t0 = time.time()
    text_encoder = CLIPTextModel.from_pretrained(
        f"{MODEL_ID}/text_encoder", torch_dtype=torch.bfloat16
    )
    text_encoder_2 = T5EncoderModel.from_pretrained(
        f"{MODEL_ID}/text_encoder_2", torch_dtype=torch.bfloat16
    )
    pipe = FluxPipeline.from_pretrained(
        MODEL_ID,
        text_encoder=text_encoder,
        text_encoder_2=text_encoder_2,
        torch_dtype=torch.bfloat16,
        use_safetensors=False,
    ).to("cuda")
    print(f" loaded in {time.time() - t0:.1f}s")
    return pipe
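

# Optional eager smoke test (sketch, not part of the repro): confirms the
# checkpoint is usable before export/compile. Prompt, step count, and output
# filename are illustrative.
def smoke_test(pipe: FluxPipeline) -> None:
    image = pipe(
        "a photo of an astronaut riding a horse",
        num_inference_steps=4,
        max_sequence_length=512,
    ).images[0]
    image.save("smoke_test.png")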


def compile_transformer(pipe: FluxPipeline) -> torch.fx.GraphModule:
    print("\nExporting transformer ...")
    transformer = pipe.transformer
    # Static shapes matching FLUX defaults: 4096 packed image tokens and
    # 512 text tokens, batch size 1.
    dummy_inputs = {
        "hidden_states": torch.randn(1, 4096, 64, dtype=torch.bfloat16, device="cuda"),
        "encoder_hidden_states": torch.randn(1, 512, 4096, dtype=torch.bfloat16, device="cuda"),
        "pooled_projections": torch.randn(1, 768, dtype=torch.bfloat16, device="cuda"),
        "timestep": torch.randn(1, dtype=torch.bfloat16, device="cuda"),
        "guidance": torch.randn(1, dtype=torch.float32, device="cuda"),
        "img_ids": torch.randn(4096, 3, dtype=torch.bfloat16, device="cuda"),
        "txt_ids": torch.randn(512, 3, dtype=torch.bfloat16, device="cuda"),
    }
    ep_transformer = torch.export.export(
        transformer, (), kwargs=dummy_inputs, strict=True,
    )
    ep_transformer.graph_module.print_readable()
    with torchtrt.dynamo.Debugger(
        log_level="debug",
        logging_dir="logs",
        engine_builder_monitor=False,
        save_layer_info=True,
        save_engine_profile=True,
    ):
        my_cache = DiskEngineCache(
            engine_cache_dir="./engine_cache_dir",
            engine_cache_size=20 * 1024**3,  # 20 GB
        )
        compiled_transformer = torchtrt.dynamo.compile(
            ep_transformer,
            inputs=dummy_inputs,
            truncate_double=True,
            require_full_compilation=True,
            cache_built_engines=True,
            reuse_cached_engines=True,
            immutable_weights=False,
            engine_cache=my_cache,
        )
    return compiled_transformer
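

# Timing helper (sketch, not part of the repro): measures raw forward latency
# of any module that accepts the dummy kwargs above, e.g. pipe.transformer
# versus the compiled module, to compare eager and TensorRT performance.
def benchmark(module, dummy_inputs, n_iter: int = 10) -> None:
    for _ in range(3):  # warm-up iterations
        module(**dummy_inputs)
    torch.cuda.synchronize()
    t0 = time.time()
    for _ in range(n_iter):
        module(**dummy_inputs)
    torch.cuda.synchronize()
    print(f"avg latency: {(time.time() - t0) / n_iter * 1000:.1f} ms")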


def main() -> None:
    pipe = load_pipeline()
    _ = compile_transformer(pipe)


if __name__ == "__main__":
    main()
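
The compile step fails while a previously cached engine is pulled and refitted (traceback truncated at the top):
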
    compiled_transformer = torchtrt.dynamo.compile(
                           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch_tensorrt/dynamo/_compiler.py", line 772, in compile
    trt_gm = compile_module(
             ^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch_tensorrt/dynamo/_compiler.py", line 1118, in compile_module
    trt_module = convert_module(
                 ^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch_tensorrt/dynamo/conversion/_conversion.py", line 348, in convert_module
    serialized_interpreter_result = interpret_module_to_result(
                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch_tensorrt/dynamo/conversion/_conversion.py", line 254, in interpret_module_to_result
    serialized_interpreter_result = pull_cached_engine(
                                    ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch_tensorrt/dynamo/conversion/_conversion.py", line 169, in pull_cached_engine
    _refit_single_trt_engine_with_gm(
  File "/usr/local/lib/python3.12/dist-packages/torch_tensorrt/_features.py", line 177, in wrapper
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch_tensorrt/dynamo/_refit.py", line 209, in _refit_single_trt_engine_with_gm
    assert len(missing_weights) == 0 and len(unset_weights) == 0, (
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: Fast refitting failed due to incomplete mapping (502 missing, 1274 unset)
Bug Description

2 bugs observed:
1. Fast refitting fails when a cached engine is pulled: `AssertionError: Fast refitting failed due to incomplete mapping (502 missing, 1274 unset)` (traceback above).
2. No engine blob is saved in the configured `engine_cache_dir`.

To Reproduce

Run the script above.

Expected behavior

`blob.bin` to be saved in `./engine_cache_dir`.
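A quick way to verify (sketch; assumes DiskEngineCache writes one `blob.bin` per hashed engine subdirectory):

import os
for root, _, files in os.walk("./engine_cache_dir"):
    for name in files:
        print(os.path.join(root, name))  # expect at least one blob.bin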

Environment

- How you installed PyTorch (`conda`, `pip`, `libtorch`, source): pip

Additional context