Skip to content

Commit 17e203f

Browse files
committed
Fix TE FP8 padding with tensor parallelism
1 parent 329f23a commit 17e203f

3 files changed

Lines changed: 57 additions & 9 deletions

File tree

nemo/collections/speechlm2/models/salm_automodel.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -179,11 +179,13 @@ def forward(
179179
te_fp8_config = (automodel_backend_config or {}).get("te_fp8", None)
180180
original_seq_len = input_embeds.shape[1] if input_embeds.dim() == 3 else input_embeds.shape[0]
181181
if cache is None and llm_kwargs.get("qkv_format", None) != "thd":
182+
tp_size = self.device_mesh["tp"].size() if self._use_tp else 1
182183
input_embeds, attention_mask, llm_kwargs, original_seq_len = maybe_pad_bshd_inputs_for_te_fp8(
183184
te_fp8_config,
184185
input_embeds,
185186
attention_mask,
186187
llm_kwargs,
188+
tp_size=tp_size,
187189
)
188190
with te_fp8_context(automodel_backend_config):
189191
out = self.llm(
@@ -462,7 +464,7 @@ def test_step(self, *args: Any, **kwargs: Any):
462464

463465
def backward(self, *args, **kwargs):
464466
self._setup_moe_fsdp_sync()
465-
with loss_parallel(), te_fp8_context(self.cfg.get("automodel_backend", None)):
467+
with loss_parallel():
466468
super().backward(*args, **kwargs)
467469

468470
def on_before_zero_grad(self, optimizer) -> None:

nemo/collections/speechlm2/parts/fp8.py

Lines changed: 16 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@
1414

1515
from collections.abc import Mapping
1616
from contextlib import nullcontext
17-
from math import gcd
17+
from math import gcd, lcm
1818
from typing import Any
1919

2020
import torch
@@ -129,25 +129,33 @@ def validate_te_fp8_hidden_size(te_fp8_config: Any, hidden_size: int) -> None:
129129
)
130130

131131

132-
def get_te_fp8_bshd_sequence_multiple(batch_size: int) -> int:
133-
"""Return the minimal sequence-length multiple so B*T is divisible by 8."""
132+
def get_te_fp8_bshd_sequence_multiple(batch_size: int, tp_size: int = 1) -> int:
133+
"""Return the minimal BSHD sequence multiple for local TE FP8 Linear inputs."""
134134
if batch_size <= 0:
135135
raise ValueError(f"batch_size must be positive; got {batch_size}.")
136-
return 8 // gcd(batch_size, 8)
136+
if tp_size <= 0:
137+
raise ValueError(f"tp_size must be positive; got {tp_size}.")
138+
139+
fp8_multiple = (8 * tp_size) // gcd(batch_size, 8 * tp_size)
140+
return lcm(tp_size, fp8_multiple)
137141

138142

139143
def maybe_pad_bshd_inputs_for_te_fp8(
140144
te_fp8_config: Any,
141145
input_embeds: torch.Tensor,
142146
attention_mask: torch.Tensor | None,
143147
llm_kwargs: Mapping[str, Any] | None = None,
148+
*,
149+
tp_size: int = 1,
144150
) -> tuple[torch.Tensor, torch.Tensor | None, dict[str, Any], int]:
145151
"""Pad BSHD LLM inputs for TE FP8 and return the original sequence length.
146152
147153
TE FP8 Linear requires the product of all input dimensions except the last
148-
to be divisible by 8 and the last dimension to be divisible by 16. For
149-
BSHD inputs this means ``B * T`` must be divisible by 8. Padding is appended
150-
on the sequence dimension and can be trimmed from logits after the LLM.
154+
to be divisible by 8 and the last dimension to be divisible by 16. With
155+
BSHD sequence parallelism, local TE Linear inputs see ``B * T / TP`` rows,
156+
so padding must keep ``T`` divisible by ``TP`` and ``B * T / TP`` divisible
157+
by 8. Padding is appended on the sequence dimension and can be trimmed from
158+
logits after the LLM.
151159
"""
152160
llm_kwargs = dict(llm_kwargs or {})
153161
if input_embeds.dim() != 3:
@@ -159,7 +167,7 @@ def maybe_pad_bshd_inputs_for_te_fp8(
159167
batch_size, seq_len, hidden_size = input_embeds.shape
160168
validate_te_fp8_hidden_size(te_fp8_config, hidden_size)
161169

162-
seq_multiple = get_te_fp8_bshd_sequence_multiple(batch_size)
170+
seq_multiple = get_te_fp8_bshd_sequence_multiple(batch_size, tp_size=tp_size)
163171
pad = (-seq_len) % seq_multiple
164172
if pad == 0:
165173
return input_embeds, attention_mask, llm_kwargs, original_seq_len

tests/collections/speechlm2/test_fp8.py

Lines changed: 38 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -178,6 +178,25 @@ def test_maybe_pad_bshd_inputs_for_te_fp8_noops_without_te_fp8():
178178
assert original_seq_len == 5
179179

180180

181+
@pytest.mark.parametrize(
182+
("batch_size", "tp_size", "expected_multiple"),
183+
[
184+
(1, 1, 8),
185+
(2, 1, 4),
186+
(16, 4, 4),
187+
(1, 4, 32),
188+
(2, 4, 16),
189+
(8, 4, 4),
190+
],
191+
)
192+
def test_get_te_fp8_bshd_sequence_multiple_accounts_for_tp(batch_size, tp_size, expected_multiple):
193+
multiple = fp8.get_te_fp8_bshd_sequence_multiple(batch_size, tp_size=tp_size)
194+
195+
assert multiple == expected_multiple
196+
assert multiple % tp_size == 0
197+
assert (batch_size * multiple // tp_size) % 8 == 0
198+
199+
181200
def test_maybe_pad_bshd_inputs_for_te_fp8_pads_sequence_tensors():
182201
input_embeds = torch.ones(2, 5, 16)
183202
attention_mask = torch.ones(2, 5, dtype=torch.bool)
@@ -200,6 +219,25 @@ def test_maybe_pad_bshd_inputs_for_te_fp8_pads_sequence_tensors():
200219
assert (llm_kwargs["position_ids"][:, 5:] == 0).all()
201220

202221

222+
def test_maybe_pad_bshd_inputs_for_te_fp8_accounts_for_tp():
223+
input_embeds = torch.ones(16, 5, 16)
224+
attention_mask = torch.ones(16, 5, dtype=torch.bool)
225+
226+
padded, padded_mask, llm_kwargs, original_seq_len = fp8.maybe_pad_bshd_inputs_for_te_fp8(
227+
DictConfig({"recipe": "block"}),
228+
input_embeds,
229+
attention_mask,
230+
tp_size=4,
231+
)
232+
233+
assert original_seq_len == 5
234+
assert padded.shape == (16, 8, 16)
235+
assert padded.shape[1] % 4 == 0
236+
assert (padded.shape[0] * padded.shape[1] // 4) % 8 == 0
237+
assert padded_mask.shape == (16, 8)
238+
assert llm_kwargs == {}
239+
240+
203241
def test_te_fp8_hidden_size_validation():
204242
te_fp8_config = DictConfig({"recipe": "block"})
205243

0 commit comments

Comments
 (0)