diff --git a/docs/nitpick-exceptions b/docs/nitpick-exceptions index affe7cab..80dea6aa 100644 --- a/docs/nitpick-exceptions +++ b/docs/nitpick-exceptions @@ -16,3 +16,4 @@ py:exc DataSourceError # misc # TODO: remove when dropping support for Python 3.9 py:class Path +py:class dd.DataFrame diff --git a/docs/source/conf.py b/docs/source/conf.py index a7dd93a6..0e900d71 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -17,13 +17,6 @@ # documentation root, use os.path.abspath to make it absolute, like shown here. from pathlib import Path -from docutils import nodes -from docutils.nodes import Text -from sphinx.addnodes import literal_emphasis, pending_xref -from sphinx.application import Sphinx -from sphinx.environment import BuildEnvironment -from sphinx.ext.intersphinx import missing_reference - import pseudopeople base_dir = Path(pseudopeople.__file__).parent diff --git a/setup.py b/setup.py index 64cabf46..2845e92b 100644 --- a/setup.py +++ b/setup.py @@ -56,6 +56,7 @@ "types-tqdm", "types-setuptools", "pyarrow-stubs", + "types-psutil", ] setup_requires = ["setuptools_scm"] diff --git a/src/pseudopeople/dataset.py b/src/pseudopeople/dataset.py index 46c311a6..e608fa77 100644 --- a/src/pseudopeople/dataset.py +++ b/src/pseudopeople/dataset.py @@ -7,7 +7,6 @@ import pandas as pd from tqdm import tqdm -from pseudopeople.configuration import Keys from pseudopeople.configuration.noise_configuration import NoiseConfiguration from pseudopeople.constants.metadata import DATEFORMATS from pseudopeople.constants.noise_type_metadata import COPY_HOUSEHOLD_MEMBER_COLS diff --git a/src/pseudopeople/interface.py b/src/pseudopeople/interface.py index 7643b879..639038fb 100644 --- a/src/pseudopeople/interface.py +++ b/src/pseudopeople/interface.py @@ -2,7 +2,7 @@ from collections.abc import Sequence from pathlib import Path -from typing import Any, Literal, cast +from typing import TYPE_CHECKING, Any, Literal, cast, overload import pandas as pd from loguru import 
logger @@ -25,6 +25,35 @@ get_state_abbreviation, ) +if TYPE_CHECKING: + import dask.dataframe as dd + + +@overload +def _generate_dataset( + dataset_schema: DatasetSchema, + source: Path | str | None, + seed: int, + config: Path | str | dict[str, Any] | None, + filters: Sequence[DataFilter], + verbose: bool, + engine_name: Literal["pandas"], +) -> pd.DataFrame: + ... + + +@overload +def _generate_dataset( + dataset_schema: DatasetSchema, + source: Path | str | None, + seed: int, + config: Path | str | dict[str, Any] | None, + filters: Sequence[DataFilter], + verbose: bool, + engine_name: Literal["dask"], +) -> dd.DataFrame: + ... + def _generate_dataset( dataset_schema: DatasetSchema, @@ -34,7 +63,7 @@ def _generate_dataset( filters: Sequence[DataFilter], verbose: bool = False, engine_name: Literal["pandas", "dask"] = "pandas", -) -> pd.DataFrame: +) -> pd.DataFrame | dd.DataFrame: """ Helper for generating noised datasets. @@ -67,7 +96,7 @@ def _generate_dataset( engine = get_engine_from_string(engine_name) - noised_dataset: pd.DataFrame + noised_dataset: pd.DataFrame | dd.DataFrame if engine == PANDAS_ENGINE: # We process shards serially data_file_paths = get_dataset_filepaths(source, dataset_schema.name) @@ -111,7 +140,7 @@ def _generate_dataset( noised_dataset = pd.concat(noised_datasets_list, ignore_index=True) noised_dataset = coerce_dtypes(noised_dataset, dataset_schema) - else: + else: # dask try: from distributed.client import default_client @@ -126,6 +155,8 @@ def _generate_dataset( import dask import dask.dataframe as dd + set_up_dask_client() + # Our work depends on the particulars of how dtypes work, and is only # built to work with NumPy dtypes, so we turn off the Dask default behavior # of using PyArrow dtypes. 
@@ -203,6 +234,7 @@ def _get_data_changelog_version(changelog: Path) -> Version: return version +@overload def generate_decennial_census( source: Path | str | None = None, seed: int = 0, @@ -210,8 +242,33 @@ def generate_decennial_census( year: int | None = 2020, state: str | None = None, verbose: bool = False, - engine: Literal["pandas", "dask"] = "pandas", + engine: Literal["pandas"] = "pandas", ) -> pd.DataFrame: + ... + + +@overload +def generate_decennial_census( + source: Path | str | None, + seed: int, + config: Path | str | dict[str, Any] | None, + year: int | None, + state: str | None, + verbose: bool, + engine: Literal["dask"], +) -> dd.DataFrame: + ... + + +def generate_decennial_census( + source: Path | str | None = None, + seed: int = 0, + config: Path | str | dict[str, Any] | None = None, + year: int | None = 2020, + state: str | None = None, + verbose: bool = False, + engine: Literal["pandas", "dask"] = "pandas", +) -> pd.DataFrame | dd.DataFrame: """ Generates a pseudopeople decennial census dataset which represents simulated responses to the US Census Bureau's Census of Population @@ -301,6 +358,7 @@ def generate_decennial_census( ) +@overload def generate_american_community_survey( source: Path | str | None = None, seed: int = 0, @@ -308,8 +366,33 @@ def generate_american_community_survey( year: int | None = 2020, state: str | None = None, verbose: bool = False, - engine: Literal["pandas", "dask"] = "pandas", + engine: Literal["pandas"] = "pandas", ) -> pd.DataFrame: + ... + + +@overload +def generate_american_community_survey( + source: Path | str | None, + seed: int, + config: Path | str | dict[str, Any] | None, + year: int | None, + state: str | None, + verbose: bool, + engine: Literal["dask"], +) -> dd.DataFrame: + ... 
+ + +def generate_american_community_survey( + source: Path | str | None = None, + seed: int = 0, + config: Path | str | dict[str, Any] | None = None, + year: int | None = 2020, + state: str | None = None, + verbose: bool = False, + engine: Literal["pandas", "dask"] = "pandas", +) -> pd.DataFrame | dd.DataFrame: """ Generates a pseudopeople ACS dataset which represents simulated responses to the ACS survey. @@ -414,6 +497,7 @@ def generate_american_community_survey( ) +@overload def generate_current_population_survey( source: Path | str | None = None, seed: int = 0, @@ -421,8 +505,33 @@ def generate_current_population_survey( year: int | None = 2020, state: str | None = None, verbose: bool = False, - engine: Literal["pandas", "dask"] = "pandas", + engine: Literal["pandas"] = "pandas", ) -> pd.DataFrame: + ... + + +@overload +def generate_current_population_survey( + source: Path | str | None, + seed: int, + config: Path | str | dict[str, Any] | None, + year: int | None, + state: str | None, + verbose: bool, + engine: Literal["dask"], +) -> dd.DataFrame: + ... + + +def generate_current_population_survey( + source: Path | str | None = None, + seed: int = 0, + config: Path | str | dict[str, Any] | None = None, + year: int | None = 2020, + state: str | None = None, + verbose: bool = False, + engine: Literal["pandas", "dask"] = "pandas", +) -> pd.DataFrame | dd.DataFrame: """ Generates a pseudopeople CPS dataset which represents simulated responses to the CPS survey. @@ -528,6 +637,7 @@ def generate_current_population_survey( ) +@overload def generate_taxes_w2_and_1099( source: Path | str | None = None, seed: int = 0, @@ -535,8 +645,33 @@ def generate_taxes_w2_and_1099( year: int | None = 2020, state: str | None = None, verbose: bool = False, - engine: Literal["pandas", "dask"] = "pandas", + engine: Literal["pandas"] = "pandas", ) -> pd.DataFrame: + ... 
+ + +@overload +def generate_taxes_w2_and_1099( + source: Path | str | None, + seed: int, + config: Path | str | dict[str, Any] | None, + year: int | None, + state: str | None, + verbose: bool, + engine: Literal["dask"], +) -> dd.DataFrame: + ... + + +def generate_taxes_w2_and_1099( + source: Path | str | None = None, + seed: int = 0, + config: Path | str | dict[str, Any] | None = None, + year: int | None = 2020, + state: str | None = None, + verbose: bool = False, + engine: Literal["pandas", "dask"] = "pandas", +) -> pd.DataFrame | dd.DataFrame: """ Generates a pseudopeople W2 and 1099 tax dataset which represents simulated tax form data. @@ -626,6 +761,7 @@ def generate_taxes_w2_and_1099( ) +@overload def generate_women_infants_and_children( source: Path | str | None = None, seed: int = 0, @@ -633,8 +769,33 @@ def generate_women_infants_and_children( year: int | None = 2020, state: str | None = None, verbose: bool = False, - engine: Literal["pandas", "dask"] = "pandas", + engine: Literal["pandas"] = "pandas", ) -> pd.DataFrame: + ... + + +@overload +def generate_women_infants_and_children( + source: Path | str | None, + seed: int, + config: Path | str | dict[str, Any] | None, + year: int | None, + state: str | None, + verbose: bool, + engine: Literal["dask"], +) -> dd.DataFrame: + ... + + +def generate_women_infants_and_children( + source: Path | str | None = None, + seed: int = 0, + config: Path | str | dict[str, Any] | None = None, + year: int | None = 2020, + state: str | None = None, + verbose: bool = False, + engine: Literal["pandas", "dask"] = "pandas", +) -> pd.DataFrame | dd.DataFrame: """ Generates a pseudopeople WIC dataset which represents a simulated version of the administrative data that would be recorded by WIC. 
@@ -729,14 +890,38 @@ def generate_women_infants_and_children( ) +@overload def generate_social_security( source: Path | str | None = None, seed: int = 0, config: Path | str | dict[str, Any] | None = None, year: int | None = 2020, verbose: bool = False, - engine: Literal["pandas", "dask"] = "pandas", + engine: Literal["pandas"] = "pandas", ) -> pd.DataFrame: + ... + + +@overload +def generate_social_security( + source: Path | str | None, + seed: int, + config: Path | str | dict[str, Any] | None, + year: int | None, + verbose: bool, + engine: Literal["dask"], +) -> dd.DataFrame: + ... + + +def generate_social_security( + source: Path | str | None = None, + seed: int = 0, + config: Path | str | dict[str, Any] | None = None, + year: int | None = 2020, + verbose: bool = False, + engine: Literal["pandas", "dask"] = "pandas", +) -> pd.DataFrame | dd.DataFrame: """ Generates a pseudopeople SSA dataset which represents simulated Social Security Administration (SSA) data. @@ -817,6 +1002,7 @@ def generate_social_security( ) +@overload def generate_taxes_1040( source: Path | str | None = None, seed: int = 0, @@ -824,8 +1010,33 @@ def generate_taxes_1040( year: int | None = 2020, state: str | None = None, verbose: bool = False, - engine: Literal["pandas", "dask"] = "pandas", + engine: Literal["pandas"] = "pandas", ) -> pd.DataFrame: + ... + + +@overload +def generate_taxes_1040( + source: Path | str | None, + seed: int, + config: Path | str | dict[str, Any] | None, + year: int | None, + state: str | None, + verbose: bool, + engine: Literal["dask"], +) -> dd.DataFrame: + ... + + +def generate_taxes_1040( + source: Path | str | None = None, + seed: int = 0, + config: Path | str | dict[str, Any] | None = None, + year: int | None = 2020, + state: str | None = None, + verbose: bool = False, + engine: Literal["pandas", "dask"] = "pandas", +) -> pd.DataFrame | dd.DataFrame: """ Generates a pseudopeople 1040 tax dataset which represents simulated tax form data. 
@@ -931,3 +1142,25 @@ def get_dataset_filepaths(source: Path, dataset_name: str) -> list[Path]: dataset_paths = [x for x in directory.glob(f"{dataset_name}*")] sorted_dataset_paths = sorted(dataset_paths) return sorted_dataset_paths + + +def set_up_dask_client() -> None: + """Sets up a Dask client if one is not already running.""" + from dask.distributed import get_client + + # Determine whether or not a Dask client is already running. If not, + # create a new one. + try: + get_client() + except ValueError: + # No Dask client is running so we create one. + from dask.distributed import LocalCluster + from dask.system import CPU_COUNT + + # Create a local cluster with one single-threaded worker per CPU + cluster = LocalCluster( # type: ignore [no-untyped-call] + name="pseudopeople_dask_cluster", + n_workers=CPU_COUNT, + threads_per_worker=1, + ) + cluster.get_client() # type: ignore [no-untyped-call] diff --git a/tests/integration/test_interface.py b/tests/integration/test_interface.py index ef34bfd3..465dd6c3 100644 --- a/tests/integration/test_interface.py +++ b/tests/integration/test_interface.py @@ -3,8 +3,6 @@ from pathlib import Path from typing import Any -import numpy as np -import numpy.typing as npt import pandas as pd import pytest from _pytest.fixtures import FixtureRequest @@ -12,17 +10,13 @@ from pytest_mock import MockerFixture from vivarium_testing_utils import FuzzyChecker -from pseudopeople.schema_entities import COLUMNS, DATASET_SCHEMAS, Column +from pseudopeople import NO_NOISE +from pseudopeople.configuration.noise_configuration import NoiseConfiguration +from pseudopeople.schema_entities import DATASET_SCHEMAS, DatasetSchema from pseudopeople.utilities import coerce_dtypes from tests.constants import DATASET_GENERATION_FUNCS -from tests.integration.conftest import ( - IDX_COLS, - SEED, - STATE, - _get_common_datasets, - get_unnoised_data, -) -from tests.utilities import initialize_dataset_with_sample, run_column_noising_tests +from
tests.integration.conftest import IDX_COLS, SEED, STATE, get_unnoised_data +from tests.utilities import initialize_dataset_with_sample @pytest.mark.parametrize( @@ -56,8 +50,6 @@ def test_noising_sharded_vs_unsharded_data( """Tests that the amount of noising is approximately the same whether we noise a single sample dataset or we concatenate and noise multiple datasets """ - if "TODO" in dataset_name: - pytest.skip(reason=dataset_name) mocker.patch("pseudopeople.interface.validate_source_compatibility") generation_function = DATASET_GENERATION_FUNCS[dataset_name] @@ -133,8 +125,6 @@ def test_seed_behavior( dataset_name: str, engine: str, config: dict[str, Any], request: FixtureRequest ) -> None: """Tests seed behavior""" - if "TODO" in dataset_name: - pytest.skip(reason=dataset_name) generation_function = DATASET_GENERATION_FUNCS[dataset_name] original = get_unnoised_data(dataset_name) if engine == "dask": @@ -211,16 +201,11 @@ def test_dataset_filter_by_year( """Mock the noising function so that it returns the date column of interest with the original (unnoised) values to ensure filtering is happening """ - if "TODO" in dataset_name: - pytest.skip(reason=dataset_name) year = 2030 # not default 2020 - # Generate a new (non-fixture) dataset for a single year but mocked such - # that no noise actually happens (otherwise the years would get noised and - # we couldn't tell if the filter was working properly) - mocker.patch("pseudopeople.dataset.Dataset._noise_dataset") + # Do not noise (noising changes values post-filtering) generation_function = DATASET_GENERATION_FUNCS[dataset_name] - data = generation_function(year=year, engine=engine) + data = generation_function(year=year, engine=engine, config=NO_NOISE) if engine == "dask": data = data.compute() dataset = DATASET_SCHEMAS.get_dataset_schema(dataset_name) @@ -242,19 +227,14 @@ def test_dataset_filter_by_year( "dask", ], ) -def test_dataset_filter_by_year_with_full_dates( - mocker: MockerFixture, dataset_name: str, 
engine: str -) -> None: +def test_dataset_filter_by_year_with_full_dates(dataset_name: str, engine: str) -> None: """Mock the noising function so that it returns the date column of interest with the original (unnoised) values to ensure filtering is happening """ year = 2030 # not default 2020 - # Generate a new (non-fixture) noised dataset for a single year but mocked such - # that no noise actually happens (otherwise the years would get noised and - # we couldn't tell if the filter was working properly) - mocker.patch("pseudopeople.dataset.Dataset._noise_dataset") + # Do not noise (noising changes values post-filtering) generation_function = DATASET_GENERATION_FUNCS[dataset_name] - noised_data = generation_function(year=year, engine=engine) + noised_data = generation_function(year=year, engine=engine, config=NO_NOISE) if engine == "dask": noised_data = noised_data.compute() dataset_schema = DATASET_SCHEMAS.get_dataset_schema(dataset_name) @@ -296,16 +276,13 @@ def test_generate_dataset_with_state_filtered( mocker: MockerFixture, ) -> None: """Test that values returned by dataset generators are only for the specified state""" - if "TODO" in dataset_name: - pytest.skip(reason=dataset_name) mocker.patch("pseudopeople.interface.validate_source_compatibility") dataset_schema = DATASET_SCHEMAS.get_dataset_schema(dataset_name) generation_function = DATASET_GENERATION_FUNCS[dataset_name] - # Skip noising (noising can incorrect select another state) - mocker.patch("pseudopeople.dataset.Dataset._noise_dataset") + # Do not noise (noising changes values post-filtering) noised_data = generation_function( - source=split_sample_data_dir_state_edit, state=STATE, engine=engine + source=split_sample_data_dir_state_edit, state=STATE, engine=engine, config=NO_NOISE ) if engine == "dask": noised_data = noised_data.compute() @@ -344,15 +321,14 @@ def test_generate_dataset_with_state_unfiltered( # the functionality of these functions to work but we should consider updating 
fixtures/tests # in the future. - albrja """Test that values returned by dataset generators are for all locations if state unspecified""" - if "TODO" in dataset_name: - pytest.skip(reason=dataset_name) mocker.patch("pseudopeople.interface.validate_source_compatibility") dataset_schema = DATASET_SCHEMAS.get_dataset_schema(dataset_name) - # Skip noising (noising can incorrect select another state) - mocker.patch("pseudopeople.dataset.Dataset._noise_dataset") generation_function = DATASET_GENERATION_FUNCS[dataset_name] - noised_data = generation_function(source=split_sample_data_dir_state_edit, engine=engine) + # Do not noise (noising changes values post-filtering) + noised_data = generation_function( + source=split_sample_data_dir_state_edit, engine=engine, config=NO_NOISE + ) assert len(noised_data[dataset_schema.state_column_name].unique()) > 1 @@ -373,24 +349,23 @@ def test_generate_dataset_with_state_unfiltered( "dask", ], ) -def test_dataset_filter_by_state_and_year( +def test_dataset_filter_by_state_and_year( mocker: MockerFixture, split_sample_data_dir_state_edit: Path, dataset_name: str, engine: str, ) -> None: """Test that dataset generation works with state and year filters in conjunction""" - if "TODO" in dataset_name: - pytest.skip(reason=dataset_name) year = 2030 # not default 2020 mocker.patch("pseudopeople.interface.validate_source_compatibility") - mocker.patch("pseudopeople.dataset.Dataset._noise_dataset") generation_function = DATASET_GENERATION_FUNCS[dataset_name] + # Do not noise (noising changes values post-filtering) noised_data = generation_function( source=split_sample_data_dir_state_edit, year=year, state=STATE, engine=engine, + config=NO_NOISE, ) if engine == "dask": noised_data = noised_data.compute() @@ -419,13 +394,14 @@ def test_dataset_filter_by_state_and_year_with_full_dates( """Test that dataset generation works with state and year filters in conjunction""" year = 2030 # not default 2020
mocker.patch("pseudopeople.interface.validate_source_compatibility") - mocker.patch("pseudopeople.dataset.Dataset._noise_dataset") generation_function = DATASET_GENERATION_FUNCS[dataset_name] + # Do not noise (noising changes values post-filtering) noised_data = generation_function( source=split_sample_data_dir_state_edit, year=year, state=STATE, engine=engine, + config=NO_NOISE, ) if engine == "dask": noised_data = noised_data.compute() @@ -466,8 +442,6 @@ def test_generate_dataset_with_bad_state( mocker: MockerFixture, ) -> None: """Test that bad state values result in informative ValueErrors""" - if "TODO" in dataset_name: - pytest.skip(reason=dataset_name) bad_state = "Silly State That Doesn't Exist" mocker.patch("pseudopeople.interface.validate_source_compatibility") generation_function = DATASET_GENERATION_FUNCS[dataset_name] @@ -503,8 +477,6 @@ def test_generate_dataset_with_bad_year( dataset_name: str, engine: str, split_sample_data_dir: Path, mocker: MockerFixture ) -> None: """Test that a ValueError is raised both for a bad year and a year that has no data""" - if "TODO" in dataset_name: - pytest.skip(reason=dataset_name) bad_year = 0 no_data_year = 2000 mocker.patch("pseudopeople.interface.validate_source_compatibility") @@ -525,28 +497,3 @@ def test_generate_dataset_with_bad_year( ) if engine == "dask": df.compute() - - -#################### -# HELPER FUNCTIONS # -#################### -def _get_column_noise_level( - column: Column, - noised_data: pd.DataFrame, - unnoised_data: pd.DataFrame, - common_idx: pd.Index[int], -) -> tuple[int, pd.Index[int]]: - - # Check that originally missing data remained missing - originally_missing_sample_idx = unnoised_data.index[unnoised_data[column.name].isna()] - - assert noised_data.loc[originally_missing_sample_idx, column.name].isna().all() - - # Check for noising where applicable - to_compare_sample_idx = common_idx.difference(originally_missing_sample_idx) - different_check: npt.NDArray[np.bool_] = np.array( - 
unnoised_data.loc[to_compare_sample_idx, column.name].values - != noised_data.loc[to_compare_sample_idx, column.name].values - ) - - return different_check.sum(), to_compare_sample_idx diff --git a/tests/unit/test_interface.py b/tests/unit/test_interface.py index 0687aaaf..8b630b8d 100644 --- a/tests/unit/test_interface.py +++ b/tests/unit/test_interface.py @@ -1,7 +1,12 @@ +import os from pathlib import Path +import numpy as np +import psutil import pytest from _pytest.tmpdir import TempPathFactory +from dask.distributed import LocalCluster, get_client +from dask.system import CPU_COUNT from packaging.version import parse from pytest_mock import MockerFixture @@ -9,9 +14,11 @@ from pseudopeople.exceptions import DataSourceError from pseudopeople.interface import ( _get_data_changelog_version, + set_up_dask_client, validate_source_compatibility, ) from pseudopeople.schema_entities import DATASET_SCHEMAS +from tests.utilities import is_on_slurm CENSUS = DATASET_SCHEMAS.get_dataset_schema(DatasetNames.CENSUS) @@ -94,3 +101,52 @@ def test_validate_source_compatibility_wrong_directory(tmp_path: Path) -> None: bad_path.mkdir() with pytest.raises(FileNotFoundError, match="Could not find 'decennial_census' in"): validate_source_compatibility(bad_path, CENSUS) + + +def test_set_up_dask_client_default() -> None: + + # There should be no dask client yet + with pytest.raises(ValueError): + client = get_client() + + set_up_dask_client() + client = get_client() + assert isinstance(client.cluster, LocalCluster) + assert client.cluster.name == "pseudopeople_dask_cluster" + workers = client.scheduler_info()["workers"] # type: ignore[no-untyped-call] + assert len(workers) == CPU_COUNT + assert all(worker["nthreads"] == 1 for worker in workers.values()) + if is_on_slurm(): + try: + available_memory = float(os.environ["SLURM_MEM_PER_NODE"]) / 1024 + except KeyError: + raise RuntimeError( + "You are on Slurm but SLURM_MEM_PER_NODE is not set. 
" + "It is likely that you are SSHed onto a node (perhaps using VSCode). " + "In this case, dask will assign the total memory of the node to the " + "cluster instead of the allocated memory from the srun call. " + "Pseudopeople should only be used on Slurm directly on the node " + "assigned via an srun (both for pytests as well as actual work)." + ) + else: + available_memory = psutil.virtual_memory().total / (1024 ** 3) + assert np.isclose(sum(worker["memory_limit"] / 1024**3 for worker in workers.values()), available_memory, rtol=0.01) + + +def test_set_up_dask_client_custom() -> None: + memory_limit = 1 # gb + n_workers = 3 + cluster = LocalCluster( # type: ignore[no-untyped-call] + name="custom", + n_workers=n_workers, + threads_per_worker=2, + memory_limit=memory_limit * 1024**3, + ) + client = cluster.get_client() # type: ignore[no-untyped-call] + set_up_dask_client() + client = get_client() + assert client.cluster.name == "custom" + workers = client.scheduler_info()["workers"] + assert len(workers) == 3 + assert all(worker["nthreads"] == 2 for worker in workers.values()) + assert sum(worker["memory_limit"] / 1024**3 for worker in workers.values()) == memory_limit * n_workers diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 118a7a84..13cd5930 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -1,3 +1,4 @@ + import numpy as np import pandas as pd import pytest diff --git a/tests/utilities.py b/tests/utilities.py index 34e46790..97918768 100644 --- a/tests/utilities.py +++ b/tests/utilities.py @@ -1,8 +1,8 @@ from __future__ import annotations import math +import shutil from collections.abc import Callable -from functools import partial from typing import Any import numpy as np @@ -190,3 +190,15 @@ def get_single_noise_type_config( ] = new_probability return config_dict + + +def is_on_slurm() -> bool: + """Returns True if the current environment is a SLURM cluster. 
+ + Notes + ----- + This function simply checks for the presence of the `sbatch` command to _infer_ + if SLURM is installed. It does _not_ check if SLURM is currently active or + managing jobs. + """ + return shutil.which("sbatch") is not None