Skip to content

Commit 1ad5826

Browse files
patricktnast and rmudambi
authored and committed
Generate stubs for Automated V&V Phase 1 (#22)
* first pass at stubbing
* lint
* refine
* Update src/vivarium_testing_utils/automated_validation/interface.py
  Co-authored-by: Rajan Mudambi <11376379+rmudambi@users.noreply.github.com>
* change pass to NotImplementedError
---------
Co-authored-by: Rajan Mudambi <11376379+rmudambi@users.noreply.github.com>
1 parent e5e4568 commit 1ad5826

File tree

5 files changed

+165
-0
lines changed

5 files changed

+165
-0
lines changed
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import pandas as pd
2+
3+
4+
def process_raw_data(
    input_data_type: str, raw_data: pd.DataFrame, measure: str
) -> pd.DataFrame:
    """Process raw data for a measure (stub, not yet implemented).

    Parameters
    ----------
    input_data_type
        Identifier of the kind of input data; the accepted values are not
        defined by this stub.
    raw_data
        The raw data to process.
    measure
        The measure the data belongs to.

    Raises
    ------
    NotImplementedError
        Always; this function is a placeholder.
    """
    raise NotImplementedError
8+
9+
10+
def compute_metric(
    input_data_type: str, intermediate_data: pd.DataFrame, measure: str
) -> pd.Series:
    """Compute a metric for a measure (stub, not yet implemented).

    Parameters
    ----------
    input_data_type
        Identifier of the kind of input data; the accepted values are not
        defined by this stub.
    intermediate_data
        The data to compute the metric from.
    measure
        The measure the metric is computed for.

    Raises
    ------
    NotImplementedError
        Always; this function is a placeholder.
    """
    raise NotImplementedError
14+
15+
16+
def ratio(numerator, denominator):
    """Compute the ratio of two quantities (stub, not yet implemented).

    The argument and return types are not yet fixed by this stub.

    Raises
    ------
    NotImplementedError
        Always; this function is a placeholder.
    """
    raise NotImplementedError
18+
19+
20+
def aggregate(data: pd.DataFrame, groupby_cols: list[str]) -> pd.DataFrame:
    """Aggregate data over the given grouping columns (stub, not yet implemented).

    Parameters
    ----------
    data
        The data to aggregate.
    groupby_cols
        Names of the columns to group by.

    Raises
    ------
    NotImplementedError
        Always; this function is a placeholder.
    """
    raise NotImplementedError
22+
23+
24+
def linear_combination(coefficients: list[float], data: pd.DataFrame) -> pd.Series:
    """Combine data linearly using the given coefficients (stub, not yet implemented).

    Parameters
    ----------
    coefficients
        The coefficients of the combination.
    data
        The data to combine; how coefficients map onto it is not yet
        defined by this stub.

    Raises
    ------
    NotImplementedError
        Always; this function is a placeholder.
    """
    raise NotImplementedError
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import pandas as pd
2+
3+
from vivarium_testing_utils.automated_validation.calculations import compute_metric
4+
5+
6+
class Comparison:
7+
def __init__(
8+
self,
9+
measure_key: str,
10+
test_data: pd.DataFrame,
11+
reference_data: pd.DataFrame,
12+
stratifications: list[str] = [],
13+
):
14+
self.measure = measure_key
15+
self.test_data = test_data
16+
self.reference_data = reference_data
17+
self.computed_comparison = compute_metric(
18+
self.test_data, self.reference_data, self.measure
19+
)
20+
# you need to marginalize out the non-stratified columns as well
21+
22+
def verify(self, stratifications: list[str]):
23+
raise NotImplementedError
24+
25+
def summarize(self, stratifications: list[str]):
26+
raise NotImplementedError
27+
28+
def heads(self, stratifications: list[str]):
29+
raise NotImplementedError
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import pandas as pd
2+
from layered_config_tree import LayeredConfigTree
3+
4+
5+
class DataLoader:
    """Loader for the datasets used in validation.

    Only the constructor is implemented; every loading and listing method
    is a stub that raises ``NotImplementedError``.
    """

    def __init__(self, results_dir: str, cache_size_mb: int = 1000):
        """Record the configuration and create empty containers.

        Parameters
        ----------
        results_dir
            Location of the results to load from.
        cache_size_mb
            Cache size in megabytes; stored but not enforced by this stub.
        """
        self.results_dir = results_dir
        self.cache_size_mb = cache_size_mb
        self.raw_datasets = LayeredConfigTree()
        self.metadata = LayeredConfigTree()
        self.artifact = None  # Just stubbing this out for now

    def load_data(self, dataset_key: str, data_type: str) -> None:
        """Load a dataset (stub; always raises NotImplementedError)."""
        raise NotImplementedError

    def get_dataset(self, dataset_key: str, data_type: str) -> pd.DataFrame:
        """Return a dataset (stub; always raises NotImplementedError)."""
        raise NotImplementedError

    def sim_outputs(self) -> list[str]:
        """List the simulation outputs (stub; always raises NotImplementedError)."""
        raise NotImplementedError

    def artifact_keys(self) -> list[str]:
        """List the artifact keys (stub; always raises NotImplementedError)."""
        raise NotImplementedError

    def load_from_sim(self, dataset_key: str) -> pd.DataFrame:
        """Load a dataset from the simulation (stub; always raises NotImplementedError)."""
        raise NotImplementedError

    def load_from_artifact(self, dataset_key: str) -> pd.DataFrame:
        """Load a dataset from the artifact (stub; always raises NotImplementedError)."""
        raise NotImplementedError

    def load_from_gbd(self, dataset_key: str) -> pd.DataFrame:
        """Load a dataset from GBD (stub; always raises NotImplementedError)."""
        raise NotImplementedError
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
from pathlib import Path
2+
3+
import pandas as pd
4+
from layered_config_tree import LayeredConfigTree
5+
6+
from vivarium_testing_utils.automated_validation import plot_utils
7+
from vivarium_testing_utils.automated_validation.comparison import Comparison
8+
from vivarium_testing_utils.automated_validation.data_loader import DataLoader
9+
10+
11+
class ValidationContext:
12+
def __init__(self, results_dir: str | Path, age_groups: pd.DataFrame | None):
13+
self.data_loader = DataLoader(results_dir)
14+
self.comparisons = LayeredConfigTree()
15+
16+
def get_sim_outputs(self):
17+
return self.data_loader.sim_outputs()
18+
19+
def get_artifact_keys(self):
20+
return self.data_loader.artifact_keys()
21+
22+
def add_comparison(
23+
self, measure_key: str, test_source: str, ref_source: str, stratifications: list[str]
24+
) -> None:
25+
test_data = self.data_loader.get_dataset(measure_key, test_source)
26+
ref_data = self.data_loader.get_dataset(measure_key, ref_source)
27+
self.comparisons.update(
28+
[measure_key], Comparison(measure_key, test_data, ref_data, stratifications)
29+
)
30+
31+
def verify(self, comparison_key: str, stratifications: list[str] = []):
32+
self.comparisons[comparison_key].verify(stratifications)
33+
34+
def plot_comparison(self, comparison_key: str, type: str, **kwargs):
35+
return plot_utils.plot_comparison(self.comparisons[comparison_key], type, kwargs)
36+
37+
def generate_comparisons(self):
38+
raise NotImplementedError
39+
40+
def verify_all(self):
41+
for comparison in self.comparisons.values():
42+
comparison.verify()
43+
44+
def plot_all(self):
45+
raise NotImplementedError
46+
47+
def get_results(self, verbose: bool = False):
48+
raise NotImplementedError
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import pandas as pd
2+
3+
from vivarium_testing_utils.automated_validation.comparison import Comparison
4+
5+
6+
def plot_comparison(comparison: Comparison, type: str, kwargs):
    """Plot a comparison (stub, not yet implemented).

    Parameters
    ----------
    comparison
        The comparison to plot.
    type
        Name of the plot type; the accepted values are not defined by this stub.
    kwargs
        Extra plotting options, passed as a single positional mapping.

    Raises
    ------
    NotImplementedError
        Always; this function is a placeholder.
    """
    raise NotImplementedError
8+
9+
10+
def plot_data(dataset: pd.DataFrame, type: str, kwargs):
    """Plot a single dataset (stub, not yet implemented).

    Parameters
    ----------
    dataset
        The data to plot.
    type
        Name of the plot type; the accepted values are not defined by this stub.
    kwargs
        Extra plotting options, passed as a single positional mapping.

    Raises
    ------
    NotImplementedError
        Always; this function is a placeholder.
    """
    raise NotImplementedError
12+
13+
14+
def line_plot(comparison: Comparison, x_axis: str, stratifications: list[str]):
    """Draw a line plot of a comparison (stub, not yet implemented).

    Parameters
    ----------
    comparison
        The comparison to plot.
    x_axis
        Name of the variable for the x axis.
    stratifications
        Columns to stratify the plot by.

    Raises
    ------
    NotImplementedError
        Always; this function is a placeholder.
    """
    raise NotImplementedError
16+
17+
18+
def bar_plot(comparison: Comparison, x_axis: str, stratifications: list[str]):
    """Draw a bar plot of a comparison (stub, not yet implemented).

    Parameters
    ----------
    comparison
        The comparison to plot.
    x_axis
        Name of the variable for the x axis.
    stratifications
        Columns to stratify the plot by.

    Raises
    ------
    NotImplementedError
        Always; this function is a placeholder.
    """
    raise NotImplementedError
20+
21+
22+
def box_plot(comparison: Comparison, cat: str, stratifications: list[str]):
    """Draw a box plot of a comparison (stub, not yet implemented).

    Parameters
    ----------
    comparison
        The comparison to plot.
    cat
        Presumably the name of the categorical variable to box by -- confirm
        once implemented.
    stratifications
        Columns to stratify the plot by.

    Raises
    ------
    NotImplementedError
        Always; this function is a placeholder.
    """
    raise NotImplementedError
24+
25+
26+
def heatmap(comparison: Comparison, row: str, col: str):
    """Draw a heatmap of a comparison (stub, not yet implemented).

    Parameters
    ----------
    comparison
        The comparison to plot.
    row
        Name of the variable laid out along the rows.
    col
        Name of the variable laid out along the columns.

    Raises
    ------
    NotImplementedError
        Always; this function is a placeholder.
    """
    raise NotImplementedError
28+
29+
30+
def save_plot(fig, name, format):
    """Save a plot (stub, not yet implemented).

    The argument types are not yet fixed by this stub.

    Parameters
    ----------
    fig
        The figure to save.
    name
        Name to save the figure under.
    format
        Output format of the saved figure.

    Raises
    ------
    NotImplementedError
        Always; this function is a placeholder.
    """
    raise NotImplementedError

0 commit comments

Comments
 (0)