From 2984010eaabf91d7f68bcbf4323b875558712c67 Mon Sep 17 00:00:00 2001 From: Emanuele Bezzi Date: Wed, 24 Jan 2024 16:31:20 -0800 Subject: [PATCH 1/5] Initial mrvi script --- tools/models/scvi/mrvi-config.yaml | 37 +++++++++++++++ tools/models/scvi/mrvi.py | 75 ++++++++++++++++++++++++++++++ 2 files changed, 112 insertions(+) create mode 100644 tools/models/scvi/mrvi-config.yaml create mode 100644 tools/models/scvi/mrvi.py diff --git a/tools/models/scvi/mrvi-config.yaml b/tools/models/scvi/mrvi-config.yaml new file mode 100644 index 000000000..374c5507a --- /dev/null +++ b/tools/models/scvi/mrvi-config.yaml @@ -0,0 +1,37 @@ +census: + organism: + "homo_sapiens" + obs_query: # Use if you want to train on a subset of the model + null + obs_query_model: # Required when loading data for model training. Do not change. + 'is_primary_data == True and nnz >= 300' +hvg: + top_n_hvg: + 500 + hvg_batch: + [suspension_type, assay] +anndata: + batch_key: + [dataset_id, assay, suspension_type, donor_id] + model_filename: + anndata_model.h5ad +model: + filename: "scvi.model" + n_hidden: 512 + n_latent: 200 + n_layers: 1 + dropout_rate: 0.1 +train: + max_epochs: 20 + batch_size: 1048 + train_size: 0.95 + early_stopping: True + trainer: + early_stopping_patience: 2 + early_stopping_monitor: validation_loss # should be validation_loss - see https://github.com/chanzuckerberg/cellxgene-census/issues/777#issuecomment-1743196837 + check_val_every_n_epoch: 1 + multi_gpu: False + num_workers: 4 + devices: [0, 1, 2, 3] +training_plan: + lr: 1.0e-4 \ No newline at end of file diff --git a/tools/models/scvi/mrvi.py b/tools/models/scvi/mrvi.py new file mode 100644 index 000000000..d03ca8ed0 --- /dev/null +++ b/tools/models/scvi/mrvi.py @@ -0,0 +1,75 @@ +import scvi_v2 + +print(scvi_v2.__file__) + + +import torch + +torch.manual_seed(0) + +import anndata as ad +import flax.linen as nn +import yaml +from lightning.pytorch.loggers import TensorBoardLogger + +file = "mrvi-config.yaml" + +if __name__ == "__main__": + with open(file) as f: + config = yaml.safe_load(f) + + adata_config = config["anndata"] + filename = adata_config.get("model_filename") + + scvi_dataset = ad.read_h5ad(filename) + + train_kwargs = { + "early_stopping": True, + "plan_kwargs": {"lr": 1e-3, "n_epochs_kl_warmup": 20}, + } + + model_kwargs = { + "n_latent": 100, + "n_latent_u": 20, + "qz_nn_flavor": "attention", + "px_nn_flavor": "attention", + "qz_kwargs": {"use_map": False, "stop_gradients": False, "stop_gradients_mlp": True, "dropout_rate": 0.03}, + "px_kwargs": { + "stop_gradients": False, + "stop_gradients_mlp": True, + "h_activation": nn.softmax, + "dropout_rate": 0.03, + "low_dim_batch": True, + }, + "learn_z_u_prior_scale": False, + "z_u_prior": False, + "u_prior_mixture": True, + "u_prior_mixture_k": 100, + } + + scvi_dataset.obs["nuisance"] = ( + # scvi_dataset.obs['dataset_id'].astype(str) + '_' + + scvi_dataset.obs["assay"].astype(str) + + "_" + + scvi_dataset.obs["suspension_type"].astype(str) + ) + scvi_dataset.obs["sample"] = ( + scvi_dataset.obs["dataset_id"].astype(str) + "_" + scvi_dataset.obs["donor_id"].astype(str) + ) + + scvi_v2.MrVI.setup_anndata(scvi_dataset, sample_key="sample", batch_key="nuisance", labels_key="cell_type") + mrvi_model = scvi_v2.MrVI(scvi_dataset, **model_kwargs) + + logger = TensorBoardLogger("tb_logs", name="my_model") + + mrvi_model.train(max_epochs=50, batch_size=4096, use_gpu=True, accelerator="gpu", devices=1, **train_kwargs) + + mrvi_model.save(filename) + + # # Get z representation + # adata.obsm["X_mrvi_z"] = mrvi_model.get_latent_representation(give_z=True) + # # Get u representation + # adata.obsm["X_mrvi_u"] = mrvi_model.get_latent_representation(give_z=False) + # sc.pp.neighbors(adata, use_rep="X_mrvi_u", key_added="neighbors_mrvi", method='rapids', n_neighbors=30) + # sc.tl.umap(adata, neighbors_key="neighbors_mrvi", method='rapids') + # sc.pl.umap(adata, color=['dataset_id', 'cell_subclass', 'suspension_type', 'sex'], ncols=1, frameon=False, wspace=0.4, title='mrVI (Census)') From e70fce8ff2664755cc45833059bbb7d7beb92e23 Mon Sep 17 00:00:00 2001 From: Emanuele Bezzi Date: Wed, 24 Jan 2024 16:34:38 -0800 Subject: [PATCH 2/5] adjustments --- tools/models/scvi/mrvi-config.yaml | 22 ++-------------------- tools/models/scvi/mrvi.py | 11 +++++++++-- 2 files changed, 11 insertions(+), 22 deletions(-) diff --git a/tools/models/scvi/mrvi-config.yaml b/tools/models/scvi/mrvi-config.yaml index 374c5507a..5d071557a 100644 --- a/tools/models/scvi/mrvi-config.yaml +++ b/tools/models/scvi/mrvi-config.yaml @@ -7,7 +7,7 @@ census: 'is_primary_data == True and nnz >= 300' hvg: top_n_hvg: - 500 + 5000 hvg_batch: [suspension_type, assay] anndata: @@ -16,22 +16,4 @@ anndata: model_filename: anndata_model.h5ad model: - filename: "scvi.model" - n_hidden: 512 - n_latent: 200 - n_layers: 1 - dropout_rate: 0.1 -train: - max_epochs: 20 - batch_size: 1048 - train_size: 0.95 - early_stopping: True - trainer: - early_stopping_patience: 2 - early_stopping_monitor: validation_loss # should be validation_loss - see https://github.com/chanzuckerberg/cellxgene-census/issues/777#issuecomment-1743196837 - check_val_every_n_epoch: 1 - multi_gpu: False - num_workers: 4 - devices: [0, 1, 2, 3] -training_plan: - lr: 1.0e-4 \ No newline at end of file + filename: "mrvi.model" \ No newline at end of file diff --git a/tools/models/scvi/mrvi.py b/tools/models/scvi/mrvi.py index d03ca8ed0..bf2572061 100644 --- a/tools/models/scvi/mrvi.py +++ b/tools/models/scvi/mrvi.py @@ -57,14 +57,21 @@ scvi_dataset.obs["dataset_id"].astype(str) + "_" + scvi_dataset.obs["donor_id"].astype(str) ) + model_config = config.get("model") + n_hidden = model_config.get("n_hidden") + n_latent = model_config.get("n_latent") + n_layers = model_config.get("n_layers") + dropout_rate = model_config.get("dropout_rate") + output_filename = model_config.get("filename") + scvi_v2.MrVI.setup_anndata(scvi_dataset, sample_key="sample", batch_key="nuisance", labels_key="cell_type") mrvi_model = scvi_v2.MrVI(scvi_dataset, **model_kwargs) - logger = TensorBoardLogger("tb_logs", name="my_model") + logger = TensorBoardLogger("mrvi_tb_logs", name="mrvi_50_epochs") mrvi_model.train(max_epochs=50, batch_size=4096, use_gpu=True, accelerator="gpu", devices=1, **train_kwargs) - mrvi_model.save(filename) + mrvi_model.save(output_filename) # # Get z representation # adata.obsm["X_mrvi_z"] = mrvi_model.get_latent_representation(give_z=True) From 723a150507d4770b379b7d768033a4db337abd19 Mon Sep 17 00:00:00 2001 From: Emanuele Bezzi Date: Thu, 25 Jan 2024 12:01:23 -0800 Subject: [PATCH 3/5] quick test --- tools/models/scvi/mrvi.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tools/models/scvi/mrvi.py b/tools/models/scvi/mrvi.py index bf2572061..67f48535e 100644 --- a/tools/models/scvi/mrvi.py +++ b/tools/models/scvi/mrvi.py @@ -25,9 +25,10 @@ train_kwargs = { "early_stopping": True, - "plan_kwargs": {"lr": 1e-3, "n_epochs_kl_warmup": 20}, } + plan_kwargs = {"lr": 1e-3, "n_epochs_kl_warmup": 20} + model_kwargs = { "n_latent": 100, "n_latent_u": 20, @@ -69,7 +70,7 @@ logger = TensorBoardLogger("mrvi_tb_logs", name="mrvi_50_epochs") - mrvi_model.train(max_epochs=50, batch_size=4096, use_gpu=True, accelerator="gpu", devices=1, **train_kwargs) + mrvi_model.train(max_epochs=50, batch_size=4096, use_gpu=True, plan_kwargs=plan_kwargs, **train_kwargs) mrvi_model.save(output_filename) From d6c8f171ed4c76bcab36f14c1486614797c3a86d Mon Sep 17 00:00:00 2001 From: Emanuele Bezzi Date: Mon, 5 Feb 2024 11:59:15 -0800 Subject: [PATCH 4/5] Purge bad arg --- tools/models/scvi/mrvi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/models/scvi/mrvi.py b/tools/models/scvi/mrvi.py index 67f48535e..23390a34d 100644 --- a/tools/models/scvi/mrvi.py +++ b/tools/models/scvi/mrvi.py @@ -70,7 +70,7 @@ logger = TensorBoardLogger("mrvi_tb_logs", name="mrvi_50_epochs") - mrvi_model.train(max_epochs=50, batch_size=4096, use_gpu=True, plan_kwargs=plan_kwargs, **train_kwargs) + mrvi_model.train(max_epochs=50, batch_size=4096, plan_kwargs=plan_kwargs, **train_kwargs) mrvi_model.save(output_filename) From efc86abf5e298564e275c0de60ebdac26a6436e3 Mon Sep 17 00:00:00 2001 From: Emanuele Bezzi Date: Tue, 6 Feb 2024 09:41:53 -0800 Subject: [PATCH 5/5] Simplify --- tools/models/scvi/mrvi-generate-embedding.py | 69 ++++++++++++++++++++ 1 file changed, 69 insertions(+) create mode 100644 tools/models/scvi/mrvi-generate-embedding.py diff --git a/tools/models/scvi/mrvi-generate-embedding.py b/tools/models/scvi/mrvi-generate-embedding.py new file mode 100644 index 000000000..ffc5afd18 --- /dev/null +++ b/tools/models/scvi/mrvi-generate-embedding.py @@ -0,0 +1,69 @@ +import scvi_v2 + +print(scvi_v2.__file__) + + +import torch + +torch.manual_seed(0) + +import anndata as ad +import numpy as np +import yaml + +file = "mrvi-config.yaml" + +if __name__ == "__main__": + with open(file) as f: + config = yaml.safe_load(f) + + adata_config = config["anndata"] + filename = adata_config.get("model_filename") + + census_config = config["census"] + experiment_name = census_config.get("organism") + + scvi_dataset = ad.read_h5ad(filename) + + scvi_dataset.obs["nuisance"] = ( + # scvi_dataset.obs['dataset_id'].astype(str) + '_' + + scvi_dataset.obs["assay"].astype(str) + + "_" + + scvi_dataset.obs["suspension_type"].astype(str) + ) + scvi_dataset.obs["sample"] = ( + scvi_dataset.obs["dataset_id"].astype(str) + "_" + scvi_dataset.obs["donor_id"].astype(str) + ) + + # hv = pd.read_pickle("hv_genes.pkl") + # hv_idx = hv[hv].index + + # census = cellxgene_census.open_soma(census_version="2023-12-15") + + # obs_query = None # not for now + + # query = census["census_data"][experiment_name].axis_query( + # measurement_name="RNA", + # obs_query=obs_query, + # var_query=soma.AxisQuery(coords=(list(hv_idx),)), + # ) + + # idx = query.obs(column_names=["soma_joinid"]).concat().to_pandas().index.to_numpy() + + # adata = query.to_anndata(X_name="raw") + + model_config = config.get("model") + model_filename = model_config.get("filename") + + # May or may not be necessary + # scvi_v2.MrVI.setup_anndata(scvi_dataset, sample_key="sample", batch_key="nuisance", labels_key="cell_type") + + mrvi_model = scvi_v2.MrVI.load("mrvi.model", adata=scvi_dataset) + + latent = mrvi_model.get_latent_representation(give_z=False) + + # with open("mrvi-latent-idx.npy", "wb") as f: + # np.save(f, idx) + + with open("mrvi-latent.npy", "wb") as f: + np.save(f, latent)