From e60d5acd654c651346ae495abbf4857cb5d0edab Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Sat, 3 May 2025 15:00:52 -0700 Subject: [PATCH 1/3] feat(annotations): ensure span annotations are included in dataset examples --- src/phoenix/db/models.py | 4 + .../server/api/mutations/dataset_mutations.py | 97 ++++++++++++------- 2 files changed, 64 insertions(+), 37 deletions(-) diff --git a/src/phoenix/db/models.py b/src/phoenix/db/models.py index 1766b8b242..63906d9318 100644 --- a/src/phoenix/db/models.py +++ b/src/phoenix/db/models.py @@ -684,6 +684,7 @@ def _llm_token_count_total_expression(cls) -> ColumnElement[int]: ) trace: Mapped["Trace"] = relationship("Trace", back_populates="spans") + span_annotations: Mapped[list["SpanAnnotation"]] = relationship(back_populates="span") document_annotations: Mapped[list["DocumentAnnotation"]] = relationship(back_populates="span") dataset_examples: Mapped[list["DatasetExample"]] = relationship(back_populates="span") @@ -830,6 +831,9 @@ class SpanAnnotation(Base): ) user_id: Mapped[Optional[int]] = mapped_column(ForeignKey("users.id", ondelete="SET NULL")) + span: Mapped["Span"] = relationship(back_populates="span_annotations") + user: Mapped[Optional["User"]] = relationship("User") + __table_args__ = ( UniqueConstraint( "name", diff --git a/src/phoenix/server/api/mutations/dataset_mutations.py b/src/phoenix/server/api/mutations/dataset_mutations.py index 0ed334ec48..89bfea9064 100644 --- a/src/phoenix/server/api/mutations/dataset_mutations.py +++ b/src/phoenix/server/api/mutations/dataset_mutations.py @@ -11,7 +11,9 @@ ToolCallAttributes, ) from sqlalchemy import and_, delete, distinct, func, insert, select, update +from sqlalchemy.orm import contains_eager from strawberry import UNSET +from strawberry.relay.types import GlobalID from strawberry.types import Info from phoenix.db import models @@ -130,44 +132,43 @@ async def add_spans_to_dataset( raise ValueError( f"Unknown dataset: {dataset_id}" ) # todo: implement error types https://github.com/Arize-ai/phoenix/issues/3221 - dataset_version_rowid = await session.scalar( - insert(models.DatasetVersion) - .values( - dataset_id=dataset_rowid, - description=dataset_version_description, - metadata_=dataset_version_metadata, - ) - .returning(models.DatasetVersion.id) + dataset_version = models.DatasetVersion( + dataset_id=dataset_rowid, + description=dataset_version_description, + metadata_=dataset_version_metadata or {}, ) + session.add(dataset_version) + await session.flush() spans = ( - await session.scalars(select(models.Span).where(models.Span.id.in_(span_rowids))) - ).all() + ( + await session.scalars( + select(models.Span) + .join( + models.SpanAnnotation, + models.Span.id == models.SpanAnnotation.span_rowid, + ) + .outerjoin(models.User, models.SpanAnnotation.user_id == models.User.id) + .order_by( + models.Span.id, + models.SpanAnnotation.name, + models.User.username, + ) + .where(models.Span.id.in_(span_rowids)) + .options( + contains_eager(models.Span.span_annotations).contains_eager( + models.SpanAnnotation.user + ) + ) + ) + ) + .unique() + .all() + ) if missing_span_rowids := span_rowids - {span.id for span in spans}: raise ValueError( f"Could not find spans with rowids: {', '.join(map(str, missing_span_rowids))}" ) # todo: implement error handling types https://github.com/Arize-ai/phoenix/issues/3221 - span_annotations = ( - await session.scalars( - select(models.SpanAnnotation).where( - models.SpanAnnotation.span_rowid.in_(span_rowids) - ) - ) - ).all() - - span_annotations_by_span: dict[int, dict[Any, Any]] = {span.id: {} for span in spans} - for annotation in span_annotations: - span_id = annotation.span_rowid - if span_id not in span_annotations_by_span: - span_annotations_by_span[span_id] = dict() - span_annotations_by_span[span_id][annotation.name] = { - "label": annotation.label, - "score": annotation.score, - "explanation": annotation.explanation, - "metadata": annotation.metadata_, - "annotator_kind": annotation.annotator_kind, - } - DatasetExample = models.DatasetExample dataset_example_rowids = ( await session.scalars( @@ -201,7 +202,7 @@ async def add_spans_to_dataset( [ { DatasetExampleRevision.dataset_example_id.key: dataset_example_rowid, - DatasetExampleRevision.dataset_version_id.key: dataset_version_rowid, + DatasetExampleRevision.dataset_version_id.key: dataset_version.id, DatasetExampleRevision.input.key: get_dataset_example_input(span), DatasetExampleRevision.output.key: get_dataset_example_output(span), DatasetExampleRevision.metadata_.key: { @@ -212,11 +213,7 @@ async def add_spans_to_dataset( if k in nonprivate_span_attributes }, "span_kind": span.span_kind, - **( - {"annotations": annotations} - if (annotations := span_annotations_by_span[span.id]) - else {} - ), + "annotations": _gather_span_annotations_by_name(span.span_annotations), }, DatasetExampleRevision.revision_kind.key: "CREATE", } @@ -602,6 +599,32 @@ def _to_orm_revision( } +def _gather_span_annotations_by_name( + span_annotations: list[models.SpanAnnotation], +) -> dict[str, list[dict[str, Any]]]: + span_annotations_by_name: dict[str, list[dict[str, Any]]] = {} + for span_annotation in span_annotations: + if span_annotation.name not in span_annotations_by_name: + span_annotations_by_name[span_annotation.name] = [] + span_annotations_by_name[span_annotation.name].append( + _to_span_annotation_dict(span_annotation) + ) + return span_annotations_by_name + + +def _to_span_annotation_dict(span_annotation: models.SpanAnnotation) -> dict[str, Any]: + return { + "label": span_annotation.label, + "score": span_annotation.score, + "explanation": span_annotation.explanation, + "metadata": span_annotation.metadata_, + "annotator_kind": span_annotation.annotator_kind, + "user_id": str(GlobalID(models.User.__name__, str(span_annotation.user_id))), + "username": user.username if (user := span_annotation.user) is not None else None, + "email": user.email if user is not None else None, + } + + INPUT_MIME_TYPE = SpanAttributes.INPUT_MIME_TYPE INPUT_VALUE = SpanAttributes.INPUT_VALUE OUTPUT_MIME_TYPE = SpanAttributes.OUTPUT_MIME_TYPE From 1c0ff16ccde6b0c77e454674c8e1b1d413501ec6 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Sat, 3 May 2025 19:27:51 -0700 Subject: [PATCH 2/3] fix test --- .../server/api/mutations/dataset_mutations.py | 12 ++--- tests/unit/graphql.py | 3 ++ .../api/mutations/test_dataset_mutations.py | 51 ++++++++++++------- 3 files changed, 41 insertions(+), 25 deletions(-) diff --git a/src/phoenix/server/api/mutations/dataset_mutations.py b/src/phoenix/server/api/mutations/dataset_mutations.py index 89bfea9064..02338f486a 100644 --- a/src/phoenix/server/api/mutations/dataset_mutations.py +++ b/src/phoenix/server/api/mutations/dataset_mutations.py @@ -143,7 +143,7 @@ async def add_spans_to_dataset( ( await session.scalars( select(models.Span) - .join( + .outerjoin( models.SpanAnnotation, models.Span.id == models.SpanAnnotation.span_rowid, ) @@ -164,10 +164,8 @@ async def add_spans_to_dataset( .unique() .all() ) - if missing_span_rowids := span_rowids - {span.id for span in spans}: - raise ValueError( - f"Could not find spans with rowids: {', '.join(map(str, missing_span_rowids))}" - ) # todo: implement error handling types https://github.com/Arize-ai/phoenix/issues/3221 + if span_rowids - {span.id for span in spans}: + raise NotFound("Some spans could not be found") DatasetExample = models.DatasetExample dataset_example_rowids = ( @@ -619,7 +617,9 @@ def _to_span_annotation_dict(span_annotation: models.SpanAnnotation) -> dict[str "explanation": span_annotation.explanation, "metadata": span_annotation.metadata_, "annotator_kind": span_annotation.annotator_kind, - "user_id": str(GlobalID(models.User.__name__, str(span_annotation.user_id))), + "user_id": str(GlobalID(models.User.__name__, str(user_id))) + if (user_id := span_annotation.user_id) is not None + else None, "username": user.username if (user := span_annotation.user) is not None else None, "email": user.email if user is not None else None, } diff --git a/tests/unit/graphql.py b/tests/unit/graphql.py index a1a2330478..7fad2810ee 100644 --- a/tests/unit/graphql.py +++ b/tests/unit/graphql.py @@ -14,6 +14,9 @@ class GraphQLError(Exception): def __init__(self, message: str) -> None: self.message = message + def __repr__(self) -> str: + return f'GraphQLError(message="{self.message}")' + @dataclass class GraphQLExecutionResult: diff --git a/tests/unit/server/api/mutations/test_dataset_mutations.py b/tests/unit/server/api/mutations/test_dataset_mutations.py index def32f6ac3..7277208980 100644 --- a/tests/unit/server/api/mutations/test_dataset_mutations.py +++ b/tests/unit/server/api/mutations/test_dataset_mutations.py @@ -148,7 +148,7 @@ async def test_updating_a_single_field_leaves_remaining_fields_unchannged( async def test_add_span_to_dataset( gql_client: AsyncGraphQLClient, empty_dataset: None, - spans: None, + spans: list[models.Span], span_annotation: None, ) -> None: dataset_id = GlobalID(type_name="Dataset", node_id=str(1)) @@ -176,10 +176,7 @@ async def test_add_span_to_dataset( query=mutation, variables={ "datasetId": str(dataset_id), - "spanIds": [ - str(GlobalID(type_name="Span", node_id=span_id)) - for span_id in map(str, range(1, 4)) - ], + "spanIds": [str(GlobalID(type_name="Span", node_id=str(span.id))) for span in spans], }, ) assert not response.errors @@ -199,6 +196,7 @@ async def test_add_span_to_dataset( }, "metadata": { "span_kind": "LLM", + "annotations": {}, }, "output": { "messages": [ @@ -225,6 +223,7 @@ async def test_add_span_to_dataset( }, "metadata": { "span_kind": "RETRIEVER", + "annotations": {}, }, } } @@ -237,13 +236,18 @@ async def test_add_span_to_dataset( "metadata": { "span_kind": "CHAIN", "annotations": { - "test annotation": { - "label": "ambiguous", - "score": 0.5, - "explanation": "meaningful words", - "metadata": {}, - "annotator_kind": "HUMAN", - } + "test annotation": [ + { + "label": "ambiguous", + "score": 0.5, + "explanation": "meaningful words", + "metadata": {}, + "annotator_kind": "HUMAN", + "user_id": None, + "username": None, + "email": None, + } + ] }, }, } @@ -524,11 +528,12 @@ async def empty_dataset(db: DbSessionFactory) -> None: @pytest.fixture -async def spans(db: DbSessionFactory) -> None: +async def spans(db: DbSessionFactory) -> list[models.Span]: """ Inserts three spans from a single trace: a chain root span, a retriever child span, and an llm child span. """ + spans = [] async with db() as session: project_row_id = await session.scalar( insert(models.Project).values(name=DEFAULT_PROJECT_NAME).returning(models.Project.id) @@ -543,7 +548,7 @@ async def spans(db: DbSessionFactory) -> None: ) .returning(models.Trace.id) ) - await session.execute( + span = await session.scalar( insert(models.Span) .values( trace_rowid=trace_row_id, @@ -564,10 +569,12 @@ async def spans(db: DbSessionFactory) -> None: cumulative_llm_token_count_prompt=0, cumulative_llm_token_count_completion=0, ) - .returning(models.Span.id) + .returning(models.Span) ) - await session.execute( - insert(models.Span).values( + spans.append(span) + span = await session.scalar( + insert(models.Span) + .values( trace_rowid=trace_row_id, span_id="2", parent_id="1", @@ -593,9 +600,12 @@ async def spans(db: DbSessionFactory) -> None: cumulative_llm_token_count_prompt=0, cumulative_llm_token_count_completion=0, ) + .returning(models.Span) ) - await session.execute( - insert(models.Span).values( + spans.append(span) + span = await session.scalar( + insert(models.Span) + .values( trace_rowid=trace_row_id, span_id="3", parent_id="1", @@ -626,7 +636,10 @@ async def spans(db: DbSessionFactory) -> None: cumulative_llm_token_count_prompt=0, cumulative_llm_token_count_completion=0, ) + .returning(models.Span) ) + spans.append(span) + return spans @pytest.fixture From dad4bfd4ccf67cf0b8e99c2d15e2ca2c46b6da49 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Sun, 4 May 2025 15:53:56 -0700 Subject: [PATCH 3/3] fix type --- tests/unit/server/api/mutations/test_dataset_mutations.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/unit/server/api/mutations/test_dataset_mutations.py b/tests/unit/server/api/mutations/test_dataset_mutations.py index 7277208980..7e6afd70f9 100644 --- a/tests/unit/server/api/mutations/test_dataset_mutations.py +++ b/tests/unit/server/api/mutations/test_dataset_mutations.py @@ -571,6 +571,7 @@ async def spans(db: DbSessionFactory) -> list[models.Span]: ) .returning(models.Span) ) + assert span is not None spans.append(span) span = await session.scalar( insert(models.Span) @@ -602,6 +603,7 @@ async def spans(db: DbSessionFactory) -> list[models.Span]: ) .returning(models.Span) ) + assert span is not None spans.append(span) span = await session.scalar( insert(models.Span) @@ -638,6 +640,7 @@ async def spans(db: DbSessionFactory) -> list[models.Span]: ) .returning(models.Span) ) + assert span is not None spans.append(span) return spans