
spaVAE_Batch

Tian Tian edited this page Dec 21, 2023 · 11 revisions

src/spaVAE_Batch

spaVAE model for batch integration.

Wrapper script to run the spaVAE_Batch model.

Parameters:

--data_file: data file name.
--select_genes: number of selected genes for embedding analysis, default = 0, which means no filtering. When set, informative genes are selected using the mean-variance relationship.
--batch_size: mini-batch size, default = "auto", which sets batch size = 128 if sample size <= 1024, 256 if 1024 < sample size <= 2048, and 512 if sample size > 2048.
--maxiter: maximum number of training iterations, default = 5000.
--train_size: proportion of spots used for training; the remainder forms the validation set, default = 0.95.
--patience: patience for early stopping on the validation set, default = 200.
--lr: learning rate, default = 1e-3.
--weight_decay: weight decay coefficient, default = 1e-6.
--noise: coefficient of random Gaussian noise for the encoder, default = 0.
--dropoutE: dropout probability for encoder, default = 0.
--dropoutD: dropout probability for decoder, default = 0.
--encoder_layers: hidden layer sizes of encoder, default = [128, 64].
--GP_dim: dimension of the latent Gaussian process embedding, default = 2.
--Normal_dim: dimension of the latent standard Gaussian embedding, default = 8.
--decoder_layers: hidden layer sizes of decoder, default = [128].
--dynamicVAE: whether to use dynamicVAE to tune the value of beta; if set to False, beta is fixed at its initial value (init_beta).
--init_beta: initial coefficient of the KL loss, default = 10.
--min_beta: minimal coefficient of the KL loss, default = 4.
--max_beta: maximal coefficient of the KL loss, default = 25.
--KL_loss: desired KL divergence value (GP and standard normal combined), default = 0.025. min_beta, max_beta, and KL_loss are used by the dynamicVAE algorithm.
--num_samples: number of samples drawn from the posterior distribution of the latent embedding during training, default = 1.
--fix_inducing_points: fixed or trainable inducing points, default = True, which means inducing points are fixed.
--grid_inducing_points: whether to use 2D grid inducing points or k-means centroids of positions as inducing points, default = True. "True" for 2D grid, "False" for k-means centroids.
--inducing_point_steps: if using 2D grid inducing points, set the number of 2D grid steps, default = None. Needed when grid_inducing_points = True.
--inducing_point_nums: if using k-means centroids on positions, set the number of inducing points, default = None. Needed when grid_inducing_points = False.
--fixed_gp_params: whether the kernel scale is fixed, default = False, which means the kernel scale is trainable.
--loc_range: spatial locations are scaled to the range 0 to loc_range, default = 20. For example, loc_range = 20 means x and y locations are scaled to the range 0 to 20.
--kernel_scale: initial kernel scale, default = 20.
--model_file: file name to save the model weights, default = model.pt.
--final_latent_file: file name to output final latent representations, default = final_latent.txt.
--denoised_counts_file: file name to output denoised counts, default = denoised_mean.txt.
--device: PyTorch device, default = cuda.
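Two of the options above describe simple preprocessing rules. The sketch below restates the documented "auto" batch-size rule and one way to rescale coordinates into [0, loc_range]. The helper names are hypothetical, and the per-axis min-max scaling is an assumption about how loc_range is applied, not a copy of the spaVAE_Batch code.

```python
import numpy as np


def auto_batch_size(n_spots: int) -> int:
    """Documented rule for --batch_size "auto"."""
    if n_spots <= 1024:
        return 128
    if n_spots <= 2048:
        return 256
    return 512


def scale_locations(loc: np.ndarray, loc_range: float = 20.0) -> np.ndarray:
    """Rescale x/y coordinates into [0, loc_range].

    Assumption: each axis is min-max scaled independently.
    """
    loc = np.asarray(loc, dtype=float)
    lo = loc.min(axis=0)
    span = loc.max(axis=0) - lo
    span[span == 0] = 1.0  # avoid division by zero for a constant axis
    return (loc - lo) / span * loc_range
```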

The most critical parameter is inducing_point_steps or inducing_point_nums, which controls the number of inducing points in the Gaussian process prior. Fewer inducing points are more computationally efficient, while more inducing points can capture more complex spatial patterns. If using inducing_point_steps, then n_inducing_points = $(\text{inducing\_point\_steps}+1)^2$.
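The relationship between inducing_point_steps and the number of inducing points follows from placing steps + 1 points along each axis. The sketch below builds such a grid and, for the grid_inducing_points = False case, a plain Lloyd's k-means in NumPy. Assumptions: the grid spans [0, loc_range] on both axes, and the k-means routine is a generic stand-in, not necessarily the clustering call spaVAE_Batch uses; all function names are hypothetical.

```python
import numpy as np


def n_inducing_points(inducing_point_steps: int) -> int:
    # steps + 1 grid points per axis, so (steps + 1)^2 points in 2D
    return (inducing_point_steps + 1) ** 2


def grid_inducing_points(inducing_point_steps: int, loc_range: float = 20.0) -> np.ndarray:
    # Assumption: the grid evenly covers [0, loc_range] on both axes
    axis = np.linspace(0.0, loc_range, inducing_point_steps + 1)
    xx, yy = np.meshgrid(axis, axis)
    return np.column_stack([xx.ravel(), yy.ravel()])


def kmeans_inducing_points(loc, inducing_point_nums: int, n_iter: int = 50, seed: int = 0) -> np.ndarray:
    # Plain Lloyd's k-means on spot positions; centroids become inducing points
    loc = np.asarray(loc, dtype=float)
    rng = np.random.default_rng(seed)
    centers = loc[rng.choice(len(loc), inducing_point_nums, replace=False)].copy()
    for _ in range(n_iter):
        d2 = ((loc[:, None, :] - centers[None, :, :]) ** 2).sum(-1)
        labels = d2.argmin(axis=1)
        for k in range(inducing_point_nums):
            members = loc[labels == k]
            if len(members):  # keep the old center if a cluster empties
                centers[k] = members.mean(axis=0)
    return centers
```

For example, inducing_point_steps = 6 gives a 7 x 7 grid, that is, 49 inducing points.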

Main model functions of spaVAE_Batch.

forward:

Forward pass.

PARAMETERS:

  • x: tensor, mini-batch of spatial locations.
  • y: tensor, mini-batch of preprocessed counts.
  • batch: mini-batch of one-hot encoded batch IDs.
  • raw_y: tensor, mini-batch of raw counts.
  • size_factor: tensor, mini-batch of size factors.
  • num_samples: tensor, number of samplings of the posterior distribution of latent embedding.
  • raw_y and size_factor are used for the negative binomial (NB) likelihood.

RETURNS:

  • Tuple of tensors needed for model training.
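To illustrate how raw_y and size_factor enter the NB likelihood, here is a minimal NumPy sketch of a negative binomial negative log-likelihood. Assumptions: the decoder outputs a per-gene mean that is multiplied by each spot's size factor, dispersion is per gene, and the loss is summed; spaVAE_Batch itself computes its loss in PyTorch, and this function name is hypothetical.

```python
import numpy as np
from math import lgamma

_lgamma = np.vectorize(lgamma)  # elementwise log-gamma


def nb_negative_log_likelihood(raw_y, mean, disp, size_factor, eps=1e-10):
    """NB negative log-likelihood summed over spots and genes.

    raw_y: (n_spots, n_genes) raw counts
    mean: (n_spots, n_genes) decoded mean before size-factor scaling
    disp: (n_genes,) dispersion theta > 0
    size_factor: (n_spots,)
    """
    y = np.asarray(raw_y, dtype=float)
    # Assumption: expected count = size factor * decoded mean
    mu = np.asarray(size_factor, dtype=float)[:, None] * np.asarray(mean, dtype=float)
    theta = np.broadcast_to(np.asarray(disp, dtype=float), y.shape)
    log_p = (
        _lgamma(y + theta) - _lgamma(theta) - _lgamma(y + 1.0)
        + theta * np.log(theta / (theta + mu) + eps)
        + y * np.log(mu / (theta + mu) + eps)
    )
    return -log_p.sum()
```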

batching_latent_samples:

Return latent representation of each spot.

PARAMETERS:

  • X: numpy array, shape (n_spots, 2), location information.
  • Y: numpy array, shape (n_spots, n_genes), preprocessed count matrix.
  • B: numpy array, shape (n_spots, n_batches), one-hot encoded batch IDs.
  • batch_size: default = 512, mini-batch size for loading data into the model.

RETURNS:

  • Numpy array, low-dimensional representation for each spot.
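batching_latent_samples processes spots in mini-batches of batch_size and returns one row per spot. The generic pattern can be sketched as follows; `batched_apply` and the toy encoder in the usage note are hypothetical stand-ins for the model's encoder, not spaVAE_Batch functions.

```python
from typing import Callable

import numpy as np


def batched_apply(fn: Callable[..., np.ndarray], *arrays: np.ndarray, batch_size: int = 512) -> np.ndarray:
    """Apply fn to aligned row slices of the inputs and stack the outputs row-wise."""
    n = arrays[0].shape[0]
    outputs = []
    for start in range(0, n, batch_size):
        stop = min(start + batch_size, n)
        outputs.append(fn(*(a[start:stop] for a in arrays)))
    return np.concatenate(outputs, axis=0)
```

For example, `batched_apply(encoder, X, Y, B, batch_size=512)` with a hypothetical `encoder(x, y, b)` returning per-spot latent means yields an array with n_spots rows, identical to calling the encoder on all spots at once but with bounded memory.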

batching_denoise_counts:

Return denoised counts (decoded) for each spot.

PARAMETERS:

  • X: numpy array, shape (n_spots, 2), location information.
  • Y: numpy array, shape (n_spots, n_genes), preprocessed count matrix.
  • B: numpy array, shape (n_spots, n_batches), one-hot encoded batch IDs.
  • n_samples: number of samples drawn from the posterior distribution of the latent embedding. The denoised counts are the average over these samples.
  • batch_size: default = 512, mini-batch size for loading data into the model.

RETURNS:

  • Numpy array, denoised counts (decoded) for each spot.
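Averaging over n_samples posterior draws is a Monte Carlo estimate of the expected decoded counts. The sketch below assumes a Gaussian latent posterior with given mean and variance, uses the reparameterization z = mean + sqrt(var) * eps, and takes a hypothetical `decode` function; in spaVAE_Batch the GP latent dimensions come from the GP posterior, which this sketch does not model.

```python
import numpy as np


def mc_denoise(z_mean, z_var, decode, n_samples=1, seed=0):
    """Average decode(z) over n_samples draws z ~ N(z_mean, diag(z_var))."""
    rng = np.random.default_rng(seed)
    acc = None
    for _ in range(n_samples):
        z = z_mean + np.sqrt(z_var) * rng.standard_normal(z_mean.shape)
        out = decode(z)
        acc = out if acc is None else acc + out
    return acc / n_samples
```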

differential_expression:

Differential expression (DE) analysis method. DE compares group1 vs group2 across n batches, denoted as batch0 through batch(n-1).

PARAMETERS:

  • group1_idx: spot index of group 1.
  • group2_idx: spot index of group 2.
  • num_denoise_samples: default = 10,000, number of samples drawn in each group.
  • batch_size: default = 512, mini-batch size for loading data into the model.
  • pos: numpy array, shape (n_spots, 2), location information.
  • ncounts: numpy array, shape (n_spots, n_genes), preprocessed count matrix.
  • batch: numpy array, shape (n_spots, n_batches), one-hot encoded batch IDs.
  • gene_name: numpy array, shape (n_genes), string array of gene names.
  • raw_counts: numpy array, shape (n_spots, n_genes), raw count matrix.
  • n_samples: number of samples drawn from the posterior distribution of the latent embedding. The denoised counts are the average over these samples.
  • estimate_pseudocount: default = True, whether to estimate the pseudocount from the data. If False, pseudocount = 0.01.

RETURNS:

  • Pandas DataFrame with the following columns:
  • "LFC": estimated group-wise log fold change.
  • "mean_LFC", "median_LFC", "sd_LFC": mean, median, and standard deviation of the pairwise log fold changes.
  • "prob_DE": probability of differential expression (DE).
  • "prob_not_DE": probability of equal expression (EE).
  • "bayes_factor": Bayes factor = log(DE probability / EE probability).
  • "denoised_mean1_batch{i}" and "denoised_mean2_batch{i}", for i = 0, ..., n-1: average denoised (decoded) counts of group1 and group2 in batch i.
  • "raw_mean1_batch{i}" and "raw_mean2_batch{i}", for i = 0, ..., n-1: average raw counts of group1 and group2 in batch i.

Bayes factor and LFC can be used to prioritize DE genes.
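The summary columns can be derived from paired draws of denoised expression. The sketch below is a minimal illustration under several assumptions not stated above: row i of each group's samples is paired with row i of the other, log fold changes use base 2, a gene counts as DE when |LFC| >= delta (a hypothetical threshold), and probabilities are clipped before the log. It returns a dict of arrays, which can be wrapped in `pandas.DataFrame` if desired; the function name is hypothetical.

```python
import numpy as np


def de_summary(samples1, samples2, pseudocount=0.01, delta=1.0, eps=1e-8):
    """samples1, samples2: (num_denoise_samples, n_genes) denoised expression draws."""
    # Pairwise log2 fold change between matched draws of group 1 and group 2
    lfc = np.log2(samples1 + pseudocount) - np.log2(samples2 + pseudocount)
    prob_de = np.mean(np.abs(lfc) >= delta, axis=0)
    p = np.clip(prob_de, eps, 1 - eps)  # keep the log ratio finite
    return {
        # Group-wise LFC from the group means
        "LFC": np.log2(samples1.mean(0) + pseudocount) - np.log2(samples2.mean(0) + pseudocount),
        "mean_LFC": lfc.mean(0),
        "median_LFC": np.median(lfc, axis=0),
        "sd_LFC": lfc.std(0),
        "prob_DE": prob_de,
        "prob_not_DE": 1 - prob_de,
        "bayes_factor": np.log(p / (1 - p)),
    }
```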

train_model:

Model training function.

PARAMETERS:

  • pos: numpy array, shape (n_spots, 2), location information.
  • ncounts: numpy array, shape (n_spots, n_genes), preprocessed count matrix.
  • raw_counts: numpy array, shape (n_spots, n_genes), raw count matrix.
  • size_factor: numpy array, shape (n_spots), the size factor of each spot, which is needed for the NB loss.
  • batch: numpy array, shape (n_spots, n_batches), one-hot encoded batch IDs.
  • lr: default = 0.001, learning rate for AdamW optimizer.
  • weight_decay: default = 1e-6, weight decay for AdamW optimizer.
  • batch_size: default = 512, mini-batch size.
  • maxiter: default = 5000, maximum number of iterations.
  • patience: default = 200, patience for early stopping.
  • save_model: default = True, whether to save the model weights.
  • model_weights: default = "model.pt", file name to save the model weights.
  • print_kernel_scale: default = True, whether to print current kernel scale during training steps.
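train_size and patience together define a validation split and an early-stopping rule. A minimal sketch, assuming a random permutation split and stopping after `patience` consecutive validation checks without improvement; both helpers are hypothetical and not taken from the spaVAE_Batch source.

```python
import numpy as np


def train_val_split(n_spots, train_size=0.95, seed=0):
    """Randomly split spot indices into training and validation sets."""
    rng = np.random.default_rng(seed)
    perm = rng.permutation(n_spots)
    n_train = int(round(train_size * n_spots))
    return perm[:n_train], perm[n_train:]


class EarlyStopping:
    """Signal a stop when validation loss has not improved for `patience` checks."""

    def __init__(self, patience=200):
        self.patience = patience
        self.best = float("inf")
        self.counter = 0

    def step(self, val_loss):
        """Record one validation loss; return True when training should stop."""
        if val_loss < self.best:
            self.best = val_loss
            self.counter = 0
        else:
            self.counter += 1
        return self.counter >= self.patience
```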

Sparse variational Gaussian process.

kernel_matrix:

Computes the GP kernel matrix $K(x, y)$ for multi-batch data. Here, the first two columns of x and y are spatial locations, and the remaining columns are one-hot encoded batch IDs.

PARAMETERS:

  • x: tensor, position vector x.
  • y: tensor, position vector y.
  • diag_only: whether to compute only the diagonal terms of the kernel matrix.

RETURN:

  • Kernel matrix (only its diagonal when diag_only is True).
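The column layout described above can be illustrated by splitting each input into locations and batch one-hot columns and masking the spatial kernel so that only spots from the same batch interact. Assumptions: a Cauchy kernel $k(x, y) = 1 / (1 + \lVert x - y \rVert^2 / \text{scale})$ with a single shared scale; spaVAE_Batch may parameterize the scale differently (for example per batch). This NumPy function only mirrors the documented name and is not the library code.

```python
import numpy as np


def kernel_matrix(x, y, scale=20.0, diag_only=False):
    """x, y: (n, 2 + n_batches) arrays = [x_loc, y_loc, one-hot batch ID]."""
    loc_x, batch_x = x[:, :2], x[:, 2:]
    loc_y, batch_y = y[:, :2], y[:, 2:]
    if diag_only:
        # Entry i only: k(x_i, y_i) times 1 if x_i and y_i share a batch
        d2 = ((loc_x - loc_y) ** 2).sum(-1)
        same = (batch_x * batch_y).sum(-1)
        return same / (1.0 + d2 / scale)
    d2 = ((loc_x[:, None, :] - loc_y[None, :, :]) ** 2).sum(-1)
    mask = batch_x @ batch_y.T  # 1 where two spots share a batch, else 0
    return mask / (1.0 + d2 / scale)
```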

variational_loss:

Computes the variational loss of the Gaussian process ($L_H$ in equation (4)) for the current mini-batch.

PARAMETERS:

  • x: tensor, shape (batch, 2), auxiliary (location) information for current batch.
  • y: tensor, shape (batch, 1), latent mean vector for current dimension, output by the encoder network.
  • noise: tensor, shape (batch, 1), latent variance vector for current dimension, output by the encoder network.
  • mu_hat: tensor, posterior mean for current dimension (equation (5)).
  • A_hat: tensor, (diagonal of) posterior covariance matrix for current dimension (equation (5)).

RETURN:

  • sum_term, KL_term (variational loss = sum term + KL term)
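In sparse variational GPs, the KL term is typically the KL divergence between the Gaussian variational distribution over inducing-point values and the GP prior at the inducing points. Assuming that form, $\mathrm{KL}(\mathcal{N}(m, S)\,\|\,\mathcal{N}(0, K)) = \tfrac{1}{2}\left(\operatorname{tr}(K^{-1}S) + m^\top K^{-1} m - M + \log\det K - \log\det S\right)$ can be computed as below. How spaVAE_Batch maps mu_hat and A_hat into this formula is not stated here, so treat the function as an illustrative, hypothetical helper.

```python
import numpy as np


def gaussian_kl(m, S, K):
    """KL( N(m, S) || N(0, K) ) for M-dimensional Gaussians."""
    M = m.shape[0]
    trace_term = np.trace(np.linalg.solve(K, S))  # tr(K^{-1} S)
    mahalanobis = m @ np.linalg.solve(K, m)        # m^T K^{-1} m
    _, logdet_K = np.linalg.slogdet(K)
    _, logdet_S = np.linalg.slogdet(S)
    return 0.5 * (trace_term + mahalanobis - M + logdet_K - logdet_S)
```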

approximate_posterior_params:

Computes the posterior parameters for the current mini-batch ($\boldsymbol{\mu}_b^l$ and $\boldsymbol{A}_b^l$ in equation (5)).

PARAMETERS:

  • index_points_test: tensor, testing set of auxiliary (location) information.
  • index_points_train: tensor, training set of auxiliary (location) information.
  • y: tensor, shape (batch, 1), latent mean vector for current dimension, output by the encoder network.
  • noise: tensor, shape (batch, 1), latent variance vector for current dimension, output by the encoder network.

RETURN:

  • mean_vector, B: $\boldsymbol{m}_b^l$ and $\boldsymbol{B}_b^l$ in equation (7).
  • mu_hat, A_hat: $\boldsymbol{\mu}_b^l$ and $\boldsymbol{A}_b^l$ in equation (5).
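A minimal sketch of these quantities uses the standard Titsias-style sparse GP formulas with heteroscedastic noise $\Lambda = \mathrm{diag}(1/\text{noise})$: $\Sigma = (K_{mm} + K_{mb}\Lambda K_{bm})^{-1}$, $\hat{\mu} = K_{mm}\Sigma K_{mb}\Lambda y$, $\hat{A} = K_{mm}\Sigma K_{mm}$, followed by the predictive mean and variance at test locations. Assumptions: these match equations (5) and (7) up to the mini-batch rescaling used in practice (omitted here), the kernel is Cauchy with a shared scale, and inducing points are passed explicitly; the function signature is hypothetical.

```python
import numpy as np


def _cauchy(a, b, scale):
    d2 = ((a[:, None, :] - b[None, :, :]) ** 2).sum(-1)
    return 1.0 / (1.0 + d2 / scale)


def approximate_posterior_params(test_x, train_x, inducing, y, noise, scale=20.0, jitter=1e-6):
    """y, noise: (b,) encoder latent mean and variance for one latent dimension."""
    K_mm = _cauchy(inducing, inducing, scale) + jitter * np.eye(len(inducing))
    K_mb = _cauchy(inducing, train_x, scale)
    K_tm = _cauchy(test_x, inducing, scale)
    prec = 1.0 / noise  # per-spot precision (diagonal of Lambda)
    Sigma = np.linalg.inv(K_mm + (K_mb * prec) @ K_mb.T)
    mu_hat = K_mm @ Sigma @ (K_mb @ (prec * y))
    A_hat = K_mm @ Sigma @ K_mm
    # Predictive moments at test locations (Cauchy kernel has k(x, x) = 1)
    W = K_tm @ np.linalg.inv(K_mm)
    mean_vector = W @ mu_hat
    B = (1.0 - np.einsum("ij,ij->i", W, K_tm)
         + np.einsum("ij,jk,ik->i", W, A_hat, W))
    return mean_vector, B, mu_hat, A_hat
```

Here B is returned as the diagonal of the predictive covariance only, which is usually all a per-spot latent variance needs.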

Kernel functions.

CauchyKernel:

Cauchy kernel.

  • forward: calculate $K(x,y)$. x, y are auxiliary (location) information.
  • forward_diag: calculate diagonal elements of $K(x,y)$. x, y are auxiliary (location) information.
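The difference between forward and forward_diag is that the latter evaluates only $k(x_i, y_i)$ for paired rows, avoiding the full $n \times n$ matrix. A NumPy sketch, assuming the parameterization $k(x, y) = 1 / (1 + \lVert x - y \rVert^2 / \text{scale})$; the function names are hypothetical.

```python
import numpy as np


def cauchy_forward(x, y, scale=20.0):
    """Full kernel matrix with entries 1 / (1 + ||x_i - y_j||^2 / scale)."""
    d2 = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)
    return 1.0 / (1.0 + d2 / scale)


def cauchy_forward_diag(x, y, scale=20.0):
    """Only k(x_i, y_i) for equal-length x and y; O(n) instead of O(n^2)."""
    d2 = ((x - y) ** 2).sum(-1)
    return 1.0 / (1.0 + d2 / scale)
```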

BatchedCauchyKernel:

Batched version of Cauchy kernel (integrating different batches).

  • forward: calculate $x_{batch}y_{batch}K(x,y)$. x, y are auxiliary (location) information, and sample_x, sample_y are one-hot encoded batch index.
  • forward_diag: calculate diagonal elements of $x_{batch}y_{batch}K(x,y)$. x, y are auxiliary (location) information, and sample_x, sample_y are one-hot encoded batch index.

EQKernel:

Exponentiated quadratic kernel.

  • forward: calculate $K(x,y)$. x, y are auxiliary (location) information.
  • forward_diag: calculate diagonal elements of $K(x,y)$. x, y are auxiliary (location) information.
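The exponentiated quadratic kernel follows the same forward / forward_diag split. This sketch assumes $k(x, y) = \exp(-\lVert x - y \rVert^2 / \text{scale})$; whether spaVAE_Batch divides by scale, scale squared, or twice the scale is an assumption here, and the function names are hypothetical.

```python
import numpy as np


def eq_forward(x, y, scale=20.0):
    """Full EQ kernel matrix exp(-||x_i - y_j||^2 / scale)."""
    d2 = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)
    return np.exp(-d2 / scale)


def eq_forward_diag(x, y, scale=20.0):
    """Diagonal entries exp(-||x_i - y_i||^2 / scale) for paired rows."""
    d2 = ((x - y) ** 2).sum(-1)
    return np.exp(-d2 / scale)
```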

BatchedEQKernel:

Batched version of exponentiated quadratic kernel (integrating different batches).

  • forward: calculate $x_{batch}y_{batch}K(x,y)$. x, y are auxiliary (location) information, and sample_x, sample_y are one-hot encoded batch index.
  • forward_diag: calculate diagonal elements of $x_{batch}y_{batch}K(x,y)$. x, y are auxiliary (location) information, and sample_x, sample_y are one-hot encoded batch index.