Skip to content

Commit 2a83dd5

Browse files
committed
add backward pass for non-DeepEP version
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
1 parent e7db783 commit 2a83dd5

File tree

5 files changed

+325
-43
lines changed

5 files changed

+325
-43
lines changed

bionemo-recipes/models/mixtral/modeling_mixtral_te.py

Lines changed: 56 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -949,6 +949,54 @@ class _AllToAllHandle:
949949
output_split_sizes: list[int] | None = None
950950

951951

952+
class _DifferentiableAllToAll(torch.autograd.Function):
953+
"""Differentiable wrapper around dist.all_to_all_single.
954+
955+
The forward pass performs the standard all-to-all communication.
956+
The backward pass reverses the communication direction (swapping
957+
input/output split sizes) so that gradients flow correctly.
958+
"""
959+
960+
@staticmethod
961+
def forward(
962+
ctx,
963+
input: torch.Tensor,
964+
output_split_sizes: list[int],
965+
input_split_sizes: list[int],
966+
group: dist.ProcessGroup,
967+
) -> torch.Tensor:
968+
"""Perform all-to-all forward and save sizes for backward."""
969+
ctx.input_split_sizes = input_split_sizes
970+
ctx.output_split_sizes = output_split_sizes
971+
ctx.group = group
972+
output = torch.empty(
973+
sum(output_split_sizes),
974+
input.shape[1],
975+
device=input.device,
976+
dtype=input.dtype,
977+
)
978+
dist.all_to_all_single(output, input.contiguous(), output_split_sizes, input_split_sizes, group=group)
979+
return output
980+
981+
@staticmethod
982+
def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None, None, None]:
983+
"""Reverse all-to-all: swap input and output split sizes."""
984+
grad_input = torch.empty(
985+
sum(ctx.input_split_sizes),
986+
grad_output.shape[1],
987+
device=grad_output.device,
988+
dtype=grad_output.dtype,
989+
)
990+
dist.all_to_all_single(
991+
grad_input,
992+
grad_output.contiguous(),
993+
ctx.input_split_sizes,
994+
ctx.output_split_sizes,
995+
group=ctx.group,
996+
)
997+
return grad_input, None, None, None
998+
999+
9521000
class AllToAllTokenDispatcher:
9531001
"""TokenDispatcher using NCCL all-to-all for expert-parallel communication.
9541002
@@ -998,10 +1046,7 @@ def dispatch(
9981046
# Compute m_splits: number of tokens per expert
9991047
m_splits_tensor = torch.bincount(selected_experts.reshape(-1), minlength=self.num_experts).int()
10001048

1001-
if self.ep_size > 1:
1002-
assert self._ep_group is not None, (
1003-
"EP group must be set via set_ep_group() before dispatch when ep_size > 1"
1004-
)
1049+
if self._ep_group is not None:
10051050
ep_group = self._ep_group
10061051

10071052
# Token counts per expert, reshaped to [ep_size, num_local_experts]
@@ -1016,14 +1061,10 @@ def dispatch(
10161061
output_split_sizes = recv_counts.sum(dim=1).tolist()
10171062
local_m_splits = recv_counts.sum(dim=0).int().tolist()
10181063

1019-
# Dispatch tokens to expert-owning ranks
1020-
recv_tokens = torch.empty(
1021-
sum(output_split_sizes),
1022-
self.hidden_size,
1023-
device=permuted_hidden.device,
1024-
dtype=permuted_hidden.dtype,
1064+
# Dispatch tokens to expert-owning ranks (differentiable)
1065+
recv_tokens = _DifferentiableAllToAll.apply(
1066+
permuted_hidden, output_split_sizes, input_split_sizes, ep_group
10251067
)
1026-
dist.all_to_all_single(recv_tokens, permuted_hidden, output_split_sizes, input_split_sizes, group=ep_group)
10271068

10281069
# Sort received tokens by local expert index.
10291070
# After all_to_all layout is [src0_exp0..src0_expL, src1_exp0..src1_expL, ...].
@@ -1060,21 +1101,14 @@ def combine(self, expert_output: torch.Tensor, handle: _AllToAllHandle) -> torch
10601101
Returns:
10611102
Combined output tensor of shape ``[N, H]`` with routing weights applied.
10621103
"""
1063-
if self.ep_size > 1:
1104+
if self._ep_group is not None:
10641105
assert handle.unsort_indices is not None
1065-
# Unsort back to source-rank-grouped order and reverse all_to_all
1066-
combined = torch.empty(
1067-
sum(handle.input_split_sizes),
1068-
self.hidden_size,
1069-
device=expert_output.device,
1070-
dtype=expert_output.dtype,
1071-
)
1072-
dist.all_to_all_single(
1073-
combined,
1106+
# Unsort back to source-rank-grouped order and reverse all_to_all (differentiable)
1107+
combined = _DifferentiableAllToAll.apply(
10741108
expert_output[handle.unsort_indices],
10751109
handle.input_split_sizes,
10761110
handle.output_split_sizes,
1077-
group=self._ep_group,
1111+
self._ep_group,
10781112
)
10791113
else:
10801114
combined = expert_output

bionemo-recipes/models/mixtral/tests/test_ep.py

Lines changed: 118 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -40,15 +40,15 @@
4040
)
4141

4242

43-
@requires_multi_gpu
44-
def test_ep2_matches_ep1(unused_tcp_port):
45-
"""Test that EP=2 produces the same logits as EP=1."""
43+
def _run_torchrun(test_name: str, port: int):
44+
"""Run a named test worker via torchrun with 2 GPUs."""
4645
cmd = [
4746
"torchrun",
4847
"--nproc_per_node=2",
4948
"--rdzv-backend=c10d",
50-
f"--rdzv-endpoint=localhost:{unused_tcp_port}",
49+
f"--rdzv-endpoint=localhost:{port}",
5150
str(Path(__file__).resolve()),
51+
test_name,
5252
]
5353
result = subprocess.run(
5454
cmd,
@@ -62,7 +62,19 @@ def test_ep2_matches_ep1(unused_tcp_port):
6262
if result.returncode != 0:
6363
print(f"STDOUT:\n{result.stdout}")
6464
print(f"STDERR:\n{result.stderr}")
65-
pytest.fail(f"EP equivalence test failed with exit code {result.returncode}")
65+
pytest.fail(f"EP {test_name} test failed with exit code {result.returncode}")
66+
67+
68+
@requires_multi_gpu
def test_ep2_matches_ep1(unused_tcp_port):
    """EP=2 forward pass must reproduce the EP=1 logits (run via torchrun)."""
    worker = "forward"
    _run_torchrun(worker, unused_tcp_port)
72+
73+
74+
@requires_multi_gpu
def test_ep2_backward_matches_ep1(unused_tcp_port):
    """EP=2 backward pass must reproduce the EP=1 gradients (run via torchrun)."""
    worker = "backward"
    _run_torchrun(worker, unused_tcp_port)
6678

6779

6880
# ---------------------------------------------------------------------------
@@ -85,18 +97,22 @@ def _distribute_state_dict(full_state_dict: dict, model: torch.nn.Module, device
8597
from torch.distributed.tensor import DTensor, distribute_tensor
8698

8799
distributed_state: dict = {}
88-
for key, value in model.state_dict().items():
89-
if key not in full_state_dict:
100+
# model.state_dict() filters _extra_state keys via the NVMixtralPreTrainedModel
101+
# override, so use nn.Module.state_dict to get the unfiltered dict that includes
102+
# TransformerEngine _extra_state entries required by load_state_dict(strict=True).
103+
for key, value in torch.nn.Module.state_dict(model).items():
104+
if key.endswith("_extra_state"):
105+
distributed_state[key] = value
106+
elif key not in full_state_dict:
90107
continue
91-
full_value = full_state_dict[key]
92-
if isinstance(value, DTensor):
108+
elif isinstance(value, DTensor):
93109
distributed_state[key] = distribute_tensor(
94-
full_value.to(device),
110+
full_state_dict[key].to(device),
95111
value.device_mesh,
96112
list(value.placements),
97113
)
98114
else:
99-
distributed_state[key] = full_value
115+
distributed_state[key] = full_state_dict[key]
100116
return distributed_state
101117

102118

@@ -183,5 +199,95 @@ def _run_ep_equivalence_test():
183199
torch.distributed.destroy_process_group()
184200

185201

202+
def _run_ep_backward_test() -> None:
    """Worker function for the EP backward equivalence test.

    1. Init distributed (one process per GPU; EP size = world size).
    2. Every rank creates an EP=1 model, runs forward+backward, saves gradients.
    3. Create EP=2 model with sharded weights, runs forward+backward, compares gradients.
    """
    from torch.distributed.tensor.device_mesh import DeviceMesh

    dist_config = DistributedConfig()
    device = torch.device(f"cuda:{dist_config.local_rank}")
    torch.cuda.set_device(device)
    torch.distributed.init_process_group(backend="nccl", device_id=device)
    ep_size = dist_config.world_size

    # --- Phase 1: EP=1 reference (every rank computes independently) ---
    config_ep1 = create_small_mixtral_config(expert_parallel_size=1)
    # Same seed on every rank so all ranks build identical reference weights.
    torch.manual_seed(0)
    model_ep1 = NVMixtralForCausalLM(config_ep1).to(dtype=torch.bfloat16, device=device)

    batch = get_dummy_batch(config_ep1.vocab_size, seq_len=32, batch_size=2, device=device)

    outputs_ep1 = model_ep1(**batch)
    outputs_ep1.loss.backward()

    # Snapshot gradients and loss on CPU so the model can be freed below.
    ref_grads = {name: p.grad.detach().clone().cpu() for name, p in model_ep1.named_parameters() if p.grad is not None}
    loss_ep1 = outputs_ep1.loss.detach().clone().cpu()

    full_state_dict = {k: v.clone().cpu() for k, v in model_ep1.state_dict().items()}
    # Free the reference model before building the EP model to limit peak GPU memory.
    del model_ep1, outputs_ep1
    torch.cuda.empty_cache()

    # --- Phase 2: EP=2 distributed run ---
    config_ep2 = create_small_mixtral_config(expert_parallel_size=ep_size)
    # Re-seed so any seed-dependent initialization matches the reference run.
    torch.manual_seed(0)
    model_ep2 = NVMixtralForCausalLM(config_ep2).to(dtype=torch.bfloat16, device=device)

    ep_mesh = DeviceMesh("cuda", list(range(ep_size)))
    ep_group = ep_mesh.get_group()
    model_ep2.model.set_ep_groups(ep_group, ep_mesh)

    # Shard the EP=1 weights onto the EP mesh and load them exactly (strict).
    distributed_state = _distribute_state_dict(full_state_dict, model_ep2, device)
    model_ep2.load_state_dict(distributed_state, strict=True)

    batch_cuda = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}

    outputs_ep2 = model_ep2(**batch_cuda)
    outputs_ep2.loss.backward()

    test_grads = {}
    for name, p in model_ep2.named_parameters():
        if p.grad is not None:
            g = p.grad
            # Sharded (DTensor) gradients expose full_tensor(); gather them so
            # they are comparable against the unsharded reference gradients.
            if hasattr(g, "full_tensor"):
                g = g.full_tensor()
            test_grads[name] = g.detach().clone().cpu()
    loss_ep2 = outputs_ep2.loss.detach().clone().cpu()

    # --- Phase 3: Compare on rank 0 ---
    if dist_config.is_main_process():
        torch.testing.assert_close(
            loss_ep2,
            loss_ep1,
            atol=1e-3,
            rtol=1e-3,
            msg="EP=2 backward: loss does not match EP=1 loss",
        )

        # All EP=1 parameters should have gradients in EP=2 as well
        for name in ref_grads:
            assert name in test_grads, f"EP=2 model missing gradient for {name}"

        # Loose tolerances: presumably to absorb bfloat16 accumulation-order
        # differences between the EP=1 and EP=2 paths — TODO confirm.
        for name in ref_grads:
            torch.testing.assert_close(
                test_grads[name],
                ref_grads[name],
                atol=3e-2,
                rtol=3e-2,
                msg=f"EP=2 gradient mismatch for {name}",
            )

        print("EP backward test PASSED: EP=2 gradients match EP=1")

    torch.distributed.destroy_process_group()
286+
287+
186288
if __name__ == "__main__":
    # torchrun re-executes this file on each rank; the first CLI argument
    # selects which worker to run (defaulting to the forward-equivalence test).
    mode = sys.argv[1] if len(sys.argv) > 1 else "forward"
    if mode == "backward":
        _run_ep_backward_test()
    else:
        _run_ep_equivalence_test()

bionemo-recipes/models/mixtral/tests/test_fsdp_ep.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -62,18 +62,22 @@ def _distribute_state_dict(full_state_dict: dict, model: torch.nn.Module, device
6262
from torch.distributed.tensor import DTensor, distribute_tensor
6363

6464
distributed_state: dict = {}
65-
for key, value in model.state_dict().items():
66-
if key not in full_state_dict:
65+
# model.state_dict() filters _extra_state keys via the NVMixtralPreTrainedModel
66+
# override, so use nn.Module.state_dict to get the unfiltered dict that includes
67+
# TransformerEngine _extra_state entries required by load_state_dict(strict=True).
68+
for key, value in torch.nn.Module.state_dict(model).items():
69+
if key.endswith("_extra_state"):
70+
distributed_state[key] = value
71+
elif key not in full_state_dict:
6772
continue
68-
full_value = full_state_dict[key]
69-
if isinstance(value, DTensor):
73+
elif isinstance(value, DTensor):
7074
distributed_state[key] = distribute_tensor(
71-
full_value.to(device),
75+
full_state_dict[key].to(device),
7276
value.device_mesh,
7377
list(value.placements),
7478
)
7579
else:
76-
distributed_state[key] = full_value
80+
distributed_state[key] = full_state_dict[key]
7781
return distributed_state
7882

7983

bionemo-recipes/models/mixtral/tests/test_hybrid_ep.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -332,14 +332,16 @@ def _run_backward_test():
332332

333333
assert len(ref_grads) > 0, "AllToAll model produced no gradients"
334334
assert len(test_grads) > 0, "FusedTokenRouter model produced no gradients"
335+
336+
# Both AllToAll (via _DifferentiableAllToAll) and FusedTokenRouter provide
337+
# differentiable dispatch/combine, so all parameters should have gradients.
335338
assert ref_grads.keys() == test_grads.keys(), (
336339
f"Gradient key mismatch: "
337-
f"only in ref={ref_grads.keys() - test_grads.keys()}, "
338-
f"only in test={test_grads.keys() - ref_grads.keys()}"
340+
f"AllToAll-only={ref_grads.keys() - test_grads.keys()}, "
341+
f"Fused-only={test_grads.keys() - ref_grads.keys()}"
339342
)
340343

341344
for name in ref_grads:
342-
assert name in test_grads, f"FusedTokenRouter missing gradient for parameter: {name}"
343345
torch.testing.assert_close(
344346
test_grads[name],
345347
ref_grads[name],

0 commit comments

Comments
 (0)