Skip to content

Commit 05c2e6c

Browse files
committed
Merge branch 'feat_ta_support' of https://github.com/GT-BAITA/iam-proxy-italia into GT-BAITA-feat_ta_support
2 parents 86a13a3 + f6b4505 commit 05c2e6c

1 file changed

Lines changed: 223 additions & 26 deletions

File tree

  • iam-proxy-italia-project/backends/cieoidc

iam-proxy-italia-project/backends/cieoidc/cieoidc.py

Lines changed: 223 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,82 @@
11
import logging
22
import inspect
3-
3+
import time
4+
from datetime import datetime, timezone
5+
from types import SimpleNamespace
6+
from typing import Optional, List
47

58
from satosa.backends.base import BackendModule
69
from satosa.backends.oauth import get_metadata_desc_for_oauth_backend
710

811
from .utils.endpoints_loader import EndpointsLoader
12+
from .storage.db_engine import OidcDbEngine
13+
from .models.trust_chain_cache import TrustChainCache
914

1015
from pyeudiw.federation.trust_chain_builder import TrustChainBuilder
1116
from pyeudiw.federation.statements import EntityStatement, get_entity_configurations
1217

18+
from .utils.exceptions import TrustChainNotFoundError
19+
1320

1421
logger = logging.getLogger(__name__)
1522

1623

24+
def _trust_chain_from_cache(cached: TrustChainCache):
25+
"""
26+
Build a minimal trust-chain-like object from TrustChainCache.
27+
Has .subject and .subject_configuration.payload as required by authorization endpoint.
28+
"""
29+
wrapper = SimpleNamespace()
30+
wrapper.subject = cached.provider_url
31+
wrapper.subject_configuration = SimpleNamespace(payload=cached.payload)
32+
return wrapper
33+
34+
35+
def _is_cache_expired(cached: TrustChainCache, now=None) -> bool:
36+
"""Return True if the cached payload is expired (exp in the past)."""
37+
exp = cached.exp or cached.payload.get("exp")
38+
if exp is None:
39+
return False
40+
t = now if now is not None else time.time()
41+
return t >= exp
42+
43+
44+
class TrustChainResolver:
45+
"""
46+
Resolves trust chains from cache or builds them on-demand via discovery.
47+
When a provider is requested but not in the cache (e.g. startup failed),
48+
discovery is performed and the resulting trust chain is stored for reuse.
49+
"""
50+
51+
def __init__(self, trust_chains: dict, build_callback):
52+
"""
53+
:param trust_chains: Dict of provider_url -> TrustChainBuilder (mutated when new chains are built)
54+
:param build_callback: Callable(provider_url) -> TrustChainBuilder; raises TrustChainNotFoundError on failure
55+
"""
56+
self._chains = trust_chains
57+
self._build = build_callback
58+
59+
def __contains__(self, key):
60+
return key in self._chains
61+
62+
def __getitem__(self, key):
63+
return self._chains[key]
64+
65+
def keys(self):
66+
return self._chains.keys()
67+
68+
def get_or_build(self, provider_url: str) -> TrustChainBuilder:
69+
"""Get trust chain from cache, or discover and store it on-demand."""
70+
for key in (
71+
provider_url,
72+
provider_url.rstrip("/"),
73+
provider_url + "/" if not provider_url.endswith("/") else None,
74+
):
75+
if key and key in self._chains:
76+
return self._chains[key]
77+
return self._build(provider_url)
78+
79+
1780
class CieOidcBackend(BackendModule):
1881

1982
def __init__(self, callback, internal_attributes, module_config, base_url, name):
@@ -23,7 +86,12 @@ def __init__(self, callback, internal_attributes, module_config, base_url, name)
2386
super().__init__(callback, internal_attributes, base_url, name)
2487
self.config = module_config
2588
self.endpoints = {}
89+
self._validated_trust_anchors: List[EntityStatement] = []
2690
self.trust_chain = self._generate_trust_chains()
91+
self._trust_chain_resolver = TrustChainResolver(
92+
self.trust_chain,
93+
self.get_or_build_trust_chain,
94+
)
2795
metadata = self.config.get("metadata", {}).get("openid_relying_party", {})
2896
self._client_id = metadata.get("client_id") or f"{base_url}/{name}"
2997

@@ -58,7 +126,7 @@ def register_endpoints(self):
58126
self.name,
59127
self.auth_callback_func,
60128
self.converter,
61-
self.trust_chain,
129+
self._trust_chain_resolver,
62130
)
63131

64132
url_map = []
@@ -81,42 +149,116 @@ def get_metadata_desc(self):
81149
meta = get_metadata_desc_for_oauth_backend(self._client_id, self.config)
82150
return meta
83151

84-
def _generate_trust_chains(self) -> dict:
85-
'''
86-
private method _generate_trust_chains:
87-
This method generate a list of trust-chain. After create a entity statement
88-
for Trust Anchor, validate itself, and call the generate_trust_chain
89-
for all providers into configuration.
90-
Add all providers into dictionary.
91-
'''
92-
logger.debug(
93-
f"Entering method: {inspect.getframeinfo(inspect.currentframe()).function}. "
94-
)
152+
def _get_storage(self) -> Optional[OidcDbEngine]:
153+
"""Create and return storage engine; connect if needed. Returns None if no storage configured."""
154+
if getattr(self, "_storage_engine", None) is not None:
155+
return self._storage_engine
156+
storage_config = self.config.get("storage") or {}
157+
if not storage_config:
158+
return None
159+
try:
160+
engine = OidcDbEngine(storage_config)
161+
engine.connect()
162+
self._storage_engine = engine
163+
return engine
164+
except Exception as e:
165+
logger.warning("Could not initialize storage for trust chain persistence: %s", e)
166+
return None
167+
168+
def _store_trust_chain(self, chain, provider_url: str) -> None:
169+
"""Persist trust chain to database if storage is available."""
170+
engine = self._get_storage()
171+
if engine is None:
172+
return
173+
try:
174+
payload = chain.subject_configuration.payload
175+
exp = payload.get("exp")
176+
variants = {
177+
provider_url.rstrip("/"),
178+
provider_url.rstrip("/") + "/"
179+
}
180+
181+
for url in variants:
182+
cached = TrustChainCache(
183+
provider_url=url,
184+
payload=payload,
185+
exp=exp,
186+
created=datetime.now(timezone.utc),
187+
)
188+
engine.add_or_update_trust_chain(cached)
189+
except Exception as e:
190+
logger.warning("Could not persist trust chain for %s: %s", provider_url, e)
95191

192+
def _generate_trust_chains(self) -> dict:
193+
"""try load from DB, or can try discovery with TA's list."""
96194
httpc_params = self.config["trust_chain"]["config"]["httpc_params"]
97-
98-
ta_url = self.config["trust_chain"]["config"]["trust_anchor"][0]
99-
jwt = get_entity_configurations(ta_url, httpc_params=httpc_params)[0]
100-
101-
trust_anchor_ec = EntityStatement(jwt, httpc_params=httpc_params)
102-
103-
trust_anchor_ec.validate_by_itself()
104-
105195
providers = self.config["providers"]
106-
107196
trust_chains = dict()
108197

109198
for provider_url in providers:
199+
# try load from DB
200+
engine = self._get_storage()
201+
if engine:
202+
cached = engine.get_trust_chain_by_provider(provider_url)
203+
if cached and not _is_cache_expired(cached):
204+
chain = _trust_chain_from_cache(cached)
205+
self._add_to_dict(trust_chains, provider_url, chain)
206+
continue
207+
208+
# Build via discovery, tryng each TA
110209
try:
111-
trust_chains[provider_url] = CieOidcBackend.generate_trust_chain(
112-
trust_anchor_ec, provider_url, httpc_params)
113-
except Exception as exception:
210+
tas = self._ensure_trust_anchors()
211+
chain_built = False
212+
for ta_ec in tas:
213+
try:
214+
chain = self.generate_trust_chain(
215+
ta_ec, provider_url, httpc_params
216+
)
217+
self._add_to_dict(trust_chains, provider_url, chain)
218+
self._store_trust_chain(chain, provider_url)
219+
logger.info(
220+
"Provider %s linked to TA %s", provider_url, ta_ec.sub
221+
)
222+
chain_built = True
223+
break
224+
except Exception as e:
225+
logger.warning(
226+
"Failed to build trust chain for provider %s with TA %s: %s",
227+
provider_url,
228+
getattr(ta_ec, "sub", "<unknown>"),
229+
e,
230+
)
231+
if not chain_built:
232+
logger.error(
233+
"Could not build trust chain for provider %s with any configured trust anchor",
234+
provider_url,
235+
)
236+
except Exception as e:
114237
logger.error(
115-
f"Exception {exception} generated from this provider {provider_url}"
238+
"Could not resolve trust chain for %s: %s", provider_url, e
116239
)
117240

118241
return trust_chains
119242

243+
def _add_to_dict(self, d, url, chain):
244+
"""Helper to add a normalized URL in a dict."""
245+
# Always store the exact URL key.
246+
d[url] = chain
247+
# Also store the normalized variant (with/without trailing slash),
248+
# but avoid silently overwriting an existing normalized entry.
249+
norm = url.rstrip("/") if url.endswith("/") else url + "/"
250+
if norm != url:
251+
if norm in d:
252+
logger.warning(
253+
"Duplicate provider URL variants configured: %s and %s; "
254+
"keeping existing trust chain for %s",
255+
url,
256+
norm,
257+
norm,
258+
)
259+
else:
260+
d[norm] = chain
261+
120262
@staticmethod
121263
def generate_trust_chain(
122264
trust_anchor_ec: EntityStatement, provider_endpoint: str, httpc_params
@@ -141,3 +283,58 @@ def generate_trust_chain(
141283
trust_chain.start()
142284
trust_chain.apply_metadata_policy()
143285
return trust_chain
286+
287+
def _ensure_trust_anchors(self) -> List[EntityStatement]:
288+
"""Return a list of valid TAs."""
289+
if not self._validated_trust_anchors:
290+
httpc_params = self.config["trust_chain"]["config"]["httpc_params"]
291+
ta_urls = self.config["trust_chain"]["config"]["trust_anchor"]
292+
293+
for ta_url in ta_urls:
294+
try:
295+
jwt = get_entity_configurations(ta_url, httpc_params=httpc_params)[0]
296+
ta_ec = EntityStatement(jwt, httpc_params=httpc_params)
297+
ta_ec.validate_by_itself()
298+
self._validated_trust_anchors.append(ta_ec)
299+
except Exception as e:
300+
logger.error(f"Failed to validate TA {ta_url}: {e}")
301+
302+
if not self._validated_trust_anchors:
303+
raise ValueError("No valid Trust Anchors could be loaded.")
304+
305+
return self._validated_trust_anchors
306+
307+
def get_or_build_trust_chain(self, provider_url: str) -> TrustChainBuilder:
308+
"""
309+
Get trust chain from cache, or from DB, or discover and build it on-demand.
310+
Newly built chains are stored in memory and in the database.
311+
"""
312+
providers = self.config.get("providers", [])
313+
provider_variants = [provider_url, provider_url.rstrip("/")]
314+
if not provider_url.endswith("/"):
315+
provider_variants.append(provider_url + "/")
316+
if not any(p in providers for p in provider_variants if p):
317+
raise TrustChainNotFoundError(f"Provider {provider_url} not in allowed list.")
318+
319+
# Try load from DB (in-memory cache already checked by TrustChainResolver)
320+
engine = self._get_storage()
321+
if engine:
322+
cached = engine.get_trust_chain_by_provider(provider_url)
323+
if cached and not _is_cache_expired(cached):
324+
chain = _trust_chain_from_cache(cached)
325+
self._add_to_dict(self.trust_chain, provider_url, chain)
326+
return chain
327+
328+
httpc_params = self.config["trust_chain"]["config"]["httpc_params"]
329+
tas = self._ensure_trust_anchors()
330+
331+
for ta_ec in tas:
332+
try:
333+
chain = self.generate_trust_chain(ta_ec, provider_url, httpc_params)
334+
self._add_to_dict(self.trust_chain, provider_url, chain)
335+
self._store_trust_chain(chain, provider_url)
336+
return chain
337+
except Exception:
338+
continue
339+
340+
raise TrustChainNotFoundError(f"Failed to build trust chain for {provider_url} with any TA.")

0 commit comments

Comments
 (0)