-
Notifications
You must be signed in to change notification settings - Fork 432
[Distributed] Extend QuantizationModifier to support distributed activation calibration #2391
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
f60200a
c4d630d
89d1ade
ac0cc2a
76cf40f
0f3e1f9
9975edc
3320812
87f4b0d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
#############################################################################
# Distributed W8A8 quantization example with activation observer sync.
# run this with `torchrun --nproc_per_node=2 llama3_8b_w8a8_distributed.py`
# or change nproc_per_node to your desired configuration
#############################################################################

import torch
from compressed_tensors.offload import dispatch_model, init_dist, load_offloaded_model
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.datasets.utils import get_rank_partition
from llmcompressor.modifiers.quantization import QuantizationModifier

MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"

DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"

NUM_CALIBRATION_SAMPLES = 256
MAX_SEQUENCE_LENGTH = 2048

###### DDP MODEL LOAD CHANGE #####
# Bring up the process group, then load weights in offloaded form so each
# rank only materializes what it needs.
init_dist()
with load_offloaded_model():
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID, dtype="auto", device_map="auto_offload"
    )
##################################

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

###### DDP DATA LOAD CHANGE #####
# Each rank pulls only its own slice of the calibration samples.
calibration_ds = load_dataset(
    DATASET_ID, split=get_rank_partition(DATASET_SPLIT, NUM_CALIBRATION_SAMPLES)
)
##################################

calibration_ds = calibration_ds.shuffle(seed=42)


def _render_chat(example):
    """Flatten one chat-format example into a single prompt string."""
    return {
        "text": tokenizer.apply_chat_template(
            example["messages"],
            tokenize=False,
        )
    }


def _tokenize(sample):
    """Tokenize a rendered prompt, truncated to the calibration length."""
    return tokenizer(
        sample["text"],
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
        add_special_tokens=False,
    )


calibration_ds = calibration_ds.map(_render_chat)
calibration_ds = calibration_ds.map(
    _tokenize, remove_columns=calibration_ds.column_names
)

# QuantizationModifier automatically detects torch.distributed and
# all-reduces activation observer statistics at layer boundaries
recipe = [
    QuantizationModifier(targets="Linear", scheme="W8A8", ignore=["lm_head"]),
]

oneshot(
    model=model,
    dataset=calibration_ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)

# Confirm generations of the quantized model look sane.
print("\n\n")
print("========== SAMPLE GENERATION ==============")
dispatch_model(model)
sample = tokenizer("Hello my name is", return_tensors="pt")
sample = {key: value.to(model.device) for key, value in sample.items()}
output = model.generate(**sample, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")

print("Saving...")
# NOTE(review): every rank writes into the same directory here; confirm
# save_pretrained is rank-aware, or guard with a rank-0 check.
SAVE_DIR = (
    MODEL_ID.rstrip("/").split("/")[-1]
    + "-W8A8-DDP"
    + str(torch.distributed.get_world_size())
)
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)

torch.distributed.destroy_process_group()
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,5 @@ | ||
| from abc import abstractmethod | ||
| from typing import Optional, Tuple | ||
| from typing import List, Optional, Tuple | ||
| from weakref import ref | ||
|
|
||
| import torch | ||
|
|
@@ -8,6 +8,7 @@ | |
| from compressed_tensors.quantization.utils import calculate_qparams, generate_gparam | ||
| from compressed_tensors.registry.registry import RegistryMixin | ||
| from compressed_tensors.utils import align_module_device | ||
| from torch import distributed as dist | ||
|
|
||
| from llmcompressor.observers.helpers import flatten_for_calibration | ||
|
|
||
|
|
@@ -133,6 +134,63 @@ def _get_module_param(self, name: str) -> Optional[torch.nn.Parameter]: | |
| with align_module_device(module): | ||
| return getattr(module, f"{self.base_name}_{name}", None) | ||
|
|
||
def synchronize(self) -> List[dist.Work]:
    """All-reduce accumulated min/max statistics across DDP ranks.

    Issues an async in-place all-reduce for each piece of accumulated state
    this observer holds (``past_min_vals``, ``past_max_vals``,
    ``past_global_min_vals``, ``past_global_max_vals``). Attributes that are
    absent or ``None`` are skipped, so memoryless observers return an empty
    list.

    NOTE(review): MIN/MAX reductions are only correct for observers whose
    state is a running min/max (e.g. static min-max); moving-average
    observers (minmax / mse moving average) would need a mean reduction
    across ranks instead — TODO confirm per subclass.
    NOTE(review): reducing directly in the stored dtype may require the fp8
    upcast trick used in GPTQ's base.py — TODO confirm.

    :return: list of async communication handles
    """
    reductions = (
        ("past_min_vals", dist.ReduceOp.MIN),
        ("past_max_vals", dist.ReduceOp.MAX),
        ("past_global_min_vals", dist.ReduceOp.MIN),
        ("past_global_max_vals", dist.ReduceOp.MAX),
    )
    handles = []
    for attr_name, reduce_op in reductions:
        tensor = getattr(self, attr_name, None)
        if tensor is None:
            continue
        handles.append(dist.all_reduce(tensor, op=reduce_op, async_op=True))
    return handles
|
|
||
def recompute_global_scale(self) -> Optional[torch.Tensor]:
    """Recompute the global scale from accumulated global min/max state.

    Intended to run after :meth:`synchronize`, so the global scale is
    derived from statistics reduced across all ranks. Observers that track
    no global min/max state (memoryless observers) yield ``None``.

    :return: global scale tensor, or ``None`` when no state is accumulated
    """
    lo = getattr(self, "past_global_min_vals", None)
    hi = getattr(self, "past_global_max_vals", None)
    if lo is None or hi is None:
        return None
    return generate_gparam(lo, hi)
|
|
||
def recompute_qparams(self) -> Optional[ScaleZpTuple]:
    """Recompute scale and zero_point from accumulated min/max state.

    Intended to run after :meth:`synchronize`, so the quantization
    parameters are derived from statistics reduced across all ranks.
    Observers that track no min/max state (memoryless observers) yield
    ``None``.

    :return: ``(scale, zero_point)`` tuple, or ``None`` when no state is
        accumulated
    """
    lo = getattr(self, "past_min_vals", None)
    hi = getattr(self, "past_max_vals", None)
    if lo is None or hi is None:
        return None

    # Tensor-group strategies require a module-level global scale; validate
    # before deriving the per-tensor qparams.
    global_scale = self._get_module_param("global_scale")
    self._check_has_global_scale(global_scale)
    return calculate_qparams(
        min_vals=lo,
        max_vals=hi,
        quantization_args=self.args,
        global_scale=global_scale,
    )
|
|
||
| def _check_has_global_scale(self, global_scale: Optional[torch.nn.Parameter]): | ||
| if ( | ||
| self.args.strategy == QuantizationStrategy.TENSOR_GROUP | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,93 @@ | ||
| """ | ||
| Run with: torchrun --nproc_per_node=2 -m pytest <this_file> -v | ||
| """ | ||
|
|
||
| import os | ||
|
|
||
| import pytest | ||
| import torch | ||
| import torch.distributed as dist | ||
| from compressed_tensors.quantization import QuantizationArgs | ||
|
|
||
| from llmcompressor.observers.min_max import StaticMinMaxObserver | ||
| from llmcompressor.utils.dist import wait_for_comms | ||
| from tests.testing_utils import requires_gpu | ||
|
|
||
# torchrun sets RANK in the environment; when present (and CUDA is usable),
# lazily create the NCCL process group and pin each rank to its local GPU.
_launched_by_torchrun = os.environ.get("RANK") is not None
if _launched_by_torchrun and torch.cuda.is_available() and not dist.is_initialized():
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0)))
|
|
||
|
|
||
def _skip_if_not_distributed():
    """Skip the calling test unless a torch.distributed group is initialized."""
    ready = dist.is_available() and dist.is_initialized()
    if not ready:
        pytest.skip("Requires torchrun --nproc_per_node=2")
|
|
||
|
|
||
@pytest.mark.multi_gpu
@requires_gpu(2)
def test_observer_synchronize_reduces_min_max():
    """Synchronize must leave every rank holding the global min/max."""
    _skip_if_not_distributed()
    rank = dist.get_rank()

    args = QuantizationArgs(num_bits=8, type="int", symmetric=True, strategy="tensor")
    observer = StaticMinMaxObserver(base_name="input", args=args)

    # Seed the two ranks with deliberately different local statistics.
    if rank == 0:
        observer.past_min_vals = torch.tensor([1.0, 3.0], device="cuda")
        observer.past_max_vals = torch.tensor([10.0, 20.0], device="cuda")
    else:
        observer.past_min_vals = torch.tensor([2.0, 1.0], device="cuda")
        observer.past_max_vals = torch.tensor([15.0, 10.0], device="cuda")

    wait_for_comms(observer.synchronize())

    # After sync, min is the element-wise minimum and max the element-wise
    # maximum across ranks.
    expected_min = torch.tensor([1.0, 1.0], device="cuda")
    expected_max = torch.tensor([15.0, 20.0], device="cuda")
    assert torch.equal(observer.past_min_vals, expected_min)
    assert torch.equal(observer.past_max_vals, expected_max)
|
|
||
|
|
||
@pytest.mark.multi_gpu
@requires_gpu(2)
def test_synced_qparams_are_identical_across_ranks():
    """After synchronize, recomputed qparams must agree on every rank."""
    _skip_if_not_distributed()
    rank = dist.get_rank()

    args = QuantizationArgs(num_bits=8, type="int", symmetric=True, strategy="tensor")
    observer = StaticMinMaxObserver(base_name="input", args=args)

    # Different local ranges per rank; sync should reconcile them.
    if rank == 0:
        observer.past_min_vals = torch.tensor([-2.0], device="cuda")
        observer.past_max_vals = torch.tensor([3.0], device="cuda")
    else:
        observer.past_min_vals = torch.tensor([-5.0], device="cuda")
        observer.past_max_vals = torch.tensor([1.0], device="cuda")

    wait_for_comms(observer.synchronize())

    result = observer.recompute_qparams()
    assert result is not None
    scale, _ = result

    # Every rank must derive the same scale from the reduced statistics.
    gathered = [torch.zeros_like(scale) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, scale)
    assert torch.equal(gathered[0], gathered[1])
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
I think we'll need to add these to GPTQ and AWQ, right?