@@ -23,6 +23,7 @@ def validate(
2323 self ,
2424 dataframe : pd .DataFrame ,
2525 errors : Literal ["skip" , "raise" , "log" ] = "raise" ,
26+ strict : bool = False ,
2627 context : Optional [
2728 dict [str , Any ]
2829 ] = None , # pylint: disable=consider-alternative-union-syntax,useless-suppression
@@ -35,13 +36,26 @@ def validate(
3536 dataframe (pd.DataFrame): The DataFrame to validate.
3637 errors (Literal["skip", "raise", "log"], optional): How to handle validation errors. Defaults to "raise".
3738 NOTE: "skip" and "log" effectively filter the dataframe, excluding invalid rows.
39+ strict (bool, default=False): whether to fail validation if extra fields/columns are present.
3840 context (Optional[dict[str, Any]], optional): The context to use for validation. Defaults to None.
3941 n_jobs (int, optional): The number of processes to use for validation. Defaults to 1.
4042 queue (Optional[Queue], optional): A custom Queue object for multiprocessing. Defaults to None.
4143
4244 Returns:
4345 pd.DataFrame: The original DataFrame if errors="raise", or a DataFrame filtered down to the valid rows if errors="skip" or "log".
4446 """
47+ # check for extra columns and handle strict mode
48+ # NOTE: this will need to be abstracted to handle different types of schema objects
49+ if strict :
50+ extras = {
51+ col for col in dataframe .columns if col not in self .schema .model_fields .keys ()
52+ }
53+ if extras :
54+ raise ValueError (
55+ f"Strict mode is enabled but the following extra columns were found in the schema: { extras } ."
56+ )
57+ del extras
58+
4559 if errors not in ["skip" , "raise" , "log" ]:
4660 raise ValueError ("errors must be one of 'skip', 'raise', or 'log'" )
4761
@@ -101,13 +115,12 @@ def validate(
101115 except ValidationError as exc : # pylint: disable=broad-exception-caught
102116 if errors == "log" :
103117 logging .info ("Validation error found at index %s\n %s" , index , exc )
104-
118+ if errors == "raise" :
119+ raise exc
105120 errors_index .append (index )
106121
107122 logging .debug ("# invalid rows: %s" , len (errors_index ))
108123
109- if len (errors_index ) > 0 and errors == "raise" :
110- raise ValueError (f"{ len (errors_index )} validation errors found in dataframe." )
111124 if len (errors_index ) > 0 and errors in ["skip" , "log" ]:
112125 return dataframe [~ dataframe .index .isin (list (errors_index ))]
113126 return dataframe
@@ -141,7 +154,8 @@ def _validate_chunk(
141154 except ValidationError as exc : # pylint: disable=broad-exception-caught
142155 if errors == "log" :
143156 logging .info ("Validation error found at index %s\n %s" , index , exc )
144-
157+ if errors == "raise" :
158+ raise exc
145159 queue .put (index )
146160
147161 logging .debug ("Process ended." )
@@ -159,9 +173,12 @@ def iterate(
159173 """Iterate over a DataFrame and yield validated schema models."""
160174 for i , row in dataframe .iterrows ():
161175 try :
162- yield i , self .schema .model_validate (
163- obj = row .to_dict (),
164- context = context ,
176+ yield (
177+ i ,
178+ self .schema .model_validate (
179+ obj = row .to_dict (),
180+ context = context ,
181+ ),
165182 )
166183 except Exception as e :
167184 if verbose :
0 commit comments