Skip to content

Commit 813a6a5

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 2fa70b3 commit 813a6a5

File tree

3 files changed

+16
-7
lines changed

3 files changed

+16
-7
lines changed

velovi/_model.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ class VELOVI(VAEMixin, UnsupervisedTrainingMixin, BaseModelClass):
4949
Use a linear decoder from latent space to time.
5050
**model_kwargs
5151
Keyword args for :class:`~velovi.VELOVAE`
52+
5253
"""
5354

5455
def __init__(
@@ -108,13 +109,8 @@ def __init__(
108109
**model_kwargs,
109110
)
110111
self._model_summary_string = (
111-
"VELOVI Model with the following params: \nn_hidden: {}, n_latent: {}, n_layers: {}, dropout_rate: "
112-
"{}"
113-
).format(
114-
n_hidden,
115-
n_latent,
116-
n_layers,
117-
dropout_rate,
112+
f"VELOVI Model with the following params: \nn_hidden: {n_hidden}, n_latent: {n_latent}, n_layers: {n_layers}, dropout_rate: "
113+
f"{dropout_rate}"
118114
)
119115
self.init_params_ = self._get_init_params(locals())
120116

@@ -164,6 +160,7 @@ def train(
164160
`train()` will overwrite values present in `plan_kwargs`, when appropriate.
165161
**trainer_kwargs
166162
Other keyword args for :class:`~scvi.train.Trainer`.
163+
167164
"""
168165
user_plan_kwargs = plan_kwargs.copy() if isinstance(plan_kwargs, dict) else {}
169166
plan_kwargs = {"lr": lr, "weight_decay": weight_decay, "optimizer": "AdamW"}
@@ -238,6 +235,7 @@ def get_state_assignment(
238235
-------
239236
If `n_samples` > 1 and `return_mean` is False, then the shape is `(samples, cells, genes)`.
240237
Otherwise, shape is `(cells, genes)`. In this case, return type is :class:`~pandas.DataFrame` unless `return_numpy` is True.
238+
241239
"""
242240
adata = self._validate_anndata(adata)
243241
scdl = self._make_data_loader(
@@ -342,6 +340,7 @@ def get_latent_time(
342340
-------
343341
If `n_samples` > 1 and `return_mean` is False, then the shape is `(samples, cells, genes)`.
344342
Otherwise, shape is `(cells, genes)`. In this case, return type is :class:`~pandas.DataFrame` unless `return_numpy` is True.
343+
345344
"""
346345
adata = self._validate_anndata(adata)
347346
if indices is None:
@@ -484,6 +483,7 @@ def get_velocity(
484483
-------
485484
If `n_samples` > 1 and `return_mean` is False, then the shape is `(samples, cells, genes)`.
486485
Otherwise, shape is `(cells, genes)`. In this case, return type is :class:`~pandas.DataFrame` unless `return_numpy` is True.
486+
487487
"""
488488
adata = self._validate_anndata(adata)
489489
if indices is None:
@@ -658,6 +658,7 @@ def get_expression_fit(
658658
-------
659659
If `n_samples` > 1 and `return_mean` is False, then the shape is `(samples, cells, genes)`.
660660
Otherwise, shape is `(cells, genes)`. In this case, return type is :class:`~pandas.DataFrame` unless `return_numpy` is True.
661+
661662
"""
662663
adata = self._validate_anndata(adata)
663664

@@ -813,6 +814,7 @@ def get_gene_likelihood(
813814
-------
814815
If `n_samples` > 1 and `return_mean` is False, then the shape is `(samples, cells, genes)`.
815816
Otherwise, shape is `(cells, genes)`. In this case, return type is :class:`~pandas.DataFrame` unless `return_numpy` is True.
817+
816818
"""
817819
adata = self._validate_anndata(adata)
818820
scdl = self._make_data_loader(
@@ -919,6 +921,7 @@ def setup_anndata(
919921
Returns
920922
-------
921923
%(returns)s
924+
922925
"""
923926
setup_method_args = cls._get_setup_method_args(**locals())
924927
anndata_fields = [
@@ -969,6 +972,7 @@ def get_permutation_scores(
969972
-------
970973
Tuple of DataFrame and AnnData. DataFrame is genes by cell types with score per cell type.
971974
AnnData is the permutated version of the original AnnData.
975+
972976
"""
973977
adata = self._validate_anndata(adata)
974978
adata_manager = self.get_anndata_manager(adata)
@@ -1092,6 +1096,7 @@ def _directional_statistics_per_cell(
10921096
----------
10931097
tensor
10941098
Shape of samples by genes for a given cell.
1099+
10951100
"""
10961101
n_samples = tensor.shape[0]
10971102
# over samples axis

velovi/_module.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Main module."""
2+
23
from typing import Callable, Iterable, Literal, Optional
34

45
import numpy as np
@@ -44,6 +45,7 @@ class DecoderVELOVI(nn.Module):
4445
Whether to use layer norm in layers
4546
linear_decoder
4647
Whether to use linear decoder for time
48+
4749
"""
4850

4951
def __init__(
@@ -183,6 +185,7 @@ class VELOVAE(BaseModuleClass):
183185
var_activation
184186
Callable used to ensure positivity of the variational distributions' variance.
185187
When `None`, defaults to `torch.exp`.
188+
186189
"""
187190

188191
def __init__(

velovi/_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def preprocess_data(
5656
Returns
5757
-------
5858
Preprocessed adata.
59+
5960
"""
6061
if min_max_scale:
6162
scaler = MinMaxScaler()

0 commit comments

Comments
 (0)