Skip to content

Commit 3ea3bac

Browse files
committed
different hash and man
1 parent 1ca1e12 commit 3ea3bac

File tree

2 files changed

+18
-55
lines changed

2 files changed

+18
-55
lines changed

R/TaskClassif_cifar.R

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -149,8 +149,8 @@ load_task_cifar10 = function(id = "cifar10") {
149149

150150
task$col_roles$feature = "image"
151151

152-
# TODO: different hash, same manual
153-
backend$hash = task$man = "mlr3torch::mlr_tasks_cifar"
152+
backend$hash = "mlr3torch::mlr_tasks_cifar10"
153+
task$man = "mlr3torch::mlr_tasks_cifar"
154154

155155
task$filter(1:50000)
156156

@@ -253,7 +253,8 @@ load_task_cifar100 = function(id = "cifar100") {
253253

254254
task$col_roles$feature = "image"
255255

256-
backend$hash = task$man = "mlr3torch::mlr_tasks_cifar"
256+
backend$hash = "mlr3torch::mlr_tasks_cifar100"
257+
task$man = "mlr3torch::mlr_tasks_cifar"
257258

258259
task$filter(1:50000)
259260

tests/testthat/test_TaskClassif_cifar.R

Lines changed: 14 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -19,59 +19,21 @@ test_that("CIFAR-10 works", {
1919
expect_equal(task$backend$ncol, 6)
2020
})
2121

22-
# TODO: delete
23-
test_that("CIFAR-10 data matches the torchvision implementation", {
22+
test_that("CIFAR-100 works", {
2423
withr::local_options(mlr3torch.cache = TRUE)
25-
task = tsk("cifar10")
26-
task$data()
27-
28-
# check the responses
29-
train_idx = 1:50000
30-
int_mlr3torch_responses = as.integer(task$data()$class[train_idx])
31-
32-
cifar10_ds_train = cifar10_dataset(root = file.path(get_cache_dir(), "datasets", "cifar10", "raw"), train = TRUE,
33-
download = FALSE)
34-
35-
# TODO: determine whether a separate function is truly necessary
36-
# proably not, if getitem() allows a vector
37-
get_response = function(idx, ds) {
38-
ds$.getitem(idx)$y
39-
}
40-
cifar10_ds_responses = map_int(train_idx, get_response, ds = cifar10_ds_train)
41-
42-
expect_true(all.equal(int_mlr3torch_responses, cifar10_ds_responses))
43-
44-
# check a subset of train images
45-
small_train_idx = c(1, 2, 27, 9999,
46-
10000, 10001, 10901, 19999,
47-
20000, 20001, 29999,
48-
30000, 30001, 39999,
49-
40000, 40001, 49999,
50-
50000
51-
)
52-
task_small = task$clone()
53-
task_small$filter(small_train_idx)
54-
55-
test_same_at_idx = function(idx, lt_list, imgs_arr) {
56-
all.equal(as.array(lt_list[[idx]]), imgs_arr[idx, , , ])
57-
}
58-
59-
lt_list = materialize(task_small$data()$image)
60-
imgs_arr = cifar10_ds_train$.getitem(small_train_idx)$x
61-
62-
expect_true(all(map_lgl(1:length(small_train_idx), test_same_at_idx, lt_list = lt_list, imgs_arr = imgs_arr)))
63-
64-
# check a subset of test images
65-
test_idx = c(1, 2, 27, 8484, 9999, 10000)
66-
67-
test_dt_from_task = task$backend$data(rows = 50001:60000, cols = task$backend$colnames)
68-
expect_true(all(test_dt_from_task$split == "test"))
24+
task = tsk("cifar100")
6925

70-
lt_list_test = materialize(test_dt_from_task[test_idx, ]$image)
71-
72-
cifar10_ds_test = cifar10_dataset(root = file.path(get_cache_dir(), "datasets", "cifar10", "raw"), train = FALSE,
73-
download = FALSE)
74-
imgs_arr_test = cifar10_ds_test$.getitem(test_idx)$x
26+
expect_equal(task$nrow, 50000)
7527

76-
expect_true(all(map_lgl(1:length(test_idx), test_same_at_idx, lt_list = lt_list, imgs_arr = imgs_arr)))
28+
task$filter(1:10)
29+
expect_equal(task$id, "cifar100")
30+
expect_equal(task$label, "CIFAR-100 Classification")
31+
expect_equal(task$feature_names, "image")
32+
expect_equal(task$target_names, "class")
33+
expect_equal(task$man, "mlr3torch::mlr_tasks_cifar10")
34+
task$data()
35+
expect_true("cifar-100-binary" %in% list.files(file.path(get_cache_dir(), "datasets", "cifar100", "raw")))
36+
expect_true("data.rds" %in% list.files(file.path(get_cache_dir(), "datasets", "cifar100")))
37+
expect_equal(task$backend$nrow, 60000)
38+
expect_equal(task$backend$ncol, 6)
7739
})

0 commit comments

Comments
 (0)