Skip to content

Commit f93923b

Browse files
add typing overloads to dataset generation functions
1 parent c42e5d5 commit f93923b

File tree

1 file changed

+218
-10
lines changed

1 file changed

+218
-10
lines changed

src/pseudopeople/interface.py

Lines changed: 218 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from collections.abc import Sequence
44
from pathlib import Path
5-
from typing import Any, Literal, cast
5+
from typing import TYPE_CHECKING, Any, Literal, cast, overload
66

77
import pandas as pd
88
from loguru import logger
@@ -25,6 +25,35 @@
2525
get_state_abbreviation,
2626
)
2727

28+
if TYPE_CHECKING:
29+
import dask.dataframe as dd
30+
31+
32+
# Typing-only overloads: the ``engine_name`` literal selects the return type.
# They are never executed; the undecorated ``_generate_dataset`` definition
# that follows is the runtime implementation.
#
# ``verbose`` and ``engine_name`` carry the same defaults as the
# implementation so that calls relying on those defaults still type-check.
@overload
def _generate_dataset(
    dataset_schema: DatasetSchema,
    source: Path | str | None,
    seed: int,
    config: Path | str | dict[str, Any] | None,
    filters: Sequence[DataFilter],
    verbose: bool = False,
    engine_name: Literal["pandas"] = "pandas",
) -> pd.DataFrame:
    ...


# "dask" passed with every argument supplied (positionally or by keyword).
@overload
def _generate_dataset(
    dataset_schema: DatasetSchema,
    source: Path | str | None,
    seed: int,
    config: Path | str | dict[str, Any] | None,
    filters: Sequence[DataFilter],
    verbose: bool,
    engine_name: Literal["dask"],
) -> dd.DataFrame:
    ...


# "dask" passed by keyword while ``verbose`` is left at its default. A
# required parameter cannot follow a defaulted one positionally, hence the
# keyword-only marker.
@overload
def _generate_dataset(
    dataset_schema: DatasetSchema,
    source: Path | str | None,
    seed: int,
    config: Path | str | dict[str, Any] | None,
    filters: Sequence[DataFilter],
    verbose: bool = False,
    *,
    engine_name: Literal["dask"],
) -> dd.DataFrame:
    ...
56+
2857

2958
def _generate_dataset(
3059
dataset_schema: DatasetSchema,
@@ -34,7 +63,7 @@ def _generate_dataset(
3463
filters: Sequence[DataFilter],
3564
verbose: bool = False,
3665
engine_name: Literal["pandas", "dask"] = "pandas",
37-
) -> pd.DataFrame:
66+
) -> pd.DataFrame | dd.DataFrame:
3867
"""
3968
Helper for generating noised datasets.
4069
@@ -67,7 +96,6 @@ def _generate_dataset(
6796

6897
engine = get_engine_from_string(engine_name)
6998

70-
noised_dataset: pd.DataFrame
7199
if engine == PANDAS_ENGINE:
72100
# We process shards serially
73101
data_file_paths = get_dataset_filepaths(source, dataset_schema.name)
@@ -205,15 +233,41 @@ def _get_data_changelog_version(changelog: Path) -> Version:
205233
return version
206234

207235

236+
# Typing-only overloads: the ``engine`` literal selects the return type.
# They are never executed; the undecorated ``generate_decennial_census``
# definition that follows is the runtime implementation.
#
# ``engine`` omitted or "pandas" -> pandas DataFrame.
@overload
def generate_decennial_census(
    source: Path | str | None = None,
    seed: int = 0,
    config: Path | str | dict[str, Any] | None = None,
    year: int | None = 2020,
    state: str | None = None,
    verbose: bool = False,
    engine: Literal["pandas"] = "pandas",
) -> pd.DataFrame:
    ...


# "dask" with every argument supplied (allows ``engine`` to be positional).
@overload
def generate_decennial_census(
    source: Path | str | None,
    seed: int,
    config: Path | str | dict[str, Any] | None,
    year: int | None,
    state: str | None,
    verbose: bool,
    engine: Literal["dask"],
) -> dd.DataFrame:
    ...


# "dask" passed by keyword with any of the other arguments left at their
# defaults, e.g. ``generate_decennial_census(engine="dask")``. Without this
# overload such a call matches no signature and is rejected by type checkers.
@overload
def generate_decennial_census(
    source: Path | str | None = None,
    seed: int = 0,
    config: Path | str | dict[str, Any] | None = None,
    year: int | None = 2020,
    state: str | None = None,
    verbose: bool = False,
    *,
    engine: Literal["dask"],
) -> dd.DataFrame:
    ...
260+
261+
262+
def generate_decennial_census(
263+
source: Path | str | None = None,
264+
seed: int = 0,
265+
config: Path | str | dict[str, Any] | None = None,
266+
year: int | None = 2020,
267+
state: str | None = None,
268+
verbose: bool = False,
269+
engine: Literal["pandas", "dask"] = "pandas",
270+
) -> pd.DataFrame | dd.DataFrame:
217271
"""
218272
Generates a pseudopeople decennial census dataset which represents
219273
simulated responses to the US Census Bureau's Census of Population
@@ -303,15 +357,41 @@ def generate_decennial_census(
303357
)
304358

305359

360+
# Typing-only overloads: the ``engine`` literal selects the return type.
# They are never executed; the undecorated
# ``generate_american_community_survey`` definition that follows is the
# runtime implementation.
#
# ``engine`` omitted or "pandas" -> pandas DataFrame.
@overload
def generate_american_community_survey(
    source: Path | str | None = None,
    seed: int = 0,
    config: Path | str | dict[str, Any] | None = None,
    year: int | None = 2020,
    state: str | None = None,
    verbose: bool = False,
    engine: Literal["pandas"] = "pandas",
) -> pd.DataFrame:
    ...


# "dask" with every argument supplied (allows ``engine`` to be positional).
@overload
def generate_american_community_survey(
    source: Path | str | None,
    seed: int,
    config: Path | str | dict[str, Any] | None,
    year: int | None,
    state: str | None,
    verbose: bool,
    engine: Literal["dask"],
) -> dd.DataFrame:
    ...


# "dask" passed by keyword with any of the other arguments left at their
# defaults, e.g. ``generate_american_community_survey(engine="dask")``.
# Without this overload such a call matches no signature and is rejected by
# type checkers.
@overload
def generate_american_community_survey(
    source: Path | str | None = None,
    seed: int = 0,
    config: Path | str | dict[str, Any] | None = None,
    year: int | None = 2020,
    state: str | None = None,
    verbose: bool = False,
    *,
    engine: Literal["dask"],
) -> dd.DataFrame:
    ...
384+
385+
386+
def generate_american_community_survey(
387+
source: Path | str | None = None,
388+
seed: int = 0,
389+
config: Path | str | dict[str, Any] | None = None,
390+
year: int | None = 2020,
391+
state: str | None = None,
392+
verbose: bool = False,
393+
engine: Literal["pandas", "dask"] = "pandas",
394+
) -> pd.DataFrame | dd.DataFrame:
315395
"""
316396
Generates a pseudopeople ACS dataset which represents simulated
317397
responses to the ACS survey.
@@ -416,15 +496,41 @@ def generate_american_community_survey(
416496
)
417497

418498

499+
# Typing-only overloads: the ``engine`` literal selects the return type.
# They are never executed; the undecorated
# ``generate_current_population_survey`` definition that follows is the
# runtime implementation.
#
# ``engine`` omitted or "pandas" -> pandas DataFrame.
@overload
def generate_current_population_survey(
    source: Path | str | None = None,
    seed: int = 0,
    config: Path | str | dict[str, Any] | None = None,
    year: int | None = 2020,
    state: str | None = None,
    verbose: bool = False,
    engine: Literal["pandas"] = "pandas",
) -> pd.DataFrame:
    ...


# "dask" with every argument supplied (allows ``engine`` to be positional).
@overload
def generate_current_population_survey(
    source: Path | str | None,
    seed: int,
    config: Path | str | dict[str, Any] | None,
    year: int | None,
    state: str | None,
    verbose: bool,
    engine: Literal["dask"],
) -> dd.DataFrame:
    ...


# "dask" passed by keyword with any of the other arguments left at their
# defaults, e.g. ``generate_current_population_survey(engine="dask")``.
# Without this overload such a call matches no signature and is rejected by
# type checkers.
@overload
def generate_current_population_survey(
    source: Path | str | None = None,
    seed: int = 0,
    config: Path | str | dict[str, Any] | None = None,
    year: int | None = 2020,
    state: str | None = None,
    verbose: bool = False,
    *,
    engine: Literal["dask"],
) -> dd.DataFrame:
    ...
523+
524+
525+
def generate_current_population_survey(
526+
source: Path | str | None = None,
527+
seed: int = 0,
528+
config: Path | str | dict[str, Any] | None = None,
529+
year: int | None = 2020,
530+
state: str | None = None,
531+
verbose: bool = False,
532+
engine: Literal["pandas", "dask"] = "pandas",
533+
) -> pd.DataFrame | dd.DataFrame:
428534
"""
429535
Generates a pseudopeople CPS dataset which represents simulated
430536
responses to the CPS survey.
@@ -530,15 +636,41 @@ def generate_current_population_survey(
530636
)
531637

532638

639+
# Typing-only overloads: the ``engine`` literal selects the return type.
# They are never executed; the undecorated ``generate_taxes_w2_and_1099``
# definition that follows is the runtime implementation.
#
# ``engine`` omitted or "pandas" -> pandas DataFrame.
@overload
def generate_taxes_w2_and_1099(
    source: Path | str | None = None,
    seed: int = 0,
    config: Path | str | dict[str, Any] | None = None,
    year: int | None = 2020,
    state: str | None = None,
    verbose: bool = False,
    engine: Literal["pandas"] = "pandas",
) -> pd.DataFrame:
    ...


# "dask" with every argument supplied (allows ``engine`` to be positional).
@overload
def generate_taxes_w2_and_1099(
    source: Path | str | None,
    seed: int,
    config: Path | str | dict[str, Any] | None,
    year: int | None,
    state: str | None,
    verbose: bool,
    engine: Literal["dask"],
) -> dd.DataFrame:
    ...


# "dask" passed by keyword with any of the other arguments left at their
# defaults, e.g. ``generate_taxes_w2_and_1099(engine="dask")``. Without this
# overload such a call matches no signature and is rejected by type checkers.
@overload
def generate_taxes_w2_and_1099(
    source: Path | str | None = None,
    seed: int = 0,
    config: Path | str | dict[str, Any] | None = None,
    year: int | None = 2020,
    state: str | None = None,
    verbose: bool = False,
    *,
    engine: Literal["dask"],
) -> dd.DataFrame:
    ...
663+
664+
665+
def generate_taxes_w2_and_1099(
666+
source: Path | str | None = None,
667+
seed: int = 0,
668+
config: Path | str | dict[str, Any] | None = None,
669+
year: int | None = 2020,
670+
state: str | None = None,
671+
verbose: bool = False,
672+
engine: Literal["pandas", "dask"] = "pandas",
673+
) -> pd.DataFrame | dd.DataFrame:
542674
"""
543675
Generates a pseudopeople W2 and 1099 tax dataset which represents
544676
simulated tax form data.
@@ -628,15 +760,41 @@ def generate_taxes_w2_and_1099(
628760
)
629761

630762

763+
# Typing-only overloads: the ``engine`` literal selects the return type.
# They are never executed; the undecorated
# ``generate_women_infants_and_children`` definition that follows is the
# runtime implementation.
#
# ``engine`` omitted or "pandas" -> pandas DataFrame.
@overload
def generate_women_infants_and_children(
    source: Path | str | None = None,
    seed: int = 0,
    config: Path | str | dict[str, Any] | None = None,
    year: int | None = 2020,
    state: str | None = None,
    verbose: bool = False,
    engine: Literal["pandas"] = "pandas",
) -> pd.DataFrame:
    ...


# "dask" with every argument supplied (allows ``engine`` to be positional).
@overload
def generate_women_infants_and_children(
    source: Path | str | None,
    seed: int,
    config: Path | str | dict[str, Any] | None,
    year: int | None,
    state: str | None,
    verbose: bool,
    engine: Literal["dask"],
) -> dd.DataFrame:
    ...


# "dask" passed by keyword with any of the other arguments left at their
# defaults, e.g. ``generate_women_infants_and_children(engine="dask")``.
# Without this overload such a call matches no signature and is rejected by
# type checkers.
@overload
def generate_women_infants_and_children(
    source: Path | str | None = None,
    seed: int = 0,
    config: Path | str | dict[str, Any] | None = None,
    year: int | None = 2020,
    state: str | None = None,
    verbose: bool = False,
    *,
    engine: Literal["dask"],
) -> dd.DataFrame:
    ...
787+
788+
789+
def generate_women_infants_and_children(
790+
source: Path | str | None = None,
791+
seed: int = 0,
792+
config: Path | str | dict[str, Any] | None = None,
793+
year: int | None = 2020,
794+
state: str | None = None,
795+
verbose: bool = False,
796+
engine: Literal["pandas", "dask"] = "pandas",
797+
) -> pd.DataFrame | dd.DataFrame:
640798
"""
641799
Generates a pseudopeople WIC dataset which represents a simulated
642800
version of the administrative data that would be recorded by WIC.
@@ -731,14 +889,38 @@ def generate_women_infants_and_children(
731889
)
732890

733891

892+
# Typing-only overloads: the ``engine`` literal selects the return type.
# They are never executed; the undecorated ``generate_social_security``
# definition that follows is the runtime implementation.
#
# ``engine`` omitted or "pandas" -> pandas DataFrame.
@overload
def generate_social_security(
    source: Path | str | None = None,
    seed: int = 0,
    config: Path | str | dict[str, Any] | None = None,
    year: int | None = 2020,
    verbose: bool = False,
    engine: Literal["pandas"] = "pandas",
) -> pd.DataFrame:
    ...


# "dask" with every argument supplied (allows ``engine`` to be positional).
@overload
def generate_social_security(
    source: Path | str | None,
    seed: int,
    config: Path | str | dict[str, Any] | None,
    year: int | None,
    verbose: bool,
    engine: Literal["dask"],
) -> dd.DataFrame:
    ...


# "dask" passed by keyword with any of the other arguments left at their
# defaults, e.g. ``generate_social_security(engine="dask")``. Without this
# overload such a call matches no signature and is rejected by type checkers.
@overload
def generate_social_security(
    source: Path | str | None = None,
    seed: int = 0,
    config: Path | str | dict[str, Any] | None = None,
    year: int | None = 2020,
    verbose: bool = False,
    *,
    engine: Literal["dask"],
) -> dd.DataFrame:
    ...
914+
915+
916+
def generate_social_security(
917+
source: Path | str | None = None,
918+
seed: int = 0,
919+
config: Path | str | dict[str, Any] | None = None,
920+
year: int | None = 2020,
921+
verbose: bool = False,
922+
engine: Literal["pandas", "dask"] = "pandas",
923+
) -> pd.DataFrame | dd.DataFrame:
742924
"""
743925
Generates a pseudopeople SSA dataset which represents simulated
744926
Social Security Administration (SSA) data.
@@ -819,15 +1001,41 @@ def generate_social_security(
8191001
)
8201002

8211003

1004+
# Typing-only overloads: the ``engine`` literal selects the return type.
# They are never executed; the undecorated ``generate_taxes_1040``
# definition that follows is the runtime implementation.
#
# ``engine`` omitted or "pandas" -> pandas DataFrame.
@overload
def generate_taxes_1040(
    source: Path | str | None = None,
    seed: int = 0,
    config: Path | str | dict[str, Any] | None = None,
    year: int | None = 2020,
    state: str | None = None,
    verbose: bool = False,
    engine: Literal["pandas"] = "pandas",
) -> pd.DataFrame:
    ...


# "dask" with every argument supplied (allows ``engine`` to be positional).
@overload
def generate_taxes_1040(
    source: Path | str | None,
    seed: int,
    config: Path | str | dict[str, Any] | None,
    year: int | None,
    state: str | None,
    verbose: bool,
    engine: Literal["dask"],
) -> dd.DataFrame:
    ...


# "dask" passed by keyword with any of the other arguments left at their
# defaults, e.g. ``generate_taxes_1040(engine="dask")``. Without this
# overload such a call matches no signature and is rejected by type checkers.
@overload
def generate_taxes_1040(
    source: Path | str | None = None,
    seed: int = 0,
    config: Path | str | dict[str, Any] | None = None,
    year: int | None = 2020,
    state: str | None = None,
    verbose: bool = False,
    *,
    engine: Literal["dask"],
) -> dd.DataFrame:
    ...
1028+
1029+
1030+
def generate_taxes_1040(
1031+
source: Path | str | None = None,
1032+
seed: int = 0,
1033+
config: Path | str | dict[str, Any] | None = None,
1034+
year: int | None = 2020,
1035+
state: str | None = None,
1036+
verbose: bool = False,
1037+
engine: Literal["pandas", "dask"] = "pandas",
1038+
) -> pd.DataFrame | dd.DataFrame:
8311039
"""
8321040
Generates a pseudopeople 1040 tax dataset which represents simulated
8331041
tax form data.

0 commit comments

Comments
 (0)