-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmerge.py
More file actions
87 lines (69 loc) · 3.42 KB
/
merge.py
File metadata and controls
87 lines (69 loc) · 3.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import os
import json
import pandas as pd
from tqdm import tqdm
from glob import glob
from concurrent.futures import ThreadPoolExecutor, as_completed
import argparse
# Metadata columns emitted next to the per-detector score columns; they hold
# one value per (dataset, label column, row) entry.
_META_COLUMNS = ["dataset", "label", "type", "column", "index"]


def _list_subdirs(path: str) -> list:
    """Return the sorted names of the immediate subdirectories of *path*."""
    return sorted(
        name for name in os.listdir(path) if os.path.isdir(os.path.join(path, name))
    )


def _discover_detectors(result_base_path: str, groups: list) -> list:
    """Return the sorted union of detector directory names across all datasets.

    Scans ``<result_base_path>/<group>/<dataset>/<detector>/`` for every group
    in *groups*. The union (instead of whichever dataset a glob happens to
    return first) makes the detector set independent of filesystem ordering
    and keeps detectors that are missing from some datasets. Those datasets
    get zero-filled scores in ``_process_dataset``.

    Raises:
        ValueError: if no detector directory exists under any dataset.
    """
    detectors = set()
    for group in groups:
        group_dir = os.path.join(result_base_path, group)
        for dataset in _list_subdirs(group_dir):
            detectors.update(_list_subdirs(os.path.join(group_dir, dataset)))
    if not detectors:
        raise ValueError(
            f"No detector result directories found under {result_base_path!r}"
        )
    return sorted(detectors)


def _process_dataset(
    dataset: str,
    group: str,
    *,
    result_base_path: str,
    data_base_path: str,
    detectors: list,
) -> dict:
    """Flatten one dataset's labels and detector scores into parallel lists.

    Reads ``<data_base_path>/<group>/<dataset>/<dataset>_types.json`` and
    ``<dataset>_labels.csv``, plus
    ``<result_base_path>/<group>/<dataset>/<detector>/scores.csv`` for each
    detector. A missing ``scores.csv`` is treated as all-zero scores with the
    same shape as the labels. The label columns are stacked in
    ``labels.columns`` order, so every returned list has
    ``len(labels) * len(labels.columns)`` entries.

    Returns:
        dict mapping each detector name and each ``_META_COLUMNS`` name to a
        list of values; all lists have the same length.

    Raises:
        ValueError: if a detector's scores have a different row count than
            the labels (the merged columns would otherwise be misaligned).
        KeyError: if a label column is missing from the types JSON or from a
            detector's scores.
    """
    data_dir = os.path.join(data_base_path, group, dataset)
    with open(os.path.join(data_dir, f"{dataset}_types.json"), "r", encoding="utf-8") as f:
        types = json.load(f)
    labels = pd.read_csv(os.path.join(data_dir, f"{dataset}_labels.csv"))
    n_rows = len(labels)

    scores_by_detector = {}
    for detector in detectors:
        score_path = os.path.join(result_base_path, group, dataset, detector, "scores.csv")
        try:
            scores = pd.read_csv(score_path)
        except FileNotFoundError:
            # Deliberate best effort: a detector without results for this
            # dataset contributes zeros instead of aborting the whole merge.
            scores = pd.DataFrame(0.0, index=labels.index, columns=labels.columns)
        if len(scores) != n_rows:
            raise ValueError(
                f"{score_path} has {len(scores)} rows but the labels have {n_rows}"
            )
        scores_by_detector[detector] = scores

    result = {name: [] for name in detectors + _META_COLUMNS}
    for col in labels.columns:
        for detector, scores in scores_by_detector.items():
            result[detector].extend(scores[col].tolist())
        # Metadata is emitted once per label column, independent of detectors.
        result["dataset"].extend([dataset] * n_rows)
        result["label"].extend(labels[col].tolist())
        result["type"].extend([types[col]] * n_rows)
        result["column"].extend([col] * n_rows)
        result["index"].extend(range(n_rows))
    return result


def _write_columns(final_data: dict, output_dir: str) -> None:
    """Write each entry of *final_data* to ``<output_dir>/<key>.csv`` as a one-column CSV."""
    os.makedirs(output_dir, exist_ok=True)
    for column, values in final_data.items():
        print(f"Writing column: {column}")
        pd.DataFrame({column: values}).to_csv(
            os.path.join(output_dir, f"{column}.csv"), index=False
        )


def main(
    result_base_path: str,
    data_base_path: str,
    output_dir: str,
    max_workers: int = 32,
) -> None:
    """Merge per-detector scores with ground-truth labels and write one CSV per column.

    Expected layout:
        ``<result_base_path>/<group>/<dataset>/<detector>/scores.csv``
        ``<data_base_path>/<group>/<dataset>/<dataset>_labels.csv``
        ``<data_base_path>/<group>/<dataset>/<dataset>_types.json``

    Args:
        result_base_path: root of the detector results tree.
        data_base_path: root of the labels/types tree.
        output_dir: directory that receives one ``<column>.csv`` per detector
            plus the metadata columns (created if missing).
        max_workers: number of threads used to load datasets concurrently.

    Raises:
        ValueError: if no detector directories exist, or if a score file's
            row count does not match its labels.
    """
    groups = _list_subdirs(result_base_path)
    detectors = _discover_detectors(result_base_path, groups)
    tasks = [
        (dataset, group)
        for group in groups
        for dataset in _list_subdirs(os.path.join(result_base_path, group))
    ]

    final_data = {name: [] for name in detectors + _META_COLUMNS}
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [
            executor.submit(
                _process_dataset,
                dataset,
                group,
                result_base_path=result_base_path,
                data_base_path=data_base_path,
                detectors=detectors,
            )
            for dataset, group in tasks
        ]
        # Consume futures in submission order (not as_completed) so the merged
        # rows keep a deterministic group -> dataset order.
        for future in tqdm(futures, total=len(futures), desc="Processing"):
            result = future.result()
            for key, values in final_data.items():
                values.extend(result[key])

    _write_columns(final_data, output_dir)
if __name__ == "__main__":
    # Command-line entry point: every path flag is mandatory and is forwarded
    # positionally to main().
    cli = argparse.ArgumentParser(description="Merge polluted detection results")
    for flag, help_text in (
        ("--result_base_path", "Path to results directory"),
        ("--data_base_path", "Path to data directory"),
        ("--output_dir", "Path to output directory"),
    ):
        cli.add_argument(flag, required=True, help=help_text)
    parsed = cli.parse_args()
    main(parsed.result_base_path, parsed.data_base_path, parsed.output_dir)