Skip to content

Commit aea239c

Browse files
committed
fix(workloadapi): harden client and sources for safe streaming and errors
Signed-off-by: Max Lambrecht <maxlambrecht@gmail.com>
1 parent fbc5f6e commit aea239c

File tree

10 files changed

+331
-64
lines changed

10 files changed

+331
-64
lines changed

spiffe/src/spiffe/workloadapi/handle_error.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,18 @@ def wrapper(*args, **kw):
3636
except ArgumentError as ae:
3737
raise ae
3838
except PySpiffeError as pe:
39-
raise error_cls(str(pe))
39+
# Avoid double-wrapping if it's already the expected error type
40+
if isinstance(pe, error_cls):
41+
raise pe
42+
raise error_cls(str(pe)) from pe
4043
except grpc.RpcError as rpc_error:
4144
if isinstance(rpc_error, grpc.Call):
42-
raise error_cls(str(rpc_error.details()))
43-
raise error_cls(DEFAULT_WL_API_ERROR_MESSAGE)
45+
details = rpc_error.details()
46+
code = rpc_error.code()
47+
raise error_cls(
48+
f'{DEFAULT_WL_API_ERROR_MESSAGE}: {details} ({code})'
49+
) from rpc_error
50+
raise error_cls(DEFAULT_WL_API_ERROR_MESSAGE) from rpc_error
4451
except Exception as e:
4552
raise error_cls(str(e))
4653

spiffe/src/spiffe/workloadapi/jwt_source.py

Lines changed: 44 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,14 @@
1616

1717
import logging
1818
import threading
19-
from typing import Optional, Set, Callable, List
19+
from typing import Optional, Set, Callable, List, FrozenSet
2020

2121
from spiffe.spiffe_id.spiffe_id import SpiffeId
2222
from spiffe.bundle.jwt_bundle.jwt_bundle import JwtBundle
2323
from spiffe.bundle.jwt_bundle.jwt_bundle_set import JwtBundleSet
2424
from spiffe.spiffe_id.spiffe_id import TrustDomain
2525
from spiffe.svid.jwt_svid import JwtSvid
26-
from spiffe.workloadapi.workload_api_client import WorkloadApiClient
26+
from spiffe.workloadapi.workload_api_client import WorkloadApiClient, StreamCancelHandler
2727
from spiffe.workloadapi.errors import JwtSourceError
2828
from spiffe.errors import ArgumentError
2929

@@ -79,9 +79,12 @@ def __init__(
7979
self._subscribers: List[Callable] = []
8080
self._subscribers_lock = threading.Lock()
8181

82+
# Track ownership: if we create the client, we own it
83+
self._owns_client = workload_api_client is None
8284
self._workload_api_client = (
8385
workload_api_client if workload_api_client else WorkloadApiClient(socket_path)
8486
)
87+
self._client_cancel_handler: Optional[StreamCancelHandler] = None
8588

8689
# Start the watcher in a separate thread
8790
threading.Thread(target=self._start_watcher, daemon=True).start()
@@ -100,12 +103,16 @@ def __init__(
100103
raise JwtSourceError(f"Failed to create JwtSource: {self._error}") from self._error
101104

102105
@property
103-
def bundles(self) -> Set[JwtBundle]:
106+
def bundles(self) -> FrozenSet[JwtBundle]:
104107
"""Returns the set of all JwtBundles."""
105108
with self._lock:
109+
if self._error is not None:
110+
raise JwtSourceError(
111+
f'Cannot get Jwt Bundles: source has error: {self._error}'
112+
)
106113
if self._closed:
107114
raise JwtSourceError('Cannot get Jwt Bundles: source is closed')
108-
return self._jwt_bundle_set.bundles
115+
return frozenset(self._jwt_bundle_set.bundles)
109116

110117
def fetch_svid(self, audience: Set[str], subject: Optional[SpiffeId] = None) -> JwtSvid:
111118
"""Fetches an JWT-SVID from the source.
@@ -124,7 +131,9 @@ def fetch_svid(self, audience: Set[str], subject: Optional[SpiffeId] = None) ->
124131
jwt_svid = self._workload_api_client.fetch_jwt_svid(audience, subject)
125132
return jwt_svid
126133

127-
def fetch_svids(self, audiences: Set[str], subject: Optional[SpiffeId] = None) -> JwtSvid:
134+
def fetch_svids(
135+
self, audiences: Set[str], subject: Optional[SpiffeId] = None
136+
) -> List[JwtSvid]:
128137
"""Fetches all JWT-SVIDs from the source.
129138
130139
Args:
@@ -148,6 +157,8 @@ def get_bundle_for_trust_domain(self, trust_domain: TrustDomain) -> Optional[Jwt
148157
JwtSourceError: In case this JWT Source is closed.
149158
"""
150159
with self._lock:
160+
if self._error is not None:
161+
raise JwtSourceError(f'Cannot get JWT Bundle: source has error: {self._error}')
151162
if self._closed:
152163
raise JwtSourceError('Cannot get JWT Bundle: source is closed')
153164
return self._jwt_bundle_set.get_bundle_for_trust_domain(trust_domain)
@@ -161,9 +172,11 @@ def close(self) -> None:
161172
"""
162173
_logger.info("Closing JWT Source")
163174
with self._lock:
164-
# the cancel method throws a grpc exception, that can be discarded
175+
if self._closed:
176+
return
165177
try:
166-
self._client_cancel_handler.cancel()
178+
if self._client_cancel_handler:
179+
self._client_cancel_handler.cancel()
167180
except Exception as err:
168181
_logger.exception(
169182
'JWT Source: Exception canceling the Workload API client connection: {}'.format(
@@ -172,6 +185,14 @@ def close(self) -> None:
172185
)
173186
self._closed = True
174187

188+
if self._owns_client:
189+
try:
190+
self._workload_api_client.close()
191+
except Exception as err:
192+
_logger.exception(
193+
'Exception closing owned Workload API client: {}'.format(str(err))
194+
)
195+
175196
def is_closed(self) -> bool:
176197
"""Checks if the source has been closed, disallowing further operations."""
177198
with self._lock:
@@ -195,7 +216,10 @@ def unsubscribe_for_updates(self, callback: Callable[[], None]) -> None:
195216
callback (Callable[[], None]): The callback function to unregister.
196217
"""
197218
with self._subscribers_lock:
198-
self._subscribers.remove(callback)
219+
try:
220+
self._subscribers.remove(callback)
221+
except ValueError:
222+
pass
199223

200224
def _start_watcher(self) -> None:
201225
self._client_cancel_handler = self._workload_api_client.stream_jwt_bundles(
@@ -213,16 +237,23 @@ def _set_jwt_bundle_set(self, jwt_bundle_set: JwtBundleSet) -> None:
213237

214238
def _notify_subscribers(self) -> None:
215239
with self._subscribers_lock:
216-
for callback in self._subscribers:
217-
try:
218-
callback()
219-
except Exception as err:
220-
_logger.exception(f"An error occurred while notifying a subscriber: {err}")
240+
subscribers = list(self._subscribers)
241+
for callback in subscribers:
242+
try:
243+
callback()
244+
except Exception as err:
245+
_logger.exception(f"An error occurred while notifying a subscriber: {err}")
221246

222247
def _on_error(self, error: Exception) -> None:
223248
self._log_error(error)
224249
with self._lock:
225250
self._error = error
251+
self._closed = True
252+
try:
253+
if self._client_cancel_handler:
254+
self._client_cancel_handler.cancel()
255+
except Exception as err:
256+
_logger.exception(f"Exception canceling stream on error: {err}")
226257
self._initialization_event.set()
227258

228259
@staticmethod

spiffe/src/spiffe/workloadapi/workload_api_client.py

Lines changed: 78 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
import logging
1818
import os
1919
import threading
20-
import time
2120
from typing import Optional, List, Mapping, Callable, Dict, Set
2221

2322
import grpc
@@ -117,13 +116,31 @@ def reset(self):
117116
class StreamCancelHandler:
118117
def __init__(self):
119118
self.response_iterator = None
119+
self._cancel_event = threading.Event()
120+
self._lock = threading.Lock()
120121

121122
def set_iterator(self, iterator):
122-
self.response_iterator = iterator
123+
with self._lock:
124+
self.response_iterator = iterator
125+
# If already cancelled, cancel the iterator immediately to avoid race
126+
if self._cancel_event.is_set() and hasattr(iterator, "cancel"):
127+
try:
128+
iterator.cancel()
129+
except Exception:
130+
pass
123131

124132
def cancel(self):
125-
if self.response_iterator:
126-
self.response_iterator.cancel()
133+
self._cancel_event.set()
134+
with self._lock:
135+
if self.response_iterator:
136+
self.response_iterator.cancel()
137+
138+
def is_cancelled(self) -> bool:
139+
return self._cancel_event.is_set()
140+
141+
def wait_cancelled(self, timeout: float) -> bool:
142+
"""Waits until the handler is cancelled or timeout is reached."""
143+
return self._cancel_event.wait(timeout)
127144

128145

129146
class WorkloadApiClient:
@@ -298,15 +315,8 @@ def fetch_jwt_bundles(self) -> JwtBundleSet:
298315
FetchJwtBundleError: In case there is an error in fetching the JWT-Bundle from the Workload API or
299316
in case the set of jwt_authorities cannot be parsed from the Workload API Response.
300317
"""
301-
302-
responses = self._spiffe_workload_api_stub.FetchJWTBundles(
303-
workload_pb2.JWTBundlesRequest(), timeout=10
304-
)
305-
res = next(responses)
306-
jwt_bundles: Dict[TrustDomain, JwtBundle] = self._create_td_jwt_bundle_dict(res)
307-
if not jwt_bundles:
308-
raise FetchJwtBundleError('JWT Bundles response is empty')
309-
318+
response = self._call_fetch_jwt_bundles()
319+
jwt_bundles: Dict[TrustDomain, JwtBundle] = self._create_td_jwt_bundle_dict(response)
310320
return JwtBundleSet(jwt_bundles)
311321

312322
@handle_error(error_cls=ValidateJwtSvidError)
@@ -442,13 +452,17 @@ def _watch_x509_context_updates(
442452
on_error: Callable[[Exception], None],
443453
):
444454
while True:
455+
if cancel_handler.is_cancelled():
456+
break
445457
try:
446458
response_iterator = self._spiffe_workload_api_stub.FetchX509SVID(
447459
workload_pb2.X509SVIDRequest()
448460
)
449461
cancel_handler.set_iterator(response_iterator)
450462

451463
for item in response_iterator:
464+
if cancel_handler.is_cancelled():
465+
break
452466
x509_context = self._process_x509_context(item)
453467
on_success(x509_context)
454468

@@ -461,7 +475,9 @@ def _watch_x509_context_updates(
461475
on_error(WorkloadApiError(f"gRPC error: {str(grpc_err.code())}"))
462476
break
463477

464-
time.sleep(retry_handler.get_backoff())
478+
backoff = retry_handler.get_backoff()
479+
if cancel_handler.wait_cancelled(backoff):
480+
break
465481

466482
except Exception as err:
467483
on_error(WorkloadApiError(str(err)))
@@ -475,13 +491,17 @@ def _watch_jwt_bundles_updates(
475491
on_error: Callable[[Exception], None],
476492
):
477493
while True:
494+
if cancel_handler.is_cancelled():
495+
break
478496
try:
479497
response_iterator = self._spiffe_workload_api_stub.FetchJWTBundles(
480498
workload_pb2.JWTBundlesRequest()
481499
)
482500
cancel_handler.set_iterator(response_iterator)
483501

484502
for item in response_iterator:
503+
if cancel_handler.is_cancelled():
504+
break
485505
jwt_bundles = self._process_jwt_bundles(item)
486506
on_success(jwt_bundles)
487507

@@ -494,7 +514,9 @@ def _watch_jwt_bundles_updates(
494514
on_error(WorkloadApiError(f"gRPC error: {str(grpc_err.code())}"))
495515
break
496516

497-
time.sleep(retry_handler.get_backoff())
517+
backoff = retry_handler.get_backoff()
518+
if cancel_handler.wait_cancelled(backoff):
519+
break
498520

499521
except Exception as err:
500522
on_error(WorkloadApiError(str(err)))
@@ -520,7 +542,8 @@ def _process_jwt_bundles(
520542
return self._create_jwt_bundle_set(jwt_bundles_response.bundles)
521543

522544
def _get_spiffe_grpc_channel(self) -> grpc.Channel:
523-
grpc_insecure_channel = grpc.insecure_channel(self._config.spiffe_endpoint_socket)
545+
target = self._grpc_target(self._config.spiffe_endpoint_socket)
546+
grpc_insecure_channel = grpc.insecure_channel(target)
524547
spiffe_client_interceptor = (
525548
header_manipulator_client_interceptor.header_adder_interceptor(
526549
WORKLOAD_API_HEADER_KEY, WORKLOAD_API_HEADER_VALUE
@@ -551,6 +574,18 @@ def _call_fetch_x509_bundles(self) -> workload_pb2.X509BundlesResponse:
551574
raise FetchX509BundleError('X.509 Bundles response is empty')
552575
return item
553576

577+
def _call_fetch_jwt_bundles(self) -> workload_pb2.JWTBundlesResponse:
578+
response = self._spiffe_workload_api_stub.FetchJWTBundles(
579+
workload_pb2.JWTBundlesRequest()
580+
)
581+
try:
582+
item = next(response)
583+
except StopIteration:
584+
raise FetchJwtBundleError('JWT Bundles response is invalid')
585+
if len(item.bundles) == 0:
586+
raise FetchJwtBundleError('JWT Bundles response is empty')
587+
return item
588+
554589
@staticmethod
555590
def _create_x509_bundle_set(resp_bundles: Mapping[str, bytes]) -> X509BundleSet:
556591
x509_bundles = [
@@ -588,7 +623,30 @@ def __exit__(self, exc_type, exc_val, exc_tb) -> None:
588623

589624
@staticmethod
590625
def _check_spiffe_socket_exists(spiffe_socket: str) -> None:
591-
if spiffe_socket.startswith('unix:'):
592-
spiffe_socket = spiffe_socket[5:]
593-
if not os.path.exists(spiffe_socket):
594-
raise ArgumentError(f'SPIFFE socket file "{spiffe_socket}" does not exist.')
626+
path_to_check = WorkloadApiClient._strip_unix_scheme(spiffe_socket)
627+
if not path_to_check:
628+
raise ArgumentError('SPIFFE endpoint socket is empty')
629+
if not os.path.exists(path_to_check):
630+
raise ArgumentError(f'SPIFFE socket file "{path_to_check}" does not exist.')
631+
632+
@staticmethod
633+
def _grpc_target(value: str) -> str:
634+
"""Returns the gRPC target for UDS, normalizing unix:/// to unix:/."""
635+
if value.startswith('unix:'):
636+
path = value[5:]
637+
if path.startswith('/'):
638+
path = '/' + path.lstrip('/')
639+
return f'unix:{path}'
640+
if value.startswith('/'):
641+
return f'unix:{value}'
642+
raise ArgumentError(
643+
f'Invalid SPIFFE endpoint socket "{value}": only unix domain sockets are supported'
644+
)
645+
646+
@staticmethod
647+
def _strip_unix_scheme(value: str) -> str:
648+
"""Strips unix: scheme and normalizes leading slashes for filesystem checks."""
649+
path = value[5:] if value.startswith('unix:') else value
650+
if path.startswith('/'):
651+
path = '/' + path.lstrip('/')
652+
return path

0 commit comments

Comments
 (0)