11import logging
2- from collections .abc import AsyncIterable , AsyncIterator , Awaitable , Callable
2+ from collections .abc import AsyncIterable , Awaitable , Callable
33from contextlib import contextmanager
4+ from dataclasses import dataclass
45from datetime import timedelta
5- from typing import Any , Generator , Generic , Literal , Optional , Union
6+ from typing import Any , AsyncGenerator , Generator , Generic , Literal , Optional , Union
67
78from opentelemetry import trace
8- from opentelemetry .trace import Span , SpanKind , StatusCode
9+ from opentelemetry .trace import Span , SpanKind , Status , StatusCode
910
1011from replit_river .client_transport import ClientTransport
1112from replit_river .error_schema import RiverError , RiverException
@@ -63,7 +64,7 @@ async def send_rpc(
6364 error_deserializer : Callable [[Any ], ErrorType ],
6465 timeout : timedelta ,
6566 ) -> ResponseType :
66- with _trace_procedure ("rpc" , service_name , procedure_name ) as span :
67+ with _trace_procedure ("rpc" , service_name , procedure_name ) as span_handle :
6768 session = await self ._transport .get_or_create_session ()
6869 return await session .send_rpc (
6970 service_name ,
@@ -72,7 +73,7 @@ async def send_rpc(
7273 request_serializer ,
7374 response_deserializer ,
7475 error_deserializer ,
75- span ,
76+ span_handle . span ,
7677 timeout ,
7778 )
7879
@@ -87,7 +88,7 @@ async def send_upload(
8788 response_deserializer : Callable [[Any ], ResponseType ],
8889 error_deserializer : Callable [[Any ], ErrorType ],
8990 ) -> ResponseType :
90- with _trace_procedure ("upload" , service_name , procedure_name ) as span :
91+ with _trace_procedure ("upload" , service_name , procedure_name ) as span_handle :
9192 session = await self ._transport .get_or_create_session ()
9293 return await session .send_upload (
9394 service_name ,
@@ -98,7 +99,7 @@ async def send_upload(
9899 request_serializer ,
99100 response_deserializer ,
100101 error_deserializer ,
101- span ,
102+ span_handle . span ,
102103 )
103104
104105 async def send_subscription (
@@ -109,8 +110,10 @@ async def send_subscription(
109110 request_serializer : Callable [[RequestType ], Any ],
110111 response_deserializer : Callable [[Any ], ResponseType ],
111112 error_deserializer : Callable [[Any ], ErrorType ],
112- ) -> AsyncIterator [Union [ResponseType , ErrorType ]]:
113- with _trace_procedure ("subscription" , service_name , procedure_name ) as span :
113+ ) -> AsyncGenerator [Union [ResponseType , ErrorType ], None ]:
114+ with _trace_procedure (
115+ "subscription" , service_name , procedure_name
116+ ) as span_handle :
114117 session = await self ._transport .get_or_create_session ()
115118 async for msg in session .send_subscription (
116119 service_name ,
@@ -119,10 +122,10 @@ async def send_subscription(
119122 request_serializer ,
120123 response_deserializer ,
121124 error_deserializer ,
122- span ,
125+ span_handle . span ,
123126 ):
124127 if isinstance (msg , RiverError ):
125- _record_river_error (span , msg )
128+ _record_river_error (span_handle , msg )
126129 yield msg # type: ignore # https://github.com/python/mypy/issues/10817
127130
128131 async def send_stream (
@@ -135,8 +138,8 @@ async def send_stream(
135138 request_serializer : Callable [[RequestType ], Any ],
136139 response_deserializer : Callable [[Any ], ResponseType ],
137140 error_deserializer : Callable [[Any ], ErrorType ],
138- ) -> AsyncIterator [Union [ResponseType , ErrorType ]]:
139- with _trace_procedure ("stream" , service_name , procedure_name ) as span :
141+ ) -> AsyncGenerator [Union [ResponseType , ErrorType ], None ]:
142+ with _trace_procedure ("stream" , service_name , procedure_name ) as span_handle :
140143 session = await self ._transport .get_or_create_session ()
141144 async for msg in session .send_stream (
142145 service_name ,
@@ -147,32 +150,63 @@ async def send_stream(
147150 request_serializer ,
148151 response_deserializer ,
149152 error_deserializer ,
150- span ,
153+ span_handle . span ,
151154 ):
152155 if isinstance (msg , RiverError ):
153- _record_river_error (span , msg )
156+ _record_river_error (span_handle , msg )
154157 yield msg # type: ignore # https://github.com/python/mypy/issues/10817
155158
156159
160+ @dataclass
161+ class _SpanHandle :
162+ """Wraps a span and keeps track of whether or not a status has been recorded yet."""
163+
164+ span : Span
165+ did_set_status : bool = False
166+
167+ def set_status (
168+ self ,
169+ status : Union [Status , StatusCode ],
170+ description : Optional [str ] = None ,
171+ ) -> None :
172+ if self .did_set_status :
173+ return
174+ self .did_set_status = True
175+ self .span .set_status (status , description )
176+
177+
157178@contextmanager
158179def _trace_procedure (
159180 procedure_type : Literal ["rpc" , "upload" , "subscription" , "stream" ],
160181 service_name : str ,
161182 procedure_name : str ,
162- ) -> Generator [Span , None , None ]:
163- with tracer .start_span (
183+ ) -> Generator [_SpanHandle , None , None ]:
184+ span = tracer .start_span (
164185 f"river.client.{ procedure_type } .{ service_name } .{ procedure_name } " ,
165186 kind = SpanKind .CLIENT ,
166- ) as span :
167- try :
168- yield span
169- except RiverException as e :
170- _record_river_error (span , RiverError (code = e .code , message = e .message ))
171- raise e
172-
173-
174- def _record_river_error (span : Span , error : RiverError ) -> None :
175- span .set_status (StatusCode .ERROR , error .message )
176- span .record_exception (RiverException (error .code , error .message ))
177- span .set_attribute ("river.error_code" , error .code )
178- span .set_attribute ("river.error_message" , error .message )
187+ )
188+ span_handle = _SpanHandle (span )
189+ try :
190+ yield span_handle
191+ except GeneratorExit :
192+ # This error indicates the caller is done with the async generator
193+ # but messages are still left. This is okay, we do not consider it an error.
194+ raise
195+ except RiverException as e :
196+ span .record_exception (e , escaped = True )
197+ _record_river_error (span_handle , RiverError (code = e .code , message = e .message ))
198+ raise e
199+ except BaseException as e :
200+ span .record_exception (e , escaped = True )
201+ span_handle .set_status (StatusCode .ERROR , f"{ type (e ).__name__ } : { e } " )
202+ raise e
203+ finally :
204+ span_handle .set_status (StatusCode .OK )
205+ span .end ()
206+
207+
208+ def _record_river_error (span_handle : _SpanHandle , error : RiverError ) -> None :
209+ span_handle .set_status (StatusCode .ERROR , error .message )
210+ span_handle .span .record_exception (RiverException (error .code , error .message ))
211+ span_handle .span .set_attribute ("river.error_code" , error .code )
212+ span_handle .span .set_attribute ("river.error_message" , error .message )
0 commit comments