Commit f4219f7

Mutable module improvement (#3394)

Authored by cehongwang, Chengzhe Xu, and peri044
Co-authored-by: Chengzhe Xu <[email protected]>
Co-authored-by: Dheeraj Peri <[email protected]>

1 parent b33f393 · commit f4219f7

File tree: 6 files changed, +515 −66 lines changed

examples/dynamo/README.rst (+1 −1)

@@ -21,4 +21,4 @@ Model Zoo
 * :ref:`_torch_export_gpt2`: Compiling a GPT2 model using AOT workflow (`ir=dynamo`)
 * :ref:`_torch_export_llama2`: Compiling a Llama2 model using AOT workflow (`ir=dynamo`)
 * :ref:`_torch_export_sam2`: Compiling SAM2 model using AOT workflow (`ir=dynamo`)
-* :ref:`_torch_export_flux_dev`: Compiling FLUX.1-dev model using AOT workflow (`ir=dynamo`)
+* :ref:`_torch_export_flux_dev`: Compiling FLUX.1-dev model using AOT workflow (`ir=dynamo`)

examples/dynamo/mutable_torchtrt_module_example.py (+144 −16)
@@ -11,11 +11,13 @@
 The Mutable Torch TensorRT Module is designed to address these challenges, making interaction with the Torch-TensorRT module easier than ever.
 
 In this tutorial, we are going to walk through
-1. Sample workflow of Mutable Torch TensorRT Module with ResNet 18
-2. Save a Mutable Torch TensorRT Module
-3. Integration with Huggingface pipeline in LoRA use case
+1. Sample workflow of Mutable Torch TensorRT Module with ResNet 18
+2. Save a Mutable Torch TensorRT Module
+3. Integration with Huggingface pipeline in LoRA use case
+4. Usage of dynamic shape with Mutable Torch TensorRT Module
 """
 
+# %%
 import numpy as np
 import torch
 import torch_tensorrt as torch_trt
@@ -63,16 +65,14 @@
 # Saving Mutable Torch TensorRT Module
 # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
-# Currently, saving is only enabled for C++ runtime, not python runtime.
+# Currently, saving is only enabled when "use_python_runtime" = False in settings
 torch_trt.MutableTorchTensorRTModule.save(mutable_module, "mutable_module.pkl")
 reload = torch_trt.MutableTorchTensorRTModule.load("mutable_module.pkl")
 
 # %%
 # Stable Diffusion with Huggingface
 # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
-# The LoRA checkpoint is from https://civitai.com/models/12597/moxin
-
 from diffusers import DiffusionPipeline
 
 with torch.no_grad():
@@ -83,33 +83,161 @@
         "immutable_weights": False,
     }
 
-    model_id = "runwayml/stable-diffusion-v1-5"
+    model_id = "stabilityai/stable-diffusion-xl-base-1.0"
     device = "cuda:0"
 
-    prompt = "house in forest, shuimobysim, wuchangshuo, best quality"
-    negative = "(worst quality:2), (low quality:2), (normal quality:2), lowres, normal quality, out of focus, cloudy, (watermark:2),"
+    prompt = "cinematic photo elsa, police uniform <lora:princess_xl_v2:0.8>, . 35mm photograph, film, bokeh, professional, 4k, highly detailed"
+    negative = "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly, nude"
 
-    pipe = DiffusionPipeline.from_pretrained(
-        model_id, revision="fp16", torch_dtype=torch.float16
-    )
+    pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
     pipe.to(device)
 
     # The only extra line you need
     pipe.unet = torch_trt.MutableTorchTensorRTModule(pipe.unet, **settings)
-
-    image = pipe(prompt, negative_prompt=negative, num_inference_steps=30).images[0]
+    BATCH = torch.export.Dim("BATCH", min=2, max=24)
+    _HEIGHT = torch.export.Dim("_HEIGHT", min=16, max=32)
+    _WIDTH = torch.export.Dim("_WIDTH", min=16, max=32)
+    HEIGHT = 4 * _HEIGHT
+    WIDTH = 4 * _WIDTH
+    args_dynamic_shapes = ({0: BATCH, 2: HEIGHT, 3: WIDTH}, {})
+    kwargs_dynamic_shapes = {
+        "encoder_hidden_states": {0: BATCH},
+        "added_cond_kwargs": {
+            "text_embeds": {0: BATCH},
+            "time_ids": {0: BATCH},
+        },
+    }
+    pipe.unet.set_expected_dynamic_shape_range(
+        args_dynamic_shapes, kwargs_dynamic_shapes
+    )
+    image = pipe(
+        prompt,
+        negative_prompt=negative,
+        num_inference_steps=30,
+        height=1024,
+        width=768,
+        num_images_per_prompt=2,
+    ).images[0]
     image.save("./without_LoRA_mutable.jpg")
 
     # Standard Huggingface LoRA loading procedure
     pipe.load_lora_weights(
         "stablediffusionapi/load_lora_embeddings",
-        weight_name="moxin.safetensors",
+        weight_name="all-disney-princess-xl-lo.safetensors",
         adapter_name="lora1",
     )
     pipe.set_adapters(["lora1"], adapter_weights=[1])
     pipe.fuse_lora()
     pipe.unload_lora_weights()
 
     # Refit triggered
-    image = pipe(prompt, negative_prompt=negative, num_inference_steps=30).images[0]
+    image = pipe(
+        prompt,
+        negative_prompt=negative,
+        num_inference_steps=30,
+        height=1024,
+        width=1024,
+        num_images_per_prompt=1,
+    ).images[0]
     image.save("./with_LoRA_mutable.jpg")
+
+
+# %%
+# Use Mutable Torch TensorRT module with dynamic shape
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+# When adding a dynamic shape hint to MutableTorchTensorRTModule, the shape hint should EXACTLY follow the semantics of arg_inputs and kwarg_inputs passed to the forward function
+# and should not omit any entries (except None in the kwarg_inputs). If there is a nested dict/list in the input, the dynamic shape for that entry should also be a nested dict/list.
+# If dynamic shape is not required for an input, an empty dictionary should be given as the shape hint for that input.
+# Note that you should exclude keyword arguments with value None, as those will be filtered out.
+
+
+class Model(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, a, b, c={}):
+        x = torch.matmul(a, b)
+        x = torch.matmul(c["a"], c["b"].T)
+        print(c["b"][0])
+        x = 2 * c["b"]
+        return x
+
+
+device = "cuda:0"
+model = Model().eval().to(device)
+inputs = (torch.rand(10, 3).to(device), torch.rand(3, 30).to(device))
+kwargs = {
+    "c": {"a": torch.rand(10, 30).to(device), "b": torch.rand(10, 30).to(device)},
+}
+dim_0 = torch.export.Dim("dim", min=1, max=50)
+dim_1 = torch.export.Dim("dim", min=1, max=50)
+dim_2 = torch.export.Dim("dim2", min=1, max=50)
+args_dynamic_shapes = ({1: dim_1}, {0: dim_0})
+kwarg_dynamic_shapes = {
+    "c": {
+        "a": {},
+        "b": {0: dim_2},
+    },  # a's shape does not change so we give it an empty dict
+}
+# Export the model first with custom dynamic shape constraints
+model = torch_trt.MutableTorchTensorRTModule(model, debug=True, min_block_size=1)
+model.set_expected_dynamic_shape_range(args_dynamic_shapes, kwarg_dynamic_shapes)
+# Compile
+model(*inputs, **kwargs)
+# Change input shape
+inputs_2 = (torch.rand(10, 5).to(device), torch.rand(10, 30).to(device))
+kwargs_2 = {
+    "c": {"a": torch.rand(10, 30).to(device), "b": torch.rand(5, 30).to(device)},
+}
+# Run without recompiling
+model(*inputs_2, **kwargs_2)
+
+# %%
+# Use Mutable Torch TensorRT module with persistent cache
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+# Leveraging engine caching, we can shortcut engine compilation and save significant time.
+import os
+
+from torch_tensorrt.dynamo._defaults import TIMING_CACHE_PATH
+
+model = models.resnet18(pretrained=True).eval().to("cuda")
+
+times = []
+start = torch.cuda.Event(enable_timing=True)
+end = torch.cuda.Event(enable_timing=True)
+
+example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),)
+model = torch_trt.MutableTorchTensorRTModule(
+    model,
+    use_python_runtime=True,
+    enabled_precisions={torch.float},
+    debug=True,
+    min_block_size=1,
+    immutable_weights=False,
+    cache_built_engines=True,
+    reuse_cached_engines=True,
+    engine_cache_size=1 << 30,  # 1GB
+)
+
+
+def remove_timing_cache(path=TIMING_CACHE_PATH):
+    if os.path.exists(path):
+        os.remove(path)
+
+
+remove_timing_cache()
+
+for i in range(4):
+    inputs = [torch.rand((100 + i, 3, 224, 224)).to("cuda")]
+
+    start.record()
+    model(*inputs)  # Recompile
+    end.record()
+    torch.cuda.synchronize()
+    times.append(start.elapsed_time(end))
+
+print("----------------dynamo_compile----------------")
+print("Without engine caching, used:", times[0], "ms")
+print("With engine caching, used:", times[1], "ms")
+print("With engine caching, used:", times[2], "ms")
+print("With engine caching, used:", times[3], "ms")
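
Note: the updated comment in this file ties saving to the "use_python_runtime" setting. A minimal save/load sketch under that constraint (the ResNet18 model and input shape mirror those used elsewhere in this example; the exact settings are illustrative):

import torch
import torch_tensorrt as torch_trt
import torchvision.models as models

# Saving requires the C++ runtime, i.e. use_python_runtime=False.
model = models.resnet18(pretrained=True).eval().cuda()
mutable_module = torch_trt.MutableTorchTensorRTModule(
    model,
    use_python_runtime=False,  # saving is only enabled in this mode
    immutable_weights=False,
)
mutable_module(torch.rand(1, 3, 224, 224).cuda())  # first call compiles the engine
torch_trt.MutableTorchTensorRTModule.save(mutable_module, "mutable_module.pkl")
reload = torch_trt.MutableTorchTensorRTModule.load("mutable_module.pkl")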

py/torch_tensorrt/dynamo/_refit.py (+6 −3)

@@ -395,9 +395,12 @@ def refit_module_weights(
        try:
            weight_name_map = compiled_submodule.weight_name_map
        except AttributeError:
-            logger.warning(
-                "The module was compiled with an old version of Torch-TensorRT. Rebuilding the weight map."
-            )
+            if not isinstance(
+                compiled_submodule, torch.fx.graph_module.GraphModule
+            ):
+                logger.warning(
+                    "The module was compiled with an old version of Torch-TensorRT. Rebuilding the weight map."
+                )
        if not weight_name_map:
            use_weight_map_cache = False
            logger.warning(
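
This hunk narrows when the "old version" warning fires: submodules that are plain torch.fx GraphModules (for example, graph-break segments that still run in PyTorch) never carry a weight_name_map, so their AttributeError is expected and should not warn. A standalone sketch of the pattern, with a hypothetical helper name:

import logging

import torch

logger = logging.getLogger(__name__)


def get_weight_name_map(compiled_submodule):  # hypothetical helper, for illustration
    try:
        return compiled_submodule.weight_name_map
    except AttributeError:
        # Plain fx.GraphModules legitimately lack a weight map; only warn
        # for modules that should have one but were built by an old release.
        if not isinstance(compiled_submodule, torch.fx.GraphModule):
            logger.warning(
                "The module was compiled with an old version of Torch-TensorRT. "
                "Rebuilding the weight map."
            )
        return None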

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py (+12 −7)

@@ -375,7 +375,10 @@ def _construct_trt_network_def(self) -> None:
 
     @staticmethod
     def find_weight(
-        weight_name: str, np_map: dict[str, Any], state_dict: dict[str, Any]
+        weight_name: str,
+        np_map: dict[str, Any],
+        state_dict: dict[str, Any],
+        device: torch.device,
     ) -> str:
         """
         We need to build map from engine weight name to state_dict weight name.

@@ -385,19 +388,21 @@ def find_weight(
         np_map: the map from weight name to np values in INetworkDefinition
         state_dict: state of the graph module
         """
-        network_weight = torch.from_numpy(np_map[weight_name]).cuda()
+        network_weight = torch.from_numpy(np_map[weight_name]).to(device)
         for sd_w_name, sd_weight in state_dict.items():
-            if TRTInterpreter.check_weight_equal(sd_weight, network_weight):
+            if TRTInterpreter.check_weight_equal(sd_weight, network_weight, device):
                 del state_dict[sd_w_name]
                 return sd_w_name
         return ""
 
     @staticmethod
     def check_weight_equal(
-        sd_weight: torch.tensor, network_weight: Union[torch.Tensor, np.ndarray]
+        sd_weight: torch.tensor,
+        network_weight: Union[torch.Tensor, np.ndarray],
+        device: torch.device,
     ) -> Any:
         if not isinstance(network_weight, torch.Tensor):
-            network_weight = torch.from_numpy(network_weight).cuda()
+            network_weight = torch.from_numpy(network_weight).to(device)
         try:
             return sd_weight.shape == network_weight.shape and torch.all(
                 torch.abs(sd_weight - network_weight) < 0.01

@@ -530,10 +535,10 @@ def _save_weight_mapping(self) -> None:
             # There is no direct connection in batch_norm layer. So skip it
             pass
         elif sd_weight_name not in sd or not TRTInterpreter.check_weight_equal(
-            sd[sd_weight_name], np_map[engine_weight_name]
+            sd[sd_weight_name], np_map[engine_weight_name], torch_device
         ):
             weight_name_map[engine_weight_name] = TRTInterpreter.find_weight(
-                engine_weight_name, np_map, sd
+                engine_weight_name, np_map, sd, torch_device
             )
             if (
                 weight_name_map[engine_weight_name] != ""
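
These hunks thread an explicit torch.device through the weight-comparison helpers in place of the previous hard-coded .cuda(), so refit weight matching also works when the module lives somewhere other than the default CUDA device. A self-contained sketch of the updated comparison (the shape check and 0.01 tolerance mirror the diff; the try/except tail of the original method is omitted here):

from typing import Any, Union

import numpy as np
import torch


def check_weight_equal(
    sd_weight: torch.Tensor,
    network_weight: Union[torch.Tensor, np.ndarray],
    device: torch.device,
) -> Any:
    # Move numpy-backed engine weights onto the caller-supplied device
    # rather than assuming the default CUDA device.
    if not isinstance(network_weight, torch.Tensor):
        network_weight = torch.from_numpy(network_weight).to(device)
    return sd_weight.shape == network_weight.shape and torch.all(
        torch.abs(sd_weight - network_weight) < 0.01
    )


# Works on CPU too, which the old hard-coded .cuda() call would have broken:
print(
    check_weight_equal(
        torch.ones(2, 2), np.ones((2, 2), dtype=np.float32), torch.device("cpu")
    )
)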
