
[BUG] DistributedOptimizer issues in Mixed-Precision: Stale model_param_group_index_map and .copy_ failure during get/load parameter_state #2777

@HelloWorld686

Description


Describe the bug

During mixed-precision training (BF16 and FP32; FP8 is optional), a RuntimeError (size mismatch) is raised when saving a checkpoint via get_parameter_state_dp_zero.

[Screenshot: RuntimeError traceback showing the size mismatch raised from the .copy_ in get_parameter_state_dp_zero]

Based on our analysis, the hardcoded parameter reordering in _build_model_and_main_param_groups within DistributedOptimizer causes self.model_param_group_index_map to fall out of sync with the actual optimizer.param_groups.

Steps/Code to reproduce bug

  1. Before building the DDP model and optimizer, use Float16Module to wrap the model for BF16 training, but manually promote certain modules (both params and inputs) to FP32 (a self-contained analogue follows this list).
  2. Train for several steps.
  3. Call get_parameter_state_dp_zero in DistributedOptimizer to collect optimizer states, which triggers the size mismatch error.

Root Cause Analysis

  1. Initial Map Construction: In __init__, self.model_param_group_index_map is first constructed via _build_optimizer_group_ranges. This map records the position (group_index, group_order) of each param in param_groups.

  2. Hardcoded Reordering: Subsequently, _build_model_and_main_param_groups reorders the parameters within each group (placing native FP32 shards at the front and main parameter shards converted from FP16/BF16 at the back) and updates optimizer.param_groups accordingly:

[Screenshot: the reordering code in _build_model_and_main_param_groups]
  3. Index Invalidation: The model_param_group_index_map is not updated after this reordering. Consequently, downstream functions like _get_main_param_and_optimizer_states retrieve the wrong Tensors using stale group_order indices, leading to shape mismatches during buffer copy operations.

Additional questions

  1. What is the design motivation behind this specific reordering (grouping by dtype)?
  2. What is the recommended best practice to fix this: disabling the reordering to maintain discovery order consistency, or explicitly updating the model_param_group_index_map after the reordering is performed?
