Skip to content

Commit 8d9e5c3

Browse files
author
tibuch
authored
Merge pull request #53 from juglab/fix_numPix_computation
Fix num pix computation
2 parents fbc0432 + 8fc0930 commit 8d9e5c3

13 files changed

Lines changed: 1425 additions & 252 deletions

File tree

examples/2D/denoising2D_BSD68/BSD68_reproducibility.ipynb

Lines changed: 925 additions & 0 deletions
Large diffs are not rendered by default.

examples/2D/denoising2D_RGB/01_training.ipynb

Lines changed: 99 additions & 41 deletions
Large diffs are not rendered by default.

examples/2D/denoising2D_RGB/02_prediction.ipynb

Lines changed: 19 additions & 10 deletions
Large diffs are not rendered by default.

examples/2D/denoising2D_SEM/01_training.ipynb

Lines changed: 84 additions & 26 deletions
Large diffs are not rendered by default.

examples/2D/denoising2D_SEM/02_prediction.ipynb

Lines changed: 21 additions & 5 deletions
Large diffs are not rendered by default.

examples/3D/01_training.ipynb

Lines changed: 112 additions & 37 deletions
Large diffs are not rendered by default.

examples/3D/02_prediction.ipynb

Lines changed: 15 additions & 5 deletions
Large diffs are not rendered by default.

n2v/internals/N2V_DataWrapper.py

Lines changed: 39 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ class N2V_DataWrapper(Sequence):
2424
The manipulator used for the pixel replacement.
2525
"""
2626

27-
def __init__(self, X, Y, batch_size, num_pix=1, shape=(64, 64),
27+
def __init__(self, X, Y, batch_size, perc_pix=0.198, shape=(64, 64),
2828
value_manipulation=None):
2929
self.X, self.Y = X, Y
3030
self.batch_size = batch_size
@@ -35,93 +35,94 @@ def __init__(self, X, Y, batch_size, num_pix=1, shape=(64, 64),
3535
self.dims = len(shape)
3636
self.n_chan = X.shape[-1]
3737

38+
num_pix = int(np.product(shape)/100.0 * perc_pix)
39+
assert num_pix >= 1, "Number of blind-spot pixels is below one. At least {}% of pixels should be replaced.".format(100.0/np.product(shape))
40+
print("{} blind-spots will be generated per training patch of size {}.".format(num_pix, shape))
41+
3842
if self.dims == 2:
3943
self.patch_sampler = self.__subpatch_sampling2D__
40-
self.box_size = np.round(np.sqrt(shape[0] * shape[1] / num_pix)).astype(np.int)
44+
self.box_size = np.round(np.sqrt(100/perc_pix)).astype(np.int)
4145
self.get_stratified_coords = self.__get_stratified_coords2D__
4246
self.rand_float = self.__rand_float_coords2D__(self.box_size)
43-
self.X_Batches = np.zeros([X.shape[0], shape[0], shape[1], X.shape[3]])
44-
self.Y_Batches = np.zeros([Y.shape[0], shape[0], shape[1], Y.shape[3]])
4547
elif self.dims == 3:
4648
self.patch_sampler = self.__subpatch_sampling3D__
47-
self.box_size = np.round(np.power(shape[0] * shape[1] * shape[2] / num_pix, 1/3.0)).astype(np.int)
49+
self.box_size = np.round(np.sqrt(100 / perc_pix)).astype(np.int)
4850
self.get_stratified_coords = self.__get_stratified_coords3D__
4951
self.rand_float = self.__rand_float_coords3D__(self.box_size)
50-
self.X_Batches = np.zeros([X.shape[0], shape[0], shape[1], shape[2], X.shape[4]])
51-
self.Y_Batches = np.zeros([Y.shape[0], shape[0], shape[1], shape[2], Y.shape[4]])
5252
else:
5353
raise Exception('Dimensionality not supported.')
5454

55+
self.X_Batches = np.zeros((self.X.shape[0], *self.shape, self.n_chan), dtype=np.float32)
56+
self.Y_Batches = np.zeros((self.Y.shape[0], *self.shape, 2*self.n_chan), dtype=np.float32)
57+
5558
def __len__(self):
5659
return int(np.ceil(len(self.X) / float(self.batch_size)))
5760

5861
def on_epoch_end(self):
5962
self.perm = np.random.permutation(len(self.X))
63+
self.X_Batches *= 0
64+
self.Y_Batches *= 0
6065

6166
def __getitem__(self, i):
6267
idx = slice(i * self.batch_size, (i + 1) * self.batch_size)
6368
idx = self.perm[idx]
64-
self.patch_sampler(self.X, self.Y, self.X_Batches, self.Y_Batches, idx, self.range, self.shape)
69+
self.patch_sampler(self.X, self.X_Batches, indices=idx, range=self.range, shape=self.shape)
6570

66-
for j in idx:
67-
for c in range(self.n_chan):
71+
for c in range(self.n_chan):
72+
for j in idx:
6873
coords = self.get_stratified_coords(self.rand_float, box_size=self.box_size,
69-
shape=np.array(self.X_Batches.shape)[1:-1])
70-
71-
y_val = []
72-
x_val = []
73-
for k in range(len(coords)):
74-
y_val.append(np.copy(self.Y_Batches[(j, *coords[k], ..., c)]))
75-
x_val.append(self.value_manipulation(self.X_Batches[j, ..., c][...,np.newaxis], coords[k], self.dims))
76-
77-
self.Y_Batches[j,...,c] *= 0
78-
self.Y_Batches[j,...,self.n_chan+c] *= 0
74+
shape=self.shape)
7975

80-
for k in range(len(coords)):
81-
self.Y_Batches[(j, *coords[k], c)] = y_val[k]
82-
self.Y_Batches[(j, *coords[k], self.n_chan+c)] = 1
83-
self.X_Batches[(j, *coords[k], c)] = x_val[k]
76+
indexing = (j,) + coords + (c,)
77+
indexing_mask = (j,) + coords + (c + self.n_chan, )
78+
y_val = self.X_Batches[indexing]
79+
x_val = self.value_manipulation(self.X_Batches[j, ..., c], coords, self.dims)
8480

81+
self.Y_Batches[indexing] = y_val
82+
self.Y_Batches[indexing_mask] = 1
83+
self.X_Batches[indexing] = x_val
8584

8685
return self.X_Batches[idx], self.Y_Batches[idx]
8786

8887
@staticmethod
89-
def __subpatch_sampling2D__(X, Y, X_Batches, Y_Batches, indices, range, shape):
88+
def __subpatch_sampling2D__(X, X_Batches, indices, range, shape):
9089
for j in indices:
9190
y_start = np.random.randint(0, range[0] + 1)
9291
x_start = np.random.randint(0, range[1] + 1)
93-
X_Batches[j] = X[j, y_start:y_start + shape[0], x_start:x_start + shape[1]]
94-
Y_Batches[j] = Y[j, y_start:y_start + shape[0], x_start:x_start + shape[1]]
92+
X_Batches[j] = np.copy(X[j, y_start:y_start + shape[0], x_start:x_start + shape[1]])
9593

9694
@staticmethod
97-
def __subpatch_sampling3D__(X, Y, X_Batches, Y_Batches, indices, range, shape):
95+
def __subpatch_sampling3D__(X, X_Batches, indices, range, shape):
9896
for j in indices:
9997
z_start = np.random.randint(0, range[0] + 1)
10098
y_start = np.random.randint(0, range[1] + 1)
10199
x_start = np.random.randint(0, range[2] + 1)
102-
X_Batches[j] = X[j, z_start:z_start + shape[0], y_start:y_start + shape[1], x_start:x_start + shape[2]]
103-
Y_Batches[j] = Y[j, z_start:z_start + shape[0], y_start:y_start + shape[1], x_start:x_start + shape[2]]
100+
X_Batches[j] = np.copy(X[j, z_start:z_start + shape[0], y_start:y_start + shape[1], x_start:x_start + shape[2]])
104101

105102
@staticmethod
106103
def __get_stratified_coords2D__(coord_gen, box_size, shape):
107-
coords = []
108104
box_count_y = int(np.ceil(shape[0] / box_size))
109105
box_count_x = int(np.ceil(shape[1] / box_size))
106+
x_coords = []
107+
y_coords = []
110108
for i in range(box_count_y):
111109
for j in range(box_count_x):
112110
y, x = next(coord_gen)
113111
y = int(i * box_size + y)
114112
x = int(j * box_size + x)
115113
if (y < shape[0] and x < shape[1]):
116-
coords.append((y, x))
117-
return coords
114+
y_coords.append(y)
115+
x_coords.append(x)
116+
return (y_coords, x_coords)
118117

119118
@staticmethod
120119
def __get_stratified_coords3D__(coord_gen, box_size, shape):
121-
coords = []
122120
box_count_z = int(np.ceil(shape[0] / box_size))
123121
box_count_y = int(np.ceil(shape[1] / box_size))
124122
box_count_x = int(np.ceil(shape[2] / box_size))
123+
x_coords = []
124+
y_coords = []
125+
z_coords = []
125126
for i in range(box_count_z):
126127
for j in range(box_count_y):
127128
for k in range(box_count_x):
@@ -130,8 +131,10 @@ def __get_stratified_coords3D__(coord_gen, box_size, shape):
130131
y = int(j * box_size + y)
131132
x = int(k * box_size + x)
132133
if (z < shape[0] and y < shape[1] and x < shape[2]):
133-
coords.append((z, y, x))
134-
return coords
134+
z_coords.append(z)
135+
y_coords.append(y)
136+
x_coords.append(x)
137+
return (z_coords, y_coords, x_coords)
135138

136139
@staticmethod
137140
def __rand_float_coords2D__(boxsize):

n2v/models/n2v_standard.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -208,14 +208,14 @@ def train(self, X, validation_X, epochs=None, steps_per_epoch=None):
208208
# Here we prepare the Noise2Void data. Our input is the noisy data X and as target we take X concatenated with
209209
# a masking channel. The N2V_DataWrapper will take care of the pixel masking and manipulating.
210210
training_data = N2V_DataWrapper(X, np.concatenate((X, np.zeros(X.shape, dtype=X.dtype)), axis=axes.index('C')),
211-
self.config.train_batch_size, int(train_num_pix/100 * self.config.n2v_perc_pix),
211+
self.config.train_batch_size, self.config.n2v_perc_pix,
212212
self.config.n2v_patch_shape, manipulator)
213213

214214
# validation_Y is also validation_X plus a concatenated masking channel.
215215
# To speed things up, we precompute the masking vo the validation data.
216216
validation_Y = np.concatenate((validation_X, np.zeros(validation_X.shape, dtype=validation_X.dtype)), axis=axes.index('C'))
217217
n2v_utils.manipulate_val_data(validation_X, validation_Y,
218-
num_pix=int(val_num_pix/100 * self.config.n2v_perc_pix),
218+
perc_pix=self.config.n2v_perc_pix,
219219
shape=val_patch_shape,
220220
value_manipulation=manipulator)
221221

n2v/utils/n2v_utils.py

Lines changed: 43 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,6 @@ def get_subpatch(patch, coord, local_sub_patch_radius):
77
start = np.maximum(0, np.array(coord) - local_sub_patch_radius)
88
end = start + local_sub_patch_radius*2 + 1
99

10-
start = np.append(start, 0)
11-
end = np.append(end, patch.shape[-1])
12-
1310
shift = np.minimum(0, patch.shape - end)
1411

1512
start += shift
@@ -37,66 +34,78 @@ def normal_int(mean, sigma, w):
3734

3835

3936
def pm_normal_withoutCP(local_sub_patch_radius):
40-
def normal_withoutCP(patch, coord, dims):
41-
rand_coords = random_neighbor(patch.shape, coord)
42-
return patch[tuple(rand_coords)]
37+
def normal_withoutCP(patch, coords, dims):
38+
vals = []
39+
for coord in zip(*coords):
40+
rand_coords = random_neighbor(patch.shape, coord)
41+
vals.append(patch[tuple(rand_coords)])
42+
return vals
4343
return normal_withoutCP
4444

4545

4646
def pm_uniform_withCP(local_sub_patch_radius):
47-
def random_neighbor_withCP_uniform(patch, coord, dims):
48-
sub_patch = get_subpatch(patch, coord,local_sub_patch_radius)
49-
rand_coords = [np.random.randint(0, s) for s in sub_patch.shape[0:dims]]
50-
return sub_patch[tuple(rand_coords)]
47+
def random_neighbor_withCP_uniform(patch, coords, dims):
48+
vals = []
49+
for coord in zip(*coords):
50+
sub_patch = get_subpatch(patch, coord,local_sub_patch_radius)
51+
rand_coords = [np.random.randint(0, s) for s in sub_patch.shape[0:dims]]
52+
vals.append(sub_patch[tuple(rand_coords)])
53+
return vals
5154
return random_neighbor_withCP_uniform
5255

5356

5457
def pm_normal_additive(pixel_gauss_sigma):
55-
def pixel_gauss(patch, coord, dims):
56-
return np.random.normal(patch[tuple(coord)], pixel_gauss_sigma)
58+
def pixel_gauss(patch, coords, dims):
59+
vals = []
60+
for coord in zip(*coords):
61+
vals.append(np.random.normal(patch[tuple(coord)], pixel_gauss_sigma))
62+
return vals
5763
return pixel_gauss
5864

5965

6066
def pm_normal_fitted(local_sub_patch_radius):
61-
def local_gaussian(patch, coord, dims):
62-
sub_patch = get_subpatch(patch, coord, local_sub_patch_radius)
63-
axis = tuple(range(dims))
64-
return np.random.normal(np.mean(sub_patch, axis=axis), np.std(sub_patch, axis=axis))
67+
def local_gaussian(patch, coords, dims):
68+
vals = []
69+
for coord in zip(*coords):
70+
sub_patch = get_subpatch(patch, coord, local_sub_patch_radius)
71+
axis = tuple(range(dims))
72+
vals.append(np.random.normal(np.mean(sub_patch, axis=axis), np.std(sub_patch, axis=axis)))
73+
return vals
6574
return local_gaussian
6675

6776

6877
def pm_identity(local_sub_patch_radius):
69-
def identity(patch, coord, dims):
70-
return patch[tuple(coord)]
78+
def identity(patch, coords, dims):
79+
vals = []
80+
for coord in zip(*coords):
81+
vals.append(patch[coord])
82+
return vals
7183
return identity
7284

7385

74-
def manipulate_val_data(X_val, Y_val, num_pix=64, shape=(64, 64), value_manipulation=pm_uniform_withCP(5)):
86+
def manipulate_val_data(X_val, Y_val, perc_pix=0.198, shape=(64, 64), value_manipulation=pm_uniform_withCP(5)):
7587
dims = len(shape)
7688
if dims == 2:
77-
box_size = np.round(np.sqrt(shape[0] * shape[1] / num_pix)).astype(np.int)
89+
box_size = np.round(np.sqrt(100/perc_pix)).astype(np.int)
7890
get_stratified_coords = dw.__get_stratified_coords2D__
7991
rand_float = dw.__rand_float_coords2D__(box_size)
8092
elif dims == 3:
81-
box_size = np.round(np.power(shape[0] * shape[1] * shape[2] / num_pix, 1 / 3.0)).astype(np.int)
93+
box_size = np.round(np.sqrt(100/perc_pix)).astype(np.int)
8294
get_stratified_coords = dw.__get_stratified_coords3D__
8395
rand_float = dw.__rand_float_coords3D__(box_size)
8496

8597
n_chan = X_val.shape[-1]
8698

99+
Y_val *= 0
87100
for j in tqdm(range(X_val.shape[0]), desc='Preparing validation data: '):
88101
coords = get_stratified_coords(rand_float, box_size=box_size,
89102
shape=np.array(X_val.shape)[1:-1])
90-
y_val = []
91-
x_val = []
92-
for k in range(len(coords)):
93-
y_val.append(np.copy(Y_val[(j, *coords[k], ...)]))
94-
x_val.append(value_manipulation(X_val[j, ...], coords[k], dims))
95-
96-
Y_val[j] *= 0
97-
98-
for k in range(len(coords)):
99-
for c in range(n_chan):
100-
Y_val[(j, *coords[k], c)] = y_val[k][c]
101-
Y_val[(j, *coords[k], n_chan+c)] = 1
102-
X_val[(j, *coords[k], c)] = x_val[k][c]
103+
for c in range(n_chan):
104+
indexing = (j,) + coords + (c,)
105+
indexing_mask = (j,) + coords + (c + n_chan,)
106+
y_val = X_val[indexing]
107+
x_val = value_manipulation(X_val[j, ..., c], coords, dims)
108+
109+
Y_val[indexing] = y_val
110+
Y_val[indexing_mask] = 1
111+
X_val[indexing] = x_val

0 commit comments

Comments
 (0)