-
Notifications
You must be signed in to change notification settings - Fork 288
Expand file tree
/
Copy pathgenerate_validation_json.py
More file actions
63 lines (50 loc) · 2.13 KB
/
generate_validation_json.py
File metadata and controls
63 lines (50 loc) · 2.13 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import argparse
import json
import os
import random
def generate_merged_validation_json(
    input_dir,
    output_file,
    num_elements=64,
    num_frames=77,
    height=480,
    width=832,
    num_inference_steps=40,
    seed=None,
):
    """Sample captioned clips and write them out as a validation JSON file.

    Reads ``video2caption_replace.json`` from *input_dir* (a JSON list whose
    entries carry a ``"cap"`` caption and optionally a ``"path"``), randomly
    samples up to *num_elements* of them, and writes ``{"data": [...]}`` to
    *output_file*. Every output entry holds the caption, the video path
    (``""`` when absent) and the supplied generation settings.

    Args:
        input_dir: Directory containing ``video2caption_replace.json``.
        output_file: Destination path for the validation JSON.
        num_elements: Number of entries to sample. When the source list is
            shorter, all of its entries are used instead of failing.
        num_frames: ``num_frames`` value written into every entry.
        height: ``height`` value written into every entry.
        width: ``width`` value written into every entry.
        num_inference_steps: ``num_inference_steps`` value for every entry.
        seed: Optional seed for a private RNG so the sample is reproducible;
            ``None`` uses the global ``random`` module state (the original
            behavior).

    Raises:
        ValueError: If *num_elements* is negative or a sampled entry has no
            ``"cap"`` value.
    """
    source_name = "video2caption_replace.json"
    with open(os.path.join(input_dir, source_name), "r", encoding="utf-8") as f:
        video2caption = json.load(f)

    total = len(video2caption)
    print(f"Number of elements in {source_name}: {total}")

    if num_elements < 0:
        raise ValueError(f"num_elements must be non-negative, got {num_elements}")
    # random.sample raises when k exceeds the population; fall back to using every entry.
    sample_size = min(num_elements, total)
    if sample_size < num_elements:
        print(f"Requested {num_elements} elements but only {total} are available; using all of them")

    rng = random if seed is None else random.Random(seed)
    sampled_elements = rng.sample(video2caption, sample_size)

    # Transform sampled elements into the validation.json format.
    validation_data = []
    for element in sampled_elements:
        caption = element.get("cap")
        # Explicit raise rather than assert so the check survives `python -O`.
        if caption is None:
            raise ValueError(f"Caption is None for element: {element}")
        validation_data.append({
            "caption": caption,
            "video_path": element.get("path", ""),
            "num_inference_steps": num_inference_steps,
            "height": height,
            "width": width,
            "num_frames": num_frames,
        })

    validation_json = {"data": validation_data}
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(validation_json, f, indent=2)
    print(f"Generated validation JSON with {len(validation_data)} entries and saved to {output_file}")


def main():
    """Parse command-line arguments and generate the requested validation JSON."""
    parser = argparse.ArgumentParser(
        description="Build a validation JSON by sampling captioned videos."
    )
    # Only the "merged" layout (video2caption_replace.json in --input_dir) is supported.
    parser.add_argument("--dataset_type", choices=["merged"], required=True)
    parser.add_argument("--input_dir", type=str, required=True)
    parser.add_argument("--output_file", type=str, required=True)
    parser.add_argument("--num_elements", type=int, default=64)
    parser.add_argument("--num_frames", type=int, default=77)
    parser.add_argument("--height", type=int, default=480)
    parser.add_argument("--width", type=int, default=832)
    parser.add_argument("--num_inference_steps", type=int, default=40)
    parser.add_argument(
        "--seed", type=int, default=None,
        help="Seed for reproducible sampling (default: unseeded).",
    )
    args = parser.parse_args()
    if args.dataset_type == "merged":
        # Forward every generation option; previously these flags were parsed but ignored.
        generate_merged_validation_json(
            args.input_dir,
            args.output_file,
            num_elements=args.num_elements,
            num_frames=args.num_frames,
            height=args.height,
            width=args.width,
            num_inference_steps=args.num_inference_steps,
            seed=args.seed,
        )
if __name__ == "__main__":
    # Run as a script: main() returns None, so SystemExit(None) exits with status 0.
    raise SystemExit(main())