11"""Decorators for DAFFY DataFrame Column Validator."""
22
3+ from __future__ import annotations
4+
35import logging
6+ from collections .abc import Callable
47from functools import wraps
5- from typing import TYPE_CHECKING , Any , Callable , Optional , TypeVar , Union
8+ from typing import TYPE_CHECKING , Any , TypeVar
69
710if TYPE_CHECKING :
811 # For static type checking, assume both are available
1417 PandasDataFrame = None
1518 PolarsDataFrame = None
1619
17- from daffy .config import get_strict
20+ from daffy .config import get_row_validation_config , get_strict
1821from daffy .dataframe_types import DataFrameType
22+ from daffy .row_validation import validate_dataframe_rows
1923from daffy .utils import (
2024 assert_is_dataframe ,
25+ format_param_context ,
2126 get_parameter ,
2227 get_parameter_name ,
2328 log_dataframe_input ,
2833# Type variables for preserving return types
2934T = TypeVar ("T" ) # Generic type var for df_log
3035if TYPE_CHECKING :
31- DF = TypeVar ("DF" , bound = Union [ PandasDataFrame , PolarsDataFrame ] )
36+ DF = TypeVar ("DF" , bound = PandasDataFrame | PolarsDataFrame )
3237else :
3338 DF = TypeVar ("DF" , bound = DataFrameType )
3439R = TypeVar ("R" ) # Return type for df_in
3540
3641
42+ def _validate_rows_with_context (
43+ df : Any ,
44+ row_validator : "type[BaseModel]" ,
45+ func_name : str ,
46+ param_name : str | None ,
47+ is_return_value : bool ,
48+ ) -> None :
49+ """Validate DataFrame rows with Pydantic model and add context to errors.
50+
51+ Args:
52+ df: DataFrame to validate
53+ row_validator: Pydantic model class for row validation
54+ func_name: Name of the decorated function
55+ param_name: Name of the parameter being validated (None for return values)
56+ is_return_value: True if validating a return value
57+ """
58+ config = get_row_validation_config ()
59+
60+ try :
61+ validate_dataframe_rows (
62+ df ,
63+ row_validator ,
64+ max_errors = config ["max_errors" ],
65+ convert_nans = config ["convert_nans" ],
66+ )
67+ except AssertionError as e :
68+ context = format_param_context (param_name , func_name , is_return_value )
69+ raise AssertionError (f"{ str (e )} { context } " ) from e
70+
71+
3772def df_out (
3873 columns : ColumnsDef = None ,
39- strict : Optional [ bool ] = None ,
40- row_validator : Optional [ "type[BaseModel]" ] = None ,
74+ strict : bool | None = None ,
75+ row_validator : "type[BaseModel] | None" = None ,
4176) -> Callable [[Callable [..., DF ]], Callable [..., DF ]]:
4277 """Decorate a function that returns a Pandas or Polars DataFrame.
4378
@@ -67,22 +102,7 @@ def wrapper(*args: Any, **kwargs: Any) -> DF:
67102 validate_dataframe (result , columns , get_strict (strict ), None , func .__name__ , True )
68103
69104 if row_validator is not None :
70- from daffy .config import get_row_validation_config
71- from daffy .row_validation import validate_dataframe_rows
72- from daffy .utils import format_param_context
73-
74- config = get_row_validation_config ()
75-
76- try :
77- validate_dataframe_rows (
78- result ,
79- row_validator ,
80- max_errors = config ["max_errors" ],
81- convert_nans = config ["convert_nans" ],
82- )
83- except AssertionError as e :
84- context = format_param_context (None , func .__name__ , True )
85- raise AssertionError (f"{ str (e )} { context } " ) from e
105+ _validate_rows_with_context (result , row_validator , func .__name__ , None , True )
86106
87107 return result
88108
@@ -92,10 +112,10 @@ def wrapper(*args: Any, **kwargs: Any) -> DF:
92112
93113
94114def df_in (
95- name : Optional [ str ] = None ,
115+ name : str | None = None ,
96116 columns : ColumnsDef = None ,
97- strict : Optional [ bool ] = None ,
98- row_validator : Optional [ "type[BaseModel]" ] = None ,
117+ strict : bool | None = None ,
118+ row_validator : "type[BaseModel] | None" = None ,
99119) -> Callable [[Callable [..., R ]], Callable [..., R ]]:
100120 """Decorate a function parameter that is a Pandas or Polars DataFrame.
101121
@@ -127,22 +147,7 @@ def wrapper(*args: Any, **kwargs: Any) -> R:
127147 validate_dataframe (df , columns , get_strict (strict ), param_name , func .__name__ )
128148
129149 if row_validator is not None :
130- from daffy .config import get_row_validation_config
131- from daffy .row_validation import validate_dataframe_rows
132- from daffy .utils import format_param_context
133-
134- config = get_row_validation_config ()
135-
136- try :
137- validate_dataframe_rows (
138- df ,
139- row_validator ,
140- max_errors = config ["max_errors" ],
141- convert_nans = config ["convert_nans" ],
142- )
143- except AssertionError as e :
144- context = format_param_context (param_name , func .__name__ , False )
145- raise AssertionError (f"{ str (e )} { context } " ) from e
150+ _validate_rows_with_context (df , row_validator , func .__name__ , param_name , False )
146151
147152 return func (* args , ** kwargs )
148153
0 commit comments