@@ -35,6 +35,13 @@ def _is_regex_pattern(column: Any) -> bool:
3535 )
3636
3737
def _as_regex_pattern(column: Union[str, RegexColumnDef]) -> Optional[RegexColumnDef]:
    """Narrow a column specification to a regex column definition.

    Args:
        column: Either a plain column name or a regex column definition.

    Returns:
        ``column`` itself when ``_is_regex_pattern`` accepts it, otherwise ``None``.
    """
    if not _is_regex_pattern(column):
        return None
    # The runtime check above has vetted the value; the type checker cannot
    # follow that narrowing, hence the targeted ignore.
    return column  # type: ignore[return-value]
43+
44+
def _assert_is_dataframe(obj: Any, context: str) -> None:
    """Ensure ``obj`` is a pandas or polars DataFrame.

    Args:
        obj: The value to check.
        context: Short description of what is being checked; it is embedded
            in the error message directly after the word "Wrong".

    Raises:
        AssertionError: If ``obj`` is neither a ``pd.DataFrame`` nor a
            ``pl.DataFrame``. The message names the actual type of ``obj``.
    """
    if isinstance(obj, (pd.DataFrame, pl.DataFrame)):
        return
    # Raised explicitly (rather than via ``assert``) so the check is not
    # stripped when Python runs with -O.
    raise AssertionError(f"Wrong {context}. Expected DataFrame, got {type(obj).__name__} instead.")
@@ -44,32 +51,39 @@ def _make_param_info(param_name: Optional[str]) -> str:
4451 return f" in parameter '{ param_name } '" if param_name else ""
4552
4653
47- def _validate_column (
48- column_spec : Union [str , RegexColumnDef ], df : DataFrameType , expected_dtype : Any = None
49- ) -> Tuple [List [str ], List [Tuple [str , Any , Any ]], Set [str ]]:
50- """Validate a single column specification against a DataFrame."""
51- missing_columns = []
52- dtype_mismatches = []
53- matched_by_regex = set ()
54-
def _find_missing_columns(column_spec: Union[str, RegexColumnDef], df_columns: List[str]) -> List[str]:
    """Report what a single column specification fails to find.

    Args:
        column_spec: A plain column name or a regex column definition.
        df_columns: Column names present in the DataFrame.

    Returns:
        A one-element list naming the missing item, or an empty list:

        * plain name: ``[column_spec]`` when the name is not in ``df_columns``;
        * regex definition: ``[pattern_str]`` (the first element of the
          definition) when the pattern matches no column;
        * anything else: always an empty list.
    """
    if isinstance(column_spec, str):
        if column_spec in df_columns:
            return []
        return [column_spec]

    if not _is_regex_pattern(column_spec):
        # Neither a name nor a regex definition: nothing to report.
        return []

    pattern_str, _ = column_spec
    if _match_column_with_regex(column_spec, df_columns):
        return []
    return [pattern_str]
63+
64+
def _find_dtype_mismatches(
    column_spec: Union[str, RegexColumnDef], df: DataFrameType, expected_dtype: Any, df_columns: List[str]
) -> List[Tuple[str, Any, Any]]:
    """Collect columns whose dtype differs from ``expected_dtype``.

    Args:
        column_spec: A plain column name or a regex column definition.
        df: The DataFrame whose column dtypes are inspected.
        expected_dtype: The dtype each selected column is compared against.
        df_columns: Column names of ``df``, supplied by the caller.

    Returns:
        A list of ``(column_name, actual_dtype, expected_dtype)`` tuples, one
        per mismatching column. A plain name that is absent from
        ``df_columns`` yields nothing here; absence is reported separately.
    """
    # Step 1: work out which existing columns this specification refers to.
    if isinstance(column_spec, str):
        targets = [column_spec] if column_spec in df_columns else []
    elif _is_regex_pattern(column_spec):
        targets = _match_column_with_regex(column_spec, df_columns)
    else:
        targets = []

    # Step 2: compare each selected column's dtype with the expectation.
    found = []
    for name in targets:
        actual_dtype = df[name].dtype
        if actual_dtype != expected_dtype:
            found.append((name, actual_dtype, expected_dtype))
    return found
79+
7180
72- return missing_columns , dtype_mismatches , matched_by_regex
def _find_regex_matches(column_spec: Union[str, RegexColumnDef], df_columns: List[str]) -> Set[str]:
    """Return the DataFrame columns matched by a regex column specification.

    Args:
        column_spec: A plain column name or a regex column definition.
        df_columns: Column names present in the DataFrame.

    Returns:
        The set of names from ``df_columns`` matched by the regex definition,
        or an empty set when ``column_spec`` is not a regex definition.
    """
    as_regex = _as_regex_pattern(column_spec)
    if not as_regex:
        return set()
    return set(_match_column_with_regex(as_regex, df_columns))
7387
7488
7589def _match_column_with_regex (column_pattern : RegexColumnDef , df_columns : List [str ]) -> List [str ]:
@@ -94,27 +108,25 @@ def _compile_regex_patterns(columns: Seq[Any]) -> List[Union[str, RegexColumnDef
94108def _check_columns (
95109 df : DataFrameType , columns : Union [ColumnsList , ColumnsDict ], strict : bool , param_name : Optional [str ] = None
96110) -> None :
111+ df_columns = list (df .columns ) # Cache the column list conversion
97112 all_missing_columns = []
98113 all_dtype_mismatches = []
99114 all_matched_by_regex = set ()
100115
101116 if isinstance (columns , list ):
102117 processed_columns = _compile_regex_patterns (columns )
103118 for column_spec in processed_columns :
104- missing , mismatches , matched = _validate_column (column_spec , df )
105- all_missing_columns .extend (missing )
106- all_dtype_mismatches .extend (mismatches )
107- all_matched_by_regex .update (matched )
119+ all_missing_columns .extend (_find_missing_columns (column_spec , df_columns ))
120+ all_matched_by_regex .update (_find_regex_matches (column_spec , df_columns ))
108121 else : # isinstance(columns, dict)
109122 assert isinstance (columns , dict )
110123 for column , expected_dtype in columns .items ():
111124 column_spec = (
112125 _compile_regex_pattern (column ) if isinstance (column , str ) and _is_regex_string (column ) else column
113126 )
114- missing , mismatches , matched = _validate_column (column_spec , df , expected_dtype )
115- all_missing_columns .extend (missing )
116- all_dtype_mismatches .extend (mismatches )
117- all_matched_by_regex .update (matched )
127+ all_missing_columns .extend (_find_missing_columns (column_spec , df_columns ))
128+ all_dtype_mismatches .extend (_find_dtype_mismatches (column_spec , df , expected_dtype , df_columns ))
129+ all_matched_by_regex .update (_find_regex_matches (column_spec , df_columns ))
118130
119131 param_info = _make_param_info (param_name )
120132
@@ -131,7 +143,7 @@ def _check_columns(
131143 if strict :
132144 explicit_columns = {col for col in columns if isinstance (col , str )}
133145 allowed_columns = explicit_columns .union (all_matched_by_regex )
134- extra_columns = set (df . columns ) - allowed_columns
146+ extra_columns = set (df_columns ) - allowed_columns
135147 if extra_columns :
136148 raise AssertionError (f"DataFrame{ param_info } contained unexpected column(s): { ', ' .join (extra_columns )} " )
137149
0 commit comments