Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/pruna/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@
"OpenImage": (
setup_open_image_dataset,
"image_generation_collate",
{"img_size": 1024},
{"img_size": 1024, "column_map": {"image": "image_quality_dev", "text": "quality_prompt"}},
),
"CIFAR10": (
setup_cifar10_dataset,
Expand Down
69 changes: 52 additions & 17 deletions src/pruna/data/collate.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,15 @@ def image_format_to_transforms(output_format: str, img_size: int) -> transforms.
raise ValueError(f"Invalid output format: {output_format}")


def _resolve_column(column_map: dict[str, str] | None, canonical: str) -> str:
"""Resolve a canonical column name using an optional mapping."""
if column_map is None:
return canonical
return column_map.get(canonical, canonical)


def image_generation_collate(
data: Any, img_size: int, output_format: str = "int"
data: Any, img_size: int, output_format: str = "int", column_map: dict[str, str] | None = None
) -> Tuple[List[str], Union[Float[torch.Tensor, ImageShape], Int[torch.Tensor, ImageShape]]]:
"""
Custom collation function for text-to-image generation datasets.
Expand All @@ -77,28 +84,34 @@ def image_generation_collate(
The output format, in ["int", "float", "normalized"].
With "int", output tensors have integer values between 0 and 255. With "float", they have float values
between 0 and 1. With "normalized", they have float values between -1 and 1.
column_map : dict[str, str] | None
Optional mapping from canonical column names (``image``, ``text``) to actual dataset column names.
For example, ``{"image": "chosen", "text": "prompt"}`` maps the canonical ``image`` key to
a ``chosen`` column and ``text`` to ``prompt`` in the dataset.

Returns
-------
Tuple[List[str], torch.Tensor]
The list of prompts and the stacked image tensor, transformed according to img_size and output_format.
"""
transformations = image_format_to_transforms(output_format, img_size)
image_col = _resolve_column(column_map, "image")
text_col = _resolve_column(column_map, "text")
images, texts = [], []

for item in data:
image = item["image"]
image = item[image_col]
if image.mode != "RGB":
image = image.convert("RGB")
image_tensor = transformations(image)
images.append(image_tensor)
texts.append(item["text"])
texts.append(item[text_col])

images_tensor = torch.stack(images)
return texts, images_tensor


def prompt_collate(data: Any) -> Tuple[List[str], None]:
def prompt_collate(data: Any, column_map: dict[str, str] | None = None) -> Tuple[List[str], None]:
"""
Custom collation function for prompt datasets.

Expand All @@ -108,17 +121,21 @@ def prompt_collate(data: Any) -> Tuple[List[str], None]:
----------
data : Any
The data to collate.
column_map : dict[str, str] | None
Optional mapping from canonical column names (``text``) to actual dataset column names.

Returns
-------
Tuple[List[str], None]
The collated data.
"""
return [item["text"] for item in data], None
text_col = _resolve_column(column_map, "text")
return [item[text_col] for item in data], None


def prompt_with_auxiliaries_collate(
data: Any,
column_map: dict[str, str] | None = None,
) -> Tuple[List[str], List[dict[str, Any]]]:
"""
Custom collation function for prompt datasets with auxiliaries.
Expand All @@ -131,20 +148,23 @@ def prompt_with_auxiliaries_collate(
----------
data : Any
The data to collate.
column_map : dict[str, str] | None
Optional mapping from canonical column names (``text``) to actual dataset column names.

Returns
-------
Tuple[List[str], List[dict[str, Any]]]
The list of prompts and, for each row, a dict of all remaining (non-text) columns.
"""
text_col = _resolve_column(column_map, "text")
# The text column has the prompt.
prompt_list = [item["text"] for item in data]
prompt_list = [item[text_col] for item in data]
# All the other columns that might include category, scene information, etc.
auxiliary_list = [{k: v for k, v in row.items() if k != "text"} for row in data]
auxiliary_list = [{k: v for k, v in row.items() if k != text_col} for row in data]
return prompt_list, auxiliary_list


def audio_collate(data: Any) -> Tuple[List[str], List[str]]:
def audio_collate(data: Any, column_map: dict[str, str] | None = None) -> Tuple[List[str], List[str]]:
"""
Custom collation function for audio datasets.

Expand All @@ -155,17 +175,21 @@ def audio_collate(data: Any) -> Tuple[List[str], List[str]]:
----------
data : Any
The data to collate.
column_map : dict[str, str] | None
Optional mapping from canonical column names (``audio``, ``sentence``) to actual dataset column names.

Returns
-------
Tuple[List[str], List[str]]
The audio file paths and the corresponding sentences.
"""
return [item["audio"]["path"] for item in data], [item["sentence"] for item in data]
audio_col = _resolve_column(column_map, "audio")
sentence_col = _resolve_column(column_map, "sentence")
return [item[audio_col]["path"] for item in data], [item[sentence_col] for item in data]


def image_classification_collate(
data: Any, img_size: int, output_format: str = "int"
data: Any, img_size: int, output_format: str = "int", column_map: dict[str, str] | None = None
) -> Tuple[Float[torch.Tensor, ImageShape], Int[torch.Tensor, LabelShape]]:
"""
Custom collation function for image classification datasets.
Expand All @@ -180,29 +204,33 @@ def image_classification_collate(
The output format, in ["int", "float", "normalized"].
With "int", output tensors have integer values between 0 and 255. With "float", they have float values
between 0 and 1. With "normalized", they have float values between -1 and 1.
column_map : dict[str, str] | None
Optional mapping from canonical column names (``image``, ``label``) to actual dataset column names.

Returns
-------
Tuple[torch.Tensor, torch.Tensor]
The collated data.
"""
transformations = image_format_to_transforms(output_format, img_size)
image_col = _resolve_column(column_map, "image")
label_col = _resolve_column(column_map, "label")
images, labels = [], []

for item in data:
image = item["image"]
image = item[image_col]
if image.mode != "RGB":
image = image.convert("RGB")
image_tensor = transformations(image)
images.append(image_tensor)
labels.append(item["label"])
labels.append(item[label_col])

images_tensor = torch.stack(images).float()
return images_tensor, torch.tensor(labels)


def text_generation_collate(
data: Any, max_seq_len: int | None, tokenizer: AutoTokenizer
data: Any, max_seq_len: int | None, tokenizer: AutoTokenizer, column_map: dict[str, str] | None = None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Custom collation function for text generation datasets.
Expand All @@ -217,17 +245,20 @@ def text_generation_collate(
The maximum sequence length.
tokenizer : AutoTokenizer
The tokenizer to use.
column_map : dict[str, str] | None
Optional mapping from canonical column names (``text``) to actual dataset column names.

Returns
-------
Tuple[torch.Tensor, torch.Tensor]
The collated data.
"""
text_col = _resolve_column(column_map, "text")
input_ids = []
for sample in data:
input_ids.append(
tokenizer(
sample["text"],
sample[text_col],
max_length=max_seq_len,
truncation=True,
padding="max_length" if max_seq_len else False,
Expand All @@ -238,7 +269,7 @@ def text_generation_collate(


def question_answering_collate(
data: Any, max_seq_len: int, tokenizer: AutoTokenizer
data: Any, max_seq_len: int, tokenizer: AutoTokenizer, column_map: dict[str, str] | None = None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Custom collation function for question answering datasets.
Expand All @@ -253,25 +284,29 @@ def question_answering_collate(
The maximum sequence length.
tokenizer : AutoTokenizer
The tokenizer to use.
column_map : dict[str, str] | None
Optional mapping from canonical column names (``question``, ``answer``) to actual dataset column names.

Returns
-------
Tuple[torch.Tensor, torch.Tensor]
The collated data.
"""
question_col = _resolve_column(column_map, "question")
answer_col = _resolve_column(column_map, "answer")
questions, answers = [], []
for sample in data:
questions.append(
tokenizer(
sample["question"],
sample[question_col],
max_length=max_seq_len,
truncation=True,
padding="max_length",
)["input_ids"]
)
answers.append(
tokenizer(
sample["answer"],
sample[answer_col],
max_length=max_seq_len,
truncation=True,
padding="max_length",
Expand Down
2 changes: 0 additions & 2 deletions src/pruna/data/datasets/text_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,6 @@ def setup_open_image_dataset(seed: int) -> Tuple[Dataset, Dataset, Dataset]:
The OpenImage dataset.
"""
dataset = load_dataset("data-is-better-together/open-image-preferences-v1")["cleaned"] # type: ignore[index]
dataset = dataset.rename_column("image_quality_dev", "image")
dataset = dataset.rename_column("quality_prompt", "text")
return split_train_into_train_val_test(dataset, seed)


Expand Down
Loading