Skip to content

Commit 026cbe4

Browse files
committed
Apply pythonic refactoring to improve code readability and maintainability
1 parent ecf8095 commit 026cbe4

File tree

2 files changed

+78
-98
lines changed

2 files changed

+78
-98
lines changed

daffy/decorators.py

Lines changed: 73 additions & 82 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@
44
import logging
55
import re
66
from functools import wraps
7-
from typing import Any, Callable, Dict, List, Optional, Pattern, Tuple, TypeVar, Union
7+
from typing import Any, Callable, Dict, List, Optional, Pattern, Set, Tuple, TypeVar, Union
88
from typing import Sequence as Seq # Renamed to avoid collision
99

1010
import pandas as pd
@@ -40,6 +40,38 @@ def _assert_is_dataframe(obj: Any, context: str) -> None:
4040
raise AssertionError(f"Wrong {context}. Expected DataFrame, got {type(obj).__name__} instead.")
4141

4242

43+
def _make_param_info(param_name: Optional[str]) -> str:
44+
return f" in parameter '{param_name}'" if param_name else ""
45+
46+
47+
def _validate_column(
48+
column_spec: Union[str, RegexColumnDef], df: DataFrameType, expected_dtype: Any = None
49+
) -> Tuple[List[str], List[Tuple[str, Any, Any]], Set[str]]:
50+
"""Validate a single column specification against a DataFrame."""
51+
missing_columns = []
52+
dtype_mismatches = []
53+
matched_by_regex = set()
54+
55+
if isinstance(column_spec, str):
56+
if column_spec not in df.columns:
57+
missing_columns.append(column_spec)
58+
elif expected_dtype is not None and df[column_spec].dtype != expected_dtype:
59+
dtype_mismatches.append((column_spec, df[column_spec].dtype, expected_dtype))
60+
elif _is_regex_pattern(column_spec):
61+
pattern_str, _ = column_spec
62+
matches = _match_column_with_regex(column_spec, list(df.columns))
63+
if not matches:
64+
missing_columns.append(pattern_str)
65+
else:
66+
matched_by_regex.update(matches)
67+
if expected_dtype is not None:
68+
for matched_col in matches:
69+
if df[matched_col].dtype != expected_dtype:
70+
dtype_mismatches.append((matched_col, df[matched_col].dtype, expected_dtype))
71+
72+
return missing_columns, dtype_mismatches, matched_by_regex
73+
74+
4375
def _match_column_with_regex(column_pattern: RegexColumnDef, df_columns: List[str]) -> List[str]:
4476
_, pattern = column_pattern
4577
return [col for col in df_columns if pattern.match(col)]
@@ -62,70 +94,45 @@ def _compile_regex_patterns(columns: Seq[Any]) -> List[Union[str, RegexColumnDef
6294
def _check_columns(
6395
df: DataFrameType, columns: Union[ColumnsList, ColumnsDict], strict: bool, param_name: Optional[str] = None
6496
) -> None:
65-
missing_columns = []
66-
dtype_mismatches = []
67-
matched_by_regex = set()
97+
all_missing_columns = []
98+
all_dtype_mismatches = []
99+
all_matched_by_regex = set()
68100

69101
if isinstance(columns, list):
70102
processed_columns = _compile_regex_patterns(columns)
71-
for column in processed_columns:
72-
if isinstance(column, str):
73-
if column not in df.columns:
74-
missing_columns.append(column)
75-
elif _is_regex_pattern(column):
76-
matches = _match_column_with_regex(column, list(df.columns))
77-
if not matches:
78-
missing_columns.append(column[0])
79-
else:
80-
matched_by_regex.update(matches)
81-
103+
for column_spec in processed_columns:
104+
missing, mismatches, matched = _validate_column(column_spec, df)
105+
all_missing_columns.extend(missing)
106+
all_dtype_mismatches.extend(mismatches)
107+
all_matched_by_regex.update(matched)
82108
else: # isinstance(columns, dict)
83109
assert isinstance(columns, dict)
84-
processed_dict: Dict[Union[str, RegexColumnDef], Any] = {}
85-
for column, dtype in columns.items():
86-
if isinstance(column, str) and _is_regex_string(column):
87-
processed_dict[_compile_regex_pattern(column)] = dtype
88-
else:
89-
processed_dict[column] = dtype
90-
91-
for column_key, dtype in processed_dict.items():
92-
if isinstance(column_key, str):
93-
if column_key not in df.columns:
94-
missing_columns.append(column_key)
95-
elif df[column_key].dtype != dtype:
96-
dtype_mismatches.append((column_key, df[column_key].dtype, dtype))
97-
elif _is_regex_pattern(column_key):
98-
pattern_str, compiled_pattern = column_key
99-
matches = _match_column_with_regex(column_key, list(df.columns))
100-
if not matches:
101-
missing_columns.append(pattern_str)
102-
else:
103-
for matched_col in matches:
104-
matched_by_regex.add(matched_col)
105-
if df[matched_col].dtype != dtype:
106-
dtype_mismatches.append((matched_col, df[matched_col].dtype, dtype))
107-
108-
if missing_columns:
109-
param_info = f" in parameter '{param_name}'" if param_name else ""
110-
raise AssertionError(f"Missing columns: {missing_columns}{param_info}. Got {_describe_pd(df)}")
111-
112-
if dtype_mismatches:
113-
param_info = f" in parameter '{param_name}'" if param_name else ""
114-
mismatches = ", ".join(
115-
[
116-
f"Column {col}{param_info} has wrong dtype. Was {was}, expected {expected}"
117-
for col, was, expected in dtype_mismatches
118-
]
110+
for column, expected_dtype in columns.items():
111+
column_spec = (
112+
_compile_regex_pattern(column) if isinstance(column, str) and _is_regex_string(column) else column
113+
)
114+
missing, mismatches, matched = _validate_column(column_spec, df, expected_dtype)
115+
all_missing_columns.extend(missing)
116+
all_dtype_mismatches.extend(mismatches)
117+
all_matched_by_regex.update(matched)
118+
119+
param_info = _make_param_info(param_name)
120+
121+
if all_missing_columns:
122+
raise AssertionError(f"Missing columns: {all_missing_columns}{param_info}. Got {_describe_pd(df)}")
123+
124+
if all_dtype_mismatches:
125+
mismatch_descriptions = ", ".join(
126+
f"Column {col}{param_info} has wrong dtype. Was {was}, expected {expected}"
127+
for col, was, expected in all_dtype_mismatches
119128
)
120-
raise AssertionError(mismatches)
129+
raise AssertionError(mismatch_descriptions)
121130

122131
if strict:
123132
explicit_columns = {col for col in columns if isinstance(col, str)}
124-
allowed_columns = explicit_columns.union(matched_by_regex)
133+
allowed_columns = explicit_columns.union(all_matched_by_regex)
125134
extra_columns = set(df.columns) - allowed_columns
126-
127135
if extra_columns:
128-
param_info = f" in parameter '{param_name}'" if param_name else ""
129136
raise AssertionError(f"DataFrame{param_info} contained unexpected column(s): {', '.join(extra_columns)}")
130137

131138

@@ -165,37 +172,27 @@ def wrapper(*args: Any, **kwargs: Any) -> DF:
165172

166173
def _get_parameter(func: Callable[..., Any], name: Optional[str] = None, *args: Any, **kwargs: Any) -> Any:
167174
if not name:
168-
if len(args) > 0:
169-
return args[0]
170-
if kwargs:
171-
return next(iter(kwargs.values()))
172-
return None
175+
return args[0] if args else next(iter(kwargs.values()), None)
173176

174-
if name not in kwargs:
175-
func_params_in_order = list(inspect.signature(func).parameters.keys())
176-
parameter_location = func_params_in_order.index(name)
177-
return args[parameter_location]
177+
if name in kwargs:
178+
return kwargs[name]
178179

179-
return kwargs[name]
180+
func_params_in_order = list(inspect.signature(func).parameters.keys())
181+
parameter_location = func_params_in_order.index(name)
182+
return args[parameter_location]
180183

181184

182185
def _get_parameter_name(
183186
func: Callable[..., Any], name: Optional[str] = None, *args: Any, **kwargs: Any
184187
) -> Optional[str]:
185-
"""Get the actual parameter name being validated."""
186188
if name:
187189
return name
188190

189-
# If no name specified, try to get the first parameter name
190-
if len(args) > 0:
191-
# Get the first parameter name from the function signature
191+
if args:
192192
func_params_in_order = list(inspect.signature(func).parameters.keys())
193193
return func_params_in_order[0]
194-
elif kwargs:
195-
# Return the first keyword argument name
196-
return next(iter(kwargs.keys()))
197194

198-
return None
195+
return next(iter(kwargs.keys()), None)
199196

200197

201198
def df_in(
@@ -246,19 +243,13 @@ def _describe_pd(df: DataFrameType, include_dtypes: bool = False) -> str:
246243

247244

248245
def _log_input(level: int, func_name: str, df: Any, include_dtypes: bool) -> None:
249-
if isinstance(df, pd.DataFrame) or isinstance(df, pl.DataFrame):
250-
logging.log(
251-
level,
252-
f"Function {func_name} parameters contained a DataFrame: {_describe_pd(df, include_dtypes)}",
253-
)
246+
if isinstance(df, (pd.DataFrame, pl.DataFrame)):
247+
logging.log(level, f"Function {func_name} parameters contained a DataFrame: {_describe_pd(df, include_dtypes)}")
254248

255249

256250
def _log_output(level: int, func_name: str, df: Any, include_dtypes: bool) -> None:
257-
if isinstance(df, pd.DataFrame) or isinstance(df, pl.DataFrame):
258-
logging.log(
259-
level,
260-
f"Function {func_name} returned a DataFrame: {_describe_pd(df, include_dtypes)}",
261-
)
251+
if isinstance(df, (pd.DataFrame, pl.DataFrame)):
252+
logging.log(level, f"Function {func_name} returned a DataFrame: {_describe_pd(df, include_dtypes)}")
262253

263254

264255
def df_log(level: int = logging.DEBUG, include_dtypes: bool = False) -> Callable[[Callable[..., T]], Callable[..., T]]:

tests/test_df_in.py

Lines changed: 5 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@
55
import pytest
66

77
from daffy import df_in
8+
from daffy.decorators import _check_columns, _get_parameter_name
89
from tests.conftest import DataFrameType, cars, extended_cars
910

1011

@@ -325,26 +326,20 @@ def test_fn(cars: DataFrameType, ext_cars: DataFrameType) -> int:
325326

326327

327328
def test_check_columns_handles_invalid_column_type_in_list() -> None:
328-
from daffy.decorators import _check_columns
329-
330329
df = pd.DataFrame({"A": [1, 2], "B": [3, 4]})
331-
columns = ["A", 123]
330+
columns: Any = ["A", 123]
332331

333332
_check_columns(df, columns, False)
334333

335334

336335
def test_check_columns_handles_invalid_column_key_in_dict() -> None:
337-
from daffy.decorators import _check_columns
338-
339336
df = pd.DataFrame({"A": [1, 2], "B": [3, 4]})
340-
columns = {"A": "int64", 123: "int64"}
337+
columns: Any = {"A": "int64", 123: "int64"}
341338

342339
_check_columns(df, columns, False)
343340

344341

345342
def test_get_parameter_name_returns_none_when_no_params() -> None:
346-
from daffy.decorators import _get_parameter_name
347-
348343
def func_with_no_params() -> None:
349344
pass
350345

@@ -353,8 +348,6 @@ def func_with_no_params() -> None:
353348

354349

355350
def test_get_parameter_name_returns_none_when_no_args_or_kwargs() -> None:
356-
from daffy.decorators import _get_parameter_name
357-
358351
def some_func(param: str) -> str:
359352
return param
360353

@@ -363,10 +356,8 @@ def some_func(param: str) -> str:
363356

364357

365358
def test_missing_column_in_dict_specification() -> None:
366-
from daffy.decorators import _check_columns
367-
368359
df = pd.DataFrame({"A": [1, 2]})
369-
columns = {"A": "int64", "MissingCol": "int64"}
360+
columns: Any = {"A": "int64", "MissingCol": "int64"}
370361

371362
with pytest.raises(AssertionError) as excinfo:
372363
_check_columns(df, columns, False)
@@ -375,10 +366,8 @@ def test_missing_column_in_dict_specification() -> None:
375366

376367

377368
def test_missing_regex_pattern_in_dict_specification() -> None:
378-
from daffy.decorators import _check_columns
379-
380369
df = pd.DataFrame({"A": [1, 2]})
381-
columns = {"A": "int64", "r/Missing_[0-9]/": "int64"}
370+
columns: Any = {"A": "int64", "r/Missing_[0-9]/": "int64"}
382371

383372
with pytest.raises(AssertionError) as excinfo:
384373
_check_columns(df, columns, False)

0 commit comments

Comments
 (0)