Skip to content

Commit 7f94ca0

Browse files
authored
Merge branch 'main' into Ori-Dropping-Jax
2 parents c2b320f + c543aa7 commit 7f94ca0

4 files changed

Lines changed: 42 additions & 3 deletions

File tree

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ to [Semantic Versioning]. The full commit history is available in the [commit lo
1313

1414
#### Fixed
1515

16+
- Fix list of metrics to be recorded in {class}`scvi.autotune.AutotuneExperiment`, {pr}`3816`.
17+
1618
#### Changed
1719

1820
#### Removed

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,8 @@ tests = ["pytest", "pytest-pretty", "coverage", "scvi-tools[optional]"]
5959
editing = ["jupyter", "pre-commit"]
6060
dev = ["scvi-tools[editing,tests]"]
6161
test = ["scvi-tools[tests]"]
62-
cuda = ["torchvision", "torchaudio", "mlx[cuda]"]
63-
cuda13 = ["torchvision", "torchaudio", "mlx[cuda13]"]
62+
cuda = ["torchvision", "torchaudio"]
63+
cuda13 = ["torchvision", "torchaudio"]
6464
tpu = ["torch_xla[tpu]"]
6565
metal = ["torchvision", "torchaudio", "mlx-metal"]
6666

src/scvi/autotune/_experiment.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -721,7 +721,7 @@ def metrics_callback(self) -> Callback:
721721
)
722722
on = "validation_end" if "validation" in self.metrics[0] else "train_end"
723723

724-
return callback_cls(metrics=self.metrics[0], on=on, save_checkpoints=self.save_checkpoints)
724+
return callback_cls(metrics=self.metrics, on=on, save_checkpoints=self.save_checkpoints)
725725

726726
@property
727727
def scib_metrics_callback(self) -> Callback:

tests/autotune/test_tune.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,43 @@ def test_run_autotune_scvi_basic_adata(save_checkpoints: bool, metric: str, save
4545
assert isinstance(experiment.result_grid, ResultGrid)
4646

4747

48+
@pytest.mark.autotune
49+
def test_run_autotune_scvi_multiple_metrics(save_path: str):
50+
from ray import tune
51+
from ray.tune import ResultGrid
52+
53+
from scvi.autotune import AutotuneExperiment, run_autotune
54+
55+
settings.logging_dir = save_path
56+
adata = synthetic_iid()
57+
SCVI.setup_anndata(adata)
58+
59+
experiment = run_autotune(
60+
SCVI,
61+
adata,
62+
metrics=["elbo_validation", "validation_loss"],
63+
mode="min",
64+
search_space={
65+
"model_params": {
66+
"n_hidden": tune.choice([1, 2]),
67+
},
68+
"train_params": {
69+
"max_epochs": 1,
70+
},
71+
},
72+
num_samples=2,
73+
seed=0,
74+
scheduler="asha",
75+
searcher="hyperopt",
76+
ignore_reinit_error=True,
77+
)
78+
assert isinstance(experiment, AutotuneExperiment)
79+
assert isinstance(experiment.result_grid, ResultGrid)
80+
result_df = experiment.result_grid.get_dataframe()
81+
assert "elbo_validation" in result_df.columns
82+
assert "validation_loss" in result_df.columns
83+
84+
4885
@pytest.mark.autotune
4986
@pytest.mark.parametrize("save_checkpoints", [True, False])
5087
def test_run_autotune_scvi_basic_mdata(save_checkpoints: bool, save_path: str):

0 commit comments

Comments
 (0)