Skip to content
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@
"tables",
"networkx",
"loguru",
"layered_config_tree",
"pyarrow",
# Type stubs
"pandas-stubs",
"networkx-stubs",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from vivarium_testing_utils.automated_validation.data_loader import DataLoader
54 changes: 43 additions & 11 deletions src/vivarium_testing_utils/automated_validation/data_loader.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,64 @@
from pathlib import Path

import pandas as pd
from layered_config_tree import LayeredConfigTree
from layered_config_tree import ConfigurationKeyError, LayeredConfigTree


class DataLoader:
    """Load and cache datasets from the sources used in automated validation.

    Datasets are identified by a ``dataset_key`` and a ``source`` (one of
    ``"sim"``, ``"gbd"``, ``"artifact"``, or ``"custom"``) and are cached in
    ``raw_datasets`` so that each dataset is only read from disk once.
    """

    def __init__(self, results_dir: str, cache_size_mb: int = 1000):
        self.results_dir = Path(results_dir)
        self.sim_output_dir = self.results_dir / "results"
        self.cache_size_mb = cache_size_mb
        # Cache of loaded datasets, keyed first by source, then by dataset key.
        self.raw_datasets = LayeredConfigTree(
            {"sim": {}, "gbd": {}, "artifact": {}, "custom": {}}
        )
        # Dispatch table mapping each supported source to its loader method.
        self.loader_mapping = {
            "sim": self.load_from_sim,
            "gbd": self.load_from_gbd,
            "artifact": self.load_from_artifact,
            "custom": self.load_custom,
        }
        self.metadata = LayeredConfigTree()
        self.artifact = None  # Just stubbing this out for now

    def get_sim_outputs(self) -> list[str]:
        """Get a list of the datasets in the given simulation output directory.
        Only return the filename, not the extension."""
        return [str(f.stem) for f in self.sim_output_dir.glob("*.parquet")]

    def get_artifact_keys(self) -> list[str]:
        raise NotImplementedError

    def get_dataset(self, dataset_key: str, source: str) -> pd.DataFrame:
        """Return the dataset from the cache if it exists, otherwise load it from the source."""
        try:
            return self.raw_datasets[source][dataset_key]
        except ConfigurationKeyError:
            # Cache miss: load from the source and remember it for next time.
            dataset = self.load_from_source(dataset_key, source)
            self.add_to_datasets(dataset_key, source, dataset)
            return dataset

    def load_from_source(self, dataset_key: str, source: str) -> pd.DataFrame:
        """Load the data from the given source via the loader mapping.

        Raises
        ------
        ValueError
            If ``source`` is not one of the supported sources.
        """
        try:
            loader = self.loader_mapping[source]
        except KeyError:
            # Surface a clear error instead of a bare KeyError on a bad source.
            raise ValueError(
                f"Unknown source '{source}'. "
                f"Supported sources: {sorted(self.loader_mapping)}"
            ) from None
        return loader(dataset_key)

    def add_to_datasets(self, dataset_key: str, source: str, data: pd.DataFrame) -> None:
        """Update the raw_datasets cache with the given data."""
        self.raw_datasets.update({source: {dataset_key: data}})

    def load_from_sim(self, dataset_key: str) -> pd.DataFrame:
        """Load the data from the simulation output directory and set the non-value columns as indices."""
        sim_data = pd.read_parquet(self.sim_output_dir / f"{dataset_key}.parquet")
        if "value" not in sim_data.columns:
            # Name the specific missing column so the message is unambiguous.
            raise ValueError(f"Column 'value' not found in {dataset_key}.parquet")
        sim_data = sim_data.set_index(sim_data.columns.drop("value").tolist())
        return sim_data

    def load_from_artifact(self, dataset_key: str) -> pd.DataFrame:
        raise NotImplementedError

    def load_from_gbd(self, dataset_key: str) -> pd.DataFrame:
        raise NotImplementedError

    def load_custom(self, dataset_key: str) -> pd.DataFrame:
        raise NotImplementedError
6 changes: 6 additions & 0 deletions tests/automated_validation/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import pytest


@pytest.fixture
def sim_result_dir() -> str:
    """Absolute path to the bundled simulation outputs.

    Anchored on this file's location so the fixture works regardless of the
    directory pytest is invoked from (a cwd-relative path would not).
    """
    from pathlib import Path

    return str(Path(__file__).resolve().parent / "data" / "sim_outputs")

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would using something like Path(file).resolve().parent be safer here?

Binary file not shown.
Binary file not shown.
Binary file not shown.
70 changes: 70 additions & 0 deletions tests/automated_validation/test_data_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from unittest.mock import MagicMock

import pandas as pd
import pytest

from vivarium_testing_utils.automated_validation.data_loader import DataLoader


def test_get_sim_outputs(sim_result_dir: str) -> None:
    """Test we have the correctly truncated sim data keys."""
    data_loader = DataLoader(sim_result_dir)
    assert set(data_loader.get_sim_outputs()) == {
        "deaths",
        "person_time_cause",
        "transition_count_cause",
    }


Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we stick with source being defined as a string, we need a test where a bad source string is provided.

def test_get_dataset(sim_result_dir: str) -> None:
    """Ensure that we load data from disk if needed, and don't if not."""
    data_loader = DataLoader(sim_result_dir)
    # check that we call load_from_source the first time we call get_dataset
    data_loader.load_from_source = MagicMock()
    data_loader.get_dataset("deaths", "sim")
    data_loader.load_from_source.assert_called_once_with("deaths", "sim")
    # check that we don't call load_from_source the second time we call get_dataset
    data_loader.load_from_source = MagicMock()
    data_loader.get_dataset("deaths", "sim")
    data_loader.load_from_source.assert_not_called()


@pytest.mark.parametrize(
    "dataset_key, source",
    [
        ("deaths", "sim"),
    ],
)
def test_load_from_source(dataset_key: str, source: str, sim_result_dir: str) -> None:
    """Ensure we can sensibly load using key / source combinations.

    Note: the function previously lacked the ``test_`` prefix, so pytest never
    collected it. It also asserted on the cache, which ``load_from_source``
    does not populate (only ``get_dataset`` does) — assert on the returned
    data instead.
    """
    data_loader = DataLoader(sim_result_dir)
    # loading should not touch the cache
    assert not data_loader.raw_datasets.get(source).get(dataset_key)
    dataset = data_loader.load_from_source(dataset_key, source)
    assert isinstance(dataset, pd.DataFrame)
    assert not dataset.empty


def test_add_to_datasets(sim_result_dir):
    """Ensure that we can add data to the cache"""
    frame = pd.DataFrame({"baz": [1, 2, 3]})
    loader = DataLoader(sim_result_dir)
    loader.add_to_datasets("foo", "bar", frame)
    cached = loader.raw_datasets.get("bar").get("foo")
    assert cached.equals(frame)


def test_load_from_sim(sim_result_dir: str) -> None:
    """Ensure that we can load data from the simulation output directory."""
    data_loader = DataLoader(sim_result_dir)
    # NOTE: this loads the "deaths" dataset; the old local name
    # `person_time_cause` was misleading.
    deaths = data_loader.load_from_sim("deaths")
    assert deaths.shape == (8, 1)
    # check that value is column and rest are indices
    assert deaths.index.names == [
        "measure",
        "entity_type",
        "entity",
        "sub_entity",
        "age_group",
        "sex",
        "input_draw",
        "random_seed",
    ]
    assert deaths.columns == ["value"]