@@ -19,59 +19,21 @@ test_that("CIFAR-10 works", {
1919 expect_equal(task $ backend $ ncol , 6 )
2020})
2121
22- # TODO: delete
23- test_that(" CIFAR-10 data matches the torchvision implementation" , {
22+ test_that(" CIFAR-100 works" , {
2423 withr :: local_options(mlr3torch.cache = TRUE )
25- task = tsk(" cifar10" )
26- task $ data()
27-
28- # check the responses
29- train_idx = 1 : 50000
30- int_mlr3torch_responses = as.integer(task $ data()$ class [train_idx ])
31-
32- cifar10_ds_train = cifar10_dataset(root = file.path(get_cache_dir(), " datasets" , " cifar10" , " raw" ), train = TRUE ,
33- download = FALSE )
34-
35- # TODO: determine whether a separate function is truly necessary
36- # proably not, if getitem() allows a vector
37- get_response = function (idx , ds ) {
38- ds $ .getitem(idx )$ y
39- }
40- cifar10_ds_responses = map_int(train_idx , get_response , ds = cifar10_ds_train )
41-
42- expect_true(all.equal(int_mlr3torch_responses , cifar10_ds_responses ))
43-
44- # check a subset of train images
45- small_train_idx = c(1 , 2 , 27 , 9999 ,
46- 10000 , 10001 , 10901 , 19999 ,
47- 20000 , 20001 , 29999 ,
48- 30000 , 30001 , 39999 ,
49- 40000 , 40001 , 49999 ,
50- 50000
51- )
52- task_small = task $ clone()
53- task_small $ filter(small_train_idx )
54-
55- test_same_at_idx = function (idx , lt_list , imgs_arr ) {
56- all.equal(as.array(lt_list [[idx ]]), imgs_arr [idx , , , ])
57- }
58-
59- lt_list = materialize(task_small $ data()$ image )
60- imgs_arr = cifar10_ds_train $ .getitem(small_train_idx )$ x
61-
62- expect_true(all(map_lgl(1 : length(small_train_idx ), test_same_at_idx , lt_list = lt_list , imgs_arr = imgs_arr )))
63-
64- # check a subset of test images
65- test_idx = c(1 , 2 , 27 , 8484 , 9999 , 10000 )
66-
67- test_dt_from_task = task $ backend $ data(rows = 50001 : 60000 , cols = task $ backend $ colnames )
68- expect_true(all(test_dt_from_task $ split == " test" ))
24+ task = tsk(" cifar100" )
6925
70- lt_list_test = materialize(test_dt_from_task [test_idx , ]$ image )
71-
72- cifar10_ds_test = cifar10_dataset(root = file.path(get_cache_dir(), " datasets" , " cifar10" , " raw" ), train = FALSE ,
73- download = FALSE )
74- imgs_arr_test = cifar10_ds_test $ .getitem(test_idx )$ x
26+ expect_equal(task $ nrow , 50000 )
7527
76- expect_true(all(map_lgl(1 : length(test_idx ), test_same_at_idx , lt_list = lt_list , imgs_arr = imgs_arr )))
28+ task $ filter(1 : 10 )
29+ expect_equal(task $ id , " cifar100" )
30+ expect_equal(task $ label , " CIFAR-100 Classification" )
31+ expect_equal(task $ feature_names , " image" )
32+ expect_equal(task $ target_names , " class" )
33+ expect_equal(task $ man , " mlr3torch::mlr_tasks_cifar10" )
34+ task $ data()
35+ expect_true(" cifar-100-binary" %in% list.files(file.path(get_cache_dir(), " datasets" , " cifar100" , " raw" )))
36+ expect_true(" data.rds" %in% list.files(file.path(get_cache_dir(), " datasets" , " cifar100" )))
37+ expect_equal(task $ backend $ nrow , 60000 )
38+ expect_equal(task $ backend $ ncol , 6 )
7739})
0 commit comments