Skip to content

Commit d0ac2a9

Browse files
authored
Add support for notifying end of remote stream (#727)
1 parent 9777213 commit d0ac2a9

11 files changed

Lines changed: 231 additions & 13 deletions

File tree

awscrt/aio/http.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,7 @@ def __init__(self,
337337
request: HttpRequest,
338338
request_body_generator: AsyncIterator[bytes] = None,
339339
loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
340+
340341
# Initialize the parent class
341342
http2_manual_write = request_body_generator is not None and connection.version is HttpVersion.Http2
342343
super()._init_common(connection, request, http2_manual_write=http2_manual_write)
@@ -357,6 +358,7 @@ def __init__(self,
357358

358359
# Create futures for async operations
359360
self._completion_future = Future()
361+
self._remote_completion_future = Future()
360362
self._response_status_future = Future()
361363
self._response_headers_future = Future()
362364
self._status_code = None
@@ -386,13 +388,12 @@ def _on_body(self, chunk: bytes) -> None:
386388
# Set result outside lock (Future is thread-safe)
387389
future.set_result(chunk)
388390

389-
def _on_complete(self, error_code: int) -> None:
390-
"""Set the completion status of the stream."""
391-
if error_code == 0:
392-
self._completion_future.set_result(self._status_code)
393-
else:
394-
self._completion_future.set_exception(awscrt.exceptions.from_code(error_code))
391+
def _resolve_pending_chunk_futures(self) -> None:
392+
"""Helper to resolve all pending chunk futures with empty bytes.
395393
394+
This indicates end of stream to any waiting get_next_response_chunk() calls.
395+
Must be called when either the stream completes or remote peer sends END_STREAM.
396+
"""
396397
# Resolve all pending chunk futures with lock protection
397398
with self._deque_lock:
398399
pending_futures = list(self._chunk_futures)
@@ -402,6 +403,20 @@ def _on_complete(self, error_code: int) -> None:
402403
for future in pending_futures:
403404
future.set_result(b"")
404405

406+
def _on_complete(self, error_code: int) -> None:
407+
"""Set the completion status of the stream."""
408+
if error_code == 0:
409+
self._completion_future.set_result(self._status_code)
410+
else:
411+
self._completion_future.set_exception(awscrt.exceptions.from_code(error_code))
412+
413+
self._resolve_pending_chunk_futures()
414+
415+
def _on_h2_remote_end_stream(self) -> None:
416+
"""Called when the remote peer has finished sending (HTTP/2 only)."""
417+
self._remote_completion_future.set_result(None)
418+
self._resolve_pending_chunk_futures()
419+
405420
async def _set_request_body_generator(self, body_iterator: AsyncIterator[bytes]):
406421
...
407422

@@ -431,7 +446,7 @@ async def get_next_response_chunk(self) -> bytes:
431446
with self._deque_lock:
432447
if self._received_chunks:
433448
return self._received_chunks.popleft()
434-
elif self._completion_future.done():
449+
elif self._completion_future.done() or self._remote_completion_future.done():
435450
return b""
436451
else:
437452
future = Future()

awscrt/http.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -569,6 +569,14 @@ def _on_complete(self, error_code: int) -> None:
569569
else:
570570
self._completion_future.set_exception(awscrt.exceptions.from_code(error_code))
571571

572+
def _on_h2_remote_end_stream(self) -> None:
573+
"""Called when remote peer sends END_STREAM (HTTP/2 only).
574+
575+
This callback is only invoked for HTTP/2 connections. HTTP/1.x streams
576+
will never receive this callback. Base implementation does nothing.
577+
"""
578+
pass
579+
572580
def update_window(self, increment_size: int) -> None:
573581
"""
574582
Update the stream's flow control window.
@@ -613,14 +621,57 @@ def activate(self) -> None:
613621

614622

615623
class Http2ClientStream(HttpClientStreamBase):
624+
__slots__ = ('_remote_end_stream_future',)
625+
616626
def __init__(self,
617627
connection: HttpClientConnection,
618628
request: 'HttpRequest',
619629
on_response: Optional[Callable[..., None]] = None,
620630
on_body: Optional[Callable[..., None]] = None,
621631
manual_write: bool = False) -> None:
632+
self._remote_end_stream_future = Future()
622633
self._init_common(connection, request, on_response, on_body, manual_write)
623634

635+
@property
636+
def remote_end_stream_future(self) -> "concurrent.futures.Future":
637+
"""
638+
concurrent.futures.Future: Future that completes when the remote peer has finished
639+
sending (HTTP/2 only). This occurs when the server sends an END_STREAM flag.
640+
641+
The future will contain a result of None on success, or an exception if the stream
642+
encounters an error before END_STREAM is received (e.g., RST_STREAM).
643+
644+
This is different from `completion_future` which completes when both the
645+
client and server have finished (bidirectional stream closure).
646+
647+
Note: This future only applies to HTTP/2 connections. It will complete when the
648+
server sends END_STREAM, which may occur before the client finishes sending.
649+
In case of stream completed without END_STREAM received, this future will complete
650+
with exception.
651+
"""
652+
return self._remote_end_stream_future
653+
654+
def _on_h2_remote_end_stream(self) -> None:
655+
"""Internal callback when remote peer sends END_STREAM (HTTP/2 only)."""
656+
if not self._remote_end_stream_future.done():
657+
self._remote_end_stream_future.set_result(None)
658+
659+
def _on_complete(self, error_code: int) -> None:
660+
# done with HttpRequest, drop reference
661+
self._request = None # type: ignore
662+
663+
# Ensure remote_completion_future is always resolved
664+
if not self._remote_end_stream_future.done():
665+
# Stream completed successfully but END_STREAM was never received,
666+
# complete `remote_completion_future` with exception.
667+
self._remote_end_stream_future.set_exception(
668+
RuntimeError("Stream completed without receiving remote END_STREAM"))
669+
670+
if error_code == 0:
671+
self._completion_future.set_result(self._response_status_code)
672+
else:
673+
self._completion_future.set_exception(awscrt.exceptions.from_code(error_code))
674+
624675
def activate(self) -> None:
625676
"""Begin sending the request.
626677

crt/aws-lc

source/http_stream.c

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,24 @@ static void s_on_stream_complete(struct aws_http_stream *native_stream, int erro
213213
/*************** GIL RELEASE ***************/
214214
}
215215

216+
static void s_on_h2_remote_end_stream(struct aws_http_stream *native_stream, void *user_data) {
217+
(void)native_stream;
218+
struct http_stream_binding *stream = user_data;
219+
220+
/*************** GIL ACQUIRE ***************/
221+
PyGILState_STATE state;
222+
if (aws_py_gilstate_ensure(&state)) {
223+
return; /* Python has shut down. Nothing matters anymore, but don't crash */
224+
}
225+
226+
PyObject *result = PyObject_CallMethod(stream->self_proxy, "_on_h2_remote_end_stream", "()");
227+
if (result) {
228+
Py_DECREF(result);
229+
}
230+
PyGILState_Release(state);
231+
/*************** GIL RELEASE ***************/
232+
}
233+
216234
static void s_stream_capsule_destructor(PyObject *http_stream_capsule) {
217235
struct http_stream_binding *stream = PyCapsule_GetPointer(http_stream_capsule, s_capsule_name_http_stream);
218236

@@ -283,6 +301,7 @@ PyObject *aws_py_http_client_stream_new(PyObject *self, PyObject *args) {
283301
.on_response_header_block_done = s_on_incoming_header_block_done,
284302
.on_response_body = s_on_incoming_body,
285303
.on_complete = s_on_stream_complete,
304+
.on_h2_remote_end_stream = s_on_h2_remote_end_stream,
286305
.user_data = stream,
287306
.http2_use_manual_data_writes = http2_manual_write,
288307
};

test/test_aiohttp_client.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -916,5 +916,80 @@ def test_h2_stream_flow_control_blocks_and_resumes(self):
916916
asyncio.run(self._test_h2_stream_flow_control_blocks_and_resumes())
917917

918918

919+
class AIOHttp2RemoteEndStreamTest(NativeResourceTest):
920+
"""Test suite for HTTP/2 on_h2_remote_end_stream callback in asyncio"""
921+
timeout = 10.0
922+
923+
async def _new_httpbin_h2_connection(self):
924+
"""Create HTTP/2 connection to httpbin.org"""
925+
event_loop_group = EventLoopGroup()
926+
host_resolver = DefaultHostResolver(event_loop_group)
927+
bootstrap = ClientBootstrap(event_loop_group, host_resolver)
928+
929+
tls_ctx_options = TlsContextOptions()
930+
tls_ctx = ClientTlsContext(tls_ctx_options)
931+
tls_conn_opt = tls_ctx.new_connection_options()
932+
tls_conn_opt.set_server_name("httpbin.org")
933+
tls_conn_opt.set_alpn_list(["h2"])
934+
935+
connection = await AIOHttp2ClientConnection.new(
936+
host_name="httpbin.org",
937+
port=443,
938+
bootstrap=bootstrap,
939+
tls_connection_options=tls_conn_opt)
940+
941+
return connection
942+
943+
async def _test_h2_remote_end_stream_ordering(self):
944+
"""Test that on_h2_remote_end_stream fires before on_complete when server finishes first"""
945+
connection = await self._new_httpbin_h2_connection()
946+
947+
# Use httpbin.org 404 path - server responds immediately
948+
request = HttpRequest('POST', '/this-path-does-not-exist-deliberately-404')
949+
request.headers.add('host', 'httpbin.org')
950+
951+
complete_success = asyncio.Event()
952+
remote_finished = asyncio.Event()
953+
complete_fired = asyncio.Event()
954+
955+
async def slow_body_generator():
956+
# Send first chunk WITHOUT end_stream
957+
yield b'chunk1'
958+
# Wait for server to finish
959+
await remote_finished.wait()
960+
if not complete_fired.is_set():
961+
# Verify complete hasn't fired yet
962+
complete_success.set()
963+
# Now finish sending
964+
yield b'chunk2'
965+
966+
stream = connection.request(request, request_body_generator=slow_body_generator())
967+
968+
# Read response
969+
status_code = await stream.get_response_status_code()
970+
self.assertEqual(404, status_code)
971+
972+
# Read all response body
973+
while True:
974+
chunk = await stream.get_next_response_chunk()
975+
if not chunk:
976+
break
977+
978+
# set remove stream
979+
remote_finished.set()
980+
981+
# Wait for stream to complete successfully
982+
await stream.wait_for_completion()
983+
complete_fired.set()
984+
985+
self.assertTrue(complete_success.is_set())
986+
987+
await connection.close()
988+
989+
def test_h2_remote_end_stream_ordering(self):
990+
"""Test callback ordering with early server response"""
991+
asyncio.run(self._test_h2_remote_end_stream_ordering())
992+
993+
919994
if __name__ == '__main__':
920995
unittest.main()

0 commit comments

Comments
 (0)