You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Previously, only the batch dimension was allowed to be `NA`.
However, this assumption is too restrictive, as it does not allow for transformer-based architectures where the sequence dimension is unknown.
This PR changes this and `NA`s can now be at every position of the `shape`.
However, many `nn()` operators expect only the batch dimension to be unknown.
Therefore, the `only_batch_unknown` argument was added to `PipeOpTorch` that is `TRUE` by default and needs to be overwritten.
E.g. `PipeOpTorchLinear` can handle `NA`s, as long as they are now in the last dimension.
Other operators that can handle it are, e.g., activation functions or in the future the multihead-attention module.
This PR also improved the shape inference and fixed some other small bugs.
TODOs:
* [x] the `infer_shapes()` method must be adjusted. Now, we replace all `NA`s dimensions with an arbitrary dimension.
we should also do this twice to check whether the results are compatible
* [x] check all occurences of `assert_shape()` and verify that we no longer make the assumption that only the batch dimension is `NA`
* [x] check the `PipeOpTorch` object implementations that they err gracefully --> `PipeOpTorch` now need to indicate whether they can handle `NA`s that are not in the batch dimension.
* [x] Update the documentation on the shape
Copy file name to clipboardExpand all lines: NEWS.md
+1
Original file line number
Diff line number
Diff line change
@@ -16,6 +16,7 @@
16
16
* feat: TabResNet learner now supports lazy tensors.
17
17
* feat: The `LearnerTorch` base class now supports the private method `$.ingress_tokens(task, param_vals)`
18
18
for generating the `torch::dataset`.
19
+
* feat: Shapes can now have multiple `NA`s and not only the batch dimension can be missing. However, most `nn()` operators still expect only one missing values and will throw an error if multiple dimensions are unknown.
0 commit comments