
Commit ffecdc9

Update tests to reflect change in overlapping logic
Signed-off-by: Deepak Narayanan <[email protected]>
1 parent 59f5560 commit ffecdc9

2 files changed: +113 -88 lines changed


tests/unit_tests/distributed/test_grad_sync_with_expert_parallel.py

Lines changed: 72 additions & 56 deletions
@@ -1,3 +1,5 @@
+# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
+
 import contextlib
 from typing import Optional

@@ -169,15 +171,20 @@ def test_grad_sync(
         )
         != 0
     ):
-        # With above conditions, the data in param_and_grad_buffer.grad_data[0] equals to 1/data_parallel_word_size
-        # When average_in_collective=False, the grad data is always first scaled by 1/data_parallel_word_size and then summed by AR/RS
-        # when use_distributed_optimizer=True, only for rank=0 param_and_grad_buffer.grad_data[0] is updated, for other ranks
-        # another shard of grad_data is updated while param_and_grad_buffer.grad_data[0] is unchanged (=1/data_parallel_word_size)
+        # With above conditions, the data in param_and_grad_buffer.grad_data[0] equals to
+        # 1/data_parallel_world_size.
+        # When average_in_collective=False, the grad data is always first scaled by
+        # 1/data_parallel_world_size and then summed by AR/RS.
+        # When use_distributed_optimizer=True, only for rank=0,
+        # param_and_grad_buffer.grad_data[0] is updated. For other ranks another shard of
+        # grad_data is updated while param_and_grad_buffer.grad_data[0] is unchanged
+        # (=1/data_parallel_world_size).
         non_ep_expected_grad_data_value_after_collective /= (
             parallel_state.get_data_parallel_world_size()
         )
     if ep_size > 1:
-        # For MoE models with exper parallelism, each expert will receive tokens from EPxETP times batches, such that the expert gradient will be EPxETP times after backward,
+        # For MoE models with expert parallelism, each expert will receive tokens from EPxETP
+        # times more batches, so the expert gradient will be EPxETP times larger after backward,
         # and the expected gradient after collective should be 1.0, the same as dense params.
         ep_param_and_grad_buffer.grad_data.data.fill_(float(ep_size * etp_size))
         ep_expected_grad_data_value_after_collective = 1
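An editorial aside to make the scaling arithmetic in the comment block above concrete: the short sketch below is illustrative only and is not part of the commit. The sizes are made-up examples rather than values read from parallel_state, and the 1/dp_size pre-scale applied to expert grads is an assumption inferred from the values the test asserts.

# Illustrative sketch of the expected-value arithmetic described in the comments above.
dp_size = 8
ep_size, etp_size = 2, 2
edp_size = dp_size // (ep_size * etp_size)  # expert-data-parallel size

# Dense params: every rank fills its grad buffer with 1.0.
# average_in_collective=False: pre-scale by 1/dp_size, then SUM over dp_size ranks.
prescaled_dense = 1.0 / dp_size
reduced_dense = prescaled_dense * dp_size
assert reduced_dense == 1.0
# With the distributed optimizer the SUM is a reduce-scatter, so a rank that does
# not own the shard holding grad_data[0] still sees only the pre-scaled value.
assert prescaled_dense == 1.0 / dp_size

# Expert params: each expert sees EPxETP times more tokens, so the test fills
# the buffer with ep_size * etp_size before the collective.
fill_expert = float(ep_size * etp_size)
prescaled_expert = fill_expert / dp_size       # value left on non-owning EDP ranks
reduced_expert = prescaled_expert * edp_size   # SUM over the expert-data-parallel group
assert prescaled_expert == 1.0 / edp_size
assert reduced_expert == 1.0                   # same as dense params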
@@ -186,67 +193,76 @@ def test_grad_sync(
             and (not average_in_collective)
             and parallel_state.get_expert_data_parallel_rank(partial_expert_data_parallel=True) != 0
         ):
-            # With above conditions, the data in param_and_grad_buffer.grad_data[0] equals to 1/EDP
-            # When average_in_collective=False, the grad data is always first scaled by expert_data_parallel_size and then summed by AR/RS
-            # after SUM collective in expert_data_group, the scale will be 1.0.
+            # With above conditions, the data in param_and_grad_buffer.grad_data[0] equals to 1/EDP.
+            # When average_in_collective=False, the grad data is always first scaled by
+            # expert_data_parallel_size and then summed by AR/RS.
+            # After SUM collective in expert_data_group, the scale will be 1.0.
             ep_expected_grad_data_value_after_collective /= (
                 parallel_state.get_expert_data_parallel_world_size()
             )

     params = list(model.parameters())
     map_bucket_to_last_param_idx = {}
-    for i, param in enumerate(params):
-        if not (param in param_to_bucket_group):
-            # it means this parameter is not on this device, skip
-            continue
-        bucket_group = param_to_bucket_group[param]
-        if bucket_group in map_bucket_to_last_param_idx:
-            param_idx = map_bucket_to_last_param_idx[bucket_group] + 1
-        else:
-            param_idx = 0
-        map_bucket_to_last_param_idx[bucket_group] = param_idx
-
-        register_grad_sync_context = (
-            contextlib.nullcontext() if overlap_grad_reduce else pytest.raises(AssertionError)
-        )
-        finish_grad_sync_context = contextlib.nullcontext()
-        if (
-            param_idx < (len(bucket_group.params) - 1)
-            and overlap_grad_reduce
-            and num_distributed_optimizer_instances == 1
-        ):
-            # Can't finish grad sync until all params have been registered ready.
-            finish_grad_sync_context = pytest.raises(AssertionError)
-
-        with register_grad_sync_context:
-            bucket_group.register_grad_ready(param)
-        with finish_grad_sync_context:
-            # When overlap_grad_reduce is True, this should throw an assertion error until all
-            # params in the model have registered their grad above.
-            # When overlap_grad_reduce is False, the collective is forced through.
-            bucket_group.finish_grad_sync()
-
-        if bucket_group in non_ep_bucket_groups:
-            expected_grad_data_value = non_ep_expected_grad_data_value_after_collective
-        else:
-            expected_grad_data_value = ep_expected_grad_data_value_after_collective
-        # Before gradient sync, the gradient value should keep original.
-        if overlap_grad_reduce and param_idx < (len(bucket_group.params) - 1):
-            if bucket_group in non_ep_bucket_groups:
-                expected_grad_data_value = 1
+    for iteration in range(2):
+        for i, param in enumerate(params):
+            if not (param in param_to_bucket_group):
+                # it means this parameter is not on this device, skip
+                continue
+            bucket_group = param_to_bucket_group[param]
+            if bucket_group in map_bucket_to_last_param_idx:
+                param_idx = map_bucket_to_last_param_idx[bucket_group] + 1
            else:
-                expected_grad_data_value = ep_size * etp_size
+                param_idx = 0
+            map_bucket_to_last_param_idx[bucket_group] = param_idx
+
+            register_grad_sync_context = (
+                contextlib.nullcontext() if overlap_grad_reduce else pytest.raises(AssertionError)
+            )
+            finish_grad_sync_context = contextlib.nullcontext()
+            if (
+                param_idx < (len(bucket_group.params) - 1)
+                and overlap_grad_reduce
+                and num_distributed_optimizer_instances == 1
+            ):
+                # Can't finish grad sync until all params have been registered ready.
+                finish_grad_sync_context = pytest.raises(AssertionError)
+
+            with register_grad_sync_context:
+                bucket_group.register_grad_ready(param)
+            # Don't call finish_grad_sync() multiple times in the first iteration when
+            # golden_per_param_grad_ready_counts is being populated.
+            if iteration == 0 and i < (len(params) - 1):
+                continue
+            with finish_grad_sync_context:
+                # When overlap_grad_reduce is True, this should throw an assertion error until all
+                # params in the model have registered their grad above.
+                # When overlap_grad_reduce is False, the collective is forced through.
+                bucket_group.finish_grad_sync()

-        if bucket_group in non_ep_bucket_groups:
-            assert non_ep_param_and_grad_buffer.grad_data[0] == expected_grad_data_value
-        else:
-            assert ep_param_and_grad_buffer.grad_data[0] == expected_grad_data_value
+            if bucket_group in non_ep_bucket_groups:
+                expected_grad_data_value = non_ep_expected_grad_data_value_after_collective
+            else:
+                expected_grad_data_value = ep_expected_grad_data_value_after_collective
+            # Before gradient sync, the gradient value should keep original.
+            if overlap_grad_reduce and param_idx < (len(bucket_group.params) - 1):
+                if bucket_group in non_ep_bucket_groups:
+                    expected_grad_data_value = 1
+                else:
+                    expected_grad_data_value = ep_size * etp_size

-        if not overlap_grad_reduce:
-            # Reset grad_data for subsequent collectives.
            if bucket_group in non_ep_bucket_groups:
-                non_ep_param_and_grad_buffer.grad_data.data.fill_(1.0)
+                assert non_ep_param_and_grad_buffer.grad_data[0] == expected_grad_data_value
            else:
-                ep_param_and_grad_buffer.grad_data.data.fill_(float(ep_size * etp_size))
+                assert ep_param_and_grad_buffer.grad_data[0] == expected_grad_data_value
+
+            if not overlap_grad_reduce:
+                # Reset grad_data for subsequent collectives.
+                if bucket_group in non_ep_bucket_groups:
+                    non_ep_param_and_grad_buffer.grad_data.data.fill_(1.0)
+                else:
+                    ep_param_and_grad_buffer.grad_data.data.fill_(float(ep_size * etp_size))
+
+        # Call reset to set .is_first_batch to False.
+        bucket_group.reset()

     Utils.destroy_model_parallel()
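The two-pass loop added above depends on the bucket group recording its expected register_grad_ready() counts on the first batch (the golden_per_param_grad_ready_counts mentioned in the comments) and enforcing them only after reset() flips .is_first_batch. The toy class below is a self-contained sketch of that pattern, not the Megatron-Core implementation; ToyBucketGroup and its attributes are invented for illustration.

import contextlib
import pytest


class ToyBucketGroup:
    """Illustrative stand-in for a bucket group (hypothetical, not Megatron-Core)."""

    def __init__(self, params):
        self.params = list(params)
        self.is_first_batch = True           # mirrors the .is_first_batch flag in the diff comment
        self.golden_grad_ready_count = None  # populated during the first batch
        self._ready = 0

    def register_grad_ready(self, param):
        assert param in self.params
        self._ready += 1

    def finish_grad_sync(self):
        if self.is_first_batch:
            # First batch: record how many register_grad_ready() calls to expect.
            self.golden_grad_ready_count = self._ready
        else:
            # Later batches: every param must have registered before syncing.
            assert self._ready == self.golden_grad_ready_count
        self._ready = 0

    def reset(self):
        # End of batch; from now on the golden count is enforced.
        self.is_first_batch = False


params = ["w0", "w1", "w2"]
group = ToyBucketGroup(params)

for iteration in range(2):
    for i, param in enumerate(params):
        group.register_grad_ready(param)
        if iteration == 0 and i < len(params) - 1:
            # Calibration batch: calling finish_grad_sync() here would record a
            # wrong (too small) golden count, so only the last call goes through.
            continue
        expected = (
            contextlib.nullcontext()
            if i == len(params) - 1
            else pytest.raises(AssertionError)
        )
        with expected:
            group.finish_grad_sync()
    # Mirrors bucket_group.reset() in the test: flips is_first_batch to False.
    group.reset()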

tests/unit_tests/distributed/test_param_and_grad_buffer.py

Lines changed: 41 additions & 32 deletions
@@ -1,3 +1,5 @@
+# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
+
 import contextlib
 import math
 from typing import Optional
@@ -164,7 +166,6 @@ def _pad_param_if_needed(numel_unpadded):
 @pytest.mark.parametrize("overlap_grad_reduce", [False, True])
 @pytest.mark.parametrize("average_in_collective", [False, True])
 @pytest.mark.parametrize("num_distributed_optimizer_instances", [1, 2])
-# @pytest.mark.flaky
 def test_grad_sync(
     use_distributed_optimizer: bool,
     overlap_grad_reduce: bool,
@@ -216,36 +217,44 @@ def test_grad_sync(
         expected_grad_data_value_after_collective /= parallel_state.get_data_parallel_world_size()

     params = list(model.parameters())
-    for i, param in enumerate(params):
-        assert param in param_to_bucket_group
-        bucket_group = param_to_bucket_group[param]
-        register_grad_sync_context = (
-            contextlib.nullcontext() if overlap_grad_reduce else pytest.raises(AssertionError)
-        )
-        finish_grad_sync_context = contextlib.nullcontext()
-        if (
-            i < (len(params) - 1)
-            and overlap_grad_reduce
-            and num_distributed_optimizer_instances == 1
-        ):
-            # Can't finish grad sync until all params have been registered ready.
-            finish_grad_sync_context = pytest.raises(AssertionError)
-
-        with register_grad_sync_context:
-            bucket_group.register_grad_ready(param)
-        with finish_grad_sync_context:
-            # When overlap_grad_reduce is True, this should throw an assertion error until all
-            # params in the model have registered their grad above.
-            # When overlap_grad_reduce is False, the collective is forced through.
-            bucket_group.finish_grad_sync()
-
-        expected_grad_data_value = expected_grad_data_value_after_collective
-        if overlap_grad_reduce and i < (len(params) - 1):
-            expected_grad_data_value = 1
-        assert param_and_grad_buffer.grad_data[0] == expected_grad_data_value
-
-        if not overlap_grad_reduce:
-            # Reset grad_data for subsequent collectives.
-            param_and_grad_buffer.grad_data.data.fill_(1.0)
+    for iteration in range(2):
+        for i, param in enumerate(params):
+            assert param in param_to_bucket_group
+            bucket_group = param_to_bucket_group[param]
+            register_grad_sync_context = (
+                contextlib.nullcontext() if overlap_grad_reduce else pytest.raises(AssertionError)
+            )
+            finish_grad_sync_context = contextlib.nullcontext()
+            if (
+                i < (len(params) - 1)
+                and overlap_grad_reduce
+                and num_distributed_optimizer_instances == 1
+            ):
+                # Can't finish grad sync until all params have been registered ready.
+                finish_grad_sync_context = pytest.raises(AssertionError)
+
+            with register_grad_sync_context:
+                bucket_group.register_grad_ready(param)
+            # Don't call finish_grad_sync() multiple times in the first iteration when
+            # golden_per_param_grad_ready_counts is being populated.
+            if iteration == 0 and i < (len(params) - 1):
+                continue
+            with finish_grad_sync_context:
+                # When overlap_grad_reduce is True, this should throw an assertion error until all
+                # params in the model have registered their grad above.
+                # When overlap_grad_reduce is False, the collective is forced through.
+                bucket_group.finish_grad_sync()
+
+            expected_grad_data_value = expected_grad_data_value_after_collective
+            if overlap_grad_reduce and i < (len(params) - 1):
+                expected_grad_data_value = 1
+            assert param_and_grad_buffer.grad_data[0] == expected_grad_data_value
+
+            if not overlap_grad_reduce:
+                # Reset grad_data for subsequent collectives.
+                param_and_grad_buffer.grad_data.data.fill_(1.0)
+
+        # Call reset to set .is_first_batch to False.
+        bucket_group.reset()

     Utils.destroy_model_parallel()
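The same expectation-context pattern drives this test as well. The helper below restates that selection logic as a standalone sketch; it is not part of the commit, and the function name and signature are invented for illustration.

import contextlib
import pytest


def expectation_contexts(i, num_params, overlap_grad_reduce, num_distributed_optimizer_instances):
    """Return (register_ctx, finish_ctx) for the i-th parameter, per the test logic above."""
    # register_grad_ready() is only expected to succeed when grad reduction is overlapped;
    # otherwise the test expects an AssertionError.
    register_ctx = (
        contextlib.nullcontext() if overlap_grad_reduce else pytest.raises(AssertionError)
    )
    # finish_grad_sync() must fail while earlier params have not registered yet, but only
    # when overlapping is on and there is a single distributed-optimizer instance.
    finish_ctx = contextlib.nullcontext()
    if (
        i < (num_params - 1)
        and overlap_grad_reduce
        and num_distributed_optimizer_instances == 1
    ):
        finish_ctx = pytest.raises(AssertionError)
    return register_ctx, finish_ctx


# With overlapping enabled and one optimizer instance, finishing before the last
# param is expected to assert; at the last param it must succeed.
_, fin_early = expectation_contexts(0, 3, True, 1)
_, fin_last = expectation_contexts(2, 3, True, 1)
assert isinstance(fin_last, contextlib.nullcontext)
assert not isinstance(fin_early, contextlib.nullcontext)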
