@@ -195,58 +195,66 @@ def test_grad_sync(
 
     params = list(model.parameters())
     map_bucket_to_last_param_idx = {}
-    for i, param in enumerate(params):
-        if not (param in param_to_bucket_group):
-            # it means this parameter is not on this device, skip
-            continue
-        bucket_group = param_to_bucket_group[param]
-        if bucket_group in map_bucket_to_last_param_idx:
-            param_idx = map_bucket_to_last_param_idx[bucket_group] + 1
-        else:
-            param_idx = 0
-        map_bucket_to_last_param_idx[bucket_group] = param_idx
-
-        register_grad_sync_context = (
-            contextlib.nullcontext() if overlap_grad_reduce else pytest.raises(AssertionError)
-        )
-        finish_grad_sync_context = contextlib.nullcontext()
-        if (
-            param_idx < (len(bucket_group.params) - 1)
-            and overlap_grad_reduce
-            and num_distributed_optimizer_instances == 1
-        ):
-            # Can't finish grad sync until all params have been registered ready.
-            finish_grad_sync_context = pytest.raises(AssertionError)
-
-        with register_grad_sync_context:
-            bucket_group.register_grad_ready(param)
-        with finish_grad_sync_context:
-            # When overlap_grad_reduce is True, this should throw an assertion error until all
-            # params in the model have registered their grad above.
-            # When overlap_grad_reduce is False, the collective is forced through.
-            bucket_group.finish_grad_sync()
-
-        if bucket_group in non_ep_bucket_groups:
-            expected_grad_data_value = non_ep_expected_grad_data_value_after_collective
-        else:
-            expected_grad_data_value = ep_expected_grad_data_value_after_collective
-        # Before gradient sync, the gradient value should keep original.
-        if overlap_grad_reduce and param_idx < (len(bucket_group.params) - 1):
-            if bucket_group in non_ep_bucket_groups:
-                expected_grad_data_value = 1
+    for iteration in range(2):
+        for i, param in enumerate(params):
+            if not (param in param_to_bucket_group):
+                # it means this parameter is not on this device, skip
+                continue
+            bucket_group = param_to_bucket_group[param]
+            if bucket_group in map_bucket_to_last_param_idx:
+                param_idx = map_bucket_to_last_param_idx[bucket_group] + 1
             else:
-                expected_grad_data_value = ep_size * etp_size
+                param_idx = 0
+            map_bucket_to_last_param_idx[bucket_group] = param_idx
+
+            register_grad_sync_context = (
+                contextlib.nullcontext() if overlap_grad_reduce else pytest.raises(AssertionError)
+            )
+            finish_grad_sync_context = contextlib.nullcontext()
+            if (
+                param_idx < (len(bucket_group.params) - 1)
+                and overlap_grad_reduce
+                and num_distributed_optimizer_instances == 1
+            ):
+                # Can't finish grad sync until all params have been registered ready.
+                finish_grad_sync_context = pytest.raises(AssertionError)
 
-        if bucket_group in non_ep_bucket_groups:
-            assert non_ep_param_and_grad_buffer.grad_data[0] == expected_grad_data_value
-        else:
-            assert ep_param_and_grad_buffer.grad_data[0] == expected_grad_data_value
+            with register_grad_sync_context:
+                bucket_group.register_grad_ready(param)
+            # Don't call finish_grad_sync() multiple times in the first iteration when
+            # golden_per_param_grad_ready_counts is being populated.
+            if iteration == 0 and i < (len(params) - 1):
+                continue
+            with finish_grad_sync_context:
+                # When overlap_grad_reduce is True, this should throw an assertion error until all
+                # params in the model have registered their grad above.
+                # When overlap_grad_reduce is False, the collective is forced through.
+                bucket_group.finish_grad_sync()
 
-        if not overlap_grad_reduce:
-            # Reset grad_data for subsequent collectives.
             if bucket_group in non_ep_bucket_groups:
-                non_ep_param_and_grad_buffer.grad_data.data.fill_(1.0)
+                expected_grad_data_value = non_ep_expected_grad_data_value_after_collective
             else:
-                ep_param_and_grad_buffer.grad_data.data.fill_(float(ep_size * etp_size))
+                expected_grad_data_value = ep_expected_grad_data_value_after_collective
+            # Before gradient sync, the gradient value should keep original.
+            if overlap_grad_reduce and param_idx < (len(bucket_group.params) - 1):
+                if bucket_group in non_ep_bucket_groups:
+                    expected_grad_data_value = 1
+                else:
+                    expected_grad_data_value = ep_size * etp_size
+
+            if bucket_group in non_ep_bucket_groups:
+                assert non_ep_param_and_grad_buffer.grad_data[0] == expected_grad_data_value
+            else:
+                assert ep_param_and_grad_buffer.grad_data[0] == expected_grad_data_value
+
+            if not overlap_grad_reduce:
+                # Reset grad_data for subsequent collectives.
+                if bucket_group in non_ep_bucket_groups:
+                    non_ep_param_and_grad_buffer.grad_data.data.fill_(1.0)
+                else:
+                    ep_param_and_grad_buffer.grad_data.data.fill_(float(ep_size * etp_size))
+
+        # Call reset to set .is_first_batch to False.
+        bucket_group.reset()
 
     Utils.destroy_model_parallel()
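The change wraps the per-parameter loop in `for iteration in range(2)`, skips the repeated `finish_grad_sync()` calls while `golden_per_param_grad_ready_counts` is being populated in the first pass, and calls `bucket_group.reset()` so that the second pass runs with `.is_first_batch` set to False. As background, here is a minimal, self-contained sketch of that kind of first-batch bookkeeping. The `ToyBucketGroup` class and its internals are illustrative assumptions only: the method and attribute names (`register_grad_ready`, `finish_grad_sync`, `reset`, `golden_per_param_grad_ready_counts`, `is_first_batch`) mirror identifiers visible in the diff, but the logic is not Megatron-LM's actual implementation.

```python
# Illustrative sketch (assumed behavior, not Megatron-LM's implementation):
# a bucket group that learns per-parameter grad-ready counts on its first
# batch and validates later batches against those counts.


class ToyBucketGroup:
    def __init__(self, params):
        self.params = list(params)
        self.is_first_batch = True
        self.golden_per_param_grad_ready_counts = {}  # frozen after the first batch
        self._current_counts = {p: 0 for p in self.params}

    def register_grad_ready(self, param):
        # Count every grad-ready notification; a param may report more than
        # once per batch (e.g. a shared or re-used parameter).
        self._current_counts[param] += 1

    def finish_grad_sync(self):
        if self.is_first_batch:
            # First batch: snapshot the observed counts as the reference.
            # Calling this repeatedly mid-batch would snapshot partial counts,
            # which is why the test only calls it once in iteration 0.
            self.golden_per_param_grad_ready_counts = dict(self._current_counts)
        else:
            # Later batches must reproduce the counts seen in the first batch.
            assert self._current_counts == self.golden_per_param_grad_ready_counts
        # A real bucket group would launch / wait on the reduce collective here.

    def reset(self):
        # Batch boundary: after the first batch the golden counts are final.
        self.is_first_batch = False
        self._current_counts = {p: 0 for p in self.params}


# Two batches, mirroring the test's `for iteration in range(2)` loop.
group = ToyBucketGroup(params=["weight", "bias"])
for _ in range(2):
    for p in group.params:
        group.register_grad_ready(p)
    group.finish_grad_sync()
    # Same role as the reset() call added in the diff.
    group.reset()
print(group.golden_per_param_grad_ready_counts)  # {'weight': 1, 'bias': 1}
```

Under this sketch, `reset()` marks the batch boundary: after the first batch the recorded counts become the reference and `is_first_batch` flips to False, which is the state the test's second iteration is meant to exercise.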