Skip to content

Commit 7daa163

Browse files
committed
feat(annotations): ensure span annotations are included in dataset examples
1 parent 649a7d8 commit 7daa163

File tree

2 files changed

+54
-37
lines changed

2 files changed

+54
-37
lines changed

src/phoenix/db/models.py

+4
Original file line numberDiff line numberDiff line change
@@ -684,6 +684,7 @@ def _llm_token_count_total_expression(cls) -> ColumnElement[int]:
684684
)
685685

686686
trace: Mapped["Trace"] = relationship("Trace", back_populates="spans")
687+
span_annotations: Mapped[list["SpanAnnotation"]] = relationship(back_populates="span")
687688
document_annotations: Mapped[list["DocumentAnnotation"]] = relationship(back_populates="span")
688689
dataset_examples: Mapped[list["DatasetExample"]] = relationship(back_populates="span")
689690

@@ -830,6 +831,9 @@ class SpanAnnotation(Base):
830831
)
831832
user_id: Mapped[Optional[int]] = mapped_column(ForeignKey("users.id", ondelete="SET NULL"))
832833

834+
span: Mapped["Span"] = relationship(back_populates="span_annotations")
835+
user: Mapped[Optional["User"]] = relationship("User")
836+
833837
__table_args__ = (
834838
UniqueConstraint(
835839
"name",

src/phoenix/server/api/mutations/dataset_mutations.py

+50-37
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@
1111
ToolCallAttributes,
1212
)
1313
from sqlalchemy import and_, delete, distinct, func, insert, select, update
14+
from sqlalchemy.orm import contains_eager
1415
from strawberry import UNSET
16+
from strawberry.relay.types import GlobalID
1517
from strawberry.types import Info
1618

1719
from phoenix.db import models
@@ -130,44 +132,43 @@ async def add_spans_to_dataset(
130132
raise ValueError(
131133
f"Unknown dataset: {dataset_id}"
132134
) # todo: implement error types https://github.com/Arize-ai/phoenix/issues/3221
133-
dataset_version_rowid = await session.scalar(
134-
insert(models.DatasetVersion)
135-
.values(
136-
dataset_id=dataset_rowid,
137-
description=dataset_version_description,
138-
metadata_=dataset_version_metadata,
139-
)
140-
.returning(models.DatasetVersion.id)
135+
dataset_version = models.DatasetVersion(
136+
dataset_id=dataset_rowid,
137+
description=dataset_version_description,
138+
metadata_=dataset_version_metadata or {},
141139
)
140+
session.add(dataset_version)
141+
await session.flush()
142142
spans = (
143-
await session.scalars(select(models.Span).where(models.Span.id.in_(span_rowids)))
144-
).all()
143+
(
144+
await session.scalars(
145+
select(models.Span)
146+
.join(
147+
models.SpanAnnotation,
148+
models.Span.id == models.SpanAnnotation.span_rowid,
149+
)
150+
.join(models.User, models.SpanAnnotation.user_id == models.User.id)
151+
.order_by(
152+
models.Span.id,
153+
models.SpanAnnotation.name,
154+
models.User.username,
155+
)
156+
.where(models.Span.id.in_(span_rowids))
157+
.options(
158+
contains_eager(models.Span.span_annotations).contains_eager(
159+
models.SpanAnnotation.user
160+
)
161+
)
162+
)
163+
)
164+
.unique()
165+
.all()
166+
)
145167
if missing_span_rowids := span_rowids - {span.id for span in spans}:
146168
raise ValueError(
147169
f"Could not find spans with rowids: {', '.join(map(str, missing_span_rowids))}"
148170
) # todo: implement error handling types https://github.com/Arize-ai/phoenix/issues/3221
149171

150-
span_annotations = (
151-
await session.scalars(
152-
select(models.SpanAnnotation).where(
153-
models.SpanAnnotation.span_rowid.in_(span_rowids)
154-
)
155-
)
156-
).all()
157-
158-
span_annotations_by_span: dict[int, dict[Any, Any]] = {span.id: {} for span in spans}
159-
for annotation in span_annotations:
160-
span_id = annotation.span_rowid
161-
if span_id not in span_annotations_by_span:
162-
span_annotations_by_span[span_id] = dict()
163-
span_annotations_by_span[span_id][annotation.name] = {
164-
"label": annotation.label,
165-
"score": annotation.score,
166-
"explanation": annotation.explanation,
167-
"metadata": annotation.metadata_,
168-
"annotator_kind": annotation.annotator_kind,
169-
}
170-
171172
DatasetExample = models.DatasetExample
172173
dataset_example_rowids = (
173174
await session.scalars(
@@ -201,7 +202,7 @@ async def add_spans_to_dataset(
201202
[
202203
{
203204
DatasetExampleRevision.dataset_example_id.key: dataset_example_rowid,
204-
DatasetExampleRevision.dataset_version_id.key: dataset_version_rowid,
205+
DatasetExampleRevision.dataset_version_id.key: dataset_version.id,
205206
DatasetExampleRevision.input.key: get_dataset_example_input(span),
206207
DatasetExampleRevision.output.key: get_dataset_example_output(span),
207208
DatasetExampleRevision.metadata_.key: {
@@ -212,11 +213,10 @@ async def add_spans_to_dataset(
212213
if k in nonprivate_span_attributes
213214
},
214215
"span_kind": span.span_kind,
215-
**(
216-
{"annotations": annotations}
217-
if (annotations := span_annotations_by_span[span.id])
218-
else {}
219-
),
216+
"annotations": [
217+
_serialize_span_annotation(span_annotation)
218+
for span_annotation in span.span_annotations
219+
],
220220
},
221221
DatasetExampleRevision.revision_kind.key: "CREATE",
222222
}
@@ -602,6 +602,19 @@ def _to_orm_revision(
602602
}
603603

604604

605+
def _serialize_span_annotation(span_annotation: models.SpanAnnotation) -> dict[str, Any]:
606+
return {
607+
"label": span_annotation.label,
608+
"score": span_annotation.score,
609+
"explanation": span_annotation.explanation,
610+
"metadata": span_annotation.metadata_,
611+
"annotator_kind": span_annotation.annotator_kind,
612+
"user_id": str(GlobalID(models.User.__name__, str(span_annotation.user_id))),
613+
"username": user.username if (user := span_annotation.user) is not None else None,
614+
"email": user.email if user is not None else None,
615+
}
616+
617+
605618
INPUT_MIME_TYPE = SpanAttributes.INPUT_MIME_TYPE
606619
INPUT_VALUE = SpanAttributes.INPUT_VALUE
607620
OUTPUT_MIME_TYPE = SpanAttributes.OUTPUT_MIME_TYPE

0 commit comments

Comments
 (0)