77# ' `__<layer>`.
88# '
99# ' @section Parameters:
10- # ' The parameters available for the block itself , as well as
10+ # ' The parameters available for the provided `block` , as well as
1111# ' * `n_blocks` :: `integer(1)`\cr
1212# ' How often to repeat the block.
13+ # ' * `trafo` :: `function(i, param_vals, param_set) -> list()`\cr
14+ # ' A function that allows to transform the parameters vaues of each layer (`block`).
15+ # ' Here,
16+ # ' * `i` :: `integer(1)`\cr
17+ # ' is the index of the layer, ranging from `1` to `n_blocks`.
18+ # ' * `param_vals` :: named `list()`\cr
19+ # ' are the parameter values of the layer `i`.
20+ # ' * `param_set` :: [`ParamSet`][paradox::ParamSet]\cr
21+ # ' is the parameter set of the whole `PipeOpTorchBlock`.
22+ # '
23+ # ' The function must return the modified parameter values for the given layer.
24+ # ' This, e.g., allows for special behavior of the first or last layer.
1325# ' @section Input and Output Channels:
1426# ' The `PipeOp` sets its input and output channels to those from the `block` (Graph)
1527# ' it received during construction.
1628# ' @templateVar id nn_block
1729# ' @template pipeop_torch
1830# ' @export
1931# ' @examplesIf torch::torch_is_installed()
20- # ' block = po("nn_linear") %>>% po("nn_relu")
21- # ' po_block = po("nn_block", block,
22- # ' nn_linear.out_features = 10L, n_blocks = 3)
23- # ' network = po("torch_ingress_num") %>>%
24- # ' po_block %>>%
25- # ' po("nn_head") %>>%
26- # ' po("torch_loss", t_loss("cross_entropy")) %>>%
27- # ' po("torch_optimizer", t_opt("adam")) %>>%
28- # ' po("torch_model_classif",
29- # ' batch_size = 50,
30- # ' epochs = 3)
32+ # ' # repeat a simple linear layer with ReLU activation 3 times, but set the bias for the last
33+ # ' # layer to `FALSE`
34+ # ' block = nn("linear") %>>% nn("relu")
3135# '
32- # ' task = tsk("iris")
33- # ' network$train(task)
36+ # ' blocks = nn("block", block,
37+ # ' linear.out_features = 10L, linear.bias = TRUE, n_blocks = 3,
38+ # ' trafo = function(i, param_vals, param_set) {
39+ # ' if (i == param_set$get_values()$n_blocks) {
40+ # ' param_vals$linear.bias = FALSE
41+ # ' }
42+ # ' param_vals
43+ # ' })
44+ # ' graph = po("torch_ingress_num") %>>%
45+ # ' blocks %>>%
46+ # ' nn("head")
47+ # ' md = graph$train(tsk("iris"))[[1L]]
48+ # ' network = model_descriptor_to_module(md)
49+ # ' network
3450PipeOpTorchBlock = R6Class(" PipeOpTorchBlock" ,
3551 inherit = PipeOpTorch ,
3652 public = list (
@@ -44,8 +60,12 @@ PipeOpTorchBlock = R6Class("PipeOpTorchBlock",
4460 initialize = function (block , id = " nn_block" , param_vals = list ()) {
4561 private $ .block = as_graph(block )
4662 private $ .param_set_base = ps(
47- n_blocks = p_int(lower = 0L , tags = c(" train" , " required" ))
63+ n_blocks = p_int(lower = 0L , tags = c(" train" , " required" )),
64+ trafo = p_uty(tags = " train" , custom_check = crate(function (x ) {
65+ check_function(x , args = c(" i" , " param_vals" , " param_set" ))
66+ }))
4867 )
68+
4969 super $ initialize(
5070 id = id ,
5171 param_vals = param_vals ,
@@ -68,11 +88,18 @@ PipeOpTorchBlock = R6Class("PipeOpTorchBlock",
6888 private = list (
6989 .block = NULL ,
7090 .make_graph = function (block , n_blocks ) {
91+ trafo = self $ param_set $ get_values()$ trafo
7192 graph = block
72- graph $ update_ids(prefix = paste0(self $ id , " ." ))
73- graphs = c(list (graph ), replicate(n_blocks - 1L , graph $ clone(deep = TRUE )))
93+ graphs = c(replicate(n_blocks , graph $ clone(deep = TRUE )))
94+ if (! is.null(trafo )) {
95+ param_vals = map(graphs , function (graph ) graph $ param_set $ get_values())
96+ walk(seq_along(param_vals ), function (i ) {
97+ vals = trafo(i = i , param_vals = param_vals [[i ]], param_set = self $ param_set )
98+ graphs [[i ]]$ param_set $ values = vals
99+ })
100+ }
74101 lapply(seq_len(n_blocks ), function (i ) {
75- graphs [[i ]]$ update_ids(postfix = paste0(" __" , i ))
102+ graphs [[i ]]$ update_ids(prefix = paste0( self $ id , " . " ), postfix = paste0(" __" , i ))
76103 })
77104 Reduce(`%>>%` , graphs )
78105 },
@@ -112,10 +139,10 @@ PipeOpTorchBlock = R6Class("PipeOpTorchBlock",
112139 map(mdouts , " pointer_shape" )
113140 },
114141 .train = function (inputs ) {
115- if (self $ param_set $ values $ n_blocks == 0L ) {
142+ param_vals = self $ param_set $ get_values()
143+ if (param_vals $ n_blocks == 0L ) {
116144 return (inputs )
117145 }
118- param_vals = self $ param_set $ get_values(tags = " train" )
119146 block = private $ .block $ clone(deep = TRUE )
120147 graph = private $ .make_graph(block , param_vals $ n_blocks )
121148 inputs = set_names(inputs , graph $ input $ name )
@@ -130,4 +157,4 @@ PipeOpTorchBlock = R6Class("PipeOpTorchBlock",
130157
131158
132159# ' @include aaa.R
133- register_po(" nn_block" , PipeOpTorchBlock , metainf = list (block = as_graph(po(" nop" ))))
160+ register_po(" nn_block" , PipeOpTorchBlock , metainf = list (block = as_graph(po(" nop" ))))
0 commit comments