diff --git a/httpx/_compat.py b/httpx/_compat.py new file mode 100644 index 0000000000..d310afbd7d --- /dev/null +++ b/httpx/_compat.py @@ -0,0 +1,22 @@ +import sys + +if sys.version_info >= (3, 10): + from contextlib import aclosing +else: + from contextlib import asynccontextmanager + from typing import Any, AsyncIterator, Awaitable, Protocol, TypeVar + + class _SupportsAclose(Protocol): + def aclose(self) -> Awaitable[object]: ... + + _SupportsAcloseT = TypeVar("_SupportsAcloseT", bound=_SupportsAclose) + + @asynccontextmanager + async def aclosing(thing: _SupportsAcloseT) -> AsyncIterator[Any]: + try: + yield thing + finally: + await thing.aclose() + + +__all__ = ["aclosing"] diff --git a/httpx/_multipart.py b/httpx/_multipart.py index b4761af9b2..3da9e9c409 100644 --- a/httpx/_multipart.py +++ b/httpx/_multipart.py @@ -5,8 +5,10 @@ import os import re import typing +from collections.abc import AsyncIterable from pathlib import Path +from ._compat import aclosing from ._types import ( AsyncByteStream, FileContent, @@ -201,6 +203,8 @@ def render_headers(self) -> bytes: return self._headers def render_data(self) -> typing.Iterator[bytes]: + if isinstance(self.file, AsyncIterable): + raise TypeError("Invalid type for file. AsyncIterable is not supported.") if isinstance(self.file, (str, bytes)): yield to_bytes(self.file) return @@ -216,10 +220,40 @@ def render_data(self) -> typing.Iterator[bytes]: yield to_bytes(chunk) chunk = self.file.read(self.CHUNK_SIZE) + async def arender_data(self) -> typing.AsyncGenerator[bytes]: + if not isinstance(self.file, AsyncIterable): + for chunk in self.render_data(): + yield chunk + return + + file_aiter = self.file.__aiter__() + + try: + achunk = await file_aiter.__anext__() + except StopAsyncIteration: + return + + if not isinstance(achunk, bytes): + raise TypeError( + "Multipart file uploads must be opened in binary mode," + " not text mode." + ) + + yield achunk + + async for achunk in file_aiter: + yield achunk + def render(self) -> typing.Iterator[bytes]: yield self.render_headers() yield from self.render_data() + async def arender(self) -> typing.AsyncGenerator[bytes]: + yield self.render_headers() + async with aclosing(self.arender_data()) as data: + async for chunk in data: + yield chunk + class MultipartStream(SyncByteStream, AsyncByteStream): """ @@ -262,6 +296,19 @@ def iter_chunks(self) -> typing.Iterator[bytes]: yield b"\r\n" yield b"--%s--\r\n" % self.boundary + async def aiter_chunks(self) -> typing.AsyncGenerator[bytes]: + for field in self.fields: + yield b"--%s\r\n" % self.boundary + if isinstance(field, FileField): + async with aclosing(field.arender()) as data: + async for chunk in data: + yield chunk + else: + for chunk in field.render(): + yield chunk + yield b"\r\n" + yield b"--%s--\r\n" % self.boundary + def get_content_length(self) -> int | None: """ Return the length of the multipart encoded content, or `None` if @@ -296,5 +343,6 @@ def __iter__(self) -> typing.Iterator[bytes]: yield chunk async def __aiter__(self) -> typing.AsyncIterator[bytes]: - for chunk in self.iter_chunks(): - yield chunk + async with aclosing(self.aiter_chunks()) as data: + async for chunk in data: + yield chunk diff --git a/httpx/_types.py b/httpx/_types.py index 704dfdffc8..2f78d30839 100644 --- a/httpx/_types.py +++ b/httpx/_types.py @@ -71,7 +71,7 @@ RequestData = Mapping[str, Any] -FileContent = Union[IO[bytes], bytes, str] +FileContent = Union[IO[bytes], bytes, str, AsyncIterable[bytes]] FileTypes = Union[ # file (or bytes) FileContent, diff --git a/pyproject.toml b/pyproject.toml index 9e67191135..2b83de5f12 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -128,5 +128,5 @@ markers = [ ] [tool.coverage.run] -omit = ["venv/*"] +omit = ["venv/*", "httpx/_compat.py"] include = ["httpx/*", "tests/*"] diff --git a/tests/test_multipart.py b/tests/test_multipart.py index 764f85a253..d6283b9c72 100644 --- a/tests/test_multipart.py +++ b/tests/test_multipart.py @@ -4,7 +4,9 @@ import tempfile import typing +import anyio import pytest +import trio import httpx @@ -41,6 +43,61 @@ def test_multipart(value, output): ) +async def test_async_multipart_streaming(tmp_path, server, anyio_backend): + content = b"\n".join([b"a" * io.DEFAULT_BUFFER_SIZE] * 3) + to_upload = tmp_path / "test.txt" + to_upload.write_bytes(content) + empty_file = tmp_path / "empty.txt" + empty_file.write_bytes(b"") + opener: typing.Any + text_opener: typing.Any + empty_opener: typing.Any + if anyio_backend == "trio": + opener = trio.open_file(to_upload, "b+r") + text_opener = trio.open_file(to_upload, "t+r") + empty_opener = trio.open_file(empty_file, "b+r") + else: + opener = anyio.open_file(to_upload, "b+r") + text_opener = anyio.open_file(to_upload, "t+r") + empty_opener = anyio.open_file(empty_file, "b+r") + url = server.url.copy_with(path="/echo_body") + async with await opener as fp, httpx.AsyncClient() as client: + files = {"file": fp} + response = await client.post(url, files=files) + boundary = response.request.headers["Content-Type"].split("boundary=")[-1] + boundary_bytes = boundary.encode("ascii") + + assert response.status_code == 200 + assert response.content == b"".join( + [ + b"--" + boundary_bytes + b"\r\n", + b'Content-Disposition: form-data; name="file";' + b' filename="test.txt"\r\n', + b"Content-Type: text/plain\r\n", + b"\r\n", + content, + b"\r\n", + b"--" + boundary_bytes + b"--\r\n", + ] + ) + + with httpx.Client() as sync_client: + with pytest.raises(TypeError, match="AsyncIterable is not supported"): + sync_client.post(url, files=files) + + async with await text_opener as fp, httpx.AsyncClient() as client: + files = {"file": fp} + with pytest.raises( + TypeError, + match="Multipart file uploads must be opened in binary mode", + ): + await client.post(url, files=files) + + async with await empty_opener as fp, httpx.AsyncClient() as client: + files = {"file": fp} + await client.post(url, files=files) + + @pytest.mark.parametrize( "header", [