@@ -232,8 +232,8 @@ class GatewayClientConfig(BaseModel):
232232 host : str = "localhost"
233233 port : Optional [int ] = Field (default = 8000 , ge = 1 , le = 65535 , description = "Port number for the gateway server" )
234234 api_route : str = "/api/v1"
235- authenticate : bool = False
236235 api_key : str = ""
236+ bearer_token : Optional [str ] = None
237237 return_type : ReturnType = Field (
238238 default = ReturnType .Raw ,
239239 description = "Determines how REST request responses should be returned. Options: 'raw' (JSON dict), 'pandas' (DataFrame), 'polars' (DataFrame), 'struct' (original type), 'wrapper' (ResponseWrapper object)." ,
@@ -251,13 +251,13 @@ def _promote_return_type(cls, v):
251251 def validate_config (self ):
252252 if self .port is not None and self .port < 1 :
253253 raise ValueError ("Port must be a positive integer" )
254- if self .api_key and not self .authenticate :
255- raise ValueError ("API key must be provided if authentication is enabled" )
256254 if self .host .startswith ("http" ):
257255 # Switch protocol to host
258256 protocol , host = self .host .split ("://" )
259257 self .__dict__ ["protocol" ] = protocol
260258 self .__dict__ ["host" ] = host
259+ if self .bearer_token and self .api_key :
260+ raise ValueError ("Cannot provide both bearer_token and api_key. Choose one authentication method." )
261261 return self
262262
263263 def __hash__ (self ):
@@ -403,6 +403,15 @@ class BaseGatewayClient(BaseModel):
403403 default = dict (follow_redirects = True ), description = "Additional arguments to pass to httpx requests (e.g., headers, auth, etc.)"
404404 )
405405
406+ # Additional initialization for bearer_token
407+ def __init__ (self , config : GatewayClientConfig = None , ** kwargs ) -> None :
408+ # Exists for compatibility with positional argument instantiation
409+ if config is None :
410+ config = GatewayClientConfig ()
411+ if kwargs :
412+ config = GatewayClientConfig (** {** config .model_dump (exclude_unset = True ), ** kwargs })
413+ super ().__init__ (config = config )
414+
406415 # openapi configureation
407416 _initialized : bool = PrivateAttr (default = False )
408417 _openapi_spec : Dict [Any , Any ] = PrivateAttr (default = None )
@@ -424,16 +433,14 @@ class BaseGatewayClient(BaseModel):
424433 _event_loop : Optional [AbstractEventLoop ] = PrivateAttr (default = None )
425434 _event_loop_thread : Optional [Thread ] = PrivateAttr (default = None )
426435
427- def __init__ (self , config : GatewayClientConfig = None , ** kwargs ) -> None :
428- # Exists for compatibility with positional argument instantiation
429- if config is None :
430- config = GatewayClientConfig ()
431- if kwargs :
432- config = GatewayClientConfig (** {** config .model_dump (exclude_unset = True ), ** kwargs })
433- super ().__init__ (config = config )
434-
435436 @model_validator (mode = "after" )
436437 def validate_client (self ):
438+ # Set Authorization header if bearer_token is provided
439+ if self .config .bearer_token :
440+ headers = self .http_args .get ("headers" , {}).copy ()
441+ headers ["Authorization" ] = f"Bearer { self .config .bearer_token } "
442+ self .http_args ["headers" ] = headers
443+
437444 if self ._event_loop is None :
438445 self ._event_loop = _get_or_new_event_loop ()
439446
@@ -461,10 +468,12 @@ def _initializeStreaming(self) -> None:
461468 def _initialize (self ) -> None :
462469 if not self ._initialized :
463470 # grab openapi spec
471+ openapi_url = f"{ _host (self .config )} /openapi.json"
472+ openapi_params = {"token" : self .config .api_key } if self .config .api_key else None
464473 self ._openapi_spec : Dict [Any , Any ] = replace_refs (
465474 cast (
466475 Dict [Any , Any ],
467- GET (f" { _host ( self . config ) } /openapi.json" , ** self .http_args ),
476+ GET (openapi_url , params = openapi_params , ** self .http_args ),
468477 ).json (),
469478 )
470479
@@ -504,9 +513,11 @@ def _buildpath(self, route: str) -> str:
504513
505514 def _buildroute (self , route : str ) -> str :
506515 url = f"{ _host (self .config )} { self ._buildpath (route )} "
507- if self .config .authenticate :
508- return url , {"token" : self .config .api_key }
509- return url , {}
516+ # If using api_key (not bearer_token), add as query param
517+ extra_params = {}
518+ if self .config .api_key :
519+ extra_params ["token" ] = self .config .api_key
520+ return url , extra_params
510521
511522 def _api_path_and_route (self , route : str ) -> str :
512523 return self .config .api_route + "/" + route
@@ -517,10 +528,11 @@ def _buildroutews(self, route: str) -> str:
517528 host = host .replace ("http://" , "ws://" )
518529 elif host .startswith ("https://" ):
519530 host = host .replace ("https://" , "wss://" )
520- if self .config .authenticate :
521- auth = f"?token={ self .config .api_key } "
522- else :
523- auth = ""
531+ # If using api_key (not bearer_token), add as query param
532+ auth = ""
533+ if self .config .api_key :
534+ sep = "&" if "?" in host else "?"
535+ auth = f"{ sep } token={ self .config .api_key } "
524536 return f"{ host } { self .config .api_route } /{ route } { auth } "
525537
526538 def _handle_response (self , resp : Response , route : str ) -> ResponseType :
0 commit comments