Skip to content

Commit 4fe5c42

Browse files
wesselhuisingWessel Huisingxaviernogueira
authored
Implement strict parameter for strict mode (#33)
* add strict mode * add test for strict mode * black and mypy fixes * 📝 adding to docstring * ✏️ refining logic for strict mode * 📝 adding note for future dev * ✏️ failing on first error, not after all errors --------- Co-authored-by: Wessel Huising <wessel@MacBook-Pro-van-Wessel.local> Co-authored-by: Xavier Nogueira <xavier.rojas.nogueira@gmail.com>
1 parent fe85dc9 commit 4fe5c42

2 files changed

Lines changed: 50 additions & 7 deletions

File tree

src/pandantic/validators/pandas.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def validate(
2323
self,
2424
dataframe: pd.DataFrame,
2525
errors: Literal["skip", "raise", "log"] = "raise",
26+
strict: bool = False,
2627
context: Optional[
2728
dict[str, Any]
2829
] = None, # pylint: disable=consider-alternative-union-syntax,useless-suppression
@@ -35,13 +36,26 @@ def validate(
3536
dataframe (pd.DataFrame): The DataFrame to validate.
3637
errors (Literal["skip", "raise", "log"], optional): How to handle validation errors. Defaults to "raise".
3738
NOTE: "skip" and "log" effectively filter the dataframe, excluding invalid rows.
39+
strict (bool, default=False): whether to fail validation if extra fields/columns are present.
3840
context (Optional[dict[str, Any]], optional): The context to use for validation. Defaults to None.
3941
n_jobs (int, optional): The number of processes to use for validation. Defaults to 1.
4042
queue (Optional[Queue], optional): A custom Queue object for multiprocessing. Defaults to None.
4143
4244
Returns:
4345
pd.DataFrame: The original DataFrame if errors="raise" or "log", or a filtered DataFrame with valid rows if errors="skip".
4446
"""
47+
# check for extra columns and handle strict mode
48+
# NOTE: this will need to be abstracted to handle different types of schema objects
49+
if strict:
50+
extras = {
51+
col for col in dataframe.columns if col not in self.schema.model_fields.keys()
52+
}
53+
if extras:
54+
raise ValueError(
55+
f"Strict mode is enabled but the following extra columns were found in the schema: {extras}."
56+
)
57+
del extras
58+
4559
if errors not in ["skip", "raise", "log"]:
4660
raise ValueError("errors must be one of 'skip', 'raise', or 'log'")
4761

@@ -101,13 +115,12 @@ def validate(
101115
except ValidationError as exc: # pylint: disable=broad-exception-caught
102116
if errors == "log":
103117
logging.info("Validation error found at index %s\n%s", index, exc)
104-
118+
if errors == "raise":
119+
raise exc
105120
errors_index.append(index)
106121

107122
logging.debug("# invalid rows: %s", len(errors_index))
108123

109-
if len(errors_index) > 0 and errors == "raise":
110-
raise ValueError(f"{len(errors_index)} validation errors found in dataframe.")
111124
if len(errors_index) > 0 and errors in ["skip", "log"]:
112125
return dataframe[~dataframe.index.isin(list(errors_index))]
113126
return dataframe
@@ -141,7 +154,8 @@ def _validate_chunk(
141154
except ValidationError as exc: # pylint: disable=broad-exception-caught
142155
if errors == "log":
143156
logging.info("Validation error found at index %s\n%s", index, exc)
144-
157+
if errors == "raise":
158+
raise exc
145159
queue.put(index)
146160

147161
logging.debug("Process ended.")
@@ -159,9 +173,12 @@ def iterate(
159173
"""Iterate over a DataFrame and yield validated schema models."""
160174
for i, row in dataframe.iterrows():
161175
try:
162-
yield i, self.schema.model_validate(
163-
obj=row.to_dict(),
164-
context=context,
176+
yield (
177+
i,
178+
self.schema.model_validate(
179+
obj=row.to_dict(),
180+
context=context,
181+
),
165182
)
166183
except Exception as e:
167184
if verbose:

tests/test_pandas_validator.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
* validate() function (full table).
44
* validate() function (to skip table).
55
"""
6+
67
import logging
78
from typing import Optional
89

@@ -155,3 +156,28 @@ class Model(BaseModel):
155156

156157
# THEN
157158
assert df_skiped.equals(df_example)
159+
160+
161+
def test_strict_mode():
162+
# GIVEN
163+
class Model(BaseModel):
164+
a: Optional[int] = None
165+
b: str
166+
167+
df_example = pd.DataFrame({"a": [None, None, None], "b": ["str", "str", "str"], "c": [1, 2, 3]})
168+
validator = PandasValidator(schema=Model)
169+
170+
# WHEN
171+
df_skiped = validator.validate(df_example, errors="skip")
172+
173+
# THEN
174+
assert df_skiped.equals(df_example)
175+
176+
# THEN
177+
with pytest.raises(ValueError):
178+
# WHEN
179+
validator.validate(
180+
dataframe=df_example,
181+
strict=True,
182+
errors="raise",
183+
)

0 commit comments

Comments
 (0)