 import logging
 import re
 from functools import wraps
-from typing import Any, Callable, Dict, List, Optional, Pattern, Tuple, TypeVar, Union
+from typing import Any, Callable, Dict, List, Optional, Pattern, Set, Tuple, TypeVar, Union
 from typing import Sequence as Seq  # Renamed to avoid collision

 import pandas as pd
@@ -40,6 +40,38 @@ def _assert_is_dataframe(obj: Any, context: str) -> None:
        raise AssertionError(f"Wrong {context}. Expected DataFrame, got {type(obj).__name__} instead.")


+def _make_param_info(param_name: Optional[str]) -> str:
+    return f" in parameter '{param_name}'" if param_name else ""
+
+
+def _validate_column(
+    column_spec: Union[str, RegexColumnDef], df: DataFrameType, expected_dtype: Any = None
+) -> Tuple[List[str], List[Tuple[str, Any, Any]], Set[str]]:
+    """Validate a single column specification against a DataFrame."""
+    missing_columns = []
+    dtype_mismatches = []
+    matched_by_regex = set()
+
+    if isinstance(column_spec, str):
+        if column_spec not in df.columns:
+            missing_columns.append(column_spec)
+        elif expected_dtype is not None and df[column_spec].dtype != expected_dtype:
+            dtype_mismatches.append((column_spec, df[column_spec].dtype, expected_dtype))
+    elif _is_regex_pattern(column_spec):
+        pattern_str, _ = column_spec
+        matches = _match_column_with_regex(column_spec, list(df.columns))
+        if not matches:
+            missing_columns.append(pattern_str)
+        else:
+            matched_by_regex.update(matches)
+            if expected_dtype is not None:
+                for matched_col in matches:
+                    if df[matched_col].dtype != expected_dtype:
+                        dtype_mismatches.append((matched_col, df[matched_col].dtype, expected_dtype))
+
+    return missing_columns, dtype_mismatches, matched_by_regex
+
+
def _match_column_with_regex(column_pattern: RegexColumnDef, df_columns: List[str]) -> List[str]:
    _, pattern = column_pattern
    return [col for col in df_columns if pattern.match(col)]
@@ -62,70 +94,45 @@ def _compile_regex_patterns(columns: Seq[Any]) -> List[Union[str, RegexColumnDef
def _check_columns(
    df: DataFrameType, columns: Union[ColumnsList, ColumnsDict], strict: bool, param_name: Optional[str] = None
) -> None:
-    missing_columns = []
-    dtype_mismatches = []
-    matched_by_regex = set()
+    all_missing_columns = []
+    all_dtype_mismatches = []
+    all_matched_by_regex = set()

    if isinstance(columns, list):
        processed_columns = _compile_regex_patterns(columns)
-        for column in processed_columns:
-            if isinstance(column, str):
-                if column not in df.columns:
-                    missing_columns.append(column)
-            elif _is_regex_pattern(column):
-                matches = _match_column_with_regex(column, list(df.columns))
-                if not matches:
-                    missing_columns.append(column[0])
-                else:
-                    matched_by_regex.update(matches)
-
+        for column_spec in processed_columns:
+            missing, mismatches, matched = _validate_column(column_spec, df)
+            all_missing_columns.extend(missing)
+            all_dtype_mismatches.extend(mismatches)
+            all_matched_by_regex.update(matched)
    else:  # isinstance(columns, dict)
        assert isinstance(columns, dict)
-        processed_dict: Dict[Union[str, RegexColumnDef], Any] = {}
-        for column, dtype in columns.items():
-            if isinstance(column, str) and _is_regex_string(column):
-                processed_dict[_compile_regex_pattern(column)] = dtype
-            else:
-                processed_dict[column] = dtype
-
-        for column_key, dtype in processed_dict.items():
-            if isinstance(column_key, str):
-                if column_key not in df.columns:
-                    missing_columns.append(column_key)
-                elif df[column_key].dtype != dtype:
-                    dtype_mismatches.append((column_key, df[column_key].dtype, dtype))
-            elif _is_regex_pattern(column_key):
-                pattern_str, compiled_pattern = column_key
-                matches = _match_column_with_regex(column_key, list(df.columns))
-                if not matches:
-                    missing_columns.append(pattern_str)
-                else:
-                    for matched_col in matches:
-                        matched_by_regex.add(matched_col)
-                        if df[matched_col].dtype != dtype:
-                            dtype_mismatches.append((matched_col, df[matched_col].dtype, dtype))
-
-    if missing_columns:
-        param_info = f" in parameter '{param_name}'" if param_name else ""
-        raise AssertionError(f"Missing columns: {missing_columns}{param_info}. Got {_describe_pd(df)}")
-
-    if dtype_mismatches:
-        param_info = f" in parameter '{param_name}'" if param_name else ""
-        mismatches = ", ".join(
-            [
-                f"Column {col}{param_info} has wrong dtype. Was {was}, expected {expected}"
-                for col, was, expected in dtype_mismatches
-            ]
+        for column, expected_dtype in columns.items():
+            column_spec = (
+                _compile_regex_pattern(column) if isinstance(column, str) and _is_regex_string(column) else column
+            )
+            missing, mismatches, matched = _validate_column(column_spec, df, expected_dtype)
+            all_missing_columns.extend(missing)
+            all_dtype_mismatches.extend(mismatches)
+            all_matched_by_regex.update(matched)
+
+    param_info = _make_param_info(param_name)
+
+    if all_missing_columns:
+        raise AssertionError(f"Missing columns: {all_missing_columns}{param_info}. Got {_describe_pd(df)}")
+
+    if all_dtype_mismatches:
+        mismatch_descriptions = ", ".join(
+            f"Column {col}{param_info} has wrong dtype. Was {was}, expected {expected}"
+            for col, was, expected in all_dtype_mismatches
        )
-        raise AssertionError(mismatches)
+        raise AssertionError(mismatch_descriptions)

    if strict:
        explicit_columns = {col for col in columns if isinstance(col, str)}
-        allowed_columns = explicit_columns.union(matched_by_regex)
+        allowed_columns = explicit_columns.union(all_matched_by_regex)
        extra_columns = set(df.columns) - allowed_columns
-
        if extra_columns:
-            param_info = f" in parameter '{param_name}'" if param_name else ""
            raise AssertionError(f"DataFrame{param_info} contained unexpected column(s): {', '.join(extra_columns)}")

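For context, a minimal sketch of how the refactored validation behaves, based only on the code in this hunk; the calls below are illustrative (the module they would be imported from is not shown in this diff):

import pandas as pd

df = pd.DataFrame({"price": [1.0, 2.0], "brand": ["a", "b"]})

# Plain-string spec: membership check, plus a dtype check when expected_dtype is given.
_validate_column("price", df, expected_dtype="float64")  # -> ([], [], set())
_validate_column("price", df, expected_dtype="int64")    # -> ([], [("price", dtype("float64"), "int64")], set())
_validate_column("missing", df)                          # -> (["missing"], [], set())

# _check_columns aggregates these per-column results and raises a single
# AssertionError listing every missing column or dtype mismatch at once.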
@@ -165,37 +172,27 @@ def wrapper(*args: Any, **kwargs: Any) -> DF:

def _get_parameter(func: Callable[..., Any], name: Optional[str] = None, *args: Any, **kwargs: Any) -> Any:
    if not name:
-        if len(args) > 0:
-            return args[0]
-        if kwargs:
-            return next(iter(kwargs.values()))
-        return None
+        return args[0] if args else next(iter(kwargs.values()), None)

-    if name not in kwargs:
-        func_params_in_order = list(inspect.signature(func).parameters.keys())
-        parameter_location = func_params_in_order.index(name)
-        return args[parameter_location]
+    if name in kwargs:
+        return kwargs[name]

-    return kwargs[name]
+    func_params_in_order = list(inspect.signature(func).parameters.keys())
+    parameter_location = func_params_in_order.index(name)
+    return args[parameter_location]


def _get_parameter_name(
    func: Callable[..., Any], name: Optional[str] = None, *args: Any, **kwargs: Any
) -> Optional[str]:
-    """Get the actual parameter name being validated."""
    if name:
        return name

-    # If no name specified, try to get the first parameter name
-    if len(args) > 0:
-        # Get the first parameter name from the function signature
+    if args:
        func_params_in_order = list(inspect.signature(func).parameters.keys())
        return func_params_in_order[0]
-    elif kwargs:
-        # Return the first keyword argument name
-        return next(iter(kwargs.keys()))

-    return None
+    return next(iter(kwargs.keys()), None)


def df_in(
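As a quick illustration of the parameter-resolution logic above (a sketch; load_data is a hypothetical example function, not part of this change):

import pandas as pd

def load_data(data, multiplier=1):
    return data

df = pd.DataFrame({"x": [1]})

_get_parameter(load_data, None, df)            # df, the first positional argument
_get_parameter(load_data, "data", data=df)     # df, taken directly from kwargs
_get_parameter(load_data, "data", df)          # df, resolved via the signature index of "data"
_get_parameter_name(load_data, None, df)       # "data", the first parameter name in the signature
_get_parameter_name(load_data, None, data=df)  # "data", the first keyword argument name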
@@ -246,19 +243,13 @@ def _describe_pd(df: DataFrameType, include_dtypes: bool = False) -> str:


def _log_input(level: int, func_name: str, df: Any, include_dtypes: bool) -> None:
-    if isinstance(df, pd.DataFrame) or isinstance(df, pl.DataFrame):
-        logging.log(
-            level,
-            f"Function {func_name} parameters contained a DataFrame: {_describe_pd(df, include_dtypes)}",
-        )
+    if isinstance(df, (pd.DataFrame, pl.DataFrame)):
+        logging.log(level, f"Function {func_name} parameters contained a DataFrame: {_describe_pd(df, include_dtypes)}")


def _log_output(level: int, func_name: str, df: Any, include_dtypes: bool) -> None:
-    if isinstance(df, pd.DataFrame) or isinstance(df, pl.DataFrame):
-        logging.log(
-            level,
-            f"Function {func_name} returned a DataFrame: {_describe_pd(df, include_dtypes)}",
-        )
+    if isinstance(df, (pd.DataFrame, pl.DataFrame)):
+        logging.log(level, f"Function {func_name} returned a DataFrame: {_describe_pd(df, include_dtypes)}")


def df_log(level: int = logging.DEBUG, include_dtypes: bool = False) -> Callable[[Callable[..., T]], Callable[..., T]]:
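The two logging helpers above back the df_log decorator whose signature closes this diff. A hedged usage sketch, assuming the decorator wraps a function that takes and returns a pandas DataFrame (add_total is a hypothetical example):

import logging
import pandas as pd

logging.basicConfig(level=logging.INFO)

@df_log(level=logging.INFO, include_dtypes=True)
def add_total(df: pd.DataFrame) -> pd.DataFrame:
    return df.assign(total=df["price"] * 2)

add_total(pd.DataFrame({"price": [1.0, 2.0]}))
# Emits one log line describing the DataFrame passed in and another describing
# the DataFrame returned, including column names and dtypes.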