Skip to content

Commit edaaf32

Browse files
fix: address PR #502 review comments
- Remove shuffle from test-only prompt datasets (test sets should not be shuffled)
- Remove duplicate Benchmark/benchmark_info from data/__init__.py (already in evaluation/benchmarks)
- Use explicit if/else for category is None in setup_parti_prompts_dataset
- Add status_code check to ImgEdit for consistency with OneIG
- Put load_dataset on one line in _load_oneig_generic
- Derive alignment_cats from ONEIG_DATASET_CATEGORIES

Made-with: Cursor
1 parent 09b353b commit edaaf32

File tree

3 files changed

+23
-221
lines changed

3 files changed

+23
-221
lines changed

src/pruna/data/__init__.py

Lines changed: 0 additions & 205 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from dataclasses import dataclass
1615
from functools import partial
1716
from typing import Any, Callable, Tuple
1817

@@ -121,207 +120,3 @@
121120
"TinyIMDB": (setup_tiny_imdb_dataset, "text_generation_collate", {}),
122121
"VBench": (setup_vbench_dataset, "prompt_with_auxiliaries_collate", {}),
123122
}
124-
125-
126-
@dataclass
127-
class Benchmark:
128-
"""
129-
Metadata for a benchmark dataset.
130-
131-
Parameters
132-
----------
133-
name : str
134-
Internal identifier for the benchmark.
135-
display_name : str
136-
Human-readable name for display purposes.
137-
description : str
138-
Description of what the benchmark evaluates.
139-
metrics : list[str]
140-
List of metric names used for evaluation.
141-
task_type : str
142-
Type of task the benchmark evaluates (e.g., 'text_to_image').
143-
"""
144-
145-
name: str
146-
display_name: str
147-
description: str
148-
metrics: list[str]
149-
task_type: str
150-
151-
152-
benchmark_info: dict[str, Benchmark] = {
153-
"PartiPrompts": Benchmark(
154-
name="parti_prompts",
155-
display_name="Parti Prompts",
156-
description=(
157-
"Holistic benchmark from Google Research with over 1,600 English prompts across 12 categories "
158-
"and 11 challenge aspects. Evaluates text-to-image models on abstract thinking, world knowledge, "
159-
"perspectives, and symbol rendering from basic to complex compositions."
160-
),
161-
metrics=["arniqa", "clip_score", "clipiqa", "sharpness"],
162-
task_type="text_to_image",
163-
),
164-
"DrawBench": Benchmark(
165-
name="drawbench",
166-
display_name="DrawBench",
167-
description=(
168-
"Comprehensive benchmark from the Imagen team for rigorous evaluation of text-to-image models. "
169-
"Enables side-by-side comparison on sample quality and image-text alignment with human raters."
170-
),
171-
metrics=[
172-
"clip_score",
173-
"clipiqa",
174-
"sharpness",
175-
# "image_reward" not supported in Pruna
176-
],
177-
task_type="text_to_image",
178-
),
179-
"GenAIBench": Benchmark(
180-
name="genai_bench",
181-
display_name="GenAI Bench",
182-
description=(
183-
"1,600 prompts from professional designers for compositional text-to-visual generation. "
184-
"Covers basic skills (scene, attributes, spatial relationships) to advanced reasoning "
185-
"(counting, comparison, logic/negation) with over 24k human ratings."
186-
),
187-
metrics=[
188-
"clip_score",
189-
"clipiqa",
190-
"sharpness",
191-
# "vqa" not supported in Pruna
192-
],
193-
task_type="text_to_image",
194-
),
195-
"VBench": Benchmark(
196-
name="vbench",
197-
display_name="VBench",
198-
description=(
199-
"Comprehensive benchmark suite for video generative models. Decomposes video quality into "
200-
"16 disentangled dimensions: temporal flickering, motion smoothness, subject consistency, "
201-
"spatial relationship, color, aesthetic quality, and more."
202-
),
203-
metrics=["clip_score"],
204-
task_type="text_to_video",
205-
),
206-
"GenEval": Benchmark(
207-
name="geneval",
208-
display_name="GenEval",
209-
description=(
210-
"Object-focused framework (NeurIPS 2023) for fine-grained text-to-image alignment. "
211-
"Evaluates compositional properties: object co-occurrence, position, count, and color binding "
212-
"via instance-level analysis rather than distribution-level metrics."
213-
),
214-
metrics=[
215-
# "qa_accuracy" not supported in Pruna
216-
],
217-
task_type="text_to_image",
218-
),
219-
"HPS": Benchmark(
220-
name="hps",
221-
display_name="HPS",
222-
description=(
223-
"Human Preference Score v2: large-scale benchmark with 798k human preference choices on "
224-
"433k image pairs. CLIP fine-tuned on HPD v2 to predict human preferences and align "
225-
"evaluation with actual human judgment across diverse generative outputs."
226-
),
227-
metrics=[
228-
# "hps" not supported in Pruna
229-
],
230-
task_type="text_to_image",
231-
),
232-
"LongTextBench": Benchmark(
233-
name="long_text_bench",
234-
display_name="Long Text Bench",
235-
description=(
236-
"DetailMaster benchmark with prompts averaging 284.89 tokens. Evaluates four dimensions: "
237-
"character attributes, structured locations, scene attributes, and spatial relationships "
238-
"to test compositional reasoning under long prompt complexity."
239-
),
240-
metrics=[
241-
# "text_score" not supported in Pruna
242-
],
243-
task_type="text_to_image",
244-
),
245-
"ImgEdit": Benchmark(
246-
name="imgedit",
247-
display_name="ImgEdit",
248-
description=(
249-
"Unified image editing benchmark (PKU-YuanGroup) with 8 edit types: replace, add, remove, "
250-
"adjust, extract, style, background, compose. Evaluates instruction adherence, editing "
251-
"quality, and detail preservation."
252-
),
253-
metrics=[
254-
# "img_edit_score" not supported in Pruna
255-
],
256-
task_type="image_edit",
257-
),
258-
"GEditBench": Benchmark(
259-
name="gedit_bench",
260-
display_name="GEdit Bench",
261-
description=(
262-
"StepFun benchmark grounded in real-world user instructions. 11 task types including "
263-
"background_change, subject_add/remove/replace, style_change, and tone_transfer for "
264-
"practical evaluation of image editing capabilities."
265-
),
266-
metrics=[
267-
# "viescore" not supported in Pruna
268-
],
269-
task_type="image_edit",
270-
),
271-
"OneIG": Benchmark(
272-
name="oneig",
273-
display_name="OneIG",
274-
description=(
275-
"Omni-dimensional benchmark (NeurIPS 2025) for nuanced image generation evaluation. "
276-
"Six categories: Text_Rendering, Anime_Stylization, Portrait, General_Object, "
277-
"Knowledge_Reasoning, Multilingualism. Addresses text rendering precision and prompt-image alignment."
278-
),
279-
metrics=[
280-
# "alignment_score", "text_score" not supported in Pruna
281-
],
282-
task_type="text_to_image",
283-
),
284-
"DPG": Benchmark(
285-
name="dpg",
286-
display_name="DPG",
287-
description=(
288-
"Dense Prompt Graph benchmark from ELLA/Tencent. ~1,000 complex prompts testing "
289-
"entity, attribute, relation, and global aspects. Evaluates models on dense prompt "
290-
"following with multiple objects and varied attributes."
291-
),
292-
metrics=[
293-
# "qa_accuracy" not supported in Pruna
294-
],
295-
task_type="text_to_image",
296-
),
297-
"COCO": Benchmark(
298-
name="coco",
299-
display_name="COCO",
300-
description=(
301-
"Microsoft COCO dataset for image generation evaluation. Real image-caption pairs "
302-
"enabling FID and alignment metrics on distribution-level and instance-level quality."
303-
),
304-
metrics=["fid", "clip_score", "clipiqa"],
305-
task_type="text_to_image",
306-
),
307-
"ImageNet": Benchmark(
308-
name="imagenet",
309-
display_name="ImageNet",
310-
description=(
311-
"Large-scale image classification benchmark with 1,000 classes. Standard evaluation "
312-
"for vision model accuracy on object recognition."
313-
),
314-
metrics=["accuracy"],
315-
task_type="image_classification",
316-
),
317-
"WikiText": Benchmark(
318-
name="wikitext",
319-
display_name="WikiText",
320-
description=(
321-
"Language modeling benchmark based on Wikipedia articles. Standard evaluation "
322-
"for text generation quality via perplexity."
323-
),
324-
metrics=["perplexity"],
325-
task_type="text_generation",
326-
),
327-
}

src/pruna/data/datasets/prompt.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import contextlib
1516
from typing import Literal, Tuple, get_args
1617

1718
from datasets import Dataset, load_dataset
@@ -122,6 +123,10 @@
122123
]
123124
DPGCategory = Literal["entity", "attribute", "relation", "global", "other"]
124125

126+
ONEIG_DATASET_CATEGORIES = frozenset(
127+
{"Anime_Stylization", "General_Object", "Knowledge_Reasoning", "Multilingualism", "Portrait", "Text_Rendering"}
128+
)
129+
125130

126131
def setup_drawbench_dataset(seed: int) -> Tuple[Dataset, Dataset, Dataset]:
127132
"""
@@ -177,7 +182,9 @@ def setup_parti_prompts_dataset(
177182
"""
178183
ds = load_dataset("nateraw/parti-prompts")["train"] # type: ignore[index]
179184

180-
if category is not None:
185+
if category is None:
186+
pass
187+
else:
181188
categories = [category] if not isinstance(category, list) else category
182189
ds = ds.filter(lambda x: x["Category"] in categories or x["Challenge"] in categories)
183190

@@ -419,8 +426,16 @@ def setup_imgedit_dataset(
419426
instructions_url = "https://raw.githubusercontent.com/PKU-YuanGroup/ImgEdit/b3eb8e74d7cd1fd0ce5341eaf9254744a8ab4c0b/Benchmark/Basic/basic_edit.json"
420427
judge_prompts_url = "https://raw.githubusercontent.com/PKU-YuanGroup/ImgEdit/c14480ac5e7b622e08cd8c46f96624a48eb9ab46/Benchmark/Basic/prompts.json"
421428

422-
instructions = json.loads(requests.get(instructions_url).text)
423-
judge_prompts = json.loads(requests.get(judge_prompts_url).text)
429+
resp_inst = requests.get(instructions_url)
430+
resp_judge = requests.get(judge_prompts_url)
431+
instructions: dict = {}
432+
if resp_inst.status_code == 200:
433+
with contextlib.suppress(json.JSONDecodeError):
434+
instructions = json.loads(resp_inst.text)
435+
judge_prompts: dict = {}
436+
if resp_judge.status_code == 200:
437+
with contextlib.suppress(json.JSONDecodeError):
438+
judge_prompts = json.loads(resp_judge.text)
424439

425440
categories = [category] if category is not None and not isinstance(category, list) else category
426441
records = []
@@ -495,7 +510,7 @@ def _load_oneig_alignment(seed: int, category: str | None = None, class_filter:
495510
except json.JSONDecodeError:
496511
pass
497512

498-
alignment_cats = {"Anime_Stylization", "Portrait", "General_Object"}
513+
alignment_cats = ONEIG_DATASET_CATEGORIES - {"Knowledge_Reasoning", "Multilingualism", "Text_Rendering"}
499514
records = []
500515
for row in ds:
501516
row_id = row.get("id", "")
@@ -525,21 +540,14 @@ def _load_oneig_alignment(seed: int, category: str | None = None, class_filter:
525540
return Dataset.from_list(records).shuffle(seed=seed)
526541

527542

528-
ONEIG_DATASET_CATEGORIES = frozenset(
529-
{"Anime_Stylization", "General_Object", "Knowledge_Reasoning", "Multilingualism", "Portrait", "Text_Rendering"}
530-
)
531-
532-
533543
def _load_oneig_generic(
534544
seed: int,
535545
category_filter: str | None = None,
536546
class_filter: str | None = None,
537547
config: str = "OneIG-Bench",
538548
) -> Dataset:
539549
"""Load OneIG data for Knowledge_Reasoning, Multilingualism, or any category without alignment questions."""
540-
ds = load_dataset("OneIG-Bench/OneIG-Bench", config)[ # type: ignore[index]
541-
"train"
542-
]
550+
ds = load_dataset("OneIG-Bench/OneIG-Bench", config)["train"] # type: ignore[index]
543551

544552
records = []
545553
for row in ds:

src/pruna/data/utils.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -268,16 +268,16 @@ def _prepare_test_only_prompt_dataset(
268268
dataset_name: str,
269269
) -> Tuple[Dataset, Dataset, Dataset]:
270270
"""
271-
Shared tail for test-only prompt datasets: shuffle, return dummy train/val + test.
271+
Shared tail for test-only prompt datasets: return dummy train/val + test.
272272
273-
All benchmark datasets use this.
273+
All benchmark datasets use this. Test datasets are not shuffled.
274274
275275
Parameters
276276
----------
277277
ds : Dataset
278278
The dataset to prepare.
279279
seed : int
280-
The seed for shuffling.
280+
Unused; kept for API compatibility.
281281
dataset_name : str
282282
Name for logging.
283283
@@ -286,7 +286,6 @@ def _prepare_test_only_prompt_dataset(
286286
Tuple[Dataset, Dataset, Dataset]
287287
Dummy train, dummy val, and test datasets.
288288
"""
289-
ds = ds.shuffle(seed=seed)
290289
pruna_logger.info(f"{dataset_name} is a test-only dataset. Do not use it for training or validation.")
291290
return ds.select([0]), ds.select([0]), ds
292291

0 commit comments

Comments (0)