diff --git a/atom/config.py b/atom/config.py index 1d433987..f75a1b35 100644 --- a/atom/config.py +++ b/atom/config.py @@ -8,11 +8,15 @@ import os import re from dataclasses import dataclass, field -from typing import Any, cast, Optional, Union +from typing import Any, Optional, Union import torch from aiter import QuantType -from aiter.utility.dtypes import d_dtypes +from atom.quant_spec import ( + LayerQuantConfig, + ParsedQuantConfig, + get_quant_parser, +) from atom.utils import envs, get_open_port from atom.utils.distributed.utils import stateless_init_torch_distributed_process_group from torch.distributed import ProcessGroup, ReduceOp @@ -251,70 +255,94 @@ def set_splitting_ops_for_v1(self): ] -class LayerQuantConfig(dict): - def __init__( - self, - quant_type=QuantType.No, - quant_dtype=torch.bfloat16, - is_dynamic=True, - quant_method="", - ): - """ - Core components of layer_quant - """ - super().__init__() - self["quant_type"] = quant_type if quant_type is not None else QuantType.No - self["quant_dtype"] = quant_dtype if quant_dtype is not None else torch.bfloat16 - self["is_dynamic"] = is_dynamic - self["quant_method"] = quant_method +class QuantizationConfig: + """Model-wide quantization configuration. 
+ API: + - ``get_layer_quant_config(prefix)`` -> :class:`LayerQuantConfig` + - ``global_quant_config`` property -> :class:`LayerQuantConfig` + - ``quant_type``, ``quant_dtype``, ``is_dynamic`` convenience properties + """ -class QuantizationConfig: def __init__(self, config: PretrainedConfig = None): if config is None: self.torch_dtype = torch.bfloat16 self.hf_quant_config = None - self.global_quant_config = LayerQuantConfig() - self.layer_quant_config = {} - self.exclude_layers = [] + self._parsed = ParsedQuantConfig() + self.exclude_layers: list[str] = [] self.quant_method = "" return self.torch_dtype = getattr(config, "torch_dtype", torch.bfloat16) self.hf_quant_config = getattr(config, "quantization_config", None) - self.global_quant_config = None - self.layer_quant_config = {} self.exclude_layers = [] if self.hf_quant_config is None: - self.global_quant_config = LayerQuantConfig( - quant_type=QuantType.No, quant_dtype=self.torch_dtype + self._parsed = ParsedQuantConfig( + global_spec=LayerQuantConfig( + quant_type=QuantType.No, quant_dtype=self.torch_dtype + ) ) self.quant_method = "" return self.quant_method = self.hf_quant_config.get("quant_method", "") - if self.quant_method == "quark": - layer_quant_config_dict = cast( - dict[str, Any], self.hf_quant_config.get("layer_quant_config", {}) - ) - for layer_name, layer_cfg in layer_quant_config_dict.items(): - self.layer_quant_config[layer_name] = self.parse_quark_config_dict( - layer_cfg - ) - global_quant_config_dict = cast( - dict[str, Any], self.hf_quant_config.get("global_quant_config", {}) - ) - self.global_quant_config = self.parse_quark_config_dict( - global_quant_config_dict - ) + # Use the parser registry to build a structured ParsedQuantConfig + parser = get_quant_parser(self.quant_method) + self._parsed = parser.parse(self.hf_quant_config) + self.exclude_layers = list(self._parsed.exclude_layers) - self.exclude_layers = cast( - list[str], self.hf_quant_config.get("exclude", []) - ) - else: - 
self.parse_other_config() + # -- typed API (preferred) ---------------------------------------------- + + @property + def global_quant_config(self) -> LayerQuantConfig: + """The default quantization spec for all layers.""" + return self._parsed.global_spec + + def get_layer_quant_config(self, layer_name: str) -> LayerQuantConfig: + """Return the :class:`LayerQuantConfig` for *layer_name*. + + Resolution order: + 1. Check exclude list -> ``LayerQuantConfig.no_quant()``. + 2. Exact match in ``parsed.layer_specs``. + 3. fnmatch-style pattern match in ``parsed.layer_pattern_specs``. + 4. Fall back to ``global_quant_config``. + """ + # 1. Exclude list + if self._is_excluded(layer_name): + return LayerQuantConfig(quant_dtype=self.torch_dtype) + + # 2. Exact match + if layer_name in self._parsed.layer_specs: + return self._parsed.layer_specs[layer_name] + + # 3. Pattern match + for pattern, spec in self._parsed.layer_pattern_specs: + if "*" not in pattern: + if layer_name in pattern: + return spec + elif fnmatch.fnmatch(layer_name, pattern): + return spec + + # 4. Global default + return self._parsed.global_spec + + # -- convenience properties (delegate to global_quant_config) ------------- + + @property + def quant_type(self) -> QuantType: + return self._parsed.global_spec.quant_type + + @property + def quant_dtype(self) -> torch.dtype: + return self._parsed.global_spec.quant_dtype + + @property + def is_dynamic(self) -> bool: + return self._parsed.global_spec.is_dynamic + + # -- other methods ------------------------------------------------------ def compute_hash(self) -> str: """ @@ -329,191 +357,41 @@ def compute_hash(self) -> str: the final hidden states. 
""" factors: list[Any] = [] - factors.append(self.global_quant_config) - factors.append(self.layer_quant_config) + factors.append(self._parsed.global_spec) + factors.append(self._parsed.layer_pattern_specs) factors.append(self.exclude_layers) hash_value = hashlib.sha256(str(factors).encode()).hexdigest() return hash_value def get_name(self): - """ - Returns the quantization method name. - """ + """Returns the quantization method name.""" return self.quant_method - def parse_quark_config_dict(self, config: dict) -> LayerQuantConfig: - quant_type = None - quant_dtype = None - is_dynamic = True - weight_config = cast(dict[str, Any], config.get("weight", {})) - input_config = cast(dict[str, Any], config.get("input_tensors", {})) - weight_qscheme = cast(str, weight_config.get("qscheme", "")) - weight_dtype = weight_config.get("dtype", "") - - # quant_type - if weight_qscheme == "per_channel": - quant_type = QuantType.per_Token - elif weight_qscheme == "per_tensor": - quant_type = QuantType.per_Tensor - elif weight_qscheme == "per_group": - # Currently, quark only supports group_size=32 - quant_type = QuantType.per_1x32 - else: - quant_type = QuantType.No - - # quant_dtype - dtype = weight_dtype.split("_")[0] - if dtype.endswith("4"): - dtype += "x2" - quant_dtype = d_dtypes[dtype] - - # is_dynamic - if input_config: - # input_dtype = input_config.get("dtype") - # input_qscheme = cast(str, input_config.get("qscheme")) - is_dynamic = cast(bool, input_config.get("is_dynamic", True)) - return LayerQuantConfig( - quant_type=quant_type, - quant_dtype=quant_dtype, - is_dynamic=is_dynamic, - quant_method="quark", - ) + # -- internal helpers --------------------------------------------------- - # TODO: For now, it's just a temporary migration. - # We should subsequently refine them in a targeted manner. 
- def parse_other_config(self): - RE_QUANT_BLOCKSIZE = ( - r"\'(?:group_size|weight_block_size)\'\:\s*(?:\[\n*)\s*(\d+)," - ) - orig_quant_config = self.hf_quant_config - quant_method = self.quant_method - orig_quant_config_str = str(orig_quant_config) - if quant_method == "compressed-tensors" or "channel'," in orig_quant_config_str: - quant_type = QuantType.per_Token - elif group_size := re.search(RE_QUANT_BLOCKSIZE, orig_quant_config_str): - group_size = int(group_size.group(1)) - assert group_size in (32, 128), f"Unsupported group size {group_size}" - if group_size == 128: - quant_type = QuantType.per_1x128 - elif group_size == 32: - quant_type = QuantType.per_1x32 - else: - quant_type = QuantType.per_Tensor - - RE_QUANT_DTYPE = r"\'(?:d?type|weight_dtype|quant_method)\'\:\s*\'(\w+)\'" - quant_dtype = None - m = re.search(RE_QUANT_DTYPE, orig_quant_config_str) - if m and m.group(1).lower() in [ - "fp8", - "fp4", - "int8", - "int4", - "fp8_e4m3", - "mxfp4", - ]: - dtype = m.group(1).lower().split("_")[0] - if dtype == "mxfp4": - dtype = "fp4" - if dtype.endswith("4"): - dtype += "x2" - quant_dtype = d_dtypes[dtype] - else: - bit_match = re.search(r"\'(?:num_)?bits\'\:\s*(\d+)", orig_quant_config_str) - if bit_match: - bit = int(bit_match.group(1)) - dtype_match = re.search(RE_QUANT_DTYPE, orig_quant_config_str) - if dtype_match: - dtype = dtype_match.group(1).lower() - dtype_prefix = "i" if dtype.startswith("int") else "fp" - else: - dtype_prefix = "i" - quant_dtype_str = ( - f"{dtype_prefix}{bit}" if bit != 4 else f"{dtype_prefix}{bit}x2" - ) - quant_dtype = d_dtypes.get(quant_dtype_str, None) - assert ( - quant_dtype is not None - ), f"Cannot parse quant dtype from {orig_quant_config_str}" - if quant_dtype == d_dtypes["fp4x2"]: - quant_type = QuantType.per_1x32 - - RE_STATIC_QUANT = r"\'(?:activation_scheme)\'\:\s*\'(static)\'" - if re.search(RE_STATIC_QUANT, orig_quant_config_str): - is_dynamic = False - else: - is_dynamic = True - if quant_method == 
"compressed-tensors": - exclude_layers_key = "ignore" - else: - logger.warning( - f"Using 'ignore' as key for exclude layers with quant_method " - f"{quant_method}, please double check the quantization config." - ) - exclude_layers_key = "ignore" - exclude_layers = orig_quant_config.get(exclude_layers_key, []) - - self.global_quant_config = LayerQuantConfig( - quant_type=quant_type, - quant_dtype=quant_dtype, - is_dynamic=is_dynamic, - quant_method=quant_method, - ) - self.exclude_layers = exclude_layers - - def should_ignore_layer_quant(self, layer_name: str) -> bool: - # TODO: solve fused_mapping case + def _is_excluded(self, layer_name: str) -> bool: if layer_name is None or not self.exclude_layers: return False return any( - self.is_equal_or_regex_match(layer_name, ignore_str) + self._matches_exclude(layer_name, ignore_str) for ignore_str in self.exclude_layers ) - def is_equal_or_regex_match( - self, layer_name: str, ignore_str: str, check_contains: bool = False + @staticmethod + def _matches_exclude( + layer_name: str, ignore_str: str, check_contains: bool = False ) -> bool: - """Match the target string or regular expression""" + """Match the target string or regular expression.""" if ignore_str.startswith("re:"): - # case "re:model.layers.*self_attn.*", remove the 're:' prefix pattern = ignore_str[3:] if re.search(pattern, layer_name): return True - # case exclude_layer like "model.layers.0.self_attn.q_a_proj" (dpsk-attn) - # a common prefix for linear layers in attn like "model.layers.0.self_attn" elif check_contains: return layer_name.lower() in ignore_str.lower() elif ignore_str == layer_name: return True return False - def get_layer_quant_config(self, layer_name: str) -> LayerQuantConfig: - if self.should_ignore_layer_quant(layer_name=layer_name): - # return unquantized config - return LayerQuantConfig(quant_dtype=self.torch_dtype) - # layer quant config - layer_quant_config = None - if self.layer_quant_config: - - def _matches_pattern(layer_name, 
pattern): - if "*" not in pattern: - return layer_name in pattern - return fnmatch.fnmatch(layer_name, pattern) - - for name_pattern, config in self.layer_quant_config.items(): - if _matches_pattern(layer_name, name_pattern): - layer_quant_config = config - - layer_quant_config = ( - self.global_quant_config - if layer_quant_config is None - else layer_quant_config - ) - # TODO: if use_aiter, we can customize the quantization format here, such as dpsk - # For FP4 and use_triton_gemm(), fused_qkv_a_proj and q_b_proj are AITER-Triton FP4 GEMMs but o_proj remains AITER BF16 GEMMs, - # For FP8 and use_triton_gemm(), fused_qkv_a_proj is AITER-Triton FP8 GEMMs while others remain AITER FP8 GEMMs - - return layer_quant_config - def remap_layer_name( self, hf_config: PretrainedConfig, packed_modules_mapping: dict | None = None ): @@ -556,11 +434,11 @@ def _remap_layer_name(name: str) -> list[str]: return [name.replace(packed_key, packed_remap_part, 1)] return [name] - new_layer_quant_config = {} - for layer_name, layer_qconfig in self.layer_quant_config.items(): - for remapped in _remap_layer_name(layer_name): - new_layer_quant_config[remapped] = layer_qconfig - self.layer_quant_config = new_layer_quant_config + new_pattern_specs = [] + for pattern, spec in self._parsed.layer_pattern_specs: + for remapped in _remap_layer_name(pattern): + new_pattern_specs.append((remapped, spec)) + self._parsed.layer_pattern_specs = new_pattern_specs new_exclude = [] for name in self.exclude_layers: diff --git a/atom/model_ops/activation.py b/atom/model_ops/activation.py index 02522eaa..4ef9dff8 100644 --- a/atom/model_ops/activation.py +++ b/atom/model_ops/activation.py @@ -6,7 +6,8 @@ from torch import nn import torch.nn.functional as F from aiter import silu_and_mul -from atom.config import QuantizationConfig, LayerQuantConfig +from atom.config import QuantizationConfig +from atom.quant_spec import LayerQuantConfig from aiter.jit.utils.torch_guard import torch_compile_guard from aiter 
import ( @@ -59,17 +60,18 @@ def __init__( self, fused_quant: bool = False, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.fused_quant = fused_quant layer_quant_config = ( LayerQuantConfig() if quant_config is None - else quant_config.global_quant_config + else quant_config.get_layer_quant_config(prefix) ) - quant_type = layer_quant_config["quant_type"] - params_dtype = layer_quant_config["quant_dtype"] + quant_type = layer_quant_config.quant_type + params_dtype = layer_quant_config.quant_dtype self.quant_type = quant_type self.params_dtype = params_dtype diff --git a/atom/model_ops/layernorm.py b/atom/model_ops/layernorm.py index 7f556fcc..ab4ad2f9 100644 --- a/atom/model_ops/layernorm.py +++ b/atom/model_ops/layernorm.py @@ -8,7 +8,8 @@ has_torch_function_unary, handle_torch_function, ) -from atom.config import QuantizationConfig, LayerQuantConfig +from atom.config import QuantizationConfig +from atom.quant_spec import LayerQuantConfig from atom.utils.decorators import mark_trace from torch import nn from aiter import ( @@ -178,6 +179,7 @@ def __init__( fused_allreduce: bool = False, fused_quant: bool = False, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.dim = dim @@ -191,10 +193,10 @@ def __init__( layer_quant_config = ( LayerQuantConfig() if quant_config is None - else quant_config.global_quant_config + else quant_config.get_layer_quant_config(prefix) ) - quant_type = layer_quant_config["quant_type"] - params_dtype = layer_quant_config["quant_dtype"] + quant_type = layer_quant_config.quant_type + params_dtype = layer_quant_config.quant_dtype self.quant_type = quant_type self.params_dtype = params_dtype diff --git a/atom/model_ops/linear.py b/atom/model_ops/linear.py index 3cfbbb39..0907cc3f 100644 --- a/atom/model_ops/linear.py +++ b/atom/model_ops/linear.py @@ -21,7 +21,8 @@ from aiter.jit.utils.torch_guard import torch_compile_guard from 
aiter.tuned_gemm import tgemm from aiter.utility import fp4_utils -from atom.config import QuantizationConfig, get_current_atom_config, LayerQuantConfig +from atom.config import QuantizationConfig, get_current_atom_config +from atom.quant_spec import LayerQuantConfig from atom.model_ops.utils import ( normalize_e4m3fn_to_e4m3fnuz, requantize_with_max_scale, @@ -212,8 +213,8 @@ def __init__( if quant_config is not None else LayerQuantConfig() ) - quant_type = layer_quant_config["quant_type"] - params_dtype = layer_quant_config["quant_dtype"] + quant_type = layer_quant_config.quant_type + params_dtype = layer_quant_config.quant_dtype self.source_quant_dtype = source_quant_dtype self.layer_quant_config = layer_quant_config super().__init__() @@ -269,7 +270,7 @@ def __init__( torch.empty(len(self.output_partition_sizes), 1, dtype=dtypes.fp32), requires_grad=False, ) - if not layer_quant_config["is_dynamic"]: + if not layer_quant_config.is_dynamic: self.input_scale = nn.Parameter( torch.empty( len(self.output_partition_sizes), 1, dtype=dtypes.fp32 diff --git a/atom/model_ops/moe.py b/atom/model_ops/moe.py index 0745df34..6e7bbf59 100644 --- a/atom/model_ops/moe.py +++ b/atom/model_ops/moe.py @@ -18,8 +18,8 @@ Config, QuantizationConfig, get_current_atom_config, - LayerQuantConfig, ) +from atom.quant_spec import LayerQuantConfig from atom.model_loader.weight_utils import set_weight_attrs from atom.model_ops.base_config import QuantizeMethodBase from atom.model_ops.fused_moe.config import ( @@ -633,9 +633,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): def __init__(self, quant_config: LayerQuantConfig, moe: FusedMoEConfig): super().__init__(moe) self.quant_config = quant_config - self.quant_type = self.quant_config["quant_type"] - self.quant_dtype = self.quant_config["quant_dtype"] - self.quant_method = self.quant_config["quant_method"] + self.quant_type = quant_config.quant_type + self.quant_dtype = quant_config.quant_dtype + self.quant_method = quant_config.quant_method 
or "" self.block_quant = ( self.quant_type == QuantType.per_1x128 or self.quant_type == QuantType.per_1x32 @@ -1029,8 +1029,8 @@ class CompressedTensorsFp8MoEMethod(FusedMoEMethodBase): def __init__(self, quant_config: LayerQuantConfig, moe: FusedMoEConfig): super().__init__(moe) self.quant_config = quant_config - self.quant_type = quant_config["quant_type"] - self.quant_dtype = quant_config["quant_dtype"] + self.quant_type = quant_config.quant_type + self.quant_dtype = quant_config.quant_dtype # Check if we need to normalize e4m3fn to e4m3fnuz (AMD GPUs) self.need_normalize_e4m3fn_to_e4m3fnuz = ( @@ -1047,7 +1047,7 @@ def __init__(self, quant_config: LayerQuantConfig, moe: FusedMoEConfig): self.per_channel = self.quant_type == QuantType.per_Token # Check if static input scales (activation quantization) - self.static_input_scales = not quant_config.get("is_dynamic", True) + self.static_input_scales = not quant_config.is_dynamic # Block sizes for block quantization if self.block_quant: @@ -1431,14 +1431,14 @@ class Fp8MoEMethod(FusedMoEMethodBase): the model weights are loaded. Args: - quant_config: The quantization config. + quant_config: The quantization config (LayerQuantConfig). 
""" def __init__(self, quant_config: LayerQuantConfig, moe: FusedMoEConfig): super().__init__(moe) self.quant_config = quant_config - self.quant_type = self.quant_config["quant_type"] - self.quant_dtype = self.quant_config["quant_dtype"] + self.quant_type = quant_config.quant_type + self.quant_dtype = quant_config.quant_dtype self.block_quant = ( self.quant_type == QuantType.per_1x128 or self.quant_type == QuantType.per_1x32 @@ -1551,7 +1551,7 @@ def create_weights( ) layer.register_parameter("w13_weight_scale", w13_weight_scale) layer.register_parameter("w2_weight_scale", w2_weight_scale) - assert self.quant_config["is_dynamic"] + assert self.quant_config.is_dynamic else: # Per-tensor w13_weight_scale = torch.nn.Parameter( @@ -1568,7 +1568,7 @@ def create_weights( # INPUT_SCALES # Per-channel uses dynamic per-token activation, no static input scales. - if self.channel_quant or self.quant_config["is_dynamic"]: + if self.channel_quant or self.quant_config.is_dynamic: layer.w13_input_scale = None layer.w2_input_scale = None else: @@ -1610,7 +1610,7 @@ def process_weights_after_loading(self, layer: nn.Module) -> None: self._process_tensor_quant(layer) def _process_block_quant(self, layer: nn.Module) -> None: - assert self.quant_config["is_dynamic"] + assert self.quant_config.is_dynamic self._normalize_weights_and_scales(layer) if not self.need_normalize_e4m3fn_to_e4m3fnuz: @@ -1647,7 +1647,7 @@ def _process_channel_quant(self, layer: nn.Module) -> None: shuffle_weights(layer.w13_weight, layer.w2_weight) def _process_tensor_quant(self, layer: nn.Module) -> None: - if not self.quant_config["is_dynamic"]: + if not self.quant_config.is_dynamic: if layer.w13_input_scale is None or layer.w2_input_scale is None: raise ValueError( "QuantConfig has static quantization, but found " @@ -1902,7 +1902,7 @@ def __init__( quant_config.get_layer_quant_config(prefix) if quant_config else None ) self.params_dtype = ( - layer_quant_config["quant_dtype"] + layer_quant_config.quant_dtype 
if layer_quant_config else torch.get_default_dtype() ) @@ -2024,24 +2024,26 @@ def __init__( # Note: get_quant_method will look at the layer's local_num_experts # for heuristic purposes, so it must be initialized first. - quant_method_str = layer_quant_config.get("quant_method", None) - if layer_quant_config["quant_type"] == QuantType.No: + quant_method_str = ( + layer_quant_config.quant_method if layer_quant_config else None + ) + if layer_quant_config.quant_type == QuantType.No: self.quant_method: Optional[QuantizeMethodBase] = UnquantizedFusedMoEMethod( moe ) elif ( quant_method_str == "compressed-tensors" - and layer_quant_config["quant_dtype"] == dtypes.fp8 + and layer_quant_config.quant_dtype == dtypes.fp8 ): # Use CompressedTensorsFp8MoEMethod for compressed-tensors format self.quant_method = CompressedTensorsFp8MoEMethod(layer_quant_config, moe) - elif layer_quant_config["quant_dtype"] == dtypes.fp8: + elif layer_quant_config.quant_dtype == dtypes.fp8: self.quant_method = Fp8MoEMethod(layer_quant_config, moe) - elif layer_quant_config["quant_dtype"] == dtypes.fp4x2: + elif layer_quant_config.quant_dtype == dtypes.fp4x2: self.quant_method = Mxfp4MoEMethod(layer_quant_config, moe) else: raise ValueError( - f"Unsupported quant dtype: {layer_quant_config['quant_dtype']}" + f"Unsupported quant dtype: {layer_quant_config.quant_dtype}" ) assert self.quant_method is not None @@ -2340,7 +2342,7 @@ def weight_loader( shard_id: str = "", expert_id: int = 0, ) -> None: - if self.layer_quant_config["quant_dtype"] == dtypes.fp4x2 and weight_name == "": + if self.layer_quant_config.quant_dtype == dtypes.fp4x2 and weight_name == "": self.mxf4_merged_weight_loader(param, loaded_weight) return @@ -2418,7 +2420,7 @@ def weight_loader( # FusedMoeWeightScaleSupported # TODO @dsikka: once hardened, refactor to use vLLM Parameters # specific to each case - quant_method = self.layer_quant_config["quant_type"] + quant_method = self.layer_quant_config.quant_type if quant_method == 
QuantType.per_Token: self._load_per_channel_weight_scale( shard_id=shard_id, diff --git a/atom/models/deepseek_mtp.py b/atom/models/deepseek_mtp.py index 4ea30083..e0b1b0fd 100644 --- a/atom/models/deepseek_mtp.py +++ b/atom/models/deepseek_mtp.py @@ -57,7 +57,7 @@ def __init__(self, atom_config: Config, prefix: str, layer_idx: int) -> None: quant_config = atom_config.quant_config layer_quant_config = quant_config.get_layer_quant_config(prefix) - if layer_quant_config["quant_dtype"] == dtypes.fp4x2: + if layer_quant_config.quant_dtype == dtypes.fp4x2: quant_config = QuantizationConfig() self.mtp_block = DeepseekV2DecoderLayer( diff --git a/atom/models/deepseek_v2.py b/atom/models/deepseek_v2.py index d1da9f05..39691add 100644 --- a/atom/models/deepseek_v2.py +++ b/atom/models/deepseek_v2.py @@ -1127,7 +1127,10 @@ def __init__( ) self.k_norm = LayerNorm(self.head_dim, eps=1e-6) self.weights_proj = ReplicatedLinear( - hidden_size, self.n_head, quant_config=None, prefix=f"{prefix}.weights_proj" + hidden_size, + self.n_head, + quant_config=quant_config, + prefix=f"{prefix}.weights_proj", ) self.softmax_scale = self.head_dim**-0.5 @@ -1251,7 +1254,7 @@ def __init__( ) layer_quant_dtype = quant_config.get_layer_quant_config( f"{prefix}.{q_a_proj_name}" - )["quant_dtype"] + ).quant_dtype if layer_quant_dtype == dtypes.fp4x2: if not use_triton_gemm(): source_quant_dtype = None @@ -1570,7 +1573,7 @@ def __init__( self.quant_dtype = ( None if quant_config is None - else quant_config.global_quant_config["quant_dtype"] + else quant_config.get_layer_quant_config(prefix).quant_dtype ) self.fuse_input_norm_quant = False self.fuse_ar_input_norm = ENABLE_ALLREDUCE_RMSNORM_FUSION diff --git a/atom/models/gpt_oss.py b/atom/models/gpt_oss.py index d689048e..3b81f3ff 100644 --- a/atom/models/gpt_oss.py +++ b/atom/models/gpt_oss.py @@ -111,7 +111,7 @@ def __init__( head_size=self.head_dim, total_num_heads=self.num_attention_heads, total_num_kv_heads=self.num_key_value_heads, - 
quant_config=None, + quant_config=quant_config, prefix=f"{prefix}.qkv_proj", bias=True, ) @@ -119,7 +119,7 @@ def __init__( self.o_proj = RowParallelLinear( input_size=self.num_attention_heads * self.head_dim, output_size=self.hidden_size, - quant_config=None, + quant_config=quant_config, prefix=f"{prefix}.o_proj", bias=True, reduce_results=not ENABLE_ALLREDUCE_RMSNORM_FUSION, diff --git a/atom/models/llama.py b/atom/models/llama.py index 893fdf79..349ebc56 100644 --- a/atom/models/llama.py +++ b/atom/models/llama.py @@ -99,7 +99,7 @@ def __init__( self.act_fn = SiluAndMul( fused_quant=self.fused_act_quant, quant_config=quant_config ) - self.quant_type = quant_config.global_quant_config["quant_type"] + self.quant_type = quant_config.get_layer_quant_config(prefix).quant_type def forward(self, x, x_scale: Optional[torch.Tensor] = None): x = self.gate_up_proj(x, x_scale=x_scale) @@ -271,7 +271,7 @@ def __init__( ATOM_LLAMA_ENABLE_AITER_TRITON_FUSED_RMSNORM_QUANT ) - self.quant_type = quant_config.global_quant_config["quant_type"] + self.quant_type = quant_config.get_layer_quant_config(prefix).quant_type self.self_attn = LlamaAttention( config=config, diff --git a/atom/models/qwen3_next.py b/atom/models/qwen3_next.py index 568d0b2e..c755e3ce 100644 --- a/atom/models/qwen3_next.py +++ b/atom/models/qwen3_next.py @@ -537,6 +537,7 @@ def __init__( input_size=self.conv_kernel_size, output_size=self.conv_dim, bias=False, + quant_config=quant_config, prefix=f"{prefix}.conv1d", ) self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) diff --git a/atom/quant_spec.py b/atom/quant_spec.py new file mode 100644 index 00000000..07baec55 --- /dev/null +++ b/atom/quant_spec.py @@ -0,0 +1,293 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +"""Typed quantization specification and parser registry. 
+ +This module introduces: +- :class:`LayerQuantConfig` — a frozen dataclass for type-safe, immutable + layer quant descriptions. +- :class:`ParsedQuantConfig` — structured output of parsing ``quantization_config`` + from a HuggingFace ``PretrainedConfig``. +- A parser registry (:func:`register_quant_parser`, :func:`get_quant_parser`) so + new quantizer back-ends (Quark, compressed-tensors, …) can each provide their + own parsing logic without bloating ``config.py``. +""" + +from __future__ import annotations + +import re +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import Any + +import torch +from aiter import QuantType +from aiter.utility.dtypes import d_dtypes + +# ────────────────────────────────────────────────────────────────────── +# Typed layer-level spec +# ────────────────────────────────────────────────────────────────────── + + +@dataclass(frozen=True) +class LayerQuantConfig: + """Immutable description of how a single layer (or default) is quantized.""" + + quant_type: QuantType = QuantType.No + quant_dtype: Any = torch.bfloat16 # torch.dtype (use Any for forward compat) + is_dynamic: bool = True + quant_method: str | None = None + + @property + def is_quantized(self) -> bool: + return self.quant_type != QuantType.No + + @classmethod + def no_quant(cls, dtype: Any = torch.bfloat16) -> LayerQuantConfig: + """Convenience: unquantized spec with a given storage dtype.""" + return cls(quant_type=QuantType.No, quant_dtype=dtype) + + +# ────────────────────────────────────────────────────────────────────── +# Structured parsed config +# ────────────────────────────────────────────────────────────────────── + + +@dataclass +class ParsedQuantConfig: + """Result of parsing a ``quantization_config`` dict.""" + + global_spec: LayerQuantConfig = field(default_factory=LayerQuantConfig) + layer_specs: dict[str, LayerQuantConfig] = field(default_factory=dict) + # Pattern specs as list of (pattern, spec) tuples to preserve 
order + layer_pattern_specs: list[tuple[str, LayerQuantConfig]] = field( + default_factory=list + ) + exclude_layers: list[str] = field(default_factory=list) + + +# ────────────────────────────────────────────────────────────────────── +# Parser registry +# ────────────────────────────────────────────────────────────────────── + +_PARSER_REGISTRY: dict[str, type[QuantConfigParser]] = {} + + +class QuantConfigParser(ABC): + """Base class for quantization config parsers.""" + + @abstractmethod + def parse(self, hf_quant_config: dict) -> ParsedQuantConfig: + """Parse a ``quantization_config`` dict into :class:`ParsedQuantConfig`.""" + ... + + +def register_quant_parser(name: str): + """Decorator: register a parser class under *name*.""" + + def wrapper(cls: type[QuantConfigParser]): + _PARSER_REGISTRY[name] = cls + return cls + + return wrapper + + +def get_quant_parser(method_name: str) -> QuantConfigParser: + """Return an instance of the parser for *method_name*. + + Falls back to the ``_generic`` parser if no specific one is registered. + """ + cls = _PARSER_REGISTRY.get(method_name) or _PARSER_REGISTRY.get("_generic") + if cls is None: + raise ValueError( + f"No quant config parser registered for {method_name!r} " + f"and no _generic fallback available." 
# ──────────────────────────────────────────────────────────────────────
# Built-in parsers
# ──────────────────────────────────────────────────────────────────────


# -- helpers ----------------------------------------------------------------

# Quark ``qscheme`` string -> quantization granularity.
_QSCHEME_TO_QUANT_TYPE: dict[str, QuantType] = {
    "per_channel": QuantType.per_Token,
    "per_tensor": QuantType.per_Tensor,
    "per_group": QuantType.per_1x32,
    "per_block": QuantType.per_1x128,
}


def _parse_quant_type(qscheme: str | None) -> QuantType:
    """Map a ``qscheme`` string to a :class:`QuantType`.

    ``None`` and unrecognised schemes both yield ``QuantType.No``.
    """
    if qscheme is None:
        return QuantType.No
    return _QSCHEME_TO_QUANT_TYPE.get(qscheme, QuantType.No)


def _lookup_dtype(key: str) -> Any:
    """Look up *key* in ``d_dtypes``, tolerating ``int<N>`` / ``i<N>`` spellings.

    This module historically spelled integer keys both ways (``"i8"`` next to
    ``"int8"`` / ``"int4x2"``).  Only one spelling can be registered in
    ``d_dtypes``, so the other one used to degrade silently to bfloat16.
    Both spellings are tried here; the exact key is always tried first.

    Returns:
        The registered dtype, or ``None`` when no spelling is known.
    """
    found = d_dtypes.get(key)
    if found is not None:
        return found
    match = re.fullmatch(r"(?:int|i)(\d.*)", key)
    if match is None:
        return None
    tail = match.group(1)
    for alias in (f"i{tail}", f"int{tail}"):
        found = d_dtypes.get(alias)
        if found is not None:
            return found
    return None


def _parse_quant_dtype(dtype_str: str | None) -> Any:
    """Resolve a config dtype string to a ``d_dtypes`` entry.

    Args:
        dtype_str: Dtype name from the config (e.g. ``"fp8_e4m3"``), or
            ``None`` when the config does not specify one.

    Returns:
        The matching dtype, or ``torch.bfloat16`` when *dtype_str* is ``None``
        or cannot be resolved.
    """
    if dtype_str is None:
        return torch.bfloat16
    # Normalise e.g. "fp8_e4m3" -> "fp8", "fp4_e2m1" -> "fp4"
    key = re.sub(r"_e\d+m\d+.*", "", dtype_str)
    # Direct lookup first, then packed variants: fp4 -> fp4x2, int4 -> int4x2.
    for candidate in (key, key + "x2", key + "x4"):
        resolved = _lookup_dtype(candidate)
        if resolved is not None:
            return resolved
    return torch.bfloat16


def _parse_is_dynamic(input_tensors: dict | None) -> bool:
    """Return the ``is_dynamic`` flag of an ``input_tensors`` section.

    Defaults to ``True`` when the section or the key is absent.
    """
    if input_tensors is None:
        return True
    return input_tensors.get("is_dynamic", True)


def _build_quark_layer_spec(layer_dict: dict) -> LayerQuantConfig:
    """Build a :class:`LayerQuantConfig` from a single Quark per-layer dict.

    Reads ``weight.qscheme``, ``weight.dtype`` and ``input_tensors.is_dynamic``;
    a missing or ``None`` ``weight`` section is treated as empty.
    """
    weight = layer_dict.get("weight", {}) or {}
    return LayerQuantConfig(
        quant_type=_parse_quant_type(weight.get("qscheme")),
        quant_dtype=_parse_quant_dtype(weight.get("dtype")),
        is_dynamic=_parse_is_dynamic(layer_dict.get("input_tensors")),
        quant_method="quark",
    )


# -- Quark ------------------------------------------------------------------


@register_quant_parser("quark")
class QuarkParser(QuantConfigParser):
    """Parser for Quark-style ``quantization_config``."""

    def parse(self, hf_quant_config: dict) -> ParsedQuantConfig:
        """Parse the global spec, per-pattern overrides and exclude list.

        Missing or ``None`` sections are treated as empty; an empty
        ``global_quant_config`` yields a default :class:`LayerQuantConfig`.
        """
        global_dict = hf_quant_config.get("global_quant_config") or {}
        layer_dict = hf_quant_config.get("layer_quant_config") or {}
        exclude = list(hf_quant_config.get("exclude") or [])

        global_spec = (
            _build_quark_layer_spec(global_dict) if global_dict else LayerQuantConfig()
        )

        pattern_specs: list[tuple[str, LayerQuantConfig]] = [
            (pattern, _build_quark_layer_spec(cfg))
            for pattern, cfg in layer_dict.items()
        ]

        return ParsedQuantConfig(
            global_spec=global_spec,
            layer_pattern_specs=pattern_specs,
            exclude_layers=exclude,
        )


# -- Generic (compressed-tensors, GPTQ, AWQ, …) ----------------------------


@register_quant_parser("_generic")
class GenericParser(QuantConfigParser):
    """Fallback parser that uses heuristics for compressed-tensors, etc."""

    # Regex patterns for identifying quantization types from config keys/values.
    # Values are ``d_dtypes`` keys and are resolved through ``_lookup_dtype``
    # so either integer spelling (``int8`` / ``i8``) works.
    _DTYPE_PATTERNS = {
        r"fp8|float8": "fp8",
        r"fp4|float4|mxfp4": "fp4x2",
        r"int8|w8a8": "int8",
        r"int4|w4a16|gptq|awq": "int4x2",
    }

    # Checked in insertion order; the first matching pattern wins.
    _QTYPE_PATTERNS = {
        r"block|per_block|blockwise|1x128": QuantType.per_1x128,
        r"per_channel|channel|per_token|token": QuantType.per_Token,
        r"per_tensor|tensor": QuantType.per_Tensor,
        r"per_group|group": QuantType.per_1x32,
    }

    # compressed-tensors ``weights`` encoding: (type, num_bits, d_dtypes key).
    _CT_WEIGHT_DTYPES = (
        ("float", 8, "fp8"),
        ("float", 4, "fp4x2"),
        ("int", 8, "i8"),
        ("int", 4, "i4x2"),
    )

    def parse(self, hf_quant_config: dict) -> ParsedQuantConfig:
        """Infer one global spec plus the exclude list from *hf_quant_config*.

        No per-layer pattern specs are produced by this parser.
        """
        quant_method = hf_quant_config.get("quant_method", "")
        config_str = str(hf_quant_config).lower()

        quant_dtype = self._infer_dtype(hf_quant_config, config_str)
        quant_type = self._infer_qtype(hf_quant_config, config_str)
        is_dynamic = hf_quant_config.get("is_dynamic", True)
        # Each quantizer uses a different key for excluded layers:
        #   Quark -> "exclude", compressed-tensors -> "ignore",
        #   gpt-oss/HF transformers -> "modules_to_not_convert"
        exclude = (
            hf_quant_config.get("ignore")
            or hf_quant_config.get("modules_to_not_convert")
            or hf_quant_config.get("exclude")
            or []
        )
        # A bare string would otherwise be exploded into single characters.
        if isinstance(exclude, str):
            exclude = [exclude]

        global_spec = LayerQuantConfig(
            quant_type=quant_type,
            quant_dtype=quant_dtype,
            is_dynamic=is_dynamic,
            quant_method=quant_method or None,
        )

        return ParsedQuantConfig(global_spec=global_spec, exclude_layers=list(exclude))

    def _infer_dtype(self, cfg: dict, config_str: str) -> Any:
        """Infer the quantized dtype; ``torch.bfloat16`` when nothing matches.

        Tried in order: explicit dtype fields, compressed-tensors
        ``config_groups`` (``type`` + ``num_bits``), then regex heuristics
        over the lower-cased stringified config.
        """
        # Check explicit fields first
        for key in ("weight_dtype", "activation_dtype", "dtype"):
            val = cfg.get(key)
            if val and isinstance(val, str):
                parsed = _parse_quant_dtype(val)
                # bfloat16 doubles as the "could not parse" result.
                if parsed != torch.bfloat16:
                    return parsed
        # Check compressed-tensors config_groups (type + num_bits encoding)
        config_groups = cfg.get("config_groups")
        if isinstance(config_groups, dict):
            for group in config_groups.values():
                if not isinstance(group, dict):
                    continue
                weights = group.get("weights") or {}
                if not isinstance(weights, dict):
                    continue
                wtype = weights.get("type", "")
                num_bits = weights.get("num_bits")
                for ct_type, ct_bits, dtype_key in self._CT_WEIGHT_DTYPES:
                    if wtype == ct_type and num_bits == ct_bits:
                        resolved = _lookup_dtype(dtype_key)
                        return torch.bfloat16 if resolved is None else resolved
        # Fall back to regex heuristics
        for pattern, dtype_key in self._DTYPE_PATTERNS.items():
            if re.search(pattern, config_str):
                resolved = _lookup_dtype(dtype_key)
                return torch.bfloat16 if resolved is None else resolved
        return torch.bfloat16

    def _infer_qtype(self, cfg: dict, config_str: str) -> QuantType:
        """Infer the quantization granularity; ``QuantType.No`` when unknown.

        Tried in order: explicit scheme fields, compressed-tensors
        ``config_groups`` weight ``strategy``, then regex heuristics over the
        lower-cased stringified config.
        """
        # Check explicit fields
        for key in ("quant_type", "quantization_type", "scheme"):
            val = cfg.get(key)
            if val and isinstance(val, str):
                lowered = val.lower()
                for pattern, qtype in self._QTYPE_PATTERNS.items():
                    if re.search(pattern, lowered):
                        return qtype
        # Check compressed-tensors config_groups for weight strategy
        config_groups = cfg.get("config_groups")
        if isinstance(config_groups, dict):
            for group in config_groups.values():
                if not isinstance(group, dict):
                    continue
                weights = group.get("weights") or {}
                if not isinstance(weights, dict):
                    continue
                strategy = weights.get("strategy", "")
                if strategy:
                    # Accept both "per_channel" and bare "channel" spellings.
                    mapped = _QSCHEME_TO_QUANT_TYPE.get(strategy)
                    if mapped is None:
                        mapped = _QSCHEME_TO_QUANT_TYPE.get(f"per_{strategy}")
                    if mapped is not None:
                        return mapped
        # Fall back to regex heuristics on full config string.
        # NOTE(review): the whole stringified config is scanned, so a method
        # name such as "compressed-tensors" matches ``tensor`` and a key such
        # as "config_groups" matches ``group``; confirm that is intended.
        for pattern, qtype in self._QTYPE_PATTERNS.items():
            if re.search(pattern, config_str):
                return qtype
        return QuantType.No
`LayerQuantConfig`, `QuantizationConfig`, `SpeculativeConfig`, `KVCacheTensor` | +| `atom/config.py` | `Config` master configuration, `ParallelConfig`, `CompilationConfig`, `QuantizationConfig`, `SpeculativeConfig`, `KVCacheTensor` | diff --git a/docs/configuration_guide.md b/docs/configuration_guide.md index d0b90e2a..78dc7156 100644 --- a/docs/configuration_guide.md +++ b/docs/configuration_guide.md @@ -16,7 +16,7 @@ controls ATOM's runtime behaviour. | `CompilationLevel` | Integer constants for the four compilation levels | | `CUDAGraphMode` | Enum controlling how CUDA graphs are captured (none / piecewise / full / hybrid) | | `QuantizationConfig` | Layer-wise quantization orchestrator: global config, per-layer overrides, exclude lists, layer name remapping | -| `LayerQuantConfig` | Per-layer quantization parameters: quant type, dtype, dynamic flag, method | +| `LayerQuantConfig` | Per-layer quantization spec (frozen dataclass): quant type, dtype, dynamic flag, method | | `ParallelConfig` | Data-parallel size, rank, master IP/port | | `SpeculativeConfig` | Speculative decoding method, draft model, number of speculative tokens | | `KVCacheConfig` / `KVCacheTensor` | Per-layer KV cache tensor descriptors (k/v caches and scales) | @@ -122,16 +122,16 @@ Helper methods on `CUDAGraphMode`: ## 3. Quantization Configuration (`QuantizationConfig` & `LayerQuantConfig`) -Defined in `atom/config.py`. The quantization system uses two classes: +Defined in `atom/config.py` and `atom/quant_spec.py`. The quantization system uses two classes: -- **`QuantizationConfig`** -- the top-level orchestrator that holds a global config, per-layer overrides, and exclusion lists. It is **not** a `dict` subclass. -- **`LayerQuantConfig(dict)`** -- a `dict` subclass that stores the concrete quantization parameters for a single layer (or as the global default). +- **`QuantizationConfig`** -- the top-level orchestrator that holds a global config, per-layer overrides, and exclusion lists. 
+- **`LayerQuantConfig`** -- a frozen dataclass (defined in `atom/quant_spec.py`) that stores the concrete quantization parameters for a single layer or as the global default. Typed, immutable, with attribute access (e.g., `spec.quant_type`). ### 3.1 `LayerQuantConfig` Fields -`LayerQuantConfig` extends `dict`. Fields are stored and accessed as dictionary keys (e.g., `cfg["quant_type"]`). +`LayerQuantConfig` is a frozen dataclass. Fields are accessed as typed attributes (e.g., `spec.quant_type`). -| Key | Type | Default | Description | +| Field | Type | Default | Description | |---|---|---|---| | `quant_type` | `QuantType` | `QuantType.No` | Quantization granularity (see below) | | `quant_dtype` | `torch.dtype` | `torch.bfloat16` | Data type for quantized weights | @@ -144,8 +144,8 @@ Defined in `atom/config.py`. The quantization system uses two classes: |---|---|---| | `torch_dtype` | `torch.dtype` | The model's default dtype (from `hf_config.torch_dtype`) | | `hf_quant_config` | `dict \| None` | Raw `quantization_config` dict from HuggingFace config | -| `global_quant_config` | `LayerQuantConfig` | Default quantization config applied to all layers | -| `layer_quant_config` | `dict[str, LayerQuantConfig]` | Per-layer overrides keyed by layer name pattern (supports fnmatch globs like `"*.mlp.*"`) | +| `global_quant_config` | `LayerQuantConfig` | Default quantization spec applied to all layers | +| `_parsed.layer_pattern_specs` | `list[tuple[str, LayerQuantConfig]]` | Per-layer overrides keyed by layer name pattern (supports fnmatch globs like `"*.mlp.*"`) | | `exclude_layers` | `list[str]` | Layer names excluded from quantization (supports exact match and `"re:"` regex prefix) | | `quant_method` | `str` | Top-level quantization method name (e.g., `"quark"`, `"compressed-tensors"`) | @@ -154,11 +154,10 @@ Key methods: | Method | Description | |---|---| | `get_name()` | Returns the quantization method name | -| `get_layer_quant_config(layer_name)` | Returns the 
`LayerQuantConfig` for a layer: checks exclusions first, then per-layer overrides, then falls back to global config | -| `should_ignore_layer_quant(layer_name)` | Returns `True` if the layer is in the exclusion list | +| `get_layer_quant_config(layer_name)` | Returns the `LayerQuantConfig` for a layer: checks exclusions first, then per-layer overrides, then falls back to global spec | | `remap_layer_name(hf_config, packed_modules_mapping)` | Remaps layer names for packed/fused modules (e.g., `q_a_proj` → `fused_qkv_a_proj` for DeepSeek) | | `compute_hash()` | Returns a SHA-256 hash of the quantization config for cache invalidation | -| `parse_quark_config_dict(config)` | Parses a quark-format config dict into a `LayerQuantConfig` | + ### 3.3 `QuantType` Values (from AITER) @@ -188,7 +187,7 @@ parameters: **For quark models** (`quant_method == "quark"`): -1. Parses `global_quant_config` dict via `parse_quark_config_dict()` to produce the global `LayerQuantConfig`. +1. Parses `global_quant_config` dict via `QuarkParser` to produce the global `LayerQuantConfig`. 2. Parses each entry in `layer_quant_config` dict to produce per-layer overrides. 3. Reads the `"exclude"` list for excluded layers. 4. Within each config dict, `weight.qscheme` determines `quant_type` (`"per_channel"` → `per_Token`, `"per_tensor"` → `per_Tensor`, `"per_group"` → `per_1x32`), and `weight.dtype` determines `quant_dtype`. @@ -388,7 +387,7 @@ Need maximum decode throughput? 
| File | Description | |---|---| -| `atom/config.py` | `Config`, `CompilationConfig`, `CompilationLevel`, `CUDAGraphMode`, `LayerQuantConfig`, `QuantizationConfig`, `ParallelConfig`, `SpeculativeConfig`, `KVCacheTensor`, `KVCacheConfig`, `get_hf_config` | +| `atom/config.py` | `Config`, `CompilationConfig`, `CompilationLevel`, `CUDAGraphMode`, `QuantizationConfig`, `ParallelConfig`, `SpeculativeConfig`, `KVCacheTensor`, `KVCacheConfig`, `get_hf_config` | | `atom/utils/envs.py` | All `ATOM_*` environment variable definitions with lazy evaluation | | `atom/model_engine/arg_utils.py` | `EngineArgs` dataclass and CLI argument parser | | `atom/sampling_params.py` | `SamplingParams` dataclass | diff --git a/tests/test_quant_config.py b/tests/test_quant_config.py index c507fc66..d8aa85b6 100644 --- a/tests/test_quant_config.py +++ b/tests/test_quant_config.py @@ -1,11 +1,13 @@ # SPDX-License-Identifier: MIT -# Tests for LayerQuantConfig and QuantizationConfig refactoring (atom/config.py). +# Tests for LayerQuantConfig, QuantizationConfig, and the +# parser registry (atom/config.py + atom/quant_spec.py). # # Covers: per-layer quant config dispatch, quark config parsing, -# layer name matching (exact / regex / fnmatch), and packed-module remapping. +# layer name matching (exact / regex / fnmatch), packed-module remapping, +# typed LayerQuantConfig API, and backward compatibility. # # atom.config depends on torch, aiter, and transformers. We load the source -# file under temporary sys.modules mocks so the tests run in any environment. +# files under temporary sys.modules mocks so the tests run in any environment. 
import contextlib import enum @@ -20,9 +22,9 @@ ATOM_ROOT = str(Path(__file__).resolve().parent.parent) -# --------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Mock primitives -# --------------------------------------------------------------------------- +# ------------------------------------------------------------------------- class QuantType(enum.IntEnum): @@ -60,9 +62,9 @@ def get_config_dict(model): return {}, {} -# --------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Module loader — patch sys.modules only while exec-ing config.py -# --------------------------------------------------------------------------- +# ------------------------------------------------------------------------- @contextlib.contextmanager @@ -123,43 +125,159 @@ def _temporary_mocks(): sys.modules[name] = orig -def _load_config(): - path = os.path.join(ATOM_ROOT, "atom", "config.py") - spec = importlib.util.spec_from_file_location("_atom_config_test", path) +def _load_module(filename: str, module_name: str): + path = os.path.join(ATOM_ROOT, "atom", filename) + spec = importlib.util.spec_from_file_location(module_name, path) mod = importlib.util.module_from_spec(spec) + # Register before exec so @dataclass etc. can resolve the module + sys.modules[module_name] = mod with _temporary_mocks(): spec.loader.exec_module(mod) return mod -_m = _load_config() -LayerQuantConfig = _m.LayerQuantConfig +# Load quant_spec first, then inject it so config.py can import it. 
+_qs = _load_module("quant_spec.py", "atom.quant_spec") +sys.modules["atom.quant_spec"] = _qs + +_m = _load_module("config.py", "_atom_config_test") + QuantizationConfig = _m.QuantizationConfig +LayerQuantConfig = _qs.LayerQuantConfig +ParsedQuantConfig = _qs.ParsedQuantConfig +QuarkParser = _qs.QuarkParser +GenericParser = _qs.GenericParser +get_quant_parser = _qs.get_quant_parser -# =========================================================================== -# Tests -# =========================================================================== + +# ========================================================================= +# Tests — LayerQuantConfig +# ========================================================================= class TestLayerQuantConfig: def test_defaults(self): - cfg = LayerQuantConfig() - assert cfg["quant_type"] == QuantType.No - assert cfg["quant_dtype"] == BF16 - assert cfg["is_dynamic"] is True - assert cfg["quant_method"] == "" - - def test_custom_values_and_dict_interface(self): - cfg = LayerQuantConfig( - quant_type=QuantType.per_Token, - quant_dtype=FP8, - is_dynamic=False, - quant_method="quark", + spec = LayerQuantConfig() + assert spec.quant_type == QuantType.No + assert spec.quant_dtype == BF16 + assert spec.is_dynamic is True + assert spec.quant_method is None + assert spec.is_quantized is False + + def test_no_quant_factory(self): + spec = LayerQuantConfig.no_quant(FP8) + assert spec.quant_type == QuantType.No + assert spec.quant_dtype == FP8 + assert spec.is_quantized is False + + def test_is_quantized(self): + spec = LayerQuantConfig(quant_type=QuantType.per_Token, quant_dtype=FP8) + assert spec.is_quantized is True + + def test_frozen(self): + spec = LayerQuantConfig() + with pytest.raises(AttributeError): + spec.quant_type = QuantType.per_Token # type: ignore[misc] + + +# ========================================================================= +# Tests — Parser Registry +# 
========================================================================= + + +class TestParserRegistry: + def test_quark_registered(self): + parser = get_quant_parser("quark") + assert isinstance(parser, QuarkParser) + + def test_generic_fallback(self): + parser = get_quant_parser("compressed-tensors") + assert isinstance(parser, GenericParser) + + def test_unknown_falls_to_generic(self): + parser = get_quant_parser("some_unknown_method") + assert isinstance(parser, GenericParser) + + +# ========================================================================= +# Tests — QuarkParser +# ========================================================================= + + +class TestQuarkParser: + def test_per_channel_fp8(self): + parser = QuarkParser() + result = parser.parse( + { + "quant_method": "quark", + "global_quant_config": { + "weight": {"qscheme": "per_channel", "dtype": "fp8_e4m3"}, + "input_tensors": {"is_dynamic": True}, + }, + } + ) + assert result.global_spec.quant_type == QuantType.per_Token + assert result.global_spec.quant_dtype == FP8 + assert result.global_spec.is_dynamic is True + + def test_per_group_fp4(self): + parser = QuarkParser() + result = parser.parse( + { + "quant_method": "quark", + "global_quant_config": { + "weight": {"qscheme": "per_group", "dtype": "fp4_e2m1"}, + "input_tensors": {"is_dynamic": False}, + }, + } + ) + assert result.global_spec.quant_type == QuantType.per_1x32 + assert result.global_spec.quant_dtype == FP4X2 + assert result.global_spec.is_dynamic is False + + def test_no_input_tensors_defaults_dynamic(self): + parser = QuarkParser() + result = parser.parse( + { + "quant_method": "quark", + "global_quant_config": { + "weight": {"qscheme": "per_tensor", "dtype": "int8"}, + "input_tensors": None, + }, + } ) - assert isinstance(cfg, dict) - assert cfg["quant_type"] == QuantType.per_Token - assert cfg["quant_dtype"] == FP8 - assert cfg["is_dynamic"] is False + assert result.global_spec.quant_type == QuantType.per_Tensor + 
assert result.global_spec.is_dynamic is True + + def test_layer_config_parsed(self): + parser = QuarkParser() + result = parser.parse( + { + "quant_method": "quark", + "global_quant_config": { + "weight": {"qscheme": "per_channel", "dtype": "fp8_e4m3"}, + "input_tensors": {"is_dynamic": True}, + }, + "layer_quant_config": { + "*.mlp.*": { + "weight": {"qscheme": "per_group", "dtype": "fp4_e2m1"}, + "input_tensors": {"is_dynamic": False}, + }, + }, + "exclude": ["lm_head"], + } + ) + assert len(result.layer_pattern_specs) == 1 + pattern, spec = result.layer_pattern_specs[0] + assert pattern == "*.mlp.*" + assert spec.quant_type == QuantType.per_1x32 + assert spec.quant_dtype == FP4X2 + assert result.exclude_layers == ["lm_head"] + + +# ========================================================================= +# Tests — QuantizationConfig init +# ========================================================================= class TestQuantizationConfigInit: @@ -167,15 +285,15 @@ def test_none_config(self): qcfg = QuantizationConfig(config=None) assert qcfg.quant_method == "" assert qcfg.exclude_layers == [] - assert qcfg.layer_quant_config == {} - assert qcfg.global_quant_config["quant_type"] == QuantType.No + assert qcfg.global_quant_config.quant_type == QuantType.No + assert qcfg.global_quant_config.is_quantized is False def test_config_without_quantization(self): hf = FakeHFConfig(torch_dtype=BF16) qcfg = QuantizationConfig(hf) assert qcfg.quant_method == "" - assert qcfg.global_quant_config["quant_type"] == QuantType.No - assert qcfg.global_quant_config["quant_dtype"] == BF16 + assert qcfg.global_quant_config.quant_type == QuantType.No + assert qcfg.global_quant_config.quant_dtype == BF16 def test_quark_config_parses_global_and_layer(self): hf = FakeHFConfig( @@ -197,115 +315,122 @@ def test_quark_config_parses_global_and_layer(self): ) qcfg = QuantizationConfig(hf) assert qcfg.quant_method == "quark" - assert qcfg.global_quant_config["quant_type"] == 
QuantType.per_Token - assert qcfg.global_quant_config["quant_dtype"] == FP8 - - assert "*.mlp.*" in qcfg.layer_quant_config - mlp = qcfg.layer_quant_config["*.mlp.*"] - assert mlp["quant_type"] == QuantType.per_1x32 - assert mlp["quant_dtype"] == FP4X2 - assert mlp["is_dynamic"] is False + assert qcfg.global_quant_config.quant_type == QuantType.per_Token + assert qcfg.global_quant_config.quant_dtype == FP8 + # layer pattern specs + assert len(qcfg._parsed.layer_pattern_specs) == 1 + mlp_pattern, mlp_spec = qcfg._parsed.layer_pattern_specs[0] + assert mlp_pattern == "*.mlp.*" + assert mlp_spec.quant_type == QuantType.per_1x32 + assert mlp_spec.quant_dtype == FP4X2 + assert mlp_spec.is_dynamic is False assert qcfg.exclude_layers == ["lm_head"] -class TestParseQuarkConfigDict: - @pytest.fixture - def qcfg(self): - return QuantizationConfig(config=None) +# ========================================================================= +# Tests — get_layer_quant_config resolution +# ========================================================================= - def test_per_channel_fp8(self, qcfg): - result = qcfg.parse_quark_config_dict( - { - "weight": {"qscheme": "per_channel", "dtype": "fp8_e4m3"}, - "input_tensors": {"is_dynamic": True}, - } + +class TestGetLayerQuantConfig: + def test_falls_back_to_global(self): + qcfg = QuantizationConfig(config=None) + qcfg._parsed = ParsedQuantConfig( + global_spec=LayerQuantConfig( + quant_type=QuantType.per_Token, quant_dtype=FP8 + ), ) - assert result["quant_type"] == QuantType.per_Token - assert result["quant_dtype"] == FP8 - assert result["is_dynamic"] is True - assert result["quant_method"] == "quark" + result = qcfg.get_layer_quant_config("model.layers.0.self_attn.q_proj") + assert result.quant_type == QuantType.per_Token + assert result.quant_dtype == FP8 - def test_per_group_fp4(self, qcfg): - result = qcfg.parse_quark_config_dict( - { - "weight": {"qscheme": "per_group", "dtype": "fp4_e2m1"}, - "input_tensors": {"is_dynamic": 
False}, - } + def test_layer_specific_overrides_global(self): + qcfg = QuantizationConfig(config=None) + qcfg._parsed = ParsedQuantConfig( + global_spec=LayerQuantConfig(quant_dtype=FP8), + layer_pattern_specs=[ + ( + "*.mlp.*", + LayerQuantConfig(quant_type=QuantType.per_1x32, quant_dtype=FP4X2), + ), + ], ) - assert result["quant_type"] == QuantType.per_1x32 - assert result["quant_dtype"] == FP4X2 - assert result["is_dynamic"] is False + result = qcfg.get_layer_quant_config("model.layers.0.mlp.gate_proj") + assert result.quant_dtype == FP4X2 + assert result.quant_type == QuantType.per_1x32 - def test_no_input_tensors_defaults_dynamic(self, qcfg): - """When input_tensors is absent/None, is_dynamic keeps its True default.""" - result = qcfg.parse_quark_config_dict( - { - "weight": {"qscheme": "per_tensor", "dtype": "int8"}, - "input_tensors": None, - } + def test_excluded_layer_returns_unquantized(self): + qcfg = QuantizationConfig(config=None) + qcfg.torch_dtype = BF16 + qcfg._parsed = ParsedQuantConfig( + global_spec=LayerQuantConfig( + quant_type=QuantType.per_Token, quant_dtype=FP8 + ), ) - assert result["quant_type"] == QuantType.per_Tensor - assert result["is_dynamic"] is True + qcfg.exclude_layers = ["lm_head"] + + result = qcfg.get_layer_quant_config("lm_head") + assert result.quant_type == QuantType.No + assert result.quant_dtype == BF16 + +# ========================================================================= +# Tests — Exclude layer matching +# ========================================================================= -class TestShouldIgnoreLayerQuant: + +class TestExcludeMatching: def _make(self, exclude_layers): qcfg = QuantizationConfig(config=None) qcfg.exclude_layers = exclude_layers return qcfg def test_empty_exclude(self): - assert self._make([]).should_ignore_layer_quant("any_layer") is False + qcfg = self._make([]) + assert not qcfg._is_excluded("any_layer") def test_none_layer_name(self): - assert 
self._make(["lm_head"]).should_ignore_layer_quant(None) is False + qcfg = self._make(["lm_head"]) + assert not qcfg._is_excluded(None) def test_exact_match(self): - assert self._make(["lm_head"]).should_ignore_layer_quant("lm_head") is True + qcfg = self._make(["lm_head"]) + assert qcfg._is_excluded("lm_head") def test_regex_match(self): qcfg = self._make(["re:model\\.layers\\..*shared_expert.*"]) - assert ( - qcfg.should_ignore_layer_quant("model.layers.3.shared_expert.gate_proj") - is True - ) + assert qcfg._is_excluded("model.layers.3.shared_expert.gate_proj") def test_no_match(self): - assert ( - self._make(["lm_head"]).should_ignore_layer_quant("self_attn.q_proj") - is False - ) - + qcfg = self._make(["lm_head"]) + assert not qcfg._is_excluded("self_attn.q_proj") -class TestIsEqualOrRegexMatch: - @pytest.fixture - def qcfg(self): - return QuantizationConfig(config=None) - def test_exact(self, qcfg): - assert qcfg.is_equal_or_regex_match("lm_head", "lm_head") is True - assert qcfg.is_equal_or_regex_match("lm_head", "other") is False +class TestMatchesExclude: + def test_exact(self): + assert QuantizationConfig._matches_exclude("lm_head", "lm_head") is True + assert QuantizationConfig._matches_exclude("lm_head", "other") is False - def test_regex(self, qcfg): + def test_regex(self): assert ( - qcfg.is_equal_or_regex_match( + QuantizationConfig._matches_exclude( "model.layers.5.self_attn.q_proj", "re:model\\.layers\\..*self_attn.*", ) is True ) assert ( - qcfg.is_equal_or_regex_match( + QuantizationConfig._matches_exclude( "model.layers.5.mlp.gate_proj", "re:model\\.layers\\..*self_attn.*", ) is False ) - def test_contains_mode(self, qcfg): + def test_contains_mode(self): assert ( - qcfg.is_equal_or_regex_match( + QuantizationConfig._matches_exclude( "self_attn", "model.layers.0.self_attn.q_a_proj", check_contains=True, @@ -313,81 +438,62 @@ def test_contains_mode(self, qcfg): is True ) assert ( - qcfg.is_equal_or_regex_match("mlp", "self_attn.q_proj", 
check_contains=True) + QuantizationConfig._matches_exclude( + "mlp", "self_attn.q_proj", check_contains=True + ) is False ) -class TestGetLayerQuantConfig: - def test_falls_back_to_global(self): - qcfg = QuantizationConfig(config=None) - qcfg.global_quant_config = LayerQuantConfig( - quant_type=QuantType.per_Token, quant_dtype=FP8 - ) - result = qcfg.get_layer_quant_config("model.layers.0.self_attn.q_proj") - assert result["quant_type"] == QuantType.per_Token - assert result["quant_dtype"] == FP8 - - def test_layer_specific_overrides_global(self): - qcfg = QuantizationConfig(config=None) - qcfg.global_quant_config = LayerQuantConfig(quant_dtype=FP8) - mlp_cfg = LayerQuantConfig(quant_type=QuantType.per_1x32, quant_dtype=FP4X2) - qcfg.layer_quant_config = {"*.mlp.*": mlp_cfg} - - result = qcfg.get_layer_quant_config("model.layers.0.mlp.gate_proj") - assert result["quant_dtype"] == FP4X2 - assert result["quant_type"] == QuantType.per_1x32 - - def test_excluded_layer_returns_unquantized(self): - qcfg = QuantizationConfig(config=None) - qcfg.torch_dtype = BF16 - qcfg.global_quant_config = LayerQuantConfig( - quant_type=QuantType.per_Token, quant_dtype=FP8 - ) - qcfg.exclude_layers = ["lm_head"] - - result = qcfg.get_layer_quant_config("lm_head") - assert result["quant_type"] == QuantType.No - assert result["quant_dtype"] == BF16 +# ========================================================================= +# Tests — remap_layer_name +# ========================================================================= class TestRemapLayerName: + @staticmethod + def _pattern_dict(qcfg): + """Helper: return pattern->spec dict from _parsed.layer_pattern_specs.""" + return dict(qcfg._parsed.layer_pattern_specs) + def test_deepseek_v3_with_q_lora_rank(self): """Individual proj names -> fused names for deepseek_v3.""" qcfg = QuantizationConfig(config=None) - qcfg.layer_quant_config = { - "*.q_a_proj": LayerQuantConfig(quant_type=QuantType.per_Token), - "*.gate_proj": 
LayerQuantConfig(quant_type=QuantType.per_1x32), - } + qcfg._parsed.layer_pattern_specs = [ + ("*.q_a_proj", LayerQuantConfig(quant_type=QuantType.per_Token)), + ("*.gate_proj", LayerQuantConfig(quant_type=QuantType.per_1x32)), + ] qcfg.exclude_layers = ["model.layers.0.q_a_proj"] hf = FakeHFConfig(model_type="deepseek_v3", q_lora_rank=512) qcfg.remap_layer_name(hf) - assert "*.fused_qkv_a_proj" in qcfg.layer_quant_config - assert "*.gate_up_proj" in qcfg.layer_quant_config - assert "*.q_a_proj" not in qcfg.layer_quant_config + pats = self._pattern_dict(qcfg) + assert "*.fused_qkv_a_proj" in pats + assert "*.gate_up_proj" in pats + assert "*.q_a_proj" not in pats assert "model.layers.0.fused_qkv_a_proj" in qcfg.exclude_layers def test_qwen3_moe_splits_fused(self): """Fused gate_up_proj -> [gate_proj, up_proj] for qwen3_moe.""" qcfg = QuantizationConfig(config=None) - qcfg.layer_quant_config = { - "*.gate_up_proj": LayerQuantConfig(quant_type=QuantType.per_Token), - } + qcfg._parsed.layer_pattern_specs = [ + ("*.gate_up_proj", LayerQuantConfig(quant_type=QuantType.per_Token)), + ] qcfg.exclude_layers = [] hf = FakeHFConfig(model_type="qwen3_moe", mlp_only_layers=[1]) qcfg.remap_layer_name(hf, packed_modules_mapping={}) - assert "*.gate_proj" in qcfg.layer_quant_config - assert "*.up_proj" in qcfg.layer_quant_config - assert "*.gate_up_proj" not in qcfg.layer_quant_config + pats = self._pattern_dict(qcfg) + assert "*.gate_proj" in pats + assert "*.up_proj" in pats + assert "*.gate_up_proj" not in pats def test_exclude_layers_deduplication(self): """gate_proj and up_proj both map to gate_up_proj -- only one remains.""" qcfg = QuantizationConfig(config=None) - qcfg.layer_quant_config = {} + qcfg._parsed.layer_pattern_specs = [] qcfg.exclude_layers = [ "model.layers.0.gate_proj", "model.layers.0.up_proj", @@ -410,8 +516,10 @@ def test_hash_is_deterministic(self): def test_different_configs_produce_different_hashes(self): qcfg1 = QuantizationConfig(config=None) qcfg2 = 
QuantizationConfig(config=None) - qcfg2.global_quant_config = LayerQuantConfig( - quant_type=QuantType.per_Token, quant_dtype=FP8 + qcfg2._parsed = ParsedQuantConfig( + global_spec=LayerQuantConfig( + quant_type=QuantType.per_Token, quant_dtype=FP8 + ), ) assert qcfg1.compute_hash() != qcfg2.compute_hash() @@ -421,10 +529,33 @@ def test_exclude_layers_affect_hash(self): qcfg2.exclude_layers = ["lm_head"] assert qcfg1.compute_hash() != qcfg2.compute_hash() - def test_layer_quant_config_affects_hash(self): + def test_layer_pattern_specs_affect_hash(self): qcfg1 = QuantizationConfig(config=None) qcfg2 = QuantizationConfig(config=None) - qcfg2.layer_quant_config = { - "*.mlp.*": LayerQuantConfig(quant_type=QuantType.per_1x32) - } + qcfg2._parsed.layer_pattern_specs = [ + ("*.mlp.*", LayerQuantConfig(quant_type=QuantType.per_1x32)), + ] assert qcfg1.compute_hash() != qcfg2.compute_hash() + + +# ========================================================================= +# Tests — Convenience properties +# ========================================================================= + + +class TestConvenienceProperties: + def test_quant_type_property(self): + hf = FakeHFConfig( + torch_dtype=BF16, + quantization_config={ + "quant_method": "quark", + "global_quant_config": { + "weight": {"qscheme": "per_channel", "dtype": "fp8_e4m3"}, + "input_tensors": {"is_dynamic": True}, + }, + }, + ) + qcfg = QuantizationConfig(hf) + assert qcfg.quant_type == QuantType.per_Token + assert qcfg.quant_dtype == FP8 + assert qcfg.is_dynamic is True