-
Notifications
You must be signed in to change notification settings - Fork 0
Load data from simulation #23
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 23 commits
4eeeb1f
203b4cd
ef6362d
8555a9c
a933823
cfbd1b3
4cd7819
29cb21b
2c2eb14
1d80041
f903ac8
dec92b8
c28eb73
eb8f284
d16a678
72af127
4b7c960
c22cb26
269941c
d66aa04
ead9479
ea4f762
d255c32
11a2b14
187b4a6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| from vivarium_testing_utils.automated_validation.data_loader import DataLoader |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,32 +1,82 @@ | ||
| from __future__ import annotations | ||
|
|
||
| from enum import Enum | ||
| from pathlib import Path | ||
|
|
||
| import pandas as pd | ||
| from layered_config_tree import LayeredConfigTree | ||
| from layered_config_tree import ConfigurationKeyError, LayeredConfigTree | ||
|
|
||
|
|
||
| class DataSource(Enum): | ||
| SIM = "sim" | ||
| GBD = "gbd" | ||
| ARTIFACT = "artifact" | ||
| CUSTOM = "custom" | ||
|
|
||
| @classmethod | ||
| def from_str(cls, source: str) -> DataSource: | ||
| try: | ||
| return cls(source) | ||
| except ValueError: | ||
| raise ValueError(f"Source {source} not recognized. Must be one of {DataSource}") | ||
|
|
||
|
|
||
| class DataLoader: | ||
| def __init__(self, results_dir: str, cache_size_mb: int = 1000): | ||
| self.results_dir = results_dir | ||
| self.results_dir = Path(results_dir) | ||
| self.sim_output_dir = self.results_dir / "results" | ||
| self.cache_size_mb = cache_size_mb | ||
| self.raw_datasets = LayeredConfigTree() | ||
| self.raw_datasets = LayeredConfigTree({data_source: {} for data_source in DataSource}) | ||
| self.loader_mapping = { | ||
| DataSource.SIM: self._load_from_sim, | ||
| DataSource.GBD: self._load_from_gbd, | ||
| DataSource.ARTIFACT: self._load_from_artifact, | ||
| DataSource.CUSTOM: self._load_custom, | ||
| } | ||
| self.metadata = LayeredConfigTree() | ||
| self.artifact = None # Just stubbing this out for now | ||
|
|
||
| def load_data(self, dataset_key: str, data_type: str) -> None: | ||
| raise NotImplementedError | ||
| def get_sim_outputs(self) -> list[str]: | ||
| """Get a list of the datasets in the given simulation output directory. | ||
| Only return the filename, not the extension.""" | ||
| return [str(f.stem) for f in self.sim_output_dir.glob("*.parquet")] | ||
|
|
||
| def get_dataset(self, dataset_key: str, data_type: str) -> pd.DataFrame: | ||
| def get_artifact_keys(self) -> list[str]: | ||
| raise NotImplementedError | ||
|
|
||
| def sim_outputs(self) -> list[str]: | ||
| raise NotImplementedError | ||
| def get_dataset(self, dataset_key: str, source: str) -> pd.DataFrame: | ||
| """Return the dataset from the cache if it exists, otherwise load it from the source.""" | ||
| source_enum = DataSource.from_str(source) | ||
| try: | ||
| return self.raw_datasets[source_enum][dataset_key] | ||
| except ConfigurationKeyError: | ||
| dataset = self._load_from_source(dataset_key, source_enum) | ||
| self._add_to_datasets(dataset_key, source_enum, dataset) | ||
| return dataset | ||
|
|
||
| def artifact_keys(self) -> list[str]: | ||
| raise NotImplementedError | ||
| def _load_from_source(self, dataset_key: str, source: DataSource) -> None: | ||
| """Load the data from the given source via the loader mapping.""" | ||
| return self.loader_mapping[source](dataset_key) | ||
|
|
||
| def _add_to_datasets( | ||
| self, dataset_key: str, source: DataSource, data: pd.DataFrame | ||
| ) -> None: | ||
| """Update the raw_datasets cache with the given data.""" | ||
| self.raw_datasets.update({source: {dataset_key: data}}) | ||
|
|
||
| def _load_from_sim(self, dataset_key: str) -> pd.DataFrame: | ||
| """Load the data from the simulation output directory and set the non-value columns as indices.""" | ||
| sim_data = pd.read_parquet(self.sim_output_dir / f"{dataset_key}.parquet") | ||
| if "value" not in sim_data.columns: | ||
| raise ValueError(f"{dataset_key}.parquet requires a column labeled 'value'.") | ||
| sim_data = sim_data.set_index(sim_data.columns.drop("value").tolist()) | ||
| return sim_data | ||
|
|
||
| def load_from_sim(self, dataset_key: str) -> pd.DataFrame: | ||
| def _load_from_artifact(self, dataset_key: str) -> pd.DataFrame: | ||
| raise NotImplementedError | ||
|
|
||
| def load_from_artifact(self, dataset_key: str) -> pd.DataFrame: | ||
| def _load_from_gbd(self, dataset_key: str) -> pd.DataFrame: | ||
| raise NotImplementedError | ||
|
|
||
| def load_from_gbd(self, dataset_key: str) -> pd.DataFrame: | ||
| def _load_custom(self, dataset_key: str) -> pd.DataFrame: | ||
| raise NotImplementedError | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,8 @@ | ||
| from pathlib import Path | ||
|
|
||
| import pytest | ||
|
|
||
|
|
||
| @pytest.fixture | ||
| def sim_result_dir(): | ||
| return Path(__file__).parent / "data/sim_outputs" |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,78 @@ | ||
| from pathlib import Path | ||
| from unittest.mock import MagicMock | ||
|
|
||
| import pandas as pd | ||
| import pytest | ||
|
|
||
| from vivarium_testing_utils.automated_validation.data_loader import DataLoader, DataSource | ||
|
|
||
|
|
||
| def test_get_sim_outputs(sim_result_dir: Path) -> None: | ||
| """Test we have the correctly truncated sim data keys""" | ||
| data_loader = DataLoader(sim_result_dir) | ||
| assert set(data_loader.get_sim_outputs()) == { | ||
| "deaths", | ||
| "person_time_cause", | ||
| "transition_count_cause", | ||
| } | ||
|
|
||
|
|
||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. If we stick with |
||
| def test_get_dataset_bad_source(sim_result_dir: Path) -> None: | ||
|
||
| """Ensure that we raise an error if the source is not recognized""" | ||
| data_loader = DataLoader(sim_result_dir) | ||
| with pytest.raises(ValueError): | ||
| data_loader.get_dataset("deaths", "foo") | ||
|
|
||
|
|
||
| def test_get_dataset(sim_result_dir: Path) -> None: | ||
| """Ensure that we load data from disk if needed, and don't if not.""" | ||
| data_loader = DataLoader(sim_result_dir) | ||
| # check that we call load_from_source the first time we call get_dataset | ||
| data_loader._load_from_source = MagicMock() | ||
| data_loader.get_dataset("deaths", "sim"), pd.DataFrame | ||
| data_loader._load_from_source.assert_called_once_with("deaths", DataSource.SIM) | ||
| # check that we don't call load_from_source the second time we call get_dataset | ||
| data_loader._load_from_source = MagicMock() | ||
| data_loader.get_dataset("deaths", "sim"), pd.DataFrame | ||
| data_loader._load_from_source.assert_not_called() | ||
|
|
||
|
|
||
| @pytest.mark.parametrize( | ||
| "dataset_key, source", | ||
| [ | ||
| ("deaths", DataSource.SIM), | ||
| ], | ||
| ) | ||
| def load_from_source(dataset_key: str, source: DataSource, sim_result_dir: Path) -> None: | ||
| """Ensure we can sensibly load using key / source combinations""" | ||
| data_loader = DataLoader(sim_result_dir) | ||
| assert not data_loader.raw_datasets.get(source).get(dataset_key) | ||
| data_loader._load_from_source(dataset_key, source) | ||
| assert data_loader.raw_datasets.get(source).get(dataset_key) | ||
|
|
||
|
|
||
| def test_add_to_datasets(sim_result_dir: Path) -> None: | ||
| """Ensure that we can add data to the cache""" | ||
| df = pd.DataFrame({"baz": [1, 2, 3]}) | ||
| data_loader = DataLoader(sim_result_dir) | ||
| data_loader._add_to_datasets("foo", "bar", df) | ||
| assert data_loader.raw_datasets.get("bar").get("foo").equals(df) | ||
|
|
||
|
|
||
| def test_load_from_sim(sim_result_dir: Path) -> None: | ||
| """Ensure that we can load data from the simulation output directory""" | ||
| data_loader = DataLoader(sim_result_dir) | ||
| person_time_cause = data_loader._load_from_sim("deaths") | ||
| assert person_time_cause.shape == (8, 1) | ||
| # check that value is column and rest are indices | ||
| assert person_time_cause.index.names == [ | ||
| "measure", | ||
| "entity_type", | ||
| "entity", | ||
| "sub_entity", | ||
| "age_group", | ||
| "sex", | ||
| "input_draw", | ||
| "random_seed", | ||
| ] | ||
| assert person_time_cause.columns == ["value"] | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should source be a string or an enum-like object? We only accept a few precise values, and I don't think the plan is for the
DataLoader to be exposed to the end-user, so we don't have to worry about it being annoying for them to have to import and use these objects. There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
agreed
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The complication here is that the user does actually specify the data sources via the interface (when they decide what test data and reference data to compare against), and they should be able to pass a string there.
The strategy I took is to still define the data sources through ENUMs and just check against the string in
get_dataset(). I'm not sure if there's an easier way to do the string check; I'm not super familiar with using enums. There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But the user will only call
context.add_comparison(), which can have a string argument to define the types of data being compared. This method will be hidden from them, so it seems it should be able to take the enum as input? There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Right, but then I think you have to import
DataSource into Context and pass the string through DataSource there before passing it into this method. It seemed to me like encapsulating all that into DataLoader / dataloader.py was preferable to keeping the signature consistent, but maybe we'll just need to use this enum in multiple places anyway and it won't matter so much.