 from pytest_mock import MockerFixture
 from vivarium_testing_utils import FuzzyChecker

-from pseudopeople.configuration import get_configuration
 from pseudopeople.schema_entities import COLUMNS, DATASET_SCHEMAS, Column
+from pseudopeople.utilities import coerce_dtypes
 from tests.constants import DATASET_GENERATION_FUNCS
 from tests.integration.conftest import (
     IDX_COLS,
@@ -26,7 +26,6 @@
     initialize_dataset_with_sample,
     run_column_noising_tests,
     run_omit_row_or_do_not_respond_tests,
-    validate_column_noise_level,
 )

@@ -49,7 +48,7 @@
         "dask",
     ],
 )
-def test_generate_dataset_from_multiple_shards(
+def test_noising_sharded_vs_unsharded_data(
     dataset_name: str,
     engine: str,
     config: dict[str, Any],
@@ -65,54 +64,54 @@ def test_generate_dataset_from_multiple_shards(
     pytest.skip(reason=dataset_name)
     mocker.patch("pseudopeople.interface.validate_source_compatibility")
     generation_function = DATASET_GENERATION_FUNCS[dataset_name]
-    original = initialize_dataset_with_sample(dataset_name)
-    noised_sample = request.getfixturevalue(f"noised_sample_data_{dataset_name}")

-    noised_dataset = generation_function(
+    unnoised_dataset = initialize_dataset_with_sample(dataset_name)
+    single_shard_noised_data = request.getfixturevalue(f"noised_sample_data_{dataset_name}")
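+    # Generate the same dataset again, this time from the sample data split across multiple shards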
+    multi_shard_noised_data = generation_function(
         seed=SEED,
         year=None,
         source=split_sample_data_dir,
         engine=engine,
         config=config,
     )
-
     if engine == "dask":
-        noised_dataset = noised_dataset.compute()
-
-    # Check same order of magnitude of rows was removed -- we don't know the
-    # full data size (we would need unnoised data for that), so we just check
-    # for similar lengths
-    assert 0.9 <= (len(noised_dataset) / len(noised_sample)) <= 1.1
-    # Check that columns are identical
-    assert noised_dataset.columns.equals(noised_sample.columns)
-
-    # Check that each columns level of noising are similar
-    check_noised_dataset, check_original_dataset, shared_dataset_idx = _get_common_datasets(
-        original, noised_dataset
-    )
+        multi_shard_noised_data = multi_shard_noised_data.compute()

-    config_tree = get_configuration(config)
-    for col_name in check_noised_dataset.columns:
-        col = COLUMNS.get_column(col_name)
-        if col.noise_types:
-            noise_level_dataset, to_compare_dataset_idx = _get_column_noise_level(
-                column=col,
-                noised_data=check_noised_dataset,
-                unnoised_data=check_original_dataset,
-                common_idx=shared_dataset_idx,
-            )
-
-            # Validate noise for each data object
-            validate_column_noise_level(
-                dataset_name=dataset_name,
-                check_data=check_original_dataset,
-                check_idx=to_compare_dataset_idx,
-                noise_level=noise_level_dataset,
-                col=col,
-                config=config_tree,
-                fuzzy_name="test_generate_dataset_from_sample_and_source_dataset",
-                validator=fuzzy_checker,
-            )
+    assert multi_shard_noised_data.columns.equals(single_shard_noised_data.columns)
+
+    # This index handling is adapted from _get_common_datasets
+    # in integration/conftest.py
+    # Define indexes
+    idx_cols = IDX_COLS.get(unnoised_dataset.dataset_schema.name)
+    unnoised_dataset._reformat_dates_for_noising()
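+    # Coerce the unnoised data to the schema dtypes so it compares cleanly against the noised outputs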
+    unnoised_dataset.data = coerce_dtypes(
+        unnoised_dataset.data, unnoised_dataset.dataset_schema
+    )
+    check_original = unnoised_dataset.data.set_index(idx_cols)
+    check_single_noised = single_shard_noised_data.set_index(idx_cols)
+    check_multi_noised = multi_shard_noised_data.set_index(idx_cols)
+
+    # Ensure the idx_cols are unique
+    assert check_original.index.duplicated().sum() == 0
+    assert check_single_noised.index.duplicated().sum() == 0
+    assert check_multi_noised.index.duplicated().sum() == 0
+
+    # Get shared indexes
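+    # (row noising can drop different rows in each run, so only rows present in all three are comparable)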
+    shared_idx = pd.Index(
+        set(check_original.index)
+        .intersection(set(check_single_noised.index))
+        .intersection(set(check_multi_noised.index))
+    )
+    check_original = check_original.loc[shared_idx]
+    check_single_noised = check_single_noised.loc[shared_idx]
+    check_multi_noised = check_multi_noised.loc[shared_idx]
+
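+    # For each column, the proportion of values changed by noising in the
+    # multi-shard data should match the single-shard noise level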
+    for col in check_single_noised.columns:
+        fuzzy_checker.fuzzy_assert_proportion(
+            target_proportion=(check_single_noised[col] != check_original[col]).mean(),
+            observed_numerator=(check_multi_noised[col] != check_original[col]).sum(),
+            observed_denominator=len(check_original),
+        )


 @pytest.mark.parametrize(
@@ -333,38 +332,6 @@ def test_row_noising_duplication(dataset_name: str) -> None:
     ...


-@pytest.mark.parametrize(
-    "dataset_name",
-    [
-        DATASET_SCHEMAS.census.name,
-        DATASET_SCHEMAS.acs.name,
-        DATASET_SCHEMAS.cps.name,
-        DATASET_SCHEMAS.ssa.name,
-        DATASET_SCHEMAS.tax_w2_1099.name,
-        DATASET_SCHEMAS.wic.name,
-        DATASET_SCHEMAS.tax_1040.name,
-    ],
-)
-@pytest.mark.parametrize(
-    "engine",
-    [
-        "pandas",
-        "dask",
-    ],
-)
-def test_generate_dataset_with_year(dataset_name: str, engine: str) -> None:
-    if "TODO" in dataset_name:
-        pytest.skip(reason=dataset_name)
-    year = 2030  # not default 2020
-    generation_function = DATASET_GENERATION_FUNCS[dataset_name]
-    original = get_unnoised_data(dataset_name)
-    # Generate a new (non-fixture) noised dataset for a single year
-    noised_data = generation_function(year=year, engine=engine)
-    if engine == "dask":
-        noised_data = noised_data.compute()
-    assert not original.data.equals(noised_data)
-
-
 @pytest.mark.parametrize(
     "dataset_name",
     [