Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 130 additions & 0 deletions tests/lora/test_moe_lora_bf16_precision.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Test that the fused MoE LoRA intermediate buffer uses float32
to prevent precision loss that causes hallucinated output.

This test verifies the fix for the bf16 precision bug in
vllm/lora/ops/triton_ops/fused_moe_lora_op.py where the intermediate
buffer between the shrink (lora_a) and expand (lora_b) kernels was
incorrectly allocated with output.dtype (bf16) instead of float32.
"""

import pytest
import torch

from vllm.platforms import current_platform


@pytest.mark.skipif(
    not current_platform.is_cuda(),
    reason="Requires CUDA GPU",
)
def test_fused_moe_lora_intermediate_buffer_is_float32():
    """Verify the intermediate buffer in fused MoE LoRA uses float32.

    The non-MoE LoRA path (punica_gpu.py) explicitly uses float32 for the
    intermediate buffer between shrink and expand operations. The fused MoE
    path should do the same to prevent precision loss that compounds across
    experts and layers, leading to hallucinated outputs in MoE models.

    This is a source-level check: it reads the function text with
    ``inspect`` and looks for the float32 allocation marker.
    """
    import inspect

    # Imported lazily so that an unavailable/broken vllm build fails only
    # this test rather than the whole module at collection time.
    from vllm.lora.ops.triton_ops.fused_moe_lora_op import _fused_moe_lora

    fn_source = inspect.getsource(_fused_moe_lora)

    expected_marker = "dtype=torch.float32"
    failure_message = (
        "fused_moe_lora intermediate buffer must use torch.float32 "
        "to match the non-MoE LoRA path and prevent precision loss. "
        "Found dtype=output.dtype which causes bf16 truncation."
    )
    assert expected_marker in fn_source, failure_message


@pytest.mark.skipif(
    not current_platform.is_cuda(),
    reason="Requires CUDA GPU",
)
def test_fused_moe_lora_kernel_no_hardcoded_bfloat16():
    """Verify the fused MoE LoRA kernel does not hardcode tl.bfloat16.

    The kernel should use the output element type for dot product casting
    rather than hardcoding bfloat16, which would:
    1. Fail to handle fp16 models correctly
    2. Discard precision from the float32 intermediate buffer
    """
    import inspect

    from vllm.lora.ops.triton_ops.fused_moe_lora_op import (
        _fused_moe_lora_kernel,
    )

    # ``.fn`` unwraps the Triton JIT wrapper to the underlying Python
    # function so its source can be inspected.
    kernel_source = inspect.getsource(_fused_moe_lora_kernel.fn)

    # Neither dot-product operand may be hardcast to bfloat16.
    forbidden_casts = ("a.to(tl.bfloat16)", "b.to(tl.bfloat16)")
    for forbidden in forbidden_casts:
        assert forbidden not in kernel_source, (
            "fused_moe_lora_kernel should not hardcode tl.bfloat16. "
            "Use c_ptr.dtype.element_ty to handle all dtypes correctly."
        )


@pytest.mark.skipif(
    not current_platform.is_cuda(),
    reason="Requires CUDA GPU",
)
def test_bf16_precision_loss_in_matmul_chain():
    """Demonstrate that bf16 intermediate truncation causes precision loss.

    This test shows the numeric impact of the bug: when a float32 matmul
    result is truncated to bf16 before a second matmul, the final result
    diverges from the float32 reference, especially for small LoRA ranks
    typical in MoE models.
    """
    torch.manual_seed(42)
    device = "cuda"
    num_tokens = 32
    hidden_size = 2880  # gpt_oss hidden size
    rank = 8  # typical LoRA rank

    # Simulate LoRA computation: output = (hidden @ lora_a) @ lora_b
    hidden = torch.randn(num_tokens, hidden_size, dtype=torch.bfloat16,
                         device=device)
    lora_a = torch.randn(hidden_size, rank, dtype=torch.bfloat16,
                         device=device) * 0.01
    lora_b = torch.randn(rank, hidden_size, dtype=torch.bfloat16,
                         device=device) * 0.01

    # Reference: full float32 intermediate (correct behavior)
    intermediate_f32 = torch.matmul(hidden.float(), lora_a.float())
    result_f32 = torch.matmul(intermediate_f32, lora_b.float()).bfloat16()

    # Bug path: truncate the same intermediate to bf16 before the second
    # matmul. (Reuses intermediate_f32 instead of recomputing the identical
    # matmul, so the only difference between the two paths is the
    # truncation itself.)
    intermediate_bf16 = intermediate_f32.bfloat16()
    result_bf16 = torch.matmul(
        intermediate_bf16.float(), lora_b.float()
    ).bfloat16()

    # Compute relative error of the truncated path vs. the reference.
    abs_diff = (result_f32.float() - result_bf16.float()).abs()
    rel_error = abs_diff / (result_f32.float().abs() + 1e-8)
    max_rel_error = rel_error.max().item()
    mean_rel_error = rel_error.mean().item()

    print(f"Max relative error from bf16 intermediate: {max_rel_error:.6f}")
    print(f"Mean relative error from bf16 intermediate: {mean_rel_error:.6f}")

    # The original assertion was `> 0.0`, which is trivially satisfiable
    # and inconsistent with the documented 0.1% threshold. bf16 keeps only
    # an 8-bit mantissa (~0.4% quantization step), so truncating the
    # intermediate must introduce at least 0.1% worst-case relative error;
    # enforce the documented threshold.
    assert max_rel_error > 1e-3, (
        "Expected non-trivial precision loss (>0.1% max relative error) "
        "between float32 and bf16 intermediate paths; "
        f"got {max_rel_error:.6f}"
    )
    # Note: for a single layer this might be small, but it compounds
    # across 128 experts x 36 layers in gpt_oss
14 changes: 12 additions & 2 deletions vllm/lora/ops/triton_ops/fused_moe_lora_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,11 @@ def _fused_moe_lora_kernel(
# compiler may infer different types for a and b when merging
# if/else branches (TMA desc path returns fp32, tl.load returns
# the pointer's element type).
accumulator += tl.dot(a.to(tl.bfloat16), b.to(tl.bfloat16))
# Use the output element type rather than hardcoding bfloat16
# so that fp16 models are handled correctly and precision from
# a float32 intermediate buffer is not discarded prematurely.
dot_dtype = c_ptr.dtype.element_ty
accumulator += tl.dot(a.to(dot_dtype), b.to(dot_dtype))
Comment on lines +389 to +390
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Using c_ptr.dtype.element_ty as the dot_dtype is problematic because c_ptr refers to the intermediate buffer in the shrink path, which is now float32. This has two major consequences:

  1. Incorrect Pointer Casting: Although not shown in this diff, line 296 casts the weight pointer cur_b_ptr using c_ptr.dtype.element_ty. In the shrink path, this will cast it to float32*, causing tl.load to read 4 bytes per element from memory containing 2-byte weights (bf16/fp16), leading to corrupted data and potential out-of-bounds access.
  2. Performance/Compatibility: Casting tl.dot inputs to float32 (when dot_dtype is float32) will bypass Tensor Core acceleration on most GPUs and may fail if the Triton backend doesn't support fp32 inputs for the dot op.

You should determine the 16-bit compute dtype based on whether the kernel is in the shrink or expand phase (e.g., using IS_PRIMARY). Note that you must also update line 296 to use this correct dtype for the weight pointer casting.

Suggested change
dot_dtype = c_ptr.dtype.element_ty
accumulator += tl.dot(a.to(dot_dtype), b.to(dot_dtype))
# Use the 16-bit model/weight dtype for the dot product inputs.
# In the shrink path (IS_PRIMARY=True), c_ptr is the fp32 intermediate buffer,
# so we use a_ptr's dtype. In the expand path, c_ptr is the 16-bit output.
dot_dtype = a_ptr.dtype.element_ty if IS_PRIMARY else c_ptr.dtype.element_ty
accumulator += tl.dot(a.to(dot_dtype), b.to(dot_dtype))


if MUL_ROUTED_WEIGHT:
moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0.0)
Expand Down Expand Up @@ -772,7 +776,13 @@ def _fused_moe_lora(

a_intermediate_cache1 = torch.zeros(
intermediate_cache_shape,
dtype=output.dtype,
# Use float32 for the intermediate buffer to avoid precision loss
# when accumulating lora_a results, matching the non-MoE LoRA path
# in punica_gpu.py which explicitly uses float32. Using output.dtype
# (e.g. bfloat16) causes the float32 accumulator in the shrink
# kernel to be truncated, and the resulting precision loss compounds
# across experts and layers, leading to hallucinated outputs.
dtype=torch.float32,
device=device,
)

Expand Down
Loading