Skip to content

Commit 61dc13b

Browse files
committed
...
1 parent 06595c7 commit 61dc13b

File tree

5 files changed

+181
-26
lines changed

5 files changed

+181
-26
lines changed

R/DataBackendLazyTensors.R

Lines changed: 81 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ DataBackendLazyTensors = R6Class("DataBackendLazyTensors",
3636
cloneable = FALSE,
3737
inherit = DataBackendDataTable,
3838
public = list(
39+
chunk_size = NULL,
3940
#' @description
4041
#' Create a new instance of this [R6][R6::R6Class] class.
4142
#' @param data (`data.table`)\cr
@@ -48,10 +49,12 @@ DataBackendLazyTensors = R6Class("DataBackendLazyTensors",
4849
#' @param cache (`character()`)\cr
4950
#' Names of the columns that should be cached.
5051
#' Per default, all columns that are converted are cached.
51-
initialize = function(data, primary_key, converter, cache = names(converter)) {
52+
initialize = function(data, primary_key, converter, cache = names(converter), chunk_size = 100) {
5253
private$.converter = assert_list(converter, types = "function", any.missing = FALSE)
5354
assert_subset(names(converter), colnames(data))
55+
assert_subset(cache, names(converter), empty.ok = TRUE)
5456
private$.cached_cols = assert_subset(cache, names(converter))
57+
self$chunk_size = assert_int(chunk_size, lower = 1L)
5558
walk(names(private$.converter), function(nm) {
5659
if (!inherits(data[[nm]], "lazy_tensor")) {
5760
stopf("Column '%s' is not a lazy tensor.", nm)
@@ -69,18 +72,25 @@ DataBackendLazyTensors = R6Class("DataBackendLazyTensors",
6972
# no caching, no materialization as this is called in the training loop
7073
return(super$data(rows, cols))
7174
}
72-
if (all(cols %in% names(private$.data_cache))) {
73-
cache_hit = private$.data_cache[list(rows), cols, on = self$primary_key, with = FALSE]
75+
if (all(intersect(cols, private$.cached_cols) %in% names(private$.data_cache))) {
76+
expensive_cols = intersect(cols, private$.cached_cols)
77+
other_cols = setdiff(cols, expensive_cols)
78+
cache_hit = private$.data_cache[list(rows), expensive_cols, on = self$primary_key, with = FALSE]
7479
complete = complete.cases(cache_hit)
7580
cache_hit = cache_hit[complete]
7681
if (nrow(cache_hit) == length(rows)) {
77-
return(cache_hit)
82+
tbl = cbind(cache_hit, super$data(rows, other_cols))
83+
setcolorder(tbl, cols)
84+
return(tbl)
7885
}
79-
combined = rbindlist(list(cache_hit, private$.load_and_cache(rows[!complete], cols)))
86+
combined = rbindlist(list(cache_hit, private$.load_and_cache(rows[!complete], expensive_cols)))
8087
reorder = vector("integer", nrow(combined))
8188
reorder[complete] = seq_len(nrow(cache_hit))
8289
reorder[!complete] = nrow(cache_hit) + seq_len(nrow(combined) - nrow(cache_hit))
83-
return(combined[reorder])
90+
91+
tbl = cbind(combined[reorder], super$data(rows, other_cols))
92+
setcolorder(tbl, cols)
93+
return(tbl)
8494
}
8595

8696
private$.load_and_cache(rows, cols)
@@ -109,7 +119,17 @@ DataBackendLazyTensors = R6Class("DataBackendLazyTensors",
109119
tbl = super$data(rows, cols)
110120
cols_to_convert = intersect(names(private$.converter), names(tbl))
111121
tbl_to_mat = tbl[, cols_to_convert, with = FALSE]
112-
tbl_mat = materialize(tbl_to_mat, rbind = TRUE)
122+
# chunk the rows of tbl_to_mat into chunks of size self$chunk_size, apply materialize
123+
n = nrow(tbl_to_mat)
124+
chunks = split(seq_len(n), rep(seq_len(ceiling(n / self$chunk_size)), each = self$chunk_size, length.out = n))
125+
126+
tbl_mat = if (n == 0) {
127+
set_names(list(torch_empty(0)), names(tbl_to_mat))
128+
} else {
129+
set_names(lapply(transpose_list(lapply(chunks, function(chunk) {
130+
materialize(tbl_to_mat[chunk, ], rbind = TRUE)
131+
})), torch_cat, dim = 1L), names(tbl_to_mat))
132+
}
113133

114134
for (nm in cols_to_convert) {
115135
converted = private$.converter[[nm]](tbl_mat[[nm]])
@@ -135,13 +155,62 @@ as_data_backend.dataset = function(x, dataset_shapes, ...) {
135155
}
136156

137157
#' @export
138-
as_task_classif.dataset = function(x, dataset_shapes, target, ...) {
139-
# TODO
158+
as_task_classif.dataset = function(x, target, levels, converter = NULL, dataset_shapes = NULL, chunk_size = 100, cache = names(converter), ...) {
159+
if (length(x) < 2) {
160+
stopf("Dataset must have at least 2 rows.")
161+
}
162+
batch = dataloader(x, batch_size = 2)$.iter()$.next()
163+
if (is.null(converter)) {
164+
if (length(levels) == 2) {
165+
if (batch[[target]]$dtype != torch_float()) {
166+
stopf("Target must be a float tensor, but has dtype %s", batch[[target]]$dtype)
167+
}
168+
if (test_equal(batch[[target]]$shape, c(2L, 1L))) {
169+
converter = set_names(list(crate(function(x) factor(as.integer(x), levels = 0:1, labels = levels), levels)), target)
170+
} else {
171+
stopf("Target must be a float tensor of shape (batch_size, 1), but has shape (batch_size, %s)",
172+
paste(batch[[target]]$shape[-1L], collapse = ", "))
173+
}
174+
converter = set_names(list(crate(function(x) factor(as.integer(x), levels = 0:1, labels = levels), levels)), target)
175+
} else {
176+
if (batch[[target]]$dtype != torch_int()) {
177+
stopf("Target must be an integer tensor, but has dtype %s", batch[[target]]$dtype)
178+
}
179+
if (test_equal(batch[[target]]$shape, 2L)) {
180+
converter = set_names(list(crate(function(x) factor(as.integer(x), labels = levels), levels)), target)
181+
} else {
182+
stopf("Target must be an integer tensor of shape (batch_size), but has shape (batch_size, %s)",
183+
paste(batch[[target]]$shape[-1L], collapse = ", "))
184+
}
185+
converter = set_names(list(crate(function(x) factor(as.integer(x), labels = levels), levels)), target)
186+
}
187+
}
188+
be = as_data_backend(x, dataset_shapes, converter = converter, cache = cache, chunk_size = chunk_size)
189+
as_task_classif(be, target = target, ...)
140190
}
141191

142192
#' @export
143-
as_task_regr.dataset = function(x, dataset_shapes, target, converter, ...) {
144-
# TODO
193+
as_task_regr.dataset = function(x, target, converter = NULL, dataset_shapes = NULL, chunk_size = 100, cache = names(converter), ...) {
194+
if (length(x) < 2) {
195+
stopf("Dataset must have at least 2 rows.")
196+
}
197+
if (is.null(converter)) {
198+
converter = set_names(list(as.numeric), target)
199+
}
200+
batch = dataloader(x, batch_size = 2)$.iter()$.next()
201+
202+
if (batch[[target]]$dtype != torch_float()) {
203+
stopf("Target must be a float tensor, but has dtype %s", batch[[target]]$dtype)
204+
}
205+
206+
if (!test_equal(batch[[target]]$shape, c(2L, 1L))) {
207+
stopf("Target must be a float tensor of shape (batch_size, 1), but has shape (batch_size, %s)",
208+
paste(batch[[target]]$shape[-1L], collapse = ", "))
209+
}
210+
211+
dataset_shapes = get_or_check_dataset_shapes(x, dataset_shapes)
212+
be = as_data_backend(x, dataset_shapes, converter = converter, cache = cache, chunk_size = chunk_size)
213+
as_task_regr(be, target = target, ...)
145214
}
146215

147216
#' @export
@@ -177,4 +246,4 @@ check_lazy_tensors_backend = function(be, candidates, visited = character()) {
177246
}
178247
union(visited, intersect(candidates, be$colnames))
179248
}
180-
}
249+
}

R/materialize.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ materialize.lazy_tensor = function(x, device = "cpu", rbind = FALSE, ...) { # no
106106
materialize_internal(x = x, device = device, cache = NULL, rbind = rbind)
107107
}
108108

109-
get_input = function(ds, ids, varying_shapes, rbind) {
109+
get_input = function(ds, ids, varying_shapes) {
110110
if (is.null(ds$.getbatch)) { # .getindex is never NULL but a function that errs if it was not defined
111111
x = map(ids, function(id) map(ds$.getitem(id), function(x) x$unsqueeze(1)))
112112
if (varying_shapes) {
@@ -201,7 +201,7 @@ materialize_internal = function(x, device = "cpu", cache = NULL, rbind) {
201201
}
202202

203203
if (!do_caching || !input_hit) {
204-
input = get_input(ds, ids, varying_shapes, rbind)
204+
input = get_input(ds, ids, varying_shapes)
205205
}
206206

207207
if (do_caching && !input_hit) {

TODO.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,4 +21,5 @@
2121
```
2222
* Add checks on usage of `DataBackendLazyTensors` in `task_dataset`
2323
* Add optimization that truth values don't have to be loaded twice during resampling, i.e.
24-
once for making the predictions and once for retrieving the truth column.
24+
once for making the predictions and once for retrieving the truth column.
25+
* only allow caching converter columns in `DataBackendLazyTensors` (probably just remove the `cache` parameter)

man/DataBackendLazyTensors.Rd

Lines changed: 2 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test_DataBackendLazyTensors.R

Lines changed: 94 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,3 @@
1-
test_that("correct input checks", {
2-
3-
})
4-
51
test_that("main API works", {
62
# regression target
73
ds = tensor_dataset(
@@ -102,11 +98,71 @@ test_that("classif target works", {
10298
})
10399

104100
test_that("errors when weird preprocessing", {
105-
# test following example pipeops:
106-
# - target trafo
107-
# - fix factors
108-
# - smote
101+
})
102+
103+
test_that("chunking works ", {
104+
ds = dataset(
105+
initialize = function() {
106+
self$x = torch_tensor(matrix(100:1, nrow = 100, ncol = 1))
107+
self$y = torch_tensor(as.matrix(1:100, nrow = 100, ncol = 1))
108+
self$counter = 0
109+
},
110+
.getbatch = function(i) {
111+
self$counter = self$counter + 1
112+
list(x = self$x[i, drop = FALSE], y = self$y[i, drop = FALSE])
113+
},
114+
.length = function() {
115+
nrow(self$x)
116+
}
117+
)()
109118

119+
be = as_data_backend(ds, dataset_shapes = list(x = c(NA, 1), y = c(NA, 1)), chunk_size = 3,
120+
converter = list(y = as.numeric))
121+
122+
counter_prev = ds$counter
123+
be$data(1:3, c("x", "y"))
124+
expect_equal(ds$counter, counter_prev + 1)
125+
counter_prev = ds$counter
126+
be$data(4:10, c("x", "y"))
127+
expect_equal(ds$counter, counter_prev + 3)
128+
})
129+
130+
test_that("can retrieve 0 rows", {
131+
ds = tensor_dataset(
132+
x = torch_tensor(matrix(100:1, nrow = 100, ncol = 1)),
133+
y = torch_tensor(as.matrix(1:100, nrow = 100, ncol = 1))
134+
)
135+
be = as_data_backend(ds, dataset_shapes = list(x = c(NA, 1), y = c(NA, 1)),
136+
converter = list(y = as.numeric))
137+
res = be$data(integer(0), c("x", "y", "row_id"))
138+
expect_data_table(res, nrows = 0, ncols = 3)
139+
expect_class(res$x, "lazy_tensor")
140+
expect_class(res$y, "numeric")
141+
expect_equal(res$row_id, integer(0))
142+
})
143+
144+
test_that("task converters work", {
145+
# regression target
146+
ds = tensor_dataset(
147+
x = torch_tensor(matrix(100:1, nrow = 100, ncol = 1))$float(),
148+
y = torch_tensor(as.matrix(1:100, nrow = 100, ncol = 1))$float()
149+
)
150+
task = as_task_regr(ds, target = "y", converter = list(y = as.numeric))
151+
task$data(integer(0))
152+
expect_equal(task$head(2)$y, 1:2)
153+
expect_equal(task$feature_names, "x")
154+
expect_equal(task$target_names, "y")
155+
expect_task(task)
156+
157+
158+
# binary classification
159+
ds = tensor_dataset(
160+
x = torch_tensor(matrix(100:1, nrow = 100, ncol = 1))$float(),
161+
y = torch_tensor(rep(0:1, times = 50))$float()$unsqueeze(2L)
162+
)
163+
task = as_task_classif(ds, target = "y", levels = c("yes", "no"))
164+
expect_task(task)
165+
expect_equal(task$head()$y, factor(rep(c("yes", "no"), times = 3), levels = c("yes", "no")))
110166
})
111167

112168
test_that("caching works", {
@@ -147,8 +203,8 @@ test_that("caching works", {
147203
# y is now in the cache, so .getitem() is not called on $data()
148204
check(be, ds, 1, "y", 0)
149205

150-
# but x is not cached, so we still need to call .getitem below
151-
check(be, ds, 1, c("x", "y"), 1)
206+
# everything is in the cache
207+
check(be, ds, 1, c("x", "y"), 0)
152208
# lazy tensor causes no materialization
153209
check(be, ds, 1, "x", 0)
154210

@@ -247,3 +303,31 @@ test_that("check_lazy_tensors_backend works", {
247303
expect_error(check_lazy_tensors_backend(task2$backend, c("x", "y")),
248304
regexp = "A converter column ('y')", fixed = TRUE)
249305
})
306+
307+
308+
test_that("...", {
309+
ds = dataset(
310+
initialize = function(x, y) {
311+
self$x = torch_randn(100, 3)
312+
self$y = torch_randn(100, 1)
313+
self$counter = 0
314+
},
315+
.getbatch = function(i) {
316+
print("hallo")
317+
self$counter = self$counter + 1L
318+
list(x = self$x[i, drop = FALSE], y = self$y[i, drop = FALSE])
319+
},
320+
.length = function() 100
321+
)()
322+
323+
task = as_task_regr(ds, target = "y")
324+
325+
counter = ds$counter
326+
task$head()
327+
print(ds$counter - counter)
328+
counter = ds$counter
329+
task$head()
330+
expect_equal(ds$counter, counter)
331+
print(ds$counter - counter)
332+
333+
})

0 commit comments

Comments
 (0)