[Enhancement]: Refactor the NVFP4 weight global scale calculation to be block-wise #1331

@yiliu30

Description

Feature Description

Current behavior:

if is_nv_fp(self.data_type):
    from auto_round.data_type.nvfp import calculate_gparam
    from auto_round.data_type.utils import update_fused_layer_global_scales
    pbar = tqdm(all_to_quantized_module_names)
    for name in pbar:
        pbar.set_description(f"Calculate weight global scale: {name}")
        m = get_module(self.model, name)
        if is_fp8_linear(m):
            m = convert_fp8_layer_to_linear(m, self.amp_dtype, self.device)
            set_module(self.model, name, m)
        weight_global_scale = calculate_gparam(m.weight, self.group_size)
        setattr(m, "weight_global_scale", weight_global_scale)

This loop walks every quantized module in the full model, so the whole model must already be materialized. That makes it impossible to delay materializing individual blocks (see #1276).

Motivation and Use Case

Expected behavior:

def calculate_and_set_global_scale(block) -> None:
    ...

for block in block_list:
    calculate_and_set_global_scale(block)
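
A minimal sketch of what the block-wise helper could look like, for illustration only. It assumes the global scale is FP8_E4M3_MAX * FP4_E2M1_MAX / amax(weight); `calculate_gparam_sketch` is a hypothetical stand-in for `calculate_gparam`, not the library's implementation. The FP8-to-linear conversion and the fused-layer scale update from the current code are left out.

```python
import torch
from torch import nn

# Assumed constants: largest finite values of FP8 E4M3 and FP4 E2M1.
FP8_E4M3_MAX = 448.0
FP4_E2M1_MAX = 6.0


def calculate_gparam_sketch(weight: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for auto_round's calculate_gparam."""
    amax = weight.detach().abs().max().to(torch.float32)
    # Guard against an all-zero weight to avoid dividing by zero.
    amax = torch.clamp(amax, min=torch.finfo(torch.float32).tiny)
    return FP8_E4M3_MAX * FP4_E2M1_MAX / amax


def calculate_and_set_global_scale(block: nn.Module) -> None:
    """Attach weight_global_scale to every nn.Linear inside block."""
    for _, module in block.named_modules():
        if isinstance(module, nn.Linear):
            module.weight_global_scale = calculate_gparam_sketch(module.weight)


# Toy "blocks"; in practice these would be the model's own blocks.
block_list = [
    nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))
    for _ in range(2)
]
for block in block_list:
    calculate_and_set_global_scale(block)
```

Because this sketch only relies on `named_modules()`, the same call would also work when handed a whole model instead of a single block.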

Alternatives Considered

No response

Definition of Done

  • Refactor this into a function that accepts either a block or the full model and computes and sets the weight global scale.
  • Block-wise NVFP4 with delayed materialization can be handled in a separate PR. @yiliu30

Additional Context

No response
