77from dataclasses import dataclass
88
99from ..config import SafeSynthesizerParameters
10+ from ..config .generate import ValidationParameters
1011from ..data_processing .record_utils import (
1112 check_if_records_are_ordered ,
1213 extract_and_validate_records ,
@@ -36,8 +37,10 @@ class Processor(ABC):
3637 schema: JSON schema as a dictionary.
3738 """
3839
def __init__(self, schema: dict, config: ValidationParameters) -> None:
    """Store the schema and validation settings shared by all processors.

    Args:
        schema: JSON schema as a dictionary.
        config: Validation parameters, kept on ``self.config`` so that
            subclasses can read them while processing generated text
            (for example the ``group_by_*`` flags that
            ``GroupedDataProcessor`` consults).
    """
    self.schema = schema
    self.config = config
    # Emitted at debug level only; it records the exact schema and
    # validation settings this processor instance was built with.
    logger.debug(f"Initialized processor with schema={schema} and config={config}")
4144
4245 @property
4346 def name (self ):
@@ -102,8 +105,15 @@ def _process_text_generation(self, text: str) -> ParsedResponse:
102105class TimeSeriesDataProcessor (Processor ):
103106 """Processor for time-series data generation tasks."""
104107
105- def __init__ (self , schema : dict , time_column : str | None , interval_seconds : int | None , time_format : str | None ):
106- super ().__init__ (schema = schema )
108+ def __init__ (
109+ self ,
110+ schema : dict ,
111+ config : ValidationParameters ,
112+ time_column : str | None ,
113+ interval_seconds : int | None ,
114+ time_format : str | None ,
115+ ):
116+ super ().__init__ (schema = schema , config = config )
107117 if time_column is None :
108118 raise ValueError (
109119 "time_column is required for TimeSeriesDataProcessor but was None. "
@@ -142,12 +152,13 @@ class GroupedDataProcessor(Processor):
142152 def __init__ (
143153 self ,
144154 schema : dict ,
155+ config : ValidationParameters ,
145156 bos_token : str ,
146157 eos_token : str ,
147158 group_by : str | list [str ],
148159 order_by : str | None = None ,
149160 ):
150- super ().__init__ (schema = schema )
161+ super ().__init__ (schema = schema , config = config )
151162 if isinstance (group_by , str ):
152163 group_by = [group_by ]
153164 self .group_by = group_by
@@ -158,12 +169,15 @@ def __init__(
158169 def _process_text_generation (self , text : str ) -> ParsedResponse :
159170 """Process the output from the fine-tuned model.
160171
161- For records to be valid, they must :
172+ For records to be valid, they should :
162173 - Be in a group that is bound by BOS and EOS tokens.
163174 - Respect the known JSONL schema.
164175 - Have a unique value for the `group_by` field(s).
165176 - Be ordered by the `order_by` field if specified.
166177
178+ These requirements may be relaxed and automatically fixed depending on
179+ the settings in self.config.
180+
167181 Args:
168182 text: Text generated by the fine-tuned model.
169183
@@ -173,6 +187,9 @@ def _process_text_generation(self, text: str) -> ParsedResponse:
173187 groups = extract_groups_from_jsonl_string (text , self .bos_token , self .eos_token )
174188 groupby_validator = "groupby"
175189
190+ if len (groups ) == 0 and self .config .group_by_accept_no_delineator :
191+ groups = [text ]
192+
176193 if len (groups ) == 0 :
177194 return ParsedResponse (
178195 valid_records = [],
@@ -186,21 +203,53 @@ def _process_text_generation(self, text: str) -> ParsedResponse:
186203 valid , invalid , errors = extract_and_validate_records (group , self .schema )
187204 valid_with_str_members = [str (item ) for item in valid ]
188205
189- # If there are any invalid records, the entire group is invalid.
190- if len (invalid ) > 0 :
191- invalid = valid_with_str_members + invalid
192- errors = errors + [("Invalid JSON in other groupby records" , groupby_validator )] * len (valid )
193- valid = []
194-
195- # The group is invalid if the set of group_by fields is not unique.
196- elif len (set ([tuple (record [group_by ] for group_by in self .group_by ) for record in valid ])) != 1 :
197- valid , invalid = [], valid_with_str_members + invalid
198- errors = [("Groupby value is not unique" , groupby_validator )] * len (invalid )
206+ if len (valid ) == 0 :
207+ invalid_groups .extend (invalid )
208+ errors_groups .extend (errors )
209+ continue
199210
200- # If order_by is specified, the group is invalid if the records are not ordered.
201- elif self .order_by is not None and not check_if_records_are_ordered (valid , self .order_by ):
202- valid , invalid = [], valid_with_str_members + invalid
203- errors = [("Group not ordered" , groupby_validator )] * len (invalid )
211+ # Handle invalid records in the group (optionally ignore and proceed).
212+ if len (invalid ) > 0 :
213+ if self .config .group_by_ignore_invalid_records :
214+ invalid = []
215+ errors = []
216+ else :
217+ # If there are any invalid records, the entire group is invalid.
218+ invalid = valid_with_str_members + invalid
219+ errors = errors + [("Invalid JSON in other groupby records" , groupby_validator )] * len (valid )
220+ valid = []
221+ valid_groups .extend (valid )
222+ invalid_groups .extend (invalid )
223+ errors_groups .extend (errors )
224+ continue
225+
226+ # Handle non-unique group_by values (optionally fix by using first record's values).
227+ if len (set (tuple (record [gb ] for gb in self .group_by ) for record in valid )) != 1 :
228+ if self .config .group_by_fix_non_unique_value :
229+ for group_by in self .group_by :
230+ for record in valid [1 :]:
231+ record [group_by ] = valid [0 ][group_by ]
232+ else :
233+ # The group is invalid if the set of group_by fields is not unique.
234+ valid , invalid = [], valid_with_str_members + invalid
235+ errors = [("Groupby value is not unique" , groupby_validator )] * len (invalid )
236+ valid_groups .extend (valid )
237+ invalid_groups .extend (invalid )
238+ errors_groups .extend (errors )
239+ continue
240+
241+ # Handle unordered records when order_by is set (optionally fix by sorting).
242+ if self .order_by is not None and not check_if_records_are_ordered (valid , self .order_by ):
243+ if self .config .group_by_fix_unordered_records :
244+ valid .sort (key = lambda x : x [self .order_by ])
245+ else :
246+ # If order_by is specified, the group is invalid if the records are not ordered.
247+ valid , invalid = [], valid_with_str_members + invalid
248+ errors = [("Group not ordered" , groupby_validator )] * len (invalid )
249+ valid_groups .extend (valid )
250+ invalid_groups .extend (invalid )
251+ errors_groups .extend (errors )
252+ continue
204253
205254 valid_groups .extend (valid )
206255 invalid_groups .extend (invalid )
@@ -227,20 +276,22 @@ def create_processor(schema: dict, metadata: ModelMetadata, config: SafeSynthesi
227276 if config .time_series .is_timeseries :
228277 processor = TimeSeriesDataProcessor (
229278 schema ,
279+ config = config .generation .validation ,
230280 time_column = config .time_series .timestamp_column ,
231281 interval_seconds = config .time_series .timestamp_interval_seconds ,
232282 time_format = config .time_series .timestamp_format ,
233283 )
234284 elif config .data .group_training_examples_by :
235285 processor = GroupedDataProcessor (
236286 schema ,
287+ config = config .generation .validation ,
237288 group_by = config .data .group_training_examples_by ,
238289 order_by = config .data .order_training_examples_by ,
239290 bos_token = metadata .prompt_config .bos_token ,
240291 eos_token = metadata .prompt_config .eos_token ,
241292 )
242293 else :
243- processor = TabularDataProcessor (schema )
294+ processor = TabularDataProcessor (schema , config = config . generation . validation )
244295
245296 logger .info (f"Initialized the { processor .name } " )
246297 return processor
0 commit comments