fix(lora): use float32 intermediate buffer in fused MoE LoRA to prevent bf16 precision loss#38686
fix(lora): use float32 intermediate buffer in fused MoE LoRA to prevent bf16 precision loss#38686prsabahrami wants to merge 1 commit intovllm-project:mainfrom
Conversation
…nt bf16 precision loss The fused MoE LoRA path in _fused_moe_lora() allocated the intermediate buffer between the shrink (lora_a) and expand (lora_b) kernels with dtype=output.dtype. For bf16 models like gpt-oss-120b, this means the float32 accumulator result from the shrink kernel gets truncated to bfloat16 before being passed to the expand kernel. This precision loss compounds across 128 experts and 36 layers in MoE models, causing hallucinated/incoherent output after LoRA adapter merging -- even though the merge operation itself completes without error. The non-MoE LoRA path in punica_gpu.py correctly uses float32 for the intermediate buffer, and this fix aligns the MoE path to match. Additionally, the Triton kernel hardcoded tl.bfloat16 for dot product operand casting, which: 1. Discards precision from the now-float32 intermediate buffer in the expand kernel path 2. Incorrectly handles fp16 models by casting to the wrong dtype This is fixed by using c_ptr.dtype.element_ty, which adapts to the actual output dtype of the kernel. Fixes: hallucinated output when merging LoRA adapters into MoE models in bf16 precision.
There was a problem hiding this comment.
Code Review
This pull request addresses precision loss in the fused MoE LoRA implementation by transitioning the intermediate buffer from bfloat16 to float32 and removing hardcoded bfloat16 casts in the Triton kernel. A comprehensive test suite is included to verify the fix and demonstrate the impact of precision loss. Feedback highlights a critical issue where using the intermediate buffer's float32 type for dot product inputs in the shrink path leads to incorrect pointer casting and performance degradation. It is recommended to dynamically determine the 16-bit compute dtype based on the kernel phase to ensure correct memory access and Tensor Core acceleration.
| dot_dtype = c_ptr.dtype.element_ty | ||
| accumulator += tl.dot(a.to(dot_dtype), b.to(dot_dtype)) |
There was a problem hiding this comment.
Using c_ptr.dtype.element_ty as the dot_dtype is problematic because c_ptr refers to the intermediate buffer in the shrink path, which is now float32. This has two major consequences:
- Incorrect Pointer Casting: Although not shown in this diff, line 296 casts the weight pointer
cur_b_ptrusingc_ptr.dtype.element_ty. In the shrink path, this will cast it tofloat32*, causingtl.loadto read 4 bytes per element from memory containing 2-byte weights (bf16/fp16), leading to corrupted data and potential out-of-bounds access. - Performance/Compatibility: Casting
tl.dotinputs tofloat32(whendot_dtypeisfloat32) will bypass Tensor Core acceleration on most GPUs and may fail if the Triton backend doesn't supportfp32inputs for thedotop.
You should determine the 16-bit compute dtype based on whether the kernel is in the shrink or expand phase (e.g., using IS_PRIMARY). Note that you must also update line 296 to use this correct dtype for the weight pointer casting.
| dot_dtype = c_ptr.dtype.element_ty | |
| accumulator += tl.dot(a.to(dot_dtype), b.to(dot_dtype)) | |
| # Use the 16-bit model/weight dtype for the dot product inputs. | |
| # In the shrink path (IS_PRIMARY=True), c_ptr is the fp32 intermediate buffer, | |
| # so we use a_ptr's dtype. In the expand path, c_ptr is the 16-bit output. | |
| dot_dtype = a_ptr.dtype.element_ty if IS_PRIMARY else c_ptr.dtype.element_ty | |
| accumulator += tl.dot(a.to(dot_dtype), b.to(dot_dtype)) |
Summary
Fix bf16 precision loss in the fused MoE LoRA path that causes hallucinated/incoherent output when merging LoRA adapters into MoE models (e.g.,
openai/gpt-oss-120b).Root Cause Analysis
Two precision bugs in
vllm/lora/ops/triton_ops/fused_moe_lora_op.py:Bug 1 (Primary): Intermediate buffer uses output dtype instead of float32
The
_fused_moe_lora()function allocates the intermediate buffer between the shrink (lora_a) and expand (lora_b) Triton kernels withdtype=output.dtype:For bf16 models, this means the float32 accumulator result from the shrink kernel gets truncated to bfloat16 before being passed to the expand kernel.
The non-MoE LoRA path (
punica_gpu.py:add_lora_linear) correctly usestorch.float32:This precision loss compounds across 128 experts x 36 layers in gpt-oss, causing hallucinated output.
Bug 2 (Secondary): Hardcoded
tl.bfloat16cast in Triton kernelThe fused MoE LoRA kernel hardcodes
tl.bfloat16for dot product operand casting:This:
Fix
Reproduction Steps
openai/gpt-oss-120bin bf16 precision with vLLMjeeejeee/gpt-oss-20b-lora-adapter-text2sql)The merge operation completes without error, making this a silent precision bug.
Before/After Comparison
Before (with bug): The intermediate buffer between shrink and expand kernels is bf16. The float32 accumulator from
hidden_states @ lora_agets truncated to bf16, and these errors compound across 128 experts per layer and 36 layers. Result: hallucinated, incoherent text.After (with fix): The intermediate buffer is float32 (matching the non-MoE path in
punica_gpu.py). Full precision is preserved through the shrink-expand pipeline. Result: coherent, correct output.Test Coverage
Added
tests/lora/test_moe_lora_bf16_precision.pywith three tests:test_fused_moe_lora_intermediate_buffer_is_float32— Verifies the intermediate buffer uses float32test_fused_moe_lora_kernel_no_hardcoded_bfloat16— Verifies no hardcoded tl.bfloat16 caststest_bf16_precision_loss_in_matmul_chain— Demonstrates the numeric impact of bf16 intermediate truncation in a LoRA-like matmul chainFiles Changed
vllm/lora/ops/triton_ops/fused_moe_lora_op.py— Fix intermediate buffer dtype and kernel casttests/lora/test_moe_lora_bf16_precision.py— New test file (3 tests)