
Add feature that allows training a LearnerTorch on a task whose target is also a lazy tensor #357

Open
@sebffischer

Description


The general approach would be to:

  • Define a TaskClassifTorch whose target is a lazy_tensor (of dtype long)
  • Define a TaskRegrTorch whose target is a lazy_tensor (of dtype float)

Because we don't want to reimplement everything (measures, learners), we need a way to convert these into a TaskClassif and a TaskRegr, respectively.

For this we need a custom DataBackend, DataBackendTorch, that returns the target as a numeric vector (regr) or a factor (classif).
In addition, DataBackendTorch should provide a method that directly retrieves the underlying tensor data, which the LearnerTorch will use when iterating over batches.
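The dual view of the data (a standard target for measures, raw data for batching) can be sketched in base R. This is a minimal illustration under stated assumptions. Features are a plain numeric matrix standing in for tensor data, and class targets are stored as 0-based integer codes, as a long tensor would hold them. `make_backend_torch()` and its `data()`/`tensor()` methods are hypothetical names and do not follow the real mlr3 `DataBackend` R6 interface.

```r
# Hypothetical sketch: NOT the mlr3 DataBackend API.
# x: numeric matrix (one row per observation), standing in for tensor data
# y: 0-based integer class codes (classif) or numeric values (regr)
# class_names: level labels for classification; NULL for regression
make_backend_torch = function(x, y, class_names = NULL) {
  stopifnot(nrow(x) == length(y))
  list(
    # Standard view for measures and other learners: the target as a
    # factor (classif) or a numeric vector (regr)
    data = function(rows = seq_along(y)) {
      target = if (is.null(class_names)) {
        as.numeric(y[rows])
      } else {
        factor(class_names[y[rows] + 1L], levels = class_names)
      }
      data.frame(row_id = rows, y = target)
    },
    # Raw view for the torch learner: untouched features and target codes,
    # as they would be fed into a batch
    tensor = function(rows = seq_along(y)) {
      list(x = x[rows, , drop = FALSE], y = y[rows])
    }
  )
}

b = make_backend_torch(
  x = matrix(runif(12), nrow = 4),
  y = c(0L, 2L, 1L, 0L),
  class_names = c("zero", "one", "two")
)
b$data(1:2)$y    # factor values "zero", "two" with all three levels
b$tensor(1:2)$y  # raw integer codes 0, 2
```

Both views read from the same storage, so a learner would never need to convert the factor back into class codes.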

We then need the converters

as_task_classif.TaskClassifTorch
as_task_regr.TaskRegrTorch
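Because the converter names follow S3 method naming, the dispatch can be sketched in base R. Everything here is hypothetical. `new_torch_task()` and the generic `to_standard_task()` are illustrative names, and the tasks are plain lists with a class attribute rather than mlr3 R6 objects. Only the target conversion is shown; the features would presumably remain lazy tensors.

```r
# Hypothetical sketch of the converter idea using base R S3 dispatch.
# Task objects are plain lists with a class attribute, not mlr3 R6 tasks.
new_torch_task = function(x, y, type, class_names = NULL) {
  structure(
    list(x = x, y = y, class_names = class_names),
    class = c(paste0("Task", type, "Torch"), "TaskTorch")
  )
}

# Generic: dispatches on the class of the torch task
to_standard_task = function(task, ...) UseMethod("to_standard_task")

# Classification: 0-based long class codes become a factor target
to_standard_task.TaskClassifTorch = function(task, ...) {
  y = factor(task$class_names[task$y + 1L], levels = task$class_names)
  structure(list(data = data.frame(y = y), target = "y"), class = "TaskClassif")
}

# Regression: float targets become a numeric target column
to_standard_task.TaskRegrTorch = function(task, ...) {
  y = as.numeric(task$y)
  structure(list(data = data.frame(y = y), target = "y"), class = "TaskRegr")
}

tc = new_torch_task(matrix(0, 3, 2), c(1L, 0L, 1L), "Classif", c("cat", "dog"))
std = to_standard_task(tc)
class(std)       # "TaskClassif"
levels(std$data$y)
```

Dispatch uses the class attribute. Supporting another torch task type only requires another method.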

With this in place, the complete code for training on MNIST (e.g. via torchvision::dataset_mnist()) would be:

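library(mlr3torch)
library(torchvision)

# Proposed interface: convert a torch dataset directly into a classification task
ds = dataset_mnist()
task = as_task_classif(ds, target = "y", input_shapes = list(...))
learner = lrn("classif.alexnet")
learner$train(task)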

Metadata


Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
