File tree 1 file changed +7
-5
lines changed
1 file changed +7
-5
lines changed Original file line number Diff line number Diff line change @@ -187,16 +187,18 @@ def reset_step(self) -> None:
187
187
f"{ [p .ds_summary for p in self .__inflight_param_registry .keys ()]} " )
188
188
189
189
if not self .is_complete_trace (): # not self.trace_complete:
190
- # Make sure that recorded parameter and submodule orders are
191
- # identical across ranks
190
+ # Make sure that recorded submodule orders are identical across ranks
192
191
assert_ints_same_as_other_ranks ([m .id for m in self .__submodule_order ])
193
- assert_ints_same_as_other_ranks ([p .param .ds_id for p in self .__param_order ])
194
- assert_ints_same_as_other_ranks (
195
- [p .step_id_last_used_at for p in self .__param_order ])
196
192
197
193
if self .is_record_trace ():
198
194
# Successfully recorded a trace
199
195
self .construct_parameter_trace_from_module_trace ()
196
+ # Make sure that recorded parameter orders are identical across ranks
197
+ assert_ints_same_as_other_ranks (
198
+ [p .param .ds_id for p in self .__param_order ])
199
+ assert_ints_same_as_other_ranks (
200
+ [p .step_id_last_used_at for p in self .__param_order ])
201
+
200
202
self .__submodule_order = tuple (self .__submodule_order ) # freeze
201
203
self .__param_order = tuple (self .__param_order ) # freeze
202
204
self .__trace_mode = ZeRoTraceMode .COMPLETE
You can’t perform that action at this time.
0 commit comments