Skip to content

Commit 0bcf0d2

Browse files
committed
Adds simple str "contains" and "does not contain" validators
This will enable someone to do:
```
@check_output(contains=["this is a sentence"], does_not_contain=["foo", "bar"])
def llm_call(...) -> str:
    return response
```
1 parent 4bdb14c commit 0bcf0d2

File tree

2 files changed

+77
-0
lines changed

2 files changed

+77
-0
lines changed

hamilton/data_quality/default_validators.py

+69
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,73 @@ def arg(cls) -> str:
407407
return "allow_none"
408408

409409

410+
class StrContainsValidator(base.BaseDefaultValidator):
    """Validator that passes when a string contains every required substring.

    Matching uses Python's ``in`` operator on ``str``, so it is a case-sensitive
    substring test. An empty collection of substrings passes vacuously.
    """

    def __init__(self, contains: Union[str, List[str]], importance: str):
        """Store the substrings that must be present.

        :param contains: A single substring, or a collection of substrings that must
            all be found in the validated string.
        :param importance: Forwarded unchanged to the base validator's constructor.
        """
        super(StrContainsValidator, self).__init__(importance)
        if isinstance(contains, str):
            self.contains = [contains]
        else:
            # Copy into our own list: the validator must not change if the caller
            # later mutates their list, and a one-shot iterable must survive
            # repeated validate() calls.
            self.contains = list(contains)

    @classmethod
    def applies_to(cls, datatype: Type[Type]) -> bool:
        """Return True only when ``datatype`` is exactly ``str``."""
        return datatype == str

    def description(self) -> str:
        """Return a human-readable summary of what this validator checks."""
        return f"Validates that a string contains [{self.contains}] within it."

    def validate(self, data: str) -> base.ValidationResult:
        """Check that every configured substring occurs in ``data``.

        :param data: The string to inspect.
        :return: A ValidationResult. On failure the diagnostics hold the required
            substrings and at most the first 100 characters of ``data``; on success
            the diagnostics are empty.
        """
        # Generator expression so all() can stop at the first missing substring.
        passes = all(required in data for required in self.contains)
        if passes:
            return base.ValidationResult(passes=passes, message="All good.", diagnostics={})
        return base.ValidationResult(
            passes=passes,
            message=f"String did not contain {self.contains}",
            # Slicing already returns the whole string when it is shorter than 100 chars.
            diagnostics={"contains": self.contains, "data": data[:100]},
        )

    @classmethod
    def arg(cls) -> str:
        """Return the keyword name associated with this validator: ``contains``."""
        return "contains"
440+
441+
442+
class StrDoesNotContainValidator(base.BaseDefaultValidator):
    """Validator that passes when a string contains none of the forbidden substrings.

    Matching uses Python's ``in`` operator on ``str``, so it is a case-sensitive
    substring test. An empty collection of substrings passes vacuously.
    """

    def __init__(self, does_not_contain: Union[str, List[str]], importance: str):
        """Store the substrings that must be absent.

        :param does_not_contain: A single substring, or a collection of substrings,
            none of which may appear in the validated string.
        :param importance: Forwarded unchanged to the base validator's constructor.
        """
        super(StrDoesNotContainValidator, self).__init__(importance)
        if isinstance(does_not_contain, str):
            self.does_not_contain = [does_not_contain]
        else:
            # Copy into our own list: the validator must not change if the caller
            # later mutates their list, and a one-shot iterable must survive
            # repeated validate() calls.
            self.does_not_contain = list(does_not_contain)

    @classmethod
    def applies_to(cls, datatype: Type[Type]) -> bool:
        """Return True only when ``datatype`` is exactly ``str``."""
        return datatype == str

    def description(self) -> str:
        """Return a human-readable summary of what this validator checks."""
        return f"Validates that a string does not contain [{self.does_not_contain}] within it."

    def validate(self, data: str) -> base.ValidationResult:
        """Check that no configured substring occurs in ``data``.

        :param data: The string to inspect.
        :return: A ValidationResult. On failure the diagnostics hold the forbidden
            substrings and at most the first 100 characters of ``data``; on success
            the diagnostics are empty.
        """
        # Generator expression so any() can stop at the first forbidden substring found.
        passes = not any(forbidden in data for forbidden in self.does_not_contain)
        if passes:
            return base.ValidationResult(passes=passes, message="All good.", diagnostics={})
        return base.ValidationResult(
            passes=passes,
            message=f"String did contain {self.does_not_contain}",
            diagnostics={
                "does_not_contain": self.does_not_contain,
                # Slicing already returns the whole string when it is shorter than 100 chars.
                "data": data[:100],
            },
        )

    @classmethod
    def arg(cls) -> str:
        """Return the keyword name associated with this validator: ``does_not_contain``."""
        return "does_not_contain"
475+
476+
410477
AVAILABLE_DEFAULT_VALIDATORS = [
411478
AllowNaNsValidatorPandasSeries,
412479
DataInRangeValidatorPandasSeries,
@@ -419,6 +486,8 @@ def arg(cls) -> str:
419486
MaxStandardDevValidatorPandasSeries,
420487
MeanInRangeValidatorPandasSeries,
421488
AllowNoneValidator,
489+
StrContainsValidator,
490+
StrDoesNotContainValidator,
422491
]
423492

424493

tests/test_default_data_quality.py

+8
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,14 @@ def test_resolve_default_validators_error(output_type, kwargs, importance):
222222
(default_validators.AllowNoneValidator, False, 1, True),
223223
(default_validators.AllowNoneValidator, True, None, True),
224224
(default_validators.AllowNoneValidator, True, 1, True),
225+
(default_validators.StrContainsValidator, "o b", "foo bar baz", True),
226+
(default_validators.StrContainsValidator, "oof", "foo bar baz", False),
227+
(default_validators.StrContainsValidator, ["o b", "baz"], "foo bar baz", True),
228+
(default_validators.StrContainsValidator, ["oof", "bar"], "foo bar baz", False),
229+
(default_validators.StrDoesNotContainValidator, "o b", "foo bar baz", False),
230+
(default_validators.StrDoesNotContainValidator, "oof", "foo bar baz", True),
231+
(default_validators.StrDoesNotContainValidator, ["o b", "boo"], "foo bar baz", False),
232+
(default_validators.StrDoesNotContainValidator, ["oof", "boo"], "foo bar baz", True),
225233
],
226234
)
227235
def test_default_data_validators(

0 commit comments

Comments
 (0)