-
Notifications
You must be signed in to change notification settings - Fork 0
Load data from simulation #23
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 15 commits
4eeeb1f
203b4cd
ef6362d
8555a9c
a933823
cfbd1b3
4cd7819
29cb21b
2c2eb14
1d80041
f903ac8
dec92b8
c28eb73
eb8f284
d16a678
72af127
4b7c960
c22cb26
269941c
d66aa04
ead9479
ea4f762
d255c32
11a2b14
187b4a6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| from vivarium_testing_utils.automated_validation.data_loader import DataLoader |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,32 +1,64 @@ | ||
| from pathlib import Path | ||
|
|
||
| import pandas as pd | ||
| from layered_config_tree import LayeredConfigTree | ||
| from layered_config_tree import ConfigurationKeyError, LayeredConfigTree | ||
|
|
||
|
|
||
| class DataLoader: | ||
| def __init__(self, results_dir: str, cache_size_mb: int = 1000): | ||
| self.results_dir = results_dir | ||
| self.results_dir = Path(results_dir) | ||
| self.sim_output_dir = self.results_dir / "results" | ||
| self.cache_size_mb = cache_size_mb | ||
| self.raw_datasets = LayeredConfigTree() | ||
| self.raw_datasets = LayeredConfigTree( | ||
| {"sim": {}, "gbd": {}, "artifact": {}, "custom": {}} | ||
|
||
| ) | ||
| self.loader_mapping = { | ||
| "sim": self.load_from_sim, | ||
| "gbd": self.load_from_gbd, | ||
| "artifact": self.load_from_artifact, | ||
| "custom": self.load_custom, | ||
| } | ||
| self.metadata = LayeredConfigTree() | ||
| self.artifact = None # Just stubbing this out for now | ||
|
|
||
| def load_data(self, dataset_key: str, data_type: str) -> None: | ||
| raise NotImplementedError | ||
| def get_sim_outputs(self) -> list[str]: | ||
| """Get a list of the datasets in the given simulation output directory. | ||
| Only return the filename, not the extension.""" | ||
| return [str(f.stem) for f in self.sim_output_dir.glob("*.parquet")] | ||
|
|
||
| def get_dataset(self, dataset_key: str, data_type: str) -> pd.DataFrame: | ||
| def get_artifact_keys(self) -> list[str]: | ||
| raise NotImplementedError | ||
|
|
||
| def sim_outputs(self) -> list[str]: | ||
| raise NotImplementedError | ||
| def get_dataset(self, dataset_key: str, source: str) -> pd.DataFrame: | ||
|
||
| """Return the dataset from the cache if it exists, otherwise load it from the source.""" | ||
| try: | ||
| return self.raw_datasets[source][dataset_key] | ||
| except ConfigurationKeyError: | ||
| dataset = self.load_from_source(dataset_key, source) | ||
| self.add_to_datasets(dataset_key, source, dataset) | ||
| return dataset | ||
|
|
||
| def artifact_keys(self) -> list[str]: | ||
| raise NotImplementedError | ||
| def load_from_source(self, dataset_key: str, source: str) -> None: | ||
|
||
| """Load the data from the given source via the loader mapping.""" | ||
| return self.loader_mapping[source](dataset_key) | ||
|
|
||
| def add_to_datasets(self, dataset_key: str, source: str, data: pd.DataFrame) -> None: | ||
|
||
| """Update the raw_datasets cache with the given data.""" | ||
| self.raw_datasets.update({source: {dataset_key: data}}) | ||
|
|
||
| def load_from_sim(self, dataset_key: str) -> pd.DataFrame: | ||
| raise NotImplementedError | ||
| """Load the data from the simulation output directory and set the non-value columns as indices.""" | ||
| sim_data = pd.read_parquet(self.sim_output_dir / f"{dataset_key}.parquet") | ||
| if "value" not in sim_data.columns: | ||
| raise ValueError(f"Value column not found in {dataset_key}.parquet") | ||
|
||
| sim_data = sim_data.set_index(sim_data.columns.drop("value").tolist()) | ||
| return sim_data | ||
|
|
||
| def load_from_artifact(self, dataset_key: str) -> pd.DataFrame: | ||
| raise NotImplementedError | ||
|
|
||
| def load_from_gbd(self, dataset_key: str) -> pd.DataFrame: | ||
| raise NotImplementedError | ||
|
|
||
| def load_custom(self, dataset_key: str) -> pd.DataFrame: | ||
| raise NotImplementedError | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,6 @@ | ||
| import pytest | ||
|
|
||
|
|
||
| @pytest.fixture | ||
| def sim_result_dir(): | ||
| return "tests/automated_validation/data/sim_outputs" | ||
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,70 @@ | ||
| from unittest.mock import MagicMock | ||
|
|
||
| import pandas as pd | ||
| import pytest | ||
|
|
||
| from vivarium_testing_utils.automated_validation.data_loader import DataLoader | ||
|
|
||
|
|
||
| def test_get_sim_outputs(sim_result_dir): | ||
|
||
| """Test we have the correctly truncated sim data keys""" | ||
| data_loader = DataLoader(sim_result_dir) | ||
| assert set(data_loader.get_sim_outputs()) == { | ||
| "deaths", | ||
| "person_time_cause", | ||
| "transition_count_cause", | ||
| } | ||
|
|
||
|
|
||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. If we stick with |
||
| def test_get_dataset(sim_result_dir): | ||
| """Ensure that we load data from disk if needed, and don't if not.""" | ||
| data_loader = DataLoader(sim_result_dir) | ||
| # check that we call load_from_source the first time we call get_dataset | ||
| data_loader.load_from_source = MagicMock() | ||
| data_loader.get_dataset("deaths", "sim"), pd.DataFrame | ||
| data_loader.load_from_source.assert_called_once_with("deaths", "sim") | ||
| # check that we don't call load_from_source the second time we call get_dataset | ||
| data_loader.load_from_source = MagicMock() | ||
| data_loader.get_dataset("deaths", "sim"), pd.DataFrame | ||
| data_loader.load_from_source.assert_not_called() | ||
|
|
||
|
|
||
| @pytest.mark.parametrize( | ||
| "dataset_key, source", | ||
| [ | ||
| ("deaths", "sim"), | ||
| ], | ||
| ) | ||
| def load_from_source(dataset_key, source, sim_result_dir): | ||
| """Ensure we can sensibly load using key / source combinations""" | ||
| data_loader = DataLoader(sim_result_dir) | ||
| assert not data_loader.raw_datasets.get(source).get(dataset_key) | ||
| data_loader.load_from_source(dataset_key, source) | ||
| assert data_loader.raw_datasets.get(source).get(dataset_key) | ||
|
|
||
|
|
||
| def test_add_to_datasets(sim_result_dir): | ||
| """Ensure that we can add data to the cache""" | ||
| df = pd.DataFrame({"baz": [1, 2, 3]}) | ||
| data_loader = DataLoader(sim_result_dir) | ||
| data_loader.add_to_datasets("foo", "bar", df) | ||
| assert data_loader.raw_datasets.get("bar").get("foo").equals(df) | ||
|
|
||
|
|
||
| def test_load_from_sim(sim_result_dir): | ||
| """Ensure that we can load data from the simulation output directory""" | ||
| data_loader = DataLoader(sim_result_dir) | ||
| person_time_cause = data_loader.load_from_sim("deaths") | ||
| assert person_time_cause.shape == (8, 1) | ||
| # check that value is column and rest are indices | ||
| assert person_time_cause.index.names == [ | ||
| "measure", | ||
| "entity_type", | ||
| "entity", | ||
| "sub_entity", | ||
| "age_group", | ||
| "sex", | ||
| "input_draw", | ||
| "random_seed", | ||
| ] | ||
| assert person_time_cause.columns == ["value"] | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we want this to be a
`LayeredConfigTree` rather than just a dict? I see we're chaining `.get()` calls in a test. Was that the reason? There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think i got a suggestion to use LCT for the 'dot' key access. There's no fundamental reason not to use a dict