@@ -329,7 +329,9 @@ def apply_warp(x, warp_field, fill_mode='reflect',
329329 return x
330330
331331
332- def random_transform (x , y = None ,
332+ def random_transform (dataset ,
333+ seq ,
334+ prefix_and_fnames = None ,
333335 rotation_range = 0. ,
334336 width_shift_range = 0. ,
335337 height_shift_range = 0. ,
@@ -361,10 +363,14 @@ def random_transform(x, y=None,
361363
362364 Parameters
363365 ----------
364- x: array of floats
365- An image.
366- y: array of int
367- An array with labels.
366+ dataset: a :class:`Dataset` instance
367+ The instance of the current dataset. First step towards making
368+ this a class method.
369+ seq: a dictionary of numpy array
370+ A dictionary with at least these keys: 'data', 'labels', 'filenames',
371+ 'subset'.
372+ prefix_and_fnames: list
373+ A list of prefix and names for the current sequence
368374 rotation_range: int
369375 Degrees of rotation (0 to 180).
370376 width_shift_range: float
@@ -432,19 +438,23 @@ def random_transform(x, y=None,
432438
433439 Reference
434440 ---------
435- [1] https://github.com/fchollet/keras/blob/master/keras/preprocessing/image.py
441+ [1] https://github.com/fchollet/keras/blob/master/keras/preprocessing/image.py # noqa
436442 '''
437443 # Set this to a dir, if you want to save augmented images samples
438444 save_to_dir = None
445+ nclasses = dataset .nclasses
446+ void_label = dataset .void_labels
439447
440448 if rescale :
441449 raise NotImplementedError ()
442450
443- # Do not modify the original images
444- x = x .copy ()
445- if y is not None and len (y ) > 0 :
446- y = y [..., None ] # Add extra dim to y to simplify computation
447- y = y .copy ()
451+ # Make sure we do not modify the original images
452+ seq ['data' ] = seq ['data' ].copy ()
453+ if seq ['labels' ] is not None and len (seq ['labels' ]) > 0 :
454+ seq ['labels' ] = seq ['labels' ].copy ()
455+ # Add extra dim to y to simplify computation
456+ seq ['labels' ] = seq ['labels' ][..., None ]
457+ sh = seq ['data' ].shape
448458
449459 # listify zoom range
450460 if np .isscalar (zoom_range ):
@@ -464,13 +474,13 @@ def random_transform(x, y=None,
464474
465475 # Channel shift
466476 if channel_shift_range != 0 :
467- x = random_channel_shift (x , channel_shift_range , rows_idx , cols_idx ,
468- chan_idx )
477+ seq [ 'data' ] = random_channel_shift (seq [ 'data' ] , channel_shift_range ,
478+ rows_idx , cols_idx , chan_idx )
469479
470480 # Gamma correction
471481 if gamma > 0 :
472482 scale = float (1 )
473- x = ((x / scale ) ** gamma ) * scale * gain
483+ seq [ 'data' ] = ((seq [ 'data' ] / scale ) ** gamma ) * scale * gain
474484
475485 # Affine transformations (zoom, rotation, shift, ..)
476486 if (rotation_range or height_shift_range or width_shift_range or
@@ -488,12 +498,12 @@ def random_transform(x, y=None,
488498 # --> Shift/Translation
489499 if height_shift_range :
490500 tx = (np .random .uniform (- height_shift_range , height_shift_range ) *
491- x . shape [rows_idx ])
501+ sh [rows_idx ])
492502 else :
493503 tx = 0
494504 if width_shift_range :
495505 ty = (np .random .uniform (- width_shift_range , width_shift_range ) *
496- x . shape [cols_idx ])
506+ sh [cols_idx ])
497507 else :
498508 ty = 0
499509 translation_matrix = np .array ([[1 , 0 , tx ],
@@ -520,62 +530,64 @@ def random_transform(x, y=None,
520530 transform_matrix = np .dot (np .dot (np .dot (rotation_matrix ,
521531 translation_matrix ),
522532 shear_matrix ), zoom_matrix )
523- h , w = x . shape [rows_idx ], x . shape [cols_idx ]
533+ h , w = sh [rows_idx ], sh [cols_idx ]
524534 transform_matrix = transform_matrix_offset_center (transform_matrix ,
525535 h , w )
526536 # Apply all the transformations together
527- x = apply_transform (x , transform_matrix , fill_mode = fill_mode ,
528- cval = cval , order = 1 , rows_idx = rows_idx ,
529- cols_idx = cols_idx )
530- if y is not None and len (y ) > 0 :
531- y = apply_transform (y , transform_matrix , fill_mode = fill_mode ,
532- cval = cvalMask , order = 0 , rows_idx = rows_idx ,
533- cols_idx = cols_idx )
537+ seq ['data' ] = apply_transform (seq ['data' ], transform_matrix ,
538+ fill_mode = fill_mode , cval = cval , order = 1 ,
539+ rows_idx = rows_idx , cols_idx = cols_idx )
540+ if seq ['labels' ] is not None and len (seq ['labels' ]) > 0 :
541+ seq ['labels' ] = apply_transform (seq ['labels' ],
542+ transform_matrix ,
543+ fill_mode = fill_mode , cval = cvalMask ,
544+ order = 0 , rows_idx = rows_idx ,
545+ cols_idx = cols_idx )
534546
535547 # Horizontal flip
536548 if np .random .random () < horizontal_flip : # 0 = disabled
537- x = flip_axis (x , cols_idx )
538- if y is not None and len (y ) > 0 :
539- y = flip_axis (y , cols_idx )
549+ seq [ 'data' ] = flip_axis (seq [ 'data' ] , cols_idx )
550+ if seq [ 'labels' ] is not None and len (seq [ 'labels' ] ) > 0 :
551+ seq [ 'labels' ] = flip_axis (seq [ 'labels' ] , cols_idx )
540552
541553 # Vertical flip
542554 if np .random .random () < vertical_flip : # 0 = disabled
543- x = flip_axis (x , rows_idx )
544- if y is not None and len (y ) > 0 :
545- y = flip_axis (y , rows_idx )
555+ seq [ 'data' ] = flip_axis (seq [ 'data' ] , rows_idx )
556+ if seq [ 'labels' ] is not None and len (seq [ 'labels' ] ) > 0 :
557+ seq [ 'labels' ] = flip_axis (seq [ 'labels' ] , rows_idx )
546558
547559 # Spline warp
548560 if spline_warp :
549561 import SimpleITK as sitk
550- warp_field = gen_warp_field (shape = (x . shape [rows_idx ],
551- x . shape [cols_idx ]),
562+ warp_field = gen_warp_field (shape = (sh [rows_idx ],
563+ sh [cols_idx ]),
552564 sigma = warp_sigma ,
553565 grid_size = warp_grid_size )
554- x = apply_warp (x , warp_field ,
555- interpolator = sitk .sitkLinear ,
556- fill_mode = fill_mode ,
557- fill_constant = cval ,
558- rows_idx = rows_idx , cols_idx = cols_idx )
559- if y is not None and len ( y ) > 0 :
560- y = np .round (apply_warp ( y , warp_field ,
561- interpolator = sitk . sitkNearestNeighbor ,
562- fill_mode = fill_mode ,
563- fill_constant = cvalMask ,
564- rows_idx = rows_idx , cols_idx = cols_idx ))
566+ seq [ 'data' ] = apply_warp (seq [ 'data' ] , warp_field ,
567+ interpolator = sitk .sitkLinear ,
568+ fill_mode = fill_mode , fill_constant = cval ,
569+ rows_idx = rows_idx , cols_idx = cols_idx )
570+ if seq [ 'labels' ] is not None and len ( seq [ 'labels' ]) > 0 :
571+ # TODO is this round right??
572+ seq [ 'labels' ] = np .round (
573+ apply_warp ( seq [ 'labels' ], warp_field ,
574+ interpolator = sitk . sitkNearestNeighbor ,
575+ fill_mode = fill_mode , fill_constant = cvalMask ,
576+ rows_idx = rows_idx , cols_idx = cols_idx ))
565577
566578 # Crop
567579 # Expects axes with shape (..., 0, 1)
568580 # TODO: Add center crop
569581 if crop_size :
570582 # Reshape to (..., 0, 1)
571- pattern = [el for el in range (x .ndim ) if el != rows_idx and
583+ pattern = [el for el in range (seq [ 'data' ] .ndim ) if el != rows_idx and
572584 el != cols_idx ] + [rows_idx , cols_idx ]
573- inv_pattern = [pattern .index (el ) for el in range (x .ndim )]
574- x = x .transpose (pattern )
585+ inv_pattern = [pattern .index (el ) for el in range (seq [ 'data' ] .ndim )]
586+ seq [ 'data' ] = seq [ 'data' ] .transpose (pattern )
575587
576588 crop = list (crop_size )
577589 pad = [0 , 0 ]
578- h , w = x .shape [- 2 :]
590+ h , w = seq [ 'data' ] .shape [- 2 :]
579591
580592 # Compute amounts
581593 if crop [0 ] < h :
@@ -594,38 +606,40 @@ def random_transform(x, y=None,
594606 left , crop [1 ] = 0 , w
595607
596608 # Cropping
597- x = x [..., top :top + crop [0 ], left :left + crop [1 ]]
598- if y is not None and len (y ) > 0 :
599- y = y .transpose (pattern )
600- y = y [..., top :top + crop [0 ], left :left + crop [1 ]]
609+ seq ['data' ] = seq ['data' ][..., top :top + crop [0 ], left :left + crop [1 ]]
610+ if seq ['labels' ] is not None and len (seq ['labels' ]) > 0 :
611+ seq ['labels' ] = seq ['labels' ].transpose (pattern )
612+ seq ['labels' ] = seq ['labels' ][..., top :top + crop [0 ],
613+ left :left + crop [1 ]]
601614 # Padding
602615 if pad != [0 , 0 ]:
603- pad_pattern = ((0 , 0 ),) * (x .ndim - 2 ) + (
616+ pad_pattern = ((0 , 0 ),) * (seq [ 'data' ] .ndim - 2 ) + (
604617 (pad [0 ]// 2 , pad [0 ] - pad [0 ]// 2 ),
605618 (pad [1 ]// 2 , pad [1 ] - pad [1 ]// 2 ))
606- x = np .pad (x , pad_pattern , 'constant' )
607- y = np .pad (y , pad_pattern , 'constant' , constant_values = void_label )
619+ seq ['data' ] = np .pad (seq ['data' ], pad_pattern , 'constant' )
620+ seq ['labels' ] = np .pad (seq ['labels' ], pad_pattern , 'constant' ,
621+ constant_values = void_label )
608622
609- x = x .transpose (inv_pattern )
610- if y is not None and len (y ) > 0 :
611- y = y .transpose (inv_pattern )
623+ # Reshape to original shape
624+ seq ['data' ] = seq ['data' ].transpose (inv_pattern )
625+ if seq ['labels' ] is not None and len (seq ['labels' ]) > 0 :
626+ seq ['labels' ] = seq ['labels' ].transpose (inv_pattern )
612627
613628 if return_optical_flow :
614- flow = optical_flow (x , rows_idx , cols_idx , chan_idx ,
629+ flow = optical_flow (seq [ 'data' ] , rows_idx , cols_idx , chan_idx ,
615630 return_rgb = return_optical_flow == 'rgb' )
616- x = np .concatenate ((x , flow ), axis = chan_idx )
631+ seq [ 'data' ] = np .concatenate ((seq [ 'data' ] , flow ), axis = chan_idx )
617632
618633 # Save augmented images
619634 if save_to_dir :
620635 import seaborn as sns
621636 fname = 'data_augm_{}.png' .format (np .random .randint (1e4 ))
622637 print ('Save to dir' .format (fname ))
623638 cmap = sns .hls_palette (nclasses )
624- save_img2 (x , y , os .path .join (save_to_dir , fname ),
639+ save_img2 (seq [ 'data' ], seq [ 'labels' ] , os .path .join (save_to_dir , fname ),
625640 cmap , void_label , rows_idx , cols_idx , chan_idx )
626641
627642 # Undo extra dim
628- if y is not None and len (y ) > 0 :
629- y = y [..., 0 ]
643+ if seq [ 'labels' ] is not None and len (seq [ 'labels' ] ) > 0 :
644+ seq [ 'labels' ] = seq [ 'labels' ] [..., 0 ]
630645
631- return x , y
0 commit comments