Skip to content

Commit e60d5ac

Browse files
committed
feat(annotations): ensure span annotations are included in dataset examples
1 parent e9ead03 commit e60d5ac

File tree

2 files changed

+64
-37
lines changed

2 files changed

+64
-37
lines changed

src/phoenix/db/models.py

+4
Original file line numberDiff line numberDiff line change
@@ -684,6 +684,7 @@ def _llm_token_count_total_expression(cls) -> ColumnElement[int]:
684684
)
685685

686686
trace: Mapped["Trace"] = relationship("Trace", back_populates="spans")
687+
span_annotations: Mapped[list["SpanAnnotation"]] = relationship(back_populates="span")
687688
document_annotations: Mapped[list["DocumentAnnotation"]] = relationship(back_populates="span")
688689
dataset_examples: Mapped[list["DatasetExample"]] = relationship(back_populates="span")
689690

@@ -830,6 +831,9 @@ class SpanAnnotation(Base):
830831
)
831832
user_id: Mapped[Optional[int]] = mapped_column(ForeignKey("users.id", ondelete="SET NULL"))
832833

834+
span: Mapped["Span"] = relationship(back_populates="span_annotations")
835+
user: Mapped[Optional["User"]] = relationship("User")
836+
833837
__table_args__ = (
834838
UniqueConstraint(
835839
"name",

src/phoenix/server/api/mutations/dataset_mutations.py

+60-37
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@
1111
ToolCallAttributes,
1212
)
1313
from sqlalchemy import and_, delete, distinct, func, insert, select, update
14+
from sqlalchemy.orm import contains_eager
1415
from strawberry import UNSET
16+
from strawberry.relay.types import GlobalID
1517
from strawberry.types import Info
1618

1719
from phoenix.db import models
@@ -130,44 +132,43 @@ async def add_spans_to_dataset(
130132
raise ValueError(
131133
f"Unknown dataset: {dataset_id}"
132134
) # todo: implement error types https://github.com/Arize-ai/phoenix/issues/3221
133-
dataset_version_rowid = await session.scalar(
134-
insert(models.DatasetVersion)
135-
.values(
136-
dataset_id=dataset_rowid,
137-
description=dataset_version_description,
138-
metadata_=dataset_version_metadata,
139-
)
140-
.returning(models.DatasetVersion.id)
135+
dataset_version = models.DatasetVersion(
136+
dataset_id=dataset_rowid,
137+
description=dataset_version_description,
138+
metadata_=dataset_version_metadata or {},
141139
)
140+
session.add(dataset_version)
141+
await session.flush()
142142
spans = (
143-
await session.scalars(select(models.Span).where(models.Span.id.in_(span_rowids)))
144-
).all()
143+
(
144+
await session.scalars(
145+
select(models.Span)
146+
.join(
147+
models.SpanAnnotation,
148+
models.Span.id == models.SpanAnnotation.span_rowid,
149+
)
150+
.outerjoin(models.User, models.SpanAnnotation.user_id == models.User.id)
151+
.order_by(
152+
models.Span.id,
153+
models.SpanAnnotation.name,
154+
models.User.username,
155+
)
156+
.where(models.Span.id.in_(span_rowids))
157+
.options(
158+
contains_eager(models.Span.span_annotations).contains_eager(
159+
models.SpanAnnotation.user
160+
)
161+
)
162+
)
163+
)
164+
.unique()
165+
.all()
166+
)
145167
if missing_span_rowids := span_rowids - {span.id for span in spans}:
146168
raise ValueError(
147169
f"Could not find spans with rowids: {', '.join(map(str, missing_span_rowids))}"
148170
) # todo: implement error handling types https://github.com/Arize-ai/phoenix/issues/3221
149171

150-
span_annotations = (
151-
await session.scalars(
152-
select(models.SpanAnnotation).where(
153-
models.SpanAnnotation.span_rowid.in_(span_rowids)
154-
)
155-
)
156-
).all()
157-
158-
span_annotations_by_span: dict[int, dict[Any, Any]] = {span.id: {} for span in spans}
159-
for annotation in span_annotations:
160-
span_id = annotation.span_rowid
161-
if span_id not in span_annotations_by_span:
162-
span_annotations_by_span[span_id] = dict()
163-
span_annotations_by_span[span_id][annotation.name] = {
164-
"label": annotation.label,
165-
"score": annotation.score,
166-
"explanation": annotation.explanation,
167-
"metadata": annotation.metadata_,
168-
"annotator_kind": annotation.annotator_kind,
169-
}
170-
171172
DatasetExample = models.DatasetExample
172173
dataset_example_rowids = (
173174
await session.scalars(
@@ -201,7 +202,7 @@ async def add_spans_to_dataset(
201202
[
202203
{
203204
DatasetExampleRevision.dataset_example_id.key: dataset_example_rowid,
204-
DatasetExampleRevision.dataset_version_id.key: dataset_version_rowid,
205+
DatasetExampleRevision.dataset_version_id.key: dataset_version.id,
205206
DatasetExampleRevision.input.key: get_dataset_example_input(span),
206207
DatasetExampleRevision.output.key: get_dataset_example_output(span),
207208
DatasetExampleRevision.metadata_.key: {
@@ -212,11 +213,7 @@ async def add_spans_to_dataset(
212213
if k in nonprivate_span_attributes
213214
},
214215
"span_kind": span.span_kind,
215-
**(
216-
{"annotations": annotations}
217-
if (annotations := span_annotations_by_span[span.id])
218-
else {}
219-
),
216+
"annotations": _gather_span_annotations_by_name(span.span_annotations),
220217
},
221218
DatasetExampleRevision.revision_kind.key: "CREATE",
222219
}
@@ -602,6 +599,32 @@ def _to_orm_revision(
602599
}
603600

604601

602+
def _gather_span_annotations_by_name(
603+
span_annotations: list[models.SpanAnnotation],
604+
) -> dict[str, list[dict[str, Any]]]:
605+
span_annotations_by_name: dict[str, list[dict[str, Any]]] = {}
606+
for span_annotation in span_annotations:
607+
if span_annotation.name not in span_annotations_by_name:
608+
span_annotations_by_name[span_annotation.name] = []
609+
span_annotations_by_name[span_annotation.name].append(
610+
_to_span_annotation_dict(span_annotation)
611+
)
612+
return span_annotations_by_name
613+
614+
615+
def _to_span_annotation_dict(span_annotation: models.SpanAnnotation) -> dict[str, Any]:
616+
return {
617+
"label": span_annotation.label,
618+
"score": span_annotation.score,
619+
"explanation": span_annotation.explanation,
620+
"metadata": span_annotation.metadata_,
621+
"annotator_kind": span_annotation.annotator_kind,
622+
"user_id": str(GlobalID(models.User.__name__, str(span_annotation.user_id))),
623+
"username": user.username if (user := span_annotation.user) is not None else None,
624+
"email": user.email if user is not None else None,
625+
}
626+
627+
605628
INPUT_MIME_TYPE = SpanAttributes.INPUT_MIME_TYPE
606629
INPUT_VALUE = SpanAttributes.INPUT_VALUE
607630
OUTPUT_MIME_TYPE = SpanAttributes.OUTPUT_MIME_TYPE

0 commit comments

Comments
 (0)