44import logging
55import re
66from functools import wraps
7- from typing import Any , Callable , Dict , List , Optional , Pattern , Tuple , Union
7+ from typing import Any , Callable , Dict , List , Optional , Pattern , Tuple , TypeVar , Union
88
99import pandas as pd
1010import polars as pl
1111
12+ # Import fully qualified types to satisfy disallow_any_unimported
13+ from pandas import DataFrame as PandasDataFrame
14+ from polars import DataFrame as PolarsDataFrame
15+
1216from daffy .config import get_strict
1317
14- # New type definition to support regex patterns
15- RegexColumnDef = Tuple [str , Pattern ] # Tuple of (pattern_str, compiled_pattern)
16- ColumnsDef = Union [List , Dict , List [Union [str , RegexColumnDef ]]]
17- DataFrameType = Union [pd .DataFrame , pl .DataFrame ]
18+ # Type variables for preserving return types
19+ T = TypeVar ("T" )
20+ R = TypeVar ("R" )
21+
22+
23+ # Improved type definitions to support regex patterns
24+ RegexColumnDef = Tuple [str , Pattern [str ]] # Tuple of (pattern_str, compiled_pattern)
25+ ColumnsDef = Union [List [Union [str , RegexColumnDef ]], Dict [str , Any ]]
26+ DataFrameType = Union [PandasDataFrame , PolarsDataFrame ]
1827
1928
2029def _is_regex_pattern (column : Any ) -> bool :
@@ -30,9 +39,9 @@ def _match_column_with_regex(column_pattern: RegexColumnDef, df_columns: List[st
3039 return [col for col in df_columns if pattern .match (col )]
3140
3241
33- def _compile_regex_patterns (columns : List ) -> List :
42+ def _compile_regex_patterns (columns : List [ Any ] ) -> List [ Union [ str , RegexColumnDef ]] :
3443 """Compile regex patterns in the column list."""
35- result = []
44+ result : List [ Union [ str , RegexColumnDef ]] = []
3645 for col in columns :
3746 if isinstance (col , str ) and col .startswith ("r/" ) and col .endswith ("/" ):
3847 # Pattern is in the format "r/pattern/"
@@ -97,7 +106,9 @@ def _check_columns(df: DataFrameType, columns: ColumnsDef, strict: bool) -> None
97106 raise AssertionError (f"DataFrame contained unexpected column(s): { ', ' .join (extra_columns )} " )
98107
99108
100- def df_out (columns : Optional [ColumnsDef ] = None , strict : Optional [bool ] = None ) -> Callable :
109+ def df_out (
110+ columns : Optional [ColumnsDef ] = None , strict : Optional [bool ] = None
111+ ) -> Callable [[Callable [..., DataFrameType ]], Callable [..., DataFrameType ]]:
101112 """Decorate a function that returns a Pandas or Polars DataFrame.
102113
103114 Document the return value of a function. The return value will be validated in runtime.
@@ -109,12 +120,12 @@ def df_out(columns: Optional[ColumnsDef] = None, strict: Optional[bool] = None)
109120 If None, uses the value from [tool.daffy] strict setting in pyproject.toml.
110121
111122 Returns:
112- Callable: Decorated function
123+ Callable: Decorated function with preserved DataFrame return type
113124 """
114125
115- def wrapper_df_out (func : Callable ) -> Callable :
126+ def wrapper_df_out (func : Callable [..., DataFrameType ] ) -> Callable [..., DataFrameType ] :
116127 @wraps (func )
117- def wrapper (* args : str , ** kwargs : Any ) -> Any :
128+ def wrapper (* args : Any , ** kwargs : Any ) -> DataFrameType :
118129 result = func (* args , ** kwargs )
119130 assert isinstance (result , pd .DataFrame ) or isinstance (result , pl .DataFrame ), (
120131 f"Wrong return type. Expected DataFrame, got { type (result )} "
@@ -128,7 +139,7 @@ def wrapper(*args: str, **kwargs: Any) -> Any:
128139 return wrapper_df_out
129140
130141
131- def _get_parameter (func : Callable , name : Optional [str ] = None , * args : str , ** kwargs : Any ) -> DataFrameType :
142+ def _get_parameter (func : Callable [..., Any ], name : Optional [str ] = None , * args : Any , ** kwargs : Any ) -> Any :
132143 if not name :
133144 if len (args ) > 0 :
134145 return args [0 ]
@@ -144,7 +155,9 @@ def _get_parameter(func: Callable, name: Optional[str] = None, *args: str, **kwa
144155 return kwargs [name ]
145156
146157
147- def df_in (name : Optional [str ] = None , columns : Optional [ColumnsDef ] = None , strict : Optional [bool ] = None ) -> Callable :
158+ def df_in (
159+ name : Optional [str ] = None , columns : Optional [ColumnsDef ] = None , strict : Optional [bool ] = None
160+ ) -> Callable [[Callable [..., R ]], Callable [..., R ]]:
148161 """Decorate a function parameter that is a Pandas or Polars DataFrame.
149162
150163 Document the contents of an input parameter. The parameter will be validated in runtime.
@@ -157,12 +170,12 @@ def df_in(name: Optional[str] = None, columns: Optional[ColumnsDef] = None, stri
157170 If None, uses the value from [tool.daffy] strict setting in pyproject.toml.
158171
159172 Returns:
160- Callable: Decorated function
173+ Callable: Decorated function with preserved return type
161174 """
162175
163- def wrapper_df_in (func : Callable ) -> Callable :
176+ def wrapper_df_in (func : Callable [..., R ] ) -> Callable [..., R ] :
164177 @wraps (func )
165- def wrapper (* args : str , ** kwargs : Any ) -> Any :
178+ def wrapper (* args : Any , ** kwargs : Any ) -> R :
166179 df = _get_parameter (func , name , * args , ** kwargs )
167180 assert isinstance (df , pd .DataFrame ) or isinstance (df , pl .DataFrame ), (
168181 f"Wrong parameter type. Expected DataFrame, got { type (df ).__name__ } instead."
@@ -203,7 +216,7 @@ def _log_output(level: int, func_name: str, df: Any, include_dtypes: bool) -> No
203216 )
204217
205218
206- def df_log (level : int = logging .DEBUG , include_dtypes : bool = False ) -> Callable :
219+ def df_log (level : int = logging .DEBUG , include_dtypes : bool = False ) -> Callable [[ Callable [..., T ]], Callable [..., T ]] :
207220 """Decorate a function that consumes or produces a Pandas DataFrame or both.
208221
209222 Logs the columns of the consumed and/or produced DataFrame.
@@ -213,15 +226,16 @@ def df_log(level: int = logging.DEBUG, include_dtypes: bool = False) -> Callable
213226 include_dtypes (bool, optional): When set to True, will log also the dtypes of each column. Defaults to False.
214227
215228 Returns:
216- Callable: Decorated function.
229+ Callable: Decorated function with preserved return type .
217230 """
218231
219- def wrapper_df_log (func : Callable ) -> Callable :
232+ def wrapper_df_log (func : Callable [..., T ] ) -> Callable [..., T ] :
220233 @wraps (func )
221- def wrapper (* args : str , ** kwargs : Any ) -> Any :
234+ def wrapper (* args : Any , ** kwargs : Any ) -> T :
222235 _log_input (level , func .__name__ , _get_parameter (func , None , * args , ** kwargs ), include_dtypes )
223236 result = func (* args , ** kwargs )
224237 _log_output (level , func .__name__ , result , include_dtypes )
238+ return result # Added missing return statement
225239
226240 return wrapper
227241
0 commit comments