156 changes: 140 additions & 16 deletions auto_round/utils/model.py
@@ -31,6 +31,58 @@
from auto_round.logger import logger
from auto_round.schemes import QuantizationScheme

# ============================================================================
# FP8 Dequantization Registry
# ============================================================================
# Registry mapping FP8 layer class names to their dequantization handlers.
# This ensures detection and dequantization logic stay in sync.
FP8_DEQUANT_REGISTRY = {}


def register_fp8_layer(layer_name: str):
"""Register a dequantization handler for an FP8 layer type.

Args:
layer_name: The class name of the FP8 layer type (e.g., "FP8Linear", "CompressedLinear")

Returns:
Decorator function that registers the handler.

Example:
@register_fp8_layer("FP8Linear")
def dequant_fp8_linear(layer, dtype=torch.bfloat16, device: str = "cpu"):
# Dequantization logic
return dequantized_weight
"""

def decorator(fn):
FP8_DEQUANT_REGISTRY[layer_name] = fn
return fn

return decorator


def dequant_fp8_layer(layer, dtype=torch.bfloat16, device: str = "cpu"):
"""Dequantize an FP8 layer using the registry.

Args:
layer: The FP8 layer to dequantize.
dtype: Target dtype for dequantized weights.
device: Target device for dequantization.

Returns:
Dequantized weight tensor.

Raises:
NotImplementedError: If the layer type is not registered.
"""
name = layer.__class__.__name__
if name not in FP8_DEQUANT_REGISTRY:
raise NotImplementedError(
f"Unsupported FP8 layer type: {name}. " f"Supported types: {list(FP8_DEQUANT_REGISTRY.keys())}"
)
return FP8_DEQUANT_REGISTRY[name](layer, dtype=dtype, device=device)
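
# Usage sketch for the registry, with hypothetical names ("MyFP8Linear" and its
# per-tensor weight_scale are assumptions, not types defined in this file):
#
#     @register_fp8_layer("MyFP8Linear")
#     def dequant_my_fp8_linear(layer, dtype=torch.bfloat16, device: str = "cpu"):
#         layer = layer.to(device)
#         # Assumes a single per-tensor scale; real handlers may use block-wise scales.
#         return layer.weight.to(dtype) * layer.weight_scale.to(dtype)
#
#     # dequant_fp8_layer then dispatches on the layer's class name:
#     dq_weight = dequant_fp8_layer(my_fp8_layer, dtype=torch.bfloat16, device="cpu")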


def clean_module_parameter(submodule: torch.nn.Module, param_name: str) -> None:
"""This function is recommended to be used instead of module.weight = None.
@@ -628,16 +680,35 @@ def is_fp8_model(model: torch.nn.Module) -> bool:


def is_fp8_linear(module: torch.nn.Module) -> bool:
"""Check if a module is an FP8 linear layer.

Detection follows this priority:
1. Explicit `is_fp8_linear` attribute (highest priority)
2. Registry lookup (for registered FP8 layer types)
3. Fallback dtype check (for backward compatibility with unregistered FP8 layers)

Args:
module: The module to check.

Returns:
True if the module is an FP8 linear layer, False otherwise.
"""
# First check if explicitly marked
if hasattr(module, "is_fp8_linear"):
return module.is_fp8_linear

# Check registry for supported FP8 layer types
layer_name = module.__class__.__name__
if layer_name in FP8_DEQUANT_REGISTRY:
return True

# Fallback: Check for FP8 dtype (for torch.nn.Linear with FP8 weights)
# This maintains backward compatibility for layers not yet in registry
if type(module) == torch.nn.Linear and module.weight is not None:
if str(module.weight.dtype).startswith("torch.float8"):
return True

return False
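
# Minimal sketch of the detection order above (the modules and attributes below
# are illustrative only):
#
#     plain = torch.nn.Linear(16, 16)
#     plain.weight.data = plain.weight.data.to(torch.float8_e4m3fn)
#     is_fp8_linear(plain)            # True, via the dtype fallback
#
#     marked = torch.nn.Linear(16, 16)
#     marked.is_fp8_linear = False    # the explicit attribute takes precedence
#     is_fp8_linear(marked)           # False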


def get_block_names(model, quant_vision=False):
@@ -1000,6 +1071,33 @@ def dequant_block_fp8_weight(
return dequant_weight


# Register FP8 layer dequantization handlers
# Note: Handlers are registered here (after dequant_block_fp8_weight is defined)
# to ensure all dependencies are available.
@register_fp8_layer("CompressedLinear")
def dequant_compressed_linear(layer, dtype=torch.bfloat16, device: str = "cpu"):
"""Dequantize CompressedLinear layer using compressor."""
layer = layer.to(device)
return layer.compressor.decompress_module(layer)


@register_fp8_layer("FP8Linear")
def dequant_fp8_linear(layer, dtype=torch.bfloat16, device: str = "cpu"):
"""Dequantize FP8Linear layer using block-based dequantization."""
layer = layer.to(device)
weight_scale = layer.weight_scale if hasattr(layer, "weight_scale") else layer.weight_scale_inv
data_type = getattr(layer, "data_type", None)
    # Pass data_type through only when dequant_block_fp8_weight accepts it.
import inspect

sig = inspect.signature(dequant_block_fp8_weight)
if "data_type" in sig.parameters:
return dequant_block_fp8_weight(layer.weight, weight_scale, layer.block_size, data_type=data_type)
else:
return dequant_block_fp8_weight(layer.weight, weight_scale, layer.block_size)
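
# Rough sketch of what block-wise dequantization amounts to, assuming a 2D FP8
# weight, a 2-element block_size, and one scale entry per block; the
# dequant_block_fp8_weight defined earlier in this file is the authoritative
# implementation:
#
#     def blockwise_dequant_sketch(weight, scale, block_size, dtype=torch.bfloat16):
#         block_h, block_w = block_size
#         w = weight.to(dtype)
#         # Expand the per-block scale to the full weight shape, then rescale.
#         s = scale.repeat_interleave(block_h, dim=0).repeat_interleave(block_w, dim=1)
#         return w * s[: w.shape[0], : w.shape[1]].to(dtype)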


def check_to_quantized(config):
"""Checks if the configuration is valid for quantization.

@@ -1058,42 +1156,68 @@ def check_seqlen_compatible(input_seqlen, tokenizer=None, model=None):


def convert_fp8_layer_to_linear(layer, dtype=torch.bfloat16, device: str = "cpu"):
""" """
"""Convert an FP8 layer to a standard Linear layer.

Uses the FP8 dequantization registry to handle different FP8 layer types.
Preserves quantization scheme attributes and other metadata.

Args:
layer: The FP8 layer to convert.
dtype: Target dtype for the converted layer.
device: Target device for the conversion.

Returns:
A new torch.nn.Linear layer with dequantized weights.

Raises:
NotImplementedError: If the layer type is not registered in the FP8 registry.
"""
from auto_round.schemes import QuantizationScheme

new_layer = torch.nn.Linear(layer.in_features, layer.out_features, bias=layer.bias is not None, dtype=dtype)
if layer.bias is not None:
new_layer.bias.data.copy_(layer.bias.data.to(dtype=dtype))

# Copy quantization scheme attributes
scheme_keys = (f.name for f in fields(QuantizationScheme))
keys = tuple(scheme_keys) + ("global_name", "scale_dtype")
for key in keys:
setattr(new_layer, key, getattr(layer, key, None))

# Handle Gaudi2 device compatibility
from auto_round.utils.device import is_gaudi2

if is_gaudi2():
device = "cpu"

# Use registry-based dequantization
dq_weight = dequant_fp8_layer(layer, dtype=dtype, device=device)
new_layer.weight.data.copy_(dq_weight.to(dtype=dtype))

return new_layer
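
# Usage sketch, assuming fp8_layer is an FP8Linear (or any registered FP8 layer
# type) taken from a loaded model:
#
#     bf16_layer = convert_fp8_layer_to_linear(fp8_layer, dtype=torch.bfloat16, device="cpu")
#     assert isinstance(bf16_layer, torch.nn.Linear)
#     assert bf16_layer.weight.dtype == torch.bfloat16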


def convert_fp8_module_to_16b(model, dtype=torch.bfloat16, device: str = "cpu"):
"""
Convert a model with FP8 quantized layers to a model with 16-bit linear layers.
This is useful for compatibility with other frameworks or for further processing.

Uses `is_fp8_linear` for detection, which supports all registered FP8 layer types
via the FP8 dequantization registry.

Args:
model: The model with FP8 layers to convert.
dtype: Target dtype for converted layers.
device: Target device for conversion.

Returns:
The model with FP8 layers converted to 16-bit Linear layers.
"""
from auto_round.utils.device import clear_memory

cnt = 0
for n, m in model.named_modules():
if is_fp8_linear(m): # Use registry-based detection
new_module = convert_fp8_layer_to_linear(m, dtype=dtype, device=device)
set_module(model, n, new_module)
cnt += 1