
Commit 7ab637e

Authored by peri044 (Dheeraj Peri) and Dheeraj Peri
fix: Split addmm nodes to not cast bias for FP32 accumulation and flux example fixes. (#3395)
Co-authored-by: Dheeraj Peri <[email protected]>
1 parent a3db469 commit 7ab637e

File tree

7 files changed, +112 -12 lines changed

examples/dynamo/torch_export_flux_dev.py (+12 -6)
@@ -9,11 +9,11 @@
 
 **FLUX.1 [dev]** is a 12 billion parameter rectified flow transformer capable of generating images from text descriptions. It is an open-weight, guidance-distilled model for non-commercial applications.
 
-Install the following dependencies before compilation
+To run this demo, you need to have access to Flux model (request for access if you do not have it already on the `FLUX.1-dev <https://huggingface.co/black-forest-labs/FLUX.1-dev>`_ page) and install the following dependencies
 
 .. code-block:: python
 
-    pip install sentencepiece=="0.2.0" transformers=="4.48.2" accelerate=="1.3.0" diffusers=="0.32.2"
+    pip install sentencepiece=="0.2.0" transformers=="4.48.2" accelerate=="1.3.0" diffusers=="0.32.2" protobuf=="5.29.3"
 
 There are different components of the ``FLUX.1-dev`` pipeline such as ``transformer``, ``vae``, ``text_encoder``, ``tokenizer`` and ``scheduler``. In this example,
 we demonstrate optimizing the ``transformer`` component of the model (which typically consumes >95% of the e2e diffusion latency)
@@ -38,11 +38,10 @@
     "black-forest-labs/FLUX.1-dev",
     torch_dtype=torch.float16,
 )
-pipe.to(DEVICE).to(torch.float16)
+
 # Store the config and transformer backbone
 config = pipe.transformer.config
-backbone = pipe.transformer
-
+backbone = pipe.transformer.to(DEVICE)
 
 # %%
 # Export the backbone using torch.export
@@ -63,6 +62,8 @@
     "txt_ids": {0: SEQ_LEN},
     "img_ids": {0: IMG_ID},
     "guidance": {0: BATCH},
+    "joint_attention_kwargs": {},
+    "return_dict": None,
 }
 # The guidance factor is of type torch.float32
 dummy_inputs = {
@@ -79,6 +80,8 @@
     "txt_ids": torch.randn((512, 3), dtype=torch.float16).to(DEVICE),
     "img_ids": torch.randn((4096, 3), dtype=torch.float16).to(DEVICE),
     "guidance": torch.tensor([1.0, 1.0], dtype=torch.float32).to(DEVICE),
+    "joint_attention_kwargs": {},
+    "return_dict": False,
 }
 # This will create an exported program which is going to be compiled with Torch-TensorRT
 ep = _export(
@@ -116,8 +119,11 @@
 # ---------------------------
 # Release the GPU memory occupied by the exported program and the pipe.transformer
 # Set the transformer in the Flux pipeline to the Torch-TRT compiled model
-backbone.to("cpu")
+
 del ep
+backbone.to("cpu")
+pipe.to(DEVICE)
+torch.cuda.empty_cache()
 pipe.transformer = trt_gm
 pipe.transformer.config = config
 
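For readers following the example, a minimal usage sketch (not part of this diff) of how the pipeline is typically invoked once pipe.transformer has been replaced by the Torch-TRT compiled module; the prompt, step count, and output filename below are illustrative placeholders.

    # Illustrative only: run the FLUX pipeline with the compiled transformer swapped in.
    image = pipe(
        "A golden retriever holding a sign that says hello",  # hypothetical prompt
        guidance_scale=3.5,
        num_inference_steps=30,
    ).images[0]
    image.save("flux_trt_output.png")  # hypothetical output path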

py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py (+2 -2)
@@ -10,7 +10,7 @@
 from .fuse_prims_broadcast import fuse_prims_broadcast
 from .lower_scaled_dot_product_attention import lower_scaled_dot_product_attention
 from .pass_manager import DynamoPassManager
-from .remove_assert_scalar import remove_assert_scalar
+from .remove_assert_nodes import remove_assert_nodes
 from .remove_detach import remove_detach
 from .remove_input_alias_fixing_clones import remove_input_alias_fixing_clones
 from .repair_input_as_output import repair_input_as_output
@@ -27,7 +27,7 @@
         replace_max_pool_with_indices,
         lower_scaled_dot_product_attention,
         view_to_reshape,
-        remove_assert_scalar,
+        remove_assert_nodes,
         accumulate_fp32_matmul,
     ]
 )

py/torch_tensorrt/dynamo/lowering/passes/accumulate_fp32_matmul.py (+39 -2)
@@ -9,17 +9,54 @@
 logger = logging.getLogger(__name__)
 
 
+def split_addmm_nodes(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+    target = torch.ops.aten.addmm.default
+    addmm_nodes = [node for node in gm.graph.nodes if node.target == target]
+    for addmm_node in addmm_nodes:
+        bias, mat1, mat2 = addmm_node.all_input_nodes
+        beta = addmm_node.kwargs.get("beta")
+        alpha = addmm_node.kwargs.get("alpha")
+
+        with gm.graph.inserting_before(addmm_node):
+            mm_node = gm.graph.call_function(
+                torch.ops.aten.mm.default,
+                args=(mat1, mat2),
+            )
+            if alpha:
+                mm_node = gm.graph.call_function(
+                    torch.ops.aten.mul.Tensor,
+                    args=(mm_node, alpha),
+                )
+
+            if beta:
+                bias = gm.graph.call_function(
+                    torch.ops.aten.mul.Tensor,
+                    args=(bias, beta),
+                )
+            add_node = gm.graph.call_function(
+                torch.ops.aten.add.Tensor,
+                args=(bias, mm_node),
+            )
+
+        addmm_node.replace_all_uses_with(add_node, propagate_meta=True)
+        gm.graph.erase_node(addmm_node)
+
+    return gm
+
+
 def accumulate_fp32_matmul(
     gm: torch.fx.GraphModule, settings: CompilationSettings
 ) -> torch.fx.GraphModule:
-    """Replace a matmul layer with fp32 accumulation nodes"""
+    """Add cast to FP32/16 nodes around a matmul layer. This pattern is detected by TensorRT and will enable FP32 accumulation during execution."""
     if settings.use_fp32_acc:
         matmul_targets = [
             torch.ops.aten.mm.default,
             torch.ops.aten.bmm.default,
-            torch.ops.aten.addmm.default,
         ]
 
+        # Split torch.addmm nodes into add + mm and only add cast nodes around mm nodes
+        split_addmm_nodes(gm)
+
         matmul_nodes = [
             node for node in gm.graph.nodes if node.target in matmul_targets
         ]
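
The split implemented above follows the standard addmm identity, out = beta * bias + alpha * (mat1 @ mat2), so the FP32-accumulation casts added later in this pass wrap only the mm node and never the bias. A small standalone sketch (not part of the commit) checking that the decomposed form matches the fused op numerically:

    import torch

    bias = torch.rand(3, 5)
    mat1 = torch.rand(3, 4)
    mat2 = torch.rand(4, 5)
    beta, alpha = 20, 2

    # Fused form: what a single aten.addmm node computes
    fused = torch.addmm(bias, mat1, mat2, beta=beta, alpha=alpha)

    # Split form produced by split_addmm_nodes: mm scaled by alpha, bias scaled by beta, then add
    split = beta * bias + alpha * torch.mm(mat1, mat2)

    assert torch.allclose(fused, split, atol=1e-5)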

py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py (+2)
@@ -51,6 +51,8 @@ def constant_fold(
         gm.graph.erase_node(node)
 
     gm = clean_up_graph_after_modifications(gm)
+    # Delete the constant folder instance which holds GPU memory
+    del cf
 
     logger.debug(f"Graph after constant folding:\n{gm.graph}")
 
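The reason for the explicit del: per the comment in the diff, the constant folder instance (cf) keeps references to tensors on the GPU, and that memory is only reclaimed once the object is dropped. A self-contained illustration of the pattern (not from the commit, requires a CUDA device):

    import torch

    # Illustrative only: memory tracked by the caching allocator is freed when the
    # owning object is deleted, and empty_cache() then returns cached blocks to the
    # driver (what nvidia-smi reports).
    class Holder:
        def __init__(self) -> None:
            self.weights = torch.rand(1024, 1024, device="cuda")

    holder = Holder()
    print(torch.cuda.memory_allocated())  # includes holder.weights (~4 MB)

    del holder                            # analogous to `del cf` above
    torch.cuda.empty_cache()
    print(torch.cuda.memory_allocated())  # holder.weights no longer counted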

py/torch_tensorrt/dynamo/lowering/passes/remove_assert_scalar.py → py/torch_tensorrt/dynamo/lowering/passes/remove_assert_nodes.py (renamed, +2 -2)
@@ -9,15 +9,15 @@
 logger = logging.getLogger(__name__)
 
 
-def remove_assert_scalar(
+def remove_assert_nodes(
     gm: torch.fx.GraphModule, settings: CompilationSettings
 ) -> torch.fx.GraphModule:
     """Remove assert_scalar ops in the graph"""
     count = 0
     for node in gm.graph.nodes:
         if (
             node.target == torch.ops.aten._assert_scalar.default
-            or node == torch.ops.aten._assert_tensor_metadata.default
+            or node.target == torch.ops.aten._assert_tensor_metadata.default
         ):
             gm.graph.erase_node(node)
             count += 1
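
The second hunk fixes a comparison bug: FX nodes are matched by their .target attribute, so comparing the node object itself against an op overload can never be true, and _assert_tensor_metadata nodes were silently kept. A minimal sketch of the distinction (not taken from the commit), using aten.add.Tensor as a stand-in op because it is easy to trace:

    import torch

    def fn(x):
        return torch.ops.aten.add.Tensor(x, 1)

    gm = torch.fx.symbolic_trace(fn)

    # Comparing the node itself to the overload never matches (the old, buggy check) ...
    assert not any(n == torch.ops.aten.add.Tensor for n in gm.graph.nodes)
    # ... while comparing node.target does, which is the check the fix uses.
    assert any(n.target == torch.ops.aten.add.Tensor for n in gm.graph.nodes)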

py/torch_tensorrt/dynamo/utils.py (+13)
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import gc
 import logging
 import warnings
 from dataclasses import fields, replace
@@ -30,6 +31,7 @@
 DYNAMIC_DIM = -1
 RTOL = 5e-3
 ATOL = 5e-3
+CPU_DEVICE = "cpu"
 
 
 class Frameworks(Enum):
@@ -81,6 +83,17 @@ class Frameworks(Enum):
 }
 
 
+def delete_module(module: torch.fx.GraphModule) -> None:
+    """
+    This is a helper function to delete the instance of module. We first move it to CPU and then
+    delete the object. This function ensures the GPU memory occupied by the module is released effectively after this call
+    """
+    module.to(CPU_DEVICE)
+    del module
+    torch.cuda.empty_cache()
+    gc.collect()
+
+
 def use_python_runtime_parser(use_python_runtime: Optional[bool] = None) -> bool:
     """Parses a user-provided input argument regarding Python runtime
 
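A usage sketch for the new helper (the surrounding code is illustrative, not from this diff, and assumes a CUDA device is available):

    import torch
    from torch_tensorrt.dynamo.utils import delete_module

    # Stand-in for an FX GraphModule whose weights live on the GPU and are no
    # longer needed, e.g. a source module after TensorRT engines have been built.
    gm = torch.fx.symbolic_trace(torch.nn.Linear(1024, 1024)).cuda()

    # Moves the module to CPU, drops the local reference, empties the CUDA
    # caching allocator, and runs the garbage collector.
    delete_module(gm)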

tests/py/dynamo/lowering/test_aten_lowering_passes.py (+42)
@@ -269,6 +269,48 @@ def forward(self, input, weight):
         )
         torch._dynamo.reset()
 
+    def test_fp32_acc_for_addmm(self):
+        class FP32Acc(torch.nn.Module):
+            def forward(self, input, mat1, mat2):
+                out = torch.ops.aten.addmm.default(input, mat1, mat2, beta=20, alpha=2)
+                return out
+
+        inputs = [
+            torch.rand((3, 5)).cuda(),
+            torch.rand((3, 4)).cuda(),
+            torch.rand((4, 5)).cuda(),
+        ]
+
+        fx_graph = torch.fx.symbolic_trace(FP32Acc())
+        expected_ops = {
+            torch.ops.aten._to_copy.default,
+            torch.ops.aten.mm.default,
+            torch.ops.aten.add.Tensor,
+        }
+        unexpected_ops = {}
+
+        unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
+            fx_graph,
+            inputs,
+            expected_ops=expected_ops,
+            unexpected_ops=unexpected_ops,
+            min_block_size=1,
+            use_fp32_acc=True,
+        )
+
+        self.assertEqual(
+            len(unexpected_ops_seen),
+            0,
+            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
+        )
+
+        self.assertEqual(
+            len(expected_ops_unseen),
+            0,
+            f"The following expected ops were not encountered: {expected_ops_unseen}",
+        )
+        torch._dynamo.reset()
+
 
 class TestLowerEfficientAttention(TestCase):
     def test_lower_efficient_attention(self):
