[JAX] MXFP8 Grouped Quantize V2#2760
Conversation
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Greptile Summary

This PR refactors the JAX grouped GEMM stack to replace the old flat …

Key changes:
Issues found:
Confidence Score: 3/5
Important Files Changed
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["grouped_gemm(lhs, rhs, ...)"] --> B{lhs type?}
    B -->|GroupedNoScaleTensor| C["scaling_mode = NO_SCALING\nlhs_first/last_dims from lhs"]
    B -->|GroupedScaledTensor1x| D["scaling_mode = lhs.scaling_mode\nlhs_first/last_dims from lhs"]
    B -->|other| E["raise TypeError"]
    C --> F{rhs type?}
    D --> F
    F -->|GroupedNoScaleTensor| G["rhs_first/last_dims from rhs"]
    F -->|GroupedScaledTensor1x| H["rhs_first/last_dims from rhs\nCheck scaling_mode iff lhs also scaled"]
    F -->|other| I["raise TypeError"]
    G --> J["Infer out_first/last_dims"]
    H --> J
    J --> K{quantizer_set?}
    K -->|noop| L["Use raw data as-is"]
    K -->|FP8/MXFP8| M["grouped_quantize lhs & rhs"]
    L --> N["Compute lhs/rhs_axis_boundary"]
    M --> N
    N --> O{_can_use_v2_grouped_gemm?}
    O -->|NO_SCALING + BF16 + SM100+| P["GroupedGemmV2FFI\n(CUDA-graph safe)\nalpha/beta buffers"]
    O -->|otherwise| Q["GroupedGemmFFI\n(legacy, FP8/MXFP8 support)\ngroup_offset buffer"]
    P --> R["Output tensor"]
    Q --> R
    style C fill:#f9f,stroke:#333
    style H fill:#f9f,stroke:#333
```
Last reviewed commit: 269a518
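The kernel-selection step at the bottom of the flowchart can be sketched as follows. This is a hedged approximation: the names `_can_use_v2_grouped_gemm`, `GroupedGemmV2FFI`, and `GroupedGemmFFI` come from the flowchart, but the real signatures and the exact SM-architecture check in the PR may differ.

```python
# Sketch of the V2-vs-legacy grouped GEMM dispatch from the flowchart.
# The V2 FFI path is assumed to require NO_SCALING + BF16 operands + SM100+;
# everything else (including FP8/MXFP8) falls back to the legacy kernel.
def _can_use_v2_grouped_gemm(scaling_mode: str, dtype: str, sm_arch: int) -> bool:
    """V2 grouped GEMM (CUDA-graph safe, alpha/beta buffers) is BF16-only."""
    return scaling_mode == "NO_SCALING" and dtype == "bfloat16" and sm_arch >= 100


def select_kernel(scaling_mode: str, dtype: str, sm_arch: int) -> str:
    if _can_use_v2_grouped_gemm(scaling_mode, dtype, sm_arch):
        return "GroupedGemmV2FFI"
    # Legacy path keeps the group_offset buffer and FP8/MXFP8 support.
    return "GroupedGemmFFI"


print(select_kernel("NO_SCALING", "bfloat16", 100))            # GroupedGemmV2FFI
print(select_kernel("MXFP8_1D_SCALING", "float8_e4m3", 100))   # GroupedGemmFFI
```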
```python
    num_gemms,
    N_dim,
), f"bias shape {bias.shape} does not match expected shape {(num_gemms, N_dim)}"
bias = jnp.empty((), jnp.float32) if bias is None else bias
```
`assert` replaces the previous `raise ValueError` for the bias shape check

The prior code used `raise ValueError(...)` for this validation. The new `assert` statement is silently disabled when Python runs in optimized mode (`python -O`), so the check is skipped entirely at runtime. This is a correctness regression from the original code.
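The `-O` behavior is easy to reproduce with a minimal repro (the snippet below is illustrative, not the PR's actual bias check):

```python
# Demonstrate that `assert` validation disappears under `python -O`,
# while a `raise ValueError` would survive regardless of optimization mode.
import subprocess
import sys

snippet = (
    "bias_shape, expected = (3, 4), (2, 4)\n"
    "assert bias_shape == expected, 'bias shape mismatch'\n"
    "print('check skipped')\n"
)

# Normal mode: the assert fires and the interpreter exits non-zero.
normal = subprocess.run([sys.executable, "-c", snippet],
                        capture_output=True, text=True)
# Optimized mode: the assert is compiled out and execution continues.
optimized = subprocess.run([sys.executable, "-O", "-c", snippet],
                           capture_output=True, text=True)

print(normal.returncode != 0)               # True: AssertionError raised
print("check skipped" in optimized.stdout)  # True: the validation vanished
```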
Suggested change:

```diff
-    num_gemms,
-    N_dim,
-), f"bias shape {bias.shape} does not match expected shape {(num_gemms, N_dim)}"
 bias = jnp.empty((), jnp.float32) if bias is None else bias
+if bias.shape != (num_gemms, N_dim):
+    raise ValueError(
+        f"bias shape {bias.shape} does not match expected shape {(num_gemms, N_dim)}"
+    )
```
```diff
 rhs_data = rhs.data.reshape(rhs_shape)
 rhs_scale_inv = rhs.scale_inv
-if lhs.scaling_mode != rhs.scaling_mode:
+rhs_first_dims = rhs.first_dims if rhs.first_dims is not None else empty_gs
+rhs_last_dims = rhs.last_dims if rhs.last_dims is not None else empty_gs
+if isinstance(lhs, GroupedScaledTensor1x) and lhs.scaling_mode != rhs.scaling_mode:
     raise ValueError(
         f"Mismatched scaling modes: lhs.scaling_mode={lhs.scaling_mode},"
         f" rhs.scaling_mode={rhs.scaling_mode}"
     )
-scaling_mode = lhs.scaling_mode
+if isinstance(lhs, GroupedScaledTensor1x):
+    scaling_mode = lhs.scaling_mode
 else:
-    raise TypeError("Unsupported lhs type object!")
+    raise TypeError(
+        f"rhs must be GroupedNoScaleTensor or GroupedScaledTensor1x, got type={type(rhs)}"
+    )

 # Infer output dims from which operand has the ragged non-contracting dim.
 if rhs_first_dims.size > 0 or rhs_last_dims.size > 0:
```
`scaling_mode` not propagated when lhs is unscaled and rhs is scaled

When lhs is a `GroupedNoScaleTensor`, `scaling_mode` is set to `ScalingMode.NO_SCALING` (line 2023). The rhs branch below only updates `scaling_mode` via `if isinstance(lhs, GroupedScaledTensor1x): scaling_mode = lhs.scaling_mode`, whose guard is False in the mixed case. As a result, if rhs is a `GroupedScaledTensor1x` (with real `scale_inv` values), the operation is dispatched to C++ as NO_SCALING even though rhs carries scale information, silently producing incorrect output.

The complementary scaling-mode mismatch check (`if isinstance(lhs, GroupedScaledTensor1x) and lhs.scaling_mode != rhs.scaling_mode`) likewise only fires when lhs is scaled, so no error is raised for the mixed combination either.
If the intent is that GroupedNoScaleTensor + GroupedScaledTensor1x is simply an unsupported combination, an explicit TypeError should be raised:
```python
elif isinstance(rhs, GroupedScaledTensor1x):
    ...
    if isinstance(lhs, GroupedNoScaleTensor):
        raise TypeError(
            "Mixed lhs/rhs types are not supported: "
            "lhs is GroupedNoScaleTensor but rhs is GroupedScaledTensor1x."
        )
```

The next hunk under review:

```python
if (
    first_dims is not None
    or last_dims is not None
```
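The fall-through the comment describes can be reproduced in miniature. The classes below are stand-ins for the PR's `GroupedNoScaleTensor` and `GroupedScaledTensor1x`, and `pick_scaling_mode` condenses the branch logic as reviewed; the real code path is more involved.

```python
# Minimal reproduction: when lhs is unscaled, neither the mismatch check nor
# the scaling_mode assignment inspects rhs, so a scaled rhs is silently
# dispatched as NO_SCALING.
class GroupedNoScaleTensor:
    pass


class GroupedScaledTensor1x:
    def __init__(self, scaling_mode):
        self.scaling_mode = scaling_mode


def pick_scaling_mode(lhs, rhs):
    scaling_mode = (
        "NO_SCALING" if isinstance(lhs, GroupedNoScaleTensor) else lhs.scaling_mode
    )
    if isinstance(rhs, GroupedScaledTensor1x):
        # Guard only fires when lhs is also scaled: the mixed case slips through.
        if isinstance(lhs, GroupedScaledTensor1x) and lhs.scaling_mode != rhs.scaling_mode:
            raise ValueError("Mismatched scaling modes")
    return scaling_mode


mode = pick_scaling_mode(GroupedNoScaleTensor(),
                         GroupedScaledTensor1x("MXFP8_1D_SCALING"))
print(mode)  # NO_SCALING -- rhs scale information is ignored, no error raised
```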
Overly broad condition silently creates `GroupedScaledTensor1x` when `original_shape` is passed

Since `group_axis` defaults to 0 (never None), the sub-condition `original_shape is not None and group_axis is not None` evaluates to True for every caller that supplies an `original_shape`. This means a future caller that passes `original_shape` without intending to create a grouped tensor will silently receive a `GroupedScaledTensor1x` instead of a plain `ScaledTensor1x`, with `first_dims=None`, `last_dims=None`, and `num_groups` inferred from `original_shape[group_axis]`.

The equivalent old guard was `if group_sizes is not None`, which was an explicit signal. The new guard conflates "caller wants a grouped tensor" with "caller happened to pass `original_shape`". Consider tightening the condition back to the explicit signal:

```python
if first_dims is not None or last_dims is not None:
```

…and move any logic that genuinely needs `original_shape` for a non-grouped tensor (e.g. shape normalisation) outside the `if` block, or add a separate explicit `is_grouped` parameter to avoid the ambiguity.
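The suggested tightening can be sketched as below. `make_tensor` and its return values are hypothetical stand-ins for the PR's tensor constructor, kept only to show how the explicit `first_dims`/`last_dims` signal separates the two cases:

```python
# Gate grouped-tensor creation on the explicit first_dims/last_dims signal,
# not on the mere presence of original_shape (group_axis defaults to 0 and
# is never None, so an original_shape-based guard is always True).
def make_tensor(data, original_shape=None, first_dims=None, last_dims=None,
                group_axis=0):
    if first_dims is not None or last_dims is not None:
        # Explicitly grouped: infer the group count from the ragged dims.
        num_groups = len(first_dims if first_dims is not None else last_dims)
        return ("GroupedScaledTensor1x", num_groups)
    # original_shape may still be used here for non-grouped shape
    # normalisation without flipping the result to a grouped tensor.
    return ("ScaledTensor1x", None)


print(make_tensor([1.0], original_shape=(4, 8)))  # plain tensor despite original_shape
print(make_tensor([1.0], first_dims=[2, 2]))      # explicitly grouped, 2 groups
```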