Skip to content

Commit a6ea8b4

Browse files
committed
multipart: aclosing support
1 parent 15de00c commit a6ea8b4

File tree

3 files changed

+36
-10
lines changed

3 files changed

+36
-10
lines changed

httpx/_compat.py

+22
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import sys
2+
3+
if sys.version_info >= (3, 10):
4+
from contextlib import aclosing
5+
else:
6+
from contextlib import asynccontextmanager
7+
from typing import Any, AsyncIterator, Awaitable, Protocol, TypeVar
8+
9+
class _SupportsAclose(Protocol):
10+
def aclose(self) -> Awaitable[object]: ...
11+
12+
_SupportsAcloseT = TypeVar("_SupportsAcloseT", bound=_SupportsAclose)
13+
14+
@asynccontextmanager
15+
async def aclosing(thing: _SupportsAcloseT) -> AsyncIterator[Any]:
16+
try:
17+
yield thing
18+
finally:
19+
await thing.aclose()
20+
21+
22+
__all__ = ["aclosing"]

httpx/_multipart.py

+13-9
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from collections.abc import AsyncIterable
99
from pathlib import Path
1010

11+
from ._compat import aclosing
1112
from ._types import (
1213
AsyncByteStream,
1314
FileContent,
@@ -219,7 +220,7 @@ def render_data(self) -> typing.Iterator[bytes]:
219220
yield to_bytes(chunk)
220221
chunk = self.file.read(self.CHUNK_SIZE)
221222

222-
async def arender_data(self) -> typing.AsyncIterator[bytes]:
223+
async def arender_data(self) -> typing.AsyncGenerator[bytes]:
223224
if not isinstance(self.file, AsyncIterable):
224225
for chunk in self.render_data():
225226
yield chunk
@@ -247,10 +248,11 @@ def render(self) -> typing.Iterator[bytes]:
247248
yield self.render_headers()
248249
yield from self.render_data()
249250

250-
async def arender(self) -> typing.AsyncIterator[bytes]:
251+
async def arender(self) -> typing.AsyncGenerator[bytes]:
251252
yield self.render_headers()
252-
async for chunk in self.arender_data():
253-
yield chunk
253+
async with aclosing(self.arender_data()) as data:
254+
async for chunk in data:
255+
yield chunk
254256

255257

256258
class MultipartStream(SyncByteStream, AsyncByteStream):
@@ -294,12 +296,13 @@ def iter_chunks(self) -> typing.Iterator[bytes]:
294296
yield b"\r\n"
295297
yield b"--%s--\r\n" % self.boundary
296298

297-
async def aiter_chunks(self) -> typing.AsyncIterator[bytes]:
299+
async def aiter_chunks(self) -> typing.AsyncGenerator[bytes]:
298300
for field in self.fields:
299301
yield b"--%s\r\n" % self.boundary
300302
if isinstance(field, FileField):
301-
async for chunk in field.arender():
302-
yield chunk
303+
async with aclosing(field.arender()) as data:
304+
async for chunk in data:
305+
yield chunk
303306
else:
304307
for chunk in field.render():
305308
yield chunk
@@ -340,5 +343,6 @@ def __iter__(self) -> typing.Iterator[bytes]:
340343
yield chunk
341344

342345
async def __aiter__(self) -> typing.AsyncIterator[bytes]:
343-
async for chunk in self.aiter_chunks():
344-
yield chunk
346+
async with aclosing(self.aiter_chunks()) as data:
347+
async for chunk in data:
348+
yield chunk

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -128,5 +128,5 @@ markers = [
128128
]
129129

130130
[tool.coverage.run]
131-
omit = ["venv/*"]
131+
omit = ["venv/*", "httpx/_compat.py"]
132132
include = ["httpx/*", "tests/*"]

0 commit comments

Comments
 (0)