33import time
44from datetime import datetime , timezone
55from types import SimpleNamespace
6- from typing import Optional
6+ from typing import Optional , List
77
88from satosa .backends .base import BackendModule
99from satosa .backends .oauth import get_metadata_desc_for_oauth_backend
@@ -86,7 +86,7 @@ def __init__(self, callback, internal_attributes, module_config, base_url, name)
8686 super ().__init__ (callback , internal_attributes , base_url , name )
8787 self .config = module_config
8888 self .endpoints = {}
89- self ._trust_anchor_ec = None
89+ self ._validated_trust_anchors : List [ EntityStatement ] = []
9090 self .trust_chain = self ._generate_trust_chains ()
9191 self ._trust_chain_resolver = TrustChainResolver (
9292 self .trust_chain ,
@@ -173,70 +173,93 @@ def _store_trust_chain(self, chain, provider_url: str) -> None:
173173 try :
174174 payload = chain .subject_configuration .payload
175175 exp = payload .get ("exp" )
176- cached = TrustChainCache (
177- provider_url = provider_url ,
178- payload = payload ,
179- exp = exp ,
180- created = datetime .now (timezone .utc ),
181- )
182- engine .add_or_update_trust_chain (cached )
176+ variants = {
177+ provider_url .rstrip ("/" ),
178+ provider_url .rstrip ("/" ) + "/"
179+ }
180+
181+ for url in variants :
182+ cached = TrustChainCache (
183+ provider_url = url ,
184+ payload = payload ,
185+ exp = exp ,
186+ created = datetime .now (timezone .utc ),
187+ )
188+ engine .add_or_update_trust_chain (cached )
183189 except Exception as e :
184190 logger .warning ("Could not persist trust chain for %s: %s" , provider_url , e )
185191
186192 def _generate_trust_chains (self ) -> dict :
187- '''
188- private method _generate_trust_chains:
189- Tries to load trust chains from DB first; for missing or expired entries,
190- fetches Trust Anchor, builds chains, and persists them to DB.
191- '''
192- logger .debug (
193- f"Entering method: { inspect .getframeinfo (inspect .currentframe ()).function } . "
194- )
195-
193+ """try load from DB, or can try discovery with TA's list."""
196194 httpc_params = self .config ["trust_chain" ]["config" ]["httpc_params" ]
197195 providers = self .config ["providers" ]
198196 trust_chains = dict ()
199197 trust_anchor_ec = None
200198
201199 for provider_url in providers :
202- # Try load from DB
200+ # try load from DB
203201 engine = self ._get_storage ()
204202 if engine :
205203 cached = engine .get_trust_chain_by_provider (provider_url )
206204 if cached and not _is_cache_expired (cached ):
207205 chain = _trust_chain_from_cache (cached )
208- trust_chains [provider_url ] = chain
209- normalized = provider_url .rstrip ("/" ) if provider_url .endswith ("/" ) else provider_url + "/"
210- if normalized != provider_url :
211- trust_chains [normalized ] = chain
206+ self ._add_to_dict (trust_chains , provider_url , chain )
212207 continue
213208
214- # Build via discovery
209+ # Build via discovery, tryng each TA
215210 try :
216- if trust_anchor_ec is None :
217- ta_url = self .config ["trust_chain" ]["config" ]["trust_anchor" ][0 ]
218- jwt = get_entity_configurations (ta_url , httpc_params = httpc_params )[0 ]
219- trust_anchor_ec = EntityStatement (jwt , httpc_params = httpc_params )
220- trust_anchor_ec .validate_by_itself ()
221- self ._trust_anchor_ec = trust_anchor_ec
222-
223- chain = CieOidcBackend .generate_trust_chain (
224- trust_anchor_ec , provider_url , httpc_params
225- )
226- trust_chains [provider_url ] = chain
227- normalized = provider_url .rstrip ("/" ) if provider_url .endswith ("/" ) else provider_url + "/"
228- if normalized != provider_url :
229- trust_chains [normalized ] = chain
230- self ._store_trust_chain (chain , provider_url )
231- except Exception as exception :
211+ tas = self ._ensure_trust_anchors ()
212+ chain_built = False
213+ for ta_ec in tas :
214+ try :
215+ chain = self .generate_trust_chain (
216+ ta_ec , provider_url , httpc_params
217+ )
218+ self ._add_to_dict (trust_chains , provider_url , chain )
219+ self ._store_trust_chain (chain , provider_url )
220+ logger .info (
221+ "Provider %s linked to TA %s" , provider_url , ta_ec .sub
222+ )
223+ chain_built = True
224+ break
225+ except Exception as e :
226+ logger .warning (
227+ "Failed to build trust chain for provider %s with TA %s: %s" ,
228+ provider_url ,
229+ getattr (ta_ec , "sub" , "<unknown>" ),
230+ e ,
231+ )
232+ if not chain_built :
233+ logger .error (
234+ "Could not build trust chain for provider %s with any configured trust anchor" ,
235+ provider_url ,
236+ )
237+ except Exception as e :
232238 logger .error (
233- "Exception %s generated from this provider %s" ,
234- exception ,
235- provider_url ,
239+ "Could not resolve trust chain for %s: %s" , provider_url , e
236240 )
237241
238242 return trust_chains
239243
244+ def _add_to_dict (self , d , url , chain ):
245+ """Helper to add a normalized URL in a dict."""
246+ # Always store the exact URL key.
247+ d [url ] = chain
248+ # Also store the normalized variant (with/without trailing slash),
249+ # but avoid silently overwriting an existing normalized entry.
250+ norm = url .rstrip ("/" ) if url .endswith ("/" ) else url + "/"
251+ if norm != url :
252+ if norm in d :
253+ logger .warning (
254+ "Duplicate provider URL variants configured: %s and %s; "
255+ "keeping existing trust chain for %s" ,
256+ url ,
257+ norm ,
258+ norm ,
259+ )
260+ else :
261+ d [norm ] = chain
262+
240263 @staticmethod
241264 def generate_trust_chain (
242265 trust_anchor_ec : EntityStatement , provider_endpoint : str , httpc_params
@@ -262,56 +285,57 @@ def generate_trust_chain(
262285 trust_chain .apply_metadata_policy ()
263286 return trust_chain
264287
265- def _ensure_trust_anchor (self ) -> EntityStatement :
266- """Return cached trust anchor EC, or fetch and cache it ."""
267- if self ._trust_anchor_ec is None :
288+ def _ensure_trust_anchors (self ) -> List [ EntityStatement ] :
289+ """Return a list of valid TAs ."""
290+ if not self ._validated_trust_anchors :
268291 httpc_params = self .config ["trust_chain" ]["config" ]["httpc_params" ]
269- ta_url = self .config ["trust_chain" ]["config" ]["trust_anchor" ][0 ]
270- jwt = get_entity_configurations (ta_url , httpc_params = httpc_params )[0 ]
271- self ._trust_anchor_ec = EntityStatement (jwt , httpc_params = httpc_params )
272- self ._trust_anchor_ec .validate_by_itself ()
273- return self ._trust_anchor_ec
292+ ta_urls = self .config ["trust_chain" ]["config" ]["trust_anchor" ]
293+
294+ for ta_url in ta_urls :
295+ try :
296+ jwt = get_entity_configurations (ta_url , httpc_params = httpc_params )[0 ]
297+ ta_ec = EntityStatement (jwt , httpc_params = httpc_params )
298+ ta_ec .validate_by_itself ()
299+ self ._validated_trust_anchors .append (ta_ec )
300+ except Exception as e :
301+ logger .error (f"Failed to validate TA { ta_url } : { e } " )
302+
303+ if not self ._validated_trust_anchors :
304+ raise ValueError ("No valid Trust Anchors could be loaded." )
305+
306+ return self ._validated_trust_anchors
274307
275308 def get_or_build_trust_chain (self , provider_url : str ) -> TrustChainBuilder :
276309 """
277310 Get trust chain from cache, or from DB, or discover and build it on-demand.
278311 Newly built chains are stored in memory and in the database.
279312 """
280313 providers = self .config .get ("providers" , [])
281- provider_variants = (provider_url , provider_url .rstrip ("/" ), provider_url + "/" if not provider_url .endswith ("/" ) else None )
314+ provider_variants = [provider_url , provider_url .rstrip ("/" )]
315+ if not provider_url .endswith ("/" ):
316+ provider_variants .append (provider_url + "/" )
282317 if not any (p in providers for p in provider_variants if p ):
283- raise TrustChainNotFoundError (
284- f"The identity provider '{ provider_url } ' is not in the configured providers list. "
285- f"Configured: { ', ' .join (providers )} ."
286- ) from None
318+ raise TrustChainNotFoundError (f"Provider { provider_url } not in allowed list." )
287319
288320 # Try load from DB (in-memory cache already checked by TrustChainResolver)
289321 engine = self ._get_storage ()
290322 if engine :
291323 cached = engine .get_trust_chain_by_provider (provider_url )
292324 if cached and not _is_cache_expired (cached ):
293325 chain = _trust_chain_from_cache (cached )
294- self .trust_chain [provider_url ] = chain
295- normalized = provider_url .rstrip ("/" ) if provider_url .endswith ("/" ) else provider_url + "/"
296- if normalized != provider_url :
297- self .trust_chain [normalized ] = chain
298- logger .info ("Trust chain loaded from DB for provider %s" , provider_url )
326+ self ._add_to_dict (self .trust_chain , provider_url , chain )
299327 return chain
300328
301- logger .info (
302- "Trust chain not in cache; performing on-demand discovery for provider %s" ,
303- provider_url ,
304- )
305329 httpc_params = self .config ["trust_chain" ]["config" ]["httpc_params" ]
306- trust_anchor_ec = self ._ensure_trust_anchor ()
330+ tas = self ._ensure_trust_anchors ()
307331
308- chain = CieOidcBackend . generate_trust_chain (
309- trust_anchor_ec , provider_url , httpc_params
310- )
311- self .trust_chain [ provider_url ] = chain
312- normalized = provider_url . rstrip ( "/" ) if provider_url . endswith ( "/" ) else provider_url + "/"
313- if normalized != provider_url :
314- self . trust_chain [ normalized ] = chain
315- self . _store_trust_chain ( chain , provider_url )
316- logger . info ( "Trust chain built and stored for provider %s" , provider_url )
317- return chain
332+ for ta_ec in tas :
333+ try :
334+ chain = self . generate_trust_chain ( ta_ec , provider_url , httpc_params )
335+ self ._add_to_dict ( self . trust_chain , provider_url , chain )
336+ self . _store_trust_chain ( chain , provider_url )
337+ return chain
338+ except Exception :
339+ continue
340+
341+ raise TrustChainNotFoundError ( f"Failed to build trust chain for { provider_url } with any TA." )
0 commit comments