Skip to content
This repository was archived by the owner on May 7, 2020. It is now read-only.

Commit 7a591e9

Browse files
committed
Refactor data_augm to modify the dict in place
* Do not create pointer seq_x, seq_y. It is easy to introduce bugs when operations on them are not reflected in the original dictioary. * Pass the dataset object rather than all its parameteres. * NOTE: This commit breaks the optical flow. Will be fixed in the next commit.
1 parent 6cb5509 commit 7a591e9

File tree

2 files changed

+126
-110
lines changed

2 files changed

+126
-110
lines changed

dataset_loaders/data_augmentation.py

Lines changed: 77 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,9 @@ def apply_warp(x, warp_field, fill_mode='reflect',
329329
return x
330330

331331

332-
def random_transform(x, y=None,
332+
def random_transform(dataset,
333+
seq,
334+
prefix_and_fnames=None,
333335
rotation_range=0.,
334336
width_shift_range=0.,
335337
height_shift_range=0.,
@@ -361,10 +363,14 @@ def random_transform(x, y=None,
361363
362364
Parameters
363365
----------
364-
x: array of floats
365-
An image.
366-
y: array of int
367-
An array with labels.
366+
dataset: a :class:`Dataset` instance
367+
The instance of the current dataset. First step towards making
368+
this a class method.
369+
seq: a dictionary of numpy array
370+
A dictionary with at least these keys: 'data', 'labels', 'filenames',
371+
'subset'.
372+
prefix_and_fnames: list
373+
A list of prefix and names for the current sequence
368374
rotation_range: int
369375
Degrees of rotation (0 to 180).
370376
width_shift_range: float
@@ -432,19 +438,23 @@ def random_transform(x, y=None,
432438
433439
Reference
434440
---------
435-
[1] https://github.com/fchollet/keras/blob/master/keras/preprocessing/image.py
441+
[1] https://github.com/fchollet/keras/blob/master/keras/preprocessing/image.py # noqa
436442
'''
437443
# Set this to a dir, if you want to save augmented images samples
438444
save_to_dir = None
445+
nclasses = dataset.nclasses
446+
void_label = dataset.void_labels
439447

440448
if rescale:
441449
raise NotImplementedError()
442450

443-
# Do not modify the original images
444-
x = x.copy()
445-
if y is not None and len(y) > 0:
446-
y = y[..., None] # Add extra dim to y to simplify computation
447-
y = y.copy()
451+
# Make sure we do not modify the original images
452+
seq['data'] = seq['data'].copy()
453+
if seq['labels'] is not None and len(seq['labels']) > 0:
454+
seq['labels'] = seq['labels'].copy()
455+
# Add extra dim to y to simplify computation
456+
seq['labels'] = seq['labels'][..., None]
457+
sh = seq['data'].shape
448458

449459
# listify zoom range
450460
if np.isscalar(zoom_range):
@@ -464,13 +474,13 @@ def random_transform(x, y=None,
464474

465475
# Channel shift
466476
if channel_shift_range != 0:
467-
x = random_channel_shift(x, channel_shift_range, rows_idx, cols_idx,
468-
chan_idx)
477+
seq['data'] = random_channel_shift(seq['data'], channel_shift_range,
478+
rows_idx, cols_idx, chan_idx)
469479

470480
# Gamma correction
471481
if gamma > 0:
472482
scale = float(1)
473-
x = ((x / scale) ** gamma) * scale * gain
483+
seq['data'] = ((seq['data'] / scale) ** gamma) * scale * gain
474484

475485
# Affine transformations (zoom, rotation, shift, ..)
476486
if (rotation_range or height_shift_range or width_shift_range or
@@ -488,12 +498,12 @@ def random_transform(x, y=None,
488498
# --> Shift/Translation
489499
if height_shift_range:
490500
tx = (np.random.uniform(-height_shift_range, height_shift_range) *
491-
x.shape[rows_idx])
501+
sh[rows_idx])
492502
else:
493503
tx = 0
494504
if width_shift_range:
495505
ty = (np.random.uniform(-width_shift_range, width_shift_range) *
496-
x.shape[cols_idx])
506+
sh[cols_idx])
497507
else:
498508
ty = 0
499509
translation_matrix = np.array([[1, 0, tx],
@@ -520,62 +530,64 @@ def random_transform(x, y=None,
520530
transform_matrix = np.dot(np.dot(np.dot(rotation_matrix,
521531
translation_matrix),
522532
shear_matrix), zoom_matrix)
523-
h, w = x.shape[rows_idx], x.shape[cols_idx]
533+
h, w = sh[rows_idx], sh[cols_idx]
524534
transform_matrix = transform_matrix_offset_center(transform_matrix,
525535
h, w)
526536
# Apply all the transformations together
527-
x = apply_transform(x, transform_matrix, fill_mode=fill_mode,
528-
cval=cval, order=1, rows_idx=rows_idx,
529-
cols_idx=cols_idx)
530-
if y is not None and len(y) > 0:
531-
y = apply_transform(y, transform_matrix, fill_mode=fill_mode,
532-
cval=cvalMask, order=0, rows_idx=rows_idx,
533-
cols_idx=cols_idx)
537+
seq['data'] = apply_transform(seq['data'], transform_matrix,
538+
fill_mode=fill_mode, cval=cval, order=1,
539+
rows_idx=rows_idx, cols_idx=cols_idx)
540+
if seq['labels'] is not None and len(seq['labels']) > 0:
541+
seq['labels'] = apply_transform(seq['labels'],
542+
transform_matrix,
543+
fill_mode=fill_mode, cval=cvalMask,
544+
order=0, rows_idx=rows_idx,
545+
cols_idx=cols_idx)
534546

535547
# Horizontal flip
536548
if np.random.random() < horizontal_flip: # 0 = disabled
537-
x = flip_axis(x, cols_idx)
538-
if y is not None and len(y) > 0:
539-
y = flip_axis(y, cols_idx)
549+
seq['data'] = flip_axis(seq['data'], cols_idx)
550+
if seq['labels'] is not None and len(seq['labels']) > 0:
551+
seq['labels'] = flip_axis(seq['labels'], cols_idx)
540552

541553
# Vertical flip
542554
if np.random.random() < vertical_flip: # 0 = disabled
543-
x = flip_axis(x, rows_idx)
544-
if y is not None and len(y) > 0:
545-
y = flip_axis(y, rows_idx)
555+
seq['data'] = flip_axis(seq['data'], rows_idx)
556+
if seq['labels'] is not None and len(seq['labels']) > 0:
557+
seq['labels'] = flip_axis(seq['labels'], rows_idx)
546558

547559
# Spline warp
548560
if spline_warp:
549561
import SimpleITK as sitk
550-
warp_field = gen_warp_field(shape=(x.shape[rows_idx],
551-
x.shape[cols_idx]),
562+
warp_field = gen_warp_field(shape=(sh[rows_idx],
563+
sh[cols_idx]),
552564
sigma=warp_sigma,
553565
grid_size=warp_grid_size)
554-
x = apply_warp(x, warp_field,
555-
interpolator=sitk.sitkLinear,
556-
fill_mode=fill_mode,
557-
fill_constant=cval,
558-
rows_idx=rows_idx, cols_idx=cols_idx)
559-
if y is not None and len(y) > 0:
560-
y = np.round(apply_warp(y, warp_field,
561-
interpolator=sitk.sitkNearestNeighbor,
562-
fill_mode=fill_mode,
563-
fill_constant=cvalMask,
564-
rows_idx=rows_idx, cols_idx=cols_idx))
566+
seq['data'] = apply_warp(seq['data'], warp_field,
567+
interpolator=sitk.sitkLinear,
568+
fill_mode=fill_mode, fill_constant=cval,
569+
rows_idx=rows_idx, cols_idx=cols_idx)
570+
if seq['labels'] is not None and len(seq['labels']) > 0:
571+
# TODO is this round right??
572+
seq['labels'] = np.round(
573+
apply_warp(seq['labels'], warp_field,
574+
interpolator=sitk.sitkNearestNeighbor,
575+
fill_mode=fill_mode, fill_constant=cvalMask,
576+
rows_idx=rows_idx, cols_idx=cols_idx))
565577

566578
# Crop
567579
# Expects axes with shape (..., 0, 1)
568580
# TODO: Add center crop
569581
if crop_size:
570582
# Reshape to (..., 0, 1)
571-
pattern = [el for el in range(x.ndim) if el != rows_idx and
583+
pattern = [el for el in range(seq['data'].ndim) if el != rows_idx and
572584
el != cols_idx] + [rows_idx, cols_idx]
573-
inv_pattern = [pattern.index(el) for el in range(x.ndim)]
574-
x = x.transpose(pattern)
585+
inv_pattern = [pattern.index(el) for el in range(seq['data'].ndim)]
586+
seq['data'] = seq['data'].transpose(pattern)
575587

576588
crop = list(crop_size)
577589
pad = [0, 0]
578-
h, w = x.shape[-2:]
590+
h, w = seq['data'].shape[-2:]
579591

580592
# Compute amounts
581593
if crop[0] < h:
@@ -594,38 +606,40 @@ def random_transform(x, y=None,
594606
left, crop[1] = 0, w
595607

596608
# Cropping
597-
x = x[..., top:top+crop[0], left:left+crop[1]]
598-
if y is not None and len(y) > 0:
599-
y = y.transpose(pattern)
600-
y = y[..., top:top+crop[0], left:left+crop[1]]
609+
seq['data'] = seq['data'][..., top:top+crop[0], left:left+crop[1]]
610+
if seq['labels'] is not None and len(seq['labels']) > 0:
611+
seq['labels'] = seq['labels'].transpose(pattern)
612+
seq['labels'] = seq['labels'][..., top:top+crop[0],
613+
left:left+crop[1]]
601614
# Padding
602615
if pad != [0, 0]:
603-
pad_pattern = ((0, 0),) * (x.ndim - 2) + (
616+
pad_pattern = ((0, 0),) * (seq['data'].ndim - 2) + (
604617
(pad[0]//2, pad[0] - pad[0]//2),
605618
(pad[1]//2, pad[1] - pad[1]//2))
606-
x = np.pad(x, pad_pattern, 'constant')
607-
y = np.pad(y, pad_pattern, 'constant', constant_values=void_label)
619+
seq['data'] = np.pad(seq['data'], pad_pattern, 'constant')
620+
seq['labels'] = np.pad(seq['labels'], pad_pattern, 'constant',
621+
constant_values=void_label)
608622

609-
x = x.transpose(inv_pattern)
610-
if y is not None and len(y) > 0:
611-
y = y.transpose(inv_pattern)
623+
# Reshape to original shape
624+
seq['data'] = seq['data'].transpose(inv_pattern)
625+
if seq['labels'] is not None and len(seq['labels']) > 0:
626+
seq['labels'] = seq['labels'].transpose(inv_pattern)
612627

613628
if return_optical_flow:
614-
flow = optical_flow(x, rows_idx, cols_idx, chan_idx,
629+
flow = optical_flow(seq['data'], rows_idx, cols_idx, chan_idx,
615630
return_rgb=return_optical_flow=='rgb')
616-
x = np.concatenate((x, flow), axis=chan_idx)
631+
seq['data'] = np.concatenate((seq['data'], flow), axis=chan_idx)
617632

618633
# Save augmented images
619634
if save_to_dir:
620635
import seaborn as sns
621636
fname = 'data_augm_{}.png'.format(np.random.randint(1e4))
622637
print ('Save to dir'.format(fname))
623638
cmap = sns.hls_palette(nclasses)
624-
save_img2(x, y, os.path.join(save_to_dir, fname),
639+
save_img2(seq['data'], seq['labels'], os.path.join(save_to_dir, fname),
625640
cmap, void_label, rows_idx, cols_idx, chan_idx)
626641

627642
# Undo extra dim
628-
if y is not None and len(y) > 0:
629-
y = y[..., 0]
643+
if seq['labels'] is not None and len(seq['labels']) > 0:
644+
seq['labels'] = seq['labels'][..., 0]
630645

631-
return x, y

0 commit comments

Comments
 (0)