Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions adala/skills/collection/entity_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,14 @@ def validate_output_format_for_ner_tag(
continue
text = row[input_field_name]
entities = row[output_field_name]
for entity in entities:
corrected_text = text[entity["start"] : entity["end"]]
if entity.get("text") is None:
entity["text"] = corrected_text
elif entity["text"] != corrected_text:
# this seems to happen rarely if at all in testing, but could lead to invalid predictions
logger.warning(f"text and indices disagree for a predicted entity")
if entities is not None:
for entity in entities:
corrected_text = text[entity["start"] : entity["end"]]
if entity.get("text") is None:
entity["text"] = corrected_text
elif entity["text"] != corrected_text:
# this seems to happen rarely if at all in testing, but could lead to invalid predictions
logger.warning(f"text and indices disagree for a predicted entity")
return df


Expand Down
147 changes: 61 additions & 86 deletions adala/utils/parse.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import re
import string
import logging
from string import Formatter
from typing import (
List,
TypedDict,
Expand All @@ -21,102 +20,78 @@


class _UnresolvedField(str):
    """Source text (braces included) of a replacement field with no supplied value.

    A dedicated ``str`` subclass is used so that the formatter can tell a
    preserved placeholder apart from a user-supplied string that merely
    happens to start with ``{``.
    """

    __slots__ = ()


class PartialStringFormatter(string.Formatter):
    """``string.Formatter`` that leaves fields it cannot fill in the output.

    * A field whose value is not supplied (missing keyword, missing positional
      index, or missing attribute) is emitted unchanged, including any
      ``!conversion`` and ``:format_spec``.
    * A field whose value is supplied but cannot be formatted with the given
      format spec is also emitted unchanged instead of raising.

    NOTE: the formatter keeps per-call bookkeeping on the instance, so a single
    instance must not be used from several threads at the same time.
    """

    def __init__(self):
        super().__init__()
        # Most recent field name seen by get_field(); only used as a fallback
        # when format_field() is called outside of vformat().
        self._current_field_name = None
        # Source text of the fields that have been looked up but not yet
        # formatted during the running vformat() call (None when idle).
        # Nested format specs such as "{x:{w}}" look up "w" before "x" is
        # formatted, so a single "current name" slot is not enough.
        self._open_fields = None

    def vformat(self, format_string, args, kwargs):
        """Format ``format_string`` while tracking which field is being rendered."""
        outer = self._open_fields
        self._open_fields = []
        try:
            return super().vformat(format_string, args, kwargs)
        finally:
            # Restore rather than clear, so a re-entrant call (e.g. from a
            # value's __format__) or an exception cannot corrupt the state.
            self._open_fields = outer

    def get_field(self, field_name, args, kwargs):
        """Look up ``field_name``; return a preserved placeholder when it is missing.

        Returns:
            ``(obj, used_key)`` as in ``string.Formatter.get_field``. For an
            unfilled field ``obj`` is the string ``"{" + field_name + "}"``.
        """
        self._current_field_name = field_name
        try:
            found = super().get_field(field_name, args, kwargs)
        except (KeyError, IndexError, AttributeError):
            # KeyError: missing keyword; IndexError: missing positional
            # argument; AttributeError: missing attribute in "a.b".
            found = _UnresolvedField("{" + field_name + "}"), field_name
        if self._open_fields is not None:
            self._open_fields.append(field_name)
        return found

    def convert_field(self, value, conversion):
        """Apply ``!r`` / ``!s`` / ``!a``, keeping it as text for unfilled fields."""
        if conversion is None:
            return value
        if self._open_fields:
            # Remember the conversion so a later formatting failure can
            # reproduce the field exactly as it was written.
            self._open_fields[-1] += "!" + conversion
        if isinstance(value, _UnresolvedField):
            return _UnresolvedField(value[:-1] + "!" + conversion + "}")
        return super().convert_field(value, conversion)

    def format_field(self, value, format_spec):
        """Format ``value``; fall back to the original field text on failure."""
        if self._open_fields:
            source = self._open_fields.pop()
        else:
            source = self._current_field_name

        if isinstance(value, _UnresolvedField):
            # Preserved placeholder: re-attach the format spec, if any.
            if format_spec:
                return value[:-1] + ":" + format_spec + "}"
            return str(value)

        try:
            return super().format_field(value, format_spec)
        except (ValueError, TypeError):
            # The spec does not apply to this value: keep the field as written.
            if format_spec and source is not None:
                return "{" + source + ":" + format_spec + "}"
            return str(value)


# Shared module-level formatter instance; partial_str_format() delegates to it.
PartialStringFormat = PartialStringFormatter()


def partial_str_format(format_string: str, **kwargs) -> str:
    """
    Format a string with the provided variables while preserving any unprovided placeholders.

    Only ``{...}`` groups that contain no nested braces are treated as
    replacement fields. Every other ``{`` or ``}`` (unmatched brackets, the
    braces of embedded JSON, ...) is treated as literal text and is emitted
    unchanged.

    Args:
        format_string: The string to format
        **kwargs: The variables to use for formatting

    Returns:
        The formatted string with preserved unprovided placeholders and format specifiers

    Examples:
        >>> partial_str_format("Hello {name}!", name="World")
        'Hello World!'
        >>> partial_str_format("Hello {name} {unknown}!", name="World")
        'Hello World {unknown}!'
        >>> partial_str_format("Value: {x:.2f}", x="not_a_float")
        'Value: {x:.2f}'
        >>> partial_str_format('{"k": "{v}"}', v=1)
        '{"k": "1"}'
    """
    if not format_string:
        return ""

    # Splitting on a capturing group yields literal text at the even indices
    # and the candidate replacement fields at the odd indices. Working on the
    # pieces directly avoids substituting sentinel text into the string, which
    # would corrupt input that happens to contain that sentinel.
    pieces = re.split(r"(\{[^{}]+\})", format_string)

    # Escape stray brackets in the literal pieces so the formatter treats them
    # as plain characters; the field pieces are passed through untouched.
    pieces[::2] = [
        literal.replace("{", "{{").replace("}", "}}") for literal in pieces[::2]
    ]

    return PartialStringFormat.format("".join(pieces), **kwargs)


class TemplateChunks(TypedDict):
Expand Down
14 changes: 14 additions & 0 deletions tests/test_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,20 @@ def test_partial_string_format():
result = partial_str_format("{missing1} and {missing2}")
assert result == "{missing1} and {missing2}"

# Test with unmatched brackets
result = partial_str_format("{ {text}", text="test")
assert result == "{ test"

# Test adversarial example
result = partial_str_format(
'{"key": "value", "text": "{text}", "unused1": "{unused}", "nested": {"subkey": "{more_unexpected}"}, "break": "\\" \' \\n \\t \\b \\f \\r \\\\ \\/ {unmatched", "unicode": "\uD83D\uDE00", "null": null, "number": 123, "array": [1, 2, "{array_item}"], "weird": "\u0000\u001F"}',
text="test",
)
assert (
result
== '{"key": "value", "text": "test", "unused1": "{unused}", "nested": {"subkey": "{more_unexpected}"}, "break": "\\" \' \\n \\t \\b \\f \\r \\\\ \\/ {unmatched", "unicode": "\uD83D\uDE00", "null": null, "number": 123, "array": [1, 2, "{array_item}"], "weird": "\u0000\u001F"}'
)

# Test larger prompt
prompt = """
Given the following product review:
Expand Down