Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
14 commits
Select commit Hold shift + click to select a range
40c6112
feat: add PartiPrompts benchmark and extend data/benchmark system
davidberenstein1957 Mar 5, 2026
02fdf76
refactor: format import statements in utils.py for improved readability
davidberenstein1957 Mar 5, 2026
f69faab
chore: update pyproject.toml and refactor function signatures for con…
davidberenstein1957 Mar 5, 2026
68fe5db
chore: update ruff ignore rules and enhance documentation in evaluati…
davidberenstein1957 Mar 5, 2026
eda9c5b
refactor: update benchmark descriptions and metrics in BenchmarkRegistry
davidberenstein1957 Mar 5, 2026
0d3a4b3
refactor: update stratify_dataset function for improved sampling flex…
davidberenstein1957 Mar 19, 2026
4f76ced
refactor: remove seed parameter from stratify_dataset calls in datase…
davidberenstein1957 Mar 19, 2026
d47643d
refactor: streamline import statements and update docstring in evalua…
davidberenstein1957 Mar 19, 2026
289571a
chore: update ruff version and enhance docstring in stratify_dataset
davidberenstein1957 Mar 19, 2026
12b58b0
feat: introduce Benchmark and BenchmarkRegistry classes for evaluatio…
davidberenstein1957 Mar 19, 2026
67fa068
chore: revert changes chore: downgrade ruff version and refine docstr…
davidberenstein1957 Mar 19, 2026
2100e42
refactor: improve parameter handling and method naming in evaluation …
davidberenstein1957 Mar 19, 2026
347134b
refactor: enhance docstring formatting and parameter descriptions in …
davidberenstein1957 Mar 19, 2026
789d02f
refactor: improve docstring clarity in MetricRegistry and stratify_da…
davidberenstein1957 Mar 19, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion src/pruna/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,15 @@
setup_mnist_dataset,
)
from pruna.data.datasets.prompt import (
setup_dpg_dataset,
setup_drawbench_dataset,
setup_gedit_dataset,
setup_genai_bench_dataset,
setup_geneval_dataset,
setup_hps_dataset,
setup_imgedit_dataset,
setup_long_text_bench_dataset,
setup_oneig_dataset,
setup_parti_prompts_dataset,
)
from pruna.data.datasets.question_answering import setup_polyglot_dataset
Expand Down Expand Up @@ -97,8 +104,19 @@
{"img_size": 224},
),
"DrawBench": (setup_drawbench_dataset, "prompt_collate", {}),
"PartiPrompts": (setup_parti_prompts_dataset, "prompt_collate", {}),
"PartiPrompts": (
setup_parti_prompts_dataset,
"prompt_with_auxiliaries_collate",
{},
),
"GenAIBench": (setup_genai_bench_dataset, "prompt_collate", {}),
"GenEval": (setup_geneval_dataset, "prompt_with_auxiliaries_collate", {}),
"HPS": (setup_hps_dataset, "prompt_with_auxiliaries_collate", {}),
"ImgEdit": (setup_imgedit_dataset, "prompt_with_auxiliaries_collate", {}),
"LongTextBench": (setup_long_text_bench_dataset, "prompt_with_auxiliaries_collate", {}),
"GEditBench": (setup_gedit_dataset, "prompt_with_auxiliaries_collate", {}),
"OneIG": (setup_oneig_dataset, "prompt_with_auxiliaries_collate", {}),
"DPG": (setup_dpg_dataset, "prompt_with_auxiliaries_collate", {}),
"TinyIMDB": (setup_tiny_imdb_dataset, "text_generation_collate", {}),
"VBench": (setup_vbench_dataset, "prompt_with_auxiliaries_collate", {}),
}
21 changes: 6 additions & 15 deletions src/pruna/data/datasets/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from datasets import Dataset, load_dataset

from pruna.data.utils import (
define_sample_size_for_dataset,
split_train_into_train_val,
split_val_into_val_test,
stratify_dataset,
Expand Down Expand Up @@ -54,11 +53,8 @@ def setup_mnist_dataset(
train_ds = cast(Dataset, train_ds)
test_ds = cast(Dataset, test_ds)

train_sample_size = define_sample_size_for_dataset(train_ds, fraction, train_sample_size)
test_sample_size = define_sample_size_for_dataset(test_ds, fraction, test_sample_size)

train_ds = stratify_dataset(train_ds, train_sample_size, seed)
test_ds = stratify_dataset(test_ds, test_sample_size, seed)
train_ds = stratify_dataset(train_ds, sample_size=train_sample_size, fraction=fraction, seed=seed)
test_ds = stratify_dataset(test_ds, sample_size=test_sample_size, fraction=fraction, seed=seed)

train_ds, val_ds = split_train_into_train_val(train_ds, seed)
val_ds, test_ds = split_val_into_val_test(val_ds, seed)
Expand Down Expand Up @@ -93,11 +89,9 @@ def setup_imagenet_dataset(
train_ds, val = load_dataset("zh-plus/tiny-imagenet", split=["train", "valid"]) # type: ignore[misc]
train_ds = cast(Dataset, train_ds)
val = cast(Dataset, val)
train_sample_size = define_sample_size_for_dataset(train_ds, fraction, train_sample_size)
train_ds = stratify_dataset(train_ds, train_sample_size, seed)
train_ds = stratify_dataset(train_ds, sample_size=train_sample_size, fraction=fraction, seed=seed)
val_ds, test_ds = split_val_into_val_test(val, seed)
test_sample_size = define_sample_size_for_dataset(test_ds, fraction, test_sample_size)
test_ds = stratify_dataset(test_ds, test_sample_size, seed)
test_ds = stratify_dataset(test_ds, sample_size=test_sample_size, fraction=fraction, seed=seed)
return train_ds, val_ds, test_ds # type: ignore[return-value]


Expand Down Expand Up @@ -136,11 +130,8 @@ def setup_cifar10_dataset(
train_ds = train_ds.rename_column("img", "image")
test_ds = test_ds.rename_column("img", "image")

train_sample_size = define_sample_size_for_dataset(train_ds, fraction, train_sample_size)
test_sample_size = define_sample_size_for_dataset(test_ds, fraction, test_sample_size)

train_ds = stratify_dataset(train_ds, train_sample_size, seed)
test_ds = stratify_dataset(test_ds, test_sample_size, seed)
train_ds = stratify_dataset(train_ds, sample_size=train_sample_size, fraction=fraction, seed=seed)
test_ds = stratify_dataset(test_ds, sample_size=test_sample_size, fraction=fraction, seed=seed)

train_ds, val_ds = split_train_into_train_val(train_ds, seed)
return train_ds, val_ds, test_ds # type: ignore[return-value]
Loading