@@ -61,13 +61,13 @@ def pairwise_align(sliceA, sliceB, alpha = 0.1, G_init = None, a_distribution =
6161 return pi , logw ['fgw_dist' ]
6262 return pi
6363
64- def center_align (A , slices , lmbda , alpha = 0.1 , n_components = 15 , threshold = 0.001 , max_iter = 10 , norm = False , random_seed = None , pis_init = None , verbose = False ):
64+ def center_align (A , slices , lmbda = None , alpha = 0.1 , n_components = 15 , threshold = 0.001 , max_iter = 10 , norm = False , random_seed = None , pis_init = None , verbose = False ):
6565 """
6666 Computes center alignment of slices.
6767
6868 param: A - Initialization of starting AnnData Spatial Object; Make sure to include gene expression AND spatial info
6969 param: slices - List of slices (AnnData objects) used to calculate center alignment
70- param: lmbda - List of probability weights assigned to each slice
70+ param: lmbda - List of probability weights assigned to each slice; default is uniform weights
7171 param: n_components - Number of components in NMF decomposition
7272 param: threshold - Threshold for convergence of W and H
7373 param: max_iter - maximum number of iterations for solving for center slice
@@ -81,6 +81,9 @@ def center_align(A, slices, lmbda, alpha = 0.1, n_components = 15, threshold = 0
8181 return: pi - List of pairwise alignment mappings of the center slice (rows) to each input slice (columns)
8282 """
8383
84+ if lmbda is None :
85+ lmbda = len (slices )* [1 / len (slices )]
86+
8487 # get common genes
8588 common_genes = A .var .index
8689 for s in slices :
@@ -104,9 +107,13 @@ def center_align(A, slices, lmbda, alpha = 0.1, n_components = 15, threshold = 0
104107 H = model .components_
105108 center_coordinates = A .obsm ['spatial' ]
106109
110+ if not isinstance (center_coordinates , np .ndarray ):
111+ print ("Warning: A.obsm['spatial'] is not of type numpy array ." )
112+
107113 # Initialize center_slice
108114 center_slice = anndata .AnnData (np .dot (W ,H ))
109115 center_slice .var .index = common_genes
116+ center_slice .obs .index = A .obs .index
110117 center_slice .obsm ['spatial' ] = center_coordinates
111118
112119 # Minimize R
0 commit comments