1717import logging
1818import os
1919import threading
20- import time
2120from typing import Optional , List , Mapping , Callable , Dict , Set
2221
2322import grpc
@@ -117,13 +116,31 @@ def reset(self):
117116class StreamCancelHandler :
118117 def __init__ (self ):
119118 self .response_iterator = None
119+ self ._cancel_event = threading .Event ()
120+ self ._lock = threading .Lock ()
120121
121122 def set_iterator (self , iterator ):
122- self .response_iterator = iterator
123+ with self ._lock :
124+ self .response_iterator = iterator
125+ # If already cancelled, cancel the iterator immediately to avoid race
126+ if self ._cancel_event .is_set () and hasattr (iterator , "cancel" ):
127+ try :
128+ iterator .cancel ()
129+ except Exception :
130+ pass
123131
124132 def cancel (self ):
125- if self .response_iterator :
126- self .response_iterator .cancel ()
133+ self ._cancel_event .set ()
134+ with self ._lock :
135+ if self .response_iterator :
136+ self .response_iterator .cancel ()
137+
138+ def is_cancelled (self ) -> bool :
139+ return self ._cancel_event .is_set ()
140+
141+ def wait_cancelled (self , timeout : float ) -> bool :
142+ """Waits until the handler is cancelled or timeout is reached."""
143+ return self ._cancel_event .wait (timeout )
127144
128145
129146class WorkloadApiClient :
@@ -298,15 +315,8 @@ def fetch_jwt_bundles(self) -> JwtBundleSet:
298315 FetchJwtBundleError: In case there is an error in fetching the JWT-Bundle from the Workload API or
299316 in case the set of jwt_authorities cannot be parsed from the Workload API Response.
300317 """
301-
302- responses = self ._spiffe_workload_api_stub .FetchJWTBundles (
303- workload_pb2 .JWTBundlesRequest (), timeout = 10
304- )
305- res = next (responses )
306- jwt_bundles : Dict [TrustDomain , JwtBundle ] = self ._create_td_jwt_bundle_dict (res )
307- if not jwt_bundles :
308- raise FetchJwtBundleError ('JWT Bundles response is empty' )
309-
318+ response = self ._call_fetch_jwt_bundles ()
319+ jwt_bundles : Dict [TrustDomain , JwtBundle ] = self ._create_td_jwt_bundle_dict (response )
310320 return JwtBundleSet (jwt_bundles )
311321
312322 @handle_error (error_cls = ValidateJwtSvidError )
@@ -442,13 +452,17 @@ def _watch_x509_context_updates(
442452 on_error : Callable [[Exception ], None ],
443453 ):
444454 while True :
455+ if cancel_handler .is_cancelled ():
456+ break
445457 try :
446458 response_iterator = self ._spiffe_workload_api_stub .FetchX509SVID (
447459 workload_pb2 .X509SVIDRequest ()
448460 )
449461 cancel_handler .set_iterator (response_iterator )
450462
451463 for item in response_iterator :
464+ if cancel_handler .is_cancelled ():
465+ break
452466 x509_context = self ._process_x509_context (item )
453467 on_success (x509_context )
454468
@@ -461,7 +475,9 @@ def _watch_x509_context_updates(
461475 on_error (WorkloadApiError (f"gRPC error: { str (grpc_err .code ())} " ))
462476 break
463477
464- time .sleep (retry_handler .get_backoff ())
478+ backoff = retry_handler .get_backoff ()
479+ if cancel_handler .wait_cancelled (backoff ):
480+ break
465481
466482 except Exception as err :
467483 on_error (WorkloadApiError (str (err )))
@@ -475,13 +491,17 @@ def _watch_jwt_bundles_updates(
475491 on_error : Callable [[Exception ], None ],
476492 ):
477493 while True :
494+ if cancel_handler .is_cancelled ():
495+ break
478496 try :
479497 response_iterator = self ._spiffe_workload_api_stub .FetchJWTBundles (
480498 workload_pb2 .JWTBundlesRequest ()
481499 )
482500 cancel_handler .set_iterator (response_iterator )
483501
484502 for item in response_iterator :
503+ if cancel_handler .is_cancelled ():
504+ break
485505 jwt_bundles = self ._process_jwt_bundles (item )
486506 on_success (jwt_bundles )
487507
@@ -494,7 +514,9 @@ def _watch_jwt_bundles_updates(
494514 on_error (WorkloadApiError (f"gRPC error: { str (grpc_err .code ())} " ))
495515 break
496516
497- time .sleep (retry_handler .get_backoff ())
517+ backoff = retry_handler .get_backoff ()
518+ if cancel_handler .wait_cancelled (backoff ):
519+ break
498520
499521 except Exception as err :
500522 on_error (WorkloadApiError (str (err )))
@@ -520,7 +542,8 @@ def _process_jwt_bundles(
520542 return self ._create_jwt_bundle_set (jwt_bundles_response .bundles )
521543
522544 def _get_spiffe_grpc_channel (self ) -> grpc .Channel :
523- grpc_insecure_channel = grpc .insecure_channel (self ._config .spiffe_endpoint_socket )
545+ target = self ._grpc_target (self ._config .spiffe_endpoint_socket )
546+ grpc_insecure_channel = grpc .insecure_channel (target )
524547 spiffe_client_interceptor = (
525548 header_manipulator_client_interceptor .header_adder_interceptor (
526549 WORKLOAD_API_HEADER_KEY , WORKLOAD_API_HEADER_VALUE
@@ -551,6 +574,18 @@ def _call_fetch_x509_bundles(self) -> workload_pb2.X509BundlesResponse:
551574 raise FetchX509BundleError ('X.509 Bundles response is empty' )
552575 return item
553576
577+ def _call_fetch_jwt_bundles (self ) -> workload_pb2 .JWTBundlesResponse :
578+ response = self ._spiffe_workload_api_stub .FetchJWTBundles (
579+ workload_pb2 .JWTBundlesRequest ()
580+ )
581+ try :
582+ item = next (response )
583+ except StopIteration :
584+ raise FetchJwtBundleError ('JWT Bundles response is invalid' )
585+ if len (item .bundles ) == 0 :
586+ raise FetchJwtBundleError ('JWT Bundles response is empty' )
587+ return item
588+
554589 @staticmethod
555590 def _create_x509_bundle_set (resp_bundles : Mapping [str , bytes ]) -> X509BundleSet :
556591 x509_bundles = [
@@ -588,7 +623,30 @@ def __exit__(self, exc_type, exc_val, exc_tb) -> None:
588623
589624 @staticmethod
590625 def _check_spiffe_socket_exists (spiffe_socket : str ) -> None :
591- if spiffe_socket .startswith ('unix:' ):
592- spiffe_socket = spiffe_socket [5 :]
593- if not os .path .exists (spiffe_socket ):
594- raise ArgumentError (f'SPIFFE socket file "{ spiffe_socket } " does not exist.' )
626+ path_to_check = WorkloadApiClient ._strip_unix_scheme (spiffe_socket )
627+ if not path_to_check :
628+ raise ArgumentError ('SPIFFE endpoint socket is empty' )
629+ if not os .path .exists (path_to_check ):
630+ raise ArgumentError (f'SPIFFE socket file "{ path_to_check } " does not exist.' )
631+
632+ @staticmethod
633+ def _grpc_target (value : str ) -> str :
634+ """Returns the gRPC target for UDS, normalizing unix:/// to unix:/."""
635+ if value .startswith ('unix:' ):
636+ path = value [5 :]
637+ if path .startswith ('/' ):
638+ path = '/' + path .lstrip ('/' )
639+ return f'unix:{ path } '
640+ if value .startswith ('/' ):
641+ return f'unix:{ value } '
642+ raise ArgumentError (
643+ f'Invalid SPIFFE endpoint socket "{ value } ": only unix domain sockets are supported'
644+ )
645+
646+ @staticmethod
647+ def _strip_unix_scheme (value : str ) -> str :
648+ """Strips unix: scheme and normalizes leading slashes for filesystem checks."""
649+ path = value [5 :] if value .startswith ('unix:' ) else value
650+ if path .startswith ('/' ):
651+ path = '/' + path .lstrip ('/' )
652+ return path
0 commit comments