
Commit ce02358

Add parallel attention (#2630)
## 📌 Description

Add a `parallel_attention` module to FlashInfer that enables distributed attention computation using **Ulysses** (all-to-all head parallelism) and **Ring** (P2P KV exchange with online softmax merging) strategies, or a combination of both.

### New files

- **`parallel_attention.py`** — `ParallelAttention` class: the main entry point that wraps any registered attention backend and applies Ulysses/Ring parallelism transparently via decorators.
- **`parallel_config.py`** — Configuration classes:
  - `AttnParallelConfig`: singleton that manages `ulysses_size`, `ring_size`, device mesh creation, and process group accessors.
  - `UnevenCPConfig`: handles uneven context parallelism where the total sequence length is not divisible by `world_size`.
  - `VarlenCPConfig`: handles variable-length (ragged) batching where multiple sequences of different lengths are packed together.
- **`parallel_wrapper.py`** — Decorator implementations:
  - `ulysses_wrapper`: performs all-to-all communication to split heads across ranks, calls the inner function, then reverses the all-to-all.
  - `ring_wrapper`: implements ring attention with P2P KV exchange and online softmax correction across ring steps.
  - Helper functions: `all_to_all`, `ulysses_a2a_in/out`, `ring_fwd_out_correction`, `ring_fwd_softmax_lse_correction`, `ring_attn_p2p_communicate` (the merge math behind the two correction helpers is sketched right after this description).
- **`attention_ops.py`** — `AttentionOpManager` registry with decorator-based backend registration. Includes `FlashAttn3` as the first registered backend.
- **`utils.py`** — Utility functions: `convert_qkv_layout`, `convert_output_layout`, `split_varlen_input`.
- **`__init__.py`** — Package API re-exports.

### Tests

- **`tests/attention/test_parallel_attention.py`** — Pytest-based test suite covering:
  - Combined Ulysses + Ring attention (`test_attn_parallel`)
  - Uneven context parallelism (`test_uneven_attn_parallel`)
  - Ulysses-only varlen attention (`test_ulysses_varlen_attn_parallel`)
  - Ring-only varlen attention (`test_ring_varlen_attn_parallel`)
  - Parametrized over `tensor_layout` (`"HND"` / `"NHD"`)

### Key design decisions

- **Backend-agnostic**: any attention function can be registered via `@AttentionOpManager.register_attn("name")` and used with the parallel wrappers.
- **Decorator-based parallelism**: `@ulysses_wrapper` and `@ring_wrapper` are composable decorators — they can be stacked or used independently.
- **No causal support yet**: `is_causal=True` raises `NotImplementedError`.

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Summary by CodeRabbit

* **New Features**
  * Parallel attention framework enabling distributed inference across multiple GPUs
  * Ulysses and Ring parallelism strategies to optimize throughput and reduce latency
  * Multiple pluggable attention backends with automatic kernel selection based on hardware
  * Variable-length sequence handling for flexible batch processing in distributed settings
  * Comprehensive utilities for tensor layout conversion and distributed sequence management
* **Tests**
  * Added comprehensive distributed test suite for parallel attention scenarios

---------

Co-authored-by: Sam (Kesen Li) <lsam@nvidia.com>
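
For reviewers unfamiliar with ring attention's merge step, here is a minimal sketch of the standard online-softmax correction that `ring_fwd_out_correction` and `ring_fwd_softmax_lse_correction` implement conceptually. It is the textbook formulation under assumed `[S, H, D]` output and `[S, H]` LSE shapes, not code taken from this diff:

```python
import torch


def merge_ring_step(out_acc, lse_acc, out_new, lse_new):
    """Fold one ring step's partial attention into the running result.

    out_*: [S, H, D] partial outputs for the same queries over different
    KV blocks; lse_*: [S, H] log-sum-exp of the softmax denominators.
    """
    # Stable log(e^a + e^b): the merged softmax denominator in log space.
    lse_merged = torch.logaddexp(lse_acc, lse_new)
    # Rescale each partial output by its share of the merged denominator.
    w_acc = torch.exp(lse_acc - lse_merged).unsqueeze(-1)
    w_new = torch.exp(lse_new - lse_merged).unsqueeze(-1)
    return out_acc * w_acc + out_new * w_new, lse_merged
```

Because LSE merging is associative, each rank can fold in KV blocks as they arrive over the ring, one P2P step at a time, and still match single-device attention.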
1 parent 06cb1b7 commit ce02358

8 files changed

Lines changed: 1937 additions & 0 deletions


`__init__.py` (Lines changed: 19 additions & 0 deletions)

from .parallel_attention import ParallelAttention as ParallelAttention
from .parallel_config import UnevenCPConfig as UnevenCPConfig
from .parallel_config import VarlenCPConfig as VarlenCPConfig
from .utils import split_varlen_input as split_varlen_input
from .utils import ulysses_varlen_config as ulysses_varlen_config
from .utils import ring_varlen_config as ring_varlen_config
from .utils import uneven_cp_config as uneven_cp_config
from .utils import get_parallel_groups as get_parallel_groups

__all__ = [
    "ParallelAttention",
    "UnevenCPConfig",
    "VarlenCPConfig",
    "split_varlen_input",
    "ulysses_varlen_config",
    "ring_varlen_config",
    "uneven_cp_config",
    "get_parallel_groups",
]
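
For orientation, a minimal usage sketch of the re-exported API. The import path and the `get_parallel_groups()` call shape are assumptions for illustration only; this diff does not show where the package is mounted inside `flashinfer`:

```python
# Illustrative sketch: the import path and the get_parallel_groups()
# return shape are assumed, not shown in this diff.
import torch
import torch.distributed as dist

from flashinfer.parallel_attention import ParallelAttention, get_parallel_groups

dist.init_process_group(backend="nccl")
ulysses_group, ring_group = get_parallel_groups()  # assumed to return both groups

# HND layout: [num_heads, local_seq_len, head_dim] on each rank.
q = torch.randn(8, 1024, 128, device="cuda", dtype=torch.bfloat16)
k = torch.randn(8, 1024, 128, device="cuda", dtype=torch.bfloat16)
v = torch.randn(8, 1024, 128, device="cuda", dtype=torch.bfloat16)

attn = ParallelAttention(
    attn_type="flash-attn3",
    ulysses_group=ulysses_group,
    ring_group=ring_group,
)
out = attn.run(q, k, v, tensor_layout="HND")
```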
`attention_ops.py` (Lines changed: 253 additions & 0 deletions)

import logging
import math

import torch

from .utils import (
    convert_output_layout,
    convert_qkv_layout,
)

logger = logging.getLogger(__name__)

try:
    import flash_attn_interface
except ImportError:
    flash_attn_interface = None

from flashinfer.prefill import fmha_varlen


class AttentionOpManager:
    _attn_registry: dict[str, type] = {}

    @classmethod
    def op_type(cls):
        return "attention"

    @classmethod
    def set_attn_config(cls, **kwargs):
        for key, value in kwargs.items():
            if hasattr(cls, key):
                setattr(cls, key, value)
            else:
                raise AttributeError(f"'{cls.__name__}' has no attribute '{key}'")

    @classmethod
    def register_attn(cls, attn_type):
        def decorator(attn_class):
            # Register the attention class
            cls._attn_registry[attn_type] = attn_class
            return attn_class

        return decorator

    @classmethod
    def get_impl(cls, name=None):
        if name is None:
            name = cls.attn_type
        attn_class = cls._attn_registry.get(name)
        if attn_class is None:
            raise ValueError(f"Attention function {name} not found in registry")
        return attn_class()  # Create and return an instance

    @classmethod
    def get_registered_types(cls):
        return list(cls._attn_registry.keys())


@AttentionOpManager.register_attn("flash-attn3")
class FlashAttn3:
    def __call__(
        self,
        query,
        key,
        value,
        attn_mask=None,
        is_causal=False,
        return_lse=False,
        tensor_layout="HND",
        cur_rank_cu_seqlens_q=None,
        cur_rank_cu_seqlens_k=None,
        cur_rank_max_seqlen_q=0,
        cur_rank_max_seqlen_k=0,
        **kwargs,
    ):
        if flash_attn_interface is None:
            raise ImportError("FlashAttn3 is not installed")

        if tensor_layout not in ["HND", "NHD"]:
            raise NotImplementedError("Tensor layout not supported for FlashAttn3")

        if tensor_layout == "HND":
            query, key, value = convert_qkv_layout(
                query, key, value, src_layout="HND", dst_layout="NHD"
            )

        if attn_mask is not None:
            raise NotImplementedError("FlashAttn3 does not support attn_mask yet")

        # FA3 only supports float16 and bfloat16
        origin_dtype = query.dtype
        if query.dtype not in [torch.float16, torch.bfloat16]:
            query = query.to(torch.float16)
            key = key.to(torch.float16)
            value = value.to(torch.float16)

        if cur_rank_cu_seqlens_q is None:
            query = torch.unsqueeze(query, dim=0)
            key = torch.unsqueeze(key, dim=0)
            value = torch.unsqueeze(value, dim=0)
            output = flash_attn_interface.flash_attn_func(
                q=query,
                k=key,
                v=value,
                softmax_scale=None,
                causal=is_causal,
                qv=None,
                q_descale=None,
                k_descale=None,
                v_descale=None,
                window_size=(-1, -1),
                attention_chunk=0,
                softcap=0.0,
                num_splits=1,
                pack_gqa=None,
                deterministic=False,
                sm_margin=0,
                return_attn_probs=return_lse,
            )

            if isinstance(output, tuple):
                lse = torch.squeeze(output[1], dim=0)
                output = torch.squeeze(output[0], dim=0)
                output = (output, lse)
            else:
                output = torch.squeeze(output, dim=0)

        else:
            output = flash_attn_interface.flash_attn_varlen_func(
                q=query,
                k=key,
                v=value,
                cu_seqlens_q=cur_rank_cu_seqlens_q,
                cu_seqlens_k=cur_rank_cu_seqlens_k,
                max_seqlen_q=cur_rank_max_seqlen_q,
                max_seqlen_k=cur_rank_max_seqlen_k,
                seqused_q=None,
                seqused_k=None,
                softmax_scale=None,
                causal=is_causal,
                qv=None,
                q_descale=None,
                k_descale=None,
                v_descale=None,
                window_size=(-1, -1),
                attention_chunk=0,
                softcap=0.0,
                num_splits=1,
                pack_gqa=None,
                deterministic=False,
                sm_margin=0,
                return_attn_probs=return_lse,
            )

        lse = None
        if isinstance(output, tuple):
            lse = output[1]
            output = output[0]

        if tensor_layout == "HND":
            output = convert_output_layout(output, src_layout="NHD", dst_layout="HND")

        if tensor_layout == "NHD" and lse is not None:
            lse = lse.permute(1, 0)

        if output.dtype != origin_dtype:
            output = output.to(origin_dtype)

        if return_lse:
            assert lse is not None, "lse is not returned by FlashAttn3"
            return output, lse
        else:
            return output


@AttentionOpManager.register_attn("cutlass")
class CutlassFmha:
    def __call__(
        self,
        query,
        key,
        value,
        attn_mask=None,
        is_causal=False,
        return_lse=False,
        tensor_layout="HND",
        cur_rank_cu_seqlens_q=None,
        cur_rank_cu_seqlens_k=None,
        cur_rank_max_seqlen_q=0,
        cur_rank_max_seqlen_k=0,
        **kwargs,
    ):
        if tensor_layout not in ["HND", "NHD"]:
            raise NotImplementedError("Tensor layout not supported for CutlassFmha")

        if tensor_layout == "HND":
            query, key, value = convert_qkv_layout(
                query, key, value, src_layout="HND", dst_layout="NHD"
            )

        if attn_mask is not None:
            raise NotImplementedError("CutlassFmha does not support attn_mask yet")

        # CutlassFmha only supports float16 and bfloat16
        origin_dtype = query.dtype
        if query.dtype not in [torch.float16, torch.bfloat16]:
            query = query.to(torch.float16)
            key = key.to(torch.float16)
            value = value.to(torch.float16)

        if cur_rank_cu_seqlens_q is None:
            qo_segment_offsets = torch.tensor(
                [0, query.shape[0]], device=query.device, dtype=torch.int32
            )
            kv_segment_offsets = torch.tensor(
                [0, key.shape[0]], device=key.device, dtype=torch.int32
            )
            max_qo_len = query.shape[0]
        else:
            qo_segment_offsets = cur_rank_cu_seqlens_q
            kv_segment_offsets = cur_rank_cu_seqlens_k
            max_qo_len = cur_rank_max_seqlen_q

        output = fmha_varlen(
            query,
            key,
            value,
            qo_segment_offsets=qo_segment_offsets,
            kv_segment_offsets=kv_segment_offsets,
            max_qo_len=max_qo_len,
            causal=is_causal,
            sm_scale=1.0 / math.sqrt(query.size(-1)),
            return_lse=return_lse,
        )

        lse = None
        if isinstance(output, tuple):
            lse = output[1]
            output = output[0]

        if tensor_layout == "HND":
            output = convert_output_layout(output, src_layout="NHD", dst_layout="HND")
            if lse is not None:
                lse = lse.permute(1, 0)

        if output.dtype != origin_dtype:
            output = output.to(origin_dtype)

        if return_lse:
            assert lse is not None, "lse is not returned by cutlass fmha"
            return output, lse
        else:
            return output
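
The registry above is backend-agnostic, so a third backend can be added without touching the parallel wrappers. Below is a hypothetical sketch (not part of this diff) built on `torch.nn.functional.scaled_dot_product_attention`; the `torch-sdpa` name and the class are invented for illustration, and only the dense path is handled. A real backend would also need `return_lse` support to work under the ring wrapper, which merges partial results by LSE:

```python
import torch
import torch.nn.functional as F


@AttentionOpManager.register_attn("torch-sdpa")
class TorchSdpa:
    def __call__(
        self,
        query,
        key,
        value,
        attn_mask=None,
        is_causal=False,
        return_lse=False,
        tensor_layout="HND",
        **kwargs,
    ):
        if return_lse:
            # SDPA does not expose the softmax LSE, so this sketch cannot
            # serve the ring path, which needs it for online merging.
            raise NotImplementedError("torch-sdpa sketch does not return LSE")
        # SDPA batches over leading dims with shape [..., S, D]; HND input
        # ([H, S, D]) already matches, NHD ([S, H, D]) needs a transpose.
        if tensor_layout == "NHD":
            query, key, value = (t.transpose(0, 1) for t in (query, key, value))
        out = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attn_mask, is_causal=is_causal
        )
        return out.transpose(0, 1) if tensor_layout == "NHD" else out


# The new backend is then selectable by name, like the built-in ones:
impl = AttentionOpManager.get_impl("torch-sdpa")
```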
`parallel_attention.py` (Lines changed: 117 additions & 0 deletions)

import logging

import torch

from .attention_ops import AttentionOpManager
from .parallel_config import UnevenCPConfig, VarlenCPConfig
from .parallel_wrapper import ring_wrapper, ulysses_wrapper

logger = logging.getLogger(__name__)


class ParallelAttention:
    """Runs an attention backend with Ulysses and/or Ring parallelism.

    Wraps any registered attention implementation (see :class:`AttentionOpManager`)
    and transparently applies Ulysses (all-to-all head splitting) and Ring
    (P2P KV exchange with online softmax merging) parallelism via decorators.

    Args:
        attn_type: Name of the registered attention backend (e.g. ``"flash-attn3"``).
        ulysses_group: Ulysses process group.
        ring_group: Ring process group.
        uneven_cp_config: Configuration for uneven context parallelism where
            sequence lengths are not evenly divisible across ranks.
        varlen_cp_config: Configuration for variable-length context parallelism
            where multiple sequences of different lengths are packed together.
        fuse_qkv: If ``True``, fuse Q/K/V into a single all-to-all communication
            in Ulysses parallelism (reduces 3 NCCL calls to 1).

    Example::

        config = AttnParallelConfig()
        config.set_config(ulysses_size=2, ring_size=2)
        attn = ParallelAttention(
            attn_type="flash-attn3",
            ulysses_group=ulysses_group,
            ring_group=ring_group,
        )
        output = attn.run(query, key, value, tensor_layout="HND")
    """

    def __init__(
        self,
        attn_type: str,
        ulysses_group: torch.distributed.ProcessGroup,
        ring_group: torch.distributed.ProcessGroup,
        uneven_cp_config: UnevenCPConfig = None,
        varlen_cp_config: VarlenCPConfig = None,
        fuse_qkv: bool = False,
    ):
        self.attn_type = attn_type
        self.attn_impl = AttentionOpManager.get_impl(attn_type)
        self.ulysses_group = ulysses_group
        self.ring_group = ring_group
        self.uneven_cp_config = uneven_cp_config
        self.varlen_cp_config = varlen_cp_config
        self.fuse_qkv = fuse_qkv

    @ulysses_wrapper
    @ring_wrapper
    def run(
        self,
        query,
        key,
        value,
        tensor_layout,
        attn_mask=None,
        is_causal=False,
        return_lse=False,
        cur_rank_cu_seqlens_q=None,
        cur_rank_cu_seqlens_k=None,
        cur_rank_max_seqlen_q=0,
        cur_rank_max_seqlen_k=0,
        **kwargs,
    ):
        """Run parallel attention on the local rank's portion of Q/K/V.

        The Ulysses and Ring wrappers transparently handle communication
        before and after this method is called.

        Args:
            query: Query tensor, shape ``[H, S, D]`` (HND) or ``[S, H, D]`` (NHD).
            key: Key tensor, same layout as query.
            value: Value tensor, same layout as query.
            tensor_layout: ``"HND"`` or ``"NHD"``.
            attn_mask: Optional attention mask (not yet supported).
            is_causal: Whether to apply causal masking (not yet supported).
            return_lse: Must be ``False``; internally managed by the ring wrapper.
            cur_rank_cu_seqlens_q / cur_rank_cu_seqlens_k /
            cur_rank_max_seqlen_q / cur_rank_max_seqlen_k:
                Do not set these manually; the parallel wrappers populate them
                from ``uneven_cp_config`` or ``varlen_cp_config``.
            **kwargs: Additional arguments forwarded to the attention backend.

        Returns:
            torch.Tensor: Attention output for the local rank, same layout as input.
        """
        if is_causal:
            raise NotImplementedError(
                "parallel attention does not support causal attention right now"
            )

        attn_inputs = {
            "query": query,
            "key": key,
            "value": value,
            "tensor_layout": tensor_layout,
            "attn_mask": attn_mask,
            "is_causal": is_causal,
            "return_lse": return_lse,
            "cur_rank_cu_seqlens_q": cur_rank_cu_seqlens_q,
            "cur_rank_cu_seqlens_k": cur_rank_cu_seqlens_k,
            "cur_rank_max_seqlen_q": cur_rank_max_seqlen_q,
            "cur_rank_max_seqlen_k": cur_rank_max_seqlen_k,
        }

        return self.attn_impl(**attn_inputs, **kwargs)
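
One subtlety worth calling out: decorator order on `run` means `run = ulysses_wrapper(ring_wrapper(run))`, so at call time the Ulysses all-to-all fires first and the ring KV exchange happens inside each Ulysses head shard. For readers adding their own wrapper, here is a minimal skeleton of the communicate/delegate/communicate-back shape such a wrapper takes; it is illustrative, not the PR's implementation:

```python
import functools


def my_comm_wrapper(fn):
    """Skeleton of a composable parallel wrapper (illustrative only)."""

    @functools.wraps(fn)
    def wrapped(self, query, key, value, *args, **kwargs):
        # Pre-communication: e.g. an all-to-all that trades sequence shards
        # for head shards, as ulysses_wrapper does (omitted here).
        out = fn(self, query, key, value, *args, **kwargs)
        # Post-communication: reverse the exchange so the caller sees the
        # original sharding again.
        return out

    return wrapped
```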
