Skip to content

Commit 71646b0

Browse files
authored
Fix: grouped-label classification when use_labels_groups=True need to use logits =false. (#3805)
close #3804
1 parent b133c19 commit 71646b0

6 files changed

Lines changed: 8 additions & 94 deletions

File tree

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ to [Semantic Versioning]. The full commit history is available in the [commit lo
2727

2828
#### Removed
2929

30+
- Removed grouped-label classification legacy code in {class}`scvi.model.SCANVI`,
31+
{class}`scvi.external.TOTALANVI`, and {class}`scvi.external.METHYLANVI`, {pr}`3805`.
32+
3033
### 1.4.2 (2026-02-26)
3134

3235
#### Added

src/scvi/external/methylvi/_methylanvi_module.py

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,6 @@ class METHYLANVAE(SupervisedModuleClass, METHYLVAE, BSSeqModuleMixin):
5555
If None, initialized to uniform probability over cell types
5656
labels_groups
5757
Label group designations
58-
use_labels_groups
59-
Whether to use the label groups
6058
linear_classifier
6159
If `True`, uses a single linear layer for classification instead of a
6260
multi-layer perceptron.
@@ -86,7 +84,6 @@ def __init__(
8684
dispersion: Literal["region", "region-cell"] = "region",
8785
y_prior=None,
8886
labels_groups: Sequence[int] = None,
89-
use_labels_groups: bool = False,
9087
linear_classifier: bool = False,
9188
classifier_parameters: dict | None = None,
9289
use_batch_norm: Literal["encoder", "decoder", "none", "both"] = "both",
@@ -156,30 +153,7 @@ def __init__(
156153
y_prior if y_prior is not None else (1 / n_labels) * torch.ones(1, n_labels),
157154
requires_grad=False,
158155
)
159-
self.use_labels_groups = use_labels_groups
160156
self.labels_groups = np.array(labels_groups) if labels_groups is not None else None
161-
if self.use_labels_groups:
162-
if labels_groups is None:
163-
raise ValueError("Specify label groups")
164-
unique_groups = np.unique(self.labels_groups)
165-
self.n_groups = len(unique_groups)
166-
if not (unique_groups == np.arange(self.n_groups)).all():
167-
raise ValueError()
168-
self.classifier_groups = Classifier(
169-
n_latent, n_hidden, self.n_groups, n_layers, dropout_rate
170-
)
171-
self.groups_index = torch.nn.ParameterList(
172-
[
173-
torch.nn.Parameter(
174-
torch.tensor(
175-
(self.labels_groups == i).astype(np.uint8),
176-
dtype=torch.uint8,
177-
),
178-
requires_grad=False,
179-
)
180-
for i in range(self.n_groups)
181-
]
182-
)
183157

184158
@auto_move_data
185159
def classify(

src/scvi/external/totalanvi/_module.py

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,6 @@ def __init__(
143143
extra_decoder_kwargs: dict | None = None,
144144
y_prior=None,
145145
labels_groups: Sequence[int] = None,
146-
use_labels_groups: bool = False,
147146
linear_classifier: bool = False,
148147
classifier_parameters: dict | None = None,
149148
):
@@ -234,30 +233,7 @@ def __init__(
234233
y_prior if y_prior is not None else (1 / n_labels) * torch.ones(1, n_labels),
235234
requires_grad=False,
236235
)
237-
self.use_labels_groups = use_labels_groups
238236
self.labels_groups = np.array(labels_groups) if labels_groups is not None else None
239-
if self.use_labels_groups:
240-
if labels_groups is None:
241-
raise ValueError("Specify label groups")
242-
unique_groups = np.unique(self.labels_groups)
243-
self.n_groups = len(unique_groups)
244-
if not (unique_groups == np.arange(self.n_groups)).all():
245-
raise ValueError()
246-
self.classifier_groups = Classifier(
247-
n_latent, n_hidden, self.n_groups, n_layers_encoder, dropout_rate_encoder
248-
)
249-
self.groups_index = torch.nn.ParameterList(
250-
[
251-
torch.nn.Parameter(
252-
torch.tensor(
253-
(self.labels_groups == i).astype(np.uint8),
254-
dtype=torch.uint8,
255-
),
256-
requires_grad=False,
257-
)
258-
for i in range(self.n_groups)
259-
]
260-
)
261237

262238
@auto_move_data
263239
def classify(

src/scvi/module/_scanvae.py

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,6 @@ class SCANVAE(SupervisedModuleClass, VAE):
7070
If None, initialized to uniform probability over cell types
7171
labels_groups
7272
Label group designations
73-
use_labels_groups
74-
Whether to use the label groups
7573
linear_classifier
7674
If `True`, uses a single linear layer for classification instead of a
7775
multi-layer perceptron.
@@ -102,7 +100,6 @@ def __init__(
102100
use_observed_lib_size: bool = True,
103101
y_prior: torch.Tensor | None = None,
104102
labels_groups: Sequence[int] = None,
105-
use_labels_groups: bool = False,
106103
linear_classifier: bool = False,
107104
classifier_parameters: dict | None = None,
108105
use_batch_norm: Literal["encoder", "decoder", "none", "both"] = "both",
@@ -176,30 +173,7 @@ def __init__(
176173
y_prior if y_prior is not None else (1 / n_labels) * torch.ones(1, n_labels),
177174
requires_grad=False,
178175
)
179-
self.use_labels_groups = use_labels_groups
180176
self.labels_groups = np.array(labels_groups) if labels_groups is not None else None
181-
if self.use_labels_groups:
182-
if labels_groups is None:
183-
raise ValueError("Specify label groups")
184-
unique_groups = np.unique(self.labels_groups)
185-
self.n_groups = len(unique_groups)
186-
if not (unique_groups == np.arange(self.n_groups)).all():
187-
raise ValueError()
188-
self.classifier_groups = Classifier(
189-
n_latent, n_hidden, self.n_groups, n_layers, dropout_rate
190-
)
191-
self.groups_index = torch.nn.ParameterList(
192-
[
193-
torch.nn.Parameter(
194-
torch.tensor(
195-
(self.labels_groups == i).astype(np.uint8),
196-
dtype=torch.uint8,
197-
),
198-
requires_grad=False,
199-
)
200-
for i in range(self.n_groups)
201-
]
202-
)
203177

204178
def loss(
205179
self,

src/scvi/module/base/_base_module.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -778,16 +778,7 @@ class SupervisedModuleClass:
778778

779779
@auto_move_data
780780
def classify_helper(self, z):
781-
if self.use_labels_groups:
782-
w_g = self.classifier_groups(z)
783-
unw_y = self.classifier(z)
784-
w_y = torch.zeros_like(unw_y)
785-
for i, group_index in enumerate(self.groups_index):
786-
unw_y_g = unw_y[:, group_index]
787-
w_y[:, group_index] = unw_y_g / (unw_y_g.sum(dim=-1, keepdim=True) + 1e-8)
788-
w_y[:, group_index] *= w_g[:, [i]]
789-
else:
790-
w_y = self.classifier(z)
781+
w_y = self.classifier(z)
791782
return w_y
792783

793784
@auto_move_data

tests/model/test_scanvi.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -167,10 +167,8 @@ def test_scanvi():
167167
assert scanvi_pxr is not scvi_pxr
168168
scanvi_model.train(1)
169169

170-
# Test without label groups
171-
scanvi_model = SCANVI.from_scvi_model(
172-
m, "label_0", labels_key="labels", use_labels_groups=False
173-
)
170+
# Test without label groups (default)
171+
scanvi_model = SCANVI.from_scvi_model(m, "label_0", labels_key="labels")
174172
scanvi_model.train(1)
175173

176174
# test from_scvi_model with size_factor
@@ -283,10 +281,8 @@ def test_scanvi_with_external_indices():
283281
assert scanvi_pxr is not scvi_pxr
284282
scanvi_model.train(1)
285283

286-
# Test without label groups
287-
scanvi_model = SCANVI.from_scvi_model(
288-
m, "label_0", labels_key="labels", use_labels_groups=False
289-
)
284+
# Test without label groups (default)
285+
scanvi_model = SCANVI.from_scvi_model(m, "label_0", labels_key="labels")
290286
scanvi_model.train(1)
291287

292288
# test from_scvi_model with size_factor

0 commit comments

Comments
 (0)