Skip to content

Commit 7ebb4cd

Browse files
feat: add GenEval benchmark (#507)
* feat: add benchmark support to PrunaDataModule and implement PartiPrompts benchmark - Introduced `from_benchmark` method in `PrunaDataModule` to create instances from benchmark classes. - Added `Benchmark`, `BenchmarkEntry`, and `BenchmarkRegistry` classes for managing benchmarks. - Implemented `PartiPrompts` benchmark for text-to-image generation with various categories and challenges. - Created utility function `benchmark_to_datasets` to convert benchmarks into datasets compatible with `PrunaDataModule`. - Added integration tests for benchmark functionality and data module interactions. * refactor: simplify benchmark system, extend PartiPrompts with subset filtering - Remove heavy benchmark abstraction (Benchmark class, registry, adapter, 24 subclasses) - Extend setup_parti_prompts_dataset with category and num_samples params - Add BenchmarkInfo dataclass for metadata (metrics, description, subsets) - Switch PartiPrompts to prompt_with_auxiliaries_collate to preserve Category/Challenge - Merge tests into test_datamodule.py Reduces 964 lines to 128 lines (87% reduction) Co-authored-by: Cursor <cursoragent@cursor.com> * feat: add GenEval benchmark Add GenEval benchmark for fine-grained compositional evaluation of text-to-image models. Fetches prompts from GitHub and generates questions. - Add setup_geneval_dataset with 6 subcategories - Categories: single_object, two_object, counting, colors, position, color_attr - Generates evaluation questions from metadata - Register in base_datasets with prompt_with_auxiliaries_collate - Add BenchmarkInfo with metrics: ["qa_accuracy"] - Add tests Co-authored-by: Cursor <cursoragent@cursor.com> * fix: add Numpydoc parameter docs for BenchmarkInfo Document all dataclass fields per Numpydoc PR01 with summary on new line per GL01. 
Co-authored-by: Cursor <cursoragent@cursor.com> * feat: add benchmark discovery functions and expand benchmark registry - Add list_benchmarks() to filter benchmarks by task type - Add get_benchmark_info() to retrieve benchmark metadata - Add COCO, ImageNet, WikiText to benchmark_info registry - Fix metric names to match MetricRegistry (clip_score, clipiqa) Co-authored-by: Cursor <cursoragent@cursor.com> * fix: properly check position value before generating question Use None default and check both pos existence and non-empty first element to avoid malformed questions. Co-authored-by: Cursor <cursoragent@cursor.com> * chore: apply ruff format to data module, add lint-before-push script Made-with: Cursor * chore: fix get_literal_values_from_param docstring, add SCOPE to lint script Made-with: Cursor * chore: remove scripts/lint-before-push.sh Made-with: Cursor * chore: align metrics with Pruna, comment unsupported InferBench metrics Made-with: Cursor * fix: remove accuracy from GenEval - qa_accuracy ≠ accuracy Made-with: Cursor * chore: simplify metric comment Made-with: Cursor --------- Co-authored-by: Cursor <cursoragent@cursor.com>
1 parent e9bf9cb commit 7ebb4cd

File tree

4 files changed

+276
-6
lines changed

4 files changed

+276
-6
lines changed

src/pruna/data/__init__.py

Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from dataclasses import dataclass, field
1516
from functools import partial
1617
from typing import Any, Callable, Tuple
1718

@@ -28,6 +29,7 @@
2829
from pruna.data.datasets.prompt import (
2930
setup_drawbench_dataset,
3031
setup_genai_bench_dataset,
32+
setup_geneval_dataset,
3133
setup_parti_prompts_dataset,
3234
)
3335
from pruna.data.datasets.question_answering import setup_polyglot_dataset
@@ -50,6 +52,7 @@
5052

5153
BENCHMARK_CATEGORY_CONFIG: dict[str, tuple[str, list[str]]] = {
5254
"PartiPrompts": ("Animals", ["Category", "Challenge"]),
55+
"GenEval": ("counting", ["tag"]),
5356
}
5457

5558
base_datasets: dict[str, Tuple[Callable, str, dict[str, Any]]] = {
@@ -107,6 +110,186 @@
107110
{},
108111
),
109112
"GenAIBench": (setup_genai_bench_dataset, "prompt_collate", {}),
113+
"GenEval": (setup_geneval_dataset, "prompt_with_auxiliaries_collate", {}),
110114
"TinyIMDB": (setup_tiny_imdb_dataset, "text_generation_collate", {}),
111115
"VBench": (setup_vbench_dataset, "prompt_with_auxiliaries_collate", {}),
112116
}
117+
118+
119+
@dataclass
class BenchmarkInfo:
    """
    Metadata for a benchmark dataset.

    Parameters
    ----------
    name : str
        Internal identifier for the benchmark.
    display_name : str
        Human-readable name for display purposes.
    description : str
        Description of what the benchmark evaluates.
    metrics : list[str]
        List of metric names used for evaluation.
    task_type : str
        Type of task the benchmark evaluates (e.g., 'text_to_image').
    subsets : list[str]
        Optional list of benchmark subset names.
    """

    name: str
    display_name: str
    description: str
    metrics: list[str]
    task_type: str
    # default_factory avoids the shared-mutable-default pitfall for the list field
    subsets: list[str] = field(default_factory=list)
146+
147+
148+
# Registry of benchmark metadata, keyed by the same names used in `base_datasets`.
# Commented-out metric names mark metrics used by the upstream benchmark but not
# supported by Pruna's MetricRegistry.
benchmark_info: dict[str, BenchmarkInfo] = {
    "PartiPrompts": BenchmarkInfo(
        name="parti_prompts",
        display_name="Parti Prompts",
        description=(
            "Over 1,600 diverse English prompts across 12 categories with 11 challenge aspects "
            "ranging from basic to complex, enabling comprehensive assessment of model capabilities "
            "across different domains and difficulty levels."
        ),
        metrics=["arniqa", "clip_score", "clipiqa", "sharpness"],
        task_type="text_to_image",
        # First 12 entries are Categories, the remaining 11 are Challenge aspects.
        subsets=[
            "Abstract",
            "Animals",
            "Artifacts",
            "Arts",
            "Food & Beverage",
            "Illustrations",
            "Indoor Scenes",
            "Outdoor Scenes",
            "People",
            "Produce & Plants",
            "Vehicles",
            "World Knowledge",
            "Basic",
            "Complex",
            "Fine-grained Detail",
            "Imagination",
            "Linguistic Structures",
            "Perspective",
            "Properties & Positioning",
            "Quantity",
            "Simple Detail",
            "Style & Format",
            "Writing & Symbols",
        ],
    ),
    "DrawBench": BenchmarkInfo(
        name="drawbench",
        display_name="DrawBench",
        description="A comprehensive benchmark for evaluating text-to-image generation models.",
        metrics=[
            "clip_score",
            "clipiqa",
            "sharpness",
            # "image_reward" not supported in Pruna
        ],
        task_type="text_to_image",
    ),
    "GenAIBench": BenchmarkInfo(
        name="genai_bench",
        display_name="GenAI Bench",
        description="A benchmark for evaluating generative AI models.",
        metrics=[
            "clip_score",
            "clipiqa",
            "sharpness",
            # "vqa" not supported in Pruna
        ],
        task_type="text_to_video",
    ) if False else BenchmarkInfo(
        name="genai_bench",
        display_name="GenAI Bench",
        description="A benchmark for evaluating generative AI models.",
        metrics=[
            "clip_score",
            "clipiqa",
            "sharpness",
            # "vqa" not supported in Pruna
        ],
        task_type="text_to_image",
    ),
    "VBench": BenchmarkInfo(
        name="vbench",
        display_name="VBench",
        description="A benchmark for evaluating video generation models.",
        metrics=["clip_score"],
        task_type="text_to_video",
    ),
    "GenEval": BenchmarkInfo(
        name="geneval",
        display_name="GenEval",
        description=(
            "Fine-grained compositional evaluation across object co-occurrence, positioning, "
            "counting, and color binding to identify specific failure modes in text-to-image alignment."
        ),
        metrics=[
            # "qa_accuracy" not supported in Pruna
        ],
        task_type="text_to_image",
        subsets=["single_object", "two_object", "counting", "colors", "position", "color_attr"],
    ),
    "COCO": BenchmarkInfo(
        name="coco",
        display_name="COCO",
        description="Microsoft COCO dataset for image generation evaluation with real image-caption pairs.",
        metrics=["fid", "clip_score", "clipiqa"],
        task_type="text_to_image",
    ),
    "ImageNet": BenchmarkInfo(
        name="imagenet",
        display_name="ImageNet",
        description="Large-scale image classification benchmark with 1000 classes.",
        metrics=["accuracy"],
        task_type="image_classification",
    ),
    "WikiText": BenchmarkInfo(
        name="wikitext",
        display_name="WikiText",
        description="Language modeling benchmark based on Wikipedia articles.",
        metrics=["perplexity"],
        task_type="text_generation",
    ),
}
251+
252+
253+
def list_benchmarks(task_type: str | None = None) -> list[str]:
    """
    List available benchmark names.

    Parameters
    ----------
    task_type : str | None
        Filter by task type (e.g., 'text_to_image', 'text_to_video').
        If None, returns all benchmarks.

    Returns
    -------
    list[str]
        List of benchmark names.
    """
    # A single pass handles both the unfiltered and the filtered case.
    matching = []
    for benchmark_name, info in benchmark_info.items():
        if task_type is None or info.task_type == task_type:
            matching.append(benchmark_name)
    return matching
271+
272+
273+
def get_benchmark_info(name: str) -> BenchmarkInfo:
    """
    Get benchmark metadata by name.

    Parameters
    ----------
    name : str
        The benchmark name.

    Returns
    -------
    BenchmarkInfo
        The benchmark metadata.

    Raises
    ------
    KeyError
        If benchmark name is not found.
    """
    # EAFP: attempt the lookup and translate a miss into a helpful error message.
    try:
        return benchmark_info[name]
    except KeyError:
        available = ", ".join(benchmark_info.keys())
        raise KeyError(f"Benchmark '{name}' not found. Available: {available}") from None

src/pruna/data/datasets/prompt.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
from pruna.data.utils import _prepare_test_only_prompt_dataset, define_sample_size_for_dataset
2020
from pruna.logging.logger import pruna_logger
2121

22+
GenEvalCategory = Literal["single_object", "two_object", "counting", "colors", "position", "color_attr"]
23+
2224
PartiCategory = Literal[
2325
"Abstract",
2426
"Animals",
@@ -110,6 +112,91 @@ def setup_parti_prompts_dataset(
110112
return _prepare_test_only_prompt_dataset(ds, seed, "PartiPrompts")
111113

112114

115+
def _generate_geneval_question(entry: dict) -> list[str]:
116+
"""Generate evaluation questions from GenEval metadata."""
117+
tag = entry.get("tag", "")
118+
include = entry.get("include", [])
119+
questions = []
120+
121+
for obj in include:
122+
cls = obj.get("class", "")
123+
if "color" in obj:
124+
questions.append(f"Does the image contain a {obj['color']} {cls}?")
125+
elif "count" in obj:
126+
questions.append(f"Does the image contain exactly {obj['count']} {cls}(s)?")
127+
else:
128+
questions.append(f"Does the image contain a {cls}?")
129+
130+
if tag == "position" and len(include) >= 2:
131+
a_cls = include[0].get("class", "")
132+
b_cls = include[1].get("class", "")
133+
pos = include[1].get("position")
134+
if pos and pos[0]:
135+
questions.append(f"Is the {b_cls} {pos[0]} the {a_cls}?")
136+
137+
return questions
138+
139+
140+
def setup_geneval_dataset(
    seed: int,
    fraction: float = 1.0,
    train_sample_size: int | None = None,
    test_sample_size: int | None = None,
    category: GenEvalCategory | list[GenEvalCategory] | None = None,
) -> Tuple[Dataset, Dataset, Dataset]:
    """
    Setup the GenEval benchmark dataset.

    License: MIT

    Parameters
    ----------
    seed : int
        The seed to use.
    fraction : float
        The fraction of the dataset to use.
    train_sample_size : int | None
        Unused; train/val are dummy.
    test_sample_size : int | None
        The sample size to use for the test dataset.
    category : GenEvalCategory | list[GenEvalCategory] | None
        Filter by category. Available: single_object, two_object, counting, colors, position, color_attr.

    Returns
    -------
    Tuple[Dataset, Dataset, Dataset]
        The GenEval dataset (dummy train, dummy val, test).

    Raises
    ------
    requests.HTTPError
        If the GenEval metadata cannot be fetched from GitHub.
    """
    import json

    import requests

    # Pinned commit so the benchmark content cannot silently change upstream.
    url = "https://raw.githubusercontent.com/djghosh13/geneval/d927da8e42fde2b1b5cd743da4df5ff83c1654ff/prompts/evaluation_metadata.jsonl"
    # timeout prevents hanging forever on network issues; raise_for_status
    # avoids feeding an HTTP error page into the JSONL parser below.
    response = requests.get(url, timeout=60)
    response.raise_for_status()
    # Skip blank lines (e.g. a trailing newline) that would crash json.loads.
    data = [json.loads(line) for line in response.text.splitlines() if line.strip()]

    if category is not None:
        categories = [category] if not isinstance(category, list) else category
        data = [entry for entry in data if entry.get("tag") in categories]

    records = []
    for entry in data:
        questions = _generate_geneval_question(entry)
        records.append(
            {
                "text": entry["prompt"],
                "tag": entry.get("tag", ""),
                "questions": questions,
                "include": entry.get("include", []),
            }
        )

    ds = Dataset.from_list(records)
    test_sample_size = define_sample_size_for_dataset(ds, fraction, test_sample_size)
    # Cap the selection at the (possibly category-filtered) dataset length.
    ds = ds.select(range(min(test_sample_size, len(ds))))
    return _prepare_test_only_prompt_dataset(ds, seed, "GenEval")
198+
199+
113200
def setup_genai_bench_dataset(seed: int) -> Tuple[Dataset, Dataset, Dataset]:
114201
"""
115202
Setup the GenAI Bench dataset.

src/pruna/data/datasets/text_generation.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,7 @@ def setup_wikitext_dataset() -> Tuple[Dataset, Dataset, Dataset]:
3333
The WikiText dataset.
3434
"""
3535
train_dataset, val_dataset, test_dataset = load_dataset(
36-
path="mikasenghaas/wikitext-2",
37-
split=["train", "validation", "test"]
36+
path="mikasenghaas/wikitext-2", split=["train", "validation", "test"]
3837
)
3938
return train_dataset, val_dataset, test_dataset # type: ignore[return-value]
4039

@@ -57,15 +56,15 @@ def setup_wikitext_tiny_dataset(seed: int = 42, num_rows: int = 960) -> Tuple[Da
5756
Tuple[Dataset, Dataset, Dataset]
5857
The TinyWikiText dataset split .8/.1/.1 into train/val/test subsets, respectively.
5958
"""
60-
assert 10 <= num_rows < 1000, 'the total number of rows, r, for the tiny wikitext dataset must be 10 <= r < 1000'
59+
assert 10 <= num_rows < 1000, "the total number of rows, r, for the tiny wikitext dataset must be 10 <= r < 1000"
6160

6261
# load the 'mikasenghaas/wikitext-2' dataset with a total of 21,580 rows using the setup_wikitext_dataset() function
6362
train_ds, val_ds, test_ds = setup_wikitext_dataset()
6463

6564
# assert the wikitext dataset train/val/test splits each have enough rows for reducing to .8/.1/.1, respectively
66-
assert train_ds.num_rows >= int(num_rows * 0.8), f'wikitext cannot be reduced to {num_rows} rows, train too small'
67-
assert val_ds.num_rows >= int(num_rows * 0.1), f'wikitext cannot be reduced to {num_rows} rows, val too small'
68-
assert test_ds.num_rows >= int(num_rows * 0.1), f'wikitext cannot be reduced to {num_rows} rows, test too small'
65+
assert train_ds.num_rows >= int(num_rows * 0.8), f"wikitext cannot be reduced to {num_rows} rows, train too small"
66+
assert val_ds.num_rows >= int(num_rows * 0.1), f"wikitext cannot be reduced to {num_rows} rows, val too small"
67+
assert test_ds.num_rows >= int(num_rows * 0.1), f"wikitext cannot be reduced to {num_rows} rows, test too small"
6968

7069
# randomly select from the wikitext dataset a total number of rows below 1000 split .8/.1/.1 between train/val/test
7170
train_dataset_tiny = train_ds.shuffle(seed=seed).select(range(int(num_rows * 0.8)))

tests/data/test_datamodule.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def iterate_dataloaders(datamodule: PrunaDataModule) -> None:
4545
pytest.param("GenAIBench", dict(), marks=pytest.mark.slow),
4646
pytest.param("TinyIMDB", dict(tokenizer=bert_tokenizer), marks=pytest.mark.slow),
4747
pytest.param("VBench", dict(), marks=pytest.mark.slow),
48+
pytest.param("GenEval", dict(), marks=pytest.mark.slow),
4849
],
4950
)
5051
def test_dm_from_string(dataset_name: str, collate_fn_args: dict[str, Any]) -> None:

0 commit comments

Comments
 (0)