
Commit c25890e

Added converter registration
1 parent a4ff6bb commit c25890e

File tree: 5 files changed (+300, -9 lines)

examples/apps/flux-demo.py

Lines changed: 1 addition & 3 deletions
@@ -4,6 +4,7 @@
 import gradio as gr
 import modelopt.torch.quantization as mtq
+import register_sdpa
 import torch
 import torch_tensorrt
 from diffusers import FluxPipeline
@@ -152,9 +153,6 @@ def load_lora(path):
     print("Refitting Finished!")
 
 
-load_lora("/home/TensorRT/examples/apps/NGRVNG.safetensors")
-
-
 # Create Gradio interface
 with gr.Blocks(title="Flux Demo with Torch-TensorRT") as demo:
     gr.Markdown("# Flux Image Generation Demo Accelerated by Torch-TensorRT")
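The only functional change here is the new import (the stray `load_lora` call with a hard-coded local path is dropped): `register_sdpa` works entirely through import side effects, running the `@_aten_lowering_pass` and converter registrations defined below when the module is loaded. A minimal sketch of that usage pattern, assuming it runs from `examples/apps` so the helper modules are importable; the toy `Attn` module and the compile flags are illustrative, not part of this commit:

import torch
import torch_tensorrt

import register_sdpa  # noqa: F401  -- side effect: registers the SDPA lowering pass and converter


class Attn(torch.nn.Module):
    def forward(self, q, k, v):
        return torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)


# Any subsequent dynamo compile now routes SDPA through the custom converter.
q = k = v = torch.randn(1, 8, 16, 64, device="cuda", dtype=torch.float16)
ep = torch.export.export(Attn(), (q, k, v))
trt_mod = torch_tensorrt.dynamo.compile(ep, inputs=[q, k, v], min_block_size=1)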

examples/apps/register_sdpa.py

Lines changed: 118 additions & 0 deletions
import copy
import logging
import operator
from typing import Callable, Sequence, Tuple

import torch
from sdpa_converter import *
from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.dynamo.conversion.aten_ops_converters import args_bounds_check
from torch_tensorrt.dynamo.lowering import TORCH_TRT_DECOMPOSITIONS
from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import (
    _aten_lowering_pass,
)
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
    clean_up_graph_after_modifications,
)

logger = logging.getLogger(__name__)

# Remove the decompositions for aten.scaled_dot_product_attention,
# aten._scaled_dot_product_efficient_attention, and aten._scaled_dot_product_flash_attention
# so that SDPA remains a standalone operator in the graph and the custom converter is invoked for it.
# TORCH_TRT_DECOMPOSITIONS.pop(torch.ops.aten.scaled_dot_product_attention.default)
# TORCH_TRT_DECOMPOSITIONS.pop(torch.ops.aten._scaled_dot_product_efficient_attention.default)
# TORCH_TRT_DECOMPOSITIONS.pop(torch.ops.aten._scaled_dot_product_flash_attention.default)

REPLACEABLE_ATEN_OPS = {
    torch.ops.aten._scaled_dot_product_efficient_attention.default,
    torch.ops.aten._scaled_dot_product_flash_attention.default,
}


@_aten_lowering_pass
def replace_variants_of_sdpa(
    gm: torch.fx.GraphModule, settings: CompilationSettings
) -> torch.fx.GraphModule:
    """Replace scaled_dot_product_attention with an equivalent
    implementation which can be accurately converted to TRT
    """
    attn_mask = None
    is_causal = True
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target in REPLACEABLE_ATEN_OPS:
            if (
                node.target
                == torch.ops.aten._scaled_dot_product_efficient_attention.default
            ):
                if len(node.args) == 7:
                    (
                        query,
                        key,
                        value,
                        attn_bias,
                        compute_log_sumexp,
                        dropout_p,
                        is_causal,
                    ) = node.args
                elif len(node.args) == 5:
                    query, key, value, attn_mask, is_causal = node.args
                    dropout_p = 0.0
                else:
                    raise ValueError(
                        f"Unexpected number of arguments for {node.target} in the graph"
                    )
            elif (
                node.target
                == torch.ops.aten._scaled_dot_product_flash_attention.default
            ):
                if len(node.args) == 6:
                    query, key, value, dropout_p, is_causal, return_debug_mask = (
                        node.args
                    )
                elif len(node.args) == 3:
                    query, key, value = node.args
                    dropout_p = 0.0
                    is_causal = True
                else:
                    raise ValueError(
                        f"Unexpected number of arguments for {node.target} in the graph"
                    )

            if attn_mask is not None:
                logger.warning(
                    f"attn_mask is not supported for {node.target} in the graph. Ignoring it and using the is_causal=True configuration."
                )

            modified_input_args = (query, key, value, None, dropout_p, is_causal)

            # Create a new node with torch.nn.functional.scaled_dot_product_attention
            # The input args are (query, key, value, attn_mask, dropout_p, is_causal); scale is passed via kwargs
            with gm.graph.inserting_after(node):
                new_node = gm.graph.call_function(
                    torch.nn.functional.scaled_dot_product_attention,
                    args=modified_input_args,
                    kwargs={"scale": node.kwargs.get("scale", None)},
                )

                # copy.deepcopy raises "RuntimeError: Cannot access data pointer of Tensor"
                # (e.g. FakeTensor, FunctionalTensor), so we use a shallow copy instead.
                new_node.meta = copy.copy(node.meta)
                # Check if there's a getitem node following this attention node
                for user in list(node.users):
                    if user.op == "call_function" and user.target == operator.getitem:
                        # If the getitem is extracting the first element (the output tensor)
                        if user.args[1] == 0:
                            # Replace all uses of the getitem with the new attention node
                            user.replace_all_uses_with(new_node)
                            new_node.meta["val"] = new_node.meta["val"][0]
                # Replace all uses of the original node with the new node
                node.replace_all_uses_with(new_node)

            gm.graph.erase_node(node)

    # Clean up the graph
    clean_up_graph_after_modifications(gm)

    logger.info(
        "Replaced variants of scaled_dot_product_attention with torch.nn.functional.scaled_dot_product_attention"
    )
    return gm
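The subtle part of this pass is the getitem bookkeeping: the aten SDPA variants are multi-output ops, so consumers read `getitem(node, 0)` rather than the node itself, while `torch.nn.functional.scaled_dot_product_attention` returns a single tensor. A self-contained sketch of the same rewrite pattern on a toy multi-output op (`torch.max(dim=...)` replaced by the single-output `torch.amax`); this is an illustration, not code from the commit:

import operator

import torch
import torch.fx


class UsesMultiOutputOp(torch.nn.Module):
    def forward(self, x):
        out = torch.max(x, dim=-1)  # multi-output: (values, indices)
        return out[0]               # consumed via operator.getitem, like aten SDPA


gm = torch.fx.symbolic_trace(UsesMultiOutputOp())
for node in list(gm.graph.nodes):
    if node.op == "call_function" and node.target == torch.max:
        # Insert a single-output replacement right after the multi-output node
        with gm.graph.inserting_after(node):
            new_node = gm.graph.call_function(torch.amax, args=node.args, kwargs=node.kwargs)
        # Redirect getitem(node, 0) users to the new node, as the pass above does
        for user in list(node.users):
            if user.op == "call_function" and user.target == operator.getitem and user.args[1] == 0:
                user.replace_all_uses_with(new_node)
gm.graph.eliminate_dead_code()  # drops the now-unused torch.max and getitem nodes
gm.recompile()

x = torch.randn(2, 3)
assert torch.equal(gm(x), torch.amax(x, dim=-1))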

examples/apps/sdpa_converter.py

Lines changed: 176 additions & 0 deletions
import logging
import math
from typing import Any, Dict, Optional, Tuple, Union

import numpy as np
import tensorrt as trt
import torch
import torch_tensorrt
from torch.fx.node import Target
from torch_tensorrt._enums import dtype
from torch_tensorrt.dynamo.conversion import impl
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.converter_utils import (
    SourceIR,
    cast_trt_tensor,
    get_trt_tensor,
)
from torch_tensorrt.fx.types import TRTTensor

logger = logging.getLogger(__name__)


def tril(
    ctx: ConversionContext,
    target: Union[Target, str],
    source_ir: Optional[SourceIR],
    name: str,
    row: TRTTensor,
    col: TRTTensor,
) -> TRTTensor:
    row_arange_tensor = impl.arange.arange(
        ctx, target, source_ir, name + "_arange_row", start=0, end=row, step=1
    )
    row_reshape_tensor = impl.shuffle.reshape(
        ctx, target, source_ir, name + "_reshape_row", row_arange_tensor, [row, 1]
    )

    col_arange_tensor = impl.arange.arange(
        ctx, target, source_ir, name + "_arange_col", start=0, end=col, step=1
    )
    col_reshape_tensor = impl.shuffle.reshape(
        ctx, target, source_ir, name + "_reshape_col", col_arange_tensor, [1, col]
    )

    mask = impl.elementwise.ge(
        ctx, target, source_ir, name + "_ge", row_reshape_tensor, col_reshape_tensor
    )
    return mask


@torch_tensorrt.dynamo.conversion.dynamo_tensorrt_converter(
    torch.nn.functional.scaled_dot_product_attention,
    enabled=True,
    supports_dynamic_shapes=True,
)
def scaled_dot_product_attention(
    ctx: torch_tensorrt.dynamo.conversion.ConversionContext,
    target: Target,
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
    name: str,
) -> TRTTensor:
    # TODO: Handle attn_mask and is_causal arguments in the future
    query, key, value, attn_mask, dropout_p, is_causal = args
    logger.info(
        "Ignoring attn_mask and is_causal arguments provided by the original graph. "
        "This converter expects is_causal to be an input to the graph. For the prefill phase, is_causal=True; "
        "for the generate phase, is_causal=False, since only one input token is passed at a time."
    )

    # TODO: remove this once we have a better way to handle the causal mask
    scale = kwargs.get("scale", None)
    source_ir = SourceIR.ATEN
    # Implementation as described here: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
    mm = impl.matmul.matrix_multiply(
        ctx,
        target,
        source_ir,
        name + "_mm",
        query,
        key,
        other_matrix_op=trt.MatrixOperation.TRANSPOSE,
    )
    if scale is None:
        scale = query.shape[-1]
        if scale < 0:
            # dynamic shape
            scale = impl.shape.shape(ctx, target, source_ir, name + "_shape", query, -1)
            sqrt_scaled = impl.unary.sqrt(ctx, target, source_ir, name + "_sqrt", scale)
        else:
            # static shape
            sqrt_scaled = math.sqrt(scale)
        scaled = impl.elementwise.div(
            ctx,
            target,
            source_ir,
            name + "_scale",
            mm,
            sqrt_scaled,
        )
    else:
        scaled = impl.elementwise.mul(
            ctx,
            target,
            source_ir,
            name + "_scale",
            mm,
            scale,
        )

    # If is_causal is True, we need to generate a causal mask
    if is_causal:
        L, S = query.shape[-2], key.shape[-2]
        if L >= 0 and S >= 0:
            # static shape
            attn_bias = np.zeros((L, S), dtype=dtype._from(query.dtype).to(np.dtype))
            temp_mask = np.logical_not(np.tril(np.ones((L, S), dtype=np.bool_), k=0))
            attn_bias = np.ma.array(attn_bias, mask=temp_mask).filled(float("-inf"))
            attn_bias = get_trt_tensor(ctx, attn_bias, name + "_attn_bias")
        else:
            # if either L or S has a dynamic shape
            if L < 0:
                L = impl.shape.shape(
                    ctx, target, source_ir, name + "_shape_0", query, 2
                )
            if S < 0:
                S = impl.shape.shape(ctx, target, source_ir, name + "_shape_1", key, 2)

            # generate the mask tensor
            tril_tensor = tril(ctx, target, source_ir, name + "_tril", L, S)

            temp_mask = impl.unary.logical_not(
                ctx, target, source_ir, name + "_logical_not", tril_tensor
            )
            temp_mask_casted = cast_trt_tensor(
                ctx, temp_mask, trt.float32, name + "_casted_bool", target, source_ir
            )
            one_minus_temp_mask = impl.elementwise.sub(
                ctx,
                target,
                source_ir,
                name + "_one_minus_temp_mask",
                1.0,
                temp_mask_casted,
            )
            attn_bias = impl.unary.log(
                ctx, target, source_ir, name + "_log", one_minus_temp_mask
            )

        scaled_add_attn_bias = impl.elementwise.add(
            ctx, target, source_ir, name + "_attn_bias_add", scaled, attn_bias
        )
    else:
        scaled_add_attn_bias = scaled

    # Create an if-conditional to check whether is_causal is True
    if isinstance(is_causal, TRTTensor):
        if_layer = ctx.net.add_if_conditional()
        condition, true_branch, false_branch = is_causal, scaled_add_attn_bias, scaled
        if_layer.set_condition(condition)
        output_layer = if_layer.add_output(true_branch, false_branch)
        scaled_add_attn_bias = output_layer.get_output(0)

    softmax = impl.normalization.softmax(
        ctx, target, source_ir, name + "_softmax", scaled_add_attn_bias, -1, False
    )
    out = impl.matmul.matrix_multiply(
        ctx,
        target,
        source_ir,
        name + "_out",
        softmax,
        value,
    )

    return out
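The converter builds the textbook SDPA decomposition out of TensorRT layers: scores = Q @ K^T scaled by 1/sqrt(E) (or multiplied by an explicit scale), plus a causal bias, then softmax and a matmul with V. The dynamic-shape branch encodes the causal bias as log(1 - not(tril)), i.e. log(tril): 0 where attention is allowed and -inf elsewhere. A plain-PyTorch reference of the same math, written here only to make the converter's semantics checkable, not taken from the commit:

import math

import torch


def sdpa_reference(query, key, value, is_causal=True, scale=None):
    # scores = Q @ K^T, scaled by 1/sqrt(E) when no explicit scale is given
    scores = query @ key.transpose(-2, -1)
    scores = scores / math.sqrt(query.size(-1)) if scale is None else scores * scale
    if is_causal:
        L, S = query.size(-2), key.size(-2)
        # log(tril) is 0 where attention is allowed and -inf elsewhere --
        # the same effect as the tril/logical_not/(1 - mask)/log chain above
        tril = torch.ones(L, S, dtype=torch.bool, device=query.device).tril()
        scores = scores + torch.log(tril.float())
    return torch.softmax(scores, dim=-1) @ value


q, k, v = (torch.randn(1, 8, 16, 64) for _ in range(3))
ref = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
torch.testing.assert_close(sdpa_reference(q, k, v), ref, rtol=1e-4, atol=1e-4)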

py/torch_tensorrt/dynamo/lowering/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1,4 +1,5 @@
 from ._decomposition_groups import (
+    TORCH_TRT_DECOMPOSITIONS,
     torch_disabled_decompositions,
     torch_enabled_decompositions,
 )
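Re-exporting `TORCH_TRT_DECOMPOSITIONS` from the `lowering` package is what lets out-of-tree scripts such as `register_sdpa.py` remove the SDPA decompositions (the commented-out `pop` calls above). A hedged sketch of that intended usage; the defensive `pop(..., None)` default is an addition here in case an op is not registered:

import torch
from torch_tensorrt.dynamo.lowering import TORCH_TRT_DECOMPOSITIONS

# Keep SDPA as a standalone op in the graph instead of letting it decompose,
# so the custom converter is the one that handles it.
for op in (
    torch.ops.aten.scaled_dot_product_attention.default,
    torch.ops.aten._scaled_dot_product_efficient_attention.default,
    torch.ops.aten._scaled_dot_product_flash_attention.default,
):
    TORCH_TRT_DECOMPOSITIONS.pop(op, None)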

tools/perf/Flux/flux_perf.py

Lines changed: 4 additions & 6 deletions
@@ -36,11 +36,9 @@
 settings = {
     "strict": False,
     "allow_complex_guards_as_runtime_asserts": True,
-    "enabled_precisions": {torch.float32},
+    "enabled_precisions": {torch.float16},
     "truncate_double": True,
     "min_block_size": 1,
-    "use_fp32_acc": True,
-    "use_explicit_typing": True,
     "debug": False,
     "use_python_runtime": True,
     "immutable_weights": False,
@@ -74,12 +72,12 @@ def generate_image(prompt, inference_step, batch_size=1, benchmark=False, iterat
 # Warmup
 generate_image(["Test"], 20)
 print("Benchmark Original PyTorch Module Latency (bfloat16)")
-for batch_size in range(1, 9):
+for batch_size in range(1, 3):
     generate_image(["Test"], 20, batch_size=batch_size, benchmark=True, iterations=3)
 
 pipe.to(torch.float16)
 print("Benchmark Original PyTorch Module Latency (float16)")
-for batch_size in range(1, 9):
+for batch_size in range(1, 3):
     generate_image(["Test"], 20, batch_size=batch_size, benchmark=True, iterations=3)
 
 trt_gm = torch_tensorrt.MutableTorchTensorRTModule(backbone, **settings)
@@ -92,6 +90,6 @@ def generate_image(prompt, inference_step, batch_size=1, benchmark=False, iterat
 print("Time Elapse compilation:", end - start)
 print()
 print("Benchmark TRT Accelerated Latency")
-for batch_size in range(1, 9):
+for batch_size in range(1, 3):
     generate_image(["Test"], 20, batch_size=batch_size, benchmark=True, iterations=3)
 torch.cuda.empty_cache()
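The benchmark loops delegate timing to `generate_image(..., benchmark=True, iterations=3)`. For readers reproducing the numbers, a minimal CUDA-aware latency helper in the same spirit; the helper is an assumption for illustration, not necessarily how `flux_perf.py` measures internally:

import time

import torch


def benchmark_latency(fn, iterations=3, warmup=1):
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()  # don't attribute previously queued GPU work to the timed region
    start = time.perf_counter()
    for _ in range(iterations):
        fn()
    torch.cuda.synchronize()  # wait for the timed kernels to finish before reading the clock
    return (time.perf_counter() - start) / iterations


# e.g. benchmark_latency(lambda: generate_image(["Test"], 20, batch_size=2))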
