Commit d039292

Update tests to reflect change in overlapping logic
1 parent 59f5560 commit d039292

File tree

2 files changed, +95 -80 lines changed


tests/unit_tests/distributed/test_grad_sync_with_expert_parallel.py

Lines changed: 56 additions & 48 deletions
@@ -195,58 +195,66 @@ def test_grad_sync(
 
     params = list(model.parameters())
     map_bucket_to_last_param_idx = {}
-    for i, param in enumerate(params):
-        if not (param in param_to_bucket_group):
-            # it means this parameter is not on this device, skip
-            continue
-        bucket_group = param_to_bucket_group[param]
-        if bucket_group in map_bucket_to_last_param_idx:
-            param_idx = map_bucket_to_last_param_idx[bucket_group] + 1
-        else:
-            param_idx = 0
-        map_bucket_to_last_param_idx[bucket_group] = param_idx
-
-        register_grad_sync_context = (
-            contextlib.nullcontext() if overlap_grad_reduce else pytest.raises(AssertionError)
-        )
-        finish_grad_sync_context = contextlib.nullcontext()
-        if (
-            param_idx < (len(bucket_group.params) - 1)
-            and overlap_grad_reduce
-            and num_distributed_optimizer_instances == 1
-        ):
-            # Can't finish grad sync until all params have been registered ready.
-            finish_grad_sync_context = pytest.raises(AssertionError)
-
-        with register_grad_sync_context:
-            bucket_group.register_grad_ready(param)
-        with finish_grad_sync_context:
-            # When overlap_grad_reduce is True, this should throw an assertion error until all
-            # params in the model have registered their grad above.
-            # When overlap_grad_reduce is False, the collective is forced through.
-            bucket_group.finish_grad_sync()
-
-        if bucket_group in non_ep_bucket_groups:
-            expected_grad_data_value = non_ep_expected_grad_data_value_after_collective
-        else:
-            expected_grad_data_value = ep_expected_grad_data_value_after_collective
-        # Before gradient sync, the gradient value should keep original.
-        if overlap_grad_reduce and param_idx < (len(bucket_group.params) - 1):
-            if bucket_group in non_ep_bucket_groups:
-                expected_grad_data_value = 1
+    for iteration in range(2):
+        for i, param in enumerate(params):
+            if not (param in param_to_bucket_group):
+                # it means this parameter is not on this device, skip
+                continue
+            bucket_group = param_to_bucket_group[param]
+            if bucket_group in map_bucket_to_last_param_idx:
+                param_idx = map_bucket_to_last_param_idx[bucket_group] + 1
             else:
-                expected_grad_data_value = ep_size * etp_size
+                param_idx = 0
+            map_bucket_to_last_param_idx[bucket_group] = param_idx
+
+            register_grad_sync_context = (
+                contextlib.nullcontext() if overlap_grad_reduce else pytest.raises(AssertionError)
+            )
+            finish_grad_sync_context = contextlib.nullcontext()
+            if (
+                param_idx < (len(bucket_group.params) - 1)
+                and overlap_grad_reduce
+                and num_distributed_optimizer_instances == 1
+            ):
+                # Can't finish grad sync until all params have been registered ready.
+                finish_grad_sync_context = pytest.raises(AssertionError)
 
-        if bucket_group in non_ep_bucket_groups:
-            assert non_ep_param_and_grad_buffer.grad_data[0] == expected_grad_data_value
-        else:
-            assert ep_param_and_grad_buffer.grad_data[0] == expected_grad_data_value
+            with register_grad_sync_context:
+                bucket_group.register_grad_ready(param)
+            # Don't call finish_grad_sync() multiple times in the first iteration when
+            # golden_per_param_grad_ready_counts is being populated.
+            if iteration == 0 and i < (len(params) - 1):
+                continue
+            with finish_grad_sync_context:
+                # When overlap_grad_reduce is True, this should throw an assertion error until all
+                # params in the model have registered their grad above.
+                # When overlap_grad_reduce is False, the collective is forced through.
+                bucket_group.finish_grad_sync()
 
-        if not overlap_grad_reduce:
-            # Reset grad_data for subsequent collectives.
             if bucket_group in non_ep_bucket_groups:
-                non_ep_param_and_grad_buffer.grad_data.data.fill_(1.0)
+                expected_grad_data_value = non_ep_expected_grad_data_value_after_collective
             else:
-                ep_param_and_grad_buffer.grad_data.data.fill_(float(ep_size * etp_size))
+                expected_grad_data_value = ep_expected_grad_data_value_after_collective
+            # Before gradient sync, the gradient value should keep original.
+            if overlap_grad_reduce and param_idx < (len(bucket_group.params) - 1):
+                if bucket_group in non_ep_bucket_groups:
+                    expected_grad_data_value = 1
+                else:
+                    expected_grad_data_value = ep_size * etp_size
+
+            if bucket_group in non_ep_bucket_groups:
+                assert non_ep_param_and_grad_buffer.grad_data[0] == expected_grad_data_value
+            else:
+                assert ep_param_and_grad_buffer.grad_data[0] == expected_grad_data_value
+
+            if not overlap_grad_reduce:
+                # Reset grad_data for subsequent collectives.
+                if bucket_group in non_ep_bucket_groups:
+                    non_ep_param_and_grad_buffer.grad_data.data.fill_(1.0)
+                else:
+                    ep_param_and_grad_buffer.grad_data.data.fill_(float(ep_size * etp_size))
+
+        # Call reset to set .is_first_batch to False.
+        bucket_group.reset()
 
     Utils.destroy_model_parallel()

tests/unit_tests/distributed/test_param_and_grad_buffer.py

Lines changed: 39 additions & 32 deletions
@@ -164,7 +164,6 @@ def _pad_param_if_needed(numel_unpadded):
 @pytest.mark.parametrize("overlap_grad_reduce", [False, True])
 @pytest.mark.parametrize("average_in_collective", [False, True])
 @pytest.mark.parametrize("num_distributed_optimizer_instances", [1, 2])
-# @pytest.mark.flaky
 def test_grad_sync(
     use_distributed_optimizer: bool,
     overlap_grad_reduce: bool,
@@ -216,36 +215,44 @@ def test_grad_sync(
     expected_grad_data_value_after_collective /= parallel_state.get_data_parallel_world_size()
 
     params = list(model.parameters())
-    for i, param in enumerate(params):
-        assert param in param_to_bucket_group
-        bucket_group = param_to_bucket_group[param]
-        register_grad_sync_context = (
-            contextlib.nullcontext() if overlap_grad_reduce else pytest.raises(AssertionError)
-        )
-        finish_grad_sync_context = contextlib.nullcontext()
-        if (
-            i < (len(params) - 1)
-            and overlap_grad_reduce
-            and num_distributed_optimizer_instances == 1
-        ):
-            # Can't finish grad sync until all params have been registered ready.
-            finish_grad_sync_context = pytest.raises(AssertionError)
-
-        with register_grad_sync_context:
-            bucket_group.register_grad_ready(param)
-        with finish_grad_sync_context:
-            # When overlap_grad_reduce is True, this should throw an assertion error until all
-            # params in the model have registered their grad above.
-            # When overlap_grad_reduce is False, the collective is forced through.
-            bucket_group.finish_grad_sync()
-
-        expected_grad_data_value = expected_grad_data_value_after_collective
-        if overlap_grad_reduce and i < (len(params) - 1):
-            expected_grad_data_value = 1
-        assert param_and_grad_buffer.grad_data[0] == expected_grad_data_value
-
-        if not overlap_grad_reduce:
-            # Reset grad_data for subsequent collectives.
-            param_and_grad_buffer.grad_data.data.fill_(1.0)
+    for iteration in range(2):
+        for i, param in enumerate(params):
+            assert param in param_to_bucket_group
+            bucket_group = param_to_bucket_group[param]
+            register_grad_sync_context = (
+                contextlib.nullcontext() if overlap_grad_reduce else pytest.raises(AssertionError)
+            )
+            finish_grad_sync_context = contextlib.nullcontext()
+            if (
+                i < (len(params) - 1)
+                and overlap_grad_reduce
+                and num_distributed_optimizer_instances == 1
+            ):
+                # Can't finish grad sync until all params have been registered ready.
+                finish_grad_sync_context = pytest.raises(AssertionError)
+
+            with register_grad_sync_context:
+                bucket_group.register_grad_ready(param)
+            # Don't call finish_grad_sync() multiple times in the first iteration when
+            # golden_per_param_grad_ready_counts is being populated.
+            if iteration == 0 and i < (len(params) - 1):
+                continue
+            with finish_grad_sync_context:
+                # When overlap_grad_reduce is True, this should throw an assertion error until all
+                # params in the model have registered their grad above.
+                # When overlap_grad_reduce is False, the collective is forced through.
+                bucket_group.finish_grad_sync()
+
+            expected_grad_data_value = expected_grad_data_value_after_collective
+            if overlap_grad_reduce and i < (len(params) - 1):
+                expected_grad_data_value = 1
+            assert param_and_grad_buffer.grad_data[0] == expected_grad_data_value
+
+            if not overlap_grad_reduce:
+                # Reset grad_data for subsequent collectives.
+                param_and_grad_buffer.grad_data.data.fill_(1.0)
+
+        # Call reset to set .is_first_batch to False.
+        bucket_group.reset()
 
     Utils.destroy_model_parallel()
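
Both files make the same change: the single pass over params is wrapped in for iteration in range(2), finish_grad_sync() is called only once (after the last parameter) during the first iteration while golden_per_param_grad_ready_counts is being populated, and bucket_group.reset() runs at the end of each iteration so that .is_first_batch becomes False. Below is a minimal sketch of that loop shape, assuming the overlap_grad_reduce path; FakeBucketGroup and run_two_iteration_loop are hypothetical stand-ins for illustration, not the Megatron-LM classes or the actual test code.

import contextlib

import pytest


class FakeBucketGroup:
    """Hypothetical stub; only mimics the register/finish/reset protocol exercised by the tests."""

    def __init__(self, params):
        self.params = list(params)
        self.is_first_batch = True
        self._ready = set()

    def register_grad_ready(self, param):
        self._ready.add(param)

    def finish_grad_sync(self):
        # Fails until every param has registered its grad, like the assertion the tests expect.
        assert len(self._ready) == len(self.params), "not all grads registered ready"
        self._ready.clear()

    def reset(self):
        # Mirrors the diff's comment: reset() sets .is_first_batch to False.
        self.is_first_batch = False
        self._ready.clear()


def run_two_iteration_loop(bucket_group):
    params = bucket_group.params
    for iteration in range(2):
        for i, param in enumerate(params):
            bucket_group.register_grad_ready(param)
            # First iteration: call finish_grad_sync() only once, after the last param.
            if iteration == 0 and i < (len(params) - 1):
                continue
            # Second iteration: earlier calls are expected to trip the all-ready assertion.
            finish_ctx = (
                pytest.raises(AssertionError)
                if iteration == 1 and i < (len(params) - 1)
                else contextlib.nullcontext()
            )
            with finish_ctx:
                bucket_group.finish_grad_sync()
        # Reset between iterations so the group no longer treats the batch as the first one.
        bucket_group.reset()


run_two_iteration_loop(FakeBucketGroup(["w1", "w2", "w3"]))

The stub only mimics the all-params-ready assertion; the real tests additionally branch on use_distributed_optimizer, average_in_collective, and num_distributed_optimizer_instances, and check the expected grad_data values after each collective.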
