Describe the bug
When enabling bf16 mixed precision in DeepSpeed, all nn.Parameters are forcibly cast to torch.bfloat16 during initialization.
There is currently no supported way to keep specific parameters in float32.
This causes runtime dtype mismatch errors in modules that intentionally require fp32 computation (e.g., MoE routers or other numerically sensitive control-flow logic).
To Reproduce
- Define a module with a parameter explicitly initialized in float32:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Ernie4_5_VL_MoeMoeStatics and maybe_autocast are defined elsewhere in the
# model implementation.
class Ernie4_5_VL_MoeMoeTopKRouter(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Router weight is intentionally created in float32.
        self.weight = nn.Parameter(
            torch.zeros(config.moe_num_experts, config.hidden_size, dtype=torch.float32)
        )
        self.moe_statics = Ernie4_5_VL_MoeMoeStatics(config)
        self.top_k = config.moe_k
        self.norm_min = config.moe_norm_min

    def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        device_type = (
            hidden_states.device.type
            if isinstance(hidden_states.device.type, str) and hidden_states.device.type != "mps"
            else "cpu"
        )
        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
            # key point: this matmul is expected to run in float32 #
            router_logits = F.linear(hidden_states.float(), self.weight)
            routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
            _, selected_experts = torch.topk(self.moe_statics(routing_weights), self.top_k, dim=-1)
            routing_weights = torch.gather(routing_weights, dim=-1, index=selected_experts)
            routing_weights = routing_weights / torch.clamp(
                routing_weights.sum(dim=-1, keepdim=True), min=self.norm_min
            )
            routing_weights = routing_weights.to(hidden_states.dtype)
        return router_logits, selected_experts, routing_weights
```

- Enable bf16 in the DeepSpeed config:

```json
{
  "bf16": {
    "enabled": true
  }
}
```

- In the forward pass, disable autocast and perform the computation in fp32:

```python
with torch.cuda.amp.autocast(enabled=False):
    out = F.linear(hidden_states.float(), self.weight)
```

- Launch training with DeepSpeed.
- Observe the following error:

```
RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::BFloat16
```
This happens because self.weight is cast to bfloat16 by DeepSpeed despite being initialized as float32.
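For reference, the forced cast can be observed with a minimal standalone script. This is a sketch only: the tiny module, config values, and dtype checks below are illustrative and not taken from the actual training setup.

```python
# Run with the deepspeed launcher, e.g.: deepspeed repro_cast.py
import torch
import torch.nn as nn
import deepspeed

class TinyRouter(nn.Module):
    def __init__(self):
        super().__init__()
        # Deliberately created in float32.
        self.weight = nn.Parameter(torch.zeros(4, 8, dtype=torch.float32))

ds_config = {
    "train_batch_size": 8,
    "bf16": {"enabled": True},
    "optimizer": {"type": "Adam", "params": {"lr": 1e-4}},
}

model = TinyRouter()
print(model.weight.dtype)  # torch.float32 before deepspeed.initialize

engine, _, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config=ds_config,
)
# On the affected setup the parameter has been cast to torch.bfloat16 here,
# even though it was explicitly constructed as float32.
print(engine.module.weight.dtype)
```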
Expected behavior
It should be possible to prevent specific parameters from being cast to bf16, i.e., to opt individual parameters or modules out of DeepSpeed's mixed-precision parameter casting.
At minimum, an official mechanism or documented workaround for parameter-level precision control would be helpful for numerically sensitive modules such as MoE routers.
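As a point of reference, plain PyTorch has no problem with a module holding parameters of different floating-point dtypes, which is the end state such an opt-out would produce. A small illustrative sketch (the module, names, and shapes are made up):

```python
import torch
import torch.nn as nn

class TinyRouterBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)                           # fine to run in bf16
        self.router_weight = nn.Parameter(torch.zeros(4, 8))  # should stay in fp32

block = TinyRouterBlock().to(torch.bfloat16)  # casts *every* floating-point parameter
# Today this has to be undone by hand after the cast:
block.router_weight.data = block.router_weight.data.float()

print(block.proj.weight.dtype)    # torch.bfloat16
print(block.router_weight.dtype)  # torch.float32
```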
ds_report output
```
RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::BFloat16
```
Screenshots
System info (please complete the following information):
- OS: Ubuntu 20.04
- GPU count and types: 8 × NVIDIA A100 80GB
- Interconnects: NVLink + InfiniBand
- Python version: 3.10
- PyTorch version: 2.x
- DeepSpeed version: 0.x.x
- CUDA version: 12.x
Launcher context
Launched using the deepspeed launcher.
Docker context
Not using Docker. Running in a bare-metal environment.
Additional context
This behavior is especially problematic for MoE router implementations, where fp32 computation is often required for numerical stability (softmax + top-k routing).
Currently, the only option is to explicitly cast the affected parameters back to float32 inside the forward pass (a sketch is included below); this works, but it is a stopgap rather than a supported solution.
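For concreteness, a minimal sketch of that cast-back-in-forward workaround, assuming nothing beyond plain PyTorch; the helper name and the dummy tensors are illustrative and not part of the model or of DeepSpeed:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def fp32_router_logits(hidden_states: torch.Tensor, weight: nn.Parameter) -> torch.Tensor:
    # Compute router logits in float32 even if `weight` has been cast to bf16.
    # Hypothetical helper; note the weight is re-upcast on every forward call.
    with torch.autocast(device_type=hidden_states.device.type, enabled=False):
        return F.linear(hidden_states.float(), weight.float())

# Dummy usage with bf16 inputs, mimicking the post-DeepSpeed state:
w = nn.Parameter(torch.zeros(4, 8, dtype=torch.bfloat16))
x = torch.randn(2, 8, dtype=torch.bfloat16)
print(fp32_router_logits(x, w).dtype)  # torch.float32
```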