Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 79 additions & 0 deletions src/vivarium_testing_utils/pytest_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,89 @@
via the pytest11 entry point defined in setup.py.
"""

import shutil
from datetime import datetime

import pytest
from _pytest.config import Config, argparsing
from _pytest.python import Function
from layered_config_tree import LayeredConfigTree
from pytest_mock import MockerFixture

SLOW_TEST_DAY = "Sunday"


def pytest_addoption(parser: argparsing.Parser) -> None:
parser.addoption("--runslow", action="store_true", default=False, help="run slow tests")
parser.addoption(
"--slurm-project",
type=str,
default="proj_simscience",
help="SLURM project for cluster tests (default: proj_simscience)",
)


def pytest_configure(config: Config) -> None:
config.addinivalue_line("markers", "slow: mark test as slow to run")
config.addinivalue_line(
"markers", "cluster: mark test as requiring a SLURM cluster environment"
)


def pytest_collection_modifyitems(config: Config, items: list[Function]) -> None:
if not config.getoption("--runslow"):
skip_slow = pytest.mark.skip(reason="need --runslow option to run")
for item in items:
if "slow" in item.keywords:
item.add_marker(skip_slow)

if not is_on_slurm():
skip_cluster = pytest.mark.skip(reason="not running on SLURM cluster")
for item in items:
if "cluster" in item.keywords:
item.add_marker(skip_cluster)

# Weekly tests also require it to be the slow test day
if not is_slow_test_day():
skip_weekly = pytest.mark.skip(
reason="not the designated slow test day for weekly tests"
)
for item in items:
if "weekly" in item.keywords:
item.add_marker(skip_weekly)


def is_on_slurm() -> bool:
"""Returns True if the current environment is a SLURM cluster."""
return shutil.which("sbatch") is not None


def is_slow_test_day(slow_test_day: str = SLOW_TEST_DAY) -> bool:
"""Determine if today is the day to run slow/weekly tests.

Parameters
----------
slow_test_day
The day to run the weekly tests on. Acceptable values are "Monday",
"Tuesday", "Wednesday", "Thursday", "Friday", "Saturday", or "Sunday".
Default is "Sunday".

Notes
-----
There is some risk that a test will be inadvertently skipped if there is a
significant delay between when a pipeline is kicked off and when the test
itself is run.
"""
return [
"Monday",
"Tuesday",
"Wednesday",
"Thursday",
"Friday",
"Saturday",
"Sunday",
][datetime.today().weekday()] == slow_test_day


@pytest.fixture
def no_gbd_cache(mocker: MockerFixture) -> None:
Expand Down
8 changes: 0 additions & 8 deletions tests/automated_validation/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,14 +616,6 @@ def reference_weights() -> pd.DataFrame:
)


def is_on_slurm() -> bool:
"""Returns True if the current environment is a SLURM cluster."""
return shutil.which("sbatch") is not None


IS_ON_SLURM = is_on_slurm()


@pytest.fixture
def gbd_pop() -> pd.DataFrame:
"""Sample GBD population structure data."""
Expand Down
5 changes: 1 addition & 4 deletions tests/automated_validation/test_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from pytest_mock import MockFixture
from vivarium_inputs import interface

from tests.automated_validation.conftest import IS_ON_SLURM
from vivarium_testing_utils.automated_validation.bundle import RatioMeasureDataBundle
from vivarium_testing_utils.automated_validation.comparison import FuzzyComparison
from vivarium_testing_utils.automated_validation.constants import (
Expand Down Expand Up @@ -315,10 +314,8 @@ def test_fuzzy_comparison_align_datasets_calculation(


@pytest.mark.slow
@pytest.mark.cluster
def test_comparison_with_gbd_init(sim_result_dir: Path) -> None:
if not IS_ON_SLURM:
pytest.skip("No cluster access to use GBD data.")

age_bins = interface.get_age_bins()
age_bins.index.rename({"age_group_name": INPUT_DATA_INDEX_NAMES.AGE_GROUP}, inplace=True)

Expand Down
4 changes: 1 addition & 3 deletions tests/automated_validation/test_data_bundle.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from pytest_mock import MockFixture
from vivarium_inputs import interface

from tests.automated_validation.conftest import IS_ON_SLURM
from vivarium_testing_utils.automated_validation.bundle import RatioMeasureDataBundle
from vivarium_testing_utils.automated_validation.constants import (
DRAW_INDEX,
Expand Down Expand Up @@ -222,10 +221,9 @@ def test_aggregate_reference_stratifications(


@pytest.mark.slow
@pytest.mark.cluster
def test_data_bundle_gbd_source(sim_result_dir: Path) -> None:
"""Test that GBD data source is handled correctly in RatioMeasureDataBundle."""
if not IS_ON_SLURM:
pytest.skip("GBD access not available for this test.")

age_bins = interface.get_age_bins()
age_bins.index.rename({"age_group_name": INPUT_DATA_INDEX_NAMES.AGE_GROUP}, inplace=True)
Expand Down
4 changes: 1 addition & 3 deletions tests/automated_validation/test_data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import pandas as pd
import pytest

from tests.automated_validation.conftest import IS_ON_SLURM
from vivarium_testing_utils.automated_validation.constants import (
DRAW_INDEX,
POPULATION_STRUCTURE_ARTIFACT_KEY,
Expand Down Expand Up @@ -222,6 +221,7 @@ def test___get_raw_data_from_source(


@pytest.mark.slow
@pytest.mark.cluster
@pytest.mark.parametrize(
"key",
[
Expand All @@ -232,8 +232,6 @@ def test___get_raw_data_from_source(
)
def test__load_gbd_data(key: str, sim_result_dir: Path) -> None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't we need the cluster mark here?

"""Ensure that we can load standard GBD data"""
if not IS_ON_SLURM:
pytest.skip("No access to IHME cluster to extract GBD data.")

data_loader = DataLoader(sim_result_dir)
gbd_data = data_loader._load_from_gbd(key)
Expand Down
13 changes: 3 additions & 10 deletions tests/automated_validation/test_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,7 @@
from vivarium.framework.artifact.artifact import ArtifactException
from vivarium_inputs import interface

from tests.automated_validation.conftest import (
IS_ON_SLURM,
get_model_spec,
load_exposure_categories,
)
from tests.automated_validation.conftest import get_model_spec, load_exposure_categories
from vivarium_testing_utils.automated_validation.constants import (
DRAW_INDEX,
INPUT_DATA_INDEX_NAMES,
Expand Down Expand Up @@ -328,6 +324,7 @@ def test_get_frame_different_test_source(test_source: str, sim_result_dir: Path)
assert set(data.columns) == {"test_rate", "reference_rate", "percent_error"}


@pytest.mark.cluster
@pytest.mark.parametrize(
"data_key",
[
Expand All @@ -344,8 +341,6 @@ def test_get_frame_different_test_source(test_source: str, sim_result_dir: Path)
def test_cache_gbd_data(sim_result_dir: Path, data_key: str) -> None:
"""Tests that we can cache custom GBD and retreive it. More importantly, tests that
GBD data is properly mapped from id columns to value columns upon caching."""
if not IS_ON_SLURM:
pytest.skip("No access to slurm shared filesystem available for testing.")

context = ValidationContext(sim_result_dir)
# NOTE: Some of these CSVs are reused but have the same schema. Users will be expected to
Expand Down Expand Up @@ -557,14 +552,12 @@ def test_get_frame_filters(mocker: MockFixture, sim_result_dir: Path) -> None:
],
)
@pytest.mark.slow
@pytest.mark.cluster
def test_compare_artifact_and_gbd(
integration_artifact_data_mapper: dict[str, pd.DataFrame | str],
tmp_path_factory: TempPathFactory,
data_key: str,
) -> None:
if not IS_ON_SLURM:
pytest.skip("No cluster access to use GBD data.")

# Create sim output directory
tmp_path = tmp_path_factory.mktemp("model_run_output")
# Create the directory structure
Expand Down
20 changes: 0 additions & 20 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,10 @@
from typing import Generator

import pytest
from _pytest.config import Config, argparsing
from _pytest.logging import LogCaptureFixture
from _pytest.python import Function
from loguru import logger


def pytest_addoption(parser: argparsing.Parser) -> None:
parser.addoption("--runslow", action="store_true", default=False, help="run slow tests")


def pytest_configure(config: Config) -> None:
config.addinivalue_line("markers", "slow: mark test as slow to run")


def pytest_collection_modifyitems(config: Config, items: list[Function]) -> None:
if config.getoption("--runslow"):
# --runslow given in cli: do not skip slow tests
return
skip_slow = pytest.mark.skip(reason="need --runslow option to run")
for item in items:
if "slow" in item.keywords:
item.add_marker(skip_slow)


@pytest.fixture
def caplog(caplog: LogCaptureFixture) -> Generator[LogCaptureFixture, None, None]:
handler_id = logger.add(
Expand Down