diff --git a/src/pruna/data/__init__.py b/src/pruna/data/__init__.py index 3a811868..6b3c898f 100644 --- a/src/pruna/data/__init__.py +++ b/src/pruna/data/__init__.py @@ -26,8 +26,15 @@ setup_mnist_dataset, ) from pruna.data.datasets.prompt import ( + setup_dpg_dataset, setup_drawbench_dataset, + setup_gedit_dataset, setup_genai_bench_dataset, + setup_geneval_dataset, + setup_hps_dataset, + setup_imgedit_dataset, + setup_long_text_bench_dataset, + setup_oneig_dataset, setup_parti_prompts_dataset, ) from pruna.data.datasets.question_answering import setup_polyglot_dataset @@ -97,8 +104,19 @@ {"img_size": 224}, ), "DrawBench": (setup_drawbench_dataset, "prompt_collate", {}), - "PartiPrompts": (setup_parti_prompts_dataset, "prompt_collate", {}), + "PartiPrompts": ( + setup_parti_prompts_dataset, + "prompt_with_auxiliaries_collate", + {}, + ), "GenAIBench": (setup_genai_bench_dataset, "prompt_collate", {}), + "GenEval": (setup_geneval_dataset, "prompt_with_auxiliaries_collate", {}), + "HPS": (setup_hps_dataset, "prompt_with_auxiliaries_collate", {}), + "ImgEdit": (setup_imgedit_dataset, "prompt_with_auxiliaries_collate", {}), + "LongTextBench": (setup_long_text_bench_dataset, "prompt_with_auxiliaries_collate", {}), + "GEditBench": (setup_gedit_dataset, "prompt_with_auxiliaries_collate", {}), + "OneIG": (setup_oneig_dataset, "prompt_with_auxiliaries_collate", {}), + "DPG": (setup_dpg_dataset, "prompt_with_auxiliaries_collate", {}), "TinyIMDB": (setup_tiny_imdb_dataset, "text_generation_collate", {}), "VBench": (setup_vbench_dataset, "prompt_with_auxiliaries_collate", {}), } diff --git a/src/pruna/data/datasets/image.py b/src/pruna/data/datasets/image.py index 0fd2fff4..679bd4ab 100644 --- a/src/pruna/data/datasets/image.py +++ b/src/pruna/data/datasets/image.py @@ -19,7 +19,6 @@ from datasets import Dataset, load_dataset from pruna.data.utils import ( - define_sample_size_for_dataset, split_train_into_train_val, split_val_into_val_test, stratify_dataset, @@ -54,11 +53,8 @@ def setup_mnist_dataset( train_ds = cast(Dataset, train_ds) test_ds = cast(Dataset, test_ds) - train_sample_size = define_sample_size_for_dataset(train_ds, fraction, train_sample_size) - test_sample_size = define_sample_size_for_dataset(test_ds, fraction, test_sample_size) - - train_ds = stratify_dataset(train_ds, train_sample_size, seed) - test_ds = stratify_dataset(test_ds, test_sample_size, seed) + train_ds = stratify_dataset(train_ds, sample_size=train_sample_size, fraction=fraction, seed=seed) + test_ds = stratify_dataset(test_ds, sample_size=test_sample_size, fraction=fraction, seed=seed) train_ds, val_ds = split_train_into_train_val(train_ds, seed) val_ds, test_ds = split_val_into_val_test(val_ds, seed) @@ -93,11 +89,9 @@ def setup_imagenet_dataset( train_ds, val = load_dataset("zh-plus/tiny-imagenet", split=["train", "valid"]) # type: ignore[misc] train_ds = cast(Dataset, train_ds) val = cast(Dataset, val) - train_sample_size = define_sample_size_for_dataset(train_ds, fraction, train_sample_size) - train_ds = stratify_dataset(train_ds, train_sample_size, seed) + train_ds = stratify_dataset(train_ds, sample_size=train_sample_size, fraction=fraction, seed=seed) val_ds, test_ds = split_val_into_val_test(val, seed) - test_sample_size = define_sample_size_for_dataset(test_ds, fraction, test_sample_size) - test_ds = stratify_dataset(test_ds, test_sample_size, seed) + test_ds = stratify_dataset(test_ds, sample_size=test_sample_size, fraction=fraction, seed=seed) return train_ds, val_ds, test_ds # type: ignore[return-value] @@ -136,11 +130,8 @@ def setup_cifar10_dataset( train_ds = train_ds.rename_column("img", "image") test_ds = test_ds.rename_column("img", "image") - train_sample_size = define_sample_size_for_dataset(train_ds, fraction, train_sample_size) - test_sample_size = define_sample_size_for_dataset(test_ds, fraction, test_sample_size) - - train_ds = stratify_dataset(train_ds, train_sample_size, seed) - test_ds = stratify_dataset(test_ds, test_sample_size, seed) + train_ds = stratify_dataset(train_ds, sample_size=train_sample_size, fraction=fraction, seed=seed) + test_ds = stratify_dataset(test_ds, sample_size=test_sample_size, fraction=fraction, seed=seed) train_ds, val_ds = split_train_into_train_val(train_ds, seed) return train_ds, val_ds, test_ds # type: ignore[return-value] diff --git a/src/pruna/data/datasets/prompt.py b/src/pruna/data/datasets/prompt.py index 1f6fab71..7764d23b 100644 --- a/src/pruna/data/datasets/prompt.py +++ b/src/pruna/data/datasets/prompt.py @@ -12,24 +12,141 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple +from typing import Literal, Tuple, get_args from datasets import Dataset, load_dataset +from pruna.data.utils import stratify_dataset from pruna.logging.logger import pruna_logger +GenEvalCategory = Literal["single_object", "two_object", "counting", "colors", "position", "color_attr"] +HPSCategory = Literal["anime", "concept-art", "paintings", "photo"] +OneIGCategory = Literal[ + "Anime_Stylization", + "General_Object", + "Knowledge_Reasoning", + "Multilingualism", + "Portrait", + "Text_Rendering", + "3d rendering", + "Baroque", + "Celluloid", + "Chibi", + "Chinese ink painting", + "Cyberpunk", + "Ghibli", + "LEGO", + "None", + "PPT generation", + "Pixar", + "Rococo", + "Ukiyo-e", + "abstract expressionism", + "advertising imagery", + "art nouveau", + "artistic renderings", + "biology", + "blackboard text", + "chemistry", + "clay", + "comic", + "common sense", + "computer science", + "crayon", + "cubism", + "fauvism", + "floating-frame text", + "geography", + "graffiti", + "graffiti-style text", + "impasto", + "impressionism", + "line art", + "long text rendering", + "mathematics", + "menu", + "minimalism", + "natural-scene text", + "noir", + "pencil sketch", + "physics", + "pixel art", + "pointillism", + "pop art", + "poster design", + "silvertone", + "stone sculpture", + "vintage", + "vivid cold", + "vivid warm", + "watercolor", +] +PartiCategory = Literal[ + "Abstract", + "Animals", + "Artifacts", + "Arts", + "Food & Beverage", + "Illustrations", + "Indoor Scenes", + "Outdoor Scenes", + "People", + "Produce & Plants", + "Vehicles", + "World Knowledge", + "Basic", + "Complex", + "Fine-grained Detail", + "Imagination", + "Linguistic Structures", + "Perspective", + "Properties & Positioning", + "Quantity", + "Simple Detail", + "Style & Format", + "Writing & Symbols", +] +ImgEditCategory = Literal["replace", "add", "remove", "adjust", "extract", "style", "background", "compose"] +GEditBenchCategory = Literal[ + "background_change", + "color_alter", + "material_alter", + "motion_change", + "ps_human", + "style_change", + "subject_add", + "subject_remove", + "subject_replace", + "text_change", + "tone_transfer", +] +DPGCategory = Literal["entity", "attribute", "relation", "global", "other"] -def setup_drawbench_dataset(seed: int) -> Tuple[Dataset, Dataset, Dataset]: + +def _to_oneig_record(row: dict, questions_by_key: dict[str, dict]) -> dict: + """Convert OneIG row to unified record format.""" + row_category = row.get("category", "") + row_class = row.get("class", "None") or "None" + qd_name = _CATEGORY_TO_QD.get(row_category, "") + lookup_key = f"{qd_name}_{row.get('id', '')}" if qd_name else "" + q_info = questions_by_key.get(lookup_key, {}) + return { + "text": row.get("prompt_en", row.get("prompt", "")), + "subset": "Text_Rendering" if row_category in ("Text_Rendering", "Text Rendering") else row_category, + "text_content": row_class if row_class != "None" else None, + "category": row_category, + "class": row_class, + "questions": q_info.get("questions", {}), + "dependencies": q_info.get("dependencies", {}), + } + + +def setup_drawbench_dataset() -> Tuple[Dataset, Dataset, Dataset]: """ Setup the DrawBench dataset. License: Apache 2.0 - Parameters - ---------- - seed : int - The seed to use. - Returns ------- Tuple[Dataset, Dataset, Dataset] @@ -41,7 +158,13 @@ def setup_drawbench_dataset(seed: int) -> Tuple[Dataset, Dataset, Dataset]: return ds.select([0]), ds.select([0]), ds -def setup_parti_prompts_dataset(seed: int) -> Tuple[Dataset, Dataset, Dataset]: +def setup_parti_prompts_dataset( + seed: int, + fraction: float = 1.0, + train_sample_size: int | None = None, + test_sample_size: int | None = None, + category: PartiCategory | list[PartiCategory] | None = None, +) -> Tuple[Dataset, Dataset, Dataset]: """ Setup the Parti Prompts dataset. @@ -51,21 +174,177 @@ def setup_parti_prompts_dataset(seed: int) -> Tuple[Dataset, Dataset, Dataset]: ---------- seed : int The seed to use. + fraction : float + The fraction of the dataset to use. + train_sample_size : int | None + The sample size to use for the train dataset (unused; train/val are dummy). + test_sample_size : int | None + The sample size to use for the test dataset. + category : PartiCategory | list[PartiCategory] | None + Filter by Category or Challenge. Returns ------- Tuple[Dataset, Dataset, Dataset] - The Parti Prompts dataset. + The Parti Prompts dataset (dummy train, dummy val, test). """ ds = load_dataset("nateraw/parti-prompts")["train"] # type: ignore[index] + + if category is not None: + categories = [category] if not isinstance(category, list) else category + ds = ds.filter(lambda x: x["Category"] in categories or x["Challenge"] in categories) + + ds = stratify_dataset(ds, sample_size=test_sample_size, fraction=fraction) ds = ds.rename_column("Prompt", "text") pruna_logger.info("PartiPrompts is a test-only dataset. Do not use it for training or validation.") return ds.select([0]), ds.select([0]), ds -def setup_genai_bench_dataset(seed: int) -> Tuple[Dataset, Dataset, Dataset]: +def _generate_geneval_question(entry: dict) -> list[str]: + """Generate evaluation questions from GenEval metadata.""" + tag = entry.get("tag", "") + include = entry.get("include", []) + questions = [] + + for obj in include: + cls = obj.get("class", "") + if "color" in obj: + questions.append(f"Does the image contain a {obj['color']} {cls}?") + elif "count" in obj: + questions.append(f"Does the image contain exactly {obj['count']} {cls}(s)?") + else: + questions.append(f"Does the image contain a {cls}?") + + if tag == "position" and len(include) >= 2: + a_cls = include[0].get("class", "") + b_cls = include[1].get("class", "") + pos = include[1].get("position") + if pos and pos[0]: + questions.append(f"Is the {b_cls} {pos[0]} the {a_cls}?") + + return questions + + +def setup_geneval_dataset( + seed: int, + fraction: float = 1.0, + train_sample_size: int | None = None, + test_sample_size: int | None = None, + category: GenEvalCategory | list[GenEvalCategory] | None = None, +) -> Tuple[Dataset, Dataset, Dataset]: """ - Setup the GenAI Bench dataset. + Setup the GenEval benchmark dataset. + + License: MIT + + Parameters + ---------- + seed : int + The seed to use. + fraction : float + The fraction of the dataset to use. + train_sample_size : int | None + Unused; train/val are dummy. + test_sample_size : int | None + The sample size to use for the test dataset. + category : GenEvalCategory | list[GenEvalCategory] | None + Filter by category. Available: single_object, two_object, counting, colors, position, color_attr. + + Returns + ------- + Tuple[Dataset, Dataset, Dataset] + The GenEval dataset (dummy train, dummy val, test). + """ + import json + + import requests + + url = "https://raw.githubusercontent.com/djghosh13/geneval/d927da8e42fde2b1b5cd743da4df5ff83c1654ff/prompts/evaluation_metadata.jsonl" + response = requests.get(url, timeout=30) + response.raise_for_status() + data = [json.loads(line) for line in response.text.splitlines()] + + if category is not None: + categories = [category] if not isinstance(category, list) else category + data = [entry for entry in data if entry.get("tag") in categories] + + records = [] + for entry in data: + questions = _generate_geneval_question(entry) + records.append( + { + "text": entry["prompt"], + "tag": entry.get("tag", ""), + "questions": questions, + } + ) + + ds = Dataset.from_list(records) + ds = stratify_dataset(ds, sample_size=test_sample_size, fraction=fraction) + pruna_logger.info("GenEval is a test-only dataset. Do not use it for training or validation.") + return ds.select([0]), ds.select([0]), ds + + +def setup_hps_dataset( + seed: int, + fraction: float = 1.0, + train_sample_size: int | None = None, + test_sample_size: int | None = None, + category: HPSCategory | list[HPSCategory] | None = None, +) -> Tuple[Dataset, Dataset, Dataset]: + """ + Setup the HPD (Human Preference Dataset) for the HPS (Human Preference Score) benchmark. + + License: MIT + + Parameters + ---------- + seed : int + The seed to use. + fraction : float + The fraction of the dataset to use. + train_sample_size : int | None + Unused; train/val are dummy. + test_sample_size : int | None + The sample size to use for the test dataset. + category : HPSCategory | list[HPSCategory] | None + Filter by category. Available: anime, concept-art, paintings, photo. + + Returns + ------- + Tuple[Dataset, Dataset, Dataset] + The HPD dataset (dummy train, dummy val, test). + """ + import json + + from huggingface_hub import hf_hub_download + + categories_to_load = ( + list(get_args(HPSCategory)) if category is None else ([category] if not isinstance(category, list) else category) + ) + + all_prompts = [] + for cat in categories_to_load: + file_path = hf_hub_download("zhwang/HPDv2", f"{cat}.json", subfolder="benchmark", repo_type="dataset") + with open(file_path, "r", encoding="utf-8") as f: + prompts = json.load(f) + for prompt in prompts: + all_prompts.append({"text": prompt, "category": cat}) + + ds = Dataset.from_list(all_prompts) + ds = stratify_dataset(ds, sample_size=test_sample_size, fraction=fraction) + pruna_logger.info("HPD is a test-only dataset. Do not use it for training or validation.") + return ds.select([0]), ds.select([0]), ds + + +def setup_long_text_bench_dataset( + seed: int, + fraction: float = 1.0, + train_sample_size: int | None = None, + test_sample_size: int | None = None, +) -> Tuple[Dataset, Dataset, Dataset]: + """ + Setup the Long Text Bench dataset. License: Apache 2.0 @@ -73,6 +352,31 @@ def setup_genai_bench_dataset(seed: int) -> Tuple[Dataset, Dataset, Dataset]: ---------- seed : int The seed to use. + fraction : float + The fraction of the dataset to use. + train_sample_size : int | None + Unused; train/val are dummy. + test_sample_size : int | None + The sample size to use for the test dataset. + + Returns + ------- + Tuple[Dataset, Dataset, Dataset] + The Long Text Bench dataset (dummy train, dummy val, test). + """ + ds = load_dataset("X-Omni/LongText-Bench")["train"] # type: ignore[index] + ds = ds.rename_column("text", "text_content") + ds = ds.rename_column("prompt", "text") + ds = stratify_dataset(ds, sample_size=test_sample_size, fraction=fraction) + pruna_logger.info("LongTextBench is a test-only dataset. Do not use it for training or validation.") + return ds.select([0]), ds.select([0]), ds + + +def setup_genai_bench_dataset() -> Tuple[Dataset, Dataset, Dataset]: + """ + Setup the GenAI Bench dataset. + + License: Apache 2.0 Returns ------- @@ -83,3 +387,288 @@ def setup_genai_bench_dataset(seed: int) -> Tuple[Dataset, Dataset, Dataset]: ds = ds.rename_column("Prompt", "text") pruna_logger.info("GenAI-Bench is a test-only dataset. Do not use it for training or validation.") return ds.select([0]), ds.select([0]), ds + + +def setup_imgedit_dataset( + seed: int, + fraction: float = 1.0, + train_sample_size: int | None = None, + test_sample_size: int | None = None, + category: ImgEditCategory | list[ImgEditCategory] | None = None, +) -> Tuple[Dataset, Dataset, Dataset]: + """ + Setup the ImgEdit benchmark dataset for image editing evaluation. + + License: Apache 2.0 + + Parameters + ---------- + seed : int + The seed to use. + fraction : float + The fraction of the dataset to use. + train_sample_size : int | None + Unused; train/val are dummy. + test_sample_size : int | None + The sample size to use for the test dataset. + category : ImgEditCategory | list[ImgEditCategory] | None + Filter by edit type. Available: replace, add, remove, adjust, extract, style, + background, compose. If None, returns all categories. + + Returns + ------- + Tuple[Dataset, Dataset, Dataset] + The ImgEdit dataset (dummy train, dummy val, test). + """ + import json + + import requests + + instructions_url = "https://raw.githubusercontent.com/PKU-YuanGroup/ImgEdit/b3eb8e74d7cd1fd0ce5341eaf9254744a8ab4c0b/Benchmark/Basic/basic_edit.json" + judge_prompts_url = "https://raw.githubusercontent.com/PKU-YuanGroup/ImgEdit/c14480ac5e7b622e08cd8c46f96624a48eb9ab46/Benchmark/Basic/prompts.json" + response_instructions = requests.get(instructions_url, timeout=30) + response_judge_prompts = requests.get(judge_prompts_url, timeout=30) + response_instructions.raise_for_status() + response_judge_prompts.raise_for_status() + instructions: dict = json.loads(response_instructions.text) + judge_prompts: dict = json.loads(response_judge_prompts.text) + + categories = [category] if category is not None and not isinstance(category, list) else category + records = [] + for _, instruction in instructions.items(): + edit_type = instruction.get("edit_type", "") + + if categories is not None and edit_type not in categories: + continue + + records.append( + { + "text": instruction.get("prompt", ""), + "category": edit_type, + "image_id": instruction.get("id", ""), + "judge_prompt": judge_prompts.get(edit_type, ""), + } + ) + + ds = Dataset.from_list(records) + ds = stratify_dataset(ds, sample_size=test_sample_size, fraction=fraction) + + if len(ds) == 0: + raise ValueError(f"No samples found for category '{category}'.") + + pruna_logger.info("ImgEdit is a test-only dataset. Do not use it for training or validation.") + return ds.select([0]), ds.select([0]), ds + + +_CATEGORY_TO_QD: dict[str, str] = { + "Anime_Stylization": "anime", + "Portrait": "human", + "General_Object": "object", +} + +_ONEIG_ALIGNMENT_BASE = "https://raw.githubusercontent.com/OneIG-Bench/OneIG-Benchmark/41b49831e79e6dde5323618c164da1c4cf0f699d/scripts/alignment/Q_D" + + +def _fetch_oneig_alignment() -> dict[str, dict]: + """Fetch alignment questions from per-category Q_D files (InferBench-style).""" + import json + + import requests + + questions_by_key: dict[str, dict] = {} + for qd_name in ("anime", "human", "object"): + url = f"{_ONEIG_ALIGNMENT_BASE}/{qd_name}.json" + resp = requests.get(url, timeout=30) + resp.raise_for_status() + data = json.loads(resp.text) + for prompt_id, item in data.items(): + q = item.get("question", {}) + d = item.get("dependency", {}) + if isinstance(q, str): + q = json.loads(q) + if isinstance(d, str): + d = json.loads(d) + questions_by_key[f"{qd_name}_{prompt_id}"] = {"questions": q, "dependencies": d} + return questions_by_key + + +def setup_oneig_dataset( + seed: int, + fraction: float = 1.0, + train_sample_size: int | None = None, + test_sample_size: int | None = None, + category: OneIGCategory | list[OneIGCategory] | None = None, +) -> Tuple[Dataset, Dataset, Dataset]: + """ + Setup the OneIG benchmark dataset. + + License: Apache 2.0 + + Parameters + ---------- + seed : int + The seed to use. + fraction : float + The fraction of the dataset to use. + train_sample_size : int | None + Unused; train/val are dummy. + test_sample_size : int | None + The sample size to use for the test dataset. + category : OneIGCategory | list[OneIGCategory] | None + Filter by dataset category (Anime_Stylization, Portrait, etc.) or class (fauvism, + watercolor, etc.). If None, returns all subsets. + + Returns + ------- + Tuple[Dataset, Dataset, Dataset] + The OneIG dataset (dummy train, dummy val, test). + """ + questions_by_key = _fetch_oneig_alignment() + + ds_raw = load_dataset("OneIG-Bench/OneIG-Bench", "OneIG-Bench")["train"] # type: ignore[index] + records = [_to_oneig_record(dict(row), questions_by_key) for row in ds_raw] + ds = Dataset.from_list(records) + + if category is not None: + categories = [category] if not isinstance(category, list) else category + ds = ds.filter( + lambda x: x.get("category") in categories or x.get("class") in categories or x.get("subset") in categories + ) + + ds = stratify_dataset(ds, sample_size=test_sample_size, fraction=fraction) + + if len(ds) == 0: + raise ValueError(f"No samples found for category '{category}'. Check that the category exists and has data.") + + pruna_logger.info("OneIG is a test-only dataset. Do not use it for training or validation.") + return ds.select([0]), ds.select([0]), ds + + +def setup_gedit_dataset( + seed: int, + fraction: float = 1.0, + train_sample_size: int | None = None, + test_sample_size: int | None = None, + category: GEditBenchCategory | list[GEditBenchCategory] | None = None, +) -> Tuple[Dataset, Dataset, Dataset]: + """ + Setup the GEditBench dataset for image editing evaluation. + + License: Apache 2.0 + + Parameters + ---------- + seed : int + The seed to use. + fraction : float + The fraction of the dataset to use. + train_sample_size : int | None + Unused; train/val are dummy. + test_sample_size : int | None + The sample size to use for the test dataset. + category : GEditBenchCategory | list[GEditBenchCategory] | None + Filter by task type. Available: background_change, color_alter, material_alter, + motion_change, ps_human, style_change, subject_add, subject_remove, subject_replace, + text_change, tone_transfer. If None, returns all categories. + + Returns + ------- + Tuple[Dataset, Dataset, Dataset] + The GEditBench dataset (dummy train, dummy val, test). + """ + task_type_map = { + "subject_add": "subject-add", + "subject_remove": "subject-remove", + "subject_replace": "subject-replace", + } + task_type_to_category = {v: k for k, v in task_type_map.items()} + + ds = load_dataset("stepfun-ai/GEdit-Bench")["train"] # type: ignore[index] + ds = ds.filter(lambda x: x["instruction_language"] == "en") + + categories = [category] if category is not None and not isinstance(category, list) else category + if categories is not None: + hf_types = [task_type_map.get(c, c) for c in categories] + ds = ds.filter(lambda x: x["task_type"] in hf_types) + + records = [] + for row in ds: + task_type = row.get("task_type", "") + category_name = task_type_to_category.get(task_type, task_type) + records.append( + { + "text": row.get("instruction", ""), + "category": category_name, + } + ) + + ds = Dataset.from_list(records) + ds = stratify_dataset(ds, sample_size=test_sample_size, fraction=fraction) + + if len(ds) == 0: + raise ValueError(f"No samples found for category '{category}'.") + + pruna_logger.info("GEditBench is a test-only dataset. Do not use it for training or validation.") + return ds.select([0]), ds.select([0]), ds + + +def setup_dpg_dataset( + seed: int, + fraction: float = 1.0, + train_sample_size: int | None = None, + test_sample_size: int | None = None, + category: DPGCategory | list[DPGCategory] | None = None, +) -> Tuple[Dataset, Dataset, Dataset]: + """ + Setup the DPG (Dense Prompt Graph) benchmark dataset. + + License: Apache 2.0 + + Parameters + ---------- + seed : int + The seed to use. + fraction : float + The fraction of the dataset to use. + train_sample_size : int | None + Unused; train/val are dummy. + test_sample_size : int | None + The sample size to use for the test dataset. + category : DPGCategory | list[DPGCategory] | None + Filter by category. Available: entity, attribute, relation, global, other. + + Returns + ------- + Tuple[Dataset, Dataset, Dataset] + The DPG dataset (dummy train, dummy val, test). + """ + import csv + import io + from collections import defaultdict + + import requests + + url = "https://raw.githubusercontent.com/TencentQQGYLab/ELLA/main/dpg_bench/dpg_bench.csv" + response = requests.get(url, timeout=30) + response.raise_for_status() + reader = csv.DictReader(io.StringIO(response.text)) + + categories = [category] if category is not None and not isinstance(category, list) else category + grouped: dict[tuple[str, str], list[str]] = defaultdict(list) + for row in reader: + row_category = row.get("category_broad", row.get("category", "")) + + if categories is not None and row_category not in categories: + continue + + key = (row.get("text", ""), row_category) + q = row.get("question_natural_language", "") + if q and q not in grouped[key]: + grouped[key].append(q) + + records = [{"text": text, "category": cat, "questions": qs} for (text, cat), qs in grouped.items()] + + ds = Dataset.from_list(records) + ds = stratify_dataset(ds, sample_size=test_sample_size, fraction=fraction) + pruna_logger.info("DPG is a test-only dataset. Do not use it for training or validation.") + return ds.select([0]), ds.select([0]), ds diff --git a/src/pruna/data/datasets/text_generation.py b/src/pruna/data/datasets/text_generation.py index 9049dac8..69f0df8e 100644 --- a/src/pruna/data/datasets/text_generation.py +++ b/src/pruna/data/datasets/text_generation.py @@ -33,8 +33,7 @@ def setup_wikitext_dataset() -> Tuple[Dataset, Dataset, Dataset]: The WikiText dataset. """ train_dataset, val_dataset, test_dataset = load_dataset( - path="mikasenghaas/wikitext-2", - split=["train", "validation", "test"] + path="mikasenghaas/wikitext-2", split=["train", "validation", "test"] ) return train_dataset, val_dataset, test_dataset # type: ignore[return-value] diff --git a/src/pruna/data/datasets/text_to_video.py b/src/pruna/data/datasets/text_to_video.py index 7906d197..633b49c9 100644 --- a/src/pruna/data/datasets/text_to_video.py +++ b/src/pruna/data/datasets/text_to_video.py @@ -15,19 +15,44 @@ from __future__ import annotations from importlib.resources import as_file, files -from typing import List, Tuple +from typing import Literal, Tuple from datasets import Dataset, load_dataset +VBenchCategory = Literal[ + "aesthetic_quality", + "appearance_style", + "background_consistency", + "color", + "dynamic_degree", + "human_action", + "imaging_quality", + "motion_smoothness", + "multiple_objects", + "object_class", + "overall_consistency", + "scene", + "spatial_relationship", + "subject_consistency", + "temporal_flickering", + "temporal_style", +] -def setup_vbench_dataset(category: str | List[str] | None = None) -> Tuple[Dataset, Dataset, Dataset]: + +def setup_vbench_dataset( + category: VBenchCategory | list[VBenchCategory] | None = None, +) -> Tuple[Dataset, Dataset, Dataset]: """ Setup the VBench dataset from the VBench full info json file. Parameters ---------- - category : str | List[str] | None - The dimension(s) of the dataset to load. + category : VBenchCategory | list[VBenchCategory] | None + Filter by dimension(s). Available: aesthetic_quality, appearance_style, + background_consistency, color, dynamic_degree, human_action, imaging_quality, + motion_smoothness, multiple_objects, object_class, overall_consistency, + scene, spatial_relationship, subject_consistency, temporal_flickering, + temporal_style. Returns ------- diff --git a/src/pruna/data/pruna_datamodule.py b/src/pruna/data/pruna_datamodule.py index 40ed66e4..6d1eaadd 100644 --- a/src/pruna/data/pruna_datamodule.py +++ b/src/pruna/data/pruna_datamodule.py @@ -137,6 +137,9 @@ def from_string( dataloader_args: dict = dict(), seed: int = 42, category: str | list[str] | None = None, + fraction: float = 1.0, + train_sample_size: int | None = None, + test_sample_size: int | None = None, ) -> "PrunaDataModule": """ Create a PrunaDataModule from the dataset name with preimplemented dataset loading. @@ -153,9 +156,14 @@ def from_string( Any additional arguments for the dataloader. seed : int The seed to use. - category : str | list[str] | None The category of the dataset. + fraction : float + Fraction of dataset to use (when setup fn accepts it). + train_sample_size : int | None + Train sample size (when setup fn accepts it). + test_sample_size : int | None + Test sample size (when setup fn accepts it). Returns ------- @@ -174,6 +182,15 @@ def from_string( if "category" in inspect.signature(setup_fn).parameters: setup_fn = partial(setup_fn, category=category) + sampling_params = { + "fraction": fraction, + "train_sample_size": train_sample_size, + "test_sample_size": test_sample_size, + } + for param, value in sampling_params.items(): + if param in inspect.signature(setup_fn).parameters: + setup_fn = partial(setup_fn, **{param: value}) + train_ds, val_ds, test_ds = setup_fn() return cls.from_datasets( diff --git a/src/pruna/data/utils.py b/src/pruna/data/utils.py index 73fc35a6..2096f9e6 100644 --- a/src/pruna/data/utils.py +++ b/src/pruna/data/utils.py @@ -14,8 +14,18 @@ from __future__ import annotations +import inspect import random -from typing import Any, Tuple, Union +from typing import ( + Any, + Callable, + Literal, + Tuple, + Union, + get_args, + get_origin, + get_type_hints, +) import torch from datasets import Dataset, IterableDataset @@ -38,9 +48,51 @@ def __init__(self, message: str = "Tokenizer is missing. Please provide a valid super().__init__(message) -def split_train_into_train_val_test( - dataset: Dataset | IterableDataset, seed: int -) -> Tuple[Dataset, Dataset, Dataset]: +def get_literal_values_from_param(func: Callable[..., Any], param_name: str) -> list[str] | None: + """ + Extract Literal values from a function parameter's type annotation (handles Union). + + Parameters + ---------- + func : Callable[..., Any] + The function to inspect. + param_name : str + The parameter name to extract Literal values from. + + Returns + ------- + list[str] | None + List of string values if the parameter is a Literal type, None otherwise. + """ + unwrapped = getattr(func, "func", func) + sig = inspect.signature(unwrapped) + if param_name not in sig.parameters: + return None + ann = sig.parameters[param_name].annotation + if ann is inspect.Parameter.empty: + return None + if isinstance(ann, str): + try: + hints = get_type_hints(unwrapped) + ann = hints.get(param_name, ann) + except Exception: + return None + + def extract(ann: Any) -> list[str] | None: + if ann is None or ann is type(None): + return None + if get_origin(ann) is Literal: + args = get_args(ann) + return list(args) if args and all(isinstance(a, str) for a in args) else None + for arg in get_args(ann) or (): + if (r := extract(arg)) is not None: + return r + return None + + return extract(ann) + + +def split_train_into_train_val_test(dataset: Dataset | IterableDataset, seed: int) -> Tuple[Dataset, Dataset, Dataset]: """ Split the training dataset into train, validation, and test. @@ -64,9 +116,7 @@ def split_train_into_train_val_test( return train_ds, val_ds, test_ds -def split_train_into_train_val( - dataset: Dataset | IterableDataset, seed: int -) -> Tuple[Dataset, Dataset]: +def split_train_into_train_val(dataset: Dataset | IterableDataset, seed: int) -> Tuple[Dataset, Dataset]: """ Split the trainingdataset into train and validation. @@ -88,9 +138,7 @@ def split_train_into_train_val( return train_ds, val_ds -def split_val_into_val_test( - dataset: Dataset | IterableDataset, seed: int -) -> Tuple[Dataset, Dataset]: +def split_val_into_val_test(dataset: Dataset | IterableDataset, seed: int) -> Tuple[Dataset, Dataset]: """ Split the dataset into validation and test. @@ -190,58 +238,51 @@ def recover_text_from_dataloader(dataloader: DataLoader, tokenizer: Any) -> list return texts -def stratify_dataset(dataset: Dataset, sample_size: int, seed: int = 42) -> Dataset: +def stratify_dataset( + dataset: Dataset, + sample_size: int | None = None, + fraction: float = 1.0, + seed: int | None = None, +) -> Dataset: """ - Stratify the dataset into a specific size. + Stratify the dataset to a specific size via optional shuffled sampling. Parameters ---------- dataset : Dataset The dataset to stratify. - sample_size : int - The size to stratify. - seed : int - The seed to use for sampling the dataset. + sample_size : int or None + Target size. If None, uses fraction or full dataset. + fraction : float + Fraction of dataset to use (0.0-1.0). Ignored if sample_size is set. + seed : int or None + Random seed for reproducible shuffled sampling. + If None, no shuffling is performed and the first target_size elements are returned. Returns ------- Dataset The stratified dataset. + + Raises + ------ + ValueError + If both fraction < 1.0 and sample_size are provided. """ + if fraction < 1.0 and sample_size is not None: + raise ValueError("Fraction and sample_size cannot be used together.") + target_size = int(len(dataset) * fraction) if fraction < 1.0 else (sample_size or len(dataset)) + dataset_length = len(dataset) - if dataset_length < sample_size: + if dataset_length < target_size: pruna_logger.warning( "Dataset length is less than the size to stratify." - f"Using the entire dataset. ({dataset_length} < {sample_size})" + f"Using the entire dataset. ({dataset_length} < {target_size})" ) return dataset indices = list(range(dataset_length)) - random.Random(seed).shuffle(indices) - selected_indices = indices[:sample_size] - dataset = dataset.select(selected_indices) - return dataset - - -def define_sample_size_for_dataset(dataset: Dataset, fraction: float, sample_size: int | None = None) -> int: - """ - Define the sample size for the dataset. - - Parameters - ---------- - dataset : Dataset - The dataset to define the sample size for. - fraction : float - The fraction of the dataset to sample. - sample_size : int | None - The sample size to use. - - Returns - ------- - int - The sample size for the dataset. - """ - if fraction < 1.0 and (sample_size is not None): - raise ValueError("Fraction and sample sizes cannot be used together.") - sample_size = int(len(dataset) * fraction) if fraction < 1.0 else sample_size or len(dataset) - return sample_size + if seed is not None: + random.Random(seed).shuffle(indices) + selected_indices = indices[:target_size] + return dataset.select(selected_indices) diff --git a/src/pruna/evaluation/benchmarks.py b/src/pruna/evaluation/benchmarks.py new file mode 100644 index 00000000..e52ae463 --- /dev/null +++ b/src/pruna/evaluation/benchmarks.py @@ -0,0 +1,296 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field + +from pruna.data import base_datasets +from pruna.data.utils import get_literal_values_from_param +from pruna.evaluation.metrics import MetricRegistry + + +@dataclass +class Benchmark: + """ + Metadata for a benchmark dataset. + + Parameters + ---------- + name : str + Human-readable name for display and base_datasets lookup. + description : str + Description of what the benchmark evaluates. + metrics : list[str] + List of metric names used for evaluation. + task_type : str + Type of task the benchmark evaluates (e.g., 'text_to_image'). + reference : str | None + URL to the canonical paper (e.g., arXiv) for this benchmark. + """ + + name: str + description: str + metrics: list[str] + task_type: str + reference: str | None = None + category: str | list[str] | None = field(default=None, init=False) + + @property + def lookup_key(self) -> str: + """Key for base_datasets lookup (name with spaces removed).""" + return self.name.replace(" ", "") + + def __post_init__(self) -> None: + """Populate category from setup function's Literal.""" + if self.lookup_key in base_datasets: + setup_fn = base_datasets[self.lookup_key][0] + literal_values = get_literal_values_from_param(setup_fn, "category") + self.category = literal_values if literal_values else None + + +class BenchmarkRegistry: + """ + Registry for benchmarks. + + Metrics per benchmark are set to those explicitly used in the reference + paper (see reference URL). All entries verified from paper evaluation + sections (ar5iv/HTML or PDF) as of verification pass: + + - Parti Prompts (2206.10789 §5.2, §5.4): human side-by-side only on P222. + - DrawBench (2205.11487 §4.3): human raters only; COCO uses FID + CLIP. + - GenAI Bench (2406.13743): VQAScore only (web/PWC; ar5iv failed). + - VBench (2311.17982): 16 dimension-specific methods; no single Pruna metric. + - COCO (2205.11487 §4.1): FID and CLIP score for fidelity and alignment. + - ImageNet (1409.0575 §4): top-1/top-5 classification accuracy. + - WikiText (1609.07843 §5): perplexity on validation/test. + - GenEval (2310.11513 §3.2): Mask2Former + CLIP color pipeline, binary score. + - HPS (2306.09341): HPS v2 scoring model (CLIP fine-tuned on HPD v2). + - ImgEdit (2505.20275 §4.2): GPT-4o 1–5 ratings and ImgEdit-Judge. + - Long Text Bench (2507.22058 §4): Text Accuracy (OCR, Qwen2.5-VL-7B). + - GEditBench (2504.17761 §4.2): VIEScore (SQ, PQ, O via GPT-4.1/Qwen2.5-VL). + - OneIG (2506.07977 §4.1): per-dimension metrics (semantic alignment, ED, etc.). + - DPG (2403.05135): DSG-style graph score, mPLUG-large adjudicator. + """ + + _registry: dict[str, Benchmark] = {} + + @classmethod + def _register(cls, benchmark: Benchmark) -> None: + missing = [m for m in benchmark.metrics if not MetricRegistry.has_metric(m)] + if missing: + raise ValueError( + f"Benchmark '{benchmark.name}' references metrics not in MetricRegistry: {missing}." + ) + if benchmark.lookup_key not in base_datasets: + available = ", ".join(base_datasets.keys()) + raise ValueError( + f"Benchmark '{benchmark.name}' (lookup key '{benchmark.lookup_key}') is not in base_datasets. " + f"Available: {available}" + ) + cls._registry[benchmark.lookup_key] = benchmark + + @classmethod + def get(cls, name: str) -> Benchmark: + """ + Get benchmark metadata by name. + + Parameters + ---------- + name : str + The benchmark name. + + Returns + ------- + Benchmark + The benchmark metadata. + + Raises + ------ + KeyError + If benchmark name is not found. + """ + key = name.replace(" ", "") + if key not in cls._registry: + raise KeyError(f"Benchmark '{name}' not found. Available: {', '.join(cls._registry)}") + return cls._registry[key] + + @classmethod + def list(cls, task_type: str | None = None) -> list[str]: + """ + List available benchmark names. + + Parameters + ---------- + task_type : str | None + Filter by task type (e.g., 'text_to_image', 'text_to_video'). + If None, returns all benchmarks. + + Returns + ------- + list[str] + List of benchmark names. + """ + if task_type is None: + return list(cls._registry) + return [key for key, b in cls._registry.items() if b.task_type == task_type] + + +for _benchmark in [ + Benchmark( + name="Parti Prompts", + description=( + "Holistic benchmark from Google Research with over 1,600 English prompts across 12 categories " + "and 11 challenge aspects. Evaluates text-to-image models on abstract thinking, world knowledge, " + "perspectives, and symbol rendering from basic to complex compositions." + ), + metrics=[], # Paper uses human evaluation only; pass explicit metrics if needed + task_type="text_to_image", + reference="https://arxiv.org/abs/2206.10789", + ), + Benchmark( + name="DrawBench", + description=( + "Comprehensive benchmark from the Imagen team for rigorous evaluation of text-to-image models. " + "Enables side-by-side comparison on sample quality and image-text alignment with human raters." + ), + metrics=[], # Paper uses human evaluation only; pass explicit metrics if needed + task_type="text_to_image", + reference="https://arxiv.org/abs/2205.11487", + ), + Benchmark( + name="GenAI Bench", + description=( + "1,600 prompts from professional designers for compositional text-to-visual generation. " + "Covers basic skills (scene, attributes, spatial relationships) to advanced reasoning " + "(counting, comparison, logic/negation) with over 24k human ratings." + ), + metrics=[], # Paper uses VQAScore only; not in Pruna + task_type="text_to_image", + reference="https://arxiv.org/abs/2406.13743", + ), + Benchmark( + name="VBench", + description=( + "Comprehensive benchmark suite for video generative models. Decomposes video quality into " + "16 disentangled dimensions: temporal flickering, motion smoothness, subject consistency, " + "spatial relationship, color, aesthetic quality, and more." + ), + metrics=[], # Paper uses dimension-specific automated metrics; not all in Pruna + task_type="text_to_video", + reference="https://arxiv.org/abs/2311.17982", + ), + Benchmark( + name="COCO", + description=( + "MS-COCO for text-to-image evaluation (Imagen, 2205.11487). Paper reports " + "FID for fidelity and CLIP score for image-text alignment." + ), + metrics=["fid", "clip_score"], # §4.1: FID + CLIP score + task_type="text_to_image", + reference="https://arxiv.org/abs/2205.11487", + ), + Benchmark( + name="ImageNet", + description=( + "Large-scale image classification benchmark with 1,000 classes. Standard evaluation " + "for vision model accuracy on object recognition." + ), + metrics=["accuracy"], + task_type="image_classification", + reference="https://arxiv.org/abs/1409.0575", + ), + Benchmark( + name="WikiText", + description=( + "Language modeling benchmark based on Wikipedia articles. Standard evaluation " + "for text generation quality via perplexity." + ), + metrics=["perplexity"], + task_type="text_generation", + reference="https://arxiv.org/abs/1609.07843", + ), + Benchmark( + name="GenEval", + description=( + "Compositional text-to-image benchmark with 6 categories: single object, two object, " + "counting, colors, position, color attributes. Evaluates fine-grained alignment " + "between prompts and generated images via VQA-style questions." + ), + metrics=["clip_score"], # §3.2: Mask2Former; not in Pruna + task_type="text_to_image", + reference="https://arxiv.org/abs/2310.11513", + ), + Benchmark( + name="HPS", + description=( + "HPD (Human Preference Dataset) v2 for HPS (Human Preference Score) evaluation. " + "Covers anime, concept-art, paintings, and photo styles with human preference data." + ), + metrics=[], # Paper uses HPS scoring model; not in Pruna + task_type="text_to_image", + reference="https://arxiv.org/abs/2306.09341", + ), + Benchmark( + name="ImgEdit", + description=( + "Image editing benchmark with 8 edit types: replace, add, remove, adjust, extract, " + "style, background, compose. Evaluates instruction-following for inpainting and editing." + ), + metrics=[], # Paper uses GPT-4o/ImgEdit-Judge; not in Pruna + task_type="text_to_image", + reference="https://arxiv.org/abs/2505.20275", + ), + Benchmark( + name="Long Text Bench", + description=( + "Text-to-image benchmark for long, detailed prompts. Evaluates model ability to " + "handle complex multi-clause descriptions and maintain coherence across long instructions." + ), + metrics=[], # Paper uses text_score/TIT-Score; not in Pruna + task_type="text_to_image", + reference="https://arxiv.org/abs/2507.22058", + ), + Benchmark( + name="GEditBench", + description=( + "General image editing benchmark with 11 task types: background change, color alter, " + "material alter, motion change, style change, subject add/remove/replace, text change, " + "tone transfer, and human retouching." + ), + metrics=[], # Paper uses VIEScore; not in Pruna + task_type="text_to_image", + reference="https://arxiv.org/abs/2504.17761", + ), + Benchmark( + name="OneIG", + description=( + "Omni-dimensional benchmark for text-to-image evaluation. Six dataset categories " + "(Anime_Stylization, General_Object, Knowledge_Reasoning, Multilingualism, Portrait, " + "Text_Rendering) plus fine-grained style classes. Includes alignment questions." + ), + metrics=[], # Paper uses dimension-specific metrics; not in Pruna + task_type="text_to_image", + reference="https://arxiv.org/abs/2506.07977", + ), + Benchmark( + name="DPG", + description=( + "Dense Prompt Graph benchmark. Evaluates entity, attribute, relation, " + "global, and other descriptive aspects with natural-language questions for alignment." + ), + metrics=[], # Paper uses custom evaluation; not in Pruna + task_type="text_to_image", + reference="https://arxiv.org/abs/2403.05135", + ), +]: + BenchmarkRegistry._register(_benchmark) diff --git a/src/pruna/evaluation/evaluation_agent.py b/src/pruna/evaluation/evaluation_agent.py index 62fa4957..5b713dea 100644 --- a/src/pruna/evaluation/evaluation_agent.py +++ b/src/pruna/evaluation/evaluation_agent.py @@ -77,6 +77,53 @@ def __init__( self.cache: List[Tensor] = [] self.evaluation_for_first_model: bool = True + @classmethod + def from_benchmark( + cls, + benchmark_name: str, + *, + tokenizer: Any = None, + device: str | torch.device | None = None, + dataloader_args: dict[str, Any] | None = None, + **kwargs: Any, + ) -> "EvaluationAgent": + """ + Create an EvaluationAgent from a benchmark name. + + Convenience one-liner that hooks up the benchmark dataset and metrics, then runs evaluation. + + Parameters + ---------- + benchmark_name : str + Benchmark name from BenchmarkRegistry (e.g. "Parti Prompts", "DrawBench"). + tokenizer : Any, optional + Tokenizer for text-generation benchmarks. Required when benchmark is "WikiText". + device : str | torch.device | None, optional + Device for evaluation. Default is None (best available). + dataloader_args : dict | None, optional + Args passed to the dataloader (e.g. batch_size). + **kwargs : Any + Additional args for PrunaDataModule.from_string (e.g. category, fraction). + + Returns + ------- + EvaluationAgent + Agent after evaluation, with results accessible via the agent. + + Examples + -------- + >>> agent = EvaluationAgent.from_benchmark("Parti Prompts", model) + >>> agent = EvaluationAgent.from_benchmark("HPS", model, category="anime", fraction=0.1) + """ + task = Task.from_benchmark( + benchmark_name, + tokenizer=tokenizer, + device=device, + dataloader_args=dataloader_args, + **kwargs, + ) + return cls(task=task) + def evaluate(self, model: Any) -> List[MetricResult]: """ Evaluate models using different metric types. diff --git a/src/pruna/evaluation/metrics/registry.py b/src/pruna/evaluation/metrics/registry.py index f2af4ea1..a3ebe1f9 100644 --- a/src/pruna/evaluation/metrics/registry.py +++ b/src/pruna/evaluation/metrics/registry.py @@ -89,6 +89,23 @@ def decorator(wrapper_class: Callable[..., Any]) -> Callable[..., Any]: return decorator + @classmethod + def has_metric(cls, name: str) -> bool: + """ + Return True if a metric with this name is registered. + + Parameters + ---------- + name : str + Name of the metric to look up. + + Returns + ------- + bool + True if the metric is registered, False otherwise. + """ + return name in cls._registry + @classmethod def get_metric(cls, name: str, **kwargs) -> BaseMetric | StatefulMetric: """ diff --git a/src/pruna/evaluation/task.py b/src/pruna/evaluation/task.py index 62dc8495..4285cfe8 100644 --- a/src/pruna/evaluation/task.py +++ b/src/pruna/evaluation/task.py @@ -20,6 +20,7 @@ from pruna.data.pruna_datamodule import PrunaDataModule from pruna.engine.utils import device_to_string, find_bytes_free_per_gpu, set_to_best_available_device, split_device +from pruna.evaluation.benchmarks import BenchmarkRegistry from pruna.evaluation.metrics.metric_base import BaseMetric from pruna.evaluation.metrics.metric_cmmd import CMMD from pruna.evaluation.metrics.metric_stateful import StatefulMetric @@ -55,6 +56,59 @@ class Task: Default is False. """ + @classmethod + def from_benchmark( + cls, + benchmark_name: str, + tokenizer: Any = None, + device: str | torch.device | None = None, + low_memory: bool = False, + dataloader_args: dict[str, Any] | None = None, + **kwargs: Any, + ) -> "Task": + """ + Create a Task from a benchmark name in BenchmarkRegistry. + + Parameters + ---------- + benchmark_name : str + Benchmark name (e.g. "Parti Prompts", "DrawBench"). + tokenizer : Any, optional + Tokenizer for text-generation benchmarks (e.g. WikiText). Required when + benchmark_name is "WikiText". + device : str | torch.device | None, optional + Device for evaluation. + low_memory : bool, optional + If True, run stateful metrics on cpu. + dataloader_args : dict | None, optional + Args passed to the dataloader (e.g. batch_size). + **kwargs : Any + Additional args for PrunaDataModule.from_string (e.g. category, fraction). + + Returns + ------- + Task + Task configured with the benchmark's metrics and datamodule. + """ + benchmark = BenchmarkRegistry.get(benchmark_name) + if benchmark.lookup_key == "WikiText" and tokenizer is None: + raise ValueError( + "Tokenizer is required for WikiText benchmark. " + "Pass tokenizer=AutoTokenizer.from_pretrained('bert-base-uncased') or similar." + ) + datamodule = PrunaDataModule.from_string( + benchmark.lookup_key, + tokenizer=tokenizer, + dataloader_args=dataloader_args or {}, + **kwargs, + ) + return cls( + request=benchmark.metrics, + datamodule=datamodule, + device=device, + low_memory=low_memory, + ) + def __init__( self, request: str | List[str] | List[BaseMetric | StatefulMetric], diff --git a/tests/data/test_datamodule.py b/tests/data/test_datamodule.py index c50cd658..103cadfb 100644 --- a/tests/data/test_datamodule.py +++ b/tests/data/test_datamodule.py @@ -1,12 +1,13 @@ from typing import Any, Callable import pytest -from transformers import AutoTokenizer -from datasets import Dataset -from torch.utils.data import TensorDataset import torch +from transformers import AutoTokenizer + +from pruna.data import base_datasets from pruna.data.datasets.image import setup_imagenet_dataset from pruna.data.pruna_datamodule import PrunaDataModule +from pruna.data.utils import get_literal_values_from_param bert_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") @@ -18,6 +19,19 @@ def iterate_dataloaders(datamodule: PrunaDataModule) -> None: next(iter(datamodule.test_dataloader())) +def _assert_at_least_one_sample(datamodule: PrunaDataModule) -> None: + """Assert train, val, and test splits each have at least one sample.""" + for name, ds in zip( + ("train", "val", "test"), + (datamodule.train_dataset, datamodule.val_dataset, datamodule.test_dataset), + ): + try: + n = len(ds) + except TypeError: + continue + assert n >= 1, f"{name} split has 0 samples" + + @pytest.mark.cpu @pytest.mark.parametrize( "dataset_name, collate_fn_args", @@ -45,18 +59,25 @@ def iterate_dataloaders(datamodule: PrunaDataModule) -> None: pytest.param("GenAIBench", dict(), marks=pytest.mark.slow), pytest.param("TinyIMDB", dict(tokenizer=bert_tokenizer), marks=pytest.mark.slow), pytest.param("VBench", dict(), marks=pytest.mark.slow), + pytest.param("GenEval", dict(), marks=pytest.mark.slow), + pytest.param("HPS", dict(), marks=pytest.mark.slow), + pytest.param("ImgEdit", dict(), marks=pytest.mark.slow), + pytest.param("LongTextBench", dict(), marks=pytest.mark.slow), + pytest.param("GEditBench", dict(), marks=pytest.mark.slow), + pytest.param("OneIG", dict(), marks=pytest.mark.slow), + pytest.param("DPG", dict(), marks=pytest.mark.slow), ], ) def test_dm_from_string(dataset_name: str, collate_fn_args: dict[str, Any]) -> None: """Test the datamodule from a string.""" # get tokenizer if available - tokenizer = collate_fn_args.get("tokenizer", None) + tokenizer = collate_fn_args.get("tokenizer") # get the datamodule from the string datamodule = PrunaDataModule.from_string(dataset_name, collate_fn_args=collate_fn_args, tokenizer=tokenizer) + _assert_at_least_one_sample(datamodule) datamodule.limit_datasets(10) - # iterate through the dataloaders iterate_dataloaders(datamodule) @@ -71,6 +92,7 @@ def test_dm_from_dataset(setup_fn: Callable, collate_fn: str, collate_fn_args: d # get datamodule with datasets and collate function as input datasets = setup_fn(seed=123) datamodule = PrunaDataModule.from_datasets(datasets, collate_fn, collate_fn_args=collate_fn_args) + _assert_at_least_one_sample(datamodule) datamodule.limit_datasets(10) batch = next(iter(datamodule.train_dataloader())) images, labels = batch @@ -80,3 +102,61 @@ def test_dm_from_dataset(setup_fn: Callable, collate_fn: str, collate_fn_args: d assert labels.dtype == torch.int64 # iterate through the dataloaders iterate_dataloaders(datamodule) + + +def _benchmarks_with_category() -> list[tuple[str, str]]: + """Benchmarks that have a category param: (dataset_name, category) for every category.""" + result = [] + for name in base_datasets: + setup_fn = base_datasets[name][0] + literal_values = get_literal_values_from_param(setup_fn, "category") + if literal_values: + for cat in literal_values: + result.append((name, cat)) + return result + + +@pytest.mark.cpu +@pytest.mark.slow +@pytest.mark.parametrize("dataset_name, category", _benchmarks_with_category()) +def test_benchmark_category_filter(dataset_name: str, category: str) -> None: + """Test dataset loading with each category filter; dataset has at least one sample.""" + dm = PrunaDataModule.from_string(dataset_name, category=category, dataloader_args={"batch_size": 4}) + _assert_at_least_one_sample(dm) + dm.limit_datasets(10) + batch = next(iter(dm.test_dataloader())) + prompts, auxiliaries = batch + + assert len(prompts) == 4 + assert all(isinstance(p, str) for p in prompts) + + def _category_in_aux(aux: dict, cat: str) -> bool: + for v in aux.values(): + if v == cat: + return True + if isinstance(v, (list, tuple)) and cat in v: + return True + return False + + assert all(_category_in_aux(aux, category) for aux in auxiliaries) + + +@pytest.mark.cpu +@pytest.mark.slow +@pytest.mark.parametrize( + "dataset_name, required_aux_key", + [ + ("LongTextBench", "text_content"), + ("OneIG", "text_content"), + ], +) +def test_prompt_benchmark_auxiliaries(dataset_name: str, required_aux_key: str) -> None: + """Test prompt-based benchmarks load with expected auxiliaries.""" + dm = PrunaDataModule.from_string(dataset_name, dataloader_args={"batch_size": 4}) + dm.limit_datasets(10) + batch = next(iter(dm.test_dataloader())) + prompts, auxiliaries = batch + + assert len(prompts) == 4 + assert all(isinstance(p, str) for p in prompts) + assert all(required_aux_key in aux for aux in auxiliaries) diff --git a/tests/evaluation/test_aesthetics_laion.py b/tests/evaluation/test_aesthetics_laion.py index c0c64f17..32ce365c 100644 --- a/tests/evaluation/test_aesthetics_laion.py +++ b/tests/evaluation/test_aesthetics_laion.py @@ -1,16 +1,15 @@ -import requests -from PIL import Image from io import BytesIO -from typing import Literal, cast -import torch import numpy as np - import pytest +import requests +import torch +from PIL import Image from pruna.data.pruna_datamodule import PrunaDataModule from pruna.evaluation.metrics.aesthetic_laion import AestheticLAION + @pytest.mark.parametrize( "device, clip_model", [