Skip to content

Commit e3e5d82

Browse files
committed
add option to seed numba rng
1 parent 1d94d23 commit e3e5d82

File tree

4 files changed

+37
-4
lines changed

4 files changed

+37
-4
lines changed

ffcv/transforms/cutout.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,18 +21,26 @@ class Cutout(Operation):
2121
Useful for when a normalization layer follows cutout, in which case
2222
you can set the fill such that the square is zero
2323
post-normalization.
24+
seed : int, optional
25+
Random seed set inside code passed to numba (for reproducibility).
26+
Nonnegative values are reduced modulo 2**32 to be in valid range.
27+
Negative values are not used. Defaults to -1.
2428
"""
25-
def __init__(self, crop_size: int, fill: Tuple[int, int, int] = (0, 0, 0)):
29+
def __init__(self, crop_size: int, fill: Tuple[int, int, int] = (0, 0, 0), seed: int = -1):
2630
super().__init__()
2731
self.crop_size = crop_size
2832
self.fill = np.array(fill)
33+
self.seed = seed % 2 ** 32 if seed >= 0 else seed
2934

3035
def generate_code(self) -> Callable:
3136
my_range = Compiler.get_iterator()
3237
crop_size = self.crop_size
3338
fill = self.fill
39+
seed = self.seed
3440

3541
def cutout_square(images, *_):
42+
if seed >= 0:
43+
np.random.seed(seed)
3644
for i in my_range(images.shape[0]):
3745
# Generate random origin
3846
coord = (

ffcv/transforms/flip.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""
22
Random horizontal flip
33
"""
4+
import numpy as np
45
from numpy import dtype
56
from numpy.random import rand
67
from typing import Callable, Optional, Tuple
@@ -18,17 +19,25 @@ class RandomHorizontalFlip(Operation):
1819
flip_prob : float
1920
The probability with which to flip each image in the batch
2021
horizontally.
22+
seed : int, optional
23+
Random seed set inside code passed to numba (for reproducibility).
24+
Nonnegative values are reduced modulo 2**32 to be in valid range.
25+
Negative values are not used. Defaults to -1.
2126
"""
2227

23-
def __init__(self, flip_prob: float = 0.5):
28+
def __init__(self, flip_prob: float = 0.5, seed: int = -1):
2429
super().__init__()
2530
self.flip_prob = flip_prob
31+
self.seed = seed % 2 ** 32 if seed >= 0 else seed
2632

2733
def generate_code(self) -> Callable:
2834
my_range = Compiler.get_iterator()
2935
flip_prob = self.flip_prob
36+
seed = self.seed
3037

3138
def flip(images, dst):
39+
if seed >= 0:
40+
np.random.seed(seed)
3241
should_flip = rand(images.shape[0]) < flip_prob
3342
for i in my_range(images.shape[0]):
3443
if should_flip[i]:

ffcv/transforms/random_resized_crop.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,24 @@ class RandomResizedCrop(Operation):
2020
Lower and upper bounds for random aspect ratio of the crop.
2121
size : int
2222
Side length of the output.
23+
seed : int, optional
24+
Random seed set inside code passed to numba (for reproducibility).
25+
Nonnegative values are reduced modulo 2**32 to be in valid range.
26+
Negative values are not used. Defaults to -1.
2327
"""
24-
def __init__(self, scale: Tuple[float, float], ratio: Tuple[float, float], size: int):
28+
def __init__(self, scale: Tuple[float, float], ratio: Tuple[float, float], size: int, seed: int = -1):
2529
super().__init__()
2630
self.scale = scale
2731
self.ratio = ratio
2832
self.size = size
33+
self.seed = seed % 2 ** 32 if seed >= 0 else seed
2934

3035
def generate_code(self) -> Callable:
3136
scale, ratio = self.scale, self.ratio
37+
seed = self.seed
3238
def random_resized_crop(im, dst):
39+
if seed >= 0:
40+
np.random.seed(seed)
3341
i, j, h, w = fast_crop.get_random_crop(im.shape[0],
3442
im.shape[1],
3543
scale,

ffcv/transforms/translate.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,18 +20,26 @@ class RandomTranslate(Operation):
2020
Max number of pixels to translate in any direction.
2121
fill : tuple
2222
An RGB color ((0, 0, 0) by default) to fill the area outside the shifted image.
23+
seed : int, optional
24+
Random seed set inside code passed to numba (for reproducibility).
25+
Nonnegative values are reduced modulo 2**32 to be in valid range.
26+
Negative values are not used. Defaults to -1.
2327
"""
2428

25-
def __init__(self, padding: int, fill: Tuple[int, int, int] = (0, 0, 0)):
29+
def __init__(self, padding: int, fill: Tuple[int, int, int] = (0, 0, 0), seed: int = -1):
2630
super().__init__()
2731
self.padding = padding
2832
self.fill = np.array(fill)
33+
self.seed = seed % 2 ** 32 if seed >= 0 else seed
2934

3035
def generate_code(self) -> Callable:
3136
my_range = Compiler.get_iterator()
3237
pad = self.padding
38+
seed = self.seed
3339

3440
def translate(images, dst):
41+
if seed >= 0:
42+
np.random.seed(seed)
3543
n, h, w, _ = images.shape
3644
# y_coords = randint(low=0, high=2 * pad + 1, size=(n,))
3745
# x_coords = randint(low=0, high=2 * pad + 1, size=(n,))

0 commit comments

Comments
 (0)