|
1 | 1 | import tensorflow as tf |
2 | 2 |
|
3 | | -from niftynet.application.base_application import BaseApplication |
4 | | -from niftynet.engine.application_factory import \ |
5 | | - ApplicationNetFactory, InitializerFactory, OptimiserFactory |
6 | | -from niftynet.engine.application_variables import \ |
7 | | - CONSOLE, NETWORK_OUTPUT, TF_SUMMARIES |
8 | | -from niftynet.engine.sampler_grid import GridSampler |
9 | | -from niftynet.engine.sampler_resize import ResizeSampler |
10 | | -from niftynet.engine.sampler_uniform import UniformSampler |
11 | | -from niftynet.engine.sampler_weighted import WeightedSampler |
12 | | -from niftynet.engine.windows_aggregator_grid import GridSamplesAggregator |
13 | | -from niftynet.engine.windows_aggregator_resize import ResizeSamplesAggregator |
| 3 | +from niftynet.application.segmentation_application import \ |
| 4 | + SegmentationApplication, SUPPORTED_INPUT |
14 | 5 | from niftynet.io.image_reader import ImageReader |
15 | 6 | from niftynet.layer.binary_masking import BinaryMaskingLayer |
16 | 7 | from niftynet.layer.discrete_label_normalisation import \ |
17 | 8 | DiscreteLabelNormalisationLayer |
18 | 9 | from niftynet.layer.histogram_normalisation import \ |
19 | 10 | HistogramNormalisationLayer |
20 | | -from niftynet.layer.loss_segmentation import LossFunction |
21 | 11 | from niftynet.layer.mean_variance_normalisation import \ |
22 | 12 | MeanVarNormalisationLayer |
23 | 13 | from niftynet.layer.pad import PadLayer |
24 | | -from niftynet.layer.post_processing import PostProcessingLayer |
| 14 | +from niftynet.layer.rand_bias_field import RandomBiasFieldLayer |
25 | 15 | from niftynet.layer.rand_flip import RandomFlipLayer |
26 | 16 | from niftynet.layer.rand_rotation import RandomRotationLayer |
27 | 17 | from niftynet.layer.rand_spatial_scaling import RandomSpatialScalingLayer |
28 | | -from niftynet.layer.rand_bias_field import RandomBiasFieldLayer |
29 | | - |
30 | | -SUPPORTED_INPUT = set(['image', 'label', 'weight', 'sampler']) |
31 | 18 |
|
32 | 19 |
|
33 | | -class SegmentationApplicationBFAug(BaseApplication): |
| 20 | +class SegmentationApplicationBFAug(SegmentationApplication): |
34 | 21 | REQUIRED_CONFIG_SECTION = "SEGMENTATION" |
35 | 22 |
|
36 | 23 | def __init__(self, net_param, action_param, is_training): |
37 | | - super(SegmentationApplicationBFAug, self).__init__() |
| 24 | + SegmentationApplication.__init__( |
| 25 | + self, net_param, action_param, is_training) |
38 | 26 | tf.logging.info('starting segmentation application') |
39 | | - self.is_training = is_training |
40 | | - |
41 | | - self.net_param = net_param |
42 | | - self.action_param = action_param |
43 | | - |
44 | | - self.data_param = None |
45 | | - self.segmentation_param = None |
46 | | - self.SUPPORTED_SAMPLING = { |
47 | | - 'uniform': (self.initialise_uniform_sampler, |
48 | | - self.initialise_grid_sampler, |
49 | | - self.initialise_grid_aggregator), |
50 | | - 'weighted': (self.initialise_weighted_sampler, |
51 | | - self.initialise_grid_sampler, |
52 | | - self.initialise_grid_aggregator), |
53 | | - 'resize': (self.initialise_resize_sampler, |
54 | | - self.initialise_resize_sampler, |
55 | | - self.initialise_resize_aggregator), |
56 | | - } |
57 | 27 |
|
58 | 28 | def initialise_dataset_loader( |
59 | 29 | self, data_param=None, task_param=None, data_partitioner=None): |
60 | 30 |
|
61 | 31 | self.data_param = data_param |
62 | 32 | self.segmentation_param = task_param |
63 | 33 |
|
64 | | - # read each line of csv files into an instance of Subject |
| 34 | + # initialise input image readers |
65 | 35 | if self.is_training: |
66 | | - file_lists = [] |
67 | | - if self.action_param.validation_every_n > 0: |
68 | | - file_lists.append(data_partitioner.train_files) |
69 | | - file_lists.append(data_partitioner.validation_files) |
70 | | - else: |
71 | | - file_lists.append(data_partitioner.train_files) |
72 | | - |
73 | | - self.readers = [] |
74 | | - for file_list in file_lists: |
75 | | - reader = ImageReader(SUPPORTED_INPUT) |
76 | | - reader.initialise(data_param, task_param, file_list) |
77 | | - self.readers.append(reader) |
78 | | - |
79 | | - else: # in the inference process use image input only |
80 | | - inference_reader = ImageReader(['image']) |
81 | | - file_list = data_partitioner.inference_files |
82 | | - inference_reader.initialise(data_param, task_param, file_list) |
83 | | - self.readers = [inference_reader] |
| 36 | + reader_names = ('image', 'label', 'weight', 'sampler') |
| 37 | + elif self.is_inference: |
| 38 | + # in the inference process use `image` input only |
| 39 | + reader_names = ('image',) |
| 40 | + elif self.is_evaluation: |
| 41 | + reader_names = ('image', 'label', 'inferred') |
| 42 | + else: |
| 43 | + tf.logging.fatal( |
| 44 | + 'Action `%s` not supported. Expected one of %s', |
| 45 | + self.action, self.SUPPORTED_PHASES) |
| 46 | + raise ValueError |
| 47 | + try: |
| 48 | + reader_phase = self.action_param.dataset_to_infer |
| 49 | + except AttributeError: |
| 50 | + reader_phase = None |
| 51 | + file_lists = data_partitioner.get_file_lists_by( |
| 52 | + phase=reader_phase, action=self.action) |
| 53 | + self.readers = [ |
| 54 | + ImageReader(reader_names).initialise( |
| 55 | + data_param, task_param, file_list) for file_list in file_lists] |
84 | 56 |
|
85 | 57 | foreground_masking_layer = None |
86 | 58 | if self.net_param.normalise_foreground_only: |
@@ -148,209 +120,15 @@ def initialise_dataset_loader( |
148 | 120 | self.action_param.bias_field_range) |
149 | 121 | augmentation_layers.append(bias_field_layer) |
150 | 122 |
|
151 | | - |
152 | 123 | volume_padding_layer = [] |
153 | 124 | if self.net_param.volume_padding_size: |
154 | 125 | volume_padding_layer.append(PadLayer( |
155 | 126 | image_name=SUPPORTED_INPUT, |
156 | 127 | border=self.net_param.volume_padding_size)) |
157 | 128 |
|
158 | | - for reader in self.readers: |
159 | | - reader.add_preprocessing_layers( |
160 | | - volume_padding_layer + |
161 | | - normalisation_layers + |
162 | | - augmentation_layers) |
163 | | - |
164 | | - def initialise_uniform_sampler(self): |
165 | | - self.sampler = [[UniformSampler( |
166 | | - reader=reader, |
167 | | - data_param=self.data_param, |
168 | | - batch_size=self.net_param.batch_size, |
169 | | - windows_per_image=self.action_param.sample_per_volume, |
170 | | - queue_length=self.net_param.queue_length) for reader in |
171 | | - self.readers]] |
172 | | - |
173 | | - def initialise_weighted_sampler(self): |
174 | | - self.sampler = [[WeightedSampler( |
175 | | - reader=reader, |
176 | | - data_param=self.data_param, |
177 | | - batch_size=self.net_param.batch_size, |
178 | | - windows_per_image=self.action_param.sample_per_volume, |
179 | | - queue_length=self.net_param.queue_length) for reader in |
180 | | - self.readers]] |
181 | | - |
182 | | - def initialise_resize_sampler(self): |
183 | | - self.sampler = [[ResizeSampler( |
184 | | - reader=reader, |
185 | | - data_param=self.data_param, |
186 | | - batch_size=self.net_param.batch_size, |
187 | | - shuffle_buffer=self.is_training, |
188 | | - queue_length=self.net_param.queue_length) for reader in |
189 | | - self.readers]] |
| 129 | + self.readers[0].add_preprocessing_layers( |
| 130 | + volume_padding_layer + normalisation_layers + augmentation_layers) |
190 | 131 |
|
191 | | - def initialise_grid_sampler(self): |
192 | | - self.sampler = [[GridSampler( |
193 | | - reader=reader, |
194 | | - data_param=self.data_param, |
195 | | - batch_size=self.net_param.batch_size, |
196 | | - spatial_window_size=self.action_param.spatial_window_size, |
197 | | - window_border=self.action_param.border, |
198 | | - queue_length=self.net_param.queue_length) for reader in |
199 | | - self.readers]] |
200 | | - |
201 | | - def initialise_grid_aggregator(self): |
202 | | - self.output_decoder = GridSamplesAggregator( |
203 | | - image_reader=self.readers[0], |
204 | | - output_path=self.action_param.save_seg_dir, |
205 | | - window_border=self.action_param.border, |
206 | | - interp_order=self.action_param.output_interp_order) |
207 | | - |
208 | | - def initialise_resize_aggregator(self): |
209 | | - self.output_decoder = ResizeSamplesAggregator( |
210 | | - image_reader=self.readers[0], |
211 | | - output_path=self.action_param.save_seg_dir, |
212 | | - window_border=self.action_param.border, |
213 | | - interp_order=self.action_param.output_interp_order) |
214 | | - |
215 | | - def initialise_sampler(self): |
216 | | - if self.is_training: |
217 | | - self.SUPPORTED_SAMPLING[self.net_param.window_sampling][0]() |
218 | | - else: |
219 | | - self.SUPPORTED_SAMPLING[self.net_param.window_sampling][1]() |
220 | | - |
221 | | - def initialise_network(self): |
222 | | - w_regularizer = None |
223 | | - b_regularizer = None |
224 | | - reg_type = self.net_param.reg_type.lower() |
225 | | - decay = self.net_param.decay |
226 | | - if reg_type == 'l2' and decay > 0: |
227 | | - from tensorflow.contrib.layers.python.layers import regularizers |
228 | | - w_regularizer = regularizers.l2_regularizer(decay) |
229 | | - b_regularizer = regularizers.l2_regularizer(decay) |
230 | | - elif reg_type == 'l1' and decay > 0: |
231 | | - from tensorflow.contrib.layers.python.layers import regularizers |
232 | | - w_regularizer = regularizers.l1_regularizer(decay) |
233 | | - b_regularizer = regularizers.l1_regularizer(decay) |
234 | | - |
235 | | - self.net = ApplicationNetFactory.create(self.net_param.name)( |
236 | | - num_classes=self.segmentation_param.num_classes, |
237 | | - w_initializer=InitializerFactory.get_initializer( |
238 | | - name=self.net_param.weight_initializer), |
239 | | - b_initializer=InitializerFactory.get_initializer( |
240 | | - name=self.net_param.bias_initializer), |
241 | | - w_regularizer=w_regularizer, |
242 | | - b_regularizer=b_regularizer, |
243 | | - acti_func=self.net_param.activation_function) |
244 | | - |
245 | | - def connect_data_and_network(self, |
246 | | - outputs_collector=None, |
247 | | - gradients_collector=None): |
248 | | - # def data_net(for_training): |
249 | | - # with tf.name_scope('train' if for_training else 'validation'): |
250 | | - # sampler = self.get_sampler()[0][0 if for_training else -1] |
251 | | - # data_dict = sampler.pop_batch_op() |
252 | | - # image = tf.cast(data_dict['image'], tf.float32) |
253 | | - # return data_dict, self.net(image, is_training=for_training) |
254 | | - |
255 | | - def switch_sampler(for_training): |
256 | | - with tf.name_scope('train' if for_training else 'validation'): |
257 | | - sampler = self.get_sampler()[0][0 if for_training else -1] |
258 | | - return sampler.pop_batch_op() |
259 | | - |
260 | | - if self.is_training: |
261 | | - # if self.action_param.validation_every_n > 0: |
262 | | - # data_dict, net_out = tf.cond(tf.logical_not(self.is_validation), |
263 | | - # lambda: data_net(True), |
264 | | - # lambda: data_net(False)) |
265 | | - # else: |
266 | | - # data_dict, net_out = data_net(True) |
267 | | - if self.action_param.validation_every_n > 0: |
268 | | - data_dict = tf.cond(tf.logical_not(self.is_validation), |
269 | | - lambda: switch_sampler(for_training=True), |
270 | | - lambda: switch_sampler(for_training=False)) |
271 | | - else: |
272 | | - data_dict = switch_sampler(for_training=True) |
273 | | - image = tf.cast(data_dict['image'], tf.float32) |
274 | | - net_out = self.net(image, is_training=self.is_training) |
275 | | - |
276 | | - with tf.name_scope('Optimiser'): |
277 | | - optimiser_class = OptimiserFactory.create( |
278 | | - name=self.action_param.optimiser) |
279 | | - self.optimiser = optimiser_class.get_instance( |
280 | | - learning_rate=self.action_param.lr) |
281 | | - loss_func = LossFunction( |
282 | | - n_class=self.segmentation_param.num_classes, |
283 | | - loss_type=self.action_param.loss_type) |
284 | | - data_loss = loss_func( |
285 | | - prediction=net_out, |
286 | | - ground_truth=data_dict.get('label', None), |
287 | | - weight_map=data_dict.get('weight', None)) |
288 | | - reg_losses = tf.get_collection( |
289 | | - tf.GraphKeys.REGULARIZATION_LOSSES) |
290 | | - if self.net_param.decay > 0.0 and reg_losses: |
291 | | - reg_loss = tf.reduce_mean( |
292 | | - [tf.reduce_mean(reg_loss) for reg_loss in reg_losses]) |
293 | | - loss = data_loss + reg_loss |
294 | | - else: |
295 | | - loss = data_loss |
296 | | - grads = self.optimiser.compute_gradients(loss) |
297 | | - # collecting gradients variables |
298 | | - gradients_collector.add_to_collection([grads]) |
299 | | - # collecting output variables |
300 | | - outputs_collector.add_to_collection( |
301 | | - var=data_loss, name='loss', |
302 | | - average_over_devices=False, collection=CONSOLE) |
303 | | - outputs_collector.add_to_collection( |
304 | | - var=data_loss, name='loss', |
305 | | - average_over_devices=True, summary_type='scalar', |
306 | | - collection=TF_SUMMARIES) |
307 | | - |
308 | | - # outputs_collector.add_to_collection( |
309 | | - # var=image*180.0, name='image', |
310 | | - # average_over_devices=False, summary_type='image3_sagittal', |
311 | | - # collection=TF_SUMMARIES) |
312 | | - |
313 | | - # outputs_collector.add_to_collection( |
314 | | - # var=image, name='image', |
315 | | - # average_over_devices=False, |
316 | | - # collection=NETWORK_OUTPUT) |
317 | | - |
318 | | - # outputs_collector.add_to_collection( |
319 | | - # var=tf.reduce_mean(image), name='mean_image', |
320 | | - # average_over_devices=False, summary_type='scalar', |
321 | | - # collection=CONSOLE) |
322 | | - else: |
323 | | - # converting logits into final output for |
324 | | - # classification probabilities or argmax classification labels |
325 | | - data_dict = switch_sampler(for_training=False) |
326 | | - image = tf.cast(data_dict['image'], tf.float32) |
327 | | - net_out = self.net(image, is_training=self.is_training) |
328 | | - |
329 | | - output_prob = self.segmentation_param.output_prob |
330 | | - num_classes = self.segmentation_param.num_classes |
331 | | - if output_prob and num_classes > 1: |
332 | | - post_process_layer = PostProcessingLayer( |
333 | | - 'SOFTMAX', num_classes=num_classes) |
334 | | - elif not output_prob and num_classes > 1: |
335 | | - post_process_layer = PostProcessingLayer( |
336 | | - 'ARGMAX', num_classes=num_classes) |
337 | | - else: |
338 | | - post_process_layer = PostProcessingLayer( |
339 | | - 'IDENTITY', num_classes=num_classes) |
340 | | - net_out = post_process_layer(net_out) |
341 | | - |
342 | | - outputs_collector.add_to_collection( |
343 | | - var=net_out, name='window', |
344 | | - average_over_devices=False, collection=NETWORK_OUTPUT) |
345 | | - outputs_collector.add_to_collection( |
346 | | - var=data_dict['image_location'], name='location', |
347 | | - average_over_devices=False, collection=NETWORK_OUTPUT) |
348 | | - init_aggregator = \ |
349 | | - self.SUPPORTED_SAMPLING[self.net_param.window_sampling][2] |
350 | | - init_aggregator() |
351 | | - |
352 | | - def interpret_output(self, batch_output): |
353 | | - if not self.is_training: |
354 | | - return self.output_decoder.decode_batch( |
355 | | - batch_output['window'], batch_output['location']) |
356 | | - return True |
| 132 | + for reader in self.readers[1:]: |
| 133 | + reader.add_preprocessing_layers( |
| 134 | + volume_padding_layer + normalisation_layers) |
0 commit comments