|
11 | 11 | The Mutable Torch TensorRT Module is designed to address these challenges, making interaction with the Torch-TensorRT module easier than ever.
|
12 | 12 |
|
13 | 13 | In this tutorial, we are going to walk through
|
14 |
| -1. Sample workflow of Mutable Torch TensorRT Module with ResNet 18 |
15 |
| -2. Save a Mutable Torch TensorRT Module |
16 |
| -3. Integration with Huggingface pipeline in LoRA use case |
| 14 | + 1. Sample workflow of Mutable Torch TensorRT Module with ResNet 18 |
| 15 | + 2. Save a Mutable Torch TensorRT Module |
| 16 | + 3. Integration with Huggingface pipeline in LoRA use case |
| 17 | + 4. Usage of dynamic shape with Mutable Torch TensorRT Module |
17 | 18 | """
|
18 | 19 |
|
| 20 | +# %% |
19 | 21 | import numpy as np
|
20 | 22 | import torch
|
21 | 23 | import torch_tensorrt as torch_trt
|
|
63 | 65 | # Saving Mutable Torch TensorRT Module
|
64 | 66 | # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
65 | 67 |
|
66 |
| -# Currently, saving is only enabled for C++ runtime, not python runtime. |
| 68 | +# Currently, saving is only supported when "use_python_runtime" is set to False in the settings. |
67 | 69 | torch_trt.MutableTorchTensorRTModule.save(mutable_module, "mutable_module.pkl")
|
68 | 70 | reload = torch_trt.MutableTorchTensorRTModule.load("mutable_module.pkl")
|
69 | 71 |
|
70 | 72 | # %%
|
71 | 73 | # Stable Diffusion with Huggingface
|
72 | 74 | # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
73 | 75 |
|
74 |
| -# The LoRA checkpoint is from https://civitai.com/models/12597/moxin |
75 |
| - |
76 | 76 | from diffusers import DiffusionPipeline
|
77 | 77 |
|
78 | 78 | with torch.no_grad():
|
|
83 | 83 | "immutable_weights": False,
|
84 | 84 | }
|
85 | 85 |
|
86 |
| - model_id = "runwayml/stable-diffusion-v1-5" |
| 86 | + model_id = "stabilityai/stable-diffusion-xl-base-1.0" |
87 | 87 | device = "cuda:0"
|
88 | 88 |
|
89 |
| - prompt = "house in forest, shuimobysim, wuchangshuo, best quality" |
90 |
| - negative = "(worst quality:2), (low quality:2), (normal quality:2), lowres, normal quality, out of focus, cloudy, (watermark:2)," |
| 89 | + prompt = "cinematic photo elsa, police uniform <lora:princess_xl_v2:0.8>, . 35mm photograph, film, bokeh, professional, 4k, highly detailed" |
| 90 | + negative = "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly, nude" |
91 | 91 |
|
92 |
| - pipe = DiffusionPipeline.from_pretrained( |
93 |
| - model_id, revision="fp16", torch_dtype=torch.float16 |
94 |
| - ) |
| 92 | + pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16) |
95 | 93 | pipe.to(device)
|
96 | 94 |
|
97 | 95 | # The only extra line you need
|
98 | 96 | pipe.unet = torch_trt.MutableTorchTensorRTModule(pipe.unet, **settings)
|
99 |
| - |
100 |
| - image = pipe(prompt, negative_prompt=negative, num_inference_steps=30).images[0] |
| 97 | + BATCH = torch.export.Dim("BATCH", min=2, max=24) |
| 98 | + _HEIGHT = torch.export.Dim("_HEIGHT", min=16, max=32) |
| 99 | + _WIDTH = torch.export.Dim("_WIDTH", min=16, max=32) |
| 100 | + HEIGHT = 4 * _HEIGHT |
| 101 | + WIDTH = 4 * _WIDTH |
| 102 | + args_dynamic_shapes = ({0: BATCH, 2: HEIGHT, 3: WIDTH}, {}) |
| 103 | + kwargs_dynamic_shapes = { |
| 104 | + "encoder_hidden_states": {0: BATCH}, |
| 105 | + "added_cond_kwargs": { |
| 106 | + "text_embeds": {0: BATCH}, |
| 107 | + "time_ids": {0: BATCH}, |
| 108 | + }, |
| 109 | + } |
| 110 | + pipe.unet.set_expected_dynamic_shape_range( |
| 111 | + args_dynamic_shapes, kwargs_dynamic_shapes |
| 112 | + ) |
| 113 | + image = pipe( |
| 114 | + prompt, |
| 115 | + negative_prompt=negative, |
| 116 | + num_inference_steps=30, |
| 117 | + height=1024, |
| 118 | + width=768, |
| 119 | + num_images_per_prompt=2, |
| 120 | + ).images[0] |
101 | 121 | image.save("./without_LoRA_mutable.jpg")
|
102 | 122 |
|
103 | 123 | # Standard Huggingface LoRA loading procedure
|
104 | 124 | pipe.load_lora_weights(
|
105 | 125 | "stablediffusionapi/load_lora_embeddings",
|
106 |
| - weight_name="moxin.safetensors", |
| 126 | + weight_name="all-disney-princess-xl-lo.safetensors", |
107 | 127 | adapter_name="lora1",
|
108 | 128 | )
|
109 | 129 | pipe.set_adapters(["lora1"], adapter_weights=[1])
|
110 | 130 | pipe.fuse_lora()
|
111 | 131 | pipe.unload_lora_weights()
|
112 | 132 |
|
113 | 133 | # Refit triggered
|
114 |
| - image = pipe(prompt, negative_prompt=negative, num_inference_steps=30).images[0] |
| 134 | + image = pipe( |
| 135 | + prompt, |
| 136 | + negative_prompt=negative, |
| 137 | + num_inference_steps=30, |
| 138 | + height=1024, |
| 139 | + width=1024, |
| 140 | + num_images_per_prompt=1, |
| 141 | + ).images[0] |
115 | 142 | image.save("./with_LoRA_mutable.jpg")
|
| 143 | + |
| 144 | + |
| 145 | +# %% |
| 146 | +# Use Mutable Torch TensorRT module with dynamic shape |
| 147 | +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ |
| 148 | +# When adding a dynamic shape hint to MutableTorchTensorRTModule, the shape hint should EXACTLY follow the semantics of arg_inputs and kwarg_inputs passed to the forward function |
| 149 | +# and should not omit any entries (except None in the kwarg_inputs). If there is a nested dict/list in the input, the dynamic shape for that entry should also be a nested dict/list. |
| 150 | +# If the dynamic shape is not required for an input, an empty dictionary should be given as the shape hint for that input. |
| 151 | +# Note that you should exclude keyword arguments with value None as those will be filtered out. |
| 152 | + |
| 153 | + |
class Model(torch.nn.Module):
    """Toy module that takes two positional tensors and a nested dict of tensors.

    It exists to demonstrate dynamic shape hints for inputs that mix
    positional arguments with a nested keyword argument.
    """

    def __init__(self):
        super().__init__()

    def forward(self, a, b, c=None):
        """Return ``2 * c["b"]``.

        Args:
            a: Left operand of a matrix product with ``b``.
            b: Right operand of a matrix product with ``a``.
            c: Dict holding tensors under the keys ``"a"`` and ``"b"``.
                Defaults to an empty dict, in which case the key lookups
                below raise ``KeyError``.

        Returns:
            The tensor ``c["b"]`` multiplied by 2.
        """
        # Use a None sentinel instead of a shared mutable default dict.
        c = {} if c is None else c
        # The two matmul results are overwritten below and never returned.
        x = torch.matmul(a, b)
        x = torch.matmul(c["a"], c["b"].T)
        print(c["b"][0])
        x = 2 * c["b"]
        return x
| 164 | + |
| 165 | + |
device = "cuda:0"
model = Model().eval().to(device)
inputs = (torch.rand(10, 3).to(device), torch.rand(3, 30).to(device))
kwargs = {
    "c": {"a": torch.rand(10, 30).to(device), "b": torch.rand(10, 30).to(device)},
}
# dim_0 and dim_1 are both named "dim"; they are applied to dimension 1 of `a`
# and dimension 0 of `b`, the two sizes that must agree in torch.matmul(a, b).
dim_0 = torch.export.Dim("dim", min=1, max=50)
dim_1 = torch.export.Dim("dim", min=1, max=50)
dim_2 = torch.export.Dim("dim2", min=1, max=50)
# One shape hint per positional input, in the same order as `inputs`.
args_dynamic_shapes = ({1: dim_1}, {0: dim_0})
kwarg_dynamic_shapes = {
    "c": {
        "a": {},
        "b": {0: dim_2},
    },  # a's shape does not change so we give it an empty dict
}
# Export the model first with custom dynamic shape constraints
model = torch_trt.MutableTorchTensorRTModule(model, debug=True, min_block_size=1)
model.set_expected_dynamic_shape_range(args_dynamic_shapes, kwarg_dynamic_shapes)
# Compile
model(*inputs, **kwargs)
# Change input shape. The second positional input is (5, 30) so that its
# dimension 0 matches dimension 1 of the first input, as matmul(a, b) and the
# shared "dim" hint both require.
inputs_2 = (torch.rand(10, 5).to(device), torch.rand(5, 30).to(device))
kwargs_2 = {
    "c": {"a": torch.rand(10, 30).to(device), "b": torch.rand(5, 30).to(device)},
}
# Run without recompiling
model(*inputs_2, **kwargs_2)
| 194 | + |
| 195 | +# %% |
| 196 | +# Use Mutable Torch TensorRT module with persistent cache |
| 197 | +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ |
| 198 | +# By leveraging engine caching, we can skip engine compilation and save a significant amount of time. |
| 199 | +import os |
| 200 | + |
| 201 | +from torch_tensorrt.dynamo._defaults import TIMING_CACHE_PATH |
| 202 | + |
model = models.resnet18(pretrained=True).eval().to("cuda")

# A pair of CUDA events used to time each call, and the list that collects
# the elapsed time of every run.
start, end = (torch.cuda.Event(enable_timing=True) for _ in range(2))
times = []

example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),)
model = torch_trt.MutableTorchTensorRTModule(
    model,
    # Compilation settings
    min_block_size=1,
    enabled_precisions={torch.float},
    use_python_runtime=True,
    immutable_weights=False,
    debug=True,
    # Engine caching settings
    cache_built_engines=True,
    reuse_cached_engines=True,
    engine_cache_size=2**30,  # 1GB
)
| 221 | + |
| 222 | + |
def remove_timing_cache(path=TIMING_CACHE_PATH):
    """Delete the timing cache file at ``path`` if it exists.

    Args:
        path: Location of the timing cache file. Defaults to
            ``TIMING_CACHE_PATH``.
    """
    # Attempt the removal directly rather than checking for existence first,
    # so a file that disappears between the check and the removal cannot
    # cause a failure.
    try:
        os.remove(path)
    except FileNotFoundError:
        pass
| 226 | + |
| 227 | + |
# Remove any existing timing cache file before measuring.
remove_timing_cache()

for i in range(4):
    # The batch size differs on every iteration (100, 101, 102, 103).
    inputs = [torch.rand((100 + i, 3, 224, 224)).to("cuda")]

    start.record()
    model(*inputs)  # Recompile
    end.record()
    # Wait for the recorded events to complete before reading the elapsed time.
    torch.cuda.synchronize()
    times.append(start.elapsed_time(end))

print("----------------dynamo_compile----------------")
print("Without engine caching, used:", times[0], "ms")
# Runs 1 to 3 are reported under the "with engine caching" label.
for i in range(1, 4):
    print("With engine caching used:", times[i], "ms")
0 commit comments