Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
298 changes: 88 additions & 210 deletions atom/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,15 @@
import os
import re
from dataclasses import dataclass, field
from typing import Any, cast, Optional, Union
from typing import Any, Optional, Union

import torch
from aiter import QuantType
from aiter.utility.dtypes import d_dtypes
from atom.quant_spec import (
LayerQuantConfig,
ParsedQuantConfig,
get_quant_parser,
)
from atom.utils import envs, get_open_port
from atom.utils.distributed.utils import stateless_init_torch_distributed_process_group
from torch.distributed import ProcessGroup, ReduceOp
Expand Down Expand Up @@ -251,70 +255,94 @@ def set_splitting_ops_for_v1(self):
]


class LayerQuantConfig(dict):
    """Per-layer quantization spec, stored as plain dict keys.

    Keys: ``quant_type``, ``quant_dtype``, ``is_dynamic``, ``quant_method``.
    ``None`` for the first two falls back to the unquantized defaults.
    """

    def __init__(
        self,
        quant_type=QuantType.No,
        quant_dtype=torch.bfloat16,
        is_dynamic=True,
        quant_method="",
    ):
        super().__init__()
        # Explicit None still means "use the default", not "store None".
        self["quant_type"] = QuantType.No if quant_type is None else quant_type
        self["quant_dtype"] = (
            torch.bfloat16 if quant_dtype is None else quant_dtype
        )
        self["is_dynamic"] = is_dynamic
        self["quant_method"] = quant_method
class QuantizationConfig:
"""Model-wide quantization configuration.

API:
- ``get_layer_quant_config(prefix)`` -> :class:`LayerQuantConfig`
- ``global_quant_config`` property -> :class:`LayerQuantConfig`
- ``quant_type``, ``quant_dtype``, ``is_dynamic`` convenience properties
"""

class QuantizationConfig:
def __init__(self, config: PretrainedConfig = None):
    """Build the model-wide quantization config from a HF config.

    With no *config*, or a config without a ``quantization_config``
    attribute, everything defaults to unquantized at the model dtype.
    Otherwise the quant-parser registry turns ``hf_quant_config`` into a
    structured :class:`ParsedQuantConfig`.

    Note: ``global_quant_config`` is a read-only property backed by
    ``self._parsed``; it must not be assigned here.
    """
    if config is None:
        self.torch_dtype = torch.bfloat16
        self.hf_quant_config = None
        self._parsed = ParsedQuantConfig()
        self.exclude_layers: list[str] = []
        self.quant_method = ""
        return

    self.torch_dtype = getattr(config, "torch_dtype", torch.bfloat16)
    self.hf_quant_config = getattr(config, "quantization_config", None)
    # Always initialized so _is_excluded() is safe on every path.
    self.exclude_layers = []

    if self.hf_quant_config is None:
        # No quantization section: global spec keeps the model dtype.
        self._parsed = ParsedQuantConfig(
            global_spec=LayerQuantConfig(
                quant_type=QuantType.No, quant_dtype=self.torch_dtype
            )
        )
        self.quant_method = ""
        return

    self.quant_method = self.hf_quant_config.get("quant_method", "")
    # Use the parser registry to build a structured ParsedQuantConfig
    parser = get_quant_parser(self.quant_method)
    self._parsed = parser.parse(self.hf_quant_config)
    self.exclude_layers = list(self._parsed.exclude_layers)
# -- typed API (preferred) ----------------------------------------------

@property
def global_quant_config(self) -> LayerQuantConfig:
    """The default quantization spec for all layers."""
    # Read-only view over the parsed config's global spec; per-layer
    # overrides are resolved separately in get_layer_quant_config().
    return self._parsed.global_spec

def get_layer_quant_config(self, layer_name: str) -> LayerQuantConfig:
    """Return the :class:`LayerQuantConfig` for *layer_name*.

    Resolution order:
    1. Check exclude list -> unquantized ``LayerQuantConfig``.
    2. Exact match in ``parsed.layer_specs``.
    3. fnmatch-style pattern match in ``parsed.layer_pattern_specs``.
    4. Fall back to ``global_quant_config``.
    """
    # 1. Excluded layers stay unquantized at the model dtype.
    if self._is_excluded(layer_name):
        return LayerQuantConfig(quant_dtype=self.torch_dtype)

    # 2. Exact match
    if layer_name in self._parsed.layer_specs:
        return self._parsed.layer_specs[layer_name]

    # 3. Pattern match. Non-glob patterns are treated as substrings of the
    # layer name (e.g. "self_attn.q_proj" matches
    # "model.layers.0.self_attn.q_proj"); glob patterns go through fnmatch.
    for pattern, spec in self._parsed.layer_pattern_specs:
        if "*" not in pattern:
            # BUGFIX: was `layer_name in pattern`, which tested the layer
            # name as a substring of the pattern — short names (e.g. "l")
            # spuriously matched, and genuine overrides never did.
            if pattern in layer_name:
                return spec
        elif fnmatch.fnmatch(layer_name, pattern):
            return spec

    # 4. Global default
    return self._parsed.global_spec

# -- convenience properties (delegate to global_quant_config) -------------

@property
def quant_type(self) -> QuantType:
    """Model-wide quantization scheme (delegates to the global spec)."""
    return self._parsed.global_spec.quant_type

@property
def quant_dtype(self) -> torch.dtype:
    """Model-wide quantized storage dtype (delegates to the global spec)."""
    return self._parsed.global_spec.quant_dtype

@property
def is_dynamic(self) -> bool:
    """Whether activation quantization is dynamic (delegates to the global spec)."""
    return self._parsed.global_spec.is_dynamic

# -- other methods ------------------------------------------------------

def compute_hash(self) -> str:
    """Return a SHA-256 hex digest over every quantization setting that
    can affect the final hidden states, suitable as a cache key.
    """
    factors: list[Any] = []
    factors.append(self._parsed.global_spec)
    # Include exact-match per-layer specs as well: omitting them would let
    # two configs differing only in exact-name overrides share a hash.
    factors.append(self._parsed.layer_specs)
    factors.append(self._parsed.layer_pattern_specs)
    factors.append(self.exclude_layers)
    return hashlib.sha256(str(factors).encode()).hexdigest()

def get_name(self):
    """Returns the quantization method name."""
    # Fix: the block carried two stacked docstrings (merge residue);
    # keep the single-line form.
    return self.quant_method

def parse_quark_config_dict(self, config: dict) -> LayerQuantConfig:
    """Translate one quark layer-config dict into a :class:`LayerQuantConfig`.

    Reads the ``weight`` sub-dict for the scheme/dtype and the
    ``input_tensors`` sub-dict (if present) for dynamic-vs-static.
    """
    weight_cfg = cast(dict[str, Any], config.get("weight", {}))
    input_cfg = cast(dict[str, Any], config.get("input_tensors", {}))

    # Map quark's qscheme names onto AITER quant types; anything
    # unrecognized means "no quantization".
    scheme_to_type = {
        "per_channel": QuantType.per_Token,
        "per_tensor": QuantType.per_Tensor,
        # Currently, quark only supports group_size=32
        "per_group": QuantType.per_1x32,
    }
    qscheme = cast(str, weight_cfg.get("qscheme", ""))
    quant_type = scheme_to_type.get(qscheme, QuantType.No)

    # Dtype key: strip any "_..." suffix; 4-bit types are stored packed
    # two-per-byte, hence the "x2" suffix in d_dtypes.
    dtype_key = weight_cfg.get("dtype", "").split("_")[0]
    if dtype_key.endswith("4"):
        dtype_key += "x2"
    quant_dtype = d_dtypes[dtype_key]

    # Static activation quant only when the input section says so.
    is_dynamic = (
        cast(bool, input_cfg.get("is_dynamic", True)) if input_cfg else True
    )

    return LayerQuantConfig(
        quant_type=quant_type,
        quant_dtype=quant_dtype,
        is_dynamic=is_dynamic,
        quant_method="quark",
    )
# -- internal helpers ---------------------------------------------------

# TODO: For now, it's just a temporary migration.
# We should subsequently refine them in a targeted manner.
def parse_other_config(self) -> None:
    """Best-effort parse of non-quark quantization configs.

    Heuristic: the whole hf_quant_config dict is stringified and probed
    with regexes for block size, dtype, bit width and activation scheme.
    Populates ``self.global_quant_config`` and ``self.exclude_layers``.
    The regex-over-str(dict) approach is order-sensitive and fragile —
    flagged upstream as a temporary migration to be refined.
    """
    # Matches 'group_size': N or 'weight_block_size': [N, ...] in the
    # stringified config.
    RE_QUANT_BLOCKSIZE = (
        r"\'(?:group_size|weight_block_size)\'\:\s*(?:\[\n*)\s*(\d+),"
    )
    orig_quant_config = self.hf_quant_config
    quant_method = self.quant_method
    orig_quant_config_str = str(orig_quant_config)
    # --- quant_type -----------------------------------------------------
    # compressed-tensors (or any per-channel scheme) -> per-token;
    # an explicit block size -> grouped; otherwise per-tensor.
    if quant_method == "compressed-tensors" or "channel'," in orig_quant_config_str:
        quant_type = QuantType.per_Token
    elif group_size := re.search(RE_QUANT_BLOCKSIZE, orig_quant_config_str):
        group_size = int(group_size.group(1))
        assert group_size in (32, 128), f"Unsupported group size {group_size}"
        if group_size == 128:
            quant_type = QuantType.per_1x128
        elif group_size == 32:
            quant_type = QuantType.per_1x32
    else:
        quant_type = QuantType.per_Tensor

    # --- quant_dtype ----------------------------------------------------
    # First try a named dtype ('fp8', 'mxfp4', ...); 4-bit types are
    # stored packed two-per-byte ("x2"). Fall back to a bit count.
    RE_QUANT_DTYPE = r"\'(?:d?type|weight_dtype|quant_method)\'\:\s*\'(\w+)\'"
    quant_dtype = None
    m = re.search(RE_QUANT_DTYPE, orig_quant_config_str)
    if m and m.group(1).lower() in [
        "fp8",
        "fp4",
        "int8",
        "int4",
        "fp8_e4m3",
        "mxfp4",
    ]:
        dtype = m.group(1).lower().split("_")[0]
        if dtype == "mxfp4":
            dtype = "fp4"
        if dtype.endswith("4"):
            dtype += "x2"
        quant_dtype = d_dtypes[dtype]
    else:
        # No recognizable dtype name: infer from 'bits'/'num_bits' plus
        # an int/fp prefix if one is present.
        bit_match = re.search(r"\'(?:num_)?bits\'\:\s*(\d+)", orig_quant_config_str)
        if bit_match:
            bit = int(bit_match.group(1))
            dtype_match = re.search(RE_QUANT_DTYPE, orig_quant_config_str)
            if dtype_match:
                dtype = dtype_match.group(1).lower()
                dtype_prefix = "i" if dtype.startswith("int") else "fp"
            else:
                # No dtype hint at all: assume integer quantization.
                dtype_prefix = "i"
            quant_dtype_str = (
                f"{dtype_prefix}{bit}" if bit != 4 else f"{dtype_prefix}{bit}x2"
            )
            quant_dtype = d_dtypes.get(quant_dtype_str, None)
    assert (
        quant_dtype is not None
    ), f"Cannot parse quant dtype from {orig_quant_config_str}"
    # FP4 packed weights imply 1x32 grouping regardless of what the
    # block-size probe concluded above.
    if quant_dtype == d_dtypes["fp4x2"]:
        quant_type = QuantType.per_1x32

    # --- is_dynamic -----------------------------------------------------
    # Only an explicit activation_scheme='static' disables dynamic quant.
    RE_STATIC_QUANT = r"\'(?:activation_scheme)\'\:\s*\'(static)\'"
    if re.search(RE_STATIC_QUANT, orig_quant_config_str):
        is_dynamic = False
    else:
        is_dynamic = True
    # --- exclude layers -------------------------------------------------
    # 'ignore' is the compressed-tensors key; for other methods we try the
    # same key but warn, since their schema may differ.
    if quant_method == "compressed-tensors":
        exclude_layers_key = "ignore"
    else:
        logger.warning(
            f"Using 'ignore' as key for exclude layers with quant_method "
            f"{quant_method}, please double check the quantization config."
        )
        exclude_layers_key = "ignore"
    exclude_layers = orig_quant_config.get(exclude_layers_key, [])

    self.global_quant_config = LayerQuantConfig(
        quant_type=quant_type,
        quant_dtype=quant_dtype,
        is_dynamic=is_dynamic,
        quant_method=quant_method,
    )
    self.exclude_layers = exclude_layers

def _is_excluded(self, layer_name: str) -> bool:
    """True if *layer_name* matches any entry in the exclude list
    (exact string or ``re:``-prefixed pattern — see _matches_exclude).
    """
    # TODO: solve fused_mapping case
    # Fix: the block carried two interleaved signatures (merge residue);
    # keep the private name and alias the old public one below.
    if layer_name is None or not self.exclude_layers:
        return False
    return any(
        self._matches_exclude(layer_name, ignore_str)
        for ignore_str in self.exclude_layers
    )

# Backward-compatible alias for the pre-refactor name.
should_ignore_layer_quant = _is_excluded

@staticmethod
def _matches_exclude(
    layer_name: str, ignore_str: str, check_contains: bool = False
) -> bool:
    """Match the target string or regular expression.

    Fix: the block carried two interleaved signatures (merge residue);
    keep the static form and alias the old public name below.
    """
    if ignore_str.startswith("re:"):
        # case "re:model.layers.*self_attn.*", remove the 're:' prefix
        pattern = ignore_str[3:]
        if re.search(pattern, layer_name):
            return True
    elif check_contains:
        # NOTE(review): this tests layer_name as a substring of ignore_str
        # (e.g. a fused prefix "model.layers.0.self_attn" against a full
        # excluded name like "...self_attn.q_a_proj") — confirm the
        # intended direction before relying on check_contains=True.
        return layer_name.lower() in ignore_str.lower()
    elif ignore_str == layer_name:
        return True
    return False

# Backward-compatible alias; staticmethod ignores the bound instance, so
# self.is_equal_or_regex_match(a, b) keeps working.
is_equal_or_regex_match = _matches_exclude

def get_layer_quant_config(self, layer_name: str) -> LayerQuantConfig:
    """Resolve the quant config for *layer_name*: exclude list first,
    then per-layer pattern overrides (last match wins), then the
    global default.
    """
    if self.should_ignore_layer_quant(layer_name=layer_name):
        # return unquantized config
        return LayerQuantConfig(quant_dtype=self.torch_dtype)
    # layer quant config
    layer_quant_config = None
    if self.layer_quant_config:

        def _matches_pattern(layer_name, pattern):
            # Non-glob patterns are treated as substrings of the layer
            # name. BUGFIX: was `layer_name in pattern` (operands
            # reversed), so short layer names matched spuriously and
            # real overrides were missed.
            if "*" not in pattern:
                return pattern in layer_name
            return fnmatch.fnmatch(layer_name, pattern)

        for name_pattern, config in self.layer_quant_config.items():
            if _matches_pattern(layer_name, name_pattern):
                # Deliberately no break: the last matching pattern wins.
                layer_quant_config = config

    layer_quant_config = (
        self.global_quant_config
        if layer_quant_config is None
        else layer_quant_config
    )
    # TODO: if use_aiter, we can customize the quantization format here, such as dpsk
    # For FP4 and use_triton_gemm(), fused_qkv_a_proj and q_b_proj are AITER-Triton FP4 GEMMs but o_proj remains AITER BF16 GEMMs,
    # For FP8 and use_triton_gemm(), fused_qkv_a_proj is AITER-Triton FP8 GEMMs while others remain AITER FP8 GEMMs

    return layer_quant_config

def remap_layer_name(
self, hf_config: PretrainedConfig, packed_modules_mapping: dict | None = None
):
Expand Down Expand Up @@ -556,11 +434,11 @@ def _remap_layer_name(name: str) -> list[str]:
return [name.replace(packed_key, packed_remap_part, 1)]
return [name]

new_layer_quant_config = {}
for layer_name, layer_qconfig in self.layer_quant_config.items():
for remapped in _remap_layer_name(layer_name):
new_layer_quant_config[remapped] = layer_qconfig
self.layer_quant_config = new_layer_quant_config
new_pattern_specs = []
for pattern, spec in self._parsed.layer_pattern_specs:
for remapped in _remap_layer_name(pattern):
new_pattern_specs.append((remapped, spec))
self._parsed.layer_pattern_specs = new_pattern_specs

new_exclude = []
for name in self.exclude_layers:
Expand Down
Loading
Loading