Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 108 additions & 64 deletions pkgs/standards/autoapi/autoapi/v3/bindings/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,27 @@ def _pk_names(model: type) -> set[str]:
return {"id"}


def _coerce_parent_identifiers(
model: type, parent_kw: Dict[str, Any]
) -> Dict[str, Any]:
"""Coerce nested path identifiers to the model's column python types."""
table = getattr(model, "__table__", None)
if table is None:
return parent_kw
out: Dict[str, Any] = dict(parent_kw)
for key, val in list(parent_kw.items()):
col = getattr(getattr(table, "c", None), key, None)
if col is None:
continue
try:
py_t = getattr(col.type, "python_type", None)
if py_t is not None and not isinstance(val, py_t):
out[key] = py_t(val)
except Exception:
pass
return out


def _get_phase_chains(
model: type, alias: str
) -> Dict[str, Sequence[Callable[..., Awaitable[Any]]]]:
Expand Down Expand Up @@ -618,7 +639,9 @@ async def _endpoint(
db: Any = Depends(db_dep),
**kw: Any,
):
parent_kw = {k: kw[k] for k in nested_vars if k in kw}
parent_kw = _coerce_parent_identifiers(
model, {k: kw[k] for k in nested_vars if k in kw}
)
query = dict(q)
query.update(parent_kw)
payload = _validate_query(model, alias, target, query)
Expand Down Expand Up @@ -675,7 +698,9 @@ async def _endpoint(
db: Any = Depends(db_dep),
**kw: Any,
):
parent_kw = {k: kw[k] for k in nested_vars if k in kw}
parent_kw = _coerce_parent_identifiers(
model, {k: kw[k] for k in nested_vars if k in kw}
)
payload: Mapping[str, Any] = dict(parent_kw)
ctx: Dict[str, Any] = {
"request": request,
Expand Down Expand Up @@ -738,7 +763,9 @@ async def _endpoint(
body=Body(...),
**kw: Any,
):
parent_kw = {k: kw[k] for k in nested_vars if k in kw}
parent_kw = _coerce_parent_identifiers(
model, {k: kw[k] for k in nested_vars if k in kw}
)
payload = _validate_body(model, alias, target, body)
if parent_kw:
if isinstance(payload, Mapping):
Expand Down Expand Up @@ -829,7 +856,9 @@ async def _endpoint(
db: Any = Depends(db_dep),
**kw: Any,
):
parent_kw = {k: kw[k] for k in nested_vars if k in kw}
parent_kw = _coerce_parent_identifiers(
model, {k: kw[k] for k in nested_vars if k in kw}
)
payload: Mapping[str, Any] = dict(parent_kw)
path_params = {real_pk: item_id, pk_param: item_id, **parent_kw}
ctx: Dict[str, Any] = {
Expand Down Expand Up @@ -899,7 +928,9 @@ async def _endpoint(
db: Any = Depends(db_dep),
**kw: Any,
):
parent_kw = {k: kw[k] for k in nested_vars if k in kw}
parent_kw = _coerce_parent_identifiers(
model, {k: kw[k] for k in nested_vars if k in kw}
)
payload: Mapping[str, Any] = dict(parent_kw)
path_params = {real_pk: item_id, pk_param: item_id, **parent_kw}
ctx: Dict[str, Any] = {
Expand Down Expand Up @@ -975,7 +1006,9 @@ async def _endpoint(
body=body_default,
**kw: Any,
):
parent_kw = {k: kw[k] for k in nested_vars if k in kw}
parent_kw = _coerce_parent_identifiers(
model, {k: kw[k] for k in nested_vars if k in kw}
)
payload = _validate_body(model, alias, target, body)

# Enforce path-PK canonicality. If body echoes PK: drop if equal, 409 if mismatch.
Expand Down Expand Up @@ -1083,47 +1116,19 @@ def _build_router(model: type, specs: Sequence[OpSpec]) -> Router:
raw_nested = _nested_prefix(model) or ""
nested_pref = re.sub(r"/{2,}", "/", raw_nested).rstrip("/") or ""
nested_vars = re.findall(r"{(\w+)}", raw_nested)
nested_base = (
re.sub(r"/{2,}", "/", f"{nested_pref}/{resource}").rstrip("/")
if nested_pref
else ""
)

for sp in specs:
if not sp.expose_routes:
continue

# Drop parent identifiers from request models when using nested paths
if nested_vars:
schemas_root = getattr(model, "schemas", None)
if schemas_root:
alias_ns = getattr(schemas_root, sp.alias, None)
if alias_ns:
in_model = getattr(alias_ns, "in_", None)
if (
in_model
and inspect.isclass(in_model)
and issubclass(in_model, BaseModel)
):
pruned = _strip_parent_fields(in_model, drop=set(nested_vars))
setattr(alias_ns, "in_", pruned)

# Determine path and membership
if nested_pref:
suffix = sp.path_suffix or _default_path_suffix(sp) or ""
if not suffix.startswith("/") and suffix:
suffix = "/" + suffix
base = nested_pref
if sp.arity == "member" or sp.target in {
"read",
"update",
"replace",
"delete",
}:
path = f"{base}/{{{pk_param}}}{suffix}"
is_member = True
else:
path = f"{base}{suffix}"
is_member = False
else:
path, is_member = _path_for_spec(
model, sp, resource=resource, pk_param=pk_param
)
base_path, base_is_member = _path_for_spec(
model, sp, resource=resource, pk_param=pk_param
)

# HARDEN list.in_ at runtime to avoid bogus defaults blowing up empty GETs
if sp.target == "list":
Expand All @@ -1141,33 +1146,28 @@ def _build_router(model: type, specs: Sequence[OpSpec]) -> Router:
safe = _optionalize_list_in_model(in_model)
setattr(alias_ns, "in_", safe)

# HTTP methods
methods = list(sp.http_methods or _DEFAULT_METHODS.get(sp.target, ("POST",)))
response_model = None # Allow hooks to mutate response freely

# Build endpoint (split by body/no-body)
if is_member:
endpoint = _make_member_endpoint(
if base_is_member:
base_endpoint = _make_member_endpoint(
model,
sp,
resource=resource,
db_dep=db_dep,
pk_param=pk_param,
nested_vars=nested_vars,
nested_vars=[],
)
else:
endpoint = _make_collection_endpoint(
base_endpoint = _make_collection_endpoint(
model,
sp,
resource=resource,
db_dep=db_dep,
nested_vars=nested_vars,
nested_vars=[],
)

# Status codes
status_code = _status_for(sp)

# Capture OUT schema for OpenAPI without enforcing runtime validation
alias_ns = getattr(getattr(model, "schemas", None), sp.alias, None)
out_model = getattr(alias_ns, "out", None) if alias_ns else None

Expand All @@ -1179,33 +1179,77 @@ def _build_router(model: type, specs: Sequence[OpSpec]) -> Router:
responses_meta[status_code] = {"description": "Successful Response"}
response_class = Response

# Attach route
label = f"{model.__name__} - {sp.alias}"
route_kwargs = dict(
path=path,
endpoint=endpoint,
path=base_path,
endpoint=base_endpoint,
methods=methods,
name=f"{model.__name__}.{sp.alias}",
summary=label,
description=label,
response_model=response_model,
status_code=status_code,
# IMPORTANT: only class name here; never table name
tags=list(sp.tags or (model.__name__,)),
responses=responses_meta,
)
if response_class is not None:
route_kwargs["response_class"] = response_class
router.add_api_route(**route_kwargs)

logger.debug(
"rest: registered %s %s -> %s.%s (response_model=%s)",
methods,
path,
model.__name__,
sp.alias,
getattr(response_model, "__name__", None) if response_model else None,
)
if nested_pref:
if nested_vars:
schemas_root = getattr(model, "schemas", None)
if schemas_root:
alias_ns = getattr(schemas_root, sp.alias, None)
if alias_ns:
in_model = getattr(alias_ns, "in_", None)
if (
in_model
and inspect.isclass(in_model)
and issubclass(in_model, BaseModel)
):
pruned = _strip_parent_fields(
in_model, drop=set(nested_vars)
)
setattr(alias_ns, "in_", pruned)

suffix = sp.path_suffix or _default_path_suffix(sp) or ""
if not suffix.startswith("/") and suffix:
suffix = "/" + suffix
base = nested_base
if sp.arity == "member" or sp.target in {
"read",
"update",
"replace",
"delete",
}:
path = f"{base}/{{{pk_param}}}{suffix}"
is_member = True
else:
path = f"{base}{suffix}"
is_member = False

if is_member:
endpoint = _make_member_endpoint(
model,
sp,
resource=resource,
db_dep=db_dep,
pk_param=pk_param,
nested_vars=nested_vars,
)
else:
endpoint = _make_collection_endpoint(
model,
sp,
resource=resource,
db_dep=db_dep,
nested_vars=nested_vars,
)

nested_kwargs = dict(route_kwargs)
nested_kwargs.update(path=path, endpoint=endpoint)
router.add_api_route(**nested_kwargs)

return router

Expand Down
67 changes: 52 additions & 15 deletions pkgs/standards/autoapi/tests/i9n/test_symmetry_parity.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,64 @@
import pytest
from autoapi.v3.types import SimpleNamespace

CRUD_MAP = {
"create": ("post", "/tenant/{tenant_id}"),
"list": ("get", "/tenant/{tenant_id}"),
"clear": ("delete", "/tenant/{tenant_id}"),
"read": ("get", "/tenant/{tenant_id}/{item_id}"),
"update": ("patch", "/tenant/{tenant_id}/{item_id}"),
"delete": ("delete", "/tenant/{tenant_id}/{item_id}"),
}
PARITY_MAP = [
(
"list",
"get",
"/item",
"/tenant/{tenant_id}/item",
"Item.list",
("tenant_id",),
),
(
"read",
"get",
"/item/{item_id}",
"/tenant/{tenant_id}/item/{item_id}",
"Item.read",
("tenant_id", "item_id"),
),
]


@pytest.mark.i9n
@pytest.mark.asyncio
async def test_route_and_method_symmetry(api_client):
@pytest.mark.parametrize(
"verb,http_verb,rest_path,nested_path,rpc_method,param_keys", PARITY_MAP
)
async def test_rest_nested_rpc_parity(
api_client, verb, http_verb, rest_path, nested_path, rpc_method, param_keys
):
client, api, _ = api_client
api.attach_diagnostics(prefix="")
spec = (await client.get("/openapi.json")).json()
paths = spec["paths"]
methods = await client.get("/methodz")
method_list = {SimpleNamespace(**m).method for m in methods.json()["methods"]}
methods_resp = await client.get("/methodz")
method_list = {SimpleNamespace(**m).method for m in methods_resp.json()["methods"]}

for verb, (http_verb, path) in CRUD_MAP.items():
assert path in paths
assert http_verb in paths[path]
assert f"Item.{verb}" in method_list
assert rest_path in paths
assert http_verb in paths[rest_path]
assert nested_path in paths
assert http_verb in paths[nested_path]
assert f"Item.{verb}" in method_list

tenant = (await client.post("/tenant", json={"name": "t"})).json()
tenant_id = tenant["id"]
item = (await client.post(f"/tenant/{tenant_id}/item", json={"name": "i"})).json()
item_id = item["id"]

ids = {"tenant_id": tenant_id, "item_id": item_id}
rest_resp = await getattr(client, http_verb)(rest_path.format(**ids))
nested_resp = await getattr(client, http_verb)(nested_path.format(**ids))
rpc_payload = {
"jsonrpc": "2.0",
"method": rpc_method,
"params": {k: ids[k] for k in param_keys},
"id": 1,
}
rpc_resp = await client.post("/rpc", json=rpc_payload)

assert rest_resp.status_code == 200
assert nested_resp.status_code == 200
assert rpc_resp.status_code == 200
assert rest_resp.json() == nested_resp.json() == rpc_resp.json()["result"]
Loading