
Jax version of VAEMixin #2779

@justjhong


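It should be straightforward to reimplement `VAEMixin` for JAX models. This will require a separate class, since the forward-pass call is completely different: JAX modules are invoked through `module.apply(...)` with explicit parameters, state and RNG keys rather than being called like a PyTorch module.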
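To illustrate the calling convention: in a JAX/Flax model the parameters, mutable state and RNG keys are not stored on a callable object; they are passed explicitly to a pure function on every forward pass. The following is a minimal pure-JAX sketch, not scvi-tools code; `init_params`, `apply` and the toy linear decoder are hypothetical and only demonstrate why the forward pass has to be threaded through `apply` with explicit `rngs`.

```python
import jax
import jax.numpy as jnp


def init_params(key, n_latent=2, n_genes=3):
    """Hypothetical toy decoder parameters: a single linear layer."""
    return {
        "w": jax.random.normal(key, (n_latent, n_genes)),
        "b": jnp.zeros((n_genes,)),
    }


def apply(variables, x, rngs):
    """Pure forward pass: everything it needs is passed in explicitly.

    Samples z ~ N(0, I) with the caller-supplied key, decodes it, and returns
    the per-cell squared reconstruction error (a stand-in for a real likelihood).
    """
    params = variables["params"]
    z = jax.random.normal(rngs["z"], (x.shape[0], params["w"].shape[0]))
    x_hat = z @ params["w"] + params["b"]
    return jnp.sum((x - x_hat) ** 2, axis=-1)


params = init_params(jax.random.PRNGKey(0))
x = jnp.ones((4, 3))
loss_a = apply({"params": params}, x, rngs={"z": jax.random.PRNGKey(1)})
loss_b = apply({"params": params}, x, rngs={"z": jax.random.PRNGKey(1)})
loss_c = apply({"params": params}, x, rngs={"z": jax.random.PRNGKey(2)})
```

Reusing the same key reproduces the same stochastic forward pass, and a different key gives a different sample; a JAX mixin has to supply parameters, state and keys this way on every call.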

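Example implementation of `get_reconstruction_error` (the class name below is a placeholder):

```python
from __future__ import annotations

from anndata import AnnData


class JaxVAEMixin:  # placeholder name for the new JAX-specific mixin
    def get_reconstruction_error(
        self,
        adata: AnnData | None = None,
        indices: list[int] | None = None,
        batch_size: int | None = None,
        **kwargs,
    ) -> float:
        """Return the negative mean per-cell reconstruction loss."""
        adata = self._validate_anndata(adata)
        # iter_ndarray=True yields NumPy arrays instead of torch tensors.
        dataloader = self._make_data_loader(
            adata=adata, indices=indices, batch_size=batch_size, iter_ndarray=True
        )

        reconstruction_loss_sum = 0.0
        for batch in dataloader:
            # Flax modules are pure: parameters, mutable state and RNG keys
            # are passed explicitly on every forward pass.
            vars_in = {"params": self.module.params, **self.module.state}
            outputs = self.module.apply(vars_in, batch, rngs=self.module.rngs, **kwargs)
            # Assumes outputs = (inference_outputs, generative_outputs, LossOutput).
            reconstruction_loss_sum += outputs[2].reconstruction_loss_sum.item()

        # Normalize by the number of cells and flip the sign of the loss.
        return -(reconstruction_loss_sum / len(dataloader.dataset))
```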
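The aggregation step (sum each batch's summed loss, divide by the total number of cells, negate) can be checked in isolation. A minimal sketch, assuming each batch's reconstruction loss is a per-cell negative log-likelihood as in scvi-tools; `gaussian_nll` and `reconstruction_error` are hypothetical helpers, not scvi-tools APIs, and the unit-variance Gaussian is only a toy likelihood.

```python
import jax.numpy as jnp


def gaussian_nll(x, mean):
    """Per-cell negative log-likelihood under N(mean, 1), summed over genes."""
    return 0.5 * jnp.sum((x - mean) ** 2 + jnp.log(2 * jnp.pi), axis=-1)


def reconstruction_error(batches, mean):
    """Negative mean per-cell reconstruction loss, accumulated batch by batch."""
    reconstruction_loss_sum = 0.0
    n_obs = 0
    for x in batches:
        # Sum the per-cell losses within the batch (cf. `reconstruction_loss_sum`).
        reconstruction_loss_sum += float(gaussian_nll(x, mean).sum())
        n_obs += x.shape[0]
    # Divide by the number of cells, not the number of batches, then negate.
    return -(reconstruction_loss_sum / n_obs)


x = jnp.arange(21, dtype=jnp.float32).reshape(7, 3) / 10
mean = jnp.zeros(3)
# Uneven minibatches (3, 3 and 1 cells), like a dataloader's final partial batch.
batches = [x[:3], x[3:6], x[6:]]
avg_log_lik = reconstruction_error(batches, mean)
```

Dividing by the number of batches instead would bias the result whenever the last batch is smaller, which is why the normalization uses the number of cells in the dataset.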

Metadata

Assignees: No one assigned
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests