22
33import inspect
44import logging
5+ import re
56from functools import wraps
6- from typing import Any , Callable , Dict , List , Optional , Union
7+ from typing import Any , Callable , Dict , List , Optional , Pattern , Tuple , Union
78
89import pandas as pd
910import polars as pl
1011
1112from daffy .config import get_strict
1213
13- ColumnsDef = Union [List , Dict ]
14+ # New type definition to support regex patterns
15+ RegexColumnDef = Tuple [str , Pattern ] # Tuple of (pattern_str, compiled_pattern)
16+ ColumnsDef = Union [List , Dict , List [Union [str , RegexColumnDef ]]]
1417DataFrameType = Union [pd .DataFrame , pl .DataFrame ]
1518
1619
20+ def _is_regex_pattern (column : Any ) -> bool :
21+ """Check if the column definition is a regex pattern tuple."""
22+ return (
23+ isinstance (column , tuple ) and len (column ) == 2 and isinstance (column [0 ], str ) and isinstance (column [1 ], Pattern )
24+ )
25+
26+
27+ def _match_column_with_regex (column_pattern : RegexColumnDef , df_columns : List [str ]) -> List [str ]:
28+ """Find all column names that match the regex pattern."""
29+ _ , pattern = column_pattern
30+ return [col for col in df_columns if pattern .match (col )]
31+
32+
33+ def _compile_regex_patterns (columns : List ) -> List :
34+ """Compile regex patterns in the column list."""
35+ result = []
36+ for col in columns :
37+ if isinstance (col , str ) and col .startswith ("r/" ) and col .endswith ("/" ):
38+ # Pattern is in the format "r/pattern/"
39+ pattern_str = col [2 :- 1 ] # Remove "r/" prefix and "/" suffix
40+ compiled_pattern = re .compile (pattern_str )
41+ result .append ((col , compiled_pattern ))
42+ else :
43+ result .append (col )
44+ return result
45+
46+
1747def _check_columns (df : DataFrameType , columns : ColumnsDef , strict : bool ) -> None :
1848 missing_columns = []
1949 dtype_mismatches = []
50+ matched_by_regex = set ()
2051
52+ # Handle list of column names/patterns
2153 if isinstance (columns , list ):
22- for column in columns :
23- if column not in df .columns :
24- missing_columns .append (column )
25- if isinstance (columns , dict ):
54+ # First, compile any regex patterns
55+ processed_columns = _compile_regex_patterns (columns )
56+
57+ for column in processed_columns :
58+ if isinstance (column , str ):
59+ # Direct column name match
60+ if column not in df .columns :
61+ missing_columns .append (column )
62+ elif _is_regex_pattern (column ):
63+ # Regex pattern match
64+ matches = _match_column_with_regex (column , list (df .columns ))
65+ if not matches :
66+ missing_columns .append (column [0 ]) # Add the original pattern string
67+ else :
68+ matched_by_regex .update (matches )
69+
70+ # Handle dictionary of column names/types
71+ elif isinstance (columns , dict ):
2672 for column , dtype in columns .items ():
2773 if column not in df .columns :
2874 missing_columns .append (column )
@@ -39,18 +85,26 @@ def _check_columns(df: DataFrameType, columns: ColumnsDef, strict: bool) -> None
3985 raise AssertionError (mismatches )
4086
4187 if strict :
42- extra_columns = set (df .columns ) - set (columns )
88+ if isinstance (columns , list ):
89+ # For regex matches, we need to consider all matched columns
90+ explicit_columns = {col for col in columns if isinstance (col , str )}
91+ allowed_columns = explicit_columns .union (matched_by_regex )
92+ extra_columns = set (df .columns ) - allowed_columns
93+ else :
94+ extra_columns = set (df .columns ) - set (columns )
95+
4396 if extra_columns :
4497 raise AssertionError (f"DataFrame contained unexpected column(s): { ', ' .join (extra_columns )} " )
4598
4699
47100def df_out (columns : Optional [ColumnsDef ] = None , strict : Optional [bool ] = None ) -> Callable :
48- """Decorate a function that returns a Pandas DataFrame.
101+ """Decorate a function that returns a Pandas or Polars DataFrame.
49102
50103 Document the return value of a function. The return value will be validated in runtime.
51104
52105 Args:
53- columns (ColumnsDef, optional): List or dict that describes expected columns of the DataFrame. Defaults to None.
106+ columns (ColumnsDef, optional): List or dict that describes expected columns of the DataFrame.
107+ List can contain regex patterns in format "r/pattern/" (e.g., "r/Col[0-9]+/"). Defaults to None.
54108 strict (bool, optional): If True, columns must match exactly with no extra columns.
55109 If None, uses the value from [tool.daffy] strict setting in pyproject.toml.
56110
@@ -91,13 +145,14 @@ def _get_parameter(func: Callable, name: Optional[str] = None, *args: str, **kwa
91145
92146
93147def df_in (name : Optional [str ] = None , columns : Optional [ColumnsDef ] = None , strict : Optional [bool ] = None ) -> Callable :
94- """Decorate a function parameter that is a Pandas DataFrame.
148+ """Decorate a function parameter that is a Pandas or Polars DataFrame.
95149
96- Document the contents of an inpute parameter. The parameter will be validated in runtime.
150+ Document the contents of an input parameter. The parameter will be validated in runtime.
97151
98152 Args:
99153 name (Optional[str], optional): Name of the parameter that contains a DataFrame. Defaults to None.
100- columns (ColumnsDef, optional): List or dict that describes expected columns of the DataFrame. Defaults to None.
154+ columns (ColumnsDef, optional): List or dict that describes expected columns of the DataFrame.
155+ List can contain regex patterns in format "r/pattern/" (e.g., "r/Col[0-9]+/"). Defaults to None.
101156 strict (bool, optional): If True, columns must match exactly with no extra columns.
102157 If None, uses the value from [tool.daffy] strict setting in pyproject.toml.
103158
0 commit comments