-
Notifications
You must be signed in to change notification settings - Fork 8
Expand file tree
/
Copy pathsplit_json.py
More file actions
38 lines (28 loc) · 1.08 KB
/
split_json.py
File metadata and controls
38 lines (28 loc) · 1.08 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
"""Split JSON test data by GPU count for multi-GPU parallel evaluation."""
import json
import os
import fire
def split(input_path: str, output_path: str, cuda_list: str):
    """Split JSON test data into one contiguous shard per GPU.

    The input file is loaded in full, cut into ``len(cuda_list)`` consecutive
    slices whose sizes differ by at most one, and each slice is written to
    ``<output_path>/<gpu_id>.json``.

    Args:
        input_path: Input JSON file path. The top-level JSON value must
            support slicing (in practice, a list of samples).
        output_path: Output directory; created if it does not exist.
        cuda_list: GPU list, e.g. "0,1,2,3" or "0 1 2 3". A single int
            (e.g. ``0``) or an iterable of ids (e.g. ``(0, 1, 2, 3)``) is
            accepted as well, because python-fire converts command-line
            values such as ``0`` or ``0,1,2,3`` to an int / tuple before
            calling this function.

    Raises:
        ValueError: If ``cuda_list`` contains no GPU id, or contains the same
            id more than once (shards would overwrite each other on disk).
    """
    # Normalise every accepted form to a list of string ids.
    if isinstance(cuda_list, str):
        # A bare split() already strips whitespace and drops empty tokens.
        gpu_ids = cuda_list.replace(",", " ").split()
    elif isinstance(cuda_list, int):
        # Single GPU given on the command line arrives as an int via fire.
        gpu_ids = [str(cuda_list)]
    else:
        gpu_ids = [str(gpu_id) for gpu_id in cuda_list]

    if not gpu_ids:
        raise ValueError("cuda_list must contain at least one GPU id")
    if len(set(gpu_ids)) != len(gpu_ids):
        # Each id maps to one output file; duplicates would silently lose data.
        raise ValueError(f"cuda_list contains duplicate GPU ids: {gpu_ids}")

    with open(input_path, encoding="utf-8") as f:
        data = json.load(f)

    os.makedirs(output_path, exist_ok=True)

    n = len(data)
    num_parts = len(gpu_ids)
    for i, gpu_id in enumerate(gpu_ids):
        # Integer boundaries i*n//k partition [0, n) exactly: every sample
        # lands in exactly one shard, and shard sizes differ by at most one.
        start = i * n // num_parts
        end = (i + 1) * n // num_parts
        part = data[start:end]

        out_file = os.path.join(output_path, f"{gpu_id}.json")
        with open(out_file, "w", encoding="utf-8") as f:
            json.dump(part, f, indent=4, ensure_ascii=False)
        print(f"GPU {gpu_id}: {len(part)} samples -> {out_file}")
# Script entry point: hand split() to python-fire so it can be invoked from the
# command line; nothing runs when this module is merely imported.
if __name__ == "__main__":
    fire.Fire(split)