Skip to content

feat: Update spans dsl to search for annotation existence #7406

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: feat/annotations
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 25 additions & 5 deletions src/phoenix/trace/dsl/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@
from uuid import uuid4

import sqlalchemy
from sqlalchemy import case, literal
from sqlalchemy.orm import Mapped, aliased
from sqlalchemy.orm.util import AliasedClass
from sqlalchemy.sql.expression import Select
from sqlalchemy.sql.expression import ColumnElement, Select
from typing_extensions import TypeAlias, TypeGuard, assert_never

import phoenix.trace.v1 as pb
Expand All @@ -31,6 +32,8 @@
r"""\b((annotations|evals)\[(".*?"|'.*?')\][.](label|score))\b"""
)

EVAL_NAME_PATTERN = re.compile(r"""(?<!\w)((annotations|evals)\[(".*?"|'.*?')\])(?![\w\.])""")


@dataclass(frozen=True)
class AliasedAnnotationRelation:
Expand All @@ -46,26 +49,33 @@ class AliasedAnnotationRelation:
table: AliasedClass[models.SpanAnnotation] = field(init=False, repr=False)
_label_attribute_alias: str = field(init=False, repr=False)
_score_attribute_alias: str = field(init=False, repr=False)
_exists_attribute_alias: str = field(init=False, repr=False)

def __post_init__(self) -> None:
table_alias = f"span_annotation_{self.index}"
alias_id = uuid4().hex
label_attribute_alias = f"{table_alias}_label_{alias_id}"
score_attribute_alias = f"{table_alias}_score_{alias_id}"
exists_attribute_alias = f"{table_alias}_exists_{alias_id}"

table = aliased(models.SpanAnnotation, name=table_alias)
object.__setattr__(self, "_label_attribute_alias", label_attribute_alias)
object.__setattr__(self, "_score_attribute_alias", score_attribute_alias)
object.__setattr__(self, "_exists_attribute_alias", exists_attribute_alias)
object.__setattr__(self, "table", table)

@property
def attributes(self) -> typing.Iterator[tuple[str, Mapped[typing.Any]]]:
def attributes(self) -> typing.Iterator[tuple[str, ColumnElement[typing.Any]]]:
"""
Alias names and attributes (i.e., columns) of the `span_annotation`
relation.
"""
yield self._label_attribute_alias, self.table.label
yield self._score_attribute_alias, self.table.score
yield (
self._exists_attribute_alias,
case((self.table.id.is_not(None), literal(True)), else_=literal(False)),
)

def attribute_alias(self, attribute: AnnotationAttribute) -> str:
"""
Expand Down Expand Up @@ -555,6 +565,7 @@ def _validate_expression(
isinstance(node, (ast.BoolOp, ast.Compare))
or isinstance(node, ast.UnaryOp)
and isinstance(node.op, ast.Not)
or _is_annotation(node)
):
continue
elif (
Expand Down Expand Up @@ -783,7 +794,7 @@ def _apply_eval_aliasing(
eval_aliases: dict[AnnotationName, AliasedAnnotationRelation] = {}
for (
annotation_expression,
annotation_type,
_annotation_type,
annotation_name,
annotation_attribute,
) in _parse_annotation_expressions_and_names(source):
Expand All @@ -792,6 +803,15 @@ def _apply_eval_aliasing(
eval_aliases[annotation_name] = eval_alias
alias_name = eval_alias.attribute_alias(annotation_attribute)
source = source.replace(annotation_expression, alias_name)

for match in EVAL_NAME_PATTERN.finditer(source):
annotation_expression, _, quoted_eval_name = match.groups()
annotation_name = quoted_eval_name[1:-1]
if (eval_alias := eval_aliases.get(annotation_name)) is None:
eval_alias = AliasedAnnotationRelation(index=len(eval_aliases), name=annotation_name)
eval_aliases[annotation_name] = eval_alias
alias_name = eval_alias._exists_attribute_alias
source = source.replace(annotation_expression, alias_name)
return source, tuple(eval_aliases.values())


Expand All @@ -811,11 +831,11 @@ def _parse_annotation_expressions_and_names(
for match in EVAL_EXPRESSION_PATTERN.finditer(source):
(
annotation_expression,
annotation_type,
_annotation_type,
quoted_eval_name,
evaluation_attribute_name,
) = match.groups()
annotation_type = typing.cast(AnnotationType, annotation_type)
annotation_type = typing.cast(AnnotationType, _annotation_type)
yield (
annotation_expression,
annotation_type,
Expand Down
11 changes: 11 additions & 0 deletions tests/unit/trace/dsl/test_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,17 @@ async def test_filter_translated(
"span_annotation_0_label_00000000000000000000000000000000 is not None",
id="double-quoted-annotation-name",
),
# Existence checks (bare annotation reference)
pytest.param(
"""evals['Hallucination']""",
"span_annotation_0_exists_00000000000000000000000000000000",
id="bare-evals-exists",
),
pytest.param(
"""annotations['Hallucination']""",
"span_annotation_0_exists_00000000000000000000000000000000",
id="bare-annotations-exists",
),
],
)
def test_apply_eval_aliasing(filter_condition: str, expected: str) -> None:
Expand Down
16 changes: 16 additions & 0 deletions tests/unit/trace/dsl/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -914,6 +914,22 @@ async def test_filter_on_trace_id_multiple(
["evals['0'].score is None or evals['1'].label is not None", ["234", "456", "567"]],
["evals['0'].score == 0 or evals['1'].label != '1'", ["345", "567"]],
["evals['0'].score != 0 or evals['1'].label == '1'", ["456"]],
[
"evals['0']",
["345", "456"],
],
[
"annotations['0']",
["345", "456"],
],
[
"evals['1']",
["456", "567"],
],
[
"annotations['1']",
["456", "567"],
],
],
)
async def test_filter_on_span_annotation(
Expand Down
Loading