This repository was archived by the owner on Mar 17, 2021. It is now read-only.

Commit 72e0907

upgrade contrib/segmentation_bf_aug
1 parent 1f2ffa2 commit 72e0907

1 file changed: 32 additions & 254 deletions
@@ -1,86 +1,58 @@
 import tensorflow as tf
 
-from niftynet.application.base_application import BaseApplication
-from niftynet.engine.application_factory import \
-    ApplicationNetFactory, InitializerFactory, OptimiserFactory
-from niftynet.engine.application_variables import \
-    CONSOLE, NETWORK_OUTPUT, TF_SUMMARIES
-from niftynet.engine.sampler_grid import GridSampler
-from niftynet.engine.sampler_resize import ResizeSampler
-from niftynet.engine.sampler_uniform import UniformSampler
-from niftynet.engine.sampler_weighted import WeightedSampler
-from niftynet.engine.windows_aggregator_grid import GridSamplesAggregator
-from niftynet.engine.windows_aggregator_resize import ResizeSamplesAggregator
+from niftynet.application.segmentation_application import \
+    SegmentationApplication, SUPPORTED_INPUT
 from niftynet.io.image_reader import ImageReader
 from niftynet.layer.binary_masking import BinaryMaskingLayer
 from niftynet.layer.discrete_label_normalisation import \
     DiscreteLabelNormalisationLayer
 from niftynet.layer.histogram_normalisation import \
     HistogramNormalisationLayer
-from niftynet.layer.loss_segmentation import LossFunction
 from niftynet.layer.mean_variance_normalisation import \
     MeanVarNormalisationLayer
 from niftynet.layer.pad import PadLayer
-from niftynet.layer.post_processing import PostProcessingLayer
+from niftynet.layer.rand_bias_field import RandomBiasFieldLayer
 from niftynet.layer.rand_flip import RandomFlipLayer
 from niftynet.layer.rand_rotation import RandomRotationLayer
 from niftynet.layer.rand_spatial_scaling import RandomSpatialScalingLayer
-from niftynet.layer.rand_bias_field import RandomBiasFieldLayer
-
-SUPPORTED_INPUT = set(['image', 'label', 'weight', 'sampler'])
 
 
-class SegmentationApplicationBFAug(BaseApplication):
+class SegmentationApplicationBFAug(SegmentationApplication):
     REQUIRED_CONFIG_SECTION = "SEGMENTATION"
 
     def __init__(self, net_param, action_param, is_training):
-        super(SegmentationApplicationBFAug, self).__init__()
+        SegmentationApplication.__init__(
+            self, net_param, action_param, is_training)
         tf.logging.info('starting segmentation application')
-        self.is_training = is_training
-
-        self.net_param = net_param
-        self.action_param = action_param
-
-        self.data_param = None
-        self.segmentation_param = None
-        self.SUPPORTED_SAMPLING = {
-            'uniform': (self.initialise_uniform_sampler,
-                        self.initialise_grid_sampler,
-                        self.initialise_grid_aggregator),
-            'weighted': (self.initialise_weighted_sampler,
-                         self.initialise_grid_sampler,
-                         self.initialise_grid_aggregator),
-            'resize': (self.initialise_resize_sampler,
-                       self.initialise_resize_sampler,
-                       self.initialise_resize_aggregator),
-        }
 
     def initialise_dataset_loader(
             self, data_param=None, task_param=None, data_partitioner=None):
 
         self.data_param = data_param
         self.segmentation_param = task_param
 
-        # read each line of csv files into an instance of Subject
+        # initialise input image readers
         if self.is_training:
-            file_lists = []
-            if self.action_param.validation_every_n > 0:
-                file_lists.append(data_partitioner.train_files)
-                file_lists.append(data_partitioner.validation_files)
-            else:
-                file_lists.append(data_partitioner.train_files)
-
-            self.readers = []
-            for file_list in file_lists:
-                reader = ImageReader(SUPPORTED_INPUT)
-                reader.initialise(data_param, task_param, file_list)
-                self.readers.append(reader)
-
-        else:  # in the inference process use image input only
-            inference_reader = ImageReader(['image'])
-            file_list = data_partitioner.inference_files
-            inference_reader.initialise(data_param, task_param, file_list)
-            self.readers = [inference_reader]
+            reader_names = ('image', 'label', 'weight', 'sampler')
+        elif self.is_inference:
+            # in the inference process use `image` input only
+            reader_names = ('image',)
+        elif self.is_evaluation:
+            reader_names = ('image', 'label', 'inferred')
+        else:
+            tf.logging.fatal(
+                'Action `%s` not supported. Expected one of %s',
+                self.action, self.SUPPORTED_PHASES)
+            raise ValueError
+        try:
+            reader_phase = self.action_param.dataset_to_infer
+        except AttributeError:
+            reader_phase = None
+        file_lists = data_partitioner.get_file_lists_by(
+            phase=reader_phase, action=self.action)
+        self.readers = [
+            ImageReader(reader_names).initialise(
+                data_param, task_param, file_list) for file_list in file_lists]
 
         foreground_masking_layer = None
         if self.net_param.normalise_foreground_only:
@@ -148,209 +120,15 @@ def initialise_dataset_loader(
                     self.action_param.bias_field_range)
                 augmentation_layers.append(bias_field_layer)
 
-
         volume_padding_layer = []
         if self.net_param.volume_padding_size:
             volume_padding_layer.append(PadLayer(
                 image_name=SUPPORTED_INPUT,
                 border=self.net_param.volume_padding_size))
 
-        for reader in self.readers:
-            reader.add_preprocessing_layers(
-                volume_padding_layer +
-                normalisation_layers +
-                augmentation_layers)
-
-    def initialise_uniform_sampler(self):
-        self.sampler = [[UniformSampler(
-            reader=reader,
-            data_param=self.data_param,
-            batch_size=self.net_param.batch_size,
-            windows_per_image=self.action_param.sample_per_volume,
-            queue_length=self.net_param.queue_length) for reader in
-            self.readers]]
-
-    def initialise_weighted_sampler(self):
-        self.sampler = [[WeightedSampler(
-            reader=reader,
-            data_param=self.data_param,
-            batch_size=self.net_param.batch_size,
-            windows_per_image=self.action_param.sample_per_volume,
-            queue_length=self.net_param.queue_length) for reader in
-            self.readers]]
-
-    def initialise_resize_sampler(self):
-        self.sampler = [[ResizeSampler(
-            reader=reader,
-            data_param=self.data_param,
-            batch_size=self.net_param.batch_size,
-            shuffle_buffer=self.is_training,
-            queue_length=self.net_param.queue_length) for reader in
-            self.readers]]
+        self.readers[0].add_preprocessing_layers(
+            volume_padding_layer + normalisation_layers + augmentation_layers)
 
-    def initialise_grid_sampler(self):
-        self.sampler = [[GridSampler(
-            reader=reader,
-            data_param=self.data_param,
-            batch_size=self.net_param.batch_size,
-            spatial_window_size=self.action_param.spatial_window_size,
-            window_border=self.action_param.border,
-            queue_length=self.net_param.queue_length) for reader in
-            self.readers]]
-
-    def initialise_grid_aggregator(self):
-        self.output_decoder = GridSamplesAggregator(
-            image_reader=self.readers[0],
-            output_path=self.action_param.save_seg_dir,
-            window_border=self.action_param.border,
-            interp_order=self.action_param.output_interp_order)
-
-    def initialise_resize_aggregator(self):
-        self.output_decoder = ResizeSamplesAggregator(
-            image_reader=self.readers[0],
-            output_path=self.action_param.save_seg_dir,
-            window_border=self.action_param.border,
-            interp_order=self.action_param.output_interp_order)
-
-    def initialise_sampler(self):
-        if self.is_training:
-            self.SUPPORTED_SAMPLING[self.net_param.window_sampling][0]()
-        else:
-            self.SUPPORTED_SAMPLING[self.net_param.window_sampling][1]()
-
-    def initialise_network(self):
-        w_regularizer = None
-        b_regularizer = None
-        reg_type = self.net_param.reg_type.lower()
-        decay = self.net_param.decay
-        if reg_type == 'l2' and decay > 0:
-            from tensorflow.contrib.layers.python.layers import regularizers
-            w_regularizer = regularizers.l2_regularizer(decay)
-            b_regularizer = regularizers.l2_regularizer(decay)
-        elif reg_type == 'l1' and decay > 0:
-            from tensorflow.contrib.layers.python.layers import regularizers
-            w_regularizer = regularizers.l1_regularizer(decay)
-            b_regularizer = regularizers.l1_regularizer(decay)
-
-        self.net = ApplicationNetFactory.create(self.net_param.name)(
-            num_classes=self.segmentation_param.num_classes,
-            w_initializer=InitializerFactory.get_initializer(
-                name=self.net_param.weight_initializer),
-            b_initializer=InitializerFactory.get_initializer(
-                name=self.net_param.bias_initializer),
-            w_regularizer=w_regularizer,
-            b_regularizer=b_regularizer,
-            acti_func=self.net_param.activation_function)
-
-    def connect_data_and_network(self,
-                                 outputs_collector=None,
-                                 gradients_collector=None):
-        # def data_net(for_training):
-        #     with tf.name_scope('train' if for_training else 'validation'):
-        #         sampler = self.get_sampler()[0][0 if for_training else -1]
-        #         data_dict = sampler.pop_batch_op()
-        #         image = tf.cast(data_dict['image'], tf.float32)
-        #         return data_dict, self.net(image, is_training=for_training)
-
-        def switch_sampler(for_training):
-            with tf.name_scope('train' if for_training else 'validation'):
-                sampler = self.get_sampler()[0][0 if for_training else -1]
-                return sampler.pop_batch_op()
-
-        if self.is_training:
-            # if self.action_param.validation_every_n > 0:
-            #     data_dict, net_out = tf.cond(tf.logical_not(self.is_validation),
-            #                                  lambda: data_net(True),
-            #                                  lambda: data_net(False))
-            # else:
-            #     data_dict, net_out = data_net(True)
-            if self.action_param.validation_every_n > 0:
-                data_dict = tf.cond(tf.logical_not(self.is_validation),
-                                    lambda: switch_sampler(for_training=True),
-                                    lambda: switch_sampler(for_training=False))
-            else:
-                data_dict = switch_sampler(for_training=True)
-            image = tf.cast(data_dict['image'], tf.float32)
-            net_out = self.net(image, is_training=self.is_training)
-
-            with tf.name_scope('Optimiser'):
-                optimiser_class = OptimiserFactory.create(
-                    name=self.action_param.optimiser)
-                self.optimiser = optimiser_class.get_instance(
-                    learning_rate=self.action_param.lr)
-            loss_func = LossFunction(
-                n_class=self.segmentation_param.num_classes,
-                loss_type=self.action_param.loss_type)
-            data_loss = loss_func(
-                prediction=net_out,
-                ground_truth=data_dict.get('label', None),
-                weight_map=data_dict.get('weight', None))
-            reg_losses = tf.get_collection(
-                tf.GraphKeys.REGULARIZATION_LOSSES)
-            if self.net_param.decay > 0.0 and reg_losses:
-                reg_loss = tf.reduce_mean(
-                    [tf.reduce_mean(reg_loss) for reg_loss in reg_losses])
-                loss = data_loss + reg_loss
-            else:
-                loss = data_loss
-            grads = self.optimiser.compute_gradients(loss)
-            # collecting gradients variables
-            gradients_collector.add_to_collection([grads])
-            # collecting output variables
-            outputs_collector.add_to_collection(
-                var=data_loss, name='loss',
-                average_over_devices=False, collection=CONSOLE)
-            outputs_collector.add_to_collection(
-                var=data_loss, name='loss',
-                average_over_devices=True, summary_type='scalar',
-                collection=TF_SUMMARIES)
-
-            # outputs_collector.add_to_collection(
-            #     var=image*180.0, name='image',
-            #     average_over_devices=False, summary_type='image3_sagittal',
-            #     collection=TF_SUMMARIES)
-
-            # outputs_collector.add_to_collection(
-            #     var=image, name='image',
-            #     average_over_devices=False,
-            #     collection=NETWORK_OUTPUT)
-
-            # outputs_collector.add_to_collection(
-            #     var=tf.reduce_mean(image), name='mean_image',
-            #     average_over_devices=False, summary_type='scalar',
-            #     collection=CONSOLE)
-        else:
-            # converting logits into final output for
-            # classification probabilities or argmax classification labels
-            data_dict = switch_sampler(for_training=False)
-            image = tf.cast(data_dict['image'], tf.float32)
-            net_out = self.net(image, is_training=self.is_training)
-
-            output_prob = self.segmentation_param.output_prob
-            num_classes = self.segmentation_param.num_classes
-            if output_prob and num_classes > 1:
-                post_process_layer = PostProcessingLayer(
-                    'SOFTMAX', num_classes=num_classes)
-            elif not output_prob and num_classes > 1:
-                post_process_layer = PostProcessingLayer(
-                    'ARGMAX', num_classes=num_classes)
-            else:
-                post_process_layer = PostProcessingLayer(
-                    'IDENTITY', num_classes=num_classes)
-            net_out = post_process_layer(net_out)
-
-            outputs_collector.add_to_collection(
-                var=net_out, name='window',
-                average_over_devices=False, collection=NETWORK_OUTPUT)
-            outputs_collector.add_to_collection(
-                var=data_dict['image_location'], name='location',
-                average_over_devices=False, collection=NETWORK_OUTPUT)
-            init_aggregator = \
-                self.SUPPORTED_SAMPLING[self.net_param.window_sampling][2]
-            init_aggregator()
-
-    def interpret_output(self, batch_output):
-        if not self.is_training:
-            return self.output_decoder.decode_batch(
-                batch_output['window'], batch_output['location'])
-        return True
+        for reader in self.readers[1:]:
+            reader.add_preprocessing_layers(
+                volume_padding_layer + normalisation_layers)
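
In summary, the commit drops the copied sampler, network, optimiser and output-handling code and inherits it from SegmentationApplication; the subclass keeps only the dataset loader, where a RandomBiasFieldLayer is appended to the training augmentations, the augmentation layers are attached to the training reader (self.readers[0]) only, and the remaining readers receive padding and normalisation alone. The loader also builds its readers in a single comprehension, which requires ImageReader.initialise to return the reader it was called on. Below is a minimal sketch of the resulting class layout, with the loader body abbreviated and everything not shown assumed to be inherited unchanged:

# Sketch only: the full loader body is in the diff above; the imported
# names are the NiftyNet classes referenced at the top of this file.
from niftynet.application.segmentation_application import \
    SegmentationApplication


class SegmentationApplicationBFAug(SegmentationApplication):
    REQUIRED_CONFIG_SECTION = "SEGMENTATION"

    def __init__(self, net_param, action_param, is_training):
        # defer all pipeline wiring (samplers, network, optimiser,
        # aggregators, outputs) to the stock segmentation application
        SegmentationApplication.__init__(
            self, net_param, action_param, is_training)

    def initialise_dataset_loader(
            self, data_param=None, task_param=None, data_partitioner=None):
        # per the diff: pick reader inputs by phase, build one ImageReader
        # per file list, append RandomBiasFieldLayer to the training
        # augmentations, then attach
        #   self.readers[0]  -> padding + normalisation + augmentation
        #   self.readers[1:] -> padding + normalisation only
        ...

Restricting the augmentation to self.readers[0] keeps validation and inference inputs free of the random bias field, which presumably mirrors how the upstream SegmentationApplication treats its own augmentation layers.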
