Skip to content

Commit 4393104

Browse files
authored
improve lazy-tensor printer and converter (#452)
1 parent 9de0863 commit 4393104

25 files changed

+92
-60
lines changed

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ NeedsCompilation: no
7979
ByteCompile: yes
8080
Encoding: UTF-8
8181
Roxygen: list(markdown = TRUE, r6 = TRUE)
82-
RoxygenNote: 7.3.2.9000
82+
RoxygenNote: 7.3.3
8383
Collate:
8484
'CallbackSet.R'
8585
'aaa.R'

NEWS.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
# mlr3torch (development version)
22

3+
* Feat: Improve `lazy_tensor` printing.
4+
* Fix: Improve consistency in `as_lazy_tensor()` when converting 1D tensors to lazy tensors.
5+
36
# mlr3torch 0.3.2
47

58
## Bug Fixes

R/PipeOpTaskPreprocTorch.R

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,9 +93,9 @@
9393
#' @examplesIf torch::torch_is_installed()
9494
#' # Creating a simple task
9595
#' d = data.table(
96-
#' x1 = as_lazy_tensor(rnorm(10)),
97-
#' x2 = as_lazy_tensor(rnorm(10)),
98-
#' x3 = as_lazy_tensor(as.double(1:10)),
96+
#' x1 = as_lazy_tensor(matrix(rnorm(10), ncol = 1)),
97+
#' x2 = as_lazy_tensor(matrix(rnorm(10), ncol = 1)),
98+
#' x3 = as_lazy_tensor(matrix(as.double(1:10), ncol = 1)),
9999
#' y = rnorm(10)
100100
#' )
101101
#'

R/lazy_tensor.R

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -91,22 +91,27 @@ format.lazy_tensor = function(x, ...) { # nolint
9191
if (!length(x)) return(character(0))
9292
shape = dd(x)$pointer_shape
9393
shape = if (is.null(shape)) {
94-
return(rep("<tnsr[]>", length(x)))
94+
return(rep("<tnsr[?]>", length(x)))
9595
}
9696
shape = paste0(dd(x)$pointer_shape[-1L], collapse = "x")
9797

9898
map_chr(x, function(elt) {
9999
sprintf("<tnsr[%s]>", shape)
100100
})
101101
}
102-
103102
#' @export
104-
print.lazy_tensor = function(x, ...) {
105-
cat(paste0("<ltnsr[", length(x), "]>", "\n", collapse = ""))
106-
if (length(x) == 0) return(invisible(x))
107-
108-
out <- stats::setNames(format(x), names(x))
109-
print(out, quote = FALSE)
103+
print.lazy_tensor = function(x, ...) { # nolint
104+
if (length(x) == 0) {
105+
cat("<ltnsr[len=0]>\n")
106+
return(invisible(x))
107+
}
108+
shape = dd(x)$pointer_shape
109+
if (is.null(shape)) {
110+
cat(sprintf("<ltnsr[len=%d, shapes=unknown]>\n", length(x)))
111+
} else {
112+
shape_str = paste0(shape[-1L], collapse = ",")
113+
cat(sprintf("<ltnsr[len=%d, shapes=(%s)]>\n", length(x), shape_str))
114+
}
110115
invisible(x)
111116
}
112117

@@ -181,9 +186,6 @@ as_lazy_tensor.numeric = function(x, ...) { # nolint
181186

182187
#' @export
183188
as_lazy_tensor.torch_tensor = function(x, ...) { # nolint
184-
if (length(dim(x)) == 1L) {
185-
x = x$unsqueeze(2)
186-
}
187189
ds = dataset(
188190
initialize = function(x) {
189191
self$x = x

man/mlr_pipeops_preproc_torch.Rd

Lines changed: 3 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# print/format
2+
3+
Code
4+
lazy_tensor()
5+
Output
6+
<ltnsr[len=0]>
7+
8+
---
9+
10+
Code
11+
as_lazy_tensor(1)
12+
Output
13+
<ltnsr[len=1, shapes=()]>
14+
15+
---
16+
17+
Code
18+
as_lazy_tensor(matrix(1:10, ncol = 1))
19+
Output
20+
<ltnsr[len=10, shapes=(1)]>
21+
22+
---
23+
24+
Code
25+
as_lazy_tensor(ds, dataset_shapes = list(x = NULL))
26+
Output
27+
<ltnsr[len=10, shapes=unknown]>
28+

tests/testthat/helper_autotest.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -345,7 +345,7 @@ expect_torch_callback = function(torch_callback, check_man = TRUE, check_paramse
345345
})
346346

347347
cb = torch_callback$generate()
348-
expect_deep_clone(cb, cb$clone(deep = TRUE))
348+
expect_deep_clone_mlr3torch(cb, cb$clone(deep = TRUE))
349349
}
350350

351351
#' @title Autotest for PipeOpTaskPreprocTorch
@@ -470,7 +470,7 @@ expect_learner_torch = function(learner, task, check_man = TRUE, check_id = TRUE
470470
# state cloning is tested separately
471471
learner1 = learner
472472
learner1$state = NULL
473-
expect_deep_clone(learner1, learner1$clone(deep = TRUE))
473+
expect_deep_clone_mlr3torch(learner1, learner1$clone(deep = TRUE))
474474
rr = resample(task, learner, rsmp("holdout"), store_models = TRUE)
475475
expect_double(rr$aggregate())
476476
checkmate::expect_class(rr, "ResampleResult")

tests/testthat/helper_functions.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ expect_pipeop_class = function(poclass, constargs = list(), check_ps_default_val
108108
expect_pipeop(po, check_ps_default_values = check_ps_default_values)
109109

110110
poclone = po$clone(deep = TRUE)
111-
expect_deep_clone(po, poclone)
111+
expect_deep_clone_mlr3torch(po, poclone)
112112

113113
in_nop = rep(list(NO_OP), po$innum)
114114
in_nonnop = rep(list(NULL), po$innum)
@@ -126,7 +126,7 @@ expect_pipeop_class = function(poclass, constargs = list(), check_ps_default_val
126126
# check again with no_op-trained PO
127127
expect_pipeop(po, check_ps_default_values = check_ps_default_values)
128128
poclone = po$clone(deep = TRUE)
129-
expect_deep_clone(po, poclone)
129+
expect_deep_clone_mlr3torch(po, poclone)
130130

131131
}
132132

tests/testthat/helper_mlr3pipelines.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ mlr_helpers = list.files(system.file("testthat", package = "mlr3pipelines"), pat
1111
lapply(mlr_helpers, FUN = source)
1212

1313
# expect that 'one' is a deep clone of 'two'
14-
expect_deep_clone = function(one, two) {
14+
expect_deep_clone_mlr3torch = function(one, two) {
1515
# is equal
1616
expect_equal(one, two)
1717
visited = new.env()

tests/testthat/test_LearnerTorch.R

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ test_that("deep cloning", {
1111
learner$state$train_task = NULL
1212

1313
learner_cloned = learner$clone(deep = TRUE)
14-
expect_deep_clone(learner, learner_cloned)
14+
expect_deep_clone_mlr3torch(learner, learner_cloned)
1515

1616
network = learner$network
1717
network_cloned = learner_cloned$network
@@ -498,7 +498,7 @@ test_that("col_info is propertly subset when comparing task validity during pred
498498
test_that("deep clone works", {
499499
l1 = lrn("classif.mlp")
500500
l2 = l1$clone(deep = TRUE)
501-
expect_deep_clone(l1, l2)
501+
expect_deep_clone_mlr3torch(l1, l2)
502502
})
503503

504504
test_that("param set is read-only", {
@@ -630,7 +630,7 @@ test_that("param_set source works", {
630630

631631
l1 = l$clone(deep = TRUE)
632632

633-
expect_deep_clone(l, l1)
633+
expect_deep_clone_mlr3torch(l, l1)
634634
l1$param_set$set_values(
635635
a = 17,
636636
epochs = 18,
@@ -744,7 +744,7 @@ test_that("configure loss, optimizer and callbacks after construction", {
744744
cb.checkpoint.freq = 456
745745
)
746746
learner1 = learner$clone(deep = TRUE)
747-
expect_deep_clone(learner, learner1)
747+
expect_deep_clone_mlr3torch(learner, learner1)
748748
expect_equal(learner1$param_set$values$loss.reduction, "mean")
749749
expect_equal(learner1$param_set$values$opt.lr, 123)
750750
expect_equal(learner1$param_set$values$cb.checkpoint.freq, 456)

0 commit comments

Comments
 (0)