Skip to content
Open
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
#############################################################################
# Distributed W8A8 quantization example with activation observer sync.
# run this with `torchrun --nproc_per_node=2 llama3_8b_w8a8_distributed.py`
# or change nproc_per_node to your desired configuration
#############################################################################

import torch
from compressed_tensors.offload import dispatch_model, init_dist, load_offloaded_model
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.datasets.utils import get_rank_partition
from llmcompressor.modifiers.quantization import QuantizationModifier

MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"

DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"

NUM_CALIBRATION_SAMPLES = 256
MAX_SEQUENCE_LENGTH = 2048

###### DDP MODEL LOAD CHANGE #####
# Initialize torch.distributed first, then load the model with
# offload-aware placement so each rank only materializes what it needs.
init_dist()
with load_offloaded_model():
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID, dtype="auto", device_map="auto_offload"
    )
##################################

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

###### DDP DATA LOAD CHANGE #####
# Each rank loads a disjoint partition of the calibration split so the
# ranks collectively cover NUM_CALIBRATION_SAMPLES samples.
ds = load_dataset(
    DATASET_ID, split=get_rank_partition(DATASET_SPLIT, NUM_CALIBRATION_SAMPLES)
)
##################################

ds = ds.shuffle(seed=42)


def preprocess(example):
    """Render the chat messages of one sample into a single text string.

    Uses the module-level tokenizer's chat template without tokenizing,
    leaving tokenization to a later map step.
    """
    rendered = tokenizer.apply_chat_template(example["messages"], tokenize=False)
    return {"text": rendered}


ds = ds.map(preprocess)


def tokenize(sample):
    """Tokenize the rendered chat text for calibration.

    Special tokens are not added here because the chat template already
    inserted them during preprocessing.
    """
    return tokenizer(
        sample["text"],
        padding=False,
        truncation=True,
        max_length=MAX_SEQUENCE_LENGTH,
        add_special_tokens=False,
    )


ds = ds.map(tokenize, remove_columns=ds.column_names)

# QuantizationModifier automatically detects torch.distributed and
# all-reduces activation observer statistics at layer boundaries.
recipe = [
    QuantizationModifier(targets="Linear", scheme="W8A8", ignore=["lm_head"]),
]

oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)

# Sanity-check a generation from the quantized model.
print("\n\n")
print("========== SAMPLE GENERATION ==============")
dispatch_model(model)
sample = tokenizer("Hello my name is", return_tensors="pt")
sample = {key: value.to(model.device) for key, value in sample.items()}
output = model.generate(**sample, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")

print("Saving...")
# e.g. "Meta-Llama-3-8B-Instruct-W8A8-DDP2" for a 2-rank run
world_size = torch.distributed.get_world_size()
SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-W8A8-DDP" + str(world_size)
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)

torch.distributed.destroy_process_group()
11 changes: 10 additions & 1 deletion src/llmcompressor/modifiers/quantization/quantization/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ class QuantizationModifier(Modifier, QuantizationMixin):
the specified module(s) forward pass will emulate quantized execution and the
modifier will be enabled until training is completed.

In DDP mode, activation observer statistics are all-reduced across ranks at
sequential layer boundaries so all ranks share identical quantization parameters.

:param config_groups: dictionary specifying quantization schemes to apply to target
modules. Modules not matching a scheme target will NOT be quantized.
:param targets: list of layer names to quantize if a scheme is provided. Defaults
Expand Down Expand Up @@ -65,14 +68,16 @@ def on_initialize(self, state: State, **kwargs) -> bool:

def on_start(self, state: State, event: Event, **kwargs):
"""
Begin calibrating activations and weights. Calibrate weights only once on start
Begin calibrating activations and weights. Calibrate weights only once
on start. Each rank calibrates weights independently.
"""
self.started_ = True
QuantizationMixin.start_calibration(self, state.model)

named_modules = list(
match_named_modules(state.model, self.resolved_targets, self.ignore)
)

# TODO: this step can be combined with update_weight_zp_scale
# once update_fused_layer_weight_global_scales is removed
# and not required by vLLM
Expand All @@ -95,7 +100,11 @@ def on_event(self, state: State, event: Event, **kwargs):
if not self.started_:
self.on_start(state, None)

if event.type_ == EventType.SEQUENTIAL_EPOCH_END:
QuantizationMixin.sync_activation_observers(self, state.model)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we'll need to add these to GPTQ and AWQ, right?


if event.type_ == EventType.CALIBRATION_EPOCH_END:
QuantizationMixin.sync_activation_observers(self, state.model)
if not self.ended_:
self.on_end(state, None)

Expand Down
51 changes: 49 additions & 2 deletions src/llmcompressor/modifiers/quantization/quantization/mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
IMPL_ATTR,
KV_CACHE_ATTR,
)
from compressed_tensors.offload.dist_utils import is_distributed
from compressed_tensors.quantization import (
DynamicType,
QuantizationArgs,
Expand All @@ -18,7 +19,7 @@
is_preset_scheme,
preset_name_to_scheme,
)
from compressed_tensors.utils import match_named_modules
from compressed_tensors.utils import match_named_modules, update_offload_parameter
from pydantic import Field, PrivateAttr, field_validator
from torch.utils.hooks import RemovableHandle

Expand All @@ -37,7 +38,11 @@
validate_group_size_divisibility,
)
from llmcompressor.modifiers.utils.hooks import HooksMixin
from llmcompressor.utils import targets_embeddings, untie_word_embeddings
from llmcompressor.utils import (
targets_embeddings,
untie_word_embeddings,
wait_for_comms,
)

__all__ = ["QuantizationMixin"]

Expand Down Expand Up @@ -257,6 +262,48 @@ def end_calibration(self, model: torch.nn.Module):

model.apply(enable_quantization) # keep quantization enabled

def sync_activation_observers(self, model: torch.nn.Module):
    """
    All-reduce activation observer min/max values across DDP ranks, then
    recompute scale/zp from the globally reduced statistics. No-op when
    not running under torch.distributed.

    :param model: model containing quantized modules
    """
    if not is_distributed():
        return

    # Launch every async all-reduce before waiting, so communication for
    # different modules/observers can overlap.
    pending = []
    to_update = []
    for _, module in match_named_modules(model, self.resolved_targets, self.ignore):
        for base_name in ("input", "output", "q", "k", "v"):
            observer = getattr(module, f"{base_name}_observer", None)
            if observer is None:
                continue
            pending.extend(observer.synchronize())
            to_update.append((module, base_name, observer))

    wait_for_comms(pending)

    # Recompute quantization parameters from the synchronized statistics.
    for module, base_name, observer in to_update:
        # global scale only exists for TENSOR_GROUP-style strategies
        global_scale = observer.recompute_global_scale()
        if global_scale is not None:
            update_offload_parameter(module, f"{base_name}_global_scale", global_scale)

        qparams = observer.recompute_qparams()
        if qparams is None:
            # memoryless observers carry no accumulated state
            continue
        scale, zero_point = qparams
        update_offload_parameter(module, f"{base_name}_scale", scale)
        if hasattr(module, f"{base_name}_zero_point"):
            update_offload_parameter(module, f"{base_name}_zero_point", zero_point)

def has_config(self) -> bool:
"""
Determine if the user has specified a quantization config on this modifier
Expand Down
60 changes: 59 additions & 1 deletion src/llmcompressor/observers/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import abstractmethod
from typing import Optional, Tuple
from typing import List, Optional, Tuple
from weakref import ref

import torch
Expand All @@ -8,6 +8,7 @@
from compressed_tensors.quantization.utils import calculate_qparams, generate_gparam
from compressed_tensors.registry.registry import RegistryMixin
from compressed_tensors.utils import align_module_device
from torch import distributed as dist

from llmcompressor.observers.helpers import flatten_for_calibration

Expand Down Expand Up @@ -133,6 +134,63 @@ def _get_module_param(self, name: str) -> Optional[torch.nn.Parameter]:
with align_module_device(module):
return getattr(module, f"{self.base_name}_{name}", None)

def synchronize(self) -> List[dist.Work]:
    """All-reduce accumulated min/max statistics across DDP ranks.

    Issues async all-reduce operations on any accumulated state
    (``past_min_vals``, ``past_max_vals``, ``past_global_min_vals``,
    ``past_global_max_vals``). Memoryless observers, which hold none of
    these attributes, return an empty list.

    NOTE(review): element-wise MIN/MAX reduction is only exactly right for
    static min-max statistics; moving-average observers would need a
    cross-rank average instead — TODO confirm per-subclass behavior.
    NOTE(review): fp8-dtype tensors may need a cast before all_reduce
    (see the workaround used in GPTQ) — TODO confirm.

    :return: list of async communication handles
    """
    reductions = (
        ("past_min_vals", dist.ReduceOp.MIN),
        ("past_max_vals", dist.ReduceOp.MAX),
        ("past_global_min_vals", dist.ReduceOp.MIN),
        ("past_global_max_vals", dist.ReduceOp.MAX),
    )
    handles = []
    for attr, reduce_op in reductions:
        tensor = getattr(self, attr, None)
        if tensor is None:
            continue
        handles.append(dist.all_reduce(tensor, op=reduce_op, async_op=True))
    return handles

def recompute_global_scale(self) -> Optional[torch.Tensor]:
    """Recompute global scale from accumulated global min/max state.

    Used after :meth:`synchronize` to rebuild the global scale from
    globally reduced statistics. Returns ``None`` for memoryless
    observers, which carry no accumulated global state.

    :return: global scale tensor or ``None``
    """
    gmin = getattr(self, "past_global_min_vals", None)
    gmax = getattr(self, "past_global_max_vals", None)
    if gmin is None or gmax is None:
        return None
    return generate_gparam(gmin, gmax)

def recompute_qparams(self) -> Optional[ScaleZpTuple]:
    """Recompute scale and zero_point from accumulated min/max state.

    Used after :meth:`synchronize` to rebuild quantization parameters
    from globally reduced statistics. Returns ``None`` for memoryless
    observers, which carry no accumulated state.

    :return: (scale, zero_point) tuple or ``None``
    """
    min_vals = getattr(self, "past_min_vals", None)
    max_vals = getattr(self, "past_max_vals", None)
    if min_vals is None or max_vals is None:
        return None

    # TENSOR_GROUP strategies require a global scale on the module
    global_scale = self._get_module_param("global_scale")
    self._check_has_global_scale(global_scale)
    return calculate_qparams(
        min_vals=min_vals,
        max_vals=max_vals,
        quantization_args=self.args,
        global_scale=global_scale,
    )

def _check_has_global_scale(self, global_scale: Optional[torch.nn.Parameter]):
if (
self.args.strategy == QuantizationStrategy.TENSOR_GROUP
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
"""
Run with: torchrun --nproc_per_node=2 -m pytest <this_file> -v
"""

import os

import pytest
import torch
import torch.distributed as dist
from compressed_tensors.quantization import QuantizationArgs

from llmcompressor.observers.min_max import StaticMinMaxObserver
from llmcompressor.utils.dist import wait_for_comms
from tests.testing_utils import requires_gpu

# initialize the process group when launched under torchrun
_under_torchrun = os.environ.get("RANK") is not None
if _under_torchrun and torch.cuda.is_available() and not dist.is_initialized():
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0)))


def _skip_if_not_distributed():
    """Skip the calling test unless a torch.distributed group is initialized."""
    ready = dist.is_available() and dist.is_initialized()
    if not ready:
        pytest.skip("Requires torchrun --nproc_per_node=2")


@pytest.mark.multi_gpu
@requires_gpu(2)
def test_observer_synchronize_reduces_min_max():
    """Element-wise MIN/MAX all-reduce must combine the two ranks' stats."""
    _skip_if_not_distributed()
    idx = 0 if dist.get_rank() == 0 else 1

    args = QuantizationArgs(num_bits=8, type="int", symmetric=True, strategy="tensor")
    observer = StaticMinMaxObserver(base_name="input", args=args)

    # each rank starts from different local statistics
    local_mins = ([1.0, 3.0], [2.0, 1.0])
    local_maxs = ([10.0, 20.0], [15.0, 10.0])
    observer.past_min_vals = torch.tensor(local_mins[idx], device="cuda")
    observer.past_max_vals = torch.tensor(local_maxs[idx], device="cuda")

    wait_for_comms(observer.synchronize())

    # after sync: element-wise minimum of mins, element-wise maximum of maxs
    expected_min = torch.tensor([1.0, 1.0], device="cuda")
    expected_max = torch.tensor([15.0, 20.0], device="cuda")
    assert torch.equal(observer.past_min_vals, expected_min)
    assert torch.equal(observer.past_max_vals, expected_max)


@pytest.mark.multi_gpu
@requires_gpu(2)
def test_synced_qparams_are_identical_across_ranks():
    """After sync + recompute, every rank must hold the same scale."""
    _skip_if_not_distributed()
    idx = 0 if dist.get_rank() == 0 else 1

    args = QuantizationArgs(num_bits=8, type="int", symmetric=True, strategy="tensor")
    observer = StaticMinMaxObserver(base_name="input", args=args)

    # rank-specific statistics; sync should make qparams agree everywhere
    observer.past_min_vals = torch.tensor(([-2.0], [-5.0])[idx], device="cuda")
    observer.past_max_vals = torch.tensor(([3.0], [1.0])[idx], device="cuda")

    wait_for_comms(observer.synchronize())

    qparams = observer.recompute_qparams()
    assert qparams is not None
    scale = qparams[0]

    gathered = [torch.zeros_like(scale) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, scale)
    assert torch.equal(gathered[0], gathered[1])
Loading
Loading