[Bugfix] Fix NaN issue for Triton FusedMoE LoRA #30585
base: main
Conversation
Code Review
This pull request correctly addresses a critical bug where NaN values could appear in the attention output. The root cause is that the output tensor, allocated with torch.empty(), was not fully initialized, and subsequent attention operations only filled a portion of it up to num_actual_tokens. The added line output[num_actual_tokens:].fill_(0) effectively zeros out the remaining uninitialized part of the tensor, preventing any garbage values or NaNs from propagating. This is a robust and necessary fix. This same pattern of not zeroing out the padded portion of the output tensor may exist in other attention backends, and it would be beneficial to audit them for similar issues to ensure consistent behavior across the system.
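As a minimal, standalone sketch of the pattern the review describes (the shapes and `num_actual_tokens` value below are illustrative, not vLLM's actual buffers):

```python
import torch

# Illustrative shapes only; in vLLM this buffer is the attention output.
num_padded_tokens, num_heads, head_dim = 4, 8, 64
num_actual_tokens = 3

# torch.empty() returns uninitialized memory, so the padded rows may hold
# arbitrary garbage (possibly NaN/Inf) left over from earlier allocations.
output = torch.empty(num_padded_tokens, num_heads, head_dim)

# The attention kernel only writes the first num_actual_tokens rows.
output[:num_actual_tokens] = torch.randn(num_actual_tokens, num_heads, head_dim)

# The fix: explicitly zero the padded tail so nothing downstream (e.g. MoE
# routing) ever reads garbage or NaN from the unwritten rows.
output[num_actual_tokens:].fill_(0)

assert torch.isfinite(output).all()
```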
Great find, thanks a lot @xyang16! cc: @jeejeelee, @robertgshaw2-redhat, @varun-sundar-rabindranath
Hi @xyang16, the pre-commit checks have failed. Please run:
uv pip install pre-commit
pre-commit install
pre-commit run --all-files
Then commit the changes and push to your branch.
Force-pushed 851c552 to b1e7f12
@jeejeelee - lmk if this looks okay to you
I was able to reproduce this error on my A5500 hardware by running these tests. However, instead of applying the fix here, I applied the fix from #30650, and after that all of these tests passed. So these two PRs and code paths are at least related, if not directly intertwined.
I believe there are two separate things in play here. The change to …

With that said, I tried testing just the changes in this PR on an H100 that's using FLASH_ATTN and Triton MXFP4 kernels, and I am still seeing the infinite generation.

Logs from the failed test showing FLASH_ATTN and the Triton MXFP4 backend in use:
@bbrowning Thanks for helping investigate this!
Is there any other reason, other than skipping custom_all_reduce, that would make …

Yes, I have a note that says: This PR doesn't address the NaN caused by FULL_AND_PIECEWISE cudagraph mode (see #29539 (comment)), so cudagraph_mode needs to be set to PIECEWISE together with this PR to make it work.
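For anyone reproducing this, a minimal sketch of forcing the PIECEWISE cudagraph mode when constructing the engine; the `compilation_config` spelling and the model id below are assumptions, so check them against your vLLM version:

```python
from vllm import LLM

# Assumption: CompilationConfig exposes a cudagraph_mode field that accepts
# "PIECEWISE"; verify the exact field name for the vLLM version in use.
llm = LLM(
    model="openai/gpt-oss-20b",  # illustrative model choice, not taken from this PR
    compilation_config={"cudagraph_mode": "PIECEWISE"},
)
print(llm.generate(["Hello"])[0].outputs[0].text)
```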
Signed-off-by: Xin Yang <[email protected]>
Force-pushed 6234130 to 6287d37
Signed-off-by: Xin Yang <[email protected]>
Purpose
This PR fixes a NaN issue with fused_moe_lora. Currently, running `test_gptoss_tp.py` will fail. We found it's because NaN values in a tensor cause `triton_kernels` to compute a wrong `gather_indx`.

The `output` tensor is created as `output = torch.empty(output_shape, dtype=output_dtype, device=query.device)` here, which was changed from `output = torch.zeros(output_shape, dtype=output_dtype, device=query.device)` by "remove attn output view kernel" #26680. `torch.empty()` allocates uninitialized memory and may contain NaNs. Then `torch.ops.vllm.unified_attention_with_output` writes `output[:num_actual_tokens]` here, so only the first `num_actual_tokens` rows of the output tensor are filled. We see a case where `output` has shape [4, 64, 64], `num_actual_tokens` is 3, and the last row has NaNs.

`triton_kernels` will compute a wrong `gather_indx` if the last few rows have NaNs. Below is an example: `attn_output` has a NaN in the last row, so the last row of `hidden_states` is NaNs, and after select_experts the last row of `topk_ids` is all 0s.

Note: this step also exposes a problem in fused_topk, because `torch.topk` is able to generate unique `topk_ids` [1, 0, 2, 3] for a NaN tensor. So it would be better for fused_topk to also handle a NaN tensor and generate unique `topk_ids` in each row, instead of [0, 0, 0, 0]; that could avoid the problems in the later steps.

The all-0 `topk_ids` then lead to a wrong `gather_indx`: `gather_indx` should contain a row index for all 16 rows, but row indices 0, 4, and 8 are replaced with -1. This means rows 0, 4, and 8 will be missing in the first matmul_ogs (this impacts non-LoRA as well). Because of the wrong `gather_indx`, the first row in `intermediate_cache1` also ends up with NaN values.

Next, `b_intermediate_cache1` is added into `output` (`output` here is the `intermediate_cache1` from the previous step). `b_intermediate_cache1` has NaNs in the last few rows because `hidden_states` is not reordered, while `output` has NaNs in the first and last rows because it is reordered. So the addition combines the wrong rows, and `output` ends up with even more NaN rows.
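As a rough, kernel-free illustration of the routing step above (plain `torch.topk` stands in for the routing; it cannot reproduce the [0, 0, 0, 0] behavior reported for fused_topk, it only shows what a NaN padding row looks like to the router):

```python
import torch

# 4 padded token rows; simulate the uninitialized last row as NaN.
num_experts, top_k = 8, 4
hidden_states = torch.randn(4, 64)
hidden_states[-1] = float("nan")

# NaN propagates through the router projection into the logits.
router_logits = hidden_states @ torch.randn(64, num_experts)

# torch.topk still returns k *distinct* expert ids for the NaN row (as noted
# above, e.g. [1, 0, 2, 3]), whereas the fused_topk path reportedly collapses
# the row to [0, 0, 0, 0], which later corrupts gather_indx.
topk_vals, topk_ids = torch.topk(router_logits, top_k, dim=-1)
print(topk_vals[-1])  # all NaN
print(topk_ids[-1])   # k distinct indices
```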
So this PR fixes the NaN issues by:

- Filling `output[num_actual_tokens:]` with 0 to avoid NaNs in the tensor. I went with filling `output[num_actual_tokens:]` with 0 because NaN values in the tensor might cause unexpected behaviors somewhere else too; please let me know what you think.
- Reordering `intermediate_cache1` back to make sure the LoRA weights are added to the correct rows here.

Note: This PR doesn't address the NaN caused by FULL_AND_PIECEWISE cudagraph mode (see #29539 (comment)), so cudagraph_mode needs to be set to PIECEWISE together with this PR to make it work.
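To make the second fix concrete, here is a toy, kernel-free sketch of the row-ordering mismatch; `gather_indx`, `intermediate_cache1`, and `b_intermediate_cache1` below are simplified stand-ins for the real buffers, and the actual Triton change may reorder differently:

```python
import torch

# Toy setup: 4 (token, expert) rows. The MoE kernel works on rows in
# expert-sorted order, while the LoRA delta is produced in token order.
gather_indx = torch.tensor([2, 0, 3, 1])               # stand-in for the sorted row order
token_rows = torch.arange(4.0).unsqueeze(1).repeat(1, 3)
intermediate_cache1 = token_rows[gather_indx]          # expert-sorted activations
b_intermediate_cache1 = 0.1 * token_rows               # LoRA delta, still token-ordered

# Naive add misaligns rows: token i's LoRA delta lands on a different token.
wrong = intermediate_cache1 + b_intermediate_cache1

# Aligning the orderings first (here by permuting the delta with the same
# indices) adds each LoRA contribution to its own token's row, which is the
# intent of reordering intermediate_cache1 back in this PR.
right = intermediate_cache1 + b_intermediate_cache1[gather_indx]
assert not torch.allclose(wrong, right)
```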
Test Plan
Tests passed.
Run `test_gptoss_tp.py` with modification:
Main:
PR:
Garbage output fixed. Tests passed.
Accuracy Testing
Marlin:
Triton:
Essential Elements of an Effective PR Description Checklist
supported_models.md and examples for a new model.

cc @robertgshaw2-redhat @jeejeelee