Skip to content
Open
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import os

import torch
import torch.distributed as dist
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier

# Select model and load it.
MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"

# Select calibration dataset.
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"

# Select number of samples.
# Increasing the number of samples can improve accuracy.
NUM_CALIBRATION_SAMPLES = 256
MAX_SEQUENCE_LENGTH = 2048

# Initialize distributed.
# Usage: torchrun --nproc_per_node=2 llama3_8b_w8a8_distributed.py
dist.init_process_group(backend="nccl")
rank = dist.get_rank()
world_size = dist.get_world_size()

# Use the node-local rank for GPU selection: in multi-node runs the global
# rank exceeds the per-node GPU count. torchrun exports LOCAL_RANK; fall back
# to a modulo of the global rank when launched without it.
local_rank = int(os.environ.get("LOCAL_RANK", rank % torch.cuda.device_count()))
torch.cuda.set_device(local_rank)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Setting the CUDA device using the global rank will fail in multi-node setups where the rank exceeds the number of GPUs per node. It is standard practice to use the local rank for device assignment.

Suggested change
torch.cuda.set_device(rank)
torch.cuda.set_device(rank % torch.cuda.device_count())


# Only log run-level info from rank 0 to avoid duplicated output.
if rank == 0:
    print(f"Running distributed quantization with {world_size} GPUs")

# Load model to CPU for sequential onloading.
# device_map=None keeps all weights on CPU; layers are onloaded to the GPU
# as needed during calibration.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    dtype="auto",
    device_map=None,
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Load and partition dataset across ranks.
# Each rank loads a disjoint slice of the calibration data.
# NOTE(review): integer division drops the remainder when
# NUM_CALIBRATION_SAMPLES is not divisible by world_size — confirm intended.
samples_per_rank = NUM_CALIBRATION_SAMPLES // world_size
start = samples_per_rank * rank
end = start + samples_per_rank

ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[{start}:{end}]")
ds = ds.shuffle(seed=42)

def preprocess(example):
    """Render a chat transcript into a single prompt string.

    Applies the tokenizer's chat template without tokenizing; tokenization
    happens in a later map step.
    """
    rendered = tokenizer.apply_chat_template(example["messages"], tokenize=False)
    return {"text": rendered}


# Apply the chat template to every calibration sample.
ds = ds.map(preprocess)


# Tokenize inputs.
def tokenize(sample):
    """Tokenize a templated prompt, truncated to MAX_SEQUENCE_LENGTH.

    Special tokens are not re-added because the chat template has already
    inserted them into the text.
    """
    kwargs = dict(
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
        add_special_tokens=False,
    )
    return tokenizer(sample["text"], **kwargs)


ds = ds.map(tokenize, remove_columns=ds.column_names)

# Configure the quantization algorithm to run.
# QuantizationModifier automatically detects torch.distributed and:
#   * partitions weight calibration across ranks
#   * all-reduces activation observer statistics at layer boundaries
recipe = [
    QuantizationModifier(targets="Linear", scheme="W8A8", ignore=["lm_head"]),
]

# Apply algorithms. Each rank calibrates on its own slice of the data.
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=samples_per_rank,
)

# Save to disk compressed (rank 0 only).
SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-W8A8-distributed"
if rank == 0:
    model.save_pretrained(SAVE_DIR, save_compressed=True)
    tokenizer.save_pretrained(SAVE_DIR)
    print(f"Model saved to {SAVE_DIR}")

# Make all ranks wait for rank 0 to finish writing before tearing down the
# process group; otherwise non-zero ranks exit mid-save.
dist.barrier()
dist.destroy_process_group()
36 changes: 36 additions & 0 deletions src/llmcompressor/modifiers/quantization/calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
"calibrate_query_hook",
"calibrate_key_hook",
"calibrate_value_hook",
"recompute_qparams_from_observer",
]


Expand Down Expand Up @@ -235,6 +236,41 @@ def calibrate_value_hook(module: Module, value_states: torch.Tensor):
calibrate_activations(module, value_states, base_name="v")


def recompute_qparams_from_observer(module: Module, base_name: str):
    """
    Recompute scale and zero_point from an observer's accumulated
    past_min_vals/past_max_vals. Used after DDP all-reduce to update
    qparams from synchronized statistics.

    :param module: module with quantization parameters
    :param base_name: "input", "output", "q", "k", or "v"
    """
    # NOTE(review): imported locally, presumably to avoid an import cycle
    # with compressed_tensors — confirm before hoisting to module level
    from compressed_tensors.quantization.utils import calculate_qparams

    observer = getattr(module, f"{base_name}_observer", None)
    if observer is None:
        return

    mins = getattr(observer, "past_min_vals", None)
    maxs = getattr(observer, "past_max_vals", None)
    if mins is None or maxs is None:
        # no statistics accumulated yet for this activation
        return

    scale, zero_point = calculate_qparams(
        min_vals=mins,
        max_vals=maxs,
        quantization_args=observer.args,
        global_scale=getattr(module, f"{base_name}_global_scale", None),
    )

    update_offload_parameter(module, f"{base_name}_scale", scale)
    zp_name = f"{base_name}_zero_point"
    if hasattr(module, zp_name):
        update_offload_parameter(module, zp_name, zero_point)

def apply_calibration_status(module: Module):
scheme = getattr(module, "quantization_scheme", None)
if not scheme:
Expand Down
112 changes: 110 additions & 2 deletions src/llmcompressor/modifiers/quantization/quantization/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,21 @@
from llmcompressor.core import Event, EventType, State
from llmcompressor.modifiers import Modifier
from llmcompressor.modifiers.quantization.calibration import (
recompute_qparams_from_observer,
update_weight_global_scale,
update_weight_zp_scale,
)
from llmcompressor.modifiers.quantization.quantization.mixin import QuantizationMixin
from llmcompressor.modifiers.utils import update_fused_layer_weight_global_scales
from llmcompressor.utils.distributed import (
all_reduce_max,
all_reduce_min,
broadcast_module_parameter,
build_module_to_rank_map,
get_rank,
is_distributed,
partition_modules_by_weight_size,
)

__all__ = ["QuantizationModifier"]

Expand All @@ -20,6 +30,9 @@ class QuantizationModifier(Modifier, QuantizationMixin):
the specified module(s) forward pass will emulate quantized execution and the
modifier will be enabled until training is completed.

In DDP mode, weight calibration is partitioned across ranks and activation
observer statistics are all-reduced at sequential layer boundaries.

:param config_groups: dictionary specifying quantization schemes to apply to target
modules. Modules not matching a scheme target will NOT be quantized.
:param targets: list of layer names to quantize if a scheme is provided. Defaults
Expand Down Expand Up @@ -65,14 +78,23 @@ def on_initialize(self, state: State, **kwargs) -> bool:

def on_start(self, state: State, event: Event, **kwargs):
    """
    Begin calibrating activations and weights. Weight calibration runs
    exactly once here; in DDP mode it is partitioned across ranks.
    """
    self.started_ = True
    QuantizationMixin.start_calibration(self, state.model)

    targeted = list(
        match_named_modules(state.model, self.resolved_targets, self.ignore)
    )

    # pick the weight-calibration strategy based on the runtime environment
    calibrate = (
        self._calibrate_weights_distributed
        if is_distributed()
        else self._calibrate_weights_single
    )
    calibrate(state.model, targeted)

def _calibrate_weights_single(self, model, named_modules):
"""Original single-process weight calibration."""
# TODO: this step can be combined with update_weight_zp_scale
# once update_fused_layer_weight_global_scales is removed
# and not required by vLLM
Expand All @@ -84,21 +106,107 @@ def on_start(self, state: State, event: Event, **kwargs):
# on targeted modules, we need to run on all modules.
# Because this call is idempotent, setting all global_scales to the
# min value, it is ok to run potentially multiple times for all modules
for module in state.model.modules():
for module in model.modules():
update_fused_layer_weight_global_scales(module)

for _, module in tqdm.tqdm(named_modules, desc="Calibrating weights"):
update_weight_zp_scale(module)

def _calibrate_weights_distributed(self, model, named_modules):
    """
    DDP-partitioned weight calibration. Each rank calibrates a subset of
    modules and broadcasts results to all ranks.

    :param model: model being calibrated
    :param named_modules: list of (name, module) pairs targeted for
        quantization
    """
    module_to_rank = build_module_to_rank_map(named_modules)
    my_modules = partition_modules_by_weight_size(named_modules)
    rank = get_rank()

    # compute global_scale for assigned modules only
    for _, module in tqdm.tqdm(
        my_modules, desc=f"[Rank {rank}] Updating global scales"
    ):
        update_weight_global_scale(module)

    # broadcast global_scales so all ranks can run the fuse step.
    # Guard with hasattr (mirroring the weight_zero_point guard below):
    # presumably only schemes that use global scales initialize this
    # parameter, and broadcasting a missing parameter would fail
    for _, module in named_modules:
        src_rank = module_to_rank[module]
        if hasattr(module, "weight_global_scale"):
            broadcast_module_parameter(module, "weight_global_scale", src_rank)

    # fuse global_scales (all ranks, idempotent)
    for module in model.modules():
        update_fused_layer_weight_global_scales(module)

    # compute scale/zp for assigned modules only
    for _, module in tqdm.tqdm(
        my_modules, desc=f"[Rank {rank}] Calibrating weights"
    ):
        update_weight_zp_scale(module)

    # broadcast scale/zp to all ranks
    for _, module in named_modules:
        src_rank = module_to_rank[module]
        broadcast_module_parameter(module, "weight_scale", src_rank)
        if hasattr(module, "weight_zero_point"):
            broadcast_module_parameter(module, "weight_zero_point", src_rank)

def on_event(self, state: State, event: Event, **kwargs):
    """Dispatch calibration lifecycle events to start/sync/end handlers."""
    event_type = event.type_

    if event_type == EventType.CALIBRATION_EPOCH_START and not self.started_:
        self.on_start(state, None)

    # sync observer statistics at both sequential-layer and epoch boundaries
    if event_type in (
        EventType.SEQUENTIAL_EPOCH_END,
        EventType.CALIBRATION_EPOCH_END,
    ):
        self._sync_activation_observers(state.model)

    if event_type == EventType.CALIBRATION_EPOCH_END and not self.ended_:
        self.on_end(state, None)

def _sync_activation_observers(self, model):
    """
    All-reduce activation observer min/max values across DDP ranks,
    then recompute scale/zp from the global statistics.
    No-op if not distributed.

    :param model: model whose targeted modules hold activation observers
    """
    if not is_distributed():
        return

    # (attribute, reduction) pairs covering both the per-group statistics
    # and the global statistics used by TENSOR_GROUP-style schemes
    reductions = (
        ("past_min_vals", all_reduce_min),
        ("past_max_vals", all_reduce_max),
        ("past_global_min_vals", all_reduce_min),
        ("past_global_max_vals", all_reduce_max),
    )

    # TODO(perf): this issues one collective per tensor; batching all
    # tensors into a single MIN buffer and a single MAX buffer would cut
    # communication latency substantially on large models
    for _, module in match_named_modules(model, self.resolved_targets, self.ignore):
        for base_name in ("input", "output", "q", "k", "v"):
            observer = getattr(module, f"{base_name}_observer", None)
            if observer is None:
                continue

            # all-reduce each accumulated statistic that exists and is set
            for attr, reduce_fn in reductions:
                vals = getattr(observer, attr, None)
                if vals is not None:
                    setattr(observer, attr, reduce_fn(vals))

            recompute_qparams_from_observer(module, base_name)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The current implementation of _sync_activation_observers performs multiple all_reduce operations per module inside a nested loop. For a typical transformer model, this can result in hundreds or even thousands of small collective communication calls. In distributed settings, the latency overhead of many small calls is much higher than a single large call.

Consider aggregating all tensors that need reduction into a single list, concatenating them into one or two large buffers (e.g., one for MIN and one for MAX), performing a single all_reduce on each buffer, and then unpacking the results back into the observers. This will significantly improve performance on high-latency networks.


def on_end(self, state: State, event: Event, **kwargs):
"""
Finish calibrating by removing observers and calibration hooks
Expand Down
1 change: 1 addition & 0 deletions src/llmcompressor/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@
from .dev import *
from .helpers import *
from .dist import *
from .distributed import *
Loading
Loading