2222
2323# Improved type definitions to support regex patterns
2424RegexColumnDef = Tuple [str , Pattern [str ]] # Tuple of (pattern_str, compiled_pattern)
25- ColumnsDef = Union [List [Union [str , RegexColumnDef ]], Dict [str , Any ]]
25+ ColumnsDef = Union [List [Union [str , RegexColumnDef ]], Dict [Union [ str , RegexColumnDef ] , Any ]]
2626DataFrameType = Union [PandasDataFrame , PolarsDataFrame ]
2727
2828
@@ -78,11 +78,38 @@ def _check_columns(df: DataFrameType, columns: ColumnsDef, strict: bool) -> None
7878
7979 # Handle dictionary of column names/types
8080 elif isinstance (columns , dict ):
81+ # First, process dictionary keys for regex patterns
82+ processed_dict : Dict [Union [str , RegexColumnDef ], Any ] = {}
8183 for column , dtype in columns .items ():
82- if column not in df .columns :
83- missing_columns .append (column )
84- elif df [column ].dtype != dtype :
85- dtype_mismatches .append ((column , df [column ].dtype , dtype ))
84+ if isinstance (column , str ) and column .startswith ("r/" ) and column .endswith ("/" ):
85+ # Pattern is in the format "r/pattern/"
86+ pattern_str = column [2 :- 1 ] # Remove "r/" prefix and "/" suffix
87+ compiled_pattern = re .compile (pattern_str )
88+ processed_dict [(column , compiled_pattern )] = dtype
89+ else :
90+ processed_dict [column ] = dtype
91+
92+ # Check each column against dictionary keys
93+ regex_matched_columns = set ()
94+ for column_key , dtype in processed_dict .items ():
95+ if isinstance (column_key , str ):
96+ # Direct column name match
97+ if column_key not in df .columns :
98+ missing_columns .append (column_key )
99+ elif df [column_key ].dtype != dtype :
100+ dtype_mismatches .append ((column_key , df [column_key ].dtype , dtype ))
101+ elif _is_regex_pattern (column_key ):
102+ # Regex pattern match
103+ pattern_str , compiled_pattern = column_key
104+ matches = _match_column_with_regex (column_key , list (df .columns ))
105+ if not matches :
106+ missing_columns .append (pattern_str ) # Add the original pattern string
107+ else :
108+ for matched_col in matches :
109+ matched_by_regex .add (matched_col )
110+ regex_matched_columns .add (matched_col )
111+ if df [matched_col ].dtype != dtype :
112+ dtype_mismatches .append ((matched_col , df [matched_col ].dtype , dtype ))
86113
87114 if missing_columns :
88115 raise AssertionError (f"Missing columns: { missing_columns } . Got { _describe_pd (df )} " )
@@ -100,7 +127,10 @@ def _check_columns(df: DataFrameType, columns: ColumnsDef, strict: bool) -> None
100127 allowed_columns = explicit_columns .union (matched_by_regex )
101128 extra_columns = set (df .columns ) - allowed_columns
102129 else :
103- extra_columns = set (df .columns ) - set (columns )
130+ # For dict with regex patterns, we need to handle both direct and regex matches
131+ explicit_columns = {col for col in columns if isinstance (col , str )}
132+ allowed_columns = explicit_columns .union (matched_by_regex )
133+ extra_columns = set (df .columns ) - allowed_columns
104134
105135 if extra_columns :
106136 raise AssertionError (f"DataFrame contained unexpected column(s): { ', ' .join (extra_columns )} " )
@@ -115,7 +145,9 @@ def df_out(
115145
116146 Args:
117147 columns (ColumnsDef, optional): List or dict that describes expected columns of the DataFrame.
118- List can contain regex patterns in format "r/pattern/" (e.g., "r/Col[0-9]+/"). Defaults to None.
148+ List can contain regex patterns in format "r/pattern/" (e.g., "r/Col[0-9]+/").
149+ Dict can use regex patterns as keys in format "r/pattern/" to validate dtypes for matching columns.
150+ Defaults to None.
119151 strict (bool, optional): If True, columns must match exactly with no extra columns.
120152 If None, uses the value from [tool.daffy] strict setting in pyproject.toml.
121153
@@ -165,7 +197,9 @@ def df_in(
165197 Args:
166198 name (Optional[str], optional): Name of the parameter that contains a DataFrame. Defaults to None.
167199 columns (ColumnsDef, optional): List or dict that describes expected columns of the DataFrame.
168- List can contain regex patterns in format "r/pattern/" (e.g., "r/Col[0-9]+/"). Defaults to None.
200+ List can contain regex patterns in format "r/pattern/" (e.g., "r/Col[0-9]+/").
201+ Dict can use regex patterns as keys in format "r/pattern/" to validate dtypes for matching columns.
202+ Defaults to None.
169203 strict (bool, optional): If True, columns must match exactly with no extra columns.
170204 If None, uses the value from [tool.daffy] strict setting in pyproject.toml.
171205
0 commit comments