-
Notifications
You must be signed in to change notification settings - Fork 438
[Distributed] Extend QuantizationModifier to support distributed activation calibration #2391
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 7 commits
f60200a
c4d630d
89d1ade
ac0cc2a
76cf40f
0f3e1f9
9975edc
3320812
87f4b0d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,98 @@ | ||
| import torch | ||
| import torch.distributed as dist | ||
| from datasets import load_dataset | ||
| from transformers import AutoModelForCausalLM, AutoTokenizer | ||
|
|
||
| from llmcompressor import oneshot | ||
| from llmcompressor.modifiers.quantization import QuantizationModifier | ||
|
|
||
| # Select model and load it. | ||
| MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct" | ||
|
|
||
| # Select calibration dataset. | ||
| DATASET_ID = "HuggingFaceH4/ultrachat_200k" | ||
| DATASET_SPLIT = "train_sft" | ||
|
|
||
| # Select number of samples. | ||
| # Increasing the number of samples can improve accuracy. | ||
| NUM_CALIBRATION_SAMPLES = 256 | ||
| MAX_SEQUENCE_LENGTH = 2048 | ||
|
|
||
| # Initialize distributed. | ||
| # Usage: torchrun --nproc_per_node=2 llama3_8b_w8a8_distributed.py | ||
| dist.init_process_group(backend="nccl") | ||
| rank = dist.get_rank() | ||
| world_size = dist.get_world_size() | ||
| torch.cuda.set_device(rank) | ||
kylesayrs marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| if rank == 0: | ||
| print(f"Running distributed quantization with {world_size} GPUs") | ||
|
|
||
| # Load model to CPU for sequential onloading. | ||
| model = AutoModelForCausalLM.from_pretrained( | ||
| MODEL_ID, | ||
| dtype="auto", | ||
| device_map=None, | ||
| ) | ||
| tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) | ||
|
|
||
| # Load and partition dataset across ranks. | ||
| # Each rank loads a disjoint slice of the calibration data. | ||
| samples_per_rank = NUM_CALIBRATION_SAMPLES // world_size | ||
| start = samples_per_rank * rank | ||
| end = start + samples_per_rank | ||
|
|
||
| ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[{start}:{end}]") | ||
| ds = ds.shuffle(seed=42) | ||
|
|
||
|
|
||
| def preprocess(example): | ||
| return { | ||
| "text": tokenizer.apply_chat_template( | ||
| example["messages"], | ||
| tokenize=False, | ||
| ) | ||
| } | ||
|
|
||
|
|
||
| ds = ds.map(preprocess) | ||
|
|
||
|
|
||
| # Tokenize inputs. | ||
| def tokenize(sample): | ||
| return tokenizer( | ||
| sample["text"], | ||
| padding=False, | ||
| max_length=MAX_SEQUENCE_LENGTH, | ||
| truncation=True, | ||
| add_special_tokens=False, | ||
| ) | ||
|
|
||
|
|
||
| ds = ds.map(tokenize, remove_columns=ds.column_names) | ||
|
|
||
| # Configure the quantization algorithm to run. | ||
| # QuantizationModifier automatically detects torch.distributed and: | ||
| # * partitions weight calibration across ranks | ||
| # * all-reduces activation observer statistics at layer boundaries | ||
| recipe = [ | ||
| QuantizationModifier(targets="Linear", scheme="W8A8", ignore=["lm_head"]), | ||
| ] | ||
|
|
||
| # Apply algorithms. | ||
| oneshot( | ||
| model=model, | ||
| dataset=ds, | ||
| recipe=recipe, | ||
| max_seq_length=MAX_SEQUENCE_LENGTH, | ||
| num_calibration_samples=samples_per_rank, | ||
| ) | ||
|
|
||
| # Save to disk compressed (rank 0 only). | ||
| SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-W8A8-distributed" | ||
| if rank == 0: | ||
| model.save_pretrained(SAVE_DIR, save_compressed=True) | ||
| tokenizer.save_pretrained(SAVE_DIR) | ||
| print(f"Model saved to {SAVE_DIR}") | ||
|
|
||
| dist.destroy_process_group() | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,11 +4,21 @@ | |
| from llmcompressor.core import Event, EventType, State | ||
| from llmcompressor.modifiers import Modifier | ||
| from llmcompressor.modifiers.quantization.calibration import ( | ||
| recompute_qparams_from_observer, | ||
| update_weight_global_scale, | ||
| update_weight_zp_scale, | ||
| ) | ||
| from llmcompressor.modifiers.quantization.quantization.mixin import QuantizationMixin | ||
| from llmcompressor.modifiers.utils import update_fused_layer_weight_global_scales | ||
| from llmcompressor.utils.distributed import ( | ||
| all_reduce_max, | ||
| all_reduce_min, | ||
| broadcast_module_parameter, | ||
| build_module_to_rank_map, | ||
| get_rank, | ||
| is_distributed, | ||
| partition_modules_by_weight_size, | ||
| ) | ||
|
|
||
| __all__ = ["QuantizationModifier"] | ||
|
|
||
|
|
@@ -20,6 +30,9 @@ class QuantizationModifier(Modifier, QuantizationMixin): | |
| the specified module(s) forward pass will emulate quantized execution and the | ||
| modifier will be enabled until training is completed. | ||
|
|
||
| In DDP mode, weight calibration is partitioned across ranks and activation | ||
| observer statistics are all-reduced at sequential layer boundaries. | ||
|
|
||
| :param config_groups: dictionary specifying quantization schemes to apply to target | ||
| modules. Modules not matching a scheme target will NOT be quantized. | ||
| :param targets: list of layer names to quantize if a scheme is provided. Defaults | ||
|
|
@@ -65,14 +78,23 @@ def on_initialize(self, state: State, **kwargs) -> bool: | |
|
|
||
| def on_start(self, state: State, event: Event, **kwargs): | ||
| """ | ||
| Begin calibrating activations and weights. Calibrate weights only once on start | ||
| Begin calibrating activations and weights. Calibrate weights only once | ||
| on start. In DDP mode, weight calibration is partitioned across ranks. | ||
| """ | ||
| self.started_ = True | ||
| QuantizationMixin.start_calibration(self, state.model) | ||
|
|
||
| named_modules = list( | ||
| match_named_modules(state.model, self.resolved_targets, self.ignore) | ||
| ) | ||
|
|
||
| if is_distributed(): | ||
| self._calibrate_weights_distributed(state.model, named_modules) | ||
| else: | ||
| self._calibrate_weights_single(state.model, named_modules) | ||
|
|
||
| def _calibrate_weights_single(self, model, named_modules): | ||
| """Original single-process weight calibration.""" | ||
| # TODO: this step can be combined with update_weight_zp_scale | ||
| # once update_fused_layer_weight_global_scales is removed | ||
| # and not required by vLLM | ||
|
|
@@ -84,21 +106,107 @@ def on_start(self, state: State, event: Event, **kwargs): | |
| # on targeted modules, we need to run on all modules. | ||
| # Because this call is idempotent, setting all global_scales to the | ||
| # min value, it is ok to run potentially multiple times for all modules | ||
| for module in state.model.modules(): | ||
| for module in model.modules(): | ||
| update_fused_layer_weight_global_scales(module) | ||
|
|
||
| for _, module in tqdm.tqdm(named_modules, desc="Calibrating weights"): | ||
| update_weight_zp_scale(module) | ||
|
|
||
| def _calibrate_weights_distributed(self, model, named_modules): | ||
| """ | ||
| DDP-partitioned weight calibration. Each rank calibrates a subset of | ||
| modules and broadcasts results to all ranks. | ||
| """ | ||
| module_to_rank = build_module_to_rank_map(named_modules) | ||
| my_modules = partition_modules_by_weight_size(named_modules) | ||
| rank = get_rank() | ||
|
|
||
| # compute global_scale for assigned modules only | ||
| for _, module in tqdm.tqdm( | ||
| my_modules, desc=f"[Rank {rank}] Updating global scales" | ||
| ): | ||
| update_weight_global_scale(module) | ||
|
|
||
| # broadcast global_scales so all ranks can run the fuse step | ||
| for _, module in named_modules: | ||
| src_rank = module_to_rank[module] | ||
| broadcast_module_parameter(module, "weight_global_scale", src_rank) | ||
|
|
||
| # fuse global_scales (all ranks, idempotent) | ||
| for module in model.modules(): | ||
| update_fused_layer_weight_global_scales(module) | ||
kylesayrs marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| # compute scale/zp for assigned modules only | ||
| for _, module in tqdm.tqdm( | ||
| my_modules, desc=f"[Rank {rank}] Calibrating weights" | ||
| ): | ||
| update_weight_zp_scale(module) | ||
|
|
||
| # broadcast scale/zp to all ranks | ||
| for _, module in named_modules: | ||
| src_rank = module_to_rank[module] | ||
| broadcast_module_parameter(module, "weight_scale", src_rank) | ||
| if hasattr(module, "weight_zero_point"): | ||
| broadcast_module_parameter(module, "weight_zero_point", src_rank) | ||
|
|
||
| def on_event(self, state: State, event: Event, **kwargs): | ||
| if event.type_ == EventType.CALIBRATION_EPOCH_START: | ||
| if not self.started_: | ||
| self.on_start(state, None) | ||
|
|
||
| if event.type_ == EventType.SEQUENTIAL_EPOCH_END: | ||
| self._sync_activation_observers(state.model) | ||
|
|
||
| if event.type_ == EventType.CALIBRATION_EPOCH_END: | ||
| self._sync_activation_observers(state.model) | ||
| if not self.ended_: | ||
| self.on_end(state, None) | ||
|
|
||
| def _sync_activation_observers(self, model): | ||
kylesayrs marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| """ | ||
| All-reduce activation observer min/max values across DDP ranks, | ||
| then recompute scale/zp from the global statistics. | ||
| No-op if not distributed. | ||
| """ | ||
| if not is_distributed(): | ||
| return | ||
|
|
||
| for _, module in match_named_modules(model, self.resolved_targets, self.ignore): | ||
| for base_name in ("input", "output", "q", "k", "v"): | ||
| observer = getattr(module, f"{base_name}_observer", None) | ||
| if observer is None: | ||
| continue | ||
|
|
||
| # all-reduce accumulated min/max across ranks | ||
| if ( | ||
| hasattr(observer, "past_min_vals") | ||
kylesayrs marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| and observer.past_min_vals is not None | ||
| ): | ||
| observer.past_min_vals = all_reduce_min(observer.past_min_vals) | ||
| if ( | ||
| hasattr(observer, "past_max_vals") | ||
| and observer.past_max_vals is not None | ||
| ): | ||
| observer.past_max_vals = all_reduce_max(observer.past_max_vals) | ||
|
|
||
| # all-reduce global min/max (TENSOR_GROUP strategy) | ||
| if ( | ||
| hasattr(observer, "past_global_min_vals") | ||
| and observer.past_global_min_vals is not None | ||
| ): | ||
| observer.past_global_min_vals = all_reduce_min( | ||
| observer.past_global_min_vals | ||
| ) | ||
| if ( | ||
| hasattr(observer, "past_global_max_vals") | ||
| and observer.past_global_max_vals is not None | ||
| ): | ||
| observer.past_global_max_vals = all_reduce_max( | ||
| observer.past_global_max_vals | ||
| ) | ||
|
|
||
| recompute_qparams_from_observer(module, base_name) | ||
|
||
|
|
||
| def on_end(self, state: State, event: Event, **kwargs): | ||
| """ | ||
| Finish calibrating by removing observers and calibration hooks | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -8,3 +8,4 @@ | |
| from .dev import * | ||
| from .helpers import * | ||
| from .dist import * | ||
| from .distributed import * | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Setting the CUDA device using the global rank will fail in multi-node setups where the rank exceeds the number of GPUs per node. It is standard practice to use the local rank (e.g. the `LOCAL_RANK` environment variable set by `torchrun`) for device assignment.