[JAX] MXFP8 Grouped Quantize V2 #2760

Draft

jberchtold-nvidia wants to merge 17 commits into NVIDIA:main from jberchtold-nvidia:jberchtold/grouped-mxfp8-quantize

Conversation

@jberchtold-nvidia
Collaborator

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

jberchtold-nvidia and others added 16 commits March 9, 2026 15:42
@jberchtold-nvidia jberchtold-nvidia marked this pull request as draft March 13, 2026 15:43
@greptile-apps
Contributor

greptile-apps bot commented Mar 13, 2026

Greptile Summary

This PR refactors the JAX grouped GEMM stack to replace the old flat group_sizes/scalar M/N/K/is_grouped_dense_wgrad interface with a richer N-D first_dims/last_dims per-tensor descriptor model, and adds a MXFP8 CUDA-graph-safe quantization path (te_grouped_quantize_v2_ffi). The new design lets each operand carry its own ragged-dimension bookkeeping, removes hardcoded shape math from C++ FFI handlers, and adds an SM100+ guard so the V2 grouped-GEMM kernel is only used on Blackwell+ hardware.

Key changes:

  • New GroupedNoScaleTensor pytree class for unquantized grouped operands; GroupedScaledTensor1x's group_sizes renamed to first_dims/last_dims
  • grouped_gemm() signature changed: group_sizes positional arg removed; callers must now wrap raw arrays in GroupedNoScaleTensor
  • C++ V2 FFI handler rebuilt around N-D buffer dimensions and a partitioned int64_workspace — each ragged-dim buffer gets its own non-aliasing slot
  • GroupedGemmFFI (V1) derives m/n/k from buffer dims + lhs/rhs_axis_boundary instead of passed-in scalars
  • ScaledTensorFactory updated throughout to thread first_dims/last_dims in place of group_sizes
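
As a rough illustration of the per-tensor descriptor model (toy stand-ins only — the real GroupedNoScaleTensor and the dim-inference logic live in transformer_engine/jax and differ in detail):

```python
from dataclasses import dataclass

import numpy as np


# Hypothetical, simplified stand-in for the PR's GroupedNoScaleTensor:
# each grouped operand carries its own ragged-dimension bookkeeping
# instead of a single flat group_sizes array shared across operands.
@dataclass
class GroupedNoScaleTensor:
    data: np.ndarray        # flattened ragged storage
    first_dims: np.ndarray  # per-group sizes of the ragged leading dim (empty if dense)
    last_dims: np.ndarray   # per-group sizes of the ragged trailing dim (empty if dense)


def infer_out_first_dims(lhs: GroupedNoScaleTensor,
                         rhs: GroupedNoScaleTensor) -> np.ndarray:
    """Toy rule: the output inherits the ragged non-contracting dim from
    whichever operand carries one (lhs takes precedence here)."""
    if lhs.first_dims.size > 0:
        return lhs.first_dims
    return rhs.first_dims


# lhs: 3 ragged groups of 2/1/3 rows, 4 cols; rhs: 3 dense 4x5 groups.
lhs = GroupedNoScaleTensor(np.zeros(6 * 4), np.array([2, 1, 3]), np.array([], dtype=int))
rhs = GroupedNoScaleTensor(np.zeros(3 * 4 * 5), np.array([], dtype=int), np.array([], dtype=int))
print(infer_out_first_dims(lhs, rhs))  # per-group output rows: [2 1 3]
```
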

Issues found:

  • In grouped_gemm(), when lhs is GroupedNoScaleTensor and rhs is GroupedScaledTensor1x, scaling_mode remains NO_SCALING with no error raised, which would produce silently wrong results for that combination
  • The bias shape validation was changed from raise ValueError to assert, which can be disabled with python -O
  • ScaledTensorFactory.create_1x's new condition routes any call with original_shape (and the default group_axis=0) into the GroupedScaledTensor1x path, which is broader than intended and could silently affect future callers

Confidence Score: 3/5

  • Needs fixes before merge: a silent scaling-mode mismatch and a regressed assert-based validation are correctness risks.
  • The core refactor is logically sound and the C++ changes are well-structured, but two issues in gemm.py lower confidence: (1) the mixed lhs/rhs type path leaves scaling_mode as NO_SCALING when rhs carries real scales, which would silently produce wrong output if that combination is ever reached; (2) the bias shape check was downgraded from raise ValueError to assert, making it bypassable in optimised mode. The overly broad ScaledTensorFactory.create_1x condition adds latent risk. None of the existing callers exercise the broken mixed-type path today, but the lack of a guard or clear error means it can break silently in future work.
  • transformer_engine/jax/cpp_extensions/gemm.py (scaling_mode propagation and assert regression), transformer_engine/jax/quantize/tensor.py (ScaledTensorFactory condition)

Important Files Changed

  • transformer_engine/jax/cpp_extensions/gemm.py — Major refactor of GroupedGemmPrimitive and grouped_gemm(): replaces scalar M/N/K attrs and group_sizes with N-D lhs/rhs/out_first_dims/last_dims buffers; removes the is_grouped_dense_wgrad flag; adds an SM100+ guard for the V2 FFI path. Two issues found: (1) assert used instead of raise ValueError for the bias shape check (regression from prior code); (2) scaling_mode is left as NO_SCALING when lhs is GroupedNoScaleTensor and rhs is GroupedScaledTensor1x, with no type error raised to guard against the invalid combination.
  • transformer_engine/jax/quantize/tensor.py — Renames group_sizes → first_dims/last_dims in GroupedScaledTensor1x; adds the new GroupedNoScaleTensor pytree class; updates ScaledTensorFactory.create_1x/create_2x/create. The create_1x condition is now overly broad: any call passing original_shape (with the default group_axis=0) silently creates a GroupedScaledTensor1x, a potential footgun for future callers.
  • transformer_engine/jax/csrc/extensions/gemm.cpp — Replaces the old GroupedGemmV2FFI (with hardcoded m/n/k and is_grouped_dense_wgrad) with a cleaner N-D version using first_dims/last_dims buffers; adds a new make_grouped_tensor overload that derives tensor shapes from XLA buffer dimensions with proper int64 workspace partitioning. Logic is consistent with the Python changes.
  • transformer_engine/jax/cpp_extensions/quantization.py — Adds the MXFP8 V2 quantization path (te_grouped_quantize_v2_ffi) with separate abstract/lowering branches; returns int64_workspace (not amax) as the 5th output for MXFP8, with a dummy zero amax substituted at the outer interface for backward compatibility. Intentional design, but the dummy amax substitution is an implicit contract that must be kept in sync with callers.
  • transformer_engine/jax/dense.py — Updates _grouped_dense_fwd_rule and _grouped_dense_bwd_rule to wrap raw arrays with GroupedNoScaleTensor when is_noop_quantizer_set=True; removes the group_sizes positional arg from grouped_gemm calls; correctly wraps wgrad_x_T, wgrad_grad, dgrad_grad and dgrad_kernel_T with appropriate first_dims.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["grouped_gemm(lhs, rhs, ...)"] --> B{lhs type?}
    B -->|GroupedNoScaleTensor| C["scaling_mode = NO_SCALING\nlhs_first/last_dims from lhs"]
    B -->|GroupedScaledTensor1x| D["scaling_mode = lhs.scaling_mode\nlhs_first/last_dims from lhs"]
    B -->|other| E["raise TypeError"]

    C --> F{rhs type?}
    D --> F

    F -->|GroupedNoScaleTensor| G["rhs_first/last_dims from rhs"]
    F -->|GroupedScaledTensor1x| H["rhs_first/last_dims from rhs\nCheck scaling_mode iff lhs also scaled"]
    F -->|other| I["raise TypeError"]

    G --> J["Infer out_first/last_dims"]
    H --> J

    J --> K{quantizer_set?}
    K -->|noop| L["Use raw data as-is"]
    K -->|FP8/MXFP8| M["grouped_quantize lhs & rhs"]

    L --> N["Compute lhs/rhs_axis_boundary"]
    M --> N

    N --> O{_can_use_v2_grouped_gemm?}
    O -->|NO_SCALING + BF16 + SM100+| P["GroupedGemmV2FFI\n(CUDA-graph safe)\nalpha/beta buffers"]
    O -->|otherwise| Q["GroupedGemmFFI\n(legacy, FP8/MXFP8 support)\ngroup_offset buffer"]

    P --> R["Output tensor"]
    Q --> R

    style C fill:#f9f,stroke:#333
    style H fill:#f9f,stroke:#333

Last reviewed commit: 269a518
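
The V2 dispatch decision in the flowchart can be read as a single predicate. The sketch below is illustrative only, not the actual _can_use_v2_grouped_gemm implementation:

```python
# Illustrative sketch of the V2 dispatch predicate from the flowchart above;
# the real _can_use_v2_grouped_gemm lives in
# transformer_engine/jax/cpp_extensions/gemm.py and may differ in detail.
def can_use_v2_grouped_gemm(scaling_mode: str, dtype: str, sm_arch: int) -> bool:
    # V2 (CUDA-graph-safe) path: unquantized BF16 on Blackwell (SM100) or newer.
    return (scaling_mode == "NO_SCALING"
            and dtype == "bfloat16"
            and sm_arch >= 100)


print(can_use_v2_grouped_gemm("NO_SCALING", "bfloat16", 100))       # True  -> GroupedGemmV2FFI
print(can_use_v2_grouped_gemm("NO_SCALING", "bfloat16", 90))        # False -> legacy GroupedGemmFFI
print(can_use_v2_grouped_gemm("MXFP8_1D_SCALING", "float8", 100))   # False -> legacy GroupedGemmFFI
```
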

Comment on lines +2208 to 2211
num_gemms,
N_dim,
), f"bias shape {bias.shape} does not match expected shape {(num_gemms, N_dim)}"
bias = jnp.empty((), jnp.float32) if bias is None else bias

assert replaces previous raise ValueError for bias shape check

The prior code used a raise ValueError(...) for this validation. The new assert statement can be silently disabled when Python is run in optimized mode (python -O), causing the check to be skipped entirely at runtime. This is a correctness regression from the original code.

Suggested change

Before (assert-based check):
num_gemms,
N_dim,
), f"bias shape {bias.shape} does not match expected shape {(num_gemms, N_dim)}"
bias = jnp.empty((), jnp.float32) if bias is None else bias

After (explicit check):
if bias.shape != (num_gemms, N_dim):
    raise ValueError(
        f"bias shape {bias.shape} does not match expected shape {(num_gemms, N_dim)}"
    )

Comment on lines +2049 to +2066
rhs_data = rhs.data.reshape(rhs_shape)
rhs_scale_inv = rhs.scale_inv
if lhs.scaling_mode != rhs.scaling_mode:
rhs_first_dims = rhs.first_dims if rhs.first_dims is not None else empty_gs
rhs_last_dims = rhs.last_dims if rhs.last_dims is not None else empty_gs
if isinstance(lhs, GroupedScaledTensor1x) and lhs.scaling_mode != rhs.scaling_mode:
raise ValueError(
f"Mismatched scaling modes: lhs.scaling_mode={lhs.scaling_mode},"
f" rhs.scaling_mode={rhs.scaling_mode}"
)
scaling_mode = lhs.scaling_mode
if isinstance(lhs, GroupedScaledTensor1x):
scaling_mode = lhs.scaling_mode
else:
raise TypeError("Unsupported lhs type object!")
raise TypeError(
f"rhs must be GroupedNoScaleTensor or GroupedScaledTensor1x, got type={type(rhs)}"
)

# Infer output dims from which operand has the ragged non-contracting dim.
if rhs_first_dims.size > 0 or rhs_last_dims.size > 0:

scaling_mode not propagated when lhs is unscaled and rhs is scaled

When lhs is GroupedNoScaleTensor, scaling_mode is set to ScalingMode.NO_SCALING (line 2023). The rhs branch below only updates scaling_mode via if isinstance(lhs, GroupedScaledTensor1x): scaling_mode = lhs.scaling_mode, which is False in the mixed case. As a result, if rhs is a GroupedScaledTensor1x (with real scale_inv values), the operation is dispatched to C++ as NO_SCALING despite the rhs carrying scale information — this would produce silently incorrect output.

The complementary scaling-mode mismatch check (if isinstance(lhs, GroupedScaledTensor1x) and lhs.scaling_mode != rhs.scaling_mode) also only fires for the lhs-scaled case, so no error is raised for the mixed combination either.

If the intent is that GroupedNoScaleTensor + GroupedScaledTensor1x is simply an unsupported combination, an explicit TypeError should be raised:

elif isinstance(rhs, GroupedScaledTensor1x):
    ...
    if isinstance(lhs, GroupedNoScaleTensor):
        raise TypeError(
            "Mixed lhs/rhs types are not supported: "
            "lhs is GroupedNoScaleTensor but rhs is GroupedScaledTensor1x."
        )
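
If the maintainers instead want a single symmetric decision table, it could look like the following toy sketch (resolve_scaling_mode is a hypothetical name, and the classes are minimal stand-ins for the PR's types):

```python
class GroupedNoScaleTensor:
    """Toy stand-in: unquantized grouped operand (no scales)."""


class GroupedScaledTensor1x:
    """Toy stand-in: quantized grouped operand carrying a scaling mode."""
    def __init__(self, scaling_mode: str):
        self.scaling_mode = scaling_mode


def resolve_scaling_mode(lhs, rhs) -> str:
    """Symmetric resolution: both unscaled -> NO_SCALING; both scaled and
    matching -> that mode; anything mixed or mismatched -> error."""
    lhs_scaled = isinstance(lhs, GroupedScaledTensor1x)
    rhs_scaled = isinstance(rhs, GroupedScaledTensor1x)
    if lhs_scaled != rhs_scaled:
        raise TypeError(
            "Mixed lhs/rhs types are not supported: both operands must be "
            "GroupedNoScaleTensor or both GroupedScaledTensor1x."
        )
    if not lhs_scaled:
        return "NO_SCALING"
    if lhs.scaling_mode != rhs.scaling_mode:
        raise ValueError(
            f"Mismatched scaling modes: {lhs.scaling_mode} vs {rhs.scaling_mode}"
        )
    return lhs.scaling_mode


print(resolve_scaling_mode(GroupedNoScaleTensor(), GroupedNoScaleTensor()))  # NO_SCALING
```

This makes the mixed combination fail loudly in both directions rather than only when lhs is the scaled operand.
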

Comment on lines +695 to +697
if (
first_dims is not None
or last_dims is not None

Overly broad condition silently creates GroupedScaledTensor1x when original_shape is passed

Since group_axis defaults to 0 (never None), the sub-condition original_shape is not None and group_axis is not None evaluates to True for every caller that supplies an original_shape. This means a future caller that passes original_shape without intending to create a grouped tensor will silently receive a GroupedScaledTensor1x instead of a plain ScaledTensor1x, with first_dims=None, last_dims=None, and num_groups inferred from original_shape[group_axis].

The equivalent old guard was if group_sizes is not None, which was an explicit signal. The new guard conflates "caller wants a grouped tensor" with "caller happened to pass original_shape". Consider tightening the condition back to the explicit signal:

if first_dims is not None or last_dims is not None:

…and move any logic that genuinely needs original_shape for a non-grouped tensor (e.g. shape normalisation) outside the if block, or add a separate explicit is_grouped parameter to avoid the ambiguity.
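
A toy illustration of the difference between the two guards (broad_guard and tight_guard are hypothetical names; only the condition shapes come from the diff):

```python
# The broad condition fires for every caller that passes original_shape,
# because group_axis defaults to 0 and is therefore never None.
def broad_guard(first_dims=None, last_dims=None, original_shape=None, group_axis=0):
    return (first_dims is not None
            or last_dims is not None
            or (original_shape is not None and group_axis is not None))


# The tightened condition only fires on the explicit grouped-tensor signal.
def tight_guard(first_dims=None, last_dims=None, **_):
    return first_dims is not None or last_dims is not None


print(broad_guard(original_shape=(8, 16, 32)))  # True: grouped tensor created unintentionally
print(tight_guard(original_shape=(8, 16, 32)))  # False: plain ScaledTensor1x path
print(tight_guard(first_dims=(2, 1, 3)))        # True: caller explicitly asked for grouping
```
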
