11# Based on
22# https://github.com/fchollet/keras/blob/master/keras/preprocessing/image.py
33import os
4+ import shutil
5+ import warnings
46
57import numpy as np
6- from scipy import interpolate
78import scipy .misc
89import scipy .ndimage as ndi
910from skimage .color import rgb2gray , gray2rgb
1011from skimage import img_as_float
1112
1213
13- def optical_flow ( seq , rows_idx , cols_idx , chan_idx , return_rgb = False ):
14- '''Optical flow
14+ def farn_optical_flow ( dataset ):
15+ '''Farneback optical flow
1516
1617 Iterates over the whole dataset, computes the dense optical flow between
1718 consecutive frames of each subset and stores it on disk as .npy files'''
1819 import cv2
19- if seq .ndim != 4 :
20+ warnings .warn ('Farneback optical flow not stored on disk. It will now be '
21+ 'computed on the whole dataset and stored on disk.'
22+ 'Time to sit back and get a coffee!' )
23+
24+ # Create a copy of the dataset to iterate on
25+ dataset = dataset .__class__ (batch_size = 1 ,
26+ return_01c = True ,
27+ return_0_255 = True ,
28+ shuffle_at_each_epoch = False ,
29+ infinite_iterator = False )
30+
31+ ret = dataset .next ()
32+ frame0 = ret ['data' ]
33+ prefix0 = ret ['subset' ][0 ]
34+ if frame0 .ndim != 4 :
2035 raise RuntimeError ('Optical flow expected 4 dimensions, got %d' %
21- seq .ndim )
22- seq = seq .copy ()
23- seq = (seq * 255 ).astype ('uint8' )
24- # Reshape to channel last: (b*seq, 0, 1, ch) if seq
25- pattern = [el for el in range (seq .ndim )
26- if el not in (rows_idx , cols_idx , chan_idx )]
27- pattern += [rows_idx , cols_idx , chan_idx ]
28- inv_pattern = [pattern .index (el ) for el in range (seq .ndim )]
29- seq = seq .transpose (pattern )
30- if seq .shape [0 ] == 1 :
31- raise RuntimeError ('Optical flow needs a sequence longer than 1 '
32- 'to work' )
33- seq = seq [..., ::- 1 ] # Go BGR for OpenCV
34-
35- frame1 = seq [0 ]
36- if return_rgb :
37- flow_seq = np .zeros_like (seq )
38- hsv = np .zeros_like (frame1 )
39- else :
40- sh = list (seq .shape )
41- sh [- 1 ] = 2
42- flow_seq = np .zeros (sh )
43-
44- frame1 = cv2 .cvtColor (frame1 , cv2 .COLOR_BGR2GRAY ) # Go to gray
36+ frame0 .ndim )
37+ frame0 = frame0 [0 , ..., ::- 1 ] # go BGR for OpenCV + remove batch dim
38+ frame0 = cv2 .cvtColor (frame0 , cv2 .COLOR_BGR2GRAY ) # Go gray
4539
4640 flow = None
47- for i , frame2 in enumerate (seq [1 :]):
48- frame2 = cv2 .cvtColor (frame2 , cv2 .COLOR_BGR2GRAY ) # Go to gray
49- flow = cv2 .calcOpticalFlowFarneback (prev = frame1 ,
50- next = frame2 ,
41+ of_path = os .path .join (dataset .path , 'OF' , 'Farn' )
42+ of_shared_path = os .path .join (dataset .shared_path , 'OF' , 'Farn' )
43+
44+ for ret in dataset :
45+ frame1 = ret ['data' ]
46+ filename1 = ret ['filenames' ][0 , 0 ]
47+ # Strip extension, if any
48+ filename1 = filename1 [:- 4 ] + '.' .join (filename1 [- 4 :].split ('.' )[:- 1 ])
49+ prefix1 = ret ['subset' ][0 ]
50+
51+ if frame1 .ndim != 4 :
52+ raise RuntimeError ('Optical flow expected 4 dimensions, got %d' %
53+ frame1 .ndim )
54+
55+ frame1 = frame1 [0 , ..., ::- 1 ] # go BGR for OpenCV + remove batch dim
56+ frame1 = cv2 .cvtColor (frame1 , cv2 .COLOR_BGR2GRAY ) # Go gray
57+
58+ if prefix1 != prefix0 :
59+ # First frame of a new subset
60+ frame0 = frame1
61+ prefix0 = prefix1
62+ continue
63+
64+ # Compute displacement
65+ flow = cv2 .calcOpticalFlowFarneback (prev = frame0 ,
66+ next = frame1 ,
5167 pyr_scale = 0.5 ,
5268 levels = 3 ,
5369 winsize = 10 ,
@@ -56,24 +72,22 @@ def optical_flow(seq, rows_idx, cols_idx, chan_idx, return_rgb=False):
5672 poly_sigma = 1.1 ,
5773 flags = 0 ,
5874 flow = flow )
59- mag , ang = cv2 .cartToPolar (flow [..., 0 ], flow [..., 1 ],
60- angleInDegrees = True )
61- # normalize between 0 and 255
62- ang = ang / 360 * 255
63- if return_rgb :
64- hsv [..., 0 ] = ang
65- hsv [..., 1 ] = 255
66- hsv [..., 2 ] = cv2 .normalize (mag , None , 0 , 255 , cv2 .NORM_MINMAX )
67- rgb = cv2 .cvtColor (hsv , cv2 .COLOR_HSV2RGB )
68- flow_seq [i + 1 ] = rgb
69- # Image.fromarray(rgb).show()
70- # cv2.imwrite('opticalfb.png', frame2)
71- # cv2.imwrite('opticalhsv.png', bgr)
72- else :
73- flow_seq [i + 1 ] = np .stack ((ang , mag ), 2 )
74- frame1 = frame2
75- flow_seq = flow_seq .transpose (inv_pattern )
76- return flow_seq / 255. # return in [0, 1]
75+
76+ # Save in the local path
77+ if not os .path .exists (os .path .join (of_path , prefix1 )):
78+ os .makedirs (os .path .join (of_path , prefix1 ))
79+ # Save the flow as dy, dx
80+ np .save (os .path .join (of_path , prefix1 , filename1 ), flow [..., ::- 1 ])
81+ # cv2.imwrite(os.path.join(of_path, prefix1, filename1 + '.png'), flow)
82+ frame0 = frame1
83+ prefix0 = prefix1
84+
85+ # Store a copy in shared_path
86+ # TODO there might be a race condition when multiple experiments are
87+ # run and one checks for the existence of the shared path OF dir
88+ # while this copy is happening.
89+ if of_path != of_shared_path :
90+ shutil .copytree (of_path , of_shared_path )
7791
7892
7993def my_label2rgb (labels , cmap , bglabel = None , bg_color = (0. , 0. , 0. )):
@@ -348,10 +362,11 @@ def random_transform(dataset,
348362 warp_sigma = 0.1 ,
349363 warp_grid_size = 3 ,
350364 crop_size = None ,
351- return_optical_flow = False ,
352365 nclasses = None ,
353366 gamma = 0. ,
354367 gain = 1. ,
368+ return_optical_flow = False ,
369+ optical_flow_type = 'Farn' ,
355370 chan_idx = 3 , # No batch yet: (s, 0, 1, c)
356371 rows_idx = 1 , # No batch yet: (s, 0, 1, c)
357372 cols_idx = 2 , # No batch yet: (s, 0, 1, c)
@@ -416,17 +431,24 @@ def random_transform(dataset,
416431 crop_size: tuple
417432 The size of crop to be applied to images and masks (after any
418433 other transformation).
419- return_optical_flow: bool
420- If not False a dense optical flow will be concatenated to the
421- end of the channel axis of the image. If True, angle and
422- magnitude will be returned, if set to 'rbg' an RGB representation
423- will be returned instead. Default: False.
424434 nclasses: int
425435 The number of classes of the dataset.
426436 gamma: float
427437 Controls gamma in Gamma correction.
428438 gain: float
429439 Controls gain in Gamma correction.
440+ return_optical_flow: string
441+ Either 'displacement' or 'rgb'.
442+ If set, a dense optical flow will be retrieved from disk (or
443+ computed when missing) and returned as a 'flow' key.
444+ If 'displacement', the optical flow will be returned as a
445+ two-dimensional array of (dx, dy) displacement. If 'rgb', a
446+ three dimensional RGB array with values in [0, 255] will be
447+ returned. Default: False.
448+ optical_flow_type: string
449+ Indicates the method used to generate the optical flow. The
450+ optical flow is loaded from a specific directory based on this
451+ type.
430452 chan_idx: int
431453 The index of the channel axis.
432454 rows_idx: int
@@ -575,6 +597,72 @@ def random_transform(dataset,
575597 fill_mode = fill_mode , fill_constant = cvalMask ,
576598 rows_idx = rows_idx , cols_idx = cols_idx ))
577599
600+ # Optical flow
601+ if return_optical_flow :
602+ return_optical_flow = return_optical_flow .lower ()
603+ if return_optical_flow not in ['rgb' , 'displacement' ]:
604+ raise RuntimeError ('Unknown return_optical_flow value: %s' %
605+ return_optical_flow )
606+ if optical_flow_type not in ['Brox' , 'Farn' , 'LK' , 'TVL1' ]:
607+ raise RuntimeError ('Unknown optical flow type: %s' %
608+ optical_flow_type )
609+ if prefix_and_fnames is None :
610+ raise RuntimeError ('You should specify a list of prefixes '
611+ 'and filenames' )
612+ # Find the filename of the first frame of this prefix
613+ first_frame_of_prefix = sorted (dataset .get_names ()[seq ['subset' ]])[0 ]
614+
615+ of_base_path = os .path .join (dataset .path , 'OF' , optical_flow_type )
616+ if not os .path .isdir (of_base_path ):
617+ # The OF is not on disk: compute it and store it
618+ if optical_flow_type != 'Farn' :
619+ raise RuntimeError ('For optical_flow_type other than Farn '
620+ 'please run your own implementation '
621+ 'manually and save it in %s' % of_base_path )
622+ farn_optical_flow (dataset ) # Compute and store on disk
623+
624+ # Load the OF from disk
625+ import skimage
626+ flow = []
627+ for frame in prefix_and_fnames :
628+ if frame [1 ] == first_frame_of_prefix :
629+ # It's the first frame of the prefix, there is no
630+ # previous frame to compute the OF with, return a blank one
631+ of = np .zeros (sh [1 :], seq ['data' ].dtype )
632+ flow .append (of )
633+ continue
634+
635+ # Read from disk
636+ of_path = os .path .join (of_base_path , frame [0 ],
637+ frame [1 ].rstrip ('.' ) + '.npy' )
638+ if os .path .exists (of_path ):
639+ of = np .load (of_path )
640+ else :
641+ raise RuntimeError ('Optical flow not found for this '
642+ 'file: %s' % of_path )
643+
644+ if return_optical_flow == 'rgb' :
645+ def cart2pol (x , y ):
646+ mag = np .sqrt (x ** 2 + y ** 2 )
647+ ang = np .arctan2 (y , x ) # note, in [-pi, pi]
648+ return mag , ang
649+ mag , ang = cart2pol (of [..., 0 ], of [..., 1 ])
650+
651+ # Normalize to [0, 1]
652+ sh = of .shape [:2 ]
653+ diag_len = np .sqrt (sh [0 ]** 2 + sh [1 ]** 2 , dtype = 'float32' )
654+ ang = ((ang / np .pi ) + 1 ) / 2
655+ mag = mag / diag_len
656+
657+ # Convert to RGB
658+ hsv = np .ones ((sh [0 ], sh [1 ], 3 )) * 255
659+ hsv [..., 0 ] = ang
660+ hsv [..., 2 ] = mag
661+ of = skimage .color .hsv2rgb (hsv ) # HSV --> RGB
662+
663+ flow .append (np .array (of ))
664+ flow = np .array (flow )
665+
578666 # Crop
579667 # Expects axes with shape (..., 0, 1)
580668 # TODO: Add center crop
@@ -611,6 +699,9 @@ def random_transform(dataset,
611699 seq ['labels' ] = seq ['labels' ].transpose (pattern )
612700 seq ['labels' ] = seq ['labels' ][..., top :top + crop [0 ],
613701 left :left + crop [1 ]]
702+ if return_optical_flow :
703+ flow = flow .transpose (pattern )
704+ flow = flow [..., top :top + crop [0 ], left :left + crop [1 ]]
614705 # Padding
615706 if pad != [0 , 0 ]:
616707 pad_pattern = ((0 , 0 ),) * (seq ['data' ].ndim - 2 ) + (
@@ -619,16 +710,15 @@ def random_transform(dataset,
619710 seq ['data' ] = np .pad (seq ['data' ], pad_pattern , 'constant' )
620711 seq ['labels' ] = np .pad (seq ['labels' ], pad_pattern , 'constant' ,
621712 constant_values = void_label )
713+ if return_optical_flow :
714+ flow = np .pad (flow , pad_pattern , 'constant' ) # pad with zeros
622715
623716 # Reshape to original shape
624717 seq ['data' ] = seq ['data' ].transpose (inv_pattern )
625718 if seq ['labels' ] is not None and len (seq ['labels' ]) > 0 :
626719 seq ['labels' ] = seq ['labels' ].transpose (inv_pattern )
627-
628- if return_optical_flow :
629- flow = optical_flow (seq ['data' ], rows_idx , cols_idx , chan_idx ,
630- return_rgb = return_optical_flow == 'rgb' )
631- seq ['data' ] = np .concatenate ((seq ['data' ], flow ), axis = chan_idx )
720+ if return_optical_flow :
721+ flow = flow .transpose (inv_pattern )
632722
633723 # Save augmented images
634724 if save_to_dir :
@@ -643,3 +733,5 @@ def random_transform(dataset,
643733 if seq ['labels' ] is not None and len (seq ['labels' ]) > 0 :
644734 seq ['labels' ] = seq ['labels' ][..., 0 ]
645735
736+ if return_optical_flow :
737+ seq ['flow' ] = np .array (flow )
0 commit comments