@@ -24,7 +24,7 @@ class N2V_DataWrapper(Sequence):
2424 The manipulator used for the pixel replacement.
2525 """
2626
27- def __init__ (self , X , Y , batch_size , num_pix = 1 , shape = (64 , 64 ),
27+ def __init__ (self , X , Y , batch_size , perc_pix = 0.198 , shape = (64 , 64 ),
2828 value_manipulation = None ):
2929 self .X , self .Y = X , Y
3030 self .batch_size = batch_size
@@ -35,93 +35,94 @@ def __init__(self, X, Y, batch_size, num_pix=1, shape=(64, 64),
3535 self .dims = len (shape )
3636 self .n_chan = X .shape [- 1 ]
3737
38+ num_pix = int (np .product (shape )/ 100.0 * perc_pix )
39+ assert num_pix >= 1 , "Number of blind-spot pixels is below one. At least {}% of pixels should be replaced." .format (100.0 / np .product (shape ))
40+ print ("{} blind-spots will be generated per training patch of size {}." .format (num_pix , shape ))
41+
3842 if self .dims == 2 :
3943 self .patch_sampler = self .__subpatch_sampling2D__
40- self .box_size = np .round (np .sqrt (shape [ 0 ] * shape [ 1 ] / num_pix )).astype (np .int )
44+ self .box_size = np .round (np .sqrt (100 / perc_pix )).astype (np .int )
4145 self .get_stratified_coords = self .__get_stratified_coords2D__
4246 self .rand_float = self .__rand_float_coords2D__ (self .box_size )
43- self .X_Batches = np .zeros ([X .shape [0 ], shape [0 ], shape [1 ], X .shape [3 ]])
44- self .Y_Batches = np .zeros ([Y .shape [0 ], shape [0 ], shape [1 ], Y .shape [3 ]])
4547 elif self .dims == 3 :
4648 self .patch_sampler = self .__subpatch_sampling3D__
47- self .box_size = np .round (np .power ( shape [ 0 ] * shape [ 1 ] * shape [ 2 ] / num_pix , 1 / 3.0 )).astype (np .int )
49+ self .box_size = np .round (np .sqrt ( 100 / perc_pix )).astype (np .int )
4850 self .get_stratified_coords = self .__get_stratified_coords3D__
4951 self .rand_float = self .__rand_float_coords3D__ (self .box_size )
50- self .X_Batches = np .zeros ([X .shape [0 ], shape [0 ], shape [1 ], shape [2 ], X .shape [4 ]])
51- self .Y_Batches = np .zeros ([Y .shape [0 ], shape [0 ], shape [1 ], shape [2 ], Y .shape [4 ]])
5252 else :
5353 raise Exception ('Dimensionality not supported.' )
5454
55+ self .X_Batches = np .zeros ((self .X .shape [0 ], * self .shape , self .n_chan ), dtype = np .float32 )
56+ self .Y_Batches = np .zeros ((self .Y .shape [0 ], * self .shape , 2 * self .n_chan ), dtype = np .float32 )
57+
5558 def __len__ (self ):
5659 return int (np .ceil (len (self .X ) / float (self .batch_size )))
5760
5861 def on_epoch_end (self ):
5962 self .perm = np .random .permutation (len (self .X ))
63+ self .X_Batches *= 0
64+ self .Y_Batches *= 0
6065
6166 def __getitem__ (self , i ):
6267 idx = slice (i * self .batch_size , (i + 1 ) * self .batch_size )
6368 idx = self .perm [idx ]
64- self .patch_sampler (self .X , self .Y , self . X_Batches , self . Y_Batches , idx , self .range , self .shape )
69+ self .patch_sampler (self .X , self .X_Batches , indices = idx , range = self .range , shape = self .shape )
6570
66- for j in idx :
67- for c in range ( self . n_chan ) :
71+ for c in range ( self . n_chan ) :
72+ for j in idx :
6873 coords = self .get_stratified_coords (self .rand_float , box_size = self .box_size ,
69- shape = np .array (self .X_Batches .shape )[1 :- 1 ])
70-
71- y_val = []
72- x_val = []
73- for k in range (len (coords )):
74- y_val .append (np .copy (self .Y_Batches [(j , * coords [k ], ..., c )]))
75- x_val .append (self .value_manipulation (self .X_Batches [j , ..., c ][...,np .newaxis ], coords [k ], self .dims ))
76-
77- self .Y_Batches [j ,...,c ] *= 0
78- self .Y_Batches [j ,...,self .n_chan + c ] *= 0
74+ shape = self .shape )
7975
80- for k in range ( len ( coords )):
81- self . Y_Batches [ (j , * coords [ k ], c )] = y_val [ k ]
82- self .Y_Batches [( j , * coords [ k ], self . n_chan + c )] = 1
83- self .X_Batches [( j , * coords [ k ] , c )] = x_val [ k ]
76+ indexing = ( j ,) + coords + ( c ,)
77+ indexing_mask = (j ,) + coords + ( c + self . n_chan , )
78+ y_val = self .X_Batches [ indexing ]
79+ x_val = self .value_manipulation ( self . X_Batches [j , ... , c ], coords , self . dims )
8480
81+ self .Y_Batches [indexing ] = y_val
82+ self .Y_Batches [indexing_mask ] = 1
83+ self .X_Batches [indexing ] = x_val
8584
8685 return self .X_Batches [idx ], self .Y_Batches [idx ]
8786
8887 @staticmethod
89- def __subpatch_sampling2D__ (X , Y , X_Batches , Y_Batches , indices , range , shape ):
88+ def __subpatch_sampling2D__ (X , X_Batches , indices , range , shape ):
9089 for j in indices :
9190 y_start = np .random .randint (0 , range [0 ] + 1 )
9291 x_start = np .random .randint (0 , range [1 ] + 1 )
93- X_Batches [j ] = X [j , y_start :y_start + shape [0 ], x_start :x_start + shape [1 ]]
94- Y_Batches [j ] = Y [j , y_start :y_start + shape [0 ], x_start :x_start + shape [1 ]]
92+ X_Batches [j ] = np .copy (X [j , y_start :y_start + shape [0 ], x_start :x_start + shape [1 ]])
9593
9694 @staticmethod
97- def __subpatch_sampling3D__ (X , Y , X_Batches , Y_Batches , indices , range , shape ):
95+ def __subpatch_sampling3D__ (X , X_Batches , indices , range , shape ):
9896 for j in indices :
9997 z_start = np .random .randint (0 , range [0 ] + 1 )
10098 y_start = np .random .randint (0 , range [1 ] + 1 )
10199 x_start = np .random .randint (0 , range [2 ] + 1 )
102- X_Batches [j ] = X [j , z_start :z_start + shape [0 ], y_start :y_start + shape [1 ], x_start :x_start + shape [2 ]]
103- Y_Batches [j ] = Y [j , z_start :z_start + shape [0 ], y_start :y_start + shape [1 ], x_start :x_start + shape [2 ]]
100+ X_Batches [j ] = np .copy (X [j , z_start :z_start + shape [0 ], y_start :y_start + shape [1 ], x_start :x_start + shape [2 ]])
104101
105102 @staticmethod
106103 def __get_stratified_coords2D__ (coord_gen , box_size , shape ):
107- coords = []
108104 box_count_y = int (np .ceil (shape [0 ] / box_size ))
109105 box_count_x = int (np .ceil (shape [1 ] / box_size ))
106+ x_coords = []
107+ y_coords = []
110108 for i in range (box_count_y ):
111109 for j in range (box_count_x ):
112110 y , x = next (coord_gen )
113111 y = int (i * box_size + y )
114112 x = int (j * box_size + x )
115113 if (y < shape [0 ] and x < shape [1 ]):
116- coords .append ((y , x ))
117- return coords
114+ y_coords .append (y )
115+ x_coords .append (x )
116+ return (y_coords , x_coords )
118117
119118 @staticmethod
120119 def __get_stratified_coords3D__ (coord_gen , box_size , shape ):
121- coords = []
122120 box_count_z = int (np .ceil (shape [0 ] / box_size ))
123121 box_count_y = int (np .ceil (shape [1 ] / box_size ))
124122 box_count_x = int (np .ceil (shape [2 ] / box_size ))
123+ x_coords = []
124+ y_coords = []
125+ z_coords = []
125126 for i in range (box_count_z ):
126127 for j in range (box_count_y ):
127128 for k in range (box_count_x ):
@@ -130,8 +131,10 @@ def __get_stratified_coords3D__(coord_gen, box_size, shape):
130131 y = int (j * box_size + y )
131132 x = int (k * box_size + x )
132133 if (z < shape [0 ] and y < shape [1 ] and x < shape [2 ]):
133- coords .append ((z , y , x ))
134- return coords
134+ z_coords .append (z )
135+ y_coords .append (y )
136+ x_coords .append (x )
137+ return (z_coords , y_coords , x_coords )
135138
136139 @staticmethod
137140 def __rand_float_coords2D__ (boxsize ):
0 commit comments