-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgenerate_perturbations.py
More file actions
150 lines (118 loc) · 5.87 KB
/
Copy pathgenerate_perturbations.py
File metadata and controls
150 lines (118 loc) · 5.87 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import argparse
import json
import logging
from pathlib import Path
from gemini_client import get_client, load_image, download_image, apply_image_edit
from config import PERTURBATION_CATEGORIES, MODEL_CONFIG
# Configure root logging once at import time: INFO and above, each record
# prefixed with a timestamp and its level name.
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s",
    level=logging.INFO,
)

# Module-level logger used by every function in this script.
logger = logging.getLogger(__name__)
def build_edit_prompt(original_prompt, instruction):
    """Compose the text prompt that accompanies an image for a single edit.

    The prompt is made of four sections separated by blank lines: a fixed
    lead-in, the edit instruction, a fixed "change nothing else" constraint,
    and the original description of the image.

    Args:
        original_prompt: Text description of the unedited image.
        instruction: The edit to apply.

    Returns:
        The assembled prompt as a single string.
    """
    sections = [
        "Apply the following edit to the image.",
        f"Edit:\n{instruction}",
        "Do not change anything else in the image.",
        f"Original description:\n{original_prompt}",
    ]
    return "\n\n".join(sections)
def get_gold_image(inst_id, gold_image_dir, instances_dir):
    """Load the gold image for an instance, preferring a local copy.

    Lookup order:
      1. ``<gold_image_dir>/<inst_id>.png`` -- loaded via ``load_image``.
      2. ``<instances_dir>/<inst_id>.json`` -- the image is fetched via
         ``download_image`` from the top-level ``image_url`` field or,
         failing that, from ``data.img_url``.

    Args:
        inst_id: Instance identifier used to build both file names.
        gold_image_dir: Directory that may hold ``<inst_id>.png``.
        instances_dir: Directory that may hold ``<inst_id>.json``.

    Returns:
        Whatever ``load_image`` / ``download_image`` returns, or ``None`` when
        neither a local image nor a usable URL is found.

    Raises:
        OSError, json.JSONDecodeError: If the instance file exists but cannot
            be read or parsed. Errors raised by ``load_image`` and
            ``download_image`` propagate unchanged.
    """
    local_path = Path(gold_image_dir) / f"{inst_id}.png"
    if local_path.exists():
        return load_image(str(local_path))

    # Fallback: try to find a URL in the instance JSON.
    inst_file = Path(instances_dir) / f"{inst_id}.json"
    if not inst_file.exists():
        return None

    # JSON is UTF-8 by specification; do not depend on the platform default.
    with open(inst_file, encoding="utf-8") as f:
        inst = json.load(f)
    if not isinstance(inst, dict):
        return None

    # "data" may be absent, null, or not an object; only a dict can carry
    # the nested URL.
    data = inst.get("data")
    nested_url = data.get("img_url") if isinstance(data, dict) else None

    image_url = inst.get("image_url") or nested_url
    if image_url:
        return download_image(image_url)
    return None
def _write_final_json(json_path, edit_data, inst_id, prompt, instruction,
                      cat_key, subcat_key):
    """Write the final record pairing an instance's gold and perturbed images."""
    final_json = {
        "id": inst_id,
        "data": {
            "prompt": prompt,
            # Image paths are recorded as fixed relative strings, not derived
            # from the actual output directory.
            "gold_image": f"images/gold/{inst_id}.png",
            "perturbed_image": f"images/{cat_key}/{subcat_key}/{inst_id}_perturbed.png",
        },
        "type": edit_data.get("type", ""),
        "difficulty": edit_data.get("difficulty", ""),
        "category": cat_key,
        "subcategory": subcat_key,
        "edit_instruction": instruction,
    }
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(final_json, f, indent=4)


def _process_edit_file(client, edit_file, args, cat_key, subcat_key,
                       img_out, json_out, index, total):
    """Process one edit file; return True (success), False (failure) or None (skipped)."""
    # A single unreadable or malformed edit file must not abort the batch.
    try:
        with open(edit_file, encoding="utf-8") as f:
            edit_data = json.load(f)
        inst_id = edit_data["id"]
        prompt = edit_data["data"]["prompt"]
        instruction = edit_data.get("edit_instruction", "")
    except (OSError, ValueError, KeyError, TypeError, AttributeError):
        logger.exception("Could not read edit file %s", edit_file)
        return False

    if not instruction:
        logger.warning("No edit instruction for %s, skipping", inst_id)
        return None

    perturbed_path = img_out / f"{inst_id}_perturbed.png"
    json_path = json_out / f"{inst_id}.json"

    try:
        if perturbed_path.exists():
            # Resume: the image is already on disk. If an earlier run stopped
            # between saving the image and writing its JSON, write the JSON
            # now without regenerating the image.
            if not json_path.exists():
                _write_final_json(json_path, edit_data, inst_id, prompt,
                                  instruction, cat_key, subcat_key)
                logger.info("Restored missing JSON for %s", inst_id)
            return True

        logger.info("[%d/%d] %s/%s: %s", index, total,
                    cat_key, subcat_key, inst_id)

        image = get_gold_image(inst_id, args.gold_image_dir, args.instances_dir)
        if image is None:
            logger.error("No gold image for %s", inst_id)
            return False

        edit_prompt = build_edit_prompt(prompt, instruction)
        result = apply_image_edit(client, image, edit_prompt, args.model)
        if not result:
            logger.error("Failed to generate perturbed image for %s", inst_id)
            return False

        # NOTE(review): result is presumably an image object exposing
        # .save(path) -- confirm against gemini_client.apply_image_edit.
        result.save(str(perturbed_path))
        _write_final_json(json_path, edit_data, inst_id, prompt,
                          instruction, cat_key, subcat_key)
        logger.info("Saved %s", perturbed_path.name)
        return True
    except Exception as e:
        # Per-item boundary: log with traceback and continue with the next item.
        logger.exception("Error for %s: %s", inst_id, e)
        return False


def main():
    """Apply edit instructions to gold images and save the perturbed results.

    Walks ``--edit_dir`` as ``category/subcategory/<id>.json``. For each edit
    file it loads the gold image, builds the edit prompt, calls
    ``apply_image_edit``, and writes:

      * ``<output_dir>/images/<category>/<subcategory>/<id>_perturbed.png``
      * ``<output_dir>/json/<category>/<subcategory>/<id>.json``

    Items whose perturbed image already exists are counted as successes and
    not regenerated, so the script can be re-run to resume. A missing JSON for
    such an item is written on resume. Failures are logged and counted; they
    do not stop the run.
    """
    parser = argparse.ArgumentParser(description="Apply perturbations to generate perturbed images")
    parser.add_argument("--edit_dir", required=True,
                        help="Directory with edit instruction JSONs (category/subcategory/id.json)")
    parser.add_argument("--gold_image_dir", required=True, help="Directory with gold images")
    parser.add_argument("--instances_dir", required=True, help="Directory with original instance JSONs")
    parser.add_argument("--output_dir", required=True, help="Directory to save perturbed images and final JSONs")
    parser.add_argument("--category", default=None,
                        choices=list(PERTURBATION_CATEGORIES.keys()))
    parser.add_argument("--subcategory", default=None)
    parser.add_argument("--model", default=MODEL_CONFIG["image_generation_model"])
    args = parser.parse_args()

    client = get_client()
    edit_dir = Path(args.edit_dir)
    output_dir = Path(args.output_dir)

    total_success = 0
    total_failed = 0

    for cat_dir in sorted(edit_dir.iterdir()):
        if not cat_dir.is_dir():
            continue
        if args.category and cat_dir.name != args.category:
            continue

        for subcat_dir in sorted(cat_dir.iterdir()):
            if not subcat_dir.is_dir():
                continue
            if args.subcategory and subcat_dir.name != args.subcategory:
                continue

            cat_key = cat_dir.name
            subcat_key = subcat_dir.name

            img_out = output_dir / "images" / cat_key / subcat_key
            json_out = output_dir / "json" / cat_key / subcat_key
            img_out.mkdir(parents=True, exist_ok=True)
            json_out.mkdir(parents=True, exist_ok=True)

            edit_files = sorted(subcat_dir.glob("*.json"))
            logger.info("Processing %s/%s: %d items", cat_key, subcat_key, len(edit_files))

            for i, edit_file in enumerate(edit_files, 1):
                outcome = _process_edit_file(
                    client, edit_file, args, cat_key, subcat_key,
                    img_out, json_out, i, len(edit_files),
                )
                # None means "skipped" and is counted in neither total.
                if outcome is True:
                    total_success += 1
                elif outcome is False:
                    total_failed += 1

    logger.info("Completed: %d succeeded, %d failed", total_success, total_failed)
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()