@@ -1,3 +1,5 @@
+# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
+
 import contextlib
 from typing import Optional
 
@@ -169,15 +171,20 @@ def test_grad_sync(
         )
         != 0
     ):
-        # With above conditions, the data in param_and_grad_buffer.grad_data[0] equals to 1/data_parallel_word_size
-        # When average_in_collective=False, the grad data is always first scaled by 1/data_parallel_word_size and then summed by AR/RS
-        # when use_distributed_optimizer=True, only for rank=0 param_and_grad_buffer.grad_data[0] is updated, for other ranks
-        # another shard of grad_data is updated while param_and_grad_buffer.grad_data[0] is unchanged (=1/data_parallel_word_size)
+        # With above conditions, the data in param_and_grad_buffer.grad_data[0] equals
+        # 1/data_parallel_world_size.
+        # When average_in_collective=False, the grad data is always first scaled by
+        # 1/data_parallel_world_size and then summed by AR/RS.
+        # When use_distributed_optimizer=True, param_and_grad_buffer.grad_data[0] is updated
+        # only on rank 0. For other ranks, another shard of grad_data is updated while
+        # param_and_grad_buffer.grad_data[0] stays unchanged
+        # (= 1/data_parallel_world_size).
         non_ep_expected_grad_data_value_after_collective /= (
             parallel_state.get_data_parallel_world_size()
         )
     if ep_size > 1:
-        # For MoE models with exper parallelism, each expert will receive tokens from EPxETP times batches, such that the expert gradient will be EPxETP times after backward,
+        # For MoE models with expert parallelism, each expert receives tokens from EPxETP times
+        # as many batches, so the expert gradient will be EPxETP times larger after backward,
         # and the expected gradient after collective should be 1.0 as same as dense params.
         ep_param_and_grad_buffer.grad_data.data.fill_(float(ep_size * etp_size))
         ep_expected_grad_data_value_after_collective = 1
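Note: the scaling described in the comments above can be sanity-checked with plain arithmetic. The sketch below is illustrative only (the helper names `simulate_allreduce_avg` and `simulate_reduce_scatter` are not Megatron-LM APIs, and the shard layout is simplified); it assumes every rank starts with a gradient of 1.0, as the test sets up via `fill_(1.0)`, and shows why pre-scaling by `1/dp_size` followed by a SUM collective yields 1.0, while a reduce-scatter leaves non-owning ranks at `1/dp_size`.

```python
# Illustrative-only sketch of the pre-scale-then-sum averaging described above.
# Assumes every rank starts with a gradient of 1.0, as the test does via fill_(1.0).


def simulate_allreduce_avg(dp_size: int, grad: float = 1.0) -> list:
    """average_in_collective=False: scale by 1/dp_size locally, then SUM across ranks."""
    prescaled = [grad / dp_size for _ in range(dp_size)]  # each rank holds grad/dp_size
    total = sum(prescaled)                                # the all-reduce sums the pre-scaled grads
    return [total for _ in range(dp_size)]                # every rank ends up with the average


def simulate_reduce_scatter(dp_size: int) -> list:
    """With the distributed optimizer, RS leaves each rank with only its own reduced shard."""
    # Before the collective, each rank holds the full pre-scaled buffer (one shard per rank here).
    buffers = [[1.0 / dp_size] * dp_size for _ in range(dp_size)]
    for rank in range(dp_size):
        # Only the shard owned by this rank is replaced by the reduced (summed) value;
        # shards owned by other ranks keep the pre-scaled value 1/dp_size.
        buffers[rank][rank] = 1.0
    return buffers


if __name__ == "__main__":
    dp = 4
    assert simulate_allreduce_avg(dp) == [1.0] * dp
    bufs = simulate_reduce_scatter(dp)
    # Only rank 0 sees index 0 updated to 1.0; other ranks still see 1/dp there.
    assert bufs[0][0] == 1.0 and all(b[0] == 1.0 / dp for b in bufs[1:])
```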
@@ -186,67 +193,76 @@ def test_grad_sync(
             and (not average_in_collective)
             and parallel_state.get_expert_data_parallel_rank(partial_expert_data_parallel=True) != 0
         ):
-            # With above conditions, the data in param_and_grad_buffer.grad_data[0] equals to 1/EDP
-            # When average_in_collective=False, the grad data is always first scaled by expert_data_parallel_size and then summed by AR/RS
-            # after SUM collective in expert_data_group, the scale will be 1.0.
+            # With above conditions, the data in param_and_grad_buffer.grad_data[0] equals 1/EDP.
+            # When average_in_collective=False, the grad data is always first scaled by
+            # expert_data_parallel_size and then summed by AR/RS.
+            # After the SUM collective in expert_data_group, the scale will be 1.0.
             ep_expected_grad_data_value_after_collective /= (
                 parallel_state.get_expert_data_parallel_world_size()
             )
 
     params = list(model.parameters())
     map_bucket_to_last_param_idx = {}
-    for i, param in enumerate(params):
-        if not (param in param_to_bucket_group):
-            # it means this parameter is not on this device, skip
-            continue
-        bucket_group = param_to_bucket_group[param]
-        if bucket_group in map_bucket_to_last_param_idx:
-            param_idx = map_bucket_to_last_param_idx[bucket_group] + 1
-        else:
-            param_idx = 0
-        map_bucket_to_last_param_idx[bucket_group] = param_idx
-
-        register_grad_sync_context = (
-            contextlib.nullcontext() if overlap_grad_reduce else pytest.raises(AssertionError)
-        )
-        finish_grad_sync_context = contextlib.nullcontext()
-        if (
-            param_idx < (len(bucket_group.params) - 1)
-            and overlap_grad_reduce
-            and num_distributed_optimizer_instances == 1
-        ):
-            # Can't finish grad sync until all params have been registered ready.
-            finish_grad_sync_context = pytest.raises(AssertionError)
-
-        with register_grad_sync_context:
-            bucket_group.register_grad_ready(param)
-        with finish_grad_sync_context:
-            # When overlap_grad_reduce is True, this should throw an assertion error until all
-            # params in the model have registered their grad above.
-            # When overlap_grad_reduce is False, the collective is forced through.
-            bucket_group.finish_grad_sync()
-
-        if bucket_group in non_ep_bucket_groups:
-            expected_grad_data_value = non_ep_expected_grad_data_value_after_collective
-        else:
-            expected_grad_data_value = ep_expected_grad_data_value_after_collective
-        # Before gradient sync, the gradient value should keep original.
-        if overlap_grad_reduce and param_idx < (len(bucket_group.params) - 1):
-            if bucket_group in non_ep_bucket_groups:
-                expected_grad_data_value = 1
+    for iteration in range(2):
+        for i, param in enumerate(params):
+            if not (param in param_to_bucket_group):
+                # This parameter is not on this device, skip it.
+                continue
+            bucket_group = param_to_bucket_group[param]
+            if bucket_group in map_bucket_to_last_param_idx:
+                param_idx = map_bucket_to_last_param_idx[bucket_group] + 1
             else:
-                expected_grad_data_value = ep_size * etp_size
+                param_idx = 0
+            map_bucket_to_last_param_idx[bucket_group] = param_idx
+
+            register_grad_sync_context = (
+                contextlib.nullcontext() if overlap_grad_reduce else pytest.raises(AssertionError)
+            )
+            finish_grad_sync_context = contextlib.nullcontext()
+            if (
+                param_idx < (len(bucket_group.params) - 1)
+                and overlap_grad_reduce
+                and num_distributed_optimizer_instances == 1
+            ):
+                # Can't finish grad sync until all params have been registered ready.
+                finish_grad_sync_context = pytest.raises(AssertionError)
+
+            with register_grad_sync_context:
+                bucket_group.register_grad_ready(param)
+            # Don't call finish_grad_sync() multiple times in the first iteration when
+            # golden_per_param_grad_ready_counts is being populated.
+            if iteration == 0 and i < (len(params) - 1):
+                continue
+            with finish_grad_sync_context:
+                # When overlap_grad_reduce is True, this should throw an assertion error until all
+                # params in the model have registered their grad above.
+                # When overlap_grad_reduce is False, the collective is forced through.
+                bucket_group.finish_grad_sync()
 
-        if bucket_group in non_ep_bucket_groups:
-            assert non_ep_param_and_grad_buffer.grad_data[0] == expected_grad_data_value
-        else:
-            assert ep_param_and_grad_buffer.grad_data[0] == expected_grad_data_value
+            if bucket_group in non_ep_bucket_groups:
+                expected_grad_data_value = non_ep_expected_grad_data_value_after_collective
+            else:
+                expected_grad_data_value = ep_expected_grad_data_value_after_collective
+            # Before gradient sync, the gradient value should keep its original value.
+            if overlap_grad_reduce and param_idx < (len(bucket_group.params) - 1):
+                if bucket_group in non_ep_bucket_groups:
+                    expected_grad_data_value = 1
+                else:
+                    expected_grad_data_value = ep_size * etp_size
 
-        if not overlap_grad_reduce:
-            # Reset grad_data for subsequent collectives.
             if bucket_group in non_ep_bucket_groups:
-                non_ep_param_and_grad_buffer.grad_data.data.fill_(1.0)
+                assert non_ep_param_and_grad_buffer.grad_data[0] == expected_grad_data_value
             else:
-                ep_param_and_grad_buffer.grad_data.data.fill_(float(ep_size * etp_size))
+                assert ep_param_and_grad_buffer.grad_data[0] == expected_grad_data_value
+
+            if not overlap_grad_reduce:
+                # Reset grad_data for subsequent collectives.
+                if bucket_group in non_ep_bucket_groups:
+                    non_ep_param_and_grad_buffer.grad_data.data.fill_(1.0)
+                else:
+                    ep_param_and_grad_buffer.grad_data.data.fill_(float(ep_size * etp_size))
+
+        # Call reset to set .is_first_batch to False.
+        bucket_group.reset()
 
     Utils.destroy_model_parallel()
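Note: the register/finish contract the loop above exercises can be summarized with a small mock. This is a simplified stand-in, not Megatron-LM's bucket-group implementation; it only mirrors the assertion behavior the test expects: registering grads is only legal when `overlap_grad_reduce=True`, and `finish_grad_sync()` asserts until every param in the group has registered.

```python
# Simplified mock of the register-then-finish contract exercised above.
# This is NOT Megatron-LM's implementation; it only mirrors the assertion behavior
# the test expects when overlap_grad_reduce=True.


class MockBucketGroup:
    def __init__(self, params, overlap_grad_reduce: bool):
        self.params = list(params)
        self.overlap_grad_reduce = overlap_grad_reduce
        self.params_ready = set()

    def register_grad_ready(self, param):
        # Registering only makes sense when reduction overlaps the backward pass.
        assert self.overlap_grad_reduce, "register_grad_ready() requires overlap_grad_reduce"
        self.params_ready.add(param)

    def finish_grad_sync(self):
        if self.overlap_grad_reduce:
            # Cannot launch the collective until every param has registered its grad.
            assert len(self.params_ready) == len(self.params), "not all grads are ready"
        # When overlap_grad_reduce is False, the collective is simply forced through here.
        self.params_ready.clear()


if __name__ == "__main__":
    import pytest

    group = MockBucketGroup(params=["w1", "w2"], overlap_grad_reduce=True)
    group.register_grad_ready("w1")
    with pytest.raises(AssertionError):
        group.finish_grad_sync()  # w2 has not registered yet
    group.register_grad_ready("w2")
    group.finish_grad_sync()      # now succeeds
```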