From f032421975d326d3ca9ba3b430b8f083f0dd79d9 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Sun, 4 May 2025 14:47:49 -0700 Subject: [PATCH 1/3] span annotation mutations --- .../api/mutations/span_annotations_mutations.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/src/phoenix/server/api/mutations/span_annotations_mutations.py b/src/phoenix/server/api/mutations/span_annotations_mutations.py index c8e5ef3354..02636e6828 100644 --- a/src/phoenix/server/api/mutations/span_annotations_mutations.py +++ b/src/phoenix/server/api/mutations/span_annotations_mutations.py @@ -5,6 +5,7 @@ from sqlalchemy import delete, insert, select from starlette.requests import Request from strawberry import UNSET, Info +from strawberry.relay import GlobalID from phoenix.db import models from phoenix.server.api.auth import IsLocked, IsNotReadOnly @@ -62,16 +63,13 @@ async def create_span_annotations( async with info.context.db() as session: for idx, (span_rowid, annotation_input) in enumerate(zip(span_rowids, input)): - resolved_identifier = annotation_input.identifier or "" - if annotation_input.source == AnnotationSource.APP: + resolved_identifier = "" + if annotation_input.identifier: + resolved_identifier = annotation_input.identifier + elif annotation_input.source == AnnotationSource.APP and user_id is not None: # Ensure that the annotation has a per-user identifier if submitted via the UI - if user_id is not None: - username = await session.scalar( - select(models.User.username).where(models.User.id == user_id) - ) - resolved_identifier = f"px-app:{username}" - else: - resolved_identifier = "px-app" + user_gid = str(GlobalID(type_name="User", node_id=str(user_id))) + resolved_identifier = f"px-app:{user_gid}" values = { "span_rowid": span_rowid, "name": annotation_input.name, From ef873c97d9cc36eae0b3c707919d5a1524c96e16 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Sun, 4 May 2025 15:34:05 -0700 Subject: [PATCH 2/3] trace mutations --- .../api/mutations/trace_annotations_mutations.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/src/phoenix/server/api/mutations/trace_annotations_mutations.py b/src/phoenix/server/api/mutations/trace_annotations_mutations.py index d307a33700..59dcd6c9f8 100644 --- a/src/phoenix/server/api/mutations/trace_annotations_mutations.py +++ b/src/phoenix/server/api/mutations/trace_annotations_mutations.py @@ -4,6 +4,7 @@ from sqlalchemy import delete, insert, select from starlette.requests import Request from strawberry import UNSET, Info +from strawberry.relay.types import GlobalID from phoenix.db import models from phoenix.server.api.auth import IsLocked, IsNotReadOnly @@ -55,16 +56,13 @@ async def create_trace_annotations( async with info.context.db() as session: for idx, (trace_rowid, annotation_input) in enumerate(zip(trace_rowids, input)): - resolved_identifier = annotation_input.identifier or "" - if annotation_input.source == AnnotationSource.APP: + resolved_identifier = "" + if annotation_input.identifier: + resolved_identifier = annotation_input.identifier + elif annotation_input.source == AnnotationSource.APP and user_id is not None: # Ensure that the annotation has a per-user identifier if submitted via the UI - if user_id is not None: - username = await session.scalar( - select(models.User.username).where(models.User.id == user_id) - ) - resolved_identifier = f"px-app:{username}" - else: - resolved_identifier = "px-app" + user_gid = str(GlobalID(type_name="User", node_id=str(user_id))) + resolved_identifier = f"px-app:{user_gid}" values = { "trace_rowid": trace_rowid, "name": annotation_input.name, From 4948e16ac98df37ea0fae042379b2b7c643d367c Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Sun, 4 May 2025 15:48:31 -0700 Subject: [PATCH 3/3] clean --- .../server/api/mutations/span_annotations_mutations.py | 6 +----- .../server/api/mutations/trace_annotations_mutations.py | 6 +----- 2 files changed, 2 insertions(+), 10 deletions(-) diff --git a/src/phoenix/server/api/mutations/span_annotations_mutations.py b/src/phoenix/server/api/mutations/span_annotations_mutations.py index 02636e6828..318154c90c 100644 --- a/src/phoenix/server/api/mutations/span_annotations_mutations.py +++ b/src/phoenix/server/api/mutations/span_annotations_mutations.py @@ -88,12 +88,8 @@ async def create_span_annotations( q = select(models.SpanAnnotation).where( models.SpanAnnotation.span_rowid == span_rowid, models.SpanAnnotation.name == annotation_input.name, + models.SpanAnnotation.identifier == resolved_identifier, ) - if resolved_identifier is None: - q = q.where(models.SpanAnnotation.identifier.is_(None)) - else: - q = q.where(models.SpanAnnotation.identifier == resolved_identifier) - existing_annotation = await session.scalar(q) if existing_annotation: diff --git a/src/phoenix/server/api/mutations/trace_annotations_mutations.py b/src/phoenix/server/api/mutations/trace_annotations_mutations.py index 59dcd6c9f8..633bcab467 100644 --- a/src/phoenix/server/api/mutations/trace_annotations_mutations.py +++ b/src/phoenix/server/api/mutations/trace_annotations_mutations.py @@ -82,12 +82,8 @@ async def create_trace_annotations( q = select(models.TraceAnnotation).where( models.TraceAnnotation.trace_rowid == trace_rowid, models.TraceAnnotation.name == annotation_input.name, + models.TraceAnnotation.identifier == resolved_identifier, ) - if resolved_identifier is None: - q = q.where(models.TraceAnnotation.identifier.is_(None)) - else: - q = q.where(models.TraceAnnotation.identifier == resolved_identifier) - existing_annotation = await session.scalar(q) if existing_annotation: