
Add device-side NaN check instrumentation with E8M0 scale factor detection#5

Open
kahyunnam wants to merge 2 commits into TomerBN-Nvidia:ultra-rl from kahyunnam:e8m0-nan-check-ultra-rl

Conversation

@kahyunnam

📌 Description

Usage

Set the environment variable:

FLASHINFER_NAN_CHECK=1 python your_script.py

Example output:

[FLASHINFER_NAN_CHECK] NaN detected in "mxfp8_quantize:input[bf16]" at index 2548 (of 4096)
[FLASHINFER_NAN_CHECK] E8M0 NaN scale (0xFF) in "mxfp8_gemm:mat2Scale[e8m0]" at index 126 (of 524288)

Summary

  • Adds opt-in device-side NaN detection at the C++ binding layer for mxfp8_quantize, mxfp8_gemm, and the fused MoE kernels.
  • Checks both FP8 data tensors (E4M3 NaN = 0x7F) and UE8M0 scale factor tensors (E8M0 NaN = 0xFF, per the OCP MX spec).
  • Uses device-side check kernels that are captured into CUDA graphs, enabling NaN detection during graph replay.
  • Zero overhead when disabled (environment variable not set).
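The bit patterns being scanned for can be sketched host-side. This is an illustrative Python rendering of the checks, not the PR's C++/CUDA code; the function names are hypothetical.

```python
# Host-side sketch of the bit-pattern tests the device kernels apply.
# Assumption: names and the -1 "not found" convention are illustrative.

def is_e4m3_nan(b: int) -> bool:
    """FP8 E4M3 NaN is S.1111.111, i.e. 0x7F once the sign bit is masked off."""
    return (b & 0x7F) == 0x7F

def is_e8m0_nan(b: int) -> bool:
    """UE8M0 scale factor: the single byte value 0xFF encodes NaN (OCP MX spec)."""
    return b == 0xFF

def first_nan_index(data: bytes, pred) -> int:
    """Return the first offending index, or -1 if the tensor is clean.
    Mirrors the 'at index i (of n)' part of the printed message."""
    for i, b in enumerate(data):
        if pred(b):
            return i
    return -1

# Scales 2^0, 2^1, NaN, 2^-127 -> the NaN byte sits at index 2
print(first_nan_index(bytes([0x7F, 0x80, 0xFF, 0x00]), is_e8m0_nan))  # -> 2
```

Note that for E4M3 the sign bit must be masked (both 0x7F and 0xFF encode NaN), while for E8M0 the check is an exact byte compare, since the format is unsigned.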

Motivation

Debugging NaN issues in the MXFP8 pipeline reported in nvbug/6084190 and nvbug/6062022. Analysis of Oleg's tensor dump showed that weight_scale contained 1,056 bytes with the value 0xFF, which per the OCP MX spec encodes NaN in E8M0 format. Decoded naively as 2^(255-127) = 2^128, that scale overflows FP32 to infinity, so the GEMM then computes 0 * inf = NaN.
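The overflow arithmetic can be checked directly. A small sketch using plain Python floats (which are float64) with FLT_MAX standing in for the float32 limit; the decode helper is illustrative, not the library's:

```python
# Illustrative arithmetic behind the failure mode described above.
import math

FLT_MAX = (2.0 - 2.0 ** -23) * 2.0 ** 127  # largest finite float32 value

def decode_e8m0_naive(b: int) -> float:
    """Naive E8M0 decode that ignores the NaN encoding: 2^(b - 127)."""
    return 2.0 ** (b - 127)

largest = decode_e8m0_naive(0xFE)  # 2^127: the largest legal scale, finite in FP32
bad = decode_e8m0_naive(0xFF)      # 2^128: exceeds FLT_MAX, so a float32
                                   # decode overflows to +inf
assert largest <= FLT_MAX < bad

# In the GEMM, a zero element times the infinite scale then yields NaN:
print(math.isnan(0.0 * math.inf))  # -> True
```

This is why 0xFF must be treated as NaN rather than as 2^128: every other byte value (0x00 through 0xFE) decodes to a finite FP32 power of two.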

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Adds opt-in NaN detection at the C++ binding layer for mxfp8_quantize,
mxfp8_gemm, and fused MoE kernels. The checks use device-side kernels
that get captured into CUDA graphs, enabling NaN detection during graph
replay — addressing the debugging gap where host-side checks (like
torch.isnan) don't execute during CUDA graph replay.

Controlled by environment variables:
- FLASHINFER_NAN_CHECK=1: enables input/output NaN scanning (device printf)
- FLASHINFER_NAN_CHECK_TRAP=1: additionally calls __trap() on first NaN

Zero overhead when disabled (env var not set).
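The gating described above can be sketched as follows. This is a hedged host-side Python analogue of the C++ env-var dispatch; the helper names are hypothetical, and the RuntimeError stands in for the device-side __trap():

```python
# Sketch of the env-var gating: check kernels are only launched when
# FLASHINFER_NAN_CHECK=1, so the disabled path costs nothing.
import os

def nan_check_enabled() -> bool:
    return os.environ.get("FLASHINFER_NAN_CHECK") == "1"

def maybe_check_scales(name: str, scales: bytes) -> None:
    if not nan_check_enabled():          # zero overhead when disabled
        return
    for i, b in enumerate(scales):
        if b == 0xFF:                    # E8M0 NaN per the OCP MX spec
            print(f'[FLASHINFER_NAN_CHECK] E8M0 NaN scale (0xFF) in "{name}" '
                  f"at index {i} (of {len(scales)})")
            if os.environ.get("FLASHINFER_NAN_CHECK_TRAP") == "1":
                raise RuntimeError("NaN trap")  # stands in for __trap()
            return

os.environ["FLASHINFER_NAN_CHECK"] = "1"
maybe_check_scales("mxfp8_gemm:mat2Scale[e8m0]", bytes([0x80, 0xFF]))
```

Because the real checks run as device-side kernels on the same stream, they are captured into the CUDA graph along with the operation they guard, which is what lets them fire during graph replay where host-side torch.isnan cannot.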

Made-with: Cursor
Per the OCP MX spec, E8M0 byte value 255 (0xFF) encodes NaN, not 2^128.
When weight scale factors contain 0xFF, the GEMM produces NaN output via
0 * inf = NaN (since 2^128 overflows FP32 to infinity).

This adds checks on UE8M0 scale factor tensors at:
- mxfp8_quantize output scales
- mxfp8_gemm input scales (mat1Scale, mat2Scale)
- CUTLASS fused MoE input scale factors (input_sf)
- TRT-LLM fused MoE scale factors (hidden_states_scale, weight scales)

The message distinguishes E8M0 NaN from FP8 data NaN:
  [FLASHINFER_NAN_CHECK] E8M0 NaN scale (0xFF) in "mxfp8_gemm:mat2Scale[e8m0]" ...

Made-with: Cursor