Commit 131d562

Merge branch 'no_jit_assert' into v0.0.3
2 parents 405c096 + d817cc2

8 files changed: +174 −50 lines changed

ffcv/transforms/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -7,6 +7,8 @@
 from .replace_label import ReplaceLabel
 from .normalize import NormalizeImage
 from .translate import RandomTranslate
+from .mixup import ImageMixup, LabelMixup, MixupToOneHot
+from .module import ModuleWrapper
 
 __all__ = ['ToTensor', 'ToDevice',
            'ToTorchImage', 'NormalizeImage',

ffcv/transforms/cutout.py

Lines changed: 2 additions & 2 deletions
@@ -3,6 +3,7 @@
 """
 import numpy as np
 from typing import Callable, Optional, Tuple
+from dataclasses import replace
 
 from ffcv.pipeline.compiler import Compiler
 from ..pipeline.allocation_query import AllocationQuery
@@ -48,5 +49,4 @@ def cutout_square(images, *_):
         return cutout_square
 
     def declare_state_and_memory(self, previous_state: State) -> Tuple[State, Optional[AllocationQuery]]:
-        assert previous_state.jit_mode
-        return previous_state, None
+        return replace(previous_state, jit_mode=True), None
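
The edit above is the pattern repeated across the transforms in this commit: instead of asserting that the previous pipeline stage already enabled JIT mode, each transform returns a copy of the state with jit_mode switched on via dataclasses.replace. A minimal sketch of the idea, using a hypothetical stand-in for ffcv.pipeline.state.State rather than the real class:

    from dataclasses import dataclass, replace
    from typing import Tuple

    @dataclass(frozen=True)
    class State:
        # Hypothetical stand-in for ffcv.pipeline.state.State, reduced to the
        # fields this sketch needs.
        jit_mode: bool
        shape: Tuple[int, ...]

    def declare_state_and_memory(previous_state: State) -> Tuple[State, None]:
        # Old behaviour: `assert previous_state.jit_mode` failed whenever the
        # upstream stage had left JIT mode off. New behaviour: request it.
        return replace(previous_state, jit_mode=True), None

    new_state, _ = declare_state_and_memory(State(jit_mode=False, shape=(32, 32, 3)))
    assert new_state.jit_mode  # the transform now runs under the JIT either way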

ffcv/transforms/flip.py

Lines changed: 3 additions & 3 deletions
@@ -1,7 +1,7 @@
 """
 Random horizontal flip
 """
-from numpy import dtype
+from dataclasses import replace
 from numpy.random import rand
 from typing import Callable, Optional, Tuple
 from ..pipeline.allocation_query import AllocationQuery
@@ -42,5 +42,5 @@ def flip(images, dst):
         return flip
 
     def declare_state_and_memory(self, previous_state: State) -> Tuple[State, Optional[AllocationQuery]]:
-        assert previous_state.jit_mode
-        return (previous_state, AllocationQuery(previous_state.shape, previous_state.dtype))
+        return (replace(previous_state, jit_mode=True),
+                AllocationQuery(previous_state.shape, previous_state.dtype))

ffcv/transforms/mixup.py

Lines changed: 1 addition & 4 deletions
@@ -53,8 +53,6 @@ def mixer(images, dst, indices):
         return mixer
 
     def declare_state_and_memory(self, previous_state: State) -> Tuple[State, Optional[AllocationQuery]]:
-        # assert previous_state.jit_mode
-        # We do everything in place
         return (previous_state, AllocationQuery(shape=previous_state.shape,
                                                 dtype=previous_state.dtype))
 
@@ -92,8 +90,6 @@ def mixer(labels, temp_array, indices):
         return mixer
 
     def declare_state_and_memory(self, previous_state: State) -> Tuple[State, Optional[AllocationQuery]]:
-        # assert previous_state.jit_mode
-        # We do everything in place
         return (replace(previous_state, shape=(3,), dtype=np.float32),
                 AllocationQuery((3,), dtype=np.float32))
 
@@ -115,6 +111,7 @@ def one_hotter(mixedup_labels, dst):
         return one_hotter
 
     def declare_state_and_memory(self, previous_state: State) -> Tuple[State, Optional[AllocationQuery]]:
+        # Should already be converted to tensor
         assert not previous_state.jit_mode
         return (replace(previous_state, shape=(self.num_classes,)), \
                 AllocationQuery((self.num_classes,), dtype=previous_state.dtype, device=previous_state.device))

ffcv/transforms/poisoning.py

Lines changed: 3 additions & 6 deletions
@@ -1,13 +1,10 @@
 """
 Poison images by adding a mask
 """
-from collections.abc import Sequence
 from typing import Tuple
+from dataclasses import replace
 
 import numpy as np
-from numpy import dtype
-from numpy.core.numeric import indices
-from numpy.random import rand
 from typing import Callable, Optional, Tuple
 from ..pipeline.allocation_query import AllocationQuery
 from ..pipeline.operation import Operation
@@ -67,6 +64,6 @@ def poison(images, temp_array, indices):
         return poison
 
     def declare_state_and_memory(self, previous_state: State) -> Tuple[State, Optional[AllocationQuery]]:
-        assert previous_state.jit_mode
         # We do everything in place
-        return (previous_state, AllocationQuery(shape=previous_state.shape, dtype=np.float32))
+        return (replace(previous_state, jit_mode=True), \
+                AllocationQuery(shape=previous_state.shape, dtype=np.dtype('float32')))

ffcv/transforms/replace_label.py

Lines changed: 2 additions & 7 deletions
@@ -1,13 +1,10 @@
 """
 Replace label
 """
-from collections.abc import Sequence
 from typing import Tuple
 
 import numpy as np
-from numpy import dtype
-from numpy.core.numeric import indices
-from numpy.random import rand
+from dataclasses import replace
 from typing import Callable, Optional, Tuple
 from ..pipeline.allocation_query import AllocationQuery
 from ..pipeline.operation import Operation
@@ -50,6 +47,4 @@ def replace_label(labels, temp_array, indices):
         return replace_label
 
     def declare_state_and_memory(self, previous_state: State) -> Tuple[State, Optional[AllocationQuery]]:
-        assert previous_state.jit_mode
-        # We do everything in place
-        return (previous_state, None)
+        return (replace(previous_state, jit_mode=True), None)

ffcv/transforms/translate.py

Lines changed: 5 additions & 4 deletions
@@ -2,9 +2,9 @@
 Random translate
 """
 import numpy as np
-from numpy import dtype
 from numpy.random import randint
-from typing import Any, Callable, Optional, Tuple, Union
+from typing import Callable, Optional, Tuple
+from dataclasses import replace
 from ..pipeline.allocation_query import AllocationQuery
 from ..pipeline.operation import Operation
 from ..pipeline.state import State
@@ -51,5 +51,6 @@ def translate(images, dst):
 
     def declare_state_and_memory(self, previous_state: State) -> Tuple[State, Optional[AllocationQuery]]:
         h, w, c = previous_state.shape
-        assert previous_state.jit_mode
-        return (previous_state, AllocationQuery((h + 2 * self.padding, w + 2 * self.padding, c), previous_state.dtype))
+        return (replace(previous_state, jit_mode=True), \
+                AllocationQuery((h + 2 * self.padding, w + 2 * self.padding, c), previous_state.dtype))
+
tests/test_augmentations.py

Lines changed: 156 additions & 24 deletions
@@ -1,21 +1,35 @@
+import os
+import uuid
 import numpy as np
 import torch as ch
 from torch.utils.data import Dataset
+from torchvision import transforms as tvt
 from assertpy import assert_that
 from tempfile import NamedTemporaryFile
 from torchvision.datasets import CIFAR10
+from torchvision.utils import save_image, make_grid
 from torch.utils.data import Subset
 from ffcv.fields.basics import IntDecoder
 from ffcv.fields.rgb_image import SimpleRGBImageDecoder
-from ffcv.transforms.cutout import Cutout
 
 from ffcv.writer import DatasetWriter
 from ffcv.fields import IntField, RGBImageField
 from ffcv.loader import Loader
 from ffcv.pipeline.compiler import Compiler
-from ffcv.transforms import Squeeze, Cutout, ToTensor, Poison, RandomHorizontalFlip
+from ffcv.transforms import *
 
-def run_test(length, pipeline, compile):
+SAVE_IMAGES = True
+IMAGES_TMP_PATH = '/tmp/ffcv_augtest_output'
+if SAVE_IMAGES:
+    os.makedirs(IMAGES_TMP_PATH, exist_ok=True)
+
+UNAUGMENTED_PIPELINE=[
+    SimpleRGBImageDecoder(),
+    ToTensor(),
+    ToTorchImage()
+]
+
+def run_test(length, pipeline, compile=False):
     my_dataset = Subset(CIFAR10(root='/tmp', train=True, download=True), range(length))
 
     with NamedTemporaryFile() as handle:
@@ -28,52 +42,170 @@ def run_test(length, pipeline, compile):
 
         writer.from_indexed_dataset(my_dataset, chunksize=10)
 
-        Compiler.set_enabled(True)
+        Compiler.set_enabled(compile)
 
         loader = Loader(name, batch_size=7, num_workers=2, pipelines={
             'image': pipeline,
             'label': [IntDecoder(), ToTensor(), Squeeze()]
         },
                         drop_last=False)
+
+        unaugmented_loader = Loader(name, batch_size=7, num_workers=2, pipelines={
+            'image': UNAUGMENTED_PIPELINE,
+            'label': [IntDecoder(), ToTensor(), Squeeze()]
+        }, drop_last=False)
+
         tot_indices = 0
         tot_images = 0
-        for images, label in loader:
-            tot_indices += label.shape[0]
+        for (images, labels), (original_images, original_labels) in zip(loader, unaugmented_loader):
+            print(images.shape, original_images.shape)
+            tot_indices += labels.shape[0]
             tot_images += images.shape[0]
+
+            for label, original_label in zip(labels, original_labels):
+                assert_that(label).is_equal_to(original_label)
+
+            if SAVE_IMAGES:
+                save_image(make_grid(ch.concat([images, original_images])/255., images.shape[0]),
+                           os.path.join(IMAGES_TMP_PATH, str(uuid.uuid4()) + '.jpeg')
+                )
+
         assert_that(tot_indices).is_equal_to(len(my_dataset))
         assert_that(tot_images).is_equal_to(len(my_dataset))
 
+def test_cutout():
+    for comp in [True, False]:
+        run_test(100, [
+            SimpleRGBImageDecoder(),
+            Cutout(8),
+            ToTensor(),
+            ToTorchImage()
+        ], comp)
+
+
 def test_flip():
+    for comp in [True, False]:
+        run_test(100, [
+            SimpleRGBImageDecoder(),
+            RandomHorizontalFlip(1.0),
+            ToTensor(),
+            ToTorchImage()
+        ], comp)
+
+
+def test_module_wrapper():
+    for comp in [True, False]:
+        run_test(100, [
+            SimpleRGBImageDecoder(),
+            ToTensor(),
+            ToTorchImage(),
+            ModuleWrapper(tvt.Grayscale(3)),
+        ], comp)
+
+
+def test_mixup():
+    for comp in [True, False]:
+        run_test(100, [
+            SimpleRGBImageDecoder(),
+            ImageMixup(.5, False),
+            ToTensor(),
+            ToTorchImage()
+        ], comp)
+
+
+def test_poison():
+    mask = np.zeros((32, 32, 3))
+    # Red sqaure
+    mask[:5, :5, 0] = 1
+    alpha = np.ones((32, 32))
+
+    for comp in [True, False]:
+        run_test(100, [
+            SimpleRGBImageDecoder(),
+            Poison(mask, alpha, list(range(100))),
+            ToTensor(),
+            ToTorchImage()
+        ], comp)
+
+
+def test_random_resized_crop():
+    for comp in [True, False]:
+        run_test(100, [
+            SimpleRGBImageDecoder(),
+            RandomResizedCrop(scale=(0.08, 1.0),
+                              ratio=(0.75, 4/3),
+                              size=32),
+            ToTensor(),
+            ToTorchImage()
+        ], comp)
+
+
+def test_translate():
+    for comp in [True, False]:
+        run_test(100, [
+            SimpleRGBImageDecoder(),
+            RandomTranslate(padding=10),
+            ToTensor(),
+            ToTorchImage()
+        ], comp)
+
+
+## Torchvision Transforms
+def test_torchvision_greyscale():
     run_test(100, [
         SimpleRGBImageDecoder(),
-        RandomHorizontalFlip(1.0),
-        ToTensor()
-    ], True)
+        ToTensor(),
+        ToTorchImage(),
+        tvt.Grayscale(3),
+    ])
 
-def test_cutout():
+def test_torchvision_centercrop_pad():
     run_test(100, [
         SimpleRGBImageDecoder(),
-        Cutout(8),
-        ToTensor()
-    ], True)
+        ToTensor(),
+        ToTorchImage(),
+        tvt.CenterCrop(10),
+        tvt.Pad(11)
+    ])
 
+def test_torchvision_random_affine():
     run_test(100, [
         SimpleRGBImageDecoder(),
-        Cutout(8),
-        ToTensor()
-    ], False)
+        ToTensor(),
+        ToTorchImage(),
+        tvt.RandomAffine(25),
+    ])
 
+def test_torchvision_random_crop():
+    run_test(100, [
+        SimpleRGBImageDecoder(),
+        ToTensor(),
+        ToTorchImage(),
+        tvt.Pad(10),
+        tvt.RandomCrop(size=32),
+    ])
 
-def test_poison():
-    mask = np.zeros((32, 32, 3))
-    # Red sqaure
-    mask[:5, :5, 0] = 1
-    alpha = np.ones((32, 32))
+def test_torchvision_color_jitter():
     run_test(100, [
         SimpleRGBImageDecoder(),
-        Poison(mask, alpha, [0, 1, 2]),
-        ToTensor()
-    ], False)
+        ToTensor(),
+        ToTorchImage(),
+        tvt.ColorJitter(.5, .5, .5, .5),
+    ])
+
 
 if __name__ == '__main__':
+    # test_cutout()
     test_flip()
+    # test_module_wrapper()
+    # test_mixup()
+    # test_poison()
+    # test_random_resized_crop()
+    # test_translate()
+
+    ## Torchvision Transforms
+    # test_torchvision_greyscale()
+    # test_torchvision_centercrop_pad()
+    # test_torchvision_random_affine()
+    # test_torchvision_random_crop()
+    # test_torchvision_color_jitter()
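
The reworked harness compares every augmented batch against an unaugmented baseline Loader, checks that labels are preserved, and, when SAVE_IMAGES is enabled, writes side-by-side image grids to IMAGES_TMP_PATH. A minimal sketch of driving one augmentation through it directly, assuming the tests directory is importable as a package and that /tmp is writable as in the file above:

    # Sketch only: run_test, the decoder and the transforms all come from the
    # diff above; the `tests.test_augmentations` import path is an assumption
    # about how the repository sits on the import path.
    from tests.test_augmentations import run_test
    from ffcv.fields.rgb_image import SimpleRGBImageDecoder
    from ffcv.transforms import RandomHorizontalFlip, ToTensor, ToTorchImage

    run_test(100, [
        SimpleRGBImageDecoder(),
        RandomHorizontalFlip(1.0),
        ToTensor(),
        ToTorchImage(),
    ], compile=True)
    # Saved comparison grids land in /tmp/ffcv_augtest_output.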
