feat: Allow for more flexible shapes #396

Merged · 13 commits · Apr 25, 2025
Changes from 10 commits
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ export(auto_device)
export(batchgetter_categ)
export(batchgetter_num)
export(callback_set)
export(infer_shapes)
export(ingress_categ)
export(ingress_ltnsr)
export(ingress_num)
Expand Down
1 change: 1 addition & 0 deletions NEWS.md
Expand Up @@ -16,6 +16,7 @@
* feat: TabResNet learner now supports lazy tensors.
* feat: The `LearnerTorch` base class now supports the private method `$.ingress_tokens(task, param_vals)`
for generating the `torch::dataset`.
* feat: Shapes can now contain multiple `NA`s, so dimensions other than the batch dimension can be unknown. However, most `nn()` operators still expect at most one unknown dimension and will throw an error if multiple dimensions are unknown.

# mlr3torch 0.2.1

Expand Down
26 changes: 20 additions & 6 deletions R/DataDescriptor.R
Expand Up @@ -247,12 +247,26 @@ assert_compatible_shapes = function(shapes, dataset) {
}

iwalk(shapes, function(dataset_shape, name) {
if (!is.null(dataset_shape) && !test_equal(shapes[[name]][-1], example[[name]]$shape[-1L])) {
expected_shape = example[[name]]$shape
expected_shape[1] = NA
stopf(paste0("First batch from dataset is incompatible with the provided shape of %s:\n",
"* Provided shape: %s.\n* Expected shape: %s."), name,
shape_to_str(unname(shapes[name])), shape_to_str(list(expected_shape)))
if (is.null(dataset_shape)) {
return(NULL)
}
shape_specified = shapes[[name]]
shape_example = example[[name]]$shape
if (length(shape_specified) != length(shape_example)) {
stopf("The specified number of dimensions for element '%s' is %s, but the dataset returned %s",
name, length(shape_specified), length(shape_example))
}

if (all(is.na(shape_specified))) {
# compatible with any shape
return(NULL)
}

shape_example[is.na(shape_specified)] = NA
if (!test_equal(shape_specified, shape_example)) {
stopf(paste0("First example batch from dataset is incompatible with the provided shape of %s:\n",
"* Observed shape: %s.\n* Specified shape: %s."), name,
shape_to_str(example[[name]]$shape), shape_to_str(shape_specified))
}
})
}
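The compatibility rule implemented above can be summarized in a standalone sketch (`shapes_compatible` is a hypothetical helper, not part of mlr3torch): two shapes are compatible if they have the same number of dimensions and every non-`NA` entry of the specified shape matches the observed shape.

```r
# Hypothetical helper mirroring the check in assert_compatible_shapes():
# NA entries in the specified shape act as wildcards.
shapes_compatible = function(specified, observed) {
  if (length(specified) != length(observed)) return(FALSE)
  if (all(is.na(specified))) return(TRUE)  # compatible with any shape
  known = !is.na(specified)
  all(specified[known] == observed[known])
}

shapes_compatible(c(NA, 3, NA), c(16, 3, 28))  # TRUE
shapes_compatible(c(NA, 4), c(16, 3))          # FALSE: known dimension differs
```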
9 changes: 3 additions & 6 deletions R/PipeOpTaskPreprocTorch.R
Expand Up @@ -404,12 +404,9 @@ create_ps = function(fn) {
#' @param shapes_out (`function` or `NULL` or `"infer"`)\cr
#' The private `.shapes_out(shapes_in, param_vals, task)` method of [`PipeOpTaskPreprocTorch`]
#' (see section Inheriting).
#' Special values are `NULL` and `infer`:
#' Special values are `NULL` and `"infer"`:
#' If `NULL`, the output shapes are unknown.
#' If "infer", the output shape function is inferred and calculates the output shapes as follows:
#' For an input shape of (NA, ...) a meta-tensor of shape (1, ...) is created and the preprocessing function is
#' applied. Afterwards the batch dimension (1) is replaced with NA and the shape is returned.
#' If the first dimension is not `NA`, the output shape of applying the preprocessing function is returned.
#' If `"infer"`, the output shapes are inferred via [`infer_shapes`].
#' This should be correct in most cases, but might fail in some edge cases.
#' @param param_set ([`ParamSet`][paradox::ParamSet] or `NULL`)\cr
#' The parameter set.
Expand Down Expand Up @@ -452,7 +449,7 @@ pipeop_preproc_torch = function(id, fn, shapes_out = NULL, param_set = NULL, pac
# we e.g. want torchvision in suggests, so we cannot already access the function.
if (identical(shapes_out, "infer")) {
shapes_out = crate(function(shapes_in, param_vals, task) {
getFromNamespace("infer_shapes", "mlr3torch")(shapes_in = shapes_in, param_vals = param_vals, output_names = self$output$name, fn = self$fn, rowwise = self$rowwise, id = self$id)
getFromNamespace("infer_shapes", "mlr3torch")(shapes_in = shapes_in, param_vals = param_vals, output_names = self$output$name, fn = self$fn, rowwise = self$rowwise, id = self$id) # nolint
})
} else if (is.function(shapes_out) || is.null(shapes_out)) {
# nothing to do
Expand Down
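As a rough illustration of what `"infer"` does (a simplified sketch; the real `infer_shapes()` additionally handles multiple `NA`s, rowwise functions, and error reporting): substitute `1` for unknown dimensions, apply the preprocessing function to a dummy array, and mark the batch dimension as unknown again.

```r
# Simplified, hypothetical sketch of shape inference (infer_shape_sketch
# is not part of mlr3torch):
infer_shape_sketch = function(shape_in, fn) {
  trial = shape_in
  trial[is.na(trial)] = 1L               # probe with size-1 stand-ins
  out = dim(fn(array(0, dim = trial)))   # apply fn to a dummy array
  out[1] = NA                            # the batch dimension stays unknown
  out
}

infer_shape_sketch(c(NA, 4), function(x) x * 2)  # c(NA, 4)
```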
17 changes: 11 additions & 6 deletions R/PipeOpTorch.R
Expand Up @@ -40,7 +40,7 @@
#' the private method `.shape_dependent_params()`.
#' * `.shapes_out(shapes_in, param_vals, task)`\cr
#' (`list()`, `list()`, [`Task`][mlr3::Task] or `NULL`) -> named `list()`\cr
#' This private method gets a list of `numeric` vectors (`shapes_in`), the parameter values (`param_vals`),
#' This private method gets a list of `integer` vectors (`shapes_in`), the parameter values (`param_vals`),
#' as well as an (optional) [`Task`][mlr3::Task].
# The `shapes_in` list indicates the shape of input tensors that will be fed to the module's `$forward()` function.
# The list has one item per input tensor, typically only one.
Expand All @@ -49,6 +49,7 @@
#' The output shapes must be in the same order as the output names of the `PipeOp`.
#' In case the output shapes depend on the task (as is the case for [`PipeOpTorchHead`]), the function should return
#' valid output shapes (possibly containing `NA`s) whether or not the `task` argument is provided.
#' It is important to properly handle the presence of `NA`s in the input shapes.
#' * `.shape_dependent_params(shapes_in, param_vals, task)`\cr
#' (`list()`, `list()`) -> named `list()`\cr
#' This private method has the same inputs as `.shapes_out`.
Expand Down Expand Up @@ -252,14 +253,19 @@ PipeOpTorch = R6Class("PipeOpTorch",
#' In case there is more than one output channel, the `nn_module` that is constructed by this
#' [`PipeOp`][mlr3pipelines::PipeOp] during training must return a named `list()`, where the names of the list are the
#' names of the output channels. The default is `"output"`.
#' @param only_batch_unknown (`logical(1)`)\cr
#' Whether only the batch dimension can be missing in the input shapes or whether other
#' dimensions can also be unknown.
#' Default is `TRUE`.
initialize = function(id, module_generator, param_set = ps(), param_vals = list(),
inname = "input", outname = "output", packages = "torch", tags = NULL) {
inname = "input", outname = "output", packages = "torch", tags = NULL, only_batch_unknown = TRUE) {
self$module_generator = assert_class(module_generator, "nn_module_generator", null.ok = TRUE)
assert_character(inname, .var.name = "input channel names")
assert_character(outname, .var.name = "output channel names", min.len = 1L)
assert_character(tags, null.ok = TRUE)
assert_character(packages, any.missing = FALSE)

private$.only_batch_unknown = assert_flag(only_batch_unknown)
packages = union(packages, c("mlr3torch", "torch"))
input = data.table(name = inname, train = "ModelDescriptor", predict = "Task")
output = data.table(name = outname, train = "ModelDescriptor", predict = "Task")
Expand Down Expand Up @@ -288,17 +294,16 @@ PipeOpTorch = R6Class("PipeOpTorch",
assert_r6(task, "Task", null.ok = TRUE)
if (is.numeric(shapes_in)) shapes_in = list(shapes_in)
# batch dimension can be known or unknown
assert_shapes(shapes_in, unknown_batch = NULL)
assert_shapes(shapes_in, unknown_batch = NULL, only_batch_unknown = private$.only_batch_unknown)
if ("..." %nin% self$input$name) {
assert_true(length(shapes_in) == nrow(self$input),
.var.name = "number of input shapes equal to number of input channels")
}
set_names(private$.shapes_out(shapes_in, self$param_set$get_values(), task = task), self$output$name)
}

# TODO: printer that calls the nn_module's printer
),
private = list(
.only_batch_unknown = TRUE,
.shapes_out = function(shapes_in, param_vals, task) shapes_in,
.shape_dependent_params = function(shapes_in, param_vals, task) param_vals,
.make_module = function(shapes_in, param_vals, task) {
Expand Down
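The effect of the new `only_batch_unknown` flag on shape validation can be sketched as follows (`check_shape` is a hypothetical stand-in for the stricter `assert_shapes()` call):

```r
# Hypothetical sketch: under the default, only the first (batch) dimension
# may be NA; operators that opt out accept NAs anywhere.
check_shape = function(shape, only_batch_unknown = TRUE) {
  if (only_batch_unknown && anyNA(shape[-1])) {
    stop("Only the batch dimension may be unknown (NA) for this operator.")
  }
  invisible(shape)
}

check_shape(c(NA, 3, 28))                               # ok: only batch unknown
try(check_shape(c(NA, NA, 28)))                         # error under the default
check_shape(c(NA, NA, 28), only_batch_unknown = FALSE)  # ok, e.g. for activations
```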
53 changes: 34 additions & 19 deletions R/PipeOpTorchActivation.R
Expand Up @@ -139,7 +139,8 @@ PipeOpTorchHardTanh = R6Class("PipeOpTorchHardTanh",
param_set = param_set,
param_vals = param_vals,
module_generator = nn_hardtanh,
tags = "activation"
tags = "activation",
only_batch_unknown = FALSE
)
}
)
Expand Down Expand Up @@ -179,7 +180,8 @@ PipeOpTorchLeakyReLU = R6Class("PipeOpTorchLeakyReLU",
param_set = param_set,
param_vals = param_vals,
module_generator = nn_leaky_relu,
tags = "activation"
tags = "activation",
only_batch_unknown = FALSE
)
}
)
Expand All @@ -199,7 +201,6 @@ register_po("nn_leaky_relu", PipeOpTorchLeakyReLU)
#' @template pipeop_torch
#' @template pipeop_torch_example
#'
#'
#' @export
PipeOpTorchLogSigmoid = R6Class("PipeOpTorchLogSigmoid",
inherit = PipeOpTorch,
Expand All @@ -214,7 +215,8 @@ PipeOpTorchLogSigmoid = R6Class("PipeOpTorchLogSigmoid",
param_set = param_set,
param_vals = param_vals,
module_generator = nn_log_sigmoid,
tags = "activation"
tags = "activation",
only_batch_unknown = FALSE
)
}
)
Expand Down Expand Up @@ -254,7 +256,8 @@ PipeOpTorchPReLU = R6Class("PipeOpTorchPReLU",
param_set = param_set,
param_vals = param_vals,
module_generator = nn_prelu,
tags = "activation"
tags = "activation",
only_batch_unknown = FALSE
)
}
)
Expand All @@ -274,7 +277,6 @@ register_po("nn_prelu", PipeOpTorchPReLU)
#' @template pipeop_torch
#' @template pipeop_torch_example
#'
#'
#' @export
PipeOpTorchReLU = R6Class("PipeOpTorchReLU",
inherit = PipeOpTorch,
Expand All @@ -291,7 +293,8 @@ PipeOpTorchReLU = R6Class("PipeOpTorchReLU",
param_set = param_set,
param_vals = param_vals,
module_generator = nn_relu,
tags = "activation"
tags = "activation",
only_batch_unknown = FALSE
)
}
)
Expand Down Expand Up @@ -327,7 +330,8 @@ PipeOpTorchReLU6 = R6Class("PipeOpTorchReLU6",
param_set = param_set,
param_vals = param_vals,
module_generator = nn_relu6,
tags = "activation"
tags = "activation",
only_batch_unknown = FALSE
)
}
)
Expand Down Expand Up @@ -370,7 +374,8 @@ PipeOpTorchRReLU = R6Class("PipeOpTorchRReLU",
param_set = param_set,
param_vals = param_vals,
module_generator = nn_rrelu,
tags = "activation"
tags = "activation",
only_batch_unknown = FALSE
)
}
)
Expand Down Expand Up @@ -407,7 +412,8 @@ PipeOpTorchSELU = R6Class("PipeOpTorchSELU",
param_set = param_set,
param_vals = param_vals,
module_generator = nn_selu,
tags = "activation"
tags = "activation",
only_batch_unknown = FALSE
)
}
)
Expand Down Expand Up @@ -447,7 +453,8 @@ PipeOpTorchCELU = R6Class("PipeOpTorchCELU",
param_set = param_set,
param_vals = param_vals,
module_generator = nn_celu,
tags = "activation"
tags = "activation",
only_batch_unknown = FALSE
)
}
)
Expand Down Expand Up @@ -484,7 +491,8 @@ PipeOpTorchGELU = R6Class("PipeOpTorchGELU",
param_set = param_set,
param_vals = param_vals,
module_generator = nn_gelu,
tags = "activation"
tags = "activation",
only_batch_unknown = FALSE
)
}
)
Expand Down Expand Up @@ -518,7 +526,8 @@ PipeOpTorchSigmoid = R6Class("PipeOpTorchSigmoid",
param_set = param_set,
param_vals = param_vals,
module_generator = nn_sigmoid,
tags = "activation"
tags = "activation",
only_batch_unknown = FALSE
)
}
)
Expand Down Expand Up @@ -557,7 +566,8 @@ PipeOpTorchSoftPlus = R6Class("PipeOpTorchSoftPlus",
param_set = param_set,
param_vals = param_vals,
module_generator = nn_softplus,
tags = "activation"
tags = "activation",
only_batch_unknown = FALSE
)
}
)
Expand Down Expand Up @@ -594,7 +604,8 @@ PipeOpTorchSoftShrink = R6Class("PipeOpTorchSoftShrink",
param_set = param_set,
param_vals = param_vals,
module_generator = nn_softshrink,
tags = "activation"
tags = "activation",
only_batch_unknown = FALSE
)
}
)
Expand Down Expand Up @@ -627,7 +638,8 @@ PipeOpTorchSoftSign = R6Class("PipeOpTorchSoftSign",
param_set = param_set,
param_vals = param_vals,
module_generator = nn_softsign,
tags = "activation"
tags = "activation",
only_batch_unknown = FALSE
)
}
)
Expand Down Expand Up @@ -661,7 +673,8 @@ PipeOpTorchTanh = R6Class("PipeOpTorchTanh",
param_set = param_set,
param_vals = param_vals,
module_generator = nn_tanh,
tags = "activation"
tags = "activation",
only_batch_unknown = FALSE
)
}
)
Expand Down Expand Up @@ -695,7 +708,8 @@ PipeOpTorchTanhShrink = R6Class("PipeOpTorchTanhShrink",
param_set = param_set,
param_vals = param_vals,
module_generator = nn_tanhshrink,
tags = "activation"
tags = "activation",
only_batch_unknown = FALSE
)
}
)
Expand Down Expand Up @@ -739,7 +753,8 @@ PipeOpTorchThreshold = R6Class("PipeOpTorchThreshold",
param_set = param_set,
param_vals = param_vals,
module_generator = nn_threshold,
tags = "activation"
tags = "activation",
only_batch_unknown = FALSE
)
}
)
Expand Down
1 change: 0 additions & 1 deletion R/PipeOpTorchConv.R
Expand Up @@ -163,4 +163,3 @@ conv_output_shape = function(shape_in, conv_dim, padding, dilation, stride, kern
(if (ceil_mode) base::ceiling else base::floor)((shape_tail + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1)
)
}
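The one-dimensional case of the formula above can be checked by hand; for instance, a length-28 input with `kernel_size = 3`, `stride = 2`, `padding = 1`, `dilation = 1` gives floor((28 + 2 - 2 - 1) / 2 + 1) = 14 (`conv_len` is a hypothetical helper mirroring the tail computation in `conv_output_shape`):

```r
# Hypothetical helper: output length of one convolved dimension.
conv_len = function(len, padding, dilation, stride, kernel_size, ceil_mode = FALSE) {
  round_fn = if (ceil_mode) base::ceiling else base::floor
  round_fn((len + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1)
}

conv_len(28, padding = 1, dilation = 1, stride = 2, kernel_size = 3)  # 14
```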

2 changes: 1 addition & 1 deletion R/PipeOpTorchIngress.R
Expand Up @@ -119,7 +119,7 @@ PipeOpTorchIngress = R6Class("PipeOpTorchIngress",
#' the output of `Task$data(rows = batch_indices, cols = features)`
#' and it should produce a tensor of shape `shape_out`.
#' @param shape (`integer`)\cr
#' Shape that `batchgetter` will produce. Batch-dimension should be included as `NA`.
#' Shape that `batchgetter` will produce. The batch dimension must be included as `NA`; other dimensions may also be `NA`, i.e., unknown.
#' @return `TorchIngressToken` object.
#' @family Graph Network
#' @export
Expand Down
9 changes: 7 additions & 2 deletions R/PipeOpTorchLinear.R
Expand Up @@ -32,13 +32,18 @@ PipeOpTorchLinear = R6Class("PipeOpTorchLinear",
id = id,
param_set = param_set,
param_vals = param_vals,
module_generator = nn_linear
module_generator = nn_linear,
only_batch_unknown = FALSE
)
}
),
private = list(
.shape_dependent_params = function(shapes_in, param_vals, task) {
c(param_vals, list(in_features = tail(shapes_in[[1]], 1)))
d_in = tail(shapes_in[[1]], 1)
if (is.na(d_in)) {
stopf("PipeOpTorchLinear received an input shape whose last dimension is unknown. Please provide a known shape.")
}
c(param_vals, list(in_features = d_in))
},
.shapes_out = function(shapes_in, param_vals, task) list(c(head(shapes_in[[1]], -1), param_vals$out_features))
)
Expand Down
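The reason for the new check: `nn_linear` needs a concrete `in_features`, so the last input dimension must be known even when other dimensions are `NA`. A minimal illustration:

```r
# Sketch: the feature dimension must be known to construct nn_linear.
shape_in = c(NA, NA, 32)   # batch and sequence length unknown, features known
d_in = tail(shape_in, 1)   # 32: usable as in_features
stopifnot(!is.na(d_in))

bad = c(NA, NA, NA)
is.na(tail(bad, 1))        # TRUE: nn_linear could not be constructed
```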
2 changes: 1 addition & 1 deletion R/PipeOpTorchReshape.R
Expand Up @@ -5,7 +5,7 @@
#' This internally calls [`torch::torch_reshape()`] with the given `shape`.
#' @section Parameters:
#' * `shape` :: `integer(1)`\cr
#' The desired output shape. Unknown dimension (one at most) can either be specified as `-1` or `NA`.
#' The desired output shape. At most one dimension may be unknown, and it must be specified as `-1`.
#' @templateVar id nn_reshape
#' @template pipeop_torch_channels_default
#' @template pipeop_torch
Expand Down
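For reference, `torch::torch_reshape()` infers at most one dimension from `-1` (a sketch assuming the torch R package is installed):

```r
library(torch)

x = torch_randn(2, 6)
y = torch_reshape(x, c(-1, 3))  # the -1 dimension is inferred as 4
dim(y)  # 4 3
```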
3 changes: 2 additions & 1 deletion R/lazy_tensor.R
Expand Up @@ -231,6 +231,7 @@ is_lazy_tensor = function(x) {
#' Is not cloned, so should be cloned beforehand.
#' @param shape (`integer()` or `NULL`)\cr
#' The shape of the lazy tensor.
#' `NA`s indicate dimensions where the shape is not known.
#' @param shape_predict (`integer()` or `NULL`)\cr
#' The shape of the lazy tensor if it was applied during `$predict()`.
#'
Expand Down Expand Up @@ -356,4 +357,4 @@ rep_len.lazy_tensor = function(x, ...) {
#' lazy_shape(lt)
lazy_shape = function(x) {
dd(x)$pointer_shape
}
}