-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathsplit_data.py
55 lines (43 loc) · 1.24 KB
/
split_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
'''
Split a data set into two at random, line by line. Used to create a hold-out data set from a training set for cross-validation.
Usage: split_data.py <input whole training set> <output split training set> <output hold out test set> <headers? (Y/N)> [fraction to hold out, default .15] [SEED]
Ex.: python split_data.py Data/text_body_trn.svm Data/text_body_trn_split.svm Data/text_body_trn_holdout.svm N .1 888
'''
__author__ = 'Bryan Gregory'
__date__ = '09-05-2013'
import csv
import sys
import random
def _split_stream(src, train_out, holdout_out, holdout_fraction, has_headers):
    """Copy each line of *src* to either the hold-out or the training stream.

    Args:
        src: iterator over input lines (an open file).
        train_out: writable file receiving the lines kept for training.
        holdout_out: writable file receiving the held-out lines.
        holdout_fraction: probability that a given line is held out.
        has_headers: when true, the first line is copied to BOTH outputs
            and is not counted as a data row.

    Returns:
        The number of data rows processed (header excluded).
    """
    if has_headers:
        # next() with a default keeps an empty input from raising
        # StopIteration; there is simply no header to copy.
        header = next(src, None)
        if header is not None:
            train_out.write(header)
            holdout_out.write(header)

    rows = 0
    for line in src:
        # random.random() lies in [0, 1), so a fraction of 0 holds out
        # nothing and a fraction of 1 holds out every row.
        if random.random() < holdout_fraction:
            holdout_out.write(line)
        else:
            train_out.write(line)
        rows += 1
    return rows


def main(argv=None):
    """Command-line entry point: randomly split a file into train/hold-out.

    Args:
        argv: argument vector laid out like ``sys.argv``:
            ``[prog, input, train_output, holdout_output, headers(Y/N),
            [holdout_fraction], [seed]]``. Defaults to ``sys.argv``.
            The fraction defaults to 0.15; without a seed the split is
            not reproducible.

    Returns:
        The number of data rows processed (header excluded).

    Raises:
        IndexError: fewer than four positional arguments were supplied.
        ValueError: the hold-out fraction is not a valid float.
    """
    if argv is None:
        argv = sys.argv

    print(__doc__)

    input_file = argv[1]
    output_file1 = argv[2]
    output_file2 = argv[3]
    headers_fg = argv[4]

    # The argument IS the fraction to hold out. The previous code used
    # ``1 - value`` here, which sent e.g. 90% of the rows to the hold-out
    # file when .1 was requested, and disagreed with the 0.15 default.
    holdout_fraction = float(argv[5]) if len(argv) > 5 else 0.15
    seed = argv[6] if len(argv) > 6 else None

    print("Splitting %s percent of data" % (holdout_fraction * 100))
    if seed:
        # The seed is passed through as the raw string, as before.
        random.seed(seed)

    # Arguments are fully parsed before any output file is opened, so a bad
    # argument can no longer truncate existing outputs. Binary mode copies
    # lines byte-for-byte on both Python 2 and Python 3, and the ``with``
    # blocks guarantee the files are flushed and closed.
    with open(input_file, 'rb') as src:
        with open(output_file1, 'wb') as train_out:
            with open(output_file2, 'wb') as holdout_out:
                counter = _split_stream(
                    src, train_out, holdout_out,
                    holdout_fraction, headers_fg == 'Y',
                )

    print("Split complete. Total rows processed: {:,}".format(counter))
    return counter


if __name__ == '__main__':
    main()