11import logging
22import inspect
3-
3+ import time
4+ from datetime import datetime , timezone
5+ from types import SimpleNamespace
6+ from typing import Optional , List
47
58from satosa .backends .base import BackendModule
69from satosa .backends .oauth import get_metadata_desc_for_oauth_backend
710
811from .utils .endpoints_loader import EndpointsLoader
12+ from .storage .db_engine import OidcDbEngine
13+ from .models .trust_chain_cache import TrustChainCache
914
1015from pyeudiw .federation .trust_chain_builder import TrustChainBuilder
1116from pyeudiw .federation .statements import EntityStatement , get_entity_configurations
1217
18+ from .utils .exceptions import TrustChainNotFoundError
19+
1320
1421logger = logging .getLogger (__name__ )
1522
1623
24+ def _trust_chain_from_cache (cached : TrustChainCache ):
25+ """
26+ Build a minimal trust-chain-like object from TrustChainCache.
27+ Has .subject and .subject_configuration.payload as required by authorization endpoint.
28+ """
29+ wrapper = SimpleNamespace ()
30+ wrapper .subject = cached .provider_url
31+ wrapper .subject_configuration = SimpleNamespace (payload = cached .payload )
32+ return wrapper
33+
34+
35+ def _is_cache_expired (cached : TrustChainCache , now = None ) -> bool :
36+ """Return True if the cached payload is expired (exp in the past)."""
37+ exp = cached .exp or cached .payload .get ("exp" )
38+ if exp is None :
39+ return False
40+ t = now if now is not None else time .time ()
41+ return t >= exp
42+
43+
44+ class TrustChainResolver :
45+ """
46+ Resolves trust chains from cache or builds them on-demand via discovery.
47+ When a provider is requested but not in the cache (e.g. startup failed),
48+ discovery is performed and the resulting trust chain is stored for reuse.
49+ """
50+
51+ def __init__ (self , trust_chains : dict , build_callback ):
52+ """
53+ :param trust_chains: Dict of provider_url -> TrustChainBuilder (mutated when new chains are built)
54+ :param build_callback: Callable(provider_url) -> TrustChainBuilder; raises TrustChainNotFoundError on failure
55+ """
56+ self ._chains = trust_chains
57+ self ._build = build_callback
58+
59+ def __contains__ (self , key ):
60+ return key in self ._chains
61+
62+ def __getitem__ (self , key ):
63+ return self ._chains [key ]
64+
65+ def keys (self ):
66+ return self ._chains .keys ()
67+
68+ def get_or_build (self , provider_url : str ) -> TrustChainBuilder :
69+ """Get trust chain from cache, or discover and store it on-demand."""
70+ for key in (
71+ provider_url ,
72+ provider_url .rstrip ("/" ),
73+ provider_url + "/" if not provider_url .endswith ("/" ) else None ,
74+ ):
75+ if key and key in self ._chains :
76+ return self ._chains [key ]
77+ return self ._build (provider_url )
78+
79+
1780class CieOidcBackend (BackendModule ):
1881
1982 def __init__ (self , callback , internal_attributes , module_config , base_url , name ):
@@ -23,7 +86,12 @@ def __init__(self, callback, internal_attributes, module_config, base_url, name)
2386 super ().__init__ (callback , internal_attributes , base_url , name )
2487 self .config = module_config
2588 self .endpoints = {}
89+ self ._validated_trust_anchors : List [EntityStatement ] = []
2690 self .trust_chain = self ._generate_trust_chains ()
91+ self ._trust_chain_resolver = TrustChainResolver (
92+ self .trust_chain ,
93+ self .get_or_build_trust_chain ,
94+ )
2795 metadata = self .config .get ("metadata" , {}).get ("openid_relying_party" , {})
2896 self ._client_id = metadata .get ("client_id" ) or f"{ base_url } /{ name } "
2997
@@ -58,7 +126,7 @@ def register_endpoints(self):
58126 self .name ,
59127 self .auth_callback_func ,
60128 self .converter ,
61- self .trust_chain ,
129+ self ._trust_chain_resolver ,
62130 )
63131
64132 url_map = []
@@ -81,42 +149,116 @@ def get_metadata_desc(self):
81149 meta = get_metadata_desc_for_oauth_backend (self ._client_id , self .config )
82150 return meta
83151
84- def _generate_trust_chains (self ) -> dict :
85- '''
86- private method _generate_trust_chains:
87- This method generate a list of trust-chain. After create a entity statement
88- for Trust Anchor, validate itself, and call the generate_trust_chain
89- for all providers into configuration.
90- Add all providers into dictionary.
91- '''
92- logger .debug (
93- f"Entering method: { inspect .getframeinfo (inspect .currentframe ()).function } . "
94- )
152+ def _get_storage (self ) -> Optional [OidcDbEngine ]:
153+ """Create and return storage engine; connect if needed. Returns None if no storage configured."""
154+ if getattr (self , "_storage_engine" , None ) is not None :
155+ return self ._storage_engine
156+ storage_config = self .config .get ("storage" ) or {}
157+ if not storage_config :
158+ return None
159+ try :
160+ engine = OidcDbEngine (storage_config )
161+ engine .connect ()
162+ self ._storage_engine = engine
163+ return engine
164+ except Exception as e :
165+ logger .warning ("Could not initialize storage for trust chain persistence: %s" , e )
166+ return None
167+
168+ def _store_trust_chain (self , chain , provider_url : str ) -> None :
169+ """Persist trust chain to database if storage is available."""
170+ engine = self ._get_storage ()
171+ if engine is None :
172+ return
173+ try :
174+ payload = chain .subject_configuration .payload
175+ exp = payload .get ("exp" )
176+ variants = {
177+ provider_url .rstrip ("/" ),
178+ provider_url .rstrip ("/" ) + "/"
179+ }
180+
181+ for url in variants :
182+ cached = TrustChainCache (
183+ provider_url = url ,
184+ payload = payload ,
185+ exp = exp ,
186+ created = datetime .now (timezone .utc ),
187+ )
188+ engine .add_or_update_trust_chain (cached )
189+ except Exception as e :
190+ logger .warning ("Could not persist trust chain for %s: %s" , provider_url , e )
95191
192+ def _generate_trust_chains (self ) -> dict :
193+ """try load from DB, or can try discovery with TA's list."""
96194 httpc_params = self .config ["trust_chain" ]["config" ]["httpc_params" ]
97-
98- ta_url = self .config ["trust_chain" ]["config" ]["trust_anchor" ][0 ]
99- jwt = get_entity_configurations (ta_url , httpc_params = httpc_params )[0 ]
100-
101- trust_anchor_ec = EntityStatement (jwt , httpc_params = httpc_params )
102-
103- trust_anchor_ec .validate_by_itself ()
104-
105195 providers = self .config ["providers" ]
106-
107196 trust_chains = dict ()
108197
109198 for provider_url in providers :
199+ # try load from DB
200+ engine = self ._get_storage ()
201+ if engine :
202+ cached = engine .get_trust_chain_by_provider (provider_url )
203+ if cached and not _is_cache_expired (cached ):
204+ chain = _trust_chain_from_cache (cached )
205+ self ._add_to_dict (trust_chains , provider_url , chain )
206+ continue
207+
208+ # Build via discovery, tryng each TA
110209 try :
111- trust_chains [provider_url ] = CieOidcBackend .generate_trust_chain (
112- trust_anchor_ec , provider_url , httpc_params )
113- except Exception as exception :
210+ tas = self ._ensure_trust_anchors ()
211+ chain_built = False
212+ for ta_ec in tas :
213+ try :
214+ chain = self .generate_trust_chain (
215+ ta_ec , provider_url , httpc_params
216+ )
217+ self ._add_to_dict (trust_chains , provider_url , chain )
218+ self ._store_trust_chain (chain , provider_url )
219+ logger .info (
220+ "Provider %s linked to TA %s" , provider_url , ta_ec .sub
221+ )
222+ chain_built = True
223+ break
224+ except Exception as e :
225+ logger .warning (
226+ "Failed to build trust chain for provider %s with TA %s: %s" ,
227+ provider_url ,
228+ getattr (ta_ec , "sub" , "<unknown>" ),
229+ e ,
230+ )
231+ if not chain_built :
232+ logger .error (
233+ "Could not build trust chain for provider %s with any configured trust anchor" ,
234+ provider_url ,
235+ )
236+ except Exception as e :
114237 logger .error (
115- f"Exception { exception } generated from this provider { provider_url } "
238+ "Could not resolve trust chain for %s: %s" , provider_url , e
116239 )
117240
118241 return trust_chains
119242
243+ def _add_to_dict (self , d , url , chain ):
244+ """Helper to add a normalized URL in a dict."""
245+ # Always store the exact URL key.
246+ d [url ] = chain
247+ # Also store the normalized variant (with/without trailing slash),
248+ # but avoid silently overwriting an existing normalized entry.
249+ norm = url .rstrip ("/" ) if url .endswith ("/" ) else url + "/"
250+ if norm != url :
251+ if norm in d :
252+ logger .warning (
253+ "Duplicate provider URL variants configured: %s and %s; "
254+ "keeping existing trust chain for %s" ,
255+ url ,
256+ norm ,
257+ norm ,
258+ )
259+ else :
260+ d [norm ] = chain
261+
120262 @staticmethod
121263 def generate_trust_chain (
122264 trust_anchor_ec : EntityStatement , provider_endpoint : str , httpc_params
@@ -141,3 +283,58 @@ def generate_trust_chain(
141283 trust_chain .start ()
142284 trust_chain .apply_metadata_policy ()
143285 return trust_chain
286+
287+ def _ensure_trust_anchors (self ) -> List [EntityStatement ]:
288+ """Return a list of valid TAs."""
289+ if not self ._validated_trust_anchors :
290+ httpc_params = self .config ["trust_chain" ]["config" ]["httpc_params" ]
291+ ta_urls = self .config ["trust_chain" ]["config" ]["trust_anchor" ]
292+
293+ for ta_url in ta_urls :
294+ try :
295+ jwt = get_entity_configurations (ta_url , httpc_params = httpc_params )[0 ]
296+ ta_ec = EntityStatement (jwt , httpc_params = httpc_params )
297+ ta_ec .validate_by_itself ()
298+ self ._validated_trust_anchors .append (ta_ec )
299+ except Exception as e :
300+ logger .error (f"Failed to validate TA { ta_url } : { e } " )
301+
302+ if not self ._validated_trust_anchors :
303+ raise ValueError ("No valid Trust Anchors could be loaded." )
304+
305+ return self ._validated_trust_anchors
306+
307+ def get_or_build_trust_chain (self , provider_url : str ) -> TrustChainBuilder :
308+ """
309+ Get trust chain from cache, or from DB, or discover and build it on-demand.
310+ Newly built chains are stored in memory and in the database.
311+ """
312+ providers = self .config .get ("providers" , [])
313+ provider_variants = [provider_url , provider_url .rstrip ("/" )]
314+ if not provider_url .endswith ("/" ):
315+ provider_variants .append (provider_url + "/" )
316+ if not any (p in providers for p in provider_variants if p ):
317+ raise TrustChainNotFoundError (f"Provider { provider_url } not in allowed list." )
318+
319+ # Try load from DB (in-memory cache already checked by TrustChainResolver)
320+ engine = self ._get_storage ()
321+ if engine :
322+ cached = engine .get_trust_chain_by_provider (provider_url )
323+ if cached and not _is_cache_expired (cached ):
324+ chain = _trust_chain_from_cache (cached )
325+ self ._add_to_dict (self .trust_chain , provider_url , chain )
326+ return chain
327+
328+ httpc_params = self .config ["trust_chain" ]["config" ]["httpc_params" ]
329+ tas = self ._ensure_trust_anchors ()
330+
331+ for ta_ec in tas :
332+ try :
333+ chain = self .generate_trust_chain (ta_ec , provider_url , httpc_params )
334+ self ._add_to_dict (self .trust_chain , provider_url , chain )
335+ self ._store_trust_chain (chain , provider_url )
336+ return chain
337+ except Exception :
338+ continue
339+
340+ raise TrustChainNotFoundError (f"Failed to build trust chain for { provider_url } with any TA." )
0 commit comments