Skip to content

Commit 37d2dac

Browse files
authored
Merge pull request #25 from vertti/parameter-names-logging
Parameter names logging
2 parents 39d3232 + 6b734cc commit 37d2dac

File tree

4 files changed

+65
-13
lines changed

4 files changed

+65
-13
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22

33
All notable changes to this project will be documented in this file.
44

5+
## 0.14.0
6+
7+
- Improve df_in error messages to include parameter names
8+
59
## 0.13.2
610

711
- Updated urls for Pypi site compatibility

daffy/decorators.py

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,9 @@ def _compile_regex_patterns(columns: Seq[Any]) -> List[Union[str, RegexColumnDef
5454
return result
5555

5656

57-
def _check_columns(df: DataFrameType, columns: Union[ColumnsList, ColumnsDict], strict: bool) -> None:
57+
def _check_columns(
58+
df: DataFrameType, columns: Union[ColumnsList, ColumnsDict], strict: bool, param_name: Optional[str] = None
59+
) -> None:
5860
missing_columns = []
5961
dtype_mismatches = []
6062
matched_by_regex = set()
@@ -113,11 +115,16 @@ def _check_columns(df: DataFrameType, columns: Union[ColumnsList, ColumnsDict],
113115
dtype_mismatches.append((matched_col, df[matched_col].dtype, dtype))
114116

115117
if missing_columns:
116-
raise AssertionError(f"Missing columns: {missing_columns}. Got {_describe_pd(df)}")
118+
param_info = f" in parameter '{param_name}'" if param_name else ""
119+
raise AssertionError(f"Missing columns: {missing_columns}{param_info}. Got {_describe_pd(df)}")
117120

118121
if dtype_mismatches:
122+
param_info = f" in parameter '{param_name}'" if param_name else ""
119123
mismatches = ", ".join(
120-
[f"Column {col} has wrong dtype. Was {was}, expected {expected}" for col, was, expected in dtype_mismatches]
124+
[
125+
f"Column {col}{param_info} has wrong dtype. Was {was}, expected {expected}"
126+
for col, was, expected in dtype_mismatches
127+
]
121128
)
122129
raise AssertionError(mismatches)
123130

@@ -134,7 +141,8 @@ def _check_columns(df: DataFrameType, columns: Union[ColumnsList, ColumnsDict],
134141
extra_columns = set(df.columns) - allowed_columns
135142

136143
if extra_columns:
137-
raise AssertionError(f"DataFrame contained unexpected column(s): {', '.join(extra_columns)}")
144+
param_info = f" in parameter '{param_name}'" if param_name else ""
145+
raise AssertionError(f"DataFrame{param_info} contained unexpected column(s): {', '.join(extra_columns)}")
138146

139147

140148
def df_out(
@@ -189,6 +197,26 @@ def _get_parameter(func: Callable[..., Any], name: Optional[str] = None, *args:
189197
return kwargs[name]
190198

191199

200+
def _get_parameter_name(
201+
func: Callable[..., Any], name: Optional[str] = None, *args: Any, **kwargs: Any
202+
) -> Optional[str]:
203+
"""Get the actual parameter name being validated."""
204+
if name:
205+
return name
206+
207+
# If no name specified, try to get the first parameter name
208+
if len(args) > 0:
209+
# Get the first parameter name from the function signature
210+
func_params_in_order = list(inspect.signature(func).parameters.keys())
211+
if func_params_in_order:
212+
return func_params_in_order[0]
213+
elif kwargs:
214+
# Return the first keyword argument name
215+
return next(iter(kwargs.keys()))
216+
217+
return None
218+
219+
192220
def df_in(
193221
name: Optional[str] = None, columns: Union[ColumnsList, ColumnsDict, None] = None, strict: Optional[bool] = None
194222
) -> Callable[[Callable[..., R]], Callable[..., R]]:
@@ -214,11 +242,12 @@ def wrapper_df_in(func: Callable[..., R]) -> Callable[..., R]:
214242
@wraps(func)
215243
def wrapper(*args: Any, **kwargs: Any) -> R:
216244
df = _get_parameter(func, name, *args, **kwargs)
245+
param_name = _get_parameter_name(func, name, *args, **kwargs)
217246
assert isinstance(df, pd.DataFrame) or isinstance(df, pl.DataFrame), (
218247
f"Wrong parameter type. Expected DataFrame, got {type(df).__name__} instead."
219248
)
220249
if columns:
221-
_check_columns(df, columns, get_strict(strict))
250+
_check_columns(df, columns, get_strict(strict), param_name)
222251
return func(*args, **kwargs)
223252

224253
return wrapper

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "daffy"
3-
version = "0.13.2"
3+
version = "0.14.0"
44
description = "Function decorators for Pandas and Polars Dataframe column name and data type validation"
55
authors = [
66
{ name="Janne Sinivirta", email="[email protected]" },

tests/test_df_in.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def test_fn(my_input: Any, _df: DataFrameType) -> DataFrameType:
104104
with pytest.raises(AssertionError) as excinfo:
105105
test_fn("foo", _df=df)
106106

107-
assert "DataFrame contained unexpected column(s): Price" in str(excinfo.value)
107+
assert "DataFrame in parameter '_df' contained unexpected column(s): Price" in str(excinfo.value)
108108

109109

110110
def test_correct_input_with_columns_and_dtypes_pandas(basic_pandas_df: pd.DataFrame) -> None:
@@ -131,7 +131,7 @@ def test_fn(my_input: Any) -> Any:
131131
with pytest.raises(AssertionError) as excinfo:
132132
test_fn(basic_pandas_df)
133133

134-
assert "Column Price has wrong dtype. Was int64, expected float64" in str(excinfo.value)
134+
assert "Column Price in parameter 'my_input' has wrong dtype. Was int64, expected float64" in str(excinfo.value)
135135

136136

137137
def test_dtype_mismatch_polars(basic_polars_df: pl.DataFrame) -> None:
@@ -142,7 +142,7 @@ def test_fn(my_input: Any) -> Any:
142142
with pytest.raises(AssertionError) as excinfo:
143143
test_fn(basic_polars_df)
144144

145-
assert "Column Price has wrong dtype. Was Int64, expected Float64" in str(excinfo.value)
145+
assert "Column Price in parameter 'my_input' has wrong dtype. Was Int64, expected Float64" in str(excinfo.value)
146146

147147

148148
@pytest.mark.parametrize(("df"), [pd.DataFrame(cars), pl.DataFrame(cars)])
@@ -153,7 +153,7 @@ def test_fn(my_input: Any) -> Any:
153153

154154
with pytest.raises(AssertionError) as excinfo:
155155
test_fn(df[["Brand"]])
156-
assert "Missing columns: ['Price']. Got columns: ['Brand']" in str(excinfo.value)
156+
assert "Missing columns: ['Price'] in parameter 'my_input'. Got columns: ['Brand']" in str(excinfo.value)
157157

158158

159159
@pytest.mark.parametrize(("df"), [pd.DataFrame(cars), pl.DataFrame(cars)])
@@ -164,7 +164,7 @@ def test_fn(my_input: Any) -> Any:
164164

165165
with pytest.raises(AssertionError) as excinfo:
166166
test_fn(df[["Brand"]])
167-
assert "Missing columns: ['Price', 'Extra']. Got columns: ['Brand']" in str(excinfo.value)
167+
assert "Missing columns: ['Price', 'Extra'] in parameter 'my_input'. Got columns: ['Brand']" in str(excinfo.value)
168168

169169

170170
@pytest.mark.parametrize(
@@ -254,7 +254,7 @@ def test_fn(my_input: Any) -> Any:
254254
with pytest.raises(AssertionError) as excinfo:
255255
test_fn(df)
256256

257-
assert "DataFrame contained unexpected column(s): Price" in str(excinfo.value)
257+
assert "DataFrame in parameter 'my_input' contained unexpected column(s): Price" in str(excinfo.value)
258258

259259

260260
def test_regex_column_with_dtype_pandas(basic_pandas_df: pd.DataFrame) -> None:
@@ -287,7 +287,7 @@ def test_fn(my_input: Any) -> Any:
287287
with pytest.raises(AssertionError) as excinfo:
288288
test_fn(df)
289289

290-
assert "Column Price_2 has wrong dtype. Was float64, expected int64" in str(excinfo.value)
290+
assert "Column Price_2 in parameter 'my_input' has wrong dtype. Was float64, expected int64" in str(excinfo.value)
291291

292292

293293
def test_regex_column_with_dtype_polars(basic_polars_df: pl.DataFrame) -> None:
@@ -303,3 +303,22 @@ def test_fn(my_input: Any) -> Any:
303303
result = test_fn(df)
304304
assert "Price_1" in result.columns
305305
assert "Price_2" in result.columns
306+
307+
308+
@pytest.mark.parametrize(
309+
("basic_df,extended_df"),
310+
[(pd.DataFrame(cars), pd.DataFrame(extended_cars)), (pl.DataFrame(cars), pl.DataFrame(extended_cars))],
311+
)
312+
def test_multiple_parameters_error_identification(basic_df: DataFrameType, extended_df: DataFrameType) -> None:
313+
"""Test that we can identify which parameter has the issue when multiple dataframes are used."""
314+
315+
@df_in(name="cars", columns=["Brand", "Price"], strict=True)
316+
@df_in(name="ext_cars", columns=["Brand", "Price", "Year", "NonExistent"], strict=True)
317+
def test_fn(cars: DataFrameType, ext_cars: DataFrameType) -> int:
318+
return len(cars) + len(ext_cars)
319+
320+
# Test missing column in second parameter
321+
with pytest.raises(AssertionError) as excinfo:
322+
test_fn(cars=basic_df, ext_cars=extended_df)
323+
324+
assert "Missing columns: ['NonExistent'] in parameter 'ext_cars'" in str(excinfo.value)

0 commit comments

Comments
 (0)