Skip to content

fix(lora): use float32 intermediate buffer in fused MoE LoRA to prevent bf16 precision loss#38686

Draft
prsabahrami wants to merge 1 commit intovllm-project:mainfrom
prsabahrami:fix/bf16-moe-lora-precision
Draft

fix(lora): use float32 intermediate buffer in fused MoE LoRA to prevent bf16 precision loss#38686
prsabahrami wants to merge 1 commit intovllm-project:mainfrom
prsabahrami:fix/bf16-moe-lora-precision

Conversation

@prsabahrami
Copy link
Copy Markdown

Summary

Fix bf16 precision loss in the fused MoE LoRA path that causes hallucinated/incoherent output when merging LoRA adapters into MoE models (e.g., openai/gpt-oss-120b).

Root Cause Analysis

Two precision bugs in vllm/lora/ops/triton_ops/fused_moe_lora_op.py:

Bug 1 (Primary): Intermediate buffer uses output dtype instead of float32

The _fused_moe_lora() function allocates the intermediate buffer between the shrink (lora_a) and expand (lora_b) Triton kernels with dtype=output.dtype:

# BEFORE (buggy):
a_intermediate_cache1 = torch.zeros(
    intermediate_cache_shape,
    dtype=output.dtype,  # bf16 for bf16 models!
    device=device,
)

For bf16 models, this means the float32 accumulator result from the shrink kernel gets truncated to bfloat16 before being passed to the expand kernel.

The non-MoE LoRA path (punica_gpu.py:add_lora_linear) correctly uses torch.float32:

buffer = torch.empty(..., dtype=torch.float32, device=x.device)
# Comment: "We set the buffer to be float32 by default"

This precision loss compounds across 128 experts x 36 layers in gpt-oss, causing hallucinated output.

Bug 2 (Secondary): Hardcoded tl.bfloat16 cast in Triton kernel

The fused MoE LoRA kernel hardcodes tl.bfloat16 for dot product operand casting:

# BEFORE (buggy):
accumulator += tl.dot(a.to(tl.bfloat16), b.to(tl.bfloat16))

This:

  1. Discards precision from the now-float32 intermediate buffer in the expand path
  2. Incorrectly handles fp16 models by casting weights to the wrong dtype

Fix

# AFTER (fix for Bug 1):
a_intermediate_cache1 = torch.zeros(
    intermediate_cache_shape,
    dtype=torch.float32,  # Match non-MoE LoRA path
    device=device,
)

# AFTER (fix for Bug 2):
dot_dtype = c_ptr.dtype.element_ty
accumulator += tl.dot(a.to(dot_dtype), b.to(dot_dtype))

Reproduction Steps

  1. Load openai/gpt-oss-120b in bf16 precision with vLLM
  2. Load a LoRA adapter (e.g., jeeejeee/gpt-oss-20b-lora-adapter-text2sql)
  3. Merge the adapter into the base model weights
  4. Generate text — output is hallucinated/incoherent

The merge operation completes without error, making this a silent precision bug.

Before/After Comparison

Before (with bug): The intermediate buffer between shrink and expand kernels is bf16. The float32 accumulator from hidden_states @ lora_a gets truncated to bf16, and these errors compound across 128 experts per layer and 36 layers. Result: hallucinated, incoherent text.

After (with fix): The intermediate buffer is float32 (matching the non-MoE path in punica_gpu.py). Full precision is preserved through the shrink-expand pipeline. Result: coherent, correct output.

Test Coverage

Added tests/lora/test_moe_lora_bf16_precision.py with three tests:

  1. test_fused_moe_lora_intermediate_buffer_is_float32 — Verifies the intermediate buffer uses float32
  2. test_fused_moe_lora_kernel_no_hardcoded_bfloat16 — Verifies no hardcoded tl.bfloat16 casts
  3. test_bf16_precision_loss_in_matmul_chain — Demonstrates the numeric impact of bf16 intermediate truncation in a LoRA-like matmul chain

Files Changed

  • vllm/lora/ops/triton_ops/fused_moe_lora_op.py — Fix intermediate buffer dtype and kernel cast
  • tests/lora/test_moe_lora_bf16_precision.py — New test file (3 tests)

…nt bf16 precision loss

The fused MoE LoRA path in _fused_moe_lora() allocated the intermediate
buffer between the shrink (lora_a) and expand (lora_b) kernels with
dtype=output.dtype. For bf16 models like gpt-oss-120b, this means the
float32 accumulator result from the shrink kernel gets truncated to
bfloat16 before being passed to the expand kernel.

This precision loss compounds across 128 experts and 36 layers in MoE
models, causing hallucinated/incoherent output after LoRA adapter
merging -- even though the merge operation itself completes without
error.

The non-MoE LoRA path in punica_gpu.py correctly uses float32 for
the intermediate buffer, and this fix aligns the MoE path to match.

Additionally, the Triton kernel hardcoded tl.bfloat16 for dot product
operand casting, which:
1. Discards precision from the now-float32 intermediate buffer in the
   expand kernel path
2. Incorrectly handles fp16 models by casting to the wrong dtype

This is fixed by using c_ptr.dtype.element_ty, which adapts to the
actual output dtype of the kernel.

Fixes: hallucinated output when merging LoRA adapters into MoE models
in bf16 precision.
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request addresses precision loss in the fused MoE LoRA implementation by transitioning the intermediate buffer from bfloat16 to float32 and removing hardcoded bfloat16 casts in the Triton kernel. A comprehensive test suite is included to verify the fix and demonstrate the impact of precision loss. Feedback highlights a critical issue where using the intermediate buffer's float32 type for dot product inputs in the shrink path leads to incorrect pointer casting and performance degradation. It is recommended to dynamically determine the 16-bit compute dtype based on the kernel phase to ensure correct memory access and Tensor Core acceleration.

Comment on lines +389 to +390
dot_dtype = c_ptr.dtype.element_ty
accumulator += tl.dot(a.to(dot_dtype), b.to(dot_dtype))
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Using c_ptr.dtype.element_ty as the dot_dtype is problematic because c_ptr refers to the intermediate buffer in the shrink path, which is now float32. This has two major consequences:

  1. Incorrect Pointer Casting: Although not shown in this diff, line 296 casts the weight pointer cur_b_ptr using c_ptr.dtype.element_ty. In the shrink path, this will cast it to float32*, causing tl.load to read 4 bytes per element from memory containing 2-byte weights (bf16/fp16), leading to corrupted data and potential out-of-bounds access.
  2. Performance/Compatibility: Casting tl.dot inputs to float32 (when dot_dtype is float32) will bypass Tensor Core acceleration on most GPUs and may fail if the Triton backend doesn't support fp32 inputs for the dot op.

You should determine the 16-bit compute dtype based on whether the kernel is in the shrink or expand phase (e.g., using IS_PRIMARY). Note that you must also update line 296 to use this correct dtype for the weight pointer casting.

Suggested change
dot_dtype = c_ptr.dtype.element_ty
accumulator += tl.dot(a.to(dot_dtype), b.to(dot_dtype))
# Use the 16-bit model/weight dtype for the dot product inputs.
# In the shrink path (IS_PRIMARY=True), c_ptr is the fp32 intermediate buffer,
# so we use a_ptr's dtype. In the expand path, c_ptr is the 16-bit output.
dot_dtype = a_ptr.dtype.element_ty if IS_PRIMARY else c_ptr.dtype.element_ty
accumulator += tl.dot(a.to(dot_dtype), b.to(dot_dtype))

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant