-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathWizHarvester.py
More file actions
181 lines (138 loc) · 5.21 KB
/
Copy pathWizHarvester.py
File metadata and controls
181 lines (138 loc) · 5.21 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
# Wizard to download VOS Datasets
import os
import yaml
import questionary
import gdown
import requests
import tarfile
import zipfile
import concurrent.futures
# ----
# MAIN
def main():
    """Run the interactive wizard, then download and extract the chosen datasets.

    The dataset catalogue is read from ``supported_datasets.yaml`` (path is
    relative to the current working directory). Each selected dataset is
    handled by its own worker process; a failure in one dataset is reported
    and does not stop the others.
    """
    YAML_CONFIG = "supported_datasets.yaml"
    gdrive_global_addr, datasets_info = prep_dict_config(YAML_CONFIG)
    parent_folder, selected_datasets, selected_splits, proceed = run_wizard(
        datasets_info
    )
    if not proceed:
        return

    os.makedirs(parent_folder, exist_ok=True)

    # os.cpu_count() returns None when the CPU count cannot be determined;
    # fall back to a single worker instead of letting min() raise TypeError.
    workers = min(len(selected_datasets), os.cpu_count() or 1)
    print(f"Using {workers} workers")

    with concurrent.futures.ProcessPoolExecutor(max_workers=workers) as executor:
        # Map every future back to its dataset so failures can be attributed.
        futures = {
            executor.submit(
                process_single_dataset,
                parent_folder,
                name,
                selected_splits,
                gdrive_global_addr,
                datasets_info[name],
            ): name
            for name in selected_datasets
        }
        for future in concurrent.futures.as_completed(futures):
            name = futures[future]
            try:
                # result() re-raises any exception raised inside the worker
                print(future.result())
            except Exception as e:
                print(f"{name} failed with error: {e}")

    print("\nData Harvestation completed")
# ---------
# FUNCTIONS
def process_single_dataset(
    parent_folder, name, selected_splits, gdrive_global_addr, datasets_info
):
    """Download one dataset and unpack its archives; return a status message.

    ``datasets_info`` is the catalogue entry of this single dataset (a mapping
    of label -> link), not the whole catalogue. The returned string reports
    either success or that nothing was downloaded for ``name``.
    """
    print(f"\n $$ Starting worker for dataset {name}")

    fetched = download_dataset(
        parent_folder, name, selected_splits, gdrive_global_addr, datasets_info
    )

    # Nothing came back from the download step: report and stop here.
    if not fetched:
        return f"Error with {name}"

    extract_compressed_content(fetched, selected_splits)
    return f"Successfully Downloaded and Extracted {name}"
def prep_dict_config(yaml_config):
    """Load the dataset catalogue from a YAML file.

    Args:
        yaml_config: Path of the YAML file to read.

    Returns:
        A tuple ``(gdrive_global_addr, datasets_info)``: the value of the
        top-level ``GDrive_addr`` key (``""`` when missing or null) and the
        remaining top-level mapping with that key removed.

    Raises:
        ValueError: If the file does not contain a top-level mapping
            (for example when it is empty).
    """
    with open(yaml_config, "r", encoding="utf-8") as file:
        yaml_config_dict = yaml.safe_load(file)

    # safe_load yields None for an empty document and may yield a list or a
    # scalar; fail with a clear message instead of an opaque error on .pop().
    if not isinstance(yaml_config_dict, dict):
        raise ValueError(
            f"{yaml_config} must contain a top-level mapping of datasets, "
            f"got {type(yaml_config_dict).__name__}"
        )

    # A key present with a null value would later break string concatenation.
    gdrive_global_addr = yaml_config_dict.pop("GDrive_addr", "") or ""
    datasets_info = yaml_config_dict
    return gdrive_global_addr, datasets_info
def run_wizard(available_datasets):
    """Ask the user for the target folder, the datasets and the splits.

    ``available_datasets`` is a mapping whose keys are offered as dataset
    choices. Returns ``(parent_folder, selected_datasets, selected_splits,
    proceed)``; when any prompt gives an empty answer the first three items
    are ``None`` and ``proceed`` is ``False``.
    """
    aborted = (None, None, None, False)

    # Step 1: destination folder
    folder_prompt = questionary.path(
        "Parent folder to download the datasets (Use 'Tab'):", only_directories=True
    )
    parent_folder = folder_prompt.ask()
    if not parent_folder:
        print("Operation cancelled")
        return aborted

    # Step 2: datasets
    dataset_prompt = questionary.checkbox(
        "Select datasets to download ('Space' to select & 'Enter' to confirm):",
        choices=list(available_datasets.keys()),
    )
    selected_datasets = dataset_prompt.ask()
    if not selected_datasets:
        print("Exit, no datasets selected")
        return aborted
    print(f"\n Datasets selected: {selected_datasets}")

    # Step 3: splits, all pre-checked so the user only has to untick
    split_choices = [
        questionary.Choice(split_name, checked=True)
        for split_name in ("train", "valid", "test")
    ]
    split_prompt = questionary.checkbox(
        "UNselect splits to download ('Space' to UNselect & 'Enter' to confirm):",
        choices=split_choices,
    )
    selected_splits = split_prompt.ask()
    if not selected_splits:
        print("Exit, no split selected")
        return aborted

    return parent_folder, selected_datasets, selected_splits, True
def filter_info_based_on_splits_selected(info, splits):
    """Drop the entries of ``info`` that belong to an unselected split.

    ``info`` maps a label to a link. It is returned unchanged (same object)
    when three splits are selected or when it holds at most one entry.
    Otherwise a new dict is returned that keeps every entry whose label does
    not contain the name of an unselected split.
    """
    # Nothing to filter: all splits are wanted, or there is a single entry.
    if len(splits) == 3 or len(info) <= 1:
        return info

    # Negative filtering: download everything except the unselected parts.
    excluded = {"train", "valid", "test"}.difference(splits)
    return {
        label: link
        for label, link in info.items()
        if all(split_name not in label for split_name in excluded)
    }
def download_dataset(parent_folder, name, splits, gdrive_global_addr, info):
    """Download the files of one dataset into ``parent_folder/name``.

    Args:
        parent_folder: Root folder chosen by the user.
        name: Dataset name; also used as the sub-folder name.
        splits: Selected splits, used to filter ``info``.
        gdrive_global_addr: Prefix prepended to links that are not full
            ``https://`` URLs.
        info: Mapping of label -> link for this dataset. The label doubles
            as the file name of the download inside the dataset folder.

    Returns:
        A list of ``(label, output_path)`` tuples, one per file that was
        downloaded. Entries whose download failed are left out.
    """
    target_dir = os.path.join(parent_folder, name)
    os.makedirs(target_dir, exist_ok=True)
    downloaded_files = []
    info = filter_info_based_on_splits_selected(info, splits)

    # download content
    for label, url_link in info.items():
        if is_gdrive_link(url_link):
            url_link = gdrive_global_addr + url_link
        print(f"Starting download of {label} for {name}")
        output_path = os.path.join(target_dir, label)
        saved_path = gdown.download(url_link, output_path, quiet=False, fuzzy=True)
        # NOTE(review): gdown.download is expected to return None when the
        # file could not be fetched (some versions raise instead) - confirm
        # against the installed gdown. Recording such a path would make the
        # extraction step fail on a file that does not exist.
        if saved_path is None:
            print(f"Download of {label} for {name} failed, skipping")
            continue
        downloaded_files.append((label, output_path))
    return downloaded_files
def is_gdrive_link(url):
    """Return True unless ``url`` starts with ``https://``.

    Anything that is not a full ``https://`` URL is treated as a Google
    Drive reference by the caller.
    """
    return not url.startswith("https://")
def extract_compressed_content(downloaded_files, selected_splits):
    """Unpack every tar/zip archive next to itself, then delete the archive.

    Args:
        downloaded_files: Iterable of ``(label, file_path)`` tuples.
        selected_splits: Accepted for interface compatibility; not used here.

    Files that are missing are skipped with a message. Files that are neither
    tar nor zip archives are left untouched.
    """
    for label, file_path in downloaded_files:
        # tarfile.is_tarfile raises on a missing path; skip instead of crashing.
        if not os.path.isfile(file_path):
            print(f"Skipping {label}: {file_path} not found")
            continue

        target_dir = os.path.dirname(file_path)
        if tarfile.is_tarfile(file_path):
            with tarfile.open(file_path) as tar:
                print(f"Starting extraction of {label}")
                # Archives come from the network: use the "data" extraction
                # filter when this Python provides it, so members cannot
                # escape target_dir (absolute paths, "..", unsafe links).
                if hasattr(tarfile, "data_filter"):
                    tar.extractall(path=target_dir, filter="data")
                else:
                    tar.extractall(path=target_dir)
            # Remove only after the archive handle has been closed.
            os.remove(file_path)
        elif zipfile.is_zipfile(file_path):
            with zipfile.ZipFile(file_path, "r") as zip_ref:
                print(f"Starting extraction of {label}")
                zip_ref.extractall(path=target_dir)
            os.remove(file_path)
# ---------
# EXECUTION
# Start the wizard only when this file is run as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()