Allow virtual lazy tensors as targets in classification and regression #386
Conversation
@tdhock the PR is WIP, but I was wondering whether you could give some feedback on whether this seems useful to you and whether the API is intuitive?
Hi, thanks for the invite to review.
@@ -11,6 +11,8 @@
This means that for binary classification tasks, `t_loss("cross_entropy")` now generates
`nn_bce_with_logits_loss` instead of `nn_cross_entropy_loss`.
This also came with a reparametrization of the `t_loss("cross_entropy")` loss (thanks to @tdhock, #374).
* fix: `NA` is now a valid shape for lazy tensors.
* feat: `lazy_tensor`s of length 0 can now be materialized.
since I have not used lazy tensors very much, it would help to see a more detailed description of changes here in NEWS. I also would expect some changes to https://mlr3torch.mlr-org.com/articles/lazy_tensor.html but I do not see any in this PR yet.
What is the typical use case which motivates this PR?
This PR adds an experimental feature that allows converting a `torch::dataset` to an `mlr3::Task`. Essentially, the `torch::dataset` is converted to a `data.table` consisting only of `lazy_tensor` columns (including the target column).

In order to make this compatible with the `mlr3` API (measures etc.), it is necessary to provide a converter for the target column that converts from the `torch_tensor` to the associated R type. When accessing data from the task, the `lazy_tensor` columns for which a converter exists are `materialize()`d and the converter is applied, making it seem like this is just a standard `numeric()`.

However, `LearnerTorch` avoids this conversion and can load the target tensors directly (as defined by the `tensor_dataset` above) during training. Because the individual batches can only be loaded as a whole, some data access is more expensive.
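The conversion described above might look like the following sketch. Note that `as_task_regr()` accepting a `torch::dataset`, as well as the `target_converter` argument name, are assumptions based on this description, not a confirmed interface:

```r
library(torch)
library(mlr3torch)

# Toy torch dataset with features x and a numeric target y
tensor_dataset = dataset("toy",
  initialize = function() {
    self$x = torch_randn(10, 3)
    self$y = torch_randn(10)
  },
  .getbatch = function(i) list(x = self$x[i, , drop = FALSE], y = self$y[i]),
  .length = function() 10
)()

# Hypothetical conversion to a regression task: every column becomes a
# lazy_tensor, and the target column "y" gets a converter that turns the
# materialized torch_tensor back into a plain numeric()
task = as_task_regr(
  tensor_dataset,
  target = "y",
  # assumed argument name; the PR only states that such a converter is needed
  target_converter = function(tensor) as.numeric(tensor)
)
```

With such a converter in place, `task$data()` would show `y` as an ordinary numeric column, while `LearnerTorch` can still bypass the conversion during training.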
E.g., `task$truth(1:10)` needs to load all 10 batches even though we are only interested in the target. For this reason, some operations are disallowed, such as target transformations or adding new rows to the task:
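A hedged illustration of both points, assuming a `task` built from a 10-row batched dataset as described (the exact error behaviour is an assumption, as the PR does not specify it):

```r
# Reading the target for rows 1:10 must materialize all 10 batches,
# because a batch can only be loaded as a whole
task$truth(1:10)

# Operations that would rewrite the target column are disallowed
# (illustrative: the exact error class/message is not specified in the PR)
extra_rows = data.frame(y = 1)
try(task$rbind(extra_rows))  # expected to error: adding rows is not supported
```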
Furthermore, converted columns are cached, which is demonstrated below: on the second access to `head()`, the counter of the `dataset` is not incremented, hence `$.getbatch()` was not called and the data was instead loaded from the cache.

Created on 2025-04-17 with reprex v2.1.1
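The caching behaviour can be sketched as follows. The conversion call is hypothetical (same caveats as above about `as_task_regr()` and `target_converter`); the counter is simply a way to observe `$.getbatch()` invocations:

```r
library(torch)
library(mlr3torch)

# Dataset that counts how often $.getbatch() is called
counting_dataset = dataset("counting",
  initialize = function() {
    self$counter = 0
    self$y = torch_randn(5)
  },
  .getbatch = function(i) {
    self$counter = self$counter + 1
    list(y = self$y[i])
  },
  .length = function() 5
)()

# Assumed conversion to a task, as described in this PR
task = as_task_regr(counting_dataset, target = "y",
  target_converter = as.numeric)

head(task$data())  # first access: materializes via $.getbatch()
n = counting_dataset$counter

head(task$data())  # second access: served from the cache
stopifnot(counting_dataset$counter == n)  # counter unchanged
```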
Internally, this works via the `DataBackendLazyTensors` (TODO: describe this).