Skip to content

Commit 2a0337b

Browse files
committed
...
1 parent c59135f commit 2a0337b

11 files changed

+99
-26
lines changed

R/DataDescriptor.R

+5-1
Original file line numberDiff line numberDiff line change
@@ -255,10 +255,14 @@ assert_compatible_shapes = function(shapes, dataset) {
255255
if (length(shape_specified) != length(shape_example)) {
256256
stopf("The specified number of dimensions for element '%s' is %s, but the dataset returned %s",
257257
name, length(shape_specified), length(shape_example))
258+
}
258259

260+
if (all(is.na(shape_specified))) {
261+
# compatible with any shape
262+
return(NULL)
259263
}
260-
shape_example[is.na(shape_specified)] = NA
261264

265+
shape_example[is.na(shape_specified)] = NA
262266
if (!test_equal(shape_specified, shape_example)) {
263267
stopf(paste0("First example batch from dataset is incompatible with the provided shape of %s:\n",
264268
"* Observed shape: %s.\n* Specified shape: %s."), name,

R/PipeOpTorch.R

-2
Original file line numberDiff line numberDiff line change
@@ -301,8 +301,6 @@ PipeOpTorch = R6Class("PipeOpTorch",
301301
}
302302
set_names(private$.shapes_out(shapes_in, self$param_set$get_values(), task = task), self$output$name)
303303
}
304-
305-
# TODO: printer that calls the nn_module's printer
306304
),
307305
private = list(
308306
.only_batch_unknown = TRUE,

R/PipeOpTorchActivation.R

+2-1
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,8 @@ PipeOpTorchReLU = R6Class("PipeOpTorchReLU",
293293
param_set = param_set,
294294
param_vals = param_vals,
295295
module_generator = nn_relu,
296-
tags = "activation"
296+
tags = "activation",
297+
only_batch_unknown = FALSE
297298
)
298299
}
299300
)

R/PipeOpTorchIngress.R

+1-1
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ PipeOpTorchIngress = R6Class("PipeOpTorchIngress",
119119
#' the output of `Task$data(rows = batch_indices, cols = features)`
120120
#' and it should produce a tensor of shape `shape_out`.
121121
#' @param shape (`integer`)\cr
122-
#' Shape that `batchgetter` will produce. Batch-dimension should be included as `NA`.
122+
#' Shape that `batchgetter` will produce. At least the batch dimension should be included as `NA`.
123123
#' @return `TorchIngressToken` object.
124124
#' @family Graph Network
125125
#' @export

R/PipeOpTorchLinear.R

+5-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,11 @@ PipeOpTorchLinear = R6Class("PipeOpTorchLinear",
3939
),
4040
private = list(
4141
.shape_dependent_params = function(shapes_in, param_vals, task) {
42-
c(param_vals, list(in_features = tail(shapes_in[[1]], 1)))
42+
d_in = tail(shapes_in[[1]], 1)
43+
if (is.na(d_in)) {
44+
stopf("PipeOpLinear received an input shape where the last dimension is unknown. Please provide a known shape.")
45+
}
46+
c(param_vals, list(in_features = d_in))
4347
},
4448
.shapes_out = function(shapes_in, param_vals, task) list(c(head(shapes_in[[1]], -1), param_vals$out_features))
4549
)

R/utils.R

+1-1
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ uniqueify = function(new, existing) {
148148

149149
shape_to_str = function(x) {
150150
assert(test_list(x) || test_integerish(x) || is.null(x))
151-
if (is.numeric(x)) { # single shape
151+
if (test_integerish(x)) { # single shape
152152
return(sprintf("(%s)", paste0(x, collapse = ",")))
153153
}
154154
if (is.null(x)) {

tests/testthat/test_DataDescriptor.R

+16-14
Original file line numberDiff line numberDiff line change
@@ -68,20 +68,22 @@ test_that("infer shapes from dataset", {
6868
expect_error(DataDescriptor$new(ds), "must return a named list of tensors")
6969
})
7070

71-
#test_that("assert_compatible_shapes", {
72-
# ds = make_dataset(list(x = c(2, 3)), getbatch = TRUE)
73-
# expect_error(assert_compatible_shapes(list(x = c(NA, 2, 3)), ds), regexp = NA)
74-
# expect_error(assert_compatible_shapes(list(x = c(NA, 2, 1)), ds), regexp = "(NA,2,1)")
75-
# ds = make_dataset(list(x = c(2, 3)), getbatch = FALSE)
76-
# expect_error(assert_compatible_shapes(list(x = c(NA, 2, 3)), ds), regexp = NA)
77-
# expect_error(assert_compatible_shapes(list(x = c(NA, 2, 1)), ds), regexp = "(NA,2,1)")
78-
# ds = make_dataset(list(x = c(2, 3), y = 1), getbatch = TRUE)
79-
# expect_error(assert_compatible_shapes(list(x = c(NA, 2, 3), y = c(NA, 1)), ds), regexp = NA)
80-
# expect_error(assert_compatible_shapes(list(x = c(NA, 2, 3), y = c(NA, 2)), ds), regexp = "shape of y")
81-
# ds = make_dataset(list(x = c(2, 3), y = 1), getbatch = FALSE)
82-
# expect_error(assert_compatible_shapes(list(x = c(NA, 2, 3), y = c(NA, 1)), ds), regexp = NA)
83-
# expect_error(assert_compatible_shapes(list(x = c(NA, 2, 3), y = c(NA, 2)), ds), regexp = "shape of y")
84-
#})
71+
test_that("assert_compatible_shapes", {
72+
ds = make_dataset(list(x = c(2, 3)), getbatch = TRUE)
73+
expect_error(assert_compatible_shapes(list(x = c(NA, 2, 3)), ds), regexp = NA)
74+
expect_error(assert_compatible_shapes(list(x = c(NA, NA, NA)), ds), regexp = NA)
75+
expect_error(assert_compatible_shapes(list(x = c(NA, NA, NA, NA)), ds), "returned 3")
76+
expect_error(assert_compatible_shapes(list(x = c(NA, 2, 1)), ds), regexp = "(NA,2,1)")
77+
ds = make_dataset(list(x = c(2, 3)), getbatch = FALSE)
78+
expect_error(assert_compatible_shapes(list(x = c(NA, 2, 3)), ds), regexp = NA)
79+
expect_error(assert_compatible_shapes(list(x = c(NA, 2, 1)), ds), regexp = "(NA,2,1)")
80+
ds = make_dataset(list(x = c(2, 3), y = 1), getbatch = TRUE)
81+
expect_error(assert_compatible_shapes(list(x = c(NA, 2, 3), y = c(NA, 1)), ds), regexp = NA)
82+
expect_error(assert_compatible_shapes(list(x = c(NA, 2, 3), y = c(NA, 2)), ds), regexp = "shape of y")
83+
ds = make_dataset(list(x = c(2, 3), y = 1), getbatch = FALSE)
84+
expect_error(assert_compatible_shapes(list(x = c(NA, 2, 3), y = c(NA, 1)), ds), regexp = NA)
85+
expect_error(assert_compatible_shapes(list(x = c(NA, 2, 3), y = c(NA, 2)), ds), regexp = "shape of y")
86+
})
8587

8688
test_that("as_data_descriptor", {
8789
ds = make_dataset(list(x = 1))

tests/testthat/test_PipeOpTorch.R

+30-1
Original file line numberDiff line numberDiff line change
@@ -146,4 +146,33 @@ test_that("only_batch_unknown", {
146146
expect_equal(obj$shapes_out(list(c(NA, NA, 1))), list(output = c(NA, NA, 10)))
147147
obj$.__enclos_env__$private$.only_batch_unknown = TRUE
148148
expect_error(obj$shapes_out(list(c(NA, NA, 1))), regexp = "Invalid shape: (NA,NA,1)", fixed = TRUE)
149-
})
149+
})
150+
151+
test_that("NA in second dimension", {
152+
ds = dataset(
153+
initialize = function() {
154+
self$xs = lapply(1:10, function(i) torch_randn(sample(1:10, 1), 3))
155+
},
156+
.getitem = function(i) {
157+
list(x = self$xs[[i]])
158+
},
159+
.length = function() {
160+
length(self$xs)
161+
}
162+
)()
163+
164+
task = as_task_regr(data.table(
165+
x = as_lazy_tensor(ds, dataset_shapes = list(x = c(NA, NA, 3))),
166+
y = rnorm(10)
167+
), target = "y", id = "test")
168+
169+
graph = po("torch_ingress_ltnsr") %>>% po("nn_linear", out_features = 10)
170+
171+
md = graph$train(task)[[1L]]
172+
173+
expect_equal(md$pointer_shape, c(NA, NA, 10))
174+
175+
net = model_descriptor_to_module(md)
176+
expect_equal(net(torch_randn(1, 2, 3))$shape, c(1, 2, 10))
177+
expect_equal(net(torch_randn(2, 1, 3))$shape, c(2, 1, 10))
178+
})

tests/testthat/test_PipeOpTorchLinear.R

+33
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,36 @@ test_that("PipeOpTorchLinear paramtest", {
1111
res = expect_paramset(po_linear, nn_linear, exclude = "in_features")
1212
expect_paramtest(res)
1313
})
14+
15+
test_that("NA in second dimension", {
16+
ds = dataset(
17+
initialize = function() {
18+
self$xs = lapply(1:10, function(i) torch_randn(sample(1:10, 1), 10))
19+
},
20+
.getitem = function(i) {
21+
list(x = self$xs[[i]])
22+
},
23+
.length = function() {
24+
length(self$xs)
25+
}
26+
)()
27+
28+
task = as_task_regr(data.table(
29+
x = as_lazy_tensor(ds, dataset_shapes = list(x = c(NA, NA, NA))),
30+
y = rnorm(10)
31+
), target = "y", id = "test")
32+
33+
graph = po("torch_ingress_ltnsr") %>>% po("nn_linear", out_features = 10)
34+
35+
expect_error(graph$train(task), "Please provide a known shape")
36+
37+
task = as_task_regr(data.table(
38+
x = as_lazy_tensor(ds, dataset_shapes = list(x = c(NA, NA, 10))),
39+
y = rnorm(10)
40+
), target = "y", id = "test")
41+
42+
md = graph$train(task)[[1L]]
43+
expect_equal(md$pointer_shape, c(NA, NA, 10))
44+
net = model_descriptor_to_module(md)
45+
expect_equal(net(torch_randn(1, 2, 10))$shape, c(1, 2, 10))
46+
})

tests/testthat/test_shape.R

+4-4
Original file line numberDiff line numberDiff line change
@@ -51,27 +51,27 @@ test_that("infer_shapes works", {
5151

5252
# names
5353
expect_equal(
54-
names(infer_shapes(list(x = c(NA, 4)), list(), output_names = "out", identity, TRUE)),
54+
names(infer_shapes(list(x = c(NA, 4)), list(), output_names = "out", identity, TRUE, "a")),
5555
"out"
5656
)
5757

5858
# multiple inputs
5959
expect_equal(
60-
infer_shapes(list(x = c(NA, 3, 4), y = c(NA, 3)), list(), output_names = c("out1", "out2"), function(x) x[.., 1:2], TRUE), # nolint
60+
infer_shapes(list(x = c(NA, 3, 4), y = c(NA, 3)), list(), output_names = c("out1", "out2"), function(x) x[.., 1:2], TRUE, "a"), # nolint
6161
list(
6262
out1 = c(NA, 3, 2),
6363
out2 = c(NA, 2)
6464
)
6565
)
6666
# param_vals
6767
expect_equal(
68-
infer_shapes(list(x = c(NA, 4)), fn = function(x, d) x[, d], param_vals = list(d = 1:2), output_names = "out", rowwise = FALSE), # nolint
68+
infer_shapes(list(x = c(NA, 4)), fn = function(x, d) x[, d], param_vals = list(d = 1:2), output_names = "out", rowwise = FALSE, "a"), # nolint
6969
list(
7070
out = c(NA, 2)
7171
)
7272
)
7373
expect_equal(
74-
infer_shapes(list(x = c(NA, 4)), fn = function(x, d) x[, d], param_vals = list(d = 1:3), output_names = "out", rowwise = FALSE), # nolint
74+
infer_shapes(list(x = c(NA, 4)), fn = function(x, d) x[, d], param_vals = list(d = 1:3), output_names = "out", rowwise = FALSE, "a"), # nolint
7575
list(
7676
out = c(NA, 3)
7777
)

tests/testthat/test_utils.R

+2
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ test_that("order_named_args works", {
5757
expect_error(order_named_args(function(y, ..., x) NULL, list(y = 4, 2, 3, x = 1)), regexp = "`...` must")
5858
})
5959
test_that("shape_to_str works", {
60+
expect_equal(shape_to_str(c(NA, NA)), "(NA,NA)")
6061
expect_equal(shape_to_str(1), "(1)")
6162
expect_equal(shape_to_str(c(1, 2)), "(1,2)")
6263
expect_equal(shape_to_str(NULL), "(<unknown>)")
@@ -67,3 +68,4 @@ test_that("shape_to_str works", {
6768

6869
md = po("torch_ingress_ltnsr")$train(list(nano_imagenet()))[[1L]]
6970
})
71+

0 commit comments

Comments
 (0)