27
27
//! OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
28
//! OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
29
30
- use crate :: common:: { Tensor } ;
30
+ use crate :: common:: Tensor ;
31
+ use crate :: dataset:: base:: { Dataset , ImageDatasetOps } ;
32
+ use crate :: get_workspace_dir;
31
33
use flate2:: read:: GzDecoder ;
32
34
use log:: debug;
35
+ use ndarray:: { IxDyn , Shape } ;
33
36
use std:: collections:: HashSet ;
34
37
use std:: fs;
35
38
use std:: fs:: File ;
36
39
use std:: future:: Future ;
37
40
use std:: io:: Read ;
38
41
use std:: path:: Path ;
39
42
use std:: pin:: Pin ;
40
- use ndarray:: { IxDyn , Shape } ;
41
43
use tar:: Archive ;
42
- use crate :: dataset:: base:: { Dataset , ImageDatasetOps } ;
43
- use crate :: get_workspace_dir;
44
44
45
45
/// A struct representing the CIFAR10 dataset.
46
46
pub struct Cifar10Dataset {
@@ -155,7 +155,11 @@ impl Cifar10Dataset {
155
155
156
156
for & file in files {
157
157
let ( img, lbl) = Self :: parse_file (
158
- & format ! ( "{}/.cache/dataset/cifar10/cifar-10-batches-bin/{}" , env!( "CARGO_MANIFEST_DIR" ) , file) ,
158
+ & format ! (
159
+ "{}/.cache/dataset/cifar10/{}" ,
160
+ get_workspace_dir( ) . display( ) ,
161
+ file
162
+ ) ,
159
163
total_examples / files. len ( ) ,
160
164
) ;
161
165
images. extend ( img) ;
@@ -172,7 +176,10 @@ impl Cifar10Dataset {
172
176
3 ,
173
177
] ) ) ,
174
178
) ,
175
- Tensor :: new ( labels, Shape :: from ( IxDyn ( & [ total_examples, Self :: CIFAR10_NUM_CLASSES ] ) ) ) ,
179
+ Tensor :: new (
180
+ labels,
181
+ Shape :: from ( IxDyn ( & [ total_examples, Self :: CIFAR10_NUM_CLASSES ] ) ) ,
182
+ ) ,
176
183
)
177
184
}
178
185
@@ -397,77 +404,99 @@ impl ImageDatasetOps for Cifar10Dataset {
397
404
Self {
398
405
train : self . train . clone ( ) ,
399
406
test : self . test . clone ( ) ,
400
- val : self . val . clone ( )
407
+ val : self . val . clone ( ) ,
401
408
}
402
409
}
403
410
}
404
411
405
- // #[cfg(test)]
406
- // mod tests {
407
- // use super::*;
408
- // use serial_test::serial;
409
- // use tokio::runtime::Runtime;
410
- //
411
- // fn setup() {
412
- // let workspace_dir = get_workspace_dir();
413
- // let cache_path = format!("{}/.cache/dataset/cifar10", workspace_dir.display());
414
- // if Path::new(&cache_path).exists() {
415
- // fs::remove_dir_all(&cache_path).expect("Failed to delete cache directory");
416
- // }
417
- // }
418
- //
419
- // #[test]
420
- // #[serial]
421
- // fn test_download_and_extract() {
422
- // setup();
423
- // let rt = Runtime::new().unwrap();
424
- // rt.block_on(async {
425
- // Cifar10Dataset::download_and_extract().await;
426
- // let workspace_dir = get_workspace_dir();
427
- // let cache_path = format!("{}/.cache/dataset/cifar10/cifar-10-binary", workspace_dir.display());
428
- // assert!(Path::new(&cache_path).exists(), "CIFAR-10 dataset should be downloaded and extracted");
429
- // });
430
- // }
431
- //
432
- // #[test]
433
- // #[serial]
434
- // fn test_parse_file() {
435
- // // Ensure the dataset is downloaded before parsing
436
- // test_download_and_extract();
437
- //
438
- // let (images, labels) = Cifar10Dataset::parse_file("path/to/data_batch_1.bin", 10000);
439
- // assert_eq!(images.len(), 10000 * 32 * 32 * 3, "Images should have the correct length");
440
- // assert_eq!(labels.len(), 10000 * 10, "Labels should have the correct length");
441
- // }
442
- //
443
- // #[test]
444
- // #[serial]
445
- // fn test_load_data() {
446
- // // Ensure the dataset is downloaded before loading data
447
- // test_download_and_extract();
448
- //
449
- // let dataset = Cifar10Dataset::load_data(&["data_batch_1.bin"], 10000);
450
- // assert_eq!(dataset.inputs.shape(), &[10000, 32, 32, 3], "Dataset inputs should have the correct shape");
451
- // assert_eq!(dataset.labels.shape(), &[10000, 10], "Dataset labels should have the correct shape");
452
- // }
453
- //
454
- // #[test]
455
- // #[serial]
456
- // fn test_load_train() {
457
- // let rt = Runtime::new().unwrap();
458
- // rt.block_on(async {
459
- // let dataset = Cifar10Dataset::load_train().await;
460
- // assert!(dataset.train.is_some(), "Training dataset should be loaded");
461
- // });
462
- // }
463
- //
464
- // #[test]
465
- // #[serial]
466
- // fn test_load_test() {
467
- // let rt = Runtime::new().unwrap();
468
- // rt.block_on(async {
469
- // let dataset = Cifar10Dataset::load_test().await;
470
- // assert!(dataset.test.is_some(), "Test dataset should be loaded");
471
- // });
472
- // }
473
- // }
412
+ #[ cfg( test) ]
413
+ mod tests {
414
+ use super :: * ;
415
+ use ndarray:: Dimension ;
416
+ use serial_test:: serial;
417
+
418
+ fn setup ( ) {
419
+ let workspace_dir = get_workspace_dir ( ) ;
420
+ let cache_path = format ! ( "{}/.cache/dataset/cifar10" , workspace_dir. display( ) ) ;
421
+ if Path :: new ( & cache_path) . exists ( ) {
422
+ fs:: remove_dir_all ( & cache_path) . expect ( "Failed to delete cache directory" ) ;
423
+ }
424
+ }
425
+
426
+ #[ tokio:: test]
427
+ #[ serial]
428
+ async fn test_download_and_extract ( ) {
429
+ setup ( ) ;
430
+ Cifar10Dataset :: download_and_extract ( ) . await ;
431
+ let workspace_dir = get_workspace_dir ( ) ;
432
+ let cache_path = format ! (
433
+ "{}/.cache/dataset/cifar10/data_batch_1.bin" ,
434
+ workspace_dir. display( )
435
+ ) ;
436
+ assert ! (
437
+ Path :: new( & cache_path) . exists( ) ,
438
+ "CIFAR-10 dataset should be downloaded and extracted"
439
+ ) ;
440
+ }
441
+
442
+ #[ test]
443
+ #[ serial]
444
+ fn test_parse_file ( ) {
445
+ // Ensure the dataset is downloaded before parsing
446
+ test_download_and_extract ( ) ;
447
+ let workspace_dir = get_workspace_dir ( ) ;
448
+ let cache_path = format ! (
449
+ "{}/.cache/dataset/cifar10/data_batch_1.bin" ,
450
+ workspace_dir. display( )
451
+ ) ;
452
+
453
+ let ( images, labels) = Cifar10Dataset :: parse_file ( & cache_path, 10000 ) ;
454
+ assert_eq ! (
455
+ images. len( ) ,
456
+ 10000 * 32 * 32 * 3 ,
457
+ "Images should have the correct length"
458
+ ) ;
459
+ assert_eq ! (
460
+ labels. len( ) ,
461
+ 10000 * 10 ,
462
+ "Labels should have the correct length"
463
+ ) ;
464
+ }
465
+
466
+ #[ test]
467
+ #[ serial]
468
+ fn test_load_data ( ) {
469
+ // Ensure the dataset is downloaded before loading data
470
+ test_download_and_extract ( ) ;
471
+
472
+ let dataset = Cifar10Dataset :: load_data ( & [ "data_batch_1.bin" ] , 10000 ) ;
473
+
474
+ // Compare the shape of inputs
475
+ assert_eq ! (
476
+ dataset. inputs. shape( ) . raw_dim( ) . as_array_view( ) . to_vec( ) ,
477
+ & [ 10000 , 32 , 32 , 3 ] ,
478
+ "Dataset inputs should have the correct shape"
479
+ ) ;
480
+
481
+ // Compare the shape of labels
482
+ assert_eq ! (
483
+ dataset. labels. shape( ) . raw_dim( ) . as_array_view( ) . to_vec( ) ,
484
+ & [ 10000 , 10 ] ,
485
+ "Dataset labels should have the correct shape"
486
+ ) ;
487
+ }
488
+
489
+ #[ tokio:: test]
490
+ #[ serial]
491
+ async fn test_load_train ( ) {
492
+ let dataset = Cifar10Dataset :: load_train ( ) . await ;
493
+ assert ! ( dataset. train. is_some( ) , "Training dataset should be loaded" ) ;
494
+ }
495
+
496
+ #[ tokio:: test]
497
+ #[ serial]
498
+ async fn test_load_test ( ) {
499
+ let dataset = Cifar10Dataset :: load_test ( ) . await ;
500
+ assert ! ( dataset. test. is_some( ) , "Test dataset should be loaded" ) ;
501
+ }
502
+ }
0 commit comments