-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathsplit_data_top.py
More file actions
31 lines (24 loc) · 1013 Bytes
/
split_data_top.py
File metadata and controls
31 lines (24 loc) · 1013 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import pandas as pd
import os
from tqdm import *
def main(
    root_csv="/media/ngxbac/Bac/competition/kaggle/competition_data/quickdraw/data/csv/train_simplified/",
    split_csv="/media/ngxbac/Bac/competition/kaggle/competition_data/quickdraw/data/all_data/",
    *,
    n_valid=5000,
    progress=None,
):
    """Split every CSV file in ``root_csv`` into a train part and a validation tail.

    For each ``*.csv`` file directly inside ``root_csv``, the columns ``key_id``,
    ``drawing``, ``recognized`` and ``countrycode`` are read. The last
    ``n_valid`` rows are written to ``<split_csv>/valid/<file>`` and all
    preceding rows to ``<split_csv>/train/<file>`` (without the index). The two
    outputs never share a row: a file with ``n_valid`` rows or fewer goes
    entirely to ``valid`` and its ``train`` file contains only the header.

    Args:
        root_csv: Directory holding the source CSV files.
        split_csv: Output directory; ``train/`` and ``valid/`` are created in it.
        n_valid: Number of trailing rows per file reserved for validation.
        progress: Callable that wraps the iterable of file names, e.g. for a
            progress bar. Defaults to ``tqdm``; pass ``lambda it: it`` to
            disable it.

    Raises:
        ValueError: If ``n_valid`` is negative.
    """
    if n_valid < 0:
        raise ValueError(f"n_valid must be non-negative, got {n_valid}")
    if progress is None:
        # ``tqdm`` comes from the file-level ``from tqdm import *``.
        progress = tqdm

    # Sorted for a deterministic processing order; non-CSV entries (e.g. stray
    # hidden files or sub-directories) are skipped instead of crashing read_csv.
    names = sorted(
        name
        for name in os.listdir(root_csv)
        if name.endswith(".csv") and os.path.isfile(os.path.join(root_csv, name))
    )

    train_dir = os.path.join(split_csv, "train")
    valid_dir = os.path.join(split_csv, "valid")
    # Create the output directories once instead of on every iteration.
    os.makedirs(train_dir, exist_ok=True)
    os.makedirs(valid_dir, exist_ok=True)

    for name in progress(names):
        df = pd.read_csv(
            os.path.join(root_csv, name),
            usecols=["key_id", "drawing", "recognized", "countrycode"],
        )
        # Clamp the validation size to the file length. The former
        # ``head(nrows - 5000)`` received a negative count for short files and
        # then kept rows that ``tail(5000)`` also returned (train/valid overlap).
        n_train = len(df) - min(n_valid, len(df))
        df.iloc[:n_train].to_csv(os.path.join(train_dir, name), index=False)
        df.iloc[n_train:].to_csv(os.path.join(valid_dir, name), index=False)
if __name__ == "__main__":
    # Script entry point: perform the split only when executed directly,
    # not when this module is imported.
    main()