Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion auto_round/data_type/nvfp.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from auto_round.data_type.fp8 import float8_e4m3fn_ste
from auto_round.data_type.register import register_dtype
from auto_round.data_type.utils import reshape_pad_tensor_by_group_size, revert_tensor_by_pad, round_ste
from auto_round.logger import logger


# taken from
Expand Down Expand Up @@ -206,6 +205,20 @@ def ref_fp4_quant(x, global_scale, block_size=16, v=0, max_scale=1.0):
return (cast_to_fp4(clipped_x) * get_reciprocal(output_scale)).reshape(m, n), scale


def ref_fp4_quant_v3(x, global_scale, block_size=16, v=0, max_scale=1.0):
    """Reference FP4 (E2M1) quant-dequant of a 2-D tensor with per-row scales.

    Unlike the sibling ``ref_fp4_quant``, the scale here is derived from a
    bfloat16 cast of the per-row absmax (no UE5M3 clipping/casting step).

    Args:
        x: 2-D input tensor (one group per row; rows are the quantization groups).
        global_scale: scalar (or float32 tensor) second-level scale. NOTE: it
            cancels out of ``output_scale`` below, so it only affects the
            returned ``scale``, not the quant-dequant values.
        block_size: unused in this implementation; kept for signature parity
            with ``ref_fp4_quant`` (grouping is assumed to be done by the caller).
        v: additive perturbation applied to the scaled values before rounding
            (used by the rounding-optimization path).
        max_scale: scalar or per-row tensor multiplier applied to the absmax.

    Returns:
        Tuple of (quant-dequantized tensor reshaped to ``(m, n)``, scale tensor).
    """
    # global_scale must be fp32 if it is a tensor; plain Python floats are fine.
    assert (not isinstance(global_scale, torch.Tensor)) or global_scale.dtype == torch.float32
    assert x.ndim == 2
    m, n = x.shape
    if isinstance(max_scale, torch.Tensor):
        # Broadcast a per-row max_scale across each row's elements.
        max_scale = max_scale.unsqueeze(dim=-1).to(x.device)
    # Per-row absolute maximum, optionally shrunk/grown by max_scale.
    vec_max = torch.max(torch.abs(x), dim=-1, keepdim=True)[0] * max_scale
    # Scale maps the row absmax onto the FP4 E2M1 max representable value (6.0).
    # The bf16 cast here is deliberate: the stored scale precision is bfloat16.
    scale = global_scale * (vec_max.to(torch.bfloat16) * get_reciprocal(FLOAT4_E2M1_MAX))
    # output_scale = 1 / (scale / global_scale); the global_scale factor cancels,
    # leaving the pure per-row quantization step.
    output_scale = get_reciprocal(scale * get_reciprocal(global_scale))
    scaled_x = x.to(torch.float32) * output_scale + v
    # Clamp to the FP4 E2M1 representable range [-6, 6] before rounding.
    clipped_x = torch.clamp(scaled_x, -6.0, 6.0)
    # Round to FP4 grid, then undo the scaling to return dequantized values.
    return (cast_to_fp4(clipped_x) * get_reciprocal(output_scale)).reshape(m, n), scale


@register_dtype("fp4_v2_with_global_scale")
def fp4_v2_with_global_scale(tensor, bits=4, group_size=16, v=0, tensor_max=None, max_scale=1.0, **kwargs):
assert group_size == 32 or group_size == 16
Expand Down Expand Up @@ -235,6 +248,16 @@ def fp4_v2(tensor, bits=4, group_size=32, v=0, max_scale=1.0, **kwargs):
return qdq_res.to(orig_dtype), scale, None


@register_dtype("fp4_v3")
def fp4_v3(tensor, bits=4, group_size=32, v=0, max_scale=1.0, **kwargs):
    """Quant-dequant *tensor* to FP4 (E2M1) with bf16 per-group scales.

    The tensor is padded/reshaped into groups of ``group_size`` elements,
    quantized via ``ref_fp4_quant_v3`` (bf16 scale path, no UE5M3 step),
    then restored to its original shape and dtype.

    Args:
        tensor: input tensor of any shape.
        bits: nominal bit width; present for registry-signature parity.
        group_size: quantization group length; must be 16 or 32.
        v: additive rounding perturbation forwarded to the reference kernel.
        max_scale: scalar or per-group multiplier on the group absmax.

    Returns:
        Tuple of (quant-dequantized tensor in the original dtype, scale
        tensor, None) matching the registry's (qdq, scale, zp) convention.
    """
    # Match the validation in fp4_v2/fp4_v2_with_global_scale: other group
    # sizes would silently produce unexpected scaling.
    assert group_size == 32 or group_size == 16
    orig_dtype = tensor.dtype
    tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size)
    # No second-level (global) scale in the v3 path; scales stay pure bf16.
    global_scale = 1.0
    qdq_res, scale = ref_fp4_quant_v3(tensor, global_scale, group_size, v, max_scale)
    qdq_res = revert_tensor_by_pad(qdq_res, orig_shape=orig_shape, pad_len=pad_len)
    return qdq_res.to(orig_dtype), scale, None


if __name__ == "__main__":
data = torch.tensor([0.0, 0.25, 0.4, 0.75, 1.25, 1.4, 1.75, 2.5, 2.9, 3.5, 5.0, 5.1, 6.0, 6.2, 8.9])
data1 = cast_to_fp4(data)
Expand Down