Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

async file in multipart file upload #3339

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions httpx/_compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import sys

if sys.version_info >= (3, 10):
from contextlib import aclosing
else:
from contextlib import asynccontextmanager
from typing import Any, AsyncIterator, Awaitable, Protocol, TypeVar

class _SupportsAclose(Protocol):
def aclose(self) -> Awaitable[object]: ...

_SupportsAcloseT = TypeVar("_SupportsAcloseT", bound=_SupportsAclose)

@asynccontextmanager
async def aclosing(thing: _SupportsAcloseT) -> AsyncIterator[Any]:
try:
yield thing
finally:
await thing.aclose()


__all__ = ["aclosing"]
52 changes: 50 additions & 2 deletions httpx/_multipart.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
import os
import re
import typing
from collections.abc import AsyncIterable
from pathlib import Path

from ._compat import aclosing
from ._types import (
AsyncByteStream,
FileContent,
Expand Down Expand Up @@ -201,6 +203,8 @@ def render_headers(self) -> bytes:
return self._headers

def render_data(self) -> typing.Iterator[bytes]:
if isinstance(self.file, AsyncIterable):
raise TypeError("Invalid type for file. AsyncIterable is not supported.")
if isinstance(self.file, (str, bytes)):
yield to_bytes(self.file)
return
Expand All @@ -216,10 +220,40 @@ def render_data(self) -> typing.Iterator[bytes]:
yield to_bytes(chunk)
chunk = self.file.read(self.CHUNK_SIZE)

async def arender_data(self) -> typing.AsyncGenerator[bytes]:
if not isinstance(self.file, AsyncIterable):
for chunk in self.render_data():
yield chunk
return
Comment on lines +224 to +227
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want block loop for non-AsynIterable types?

Suggested change
if not isinstance(self.file, AsyncIterable):
for chunk in self.render_data():
yield chunk
return
if not isinstance(self.file, AsyncIterable):
raise TypeError("Invalid type for file. Only AsyncIterable is supported.")

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct me if I'm wrong, but won't it break IO[bytes] usage with AsyncClient?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, you're right.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Async generator objects should always be typed as AsyncGenerators so they can be used with contextlib.aclosing


file_aiter = self.file.__aiter__()

try:
achunk = await file_aiter.__anext__()
except StopAsyncIteration:
return

if not isinstance(achunk, bytes):
raise TypeError(
"Multipart file uploads must be opened in binary mode,"
" not text mode."
)

yield achunk

async for achunk in file_aiter:
yield achunk

def render(self) -> typing.Iterator[bytes]:
yield self.render_headers()
yield from self.render_data()

async def arender(self) -> typing.AsyncGenerator[bytes]:
yield self.render_headers()
async with aclosing(self.arender_data()) as data:
async for chunk in data:
yield chunk


class MultipartStream(SyncByteStream, AsyncByteStream):
"""
Expand Down Expand Up @@ -262,6 +296,19 @@ def iter_chunks(self) -> typing.Iterator[bytes]:
yield b"\r\n"
yield b"--%s--\r\n" % self.boundary

async def aiter_chunks(self) -> typing.AsyncGenerator[bytes]:
for field in self.fields:
yield b"--%s\r\n" % self.boundary
if isinstance(field, FileField):
async with aclosing(field.arender()) as data:
async for chunk in data:
yield chunk
else:
for chunk in field.render():
yield chunk
yield b"\r\n"
yield b"--%s--\r\n" % self.boundary

def get_content_length(self) -> int | None:
"""
Return the length of the multipart encoded content, or `None` if
Expand Down Expand Up @@ -296,5 +343,6 @@ def __iter__(self) -> typing.Iterator[bytes]:
yield chunk

async def __aiter__(self) -> typing.AsyncIterator[bytes]:
for chunk in self.iter_chunks():
yield chunk
async with aclosing(self.aiter_chunks()) as data:
async for chunk in data:
yield chunk
2 changes: 1 addition & 1 deletion httpx/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@

RequestData = Mapping[str, Any]

FileContent = Union[IO[bytes], bytes, str]
FileContent = Union[IO[bytes], bytes, str, AsyncIterable[bytes]]
FileTypes = Union[
# file (or bytes)
FileContent,
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -128,5 +128,5 @@ markers = [
]

[tool.coverage.run]
omit = ["venv/*"]
omit = ["venv/*", "httpx/_compat.py"]
include = ["httpx/*", "tests/*"]
57 changes: 57 additions & 0 deletions tests/test_multipart.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
import tempfile
import typing

import anyio
import pytest
import trio

import httpx

Expand Down Expand Up @@ -41,6 +43,61 @@ def test_multipart(value, output):
)


async def test_async_multipart_streaming(tmp_path, server, anyio_backend):
content = b"\n".join([b"a" * io.DEFAULT_BUFFER_SIZE] * 3)
to_upload = tmp_path / "test.txt"
to_upload.write_bytes(content)
empty_file = tmp_path / "empty.txt"
empty_file.write_bytes(b"")
opener: typing.Any
text_opener: typing.Any
empty_opener: typing.Any
if anyio_backend == "trio":
opener = trio.open_file(to_upload, "b+r")
text_opener = trio.open_file(to_upload, "t+r")
empty_opener = trio.open_file(empty_file, "b+r")
else:
opener = anyio.open_file(to_upload, "b+r")
text_opener = anyio.open_file(to_upload, "t+r")
empty_opener = anyio.open_file(empty_file, "b+r")
url = server.url.copy_with(path="/echo_body")
async with await opener as fp, httpx.AsyncClient() as client:
files = {"file": fp}
response = await client.post(url, files=files)
boundary = response.request.headers["Content-Type"].split("boundary=")[-1]
boundary_bytes = boundary.encode("ascii")

assert response.status_code == 200
assert response.content == b"".join(
[
b"--" + boundary_bytes + b"\r\n",
b'Content-Disposition: form-data; name="file";'
b' filename="test.txt"\r\n',
b"Content-Type: text/plain\r\n",
b"\r\n",
content,
b"\r\n",
b"--" + boundary_bytes + b"--\r\n",
]
)

with httpx.Client() as sync_client:
with pytest.raises(TypeError, match="AsyncIterable is not supported"):
sync_client.post(url, files=files)

async with await text_opener as fp, httpx.AsyncClient() as client:
files = {"file": fp}
with pytest.raises(
TypeError,
match="Multipart file uploads must be opened in binary mode",
):
await client.post(url, files=files)

async with await empty_opener as fp, httpx.AsyncClient() as client:
files = {"file": fp}
await client.post(url, files=files)


@pytest.mark.parametrize(
"header",
[
Expand Down