@@ -70,8 +70,6 @@ class SCANVAE(SupervisedModuleClass, VAE):
7070 If None, initialized to uniform probability over cell types
7171 labels_groups
7272 Label group designations
73- use_labels_groups
74- Whether to use the label groups
7573 linear_classifier
7674 If `True`, uses a single linear layer for classification instead of a
7775 multi-layer perceptron.
@@ -102,7 +100,6 @@ def __init__(
102100 use_observed_lib_size : bool = True ,
103101 y_prior : torch .Tensor | None = None ,
104102 labels_groups : Sequence [int ] = None ,
105- use_labels_groups : bool = False ,
106103 linear_classifier : bool = False ,
107104 classifier_parameters : dict | None = None ,
108105 use_batch_norm : Literal ["encoder" , "decoder" , "none" , "both" ] = "both" ,
@@ -176,30 +173,7 @@ def __init__(
176173 y_prior if y_prior is not None else (1 / n_labels ) * torch .ones (1 , n_labels ),
177174 requires_grad = False ,
178175 )
179- self .use_labels_groups = use_labels_groups
180176 self .labels_groups = np .array (labels_groups ) if labels_groups is not None else None
181- if self .use_labels_groups :
182- if labels_groups is None :
183- raise ValueError ("Specify label groups" )
184- unique_groups = np .unique (self .labels_groups )
185- self .n_groups = len (unique_groups )
186- if not (unique_groups == np .arange (self .n_groups )).all ():
187- raise ValueError ()
188- self .classifier_groups = Classifier (
189- n_latent , n_hidden , self .n_groups , n_layers , dropout_rate
190- )
191- self .groups_index = torch .nn .ParameterList (
192- [
193- torch .nn .Parameter (
194- torch .tensor (
195- (self .labels_groups == i ).astype (np .uint8 ),
196- dtype = torch .uint8 ,
197- ),
198- requires_grad = False ,
199- )
200- for i in range (self .n_groups )
201- ]
202- )
203177
204178 def loss (
205179 self ,
0 commit comments