1414
from collections.abc import Mapping
from contextlib import nullcontext
from math import gcd, lcm
from typing import Any

import torch
@@ -129,25 +129,33 @@ def validate_te_fp8_hidden_size(te_fp8_config: Any, hidden_size: int) -> None:
129129 )
130130
131131
def get_te_fp8_bshd_sequence_multiple(batch_size: int, tp_size: int = 1) -> int:
    """Return the minimal BSHD sequence multiple for local TE FP8 Linear inputs.

    The returned value ``m`` is the smallest positive integer such that any
    sequence length ``T`` that is a multiple of ``m`` satisfies both:

    * ``T`` is divisible by ``tp_size``, and
    * ``batch_size * T / tp_size`` is divisible by 8.

    With ``tp_size == 1`` this reduces to ``8 // gcd(batch_size, 8)``, i.e.
    the smallest multiple that makes ``batch_size * T`` divisible by 8.

    Args:
        batch_size: Batch dimension ``B`` of the BSHD input; must be positive.
        tp_size: Number of ranks the sequence dimension is split across;
            must be positive. Defaults to 1 (no split).

    Returns:
        The minimal sequence-length multiple described above.

    Raises:
        ValueError: If ``batch_size`` or ``tp_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive; got {batch_size}.")
    if tp_size <= 0:
        raise ValueError(f"tp_size must be positive; got {tp_size}.")

    # B * T / tp_size divisible by 8  <=>  B * T divisible by 8 * tp_size.
    required_rows_multiple = 8 * tp_size
    # Smallest T-multiple for which B * T is divisible by 8 * tp_size.
    fp8_multiple = required_rows_multiple // gcd(batch_size, required_rows_multiple)
    # T must additionally be divisible by tp_size itself; combine both
    # divisibility constraints with the least common multiple.
    return lcm(tp_size, fp8_multiple)
137141
138142
139143def maybe_pad_bshd_inputs_for_te_fp8 (
140144 te_fp8_config : Any ,
141145 input_embeds : torch .Tensor ,
142146 attention_mask : torch .Tensor | None ,
143147 llm_kwargs : Mapping [str , Any ] | None = None ,
148+ * ,
149+ tp_size : int = 1 ,
144150) -> tuple [torch .Tensor , torch .Tensor | None , dict [str , Any ], int ]:
145151 """Pad BSHD LLM inputs for TE FP8 and return the original sequence length.
146152
147153 TE FP8 Linear requires the product of all input dimensions except the last
148- to be divisible by 8 and the last dimension to be divisible by 16. For
149- BSHD inputs this means ``B * T`` must be divisible by 8. Padding is appended
150- on the sequence dimension and can be trimmed from logits after the LLM.
154+ to be divisible by 8 and the last dimension to be divisible by 16. With
155+ BSHD sequence parallelism, local TE Linear inputs see ``B * T / TP`` rows,
156+ so padding must keep ``T`` divisible by ``TP`` and ``B * T / TP`` divisible
157+ by 8. Padding is appended on the sequence dimension and can be trimmed from
158+ logits after the LLM.
151159 """
152160 llm_kwargs = dict (llm_kwargs or {})
153161 if input_embeds .dim () != 3 :
@@ -159,7 +167,7 @@ def maybe_pad_bshd_inputs_for_te_fp8(
159167 batch_size , seq_len , hidden_size = input_embeds .shape
160168 validate_te_fp8_hidden_size (te_fp8_config , hidden_size )
161169
162- seq_multiple = get_te_fp8_bshd_sequence_multiple (batch_size )
170+ seq_multiple = get_te_fp8_bshd_sequence_multiple (batch_size , tp_size = tp_size )
163171 pad = (- seq_len ) % seq_multiple
164172 if pad == 0 :
165173 return input_embeds , attention_mask , llm_kwargs , original_seq_len
0 commit comments