Skip to content
Merged
Show file tree
Hide file tree
Changes from 23 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@
"tables",
"networkx",
"loguru",
"layered_config_tree",
"pyarrow",
# Type stubs
"pandas-stubs",
"networkx-stubs",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from vivarium_testing_utils.automated_validation.data_loader import DataLoader
76 changes: 63 additions & 13 deletions src/vivarium_testing_utils/automated_validation/data_loader.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,82 @@
from __future__ import annotations

from enum import Enum
from pathlib import Path

import pandas as pd
from layered_config_tree import LayeredConfigTree
from layered_config_tree import ConfigurationKeyError, LayeredConfigTree


class DataSource(Enum):
    """Enumerates the supported origins for validation datasets."""

    SIM = "sim"
    GBD = "gbd"
    ARTIFACT = "artifact"
    CUSTOM = "custom"

    @classmethod
    def from_str(cls, source: str) -> "DataSource":
        """Convert a user-supplied string into a DataSource member.

        Parameters
        ----------
        source
            The string form of a member value, e.g. "sim" or "gbd".

        Raises
        ------
        ValueError
            If ``source`` does not match any member value. The message lists
            the valid values rather than the enum class repr.
        """
        try:
            return cls(source)
        except ValueError:
            valid_values = [member.value for member in cls]
            raise ValueError(
                f"Source {source} not recognized. Must be one of {valid_values}"
            ) from None


class DataLoader:
def __init__(self, results_dir: str | Path, cache_size_mb: int = 1000):
    """Set up dataset caches and the per-source loader dispatch table.

    Parameters
    ----------
    results_dir
        Root simulation results directory; the "results/" subdirectory
        beneath it holds the parquet sim outputs.
    cache_size_mb
        Soft size limit for the raw-dataset cache, in megabytes.
    """
    self.results_dir = Path(results_dir)
    self.sim_output_dir = self.results_dir / "results"
    self.cache_size_mb = cache_size_mb
    # One cache layer per data source so cached datasets are namespaced by origin.
    self.raw_datasets = LayeredConfigTree({data_source: {} for data_source in DataSource})
    # Dispatch table mapping each DataSource to its loader; keeps
    # _load_from_source free of if/elif chains.
    self.loader_mapping = {
        DataSource.SIM: self._load_from_sim,
        DataSource.GBD: self._load_from_gbd,
        DataSource.ARTIFACT: self._load_from_artifact,
        DataSource.CUSTOM: self._load_custom,
    }
    self.metadata = LayeredConfigTree()
    self.artifact = None  # Just stubbing this out for now

def get_sim_outputs(self) -> list[str]:
    """Get a list of the datasets in the given simulation output directory.

    Only the filename stem is returned, not the extension.
    """
    # Path.stem is already a str, so no explicit conversion is needed.
    return [f.stem for f in self.sim_output_dir.glob("*.parquet")]

def get_artifact_keys(self) -> list[str]:
    """List the keys available in the artifact. Not yet implemented."""
    raise NotImplementedError
def get_dataset(self, dataset_key: str, source: str) -> pd.DataFrame:
    """Return the dataset from the cache if it exists, otherwise load it from the source.

    Parameters
    ----------
    dataset_key
        Name of the dataset, e.g. a sim output file stem or an artifact key.
    source
        String form of a DataSource member ("sim", "gbd", "artifact", "custom").
        Validated here so callers can pass plain strings.

    Raises
    ------
    ValueError
        If ``source`` is not a recognized DataSource value.
    """
    source_enum = DataSource.from_str(source)
    try:
        # EAFP: after the first load, a cache hit is the common case.
        return self.raw_datasets[source_enum][dataset_key]
    except ConfigurationKeyError:
        dataset = self._load_from_source(dataset_key, source_enum)
        self._add_to_datasets(dataset_key, source_enum, dataset)
        return dataset

def _load_from_source(self, dataset_key: str, source: DataSource) -> pd.DataFrame:
    """Load the data from the given source via the loader mapping.

    Note: this only loads and returns the data; it does not populate the
    raw_datasets cache (see _add_to_datasets).
    """
    # Fixed return annotation: the dispatched loader returns a DataFrame, not None.
    return self.loader_mapping[source](dataset_key)

def _add_to_datasets(
    self, dataset_key: str, source: DataSource, data: pd.DataFrame
) -> None:
    """Store ``data`` in the raw-dataset cache under ``source``/``dataset_key``."""
    cache_entry = {source: {dataset_key: data}}
    self.raw_datasets.update(cache_entry)

def _load_from_sim(self, dataset_key: str) -> pd.DataFrame:
    """Read ``<dataset_key>.parquet`` from the sim output directory.

    Every column except 'value' becomes an index level, leaving 'value' as
    the sole data column.
    """
    parquet_path = self.sim_output_dir / f"{dataset_key}.parquet"
    sim_data = pd.read_parquet(parquet_path)
    if "value" not in sim_data.columns:
        raise ValueError(f"{dataset_key}.parquet requires a column labeled 'value'.")
    index_columns = [column for column in sim_data.columns if column != "value"]
    return sim_data.set_index(index_columns)

def _load_from_artifact(self, dataset_key: str) -> pd.DataFrame:
    """Load ``dataset_key`` from the artifact. Not yet implemented."""
    raise NotImplementedError

def _load_from_gbd(self, dataset_key: str) -> pd.DataFrame:
    """Load ``dataset_key`` from GBD. Not yet implemented."""
    raise NotImplementedError

def _load_custom(self, dataset_key: str) -> pd.DataFrame:
    """Load ``dataset_key`` from a custom user-supplied source. Not yet implemented."""
    raise NotImplementedError
14 changes: 5 additions & 9 deletions src/vivarium_testing_utils/automated_validation/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,28 +5,24 @@

from vivarium_testing_utils.automated_validation import plot_utils
from vivarium_testing_utils.automated_validation.comparison import Comparison
from vivarium_testing_utils.automated_validation.data_loader import DataLoader
from vivarium_testing_utils.automated_validation.data_loader import DataLoader, DataSource


class ValidationContext:
    """User-facing entry point for comparing simulation outputs to reference data."""

    def __init__(self, results_dir: str | Path, age_groups: pd.DataFrame | None):
        self._data_loader = DataLoader(results_dir)
        self.comparisons = LayeredConfigTree()
        # NOTE(review): age_groups is accepted but not stored or used here —
        # confirm whether it should be retained for later age-group mapping.

    def get_sim_outputs(self):
        """List dataset names available in the simulation output directory."""
        # DataLoader's method is get_sim_outputs (renamed from sim_outputs).
        return self._data_loader.get_sim_outputs()

    def get_artifact_keys(self):
        """List keys available in the artifact."""
        # DataLoader's method is get_artifact_keys (renamed from artifact_keys).
        return self._data_loader.get_artifact_keys()

    def add_comparison(
        self, measure_key: str, test_source: str, ref_source: str, stratifications: list[str]
    ) -> None:
        """Register a comparison between test and reference data. Not yet implemented."""
        raise NotImplementedError

    def verify(self, comparison_key: str, stratifications: list[str] | None = None):
        """Run verification for a registered comparison.

        ``stratifications`` defaults to None (treated as []) to avoid the
        shared-mutable-default pitfall.
        """
        self.comparisons[comparison_key].verify(stratifications or [])
Expand Down
8 changes: 8 additions & 0 deletions tests/automated_validation/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from pathlib import Path

import pytest


@pytest.fixture
def sim_result_dir():
    """Path to the checked-in simulation outputs used by these tests."""
    return Path(__file__).parent / "data" / "sim_outputs"
Binary file not shown.
Binary file not shown.
Binary file not shown.
78 changes: 78 additions & 0 deletions tests/automated_validation/test_data_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from pathlib import Path
from unittest.mock import MagicMock

import pandas as pd
import pytest

from vivarium_testing_utils.automated_validation.data_loader import DataLoader, DataSource


def test_get_sim_outputs(sim_result_dir: Path) -> None:
    """Sim output keys are the parquet file stems, without extensions."""
    loader = DataLoader(sim_result_dir)
    expected_keys = {
        "deaths",
        "person_time_cause",
        "transition_count_cause",
    }
    assert set(loader.get_sim_outputs()) == expected_keys


Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we stick with source being defined as a string, we need a test where a bad source string is provided.

def test_get_dataset_bad_source(sim_result_dir: Path) -> None:
    """Ensure that we raise an error if the source is not recognized"""
    data_loader = DataLoader(sim_result_dir)
    with pytest.raises(ValueError):
        data_loader.get_dataset("deaths", "foo")


def test_get_dataset(sim_result_dir: Path) -> None:
    """Ensure that we load data from disk if needed, and don't if not."""
    data_loader = DataLoader(sim_result_dir)
    # First access: the loader must be invoked exactly once.
    # (Removed stray no-op expressions `..., pd.DataFrame` left from an
    # incomplete assert; with _load_from_source mocked, the cached value is a
    # MagicMock, so an isinstance check against DataFrame would be wrong.)
    data_loader._load_from_source = MagicMock()
    data_loader.get_dataset("deaths", "sim")
    data_loader._load_from_source.assert_called_once_with("deaths", DataSource.SIM)
    # Second access: served from the cache, so the loader is never invoked.
    data_loader._load_from_source = MagicMock()
    data_loader.get_dataset("deaths", "sim")
    data_loader._load_from_source.assert_not_called()


@pytest.mark.parametrize(
    "dataset_key, source",
    [
        ("deaths", DataSource.SIM),
    ],
)
def test_load_from_source(dataset_key: str, source: DataSource, sim_result_dir: Path) -> None:
    """Ensure we can sensibly load using key / source combinations.

    Renamed with a test_ prefix so pytest actually collects it. Also fixed
    the assertions: _load_from_source only loads and returns data — the
    cache is populated by _add_to_datasets — and truth-testing a DataFrame
    directly raises, so compare against None / identity instead.
    """
    data_loader = DataLoader(sim_result_dir)
    # Cache starts empty for this key.
    assert data_loader.raw_datasets.get(source).get(dataset_key) is None
    dataset = data_loader._load_from_source(dataset_key, source)
    assert isinstance(dataset, pd.DataFrame)
    # Loading alone does not populate the cache; _add_to_datasets does.
    data_loader._add_to_datasets(dataset_key, source, dataset)
    assert data_loader.raw_datasets.get(source).get(dataset_key) is dataset


def test_add_to_datasets(sim_result_dir: Path) -> None:
    """Data added via _add_to_datasets is retrievable from the cache."""
    frame = pd.DataFrame({"baz": [1, 2, 3]})
    loader = DataLoader(sim_result_dir)
    loader._add_to_datasets("foo", "bar", frame)
    cached = loader.raw_datasets.get("bar").get("foo")
    assert cached.equals(frame)


def test_load_from_sim(sim_result_dir: Path) -> None:
    """Sim parquet data loads with 'value' as the sole column and all other
    columns promoted to index levels."""
    loader = DataLoader(sim_result_dir)
    # Local renamed from the misleading `person_time_cause`: this loads "deaths".
    deaths = loader._load_from_sim("deaths")
    assert deaths.shape == (8, 1)
    expected_index = [
        "measure",
        "entity_type",
        "entity",
        "sub_entity",
        "age_group",
        "sex",
        "input_draw",
        "random_seed",
    ]
    assert deaths.index.names == expected_index
    assert deaths.columns == ["value"]