
[BUG] Dtype mismatch (bf16 vs fp32) when resuming Muon optimizer from checkpoint #7746

@literid

Bug Description

When resuming training from a checkpoint using bf16 and the Muon optimizer, a RuntimeError occurs due to a dtype mismatch.

  • Model parameters and gradients are in bf16.
  • Optimizer state (momentum_buffer) is loaded from the checkpoint as fp32.
  • The mismatch surfaces when Muon applies updates (e.g., lerp_) between the fp32 momentum buffer and the bf16 gradient; a standalone reproduction of this tensor-level clash is sketched below.
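
For context, the underlying tensor-level failure can be reproduced with plain PyTorch, independent of DeepSpeed. This is a minimal sketch; the shapes are taken from the trace below, and the fp32 buffer stands in for the state restored from the checkpoint:

import torch

momentum = torch.zeros(1, 64, dtype=torch.float32)  # momentum buffer as restored from the checkpoint (fp32)
grad = torch.randn(1, 64, dtype=torch.bfloat16)      # gradient produced under bf16 training
momentum.lerp_(grad, 0.05)                           # raises a dtype mismatch error like the one in the trace below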

Minimal Reproducible Example

import torch
import os
import deepspeed

torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

def train_step(model_engine, x, y):
    output = model_engine(x)
    loss = ((output - y) ** 2).mean()
    model_engine.backward(loss)
    model_engine.step()

hidden_size = 64
out_size = 1
dtype = torch.bfloat16

# Setup dummy data and model
x = torch.randn(16, hidden_size, dtype=dtype).cuda()
y = torch.randn(16, out_size, dtype=dtype).cuda()
model = torch.nn.Linear(hidden_size, out_size)

ds_config = {
    "bf16": {"enabled": True},
    "zero_optimization": {"stage": 2},
    "optimizer": {"type": "Muon", "params": {"lr": 1e-3}},
    "zero_allow_untested_optimizer": True,
    "train_batch_size": 4,
    "train_micro_batch_size_per_gpu": 1
}

model_engine, optimizer, _, _ = deepspeed.initialize(model=model, config=ds_config)

# Train one step before checkpointing
train_step(model_engine, x, y)

model_engine.save_checkpoint("./test_checkpoint")
model_engine.load_checkpoint("./test_checkpoint")

# Resume training -> trigger error
train_step(model_engine, x, y)

Run command:

deepspeed --include localhost:0,1,2,3 example.py

Error trace:

[rank2]: torch._dynamo.exc.TorchRuntimeError: Failed running call_method lerp_(*(FakeTensor(..., device='cuda:2', size=(1, 64)), FakeTensor(..., device='cuda:2', size=(1, 64), dtype=torch.bfloat16), 0.050000000000000044), **{}):
[rank2]: expected dtype torch.float32 for `end`, but got dtype torch.bfloat16

[rank2]: from user code:
[rank2]:    File "/usr/local/lib/python3.12/dist-packages/deepspeed/runtime/zero/muon/original_muon.py", line 72, in torch_dynamo_resume_in_muon_update_at_71
[rank2]:     momentum.lerp_(grad, 1 - beta)

Environment

  • GPU: 4× H100
  • Docker image: nvidia/pytorch:25.03-py3
  • torch: 2.7.0a0+7c8ec84dab.nv25.3
  • deepspeed: 0.18.3

My Workaround

Debugging showed that momentum_buffer is loaded as an fp32 tensor, which conflicts with bf16 gradients in muon_update. I found that manually converting the buffer to bf16 right after load_checkpoint() fixes the crash:

# Cast the fp32 momentum buffers restored by load_checkpoint() back to bf16
for tensor_key, values_dict in optimizer.optimizer.state.items():
    for key, tensor_value in values_dict.items():
        if key == "momentum_buffer":
            values_dict[key] = tensor_value.to(dtype=torch.bfloat16)

This resolves the error, but I'm unsure if forcing momentum_buffer to bf16 is numerically stable or the intended behavior, given that optimizer states are typically kept in fp32. Opening this issue to find the proper fix.
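
If hardcoding bf16 is undesirable, a slightly more general form of the same workaround would be to cast each momentum_buffer to the dtype of the parameter it belongs to. This is only a sketch, assuming the inner optimizer's state dict is keyed by the parameter tensors (as in the snippet above); it has not been verified against DeepSpeed's ZeRO internals:

for param, state in optimizer.optimizer.state.items():
    buf = state.get("momentum_buffer")
    if buf is not None and buf.dtype != param.dtype:
        # keep the buffer in the same dtype as the parameter/gradient it updates
        state["momentum_buffer"] = buf.to(dtype=param.dtype)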

Labels

bug (Something isn't working), training