Commit f70fc60
feat: Allow for more flexible shapes (#396)
Previously, only the batch dimension was allowed to be `NA`. This assumption is too restrictive, as it rules out transformer-based architectures where the sequence dimension is also unknown. This PR lifts the restriction: `NA`s can now appear at any position of the `shape`. However, many `nn()` operators still expect only the batch dimension to be unknown. Therefore, the `only_batch_unknown` construction argument was added to `PipeOpTorch`; it is `TRUE` by default and needs to be overwritten by operators that can handle more. E.g., `PipeOpTorchLinear` can handle `NA`s as long as they are not in the last dimension. Other operators that can handle them are, e.g., activation functions or, in the future, the multi-head attention module. This PR also improves the shape inference and fixes some other small bugs.

TODOs:

* [x] Adjust the `infer_shapes()` method: we now replace all `NA` dimensions with an arbitrary value, and do this twice to check whether the results are compatible.
* [x] Check all occurrences of `assert_shape()` and verify that we no longer make the assumption that only the batch dimension is `NA`.
* [x] Check that the `PipeOpTorch` object implementations err gracefully --> `PipeOpTorch` objects now need to indicate whether they can handle `NA`s that are not in the batch dimension.
* [x] Update the documentation on shapes.
1 parent 1392aba · commit f70fc60

28 files changed (+404 −109 lines)

NAMESPACE

+1
@@ -171,6 +171,7 @@ export(auto_device)
 export(batchgetter_categ)
 export(batchgetter_num)
 export(callback_set)
+export(infer_shapes)
 export(ingress_categ)
 export(ingress_ltnsr)
 export(ingress_num)

NEWS.md

+1
@@ -16,6 +16,7 @@
 * feat: TabResNet learner now supports lazy tensors.
 * feat: The `LearnerTorch` base class now supports the private method `$.ingress_tokens(task, param_vals)`
   for generating the `torch::dataset`.
+* feat: Shapes can now have multiple `NA`s and not only the batch dimension can be missing. However, most `nn()` operators still expect only one missing value and will throw an error if multiple dimensions are unknown.

 # mlr3torch 0.2.1

R/DataDescriptor.R

+20-6
@@ -247,12 +247,26 @@ assert_compatible_shapes = function(shapes, dataset) {
   }

   iwalk(shapes, function(dataset_shape, name) {
-    if (!is.null(dataset_shape) && !test_equal(shapes[[name]][-1], example[[name]]$shape[-1L])) {
-      expected_shape = example[[name]]$shape
-      expected_shape[1] = NA
-      stopf(paste0("First batch from dataset is incompatible with the provided shape of %s:\n",
-        "* Provided shape: %s.\n* Expected shape: %s."), name,
-        shape_to_str(unname(shapes[name])), shape_to_str(list(expected_shape)))
+    if (is.null(dataset_shape)) {
+      return(NULL)
+    }
+    shape_specified = shapes[[name]]
+    shape_example = example[[name]]$shape
+    if (length(shape_specified) != length(shape_example)) {
+      stopf("The specified number of dimensions for element '%s' is %s, but the dataset returned %s",
+        name, length(shape_specified), length(shape_example))
+    }
+
+    if (all(is.na(shape_specified))) {
+      # compatible with any shape
+      return(NULL)
+    }
+
+    shape_example[is.na(shape_specified)] = NA
+    if (!test_equal(shape_specified, shape_example)) {
+      stopf(paste0("First example batch from dataset is incompatible with the provided shape of %s:\n",
+        "* Observed shape: %s.\n* Specified shape: %s."), name,
+        shape_to_str(example[[name]]$shape), shape_to_str(shape_specified))
     }
   })
 }
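The comparison logic above can be distilled into a standalone sketch: a specified shape matches an observed example shape when both have the same length and every non-`NA` entry agrees. The helper name `shapes_compatible` is hypothetical, not part of mlr3torch:

```r
# Hypothetical helper mirroring the check in assert_compatible_shapes():
# NA entries in the specified shape act as wildcards.
shapes_compatible = function(shape_specified, shape_observed) {
  if (length(shape_specified) != length(shape_observed)) {
    return(FALSE) # number of dimensions must match exactly
  }
  if (all(is.na(shape_specified))) {
    return(TRUE) # a fully unknown shape is compatible with anything
  }
  # blank out observed entries wherever the specification is unknown
  shape_observed[is.na(shape_specified)] = NA
  isTRUE(all.equal(shape_specified, shape_observed))
}

shapes_compatible(c(NA, 3L, NA), c(16L, 3L, 128L)) # TRUE
shapes_compatible(c(NA, 4L), c(16L, 3L))           # FALSE
```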

R/PipeOpTaskPreprocTorch.R

+3-6
@@ -404,12 +404,9 @@ create_ps = function(fn) {
 #' @param shapes_out (`function` or `NULL` or `"infer"`)\cr
 #' The private `.shapes_out(shapes_in, param_vals, task)` method of [`PipeOpTaskPreprocTorch`]
 #' (see section Inheriting).
-#' Special values are `NULL` and `infer`:
+#' Special values are `NULL` and `"infer"`:
 #' If `NULL`, the output shapes are unknown.
-#' If "infer", the output shape function is inferred and calculates the output shapes as follows:
-#' For an input shape of (NA, ...) a meta-tensor of shape (1, ...) is created and the preprocessing function is
-#' applied. Afterwards the batch dimension (1) is replaced with NA and the shape is returned.
-#' If the first dimension is not `NA`, the output shape of applying the preprocessing function is returned.
+#' Option `"infer"` uses [`infer_shapes`].
 #' Method `"infer"` should be correct in most cases, but might fail in some edge cases.
 #' @param param_set ([`ParamSet`][paradox::ParamSet] or `NULL`)\cr
 #' The parameter set.
@@ -452,7 +449,7 @@ pipeop_preproc_torch = function(id, fn, shapes_out = NULL, param_set = NULL, pac
   # we e.g. want torchvision in suggests, so we cannot already access the function.
   if (identical(shapes_out, "infer")) {
     shapes_out = crate(function(shapes_in, param_vals, task) {
-      getFromNamespace("infer_shapes", "mlr3torch")(shapes_in = shapes_in, param_vals = param_vals, output_names = self$output$name, fn = self$fn, rowwise = self$rowwise, id = self$id)
+      getFromNamespace("infer_shapes", "mlr3torch")(shapes_in = shapes_in, param_vals = param_vals, output_names = self$output$name, fn = self$fn, rowwise = self$rowwise, id = self$id) # nolint
     })
   } else if (is.function(shapes_out) || is.null(shapes_out)) {
     # nothing to do
R/PipeOpTorch.R

+13-5
@@ -40,7 +40,7 @@
 #' the private method `.shape_dependent_params()`.
 #' * `.shapes_out(shapes_in, param_vals, task)`\cr
 #' (`list()`, `list()`, [`Task`][mlr3::Task] or `NULL`) -> named `list()`\cr
-#' This private method gets a list of `numeric` vectors (`shapes_in`), the parameter values (`param_vals`),
+#' This private method gets a list of `integer` vectors (`shapes_in`), the parameter values (`param_vals`),
 #' as well as an (optional) [`Task`][mlr3::Task].
 # The `shapes_in` list indicates the shape of input tensors that will be fed to the module's `$forward()` function.
 # The list has one item per input tensor, typically only one.
@@ -49,6 +49,10 @@
 #' The output shapes must be in the same order as the output names of the `PipeOp`.
 #' In case the output shapes depends on the task (as is the case for [`PipeOpTorchHead`]), the function should return
 #' valid output shapes (possibly containing `NA`s) if the `task` argument is provided or not.
+#' It is important to properly handle the presence of `NA`s in the input shapes.
+#' By default (if construction argument `only_batch_unknown` is `TRUE`), only the batch dimension can be `NA`.
+#' If you set this to `FALSE`, you need to take other unknown dimensions into account.
+#' The method can also throw an error if the input shapes violate some assumptions.
 #' * `.shape_dependent_params(shapes_in, param_vals, task)`\cr
 #' (`list()`, `list()`) -> named `list()`\cr
 #' This private method has the same inputs as `.shapes_out`.
@@ -252,14 +256,19 @@ PipeOpTorch = R6Class("PipeOpTorch",
     #' In case there is more than one output channel, the `nn_module` that is constructed by this
     #' [`PipeOp`][mlr3pipelines::PipeOp] during training must return a named `list()`, where the names of the list are the
     #' names out the output channels. The default is `"output"`.
+    #' @param only_batch_unknown (`logical(1)`)\cr
+    #'   Whether only the batch dimension can be missing in the input shapes or whether other
+    #'   dimensions can also be unknown.
+    #'   Default is `TRUE`.
     initialize = function(id, module_generator, param_set = ps(), param_vals = list(),
-      inname = "input", outname = "output", packages = "torch", tags = NULL) {
+      inname = "input", outname = "output", packages = "torch", tags = NULL, only_batch_unknown = TRUE) {
       self$module_generator = assert_class(module_generator, "nn_module_generator", null.ok = TRUE)
       assert_character(inname, .var.name = "input channel names")
       assert_character(outname, .var.name = "output channel names", min.len = 1L)
       assert_character(tags, null.ok = TRUE)
       assert_character(packages, any.missing = FALSE)

+      private$.only_batch_unknown = assert_flag(only_batch_unknown)
       packages = union(packages, c("mlr3torch", "torch"))
       input = data.table(name = inname, train = "ModelDescriptor", predict = "Task")
       output = data.table(name = outname, train = "ModelDescriptor", predict = "Task")
@@ -288,17 +297,16 @@ PipeOpTorch = R6Class("PipeOpTorch",
       assert_r6(task, "Task", null.ok = TRUE)
       if (is.numeric(shapes_in)) shapes_in = list(shapes_in)
       # batch dimension can be known or unknown
-      assert_shapes(shapes_in, unknown_batch = NULL)
+      assert_shapes(shapes_in, unknown_batch = NULL, only_batch_unknown = private$.only_batch_unknown)
       if ("..." %nin% self$input$name) {
         assert_true(length(shapes_in) == nrow(self$input),
           .var.name = "number of input shapes equal to number of input channels")
       }
       set_names(private$.shapes_out(shapes_in, self$param_set$get_values(), task = task), self$output$name)
     }
-
-    # TODO: printer that calls the nn_module's printer
   ),
   private = list(
+    .only_batch_unknown = TRUE,
    .shapes_out = function(shapes_in, param_vals, task) shapes_in,
    .shape_dependent_params = function(shapes_in, param_vals, task) param_vals,
    .make_module = function(shapes_in, param_vals, task) {
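To illustrate the `only_batch_unknown` contract, here is a self-contained approximation of what the shape assertion enforces. The name `check_shape` is invented here; the real validation lives in mlr3torch's `assert_shapes()` and is more thorough:

```r
# With only_batch_unknown = TRUE, an NA anywhere but position 1 is rejected.
check_shape = function(shape, only_batch_unknown = TRUE) {
  if (only_batch_unknown && anyNA(shape[-1])) {
    stop("only the batch dimension (position 1) may be NA")
  }
  invisible(TRUE)
}

check_shape(c(NA, 3L, 32L))                             # ok: only batch is NA
check_shape(c(NA, NA, 32L), only_batch_unknown = FALSE) # ok: NAs allowed anywhere
try(check_shape(c(NA, NA, 32L)))                        # error: sequence dim is NA
```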

R/PipeOpTorchActivation.R

+34-19
@@ -139,7 +139,8 @@ PipeOpTorchHardTanh = R6Class("PipeOpTorchHardTanh",
       param_set = param_set,
       param_vals = param_vals,
       module_generator = nn_hardtanh,
-      tags = "activation"
+      tags = "activation",
+      only_batch_unknown = FALSE
     )
   }
 )
@@ -179,7 +180,8 @@ PipeOpTorchLeakyReLU = R6Class("PipeOpTorchLeakyReLU",
       param_set = param_set,
       param_vals = param_vals,
       module_generator = nn_leaky_relu,
-      tags = "activation"
+      tags = "activation",
+      only_batch_unknown = FALSE
     )
   }
 )
@@ -199,7 +201,6 @@ register_po("nn_leaky_relu", PipeOpTorchLeakyReLU)
 #' @template pipeop_torch
 #' @template pipeop_torch_example
 #'
-#'
 #' @export
 PipeOpTorchLogSigmoid = R6Class("PipeOpTorchLogSigmoid",
   inherit = PipeOpTorch,
@@ -214,7 +215,8 @@ PipeOpTorchLogSigmoid = R6Class("PipeOpTorchLogSigmoid",
       param_set = param_set,
       param_vals = param_vals,
       module_generator = nn_log_sigmoid,
-      tags = "activation"
+      tags = "activation",
+      only_batch_unknown = FALSE
     )
   }
 )
@@ -254,7 +256,8 @@ PipeOpTorchPReLU = R6Class("PipeOpTorchPReLU",
       param_set = param_set,
       param_vals = param_vals,
       module_generator = nn_prelu,
-      tags = "activation"
+      tags = "activation",
+      only_batch_unknown = FALSE
     )
   }
 )
@@ -274,7 +277,6 @@ register_po("nn_prelu", PipeOpTorchPReLU)
 #' @template pipeop_torch
 #' @template pipeop_torch_example
 #'
-#'
 #' @export
 PipeOpTorchReLU = R6Class("PipeOpTorchReLU",
   inherit = PipeOpTorch,
@@ -291,7 +293,8 @@ PipeOpTorchReLU = R6Class("PipeOpTorchReLU",
       param_set = param_set,
       param_vals = param_vals,
       module_generator = nn_relu,
-      tags = "activation"
+      tags = "activation",
+      only_batch_unknown = FALSE
     )
   }
 )
@@ -327,7 +330,8 @@ PipeOpTorchReLU6 = R6Class("PipeOpTorchReLU6",
       param_set = param_set,
       param_vals = param_vals,
       module_generator = nn_relu6,
-      tags = "activation"
+      tags = "activation",
+      only_batch_unknown = FALSE
     )
   }
 )
@@ -370,7 +374,8 @@ PipeOpTorchRReLU = R6Class("PipeOpTorchRReLU",
       param_set = param_set,
       param_vals = param_vals,
       module_generator = nn_rrelu,
-      tags = "activation"
+      tags = "activation",
+      only_batch_unknown = FALSE
     )
   }
 )
@@ -407,7 +412,8 @@ PipeOpTorchSELU = R6Class("PipeOpTorchSELU",
       param_set = param_set,
       param_vals = param_vals,
       module_generator = nn_selu,
-      tags = "activation"
+      tags = "activation",
+      only_batch_unknown = FALSE
     )
   }
 )
@@ -447,7 +453,8 @@ PipeOpTorchCELU = R6Class("PipeOpTorchCELU",
       param_set = param_set,
       param_vals = param_vals,
       module_generator = nn_celu,
-      tags = "activation"
+      tags = "activation",
+      only_batch_unknown = FALSE
    )
   }
 )
@@ -484,7 +491,8 @@ PipeOpTorchGELU = R6Class("PipeOpTorchGELU",
       param_set = param_set,
       param_vals = param_vals,
       module_generator = nn_gelu,
-      tags = "activation"
+      tags = "activation",
+      only_batch_unknown = FALSE
     )
   }
 )
@@ -518,7 +526,8 @@ PipeOpTorchSigmoid = R6Class("PipeOpTorchSigmoid",
       param_set = param_set,
       param_vals = param_vals,
       module_generator = nn_sigmoid,
-      tags = "activation"
+      tags = "activation",
+      only_batch_unknown = FALSE
     )
   }
 )
@@ -557,7 +566,8 @@ PipeOpTorchSoftPlus = R6Class("PipeOpTorchSoftPlus",
       param_set = param_set,
       param_vals = param_vals,
       module_generator = nn_softplus,
-      tags = "activation"
+      tags = "activation",
+      only_batch_unknown = FALSE
     )
   }
 )
@@ -594,7 +604,8 @@ PipeOpTorchSoftShrink = R6Class("PipeOpTorchSoftShrink",
       param_set = param_set,
       param_vals = param_vals,
       module_generator = nn_softshrink,
-      tags = "activation"
+      tags = "activation",
+      only_batch_unknown = FALSE
     )
   }
 )
@@ -627,7 +638,8 @@ PipeOpTorchSoftSign = R6Class("PipeOpTorchSoftSign",
       param_set = param_set,
       param_vals = param_vals,
       module_generator = nn_softsign,
-      tags = "activation"
+      tags = "activation",
+      only_batch_unknown = FALSE
     )
   }
 )
@@ -661,7 +673,8 @@ PipeOpTorchTanh = R6Class("PipeOpTorchTanh",
       param_set = param_set,
       param_vals = param_vals,
       module_generator = nn_tanh,
-      tags = "activation"
+      tags = "activation",
+      only_batch_unknown = FALSE
     )
   }
 )
@@ -695,7 +708,8 @@ PipeOpTorchTanhShrink = R6Class("PipeOpTorchTanhShrink",
       param_set = param_set,
       param_vals = param_vals,
       module_generator = nn_tanhshrink,
-      tags = "activation"
+      tags = "activation",
+      only_batch_unknown = FALSE
     )
   }
 )
@@ -739,7 +753,8 @@ PipeOpTorchThreshold = R6Class("PipeOpTorchThreshold",
       param_set = param_set,
       param_vals = param_vals,
       module_generator = nn_threshold,
-      tags = "activation"
+      tags = "activation",
+      only_batch_unknown = FALSE
     )
   }
 )

R/PipeOpTorchConv.R

-1
@@ -163,4 +163,3 @@ conv_output_shape = function(shape_in, conv_dim, padding, dilation, stride, kern
     (if (ceil_mode) base::ceiling else base::floor)((shape_tail + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1)
   )
 }
-
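The formula in `conv_output_shape` can be exercised on its own. This standalone sketch (the name `conv_out_dim` is invented here) computes a single spatial output dimension of a convolution or pooling layer:

```r
# Standard conv/pool output-size formula, matching the expression above.
conv_out_dim = function(size, kernel_size, stride = 1, padding = 0,
                        dilation = 1, ceil_mode = FALSE) {
  round_fn = if (ceil_mode) base::ceiling else base::floor
  round_fn((size + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1)
}

# 32-pixel input, 3x3 kernel, stride 2, padding 1 -> 16 pixels
conv_out_dim(32, kernel_size = 3, stride = 2, padding = 1) # 16
```

Note how `ceil_mode` only matters when the kernel does not tile the input evenly.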

R/PipeOpTorchIngress.R

+1-1
@@ -119,7 +119,7 @@ PipeOpTorchIngress = R6Class("PipeOpTorchIngress",
 #' the output of `Task$data(rows = batch_indices, cols = features)`
 #' and it should produce a tensor of shape `shape_out`.
 #' @param shape (`integer`)\cr
-#'   Shape that `batchgetter` will produce. Batch-dimension should be included as `NA`.
+#'   Shape that `batchgetter` will produce. The batch dimension must be included as `NA` (but other dimensions can also be `NA`, i.e., unknown).
 #' @return `TorchIngressToken` object.
 #' @family Graph Network
 #' @export

R/PipeOpTorchLinear.R

+7-2
@@ -32,13 +32,18 @@ PipeOpTorchLinear = R6Class("PipeOpTorchLinear",
       id = id,
       param_set = param_set,
       param_vals = param_vals,
-      module_generator = nn_linear
+      module_generator = nn_linear,
+      only_batch_unknown = FALSE
     )
   }
 ),
 private = list(
   .shape_dependent_params = function(shapes_in, param_vals, task) {
-    c(param_vals, list(in_features = tail(shapes_in[[1]], 1)))
+    d_in = tail(shapes_in[[1]], 1)
+    if (is.na(d_in)) {
+      stopf("PipeOpTorchLinear received an input shape where the last dimension is unknown. Please provide an input with a known last dimension.")
+    }
+    c(param_vals, list(in_features = d_in))
   },
   .shapes_out = function(shapes_in, param_vals, task) list(c(head(shapes_in[[1]], -1), param_vals$out_features))
 )
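The in-feature inference above only needs the last entry of the input shape, which is why the linear layer tolerates `NA`s everywhere else. A minimal standalone sketch of that rule (`infer_in_features` is a hypothetical name, not part of the package):

```r
# nn_linear needs a concrete in_features; any other dimension may stay NA.
infer_in_features = function(shape_in) {
  d_in = tail(shape_in, 1)
  if (is.na(d_in)) {
    stop("the last dimension must be known to construct nn_linear")
  }
  d_in
}

infer_in_features(c(NA, NA, 32L))      # 32: batch and sequence may be unknown
try(infer_in_features(c(NA, 32L, NA))) # error: last dimension unknown
```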

R/PipeOpTorchReshape.R

+1-1
@@ -5,7 +5,7 @@
 #' This internally calls [`torch::torch_reshape()`] with the given `shape`.
 #' @section Parameters:
 #' * `shape` :: `integer(1)`\cr
-#'   The desired output shape. Unknown dimension (one at most) can either be specified as `-1` or `NA`.
+#'   The desired output shape. The unknown dimension (at most one) must be specified as `-1`.
 #' @templateVar id nn_reshape
 #' @template pipeop_torch_channels_default
 #' @template pipeop_torch
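How the `-1` placeholder is resolved can be sketched without torch: it becomes the total element count divided by the product of the known output dimensions. `resolve_reshape` below is a hypothetical helper for illustration, not the actual implementation:

```r
# Resolve a single -1 placeholder in a reshape target, torch-style.
resolve_reshape = function(shape_in, shape_out) {
  stopifnot(sum(shape_out == -1) <= 1) # at most one unknown dimension
  known = prod(shape_out[shape_out != -1])
  shape_out[shape_out == -1] = prod(shape_in) / known
  shape_out
}

resolve_reshape(c(2L, 3L, 4L), c(2L, -1L)) # c(2, 12)
```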

R/lazy_tensor.R

+2-1
@@ -231,6 +231,7 @@ is_lazy_tensor = function(x) {
 #' Is not cloned, so should be cloned beforehand.
 #' @param shape (`integer()` or `NULL`)\cr
 #'   The shape of the lazy tensor.
+#'   `NA`s indicate dimensions where the shape is not known.
 #' @param shape_predict (`integer()` or `NULL`)\cr
 #'   The shape of the lazy tensor if it was applied during `$predict()`.
 #'
@@ -356,4 +357,4 @@ rep_len.lazy_tensor = function(x, ...) {
 #' lazy_shape(lt)
 lazy_shape = function(x) {
   dd(x)$pointer_shape
-}
\ No newline at end of file
+}
