Skip to content

Commit 7595f86

Browse files
committed
multipart: aclosing support
1 parent cbb1062 commit 7595f86

File tree

2 files changed

+33
-10
lines changed

2 files changed

+33
-10
lines changed

httpx/_compat.py

+20-1
Original file line numberDiff line numberDiff line change
@@ -60,4 +60,23 @@ def set_minimum_tls_version_1_2(context: ssl.SSLContext) -> None:
6060
context.options |= ssl.OP_NO_TLSv1_1
6161

6262

63-
__all__ = ["brotli", "set_minimum_tls_version_1_2"]
63+
if sys.version_info >= (3, 10):
64+
from contextlib import aclosing
65+
else:
66+
from contextlib import asynccontextmanager
67+
from typing import Awaitable, Protocol, TypeVar
68+
69+
class _SupportsAclose(Protocol):
70+
def aclose(self) -> Awaitable[object]: ...
71+
72+
_SupportsAcloseT = TypeVar("_SupportsAcloseT", bound=_SupportsAclose)
73+
74+
@asynccontextmanager
75+
async def aclosing(thing: _SupportsAcloseT) -> None:
76+
try:
77+
yield thing
78+
finally:
79+
await thing.aclose()
80+
81+
82+
__all__ = ["brotli", "set_minimum_tls_version_1_2", "aclosing"]

httpx/_multipart.py

+13-9
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from collections.abc import AsyncIterable
77
from pathlib import Path
88

9+
from ._compat import aclosing
910
from ._types import (
1011
AsyncByteStream,
1112
FileContent,
@@ -188,7 +189,7 @@ def render_data(self) -> typing.Iterator[bytes]:
188189
yield to_bytes(chunk)
189190
chunk = self.file.read(self.CHUNK_SIZE)
190191

191-
async def arender_data(self) -> typing.AsyncIterator[bytes]:
192+
async def arender_data(self) -> typing.AsyncGenerator[bytes]:
192193
if not isinstance(self.file, AsyncIterable):
193194
for chunk in self.render_data():
194195
yield chunk
@@ -216,10 +217,11 @@ def render(self) -> typing.Iterator[bytes]:
216217
yield self.render_headers()
217218
yield from self.render_data()
218219

219-
async def arender(self) -> typing.AsyncIterator[bytes]:
220+
async def arender(self) -> typing.AsyncGenerator[bytes]:
220221
yield self.render_headers()
221-
async for chunk in self.arender_data():
222-
yield chunk
222+
async with aclosing(self.arender_data()) as data:
223+
async for chunk in data:
224+
yield chunk
223225

224226

225227
class MultipartStream(SyncByteStream, AsyncByteStream):
@@ -263,12 +265,13 @@ def iter_chunks(self) -> typing.Iterator[bytes]:
263265
yield b"\r\n"
264266
yield b"--%s--\r\n" % self.boundary
265267

266-
async def aiter_chunks(self) -> typing.AsyncIterator[bytes]:
268+
async def aiter_chunks(self) -> typing.AsyncGenerator[bytes]:
267269
for field in self.fields:
268270
yield b"--%s\r\n" % self.boundary
269271
if isinstance(field, FileField):
270-
async for chunk in field.arender():
271-
yield chunk
272+
async with aclosing(field.arender()) as data:
273+
async for chunk in data:
274+
yield chunk
272275
else:
273276
for chunk in field.render():
274277
yield chunk
@@ -309,5 +312,6 @@ def __iter__(self) -> typing.Iterator[bytes]:
309312
yield chunk
310313

311314
async def __aiter__(self) -> typing.AsyncIterator[bytes]:
312-
async for chunk in self.aiter_chunks():
313-
yield chunk
315+
async with aclosing(self.aiter_chunks()) as data:
316+
async for chunk in data:
317+
yield chunk

0 commit comments

Comments
 (0)