Skip to content
This repository was archived by the owner on May 7, 2020. It is now read-only.

Commit ed8f880

Browse files
committed
Improve optical flow to load/store from disk
* Allow to load OF from disk from .npy files * Compute the OF at run time if missing (only Farneback available ATM) * Add parameter to select the OF type * Add parameter to select whether to return OF as RGB or displacement
1 parent 7a591e9 commit ed8f880

File tree

3 files changed

+170
-61
lines changed

3 files changed

+170
-61
lines changed

README.md

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,22 @@ If you use this code, please cite:
6060
the *optical flow* data augmentations respectively.
6161
</br>
6262

63+
### Optical flow
64+
The dataset loaders can optionally load from disk, or in some cases compute,
65+
the optical flow associated to the video sequences. To do so it looks for a
66+
file in `<dataset_path>/OF/<OF_type>/prefix/filename.npy>` where prefix is the
67+
name of the subset (or video) as returned by get_names(). If the file is
68+
missing it will try to compute the optical flow for the entire dataset once and
69+
store it on disk.
70+
71+
At the moment the only optical flow algorithm supported to this end is the
72+
Farneback (requires openCV installed, choose 'Farn' as type), but you can
73+
easily pre-compute the optical flow with your preferred algorithm and then load
74+
it via the dataset loaders. An example code for a few algorithms is provided
75+
[here](https://gist.github.com/marcociccone/593638e932a48df7cfd0afe71052ef1d).
76+
NO SUPPORT WILL BE PROVIDED FOR THIS CODE OR ANY OTHER OPTICAL FLOW CODE NOT
77+
DIRECTLY INTEGRATED IN THIS FRAMEWORK.
78+
6379
### Notes
6480
* **The code is provided as is, please expect minimal-to-none support on it.**
6581
* This framework is provided for research purposes only. Although we tried our

dataset_loaders/data_augmentation.py

Lines changed: 153 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1,53 +1,69 @@
11
# Based on
22
# https://github.com/fchollet/keras/blob/master/keras/preprocessing/image.py
33
import os
4+
import shutil
5+
import warnings
46

57
import numpy as np
6-
from scipy import interpolate
78
import scipy.misc
89
import scipy.ndimage as ndi
910
from skimage.color import rgb2gray, gray2rgb
1011
from skimage import img_as_float
1112

1213

13-
def optical_flow(seq, rows_idx, cols_idx, chan_idx, return_rgb=False):
14-
'''Optical flow
14+
def farn_optical_flow(dataset):
15+
'''Farneback optical flow
1516
1617
Takes a 4D array of sequences and returns a 4D array with
1718
an RGB optical flow image for each frame in the input'''
1819
import cv2
19-
if seq.ndim != 4:
20+
warnings.warn('Farneback optical flow not stored on disk. It will now be '
21+
'computed on the whole dataset and stored on disk.'
22+
'Time to sit back and get a coffee!')
23+
24+
# Create a copy of the dataset to iterate on
25+
dataset = dataset.__class__(batch_size=1,
26+
return_01c=True,
27+
return_0_255=True,
28+
shuffle_at_each_epoch=False,
29+
infinite_iterator=False)
30+
31+
ret = dataset.next()
32+
frame0 = ret['data']
33+
prefix0 = ret['subset'][0]
34+
if frame0.ndim != 4:
2035
raise RuntimeError('Optical flow expected 4 dimensions, got %d' %
21-
seq.ndim)
22-
seq = seq.copy()
23-
seq = (seq * 255).astype('uint8')
24-
# Reshape to channel last: (b*seq, 0, 1, ch) if seq
25-
pattern = [el for el in range(seq.ndim)
26-
if el not in (rows_idx, cols_idx, chan_idx)]
27-
pattern += [rows_idx, cols_idx, chan_idx]
28-
inv_pattern = [pattern.index(el) for el in range(seq.ndim)]
29-
seq = seq.transpose(pattern)
30-
if seq.shape[0] == 1:
31-
raise RuntimeError('Optical flow needs a sequence longer than 1 '
32-
'to work')
33-
seq = seq[..., ::-1] # Go BGR for OpenCV
34-
35-
frame1 = seq[0]
36-
if return_rgb:
37-
flow_seq = np.zeros_like(seq)
38-
hsv = np.zeros_like(frame1)
39-
else:
40-
sh = list(seq.shape)
41-
sh[-1] = 2
42-
flow_seq = np.zeros(sh)
43-
44-
frame1 = cv2.cvtColor(frame1, cv2.COLOR_BGR2GRAY) # Go to gray
36+
frame0.ndim)
37+
frame0 = frame0[0, ..., ::-1] # go BGR for OpenCV + remove batch dim
38+
frame0 = cv2.cvtColor(frame0, cv2.COLOR_BGR2GRAY) # Go gray
4539

4640
flow = None
47-
for i, frame2 in enumerate(seq[1:]):
48-
frame2 = cv2.cvtColor(frame2, cv2.COLOR_BGR2GRAY) # Go to gray
49-
flow = cv2.calcOpticalFlowFarneback(prev=frame1,
50-
next=frame2,
41+
of_path = os.path.join(dataset.path, 'OF', 'Farn')
42+
of_shared_path = os.path.join(dataset.shared_path, 'OF', 'Farn')
43+
44+
for ret in dataset:
45+
frame1 = ret['data']
46+
filename1 = ret['filenames'][0, 0]
47+
# Strip extension, if any
48+
filename1 = filename1[:-4] + '.'.join(filename1[-4:].split('.')[:-1])
49+
prefix1 = ret['subset'][0]
50+
51+
if frame1.ndim != 4:
52+
raise RuntimeError('Optical flow expected 4 dimensions, got %d' %
53+
frame1.ndim)
54+
55+
frame1 = frame1[0, ..., ::-1] # go BGR for OpenCV + remove batch dim
56+
frame1 = cv2.cvtColor(frame1, cv2.COLOR_BGR2GRAY) # Go gray
57+
58+
if prefix1 != prefix0:
59+
# First frame of a new subset
60+
frame0 = frame1
61+
prefix0 = prefix1
62+
continue
63+
64+
# Compute displacement
65+
flow = cv2.calcOpticalFlowFarneback(prev=frame0,
66+
next=frame1,
5167
pyr_scale=0.5,
5268
levels=3,
5369
winsize=10,
@@ -56,24 +72,22 @@ def optical_flow(seq, rows_idx, cols_idx, chan_idx, return_rgb=False):
5672
poly_sigma=1.1,
5773
flags=0,
5874
flow=flow)
59-
mag, ang = cv2.cartToPolar(flow[..., 0], flow[..., 1],
60-
angleInDegrees=True)
61-
# normalize between 0 and 255
62-
ang = ang / 360 * 255
63-
if return_rgb:
64-
hsv[..., 0] = ang
65-
hsv[..., 1] = 255
66-
hsv[..., 2] = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX)
67-
rgb = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)
68-
flow_seq[i+1] = rgb
69-
# Image.fromarray(rgb).show()
70-
# cv2.imwrite('opticalfb.png', frame2)
71-
# cv2.imwrite('opticalhsv.png', bgr)
72-
else:
73-
flow_seq[i+1] = np.stack((ang, mag), 2)
74-
frame1 = frame2
75-
flow_seq = flow_seq.transpose(inv_pattern)
76-
return flow_seq / 255. # return in [0, 1]
75+
76+
# Save in the local path
77+
if not os.path.exists(os.path.join(of_path, prefix1)):
78+
os.makedirs(os.path.join(of_path, prefix1))
79+
# Save the flow as dy, dx
80+
np.save(os.path.join(of_path, prefix1, filename1), flow[..., ::-1])
81+
# cv2.imwrite(os.path.join(of_path, prefix1, filename1 + '.png'), flow)
82+
frame0 = frame1
83+
prefix0 = prefix1
84+
85+
# Store a copy in shared_path
86+
# TODO there might be a race condition when multiple experiments are
87+
# run and one checks for the existence of the shared path OF dir
88+
# while this copy is happening.
89+
if of_path != of_shared_path:
90+
shutil.copytree(of_path, of_shared_path)
7791

7892

7993
def my_label2rgb(labels, cmap, bglabel=None, bg_color=(0., 0., 0.)):
@@ -348,10 +362,11 @@ def random_transform(dataset,
348362
warp_sigma=0.1,
349363
warp_grid_size=3,
350364
crop_size=None,
351-
return_optical_flow=False,
352365
nclasses=None,
353366
gamma=0.,
354367
gain=1.,
368+
return_optical_flow=False,
369+
optical_flow_type='Farn',
355370
chan_idx=3, # No batch yet: (s, 0, 1, c)
356371
rows_idx=1, # No batch yet: (s, 0, 1, c)
357372
cols_idx=2, # No batch yet: (s, 0, 1, c)
@@ -416,17 +431,24 @@ def random_transform(dataset,
416431
crop_size: tuple
417432
The size of crop to be applied to images and masks (after any
418433
other transformation).
419-
return_optical_flow: bool
420-
If not False a dense optical flow will be concatenated to the
421-
end of the channel axis of the image. If True, angle and
422-
magnitude will be returned, if set to 'rbg' an RGB representation
423-
will be returned instead. Default: False.
424434
nclasses: int
425435
The number of classes of the dataset.
426436
gamma: float
427437
Controls gamma in Gamma correction.
428438
gain: float
429439
Controls gain in Gamma correction.
440+
return_optical_flow: string
441+
Either 'displacement' or 'rbg'.
442+
If set, a dense optical flow will be retrieved from disk (or
443+
computed when missing) and returned as a 'flow' key.
444+
If 'displacement', the optical flow will be returned as a
445+
two-dimensional array of (dx, dy) displacement. If 'rgb', a
446+
three dimensional RGB array with values in [0, 255] will be
447+
returned. Default: None.
448+
optical_flow_type: string
449+
Indicates the method used to generate the optical flow. The
450+
optical flow is loaded from a specific directory based on this
451+
type.
430452
chan_idx: int
431453
The index of the channel axis.
432454
rows_idx: int
@@ -575,6 +597,72 @@ def random_transform(dataset,
575597
fill_mode=fill_mode, fill_constant=cvalMask,
576598
rows_idx=rows_idx, cols_idx=cols_idx))
577599

600+
# Optical flow
601+
if return_optical_flow:
602+
return_optical_flow = return_optical_flow.lower()
603+
if return_optical_flow not in ['rgb', 'displacement']:
604+
raise RuntimeError('Unknown return_optical_flow value: %s' %
605+
return_optical_flow)
606+
if optical_flow_type not in ['Brox', 'Farn', 'LK', 'TVL1']:
607+
raise RuntimeError('Unknown optical flow type: %s' %
608+
optical_flow_type)
609+
if prefix_and_fnames is None:
610+
raise RuntimeError('You should specify a list of prefixes '
611+
'and filenames')
612+
# Find the filename of the first frame of this prefix
613+
first_frame_of_prefix = sorted(dataset.get_names()[seq['subset']])[0]
614+
615+
of_base_path = os.path.join(dataset.path, 'OF', optical_flow_type)
616+
if not os.path.isdir(of_base_path):
617+
# The OF is not on disk: compute it and store it
618+
if optical_flow_type != 'Farn':
619+
raise RuntimeError('For optical_flow_type other than Farn '
620+
'please run your own implementation '
621+
'manually and save it in %s' % of_base_path)
622+
farn_optical_flow(dataset) # Compute and store on disk
623+
624+
# Load the OF from disk
625+
import skimage
626+
flow = []
627+
for frame in prefix_and_fnames:
628+
if frame[1] == first_frame_of_prefix:
629+
# It's the first frame of the prefix, there is no
630+
# previous frame to compute the OF with, return a blank one
631+
of = np.zeros(sh[1:], seq['data'].dtype)
632+
flow.append(of)
633+
continue
634+
635+
# Read from disk
636+
of_path = os.path.join(of_base_path, frame[0],
637+
frame[1].rstrip('.') + '.npy')
638+
if os.path.exists(of_path):
639+
of = np.load(of_path)
640+
else:
641+
raise RuntimeError('Optical flow not found for this '
642+
'file: %s' % of_path)
643+
644+
if return_optical_flow == 'rgb':
645+
def cart2pol(x, y):
646+
mag = np.sqrt(x**2 + y**2)
647+
ang = np.arctan2(y, x) # note, in [-pi, pi]
648+
return mag, ang
649+
mag, ang = cart2pol(of[..., 0], of[..., 1])
650+
651+
# Normalize to [0, 1]
652+
sh = of.shape[:2]
653+
diag_len = np.sqrt(sh[0]**2 + sh[1]**2, dtype='float32')
654+
ang = ((ang / np.pi) + 1) / 2
655+
mag = mag / diag_len
656+
657+
# Convert to RGB
658+
hsv = np.ones((sh[0], sh[1], 3)) * 255
659+
hsv[..., 0] = ang
660+
hsv[..., 2] = mag
661+
of = skimage.color.hsv2rgb(hsv) # HSV --> RGB
662+
663+
flow.append(np.array(of))
664+
flow = np.array(flow)
665+
578666
# Crop
579667
# Expects axes with shape (..., 0, 1)
580668
# TODO: Add center crop
@@ -611,6 +699,9 @@ def random_transform(dataset,
611699
seq['labels'] = seq['labels'].transpose(pattern)
612700
seq['labels'] = seq['labels'][..., top:top+crop[0],
613701
left:left+crop[1]]
702+
if return_optical_flow:
703+
flow = flow.transpose(pattern)
704+
flow = flow[..., top:top+crop[0], left:left+crop[1]]
614705
# Padding
615706
if pad != [0, 0]:
616707
pad_pattern = ((0, 0),) * (seq['data'].ndim - 2) + (
@@ -619,16 +710,15 @@ def random_transform(dataset,
619710
seq['data'] = np.pad(seq['data'], pad_pattern, 'constant')
620711
seq['labels'] = np.pad(seq['labels'], pad_pattern, 'constant',
621712
constant_values=void_label)
713+
if return_optical_flow:
714+
flow = np.pad(flow, pad_pattern, 'constant') # pad with zeros
622715

623716
# Reshape to original shape
624717
seq['data'] = seq['data'].transpose(inv_pattern)
625718
if seq['labels'] is not None and len(seq['labels']) > 0:
626719
seq['labels'] = seq['labels'].transpose(inv_pattern)
627-
628-
if return_optical_flow:
629-
flow = optical_flow(seq['data'], rows_idx, cols_idx, chan_idx,
630-
return_rgb=return_optical_flow=='rgb')
631-
seq['data'] = np.concatenate((seq['data'], flow), axis=chan_idx)
720+
if return_optical_flow:
721+
flow = flow.transpose(inv_pattern)
632722

633723
# Save augmented images
634724
if save_to_dir:
@@ -643,3 +733,5 @@ def random_transform(dataset,
643733
if seq['labels'] is not None and len(seq['labels']) > 0:
644734
seq['labels'] = seq['labels'][..., 0]
645735

736+
if return_optical_flow:
737+
seq['flow'] = np.array(flow)

dataset_loaders/parallel_loader.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,7 @@ def __init__(self,
200200
# Set default values for the data augmentation params if not specified
201201
default_data_augm_kwargs = {
202202
'crop_size': None,
203+
'return_optical_flow': None,
203204
'rotation_range': 0,
204205
'width_shift_range': 0,
205206
'height_shift_range': 0,

0 commit comments

Comments
 (0)