Skip to content

Commit f6feddd

Browse files
authored
feat(nn_block): parameter trafo (#399)
1 parent f70fc60 commit f6feddd

File tree

6 files changed

+118
-37
lines changed

6 files changed

+118
-37
lines changed

NEWS.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616
* feat: TabResNet learner now supports lazy tensors.
1717
* feat: The `LearnerTorch` base class now supports the private method `$.ingress_tokens(task, param_vals)`
1818
for generating the `torch::dataset`.
19+
* feat: `nn("block")` (which allows repeating the same network segment multiple
20+
times) now has an extra argument `trafo`, which allows modifying the
21+
parameter values per layer.
1922
* feat: Shapes can now have multiple `NA`s and not only the batch dimension can be missing. However, most `nn()` operators still expect only one missing value and will throw an error if multiple dimensions are unknown.
2023

2124
# mlr3torch 0.2.1

R/PipeOpTorchBlock.R

Lines changed: 48 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -7,30 +7,46 @@
77
#' `__<layer>`.
88
#'
99
#' @section Parameters:
10-
#' The parameters available for the block itself, as well as
10+
#' The parameters available for the provided `block`, as well as
1111
#' * `n_blocks` :: `integer(1)`\cr
1212
#' How often to repeat the block.
13+
#' * `trafo` :: `function(i, param_vals, param_set) -> list()`\cr
14+
#' A function that allows transforming the parameter values of each layer (`block`).
15+
#' Here,
16+
#' * `i` :: `integer(1)`\cr
17+
#' is the index of the layer, ranging from `1` to `n_blocks`.
18+
#' * `param_vals` :: named `list()`\cr
19+
#' are the parameter values of the layer `i`.
20+
#' * `param_set` :: [`ParamSet`][paradox::ParamSet]\cr
21+
#' is the parameter set of the whole `PipeOpTorchBlock`.
22+
#'
23+
#' The function must return the modified parameter values for the given layer.
24+
#' This, e.g., allows for special behavior of the first or last layer.
1325
#' @section Input and Output Channels:
1426
#' The `PipeOp` sets its input and output channels to those from the `block` (Graph)
1527
#' it received during construction.
1628
#' @templateVar id nn_block
1729
#' @template pipeop_torch
1830
#' @export
1931
#' @examplesIf torch::torch_is_installed()
20-
#' block = po("nn_linear") %>>% po("nn_relu")
21-
#' po_block = po("nn_block", block,
22-
#' nn_linear.out_features = 10L, n_blocks = 3)
23-
#' network = po("torch_ingress_num") %>>%
24-
#' po_block %>>%
25-
#' po("nn_head") %>>%
26-
#' po("torch_loss", t_loss("cross_entropy")) %>>%
27-
#' po("torch_optimizer", t_opt("adam")) %>>%
28-
#' po("torch_model_classif",
29-
#' batch_size = 50,
30-
#' epochs = 3)
32+
#' # repeat a simple linear layer with ReLU activation 3 times, but set the bias for the last
33+
#' # layer to `FALSE`
34+
#' block = nn("linear") %>>% nn("relu")
3135
#'
32-
#' task = tsk("iris")
33-
#' network$train(task)
36+
#' blocks = nn("block", block,
37+
#' linear.out_features = 10L, linear.bias = TRUE, n_blocks = 3,
38+
#' trafo = function(i, param_vals, param_set) {
39+
#' if (i == param_set$get_values()$n_blocks) {
40+
#' param_vals$linear.bias = FALSE
41+
#' }
42+
#' param_vals
43+
#' })
44+
#' graph = po("torch_ingress_num") %>>%
45+
#' blocks %>>%
46+
#' nn("head")
47+
#' md = graph$train(tsk("iris"))[[1L]]
48+
#' network = model_descriptor_to_module(md)
49+
#' network
3450
PipeOpTorchBlock = R6Class("PipeOpTorchBlock",
3551
inherit = PipeOpTorch,
3652
public = list(
@@ -44,8 +60,12 @@ PipeOpTorchBlock = R6Class("PipeOpTorchBlock",
4460
initialize = function(block, id = "nn_block", param_vals = list()) {
4561
private$.block = as_graph(block)
4662
private$.param_set_base = ps(
47-
n_blocks = p_int(lower = 0L, tags = c("train", "required"))
63+
n_blocks = p_int(lower = 0L, tags = c("train", "required")),
64+
trafo = p_uty(tags = "train", custom_check = crate(function(x) {
65+
check_function(x, args = c("i", "param_vals", "param_set"))
66+
}))
4867
)
68+
4969
super$initialize(
5070
id = id,
5171
param_vals = param_vals,
@@ -68,11 +88,18 @@ PipeOpTorchBlock = R6Class("PipeOpTorchBlock",
6888
private = list(
6989
.block = NULL,
7090
.make_graph = function(block, n_blocks) {
91+
trafo = self$param_set$get_values()$trafo
7192
graph = block
72-
graph$update_ids(prefix = paste0(self$id, "."))
73-
graphs = c(list(graph), replicate(n_blocks - 1L, graph$clone(deep = TRUE)))
93+
graphs = c(replicate(n_blocks, graph$clone(deep = TRUE)))
94+
if (!is.null(trafo)) {
95+
param_vals = map(graphs, function(graph) graph$param_set$get_values())
96+
walk(seq_along(param_vals), function(i) {
97+
vals = trafo(i = i, param_vals = param_vals[[i]], param_set = self$param_set)
98+
graphs[[i]]$param_set$values = vals
99+
})
100+
}
74101
lapply(seq_len(n_blocks), function(i) {
75-
graphs[[i]]$update_ids(postfix = paste0("__", i))
102+
graphs[[i]]$update_ids(prefix = paste0(self$id, "."), postfix = paste0("__", i))
76103
})
77104
Reduce(`%>>%`, graphs)
78105
},
@@ -112,10 +139,10 @@ PipeOpTorchBlock = R6Class("PipeOpTorchBlock",
112139
map(mdouts, "pointer_shape")
113140
},
114141
.train = function(inputs) {
115-
if (self$param_set$values$n_blocks == 0L) {
142+
param_vals = self$param_set$get_values()
143+
if (param_vals$n_blocks == 0L) {
116144
return(inputs)
117145
}
118-
param_vals = self$param_set$get_values(tags = "train")
119146
block = private$.block$clone(deep = TRUE)
120147
graph = private$.make_graph(block, param_vals$n_blocks)
121148
inputs = set_names(inputs, graph$input$name)
@@ -130,4 +157,4 @@ PipeOpTorchBlock = R6Class("PipeOpTorchBlock",
130157

131158

132159
#' @include aaa.R
133-
register_po("nn_block", PipeOpTorchBlock, metainf = list(block = as_graph(po("nop"))))
160+
register_po("nn_block", PipeOpTorchBlock, metainf = list(block = as_graph(po("nop"))))

R/nn.R

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,9 @@
1111
#' # is the same as:
1212
#' po2 = nn("linear")
1313
nn = function(.key, ...) {
14-
invoke(po, .obj = paste0("nn_", .key), .args = insert_named(list(id = .key), list(...)))
14+
args = list(...)
15+
if (is.null(args$id)) {
16+
args$id = .key
17+
}
18+
invoke(po, .obj = paste0("nn_", .key), .args = args)
1519
}

man/mlr_pipeops_nn_block.Rd

Lines changed: 32 additions & 14 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test_PipeOpTorchBlock.R

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,4 +94,20 @@ test_that("0 blocks are possible", {
9494
md = po("torch_ingress_num")$train(list(tsk("iris")))[[1L]]
9595
mdout = nn("block", block = nn("linear", out_features = 10), n_blocks = 0)$train(list(md))[[1L]]
9696
expect_equal(mdout$pointer_shape, c(NA, 4))
97-
})
97+
})
98+
99+
test_that("trafo works", {
100+
graph = po("torch_ingress_num") %>>%
101+
po("nn_block", nn("linear", out_features = 10),
102+
n_blocks = 2L,
103+
trafo = function(i, param_vals, param_set) {
104+
if (i == 2) {
105+
param_vals$linear.bias = FALSE
106+
}
107+
param_vals
108+
})
109+
110+
network = model_descriptor_to_module(graph$train(tsk("iris"))[[1L]])
111+
expect_class(network$module_list$`0`$bias, "torch_tensor")
112+
expect_true(is.null(network$module_list$`1`$bias))
113+
})

tests/testthat/test_nn.R

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,16 @@ test_that("nn works", {
44
expect_class(x, "PipeOpTorchLinear")
55
expect_equal(x$param_set$values$out_features, 3)
66
})
7+
8+
test_that("overwrite id", {
9+
obj = nn("linear", id = "abc")
10+
expect_equal(obj$id, "abc")
11+
})
12+
13+
test_that("unnamed arg", {
14+
graph = po("torch_ingress_num") %>>% nn("block", nn("linear", out_features = 3), n_blocks = 2)
15+
md = graph$train(tsk("iris"))[[1L]]
16+
network = model_descriptor_to_module(md)
17+
expect_equal(network$module_list[[1]]$out_features, 3)
18+
expect_equal(network$module_list[[2]]$out_features, 3)
19+
})

0 commit comments

Comments
 (0)