Skip to content

Commit f3e8f85

Browse files
committed
Attach authn schema in tigrbl_auth tests
1 parent 438fb8c commit f3e8f85

7 files changed

Lines changed: 442 additions & 13 deletions

File tree

pkgs/standards/tigrbl/tigrbl/runtime/errors/converters.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,13 @@
1010
IntegrityError,
1111
DBAPIError,
1212
OperationalError,
13+
StatementError,
1314
NoResultFound,
1415
_is_asyncpg_constraint_error,
1516
_stringify_exc,
1617
_format_validation,
18+
_format_sqlalchemy_error_data,
19+
_looks_like_validation_error,
1720
)
1821
from .exceptions import TigrblError
1922
from .mappings import (
@@ -122,6 +125,16 @@ def _classify_exception(
122125
if _is_asyncpg_constraint_error(exc):
123126
return status.HTTP_409_CONFLICT, _stringify_exc(exc), None
124127

128+
if (StatementError is not None) and isinstance(exc, StatementError):
129+
msg = _stringify_exc(exc)
130+
if _looks_like_validation_error(msg):
131+
return status.HTTP_422_UNPROCESSABLE_ENTITY, msg, None
132+
return (
133+
status.HTTP_500_INTERNAL_SERVER_ERROR,
134+
msg,
135+
_format_sqlalchemy_error_data(exc),
136+
)
137+
125138
if (IntegrityError is not None) and isinstance(exc, IntegrityError):
126139
msg = _stringify_exc(exc)
127140
lower_msg = msg.lower()
@@ -130,10 +143,21 @@ def _classify_exception(
130143
return status.HTTP_409_CONFLICT, msg, None
131144

132145
if (OperationalError is not None) and isinstance(exc, OperationalError):
133-
return status.HTTP_503_SERVICE_UNAVAILABLE, _stringify_exc(exc), None
146+
return (
147+
status.HTTP_503_SERVICE_UNAVAILABLE,
148+
_stringify_exc(exc),
149+
_format_sqlalchemy_error_data(exc),
150+
)
134151

135152
if (DBAPIError is not None) and isinstance(exc, DBAPIError):
136-
return status.HTTP_500_INTERNAL_SERVER_ERROR, _stringify_exc(exc), None
153+
msg = _stringify_exc(exc)
154+
if _looks_like_validation_error(msg):
155+
return status.HTTP_422_UNPROCESSABLE_ENTITY, msg, None
156+
return (
157+
status.HTTP_500_INTERNAL_SERVER_ERROR,
158+
msg,
159+
_format_sqlalchemy_error_data(exc),
160+
)
137161

138162
# 5) Fallback
139163
return status.HTTP_500_INTERNAL_SERVER_ERROR, _stringify_exc(exc), None

pkgs/standards/tigrbl/tigrbl/runtime/errors/utils.py

Lines changed: 55 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,16 @@ class _Status:
5555

5656
try:
5757
# SQLAlchemy v1/v2 exception sets
58-
from sqlalchemy.exc import IntegrityError, DBAPIError, OperationalError
58+
from sqlalchemy.exc import (
59+
IntegrityError,
60+
DBAPIError,
61+
OperationalError,
62+
StatementError,
63+
)
5964
from sqlalchemy.orm.exc import NoResultFound # type: ignore
6065
except Exception: # pragma: no cover
61-
IntegrityError = DBAPIError = OperationalError = NoResultFound = None # type: ignore
66+
IntegrityError = DBAPIError = OperationalError = StatementError = None # type: ignore
67+
NoResultFound = None # type: ignore
6268

6369

6470
# Detect asyncpg constraint errors without importing asyncpg (optional dep).
@@ -99,6 +105,50 @@ def _format_validation(err: Any) -> Any:
99105
return _limit(str(err))
100106

101107

108+
def _format_sqlalchemy_error_data(exc: BaseException) -> Optional[Dict[str, Any]]:
109+
data: Dict[str, Any] = {}
110+
hide_parameters = bool(getattr(exc, "hide_parameters", False))
111+
statement = getattr(exc, "statement", None)
112+
if statement and not hide_parameters:
113+
data["statement"] = _limit(str(statement))
114+
if hasattr(exc, "params"):
115+
params = getattr(exc, "params")
116+
if hide_parameters:
117+
data.update(_safe_params_metadata(params))
118+
else:
119+
try:
120+
params_repr = repr(params)
121+
except Exception: # pragma: no cover
122+
params_repr = "<unrepresentable params>"
123+
data["params"] = _limit(params_repr)
124+
orig = getattr(exc, "orig", None)
125+
if orig:
126+
data["orig"] = _limit(f"{type(orig).__name__}: {orig}")
127+
return data or None
128+
129+
130+
def _safe_params_metadata(params: Any) -> Dict[str, Any]:
131+
metadata: Dict[str, Any] = {"params_redacted": True}
132+
if isinstance(params, Mapping):
133+
metadata["param_keys"] = list(params.keys())
134+
metadata["param_types"] = {
135+
key: type(value).__name__ for key, value in params.items()
136+
}
137+
return metadata
138+
if isinstance(params, (list, tuple)):
139+
metadata["param_keys"] = list(range(len(params)))
140+
metadata["param_types"] = [type(value).__name__ for value in params]
141+
return metadata
142+
metadata["param_keys"] = [0]
143+
metadata["param_types"] = [type(params).__name__]
144+
return metadata
145+
146+
147+
def _looks_like_validation_error(message: str) -> bool:
148+
lowered = message.lower()
149+
return "not null constraint" in lowered or "check constraint" in lowered
150+
151+
102152
def _get_temp(ctx: Any) -> Mapping[str, Any]:
103153
tmp = getattr(ctx, "temp", None)
104154
return tmp if isinstance(tmp, Mapping) else {}
@@ -139,11 +189,14 @@ def _read_in_errors(ctx: Any) -> List[Dict[str, Any]]:
139189
"IntegrityError",
140190
"DBAPIError",
141191
"OperationalError",
192+
"StatementError",
142193
"NoResultFound",
143194
"_is_asyncpg_constraint_error",
144195
"_limit",
145196
"_stringify_exc",
146197
"_format_validation",
198+
"_format_sqlalchemy_error_data",
199+
"_looks_like_validation_error",
147200
"_get_temp",
148201
"_has_in_errors",
149202
"_read_in_errors",

pkgs/standards/tigrbl_auth/tests/conftest.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,9 @@ async def test_db_engine() -> AsyncGenerator[Engine, None]:
6262
engine_resolver.register_api(surface_api, provider)
6363
engine_resolver.register_api(app, provider)
6464
setattr(surface_api, "_ddl_executed", False)
65-
await surface_api.initialize()
65+
temp_dir = Path(tempfile.mkdtemp())
66+
authn_db = temp_dir / "authn.db"
67+
await surface_api.initialize(sqlite_attachments={"authn": str(authn_db)})
6668
try:
6769
yield engine
6870
finally:
@@ -71,6 +73,7 @@ async def test_db_engine() -> AsyncGenerator[Engine, None]:
7173
engine_resolver.register_api(surface_api, original_surface)
7274
engine_resolver.register_api(app, original_app)
7375
setattr(surface_api, "_ddl_executed", False)
76+
shutil.rmtree(temp_dir, ignore_errors=True)
7477

7578

7679
@pytest_asyncio.fixture

pkgs/standards/tigrbl_auth/tests/unit/test_rfc8523_jwt_client_auth.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import pytest
1010

1111
from tigrbl_auth.errors import InvalidTokenError
12+
from tigrbl_auth.rfc import rfc8523
1213
from tigrbl_auth.rfc.rfc8523 import (
1314
RFC8523_SPEC_URL,
1415
validate_enhanced_jwt_bearer,
@@ -189,10 +190,31 @@ def test_validate_enhanced_jwt_bearer_disabled():
189190

190191
@pytest.mark.unit
191192
def test_is_jwt_replay():
192-
"""RFC 8523: JWT replay detection placeholder."""
193-
# This is a placeholder test for the replay detection function
194-
result = is_jwt_replay("test-jti", int(time.time()), 300)
195-
assert result is False # Currently always returns False
193+
"""RFC 8523: JWT replay detection with in-memory cache."""
194+
rfc8523._JTI_CACHE.clear()
195+
iat = int(time.time())
196+
assert is_jwt_replay("test-jti", iat, 300) is False
197+
assert is_jwt_replay("test-jti", iat, 300) is True
198+
199+
200+
@pytest.mark.unit
201+
def test_validate_enhanced_jwt_bearer_replay_detected():
202+
"""RFC 8523: Replay protection rejects reused JTIs."""
203+
rfc8523._JTI_CACHE.clear()
204+
with patch.object(settings, "enable_rfc8523", True):
205+
with patch.object(settings, "enable_rfc7523", True):
206+
token = encode_jwt(
207+
iss="client",
208+
sub="client",
209+
aud="token-endpoint",
210+
exp=int(time.time()) + 300,
211+
iat=int(time.time()),
212+
jti="unique-jwt-id-456",
213+
tid="tenant-1",
214+
)
215+
validate_enhanced_jwt_bearer(token, audience="token-endpoint")
216+
with pytest.raises(InvalidTokenError, match="JWT replay detected"):
217+
validate_enhanced_jwt_bearer(token, audience="token-endpoint")
196218

197219

198220
@pytest.mark.unit

pkgs/standards/tigrbl_auth/tigrbl_auth/rfc/rfc8523.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from __future__ import annotations
1111

12+
import threading
1213
import time
1314
import warnings
1415
from typing import Any, Dict, Iterable, Optional, Set, Union
@@ -19,6 +20,18 @@
1920

2021
RFC8523_SPEC_URL = "https://www.rfc-editor.org/rfc/rfc8523"
2122
REQUIRED_CLAIMS: Set[str] = {"iss", "sub", "aud", "exp", "iat", "jti"}
23+
_JTI_CACHE: Dict[str, int] = {}
24+
_JTI_LOCK = threading.Lock()
25+
26+
27+
def _purge_expired_jtis(current_time: int, max_age_seconds: int) -> None:
28+
expired = [
29+
jti
30+
for jti, seen_at in _JTI_CACHE.items()
31+
if current_time - seen_at > max_age_seconds
32+
]
33+
for jti in expired:
34+
_JTI_CACHE.pop(jti, None)
2235

2336

2437
def validate_enhanced_jwt_bearer(
@@ -82,6 +95,8 @@ def validate_enhanced_jwt_bearer(
8295
jti = claims.get("jti")
8396
if not isinstance(jti, str) or not jti.strip():
8497
raise ValueError("'jti' claim must be a non-empty string")
98+
if is_jwt_replay(jti, iat, max_age_seconds=max_age_seconds):
99+
raise InvalidTokenError("JWT replay detected")
85100

86101
return claims
87102

@@ -152,8 +167,8 @@ def create_client_assertion_jwt(
152167
def is_jwt_replay(jti: str, iat: int, max_age_seconds: int = 300) -> bool:
153168
"""Check if a JWT ID indicates a replay attack.
154169
155-
This is a placeholder implementation. In production, this should
156-
check against a cache/database of recently used JTIs.
170+
This uses an in-memory cache of recently seen JTIs. In production, this should
171+
be backed by a shared cache or database to enforce replay protection.
157172
158173
Args:
159174
jti: JWT ID claim value
@@ -163,8 +178,13 @@ def is_jwt_replay(jti: str, iat: int, max_age_seconds: int = 300) -> bool:
163178
Returns:
164179
True if the JWT appears to be a replay, False otherwise
165180
"""
166-
# TODO: Implement proper JTI tracking with cache/database
167-
# For now, always return False (no replay detection)
181+
current_time = int(time.time())
182+
with _JTI_LOCK:
183+
_purge_expired_jtis(current_time, max_age_seconds)
184+
if jti in _JTI_CACHE:
185+
return True
186+
if current_time - iat <= max_age_seconds:
187+
_JTI_CACHE[jti] = iat
168188
return False
169189

170190

0 commit comments

Comments
 (0)