diff --git a/adala/skills/collection/entity_extraction.py b/adala/skills/collection/entity_extraction.py
index 18177882..3ca81b91 100644
--- a/adala/skills/collection/entity_extraction.py
+++ b/adala/skills/collection/entity_extraction.py
@@ -32,13 +32,14 @@ def validate_output_format_for_ner_tag(
             continue
         text = row[input_field_name]
         entities = row[output_field_name]
-        for entity in entities:
-            corrected_text = text[entity["start"] : entity["end"]]
-            if entity.get("text") is None:
-                entity["text"] = corrected_text
-            elif entity["text"] != corrected_text:
-                # this seems to happen rarely if at all in testing, but could lead to invalid predictions
-                logger.warning(f"text and indices disagree for a predicted entity")
+        if entities is not None:
+            for entity in entities:
+                corrected_text = text[entity["start"] : entity["end"]]
+                if entity.get("text") is None:
+                    entity["text"] = corrected_text
+                elif entity["text"] != corrected_text:
+                    # this seems to happen rarely if at all in testing, but could lead to invalid predictions
+                    logger.error("text and indices disagree for a predicted entity")
     return df
 
 
diff --git a/adala/utils/parse.py b/adala/utils/parse.py
index b367c463..3a07dd3a 100644
--- a/adala/utils/parse.py
+++ b/adala/utils/parse.py
@@ -1,7 +1,6 @@
 import re
 import string
 import logging
-from string import Formatter
 from typing import (
     List,
     TypedDict,
@@ -21,102 +20,97 @@
 
 
+class _UnfilledField(str):
+    """Placeholder text for a template field that no argument was provided for.
+
+    A dedicated type lets the formatter tell a preserved placeholder apart from
+    a caller-provided value that merely starts with "{" (e.g. a JSON string).
+    """
+
+
 class PartialStringFormatter(string.Formatter):
-    def get_value(self, key, args, kwds):
-        if isinstance(key, str):
-            try:
-                return kwds[key]
-            except KeyError:
-                return "{" + key + "}"
-        else:
-            Formatter.get_value(key, args, kwds)
+    """Formatter that leaves unresolvable replacement fields in the output unchanged."""
+
+    def __init__(self):
+        super().__init__()
+        # Name of the field most recently looked up; format_field uses it to
+        # rebuild the placeholder when the value rejects the format spec.
+        self._current_field_name = None
 
-    def format_field(self, value, format_spec):
+    def get_field(self, field_name, args, kwargs):
+        self._current_field_name = field_name
         try:
-            return super().format_field(value, format_spec)
-        except ValueError:
-            # HACK: the value was an unfilled variable or not a variable at all, so the format spec should be considered part of the variable name
-            if value.startswith("{") and value.endswith("}"):
-                return value[:-1] + ":" + format_spec + "}"
+            return super().get_field(field_name, args, kwargs)
+        except (KeyError, IndexError, AttributeError):
+            # Unprovided keyword (KeyError), positional argument or index
+            # (IndexError), or attribute (AttributeError): keep the placeholder.
+            return _UnfilledField("{" + field_name + "}"), field_name
 
-    def _vformat(
-        self, format_string, args, kwargs, used_args, recursion_depth, auto_arg_index=0
-    ):
-        # copied verbatim from parent class except for the # HACK
-        if recursion_depth < 0:
-            raise ValueError("Max string recursion exceeded")
-        result = []
-        for literal_text, field_name, format_spec, conversion in self.parse(
-            format_string
-        ):
-
-            # output the literal text
-            if literal_text:
-                result.append(literal_text)
-
-            # if there's a field, output it
-            if field_name is not None:
-                # this is some markup, find the object and do
-                #  the formatting
-
-                # handle arg indexing when empty field_names are given.
-                if field_name == "":
-                    if auto_arg_index is False:
-                        raise ValueError(
-                            "cannot switch from manual field "
-                            "specification to automatic field "
-                            "numbering"
-                        )
-                    field_name = str(auto_arg_index)
-                    auto_arg_index += 1
-                elif field_name.isdigit():
-                    if auto_arg_index:
-                        raise ValueError(
-                            "cannot switch from manual field "
-                            "specification to automatic field "
-                            "numbering"
-                        )
-                    # disable auto arg incrementing, if it gets
-                    # used later on, then an exception will be raised
-                    auto_arg_index = False
-
-                # given the field_name, find the object it references
-                #  and the argument it came from
-                obj, arg_used = self.get_field(field_name, args, kwargs)
-                used_args.add(arg_used)
-
-                # do any conversion on the resulting object
-                obj = self.convert_field(obj, conversion)
-
-                # expand the format spec, if needed
-                format_spec, auto_arg_index = self._vformat(
-                    format_spec,
-                    args,
-                    kwargs,
-                    used_args,
-                    recursion_depth - 1,
-                    auto_arg_index=auto_arg_index,
-                )
-
-                # format the object and append to the result
-                # HACK: if the format_spec is invalid, assume this field_name was not meant to be a variable, and don't substitute anything
-                formatted_field = self.format_field(obj, format_spec)
-                if formatted_field is None:
-                    result.append("{" + ":".join([field_name, format_spec]) + "}")
-                else:
-                    result.append(formatted_field)
-
-        return "".join(result), auto_arg_index
+    def convert_field(self, value, conversion):
+        if isinstance(value, _UnfilledField):
+            # Keep the conversion flag in the placeholder instead of applying it
+            if conversion is None:
+                return value
+            return _UnfilledField(value[:-1] + "!" + conversion + "}")
+        return super().convert_field(value, conversion)
+
+    def format_field(self, value, format_spec):
+        if isinstance(value, _UnfilledField):
+            # This is a preserved placeholder, re-attach its format spec if any
+            if format_spec:
+                return value[:-1] + ":" + format_spec + "}"
+            return str(value)
+
+        try:
+            return super().format_field(value, format_spec)
+        except (ValueError, TypeError):
+            # If format spec is invalid, preserve the original field name and format spec
+            if format_spec:
+                return "{" + self._current_field_name + ":" + format_spec + "}"
+            return str(value)
 
 
 PartialStringFormat = PartialStringFormatter()
 
 
-def partial_str_format(string, **kwargs):
-    """
-    Formats a string with a subset of the arguments.
-    Analogous to str.format, but ignores missing arguments.
-    """
-    return PartialStringFormat.format(string, **kwargs)
+# Matches one non-nested "{...}" field candidate; any other brace is literal text
+_FIELD_PATTERN = re.compile(r"\{[^{}]+\}")
+
+
+def partial_str_format(format_string: str, **kwargs) -> str:
+    """
+    Format a string with the provided variables while preserving any unprovided placeholders.
+    Preserves format specifiers for both provided and unprovided variables.
+
+    Args:
+        format_string: The string to format
+        **kwargs: The variables to use for formatting
+
+    Returns:
+        The formatted string with preserved unprovided placeholders and format specifiers
+
+    Examples:
+        >>> partial_str_format("Hello {name}!", name="World")
+        'Hello World!'
+        >>> partial_str_format("Hello {name} {unknown}!", name="World")
+        'Hello World {unknown}!'
+        >>> partial_str_format("Value: {x:.2f}", x="not_a_float")
+        'Value: {x:.2f}'
+    """
+    if not format_string:
+        return ""
+
+    # Escape every brace outside a field candidate in a single pass. Nothing is
+    # substituted into the text, so literal input can never collide with a marker.
+    pieces = []
+    cursor = 0
+    for match in _FIELD_PATTERN.finditer(format_string):
+        literal = format_string[cursor : match.start()]
+        pieces.append(literal.replace("{", "{{").replace("}", "}}"))
+        pieces.append(match.group(0))
+        cursor = match.end()
+    pieces.append(format_string[cursor:].replace("{", "{{").replace("}", "}}"))
+
+    return PartialStringFormat.format("".join(pieces), **kwargs)
 
 
 class TemplateChunks(TypedDict):
diff --git a/tests/test_parse.py b/tests/test_parse.py
index 77145086..2b589c1e 100644
--- a/tests/test_parse.py
+++ b/tests/test_parse.py
@@ -45,6 +45,32 @@ def test_partial_string_format():
     result = partial_str_format("{missing1} and {missing2}")
     assert result == "{missing1} and {missing2}"
 
+    # Test with unmatched brackets
+    result = partial_str_format("{ {text}", text="test")
+    assert result == "{ test"
+
+    # Test adversarial example
+    result = partial_str_format(
+        '{"key": "value", "text": "{text}", "unused1": "{unused}", "nested": {"subkey": "{more_unexpected}"}, "break": "\\" \' \\n \\t \\b \\f \\r \\\\ \\/ {unmatched", "unicode": "\uD83D\uDE00", "null": null, "number": 123, "array": [1, 2, "{array_item}"], "weird": "\u0000\u001F"}',
+        text="test",
+    )
+    assert (
+        result
+        == '{"key": "value", "text": "test", "unused1": "{unused}", "nested": {"subkey": "{more_unexpected}"}, "break": "\\" \' \\n \\t \\b \\f \\r \\\\ \\/ {unmatched", "unicode": "\uD83D\uDE00", "null": null, "number": 123, "array": [1, 2, "{array_item}"], "weird": "\u0000\u001F"}'
+    )
+
+    # Test provided value that itself looks like a placeholder
+    result = partial_str_format("{text:>6}", text="{a}")
+    assert result == "   {a}"
+
+    # Test positional and converted placeholders are preserved
+    result = partial_str_format("{0} {name!r}")
+    assert result == "{0} {name!r}"
+
+    # Test literal text that resembles an internal marker
+    result = partial_str_format("__MARKER_0__ {text}", text="test")
+    assert result == "__MARKER_0__ test"
+
     # Test larger prompt
     prompt = """
     Given the following product review: