Skip to content

Commit 78c9f82

Browse files
committed
Noise shards together
1 parent 71e1c00 commit 78c9f82

File tree

2 files changed

+42
-32
lines changed

2 files changed

+42
-32
lines changed

src/pseudopeople/interface.py

Lines changed: 41 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -67,42 +67,39 @@ def _generate_dataset(
6767
"Please provide the path to the unmodified root data directory."
6868
)
6969
validate_data_path_suffix(data_paths)
70-
noised_dataset = []
71-
iterator = (
72-
tqdm(data_paths, desc="Noising data", leave=False)
73-
if len(data_paths) > 1
74-
else data_paths
75-
)
70+
all_data = []
71+
iterator = tqdm(data_paths, desc="Loading data") if len(data_paths) > 1 else data_paths
7672

77-
for data_path_index, data_path in enumerate(iterator):
73+
for data_path in iterator:
7874
logger.debug(f"Loading data from {data_path}.")
7975
data = _load_data_from_path(data_path, user_filters)
8076
if data.empty:
8177
continue
82-
data = _reformat_dates_for_noising(data, dataset)
83-
data = _coerce_dtypes(data, dataset)
84-
# Use a different seed for each data file/shard, otherwise the randomness will duplicate
85-
# and the Nth row in each shard will get the same noise
86-
data_path_seed = f"{seed}_{data_path_index}"
87-
noised_data = noise_dataset(dataset, data, configuration_tree, data_path_seed)
88-
noised_data = _extract_columns(dataset.columns, noised_data)
89-
noised_dataset.append(noised_data)
78+
# FIXME: Right now, Categorical columns in the Rhode Island data
79+
# contain a very large number of unnecessary categories. We want
80+
# to get rid of these during this loop so that they are never all
81+
# in memory at the same time.
82+
# TODO: Remove this when we stop Categorical encoding.
83+
data = _remove_unused_categories(data, dataset)
84+
all_data.append(data)
9085

9186
# Check if all shards for the dataset are empty
92-
if len(noised_dataset) == 0:
87+
if len(all_data) == 0:
9388
raise ValueError(
9489
"Invalid value provided for 'state' or 'year'. No data found with "
95-
f"the user provided 'state' or 'year' filters at {data_path}."
90+
f"the user provided 'state' or 'year' filters at {source}."
9691
)
97-
noised_dataset = pd.concat(noised_dataset, ignore_index=True)
9892

99-
# Known pandas bug: pd.concat does not preserve category dtypes so we coerce
100-
# again after concat (https://github.com/pandas-dev/pandas/issues/51362)
101-
noised_dataset = _coerce_dtypes(noised_dataset, dataset, cleanse_int_cols=True)
93+
all_data = pd.concat(all_data, ignore_index=True)
94+
_reformat_dates_for_noising(all_data, dataset)
95+
all_data = _coerce_dtypes(all_data, dataset)
96+
all_data = noise_dataset(dataset, all_data, configuration_tree, seed)
97+
all_data = _extract_columns(dataset.columns, all_data)
98+
all_data = _coerce_dtypes(all_data, dataset, cleanse_int_cols=True)
10299

103100
logger.debug("*** Finished ***")
104101

105-
return noised_dataset
102+
return all_data
106103

107104

108105
def validate_source_compatibility(source: Path):
@@ -151,15 +148,27 @@ def _coerce_dtypes(
151148
return data
152149

153150

151+
def _remove_unused_categories(data: pd.DataFrame, dataset: Dataset) -> pd.DataFrame:
"""Strip unused categories from categorical columns of ``data`` and return it.

Columns are reassigned on the passed-in frame, so the caller's object is
modified as well as returned. A categorical column is left untouched only
when ``dataset.columns`` contains it with ``dtype_name == "category"``.

NOTE(review): the ``in`` test and the ``dataset.columns[col]`` lookup both
use the column name as a key, which presumably requires ``dataset.columns``
to be a name-keyed mapping -- confirm. If it is a sequence of column
objects, the ``not in`` test would likely be true for every name and the
exemption described in the NOTE below would never apply.
"""
152+
for col in data.columns:
153+
if data[col].dtype.name == "category" and (
154+
# NOTE: We want to avoid dropping categories that just happen not to be used
155+
# in columns that are returned as Categorical to the user such as event_type
156+
col not in dataset.columns
157+
or dataset.columns[col].dtype_name != "category"
158+
):
159+
# Reassign the column so it no longer carries categories that have no rows.
data[col] = data[col].cat.remove_unused_categories()
160+
161+
return data
162+
163+
154164
def _load_data_from_path(data_path: Path, user_filters: List[Tuple]) -> pd.DataFrame:
155165
"""Load data from the data file at ``data_path``, passing ``user_filters`` to the loader."""
156166
# Thin wrapper: both arguments are forwarded unchanged to load_standard_dataset_file.
data = load_standard_dataset_file(data_path, user_filters)
157167
return data
158168

159169

160-
def _reformat_dates_for_noising(data: pd.DataFrame, dataset: Dataset):
170+
def _reformat_dates_for_noising(data: pd.DataFrame, dataset: Dataset) -> None:
161171
"""Formats date columns so they can be noised as strings."""
162-
data = data.copy()
163172

164173
for date_column in [COLUMNS.dob.name, COLUMNS.ssa_event_date.name]:
165174
# Format both the actual column, and the shadow version that will be used
@@ -170,19 +179,20 @@ def _reformat_dates_for_noising(data: pd.DataFrame, dataset: Dataset):
170179
# re-parse the format string for each row
171180
# https://github.com/pandas-dev/pandas/issues/44764
172181
# Year is already guaranteed to be 4-digit: https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#timeseries-timestamp-limits
173-
year_string = data[column].dt.year.astype(str)
174-
month_string = _zfill_fast(data[column].dt.month.astype(str), 2)
175-
day_string = _zfill_fast(data[column].dt.day.astype(str), 2)
182+
data_column = data[column]
183+
year_string = data_column.dt.year.astype(str)
184+
month_string = _zfill_fast(data_column.dt.month.astype(str), 2)
185+
day_string = _zfill_fast(data_column.dt.day.astype(str), 2)
176186
if dataset.date_format == DATEFORMATS.YYYYMMDD:
177-
data[column] = year_string + month_string + day_string
187+
result = year_string + month_string + day_string
178188
elif dataset.date_format == DATEFORMATS.MM_DD_YYYY:
179-
data[column] = month_string + "/" + day_string + "/" + year_string
189+
result = month_string + "/" + day_string + "/" + year_string
180190
elif dataset.date_format == DATEFORMATS.MMDDYYYY:
181-
data[column] = month_string + day_string + year_string
191+
result = month_string + day_string + year_string
182192
else:
183193
raise ValueError(f"Invalid date format in {dataset.name}.")
184194

185-
return data
195+
data[column] = result
186196

187197

188198
def _zfill_fast(col: pd.Series, desired_length: int) -> pd.Series:

src/pseudopeople/noise.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def noise_dataset(
5656
# except for the leave_blank kind which is special-cased below
5757
missingness = (dataset_data == "") | (dataset_data.isna())
5858

59-
for noise_type in tqdm(NOISE_TYPES, desc="Applying noise", unit="type", leave=False):
59+
for noise_type in tqdm(NOISE_TYPES, desc="Applying noise", unit="type"):
6060
if isinstance(noise_type, RowNoiseType):
6161
if (
6262
Keys.ROW_NOISE in noise_configuration

0 commit comments

Comments
 (0)