@@ -36,6 +36,7 @@ DataBackendLazyTensors = R6Class("DataBackendLazyTensors",
3636 cloneable = FALSE ,
3737 inherit = DataBackendDataTable ,
3838 public = list (
39+ chunk_size = NULL ,
3940 # ' @description
4041 # ' Create a new instance of this [R6][R6::R6Class] class.
4142 # ' @param data (`data.table`)\cr
@@ -48,10 +49,12 @@ DataBackendLazyTensors = R6Class("DataBackendLazyTensors",
4849 # ' @param cache (`character()`)\cr
4950 # ' Names of the columns that should be cached.
5051 # ' Per default, all columns that are converted are cached.
51- initialize = function (data , primary_key , converter , cache = names(converter )) {
52+ initialize = function (data , primary_key , converter , cache = names(converter ), chunk_size = 100 ) {
5253 private $ .converter = assert_list(converter , types = " function" , any.missing = FALSE )
5354 assert_subset(names(converter ), colnames(data ))
55+ assert_subset(cache , names(converter ), empty.ok = TRUE )
5456 private $ .cached_cols = assert_subset(cache , names(converter ))
57+ self $ chunk_size = assert_int(chunk_size , lower = 1L )
5558 walk(names(private $ .converter ), function (nm ) {
5659 if (! inherits(data [[nm ]], " lazy_tensor" )) {
5760 stopf(" Column '%s' is not a lazy tensor." , nm )
@@ -69,18 +72,25 @@ DataBackendLazyTensors = R6Class("DataBackendLazyTensors",
6972 # no caching, no materialization as this is called in the training loop
7073 return (super $ data(rows , cols ))
7174 }
72- if (all(cols %in% names(private $ .data_cache ))) {
73- cache_hit = private $ .data_cache [list (rows ), cols , on = self $ primary_key , with = FALSE ]
75+ if (all(intersect(cols , private $ .cached_cols ) %in% names(private $ .data_cache ))) {
76+ expensive_cols = intersect(cols , private $ .cached_cols )
77+ other_cols = setdiff(cols , expensive_cols )
78+ cache_hit = private $ .data_cache [list (rows ), expensive_cols , on = self $ primary_key , with = FALSE ]
7479 complete = complete.cases(cache_hit )
7580 cache_hit = cache_hit [complete ]
7681 if (nrow(cache_hit ) == length(rows )) {
77- return (cache_hit )
82+ tbl = cbind(cache_hit , super $ data(rows , other_cols ))
83+ setcolorder(tbl , cols )
84+ return (tbl )
7885 }
79- combined = rbindlist(list (cache_hit , private $ .load_and_cache(rows [! complete ], cols )))
86+ combined = rbindlist(list (cache_hit , private $ .load_and_cache(rows [! complete ], expensive_cols )))
8087 reorder = vector(" integer" , nrow(combined ))
8188 reorder [complete ] = seq_len(nrow(cache_hit ))
8289 reorder [! complete ] = nrow(cache_hit ) + seq_len(nrow(combined ) - nrow(cache_hit ))
83- return (combined [reorder ])
90+
91+ tbl = cbind(combined [reorder ], super $ data(rows , other_cols ))
92+ setcolorder(tbl , cols )
93+ return (tbl )
8494 }
8595
8696 private $ .load_and_cache(rows , cols )
@@ -109,7 +119,17 @@ DataBackendLazyTensors = R6Class("DataBackendLazyTensors",
109119 tbl = super $ data(rows , cols )
110120 cols_to_convert = intersect(names(private $ .converter ), names(tbl ))
111121 tbl_to_mat = tbl [, cols_to_convert , with = FALSE ]
112- tbl_mat = materialize(tbl_to_mat , rbind = TRUE )
122+ # chunk the rows of tbl_to_mat into chunks of size self$chunk_size, apply materialize
123+ n = nrow(tbl_to_mat )
124+ chunks = split(seq_len(n ), rep(seq_len(ceiling(n / self $ chunk_size )), each = self $ chunk_size , length.out = n ))
125+
126+ tbl_mat = if (n == 0 ) {
127+ set_names(list (torch_empty(0 )), names(tbl_to_mat ))
128+ } else {
129+ set_names(lapply(transpose_list(lapply(chunks , function (chunk ) {
130+ materialize(tbl_to_mat [chunk , ], rbind = TRUE )
131+ })), torch_cat , dim = 1L ), names(tbl_to_mat ))
132+ }
113133
114134 for (nm in cols_to_convert ) {
115135 converted = private $ .converter [[nm ]](tbl_mat [[nm ]])
@@ -135,13 +155,62 @@ as_data_backend.dataset = function(x, dataset_shapes, ...) {
135155}
136156
137157# ' @export
#' @export
#' @description
#' Convert a [torch::dataset] into a classification task.
#'
#' @param x (`dataset`)\cr
#'   The torch dataset. Must have at least 2 rows so a probe batch can be drawn.
#' @param target (`character(1)`)\cr
#'   Name of the target column in the dataset's batches.
#' @param levels (`character()`)\cr
#'   Class labels. Length 2 implies a binary target encoded as a float tensor of
#'   shape `(batch_size, 1)` with values 0/1; more levels imply an integer tensor
#'   of shape `(batch_size)` with values `1, ..., length(levels)`.
#' @param converter (`list()` of `function`s | `NULL`)\cr
#'   Named list of converter functions. If `NULL`, a default converter for the
#'   target is constructed (and the target's dtype/shape are validated).
#' @param dataset_shapes (`list()` | `NULL`)\cr
#'   Shapes of the dataset's tensors; inferred / checked via
#'   `get_or_check_dataset_shapes()`.
#' @param chunk_size (`integer(1)`)\cr
#'   Number of rows materialized per chunk in the backend.
#' @param cache (`character()`)\cr
#'   Names of converted columns to cache; defaults to all converted columns.
#' @param ... Passed on to [mlr3::as_task_classif].
as_task_classif.dataset = function(x, target, levels, converter = NULL, dataset_shapes = NULL, chunk_size = 100, cache = names(converter), ...) {
  if (length(x) < 2) {
    stopf("Dataset must have at least 2 rows.")
  }
  if (is.null(converter)) {
    # Peek at one small batch only when we must build the default converter;
    # a user-supplied converter is trusted as-is (consistent with as_task_regr).
    batch = dataloader(x, batch_size = 2)$.iter()$.next()
    if (length(levels) == 2) {
      # Binary case: float tensor of shape (batch_size, 1), values 0/1.
      if (batch[[target]]$dtype != torch_float()) {
        stopf("Target must be a float tensor, but has dtype %s", batch[[target]]$dtype)
      }
      if (!test_equal(batch[[target]]$shape, c(2L, 1L))) {
        stopf("Target must be a float tensor of shape (batch_size, 1), but has shape (batch_size, %s)",
          paste(batch[[target]]$shape[-1L], collapse = ", "))
      }
      converter = set_names(list(crate(function(x) factor(as.integer(x), levels = 0:1, labels = levels), levels)), target)
    } else {
      # Multiclass case: integer tensor of shape (batch_size), values 1..k.
      if (batch[[target]]$dtype != torch_int()) {
        stopf("Target must be an integer tensor, but has dtype %s", batch[[target]]$dtype)
      }
      if (!test_equal(batch[[target]]$shape, 2L)) {
        stopf("Target must be an integer tensor of shape (batch_size), but has shape (batch_size, %s)",
          paste(batch[[target]]$shape[-1L], collapse = ", "))
      }
      # Pin levels = seq_along(levels): factor(x, labels = ...) alone errors
      # (or silently relabels) when not every class value is present in the data.
      converter = set_names(list(crate(function(x) {
        factor(as.integer(x), levels = seq_along(levels), labels = levels)
      }, levels)), target)
    }
  }
  # Infer or validate dataset shapes, mirroring as_task_regr.dataset.
  dataset_shapes = get_or_check_dataset_shapes(x, dataset_shapes)
  be = as_data_backend(x, dataset_shapes, converter = converter, cache = cache, chunk_size = chunk_size)
  as_task_classif(be, target = target, ...)
}
141191
142192# ' @export
143- as_task_regr.dataset = function (x , dataset_shapes , target , converter , ... ) {
144- # TODO
193+ as_task_regr.dataset = function (x , target , converter = NULL , dataset_shapes = NULL , chunk_size = 100 , cache = names(converter ), ... ) {
194+ if (length(x ) < 2 ) {
195+ stopf(" Dataset must have at least 2 rows." )
196+ }
197+ if (is.null(converter )) {
198+ converter = set_names(list (as.numeric ), target )
199+ }
200+ batch = dataloader(x , batch_size = 2 )$ .iter()$ .next()
201+
202+ if (batch [[target ]]$ dtype != torch_float()) {
203+ stopf(" Target must be a float tensor, but has dtype %s" , batch [[target ]]$ dtype )
204+ }
205+
206+ if (! test_equal(batch [[target ]]$ shape , c(2L , 1L ))) {
207+ stopf(" Target must be a float tensor of shape (batch_size, 1), but has shape (batch_size, %s)" ,
208+ paste(batch [[target ]]$ shape [- 1L ], collapse = " , " ))
209+ }
210+
211+ dataset_shapes = get_or_check_dataset_shapes(x , dataset_shapes )
212+ be = as_data_backend(x , dataset_shapes , converter = converter , cache = cache , chunk_size = chunk_size )
213+ as_task_regr(be , target = target , ... )
145214}
146215
147216# ' @export
@@ -177,4 +246,4 @@ check_lazy_tensors_backend = function(be, candidates, visited = character()) {
177246 }
178247 union(visited , intersect(candidates , be $ colnames ))
179248 }
180- }
249+ }
0 commit comments