Skip to content

Commit 6917075

Browse files
authored
Merge pull request #220 from Point72/tkp/apik2
Revamp api key setup in anticipation of other mechanisms, add external api key validation mechanism
2 parents 9456298 + d124e94 commit 6917075

File tree

18 files changed

+596
-164
lines changed

18 files changed

+596
-164
lines changed

csp_gateway/client/client.py

Lines changed: 31 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -232,8 +232,8 @@ class GatewayClientConfig(BaseModel):
232232
host: str = "localhost"
233233
port: Optional[int] = Field(default=8000, ge=1, le=65535, description="Port number for the gateway server")
234234
api_route: str = "/api/v1"
235-
authenticate: bool = False
236235
api_key: str = ""
236+
bearer_token: Optional[str] = None
237237
return_type: ReturnType = Field(
238238
default=ReturnType.Raw,
239239
description="Determines how REST request responses should be returned. Options: 'raw' (JSON dict), 'pandas' (DataFrame), 'polars' (DataFrame), 'struct' (original type), 'wrapper' (ResponseWrapper object).",
@@ -251,13 +251,13 @@ def _promote_return_type(cls, v):
251251
def validate_config(self):
252252
if self.port is not None and self.port < 1:
253253
raise ValueError("Port must be a positive integer")
254-
if self.api_key and not self.authenticate:
255-
raise ValueError("API key must be provided if authentication is enabled")
256254
if self.host.startswith("http"):
257255
# Switch protocol to host
258256
protocol, host = self.host.split("://")
259257
self.__dict__["protocol"] = protocol
260258
self.__dict__["host"] = host
259+
if self.bearer_token and self.api_key:
260+
raise ValueError("Cannot provide both bearer_token and api_key. Choose one authentication method.")
261261
return self
262262

263263
def __hash__(self):
@@ -403,6 +403,15 @@ class BaseGatewayClient(BaseModel):
403403
default=dict(follow_redirects=True), description="Additional arguments to pass to httpx requests (e.g., headers, auth, etc.)"
404404
)
405405

406+
# Additional initialization for bearer_token
407+
def __init__(self, config: GatewayClientConfig = None, **kwargs) -> None:
408+
# Exists for compatibility with positional argument instantiation
409+
if config is None:
410+
config = GatewayClientConfig()
411+
if kwargs:
412+
config = GatewayClientConfig(**{**config.model_dump(exclude_unset=True), **kwargs})
413+
super().__init__(config=config)
414+
406415
# openapi configureation
407416
_initialized: bool = PrivateAttr(default=False)
408417
_openapi_spec: Dict[Any, Any] = PrivateAttr(default=None)
@@ -424,16 +433,14 @@ class BaseGatewayClient(BaseModel):
424433
_event_loop: Optional[AbstractEventLoop] = PrivateAttr(default=None)
425434
_event_loop_thread: Optional[Thread] = PrivateAttr(default=None)
426435

427-
def __init__(self, config: GatewayClientConfig = None, **kwargs) -> None:
428-
# Exists for compatibility with positional argument instantiation
429-
if config is None:
430-
config = GatewayClientConfig()
431-
if kwargs:
432-
config = GatewayClientConfig(**{**config.model_dump(exclude_unset=True), **kwargs})
433-
super().__init__(config=config)
434-
435436
@model_validator(mode="after")
436437
def validate_client(self):
438+
# Set Authorization header if bearer_token is provided
439+
if self.config.bearer_token:
440+
headers = self.http_args.get("headers", {}).copy()
441+
headers["Authorization"] = f"Bearer {self.config.bearer_token}"
442+
self.http_args["headers"] = headers
443+
437444
if self._event_loop is None:
438445
self._event_loop = _get_or_new_event_loop()
439446

@@ -461,10 +468,12 @@ def _initializeStreaming(self) -> None:
461468
def _initialize(self) -> None:
462469
if not self._initialized:
463470
# grab openapi spec
471+
openapi_url = f"{_host(self.config)}/openapi.json"
472+
openapi_params = {"token": self.config.api_key} if self.config.api_key else None
464473
self._openapi_spec: Dict[Any, Any] = replace_refs(
465474
cast(
466475
Dict[Any, Any],
467-
GET(f"{_host(self.config)}/openapi.json", **self.http_args),
476+
GET(openapi_url, params=openapi_params, **self.http_args),
468477
).json(),
469478
)
470479

@@ -504,9 +513,11 @@ def _buildpath(self, route: str) -> str:
504513

505514
def _buildroute(self, route: str) -> str:
506515
url = f"{_host(self.config)}{self._buildpath(route)}"
507-
if self.config.authenticate:
508-
return url, {"token": self.config.api_key}
509-
return url, {}
516+
# If using api_key (not bearer_token), add as query param
517+
extra_params = {}
518+
if self.config.api_key:
519+
extra_params["token"] = self.config.api_key
520+
return url, extra_params
510521

511522
def _api_path_and_route(self, route: str) -> str:
512523
return self.config.api_route + "/" + route
@@ -517,10 +528,11 @@ def _buildroutews(self, route: str) -> str:
517528
host = host.replace("http://", "ws://")
518529
elif host.startswith("https://"):
519530
host = host.replace("https://", "wss://")
520-
if self.config.authenticate:
521-
auth = f"?token={self.config.api_key}"
522-
else:
523-
auth = ""
531+
# If using api_key (not bearer_token), add as query param
532+
auth = ""
533+
if self.config.api_key:
534+
sep = "&" if "?" in host else "?"
535+
auth = f"{sep}token={self.config.api_key}"
524536
return f"{host}{self.config.api_route}/{route}{auth}"
525537

526538
def _handle_response(self, resp: Response, route: str) -> ResponseType:

csp_gateway/server/config/gateway/omnibus.yaml

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
defaults:
33
- _self_
44

5-
authenticate: ???
65
port: ???
76

87
modules:
@@ -32,16 +31,12 @@ modules:
3231
force_mount_all: True
3332
mount_websocket_routes:
3433
_target_: csp_gateway.MountWebSocketRoutes
35-
mount_api_key_middleware:
36-
_target_: csp_gateway.MountAPIKeyMiddleware
3734

3835
gateway:
3936
_target_: csp_gateway.Gateway
4037
settings:
4138
PORT: ${port}
42-
AUTHENTICATE: ${authenticate}
4339
UI: True
44-
API_KEY: "12345"
4540
modules:
4641
- /modules/logfire
4742
- /modules/example_module

csp_gateway/server/demo/config/logfire.yaml

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ logfire:
1010
instrument_fastapi: True
1111
capture_logging: True
1212
log_level: DEBUG
13-
13+
1414
publish_logfire:
1515
_target_: csp_gateway.server.modules.logging.PublishLogfire
1616
selection:
@@ -25,9 +25,7 @@ gateway:
2525
_target_: csp_gateway.Gateway
2626
settings:
2727
PORT: ${port}
28-
AUTHENTICATE: ${authenticate}
2928
UI: True
30-
API_KEY: "12345"
3129
modules:
3230
- /modules/example_module
3331
- /modules/mount_outputs
@@ -40,5 +38,4 @@ gateway:
4038

4139
# csp-gateway-start --config-dir=csp_gateway/server/demo +config=logfire
4240

43-
authenticate: False
4441
port: 8000

csp_gateway/server/demo/config/omnibus.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,4 @@ defaults:
55

66
# csp-gateway-start --config-dir=csp_gateway/server/demo +config=omnibus
77

8-
authenticate: False
98
port: 8000

csp_gateway/server/gateway/csp/module.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from .channels import ChannelsType
1313

1414
if typing.TYPE_CHECKING:
15-
from csp_gateway.server import GatewayWebApp
15+
from csp_gateway.server import GatewaySettings, GatewayWebApp
1616

1717

1818
class Module(BaseModel, Generic[ChannelsType], abc.ABC):
@@ -33,6 +33,8 @@ def connect(self, Channels: ChannelsType) -> None: ...
3333

3434
def rest(self, app: "GatewayWebApp") -> None: ...
3535

36+
def info(self, settings: "GatewaySettings") -> Optional[str]: ...
37+
3638
@abc.abstractmethod
3739
def shutdown(self) -> None: ...
3840

csp_gateway/server/gateway/gateway.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -281,11 +281,11 @@ def start(
281281
log.info("Launching web server on:")
282282
url = f"http://{gethostname()}:{self.settings.PORT}"
283283

284-
if ui:
285-
if self.settings.AUTHENTICATE:
286-
log.info(f"\tUI: {url}?token={self.settings.API_KEY}")
287-
else:
288-
log.info(f"\tUI: {url}")
284+
# Allow module sto log information at statup
285+
for module in self.modules:
286+
_info = module.info(self.settings)
287+
if _info:
288+
log.info(_info)
289289

290290
log.info(f"\tDocs: {url}/docs")
291291
log.info(f"\tDocs: {url}/redoc")
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,9 @@
11
from .api_key import MountAPIKeyMiddleware
2+
from .api_key_external import MountExternalAPIKeyMiddleware
3+
from .base import AuthenticationMiddleware
4+
5+
__all__ = (
6+
"AuthenticationMiddleware",
7+
"MountAPIKeyMiddleware",
8+
"MountExternalAPIKeyMiddleware",
9+
)

0 commit comments

Comments
 (0)