Skip to content

Commit 2a62d0a

Browse files
authored
Merge pull request #22 from vertti/fix-typings
Fix typings
2 parents 39db742 + 670d36e commit 2a62d0a

File tree

6 files changed

+203
-25
lines changed

6 files changed

+203
-25
lines changed

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,12 @@
22

33
All notable changes to this project will be documented in this file.
44

5+
## 0.13.0
6+
7+
- Fix type annotation issues with decorator parameters that could cause type errors in strict type checking
8+
- Use `Sequence` instead of `List` for better type variance compatibility
9+
- Add test case that validates type compatibility
10+
511
## 0.12.0
612

713
- Add support for regex patterns used with column dtype validation

daffy/decorators.py

Lines changed: 21 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@
55
import re
66
from functools import wraps
77
from typing import Any, Callable, Dict, List, Optional, Pattern, Tuple, TypeVar, Union
8+
from typing import Sequence as Seq # Renamed to avoid collision
89

910
import pandas as pd
1011
import polars as pl
@@ -16,30 +17,30 @@
1617
from daffy.config import get_strict
1718

1819
# Type variables for preserving return types
19-
T = TypeVar("T")
20-
R = TypeVar("R")
20+
T = TypeVar("T") # Generic type var for df_log
21+
DF = TypeVar("DF", bound=Union[PandasDataFrame, PolarsDataFrame])
22+
R = TypeVar("R") # Return type for df_in
2123

24+
RegexColumnDef = Tuple[str, Pattern[str]]
2225

23-
# Improved type definitions to support regex patterns
24-
RegexColumnDef = Tuple[str, Pattern[str]] # Tuple of (pattern_str, compiled_pattern)
25-
ColumnsDef = Union[List[Union[str, RegexColumnDef]], Dict[Union[str, RegexColumnDef], Any]]
26+
ColumnsList = Seq[Union[str, RegexColumnDef]]
27+
ColumnsDict = Dict[Union[str, RegexColumnDef], Any]
28+
ColumnsDef = Union[ColumnsList, ColumnsDict, None]
2629
DataFrameType = Union[PandasDataFrame, PolarsDataFrame]
2730

2831

2932
def _is_regex_pattern(column: Any) -> bool:
30-
"""Check if the column definition is a regex pattern tuple."""
3133
return (
3234
isinstance(column, tuple) and len(column) == 2 and isinstance(column[0], str) and isinstance(column[1], Pattern)
3335
)
3436

3537

3638
def _match_column_with_regex(column_pattern: RegexColumnDef, df_columns: List[str]) -> List[str]:
37-
"""Find all column names that match the regex pattern."""
3839
_, pattern = column_pattern
3940
return [col for col in df_columns if pattern.match(col)]
4041

4142

42-
def _compile_regex_patterns(columns: List[Any]) -> List[Union[str, RegexColumnDef]]:
43+
def _compile_regex_patterns(columns: Seq[Any]) -> List[Union[str, RegexColumnDef]]:
4344
"""Compile regex patterns in the column list."""
4445
result: List[Union[str, RegexColumnDef]] = []
4546
for col in columns:
@@ -53,7 +54,7 @@ def _compile_regex_patterns(columns: List[Any]) -> List[Union[str, RegexColumnDe
5354
return result
5455

5556

56-
def _check_columns(df: DataFrameType, columns: ColumnsDef, strict: bool) -> None:
57+
def _check_columns(df: DataFrameType, columns: Union[ColumnsList, ColumnsDict], strict: bool) -> None:
5758
missing_columns = []
5859
dtype_mismatches = []
5960
matched_by_regex = set()
@@ -137,15 +138,16 @@ def _check_columns(df: DataFrameType, columns: ColumnsDef, strict: bool) -> None
137138

138139

139140
def df_out(
140-
columns: Optional[ColumnsDef] = None, strict: Optional[bool] = None
141-
) -> Callable[[Callable[..., DataFrameType]], Callable[..., DataFrameType]]:
141+
columns: Union[ColumnsList, ColumnsDict, None] = None, strict: Optional[bool] = None
142+
) -> Callable[[Callable[..., DF]], Callable[..., DF]]:
142143
"""Decorate a function that returns a Pandas or Polars DataFrame.
143144
144145
Document the return value of a function. The return value will be validated in runtime.
145146
146147
Args:
147-
columns (ColumnsDef, optional): List or dict that describes expected columns of the DataFrame.
148-
List can contain regex patterns in format "r/pattern/" (e.g., "r/Col[0-9]+/").
148+
columns (Union[Sequence[str], Dict[str, Any]], optional): Sequence or dict that describes expected columns
149+
of the DataFrame.
150+
Sequence can contain regex patterns in format "r/pattern/" (e.g., "r/Col[0-9]+/").
149151
Dict can use regex patterns as keys in format "r/pattern/" to validate dtypes for matching columns.
150152
Defaults to None.
151153
strict (bool, optional): If True, columns must match exactly with no extra columns.
@@ -155,9 +157,9 @@ def df_out(
155157
Callable: Decorated function with preserved DataFrame return type
156158
"""
157159

158-
def wrapper_df_out(func: Callable[..., DataFrameType]) -> Callable[..., DataFrameType]:
160+
def wrapper_df_out(func: Callable[..., DF]) -> Callable[..., DF]:
159161
@wraps(func)
160-
def wrapper(*args: Any, **kwargs: Any) -> DataFrameType:
162+
def wrapper(*args: Any, **kwargs: Any) -> DF:
161163
result = func(*args, **kwargs)
162164
assert isinstance(result, pd.DataFrame) or isinstance(result, pl.DataFrame), (
163165
f"Wrong return type. Expected DataFrame, got {type(result)}"
@@ -188,16 +190,17 @@ def _get_parameter(func: Callable[..., Any], name: Optional[str] = None, *args:
188190

189191

190192
def df_in(
191-
name: Optional[str] = None, columns: Optional[ColumnsDef] = None, strict: Optional[bool] = None
193+
name: Optional[str] = None, columns: Union[ColumnsList, ColumnsDict, None] = None, strict: Optional[bool] = None
192194
) -> Callable[[Callable[..., R]], Callable[..., R]]:
193195
"""Decorate a function parameter that is a Pandas or Polars DataFrame.
194196
195197
Document the contents of an input parameter. The parameter will be validated in runtime.
196198
197199
Args:
198200
name (Optional[str], optional): Name of the parameter that contains a DataFrame. Defaults to None.
199-
columns (ColumnsDef, optional): List or dict that describes expected columns of the DataFrame.
200-
List can contain regex patterns in format "r/pattern/" (e.g., "r/Col[0-9]+/").
201+
columns (Union[Sequence[str], Dict[str, Any]], optional): Sequence or dict that describes expected columns
202+
of the DataFrame.
203+
Sequence can contain regex patterns in format "r/pattern/" (e.g., "r/Col[0-9]+/").
201204
Dict can use regex patterns as keys in format "r/pattern/" to validate dtypes for matching columns.
202205
Defaults to None.
203206
strict (bool, optional): If True, columns must match exactly with no extra columns.

mypy.ini

Lines changed: 0 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -23,8 +23,3 @@ warn_unused_ignores = True
2323
implicit_reexport = False
2424
strict_optional = True
2525
strict_equality = True
26-
27-
# Relax rules for tests
28-
[mypy-tests.*]
29-
disallow_any_unimported = False
30-
disallow_any_decorated = False

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "daffy"
3-
version = "0.12.0"
3+
version = "0.13.0"
44
description = "Function decorators for Pandas and Polars Dataframe column name and data type validation"
55
authors = [
66
{ name="Janne Sinivirta", email="[email protected]" },

tests/test_df_out.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@
99

1010

1111
def test_wrong_return_type() -> None:
12-
@df_out() # type: ignore[arg-type]
12+
@df_out() # type: ignore[type-var]
1313
def test_fn() -> int:
1414
return 1
1515

tests/test_type_compatibility.py

Lines changed: 174 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,174 @@
1+
"""Test type compatibility issues that might occur in client code."""
2+
3+
from typing import Sequence
4+
5+
import pandas as pd
6+
import polars as pl
7+
8+
from daffy import df_in, df_out
9+
10+
11+
# Pass-through function for testing
12+
@df_in(columns=["Brand", "Price"])
13+
def simple_list_columns(df: pd.DataFrame) -> pd.DataFrame:
14+
return df
15+
16+
17+
def test_simple_list_columns() -> None:
18+
"""Test with a simple list of string columns."""
19+
df = pd.DataFrame({"Brand": ["Toyota"], "Price": [25000]})
20+
result = simple_list_columns(df)
21+
assert isinstance(result, pd.DataFrame)
22+
23+
24+
# This would test the Union type DataFrameType compatibility
25+
@df_out(columns=["Brand", "Price"])
26+
def return_dataframe() -> pd.DataFrame:
27+
return pd.DataFrame({"Brand": ["Toyota"], "Price": [25000]})
28+
29+
30+
def function_with_explicit_type_annotations(columns: Sequence[str]) -> None:
31+
@df_in(columns=columns)
32+
def inner_function(df: pd.DataFrame) -> pd.DataFrame:
33+
return df
34+
35+
df = pd.DataFrame({"Brand": ["Toyota"], "Price": [25000]})
36+
inner_function(df)
37+
38+
39+
def test_with_polars() -> None:
40+
df = pl.DataFrame({"Brand": ["Toyota"], "Price": [25000]})
41+
42+
@df_in(columns=["Brand", "Price"])
43+
def inner_function(df_param: pl.DataFrame) -> pl.DataFrame:
44+
return df_param
45+
46+
inner_function(df)
47+
48+
49+
def test_function_with_explicit_type_annotations() -> None:
50+
columns = ["Brand", "Price"]
51+
function_with_explicit_type_annotations(columns)
52+
53+
54+
def test_simple_list_columns_function() -> None:
55+
df = pd.DataFrame({"Brand": ["Toyota"], "Price": [25000]})
56+
simple_list_columns(df)
57+
58+
59+
def test_return_dataframe_function() -> None:
60+
result = return_dataframe()
61+
assert isinstance(result, pd.DataFrame)
62+
63+
64+
def test_dtype_with_regex_pandas() -> None:
65+
"""Test using both dtype validation and regex patterns with pandas."""
66+
# Create a DataFrame with numeric columns following a pattern
67+
df = pd.DataFrame(
68+
{
69+
"measure_2020": [10, 20, 30],
70+
"measure_2021": [15, 25, 35],
71+
"measure_2022": [18, 28, 38],
72+
"category": ["A", "B", "C"],
73+
}
74+
)
75+
76+
# Define a function using both regex patterns and dtype validation
77+
@df_in(
78+
columns={
79+
"category": "object",
80+
"r/measure_\\d{4}/": "int64", # All measure_YYYY columns should be int64
81+
}
82+
)
83+
def process_measures(data: pd.DataFrame) -> pd.DataFrame:
84+
return data
85+
86+
# This should pass type checking and runtime validation
87+
result = process_measures(df)
88+
assert "measure_2020" in result.columns
89+
assert "measure_2021" in result.columns
90+
assert "measure_2022" in result.columns
91+
92+
93+
def test_dtype_with_regex_polars() -> None:
94+
"""Test using both dtype validation and regex patterns with polars."""
95+
# Create a Polars DataFrame with numeric columns following a pattern
96+
df = pl.DataFrame(
97+
{
98+
"measure_2020": [10, 20, 30],
99+
"measure_2021": [15, 25, 35],
100+
"measure_2022": [18, 28, 38],
101+
"category": ["A", "B", "C"],
102+
}
103+
)
104+
105+
# Define a function using both regex patterns and dtype validation
106+
@df_in(
107+
columns={
108+
"category": pl.String,
109+
"r/measure_\\d{4}/": pl.Int64, # All measure_YYYY columns should be Int64
110+
}
111+
)
112+
def process_measures(data: pl.DataFrame) -> pl.DataFrame:
113+
return data
114+
115+
# This should pass type checking and runtime validation
116+
result = process_measures(df)
117+
assert "measure_2020" in result.columns
118+
assert "measure_2021" in result.columns
119+
assert "measure_2022" in result.columns
120+
121+
122+
def test_type_narrowing_with_df_out_pandas() -> None:
123+
"""Test assigning df_out decorated function result to a specific Pandas DataFrame type."""
124+
125+
# Define a function that returns a DataFrame with df_out decoration
126+
@df_out(columns=["name", "value"])
127+
def get_data() -> pd.DataFrame:
128+
return pd.DataFrame({"name": ["A", "B", "C"], "value": [1, 2, 3]})
129+
130+
# The critical test: we should be able to assign the result to a variable
131+
# explicitly typed as pd.DataFrame without mypy errors
132+
result: pd.DataFrame = get_data()
133+
assert "name" in result.columns
134+
assert "value" in result.columns
135+
136+
137+
def test_type_narrowing_with_df_out_polars() -> None:
138+
"""Test assigning df_out decorated function result to a specific Polars DataFrame type."""
139+
140+
# Define a function that returns a DataFrame with df_out decoration
141+
@df_out(columns=["name", "value"])
142+
def get_data() -> pl.DataFrame:
143+
return pl.DataFrame({"name": ["A", "B", "C"], "value": [1, 2, 3]})
144+
145+
# The critical test: we should be able to assign the result to a variable
146+
# explicitly typed as pl.DataFrame without mypy errors
147+
result: pl.DataFrame = get_data()
148+
assert "name" in result.columns
149+
assert "value" in result.columns
150+
151+
152+
def test_df_out_preserves_specific_return_type() -> None:
153+
"""Test that df_out preserves the specific DataFrame return type annotation."""
154+
155+
# Function that specifically returns pandas DataFrame with df_out
156+
@df_out(columns=["col1", "col2"])
157+
def function_with_pandas_df() -> pd.DataFrame:
158+
return pd.DataFrame({"col1": [1, 2], "col2": [3, 4]})
159+
160+
# We should be able to assign to a variable typed as pandas DataFrame
161+
# without having to cast or getting type errors
162+
result: pd.DataFrame = function_with_pandas_df()
163+
164+
# Same with a function returning polars DataFrame
165+
@df_out(columns=["col1", "col2"])
166+
def function_with_polars_df() -> pl.DataFrame:
167+
return pl.DataFrame({"col1": [1, 2], "col2": [3, 4]})
168+
169+
# Should be assignable to a variable typed as polars DataFrame
170+
polars_result: pl.DataFrame = function_with_polars_df()
171+
172+
# Both should work at runtime too
173+
assert isinstance(result, pd.DataFrame)
174+
assert isinstance(polars_result, pl.DataFrame)

0 commit comments

Comments (0)