
Commit 8a02afe

support only a path to data root directory (#57)
1 parent 0ba7e0e commit 8a02afe

3 files changed (+152, -70 lines)


src/pseudopeople/entity_types.py

Lines changed: 1 addition & 2 deletions
@@ -6,7 +6,6 @@
 from vivarium import ConfigTree
 from vivarium.framework.randomness import RandomnessStream
 
-from pseudopeople import schema_entities
 from pseudopeople.utilities import get_index_to_noise
 
 
@@ -82,7 +81,7 @@ def __call__(
            column.loc[to_noise_idx], configuration, randomness_stream, additional_key
        )
 
-        # Coerce noised column dtype back to original column's if it's changed
+        # Coerce noised column dtype back to original column's if it has changed
        if noised_data.dtype.name != column.dtype.name:
            noised_data = noised_data.astype(column.dtype)
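
Note on the comment change above: noise functions can hand back a column whose dtype differs from the input (string edits on a numeric column, for example), so the result is coerced back afterwards. A minimal sketch with made-up values, not pseudopeople's actual noise functions:

    import pandas as pd

    # Hypothetical column: string-based "noise" changes int64 to object,
    # so the result is coerced back to the original dtype, as in the hunk above.
    column = pd.Series([1950, 1960, 1970], dtype="int64")
    noised_data = column.astype(str).str.slice(0, 4)  # stand-in for a noise function

    if noised_data.dtype.name != column.dtype.name:
        noised_data = noised_data.astype(column.dtype)

    print(noised_data.dtype)  # int64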

src/pseudopeople/interface.py

Lines changed: 69 additions & 53 deletions
@@ -1,8 +1,9 @@
 from pathlib import Path
-from typing import List, Union
+from typing import Union
 
 import pandas as pd
 import pyarrow.parquet as pq
+from loguru import logger
 
 from pseudopeople.configuration import get_configuration
 from pseudopeople.constants import paths
@@ -12,7 +13,7 @@
 
 def _generate_form(
     form: Form,
-    source: Union[Path, str, pd.DataFrame],
+    source: Union[Path, str],
     seed: int,
     configuration: Union[Path, str, dict],
     year_filter: dict,
@@ -23,7 +24,7 @@ def _generate_form(
     :param form:
         Form needing to be noised
     :param source:
-        Clean data input which needs to be noised
+        Root directory of clean data input which needs to be noised
     :param seed:
         Seed for controlling randomness
     :param configuration:
@@ -32,73 +33,88 @@ def _generate_form(
         Noised form data in a pd.DataFrame
     """
     configuration_tree = get_configuration(configuration)
+    # TODO: we should save outputs of the simulation with filenames that are
+    # consistent with the names of the forms if possible.
+    form_file_name = {
+        FORMS.acs.name: "household_survey_observer_acs",
+        FORMS.cps.name: "household_survey_observer_cps",
+        FORMS.tax_w2_1099.name: "tax_w2_observer",
+        FORMS.wic.name: "wic_observer",
+    }.get(form.name, f"{form.name}_observer")
     if source is None:
-        # TODO: hard-coding the .parquet extension for now. This will go away
-        # once we only support passing the root directory of the data.
-        # TODO: we should save outputs of the simulation with filenames that are
-        # consistent with the names of the forms if possible.
-        form_file_name = {
-            FORMS.acs.name: "household_survey_observer_acs",
-            FORMS.cps.name: "household_survey_observer_cps",
-            FORMS.tax_w2_1099.name: "tax_w2_observer",
-            FORMS.wic.name: "wic_observer",
-        }.get(form.name, f"{form.name}_observer")
-
-        source = paths.SAMPLE_DATA_ROOT / form_file_name / f"{form_file_name}.parquet"
-    if isinstance(source, str):
-        source = Path(source)
-    if isinstance(source, pd.DataFrame):
-        data = source
-    elif isinstance(source, Path):
-        if source.suffix == ".hdf":
-            with pd.HDFStore(str(source), mode="r") as hdf_store:
+        source = paths.SAMPLE_DATA_ROOT
+    source = Path(source) / form_file_name
+    data_paths = [x for x in source.glob(f"{form_file_name}*")]
+    if not data_paths:
+        logger.warning(
+            f"No datasets found at directory {str(source)}. "
+            "Please provide the path to the unmodified root data directory."
+        )
+        return None
+    suffix = set(x.suffix for x in data_paths)
+    if len(suffix) > 1:
+        raise TypeError(
+            f"Only one type of file extension expected but more than one found: {suffix}. "
+            "Please provide the path to the unmodified root data directory."
+        )
+    noised_form = []
+    columns_to_keep = [c for c in form.columns]
+    for data_path in data_paths:
+        if data_path.suffix == ".hdf":
+            with pd.HDFStore(str(data_path), mode="r") as hdf_store:
                data = hdf_store.select("data", where=year_filter["hdf"])
-            hdf_store.close()
-        elif source.suffix == ".parquet":
-            data = pq.read_table(source, filters=year_filter["parquet"]).to_pandas()
+        elif data_path.suffix == ".parquet":
+            data = pq.read_table(data_path, filters=year_filter["parquet"]).to_pandas()
        else:
            raise ValueError(
                "Source path must either be a .hdf or a .parquet file. Provided "
-                f"{source.suffix}"
+                f"{data_path.suffix}"
            )
        if not isinstance(data, pd.DataFrame):
-            raise TypeError(f"File located at {source} must contain a pandas DataFrame.")
-    else:
-        raise TypeError(
-            f"Source {source} must be either a pandas DataFrame or a path to a "
-            "file containing a pandas DataFrame."
-        )
+            raise TypeError(
+                f"File located at {data_path} must contain a pandas DataFrame. "
+                "Please provide the path to the unmodified root data directory."
+            )
 
-    columns_to_keep = [c for c in form.columns]
-    # Coerce dtypes
+        # Coerce dtypes prior to noising to catch issues early as well as
+        # get most columns away from dtype 'category' and into 'object' (strings)
+        for col in columns_to_keep:
+            if col.dtype_name != data[col.name].dtype.name:
+                data[col.name] = data[col.name].astype(col.dtype_name)
+
+        noised_data = noise_form(form, data, configuration_tree, seed)
+        noised_data = _extract_columns(columns_to_keep, noised_data)
+        noised_form.append(noised_data)
+
+    noised_form = pd.concat(noised_form, ignore_index=True)
+
+    # Known pandas bug: pd.concat does not preserve category dtypes so we coerce
+    # again after concat (https://github.com/pandas-dev/pandas/issues/51362)
     for col in columns_to_keep:
-        if col.dtype_name != data[col.name].dtype.name:
-            data[col.name] = data[col.name].astype(col.dtype_name)
-    noised_form = noise_form(form, data, configuration_tree, seed)
-    noised_form = _extract_columns(columns_to_keep, noised_form)
+        if col.dtype_name != noised_form[col.name].dtype.name:
+            noised_form[col.name] = noised_form[col.name].astype(col.dtype_name)
+
     return noised_form
 
 
 def _extract_columns(columns_to_keep, noised_form):
+    """Helper function for test mocking purposes"""
     if columns_to_keep:
        noised_form = noised_form[[c.name for c in columns_to_keep]]
     return noised_form
 
 
 # TODO: add year as parameter to select the year of the decennial census to generate (MIC-3909)
-# TODO: add default path: have the package install the small data in a known location and then
-# to make this parameter optional, with the default being the location of the small data that
-# is installed with the package (MIC-3884)
 def generate_decennial_census(
-    source: Union[Path, str, pd.DataFrame] = None,
+    source: Union[Path, str] = None,
     seed: int = 0,
     configuration: Union[Path, str, dict] = None,
     year: int = 2020,
 ) -> pd.DataFrame:
     """
     Generates noised decennial census data from un-noised data.
 
-    :param source: A path to or pd.DataFrame of the un-noised source census data
+    :param source: A path to un-noised source census data
     :param seed: An integer seed for randomness
     :param configuration: (optional) A path to a configuration YAML file or a dictionary to override the default configuration
     :param year: The year from the data to noise
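
The re-coercion added after pd.concat in this hunk works around the referenced pandas issue: concatenating shards whose categorical columns do not share identical categories silently falls back to object dtype. A minimal sketch with hypothetical shards standing in for the per-file data:

    import pandas as pd

    # Two hypothetical shards; their categorical columns have different categories,
    # so pd.concat drops the category dtype and it is re-applied afterwards.
    shard_1 = pd.DataFrame({"state": pd.Categorical(["WA", "OR"])})
    shard_2 = pd.DataFrame({"state": pd.Categorical(["CA"])})

    combined = pd.concat([shard_1, shard_2], ignore_index=True)
    print(combined["state"].dtype)  # object

    combined["state"] = combined["state"].astype("category")
    print(combined["state"].dtype)  # category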
@@ -112,15 +128,15 @@ def generate_decennial_census(
 
 
 def generate_american_communities_survey(
-    source: Union[Path, str, pd.DataFrame] = None,
+    source: Union[Path, str] = None,
     seed: int = 0,
     configuration: Union[Path, str, dict] = None,
     year: int = 2020,
 ) -> pd.DataFrame:
     """
     Generates noised American Communities Survey (ACS) data from un-noised data.
 
-    :param source: A path to or pd.DataFrame of the un-noised source ACS data
+    :param source: A path to un-noised source ACS data
     :param seed: An integer seed for randomness
     :param configuration: (optional) A path to a configuration YAML file or a dictionary to override the default configuration
     :param year: The year from the data to noise
@@ -140,15 +156,15 @@ def generate_american_communities_survey(
 
 
 def generate_current_population_survey(
-    source: Union[Path, str, pd.DataFrame] = None,
+    source: Union[Path, str] = None,
     seed: int = 0,
     configuration: Union[Path, str, dict] = None,
     year: int = 2020,
 ) -> pd.DataFrame:
     """
     Generates noised Current Population Survey (CPS) data from un-noised data.
 
-    :param source: A path to or pd.DataFrame of the un-noised source CPS data
+    :param source: A path to un-noised source CPS data
     :param seed: An integer seed for randomness
     :param configuration: (optional) A path to a configuration YAML file or a dictionary to override the default configuration
     :param year: The year from the data to noise
@@ -168,15 +184,15 @@ def generate_current_population_survey(
 
 
 def generate_taxes_w2_and_1099(
-    source: Union[Path, str, pd.DataFrame] = None,
+    source: Union[Path, str] = None,
     seed: int = 0,
     configuration: Union[Path, str, dict] = None,
     year: int = 2020,
 ) -> pd.DataFrame:
     """
     Generates noised W2 and 1099 data from un-noised data.
 
-    :param source: A path to or pd.DataFrame of the un-noised source W2 and 1099 data
+    :param source: A path to un-noised source W2 and 1099 data
     :param seed: An integer seed for randomness
     :param configuration: (optional) A path to a configuration YAML file or a dictionary to override the default configuration
     :param year: The year from the data to noise
@@ -191,15 +207,15 @@ def generate_taxes_w2_and_1099(
 
 
 def generate_women_infants_and_children(
-    source: Union[Path, str, pd.DataFrame] = None,
+    source: Union[Path, str] = None,
     seed: int = 0,
     configuration: Union[Path, str, dict] = None,
     year: int = 2020,
 ) -> pd.DataFrame:
     """
     Generates noised Women Infants and Children (WIC) data from un-noised data.
 
-    :param source: A path to or pd.DataFrame of the un-noised source WIC data
+    :param source: A path to un-noised source WIC data
     :param seed: An integer seed for randomness
     :param configuration: (optional) A path to a configuration YAML file or a dictionary to override the default configuration
     :param year: The year from the data to noise
@@ -214,15 +230,15 @@ def generate_women_infants_and_children(
 
 
 def generate_social_security(
-    source: Union[Path, str, pd.DataFrame] = None,
+    source: Union[Path, str] = None,
     seed: int = 0,
     configuration: Union[Path, str, dict] = None,
     year: int = 2020,
 ) -> pd.DataFrame:
     """
     Generates noised Social Security (SSA) data from un-noised data.
 
-    :param source: A path to or pd.DataFrame of the un-noised source SSA data
+    :param source: A path to un-noised source SSA data
     :param seed: An integer seed for randomness
     :param configuration: (optional) A path to a configuration YAML file or a dictionary to override the default configuration
     :param year: The year up to which to noise from the data
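
With the pd.DataFrame option gone, callers now pass either nothing (to use the packaged sample data) or the root directory of the unmodified simulated data. A minimal usage sketch; the path is a placeholder, and the import simply mirrors the module edited in this commit:

    from pseudopeople.interface import generate_decennial_census

    # Default: no source falls back to the sample data shipped with the package.
    census_sample = generate_decennial_census(seed=0)

    # Custom root: a directory containing one sub-directory per form, e.g.
    # /path/to/data_root/decennial_census_observer/decennial_census_observer*.parquet
    census = generate_decennial_census(source="/path/to/data_root", seed=0, year=2020)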

tests/integration/test_interface.py

Lines changed: 82 additions & 15 deletions
@@ -15,29 +15,64 @@
 )
 from pseudopeople.schema_entities import COLUMNS, FORMS
 
+# TODO: Move into a metadata file and import metadata into prl
+DATA_COLUMNS = ["year", "event_date", "survey_date", "tax_year"]
+
 
 @pytest.mark.parametrize(
-    "data_dir_name, noising_function",
+    "data_dir_name, noising_function, use_sample_data",
     [
-        ("decennial_census_observer", generate_decennial_census),
-        ("household_survey_observer_acs", generate_american_communities_survey),
-        ("household_survey_observer_cps", generate_current_population_survey),
-        ("social_security_observer", generate_social_security),
-        ("tax_w2_observer", generate_taxes_w2_and_1099),
-        ("wic_observer", generate_women_infants_and_children),
-        ("tax 1040", "todo"),
+        ("decennial_census_observer", generate_decennial_census, True),
+        ("decennial_census_observer", generate_decennial_census, False),
+        ("household_survey_observer_acs", generate_american_communities_survey, True),
+        ("household_survey_observer_acs", generate_american_communities_survey, False),
+        ("household_survey_observer_cps", generate_current_population_survey, True),
+        ("household_survey_observer_cps", generate_current_population_survey, False),
+        ("social_security_observer", generate_social_security, True),
+        ("social_security_observer", generate_social_security, False),
+        ("tax_w2_observer", generate_taxes_w2_and_1099, True),
+        ("tax_w2_observer", generate_taxes_w2_and_1099, False),
+        ("wic_observer", generate_women_infants_and_children, True),
+        ("wic_observer", generate_women_infants_and_children, False),
+        ("tax 1040", "todo", True),
+        ("tax 1040", "todo", False),
     ],
 )
-def test_generate_form(data_dir_name: str, noising_function: Callable):
+def test_generate_form(
+    data_dir_name: str, noising_function: Callable, use_sample_data: bool, tmpdir
+):
+    """Tests that noised forms are generated and as expected. The 'use_sample_data'
+    parameter determines whether or not to use the sample data (if True) or
+    a non-default root directory with multiple datasets to compile (if False)
+    """
     if noising_function == "todo":
        pytest.skip(reason=f"TODO: implement form {data_dir_name}")
-    # todo fix hard-coding in MIC-3960
-    data_path = paths.SAMPLE_DATA_ROOT / data_dir_name / f"{data_dir_name}.parquet"
-    data = pd.read_parquet(data_path)
 
-    noised_data = noising_function(seed=0)
-    noised_data_same_seed = noising_function(seed=0)
-    noised_data_different_seed = noising_function(seed=1)
+    sample_data_path = list(
+        (paths.SAMPLE_DATA_ROOT / data_dir_name).glob(f"{data_dir_name}*")
+    )[0]
+
+    # Load the unnoised sample data
+    if sample_data_path.suffix == ".parquet":
+        data = pd.read_parquet(sample_data_path)
+    elif sample_data_path.suffix == ".hdf":
+        data = pd.read_hdf(sample_data_path)
+    else:
+        raise NotImplementedError(
+            f"Expected hdf or parquet but got {sample_data_path.suffix}"
+        )
+
+    # Configure if default (sample data) is used or a different root directory
+    if use_sample_data:
+        source = None  # will default to using sample data
+    else:
+        source = _generate_non_default_data_root(
+            data_dir_name, tmpdir, sample_data_path, data
+        )
+
+    noised_data = noising_function(seed=0, source=source)
+    noised_data_same_seed = noising_function(seed=0, source=source)
+    noised_data_different_seed = noising_function(seed=1, source=source)
 
     assert not data.equals(noised_data)
     assert noised_data.equals(noised_data_same_seed)
@@ -52,6 +87,38 @@ def test_generate_form(data_dir_name: str, noising_function: Callable):
        assert noised_data[col].dtype == expected_dtype
 
 
+def _generate_non_default_data_root(data_dir_name, tmpdir, sample_data_path, data):
+    """Helper function to break the single sample dataset into two and save
+    out to tmpdir to be used as a non-default 'source' argument
+    """
+    outdir = tmpdir.mkdir(data_dir_name)
+    suffix = sample_data_path.suffix
+    split_idx = int(len(data) / 2)
+    if suffix == ".parquet":
+        data[:split_idx].to_parquet(outdir / f"{data_dir_name}_1{suffix}")
+        data[split_idx:].to_parquet(outdir / f"{data_dir_name}_2{suffix}")
+    elif suffix == ".hdf":
+        data[:split_idx].to_hdf(
+            outdir / f"{data_dir_name}_1{suffix}",
+            "data",
+            format="table",
+            complib="bzip2",
+            complevel=9,
+            data_columns=DATA_COLUMNS,
+        )
+        data[split_idx:].to_hdf(
+            outdir / f"{data_dir_name}_2{suffix}",
+            "data",
+            format="table",
+            complib="bzip2",
+            complevel=9,
+            data_columns=DATA_COLUMNS,
+        )
+    else:
+        raise NotImplementedError(f"Requires hdf or parquet, got {suffix}")
+    return tmpdir
+
+
 # TODO [MIC-4000]: add test that each col to get noised actually does get noised
 
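
For orientation, a sketch of the non-default root this helper writes into tmpdir (parquet case shown; hdf is analogous, and the exact tmpdir path comes from pytest). The glob in _generate_form then picks up both shards under the form's sub-directory and concatenates them:

    <tmpdir>/
        decennial_census_observer/
            decennial_census_observer_1.parquet
            decennial_census_observer_2.parquet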
