-
Notifications
You must be signed in to change notification settings - Fork 450
Expand file tree
/
Copy path_vae.py
More file actions
906 lines (806 loc) · 34.9 KB
/
_vae.py
File metadata and controls
906 lines (806 loc) · 34.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
from __future__ import annotations
import logging
import warnings
from typing import TYPE_CHECKING
import numpy as np
import torch
from torch.nn.functional import one_hot
from scvi import REGISTRY_KEYS, settings
from scvi.data._constants import ADATA_MINIFY_TYPE
from scvi.module._constants import MODULE_KEYS
from scvi.module.base import (
BaseMinifiedModeModuleClass,
EmbeddingModuleMixin,
LossOutput,
auto_move_data,
)
from scvi.utils import unsupported_if_adata_minified
if TYPE_CHECKING:
from collections.abc import Callable
from typing import Literal
from torch.distributions import Distribution
logger = logging.getLogger(__name__)
class VAE(EmbeddingModuleMixin, BaseMinifiedModeModuleClass):
"""Variational auto-encoder :cite:p:`Lopez18`.
Parameters
----------
n_input
Number of input features.
n_batch
Number of batches. If ``0``, no batch correction is performed.
n_labels
Number of labels.
n_hidden
Number of nodes per hidden layer. Passed into :class:`~scvi.nn.Encoder` and
:class:`~scvi.nn.DecoderSCVI`.
n_latent
Dimensionality of the latent space.
n_layers
Number of hidden layers. Passed into :class:`~scvi.nn.Encoder` and
:class:`~scvi.nn.DecoderSCVI`.
n_continuous_cov
Number of continuous covariates.
n_cats_per_cov
A list of integers containing the number of categories for each categorical covariate.
dropout_rate
Dropout rate. Passed into :class:`~scvi.nn.Encoder` but not :class:`~scvi.nn.DecoderSCVI`.
dispersion
Flexibility of the dispersion parameter when ``gene_likelihood`` is either ``"nb"`` or
``"zinb"``. One of the following:
* ``"gene"``: parameter is constant per gene across cells.
* ``"gene-batch"``: parameter is constant per gene per batch.
* ``"gene-label"``: parameter is constant per gene per label.
* ``"gene-cell"``: parameter is constant per gene per cell.
log_variational
If ``True``, use :func:`~torch.log1p` on input data before encoding for numerical stability
(not normalization).
gene_likelihood
Distribution to use for reconstruction in the generative process. One of the following:
* ``"nb"``: :class:`~scvi.distributions.NegativeBinomial`.
* ``"zinb"``: :class:`~scvi.distributions.ZeroInflatedNegativeBinomial`.
* ``"poisson"``: :class:`~scvi.distributions.Poisson`.
* ``"normal"``: :class:`~torch.distributions.Normal`.
latent_distribution
Distribution to use for the latent space. One of the following:
* ``"normal"``: isotropic normal.
* ``"ln"``: logistic normal with normal params N(0, 1).
encode_covariates
If ``True``, covariates are concatenated to gene expression prior to passing through
the encoder(s). Else, only the gene expression is used.
deeply_inject_covariates
If ``True`` and ``n_layers > 1``, covariates are concatenated to the outputs of hidden
layers in the encoder(s) (if ``encoder_covariates`` is ``True``) and the decoder prior to
passing through the next layer.
batch_representation
``EXPERIMENTAL`` Method for encoding batch information. One of the following:
* ``"one-hot"``: represent batches with one-hot encodings.
* ``"embedding"``: represent batches with continuously-valued embeddings using
:class:`~scvi.nn.Embedding`.
Note that batch representations are only passed into the encoder(s) if
``encode_covariates`` is ``True``.
use_batch_norm
Specifies where to use :class:`~torch.nn.BatchNorm1d` in the model. One of the following:
* ``"none"``: don't use batch norm in either encoder(s) or decoder.
* ``"encoder"``: use batch norm only in the encoder(s).
* ``"decoder"``: use batch norm only in the decoder.
* ``"both"``: use batch norm in both encoder(s) and decoder.
Note: if ``use_layer_norm`` is also specified, both will be applied (first
:class:`~torch.nn.BatchNorm1d`, then :class:`~torch.nn.LayerNorm`).
use_layer_norm
Specifies where to use :class:`~torch.nn.LayerNorm` in the model. One of the following:
* ``"none"``: don't use layer norm in either encoder(s) or decoder.
* ``"encoder"``: use layer norm only in the encoder(s).
* ``"decoder"``: use layer norm only in the decoder.
* ``"both"``: use layer norm in both encoder(s) and decoder.
Note: if ``use_batch_norm`` is also specified, both will be applied (first
:class:`~torch.nn.BatchNorm1d`, then :class:`~torch.nn.LayerNorm`).
use_size_factor_key
If ``True``, use the :attr:`~anndata.AnnData.obs` column as defined by the
``size_factor_key`` parameter in the model's ``setup_anndata`` method as the scaling
factor in the mean of the conditional distribution. Takes priority over
``use_observed_lib_size``.
use_observed_lib_size
If ``True``, use the observed library size for RNA as the scaling factor in the mean of the
conditional distribution.
extra_payload_autotune
If ``True``, will return extra matrices in the loss output to be used during autotune
library_log_means
:class:`~numpy.ndarray` of shape ``(1, n_batch)`` of means of the log library sizes that
parameterize the prior on library size if ``use_size_factor_key`` is ``False`` and
``use_observed_lib_size`` is ``False``.
library_log_vars
:class:`~numpy.ndarray` of shape ``(1, n_batch)`` of variances of the log library sizes
that parameterize the prior on library size if ``use_size_factor_key`` is ``False`` and
``use_observed_lib_size`` is ``False``.
var_activation
Callable used to ensure positivity of the variance of the variational distribution. Passed
into :class:`~scvi.nn.Encoder`. Defaults to :func:`~torch.exp`.
extra_encoder_kwargs
Additional keyword arguments passed into :class:`~scvi.nn.Encoder`.
extra_decoder_kwargs
Additional keyword arguments passed into :class:`~scvi.nn.DecoderSCVI`.
batch_embedding_kwargs
Keyword arguments passed into :class:`~scvi.nn.Embedding` if ``batch_representation`` is
set to ``"embedding"``.
"""
def __init__(
self,
n_input: int,
n_batch: int = 0,
n_labels: int = 0,
n_hidden: int = 128,
n_latent: int = 10,
n_layers: int = 1,
n_continuous_cov: int = 0,
n_cats_per_cov: list[int] | None = None,
dropout_rate: float = 0.1,
dispersion: Literal["gene", "gene-batch", "gene-label", "gene-cell"] = "gene",
log_variational: bool = True,
gene_likelihood: Literal["zinb", "nb", "poisson", "normal"] = "zinb",
latent_distribution: Literal["normal", "ln"] = "normal",
encode_covariates: bool = False,
deeply_inject_covariates: bool = True,
batch_representation: Literal["one-hot", "embedding"] = "one-hot",
use_batch_norm: Literal["encoder", "decoder", "none", "both"] = "both",
use_layer_norm: Literal["encoder", "decoder", "none", "both"] = "none",
use_size_factor_key: bool = False,
use_observed_lib_size: bool = True,
extra_payload_autotune: bool = False,
library_log_means: np.ndarray | None = None,
library_log_vars: np.ndarray | None = None,
var_activation: Callable[[torch.Tensor], torch.Tensor] = None,
extra_encoder_kwargs: dict | None = None,
extra_decoder_kwargs: dict | None = None,
batch_embedding_kwargs: dict | None = None,
):
from scvi.nn import DecoderSCVI, Encoder
super().__init__()
self.dispersion = dispersion
self.n_latent = n_latent
self.log_variational = log_variational
self.gene_likelihood = gene_likelihood
self.n_batch = n_batch
self.n_input = n_input
self.n_labels = n_labels
self.n_hidden = n_hidden
self.n_layers = n_layers
self.latent_distribution = latent_distribution
self.encode_covariates = encode_covariates
self.use_size_factor_key = use_size_factor_key
self.use_observed_lib_size = use_size_factor_key or use_observed_lib_size
self.extra_payload_autotune = extra_payload_autotune
if not self.use_observed_lib_size:
if library_log_means is None or library_log_vars is None:
raise ValueError(
"If not using observed_lib_size, "
"must provide library_log_means and library_log_vars."
)
self.register_buffer("library_log_means", torch.from_numpy(library_log_means).float())
self.register_buffer("library_log_vars", torch.from_numpy(library_log_vars).float())
if self.dispersion == "gene":
self.px_r = torch.nn.Parameter(torch.randn(n_input))
elif self.dispersion == "gene-batch":
self.px_r = torch.nn.Parameter(torch.randn(n_input, n_batch))
elif self.dispersion == "gene-label":
self.px_r = torch.nn.Parameter(torch.randn(n_input, n_labels))
elif self.dispersion == "gene-cell":
pass
else:
raise ValueError(
"`dispersion` must be one of 'gene', 'gene-batch', 'gene-label', 'gene-cell'."
)
self.batch_representation = batch_representation
if self.batch_representation == "embedding":
self.init_embedding(REGISTRY_KEYS.BATCH_KEY, n_batch, **(batch_embedding_kwargs or {}))
batch_dim = self.get_embedding(REGISTRY_KEYS.BATCH_KEY).embedding_dim
elif self.batch_representation != "one-hot":
raise ValueError("`batch_representation` must be one of 'one-hot', 'embedding'.")
use_batch_norm_encoder = use_batch_norm == "encoder" or use_batch_norm == "both"
use_batch_norm_decoder = use_batch_norm == "decoder" or use_batch_norm == "both"
use_layer_norm_encoder = use_layer_norm == "encoder" or use_layer_norm == "both"
use_layer_norm_decoder = use_layer_norm == "decoder" or use_layer_norm == "both"
n_input_encoder = n_input + n_continuous_cov * encode_covariates
if self.batch_representation == "embedding":
n_input_encoder += batch_dim * encode_covariates
cat_list = list([] if n_cats_per_cov is None else n_cats_per_cov)
else:
cat_list = [n_batch] + list([] if n_cats_per_cov is None else n_cats_per_cov)
encoder_cat_list = cat_list if encode_covariates else None
_extra_encoder_kwargs = extra_encoder_kwargs or {}
z_encoder_kwargs = {
"n_cat_list": encoder_cat_list,
"n_layers": n_layers,
"n_hidden": n_hidden,
"dropout_rate": dropout_rate,
"distribution": latent_distribution,
"inject_covariates": deeply_inject_covariates,
"use_batch_norm": use_batch_norm_encoder,
"use_layer_norm": use_layer_norm_encoder,
"var_activation": var_activation,
"return_dist": True,
}
z_encoder_kwargs.update(_extra_encoder_kwargs)
self.z_encoder = Encoder(
n_input_encoder,
n_latent,
**z_encoder_kwargs,
)
# l encoder goes from n_input-dimensional data to 1-d library size
# n_layers is fixed to 1 for the library encoder
l_encoder_extra_kwargs = {
k: v for k, v in _extra_encoder_kwargs.items() if k != "n_layers"
}
l_encoder_kwargs = {
"n_layers": 1,
"n_cat_list": encoder_cat_list,
"n_hidden": n_hidden,
"dropout_rate": dropout_rate,
"inject_covariates": deeply_inject_covariates,
"use_batch_norm": use_batch_norm_encoder,
"use_layer_norm": use_layer_norm_encoder,
"var_activation": var_activation,
"return_dist": True,
}
l_encoder_kwargs.update(l_encoder_extra_kwargs)
self.l_encoder = Encoder(
n_input_encoder,
1,
**l_encoder_kwargs,
)
n_input_decoder = n_latent + n_continuous_cov
if self.batch_representation == "embedding":
n_input_decoder += batch_dim
_extra_decoder_kwargs = extra_decoder_kwargs or {}
decoder_kwargs = {
"n_cat_list": cat_list,
"n_layers": n_layers,
"n_hidden": n_hidden,
"inject_covariates": deeply_inject_covariates,
"use_batch_norm": use_batch_norm_decoder,
"use_layer_norm": use_layer_norm_decoder,
"scale_activation": "softplus" if use_size_factor_key else "softmax",
}
decoder_kwargs.update(_extra_decoder_kwargs)
self.decoder = DecoderSCVI(
n_input_decoder,
n_input,
**decoder_kwargs,
)
def _get_inference_input(
self,
tensors: dict[str, torch.Tensor | None],
full_forward_pass: bool = False,
) -> dict[str, torch.Tensor | None]:
"""Get input tensors for the inference process."""
if full_forward_pass or self.minified_data_type is None:
loader = "full_data"
elif self.minified_data_type in [
ADATA_MINIFY_TYPE.LATENT_POSTERIOR,
ADATA_MINIFY_TYPE.LATENT_POSTERIOR_WITH_COUNTS,
]:
loader = "minified_data"
else:
raise NotImplementedError(f"Unknown minified-data type: {self.minified_data_type}")
if loader == "full_data":
return {
MODULE_KEYS.X_KEY: tensors[REGISTRY_KEYS.X_KEY],
MODULE_KEYS.BATCH_INDEX_KEY: tensors[REGISTRY_KEYS.BATCH_KEY],
MODULE_KEYS.CONT_COVS_KEY: tensors.get(REGISTRY_KEYS.CONT_COVS_KEY, None),
MODULE_KEYS.CAT_COVS_KEY: tensors.get(REGISTRY_KEYS.CAT_COVS_KEY, None),
}
else:
return {
MODULE_KEYS.QZM_KEY: tensors[REGISTRY_KEYS.LATENT_QZM_KEY],
MODULE_KEYS.QZV_KEY: tensors[REGISTRY_KEYS.LATENT_QZV_KEY],
REGISTRY_KEYS.OBSERVED_LIB_SIZE: tensors[REGISTRY_KEYS.OBSERVED_LIB_SIZE],
}
def _get_generative_input(
self,
tensors: dict[str, torch.Tensor],
inference_outputs: dict[str, torch.Tensor | Distribution | None],
) -> dict[str, torch.Tensor | None]:
"""Get input tensors for the generative process."""
size_factor = tensors.get(REGISTRY_KEYS.SIZE_FACTOR_KEY, None)
if size_factor is not None:
size_factor = torch.log(size_factor)
return {
MODULE_KEYS.Z_KEY: inference_outputs[MODULE_KEYS.Z_KEY],
MODULE_KEYS.LIBRARY_KEY: inference_outputs[MODULE_KEYS.LIBRARY_KEY],
MODULE_KEYS.BATCH_INDEX_KEY: tensors[REGISTRY_KEYS.BATCH_KEY],
MODULE_KEYS.Y_KEY: tensors[REGISTRY_KEYS.LABELS_KEY],
MODULE_KEYS.CONT_COVS_KEY: tensors.get(REGISTRY_KEYS.CONT_COVS_KEY, None),
MODULE_KEYS.CAT_COVS_KEY: tensors.get(REGISTRY_KEYS.CAT_COVS_KEY, None),
MODULE_KEYS.SIZE_FACTOR_KEY: size_factor,
}
def _compute_local_library_params(
self,
batch_index: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Computes local library parameters.
Compute two tensors of shape (batch_index.shape[0], 1) where each
element corresponds to the mean and variances, respectively, of the
log library sizes in the batch the cell corresponds to.
"""
from torch.nn.functional import linear
n_batch = self.library_log_means.shape[1]
local_library_log_means = linear(
one_hot(batch_index.squeeze(-1), n_batch).float(), self.library_log_means
)
local_library_log_vars = linear(
one_hot(batch_index.squeeze(-1), n_batch).float(), self.library_log_vars
)
return local_library_log_means, local_library_log_vars
@auto_move_data
def _regular_inference(
self,
x: torch.Tensor,
batch_index: torch.Tensor,
cont_covs: torch.Tensor | None = None,
cat_covs: torch.Tensor | None = None,
n_samples: int = 1,
) -> dict[str, torch.Tensor | Distribution | None]:
"""Run the regular inference process."""
x_ = x
if self.use_observed_lib_size:
library = torch.log(x.sum(1)).unsqueeze(1)
if self.log_variational:
x_ = torch.log1p(x_)
if cont_covs is not None and self.encode_covariates:
encoder_input = torch.cat((x_, cont_covs), dim=-1)
else:
encoder_input = x_
if cat_covs is not None and self.encode_covariates:
categorical_input = torch.split(cat_covs, 1, dim=1)
else:
categorical_input = ()
if self.batch_representation == "embedding" and self.encode_covariates:
batch_rep = self.compute_embedding(REGISTRY_KEYS.BATCH_KEY, batch_index)
encoder_input = torch.cat([encoder_input, batch_rep], dim=-1)
qz, z = self.z_encoder(encoder_input, *categorical_input)
else:
qz, z = self.z_encoder(encoder_input, batch_index, *categorical_input)
ql = None
if not self.use_observed_lib_size:
if self.batch_representation == "embedding":
ql, library_encoded = self.l_encoder(encoder_input, *categorical_input)
else:
ql, library_encoded = self.l_encoder(
encoder_input, batch_index, *categorical_input
)
library = library_encoded
if n_samples > 1:
untran_z = qz.sample((n_samples,))
z = self.z_encoder.z_transformation(untran_z)
if self.use_observed_lib_size:
library = library.unsqueeze(0).expand(
(n_samples, library.size(0), library.size(1))
)
else:
library = ql.sample((n_samples,))
return {
MODULE_KEYS.Z_KEY: z,
MODULE_KEYS.QZ_KEY: qz,
MODULE_KEYS.QL_KEY: ql,
MODULE_KEYS.LIBRARY_KEY: library,
}
@auto_move_data
def _cached_inference(
self,
qzm: torch.Tensor,
qzv: torch.Tensor,
observed_lib_size: torch.Tensor,
n_samples: int = 1,
) -> dict[str, torch.Tensor | None]:
"""Run the cached inference process."""
from torch.distributions import Normal
qz = Normal(qzm, qzv.sqrt())
# use dist.sample() rather than rsample because we aren't optimizing the z here
untran_z = qz.sample() if n_samples == 1 else qz.sample((n_samples,))
z = self.z_encoder.z_transformation(untran_z)
library = torch.log(observed_lib_size)
if n_samples > 1:
library = library.unsqueeze(0).expand((n_samples, library.size(0), library.size(1)))
return {
MODULE_KEYS.Z_KEY: z,
MODULE_KEYS.QZ_KEY: qz,
MODULE_KEYS.QL_KEY: None,
MODULE_KEYS.LIBRARY_KEY: library,
}
@auto_move_data
def generative(
self,
z: torch.Tensor,
library: torch.Tensor,
batch_index: torch.Tensor,
cont_covs: torch.Tensor | None = None,
cat_covs: torch.Tensor | None = None,
size_factor: torch.Tensor | None = None,
y: torch.Tensor | None = None,
transform_batch: torch.Tensor | None = None,
) -> dict[str, Distribution | None]:
"""Run the generative process."""
from torch.nn.functional import linear
from scvi.distributions import (
NegativeBinomial,
Normal,
Poisson,
ZeroInflatedNegativeBinomial,
)
# TODO: refactor forward function to not rely on y
# Likelihood distribution
if cont_covs is None:
decoder_input = z
elif z.dim() != cont_covs.dim():
decoder_input = torch.cat(
[z, cont_covs.unsqueeze(0).expand(z.size(0), -1, -1)], dim=-1
)
else:
decoder_input = torch.cat([z, cont_covs], dim=-1)
if cat_covs is not None:
categorical_input = torch.split(cat_covs, 1, dim=1)
else:
categorical_input = ()
if transform_batch is not None:
batch_index = torch.ones_like(batch_index) * transform_batch
if not self.use_size_factor_key:
size_factor = library
if self.batch_representation == "embedding":
batch_rep = self.compute_embedding(REGISTRY_KEYS.BATCH_KEY, batch_index)
decoder_input = torch.cat([decoder_input, batch_rep], dim=-1)
px_scale, px_r, px_rate, px_dropout = self.decoder(
self.dispersion,
decoder_input,
size_factor,
*categorical_input,
y,
)
else:
px_scale, px_r, px_rate, px_dropout = self.decoder(
self.dispersion,
decoder_input,
size_factor,
batch_index,
*categorical_input,
y,
)
if self.dispersion == "gene-label":
px_r = linear(
one_hot(y.squeeze(-1), self.n_labels).float(), self.px_r
) # px_r gets transposed - last dimension is nb genes
elif self.dispersion == "gene-batch":
px_r = linear(one_hot(batch_index.squeeze(-1), self.n_batch).float(), self.px_r)
elif self.dispersion == "gene":
px_r = self.px_r
px_r = torch.exp(px_r)
if self.gene_likelihood == "zinb":
px = ZeroInflatedNegativeBinomial(
mu=px_rate,
theta=px_r,
zi_logits=px_dropout,
scale=px_scale,
)
elif self.gene_likelihood == "nb":
px = NegativeBinomial(mu=px_rate, theta=px_r, scale=px_scale)
elif self.gene_likelihood == "poisson":
px = Poisson(rate=px_rate, scale=px_scale)
elif self.gene_likelihood == "normal":
px = Normal(px_rate, px_r, normal_mu=px_scale)
# Priors
if self.use_observed_lib_size:
pl = None
else:
(
local_library_log_means,
local_library_log_vars,
) = self._compute_local_library_params(batch_index)
pl = Normal(local_library_log_means, local_library_log_vars.sqrt())
pz = Normal(torch.zeros_like(z), torch.ones_like(z))
return {
MODULE_KEYS.PX_KEY: px,
MODULE_KEYS.PL_KEY: pl,
MODULE_KEYS.PZ_KEY: pz,
}
@unsupported_if_adata_minified
def loss(
self,
tensors: dict[str, torch.Tensor],
inference_outputs: dict[str, torch.Tensor | Distribution | None],
generative_outputs: dict[str, Distribution | None],
kl_weight: torch.tensor | float = 1.0,
) -> LossOutput:
"""Compute the loss."""
from torch.distributions import kl_divergence
x = tensors[REGISTRY_KEYS.X_KEY]
kl_divergence_z = kl_divergence(
inference_outputs[MODULE_KEYS.QZ_KEY], generative_outputs[MODULE_KEYS.PZ_KEY]
).sum(dim=-1)
if not self.use_observed_lib_size:
kl_divergence_l = kl_divergence(
inference_outputs[MODULE_KEYS.QL_KEY], generative_outputs[MODULE_KEYS.PL_KEY]
).sum(dim=1)
else:
kl_divergence_l = torch.zeros_like(kl_divergence_z)
reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
kl_local_for_warmup = kl_divergence_z
kl_local_no_warmup = kl_divergence_l
weighted_kl_local = kl_weight * kl_local_for_warmup + kl_local_no_warmup
loss = torch.mean(reconst_loss + weighted_kl_local)
# a payload to be used during autotune
if self.extra_payload_autotune:
extra_metrics_payload = {
"z": inference_outputs["z"],
"batch": tensors[REGISTRY_KEYS.BATCH_KEY],
"labels": tensors[REGISTRY_KEYS.LABELS_KEY],
}
else:
extra_metrics_payload = {}
return LossOutput(
loss=loss,
reconstruction_loss=reconst_loss,
kl_local={
MODULE_KEYS.KL_L_KEY: kl_divergence_l,
MODULE_KEYS.KL_Z_KEY: kl_divergence_z,
},
extra_metrics=extra_metrics_payload,
)
@torch.inference_mode()
def sample(
self,
tensors: dict[str, torch.Tensor],
n_samples: int = 1,
max_poisson_rate: float = 1e8,
generative_kwargs: dict | None = None,
) -> torch.Tensor:
r"""Generate predictive samples from the posterior predictive distribution.
The posterior predictive distribution is denoted as :math:`p(\hat{x} \mid x)`, where
:math:`x` is the input data and :math:`\hat{x}` is the sampled data.
We sample from this distribution by first sampling ``n_samples`` times from the posterior
distribution :math:`q(z \mid x)` for a given observation, and then sampling from the
likelihood :math:`p(\hat{x} \mid z)` for each of these.
Parameters
----------
tensors
Dictionary of tensors passed into ``VAE.forward``.
n_samples
Number of Monte Carlo samples to draw from the distribution for each observation.
max_poisson_rate
The maximum value to which to clip the ``rate`` parameter of
:class:`~scvi.distributions.Poisson`. Avoids numerical sampling issues when the
parameter is very large due to the variance of the distribution.
generative_kwargs
Keyword args for ``generative()`` in fwd pass
Returns
-------
Tensor on CPU with shape ``(n_obs, n_vars)`` if ``n_samples == 1``, else
``(n_obs, n_vars,)``.
"""
from scvi.distributions import Poisson
inference_kwargs = {"n_samples": n_samples}
_, generative_outputs = self.forward(
tensors,
inference_kwargs=inference_kwargs,
generative_kwargs=generative_kwargs,
compute_loss=False,
)
dist = generative_outputs[MODULE_KEYS.PX_KEY]
if self.gene_likelihood == "poisson":
# TODO: NEED TORCH MPS FIX for 'aten::poisson'
dist = (
Poisson(torch.clamp(dist.rate.to("cpu"), max=max_poisson_rate))
if self.device.type == "mps"
else Poisson(torch.clamp(dist.rate, max=max_poisson_rate))
)
# (n_obs, n_vars) if n_samples == 1, else (n_samples, n_obs, n_vars)
samples = dist.sample()
# (n_samples, n_obs, n_vars) -> (n_obs, n_vars, n_samples)
samples = torch.permute(samples, (1, 2, 0)) if n_samples > 1 else samples
return samples.cpu()
@torch.inference_mode()
@auto_move_data
def marginal_ll(
self,
tensors: dict[str, torch.Tensor],
n_mc_samples: int,
return_mean: bool = False,
n_mc_samples_per_pass: int = 1,
):
"""Compute the marginal log-likelihood of the data under the model.
Parameters
----------
tensors
Dictionary of tensors passed into ``VAE.forward``.
n_mc_samples
Number of Monte Carlo samples to use for the estimation of the marginal log-likelihood.
return_mean
Whether to return the mean of marginal likelihoods over cells.
n_mc_samples_per_pass
Number of Monte Carlo samples to use per pass. This is useful to avoid memory issues.
"""
from torch import logsumexp
from torch.distributions import Normal
batch_index = tensors[REGISTRY_KEYS.BATCH_KEY]
to_sum = []
if n_mc_samples_per_pass > n_mc_samples:
warnings.warn(
"Number of chunks is larger than the total number of samples, setting it to the "
"number of samples",
RuntimeWarning,
stacklevel=settings.warnings_stacklevel,
)
n_mc_samples_per_pass = n_mc_samples
n_passes = int(np.ceil(n_mc_samples / n_mc_samples_per_pass))
for _ in range(n_passes):
# Distribution parameters and sampled variables
inference_outputs, _, losses = self.forward(
tensors,
inference_kwargs={"n_samples": n_mc_samples_per_pass},
get_inference_input_kwargs={"full_forward_pass": True},
)
qz = inference_outputs[MODULE_KEYS.QZ_KEY]
ql = inference_outputs[MODULE_KEYS.QL_KEY]
z = inference_outputs[MODULE_KEYS.Z_KEY]
library = inference_outputs[MODULE_KEYS.LIBRARY_KEY]
# Reconstruction Loss
reconst_loss = losses.dict_sum(losses.reconstruction_loss)
# Log-probabilities
p_z = (
Normal(torch.zeros_like(qz.loc), torch.ones_like(qz.scale)).log_prob(z).sum(dim=-1)
)
p_x_zl = -reconst_loss
q_z_x = qz.log_prob(z).sum(dim=-1)
log_prob_sum = p_z + p_x_zl - q_z_x
if not self.use_observed_lib_size:
(
local_library_log_means,
local_library_log_vars,
) = self._compute_local_library_params(batch_index)
p_l = (
Normal(local_library_log_means, local_library_log_vars.sqrt())
.log_prob(library)
.sum(dim=-1)
)
q_l_x = ql.log_prob(library).sum(dim=-1)
log_prob_sum += p_l - q_l_x
if n_mc_samples_per_pass == 1:
log_prob_sum = log_prob_sum.unsqueeze(0)
to_sum.append(log_prob_sum)
to_sum = torch.cat(to_sum, dim=0)
batch_log_lkl = logsumexp(to_sum, dim=0) - np.log(n_mc_samples)
if return_mean:
batch_log_lkl = torch.mean(batch_log_lkl).item()
else:
batch_log_lkl = batch_log_lkl.cpu()
return batch_log_lkl
class LDVAE(VAE):
"""Linear-decoded Variational auto-encoder model.
Implementation of :cite:p:`Svensson20`.
This model uses a linear decoder, directly mapping the latent representation
to gene expression levels. It still uses a deep neural network to encode
the latent representation.
Compared to standard VAE, this model is less powerful but can be used to
inspect which genes contribute to variation in the dataset. It may also be used
for all scVI tasks, like differential expression, batch correction, imputation, etc.
However, batch correction may be less powerful as it assumes a linear model.
Parameters
----------
n_input
Number of input genes
n_batch
Number of batches
n_labels
Number of labels
n_hidden
Number of nodes per hidden layer (for encoder)
n_latent
Dimensionality of the latent space
n_layers_encoder
Number of hidden layers used for encoder NNs
dropout_rate
Dropout rate for neural networks
dispersion
One of the following
* ``'gene'`` - dispersion parameter of NB is constant per gene across cells
* ``'gene-batch'`` - dispersion can differ between different batches
* ``'gene-label'`` - dispersion can differ between different labels
* ``'gene-cell'`` - dispersion can differ for every gene in every cell
log_variational
Log(data+1) prior to encoding for numerical stability. Not normalization.
gene_likelihood
One of
* ``'nb'`` - Negative binomial distribution
* ``'zinb'`` - Zero-inflated negative binomial distribution
* ``'poisson'`` - Poisson distribution
use_batch_norm
Bool whether to use batch norm in decoder
bias
Bool whether to have bias term in linear decoder
latent_distribution
One of
* ``'normal'`` - Isotropic normal
* ``'ln'`` - Logistic normal with normal params N(0, 1)
use_observed_lib_size
Use observed library size for RNA as scaling factor in mean of conditional distribution.
**kwargs
"""
def __init__(
self,
n_input: int,
n_batch: int = 0,
n_labels: int = 0,
n_hidden: int = 128,
n_latent: int = 10,
n_layers_encoder: int = 1,
dropout_rate: float = 0.1,
dispersion: str = "gene",
log_variational: bool = True,
gene_likelihood: str = "nb",
use_batch_norm: bool = True,
bias: bool = False,
latent_distribution: str = "normal",
use_observed_lib_size: bool = False,
**kwargs,
):
from scvi.nn import Encoder, LinearDecoderSCVI
super().__init__(
n_input=n_input,
n_batch=n_batch,
n_labels=n_labels,
n_hidden=n_hidden,
n_latent=n_latent,
n_layers=n_layers_encoder,
dropout_rate=dropout_rate,
dispersion=dispersion,
log_variational=log_variational,
gene_likelihood=gene_likelihood,
latent_distribution=latent_distribution,
use_observed_lib_size=use_observed_lib_size,
**kwargs,
)
self.use_batch_norm = use_batch_norm
self.z_encoder = Encoder(
n_input,
n_latent,
n_layers=n_layers_encoder,
n_hidden=n_hidden,
dropout_rate=dropout_rate,
distribution=latent_distribution,
use_batch_norm=True,
use_layer_norm=False,
return_dist=True,
)
self.l_encoder = Encoder(
n_input,
1,
n_layers=1,
n_hidden=n_hidden,
dropout_rate=dropout_rate,
use_batch_norm=True,
use_layer_norm=False,
return_dist=True,
)
self.decoder = LinearDecoderSCVI(
n_latent,
n_input,
n_cat_list=[n_batch],
use_batch_norm=use_batch_norm,
use_layer_norm=False,
bias=bias,
)
@torch.inference_mode()
def get_loadings(self) -> np.ndarray:
"""Extract per-gene weights in the linear decoder."""
# This is BW, where B is diag(b) batch norm, W is the weight matrix
if self.use_batch_norm is True:
w = self.decoder.factor_regressor.fc_layers[0][0].weight
bn = self.decoder.factor_regressor.fc_layers[0][1]
sigma = torch.sqrt(bn.running_var + bn.eps)
gamma = bn.weight
b = gamma / sigma
b_identity = torch.diag(b)
loadings = torch.matmul(b_identity, w)
else:
loadings = self.decoder.factor_regressor.fc_layers[0][0].weight
loadings = loadings.detach().cpu().numpy()
if self.n_batch > 1:
loadings = loadings[:, : -self.n_batch]
return loadings