Skip to content

Commit ba6bc9a

Browse files
authored
Merge pull request #18 from vertti/support-regex-patters
Support regex patterns
2 parents b65b4c3 + 5651936 commit ba6bc9a

File tree

7 files changed

+188
-13
lines changed

7 files changed

+188
-13
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22

33
All notable changes to this project will be documented in this file.
44

5+
## 0.10.0
6+
7+
- Add support for regex patterns in column name validation
8+
59
## 0.9.4
610

711
- Fix to strict flag loading when tool config was missing

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ Like type hints for DataFrames, Daffy helps you catch structural mismatches earl
2222
## Key Features
2323

2424
- Validate DataFrame columns at function entry and exit points
25+
- Support regex patterns for matching column names (e.g., `"r/column_\d+/"`)
2526
- Check data types of columns
2627
- Control strictness of validation (allow or disallow extra columns)
2728
- Works with both Pandas and Polars DataFrames

daffy/decorators.py

Lines changed: 67 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,27 +2,73 @@
22

33
import inspect
44
import logging
5+
import re
56
from functools import wraps
6-
from typing import Any, Callable, Dict, List, Optional, Union
7+
from typing import Any, Callable, Dict, List, Optional, Pattern, Tuple, Union
78

89
import pandas as pd
910
import polars as pl
1011

1112
from daffy.config import get_strict
1213

13-
ColumnsDef = Union[List, Dict]
14+
# New type definition to support regex patterns
15+
RegexColumnDef = Tuple[str, Pattern] # Tuple of (pattern_str, compiled_pattern)
16+
ColumnsDef = Union[List, Dict, List[Union[str, RegexColumnDef]]]
1417
DataFrameType = Union[pd.DataFrame, pl.DataFrame]
1518

1619

20+
def _is_regex_pattern(column: Any) -> bool:
21+
"""Check if the column definition is a regex pattern tuple."""
22+
return (
23+
isinstance(column, tuple) and len(column) == 2 and isinstance(column[0], str) and isinstance(column[1], Pattern)
24+
)
25+
26+
27+
def _match_column_with_regex(column_pattern: RegexColumnDef, df_columns: List[str]) -> List[str]:
28+
"""Find all column names that match the regex pattern."""
29+
_, pattern = column_pattern
30+
return [col for col in df_columns if pattern.match(col)]
31+
32+
33+
def _compile_regex_patterns(columns: List) -> List:
34+
"""Compile regex patterns in the column list."""
35+
result = []
36+
for col in columns:
37+
if isinstance(col, str) and col.startswith("r/") and col.endswith("/"):
38+
# Pattern is in the format "r/pattern/"
39+
pattern_str = col[2:-1] # Remove "r/" prefix and "/" suffix
40+
compiled_pattern = re.compile(pattern_str)
41+
result.append((col, compiled_pattern))
42+
else:
43+
result.append(col)
44+
return result
45+
46+
1747
def _check_columns(df: DataFrameType, columns: ColumnsDef, strict: bool) -> None:
1848
missing_columns = []
1949
dtype_mismatches = []
50+
matched_by_regex = set()
2051

52+
# Handle list of column names/patterns
2153
if isinstance(columns, list):
22-
for column in columns:
23-
if column not in df.columns:
24-
missing_columns.append(column)
25-
if isinstance(columns, dict):
54+
# First, compile any regex patterns
55+
processed_columns = _compile_regex_patterns(columns)
56+
57+
for column in processed_columns:
58+
if isinstance(column, str):
59+
# Direct column name match
60+
if column not in df.columns:
61+
missing_columns.append(column)
62+
elif _is_regex_pattern(column):
63+
# Regex pattern match
64+
matches = _match_column_with_regex(column, list(df.columns))
65+
if not matches:
66+
missing_columns.append(column[0]) # Add the original pattern string
67+
else:
68+
matched_by_regex.update(matches)
69+
70+
# Handle dictionary of column names/types
71+
elif isinstance(columns, dict):
2672
for column, dtype in columns.items():
2773
if column not in df.columns:
2874
missing_columns.append(column)
@@ -39,18 +85,26 @@ def _check_columns(df: DataFrameType, columns: ColumnsDef, strict: bool) -> None
3985
raise AssertionError(mismatches)
4086

4187
if strict:
42-
extra_columns = set(df.columns) - set(columns)
88+
if isinstance(columns, list):
89+
# For regex matches, we need to consider all matched columns
90+
explicit_columns = {col for col in columns if isinstance(col, str)}
91+
allowed_columns = explicit_columns.union(matched_by_regex)
92+
extra_columns = set(df.columns) - allowed_columns
93+
else:
94+
extra_columns = set(df.columns) - set(columns)
95+
4396
if extra_columns:
4497
raise AssertionError(f"DataFrame contained unexpected column(s): {', '.join(extra_columns)}")
4598

4699

47100
def df_out(columns: Optional[ColumnsDef] = None, strict: Optional[bool] = None) -> Callable:
48-
"""Decorate a function that returns a Pandas DataFrame.
101+
"""Decorate a function that returns a Pandas or Polars DataFrame.
49102
50103
Document the return value of a function. The return value will be validated in runtime.
51104
52105
Args:
53-
columns (ColumnsDef, optional): List or dict that describes expected columns of the DataFrame. Defaults to None.
106+
columns (ColumnsDef, optional): List or dict that describes expected columns of the DataFrame.
107+
List can contain regex patterns in format "r/pattern/" (e.g., "r/Col[0-9]+/"). Defaults to None.
54108
strict (bool, optional): If True, columns must match exactly with no extra columns.
55109
If None, uses the value from [tool.daffy] strict setting in pyproject.toml.
56110
@@ -91,13 +145,14 @@ def _get_parameter(func: Callable, name: Optional[str] = None, *args: str, **kwa
91145

92146

93147
def df_in(name: Optional[str] = None, columns: Optional[ColumnsDef] = None, strict: Optional[bool] = None) -> Callable:
94-
"""Decorate a function parameter that is a Pandas DataFrame.
148+
"""Decorate a function parameter that is a Pandas or Polars DataFrame.
95149
96-
Document the contents of an inpute parameter. The parameter will be validated in runtime.
150+
Document the contents of an input parameter. The parameter will be validated in runtime.
97151
98152
Args:
99153
name (Optional[str], optional): Name of the parameter that contains a DataFrame. Defaults to None.
100-
columns (ColumnsDef, optional): List or dict that describes expected columns of the DataFrame. Defaults to None.
154+
columns (ColumnsDef, optional): List or dict that describes expected columns of the DataFrame.
155+
List can contain regex patterns in format "r/pattern/" (e.g., "r/Col[0-9]+/"). Defaults to None.
101156
strict (bool, optional): If True, columns must match exactly with no extra columns.
102157
If None, uses the value from [tool.daffy] strict setting in pyproject.toml.
103158

docs/usage.md

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,31 @@ def filter_cars(car_df):
6363
return filtered_cars_df
6464
```
6565

66+
## Column Pattern Matching with Regex
67+
68+
You can use regex patterns to match column names that follow a specific pattern. This is useful when working with dynamic column names or when dealing with many similar columns.
69+
70+
Define a regex pattern by using the format `"r/pattern/"`:
71+
72+
```python
73+
@df_in(columns=["Brand", "r/Price_\d+/"])
74+
def process_data(df):
75+
# This will accept DataFrames with columns like "Brand", "Price_1", "Price_2", etc.
76+
...
77+
```
78+
79+
In this example:
80+
- The DataFrame must have a column named exactly "Brand"
81+
- The DataFrame must have at least one column matching the pattern "Price_\d+" (e.g., "Price_1", "Price_2", etc.)
82+
83+
If no columns match a regex pattern, an error is raised:
84+
85+
```
86+
AssertionError: Missing columns: ['r/Price_\d+/']. Got columns: ['Brand', 'Model']
87+
```
88+
89+
Regex patterns are also considered in strict mode. Any column matching a regex pattern is considered valid.
90+
6691
## Data Type Validation
6792

6893
If you want to also check the data types of each column, you can replace the column array:
@@ -83,6 +108,8 @@ This will not only check that the specified columns are found from the DataFrame
83108
AssertionError("Column Price has wrong dtype. Was int64, expected float64")
84109
```
85110

111+
> Note: Regex pattern matching is only available for column name lists, not for dictionaries specifying data types.
112+
86113
## Strict Mode
87114

88115
You can enable strict-mode for both `@df_in` and `@df_out`. This will raise an error if the DataFrame contains columns not defined in the annotation:

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "daffy"
3-
version = "0.9.4"
3+
version = "0.10.0"
44
description = "Function decorators for Pandas and Polars Dataframe column name and data type validation"
55
authors = [
66
{ name="Janne Sinivirta", email="[email protected]" },

tests/test_df_in.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,3 +208,50 @@ def test_fn(cars: DataFrameType, ext_cars: DataFrameType) -> int:
208208
return len(cars) + len(ext_cars)
209209

210210
test_fn(basic_df, ext_cars=extended_df)
211+
212+
213+
def test_regex_column_pattern(basic_pandas_df: pd.DataFrame) -> None:
214+
# Create a DataFrame with numbered price columns
215+
df = basic_pandas_df.copy()
216+
df["Price_1"] = df["Price"] * 1
217+
df["Price_2"] = df["Price"] * 2
218+
df["Price_3"] = df["Price"] * 3
219+
220+
@df_in(columns=["Brand", "r/Price_[0-9]/"])
221+
def test_fn(my_input: Any) -> Any:
222+
return my_input
223+
224+
# This should pass since we have Price_1, Price_2, and Price_3 columns
225+
result = test_fn(df)
226+
assert "Price_1" in result.columns
227+
assert "Price_2" in result.columns
228+
assert "Price_3" in result.columns
229+
230+
231+
def test_regex_column_pattern_missing(basic_pandas_df: pd.DataFrame) -> None:
232+
@df_in(columns=["Brand", "r/NonExistent_[0-9]/"])
233+
def test_fn(my_input: Any) -> Any:
234+
return my_input
235+
236+
# This should fail since we don't have any columns matching the pattern
237+
with pytest.raises(AssertionError) as excinfo:
238+
test_fn(basic_pandas_df)
239+
240+
assert "Missing columns: ['r/NonExistent_[0-9]/']" in str(excinfo.value)
241+
242+
243+
def test_regex_column_pattern_with_strict(basic_pandas_df: pd.DataFrame) -> None:
244+
# Create a DataFrame with numbered price columns
245+
df = basic_pandas_df.copy()
246+
df["Price_1"] = df["Price"] * 1
247+
df["Price_2"] = df["Price"] * 2
248+
249+
@df_in(columns=["Brand", "r/Price_[0-9]/"], strict=True)
250+
def test_fn(my_input: Any) -> Any:
251+
return my_input
252+
253+
# This should fail in strict mode: "Price" is unexpected, even though "Price_1" and "Price_2" match the regex
254+
with pytest.raises(AssertionError) as excinfo:
255+
test_fn(df)
256+
257+
assert "DataFrame contained unexpected column(s): Price" in str(excinfo.value)

tests/test_df_out.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,3 +87,44 @@ def test_fn(my_input: Any) -> Any:
8787

8888
assert list(basic_pandas_df.columns) == ["Brand", "Price"] # For sanity
8989
pd.testing.assert_frame_equal(extended_pandas_df, test_fn(basic_pandas_df.copy()))
90+
91+
92+
def test_regex_column_pattern_in_output(basic_pandas_df: pd.DataFrame) -> None:
93+
# Create a function that adds numbered price columns
94+
@df_out(columns=["Brand", "r/Price_[0-9]/"])
95+
def test_fn() -> pd.DataFrame:
96+
df = basic_pandas_df.copy()
97+
df["Price_1"] = df["Price"] * 1
98+
df["Price_2"] = df["Price"] * 2
99+
return df
100+
101+
# This should pass since the output has Brand and Price_1, Price_2 columns
102+
result = test_fn()
103+
assert "Price_1" in result.columns
104+
assert "Price_2" in result.columns
105+
106+
107+
def test_regex_column_pattern_missing_in_output(basic_pandas_df: pd.DataFrame) -> None:
108+
@df_out(columns=["Brand", "r/NonExistent_[0-9]/"])
109+
def test_fn() -> pd.DataFrame:
110+
return basic_pandas_df.copy()
111+
112+
# This should fail since the output doesn't have columns matching the pattern
113+
with pytest.raises(AssertionError) as excinfo:
114+
test_fn()
115+
116+
assert "Missing columns: ['r/NonExistent_[0-9]/']" in str(excinfo.value)
117+
118+
119+
def test_regex_column_pattern_with_strict_in_output(basic_pandas_df: pd.DataFrame) -> None:
120+
@df_out(columns=["Brand", "r/Price_[0-9]/"], strict=True)
121+
def test_fn() -> pd.DataFrame:
122+
df = basic_pandas_df.copy()
123+
df["Price_1"] = df["Price"] * 1
124+
return df
125+
126+
# This should raise an error because Price is unexpected
127+
with pytest.raises(AssertionError) as excinfo:
128+
test_fn()
129+
130+
assert "DataFrame contained unexpected column(s): Price" in str(excinfo.value)

0 commit comments

Comments
 (0)