Skip to content

Commit 39db742

Browse files
authored
Merge pull request #21 from vertti/regex-with-dtypes
Regex with dtypes
2 parents 18f07c4 + c92fa04 commit 39db742

File tree

6 files changed

+175
-10
lines changed

6 files changed

+175
-10
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22

33
All notable changes to this project will be documented in this file.
44

5+
## 0.12.0
6+
7+
- Add support for regex patterns used with column dtype validation
8+
59
## 0.11.0
610

711
- Update function parameter types for better type safety

daffy/decorators.py

Lines changed: 42 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
# Improved type definitions to support regex patterns
2424
RegexColumnDef = Tuple[str, Pattern[str]] # Tuple of (pattern_str, compiled_pattern)
25-
ColumnsDef = Union[List[Union[str, RegexColumnDef]], Dict[str, Any]]
25+
ColumnsDef = Union[List[Union[str, RegexColumnDef]], Dict[Union[str, RegexColumnDef], Any]]
2626
DataFrameType = Union[PandasDataFrame, PolarsDataFrame]
2727

2828

@@ -78,11 +78,38 @@ def _check_columns(df: DataFrameType, columns: ColumnsDef, strict: bool) -> None
7878

7979
# Handle dictionary of column names/types
8080
elif isinstance(columns, dict):
81+
# First, process dictionary keys for regex patterns
82+
processed_dict: Dict[Union[str, RegexColumnDef], Any] = {}
8183
for column, dtype in columns.items():
82-
if column not in df.columns:
83-
missing_columns.append(column)
84-
elif df[column].dtype != dtype:
85-
dtype_mismatches.append((column, df[column].dtype, dtype))
84+
if isinstance(column, str) and column.startswith("r/") and column.endswith("/"):
85+
# Pattern is in the format "r/pattern/"
86+
pattern_str = column[2:-1] # Remove "r/" prefix and "/" suffix
87+
compiled_pattern = re.compile(pattern_str)
88+
processed_dict[(column, compiled_pattern)] = dtype
89+
else:
90+
processed_dict[column] = dtype
91+
92+
# Check each column against dictionary keys
93+
regex_matched_columns = set()
94+
for column_key, dtype in processed_dict.items():
95+
if isinstance(column_key, str):
96+
# Direct column name match
97+
if column_key not in df.columns:
98+
missing_columns.append(column_key)
99+
elif df[column_key].dtype != dtype:
100+
dtype_mismatches.append((column_key, df[column_key].dtype, dtype))
101+
elif _is_regex_pattern(column_key):
102+
# Regex pattern match
103+
pattern_str, compiled_pattern = column_key
104+
matches = _match_column_with_regex(column_key, list(df.columns))
105+
if not matches:
106+
missing_columns.append(pattern_str) # Add the original pattern string
107+
else:
108+
for matched_col in matches:
109+
matched_by_regex.add(matched_col)
110+
regex_matched_columns.add(matched_col)
111+
if df[matched_col].dtype != dtype:
112+
dtype_mismatches.append((matched_col, df[matched_col].dtype, dtype))
86113

87114
if missing_columns:
88115
raise AssertionError(f"Missing columns: {missing_columns}. Got {_describe_pd(df)}")
@@ -100,7 +127,10 @@ def _check_columns(df: DataFrameType, columns: ColumnsDef, strict: bool) -> None
100127
allowed_columns = explicit_columns.union(matched_by_regex)
101128
extra_columns = set(df.columns) - allowed_columns
102129
else:
103-
extra_columns = set(df.columns) - set(columns)
130+
# For dict with regex patterns, we need to handle both direct and regex matches
131+
explicit_columns = {col for col in columns if isinstance(col, str)}
132+
allowed_columns = explicit_columns.union(matched_by_regex)
133+
extra_columns = set(df.columns) - allowed_columns
104134

105135
if extra_columns:
106136
raise AssertionError(f"DataFrame contained unexpected column(s): {', '.join(extra_columns)}")
@@ -115,7 +145,9 @@ def df_out(
115145
116146
Args:
117147
columns (ColumnsDef, optional): List or dict that describes expected columns of the DataFrame.
118-
List can contain regex patterns in format "r/pattern/" (e.g., "r/Col[0-9]+/"). Defaults to None.
148+
List can contain regex patterns in format "r/pattern/" (e.g., "r/Col[0-9]+/").
149+
Dict can use regex patterns as keys in format "r/pattern/" to validate dtypes for matching columns.
150+
Defaults to None.
119151
strict (bool, optional): If True, columns must match exactly with no extra columns.
120152
If None, uses the value from [tool.daffy] strict setting in pyproject.toml.
121153
@@ -165,7 +197,9 @@ def df_in(
165197
Args:
166198
name (Optional[str], optional): Name of the parameter that contains a DataFrame. Defaults to None.
167199
columns (ColumnsDef, optional): List or dict that describes expected columns of the DataFrame.
168-
List can contain regex patterns in format "r/pattern/" (e.g., "r/Col[0-9]+/"). Defaults to None.
200+
List can contain regex patterns in format "r/pattern/" (e.g., "r/Col[0-9]+/").
201+
Dict can use regex patterns as keys in format "r/pattern/" to validate dtypes for matching columns.
202+
Defaults to None.
169203
strict (bool, optional): If True, columns must match exactly with no extra columns.
170204
If None, uses the value from [tool.daffy] strict setting in pyproject.toml.
171205

docs/usage.md

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,26 @@ This will not only check that the specified columns are found from the DataFrame
108108
AssertionError("Column Price has wrong dtype. Was int64, expected float64")
109109
```
110110

111-
> Note: Regex pattern matching is only available for column name lists, not for dictionaries specifying data types.
111+
### Combining Regex Patterns with Data Type Validation
112+
113+
Regex patterns can also be used as dictionary keys when specifying data types:
114+
115+
```python
116+
@df_in(columns={"Brand": "object", "r/Price_\d+/": "int64"})
117+
def process_data(df):
118+
# This will check that all columns matching "Price_\d+" have int64 dtype
119+
...
120+
```
121+
122+
In this example:
123+
- The DataFrame must have a column named exactly "Brand" with dtype "object"
124+
- Any columns matching the pattern "Price_\d+" (e.g., "Price_1", "Price_2") must have dtype "int64"
125+
126+
If a column matches the regex pattern but has the wrong dtype, an error is raised:
127+
128+
```
129+
AssertionError: Column Price_2 has wrong dtype. Was float64, expected int64
130+
```
112131

113132
## Strict Mode
114133

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "daffy"
3-
version = "0.11.0"
3+
version = "0.12.0"
44
description = "Function decorators for Pandas and Polars Dataframe column name and data type validation"
55
authors = [
66
{ name="Janne Sinivirta", email="[email protected]" },

tests/test_df_in.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,3 +255,51 @@ def test_fn(my_input: Any) -> Any:
255255
test_fn(df)
256256

257257
assert "DataFrame contained unexpected column(s): Price" in str(excinfo.value)
258+
259+
260+
def test_regex_column_with_dtype_pandas(basic_pandas_df: pd.DataFrame) -> None:
261+
# Create a DataFrame with numbered price columns
262+
df = basic_pandas_df.copy()
263+
df["Price_1"] = df["Price"] * 1
264+
df["Price_2"] = df["Price"] * 2
265+
266+
@df_in(columns={"Brand": "object", "r/Price_[0-9]/": "int64"})
267+
def test_fn(my_input: Any) -> Any:
268+
return my_input
269+
270+
# This should pass since Price_1 and Price_2 are int64
271+
result = test_fn(df)
272+
assert "Price_1" in result.columns
273+
assert "Price_2" in result.columns
274+
275+
276+
def test_regex_column_with_dtype_mismatch_pandas(basic_pandas_df: pd.DataFrame) -> None:
277+
# Create a DataFrame with numbered price columns
278+
df = basic_pandas_df.copy()
279+
df["Price_1"] = df["Price"] * 1
280+
df["Price_2"] = df["Price"] * 2.0 # Make this a float
281+
282+
@df_in(columns={"Brand": "object", "r/Price_[0-9]/": "int64"})
283+
def test_fn(my_input: Any) -> Any:
284+
return my_input
285+
286+
# This should fail since Price_2 is float64, not int64
287+
with pytest.raises(AssertionError) as excinfo:
288+
test_fn(df)
289+
290+
assert "Column Price_2 has wrong dtype. Was float64, expected int64" in str(excinfo.value)
291+
292+
293+
def test_regex_column_with_dtype_polars(basic_polars_df: pl.DataFrame) -> None:
294+
# Create a DataFrame with numbered price columns
295+
# Polars DataFrames are immutable, so we don't need to copy
296+
df = basic_polars_df.with_columns([pl.col("Price").alias("Price_1"), pl.col("Price").alias("Price_2")])
297+
298+
@df_in(columns={"Brand": pl.datatypes.String, "r/Price_[0-9]/": pl.datatypes.Int64})
299+
def test_fn(my_input: Any) -> Any:
300+
return my_input
301+
302+
# This should pass since Price_1 and Price_2 are Int64
303+
result = test_fn(df)
304+
assert "Price_1" in result.columns
305+
assert "Price_2" in result.columns

tests/test_df_out.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,3 +131,63 @@ def test_fn() -> pd.DataFrame:
131131
test_fn()
132132

133133
assert "DataFrame contained unexpected column(s): Price" in str(excinfo.value)
134+
135+
136+
def test_regex_column_with_dtype_in_output_pandas(basic_pandas_df: pd.DataFrame) -> None:
137+
# Create a function that adds numbered price columns
138+
@df_out(columns={"Brand": "object", "r/Price_[0-9]/": "int64"})
139+
def test_fn() -> pd.DataFrame:
140+
df = basic_pandas_df.copy()
141+
df["Price_1"] = df["Price"] * 1
142+
df["Price_2"] = df["Price"] * 2
143+
return df
144+
145+
# This should pass since Price_1 and Price_2 are int64
146+
result = test_fn()
147+
assert "Price_1" in result.columns
148+
assert "Price_2" in result.columns
149+
150+
151+
def test_regex_column_with_dtype_mismatch_in_output_pandas(basic_pandas_df: pd.DataFrame) -> None:
152+
# Create a function that adds numbered price columns with one wrong dtype
153+
@df_out(columns={"Brand": "object", "r/Price_[0-9]/": "int64"})
154+
def test_fn() -> pd.DataFrame:
155+
df = basic_pandas_df.copy()
156+
df["Price_1"] = df["Price"] * 1
157+
df["Price_2"] = df["Price"] * 2.0 # Make this a float
158+
return df
159+
160+
# This should fail since Price_2 is float64, not int64
161+
with pytest.raises(AssertionError) as excinfo:
162+
test_fn()
163+
164+
assert "Column Price_2 has wrong dtype. Was float64, expected int64" in str(excinfo.value)
165+
166+
167+
def test_regex_column_with_dtype_in_output_polars(basic_polars_df: pl.DataFrame) -> None:
168+
# Create a function that adds numbered price columns
169+
@df_out(columns={"Brand": pl.datatypes.String, "r/Price_[0-9]/": pl.datatypes.Int64})
170+
def test_fn() -> pl.DataFrame:
171+
# Polars DataFrames are immutable, so we build a new one
172+
return basic_polars_df.with_columns([pl.col("Price").alias("Price_1"), pl.col("Price").alias("Price_2")])
173+
174+
# This should pass since Price_1 and Price_2 are Int64
175+
result = test_fn()
176+
assert "Price_1" in result.columns
177+
assert "Price_2" in result.columns
178+
179+
180+
def test_regex_column_with_dtype_strict_in_output_pandas(basic_pandas_df: pd.DataFrame) -> None:
181+
# Create a function that adds numbered price columns
182+
@df_out(columns={"Brand": "object", "r/Price_[0-9]/": "int64"}, strict=True)
183+
def test_fn() -> pd.DataFrame:
184+
df = basic_pandas_df.copy()
185+
df["Price_1"] = df["Price"] * 1
186+
df["Price_2"] = df["Price"] * 2
187+
return df
188+
189+
# This should fail because Price is unexpected
190+
with pytest.raises(AssertionError) as excinfo:
191+
test_fn()
192+
193+
assert "DataFrame contained unexpected column(s): Price" in str(excinfo.value)

0 commit comments

Comments
 (0)