Skip to content

Commit ab5f334

Browse files
feat: Add options for validation and auto fixing in GroupedDataProcessor (#31)
Signed-off-by: Kendrick Boyd <kendrickb@nvidia.com>
1 parent 18a8351 commit ab5f334

4 files changed

Lines changed: 314 additions & 33 deletions

File tree

src/nemo_safe_synthesizer/config/generate.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,47 @@
1818
range_validator,
1919
)
2020

21-
__all__ = ["GenerateParameters"]
21+
__all__ = ["GenerateParameters", "ValidationParameters"]
22+
23+
24+
class ValidationParameters(Parameters, BaseModel):
25+
"""Configuration for record and sequence validation.
26+
27+
These parameters control the validation and automatic fixes when going
28+
from LLM output to tabular data.
29+
"""
30+
31+
group_by_accept_no_delineator: Annotated[
32+
bool,
33+
Field(
34+
title="group_by_accept_no_delineator",
35+
description="Whether to accept completions without both beginning and end of sequence delineators as a single sequence.",
36+
),
37+
] = False
38+
39+
group_by_ignore_invalid_records: Annotated[
40+
bool,
41+
Field(
42+
title="group_by_ignore_invalid_records",
43+
description="Whether to ignore invalid records in a sequence and proceed with the valid records.",
44+
),
45+
] = False
46+
47+
group_by_fix_non_unique_value: Annotated[
48+
bool,
49+
Field(
50+
title="group_by_fix_non_unique_value",
51+
description="Whether to automatically fix non-unique group by values in a sequence by using the first unique value for all records.",
52+
),
53+
] = False
54+
55+
group_by_fix_unordered_records: Annotated[
56+
bool,
57+
Field(
58+
title="group_by_fix_unordered_records",
59+
description="Whether to automatically fix unordered records in a sequence by sorting the records.",
60+
),
61+
] = False
2262

2363

2464
class GenerateParameters(Parameters, BaseModel):
@@ -134,3 +174,8 @@ class GenerateParameters(Parameters, BaseModel):
134174
description="Enforce timeseries fidelity by enforcing the time series order, intervals, start and end times of the records.",
135175
),
136176
] = False
177+
178+
validation: ValidationParameters = Field(
179+
description="Validation parameters controlling validation logic and automatic fixes when parsing LLM output and converting to tabular data.",
180+
default_factory=ValidationParameters,
181+
)

src/nemo_safe_synthesizer/generation/processors.py

Lines changed: 71 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from dataclasses import dataclass
88

99
from ..config import SafeSynthesizerParameters
10+
from ..config.generate import ValidationParameters
1011
from ..data_processing.record_utils import (
1112
check_if_records_are_ordered,
1213
extract_and_validate_records,
@@ -36,8 +37,10 @@ class Processor(ABC):
3637
schema: JSON schema as a dictionary.
3738
"""
3839

39-
def __init__(self, schema: dict):
40+
def __init__(self, schema: dict, config: ValidationParameters):
4041
self.schema = schema
42+
self.config = config
43+
logger.debug(f"Initialized processor with schema={schema} and config={config}")
4144

4245
@property
4346
def name(self):
@@ -102,8 +105,15 @@ def _process_text_generation(self, text: str) -> ParsedResponse:
102105
class TimeSeriesDataProcessor(Processor):
103106
"""Processor for time-series data generation tasks."""
104107

105-
def __init__(self, schema: dict, time_column: str | None, interval_seconds: int | None, time_format: str | None):
106-
super().__init__(schema=schema)
108+
def __init__(
109+
self,
110+
schema: dict,
111+
config: ValidationParameters,
112+
time_column: str | None,
113+
interval_seconds: int | None,
114+
time_format: str | None,
115+
):
116+
super().__init__(schema=schema, config=config)
107117
if time_column is None:
108118
raise ValueError(
109119
"time_column is required for TimeSeriesDataProcessor but was None. "
@@ -142,12 +152,13 @@ class GroupedDataProcessor(Processor):
142152
def __init__(
143153
self,
144154
schema: dict,
155+
config: ValidationParameters,
145156
bos_token: str,
146157
eos_token: str,
147158
group_by: str | list[str],
148159
order_by: str | None = None,
149160
):
150-
super().__init__(schema=schema)
161+
super().__init__(schema=schema, config=config)
151162
if isinstance(group_by, str):
152163
group_by = [group_by]
153164
self.group_by = group_by
@@ -158,12 +169,15 @@ def __init__(
158169
def _process_text_generation(self, text: str) -> ParsedResponse:
159170
"""Process the output from the fine-tuned model.
160171
161-
For records to be valid, they must:
172+
For records to be valid, they should:
162173
- Be in a group that is bound by BOS and EOS tokens.
163174
- Respect the known JSONL schema.
164175
- Have a unique value for the `group_by` field(s).
165176
- Be ordered by the `order_by` field if specified.
166177
178+
These requirements may be relaxed and automatically fixed depending on
179+
the settings in self.config.
180+
167181
Args:
168182
text: Text generated by the fine-tuned model.
169183
@@ -173,6 +187,9 @@ def _process_text_generation(self, text: str) -> ParsedResponse:
173187
groups = extract_groups_from_jsonl_string(text, self.bos_token, self.eos_token)
174188
groupby_validator = "groupby"
175189

190+
if len(groups) == 0 and self.config.group_by_accept_no_delineator:
191+
groups = [text]
192+
176193
if len(groups) == 0:
177194
return ParsedResponse(
178195
valid_records=[],
@@ -186,21 +203,53 @@ def _process_text_generation(self, text: str) -> ParsedResponse:
186203
valid, invalid, errors = extract_and_validate_records(group, self.schema)
187204
valid_with_str_members = [str(item) for item in valid]
188205

189-
# If there are any invalid records, the entire group is invalid.
190-
if len(invalid) > 0:
191-
invalid = valid_with_str_members + invalid
192-
errors = errors + [("Invalid JSON in other groupby records", groupby_validator)] * len(valid)
193-
valid = []
194-
195-
# The group is invalid if the set of group_by fields is not unique.
196-
elif len(set([tuple(record[group_by] for group_by in self.group_by) for record in valid])) != 1:
197-
valid, invalid = [], valid_with_str_members + invalid
198-
errors = [("Groupby value is not unique", groupby_validator)] * len(invalid)
206+
if len(valid) == 0:
207+
invalid_groups.extend(invalid)
208+
errors_groups.extend(errors)
209+
continue
199210

200-
# If order_by is specified, the group is invalid if the records are not ordered.
201-
elif self.order_by is not None and not check_if_records_are_ordered(valid, self.order_by):
202-
valid, invalid = [], valid_with_str_members + invalid
203-
errors = [("Group not ordered", groupby_validator)] * len(invalid)
211+
# Handle invalid records in the group (optionally ignore and proceed).
212+
if len(invalid) > 0:
213+
if self.config.group_by_ignore_invalid_records:
214+
invalid = []
215+
errors = []
216+
else:
217+
# If there are any invalid records, the entire group is invalid.
218+
invalid = valid_with_str_members + invalid
219+
errors = errors + [("Invalid JSON in other groupby records", groupby_validator)] * len(valid)
220+
valid = []
221+
valid_groups.extend(valid)
222+
invalid_groups.extend(invalid)
223+
errors_groups.extend(errors)
224+
continue
225+
226+
# Handle non-unique group_by values (optionally fix by using first record's values).
227+
if len(set(tuple(record[gb] for gb in self.group_by) for record in valid)) != 1:
228+
if self.config.group_by_fix_non_unique_value:
229+
for group_by in self.group_by:
230+
for record in valid[1:]:
231+
record[group_by] = valid[0][group_by]
232+
else:
233+
# The group is invalid if the set of group_by fields is not unique.
234+
valid, invalid = [], valid_with_str_members + invalid
235+
errors = [("Groupby value is not unique", groupby_validator)] * len(invalid)
236+
valid_groups.extend(valid)
237+
invalid_groups.extend(invalid)
238+
errors_groups.extend(errors)
239+
continue
240+
241+
# Handle unordered records when order_by is set (optionally fix by sorting).
242+
if self.order_by is not None and not check_if_records_are_ordered(valid, self.order_by):
243+
if self.config.group_by_fix_unordered_records:
244+
valid.sort(key=lambda x: x[self.order_by])
245+
else:
246+
# If order_by is specified, the group is invalid if the records are not ordered.
247+
valid, invalid = [], valid_with_str_members + invalid
248+
errors = [("Group not ordered", groupby_validator)] * len(invalid)
249+
valid_groups.extend(valid)
250+
invalid_groups.extend(invalid)
251+
errors_groups.extend(errors)
252+
continue
204253

205254
valid_groups.extend(valid)
206255
invalid_groups.extend(invalid)
@@ -227,20 +276,22 @@ def create_processor(schema: dict, metadata: ModelMetadata, config: SafeSynthesi
227276
if config.time_series.is_timeseries:
228277
processor = TimeSeriesDataProcessor(
229278
schema,
279+
config=config.generation.validation,
230280
time_column=config.time_series.timestamp_column,
231281
interval_seconds=config.time_series.timestamp_interval_seconds,
232282
time_format=config.time_series.timestamp_format,
233283
)
234284
elif config.data.group_training_examples_by:
235285
processor = GroupedDataProcessor(
236286
schema,
287+
config=config.generation.validation,
237288
group_by=config.data.group_training_examples_by,
238289
order_by=config.data.order_training_examples_by,
239290
bos_token=metadata.prompt_config.bos_token,
240291
eos_token=metadata.prompt_config.eos_token,
241292
)
242293
else:
243-
processor = TabularDataProcessor(schema)
294+
processor = TabularDataProcessor(schema, config=config.generation.validation)
244295

245296
logger.info(f"Initialized the {processor.name}")
246297
return processor

0 commit comments

Comments
 (0)