Skip to content

Commit e466d2f

Browse files
authored
feat(annotations): ensure all span annotations are included in dataset examples (#7412)
1 parent 7e34714 commit e466d2f

File tree

4 files changed

+104
-58
lines changed

4 files changed

+104
-58
lines changed

src/phoenix/db/models.py

+4
Original file line numberDiff line numberDiff line change
@@ -684,6 +684,7 @@ def _llm_token_count_total_expression(cls) -> ColumnElement[int]:
684684
)
685685

686686
trace: Mapped["Trace"] = relationship("Trace", back_populates="spans")
687+
span_annotations: Mapped[list["SpanAnnotation"]] = relationship(back_populates="span")
687688
document_annotations: Mapped[list["DocumentAnnotation"]] = relationship(back_populates="span")
688689
dataset_examples: Mapped[list["DatasetExample"]] = relationship(back_populates="span")
689690

@@ -830,6 +831,9 @@ class SpanAnnotation(Base):
830831
)
831832
user_id: Mapped[Optional[int]] = mapped_column(ForeignKey("users.id", ondelete="SET NULL"))
832833

834+
span: Mapped["Span"] = relationship(back_populates="span_annotations")
835+
user: Mapped[Optional["User"]] = relationship("User")
836+
833837
__table_args__ = (
834838
UniqueConstraint(
835839
"name",

src/phoenix/server/api/mutations/dataset_mutations.py

+62-39
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@
1111
ToolCallAttributes,
1212
)
1313
from sqlalchemy import and_, delete, distinct, func, insert, select, update
14+
from sqlalchemy.orm import contains_eager
1415
from strawberry import UNSET
16+
from strawberry.relay.types import GlobalID
1517
from strawberry.types import Info
1618

1719
from phoenix.db import models
@@ -130,43 +132,40 @@ async def add_spans_to_dataset(
130132
raise ValueError(
131133
f"Unknown dataset: {dataset_id}"
132134
) # todo: implement error types https://github.com/Arize-ai/phoenix/issues/3221
133-
dataset_version_rowid = await session.scalar(
134-
insert(models.DatasetVersion)
135-
.values(
136-
dataset_id=dataset_rowid,
137-
description=dataset_version_description,
138-
metadata_=dataset_version_metadata,
139-
)
140-
.returning(models.DatasetVersion.id)
135+
dataset_version = models.DatasetVersion(
136+
dataset_id=dataset_rowid,
137+
description=dataset_version_description,
138+
metadata_=dataset_version_metadata or {},
141139
)
140+
session.add(dataset_version)
141+
await session.flush()
142142
spans = (
143-
await session.scalars(select(models.Span).where(models.Span.id.in_(span_rowids)))
144-
).all()
145-
if missing_span_rowids := span_rowids - {span.id for span in spans}:
146-
raise ValueError(
147-
f"Could not find spans with rowids: {', '.join(map(str, missing_span_rowids))}"
148-
) # todo: implement error handling types https://github.com/Arize-ai/phoenix/issues/3221
149-
150-
span_annotations = (
151-
await session.scalars(
152-
select(models.SpanAnnotation).where(
153-
models.SpanAnnotation.span_rowid.in_(span_rowids)
143+
(
144+
await session.scalars(
145+
select(models.Span)
146+
.outerjoin(
147+
models.SpanAnnotation,
148+
models.Span.id == models.SpanAnnotation.span_rowid,
149+
)
150+
.outerjoin(models.User, models.SpanAnnotation.user_id == models.User.id)
151+
.order_by(
152+
models.Span.id,
153+
models.SpanAnnotation.name,
154+
models.User.username,
155+
)
156+
.where(models.Span.id.in_(span_rowids))
157+
.options(
158+
contains_eager(models.Span.span_annotations).contains_eager(
159+
models.SpanAnnotation.user
160+
)
161+
)
154162
)
155163
)
156-
).all()
157-
158-
span_annotations_by_span: dict[int, dict[Any, Any]] = {span.id: {} for span in spans}
159-
for annotation in span_annotations:
160-
span_id = annotation.span_rowid
161-
if span_id not in span_annotations_by_span:
162-
span_annotations_by_span[span_id] = dict()
163-
span_annotations_by_span[span_id][annotation.name] = {
164-
"label": annotation.label,
165-
"score": annotation.score,
166-
"explanation": annotation.explanation,
167-
"metadata": annotation.metadata_,
168-
"annotator_kind": annotation.annotator_kind,
169-
}
164+
.unique()
165+
.all()
166+
)
167+
if span_rowids - {span.id for span in spans}:
168+
raise NotFound("Some spans could not be found")
170169

171170
DatasetExample = models.DatasetExample
172171
dataset_example_rowids = (
@@ -201,7 +200,7 @@ async def add_spans_to_dataset(
201200
[
202201
{
203202
DatasetExampleRevision.dataset_example_id.key: dataset_example_rowid,
204-
DatasetExampleRevision.dataset_version_id.key: dataset_version_rowid,
203+
DatasetExampleRevision.dataset_version_id.key: dataset_version.id,
205204
DatasetExampleRevision.input.key: get_dataset_example_input(span),
206205
DatasetExampleRevision.output.key: get_dataset_example_output(span),
207206
DatasetExampleRevision.metadata_.key: {
@@ -212,11 +211,7 @@ async def add_spans_to_dataset(
212211
if k in nonprivate_span_attributes
213212
},
214213
"span_kind": span.span_kind,
215-
**(
216-
{"annotations": annotations}
217-
if (annotations := span_annotations_by_span[span.id])
218-
else {}
219-
),
214+
"annotations": _gather_span_annotations_by_name(span.span_annotations),
220215
},
221216
DatasetExampleRevision.revision_kind.key: "CREATE",
222217
}
@@ -602,6 +597,34 @@ def _to_orm_revision(
602597
}
603598

604599

600+
def _gather_span_annotations_by_name(
    span_annotations: list[models.SpanAnnotation],
) -> dict[str, list[dict[str, Any]]]:
    """
    Groups span annotations by annotation name.

    Each annotation is converted with `_to_span_annotation_dict` and appended
    to the list stored under its name, so several annotations sharing a name
    are all kept. Names appear in first-seen order, and the entries under a
    name keep the order of the input list. An empty input yields an empty
    dict.
    """
    grouped: dict[str, list[dict[str, Any]]] = {}
    for annotation in span_annotations:
        # setdefault creates the per-name list on first sight of the name.
        grouped.setdefault(annotation.name, []).append(_to_span_annotation_dict(annotation))
    return grouped
611+
612+
613+
def _to_span_annotation_dict(span_annotation: models.SpanAnnotation) -> dict[str, Any]:
    """
    Converts a span annotation into a plain dict.

    The result carries the annotation's label, score, explanation, metadata
    and annotator kind, followed by three fields describing its user:

    - "user_id": a stringified `GlobalID` built from the `User` model name and
      the annotation's `user_id`, or None when `user_id` is None.
    - "username" / "email": read from the annotation's `user` relationship, or
      None when no user object is attached.
    """
    annotation_dict: dict[str, Any] = {
        "label": span_annotation.label,
        "score": span_annotation.score,
        "explanation": span_annotation.explanation,
        "metadata": span_annotation.metadata_,
        "annotator_kind": span_annotation.annotator_kind,
    }

    # The ID comes from the foreign-key column, not from the related object.
    user_rowid = span_annotation.user_id
    if user_rowid is None:
        annotation_dict["user_id"] = None
    else:
        annotation_dict["user_id"] = str(GlobalID(models.User.__name__, str(user_rowid)))

    # Username and email require the related user object to be present.
    annotator = span_annotation.user
    if annotator is None:
        annotation_dict["username"] = None
        annotation_dict["email"] = None
    else:
        annotation_dict["username"] = annotator.username
        annotation_dict["email"] = annotator.email
    return annotation_dict
626+
627+
605628
INPUT_MIME_TYPE = SpanAttributes.INPUT_MIME_TYPE
606629
INPUT_VALUE = SpanAttributes.INPUT_VALUE
607630
OUTPUT_MIME_TYPE = SpanAttributes.OUTPUT_MIME_TYPE

tests/unit/graphql.py

+3
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@ class GraphQLError(Exception):
1414
def __init__(self, message: str) -> None:
1515
self.message = message
1616

17+
def __repr__(self) -> str:
    """Returns a debugging representation that shows the error message."""
    template = 'GraphQLError(message="{}")'
    return template.format(self.message)
19+
1720

1821
@dataclass
1922
class GraphQLExecutionResult:

tests/unit/server/api/mutations/test_dataset_mutations.py

+35-19
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ async def test_updating_a_single_field_leaves_remaining_fields_unchannged(
148148
async def test_add_span_to_dataset(
149149
gql_client: AsyncGraphQLClient,
150150
empty_dataset: None,
151-
spans: None,
151+
spans: list[models.Span],
152152
span_annotation: None,
153153
) -> None:
154154
dataset_id = GlobalID(type_name="Dataset", node_id=str(1))
@@ -176,10 +176,7 @@ async def test_add_span_to_dataset(
176176
query=mutation,
177177
variables={
178178
"datasetId": str(dataset_id),
179-
"spanIds": [
180-
str(GlobalID(type_name="Span", node_id=span_id))
181-
for span_id in map(str, range(1, 4))
182-
],
179+
"spanIds": [str(GlobalID(type_name="Span", node_id=str(span.id))) for span in spans],
183180
},
184181
)
185182
assert not response.errors
@@ -199,6 +196,7 @@ async def test_add_span_to_dataset(
199196
},
200197
"metadata": {
201198
"span_kind": "LLM",
199+
"annotations": {},
202200
},
203201
"output": {
204202
"messages": [
@@ -225,6 +223,7 @@ async def test_add_span_to_dataset(
225223
},
226224
"metadata": {
227225
"span_kind": "RETRIEVER",
226+
"annotations": {},
228227
},
229228
}
230229
}
@@ -237,13 +236,18 @@ async def test_add_span_to_dataset(
237236
"metadata": {
238237
"span_kind": "CHAIN",
239238
"annotations": {
240-
"test annotation": {
241-
"label": "ambiguous",
242-
"score": 0.5,
243-
"explanation": "meaningful words",
244-
"metadata": {},
245-
"annotator_kind": "HUMAN",
246-
}
239+
"test annotation": [
240+
{
241+
"label": "ambiguous",
242+
"score": 0.5,
243+
"explanation": "meaningful words",
244+
"metadata": {},
245+
"annotator_kind": "HUMAN",
246+
"user_id": None,
247+
"username": None,
248+
"email": None,
249+
}
250+
]
247251
},
248252
},
249253
}
@@ -524,11 +528,12 @@ async def empty_dataset(db: DbSessionFactory) -> None:
524528

525529

526530
@pytest.fixture
527-
async def spans(db: DbSessionFactory) -> None:
531+
async def spans(db: DbSessionFactory) -> list[models.Span]:
528532
"""
529533
Inserts three spans from a single trace: a chain root span, a retriever
530534
child span, and an llm child span.
531535
"""
536+
spans = []
532537
async with db() as session:
533538
project_row_id = await session.scalar(
534539
insert(models.Project).values(name=DEFAULT_PROJECT_NAME).returning(models.Project.id)
@@ -543,7 +548,7 @@ async def spans(db: DbSessionFactory) -> None:
543548
)
544549
.returning(models.Trace.id)
545550
)
546-
await session.execute(
551+
span = await session.scalar(
547552
insert(models.Span)
548553
.values(
549554
trace_rowid=trace_row_id,
@@ -564,10 +569,13 @@ async def spans(db: DbSessionFactory) -> None:
564569
cumulative_llm_token_count_prompt=0,
565570
cumulative_llm_token_count_completion=0,
566571
)
567-
.returning(models.Span.id)
572+
.returning(models.Span)
568573
)
569-
await session.execute(
570-
insert(models.Span).values(
574+
assert span is not None
575+
spans.append(span)
576+
span = await session.scalar(
577+
insert(models.Span)
578+
.values(
571579
trace_rowid=trace_row_id,
572580
span_id="2",
573581
parent_id="1",
@@ -593,9 +601,13 @@ async def spans(db: DbSessionFactory) -> None:
593601
cumulative_llm_token_count_prompt=0,
594602
cumulative_llm_token_count_completion=0,
595603
)
604+
.returning(models.Span)
596605
)
597-
await session.execute(
598-
insert(models.Span).values(
606+
assert span is not None
607+
spans.append(span)
608+
span = await session.scalar(
609+
insert(models.Span)
610+
.values(
599611
trace_rowid=trace_row_id,
600612
span_id="3",
601613
parent_id="1",
@@ -626,7 +638,11 @@ async def spans(db: DbSessionFactory) -> None:
626638
cumulative_llm_token_count_prompt=0,
627639
cumulative_llm_token_count_completion=0,
628640
)
641+
.returning(models.Span)
629642
)
643+
assert span is not None
644+
spans.append(span)
645+
return spans
630646

631647

632648
@pytest.fixture

0 commit comments

Comments
 (0)