Training models with torch.Tensor input #2736

@j-bac

Description

Is your feature request related to a problem? Please describe.
Passing external dataloaders to train a model is currently not straightforward. In particular, loading torch.Tensor data and feeding it directly to a model as input does not seem possible, because scvi.data._utils._check_nonnegative_integers does not handle torch.Tensor.
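To illustrate the kind of check that is missing, here is a minimal sketch of a tensor-aware validation. The function name `is_nonnegative_integer_tensor` is hypothetical and this is not the scvi-tools implementation; it only assumes that the check should accept dense tensors whose entries are all finite, non-negative whole numbers.

```python
import torch


def is_nonnegative_integer_tensor(x: torch.Tensor) -> bool:
    """Hypothetical helper (not part of scvi-tools).

    Returns True if every entry of a dense tensor is a finite,
    non-negative whole number, regardless of integer or float dtype.
    """
    if x.is_complex() or x.dtype == torch.bool:
        raise TypeError(f"unsupported dtype for count data: {x.dtype}")
    if x.numel() == 0:
        return True
    # Any negative entry rules the tensor out as count data.
    if bool((x < 0).any()):
        return False
    if x.is_floating_point():
        # Reject NaN/inf, then require every value to equal its rounding.
        if not bool(torch.isfinite(x).all()):
            return False
        return bool((x == x.round()).all())
    # Integer dtypes with no negative entries pass.
    return True


counts = torch.randint(0, 10, (500, 10))
print(is_nonnegative_integer_tensor(counts))
```

The check runs entirely in torch, so it would not need the data to be converted to NumPy first.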

It would be very useful to be able to pass a custom dataloader, a dictionary, or an AnnData object directly to model.train(), without having to copy torch.Tensor data back to NumPy or pandas. Maybe this could be implemented as model.train(data_module=data_module)?
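As a possible interim workaround (a sketch, not a substitute for native tensor support): for CPU tensors, `torch.Tensor.numpy()` returns an array that shares memory with the tensor, so no copy is made. This does not help for tensors on an accelerator, which must be moved to the CPU first.

```python
import torch

counts = torch.randint(0, 10, (500, 10))

# For a CPU tensor, .numpy() returns a view on the same memory, not a copy.
counts_np = counts.numpy()

# Writing through the NumPy view is visible from the tensor.
counts_np[0, 0] = 7
print(int(counts[0, 0]))  # prints 7
```

The resulting array could then be placed in an AnnData layer, at the cost of leaving the torch ecosystem for the setup step.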

Describe the solution you'd like

import scipy.sparse
import torch
import scanpy as sc
import scvi

# Simulated count matrix: 500 cells x 10 genes
counts = torch.randint(0, 10, (500, 10))

# AnnData does not allow torch.Tensor in the .X field,
# so .X holds an empty sparse placeholder of the same shape
adata = sc.AnnData(
    scipy.sparse.csr_matrix(counts.shape),
    layers={"counts": counts},
)

scvi.model.SCVI.setup_anndata(adata, layer="counts")
model = scvi.model.SCVI(adata)
model.train()
