@@ -67,42 +67,39 @@ def _generate_dataset(
6767 "Please provide the path to the unmodified root data directory."
6868 )
6969 validate_data_path_suffix (data_paths )
70- noised_dataset = []
71- iterator = (
72- tqdm (data_paths , desc = "Noising data" , leave = False )
73- if len (data_paths ) > 1
74- else data_paths
75- )
70+ all_data = []
71+ iterator = tqdm (data_paths , desc = "Loading data" ) if len (data_paths ) > 1 else data_paths
7672
77- for data_path_index , data_path in enumerate ( iterator ) :
73+ for data_path in iterator :
7874 logger .debug (f"Loading data from { data_path } ." )
7975 data = _load_data_from_path (data_path , user_filters )
8076 if data .empty :
8177 continue
82- data = _reformat_dates_for_noising (data , dataset )
83- data = _coerce_dtypes (data , dataset )
84- # Use a different seed for each data file/shard, otherwise the randomness will duplicate
85- # and the Nth row in each shard will get the same noise
86- data_path_seed = f"{ seed } _{ data_path_index } "
87- noised_data = noise_dataset (dataset , data , configuration_tree , data_path_seed )
88- noised_data = _extract_columns (dataset .columns , noised_data )
89- noised_dataset .append (noised_data )
78+ # FIXME: Right now, Categorical columns in the Rhode Island data
79+ # contain a very large number of unnecessary categories. We want
80+ # to get rid of these during this loop so that they are never all
81+ # in memory at the same time.
82+ # TODO: Remove this when we stop Categorical encoding.
83+ data = _remove_unused_categories (data , dataset )
84+ all_data .append (data )
9085
9186 # Check if all shards for the dataset are empty
92- if len (noised_dataset ) == 0 :
87+ if len (all_data ) == 0 :
9388 raise ValueError (
9489 "Invalid value provided for 'state' or 'year'. No data found with "
95- f"the user provided 'state' or 'year' filters at { data_path } ."
90+ f"the user provided 'state' or 'year' filters at { source } ."
9691 )
97- noised_dataset = pd .concat (noised_dataset , ignore_index = True )
9892
99- # Known pandas bug: pd.concat does not preserve category dtypes so we coerce
100- # again after concat (https://github.com/pandas-dev/pandas/issues/51362)
101- noised_dataset = _coerce_dtypes (noised_dataset , dataset , cleanse_int_cols = True )
93+ all_data = pd .concat (all_data , ignore_index = True )
94+ _reformat_dates_for_noising (all_data , dataset )
95+ all_data = _coerce_dtypes (all_data , dataset )
96+ all_data = noise_dataset (dataset , all_data , configuration_tree , seed )
97+ all_data = _extract_columns (dataset .columns , all_data )
98+ all_data = _coerce_dtypes (all_data , dataset , cleanse_int_cols = True )
10299
103100 logger .debug ("*** Finished ***" )
104101
105- return noised_dataset
102+ return all_data
106103
107104
108105def validate_source_compatibility (source : Path ):
@@ -151,15 +148,27 @@ def _coerce_dtypes(
151148 return data
152149
153150
151+ def _remove_unused_categories (data : pd .DataFrame , dataset : Dataset ) -> pd .DataFrame :
152+ for col in data .columns :
153+ if data [col ].dtype .name == "category" and (
154+ # NOTE: We want to avoid dropping categories that just happen not to be used
155+ # in columns that are returned as Categorical to the user such as event_type
156+ col not in dataset .columns
157+ or dataset .columns [col ].dtype_name != "category"
158+ ):
159+ data [col ] = data [col ].cat .remove_unused_categories ()
160+
161+ return data
162+
163+
def _load_data_from_path(data_path: Path, user_filters: List[Tuple]) -> pd.DataFrame:
    """Read one standard dataset file at ``data_path``, applying ``user_filters``."""
    return load_standard_dataset_file(data_path, user_filters)
158168
159169
160- def _reformat_dates_for_noising (data : pd .DataFrame , dataset : Dataset ):
170+ def _reformat_dates_for_noising (data : pd .DataFrame , dataset : Dataset ) -> None :
161171 """Formats date columns so they can be noised as strings."""
162- data = data .copy ()
163172
164173 for date_column in [COLUMNS .dob .name , COLUMNS .ssa_event_date .name ]:
165174 # Format both the actual column, and the shadow version that will be used
@@ -170,19 +179,20 @@ def _reformat_dates_for_noising(data: pd.DataFrame, dataset: Dataset):
170179 # re-parse the format string for each row
171180 # https://github.com/pandas-dev/pandas/issues/44764
172181 # Year is already guaranteed to be 4-digit: https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#timeseries-timestamp-limits
173- year_string = data [column ].dt .year .astype (str )
174- month_string = _zfill_fast (data [column ].dt .month .astype (str ), 2 )
175- day_string = _zfill_fast (data [column ].dt .day .astype (str ), 2 )
182+ data_column = data [column ]
183+ year_string = data_column .dt .year .astype (str )
184+ month_string = _zfill_fast (data_column .dt .month .astype (str ), 2 )
185+ day_string = _zfill_fast (data_column .dt .day .astype (str ), 2 )
176186 if dataset .date_format == DATEFORMATS .YYYYMMDD :
177- data [ column ] = year_string + month_string + day_string
187+ result = year_string + month_string + day_string
178188 elif dataset .date_format == DATEFORMATS .MM_DD_YYYY :
179- data [ column ] = month_string + "/" + day_string + "/" + year_string
189+ result = month_string + "/" + day_string + "/" + year_string
180190 elif dataset .date_format == DATEFORMATS .MMDDYYYY :
181- data [ column ] = month_string + day_string + year_string
191+ result = month_string + day_string + year_string
182192 else :
183193 raise ValueError (f"Invalid date format in { dataset .name } ." )
184194
185- return data
195+ data [ column ] = result
186196
187197
188198def _zfill_fast (col : pd .Series , desired_length : int ) -> pd .Series :
0 commit comments