Skip to content

Commit 92c3da0

Browse files
authored
Don't record GeneratorExit errors in stream RPC spans (#129)
Why
===

Stream RPCs which are not read to completion would show up with a `GeneratorExit` error in their associated span. There are valid cases where a caller may not need to exhaustively read the stream, so these should not be considered errors.

What changed
============

- Stop using the span as a context manager; we'll handle recording errors ourselves
- Catch RiverExceptions, Exceptions, and CancelledErrors
- Switch from `AsyncIterator` to `AsyncGenerator` for stream return types
  - This allows us to call close on the generator
  - We may also want to do this in the codegen; this will be especially useful for cancellation support (just close the generator)

Test plan
=========

- Added a test which ensures that a closed generator creates a span with OK status
1 parent 933fcc0 commit 92c3da0

File tree

3 files changed

+105
-34
lines changed

3 files changed

+105
-34
lines changed

replit_river/client.py

Lines changed: 64 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import logging
2-
from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable
2+
from collections.abc import AsyncIterable, Awaitable, Callable
33
from contextlib import contextmanager
4+
from dataclasses import dataclass
45
from datetime import timedelta
5-
from typing import Any, Generator, Generic, Literal, Optional, Union
6+
from typing import Any, AsyncGenerator, Generator, Generic, Literal, Optional, Union
67

78
from opentelemetry import trace
8-
from opentelemetry.trace import Span, SpanKind, StatusCode
9+
from opentelemetry.trace import Span, SpanKind, Status, StatusCode
910

1011
from replit_river.client_transport import ClientTransport
1112
from replit_river.error_schema import RiverError, RiverException
@@ -63,7 +64,7 @@ async def send_rpc(
6364
error_deserializer: Callable[[Any], ErrorType],
6465
timeout: timedelta,
6566
) -> ResponseType:
66-
with _trace_procedure("rpc", service_name, procedure_name) as span:
67+
with _trace_procedure("rpc", service_name, procedure_name) as span_handle:
6768
session = await self._transport.get_or_create_session()
6869
return await session.send_rpc(
6970
service_name,
@@ -72,7 +73,7 @@ async def send_rpc(
7273
request_serializer,
7374
response_deserializer,
7475
error_deserializer,
75-
span,
76+
span_handle.span,
7677
timeout,
7778
)
7879

@@ -87,7 +88,7 @@ async def send_upload(
8788
response_deserializer: Callable[[Any], ResponseType],
8889
error_deserializer: Callable[[Any], ErrorType],
8990
) -> ResponseType:
90-
with _trace_procedure("upload", service_name, procedure_name) as span:
91+
with _trace_procedure("upload", service_name, procedure_name) as span_handle:
9192
session = await self._transport.get_or_create_session()
9293
return await session.send_upload(
9394
service_name,
@@ -98,7 +99,7 @@ async def send_upload(
9899
request_serializer,
99100
response_deserializer,
100101
error_deserializer,
101-
span,
102+
span_handle.span,
102103
)
103104

104105
async def send_subscription(
@@ -109,8 +110,10 @@ async def send_subscription(
109110
request_serializer: Callable[[RequestType], Any],
110111
response_deserializer: Callable[[Any], ResponseType],
111112
error_deserializer: Callable[[Any], ErrorType],
112-
) -> AsyncIterator[Union[ResponseType, ErrorType]]:
113-
with _trace_procedure("subscription", service_name, procedure_name) as span:
113+
) -> AsyncGenerator[Union[ResponseType, ErrorType], None]:
114+
with _trace_procedure(
115+
"subscription", service_name, procedure_name
116+
) as span_handle:
114117
session = await self._transport.get_or_create_session()
115118
async for msg in session.send_subscription(
116119
service_name,
@@ -119,10 +122,10 @@ async def send_subscription(
119122
request_serializer,
120123
response_deserializer,
121124
error_deserializer,
122-
span,
125+
span_handle.span,
123126
):
124127
if isinstance(msg, RiverError):
125-
_record_river_error(span, msg)
128+
_record_river_error(span_handle, msg)
126129
yield msg # type: ignore # https://github.com/python/mypy/issues/10817
127130

128131
async def send_stream(
@@ -135,8 +138,8 @@ async def send_stream(
135138
request_serializer: Callable[[RequestType], Any],
136139
response_deserializer: Callable[[Any], ResponseType],
137140
error_deserializer: Callable[[Any], ErrorType],
138-
) -> AsyncIterator[Union[ResponseType, ErrorType]]:
139-
with _trace_procedure("stream", service_name, procedure_name) as span:
141+
) -> AsyncGenerator[Union[ResponseType, ErrorType], None]:
142+
with _trace_procedure("stream", service_name, procedure_name) as span_handle:
140143
session = await self._transport.get_or_create_session()
141144
async for msg in session.send_stream(
142145
service_name,
@@ -147,32 +150,63 @@ async def send_stream(
147150
request_serializer,
148151
response_deserializer,
149152
error_deserializer,
150-
span,
153+
span_handle.span,
151154
):
152155
if isinstance(msg, RiverError):
153-
_record_river_error(span, msg)
156+
_record_river_error(span_handle, msg)
154157
yield msg # type: ignore # https://github.com/python/mypy/issues/10817
155158

156159

160+
@dataclass
161+
class _SpanHandle:
162+
"""Wraps a span and keeps track of whether or not a status has been recorded yet."""
163+
164+
span: Span
165+
did_set_status: bool = False
166+
167+
def set_status(
168+
self,
169+
status: Union[Status, StatusCode],
170+
description: Optional[str] = None,
171+
) -> None:
172+
if self.did_set_status:
173+
return
174+
self.did_set_status = True
175+
self.span.set_status(status, description)
176+
177+
157178
@contextmanager
158179
def _trace_procedure(
159180
procedure_type: Literal["rpc", "upload", "subscription", "stream"],
160181
service_name: str,
161182
procedure_name: str,
162-
) -> Generator[Span, None, None]:
163-
with tracer.start_span(
183+
) -> Generator[_SpanHandle, None, None]:
184+
span = tracer.start_span(
164185
f"river.client.{procedure_type}.{service_name}.{procedure_name}",
165186
kind=SpanKind.CLIENT,
166-
) as span:
167-
try:
168-
yield span
169-
except RiverException as e:
170-
_record_river_error(span, RiverError(code=e.code, message=e.message))
171-
raise e
172-
173-
174-
def _record_river_error(span: Span, error: RiverError) -> None:
175-
span.set_status(StatusCode.ERROR, error.message)
176-
span.record_exception(RiverException(error.code, error.message))
177-
span.set_attribute("river.error_code", error.code)
178-
span.set_attribute("river.error_message", error.message)
187+
)
188+
span_handle = _SpanHandle(span)
189+
try:
190+
yield span_handle
191+
except GeneratorExit:
192+
# This error indicates the caller is done with the async generator
193+
# but messages are still left. This is okay, we do not consider it an error.
194+
raise
195+
except RiverException as e:
196+
span.record_exception(e, escaped=True)
197+
_record_river_error(span_handle, RiverError(code=e.code, message=e.message))
198+
raise e
199+
except BaseException as e:
200+
span.record_exception(e, escaped=True)
201+
span_handle.set_status(StatusCode.ERROR, f"{type(e).__name__}: {e}")
202+
raise e
203+
finally:
204+
span_handle.set_status(StatusCode.OK)
205+
span.end()
206+
207+
208+
def _record_river_error(span_handle: _SpanHandle, error: RiverError) -> None:
209+
span_handle.set_status(StatusCode.ERROR, error.message)
210+
span_handle.span.record_exception(RiverException(error.code, error.message))
211+
span_handle.span.set_attribute("river.error_code", error.code)
212+
span_handle.span.set_attribute("river.error_message", error.message)

replit_river/client_session.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import asyncio
22
import logging
3-
from collections.abc import AsyncIterable, AsyncIterator
3+
from collections.abc import AsyncIterable
44
from datetime import timedelta
5-
from typing import Any, Callable, Optional, Union
5+
from typing import Any, AsyncGenerator, Callable, Optional, Union
66

77
import nanoid # type: ignore
88
from aiochannel import Channel
@@ -194,7 +194,7 @@ async def send_subscription(
194194
response_deserializer: Callable[[Any], ResponseType],
195195
error_deserializer: Callable[[Any], ErrorType],
196196
span: Span,
197-
) -> AsyncIterator[Union[ResponseType, ErrorType]]:
197+
) -> AsyncGenerator[Union[ResponseType, ErrorType], None]:
198198
"""Sends a subscription request to the server.
199199
200200
Expects the input and output be messages that will be msgpacked.
@@ -246,7 +246,7 @@ async def send_stream(
246246
response_deserializer: Callable[[Any], ResponseType],
247247
error_deserializer: Callable[[Any], ErrorType],
248248
span: Span,
249-
) -> AsyncIterator[Union[ResponseType, ErrorType]]:
249+
) -> AsyncGenerator[Union[ResponseType, ErrorType], None]:
250250
"""Sends a subscription request to the server.
251251
252252
Expects the input and output be messages that will be msgpacked.

tests/test_opentelemetry.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import contextlib
12
from datetime import timedelta
23
from typing import AsyncGenerator, AsyncIterator, Iterator
34

@@ -182,3 +183,39 @@ async def stream_data() -> AsyncGenerator[str, None]:
182183
assert len(spans) == 1
183184
assert spans[0].name == "river.client.stream.test_service.stream_method_error"
184185
assert spans[0].status.status_code == StatusCode.ERROR
186+
187+
188+
@pytest.mark.asyncio
189+
@pytest.mark.parametrize("handlers", [{**basic_stream}])
190+
async def test_stream_method_span_generator_exit_not_recorded(
191+
client: Client, span_exporter: InMemorySpanExporter
192+
) -> None:
193+
async def stream_data() -> AsyncGenerator[str, None]:
194+
yield "Stream 1"
195+
yield "Stream 2"
196+
yield "Stream 3"
197+
198+
responses = []
199+
stream = client.send_stream(
200+
"test_service",
201+
"stream_method",
202+
"Initial Stream Data",
203+
stream_data(),
204+
serialize_request,
205+
serialize_request,
206+
deserialize_response,
207+
deserialize_error,
208+
)
209+
async with contextlib.aclosing(stream) as generator:
210+
async for response in generator:
211+
responses.append(response)
212+
break
213+
214+
assert responses == [
215+
"Stream response for Initial Stream Data",
216+
]
217+
218+
spans = span_exporter.get_finished_spans()
219+
assert len(spans) == 1
220+
assert spans[0].name == "river.client.stream.test_service.stream_method"
221+
assert spans[0].status.status_code == StatusCode.OK

0 commit comments

Comments
 (0)