Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,14 @@ jobs:
run: |
uv run python -m pytest tests/ -vv -s


- name: Run fastapi compatibility tests
if: steps.check_test_files.outputs.files_exists == 'true'
env:
PYTHONWARNINGS: ignore
run: |
uv run nox -s test-compat-fastapi

#----------------------------------------------
# make sure docs build
#----------------------------------------------
Expand Down
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ https://github.com/jymchng/fastapi-shield
<hr style="border: none; border-top: 1px solid #ccc; margin: 1em 0;">

### Compatibility and Version
<img src="https://github.com/jymchng/fastapi-shield/actions/workflows/tests.yaml/badge.svg">
<img src="https://img.shields.io/badge/dynamic/toml?url=https%3A%2F%2Fraw.githubusercontent.com%2Fjymchng%2Ffastapi-shield%2Frefs%2Fheads%2Fmain%2Fpyproject.toml&query=%24.project.dependencies%5B0%5D&label=compat&labelColor=green">
<img src="https://img.shields.io/pypi/pyversions/fastapi-shield?color=green" alt="Python compat">
<a href="https://pypi.python.org/pypi/fastapi-shield"><img src="https://img.shields.io/pypi/v/fastapi-shield.svg" alt="PyPi"></a>

Expand Down
195 changes: 195 additions & 0 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
import shutil
from functools import wraps
import pathlib
import urllib.request
import json
import re

import nox
import nox.command as nox_command
Expand Down Expand Up @@ -154,6 +157,126 @@ def session(
)


# --- FastAPI compatibility matrix helpers ---
PYPI_JSON_URL_TEMPLATE = "https://pypi.org/pypi/{package}/json"


def _parse_strict_version_tuple(ver_str: str):
"""Parse a strict semantic version 'X.Y.Z' into a tuple of ints.

Returns None if version doesn't match strict pattern (filters out pre-releases).
"""
m = re.match(r"^(\d+)\.(\d+)\.(\d+)$", ver_str)
if not m:
return None
return int(m.group(1)), int(m.group(2)), int(m.group(3))


def _version_tuple_to_str(t):
return f"{t[0]}.{t[1]}.{t[2]}"


def _cmp_major_minor(a, b):
"""Compare (major, minor) tuples only."""
if a[0] != b[0]:
return a[0] - b[0]
return a[1] - b[1]


def _get_min_supported_version_from_pyproject(
package_name: str, manifest: dict = PROJECT_MANIFEST
):
"""Extract minimum supported version from pyproject for given package.

Supports entries like 'fastapi>=0.100.1' and 'fastapi[standard]>=0.100.1'.
Returns a version tuple (major, minor, patch) or None if not found.
"""
deps = manifest.get("project", {}).get("dependencies", [])
patterns = [
rf"^{re.escape(package_name)}>=([0-9]+\.[0-9]+\.[0-9]+)$",
rf"^{re.escape(package_name)}\[[^\]]+\]>=([0-9]+\.[0-9]+\.[0-9]+)$",
]
for dep in deps:
for pat in patterns:
m = re.match(pat, dep)
if m:
vt = _parse_strict_version_tuple(m.group(1))
if vt:
return vt
return None


def _fetch_pypi_latest_and_releases(package_name: str):
"""Fetch latest version and releases list from PyPI JSON.

Returns (latest_version_tuple, releases_dict) where releases_dict maps
(major, minor) -> max patch available for that minor.
"""
url = PYPI_JSON_URL_TEMPLATE.format(package=package_name)
try:
with urllib.request.urlopen(url) as resp:
data = json.loads(resp.read().decode("utf-8"))
except Exception:
return None, {}

latest_str = data.get("info", {}).get("version")
latest_tuple = _parse_strict_version_tuple(latest_str) if latest_str else None

releases = data.get("releases", {})
minor_to_max_patch = {}
for ver_str in releases.keys():
vt = _parse_strict_version_tuple(ver_str)
if not vt:
# skip pre-release or non-strict versions
continue
major, minor, patch = vt
key = (major, minor)
prev = minor_to_max_patch.get(key)
if prev is None or patch > prev:
minor_to_max_patch[key] = patch

return latest_tuple, minor_to_max_patch


def _build_minor_matrix(min_vt, latest_vt, minor_to_max_patch):
"""Build a list of version strings representing the highest patch in each minor
from min_vt to latest_vt inclusive. Only includes minors that exist in releases.
"""
if not min_vt or not latest_vt:
return []
result = []
# Collect and sort available minor keys
available_minors = sorted(minor_to_max_patch.keys(), key=lambda k: (k[0], k[1]))
for major, minor in available_minors:
# range filter: min <= (major, minor) <= latest
if _cmp_major_minor((major, minor), (min_vt[0], min_vt[1])) < 0:
continue
if _cmp_major_minor((major, minor), (latest_vt[0], latest_vt[1])) > 0:
continue
patch = minor_to_max_patch[(major, minor)]
result.append(_version_tuple_to_str((major, minor, patch)))
return result


def _compute_fastapi_minor_matrix():
package = "fastapi"
min_vt = _get_min_supported_version_from_pyproject(package)
latest_vt, minor_to_max_patch = _fetch_pypi_latest_and_releases(package)
matrix = _build_minor_matrix(min_vt, latest_vt, minor_to_max_patch)
# Fallbacks if network fails or parsing issues
if not matrix:
vals = []
if min_vt:
vals.append(_version_tuple_to_str(min_vt))
if latest_vt and latest_vt != min_vt:
vals.append(_version_tuple_to_str(latest_vt))
matrix = vals or ["0.100.1"]
return matrix


FASTAPI_MINOR_MATRIX = _compute_fastapi_minor_matrix()


def uv_install_group_dependencies(session: Session, dependency_group: str):
pyproject = nox.project.load_toml(MANIFEST_FILENAME)
dependencies = nox.project.dependency_groups(pyproject, dependency_group)
Expand Down Expand Up @@ -256,6 +379,42 @@ def test(session: AlteredSession):
session.run(*command)


@session(
dependency_group=None,
default_posargs=[TEST_DIR, "-s", "-vv", "-n", "auto", "--dist", "worksteal"],
reuse_venv=False,
)
@nox.parametrize("fastapi_version", FASTAPI_MINOR_MATRIX)
def test_compat_fastapi(session: AlteredSession, fastapi_version: str):
"""Run tests against a matrix of FastAPI minor versions.

The matrix is computed from pyproject's minimum supported version and
PyPI's latest release, selecting the highest patch per minor.
"""
session.log(f"Testing compatibility with FastAPI versions: {FASTAPI_MINOR_MATRIX}")
# Pin FastAPI (and extras) to the target minor's highest patch before running tests.
# Install dev dependencies excluding FastAPI to avoid overriding the pinned version.
pyproject = load_toml(MANIFEST_FILENAME)
dev_deps = nox.project.dependency_groups(pyproject, "dev")
filtered_dev_deps = [d for d in dev_deps if not d.startswith("fastapi")]
if filtered_dev_deps:
session.install(*filtered_dev_deps)
# Pin FastAPI (and extras) to the target minor's highest patch before running tests.
session.install(f"fastapi[standard]=={fastapi_version}")
with alter_session(session, dependency_group=None) as session:
session.install(f".")
session.run(
*(
"python",
"-c",
f'from fastapi import __version__; assert __version__ == "{fastapi_version}", __version__',
)
)

# Run pytest using the Nox-managed virtualenv (avoid external interpreter).
session.run("pytest")


@contextlib.contextmanager
def alter_session(
session: AlteredSession,
Expand Down Expand Up @@ -606,6 +765,42 @@ def ci(session: Session):
test(session)


@session(reuse_venv=False)
def install_latest_tarball(session: Session):
import glob
import re

from packaging import version

# Get all tarball files
tarball_files = glob.glob(f"{DIST_DIR}/{PROJECT_NAME_NORMALIZED}-*.tar.gz")

if not tarball_files:
session.error("No tarball files found in dist/ directory")

# Extract version numbers using regex
version_pattern = re.compile(
rf"{PROJECT_NAME_NORMALIZED}-([0-9]+\.[0-9]+\.[0-9]+(?:\.[0-9]+)?(?:(?:a|b|rc)[0-9]+)?(?:\.post[0-9]+)?(?:\.dev[0-9]+)?).tar.gz"
)

# Create a list of (file_path, version) tuples
versioned_files = []
for file_path in tarball_files:
match = version_pattern.search(file_path)
if match:
ver_str = match.group(1)
versioned_files.append((file_path, version.parse(ver_str)))

if not versioned_files:
session.error("Could not extract version information from tarball files")

# Sort by version (highest first) and get the path
latest_tarball = sorted(versioned_files, key=lambda x: x[1], reverse=True)[0][0]
session.log(f"Installing latest version: {latest_tarball}")
session.run("uv", "run", "pip", "uninstall", f"{PROJECT_NAME}", "-y")
session.install(latest_tarball)


@session(reuse_venv=False)
def test_client_install_run(session: Session):
with alter_session(session, dependency_group="dev"):
Expand Down
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ classifiers = [
]
requires-python = ">=3.9"
dependencies = [
"fastapi>=0.100.1",
"fastapi>=0.115.2",
"typing-extensions>=4.0.0; python_version<'3.10'",
]
dynamic = []
Expand Down Expand Up @@ -116,7 +116,8 @@ reportMissingImports = "none"
[dependency-groups]
dev = [
"bcrypt==4.3.0",
"fastapi[standard]>=0.100.1",
"email-validator>=2.3.0",
"fastapi>=0.115.2",
"httpx>=0.24.0",
"isort>=6.0.1",
"mypy>=1.18.2",
Expand Down
2 changes: 1 addition & 1 deletion src/fastapi_shield/shield.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,13 @@

from fastapi import HTTPException, Request, Response, status
from fastapi._compat import _normalize_errors
from fastapi.dependencies.utils import is_coroutine_callable
from fastapi.exceptions import RequestValidationError
from fastapi.params import Security
from typing_extensions import Doc

# Import directly to make patching work correctly in tests
import fastapi_shield.utils
from fastapi_shield.utils import is_coroutine_callable
from fastapi_shield.consts import (
IS_SHIELDED_ENDPOINT_KEY,
SHIELDED_ENDPOINT_PATH_FORMAT_KEY,
Expand Down
37 changes: 36 additions & 1 deletion src/fastapi_shield/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,58 @@
from collections.abc import Iterator
from contextlib import AsyncExitStack
from inspect import Parameter, signature
import inspect
from typing import Any, Callable, Optional, List, Union

from fastapi import HTTPException, Request, params
from fastapi._compat import ModelField, Undefined
from fastapi.dependencies.models import Dependant
from fastapi.dependencies.utils import (
_should_embed_body_fields,
get_body_field,
get_dependant,
get_flat_dependant,
solve_dependencies,
)
from pydantic import BaseModel
from pydantic._internal._utils import lenient_issubclass
from fastapi.exceptions import RequestValidationError

from starlette.routing import get_name


# copied from `fastapi.dependencies.utils`
def is_coroutine_callable(call: Callable[..., Any]) -> bool:
if inspect.isroutine(call):
return inspect.iscoroutinefunction(call)
if inspect.isclass(call):
return False
dunder_call = getattr(call, "__call__", None) # noqa: B004
return inspect.iscoroutinefunction(dunder_call)


# copied from `fastapi.dependencies.utils`
def _should_embed_body_fields(fields: List["ModelField"]) -> bool:
if not fields:
return False
# More than one dependency could have the same field, it would show up as multiple
# fields but it's the same one, so count them by name
body_param_names_set = {field.name for field in fields}
# A top level field has to be a single field, not multiple
if len(body_param_names_set) > 1:
return True
first_field = fields[0]
# If it explicitly specifies it is embedded, it has to be embedded
if getattr(first_field.field_info, "embed", None):
return True
# If it's a Form (or File) field, it has to be a BaseModel to be top level
# otherwise it has to be embedded, so that the key value pair can be extracted
if isinstance(first_field.field_info, params.Form) and not lenient_issubclass(
first_field.type_, BaseModel
):
return True
return False


def generate_unique_id_for_fastapi_shield(dependant: Dependant, path_format: str):
"""Generate a unique identifier for FastAPI Shield dependants.

Expand Down
12 changes: 2 additions & 10 deletions tests/test_basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,16 +218,8 @@ def test_unprotected_endpoint():
client = TestClient(app)
response = client.get("/unprotected")
assert response.status_code == 200
assert response.json() == {
"message": "This is an unprotected endpoint",
"user": {
"dependency": {},
"use_cache": True,
"scopes": [],
"shielded_dependency": {},
"unblocked": False,
},
}, response.json()
result_json = response.json()
assert result_json["message"] == "This is an unprotected endpoint", response.json()


def test_protected_endpoint_without_token():
Expand Down
12 changes: 2 additions & 10 deletions tests/test_basics_three.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,16 +300,8 @@ def test_unprotected_endpoint():
client = TestClient(app)
response = client.get("/unprotected")
assert response.status_code == 200
assert response.json() == {
"message": "This is an unprotected endpoint",
"user": {
"dependency": {},
"use_cache": True,
"scopes": [],
"shielded_dependency": {},
"unblocked": False,
},
}, response.json()
result_json = response.json()
assert result_json["message"] == "This is an unprotected endpoint", response.json()


def test_protected_endpoint_without_token():
Expand Down
12 changes: 2 additions & 10 deletions tests/test_basics_two.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,16 +300,8 @@ def test_unprotected_endpoint():
client = TestClient(app)
response = client.get("/unprotected")
assert response.status_code == 200
assert response.json() == {
"message": "This is an unprotected endpoint",
"user": {
"dependency": {},
"use_cache": True,
"scopes": [],
"shielded_dependency": {},
"unblocked": False,
},
}, response.json()
result_json = response.json()
assert result_json["message"] == "This is an unprotected endpoint", response.json()


def test_protected_endpoint_without_token():
Expand Down
Loading