Support Hadamard transform for mxfp4 with RTN or AutoRound method #1349
base: main
Conversation
Signed-off-by: lkk12014402 <kaokao.lv@intel.com>
for more information, see https://pre-commit.ci
Pull request overview
This PR adds support for the Hadamard transform for mxfp4 quantization with the RTN or AutoRound methods. The Hadamard transform is an orthogonal transformation that rotates weights and activations to improve quantization quality by decorrelating features.
Changes:
- Introduced a transform system supporting identity and Hadamard transforms
- Integrated transform configuration throughout the quantization pipeline
- Implemented custom Triton kernel for efficient mxfp4 quantization with Hadamard transform
Reviewed changes
Copilot reviewed 9 out of 9 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| auto_round/transforms/transforms.py | New module defining transform classes (Identity, Hadamard) and factory function (see the sketch after this table) |
| auto_round/wrapper.py | Added transform support to WrapperLinear and WrapperWALayer with transform matrix persistence |
| auto_round/schemes.py | Extended QuantizationScheme to include transform_config field |
| auto_round/inference/convert_model.py | Propagated transform_config from quantization config to layer config |
| auto_round/export/export_to_autoround/export_to_nvfp_mxfp.py | Registered forward_hadamard_matrix buffer during layer packing |
| auto_round/experimental/triton/mxfp4.py | New Triton kernel for mxfp4 forward pass with Hadamard transform |
| auto_round/experimental/qmodules/mx.py | Integrated Triton mxfp4 kernel for inference when transform is enabled |
| auto_round/compressors/base.py | Added transform_config parameter and propagated to RTN quantization |
| auto_round/autoround.py | Added transform_config parameter to AutoRound interface |
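To make the table concrete, here is a minimal sketch of how a `transform_config` like `{"transform_class": "hadamard"}` might be resolved by the factory described above. This is an illustration under stated assumptions: the class and function names follow the table's descriptions rather than the PR's exact code, and the `group_size` default of 32 (the mxfp4 group size) is an assumption.

```python
import math
import torch

class IdentityTransform(torch.nn.Module):
    """No-op transform used when no rotation is configured."""
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x

class HadamardTransform(torch.nn.Module):
    """Rotates feature groups by an orthogonal Hadamard matrix (body sketched in the thread below)."""
    def __init__(self, group_size: int):
        super().__init__()
        self.group_size = group_size
        self.scale = 1 / math.sqrt(group_size)

def get_transform(transform_config: dict, group_size: int = 32) -> torch.nn.Module:
    # Hypothetical factory: resolve {"transform_class": "hadamard"} to a module.
    name = transform_config.get("transform_class", "identity")
    if name == "identity":
        return IdentityTransform()
    if name == "hadamard":
        return HadamardTransform(group_size)
    raise ValueError(f"unknown transform_class: {name!r}")
```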
```python
def forward(self, x: torch.Tensor):
    return x

def remove_parametrizations(self) -> None:
```
Copilot AI · Jan 27, 2026
The scale calculation `1 / math.sqrt(self.group_size)` is duplicated in both `__init__` (line 29) and `get_transform_matrix`. Consider using `self.scale` in `get_transform_matrix` to avoid duplication and ensure consistency.
Suggested change:

```python
return hadamard_transform(torch.eye(self.group_size, device=device, dtype=dtype), scale=self.scale)
```
```python
# Create output tensors on CUDA
```
Copilot AI · Jan 27, 2026
Using a lambda for the grid calculation can cause issues with serialization and debugging. Consider using a regular function definition instead.
Suggested change:

```python
def grid(meta):
    return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
```
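At the call site the named function replaces the lambda directly. A hedged sketch — the kernel name and pointer arguments are assumptions based on this file:

```python
# Before: grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
def grid(meta):
    return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

# Triton resolves BLOCK_SIZE from the autotune config at launch time.
mxfp4_forward_kernel[grid](x_ptr, out_ptr, n_elements)
```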
```python
minmax_lr (float, optional): Learning rate for min-max tuning; defaults to `lr`.
low_gpu_mem_usage (bool, optional): Lower GPU memory mode. Defaults to False.
low_cpu_mem_usage (bool, optional): Lower CPU memory mode. Defaults to False.
transform_config (dict, optional): transform matrix config like hadamard, like {"transform_class": "hadamard"}.
```
Copilot AI · Jan 27, 2026
Corrected spelling of 'matirx' to 'matrix'.
Suggested change:

```python
transform_config (dict, optional): transform matrix config like hadamard, like {"transform_class": "hadamard"}.
```
```python
disable_opt_rtn (bool, optional): Disable RTN-mode optimization (iters=0) for fast quantization
    with lower accuracy. Defaults to None.
low_cpu_mem_usage (bool, optional): Lower CPU memory mode. Defaults to False.
transform_config (dict, optional): transform matrix config like hadamard, like {"transform_class": "hadamard"}.
```
Copilot AI · Jan 27, 2026
Corrected spelling of 'matirx' to 'matrix'.
Suggested change:

```python
transform_config (dict, optional): transform matrix config like hadamard, like {"transform_class": "hadamard"}.
```
```python
orig_shape = input.shape
x_flat = input.contiguous().flatten(end_dim=-2)
qdq_input, _ = mxfp4_forward_kernel_wrapper(
```
How can the code ensure that the transformation is equivalent?
Please see the transformation equation in the comments: the activation is rotated and the packed weight is also rotated, so the linear layer is equivalent.
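For reference, the equivalence being appealed to is the standard identity for an orthogonal $$H$$:

$$y = xW^\top = xHH^\top W^\top = (xH)(WH)^\top$$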
OK, thanks. Should QKV or MoE layers share the same H on activations, since they are fused in vLLM?
Currently we only support one shared H for all layers. We can support layer-specific transform matrices when we support other transform functions in this file: [auto_round/transforms/transforms.py](https://github.com/intel/auto-round/pull/1349/files#diff-c384e42585ac4e70e9217f23c870c46b8b66097f7d6810bea55779a55e38fe73)
Accuracy test on llama3.1-8b-instruct (BF16 vs. MXFP4):
```python
minmax_lr (float, optional): Learning rate for min-max tuning; defaults to `lr`.
low_gpu_mem_usage (bool, optional): Lower GPU memory mode. Defaults to False.
low_cpu_mem_usage (bool, optional): Lower CPU memory mode. Defaults to False.
transform_config (dict, optional): transform matrix config like hadamard, like {"transform_class": "hadamard"}.
```
Please coordinate with Heng and Weiwei regarding the API (@n1ck-guo @WeiweiZhang1).
We should avoid fusing multiple algorithms into a single implementation, as we plan to support more algorithms in the future. We can take inspiration from the design used in LLMC.
Good suggestion. We can follow the LLMC-style design, like https://github.com/vllm-project/llm-compressor/blob/main/examples/transform/quip_example.py#L26, to support more rotation methods.
```python
disable_opt_rtn: bool | None = None,
seed: int = 42,
low_cpu_mem_usage: bool = True,
transform_config: dict = {},
```
It’s better not to name it transform_config, as it may be confused with the Transformers library.
```python
    ],
    key=[],
)
@triton.jit
```
I assume XPU does not support this, but it's not a big issue for now.
XPU supports this kernel, but I haven't tested the performance.
```python
    key=[],
)
@triton.jit
def mxfp4_forward_kernel(
```
Please add the source if the code is copied from another repo. Better to add their license at the beginning of this file as well.
@WeiweiZhang1 Please arrange an accuracy test to evaluate the benefits if the PR runs correctly. There’s no need to wait for the code to be fully refined.
@WeiweiZhang1 You can tell me how to do this. I will test more models.
Thanks a lot for the PR! It would be better to make it compatible with other schemes in the future.
Thank you. I will make a plan for other schemes and other transform matrices (methods).
Given that we previously evaluated MXFP4 SignRoundV2, I think we should follow the same scope as before: test commonly used models in the 8B–70B range and compare results on around 10 mainstream lm_eval tasks, right?
- paper details link
- previous env info:
- shell cmd example:

Thanks, @lkk12014402! Could you share your BKC and let me know which kind of benchmark you'd like me to run?
No problem. Let me do the test. Thank you @WeiweiZhang1.
```python
disable_opt_rtn (bool, optional): Disable RTN-mode optimization (iters=0) for fast quantization
    with lower accuracy. Defaults to None.
low_cpu_mem_usage (bool, optional): Lower CPU memory mode. Defaults to False.
transform_config (dict, optional): transform matrix config like hadamard, like {"transform_class": "hadamard"}.
```
Please mark it as an experimental feature, and clarify the limitations.
```python
def forward(self, x):
    act_max = self.orig_layer.act_max if hasattr(self.orig_layer, "act_max") else None
    if self.enable_transform:
```
Nit: using a forward hook might make the implementation clearer from a readability perspective.
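For illustration, the hook-based alternative might look like this — a minimal sketch; the buffer name `forward_hadamard_matrix` is taken from the file table above, while `wrapper` and the whole-tensor matmul (the real PR rotates per group) are assumptions:

```python
import torch

def rotate_activation_hook(module, args):
    # Forward pre-hook: rotate the incoming activation by H before the wrapped
    # layer runs, instead of branching on enable_transform inside forward().
    (x, *rest) = args
    return (x @ module.forward_hadamard_matrix, *rest)

wrapper.register_forward_pre_hook(rotate_activation_hook)
```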
I think so too. Let me look into refining the code.
A cleaner API would look like `model = hadamard_transform(model)`, and then pass the transformed model to AutoRound. Please sync with Heng on this or other possibilities.
If you are using transformers to evaluate the accuracy, please follow the notes here for Llama: https://github.com/intel/auto-round/blob/main/docs/alg_202508.md


Description

original linear:

$$y = x W^\top$$

transform matrix $$H$$ (Hadamard satisfies $$H^\top H = I$$, and $$H^{-1}=H^\top$$):

define:

$$\tilde{x} = xH, \qquad \tilde{W} = WH$$

then:

$$y = \tilde{x}\,\tilde{W}^\top = (xH)(WH)^\top = xHH^\top W^\top = xW^\top$$
with huggingface/transformers
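The usage code did not survive extraction; below is a minimal sketch of the intended flow under stated assumptions: the model name and output path are placeholders, and the `transform_config` keys follow the docstrings quoted above.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from auto_round import AutoRound

model_name = "meta-llama/Llama-3.1-8B-Instruct"  # placeholder
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)

autoround = AutoRound(
    model,
    tokenizer,
    scheme="MXFP4",  # mxfp4 quantization scheme
    transform_config={"transform_class": "hadamard"},  # enable the Hadamard rotation
)
autoround.quantize_and_save("./llama3.1-8b-instruct-mxfp4-hadamard")
```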
with vllm
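Similarly, a hedged sketch of serving the exported checkpoint with vLLM; the model path is the placeholder from the previous snippet:

```python
from vllm import LLM, SamplingParams

llm = LLM(model="./llama3.1-8b-instruct-mxfp4-hadamard")
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```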