-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathdata_augmentation.py
102 lines (73 loc) · 2.57 KB
/
data_augmentation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
# coding: utf-8
# In[ ]:
import os
import numpy as np
from scipy import misc
from shutil import copyfile
# Training sample directories under stage1_train/, ignoring hidden entries
# such as .DS_Store.
train_list = os.listdir('stage1_train')
train_list = [x for x in train_list if not x.startswith('.')]
# In[ ]:
# Generate cumulative masks.
# For every sample directory stage1_train/<id>/, combine all files in its
# masks/ subfolder into a single mask and copy the first file of its images/
# subfolder. Outputs are indexed by the sample's position i in train_list:
#   train/y/mask_<i>.png   combined mask
#   train/x/img_<i>.png    the sample's image
# NOTE(review): scipy.misc.imread/imsave were removed in SciPy 1.2, so this
# needs an older SciPy (or a replacement such as imageio) -- confirm the
# pinned version.

# imsave/copyfile do not create missing folders.
os.makedirs('train/x', exist_ok=True)
os.makedirs('train/y', exist_ok=True)

for ii, f in enumerate(train_list):
    sample_dir = 'stage1_train/' + f
    # Skip hidden files (e.g. .DS_Store) the same way train_list does.
    small_masks = [s for s in os.listdir(sample_dir + '/masks') if not s.startswith('.')]
    og_image = [s for s in os.listdir(sample_dir + '/images') if not s.startswith('.')]

    # None marks "nothing accumulated yet". A string placeholder cannot be
    # used here: comparing an ndarray with a string is elementwise, and
    # recent NumPy then raises when that array is used in an `if`.
    big_mask = None
    for s in small_masks:
        a = misc.imread(sample_dir + '/masks/' + s)
        # Elementwise max is the union of the masks. Where masks do not
        # overlap it equals `big_mask + a`; where they do, it cannot
        # overflow/wrap the integer dtype (e.g. 255 + 255 in uint8).
        big_mask = a if big_mask is None else np.maximum(big_mask, a)

    if big_mask is None or not og_image:
        # Nothing to pair: skip instead of failing in imsave / og_image[0].
        print('Skipping {}: no mask or no image file found'.format(f))
        continue

    misc.imsave('train/y/' + 'mask_' + str(ii) + '.png', big_mask)
    copyfile(sample_dir + '/images/' + og_image[0], 'train/x/' + 'img_' + str(ii) + '.png')
# In[ ]:
# Extract 100 random patches from each original image/mask pair.
# Each train/x/img_<i>.png is stacked with train/y/mask_<i>.png along the
# channel axis, so image and mask patches come from exactly the same window.
# The last channel of every patch is the mask; the remaining channels are the
# image. Outputs: train/x/img_<i>_<j>.png and train/y/mask_<i>_<j>.png.
import re

from sklearn.feature_extraction import image

np.random.seed(0)
og_x_list = os.listdir('train/x/')
# Keep only the originals named img_<i>.png. This skips .DS_Store and, if
# the cell is re-run, previously generated patches/rotations, which would
# otherwise be patched again.
og_x_list = [x for x in og_x_list if re.fullmatch(r'img_\d+\.png', x)]
for x in og_x_list:
    stem = x[4:-4]  # 'img_<i>.png' -> '<i>'
    img_x = misc.imread('train/x/' + x)
    img_y = misc.imread('train/y/' + 'mask_' + stem + '.png')
    img = np.dstack((img_x, img_y))

    # Patch size: height and width each drawn in [100, 199] from the global
    # NumPy RNG seeded above. Each is clamped to the image size, because
    # extract_patches_2d raises when the patch is larger than the image.
    # The two randint draws happen in the same order as before.
    patch_h = min(100 + np.random.randint(low=0, high=100), img.shape[0])
    patch_w = min(100 + np.random.randint(low=0, high=100), img.shape[1])

    # max_patches=100 is an integer cap (at most 100 patches, not a fraction
    # of all possible ones); random_state=0 makes patch positions reproducible.
    patches = image.extract_patches_2d(img, (patch_h, patch_w), max_patches=100, random_state=0)

    for i, p in enumerate(patches):
        misc.imsave('train/y/' + 'mask_' + stem + '_' + str(i) + '.png', p[:, :, -1])
        misc.imsave('train/x/' + 'img_' + stem + '_' + str(i) + '.png', p[:, :, :-1])
# In[ ]:
# Rotate and flip.
# Rotation step: every PNG in train/x and train/y gets three rotated copies,
# <name>_r1.png, <name>_r2.png and <name>_r3.png, rotated by 90, 180 and
# 270 degrees (np.rot90 with k = 1, 2, 3, the same as applying np.rot90
# once, twice or three times).
# Each directory is listed once up front, so copies written here are not
# rotated again in this pass.
for split in ('x', 'y'):
    folder = './train/' + split + '/'
    # Only PNGs: skips .DS_Store and other non-image files imread cannot load.
    names = [n for n in os.listdir(folder) if n.endswith('.png')]
    for name in names:
        im = misc.imread(folder + name)
        stem = name[:-4]  # drop '.png'
        for k in (1, 2, 3):
            misc.imsave(folder + stem + '_r' + str(k) + '.png', np.rot90(im, k))
# Flip step: every PNG in train/x and train/y when this cell runs gets an
# upside-down copy, <name>_flip.png (np.flipud reverses the row axis).
# Run after the rotation step, an original, its three rotations and the
# flipped copies of all four together give the eight 90-degree
# rotation/reflection variants.
for split in ('x', 'y'):
    folder = './train/' + split + '/'
    # Only PNGs: imread cannot load .DS_Store or other non-image files.
    names = [n for n in os.listdir(folder) if n.endswith('.png')]
    for name in names:
        im = misc.imread(folder + name)
        # The NumPy function is np.flipud; np.flip_ud does not exist and
        # raises AttributeError.
        misc.imsave(folder + name[:-4] + '_flip.png', np.flipud(im))