Bug Description
When resuming training from a checkpoint using bf16 and the Muon optimizer, a RuntimeError occurs due to a dtype mismatch.
- Model parameters and gradients are in `bf16`.
- Optimizer state (`momentum_buffer`) is loaded from the checkpoint as `fp32`.
- The mismatch happens when Muon tries to apply updates (e.g., `lerp_`) between the `fp32` momentum buffers and the `bf16` gradients (a standalone sketch of this clash is shown below).
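Outside of DeepSpeed, the same check can be reproduced directly, since in-place `lerp_` requires `end` to have the same dtype as the tensor being updated. A minimal sketch of the clash (plain PyTorch, not the DeepSpeed code path):

```python
import torch

# fp32 momentum buffer, as restored from the checkpoint
momentum = torch.zeros(1, 64, dtype=torch.float32)
# bf16 gradient, as produced by the bf16 training run
grad = torch.randn(1, 64, dtype=torch.bfloat16)

# Raises the same complaint as the trace below:
# "expected dtype ... for `end`, but got dtype ... BFloat16"
momentum.lerp_(grad, 0.05)
```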
Minimal Reproducible Example
```python
import torch
import os
import deepspeed

torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))


def train_step(model_engine, x, y):
    output = model_engine(x)
    loss = ((output - y) ** 2).mean()
    model_engine.backward(loss)
    model_engine.step()


hidden_size = 64
out_size = 1
dtype = torch.bfloat16

# Setup dummy data and model
x = torch.randn(16, hidden_size, dtype=dtype).cuda()
y = torch.randn(16, out_size, dtype=dtype).cuda()
model = torch.nn.Linear(hidden_size, out_size)

ds_config = {
    "bf16": {"enabled": True},
    "zero_optimization": {"stage": 2},
    "optimizer": {"type": "Muon", "params": {"lr": 1e-3}},
    "zero_allow_untested_optimizer": True,
    "train_batch_size": 4,
    "train_micro_batch_size_per_gpu": 1
}

model_engine, optimizer, _, _ = deepspeed.initialize(model=model, config=ds_config)

train_step(model_engine, x, y)
model_engine.save_checkpoint("./test_checkpoint")
model_engine.load_checkpoint("./test_checkpoint")

# Resume training -> trigger error
train_step(model_engine, x, y)
```

Run command:

```bash
deepspeed --include localhost:0,1,2,3 example.py
```

Error trace:
```
[rank2]: torch._dynamo.exc.TorchRuntimeError: Failed running call_method lerp_(*(FakeTensor(..., device='cuda:2', size=(1, 64)), FakeTensor(..., device='cuda:2', size=(1, 64), dtype=torch.bfloat16), 0.050000000000000044), **{}):
[rank2]: expected dtype torch.float32 for `end`, but got dtype torch.bfloat16

[rank2]: from user code:
[rank2]:   File "/usr/local/lib/python3.12/dist-packages/deepspeed/runtime/zero/muon/original_muon.py", line 72, in torch_dynamo_resume_in_muon_update_at_71
[rank2]:     momentum.lerp_(grad, 1 - beta)
```
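A quick way to confirm the mismatch after `load_checkpoint()` (a diagnostic sketch based on the MRE above; it reuses the `optimizer.optimizer.state` access from the workaround below):

```python
# Run immediately after model_engine.load_checkpoint("./test_checkpoint")
for state in optimizer.optimizer.state.values():
    buf = state.get("momentum_buffer")
    if buf is not None:
        print("momentum_buffer dtype:", buf.dtype)  # torch.float32 after loading

for p in model_engine.parameters():
    print("param dtype:", p.dtype)  # torch.bfloat16
    break
```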
Environment
- GPU: 4x H100
- Docker image: nvidia/pytorch:25.03-py3
- torch: 2.7.0a0+7c8ec84dab.nv25.3
- deepspeed: 0.18.3
My Workaround
Debugging showed that `momentum_buffer` is loaded as an `fp32` tensor, which conflicts with the `bf16` gradients in `muon_update`. I found that manually converting the buffer to `bf16` right after `load_checkpoint()` fixes the crash:
```python
for tensor_key, values_dict in optimizer.optimizer.state.items():
    for key, tensor_value in values_dict.items():
        if key == "momentum_buffer":
            values_dict[key] = tensor_value.to(dtype=torch.bfloat16)
```

This resolves the error, but I'm not sure whether forcing `momentum_buffer` to `bf16` is numerically stable or the intended behavior, given that optimizer states are typically kept in `fp32`. Opening this issue to find the proper fix.