
Commit a6d45ba

Merge remote-tracking branch 'origin' into kylesayrs/remove-double-init

2 parents: 606b93b + 2a59554


50 files changed: +928 −836 lines

README.md

+33-7
@@ -23,6 +23,32 @@
 * SmoothQuant
 * SparseGPT
 
+### When to Use Which Optimization
+
+#### PTQ
+PTQ is performed to reduce the precision of quantizable weights (e.g., linear layers) to a lower bit-width. Supported formats are:
+
+##### [W4A16](./examples/quantization_w4a16/README.md)
+- Uses GPTQ to compress weights to 4 bits. Requires a calibration dataset.
+- Useful for speedups in low-QPS regimes, with the greatest weight compression.
+- Recommended for any GPU type.
+##### [W8A8-INT8](./examples/quantization_w8a8_int8/README.md)
+- Uses channel-wise quantization to compress weights to 8 bits with GPTQ, and dynamic per-token quantization to compress activations to 8 bits. Requires a calibration dataset for weight quantization; activation quantization is carried out during inference on vLLM.
+- Useful for speedups in high-QPS regimes or offline serving on vLLM.
+- Recommended for NVIDIA GPUs with compute capability <8.9 (Ampere, Turing, Volta, Pascal, or older).
+##### [W8A8-FP8](./examples/quantization_w8a8_fp8/README.md)
+- Uses channel-wise quantization to compress weights to 8 bits, and dynamic per-token quantization to compress activations to 8 bits. Does not require a calibration dataset; activation quantization is carried out during inference on vLLM.
+- Useful for speedups in high-QPS regimes or offline serving on vLLM.
+- Recommended for NVIDIA GPUs with compute capability >=8.9 (Hopper and Ada Lovelace).
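To make the W4A16 path above concrete, a minimal sketch follows; the model id, dataset, and modifier settings are illustrative assumptions drawn from the linked `quantization_w4a16` example, not values taken from this diff:

```python
# Hedged sketch: compress an (assumed) model to W4A16 with GPTQ via oneshot.
from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.transformers import oneshot

oneshot(
    model="meta-llama/Meta-Llama-3-8B-Instruct",  # assumed model id
    dataset="open_platypus",                      # assumed calibration dataset
    recipe=GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]),
    max_seq_length=2048,
    num_calibration_samples=512,
    output_dir="Meta-Llama-3-8B-Instruct-W4A16",
)
```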
+
+#### Sparsification
+Sparsification reduces model complexity by pruning selected weight values to zero while retaining essential weights in a subset of parameters. Supported formats include:
+
+##### [2:4-Sparsity with FP8 Weight, FP8 Input Activation](./examples/sparse_2of4_quantization_fp8/README.md)
+- Uses (1) semi-structured sparsity (SparseGPT), where, for every four contiguous weights in a tensor, two are set to zero, and (2) channel-wise quantization to compress weights to 8 bits plus dynamic per-token quantization to compress activations to 8 bits.
+- Useful for faster inference than W8A8-FP8, with almost no drop in evaluation score ([blog](https://neuralmagic.com/blog/24-sparse-llama-fp8-sota-performance-for-nvidia-hopper-gpus/)). Note: small models may experience accuracy drops when the remaining non-zero weights are insufficient to recapitulate the original distribution.
+- Recommended for compute capability >=8.9 (Hopper and Ada Lovelace).
+
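Similarly, a hedged sketch of the 2:4 sparsity plus FP8 flow described above; the modifier choices follow the linked `sparse_2of4_quantization_fp8` example, while the model id and dataset are assumptions:

```python
# Hedged sketch: 2:4 prune with SparseGPT, then quantize weights/activations to FP8.
from llmcompressor.modifiers.obcq import SparseGPTModifier
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.transformers import oneshot

recipe = [
    # for every four contiguous weights, keep two and zero the other two
    SparseGPTModifier(sparsity=0.5, mask_structure="2:4"),
    # channel-wise FP8 weights with dynamic per-token FP8 activations
    QuantizationModifier(targets="Linear", scheme="FP8_DYNAMIC", ignore=["lm_head"]),
]

oneshot(
    model="meta-llama/Meta-Llama-3-8B-Instruct",  # assumed model id
    dataset="open_platypus",                      # assumed calibration dataset
    recipe=recipe,
    num_calibration_samples=512,
    output_dir="Meta-Llama-3-8B-Instruct-2of4-FP8-Dynamic",
)
```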
 
 ## Installation
 
@@ -35,16 +61,16 @@ pip install llmcompressor
 ### End-to-End Examples
 
 Applying quantization with `llmcompressor`:
-* [Activation quantization to `int8`](examples/quantization_w8a8_int8)
-* [Activation quantization to `fp8`](examples/quantization_w8a8_fp8)
-* [Weight only quantization to `int4`](examples/quantization_w4a16)
-* [Quantizing MoE LLMs](examples/quantizing_moe)
-* [Quantizing Vision-Language Models](examples/multimodal_vision)
-* [Quantizing Audio-Language Models](examples/multimodal_audio)
+* [Activation quantization to `int8`](examples/quantization_w8a8_int8/README.md)
+* [Activation quantization to `fp8`](examples/quantization_w8a8_fp8/README.md)
+* [Weight only quantization to `int4`](examples/quantization_w4a16/README.md)
+* [Quantizing MoE LLMs](examples/quantizing_moe/README.md)
+* [Quantizing Vision-Language Models](examples/multimodal_vision/README.md)
+* [Quantizing Audio-Language Models](examples/multimodal_audio/README.md)
 
 ### User Guides
 Deep dives into advanced usage of `llmcompressor`:
-* [Quantizing with large models with the help of `accelerate`](examples/big_models_with_accelerate)
+* [Quantizing with large models with the help of `accelerate`](examples/big_models_with_accelerate/README.md)
 
 
 ## Quick Tour

examples/trl_mixin/ex_trl_distillation.py

+4-4
@@ -19,12 +19,12 @@
 max_seq_length = 512
 
 # Load gsm8k using SparseML dataset tools
-data_args = DatasetArguments(
+dataset_args = DatasetArguments(
     dataset="gsm8k", dataset_config_name="main", max_seq_length=max_seq_length
 )
 dataset_manager = TextGenerationDataset.load_from_registry(
-    data_args.dataset,
-    data_args=data_args,
+    dataset_args.dataset,
+    dataset_args=dataset_args,
     split="train",
     processor=tokenizer,
 )
@@ -69,7 +69,7 @@
     train_dataset=train_dataset,
     data_collator=data_collator,
     trl_sft_config_args=trl_sft_config_args,
-    data_args=data_args,
+    dataset_args=dataset_args,
     model_args=model_args,
 )
 trainer.train()

src/llmcompressor/args/__init__.py

+1
@@ -4,3 +4,4 @@
 from .model_arguments import ModelArguments
 from .recipe_arguments import RecipeArguments
 from .training_arguments import TrainingArguments
+from .utils import parse_args

src/llmcompressor/args/utils.py

+73
@@ -0,0 +1,73 @@
from typing import Tuple

from loguru import logger
from transformers import HfArgumentParser

from llmcompressor.args import (
    DatasetArguments,
    ModelArguments,
    RecipeArguments,
    TrainingArguments,
)
from llmcompressor.transformers.utils.helpers import resolve_processor_from_model_args


def parse_args(
    include_training_args: bool = False, **kwargs
) -> Tuple[ModelArguments, DatasetArguments, RecipeArguments, TrainingArguments, str]:
    """
    Parses keyword arguments passed in from `oneshot` or `train` and separates
    them into the following:

    * ModelArguments in src/llmcompressor/args/model_arguments.py
    * DatasetArguments in src/llmcompressor/args/dataset_arguments.py
    * RecipeArguments in src/llmcompressor/args/recipe_arguments.py
    * TrainingArguments in src/llmcompressor/args/training_arguments.py

    ModelArguments, DatasetArguments, and RecipeArguments are used for both
    `oneshot` and `train`. TrainingArguments is only used for `train`.
    """

    # pop output_dir; it is an attribute of TrainingArguments, which oneshot does not use
    output_dir = kwargs.pop("output_dir", None)

    parser_args = (ModelArguments, DatasetArguments, RecipeArguments)
    if include_training_args:
        parser_args += (TrainingArguments,)

    parser = HfArgumentParser(parser_args)
    parsed_args = parser.parse_dict(kwargs)

    training_args = None
    if include_training_args:
        model_args, dataset_args, recipe_args, training_args = parsed_args
        if output_dir is not None:
            training_args.output_dir = output_dir
    else:
        model_args, dataset_args, recipe_args = parsed_args

    # convert a list of "key=value" strings into a recipe-args dict
    if recipe_args.recipe_args is not None:
        if not isinstance(recipe_args.recipe_args, dict):
            arg_dict = {}
            for recipe_arg in recipe_args.recipe_args:
                key, value = recipe_arg.split("=")
                arg_dict[key] = value
            recipe_args.recipe_args = arg_dict

    # raise deprecation warnings
    if dataset_args.remove_columns is not None:
        logger.warning(
            "`remove_columns` argument is deprecated. When tokenizing datasets, all "
            "columns which are invalid inputs to the tokenizer will be removed."
        )

    # silently assign tokenizer to processor
    resolve_processor_from_model_args(model_args)

    return model_args, dataset_args, recipe_args, training_args, output_dir
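A hypothetical usage sketch of `parse_args` follows; the field names passed as keyword arguments (`model`, `dataset`, `num_calibration_samples`, `recipe`) are assumptions about the argument dataclasses and are not shown in this diff:

```python
# Hypothetical call: parse_args splits one flat kwargs dict into the dataclasses.
from llmcompressor.args import parse_args

model_args, dataset_args, recipe_args, training_args, output_dir = parse_args(
    include_training_args=False,
    model="meta-llama/Meta-Llama-3-8B-Instruct",  # assumed ModelArguments field
    dataset="open_platypus",                      # assumed DatasetArguments field
    num_calibration_samples=512,                  # assumed DatasetArguments field
    recipe="recipe.yaml",                         # assumed RecipeArguments field
    output_dir="./output",                        # popped before parsing, returned as-is
)

assert training_args is None  # only populated when include_training_args=True
```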
+8
@@ -0,0 +1,8 @@
# flake8: noqa

from .utils import (
    format_calibration_data,
    get_calibration_dataloader,
    get_processed_dataset,
    make_dataset_splits,
)

src/llmcompressor/datasets/utils.py

+191
@@ -0,0 +1,191 @@
import re
from typing import Any, Callable, Dict, List, Optional

import torch
from datasets import Dataset
from loguru import logger
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers.data import default_data_collator

from llmcompressor.args import DatasetArguments
from llmcompressor.transformers.finetune.data import TextGenerationDataset
from llmcompressor.typing import Processor


def get_processed_dataset(
    dataset_args: DatasetArguments,
    processor: Processor,
    do_oneshot: bool = False,
    do_train: bool = True,
) -> Optional[Dict[str, Dataset]]:
    """
    Loads datasets for each flow based on dataset_args and stores a Dataset for
    each enabled flow.

    :param dataset_args: DatasetArguments that contain dataset loading and
        processing params
    :param processor: processor or tokenizer to use for dataset tokenization
    :param do_oneshot: True for the oneshot pathway
    :param do_train: True for the train pathway
    :return: A dataset corresponding to either train or calibration (oneshot)
    """
    if dataset_args.dataset is None:
        logger.warning(
            "Running oneshot without calibration data. This is expected for "
            "weight-only and dynamic quantization"
        )
        return

    splits = dataset_args.splits
    tokenized_datasets = {}

    def _get_split_name(inp_str):
        # strip out split name, for ex train[60%:] -> train
        match = re.match(r"(\w*)\[.*\]", inp_str)
        if match is not None:
            return match.group(1)
        return inp_str

    if splits is None:
        splits = {"all": None}
    elif isinstance(splits, str):
        splits = {_get_split_name(splits): splits}
    elif isinstance(splits, List):
        splits = {_get_split_name(s): s for s in splits}

    # default to custom dataset if dataset provided isn't a string
    registry_id = (
        dataset_args.dataset if isinstance(dataset_args.dataset, str) else "custom"
    )
    for split_name, split_str in splits.items():
        dataset = dataset_args.dataset
        if hasattr(dataset, "column_names") and "input_ids" in dataset.column_names:
            # dataset is already tokenized
            tokenized_datasets[split_name] = dataset
        else:
            # dataset needs to be tokenized
            dataset_manager = TextGenerationDataset.load_from_registry(
                registry_id,
                dataset_args=dataset_args,
                split=split_str,
                processor=processor,
            )
            tokenized_datasets[split_name] = dataset_manager(add_labels=do_train)

    return make_dataset_splits(
        tokenized_datasets,
        do_oneshot=do_oneshot,
        do_train=do_train,
    )


def get_calibration_dataloader(
    dataset_args: DatasetArguments,
    processor: Processor,
) -> torch.utils.data.DataLoader:
    """
    Get the dataloader used for oneshot calibration.

    :param dataset_args: DatasetArguments that contains the dataset parameters.
    :param processor: Processor or the tokenizer of the model.
    :return: PyTorch dataloader object that contains the calibration dataset.
    """
    if dataset_args.dataset is None:
        # weight-only quantization or dynamic quantization
        return

    datasets = get_processed_dataset(
        dataset_args=dataset_args,
        processor=processor,
        do_oneshot=True,
        do_train=False,
    )

    calibration_dataset = datasets.get("calibration")

    return format_calibration_data(
        tokenized_dataset=calibration_dataset,
        num_calibration_samples=dataset_args.num_calibration_samples,
        do_shuffle=dataset_args.shuffle_calibration_samples,
        collate_fn=dataset_args.data_collator,
    )


def format_calibration_data(
    tokenized_dataset: Dataset,
    num_calibration_samples: Optional[int] = None,
    do_shuffle: bool = True,
    collate_fn: Callable = default_data_collator,
) -> List[torch.Tensor]:
    """
    Creates a dataloader out of the calibration dataset split, trimming it to
    the desired number of calibration samples.

    :param tokenized_dataset: dataset to convert to dataloader
    :param num_calibration_samples: number of data samples to convert
    :param do_shuffle: whether to shuffle the dataset before selecting calibration
        samples, true by default
    :param collate_fn: optional custom collate function, or use default
    :return: list of trimmed calibration data tensors
    """
    safe_calibration_samples = len(tokenized_dataset)
    if num_calibration_samples is not None:
        safe_calibration_samples = min(len(tokenized_dataset), num_calibration_samples)
        if safe_calibration_samples != num_calibration_samples:
            logger.warning(
                f"Requested {num_calibration_samples} calibration samples but "
                f"the provided dataset only has {safe_calibration_samples}. "
            )

    if do_shuffle:
        tokenized_dataset = tokenized_dataset.shuffle()
    tokenized_calibration = tokenized_dataset.select(range(safe_calibration_samples))

    dataloader_params = {
        "batch_size": 1,
        "sampler": RandomSampler(tokenized_calibration)
        if do_shuffle
        else SequentialSampler(tokenized_calibration),
        "collate_fn": collate_fn,
        "pin_memory": True,
    }

    calibration_dataloader = DataLoader(tokenized_calibration, **dataloader_params)

    return calibration_dataloader


def make_dataset_splits(
    tokenized_datasets: Dict[str, Any],
    do_oneshot: bool = True,
    do_train: bool = False,
) -> Dict[str, Dataset]:
    """
    Restructures the datasets dictionary based on what tasks will be run
    (oneshot, train).

    :param tokenized_datasets: dictionary of processed datasets
    :param do_oneshot: Whether to store the calibration dataset
    :param do_train: Whether to store the train dataset
    :return: A dataset corresponding to either train or calibration (oneshot)
    """

    # handles case where all splits are contained in a single dataset
    if "all" in tokenized_datasets and len(tokenized_datasets) == 1:
        tokenized_datasets = tokenized_datasets.get("all")
        if isinstance(tokenized_datasets, Dataset):
            tokenized_datasets = {"train": tokenized_datasets}

    train_split = calib_split = None

    if do_train:
        if "train" not in tokenized_datasets:
            raise ValueError("--do_train requires a train dataset")
        train_split = tokenized_datasets["train"]
    if do_oneshot:
        calib_split = tokenized_datasets.get("calibration")
        if calib_split is None:
            if "train" not in tokenized_datasets:
                raise ValueError("--do_oneshot requires a calibration dataset")
            calib_split = tokenized_datasets["train"]

    split_datasets = {
        "train": train_split,
        "calibration": calib_split,
    }
    return split_datasets
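As a quick illustration of how the helpers above fit together, here is a toy sketch; the tiny hand-built dataset exists only for this example:

```python
# Toy sketch: route a single pre-tokenized dataset to the calibration split and
# wrap it in a dataloader using the helpers defined above.
from datasets import Dataset

from llmcompressor.datasets.utils import format_calibration_data, make_dataset_splits

tokenized = Dataset.from_dict(
    {"input_ids": [[1, 2, 3]] * 8, "attention_mask": [[1, 1, 1]] * 8}
)

splits = make_dataset_splits({"all": tokenized}, do_oneshot=True, do_train=False)
loader = format_calibration_data(
    tokenized_dataset=splits["calibration"],
    num_calibration_samples=4,
    do_shuffle=False,
)
print(len(loader))  # 4 single-sample batches
```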
@@ -1,2 +1,3 @@
 # flake8: noqa
 from .oneshot import Oneshot, oneshot
+from .utils import post_process, pre_process
