-
Notifications
You must be signed in to change notification settings - Fork 490
Expand file tree
/
Copy pathutils.py
More file actions
417 lines (342 loc) · 15 KB
/
utils.py
File metadata and controls
417 lines (342 loc) · 15 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
"""
Dataset utility functions for LLM compression workflows.
Provides helper functions for loading, processing, and formatting datasets used
in model compression pipelines. Handles dataset splitting, tokenization,
calibration data preparation, and dataloader creation for both training and
one-shot calibration workflows.
"""
import hashlib
import math
from collections.abc import Iterator, Sized
from typing import Any, Callable, Optional

import torch
from datasets import Dataset
from loguru import logger
from torch import distributed as dist
from torch.utils.data import DataLoader, RandomSampler, Sampler
from transformers.data import DataCollatorWithPadding, default_data_collator

from llmcompressor.args import DatasetArguments
from llmcompressor.transformers.data import TextGenerationDataset
from llmcompressor.typing import Processor
# Calibration batch sizes above this value log a warning when the collate
# function is built, because truncation drops (and padding adds) more tokens
# as batches grow.
BS_WARNING_THRESHOLD = 16
def get_processed_dataset(
dataset_args: DatasetArguments,
processor: Processor | None = None,
) -> Dataset | None:
"""
Loads dataset based on dataset_args.
:param dataset_args: DatasetArguments that contain dataset loading and
processing params
:param processor: processor or tokenizer to use for dataset tokenization
:return: A Dataset corresponding to the single split for calibration
"""
if dataset_args.dataset is None:
logger.warning(
"Running oneshot without calibration data. This is expected for "
"weight-only and dynamic quantization"
)
return
splits = dataset_args.splits
match splits:
case None:
split_str = None
case str():
split_str = splits
case dict():
if "calibration" in splits:
split_str = splits["calibration"]
if len(splits) > 1:
ignored_keys = set(splits.keys()) - {"calibration"}
logger.warning(
f"Ignoring extra keys in splits: {list(ignored_keys)}. "
"Only the 'calibration' split is used."
)
else:
raise ValueError(
"Passing `splits` as a dict is only supported when it contains a "
"`'calibration'` key during the deprecation period. "
"Please pass a split string instead."
)
logger.warning(
"Passing `splits` as a dictionary is deprecated. "
f"Extracted split string: '{split_str}'. "
"Please pass `splits` as a string instead."
)
case list():
split_str = splits[0] if len(splits) > 0 else None
logger.warning(
"Passing `splits` as a list is deprecated. "
f"Using first element: '{split_str}'. "
"Please pass `splits` as a string instead."
)
case _:
raise ValueError(
f"Invalid splits type: {type(splits)}. Expected a split string "
"or the deprecated `{'calibration': ...}` form."
)
# default to custom dataset if dataset provided isn't a string
registry_id = (
dataset_args.dataset if isinstance(dataset_args.dataset, str) else "custom"
)
dataset = dataset_args.dataset
if hasattr(dataset, "column_names") and "input_ids" in dataset.column_names:
# dataset is already tokenized
return dataset
else:
# dataset needs to be tokenized
dataset_manager = TextGenerationDataset.load_from_registry(
registry_id,
dataset_args=dataset_args,
split=split_str,
processor=processor,
)
dataset = dataset_manager()
# If no split was specified, a DatasetDict format is typically returned.
# Fallback to the 'train' split for backward compatibility.
if not isinstance(dataset, Dataset):
if "train" in dataset:
logger.warning(
"No split was specified, but a multi-split dataset was loaded. "
"Falling back to the 'train' split for calibration."
)
dataset = dataset["train"]
else:
raise ValueError(
"No split specified and 'train' split not found in dataset. "
"Please specify `splits` explicitly."
)
return dataset
def get_calibration_dataloader(
    dataset_args: DatasetArguments,
    processor: Processor,
) -> DataLoader | None:
    """
    Build the DataLoader that feeds oneshot calibration.

    A ``DataLoader`` supplied directly as ``dataset_args.dataset`` is passed
    through untouched; otherwise the dataset is loaded/tokenized with
    ``get_processed_dataset`` and wrapped by ``format_calibration_data``.

    :param dataset_args: DatasetArguments holding the dataset parameters.
    :param processor: Processor or the tokenizer of the model.
    :return: the calibration DataLoader, or None for data-free flows (no dataset
        configured, or nothing returned by ``get_processed_dataset``).
    """
    source = dataset_args.dataset

    # weight-only quantization or dynamic quantization need no calibration data
    if source is None:
        return None

    # user-built loaders bypass loading and tokenization entirely
    if isinstance(source, DataLoader):
        return source

    processed = get_processed_dataset(
        dataset_args=dataset_args,
        processor=processor,
    )
    return (
        None
        if processed is None
        else format_calibration_data(dataset_args, processed, processor)
    )
def format_calibration_data(
    args: DatasetArguments,
    tokenized_dataset: Dataset,
    processor: Processor,
) -> DataLoader:
    """
    Wrap an already-tokenized calibration dataset in a DataLoader.

    The sampler is built by ``_make_sampler`` and the collate function by
    ``_make_collate_fn``, both configured from ``args``.

    :param args: dataset arguments (batch size, worker count, collator, sampling)
    :param tokenized_dataset: dataset the DataLoader iterates over
    :param processor: processor or tokenizer, forwarded to the collate builder
    :return: DataLoader over ``tokenized_dataset``
    """
    worker_count = args.dataloader_num_workers
    has_workers = worker_count > 0

    # Pinning is skipped without workers to save RAM for low-memory users; with
    # workers and an accelerator present it speeds up host-to-device copies.
    pin = torch.accelerator.is_available() and has_workers

    # persistent_workers keeps workers alive between epochs and is only
    # passed when workers exist. prefetch_factor stays at the DataLoader default.
    loader_options: dict[str, Any] = (
        {"persistent_workers": True} if has_workers else {}
    )

    return DataLoader(
        tokenized_dataset,
        batch_size=args.batch_size,
        sampler=_make_sampler(args, tokenized_dataset),
        collate_fn=_make_collate_fn(args, processor),
        pin_memory=pin,
        num_workers=worker_count,
        **loader_options,
    )
def _make_collate_fn(args: DatasetArguments, processor: Processor) -> Callable:
    """
    Build the collate function used by the calibration DataLoader.

    :param args: dataset arguments; ``args.data_collator`` is either a callable
        (returned as-is) or one of the strings ``"truncation"`` / ``"padding"``.
        ``args.batch_size`` above ``BS_WARNING_THRESHOLD`` logs a warning.
    :param processor: processor or tokenizer. For ``"padding"``, its
        ``tokenizer`` attribute (or the processor itself) is used, and its PAD
        token is set to the EOS token when no usable PAD token exists. This
        mutates the tokenizer.
    :return: callable that collates a list of samples into a batch
    :raises ValueError: if ``args.data_collator`` is not a callable or a
        recognized string
    """
    if callable(args.data_collator):
        return args.data_collator

    if args.data_collator == "truncation":
        if args.batch_size > BS_WARNING_THRESHOLD:
            logger.warning(
                f"Calibrating with batch sizes greater than {BS_WARNING_THRESHOLD} and "
                "`data_collator='truncation'` can lead to significant portions of the "
                "calibration dataset being deleted via truncation. Please consider "
                "reducing the calibration batch size or using filtering the dataset "
                "to use more uniform sequence lengths"
            )
        return data_collator_with_truncation

    if args.data_collator == "padding":
        if args.batch_size > BS_WARNING_THRESHOLD:
            logger.warning(
                f"Calibrating with batch sizes greater than {BS_WARNING_THRESHOLD} and "
                "`data_collator='padding'` can lead to excess token used for padding, "
                "which slows down calibration time and calibrates on padding tokens not"
                " seen at runtime. Please consider reducing the calibration batch size "
                "or using filtering the dataset to use more uniform sequence lengths"
            )
        tokenizer = getattr(processor, "tokenizer", processor)
        # A None id is checked explicitly: `None < 0` would raise TypeError.
        if (
            tokenizer.pad_token is None
            or tokenizer.pad_token_id is None
            or tokenizer.pad_token_id < 0
        ):
            logger.debug("Could not find padding token. Setting PAD token to EOS token")
            tokenizer.pad_token = tokenizer.eos_token
        return DataCollatorWithPadding(tokenizer)

    raise ValueError(f"Unknown data collator {args.data_collator}")
def _is_dist_and_same_ds(dataset: Dataset) -> bool:
    """
    Return True when torch.distributed is initialized and every rank reports
    the same dataset identity.

    Identity is the dataset's ``_fingerprint`` attribute when present (and not
    None); otherwise a SHA-256 digest of ``str(dataset[0])``. The fallback is a
    heuristic: datasets that differ only after the first sample compare equal.
    This is a collective call; every rank must reach it.

    :param dataset: the local rank's calibration dataset
    :return: False when not distributed; otherwise whether all ranks match
    """
    if not dist.is_initialized():
        return False

    assert len(dataset) > 0, (
        "Dataset must have at least one sample on each "
        f"device but got 0 samples for rank={dist.get_rank()}"
    )

    local_hash = getattr(dataset, "_fingerprint", None)
    if local_hash is None:
        # The built-in hash() of a str is salted per process (PYTHONHASHSEED),
        # so it would differ across ranks even for identical data. Use a
        # deterministic digest instead. Computed lazily so datasets with a
        # fingerprint are never indexed here.
        first_sample = str(dataset[0]).encode("utf-8", errors="backslashreplace")
        local_hash = hashlib.sha256(first_sample).hexdigest()

    all_hashes = [None for _ in range(dist.get_world_size())]
    dist.all_gather_object(all_hashes, local_hash)
    return all(local_hash == other_hash for other_hash in all_hashes)
def _get_partition_start_end(
num_samples: int, index: int, num_partitions: int
) -> tuple[int, int]:
# num_samples / num_partitions is average samples per partition
# we multiply this number with the partition indices to get partition bounds
# note that final partition has index+1 == num_partitions so it will
# always get all samples
start = math.floor(num_samples * (index / num_partitions))
end = math.floor(num_samples * ((index + 1) / num_partitions))
return start, end
class _OffsetSampler(Sampler[int]):
    """
    Wrap a sampler and shift every index it yields by a fixed ``offset``.

    Used when a sampler was built over a contiguous slice starting at
    ``offset`` while the DataLoader consuming it indexes the full dataset.
    """

    def __init__(self, sampler: Sampler[int], offset: int) -> None:
        self.sampler = sampler
        self.offset = offset

    def __iter__(self) -> Iterator[int]:
        return (self.offset + index for index in self.sampler)

    def __len__(self) -> int:
        return len(self.sampler)


def _make_sampler(args: DatasetArguments, dataset: Dataset) -> Sampler:
    """
    Build the sampler for the calibration DataLoader.

    When torch.distributed is initialized and all ranks hold the same dataset,
    the first ``num_calibration_samples`` rows (all rows if unset) are split into
    contiguous per-rank partitions. Indices yielded by the returned sampler
    always refer to the ``dataset`` passed in, so it can be paired with a
    DataLoader over that same dataset.

    :param args: dataset arguments (``num_calibration_samples``,
        ``shuffle_calibration_samples``, ``batch_size``)
    :param dataset: tokenized calibration dataset
    :return: a ``RandomSampler`` when shuffling, else a ``LengthAwareSampler``;
        wrapped to shift indices when this rank's partition does not start at 0
    """
    num_samples = args.num_calibration_samples
    shuffle = args.shuffle_calibration_samples
    batch_size = args.batch_size
    # start of this rank's partition within `dataset`; 0 when not partitioned
    offset = 0

    # detect whether we're in a distributed setting
    # but all ranks have the same dataset.
    if _is_dist_and_same_ds(dataset):
        logger.info(
            "Detected distributed setting with identical datasets across ranks. "
            "partitioning dataset across ranks."
        )
        total = num_samples or len(dataset)
        # clamp before partitioning so `select` never goes out of range
        if total > len(dataset):
            logger.warning(
                f"Requested {total} samples but the provided dataset only has "
                f"{len(dataset)} samples."
            )
            total = len(dataset)
        start, end = _get_partition_start_end(
            total, dist.get_rank(), dist.get_world_size()
        )
        dataset = dataset.select(range(start, end))
        # from here on `num_samples` is this rank's share, not the global total
        num_samples = end - start
        offset = start

    if num_samples is not None and num_samples > len(dataset):
        logger.warning(
            f"Requested {num_samples} samples but the provided dataset only has "
            f"{len(dataset)} samples."
        )
        num_samples = len(dataset)

    if shuffle:
        if batch_size > 1:
            logger.warning(
                "Shuffling the dataset can lead to unoptimal batching for sequence "
                "lengths non-uniform sizes. When collating with truncation, this will "
                "delete a large number of tokens. When collating with padding, this "
                "will add a large number of padding tokens.\n\nPlease consider calling "
                "`oneshot` with `batch_size=1`"
            )
        sampler = RandomSampler(dataset, num_samples=num_samples)
    else:
        sampler = LengthAwareSampler(
            dataset, num_samples=num_samples, batch_size=batch_size
        )

    # The samplers above index the (possibly sliced) local `dataset`, but the
    # caller's DataLoader indexes the full dataset; shift back by `offset`.
    return _OffsetSampler(sampler, offset) if offset else sampler
def data_collator_with_truncation(
    features: list[dict[str, Any]], return_tensors: str = "pt"
) -> dict[str, Any]:
    """
    Collate samples by truncating sequence fields to the batch's shortest length.

    For each of ``input_ids``, ``labels``, ``attention_mask`` and ``loss_mask``
    that is present in every sample, each sample's value is cut, in place, to
    the minimum length of that field across the batch. Fields missing from any
    sample are left untouched. The result is collated by
    ``default_data_collator``.

    :param features: per-sample feature dicts; mutated in place
    :param return_tensors: tensor type forwarded to ``default_data_collator``
    :return: the collated batch
    """
    truncated_columns = ("input_ids", "labels", "attention_mask", "loss_mask")
    for column in truncated_columns:
        if not all(column in sample for sample in features):
            continue
        shortest = min(len(sample[column]) for sample in features)
        for sample in features:
            sample[column] = sample[column][:shortest]
    return default_data_collator(features, return_tensors)
class LengthAwareSampler(Sampler[int]):
    """
    Sample data in order of descending sequence length. Relies on `input_ids` or
    `decoder_input_ids` column existing in dataset; if neither exists, a warning
    is logged and the dataset's natural order is used.

    :param data_source: dataset containing a `input_ids` or `decoder_input_ids` column
    :param num_samples: Maximum number of samples to sample. Shortest sequences
        are dropped first. ``None`` or ``0`` means the whole dataset; larger
        values are clamped to the dataset length
    :param batch_size: batch size the sampler will be consumed with; only used
        to log how many tokens truncation/padding would remove/add
    """

    data_source: Sized
    replacement: bool

    def __init__(
        self,
        data_source: Dataset,
        num_samples: Optional[int] = None,
        batch_size: int = 1,
    ) -> None:
        self.data_source = data_source
        dataset_len = len(data_source)
        # clamp so __len__ never reports more indices than __iter__ yields
        self._num_samples = min(num_samples or dataset_len, dataset_len)
        self.batch_size = batch_size

        if "input_ids" in data_source.column_names:
            feature_name = "input_ids"
        elif "decoder_input_ids" in data_source.column_names:
            feature_name = "decoder_input_ids"
        else:
            logger.warning(f"Could not find input ids in {data_source.column_names}")
            self.order = range(dataset_len)
            return

        lengths = [len(sample) for sample in data_source[feature_name]]
        # indices ordered longest-first, so truncating to num_samples drops the
        # shortest sequences
        self.order = torch.argsort(torch.tensor(lengths), descending=True).tolist()
        self._calculate_and_log_batch_stats(lengths)

    def _calculate_and_log_batch_stats(self, lengths: list[int]):
        """
        Log (debug level) the tokens batching in ``self.order`` would remove via
        truncation and add via padding. Does nothing when ``batch_size == 1``.

        :param lengths: per-sample sequence lengths in dataset order
        """
        if self.batch_size == 1:
            return

        logger.debug(
            "LengthAwareSampler: Calculating batch statistics for "
            f"{self.num_samples} samples with batch size {self.batch_size}"
        )

        sorted_lengths = [lengths[i] for i in self.order][: self.num_samples]
        total_tokens_removed = 0
        total_tokens_added = 0

        for i in range(0, self.num_samples, self.batch_size):
            batch_lengths = sorted_lengths[i : i + self.batch_size]
            if not batch_lengths:
                continue

            shortest_in_batch = min(batch_lengths)
            longest_in_batch = max(batch_lengths)
            # truncation cuts every sample to the shortest; padding grows every
            # sample to the longest
            tokens_removed = sum(lgth - shortest_in_batch for lgth in batch_lengths)
            tokens_added = sum(longest_in_batch - lgth for lgth in batch_lengths)

            total_tokens_removed += tokens_removed
            total_tokens_added += tokens_added

        if total_tokens_removed > 0 or total_tokens_added > 0:
            logger.debug(
                f"LengthAwareSampler: Total token overhead - "
                f"removed (truncation): {total_tokens_removed}, "
                f"added (padding): {total_tokens_added}"
            )

    @property
    def num_samples(self) -> int:
        """Number of indices this sampler yields."""
        return self._num_samples

    def __iter__(self) -> Iterator[int]:
        return iter(self.order[: self._num_samples])

    def __len__(self) -> int:
        return self._num_samples
def get_rank_partition(split: str, num_samples: int) -> str:
    """
    Build a split string selecting the current rank's share of a dataset split,
    for loading calibration data in a distributed setting.

    Calls ``torch.distributed.get_rank()`` and ``get_world_size()``, so the
    default process group is expected to be initialized.

    :param split: the split string to partition, e.g. "train"; must not already
        contain a ``[start:end]`` slice
    :param num_samples: the total number of samples in the dataset to partition
    :return: a partitioned split string of the form ``"{split}[start:end]"``

    Usage example:

        DATASET_ID = "HuggingFaceH4/ultrachat_200k"
        DATASET_SPLIT = "train_sft"
        NUM_CALIBRATION_SAMPLES = 256

        split = get_rank_partition(DATASET_SPLIT, NUM_CALIBRATION_SAMPLES)
        ds = load_dataset(
            DATASET_ID,
            split=split,
        )

    For S samples and D devices, when S is not perfectly divisible by D, each
    device gets at least S//D samples and the remaining samples are distributed
    as evenly as possible across all devices.
    """
    assert (
        "[" not in split
    ), "Split string should not already contain partitioning brackets"

    rank = dist.get_rank()
    world_size = dist.get_world_size()
    first, last = _get_partition_start_end(num_samples, rank, world_size)
    return "{}[{}:{}]".format(split, first, last)