 import logging
 import re
 from functools import wraps
-from typing import Any, Callable, Dict, List, Optional, Pattern, Tuple, TypeVar, Union
+from typing import Any, Callable, Dict, List, Optional, Pattern, Set, Tuple, TypeVar, Union
 from typing import Sequence as Seq  # Renamed to avoid collision
 
 import pandas as pd
@@ -35,113 +35,104 @@ def _is_regex_pattern(column: Any) -> bool:
     )
 
 
+def _assert_is_dataframe(obj: Any, context: str) -> None:
+    if not isinstance(obj, (pd.DataFrame, pl.DataFrame)):
+        raise AssertionError(f"Wrong {context}. Expected DataFrame, got {type(obj).__name__} instead.")
+
+
+def _make_param_info(param_name: Optional[str]) -> str:
+    return f" in parameter '{param_name}'" if param_name else ""
+
+
+def _validate_column(
+    column_spec: Union[str, RegexColumnDef], df: DataFrameType, expected_dtype: Any = None
+) -> Tuple[List[str], List[Tuple[str, Any, Any]], Set[str]]:
+    """Validate a single column specification against a DataFrame."""
+    missing_columns = []
+    dtype_mismatches = []
+    matched_by_regex = set()
+
+    if isinstance(column_spec, str):
+        if column_spec not in df.columns:
+            missing_columns.append(column_spec)
+        elif expected_dtype is not None and df[column_spec].dtype != expected_dtype:
+            dtype_mismatches.append((column_spec, df[column_spec].dtype, expected_dtype))
+    elif _is_regex_pattern(column_spec):
+        pattern_str, _ = column_spec
+        matches = _match_column_with_regex(column_spec, list(df.columns))
+        if not matches:
+            missing_columns.append(pattern_str)
+        else:
+            matched_by_regex.update(matches)
+            if expected_dtype is not None:
+                for matched_col in matches:
+                    if df[matched_col].dtype != expected_dtype:
+                        dtype_mismatches.append((matched_col, df[matched_col].dtype, expected_dtype))
+
+    return missing_columns, dtype_mismatches, matched_by_regex
+
+
 def _match_column_with_regex(column_pattern: RegexColumnDef, df_columns: List[str]) -> List[str]:
     _, pattern = column_pattern
     return [col for col in df_columns if pattern.match(col)]
 
 
+def _compile_regex_pattern(pattern_string: str) -> RegexColumnDef:
+    pattern_str = pattern_string[2:-1]
+    compiled_pattern = re.compile(pattern_str)
+    return (pattern_string, compiled_pattern)
+
+
+def _is_regex_string(column: str) -> bool:
+    return column.startswith("r/") and column.endswith("/")
+
+
 def _compile_regex_patterns(columns: Seq[Any]) -> List[Union[str, RegexColumnDef]]:
-    """Compile regex patterns in the column list."""
-    result: List[Union[str, RegexColumnDef]] = []
-    for col in columns:
-        if isinstance(col, str) and col.startswith("r/") and col.endswith("/"):
-            # Pattern is in the format "r/pattern/"
-            pattern_str = col[2:-1]  # Remove "r/" prefix and "/" suffix
-            compiled_pattern = re.compile(pattern_str)
-            result.append((col, compiled_pattern))
-        else:
-            result.append(col)
-    return result
+    return [_compile_regex_pattern(col) if isinstance(col, str) and _is_regex_string(col) else col for col in columns]
 
 
 def _check_columns(
     df: DataFrameType, columns: Union[ColumnsList, ColumnsDict], strict: bool, param_name: Optional[str] = None
 ) -> None:
-    missing_columns = []
-    dtype_mismatches = []
-    matched_by_regex = set()
+    all_missing_columns = []
+    all_dtype_mismatches = []
+    all_matched_by_regex = set()
 
-    # Handle list of column names/patterns
     if isinstance(columns, list):
-        # First, compile any regex patterns
         processed_columns = _compile_regex_patterns(columns)
+        for column_spec in processed_columns:
+            missing, mismatches, matched = _validate_column(column_spec, df)
+            all_missing_columns.extend(missing)
+            all_dtype_mismatches.extend(mismatches)
+            all_matched_by_regex.update(matched)
+    else:  # isinstance(columns, dict)
+        assert isinstance(columns, dict)
+        for column, expected_dtype in columns.items():
+            column_spec = (
+                _compile_regex_pattern(column) if isinstance(column, str) and _is_regex_string(column) else column
+            )
+            missing, mismatches, matched = _validate_column(column_spec, df, expected_dtype)
+            all_missing_columns.extend(missing)
+            all_dtype_mismatches.extend(mismatches)
+            all_matched_by_regex.update(matched)
+
+    param_info = _make_param_info(param_name)
+
+    if all_missing_columns:
+        raise AssertionError(f"Missing columns: {all_missing_columns}{param_info}. Got {_describe_pd(df)}")
 
-        for column in processed_columns:
-            if isinstance(column, str):
-                # Direct column name match
-                if column not in df.columns:
-                    missing_columns.append(column)
-            elif _is_regex_pattern(column):
-                # Regex pattern match
-                matches = _match_column_with_regex(column, list(df.columns))
-                if not matches:
-                    missing_columns.append(column[0])  # Add the original pattern string
-                else:
-                    matched_by_regex.update(matches)
-
-    # Handle dictionary of column names/types
-    elif isinstance(columns, dict):
-        # First, process dictionary keys for regex patterns
-        processed_dict: Dict[Union[str, RegexColumnDef], Any] = {}
-        for column, dtype in columns.items():
-            if isinstance(column, str) and column.startswith("r/") and column.endswith("/"):
-                # Pattern is in the format "r/pattern/"
-                pattern_str = column[2:-1]  # Remove "r/" prefix and "/" suffix
-                compiled_pattern = re.compile(pattern_str)
-                processed_dict[(column, compiled_pattern)] = dtype
-            else:
-                processed_dict[column] = dtype
-
-        # Check each column against dictionary keys
-        regex_matched_columns = set()
-        for column_key, dtype in processed_dict.items():
-            if isinstance(column_key, str):
-                # Direct column name match
-                if column_key not in df.columns:
-                    missing_columns.append(column_key)
-                elif df[column_key].dtype != dtype:
-                    dtype_mismatches.append((column_key, df[column_key].dtype, dtype))
-            elif _is_regex_pattern(column_key):
-                # Regex pattern match
-                pattern_str, compiled_pattern = column_key
-                matches = _match_column_with_regex(column_key, list(df.columns))
-                if not matches:
-                    missing_columns.append(pattern_str)  # Add the original pattern string
-                else:
-                    for matched_col in matches:
-                        matched_by_regex.add(matched_col)
-                        regex_matched_columns.add(matched_col)
-                        if df[matched_col].dtype != dtype:
-                            dtype_mismatches.append((matched_col, df[matched_col].dtype, dtype))
-
-    if missing_columns:
-        param_info = f" in parameter '{param_name}'" if param_name else ""
-        raise AssertionError(f"Missing columns: {missing_columns}{param_info}. Got {_describe_pd(df)}")
-
-    if dtype_mismatches:
-        param_info = f" in parameter '{param_name}'" if param_name else ""
-        mismatches = ", ".join(
-            [
-                f"Column {col}{param_info} has wrong dtype. Was {was}, expected {expected}"
-                for col, was, expected in dtype_mismatches
-            ]
+    if all_dtype_mismatches:
+        mismatch_descriptions = ", ".join(
+            f"Column {col}{param_info} has wrong dtype. Was {was}, expected {expected}"
+            for col, was, expected in all_dtype_mismatches
         )
-        raise AssertionError(mismatches)
+        raise AssertionError(mismatch_descriptions)
 
     if strict:
-        if isinstance(columns, list):
-            # For regex matches, we need to consider all matched columns
-            explicit_columns = {col for col in columns if isinstance(col, str)}
-            allowed_columns = explicit_columns.union(matched_by_regex)
-            extra_columns = set(df.columns) - allowed_columns
-        else:
-            # For dict with regex patterns, we need to handle both direct and regex matches
-            explicit_columns = {col for col in columns if isinstance(col, str)}
-            allowed_columns = explicit_columns.union(matched_by_regex)
-            extra_columns = set(df.columns) - allowed_columns
-
+        explicit_columns = {col for col in columns if isinstance(col, str)}
+        allowed_columns = explicit_columns.union(all_matched_by_regex)
+        extra_columns = set(df.columns) - allowed_columns
         if extra_columns:
-            param_info = f" in parameter '{param_name}'" if param_name else ""
             raise AssertionError(f"DataFrame{param_info} contained unexpected column(s): {', '.join(extra_columns)}")
 
 
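For context, a minimal sketch of how the list and dict column specifications validated by `_check_columns` are typically written at the call site. This is illustrative only and not part of the commit: the `df_in`/`df_out` decorators appear later in this diff, while the import path, example function, column names, and dtypes are assumptions.

# Illustrative sketch (not part of the commit); the import path is an assumption.
import pandas as pd
from daffy import df_in, df_out

@df_in(columns=["id", "r/^score_/"])                  # plain name plus an "r/pattern/" regex spec
@df_out(columns={"id": "int64", "total": "float64"})  # dict spec also checks dtypes
def add_total(df: pd.DataFrame) -> pd.DataFrame:
    out = df.copy()
    out["total"] = out.filter(regex="^score_").astype("float64").sum(axis=1)
    return out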
@@ -169,9 +160,7 @@ def wrapper_df_out(func: Callable[..., DF]) -> Callable[..., DF]:
         @wraps(func)
         def wrapper(*args: Any, **kwargs: Any) -> DF:
             result = func(*args, **kwargs)
-            assert isinstance(result, pd.DataFrame) or isinstance(result, pl.DataFrame), (
-                f"Wrong return type. Expected DataFrame, got {type(result)}"
-            )
+            _assert_is_dataframe(result, "return type")
             if columns:
                 _check_columns(result, columns, get_strict(strict))
             return result
@@ -183,38 +172,27 @@ def wrapper(*args: Any, **kwargs: Any) -> DF:
 
 def _get_parameter(func: Callable[..., Any], name: Optional[str] = None, *args: Any, **kwargs: Any) -> Any:
     if not name:
-        if len(args) > 0:
-            return args[0]
-        if kwargs:
-            return next(iter(kwargs.values()))
-        return None
+        return args[0] if args else next(iter(kwargs.values()), None)
 
-    if name and (name not in kwargs):
-        func_params_in_order = list(inspect.signature(func).parameters.keys())
-        parameter_location = func_params_in_order.index(name)
-        return args[parameter_location]
+    if name in kwargs:
+        return kwargs[name]
 
-    return kwargs[name]
+    func_params_in_order = list(inspect.signature(func).parameters.keys())
+    parameter_location = func_params_in_order.index(name)
+    return args[parameter_location]
 
 
 def _get_parameter_name(
     func: Callable[..., Any], name: Optional[str] = None, *args: Any, **kwargs: Any
 ) -> Optional[str]:
-    """Get the actual parameter name being validated."""
     if name:
         return name
 
-    # If no name specified, try to get the first parameter name
-    if len(args) > 0:
-        # Get the first parameter name from the function signature
+    if args:
         func_params_in_order = list(inspect.signature(func).parameters.keys())
-        if func_params_in_order:
-            return func_params_in_order[0]
-    elif kwargs:
-        # Return the first keyword argument name
-        return next(iter(kwargs.keys()))
+        return func_params_in_order[0]
 
-    return None
+    return next(iter(kwargs.keys()), None)
 
 
 def df_in(
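A brief, comment-only sketch of the argument-resolution order that the refactored `_get_parameter` above implements; the example function and call sites are invented for illustration.

# Illustrative only: how _get_parameter resolves the DataFrame argument.
def transform(data, factor=2): ...

# name=None  -> first positional argument, else first keyword value, else None:
#   _get_parameter(transform, None, df)                   # -> df
# name given and passed as a keyword argument:
#   _get_parameter(transform, "data", factor=3, data=df)  # -> df
# name given but passed positionally -> position found via inspect.signature:
#   _get_parameter(transform, "data", df, 3)              # -> df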
@@ -243,9 +221,7 @@ def wrapper_df_in(func: Callable[..., R]) -> Callable[..., R]:
         def wrapper(*args: Any, **kwargs: Any) -> R:
             df = _get_parameter(func, name, *args, **kwargs)
             param_name = _get_parameter_name(func, name, *args, **kwargs)
-            assert isinstance(df, pd.DataFrame) or isinstance(df, pl.DataFrame), (
-                f"Wrong parameter type. Expected DataFrame, got {type(df).__name__} instead."
-            )
+            _assert_is_dataframe(df, "parameter type")
             if columns:
                 _check_columns(df, columns, get_strict(strict), param_name)
             return func(*args, **kwargs)
@@ -261,25 +237,19 @@ def _describe_pd(df: DataFrameType, include_dtypes: bool = False) -> str:
         if isinstance(df, pd.DataFrame):
             readable_dtypes = [dtype.name for dtype in df.dtypes]
             result += f" with dtypes {readable_dtypes}"
-        if isinstance(df, pl.DataFrame):
+        else:
             result += f" with dtypes {df.dtypes}"
     return result
 
 
 def _log_input(level: int, func_name: str, df: Any, include_dtypes: bool) -> None:
-    if isinstance(df, pd.DataFrame) or isinstance(df, pl.DataFrame):
-        logging.log(
-            level,
-            f"Function {func_name} parameters contained a DataFrame: {_describe_pd(df, include_dtypes)}",
-        )
+    if isinstance(df, (pd.DataFrame, pl.DataFrame)):
+        logging.log(level, f"Function {func_name} parameters contained a DataFrame: {_describe_pd(df, include_dtypes)}")
 
 
 def _log_output(level: int, func_name: str, df: Any, include_dtypes: bool) -> None:
-    if isinstance(df, pd.DataFrame) or isinstance(df, pl.DataFrame):
-        logging.log(
-            level,
-            f"Function {func_name} returned a DataFrame: {_describe_pd(df, include_dtypes)}",
-        )
+    if isinstance(df, (pd.DataFrame, pl.DataFrame)):
+        logging.log(level, f"Function {func_name} returned a DataFrame: {_describe_pd(df, include_dtypes)}")
 
 
 def df_log(level: int = logging.DEBUG, include_dtypes: bool = False) -> Callable[[Callable[..., T]], Callable[..., T]]:
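Finally, a hedged usage sketch for the `df_log` decorator whose signature closes this diff. The import path and example function are assumptions; the logged message templates come from `_log_input`/`_log_output` above.

# Illustrative sketch (not part of the commit); the import path is an assumption.
import logging
import pandas as pd
from daffy import df_log

logging.basicConfig(level=logging.INFO)

@df_log(level=logging.INFO, include_dtypes=True)
def drop_na_rows(df: pd.DataFrame) -> pd.DataFrame:
    return df.dropna()

# Logs "Function drop_na_rows parameters contained a DataFrame: ..." on the way in
# and "Function drop_na_rows returned a DataFrame: ..." on the way out.
drop_na_rows(pd.DataFrame({"id": [1.0, None, 3.0]}))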