Skip to content

Commit 413d2c0

Browse files
Remove the RoPE fusion option; perform the fusion automatically under torch.compile
ghstack-source-id: c37cc48 Pull-Request: #4055
1 parent 5ebd10d commit 413d2c0

File tree

10 files changed

+264
-470
lines changed

10 files changed

+264
-470
lines changed

benchmarks/prototype/attention/eval_flux_model.py

Lines changed: 1 addition & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,6 @@
4343
"fp8": True,
4444
"fp8_backend": AttentionBackend.FP8_FA3,
4545
},
46-
"fa4": {"flash_impl": "FA4", "fp8": False},
47-
"fa4_fp8": {
48-
"flash_impl": "FA4",
49-
"fp8": True,
50-
"fp8_backend": AttentionBackend.FP8_FA4,
51-
},
5246
}
5347

5448
IMAGE_SIZE = (512, 512) # (width, height) - resize for consistent LPIPS
@@ -68,7 +62,6 @@ def setup_backend(
6862
backend_name,
6963
compile_flag,
7064
orig_transformer,
71-
fuse_rope_using_torch_compile=False,
7265
):
7366
"""Set up a backend and return the flash_impl name."""
7467
cfg = BACKENDS[backend_name]
@@ -79,16 +72,8 @@ def setup_backend(
7972
pipe.transformer = apply_low_precision_attention(
8073
pipe.transformer,
8174
backend=cfg["fp8_backend"],
82-
fuse_rope_using_torch_compile=fuse_rope_using_torch_compile,
8375
)
84-
if fuse_rope_using_torch_compile:
85-
print(
86-
f"Compiling transformer with torch.compile ({backend_name}, FP8 backend)..."
87-
)
88-
pipe.transformer = torch.compile(
89-
pipe.transformer, backend=pipe.transformer.compile_backend
90-
)
91-
elif compile_flag:
76+
if compile_flag:
9277
print(f"Compiling transformer with torch.compile ({backend_name})...")
9378
pipe.transformer = torch.compile(pipe.transformer)
9479
return cfg["flash_impl"]
@@ -161,7 +146,6 @@ def run_benchmark(
161146
debug_prompt: Optional[str] = None,
162147
warmup_iters: int = 2,
163148
compile: bool = False,
164-
fuse_rope_using_torch_compile: bool = False,
165149
):
166150
"""Run the attention backend benchmark on FLUX.1-schnell."""
167151
compile_str = " + torch.compile" if compile else ""
@@ -218,7 +202,6 @@ def run_benchmark(
218202
baseline_backend,
219203
compile,
220204
orig_transformer,
221-
fuse_rope_using_torch_compile=fuse_rope_using_torch_compile,
222205
)
223206

224207
print(f"Warming up {baseline_backend} with {warmup_iters} iterations...")
@@ -282,7 +265,6 @@ def run_benchmark(
282265
test_backend,
283266
compile,
284267
orig_transformer,
285-
fuse_rope_using_torch_compile=fuse_rope_using_torch_compile,
286268
)
287269

288270
print(f"Warming up {test_backend} with {warmup_iters} iterations...")
@@ -422,11 +404,6 @@ def main():
422404
action="store_true",
423405
help="Wrap the model with torch.compile for both backends",
424406
)
425-
parser.add_argument(
426-
"--fuse_rope_using_torch_compile",
427-
action="store_true",
428-
help="Fuse RoPE into the FP8 kernel (compile path, off by default)",
429-
)
430407
parser.add_argument(
431408
"--height",
432409
type=int,
@@ -452,7 +429,6 @@ def main():
452429
debug_prompt=args.debug_prompt,
453430
warmup_iters=args.warmup_iters,
454431
compile=args.compile,
455-
fuse_rope_using_torch_compile=args.fuse_rope_using_torch_compile,
456432
)
457433

458434

benchmarks/prototype/attention/eval_llama3_model.py

Lines changed: 2 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -57,17 +57,6 @@
5757
"fp8_backend": AttentionBackend.FP8_FA3,
5858
"label": "FA3 FP8",
5959
},
60-
"fa4": {
61-
"flash_impl": "FA4",
62-
"fp8": False,
63-
"label": "FA4 BF16",
64-
},
65-
"fa4_fp8": {
66-
"flash_impl": "FA4",
67-
"fp8": True,
68-
"fp8_backend": AttentionBackend.FP8_FA4,
69-
"label": "FA4 FP8",
70-
},
7160
}
7261

7362
RANDOM_SEED = 42
@@ -115,9 +104,7 @@ def mask_strip_backend(gm, example_inputs):
115104
return torch.compile(model, backend=mask_strip_backend)
116105

117106

118-
def setup_backend(
119-
orig_model, backend_name, compile_flag, fuse_rope_using_torch_compile=False
120-
):
107+
def setup_backend(orig_model, backend_name, compile_flag):
121108
"""Set up a backend and return (model, flash_impl)."""
122109
cfg = BACKENDS[backend_name]
123110

@@ -129,14 +116,8 @@ def setup_backend(
129116
model = apply_low_precision_attention(
130117
orig_model,
131118
backend=cfg["fp8_backend"],
132-
fuse_rope_using_torch_compile=fuse_rope_using_torch_compile,
133119
)
134-
if fuse_rope_using_torch_compile:
135-
print(
136-
f" Compiling model with torch.compile ({backend_name}, FP8 backend)..."
137-
)
138-
model = torch.compile(model, backend=model.compile_backend)
139-
elif compile_flag:
120+
if compile_flag:
140121
print(f" Compiling model with torch.compile ({backend_name})...")
141122
model = torch.compile(model)
142123
return model, cfg["flash_impl"]
@@ -221,7 +202,6 @@ def run_benchmark(
221202
num_warmup: int = 3,
222203
seq_lengths: list[int] | None = None,
223204
compile: bool = False,
224-
fuse_rope_using_torch_compile: bool = False,
225205
):
226206
if seq_lengths is None:
227207
seq_lengths = DEFAULT_SEQ_LENGTHS
@@ -266,7 +246,6 @@ def run_benchmark(
266246
orig_model,
267247
baseline_backend,
268248
compile,
269-
fuse_rope_using_torch_compile=fuse_rope_using_torch_compile,
270249
)
271250
baseline_ppl = evaluate_perplexity(baseline_model, tokenizer, baseline_flash)
272251
print(f" {baseline_label} perplexity: {baseline_ppl:.2f}")
@@ -277,7 +256,6 @@ def run_benchmark(
277256
orig_model,
278257
test_backend,
279258
compile,
280-
fuse_rope_using_torch_compile=fuse_rope_using_torch_compile,
281259
)
282260
test_ppl = evaluate_perplexity(test_model, tokenizer, test_flash)
283261
print(f" {test_label} perplexity: {test_ppl:.2f}")
@@ -299,7 +277,6 @@ def run_benchmark(
299277
orig_model,
300278
baseline_backend,
301279
compile,
302-
fuse_rope_using_torch_compile=fuse_rope_using_torch_compile,
303280
)
304281
baseline_runtimes = {}
305282
for S in seq_lengths:
@@ -330,7 +307,6 @@ def run_benchmark(
330307
orig_model,
331308
test_backend,
332309
compile,
333-
fuse_rope_using_torch_compile=fuse_rope_using_torch_compile,
334310
)
335311
test_runtimes = {}
336312
for S in baseline_runtimes:
@@ -459,11 +435,6 @@ def main():
459435
action="store_true",
460436
help="Wrap the model with torch.compile (applies to non-FP8 backends)",
461437
)
462-
parser.add_argument(
463-
"--fuse_rope_using_torch_compile",
464-
action="store_true",
465-
help="Fuse RoPE into the FP8 kernel (compile path, off by default)",
466-
)
467438
args = parser.parse_args()
468439

469440
run_benchmark(
@@ -474,7 +445,6 @@ def main():
474445
num_warmup=args.num_warmup,
475446
seq_lengths=args.seq_lengths,
476447
compile=args.compile,
477-
fuse_rope_using_torch_compile=args.fuse_rope_using_torch_compile,
478448
)
479449

480450

test/prototype/attention/test_fp8_attention.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,6 @@ def test_monkey_patch_model(self, dtype):
190190
fp8_model = apply_low_precision_attention(
191191
fp8_model,
192192
backend=AttentionBackend.FP8_FA3,
193-
fuse_rope_using_torch_compile=False,
194193
)
195194

196195
with torch.no_grad():
@@ -231,9 +230,8 @@ def test_rope_fusion_model(self, dtype):
231230
fp8_model = apply_low_precision_attention(
232231
fp8_model,
233232
backend=AttentionBackend.FP8_FA3,
234-
fuse_rope_using_torch_compile=True,
235233
)
236-
fp8_model = torch.compile(fp8_model, backend=fp8_model.compile_backend)
234+
fp8_model = torch.compile(fp8_model)
237235

238236
with torch.no_grad():
239237
out_fp8 = fp8_model(x, cos, sin)

torchao/prototype/attention/api.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -60,20 +60,19 @@ def _check_backend_available(backend: AttentionBackend) -> None:
6060
def apply_low_precision_attention(
6161
model: nn.Module,
6262
backend: Optional[AttentionBackend] = None,
63-
fuse_rope_using_torch_compile: bool = False,
6463
) -> nn.Module:
6564
"""Apply low-precision attention to a model.
6665
6766
Must be called before ``torch.compile``. KV caching should be
6867
disabled before calling (e.g., ``config.use_cache = False`` for
6968
HuggingFace models).
7069
71-
When ``fuse_rope_using_torch_compile=True``, the returned wrapper
72-
exposes a ``compile_backend`` attribute. You must compile with it to get
73-
the RoPE fusion::
70+
This replaces ``F.scaled_dot_product_attention`` with an FP8 SDPA
71+
for eager execution and sets a global pre-grad pass so that
72+
``torch.compile`` will automatically fuse RoPE where detected::
7473
75-
model = apply_low_precision_attention(model, fuse_rope_using_torch_compile=True)
76-
model = torch.compile(model, backend=model.compile_backend)
74+
model = apply_low_precision_attention(model)
75+
model = torch.compile(model) # RoPE fusion happens automatically
7776
"""
7877
if isinstance(model, _LowPrecisionAttentionWrapper):
7978
raise RuntimeError(
@@ -90,6 +89,6 @@ def apply_low_precision_attention(
9089
_check_backend_available(backend)
9190

9291
if backend == AttentionBackend.FP8_FA3:
93-
return setup_fp8_backend(model, "FA3", fuse_rope_using_torch_compile)
92+
return setup_fp8_backend(model, "FA3")
9493

9594
raise ValueError(f"Unknown backend: {backend}")

torchao/prototype/attention/fp8_fa3/attention.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@
2323
_fp8_rope_sdpa,
2424
_fp8_sdpa,
2525
)
26+
from torchao.prototype.attention.shared_utils.custom_ops import (
27+
register_fp8_attention_ops,
28+
)
2629

2730
fp8_fa3_sdpa = partial(_fp8_sdpa, backend_name="FA3")
2831
fp8_fa3_sdpa.__doc__ = _fp8_sdpa.__doc__
@@ -33,3 +36,9 @@
3336
fp8_fa3_rope_sdpa.__doc__ = _fp8_rope_sdpa.__doc__
3437
fp8_fa3_rope_sdpa.__name__ = "fp8_fa3_rope_sdpa"
3538
fp8_fa3_rope_sdpa.__qualname__ = "fp8_fa3_rope_sdpa"
39+
40+
_ops = register_fp8_attention_ops(
41+
backend_name="fa3",
42+
rope_sdpa_fn=fp8_fa3_rope_sdpa,
43+
sdpa_fn=fp8_fa3_sdpa,
44+
)

torchao/prototype/attention/fp8_fa3/fusion_pass.py

Lines changed: 0 additions & 22 deletions
This file was deleted.

torchao/prototype/attention/shared_utils/custom_ops.py

Lines changed: 0 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,9 @@
1212
tells the compiler to treat them as opaque nodes with known shapes/dtypes.
1313
"""
1414

15-
from functools import partial
1615
from typing import Callable, NamedTuple
1716

1817
import torch
19-
import torch._inductor.config as inductor_config
20-
import torch.nn as nn
21-
22-
from torchao.prototype.attention.shared_utils.fusion_utils import (
23-
detect_causal_mask,
24-
)
25-
from torchao.prototype.attention.shared_utils.fusion_utils import (
26-
rope_sdpa_fusion_pass as _shared_fusion_pass,
27-
)
2818

2919

3020
class RegisteredOps(NamedTuple):
@@ -120,42 +110,3 @@ def _sdpa_fake(
120110
fp8_sdpa_op = getattr(getattr(torch.ops, "torchao"), f"fp8_{backend}_sdpa").default
121111

122112
return RegisteredOps(rope_sdpa_op=rope_sdpa_op, fp8_sdpa_op=fp8_sdpa_op)
123-
124-
125-
def make_backend_fn(
126-
ops: RegisteredOps,
127-
backend_name: str,
128-
flash_impl_name: str,
129-
max_head_dim: int = 256,
130-
) -> Callable:
131-
"""Return a ``make_fp8_backend(model, fuse_rope_using_torch_compile)`` function for a backend."""
132-
133-
def make_fp8_backend(
134-
model: nn.Module,
135-
fuse_rope_using_torch_compile: bool,
136-
) -> Callable:
137-
from torch._inductor.compile_fx import compile_fx
138-
139-
strip_causal_mask = detect_causal_mask(model, flash_impl_name=flash_impl_name)
140-
141-
pass_fn = partial(
142-
_shared_fusion_pass,
143-
rope_sdpa_op=ops.rope_sdpa_op,
144-
fp8_sdpa_op=ops.fp8_sdpa_op,
145-
max_head_dim=max_head_dim,
146-
backend_name=backend_name,
147-
fuse_rope=fuse_rope_using_torch_compile,
148-
strip_causal_mask=strip_causal_mask,
149-
)
150-
151-
def fp8_attention_backend(gm, example_inputs):
152-
old_pass = inductor_config.pre_grad_custom_pass
153-
inductor_config.pre_grad_custom_pass = pass_fn
154-
try:
155-
return compile_fx(gm, example_inputs)
156-
finally:
157-
inductor_config.pre_grad_custom_pass = old_pass
158-
159-
return fp8_attention_backend
160-
161-
return make_fp8_backend

0 commit comments

Comments (0)