
Commit c52694d

Merge pull request #27 from vertti/more-refactorings
More refactorings
2 parents 17128ee + 15cd40b commit c52694d

File tree: 3 files changed (+48, -32 lines)

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
@@ -2,6 +2,10 @@
 
 All notable changes to this project will be documented in this file.
 
+## 0.14.1
+
+- Internal code quality improvements
+
 ## 0.14.0
 
 - Improve df_in error messages to include parameter names

daffy/decorators.py

Lines changed: 43 additions & 31 deletions
@@ -35,6 +35,13 @@ def _is_regex_pattern(column: Any) -> bool:
     )
 
 
+def _as_regex_pattern(column: Union[str, RegexColumnDef]) -> Optional[RegexColumnDef]:
+    """Convert column to RegexColumnDef if it is a regex pattern, otherwise return None."""
+    if _is_regex_pattern(column):
+        return column  # type: ignore[return-value]  # We know it's the right type after the check
+    return None
+
+
 def _assert_is_dataframe(obj: Any, context: str) -> None:
     if not isinstance(obj, (pd.DataFrame, pl.DataFrame)):
         raise AssertionError(f"Wrong {context}. Expected DataFrame, got {type(obj).__name__} instead.")
@@ -44,32 +51,39 @@ def _make_param_info(param_name: Optional[str]) -> str:
     return f" in parameter '{param_name}'" if param_name else ""
 
 
-def _validate_column(
-    column_spec: Union[str, RegexColumnDef], df: DataFrameType, expected_dtype: Any = None
-) -> Tuple[List[str], List[Tuple[str, Any, Any]], Set[str]]:
-    """Validate a single column specification against a DataFrame."""
-    missing_columns = []
-    dtype_mismatches = []
-    matched_by_regex = set()
-
+def _find_missing_columns(column_spec: Union[str, RegexColumnDef], df_columns: List[str]) -> List[str]:
+    """Find missing columns for a single column specification."""
     if isinstance(column_spec, str):
-        if column_spec not in df.columns:
-            missing_columns.append(column_spec)
-        elif expected_dtype is not None and df[column_spec].dtype != expected_dtype:
-            dtype_mismatches.append((column_spec, df[column_spec].dtype, expected_dtype))
+        return [column_spec] if column_spec not in df_columns else []
     elif _is_regex_pattern(column_spec):
         pattern_str, _ = column_spec
-        matches = _match_column_with_regex(column_spec, list(df.columns))
-        if not matches:
-            missing_columns.append(pattern_str)
-        else:
-            matched_by_regex.update(matches)
-            if expected_dtype is not None:
-                for matched_col in matches:
-                    if df[matched_col].dtype != expected_dtype:
-                        dtype_mismatches.append((matched_col, df[matched_col].dtype, expected_dtype))
+        matches = _match_column_with_regex(column_spec, df_columns)
+        return [pattern_str] if not matches else []
+    return []
+
+
+def _find_dtype_mismatches(
+    column_spec: Union[str, RegexColumnDef], df: DataFrameType, expected_dtype: Any, df_columns: List[str]
+) -> List[Tuple[str, Any, Any]]:
+    """Find dtype mismatches for a single column specification."""
+    mismatches = []
+    if isinstance(column_spec, str):
+        if column_spec in df_columns and df[column_spec].dtype != expected_dtype:
+            mismatches.append((column_spec, df[column_spec].dtype, expected_dtype))
+    elif _is_regex_pattern(column_spec):
+        matches = _match_column_with_regex(column_spec, df_columns)
+        for matched_col in matches:
+            if df[matched_col].dtype != expected_dtype:
+                mismatches.append((matched_col, df[matched_col].dtype, expected_dtype))
+    return mismatches
+
 
-    return missing_columns, dtype_mismatches, matched_by_regex
+def _find_regex_matches(column_spec: Union[str, RegexColumnDef], df_columns: List[str]) -> Set[str]:
+    """Find regex matches for a single column specification."""
+    regex_pattern = _as_regex_pattern(column_spec)
+    if regex_pattern:
+        return set(_match_column_with_regex(regex_pattern, df_columns))
+    return set()
 
 
 def _match_column_with_regex(column_pattern: RegexColumnDef, df_columns: List[str]) -> List[str]:
@@ -94,27 +108,25 @@ def _compile_regex_patterns(columns: Seq[Any]) -> List[Union[str, RegexColumnDef
 def _check_columns(
     df: DataFrameType, columns: Union[ColumnsList, ColumnsDict], strict: bool, param_name: Optional[str] = None
 ) -> None:
+    df_columns = list(df.columns)  # Cache the column list conversion
     all_missing_columns = []
     all_dtype_mismatches = []
     all_matched_by_regex = set()
 
     if isinstance(columns, list):
         processed_columns = _compile_regex_patterns(columns)
         for column_spec in processed_columns:
-            missing, mismatches, matched = _validate_column(column_spec, df)
-            all_missing_columns.extend(missing)
-            all_dtype_mismatches.extend(mismatches)
-            all_matched_by_regex.update(matched)
+            all_missing_columns.extend(_find_missing_columns(column_spec, df_columns))
+            all_matched_by_regex.update(_find_regex_matches(column_spec, df_columns))
     else:  # isinstance(columns, dict)
         assert isinstance(columns, dict)
         for column, expected_dtype in columns.items():
             column_spec = (
                 _compile_regex_pattern(column) if isinstance(column, str) and _is_regex_string(column) else column
             )
-            missing, mismatches, matched = _validate_column(column_spec, df, expected_dtype)
-            all_missing_columns.extend(missing)
-            all_dtype_mismatches.extend(mismatches)
-            all_matched_by_regex.update(matched)
+            all_missing_columns.extend(_find_missing_columns(column_spec, df_columns))
+            all_dtype_mismatches.extend(_find_dtype_mismatches(column_spec, df, expected_dtype, df_columns))
+            all_matched_by_regex.update(_find_regex_matches(column_spec, df_columns))
 
     param_info = _make_param_info(param_name)
 
@@ -131,7 +143,7 @@ def _check_columns(
     if strict:
         explicit_columns = {col for col in columns if isinstance(col, str)}
         allowed_columns = explicit_columns.union(all_matched_by_regex)
-        extra_columns = set(df.columns) - allowed_columns
+        extra_columns = set(df_columns) - allowed_columns
         if extra_columns:
             raise AssertionError(f"DataFrame{param_info} contained unexpected column(s): {', '.join(extra_columns)}")
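
Behaviour is unchanged from the caller's point of view: _check_columns still raises AssertionError for missing columns, dtype mismatches and, in strict mode, unexpected columns. A minimal usage sketch of what these internals back, written against daffy's public df_in decorator; the column names and dtypes below are illustrative only, and the sketch assumes df_in validates the function's first DataFrame argument by default:

import pandas as pd

from daffy import df_in


# The columns dict maps column names to expected dtypes; a missing column or a
# dtype mismatch raises AssertionError before the function body runs, with the
# offending parameter named in the message (the 0.14.0 improvement above).
@df_in(columns={"price": "int64", "name": "object"})
def add_tax(data: pd.DataFrame) -> pd.DataFrame:
    out = data.copy()
    out["price_with_tax"] = out["price"] * 1.24
    return out


add_tax(pd.DataFrame({"name": ["a", "b"], "price": [100, 200]}))  # passes validation
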

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "daffy"
-version = "0.14.0"
+version = "0.14.1"
 description = "Function decorators for Pandas and Polars Dataframe column name and data type validation"
 authors = [
     { name="Janne Sinivirta", email="[email protected]" },
