-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathsplit_train_test.py
More file actions
84 lines (65 loc) · 2.4 KB
/
split_train_test.py
File metadata and controls
84 lines (65 loc) · 2.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
"""
Split the dataset into train and test sets.

Routine Listings
----------------
get_params()
    Get the DVC stage parameters.
split(seed, test_ratio, input_path, train, test)
    Split the dataset into train and test sets.
"""
import sys
import dask
import dask.distributed
import pandas as pd
from sklearn.model_selection import train_test_split
import conf
def get_params():
    """Return the parameters of this DVC stage as a new dict.

    The dict holds ``test_ratio`` (0.33) and ``seed`` (42).
    """
    stage_params = dict(test_ratio=0.33, seed=42)
    return stage_params
@dask.delayed
def split(seed, test_ratio, input_path, train, test):
    """Split dataset into train and test set.

    Reads a headerless, tab-separated file with the columns ``id``,
    ``label`` and ``text``.  The rows labelled ``1`` and the rows labelled
    ``0`` are split separately, each with the same ratio and seed, and the
    two train parts and the two test parts are concatenated (positives
    first) and written as headerless TSV files without the index.  Rows
    carrying any other label end up in neither output.

    Parameters
    ----------
    seed
        Forwarded to ``train_test_split`` as ``random_state``.
    test_ratio
        Forwarded to ``train_test_split`` as ``test_size``.
    input_path
        Location of the input TSV, handed to ``pandas.read_csv``.
    train, test
        Destinations of the train and test TSVs, handed to
        ``DataFrame.to_csv``.
    """
    def split_rows(df_class):
        # Sample row positions instead of ids.  Selecting rows by joining
        # on the id column multiplies rows whenever an id occurs more than
        # once and can place the same row in both outputs; positions are
        # unique by construction, so every row lands in exactly one side.
        positions = list(range(len(df_class)))
        train_pos, test_pos = train_test_split(
            positions, test_size=test_ratio, random_state=seed)
        # Sorting restores the input order of the rows, so only the
        # membership of each side depends on the shuffle.
        return (df_class.iloc[sorted(train_pos)],
                df_class.iloc[sorted(test_pos)])

    df = pd.read_csv(
        input_path,
        encoding='utf-8',
        header=None,
        delimiter='\t',
        names=['id', 'label', 'text']
    )

    df_positive = df[df['label'] == 1]
    df_negative = df[df['label'] == 0]

    sys.stderr.write('Positive size {}, negative size {}\n'.format(
        df_positive.shape[0],
        df_negative.shape[0]
    ))

    # Splitting each class on its own keeps the class ratio the same in
    # both outputs.
    df_pos_train, df_pos_test = split_rows(df_positive)
    df_neg_train, df_neg_test = split_rows(df_negative)

    df_train = pd.concat([df_pos_train, df_neg_train])
    df_test = pd.concat([df_pos_test, df_neg_test])

    df_train.to_csv(train, sep='\t', header=False, index=False)
    df_test.to_csv(test, sep='\t', header=False, index=False)
if __name__ == '__main__':
    from pathlib import Path

    # The client connects to the scheduler at localhost:8786; the with-block
    # closes that connection once the stage is done, also on failure.
    # NOTE(review): the .compute() calls below are expected to run through
    # this client because dask registers a new Client as the default
    # scheduler - confirm against the deployed dask version.
    with dask.distributed.Client('localhost:8786'):
        INPUT_DATASET_TSV_PATH = conf.data_dir/'xml_to_tsv'/'Posts.tsv'

        # The stage is named after this script.  Path.stem drops the
        # directory part and the '.py' suffix; str.strip('.py') would strip
        # any leading/trailing '.', 'p' or 'y' characters instead and would
        # keep directory components of __file__.
        dvc_stage_name = Path(__file__).stem
        print(f'dvc_stage_name: {dvc_stage_name}')

        STAGE_OUTPUT_PATH = conf.data_dir/dvc_stage_name
        # Create the output directory before the split writes into it.
        conf.remote_mkdir(STAGE_OUTPUT_PATH).compute()

        OUTPUT_TRAIN_TSV_PATH = STAGE_OUTPUT_PATH/'Posts-train.tsv'
        OUTPUT_TEST_TSV_PATH = STAGE_OUTPUT_PATH/'Posts-test.tsv'

        config = get_params()
        TEST_RATIO = config['test_ratio']
        SEED = config['seed']

        split(SEED, TEST_RATIO, INPUT_DATASET_TSV_PATH,
              OUTPUT_TRAIN_TSV_PATH, OUTPUT_TEST_TSV_PATH).compute()