
Conversation

@lkk12014402 (Contributor) commented Jan 27, 2026

Description

  1. Support Hadamard transform for MXFP4

Original linear:

$$ y = Wx $$

Transform matrix $$H$$ (a Hadamard matrix satisfies $$H^\top H = I$$, so $$H^{-1} = H^\top$$):

$$ y = W x = (W H^\top) (H x) $$

Define:

  • $$W' = W H^\top$$ (rotated weight)
  • $$x' = H x$$ (rotated activation)

Then:

$$ y = W' x' $$
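
A quick numerical check of this identity (an illustrative sketch only; the `hadamard_matrix` helper below is not part of this PR):

```python
import torch

def hadamard_matrix(n: int) -> torch.Tensor:
    # Sylvester construction, scaled so that H.T @ H == I (n must be a power of two)
    H = torch.ones(1, 1)
    while H.shape[0] < n:
        H = torch.cat([torch.cat([H, H], dim=1), torch.cat([H, -H], dim=1)], dim=0)
    return H / H.shape[0] ** 0.5

W = torch.randn(16, 64)   # weight
x = torch.randn(64)       # activation
H = hadamard_matrix(64)

W_rot = W @ H.T           # W' = W H^T  (rotated weight)
x_rot = H @ x             # x' = H x    (rotated activation)
assert torch.allclose(W @ x, W_rot @ x_rot, atol=1e-4)  # y = W x = W' x'
```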

  2. Support evaluation with huggingface/transformers

With huggingface/transformers:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the Hadamard-transformed MXFP4 checkpoint
model_name = "./mxfp4_transformed_model"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto")
model.to("cuda")
print(model)

tokenizer = AutoTokenizer.from_pretrained(model_name)
text = "There is a girl who likes adventure,"
inputs = tokenizer(text, return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=50)[0]))
```

Evaluation with lm-eval (the command below uses the Hugging Face backend):

```shell
lm_eval --model hf --model_args pretrained=./mxfp4_transformed_model --tasks gsm8k --batch_size 8
```
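
To evaluate through vLLM instead, lm-eval's vLLM backend should work the same way, e.g. `lm_eval --model vllm --model_args pretrained=./mxfp4_transformed_model --tasks gsm8k --batch_size 8` (this variant is not part of the PR description and assumes vLLM can load the exported checkpoint).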

Signed-off-by: lkk12014402 <kaokao.lv@intel.com>
Copilot AI review requested due to automatic review settings January 27, 2026 05:20
Copilot AI left a comment

Pull request overview

This PR adds support for Hadamard transform for mxfp4 quantization with RTN or AutoRound methods. The Hadamard transform is an orthogonal transformation that rotates weights and activations to improve quantization quality by decorrelating features.

Changes:

  • Introduced a transform system supporting identity and Hadamard transforms
  • Integrated transform configuration throughout the quantization pipeline
  • Implemented custom Triton kernel for efficient mxfp4 quantization with Hadamard transform

Reviewed changes

Copilot reviewed 9 out of 9 changed files in this pull request and generated 4 comments.

| File | Description |
| --- | --- |
| auto_round/transforms/transforms.py | New module defining transform classes (Identity, Hadamard) and factory function |
| auto_round/wrapper.py | Added transform support to WrapperLinear and WrapperWALayer with transform matrix persistence |
| auto_round/schemes.py | Extended QuantizationScheme to include transform_config field |
| auto_round/inference/convert_model.py | Propagated transform_config from quantization config to layer config |
| auto_round/export/export_to_autoround/export_to_nvfp_mxfp.py | Registered forward_hadamard_matrix buffer during layer packing |
| auto_round/experimental/triton/mxfp4.py | New Triton kernel for mxfp4 forward pass with Hadamard transform |
| auto_round/experimental/qmodules/mx.py | Integrated Triton mxfp4 kernel for inference when transform is enabled |
| auto_round/compressors/base.py | Added transform_config parameter and propagated to RTN quantization |
| auto_round/autoround.py | Added transform_config parameter to AutoRound interface |


def forward(self, x: torch.Tensor):
    return x

def remove_parametrizations(self) -> None:
Copilot AI commented Jan 27, 2026

The scale calculation `1 / math.sqrt(self.group_size)` is duplicated in both `__init__` (line 29) and `get_transform_matrix`. Consider using `self.scale` in `get_transform_matrix` to avoid duplication and ensure consistency.

Suggested change:

```diff
-def remove_parametrizations(self) -> None:
+return hadamard_transform(torch.eye(self.group_size, device=device, dtype=dtype), scale=self.scale)
```
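
For illustration, a minimal sketch of the suggested refactor, assuming the class keeps the scale computed in `__init__` and that `hadamard_transform` comes from the fast_hadamard_transform package (both are assumptions about the PR's actual code, not a definitive implementation):

```python
import math
import torch
from fast_hadamard_transform import hadamard_transform  # assumed dependency (CUDA-only kernel)

class HadamardTransform(torch.nn.Module):
    def __init__(self, group_size: int):
        super().__init__()
        self.group_size = group_size
        self.scale = 1 / math.sqrt(group_size)  # single source of truth for the scale

    def get_transform_matrix(self, device="cuda", dtype=torch.float32) -> torch.Tensor:
        # Reuses self.scale instead of recomputing 1 / math.sqrt(self.group_size)
        return hadamard_transform(
            torch.eye(self.group_size, device=device, dtype=dtype), scale=self.scale
        )
```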

Comment on lines +163 to +164

# Create output tensors on CUDA
Copilot AI commented Jan 27, 2026

Using lambda for the grid calculation can cause issues with serialization and debugging. Consider using a regular function definition instead.

Suggested change:

```diff
-# Create output tensors on CUDA
+def grid(meta):
+    return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
```

minmax_lr (float, optional): Learning rate for min-max tuning; defaults to `lr`.
low_gpu_mem_usage (bool, optional): Lower GPU memory mode. Defaults to False.
low_cpu_mem_usage (bool, optional): Lower CPU memory mode. Defaults to False.
transform_config (dict, optional): transform matrix config like hadamard, like {"transform_class": "hadamard"}.
Copilot AI commented Jan 27, 2026

Corrected spelling of 'matirx' to 'matrix'.

Suggested change:

```diff
-transform_config (dict, optional): transform matirx config like hadamard, like {"transform_class": "hadamard"}.
+transform_config (dict, optional): transform matrix config like hadamard, like {"transform_class": "hadamard"}.
```

disable_opt_rtn (bool, optional): Disable RTN-mode optimization (iters=0) for fast quantization
with lower accuracy. Defaults to None.
low_cpu_mem_usage (bool, optional): Lower CPU memory mode. Defaults to False.
transform_config (dict, optional): transform matrix config like hadamard, like {"transform_class": "hadamard"}.
Copilot AI commented Jan 27, 2026

Corrected spelling of 'matirx' to 'matrix'.

Suggested change:

```diff
-transform_config (dict, optional): transform matirx config like hadamard, like {"transform_class": "hadamard"}.
+transform_config (dict, optional): transform matrix config like hadamard, like {"transform_class": "hadamard"}.
```

@lkk12014402 lkk12014402 requested review from WeiweiZhang1, wenhuach21 and yiliu30 and removed request for WeiweiZhang1, wenhuach21 and yiliu30 January 27, 2026 05:24

orig_shape = input.shape
x_flat = input.contiguous().flatten(end_dim=-2)
qdq_input, _ = mxfp4_forward_kernel_wrapper(
Contributor left a comment

How can the code ensure that the transformation is equivalent?

lkk12014402 (author) replied:

Please see the transformation equations in the description: the activation is rotated and the packed weight is also rotated, so the linear output is equivalent.

@wenhuach21 (Contributor) commented Jan 27, 2026

OK, thanks. Should QKV or MoE layers share the same H on activations, since they are fused in vLLM?

@lkk12014402 (author) replied Jan 27, 2026

Currently only a single shared H is supported for all layers. We can support layer-specific transform matrices once other transform functions are added in [auto_round/transforms/transforms.py](https://github.com/intel/auto-round/pull/1349/files#diff-c384e42585ac4e70e9217f23c870c46b8b66097f7d6810bea55779a55e38fe73).

@lkk12014402 (author) commented Jan 27, 2026

Accuracy test on Llama-3.1-8B-Instruct

BF16

| Quantization | GSM8K (flexible-extract / strict-match) |
| --- | --- |
| N/A | 0.7794 / 0.7096 |

MXFP4

| Quantization | GSM8K (flexible-extract / strict-match) |
| --- | --- |
| RTN | 0.5724 / 0.5595 |
| RTN + HAD (GS32) | 0.5914 / 0.5898 |
| AutoRound | 0.6399 / 0.6391 |
| AutoRound + HAD (GS32) | 0.6626 / 0.6603 |

minmax_lr (float, optional): Learning rate for min-max tuning; defaults to `lr`.
low_gpu_mem_usage (bool, optional): Lower GPU memory mode. Defaults to False.
low_cpu_mem_usage (bool, optional): Lower CPU memory mode. Defaults to False.
transform_config (dict, optional): transform matrix config like hadamard, like {"transform_class": "hadamard"}.
Contributor left a comment

Please coordinate with Heng and Weiwei regarding the API (@n1ck-guo @WeiweiZhang1).
We should avoid fusing multiple algorithms into a single implementation, as we plan to support more algorithms in the future. We can take inspiration from the design used in LLMC.

lkk12014402 (author) replied:

Good suggestion. We can follow an approach like LLMC's, e.g. this llm-compressor example https://github.com/vllm-project/llm-compressor/blob/main/examples/transform/quip_example.py#L26, to support more rotation methods.

disable_opt_rtn: bool | None = None,
seed: int = 42,
low_cpu_mem_usage: bool = True,
transform_config: dict = {},
Contributor left a comment

It’s better not to name it transform_config, as it may be confusing with Transformers.

],
key=[],
)
@triton.jit
Contributor left a comment

I assume XPU does not support this, but it's not a big issue for now.

lkk12014402 (author) replied:

XPU supports this kernel, but I haven't tested the performance.

key=[],
)
@triton.jit
def mxfp4_forward_kernel(
Contributor left a comment

Please add the source if the code is copied from another repo. Better to add their license at the beginning of this file.

@wenhuach21 (Contributor) commented:

@WeiweiZhang1 Please arrange an accuracy test to evaluate the benefits if the PR runs correctly. There’s no need to wait for the code to be fully refined.

@lkk12014402 (author) replied:

> @WeiweiZhang1 Please arrange an accuracy test to evaluate the benefits if the PR runs correctly. There’s no need to wait for the code to be fully refined.

@WeiweiZhang1 you can tell me how to do this. I will test more models.

@wenhuach21 (Contributor) commented:

Thanks a lot for the PR! It would be better to make it compatible with other schemes in the future.

@lkk12014402 (author) commented Jan 27, 2026

> Thanks a lot for the PR! It would be better to make it compatible with other schemes in the future.

Thank you. I will make a plan for other schemes and other transform matrices (methods).

@WeiweiZhang1 (Contributor) commented:

> @WeiweiZhang1 Please arrange an accuracy test to evaluate the benefits if the PR runs correctly. There’s no need to wait for the code to be fully refined.

Given that we previously evaluated MXFP4 SignRoundV2, I think we should follow the same scope as before: test commonly used models in the 8B–70B range and compare results on around 10 mainstream lm_eval tasks, right? paper details link

Previous env info: torch==2.8.0, transformers==4.57.1, lm_eval==0.4.9.1
Device info: 8B: A100; 32B–70B: B200

Shell cmd example:

```shell
CUDA_VISIBLE_DEVICES=0 python3 -m auto_round \
    --model_name $dir/$model \
    --scheme "mxfp4" \
    --format fake \
    --tasks "lambada_openai,hellaswag,winogrande,piqa,mmlu,truthfulqa_mc1,openbookqa,boolq,arc_easy,arc_challenge,gsm8k" \
    --enable_alg_ext \
    --enable_torch_compile \
    --enable_deterministic_algorithms \
    --eval_task_by_task \
    --eval_bs 32
```

Thanks, @lkk12014402 ! Could you share your BKC and let me know which kind of benchmark you'd like me to run?

@lkk12014402 (author) replied:

> […] Thanks, @lkk12014402 ! Could you share your BKC and let me know which kind of benchmark you'd like me to run?

No problem. Let me do the test. Thank you @WeiweiZhang1

@lkk12014402 lkk12014402 added this to the 0.10.0 milestone Jan 27, 2026
disable_opt_rtn (bool, optional): Disable RTN-mode optimization (iters=0) for fast quantization
with lower accuracy. Defaults to None.
low_cpu_mem_usage (bool, optional): Lower CPU memory mode. Defaults to False.
transform_config (dict, optional): transform matrix config like hadamard, like {"transform_class": "hadamard"}.
Contributor left a comment

Please mark it as an experimental feature, and clarify the limitations.
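
For reference, a hedged sketch of how the documented `transform_config` might be passed in (the AutoRound constructor and `quantize_and_save` calls follow the existing auto-round API, but how this PR wires the new option is an assumption, and the model name is just an example matching the accuracy table above):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from auto_round import AutoRound

model_name = "meta-llama/Llama-3.1-8B-Instruct"  # example model
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)

autoround = AutoRound(
    model,
    tokenizer,
    scheme="MXFP4",
    transform_config={"transform_class": "hadamard"},  # the new knob this PR adds
)
autoround.quantize_and_save("./mxfp4_transformed_model")
```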


def forward(self, x):
    act_max = self.orig_layer.act_max if hasattr(self.orig_layer, "act_max") else None
    if self.enable_transform:
Contributor left a comment

Nit: using a forward hook might make the implementation clearer from a readability perspective.

lkk12014402 (author) replied:

I think so too. Let me look into refining the code.

Contributor replied:

A cleaner API would look like: model = hadamard_transform(model), and then pass the transformed model to AutoRound. Please sync with Heng on this or other possibilities.
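
For discussion, a minimal sketch of that API shape, assuming a hypothetical `hadamard_transform_model` helper that rotates every Linear's weight by $$W' = W H^\top$$ and applies $$x' = H x$$ through a pre-forward hook (the PR itself rotates per quantization group, so this full-dimension version is a simplification, not the PR's implementation):

```python
import torch
from scipy.linalg import hadamard  # SciPy's Hadamard matrix (requires power-of-two sizes)

def hadamard_transform_model(model: torch.nn.Module) -> torch.nn.Module:
    """Hypothetical helper: rotate Linear weights offline, rotate activations at runtime."""
    for module in model.modules():
        if isinstance(module, torch.nn.Linear):
            n = module.in_features  # assumed to be a power of two
            H = torch.tensor(hadamard(n), dtype=module.weight.dtype,
                             device=module.weight.device) / n ** 0.5
            with torch.no_grad():
                module.weight.copy_(module.weight @ H.T)  # W' = W H^T
            # x' = H x, applied row-wise to the incoming activation
            module.register_forward_pre_hook(
                lambda mod, args, H=H: (args[0] @ H.T,) + args[1:]
            )
    return model

# model = hadamard_transform_model(model)                  # rotate first ...
# autoround = AutoRound(model, tokenizer, scheme="MXFP4")  # ... then quantize with auto-round
```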

@wenhuach21 (Contributor) commented:

> […] no problem. Let me do the test. Thank you @WeiweiZhang1

If you are using transformers to evaluate the accuracy, please follow the notes here for Llama: https://github.com/intel/auto-round/blob/main/docs/alg_202508.md
