1+ import os
2+ import uuid
13import numpy as np
24import torch as ch
35from torch .utils .data import Dataset
6+ from torchvision import transforms as tvt
47from assertpy import assert_that
58from tempfile import NamedTemporaryFile
69from torchvision .datasets import CIFAR10
10+ from torchvision .utils import save_image , make_grid
711from torch .utils .data import Subset
812from ffcv .fields .basics import IntDecoder
913from ffcv .fields .rgb_image import SimpleRGBImageDecoder
10- from ffcv .transforms .cutout import Cutout
1114
1215from ffcv .writer import DatasetWriter
1316from ffcv .fields import IntField , RGBImageField
1417from ffcv .loader import Loader
1518from ffcv .pipeline .compiler import Compiler
16- from ffcv .transforms import Squeeze , Cutout , ToTensor , Poison , RandomHorizontalFlip
19+ from ffcv .transforms import *
1720
18- def run_test (length , pipeline , compile ):
21+ SAVE_IMAGES = True
22+ IMAGES_TMP_PATH = '/tmp/ffcv_augtest_output'
23+ if SAVE_IMAGES :
24+ os .makedirs (IMAGES_TMP_PATH , exist_ok = True )
25+
26+ UNAUGMENTED_PIPELINE = [
27+ SimpleRGBImageDecoder (),
28+ ToTensor (),
29+ ToTorchImage ()
30+ ]
31+
32+ def run_test (length , pipeline , compile = False ):
1933 my_dataset = Subset (CIFAR10 (root = '/tmp' , train = True , download = True ), range (length ))
2034
2135 with NamedTemporaryFile () as handle :
@@ -28,52 +42,170 @@ def run_test(length, pipeline, compile):
2842
2943 writer .from_indexed_dataset (my_dataset , chunksize = 10 )
3044
31- Compiler .set_enabled (True )
45+ Compiler .set_enabled (compile )
3246
3347 loader = Loader (name , batch_size = 7 , num_workers = 2 , pipelines = {
3448 'image' : pipeline ,
3549 'label' : [IntDecoder (), ToTensor (), Squeeze ()]
3650 },
3751 drop_last = False )
52+
53+ unaugmented_loader = Loader (name , batch_size = 7 , num_workers = 2 , pipelines = {
54+ 'image' : UNAUGMENTED_PIPELINE ,
55+ 'label' : [IntDecoder (), ToTensor (), Squeeze ()]
56+ }, drop_last = False )
57+
3858 tot_indices = 0
3959 tot_images = 0
40- for images , label in loader :
41- tot_indices += label .shape [0 ]
60+ for (images , labels ), (original_images , original_labels ) in zip (loader , unaugmented_loader ):
61+ print (images .shape , original_images .shape )
62+ tot_indices += labels .shape [0 ]
4263 tot_images += images .shape [0 ]
64+
65+ for label , original_label in zip (labels , original_labels ):
66+ assert_that (label ).is_equal_to (original_label )
67+
68+ if SAVE_IMAGES :
69+ save_image (make_grid (ch .concat ([images , original_images ])/ 255. , images .shape [0 ]),
70+ os .path .join (IMAGES_TMP_PATH , str (uuid .uuid4 ()) + '.jpeg' )
71+ )
72+
4373 assert_that (tot_indices ).is_equal_to (len (my_dataset ))
4474 assert_that (tot_images ).is_equal_to (len (my_dataset ))
4575
76+ def test_cutout ():
77+ for comp in [True , False ]:
78+ run_test (100 , [
79+ SimpleRGBImageDecoder (),
80+ Cutout (8 ),
81+ ToTensor (),
82+ ToTorchImage ()
83+ ], comp )
84+
85+
4686def test_flip ():
87+ for comp in [True , False ]:
88+ run_test (100 , [
89+ SimpleRGBImageDecoder (),
90+ RandomHorizontalFlip (1.0 ),
91+ ToTensor (),
92+ ToTorchImage ()
93+ ], comp )
94+
95+
96+ def test_module_wrapper ():
97+ for comp in [True , False ]:
98+ run_test (100 , [
99+ SimpleRGBImageDecoder (),
100+ ToTensor (),
101+ ToTorchImage (),
102+ ModuleWrapper (tvt .Grayscale (3 )),
103+ ], comp )
104+
105+
106+ def test_mixup ():
107+ for comp in [True , False ]:
108+ run_test (100 , [
109+ SimpleRGBImageDecoder (),
110+ ImageMixup (.5 , False ),
111+ ToTensor (),
112+ ToTorchImage ()
113+ ], comp )
114+
115+
116+ def test_poison ():
117+ mask = np .zeros ((32 , 32 , 3 ))
118+ # Red sqaure
119+ mask [:5 , :5 , 0 ] = 1
120+ alpha = np .ones ((32 , 32 ))
121+
122+ for comp in [True , False ]:
123+ run_test (100 , [
124+ SimpleRGBImageDecoder (),
125+ Poison (mask , alpha , list (range (100 ))),
126+ ToTensor (),
127+ ToTorchImage ()
128+ ], comp )
129+
130+
131+ def test_random_resized_crop ():
132+ for comp in [True , False ]:
133+ run_test (100 , [
134+ SimpleRGBImageDecoder (),
135+ RandomResizedCrop (scale = (0.08 , 1.0 ),
136+ ratio = (0.75 , 4 / 3 ),
137+ size = 32 ),
138+ ToTensor (),
139+ ToTorchImage ()
140+ ], comp )
141+
142+
143+ def test_translate ():
144+ for comp in [True , False ]:
145+ run_test (100 , [
146+ SimpleRGBImageDecoder (),
147+ RandomTranslate (padding = 10 ),
148+ ToTensor (),
149+ ToTorchImage ()
150+ ], comp )
151+
152+
153+ ## Torchvision Transforms
154+ def test_torchvision_greyscale ():
47155 run_test (100 , [
48156 SimpleRGBImageDecoder (),
49- RandomHorizontalFlip (1.0 ),
50- ToTensor ()
51- ], True )
157+ ToTensor (),
158+ ToTorchImage (),
159+ tvt .Grayscale (3 ),
160+ ])
52161
53- def test_cutout ():
162+ def test_torchvision_centercrop_pad ():
54163 run_test (100 , [
55164 SimpleRGBImageDecoder (),
56- Cutout (8 ),
57- ToTensor ()
58- ], True )
165+ ToTensor (),
166+ ToTorchImage (),
167+ tvt .CenterCrop (10 ),
168+ tvt .Pad (11 )
169+ ])
59170
171+ def test_torchvision_random_affine ():
60172 run_test (100 , [
61173 SimpleRGBImageDecoder (),
62- Cutout (8 ),
63- ToTensor ()
64- ], False )
174+ ToTensor (),
175+ ToTorchImage (),
176+ tvt .RandomAffine (25 ),
177+ ])
65178
179+ def test_torchvision_random_crop ():
180+ run_test (100 , [
181+ SimpleRGBImageDecoder (),
182+ ToTensor (),
183+ ToTorchImage (),
184+ tvt .Pad (10 ),
185+ tvt .RandomCrop (size = 32 ),
186+ ])
66187
67- def test_poison ():
68- mask = np .zeros ((32 , 32 , 3 ))
69- # Red sqaure
70- mask [:5 , :5 , 0 ] = 1
71- alpha = np .ones ((32 , 32 ))
188+ def test_torchvision_color_jitter ():
72189 run_test (100 , [
73190 SimpleRGBImageDecoder (),
74- Poison (mask , alpha , [0 , 1 , 2 ]),
75- ToTensor ()
76- ], False )
191+ ToTensor (),
192+ ToTorchImage (),
193+ tvt .ColorJitter (.5 , .5 , .5 , .5 ),
194+ ])
195+
77196
78197if __name__ == '__main__' :
198+ # test_cutout()
79199 test_flip ()
200+ # test_module_wrapper()
201+ # test_mixup()
202+ # test_poison()
203+ # test_random_resized_crop()
204+ # test_translate()
205+
206+ ## Torchvision Transforms
207+ # test_torchvision_greyscale()
208+ # test_torchvision_centercrop_pad()
209+ # test_torchvision_random_affine()
210+ # test_torchvision_random_crop()
211+ # test_torchvision_color_jitter()
0 commit comments