11
11
ToolCallAttributes ,
12
12
)
13
13
from sqlalchemy import and_ , delete , distinct , func , insert , select , update
14
+ from sqlalchemy .orm import contains_eager
14
15
from strawberry import UNSET
16
+ from strawberry .relay .types import GlobalID
15
17
from strawberry .types import Info
16
18
17
19
from phoenix .db import models
@@ -130,44 +132,43 @@ async def add_spans_to_dataset(
130
132
raise ValueError (
131
133
f"Unknown dataset: { dataset_id } "
132
134
) # todo: implement error types https://github.com/Arize-ai/phoenix/issues/3221
133
- dataset_version_rowid = await session .scalar (
134
- insert (models .DatasetVersion )
135
- .values (
136
- dataset_id = dataset_rowid ,
137
- description = dataset_version_description ,
138
- metadata_ = dataset_version_metadata ,
139
- )
140
- .returning (models .DatasetVersion .id )
135
+ dataset_version = models .DatasetVersion (
136
+ dataset_id = dataset_rowid ,
137
+ description = dataset_version_description ,
138
+ metadata_ = dataset_version_metadata or {},
141
139
)
140
+ session .add (dataset_version )
141
+ await session .flush ()
142
142
spans = (
143
- await session .scalars (select (models .Span ).where (models .Span .id .in_ (span_rowids )))
144
- ).all ()
143
+ (
144
+ await session .scalars (
145
+ select (models .Span )
146
+ .join (
147
+ models .SpanAnnotation ,
148
+ models .Span .id == models .SpanAnnotation .span_rowid ,
149
+ )
150
+ .outerjoin (models .User , models .SpanAnnotation .user_id == models .User .id )
151
+ .order_by (
152
+ models .Span .id ,
153
+ models .SpanAnnotation .name ,
154
+ models .User .username ,
155
+ )
156
+ .where (models .Span .id .in_ (span_rowids ))
157
+ .options (
158
+ contains_eager (models .Span .span_annotations ).contains_eager (
159
+ models .SpanAnnotation .user
160
+ )
161
+ )
162
+ )
163
+ )
164
+ .unique ()
165
+ .all ()
166
+ )
145
167
if missing_span_rowids := span_rowids - {span .id for span in spans }:
146
168
raise ValueError (
147
169
f"Could not find spans with rowids: { ', ' .join (map (str , missing_span_rowids ))} "
148
170
) # todo: implement error handling types https://github.com/Arize-ai/phoenix/issues/3221
149
171
150
- span_annotations = (
151
- await session .scalars (
152
- select (models .SpanAnnotation ).where (
153
- models .SpanAnnotation .span_rowid .in_ (span_rowids )
154
- )
155
- )
156
- ).all ()
157
-
158
- span_annotations_by_span : dict [int , dict [Any , Any ]] = {span .id : {} for span in spans }
159
- for annotation in span_annotations :
160
- span_id = annotation .span_rowid
161
- if span_id not in span_annotations_by_span :
162
- span_annotations_by_span [span_id ] = dict ()
163
- span_annotations_by_span [span_id ][annotation .name ] = {
164
- "label" : annotation .label ,
165
- "score" : annotation .score ,
166
- "explanation" : annotation .explanation ,
167
- "metadata" : annotation .metadata_ ,
168
- "annotator_kind" : annotation .annotator_kind ,
169
- }
170
-
171
172
DatasetExample = models .DatasetExample
172
173
dataset_example_rowids = (
173
174
await session .scalars (
@@ -201,7 +202,7 @@ async def add_spans_to_dataset(
201
202
[
202
203
{
203
204
DatasetExampleRevision .dataset_example_id .key : dataset_example_rowid ,
204
- DatasetExampleRevision .dataset_version_id .key : dataset_version_rowid ,
205
+ DatasetExampleRevision .dataset_version_id .key : dataset_version . id ,
205
206
DatasetExampleRevision .input .key : get_dataset_example_input (span ),
206
207
DatasetExampleRevision .output .key : get_dataset_example_output (span ),
207
208
DatasetExampleRevision .metadata_ .key : {
@@ -212,11 +213,7 @@ async def add_spans_to_dataset(
212
213
if k in nonprivate_span_attributes
213
214
},
214
215
"span_kind" : span .span_kind ,
215
- ** (
216
- {"annotations" : annotations }
217
- if (annotations := span_annotations_by_span [span .id ])
218
- else {}
219
- ),
216
+ "annotations" : _gather_span_annotations_by_name (span .span_annotations ),
220
217
},
221
218
DatasetExampleRevision .revision_kind .key : "CREATE" ,
222
219
}
@@ -602,6 +599,32 @@ def _to_orm_revision(
602
599
}
603
600
604
601
602
def _gather_span_annotations_by_name(
    span_annotations: list[models.SpanAnnotation],
) -> dict[str, list[dict[str, Any]]]:
    """Group serialized span annotations under their annotation name.

    Args:
        span_annotations: The annotations to group, in the order they should
            appear within each group.

    Returns:
        A dict mapping each annotation name to the list of dicts produced by
        ``_to_span_annotation_dict`` for the annotations carrying that name.
        Entries within a list keep the order of the input. The dict is empty
        when no annotations are given.
    """
    grouped: dict[str, list[dict[str, Any]]] = {}
    for annotation in span_annotations:
        # setdefault creates the bucket on first sight of a name, then reuses it.
        grouped.setdefault(annotation.name, []).append(_to_span_annotation_dict(annotation))
    return grouped
613
+
614
+
615
def _to_span_annotation_dict(span_annotation: models.SpanAnnotation) -> dict[str, Any]:
    """Serialize a span annotation into a plain dict.

    Args:
        span_annotation: The annotation to serialize. Its ``user``
            relationship is read here, so it should already be loaded
            (the span query in this file populates it via ``contains_eager``).

    Returns:
        A dict with the keys ``label``, ``score``, ``explanation``,
        ``metadata``, ``annotator_kind``, ``user_id``, ``username`` and
        ``email``. ``user_id`` is the string form of a ``GlobalID`` typed
        with the ``User`` model name; ``user_id``, ``username`` and ``email``
        are ``None`` when the annotation has no associated user.
    """
    user_rowid = span_annotation.user_id
    user = span_annotation.user
    return {
        "label": span_annotation.label,
        "score": span_annotation.score,
        "explanation": span_annotation.explanation,
        "metadata": span_annotation.metadata_,
        "annotator_kind": span_annotation.annotator_kind,
        # Only mint a global ID when a user is attached; otherwise the ID
        # would wrap the literal string "None" and point at no real user.
        "user_id": (
            str(GlobalID(models.User.__name__, str(user_rowid)))
            if user_rowid is not None
            else None
        ),
        "username": user.username if user is not None else None,
        "email": user.email if user is not None else None,
    }
626
+
627
+
605
628
INPUT_MIME_TYPE = SpanAttributes .INPUT_MIME_TYPE
606
629
INPUT_VALUE = SpanAttributes .INPUT_VALUE
607
630
OUTPUT_MIME_TYPE = SpanAttributes .OUTPUT_MIME_TYPE
0 commit comments