Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 42 additions & 29 deletions src/scvi/module/_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,52 +236,65 @@ def __init__(

encoder_cat_list = cat_list if encode_covariates else None
_extra_encoder_kwargs = extra_encoder_kwargs or {}
z_encoder_kwargs = {
"n_cat_list": encoder_cat_list,
"n_layers": n_layers,
"n_hidden": n_hidden,
"dropout_rate": dropout_rate,
"distribution": latent_distribution,
"inject_covariates": deeply_inject_covariates,
"use_batch_norm": use_batch_norm_encoder,
"use_layer_norm": use_layer_norm_encoder,
"var_activation": var_activation,
"return_dist": True,
}
z_encoder_kwargs.update(_extra_encoder_kwargs)
self.z_encoder = Encoder(
n_input_encoder,
n_latent,
n_cat_list=encoder_cat_list,
n_layers=n_layers,
n_hidden=n_hidden,
dropout_rate=dropout_rate,
distribution=latent_distribution,
inject_covariates=deeply_inject_covariates,
use_batch_norm=use_batch_norm_encoder,
use_layer_norm=use_layer_norm_encoder,
var_activation=var_activation,
return_dist=True,
**_extra_encoder_kwargs,
**z_encoder_kwargs,
)
# l encoder goes from n_input-dimensional data to 1-d library size
# n_layers is fixed to 1 for the library encoder
l_encoder_extra_kwargs = {
k: v for k, v in _extra_encoder_kwargs.items() if k != "n_layers"
}
l_encoder_kwargs = {
"n_layers": 1,
"n_cat_list": encoder_cat_list,
"n_hidden": n_hidden,
"dropout_rate": dropout_rate,
"inject_covariates": deeply_inject_covariates,
"use_batch_norm": use_batch_norm_encoder,
"use_layer_norm": use_layer_norm_encoder,
"var_activation": var_activation,
"return_dist": True,
}
l_encoder_kwargs.update(l_encoder_extra_kwargs)
self.l_encoder = Encoder(
n_input_encoder,
1,
n_layers=1,
n_cat_list=encoder_cat_list,
n_hidden=n_hidden,
dropout_rate=dropout_rate,
inject_covariates=deeply_inject_covariates,
use_batch_norm=use_batch_norm_encoder,
use_layer_norm=use_layer_norm_encoder,
var_activation=var_activation,
return_dist=True,
**_extra_encoder_kwargs,
**l_encoder_kwargs,
)
n_input_decoder = n_latent + n_continuous_cov
if self.batch_representation == "embedding":
n_input_decoder += batch_dim

_extra_decoder_kwargs = extra_decoder_kwargs or {}
decoder_kwargs = {
"n_cat_list": cat_list,
"n_layers": n_layers,
"n_hidden": n_hidden,
"inject_covariates": deeply_inject_covariates,
"use_batch_norm": use_batch_norm_decoder,
"use_layer_norm": use_layer_norm_decoder,
"scale_activation": "softplus" if use_size_factor_key else "softmax",
}
decoder_kwargs.update(_extra_decoder_kwargs)
self.decoder = DecoderSCVI(
n_input_decoder,
n_input,
n_cat_list=cat_list,
n_layers=n_layers,
n_hidden=n_hidden,
inject_covariates=deeply_inject_covariates,
use_batch_norm=use_batch_norm_decoder,
use_layer_norm=use_layer_norm_decoder,
scale_activation="softplus" if use_size_factor_key else "softmax",
**_extra_decoder_kwargs,
**decoder_kwargs,
)

def _get_inference_input(
Expand Down
11 changes: 11 additions & 0 deletions tests/model/test_scvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -1647,3 +1647,14 @@ def test_scvi_mlflow(
mlflow_log_artifact(
model_path + "/" + run_name + "_" + SAVE_KEYS.MODEL_FNAME, run_id=model.run_id
)


def test_scvi_asymmetric():
adata = synthetic_iid()
SCVI.setup_anndata(
adata,
batch_key="batch",
labels_key="labels",
)
model = SCVI(adata, extra_encoder_kwargs={"n_layers": 4}, extra_decoder_kwargs={"n_layers": 1})
model.train(1)
Loading