| 1 | +import re |
| 2 | +from typing import Any, Callable, Dict, Optional |
| 3 | + |
| 4 | +import torch |
| 5 | +from datasets import Dataset |
| 6 | +from loguru import logger |
| 7 | +from torch.utils.data import DataLoader, RandomSampler, SequentialSampler |
| 8 | +from transformers.data import default_data_collator |
| 9 | + |
| 10 | +from llmcompressor.args import DatasetArguments |
| 11 | +from llmcompressor.transformers.finetune.data import TextGenerationDataset |
| 12 | +from llmcompressor.typing import Processor |
| 13 | + |
| 14 | + |
| 15 | +def get_processed_dataset( |
| 16 | + dataset_args: DatasetArguments, |
| 17 | + processor: Processor, |
| 18 | + do_oneshot: bool = False, |
| 19 | + do_train: bool = True, |
| 20 | +) -> Optional[Dict[str, Dataset]]: |
| 21 | + """ |
| 22 | +    Loads and tokenizes a dataset for each enabled flow based on dataset_args, |
| 23 | +    returning one Dataset per enabled flow |
| 24 | + :param dataset_args: DatasetArguments that contain dataset loading and |
| 25 | + processing params |
| 26 | + :param processor: processor or tokenizer to use for dataset tokenization |
| 27 | + :param do_oneshot: True for oneshot pathway |
| 28 | + :param do_train: True for train pathway |
| 29 | +    :return: dict of "train" and "calibration" Datasets, or None if no dataset is set |
| 30 | + """ |
| 31 | + if dataset_args.dataset is None: |
| 32 | + logger.warning( |
| 33 | + "Running oneshot without calibration data. This is expected for " |
| 34 | + "weight-only and dynamic quantization" |
| 35 | + ) |
| 36 | + return |
| 37 | + |
| 38 | + splits = dataset_args.splits |
| 39 | + tokenized_datasets = {} |
| 40 | + |
| 41 | + def _get_split_name(inp_str): |
| 42 | + # strip out split name, for ex train[60%:] -> train |
| 43 | + match = re.match(r"(\w*)\[.*\]", inp_str) |
| 44 | + if match is not None: |
| 45 | + return match.group(1) |
| 46 | + return inp_str |
| 47 | + |
| 48 | + if splits is None: |
| 49 | + splits = {"all": None} |
| 50 | + elif isinstance(splits, str): |
| 51 | + splits = {_get_split_name(splits): splits} |
| 52 | +    elif isinstance(splits, list): |
| 53 | + splits = {_get_split_name(s): s for s in splits} |
| 54 | + |
| 55 | + # default to custom dataset if dataset provided isn't a string |
| 56 | + registry_id = ( |
| 57 | + dataset_args.dataset if isinstance(dataset_args.dataset, str) else "custom" |
| 58 | + ) |
| 59 | + for split_name, split_str in splits.items(): |
| 60 | + dataset = dataset_args.dataset |
| 61 | + if hasattr(dataset, "column_names") and "input_ids" in dataset.column_names: |
| 62 | + # dataset is already tokenized |
| 63 | + tokenized_datasets[split_name] = dataset |
| 64 | + else: |
| 65 | + # dataset needs to be tokenized |
| 66 | + dataset_manager = TextGenerationDataset.load_from_registry( |
| 67 | + registry_id, |
| 68 | + dataset_args=dataset_args, |
| 69 | + split=split_str, |
| 70 | + processor=processor, |
| 71 | + ) |
| 72 | + tokenized_datasets[split_name] = dataset_manager(add_labels=do_train) |
| 73 | + |
| 74 | + return make_dataset_splits( |
| 75 | + tokenized_datasets, |
| 76 | + do_oneshot=do_oneshot, |
| 77 | + do_train=do_train, |
| 78 | + ) |
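To make the flow above concrete, here is a minimal usage sketch of `get_processed_dataset` for the oneshot pathway. The tokenizer checkpoint, the `open_platypus` registry id, and the assumption that `DatasetArguments` accepts these fields as constructor keyword arguments are illustrative, not guaranteed by the library.

```python
# Hypothetical sketch: the tokenizer checkpoint and dataset id are placeholders,
# and DatasetArguments is assumed to accept these fields as keyword arguments.
from transformers import AutoTokenizer

from llmcompressor.args import DatasetArguments

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
dataset_args = DatasetArguments(
    dataset="open_platypus",               # string id resolved via the registry
    splits={"calibration": "train[:1%]"},  # split names become the dict keys
    num_calibration_samples=512,
)

datasets = get_processed_dataset(
    dataset_args=dataset_args,
    processor=tokenizer,
    do_oneshot=True,
    do_train=False,
)
# -> {"train": None, "calibration": <tokenized Dataset>}
```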
| 79 | + |
| 80 | + |
| 81 | +def get_calibration_dataloader( |
| 82 | + dataset_args: DatasetArguments, |
| 83 | + processor: Processor, |
| 84 | +) -> Optional[DataLoader]: |
| 85 | + """ |
| 86 | + Get the dataloader used for oneshot calibration. |
| 87 | + :param dataset_args: DatasetArguments that contains the dataset parameters. |
| 88 | + :param processor: Processor or the tokenizer of the model. |
| 89 | +    :return: dataloader over the calibration dataset, or None if no dataset is set |
| 90 | + """ |
| 91 | + if dataset_args.dataset is None: |
| 92 | + # weight-only quantization or dynamic quantization |
| 93 | + return |
| 94 | + |
| 95 | + datasets = get_processed_dataset( |
| 96 | + dataset_args=dataset_args, |
| 97 | + processor=processor, |
| 98 | + do_oneshot=True, |
| 99 | + do_train=False, |
| 100 | + ) |
| 101 | + |
| 102 | + calibration_dataset = datasets.get("calibration") |
| 103 | + |
| 104 | + return format_calibration_data( |
| 105 | + tokenized_dataset=calibration_dataset, |
| 106 | + num_calibration_samples=dataset_args.num_calibration_samples, |
| 107 | + do_shuffle=dataset_args.shuffle_calibration_samples, |
| 108 | + collate_fn=dataset_args.data_collator, |
| 109 | + ) |
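Continuing the hypothetical `dataset_args` and `tokenizer` from the sketch above, `get_calibration_dataloader` runs the same processing and wraps the result for oneshot calibration; the batch keys named in the comment depend on the tokenizer and collator and are assumptions.

```python
# Continuing the hypothetical dataset_args / tokenizer from the previous sketch.
dataloader = get_calibration_dataloader(
    dataset_args=dataset_args,
    processor=tokenizer,
)
if dataloader is not None:  # None when dataset_args.dataset is None
    batch = next(iter(dataloader))
    # with the default collator and batch_size=1, batch is a dict of tensors,
    # typically including "input_ids" and "attention_mask"
    print({name: tensor.shape for name, tensor in batch.items()})
```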
| 110 | + |
| 111 | + |
| 112 | +def format_calibration_data( |
| 113 | + tokenized_dataset: Dataset, |
| 114 | + num_calibration_samples: Optional[int] = None, |
| 115 | + do_shuffle: bool = True, |
| 116 | + collate_fn: Callable = default_data_collator, |
| 117 | +) -> DataLoader: |
| 118 | + """ |
| 119 | + Creates a dataloader out of the calibration dataset split, trimming it to |
| 120 | + the desired number of calibration samples |
| 121 | + :param tokenized_dataset: dataset to convert to dataloader |
| 122 | + :param num_calibration_samples: number of data samples to convert |
| 123 | + :param do_shuffle: whether to shuffle the dataset before selecting calibration |
| 124 | + samples, true by default |
| 125 | + :param collate_fn: optional custom collate function, or use default |
| 126 | +    :return: dataloader over the trimmed calibration dataset |
| 127 | + """ |
| 128 | + safe_calibration_samples = len(tokenized_dataset) |
| 129 | + if num_calibration_samples is not None: |
| 130 | + safe_calibration_samples = min(len(tokenized_dataset), num_calibration_samples) |
| 131 | + if safe_calibration_samples != num_calibration_samples: |
| 132 | +            logger.warning( |
| 133 | + f"Requested {num_calibration_samples} calibration samples but " |
| 134 | + f"the provided dataset only has {safe_calibration_samples}. " |
| 135 | + ) |
| 136 | + |
| 137 | + if do_shuffle: |
| 138 | + tokenized_dataset = tokenized_dataset.shuffle() |
| 139 | + tokenized_calibration = tokenized_dataset.select(range(safe_calibration_samples)) |
| 140 | + |
| 141 | + dataloader_params = { |
| 142 | + "batch_size": 1, |
| 143 | + "sampler": RandomSampler(tokenized_calibration) |
| 144 | + if do_shuffle |
| 145 | + else SequentialSampler(tokenized_calibration), |
| 146 | + "collate_fn": collate_fn, |
| 147 | + "pin_memory": True, |
| 148 | + } |
| 149 | + |
| 150 | + calibration_dataloader = DataLoader(tokenized_calibration, **dataloader_params) |
| 151 | + |
| 152 | + return calibration_dataloader |
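A self-contained toy sketch of `format_calibration_data`: a tiny pre-tokenized `datasets.Dataset` is trimmed to two samples and wrapped in a dataloader. The token ids are arbitrary placeholders.

```python
# Toy pre-tokenized dataset; the token ids are arbitrary placeholders.
from datasets import Dataset

toy = Dataset.from_dict(
    {
        "input_ids": [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
        "attention_mask": [[1, 1, 1], [1, 1, 1], [1, 1, 1]],
    }
)

loader = format_calibration_data(
    tokenized_dataset=toy,
    num_calibration_samples=2,  # trims the 3-row dataset down to 2 samples
    do_shuffle=False,           # SequentialSampler keeps the original order
)
for batch in loader:
    print(batch["input_ids"].shape)  # torch.Size([1, 3]) for each of 2 batches
```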
| 153 | + |
| 154 | + |
| 155 | +def make_dataset_splits( |
| 156 | + tokenized_datasets: Dict[str, Any], |
| 157 | + do_oneshot: bool = True, |
| 158 | + do_train: bool = False, |
| 159 | +) -> Dict[str, Dataset]: |
| 160 | + """ |
| 161 | +    Restructures the datasets dictionary based on which tasks will be run |
| 162 | +    (train and/or oneshot calibration) |
| 163 | +    :param tokenized_datasets: dictionary of processed datasets |
| 164 | +    :param do_oneshot: Whether to store the calibration dataset |
| 165 | +    :param do_train: Whether to store the train dataset |
| 166 | +    :return: dict with "train" and "calibration" Datasets; unused flows are None |
| 166 | + """ |
| 167 | + |
| 168 | + # handles case where all splits are contained in a single dataset |
| 169 | + if "all" in tokenized_datasets and len(tokenized_datasets) == 1: |
| 170 | + tokenized_datasets = tokenized_datasets.get("all") |
| 171 | + if isinstance(tokenized_datasets, Dataset): |
| 172 | + tokenized_datasets = {"train": tokenized_datasets} |
| 173 | + |
| 174 | + train_split = calib_split = None |
| 175 | + |
| 176 | + if do_train: |
| 177 | + if "train" not in tokenized_datasets: |
| 178 | + raise ValueError("--do_train requires a train dataset") |
| 179 | + train_split = tokenized_datasets["train"] |
| 180 | + if do_oneshot: |
| 181 | + calib_split = tokenized_datasets.get("calibration") |
| 182 | + if calib_split is None: |
| 183 | + if "train" not in tokenized_datasets: |
| 184 | + raise ValueError("--do_oneshot requires a calibration dataset") |
| 185 | + calib_split = tokenized_datasets["train"] |
| 186 | + |
| 187 | + split_datasets = { |
| 188 | + "train": train_split, |
| 189 | + "calibration": calib_split, |
| 190 | + } |
| 191 | + return split_datasets |
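Finally, a short sketch of the fallback behavior in `make_dataset_splits`: when only a "train" split is supplied and the oneshot flow is enabled, that split is reused as the calibration split (the `toy` dataset from the previous sketch stands in for a real one).

```python
# Reusing the toy Dataset from the previous sketch.
splits = make_dataset_splits({"train": toy}, do_oneshot=True, do_train=False)
assert splits["calibration"] is toy  # the train split doubles as calibration
assert splits["train"] is None       # train flow disabled, so no train split
```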