Add device-side NaN check instrumentation with E8M0 scale factor detection #5
Open
kahyunnam wants to merge 2 commits into TomerBN-Nvidia:ultra-rl
Adds opt-in NaN detection at the C++ binding layer for mxfp8_quantize, mxfp8_gemm, and fused MoE kernels. The checks use device-side kernels that get captured into CUDA graphs, which enables NaN detection during graph replay. This addresses the debugging gap where host-side checks (like torch.isnan) don't execute during CUDA graph replay.

Controlled by environment variables:
- FLASHINFER_NAN_CHECK=1: enables input/output NaN scanning (device printf)
- FLASHINFER_NAN_CHECK_TRAP=1: additionally calls __trap() on first NaN

Zero overhead when disabled (env var not set).

Made-with: Cursor
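A minimal host-side sketch of how this kind of opt-in gating could look, assuming each flag is read once and cached so that the disabled path costs only a branch on a boolean. The helper names (`env_flag_enabled`, `nan_check_enabled`, `nan_check_trap_enabled`) are hypothetical and not taken from this PR; only the two environment variable names come from the commit message. Treating the trap flag as effective only when the scan flag is also set is an assumption.

```cpp
#include <cstdlib>
#include <cstring>

// Hypothetical helper: true only when the variable is set to exactly "1".
inline bool env_flag_enabled(const char* name) {
  const char* value = std::getenv(name);
  return value != nullptr && std::strcmp(value, "1") == 0;
}

// Hypothetical cached accessor: the environment is read on the first call
// only, so later calls reduce to loading a static boolean.
inline bool nan_check_enabled() {
  static const bool enabled = env_flag_enabled("FLASHINFER_NAN_CHECK");
  return enabled;
}

// Assumption: the trap flag only matters when scanning is enabled, since the
// description says it "additionally" traps on the first NaN.
inline bool nan_check_trap_enabled() {
  static const bool trap =
      nan_check_enabled() && env_flag_enabled("FLASHINFER_NAN_CHECK_TRAP");
  return trap;
}
```

A binding function would then wrap each check launch in `if (nan_check_enabled()) { ... }`, so nothing is launched or captured into a graph when the variable is unset.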
Per the OCP MX spec, E8M0 byte value 255 (0xFF) encodes NaN, not 2^128. When weight scale factors contain 0xFF, the GEMM produces NaN output via 0 * inf = NaN (since 2^128 overflows FP32 to infinity).

This adds checks on UE8M0 scale factor tensors at:
- mxfp8_quantize output scales
- mxfp8_gemm input scales (mat1Scale, mat2Scale)
- CUTLASS fused MoE input scale factors (input_sf)
- TRT-LLM fused MoE scale factors (hidden_states_scale, weight scales)

The message distinguishes E8M0 NaN from FP8 data NaN:
[FLASHINFER_NAN_CHECK] E8M0 NaN scale (0xFF) in "mxfp8_gemm:mat2Scale[e8m0]" ...

Made-with: Cursor
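The arithmetic behind the 0 * inf = NaN path can be reproduced on the host. The sketch below is illustrative only: `decode_e8m0` and `decode_e8m0_naive` are hypothetical names, not functions from this PR. The first follows the spec rule quoted above (0xFF is NaN, any other byte b is 2^(b - 127)); the second treats 0xFF as an ordinary exponent to show how the overflow arises.

```cpp
#include <cmath>
#include <cstdint>
#include <limits>

// Spec-conforming decode: 0xFF is NaN, any other byte b encodes 2^(b - 127).
inline float decode_e8m0(std::uint8_t b) {
  if (b == 0xFF) return std::numeric_limits<float>::quiet_NaN();
  return std::ldexp(1.0f, static_cast<int>(b) - 127);
}

// Decode that ignores the NaN encoding. For b == 0xFF this computes 2^128,
// which is not representable in FP32 and becomes +infinity.
inline float decode_e8m0_naive(std::uint8_t b) {
  return std::ldexp(1.0f, static_cast<int>(b) - 127);
}

// Scaling a zero-valued element by such a scale yields NaN, not zero.
inline float scale_element(float element, std::uint8_t scale_byte) {
  return element * decode_e8m0_naive(scale_byte);
}
```

With these definitions, `scale_element(0.0f, 0xFF)` is NaN while `scale_element(0.0f, 127)` is 0, which is why a single 0xFF scale byte is enough to poison an output block.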
📌 Description
Usage
Set env variable: FLASHINFER_NAN_CHECK=1 (add FLASHINFER_NAN_CHECK_TRAP=1 to also call __trap() on the first NaN)
Example output:
Summary
Adds opt-in device-side NaN detection at the C++ binding layer for mxfp8_quantize, mxfp8_gemm, and fused MoE kernels
Checks both FP8 data tensors (E4M3 NaN = 0x7F) and UE8M0 scale factor tensors (E8M0 NaN = 0xFF per OCP MX spec)
Uses device-side kernels that get captured into CUDA graphs, enabling NaN detection during graph replay
Zero overhead when disabled (env var not set)
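As a host-side reference for the two byte patterns named in the summary, the scan can be sketched as below. This is not the PR's device kernel, which runs on the GPU so it can be captured into CUDA graphs; the names `is_e4m3_nan`, `is_e8m0_nan`, `NanScanResult`, and `scan_bytes` are hypothetical. Masking off the sign bit so that both 0x7F and 0xFF count as E4M3 NaN is an assumption about the FP8 E4M3 encoding, not something the PR text states.

```cpp
#include <cstddef>
#include <cstdint>

// E4M3 NaN has all exponent and mantissa bits set. Assumption: the sign bit
// is ignored, so 0x7F and 0xFF both count.
inline bool is_e4m3_nan(std::uint8_t b) { return (b & 0x7F) == 0x7F; }

// E8M0 has no sign bit; 0xFF is the single NaN encoding.
inline bool is_e8m0_nan(std::uint8_t b) { return b == 0xFF; }

struct NanScanResult {
  std::size_t count;        // number of NaN bytes found
  std::size_t first_index;  // index of the first NaN byte, or n if none
};

// Linear scan over a raw byte buffer with a caller-supplied NaN predicate.
template <typename Pred>
NanScanResult scan_bytes(const std::uint8_t* data, std::size_t n, Pred is_nan) {
  NanScanResult result{0, n};
  for (std::size_t i = 0; i < n; ++i) {
    if (is_nan(data[i])) {
      if (result.count == 0) result.first_index = i;
      ++result.count;
    }
  }
  return result;
}
```

The same buffer can give different answers depending on how it is interpreted: a byte of 0x7F is NaN as E4M3 data but a valid exponent (2^0) as an E8M0 scale, which is why data and scale tensors need separate predicates.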
Motivation
This change supports debugging the NaN issues in the MXFP8 pipeline reported in nvbug/6084190 and nvbug/6062022. Analysis of Oleg's tensor dump showed that weight_scale contained 1,056 bytes with value 0xFF, which per the OCP MX spec encodes NaN in E8M0 format. When 0xFF is instead decoded as an ordinary exponent, 2^(255-127) = 2^128 overflows FP32 to infinity, and the GEMM then computes 0 * inf = NaN.