Skip to content

Commit e2dc3ee

Browse files
wukong1992shaominhwchen2017
authored
Fix, bf16 optimizer remove dup loop (#7054)
When bf16 is used with MoE, refreshing the optimizer state from a bf16 checkpoint raises IndexError: list index out of range. Signed-off-by: shaomin <[email protected]> Co-authored-by: shaomin <[email protected]> Co-authored-by: Hongwei Chen <[email protected]>
1 parent d98204b commit e2dc3ee

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

deepspeed/runtime/bf16_optimizer.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -472,10 +472,10 @@ def state_dict(self):
472472

473473
# Restore base optimizer fp32 weights from bfloat16 weights
474474
def _restore_from_bit16_weights(self):
475-
for i, group in enumerate(self.bf16_groups):
475+
for i, (bf16_partitions,
476+
fp32_partition) in enumerate(zip(self.bf16_partitioned_groups, self.fp32_groups_flat_partition)):
476477
partition_id = dist.get_rank(group=self.real_dp_process_group[i])
477-
for bf16_partitions, fp32_partition in zip(self.bf16_partitioned_groups, self.fp32_groups_flat_partition):
478-
fp32_partition.data.copy_(bf16_partitions[partition_id].data)
478+
fp32_partition.data.copy_(bf16_partitions[partition_id].data)
479479

480480
def refresh_fp32_params(self):
481481
self._restore_from_bit16_weights()

0 commit comments

Comments
 (0)