Skip to content

Commit aeda7f9

Browse files
inkcherry and tjruwase authored
Fix invalid check of recorded parameter orders in zero stage3. (#2550)
Co-authored-by: Olatunji Ruwase <[email protected]>
1 parent ffcf384 commit aeda7f9

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

deepspeed/runtime/zero/partitioned_param_coordinator.py

+7 -5
Original file line number | Diff line number | Diff line change
@@ -187,16 +187,18 @@ def reset_step(self) -> None:
187187
f"{[p.ds_summary for p in self.__inflight_param_registry.keys()]}")
188188

189189
if not self.is_complete_trace(): # not self.trace_complete:
190-
# Make sure that recorded parameter and submodule orders are
191-
# identical across ranks
190+
# Make sure that recorded submodule orders are identical across ranks
192191
assert_ints_same_as_other_ranks([m.id for m in self.__submodule_order])
193-
assert_ints_same_as_other_ranks([p.param.ds_id for p in self.__param_order])
194-
assert_ints_same_as_other_ranks(
195-
[p.step_id_last_used_at for p in self.__param_order])
196192

197193
if self.is_record_trace():
198194
# Successfully recorded a trace
199195
self.construct_parameter_trace_from_module_trace()
196+
# Make sure that recorded parameter orders are identical across ranks
197+
assert_ints_same_as_other_ranks(
198+
[p.param.ds_id for p in self.__param_order])
199+
assert_ints_same_as_other_ranks(
200+
[p.step_id_last_used_at for p in self.__param_order])
201+
200202
self.__submodule_order = tuple(self.__submodule_order) # freeze
201203
self.__param_order = tuple(self.__param_order) # freeze
202204
self.__trace_mode = ZeRoTraceMode.COMPLETE

0 commit comments

Comments
 (0)