Skip to content

Commit 18f07c4

Browse files
authored
Merge pull request #20 from vertti/improve-types
Improve types
2 parents 79bfdd8 + 29a6c92 commit 18f07c4

File tree

9 files changed

+83
-40
lines changed

9 files changed

+83
-40
lines changed

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,12 @@
22

33
All notable changes to this project will be documented in this file.
44

5+
## 0.11.0
6+
7+
- Update function parameter types for better type safety
8+
- Fix missing return statement in df_log decorator
9+
- Added stricter mypy type checking settings
10+
511
## 0.10.1
612

713
- Built and published with UV. No functional changes

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ Like type hints for DataFrames, Daffy helps you catch structural mismatches earl
2828
- Works with both Pandas and Polars DataFrames
2929
- Project-wide configuration via pyproject.toml
3030
- Integrated logging for DataFrame structure inspection
31+
- Enhanced type annotations for improved IDE and type checker support
3132

3233
## Documentation
3334

daffy/config.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
"""Configuration handling for DAFFY."""
22

33
import os
4-
from typing import Optional
4+
from typing import Any, Dict, Optional
55

66
import tomli
77

88

9-
def load_config() -> dict:
9+
def load_config() -> Dict[str, Any]:
1010
"""
1111
Load daffy configuration from pyproject.toml.
1212
@@ -61,7 +61,7 @@ def find_config_file() -> Optional[str]:
6161
_config_cache = None
6262

6363

64-
def get_config() -> dict:
64+
def get_config() -> Dict[str, Any]:
6565
"""
6666
Get the daffy configuration, loading it if necessary.
6767

daffy/decorators.py

Lines changed: 34 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,26 @@
44
import logging
55
import re
66
from functools import wraps
7-
from typing import Any, Callable, Dict, List, Optional, Pattern, Tuple, Union
7+
from typing import Any, Callable, Dict, List, Optional, Pattern, Tuple, TypeVar, Union
88

99
import pandas as pd
1010
import polars as pl
1111

12+
# Import fully qualified types to satisfy disallow_any_unimported
13+
from pandas import DataFrame as PandasDataFrame
14+
from polars import DataFrame as PolarsDataFrame
15+
1216
from daffy.config import get_strict
1317

14-
# New type definition to support regex patterns
15-
RegexColumnDef = Tuple[str, Pattern] # Tuple of (pattern_str, compiled_pattern)
16-
ColumnsDef = Union[List, Dict, List[Union[str, RegexColumnDef]]]
17-
DataFrameType = Union[pd.DataFrame, pl.DataFrame]
18+
# Type variables for preserving return types
19+
T = TypeVar("T")
20+
R = TypeVar("R")
21+
22+
23+
# Improved type definitions to support regex patterns
24+
RegexColumnDef = Tuple[str, Pattern[str]] # Tuple of (pattern_str, compiled_pattern)
25+
ColumnsDef = Union[List[Union[str, RegexColumnDef]], Dict[str, Any]]
26+
DataFrameType = Union[PandasDataFrame, PolarsDataFrame]
1827

1928

2029
def _is_regex_pattern(column: Any) -> bool:
@@ -30,9 +39,9 @@ def _match_column_with_regex(column_pattern: RegexColumnDef, df_columns: List[st
3039
return [col for col in df_columns if pattern.match(col)]
3140

3241

33-
def _compile_regex_patterns(columns: List) -> List:
42+
def _compile_regex_patterns(columns: List[Any]) -> List[Union[str, RegexColumnDef]]:
3443
"""Compile regex patterns in the column list."""
35-
result = []
44+
result: List[Union[str, RegexColumnDef]] = []
3645
for col in columns:
3746
if isinstance(col, str) and col.startswith("r/") and col.endswith("/"):
3847
# Pattern is in the format "r/pattern/"
@@ -97,7 +106,9 @@ def _check_columns(df: DataFrameType, columns: ColumnsDef, strict: bool) -> None
97106
raise AssertionError(f"DataFrame contained unexpected column(s): {', '.join(extra_columns)}")
98107

99108

100-
def df_out(columns: Optional[ColumnsDef] = None, strict: Optional[bool] = None) -> Callable:
109+
def df_out(
110+
columns: Optional[ColumnsDef] = None, strict: Optional[bool] = None
111+
) -> Callable[[Callable[..., DataFrameType]], Callable[..., DataFrameType]]:
101112
"""Decorate a function that returns a Pandas or Polars DataFrame.
102113
103114
Document the return value of a function. The return value will be validated in runtime.
@@ -109,12 +120,12 @@ def df_out(columns: Optional[ColumnsDef] = None, strict: Optional[bool] = None)
109120
If None, uses the value from [tool.daffy] strict setting in pyproject.toml.
110121
111122
Returns:
112-
Callable: Decorated function
123+
Callable: Decorated function with preserved DataFrame return type
113124
"""
114125

115-
def wrapper_df_out(func: Callable) -> Callable:
126+
def wrapper_df_out(func: Callable[..., DataFrameType]) -> Callable[..., DataFrameType]:
116127
@wraps(func)
117-
def wrapper(*args: str, **kwargs: Any) -> Any:
128+
def wrapper(*args: Any, **kwargs: Any) -> DataFrameType:
118129
result = func(*args, **kwargs)
119130
assert isinstance(result, pd.DataFrame) or isinstance(result, pl.DataFrame), (
120131
f"Wrong return type. Expected DataFrame, got {type(result)}"
@@ -128,7 +139,7 @@ def wrapper(*args: str, **kwargs: Any) -> Any:
128139
return wrapper_df_out
129140

130141

131-
def _get_parameter(func: Callable, name: Optional[str] = None, *args: str, **kwargs: Any) -> DataFrameType:
142+
def _get_parameter(func: Callable[..., Any], name: Optional[str] = None, *args: Any, **kwargs: Any) -> Any:
132143
if not name:
133144
if len(args) > 0:
134145
return args[0]
@@ -144,7 +155,9 @@ def _get_parameter(func: Callable, name: Optional[str] = None, *args: str, **kwa
144155
return kwargs[name]
145156

146157

147-
def df_in(name: Optional[str] = None, columns: Optional[ColumnsDef] = None, strict: Optional[bool] = None) -> Callable:
158+
def df_in(
159+
name: Optional[str] = None, columns: Optional[ColumnsDef] = None, strict: Optional[bool] = None
160+
) -> Callable[[Callable[..., R]], Callable[..., R]]:
148161
"""Decorate a function parameter that is a Pandas or Polars DataFrame.
149162
150163
Document the contents of an input parameter. The parameter will be validated in runtime.
@@ -157,12 +170,12 @@ def df_in(name: Optional[str] = None, columns: Optional[ColumnsDef] = None, stri
157170
If None, uses the value from [tool.daffy] strict setting in pyproject.toml.
158171
159172
Returns:
160-
Callable: Decorated function
173+
Callable: Decorated function with preserved return type
161174
"""
162175

163-
def wrapper_df_in(func: Callable) -> Callable:
176+
def wrapper_df_in(func: Callable[..., R]) -> Callable[..., R]:
164177
@wraps(func)
165-
def wrapper(*args: str, **kwargs: Any) -> Any:
178+
def wrapper(*args: Any, **kwargs: Any) -> R:
166179
df = _get_parameter(func, name, *args, **kwargs)
167180
assert isinstance(df, pd.DataFrame) or isinstance(df, pl.DataFrame), (
168181
f"Wrong parameter type. Expected DataFrame, got {type(df).__name__} instead."
@@ -203,7 +216,7 @@ def _log_output(level: int, func_name: str, df: Any, include_dtypes: bool) -> No
203216
)
204217

205218

206-
def df_log(level: int = logging.DEBUG, include_dtypes: bool = False) -> Callable:
219+
def df_log(level: int = logging.DEBUG, include_dtypes: bool = False) -> Callable[[Callable[..., T]], Callable[..., T]]:
207220
"""Decorate a function that consumes or produces a Pandas DataFrame or both.
208221
209222
Logs the columns of the consumed and/or produced DataFrame.
@@ -213,15 +226,16 @@ def df_log(level: int = logging.DEBUG, include_dtypes: bool = False) -> Callable
213226
include_dtypes (bool, optional): When set to True, will log also the dtypes of each column. Defaults to False.
214227
215228
Returns:
216-
Callable: Decorated function.
229+
Callable: Decorated function with preserved return type.
217230
"""
218231

219-
def wrapper_df_log(func: Callable) -> Callable:
232+
def wrapper_df_log(func: Callable[..., T]) -> Callable[..., T]:
220233
@wraps(func)
221-
def wrapper(*args: str, **kwargs: Any) -> Any:
234+
def wrapper(*args: Any, **kwargs: Any) -> T:
222235
_log_input(level, func.__name__, _get_parameter(func, None, *args, **kwargs), include_dtypes)
223236
result = func(*args, **kwargs)
224237
_log_output(level, func.__name__, result, include_dtypes)
238+
return result # Added missing return statement
225239

226240
return wrapper
227241

mypy.ini

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,30 @@
11
[mypy]
22
python_version = 3.10
3-
warn_return_any = True
3+
# Enable all strict options
4+
strict = True
5+
# The strict flag enables:
6+
# - disallow_untyped_calls
7+
# - disallow_untyped_defs
8+
# - disallow_incomplete_defs
9+
# - check_untyped_defs
10+
# - disallow_untyped_decorators
11+
# - no_implicit_optional
12+
# - warn_redundant_casts
13+
# - warn_unused_ignores
14+
# - warn_no_return
15+
# - warn_return_any
16+
# - warn_unreachable
17+
# - disallow_any_generics
18+
# Additional strict options not included in strict flag
19+
warn_incomplete_stub = True
20+
warn_redundant_casts = True
421
warn_unused_configs = True
5-
ignore_missing_imports = True
6-
disallow_untyped_calls = True
7-
disallow_untyped_defs = True
8-
disallow_incomplete_defs = True
9-
check_untyped_defs = True
1022
warn_unused_ignores = True
11-
warn_redundant_casts = True
12-
no_implicit_optional = True
23+
implicit_reexport = False
24+
strict_optional = True
25+
strict_equality = True
26+
27+
# Relax rules for tests
28+
[mypy-tests.*]
29+
disallow_any_unimported = False
30+
disallow_any_decorated = False

pyproject.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "daffy"
3-
version = "0.10.1"
3+
version = "0.11.0"
44
description = "Function decorators for Pandas and Polars Dataframe column name and data type validation"
55
authors = [
66
{ name="Janne Sinivirta", email="[email protected]" },
@@ -42,6 +42,7 @@ dev = [
4242
"coverage[toml]>=7.3.2",
4343
"pydocstyle>=6.3.0",
4444
"ruff>=0.9.1",
45+
"pandas-stubs>=2.2.2.240807",
4546
]
4647

4748
[build-system]
@@ -83,4 +84,4 @@ line-length = 120
8384
target-version = "py39"
8485

8586
[tool.ruff.lint]
86-
select = ["F", "E", "W", "I", "N"]
87+
select = ["F", "E", "W", "I", "N"]

tests/test_decorators.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from typing import Any
2-
31
import pandas as pd
42

53
from daffy import df_in, df_out
@@ -8,8 +6,10 @@
86
def test_decorator_combinations(basic_pandas_df: pd.DataFrame, extended_pandas_df: pd.DataFrame) -> None:
97
@df_in(columns=["Brand", "Price"])
108
@df_out(columns=["Brand", "Price", "Year"])
11-
def test_fn(my_input: Any) -> Any:
9+
def test_fn(my_input: pd.DataFrame) -> pd.DataFrame:
1210
my_input["Year"] = list(extended_pandas_df["Year"])
1311
return my_input
1412

15-
pd.testing.assert_frame_equal(extended_pandas_df, test_fn(basic_pandas_df.copy()))
13+
result = test_fn(basic_pandas_df.copy())
14+
assert isinstance(result, pd.DataFrame) # Confirm the type for mypy
15+
pd.testing.assert_frame_equal(extended_pandas_df, result)

tests/test_df_log.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def test_fn(foo_df: pd.DataFrame) -> pd.DataFrame:
6060

6161
def test_log_df_with_dtypes_polars(basic_polars_df: pl.DataFrame, mocker: MockerFixture) -> None:
6262
@df_log(include_dtypes=True)
63-
def test_fn(foo_df: pd.DataFrame) -> pd.DataFrame:
63+
def test_fn(foo_df: pl.DataFrame) -> pl.DataFrame:
6464
return basic_polars_df
6565

6666
mock_log = mocker.patch("daffy.decorators.logging.log")

tests/test_df_out.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010

1111
def test_wrong_return_type() -> None:
12-
@df_out()
12+
@df_out() # type: ignore[arg-type]
1313
def test_fn() -> int:
1414
return 1
1515

@@ -86,7 +86,10 @@ def test_fn(my_input: Any) -> Any:
8686
return my_input
8787

8888
assert list(basic_pandas_df.columns) == ["Brand", "Price"] # For sanity
89-
pd.testing.assert_frame_equal(extended_pandas_df, test_fn(basic_pandas_df.copy()))
89+
result = test_fn(basic_pandas_df.copy())
90+
# Type check to ensure we get a pandas DataFrame before comparing
91+
assert isinstance(result, pd.DataFrame)
92+
pd.testing.assert_frame_equal(extended_pandas_df, result)
9093

9194

9295
def test_regex_column_pattern_in_output(basic_pandas_df: pd.DataFrame) -> None:

0 commit comments

Comments
 (0)