forked from nunchaku-ai/nunchaku
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathget_metrics.py
More file actions
73 lines (62 loc) · 2.7 KB
/
Copy pathget_metrics.py
File metadata and controls
73 lines (62 loc) · 2.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import argparse
import json
import os
from data import get_dataset
from metrics.fid import compute_fid
from metrics.image_reward import compute_image_reward
from metrics.multimodal import compute_image_multimodal_metrics
from metrics.similarity import compute_image_similarity_metrics
def get_args(argv=None):
    """Parse the command-line arguments of the metrics script.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process arguments. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, which is what callers passing no argument get.

    Returns:
        argparse.Namespace with the attributes ``input_roots`` (list of
        str, possibly empty because the positional uses ``nargs="*"``),
        ``output_path`` (str) and ``max_dataset_size`` (int).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "input_roots",
        type=str,
        nargs="*",
        help="One or two root directories containing one sub-directory of generated images per dataset",
    )
    # The output is the metrics JSON file, not an image.
    parser.add_argument(
        "-o",
        "--output-path",
        type=str,
        default="metrics.json",
        help="Path of the JSON file the metrics are written to",
    )
    parser.add_argument(
        "--max-dataset-size",
        type=int,
        default=1024,
        help="Maximum number of images to compute metrics for each dataset",
    )
    return parser.parse_args(argv)
def _record_metrics(dataset_results, metrics):
    """Merge ``metrics`` into ``dataset_results`` and print each entry as ``key: value``."""
    dataset_results.update(metrics)
    for key, value in metrics.items():
        print(f"{key}:", value)


def main():
    """Compute metrics for generated images and write them to a JSON file.

    Each sub-directory of the first input root is treated as a dataset name.
    For every such dataset the reference data is loaded with ``get_dataset``
    and FID, multimodal metrics and image reward are computed for the images
    under ``<root1>/<dataset_name>``. When a second input root is given, only
    datasets that also appear in it are processed, and similarity metrics
    between the two generated directories are added.

    The collected results are dumped to ``args.output_path`` as JSON with
    sorted keys; missing parent directories are created.

    Raises:
        ValueError: If the number of input roots is not 1 or 2.
    """
    args = get_args()

    # Explicit check instead of ``assert``: assertions are removed under
    # ``python -O``, which would turn a usage error into an IndexError below.
    num_roots = len(args.input_roots)
    if not 1 <= num_roots <= 2:
        raise ValueError(f"Expected 1 or 2 input roots, got {num_roots}")

    image_root1 = args.input_roots[0]
    image_root2 = args.input_roots[1] if num_roots > 1 else None

    # List the second root once (as a set) rather than on every iteration.
    root2_names = set(os.listdir(image_root2)) if image_root2 is not None else None

    results = {}
    for dataset_name in sorted(os.listdir(image_root1)):
        if root2_names is not None and dataset_name not in root2_names:
            continue

        gen_dirpath = os.path.join(image_root1, dataset_name)
        # Skip stray files in the root (e.g. a previously written metrics
        # file); only directories can hold a dataset's generated images.
        if not os.path.isdir(gen_dirpath):
            continue

        print("Results for dataset:", dataset_name)
        dataset_results = results[dataset_name] = {}
        dataset = get_dataset(name=dataset_name, return_gt=True, max_dataset_size=args.max_dataset_size)

        fid = compute_fid(ref_dirpath_or_dataset=dataset, gen_dirpath=gen_dirpath)
        dataset_results["fid"] = fid
        print("FID:", fid)

        _record_metrics(
            dataset_results,
            compute_image_multimodal_metrics(ref_dataset=dataset, gen_dirpath=gen_dirpath),
        )
        _record_metrics(
            dataset_results,
            compute_image_reward(ref_dataset=dataset, gen_dirpath=gen_dirpath),
        )

        if image_root2 is not None:
            other_gen_dirpath = os.path.join(image_root2, dataset_name)
            if os.path.exists(other_gen_dirpath):
                _record_metrics(
                    dataset_results,
                    compute_image_similarity_metrics(gen_dirpath, other_gen_dirpath),
                )

    # abspath() guarantees a non-empty dirname even for a bare file name.
    os.makedirs(os.path.dirname(os.path.abspath(args.output_path)), exist_ok=True)
    with open(args.output_path, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, sort_keys=True)
# Run the metrics computation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()