Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/disk_offloading/kimi_k2_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
model=model,
processor=tokenizer,
dataset=DATASET_ID,
splits={"calibration": f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]"},
splits=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]",
recipe=recipe,
max_seq_length=MAX_SEQUENCE_LENGTH,
num_calibration_samples=NUM_CALIBRATION_SAMPLES,
Expand Down
2 changes: 1 addition & 1 deletion examples/disk_offloading/qwen3_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
oneshot(
model=model,
dataset=DATASET_ID,
splits={"calibration": f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]"},
splits=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]",
recipe=recipe,
max_seq_length=MAX_SEQUENCE_LENGTH,
num_calibration_samples=NUM_CALIBRATION_SAMPLES,
Expand Down
2 changes: 1 addition & 1 deletion examples/imatrix/llama3_imatrix_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
oneshot(
model=model,
dataset=DATASET_ID,
splits={"calibration": "train[:5%]"},
splits="train[:5%]",
recipe=recipe,
max_seq_length=MAX_SEQUENCE_LENGTH,
num_calibration_samples=NUM_CALIBRATION_SAMPLES,
Expand Down
2 changes: 1 addition & 1 deletion examples/multimodal_vision/llava_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
model=model,
tokenizer=model_id,
dataset=DATASET_ID,
splits={"calibration": f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]"},
splits=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]",
recipe=recipe,
max_seq_length=MAX_SEQUENCE_LENGTH,
num_calibration_samples=NUM_CALIBRATION_SAMPLES,
Expand Down
2 changes: 1 addition & 1 deletion examples/multimodal_vision/mistral3_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def data_collator(features):
model=model,
tokenizer=model_id,
dataset=DATASET_ID,
splits={"calibration": f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]"},
splits=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]",
recipe=recipe,
max_seq_length=MAX_SEQUENCE_LENGTH,
num_calibration_samples=NUM_CALIBRATION_SAMPLES,
Expand Down
2 changes: 1 addition & 1 deletion examples/multimodal_vision/mllama_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
model=model,
tokenizer=model_id,
dataset=DATASET_ID,
splits={"calibration": f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]"},
splits=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]",
recipe=recipe,
max_seq_length=MAX_SEQUENCE_LENGTH,
num_calibration_samples=NUM_CALIBRATION_SAMPLES,
Expand Down
2 changes: 1 addition & 1 deletion examples/multimodal_vision/pixtral_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def data_collator(features):
model=model,
tokenizer=model_id,
dataset=DATASET_ID,
splits={"calibration": f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]"},
splits=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]",
data_collator=data_collator,
recipe=recipe,
max_seq_length=MAX_SEQUENCE_LENGTH,
Expand Down
9 changes: 8 additions & 1 deletion src/llmcompressor/args/dataset_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,14 @@ class DatasetArguments(CustomDatasetArguments):
)
splits: None | str | list[str] | dict[str, str] = field(
default=None,
metadata={"help": "Optional percentages of each split to download"},
metadata={
"help": (
"Optional dataset split selector. Passing a string like 'train' or "
"'train[:50%]' is strongly recommended. Legacy dict input is "
"deprecated and only supported for calibration compatibility "
"(for example: {'calibration': 'train[:50%]'})."
)
},
Comment thread
arpitkh101 marked this conversation as resolved.
Comment thread
arpitkh101 marked this conversation as resolved.
)
num_calibration_samples: int | None = field(
default=512,
Expand Down
1 change: 0 additions & 1 deletion src/llmcompressor/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,5 @@
format_calibration_data,
get_calibration_dataloader,
get_processed_dataset,
make_dataset_splits,
get_rank_partition,
)
155 changes: 70 additions & 85 deletions src/llmcompressor/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
"""

import math
import re
from collections.abc import Iterator, Sized
from typing import Any, Callable, Optional

Expand All @@ -29,18 +28,13 @@
def get_processed_dataset(
dataset_args: DatasetArguments,
processor: Processor | None = None,
do_oneshot: bool = False,
do_train: bool = True,
) -> dict[str, Dataset] | None:
) -> Dataset | None:
"""
Comment thread
arpitkh101 marked this conversation as resolved.
Loads datasets for each flow based on dataset_args, stores a Dataset for each
enabled flow in datasets
Loads dataset based on dataset_args.
:param dataset_args: DatasetArguments that contain dataset loading and
processing params
:param processor: processor or tokenizer to use for dataset tokenization
:param do_oneshot: True for oneshot pathway
:param do_train: True for train pathway
:return: A dataset corresponding to either train or calibration (oneshot)
:return: A Dataset corresponding to the single split for calibration
"""
if dataset_args.dataset is None:
logger.warning(
Expand All @@ -50,51 +44,81 @@ def get_processed_dataset(
return

splits = dataset_args.splits
tokenized_datasets = {}

def _get_split_name(inp_str):
# strip out split name, for ex train[60%:] -> train
split_name_match = re.match(r"(\w*)\[.*\]", inp_str)
if split_name_match is not None:
return split_name_match.group(1)
return inp_str

match splits:
case None:
splits = {"all": None}
split_str = None
case str():
splits = {_get_split_name(splits): splits}
case list():
Comment thread
arpitkh101 marked this conversation as resolved.
splits = {_get_split_name(s): s for s in splits}
split_str = splits
Comment thread
arpitkh101 marked this conversation as resolved.
case dict():
pass
if "calibration" in splits:
split_str = splits["calibration"]
if len(splits) > 1:
ignored_keys = set(splits.keys()) - {"calibration"}
logger.warning(
f"Ignoring extra keys in splits: {list(ignored_keys)}. "
"Only the 'calibration' split is used."
)
else:
raise ValueError(
"Passing `splits` as a dict is only supported when it contains a "
"`'calibration'` key during the deprecation period. "
"Please pass a split string instead."
)

logger.warning(
"Passing `splits` as a dictionary is deprecated. "
f"Extracted split string: '{split_str}'. "
"Please pass `splits` as a string instead."
)
case list():
split_str = splits[0] if len(splits) > 0 else None
logger.warning(
"Passing `splits` as a list is deprecated. "
f"Using first element: '{split_str}'. "
"Please pass `splits` as a string instead."
)
case _:
raise ValueError(f"Invalid splits type: {type(splits)}")
raise ValueError(
f"Invalid splits type: {type(splits)}. Expected a split string "
"or the deprecated `{'calibration': ...}` form."
)

# default to custom dataset if dataset provided isn't a string
registry_id = (
dataset_args.dataset if isinstance(dataset_args.dataset, str) else "custom"
)
for split_name, split_str in splits.items():
dataset = dataset_args.dataset
if hasattr(dataset, "column_names") and "input_ids" in dataset.column_names:
# dataset is already tokenized
tokenized_datasets[split_name] = dataset
else:
# dataset needs to be tokenized
dataset_manager = TextGenerationDataset.load_from_registry(
registry_id,
dataset_args=dataset_args,
split=split_str,
processor=processor,
)
tokenized_datasets[split_name] = dataset_manager(add_labels=do_train)

return make_dataset_splits(
tokenized_datasets,
do_oneshot=do_oneshot,
do_train=do_train,
)
dataset = dataset_args.dataset
if hasattr(dataset, "column_names") and "input_ids" in dataset.column_names:
# dataset is already tokenized
return dataset
else:
# dataset needs to be tokenized
dataset_manager = TextGenerationDataset.load_from_registry(
registry_id,
dataset_args=dataset_args,
split=split_str,
processor=processor,
)
dataset = dataset_manager()

# If no split was specified, a DatasetDict format is typically returned.
# Fallback to the 'train' split for backward compatibility.
if not isinstance(dataset, Dataset):
if "train" in dataset:
logger.warning(
"No split was specified, but a multi-split dataset was loaded. "
"Falling back to the 'train' split for calibration."
)
dataset = dataset["train"]
else:
raise ValueError(
"No split specified and 'train' split not found in dataset. "
"Please specify `splits` explicitly."
)

return dataset


def get_calibration_dataloader(
Expand All @@ -119,13 +143,13 @@ def get_calibration_dataloader(
if isinstance(dataset_args.dataset, DataLoader):
return dataset_args.dataset

datasets = get_processed_dataset(
calibration_dataset = get_processed_dataset(
dataset_args=dataset_args,
processor=processor,
do_oneshot=True,
do_train=False,
)
calibration_dataset = datasets.get("calibration")

if calibration_dataset is None:
return None

return format_calibration_data(dataset_args, calibration_dataset, processor)

Expand Down Expand Up @@ -155,45 +179,6 @@ def format_calibration_data(
)


def make_dataset_splits(
    tokenized_datasets: dict[str, Any],
    do_oneshot: bool = True,
    do_train: bool = False,
) -> dict[str, Dataset]:
    """
    Restructures the datasets dictionary based on which flows will be run.

    :param tokenized_datasets: dictionary of processed datasets keyed by split
        name (e.g. "all", "train", "calibration")
    :param do_oneshot: whether to populate the "calibration" entry of the
        result; falls back to the "train" split when no explicit
        "calibration" split is present
    :param do_train: whether to populate the "train" entry of the result
    :return: dict with exactly the keys "train" and "calibration"; the entry
        for a disabled flow is None
    :raises ValueError: if an enabled flow has no usable split available
    """

    # handles case where all splits are contained in a single dataset
    if "all" in tokenized_datasets and len(tokenized_datasets) == 1:
        tokenized_datasets = tokenized_datasets.get("all")
    # a bare Dataset (no split mapping) is treated as the train split
    if isinstance(tokenized_datasets, Dataset):
        tokenized_datasets = {"train": tokenized_datasets}

    train_split = calib_split = None

    if do_train:
        if "train" not in tokenized_datasets:
            raise ValueError("--do_train requires a train dataset")
        train_split = tokenized_datasets["train"]
    if do_oneshot:
        calib_split = tokenized_datasets.get("calibration")
        if calib_split is None:
            # no explicit calibration split: reuse "train" for calibration
            if "train" not in tokenized_datasets:
                raise ValueError("--do_oneshot requires a calibration dataset")
            calib_split = tokenized_datasets["train"]

    split_datasets = {
        "train": train_split,
        "calibration": calib_split,
    }
    return split_datasets


def _make_collate_fn(args: DatasetArguments, processor: Processor) -> Callable:
if isinstance(args.data_collator, Callable):
return args.data_collator
Expand Down
8 changes: 4 additions & 4 deletions src/llmcompressor/transformers/tracing/debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def trace(
dataset = TextGenerationDataset.load_from_registry(
dataset_args.dataset,
dataset_args=dataset_args,
split=dataset_args.splits["calibration"],
split=dataset_args.splits,
processor=processor,
)(add_labels=False)
sample = next(iter(dataset))
Expand Down Expand Up @@ -121,17 +121,17 @@ def get_dataset_kwargs(modality: str, ignore: list[str]) -> dict[str, str]:
dataset_kwargs = {
"text": {
"dataset": "ultrachat-200k",
"splits": {"calibration": "test_sft[:1]"},
"splits": "test_sft[:1]",
"max_seq_length": 4096,
},
"vision": {
"dataset": "flickr",
"splits": {"calibration": "test[:1]"},
"splits": "test[:1]",
"max_seq_length": 4096,
},
"audio": {
"dataset": "peoples_speech",
"splits": {"calibration": "test[:1]"},
"splits": "test[:1]",
"max_seq_length": 4096,
},
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def test_pipeline_produces_quantized_model(self):
oneshot(
model=model,
dataset=DATASET,
splits={"calibration": "train[:5%]"},
splits="train[:5%]",
recipe=recipe,
num_calibration_samples=NUM_CALIB_SAMPLES,
max_seq_length=MAX_SEQ_LEN,
Expand Down Expand Up @@ -124,7 +124,7 @@ def test_pipeline_produces_quantized_model(self):
oneshot(
model=model_no_gatherer,
dataset=DATASET,
splits={"calibration": "train[:5%]"},
splits="train[:5%]",
recipe=recipe_no_gatherer,
num_calibration_samples=NUM_CALIB_SAMPLES,
max_seq_length=MAX_SEQ_LEN,
Expand Down Expand Up @@ -154,7 +154,7 @@ def test_gatherer_without_observer_no_crash(self):
oneshot(
model=model,
dataset=DATASET,
splits={"calibration": "train[:5%]"},
splits="train[:5%]",
recipe=recipe,
num_calibration_samples=NUM_CALIB_SAMPLES,
max_seq_length=MAX_SEQ_LEN,
Expand Down Expand Up @@ -202,7 +202,7 @@ def test_pipeline_with_regex_targets(self):
oneshot(
model=model,
dataset=DATASET,
splits={"calibration": "train[:5%]"},
splits="train[:5%]",
recipe=recipe,
num_calibration_samples=NUM_CALIB_SAMPLES,
max_seq_length=MAX_SEQ_LEN,
Expand Down Expand Up @@ -246,7 +246,7 @@ def test_observer_without_gatherer_fallback(self):
oneshot(
model=model,
dataset=DATASET,
splits={"calibration": "train[:5%]"},
splits="train[:5%]",
recipe=recipe,
num_calibration_samples=NUM_CALIB_SAMPLES,
max_seq_length=MAX_SEQ_LEN,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ def test_smoothquant_e2e():
oneshot(
model=model,
dataset="open_platypus",
splits={"calibration": "train[:10%]"},
splits="train[:10%]",
recipe=SmoothQuantModifier(smoothing_strength=0.5),
num_calibration_samples=4,
max_seq_length=128,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def test_quant_model_compressed(tmp_path):
model_path = "nm-testing/tinysmokellama-3.2"
dataset = "open_platypus"
num_calibration_samples = 16
splits = {"calibration": f"train[:{num_calibration_samples}]"}
splits = f"train[:{num_calibration_samples}]"

# create a compressed
model = oneshot(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def setup_model_and_config(request, tmpdir_factory):
num_calibration_samples=num_calibration_samples,
recipe=config["new_recipe"],
pad_to_max_length=pad_to_max_length,
splits={"calibration": f"train_gen[:{num_calibration_samples}]"},
splits=f"train_gen[:{num_calibration_samples}]",
save_compressed=False,
)

Expand Down
Loading