Skip to content

Commit 17128ee

Browse files
authored
Merge pull request #26 from vertti/improve-coverage
Improve code coverage
2 parents 37d2dac + 026cbe4 commit 17128ee

File tree

3 files changed

+220
-126
lines changed

3 files changed

+220
-126
lines changed

daffy/decorators.py

Lines changed: 96 additions & 126 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import logging
55
import re
66
from functools import wraps
7-
from typing import Any, Callable, Dict, List, Optional, Pattern, Tuple, TypeVar, Union
7+
from typing import Any, Callable, Dict, List, Optional, Pattern, Set, Tuple, TypeVar, Union
88
from typing import Sequence as Seq # Renamed to avoid collision
99

1010
import pandas as pd
@@ -35,113 +35,104 @@ def _is_regex_pattern(column: Any) -> bool:
3535
)
3636

3737

38+
def _assert_is_dataframe(obj: Any, context: str) -> None:
39+
if not isinstance(obj, (pd.DataFrame, pl.DataFrame)):
40+
raise AssertionError(f"Wrong {context}. Expected DataFrame, got {type(obj).__name__} instead.")
41+
42+
43+
def _make_param_info(param_name: Optional[str]) -> str:
44+
return f" in parameter '{param_name}'" if param_name else ""
45+
46+
47+
def _validate_column(
48+
column_spec: Union[str, RegexColumnDef], df: DataFrameType, expected_dtype: Any = None
49+
) -> Tuple[List[str], List[Tuple[str, Any, Any]], Set[str]]:
50+
"""Validate a single column specification against a DataFrame."""
51+
missing_columns = []
52+
dtype_mismatches = []
53+
matched_by_regex = set()
54+
55+
if isinstance(column_spec, str):
56+
if column_spec not in df.columns:
57+
missing_columns.append(column_spec)
58+
elif expected_dtype is not None and df[column_spec].dtype != expected_dtype:
59+
dtype_mismatches.append((column_spec, df[column_spec].dtype, expected_dtype))
60+
elif _is_regex_pattern(column_spec):
61+
pattern_str, _ = column_spec
62+
matches = _match_column_with_regex(column_spec, list(df.columns))
63+
if not matches:
64+
missing_columns.append(pattern_str)
65+
else:
66+
matched_by_regex.update(matches)
67+
if expected_dtype is not None:
68+
for matched_col in matches:
69+
if df[matched_col].dtype != expected_dtype:
70+
dtype_mismatches.append((matched_col, df[matched_col].dtype, expected_dtype))
71+
72+
return missing_columns, dtype_mismatches, matched_by_regex
73+
74+
3875
def _match_column_with_regex(column_pattern: RegexColumnDef, df_columns: List[str]) -> List[str]:
3976
_, pattern = column_pattern
4077
return [col for col in df_columns if pattern.match(col)]
4178

4279

80+
def _compile_regex_pattern(pattern_string: str) -> RegexColumnDef:
81+
pattern_str = pattern_string[2:-1]
82+
compiled_pattern = re.compile(pattern_str)
83+
return (pattern_string, compiled_pattern)
84+
85+
86+
def _is_regex_string(column: str) -> bool:
87+
return column.startswith("r/") and column.endswith("/")
88+
89+
4390
def _compile_regex_patterns(columns: Seq[Any]) -> List[Union[str, RegexColumnDef]]:
44-
"""Compile regex patterns in the column list."""
45-
result: List[Union[str, RegexColumnDef]] = []
46-
for col in columns:
47-
if isinstance(col, str) and col.startswith("r/") and col.endswith("/"):
48-
# Pattern is in the format "r/pattern/"
49-
pattern_str = col[2:-1] # Remove "r/" prefix and "/" suffix
50-
compiled_pattern = re.compile(pattern_str)
51-
result.append((col, compiled_pattern))
52-
else:
53-
result.append(col)
54-
return result
91+
return [_compile_regex_pattern(col) if isinstance(col, str) and _is_regex_string(col) else col for col in columns]
5592

5693

5794
def _check_columns(
5895
df: DataFrameType, columns: Union[ColumnsList, ColumnsDict], strict: bool, param_name: Optional[str] = None
5996
) -> None:
60-
missing_columns = []
61-
dtype_mismatches = []
62-
matched_by_regex = set()
97+
all_missing_columns = []
98+
all_dtype_mismatches = []
99+
all_matched_by_regex = set()
63100

64-
# Handle list of column names/patterns
65101
if isinstance(columns, list):
66-
# First, compile any regex patterns
67102
processed_columns = _compile_regex_patterns(columns)
103+
for column_spec in processed_columns:
104+
missing, mismatches, matched = _validate_column(column_spec, df)
105+
all_missing_columns.extend(missing)
106+
all_dtype_mismatches.extend(mismatches)
107+
all_matched_by_regex.update(matched)
108+
else: # isinstance(columns, dict)
109+
assert isinstance(columns, dict)
110+
for column, expected_dtype in columns.items():
111+
column_spec = (
112+
_compile_regex_pattern(column) if isinstance(column, str) and _is_regex_string(column) else column
113+
)
114+
missing, mismatches, matched = _validate_column(column_spec, df, expected_dtype)
115+
all_missing_columns.extend(missing)
116+
all_dtype_mismatches.extend(mismatches)
117+
all_matched_by_regex.update(matched)
118+
119+
param_info = _make_param_info(param_name)
120+
121+
if all_missing_columns:
122+
raise AssertionError(f"Missing columns: {all_missing_columns}{param_info}. Got {_describe_pd(df)}")
68123

69-
for column in processed_columns:
70-
if isinstance(column, str):
71-
# Direct column name match
72-
if column not in df.columns:
73-
missing_columns.append(column)
74-
elif _is_regex_pattern(column):
75-
# Regex pattern match
76-
matches = _match_column_with_regex(column, list(df.columns))
77-
if not matches:
78-
missing_columns.append(column[0]) # Add the original pattern string
79-
else:
80-
matched_by_regex.update(matches)
81-
82-
# Handle dictionary of column names/types
83-
elif isinstance(columns, dict):
84-
# First, process dictionary keys for regex patterns
85-
processed_dict: Dict[Union[str, RegexColumnDef], Any] = {}
86-
for column, dtype in columns.items():
87-
if isinstance(column, str) and column.startswith("r/") and column.endswith("/"):
88-
# Pattern is in the format "r/pattern/"
89-
pattern_str = column[2:-1] # Remove "r/" prefix and "/" suffix
90-
compiled_pattern = re.compile(pattern_str)
91-
processed_dict[(column, compiled_pattern)] = dtype
92-
else:
93-
processed_dict[column] = dtype
94-
95-
# Check each column against dictionary keys
96-
regex_matched_columns = set()
97-
for column_key, dtype in processed_dict.items():
98-
if isinstance(column_key, str):
99-
# Direct column name match
100-
if column_key not in df.columns:
101-
missing_columns.append(column_key)
102-
elif df[column_key].dtype != dtype:
103-
dtype_mismatches.append((column_key, df[column_key].dtype, dtype))
104-
elif _is_regex_pattern(column_key):
105-
# Regex pattern match
106-
pattern_str, compiled_pattern = column_key
107-
matches = _match_column_with_regex(column_key, list(df.columns))
108-
if not matches:
109-
missing_columns.append(pattern_str) # Add the original pattern string
110-
else:
111-
for matched_col in matches:
112-
matched_by_regex.add(matched_col)
113-
regex_matched_columns.add(matched_col)
114-
if df[matched_col].dtype != dtype:
115-
dtype_mismatches.append((matched_col, df[matched_col].dtype, dtype))
116-
117-
if missing_columns:
118-
param_info = f" in parameter '{param_name}'" if param_name else ""
119-
raise AssertionError(f"Missing columns: {missing_columns}{param_info}. Got {_describe_pd(df)}")
120-
121-
if dtype_mismatches:
122-
param_info = f" in parameter '{param_name}'" if param_name else ""
123-
mismatches = ", ".join(
124-
[
125-
f"Column {col}{param_info} has wrong dtype. Was {was}, expected {expected}"
126-
for col, was, expected in dtype_mismatches
127-
]
124+
if all_dtype_mismatches:
125+
mismatch_descriptions = ", ".join(
126+
f"Column {col}{param_info} has wrong dtype. Was {was}, expected {expected}"
127+
for col, was, expected in all_dtype_mismatches
128128
)
129-
raise AssertionError(mismatches)
129+
raise AssertionError(mismatch_descriptions)
130130

131131
if strict:
132-
if isinstance(columns, list):
133-
# For regex matches, we need to consider all matched columns
134-
explicit_columns = {col for col in columns if isinstance(col, str)}
135-
allowed_columns = explicit_columns.union(matched_by_regex)
136-
extra_columns = set(df.columns) - allowed_columns
137-
else:
138-
# For dict with regex patterns, we need to handle both direct and regex matches
139-
explicit_columns = {col for col in columns if isinstance(col, str)}
140-
allowed_columns = explicit_columns.union(matched_by_regex)
141-
extra_columns = set(df.columns) - allowed_columns
142-
132+
explicit_columns = {col for col in columns if isinstance(col, str)}
133+
allowed_columns = explicit_columns.union(all_matched_by_regex)
134+
extra_columns = set(df.columns) - allowed_columns
143135
if extra_columns:
144-
param_info = f" in parameter '{param_name}'" if param_name else ""
145136
raise AssertionError(f"DataFrame{param_info} contained unexpected column(s): {', '.join(extra_columns)}")
146137

147138

@@ -169,9 +160,7 @@ def wrapper_df_out(func: Callable[..., DF]) -> Callable[..., DF]:
169160
@wraps(func)
170161
def wrapper(*args: Any, **kwargs: Any) -> DF:
171162
result = func(*args, **kwargs)
172-
assert isinstance(result, pd.DataFrame) or isinstance(result, pl.DataFrame), (
173-
f"Wrong return type. Expected DataFrame, got {type(result)}"
174-
)
163+
_assert_is_dataframe(result, "return type")
175164
if columns:
176165
_check_columns(result, columns, get_strict(strict))
177166
return result
@@ -183,38 +172,27 @@ def wrapper(*args: Any, **kwargs: Any) -> DF:
183172

184173
def _get_parameter(func: Callable[..., Any], name: Optional[str] = None, *args: Any, **kwargs: Any) -> Any:
185174
if not name:
186-
if len(args) > 0:
187-
return args[0]
188-
if kwargs:
189-
return next(iter(kwargs.values()))
190-
return None
175+
return args[0] if args else next(iter(kwargs.values()), None)
191176

192-
if name and (name not in kwargs):
193-
func_params_in_order = list(inspect.signature(func).parameters.keys())
194-
parameter_location = func_params_in_order.index(name)
195-
return args[parameter_location]
177+
if name in kwargs:
178+
return kwargs[name]
196179

197-
return kwargs[name]
180+
func_params_in_order = list(inspect.signature(func).parameters.keys())
181+
parameter_location = func_params_in_order.index(name)
182+
return args[parameter_location]
198183

199184

200185
def _get_parameter_name(
201186
func: Callable[..., Any], name: Optional[str] = None, *args: Any, **kwargs: Any
202187
) -> Optional[str]:
203-
"""Get the actual parameter name being validated."""
204188
if name:
205189
return name
206190

207-
# If no name specified, try to get the first parameter name
208-
if len(args) > 0:
209-
# Get the first parameter name from the function signature
191+
if args:
210192
func_params_in_order = list(inspect.signature(func).parameters.keys())
211-
if func_params_in_order:
212-
return func_params_in_order[0]
213-
elif kwargs:
214-
# Return the first keyword argument name
215-
return next(iter(kwargs.keys()))
193+
return func_params_in_order[0]
216194

217-
return None
195+
return next(iter(kwargs.keys()), None)
218196

219197

220198
def df_in(
@@ -243,9 +221,7 @@ def wrapper_df_in(func: Callable[..., R]) -> Callable[..., R]:
243221
def wrapper(*args: Any, **kwargs: Any) -> R:
244222
df = _get_parameter(func, name, *args, **kwargs)
245223
param_name = _get_parameter_name(func, name, *args, **kwargs)
246-
assert isinstance(df, pd.DataFrame) or isinstance(df, pl.DataFrame), (
247-
f"Wrong parameter type. Expected DataFrame, got {type(df).__name__} instead."
248-
)
224+
_assert_is_dataframe(df, "parameter type")
249225
if columns:
250226
_check_columns(df, columns, get_strict(strict), param_name)
251227
return func(*args, **kwargs)
@@ -261,25 +237,19 @@ def _describe_pd(df: DataFrameType, include_dtypes: bool = False) -> str:
261237
if isinstance(df, pd.DataFrame):
262238
readable_dtypes = [dtype.name for dtype in df.dtypes]
263239
result += f" with dtypes {readable_dtypes}"
264-
if isinstance(df, pl.DataFrame):
240+
else:
265241
result += f" with dtypes {df.dtypes}"
266242
return result
267243

268244

269245
def _log_input(level: int, func_name: str, df: Any, include_dtypes: bool) -> None:
270-
if isinstance(df, pd.DataFrame) or isinstance(df, pl.DataFrame):
271-
logging.log(
272-
level,
273-
f"Function {func_name} parameters contained a DataFrame: {_describe_pd(df, include_dtypes)}",
274-
)
246+
if isinstance(df, (pd.DataFrame, pl.DataFrame)):
247+
logging.log(level, f"Function {func_name} parameters contained a DataFrame: {_describe_pd(df, include_dtypes)}")
275248

276249

277250
def _log_output(level: int, func_name: str, df: Any, include_dtypes: bool) -> None:
278-
if isinstance(df, pd.DataFrame) or isinstance(df, pl.DataFrame):
279-
logging.log(
280-
level,
281-
f"Function {func_name} returned a DataFrame: {_describe_pd(df, include_dtypes)}",
282-
)
251+
if isinstance(df, (pd.DataFrame, pl.DataFrame)):
252+
logging.log(level, f"Function {func_name} returned a DataFrame: {_describe_pd(df, include_dtypes)}")
283253

284254

285255
def df_log(level: int = logging.DEBUG, include_dtypes: bool = False) -> Callable[[Callable[..., T]], Callable[..., T]]:

tests/test_config.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,3 +51,66 @@ def test_config_from_pyproject() -> None:
5151

5252
config = load_config()
5353
assert config["strict"] is True
54+
55+
56+
def test_load_config_returns_default_when_file_not_found() -> None:
57+
with patch("daffy.config.find_config_file", return_value="/nonexistent/pyproject.toml"):
58+
config = get_config()
59+
assert config["strict"] is False
60+
61+
62+
def test_load_config_returns_default_when_toml_malformed() -> None:
63+
with tempfile.TemporaryDirectory() as tmpdir:
64+
with open(os.path.join(tmpdir, "pyproject.toml"), "w") as f:
65+
f.write("invalid toml [[[")
66+
67+
with patch("daffy.config.os.getcwd", return_value=tmpdir):
68+
import daffy.config
69+
70+
daffy.config._config_cache = None
71+
72+
config = get_config()
73+
assert config["strict"] is False
74+
75+
76+
def test_find_config_file_returns_none_when_no_pyproject_exists() -> None:
77+
with tempfile.TemporaryDirectory() as tmpdir:
78+
with patch("daffy.config.os.getcwd", return_value=tmpdir):
79+
from daffy.config import find_config_file
80+
81+
result = find_config_file()
82+
assert result is None
83+
84+
85+
def test_load_config_without_strict_setting() -> None:
86+
with tempfile.TemporaryDirectory() as tmpdir:
87+
with open(os.path.join(tmpdir, "pyproject.toml"), "w") as f:
88+
f.write("""
89+
[tool.daffy]
90+
other_setting = "value"
91+
""")
92+
93+
with patch("daffy.config.os.getcwd", return_value=tmpdir):
94+
import daffy.config
95+
96+
daffy.config._config_cache = None
97+
98+
config = get_config()
99+
assert config["strict"] is False
100+
101+
102+
def test_load_config_daffy_section_without_strict() -> None:
103+
with tempfile.TemporaryDirectory() as tmpdir:
104+
with open(os.path.join(tmpdir, "pyproject.toml"), "w") as f:
105+
f.write("""
106+
[tool.daffy]
107+
some_other_setting = "value"
108+
""")
109+
110+
with patch("daffy.config.os.getcwd", return_value=tmpdir):
111+
import daffy.config
112+
113+
daffy.config._config_cache = None
114+
115+
config = get_config()
116+
assert config["strict"] is False

0 commit comments

Comments
 (0)