Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@ Please follow [the Keep a Changelog standard](https://keepachangelog.com/en/1.0.

## [Unreleased]

## [6.1.0]

### Added

- Added `form` property to `RequestInfo` for reading and modifying multipart/form-encoded fields (strings and file uploads) inside `convert_request_to_next_version_for` handlers (#359)

## [6.0.4]

### Added
Expand Down
13 changes: 10 additions & 3 deletions cadwyn/structure/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import ClassVar, Union, cast

from fastapi import Request, Response
from starlette.datastructures import MutableHeaders
from starlette.datastructures import FormData, MutableHeaders, UploadFile
from typing_extensions import Any, ParamSpec, overload

from cadwyn._utils import same_definition_as_in
Expand All @@ -14,9 +14,8 @@
_P = ParamSpec("_P")


# TODO (https://github.com/zmievsa/cadwyn/issues/49): Add form handling
class RequestInfo:
__slots__ = ("_cookies", "_query_params", "_request", "body", "headers")
__slots__ = ("_cookies", "_form", "_query_params", "_request", "body", "headers")

def __init__(self, request: Request, body: Any):
super().__init__()
Expand All @@ -25,6 +24,10 @@ def __init__(self, request: Request, body: Any):
self._cookies = request.cookies
self._query_params = request.query_params._dict
self._request = request
if isinstance(body, FormData):
self._form: Union[list[tuple[str, Union[UploadFile, str]]], None] = list(body.multi_items())
else:
self._form = None

@property
def cookies(self) -> dict[str, str]:
Expand All @@ -34,6 +37,10 @@ def cookies(self) -> dict[str, str]:
def query_params(self) -> dict[str, str]:
return self._query_params

@property
def form(self) -> Union[list[tuple[str, Union[UploadFile, str]]], None]:
return self._form


# TODO (https://github.com/zmievsa/cadwyn/issues/111): handle _response.media_type and _response.background
class ResponseInfo:
Expand Down
9 changes: 8 additions & 1 deletion cadwyn/structure/versions.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from pydantic import BaseModel
from pydantic_core import PydanticUndefined
from starlette._utils import is_async_callable
from starlette.datastructures import FormData
from typing_extensions import Any, ParamSpec, TypeAlias, TypeVar, assert_never, deprecated, get_args

from cadwyn._internal.context_vars import CURRENT_DEPENDENCY_SOLVER_VAR
Expand Down Expand Up @@ -431,12 +432,18 @@ async def _migrate_request(
del request._headers
# This gives us the ability to tell the user whether cadwyn is running its dependencies or FastAPI
CURRENT_DEPENDENCY_SOLVER_VAR.set("cadwyn")

if request_info._form is not None:
body_for_solving = FormData(request_info._form)
else:
body_for_solving = request_info.body

# Remember this: if len(body_params) == 1, then route.body_schema == route.dependant.body_params[0]
result = await solve_dependencies(
request=request,
response=response,
dependant=head_dependant,
body=request_info.body,
body=body_for_solving,
dependency_overrides_provider=head_route.dependency_overrides_provider,
async_exit_stack=exit_stack,
embed_body_fields=embed_body_fields,
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "cadwyn"
version = "6.0.4"
version = "6.1.0"
description = "Production-ready community-driven modern Stripe-like API versioning in FastAPI"
authors = [{ name = "Stanislav Zmiev", email = "zmievsa@gmail.com" }]
license = "MIT"
Expand Down
184 changes: 183 additions & 1 deletion tests/test_data_migrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import fastapi
import pytest
from dirty_equals import IsPartialDict, IsStr
from fastapi import APIRouter, Body, Cookie, File, Header, HTTPException, Query, Request, Response, UploadFile
from fastapi import APIRouter, Body, Cookie, File, Form, Header, HTTPException, Query, Request, Response, UploadFile
from fastapi.responses import JSONResponse
from fastapi.routing import APIRoute
from fastapi.testclient import TestClient
Expand Down Expand Up @@ -890,6 +890,188 @@ async def endpoint(file: UploadFile = File(...)):
}


def test__form_migration__can_modify_string_form_field(
create_versioned_clients: CreateVersionedClients,
test_path: Literal["/test"],
router: VersionedAPIRouter,
):
@router.post(test_path)
async def endpoint(name: str = Form(...), age: int = Form(...)):
return {"name": name, "age": age}

@convert_request_to_next_version_for(test_path, ["POST"])
def migrator(request: RequestInfo):
request.form[:] = [(k, "migrated_name") if k == "name" else (k, v) for k, v in request.form] # type: ignore[index]

clients = create_versioned_clients(version_change(migrator=migrator))
resp_2000 = clients["2000-01-01"].post(test_path, data={"name": "original_name", "age": "25"})
resp_2001 = clients["2001-01-01"].post(test_path, data={"name": "original_name", "age": "25"})

assert resp_2000.json() == {"name": "migrated_name", "age": 25}
assert resp_2001.json() == {"name": "original_name", "age": 25}


def test__form_migration__form_is_none_for_non_form_requests(
create_versioned_clients: CreateVersionedClients,
test_path: Literal["/test"],
router: VersionedAPIRouter,
):
@router.post(test_path, response_model=AnyResponseSchema)
async def endpoint(payload: AnyRequestSchema):
return payload

captured = []

@convert_request_to_next_version_for(test_path, ["POST"])
def migrator(request: RequestInfo):
captured.append(request.form)

clients = create_versioned_clients(version_change(migrator=migrator))
clients["2000-01-01"].post(test_path, json={"foo": "bar"})

assert captured == [None]


def test__form_migration__can_add_new_form_field(
create_versioned_clients: CreateVersionedClients,
test_path: Literal["/test"],
router: VersionedAPIRouter,
):
@router.post(test_path)
async def endpoint(name: str = Form(...), extra: str = Form("default")):
return {"name": name, "extra": extra}

@convert_request_to_next_version_for(test_path, ["POST"])
def migrator(request: RequestInfo):
request.form.append(("extra", "added_by_migration")) # type: ignore[union-attr]

clients = create_versioned_clients(version_change(migrator=migrator))
resp_2000 = clients["2000-01-01"].post(test_path, data={"name": "my_name"})
resp_2001 = clients["2001-01-01"].post(test_path, data={"name": "my_name"})

assert resp_2000.json() == {"name": "my_name", "extra": "added_by_migration"}
assert resp_2001.json() == {"name": "my_name", "extra": "default"}


def test__form_migration__multi_value_fields_are_preserved(
create_versioned_clients: CreateVersionedClients,
test_path: Literal["/test"],
router: VersionedAPIRouter,
):
@router.post(test_path)
async def endpoint(tags: list[str] = Form(...)):
return {"tags": tags}

clients = create_versioned_clients(version_change())
resp_2000 = clients["2000-01-01"].post(
test_path,
content=b"tags=python&tags=web",
headers={"content-type": "application/x-www-form-urlencoded"},
)
resp_2001 = clients["2001-01-01"].post(
test_path,
content=b"tags=python&tags=web",
headers={"content-type": "application/x-www-form-urlencoded"},
)

assert resp_2000.json() == {"tags": ["python", "web"]}
assert resp_2001.json() == {"tags": ["python", "web"]}


def test__request_body_parsing__empty_body_with_optional_params(
create_versioned_clients: CreateVersionedClients,
test_path: Literal["/test"],
router: VersionedAPIRouter,
):
@router.post(test_path)
async def endpoint(name: Optional[str] = Body(default=None), age: Optional[int] = Body(default=None)):
return {"name": name, "age": age}

clients = create_versioned_clients(version_change())
resp = clients["2000-01-01"].post(
test_path,
content=b"",
headers={"content-type": "application/json"},
)

assert resp.status_code == 200
assert resp.json() == {"name": None, "age": None}


def test__request_body_parsing__json_body_without_content_type(
create_versioned_clients: CreateVersionedClients,
test_path: Literal["/test"],
router: VersionedAPIRouter,
):
@router.post(test_path)
async def endpoint(name: Optional[str] = Body(default=None), age: Optional[int] = Body(default=None)):
return {"name": name, "age": age}

clients = create_versioned_clients(version_change())
resp = clients["2000-01-01"].post(test_path, content=b'{"name": "test", "age": 25}')

assert resp.status_code == 200
assert resp.json() == {"name": "test", "age": 25}


def test__request_body_parsing__non_application_content_type(
create_versioned_clients: CreateVersionedClients,
test_path: Literal["/test"],
router: VersionedAPIRouter,
):
@router.post(test_path)
async def endpoint(name: Optional[str] = Body(default=None), age: Optional[int] = Body(default=None)):
return {"name": name, "age": age}

clients = create_versioned_clients(version_change())
resp = clients["2000-01-01"].post(
test_path,
content=b'{"name": "test", "age": 25}',
headers={"content-type": "text/plain"},
)

assert resp.status_code in [200, 422]


def test__request_body_parsing__application_xml_body(
create_versioned_clients: CreateVersionedClients,
test_path: Literal["/test"],
router: VersionedAPIRouter,
):
@router.post(test_path)
async def endpoint(name: Optional[str] = Body(default=None), age: Optional[int] = Body(default=None)):
return {"name": name, "age": age}

clients = create_versioned_clients(version_change())
resp = clients["2000-01-01"].post(
test_path,
content=b"<root/>",
headers={"content-type": "application/xml"},
)

assert resp.status_code in [200, 422]


def test__request_body_parsing__malformed_json_returns_422(
create_versioned_clients: CreateVersionedClients,
test_path: Literal["/test"],
router: VersionedAPIRouter,
):
@router.post(test_path)
async def endpoint(name: Optional[str] = Body(default=None), age: Optional[int] = Body(default=None)):
return {"name": name, "age": age}

clients = create_versioned_clients(version_change())
resp = clients["2000-01-01"].post(
test_path,
content=b"not valid json",
headers={"content-type": "application/json"},
)

assert resp.status_code == 422
assert resp.json()["detail"][0]["type"] == "json_invalid"


def test__request_and_response_migrations__for_paths_with_variables__can_match(
create_versioned_clients: CreateVersionedClients,
router: VersionedAPIRouter,
Expand Down
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading