Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 26 additions & 2 deletions pkgs/standards/tigrbl/tigrbl/runtime/errors/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,13 @@
IntegrityError,
DBAPIError,
OperationalError,
StatementError,
NoResultFound,
_is_asyncpg_constraint_error,
_stringify_exc,
_format_validation,
_format_sqlalchemy_error_data,
_looks_like_validation_error,
)
from .exceptions import TigrblError
from .mappings import (
Expand Down Expand Up @@ -122,6 +125,16 @@ def _classify_exception(
if _is_asyncpg_constraint_error(exc):
return status.HTTP_409_CONFLICT, _stringify_exc(exc), None

if (StatementError is not None) and isinstance(exc, StatementError):
msg = _stringify_exc(exc)
if _looks_like_validation_error(msg):
return status.HTTP_422_UNPROCESSABLE_ENTITY, msg, None
return (
status.HTTP_500_INTERNAL_SERVER_ERROR,
msg,
_format_sqlalchemy_error_data(exc),
)

if (IntegrityError is not None) and isinstance(exc, IntegrityError):
msg = _stringify_exc(exc)
lower_msg = msg.lower()
Expand All @@ -130,10 +143,21 @@ def _classify_exception(
return status.HTTP_409_CONFLICT, msg, None

if (OperationalError is not None) and isinstance(exc, OperationalError):
return status.HTTP_503_SERVICE_UNAVAILABLE, _stringify_exc(exc), None
return (
status.HTTP_503_SERVICE_UNAVAILABLE,
_stringify_exc(exc),
_format_sqlalchemy_error_data(exc),
)

if (DBAPIError is not None) and isinstance(exc, DBAPIError):
return status.HTTP_500_INTERNAL_SERVER_ERROR, _stringify_exc(exc), None
msg = _stringify_exc(exc)
if _looks_like_validation_error(msg):
return status.HTTP_422_UNPROCESSABLE_ENTITY, msg, None
return (
status.HTTP_500_INTERNAL_SERVER_ERROR,
msg,
_format_sqlalchemy_error_data(exc),
)

# 5) Fallback
return status.HTTP_500_INTERNAL_SERVER_ERROR, _stringify_exc(exc), None
Expand Down
57 changes: 55 additions & 2 deletions pkgs/standards/tigrbl/tigrbl/runtime/errors/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,16 @@ class _Status:

try:
# SQLAlchemy v1/v2 exception sets
from sqlalchemy.exc import IntegrityError, DBAPIError, OperationalError
from sqlalchemy.exc import (
IntegrityError,
DBAPIError,
OperationalError,
StatementError,
)
from sqlalchemy.orm.exc import NoResultFound # type: ignore
except Exception: # pragma: no cover
IntegrityError = DBAPIError = OperationalError = NoResultFound = None # type: ignore
IntegrityError = DBAPIError = OperationalError = StatementError = None # type: ignore
NoResultFound = None # type: ignore


# Detect asyncpg constraint errors without importing asyncpg (optional dep).
Expand Down Expand Up @@ -99,6 +105,50 @@ def _format_validation(err: Any) -> Any:
return _limit(str(err))


def _format_sqlalchemy_error_data(exc: BaseException) -> Optional[Dict[str, Any]]:
data: Dict[str, Any] = {}
hide_parameters = bool(getattr(exc, "hide_parameters", False))
statement = getattr(exc, "statement", None)
if statement and not hide_parameters:
data["statement"] = _limit(str(statement))
if hasattr(exc, "params"):
params = getattr(exc, "params")
if hide_parameters:
data.update(_safe_params_metadata(params))
else:
try:
params_repr = repr(params)
except Exception: # pragma: no cover
params_repr = "<unrepresentable params>"
data["params"] = _limit(params_repr)
orig = getattr(exc, "orig", None)
if orig:
data["orig"] = _limit(f"{type(orig).__name__}: {orig}")
return data or None


def _safe_params_metadata(params: Any) -> Dict[str, Any]:
metadata: Dict[str, Any] = {"params_redacted": True}
if isinstance(params, Mapping):
metadata["param_keys"] = list(params.keys())
metadata["param_types"] = {
key: type(value).__name__ for key, value in params.items()
}
return metadata
if isinstance(params, (list, tuple)):
metadata["param_keys"] = list(range(len(params)))
metadata["param_types"] = [type(value).__name__ for value in params]
return metadata
metadata["param_keys"] = [0]
metadata["param_types"] = [type(params).__name__]
return metadata


def _looks_like_validation_error(message: str) -> bool:
lowered = message.lower()
return "not null constraint" in lowered or "check constraint" in lowered


def _get_temp(ctx: Any) -> Mapping[str, Any]:
tmp = getattr(ctx, "temp", None)
return tmp if isinstance(tmp, Mapping) else {}
Expand Down Expand Up @@ -139,11 +189,14 @@ def _read_in_errors(ctx: Any) -> List[Dict[str, Any]]:
"IntegrityError",
"DBAPIError",
"OperationalError",
"StatementError",
"NoResultFound",
"_is_asyncpg_constraint_error",
"_limit",
"_stringify_exc",
"_format_validation",
"_format_sqlalchemy_error_data",
"_looks_like_validation_error",
"_get_temp",
"_has_in_errors",
"_read_in_errors",
Expand Down
5 changes: 4 additions & 1 deletion pkgs/standards/tigrbl_auth/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,9 @@ async def test_db_engine() -> AsyncGenerator[Engine, None]:
engine_resolver.register_api(surface_api, provider)
engine_resolver.register_api(app, provider)
setattr(surface_api, "_ddl_executed", False)
await surface_api.initialize()
temp_dir = Path(tempfile.mkdtemp())
authn_db = temp_dir / "authn.db"
await surface_api.initialize(sqlite_attachments={"authn": str(authn_db)})
try:
yield engine
finally:
Expand All @@ -71,6 +73,7 @@ async def test_db_engine() -> AsyncGenerator[Engine, None]:
engine_resolver.register_api(surface_api, original_surface)
engine_resolver.register_api(app, original_app)
setattr(surface_api, "_ddl_executed", False)
shutil.rmtree(temp_dir, ignore_errors=True)


@pytest_asyncio.fixture
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import pytest

from tigrbl_auth.errors import InvalidTokenError
from tigrbl_auth.rfc import rfc8523
from tigrbl_auth.rfc.rfc8523 import (
RFC8523_SPEC_URL,
validate_enhanced_jwt_bearer,
Expand Down Expand Up @@ -189,10 +190,31 @@ def test_validate_enhanced_jwt_bearer_disabled():

@pytest.mark.unit
def test_is_jwt_replay():
"""RFC 8523: JWT replay detection placeholder."""
# This is a placeholder test for the replay detection function
result = is_jwt_replay("test-jti", int(time.time()), 300)
assert result is False # Currently always returns False
"""RFC 8523: JWT replay detection with in-memory cache."""
rfc8523._JTI_CACHE.clear()
iat = int(time.time())
assert is_jwt_replay("test-jti", iat, 300) is False
assert is_jwt_replay("test-jti", iat, 300) is True


@pytest.mark.unit
def test_validate_enhanced_jwt_bearer_replay_detected():
"""RFC 8523: Replay protection rejects reused JTIs."""
rfc8523._JTI_CACHE.clear()
with patch.object(settings, "enable_rfc8523", True):
with patch.object(settings, "enable_rfc7523", True):
token = encode_jwt(
iss="client",
sub="client",
aud="token-endpoint",
exp=int(time.time()) + 300,
iat=int(time.time()),
jti="unique-jwt-id-456",
tid="tenant-1",
)
validate_enhanced_jwt_bearer(token, audience="token-endpoint")
with pytest.raises(InvalidTokenError, match="JWT replay detected"):
validate_enhanced_jwt_bearer(token, audience="token-endpoint")


@pytest.mark.unit
Expand Down
28 changes: 24 additions & 4 deletions pkgs/standards/tigrbl_auth/tigrbl_auth/rfc/rfc8523.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from __future__ import annotations

import threading
import time
import warnings
from typing import Any, Dict, Iterable, Optional, Set, Union
Expand All @@ -19,6 +20,18 @@

RFC8523_SPEC_URL = "https://www.rfc-editor.org/rfc/rfc8523"
REQUIRED_CLAIMS: Set[str] = {"iss", "sub", "aud", "exp", "iat", "jti"}
_JTI_CACHE: Dict[str, int] = {}
_JTI_LOCK = threading.Lock()


def _purge_expired_jtis(current_time: int, max_age_seconds: int) -> None:
expired = [
jti
for jti, seen_at in _JTI_CACHE.items()
if current_time - seen_at > max_age_seconds
]
for jti in expired:
_JTI_CACHE.pop(jti, None)


def validate_enhanced_jwt_bearer(
Expand Down Expand Up @@ -82,6 +95,8 @@ def validate_enhanced_jwt_bearer(
jti = claims.get("jti")
if not isinstance(jti, str) or not jti.strip():
raise ValueError("'jti' claim must be a non-empty string")
if is_jwt_replay(jti, iat, max_age_seconds=max_age_seconds):
raise InvalidTokenError("JWT replay detected")

return claims

Expand Down Expand Up @@ -152,8 +167,8 @@ def create_client_assertion_jwt(
def is_jwt_replay(jti: str, iat: int, max_age_seconds: int = 300) -> bool:
"""Check if a JWT ID indicates a replay attack.

This is a placeholder implementation. In production, this should
check against a cache/database of recently used JTIs.
This uses an in-memory cache of recently seen JTIs. In production, this should
be backed by a shared cache or database to enforce replay protection.

Args:
jti: JWT ID claim value
Expand All @@ -163,8 +178,13 @@ def is_jwt_replay(jti: str, iat: int, max_age_seconds: int = 300) -> bool:
Returns:
True if the JWT appears to be a replay, False otherwise
"""
# TODO: Implement proper JTI tracking with cache/database
# For now, always return False (no replay detection)
current_time = int(time.time())
with _JTI_LOCK:
_purge_expired_jtis(current_time, max_age_seconds)
if jti in _JTI_CACHE:
return True
if current_time - iat <= max_age_seconds:
_JTI_CACHE[jti] = iat
Comment on lines +183 to +187

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Cache JTIs for skew-allowed tokens

Because validate_enhanced_jwt_bearer allows tokens up to max_age_seconds + clock_skew_seconds old, a token with age just over max_age_seconds still passes validation, but is_jwt_replay won’t cache it due to the stricter current_time - iat <= max_age_seconds check. That means reusing the same token within the skew window won’t be detected as a replay. Consider caching JTIs for the full validation window (e.g., include skew or pass the effective max age) so second use of an otherwise valid token is rejected.

Useful? React with 👍 / 👎.

return False


Expand Down
Loading
Loading