Skip to content

Commit 42892c4

Browse files
authored
Merge pull request #20 from oracle-samples/auth-support-token-refresh
Auth support token refresh
2 parents 591e1bf + eb4d948 commit 42892c4

File tree

3 files changed

+289
-42
lines changed

3 files changed

+289
-42
lines changed

src/oci_openai/__about__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
# Copyright (c) 2025 Oracle and/or its affiliates.
22
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
33

4-
__version__ = "1.0.0"
4+
__version__ = "1.1.0"

src/oci_openai/oci_openai.py

Lines changed: 227 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44
from __future__ import annotations
55

66
import logging
7+
import threading
8+
import time
9+
from abc import ABC, abstractmethod
710
from typing import Any, Generator, Mapping, Optional, Type
811

912
import httpx
@@ -150,45 +153,126 @@ def __init__(
150153
)
151154

152155

153-
class HttpxOciAuth(httpx.Auth):
156+
class HttpxOciAuth(httpx.Auth, ABC):
154157
"""
155-
Custom HTTPX authentication class that implements OCI request signing.
158+
Enhanced custom HTTPX authentication class that implements OCI request signing
159+
with auto-refresh.
156160
157161
This class handles the authentication flow for HTTPX requests by signing them
158-
using the OCI Signer, which adds the necessary authentication headers for
159-
OCI API calls.
160-
162+
using the OCI Signer, which adds the necessary authentication headers for OCI API calls.
163+
It also provides automatic token refresh functionality for token-based authentication methods.
161164
Attributes:
162165
signer (oci.signer.Signer): The OCI signer instance used for request signing
166+
refresh_interval: Seconds between token refreshes (default: 3600 - 1 hour)
167+
_lock: Threading lock for thread-safe token refresh
168+
_last_refresh: Last refresh timestamp
163169
"""
164170

165-
def __init__(self, signer: OciAuthSigner):
171+
def __init__(self, signer: OciAuthSigner, refresh_interval: int = 3600):
172+
"""
173+
Initialize the authentication with a signer and refresh configuration.
174+
Args:
175+
signer: OCI signer instance
176+
refresh_interval: Seconds between token refreshes (default: 3600 - 1 hour)
177+
"""
166178
self.signer = signer
179+
self.refresh_interval = refresh_interval
180+
self._lock = threading.Lock()
181+
self._last_refresh: Optional[float] = time.time()
182+
logger.info(
183+
"Initialized %s with refresh interval: %d seconds",
184+
self.__class__.__name__,
185+
refresh_interval,
186+
)
167187

168-
@override
169-
def auth_flow(self, request: httpx.Request) -> Generator[httpx.Request, httpx.Response, None]:
170-
# Read the request content to handle streaming requests properly
171-
try:
172-
content = request.content
173-
except httpx.RequestNotRead:
174-
# For streaming requests, we need to read the content first
175-
content = request.read()
188+
def _should_refresh_token(self) -> bool:
189+
"""
190+
Check if the token should be refreshed based on time interval.
191+
Returns:
192+
bool: True if token should be refreshed, False otherwise
193+
"""
194+
if not self._last_refresh:
195+
return True
196+
current_time = time.time()
197+
return (current_time - self._last_refresh) >= self.refresh_interval
198+
199+
@abstractmethod
200+
def _refresh_signer(self) -> None:
201+
"""
202+
Abstract method to refresh the signer. Must be implemented by subclasses.
203+
This method should create a new signer instance with fresh credentials/tokens.
204+
"""
205+
pass
176206

207+
def _refresh_if_needed(self) -> None:
208+
"""
209+
Refresh the signer if enough time has passed since last refresh.
210+
This method is thread-safe and will only refresh once per interval.
211+
"""
212+
with self._lock:
213+
if self._should_refresh_token():
214+
logger.info("Time interval reached, refreshing %s ...", self.__class__.__name__)
215+
try:
216+
self._refresh_signer()
217+
self._last_refresh = time.time()
218+
logger.info("%s token refresh completed successfully", self.__class__.__name__)
219+
except Exception as e:
220+
logger.exception("Warning: Token refresh failed:", e)
221+
222+
def _sign_request(self, request: httpx.Request, content: bytes) -> None:
223+
"""
224+
Sign the given HTTPX request with the OCI signer using the provided content.
225+
Updates request.headers in place with the signed headers.
226+
"""
177227
req = requests.Request(
178228
method=request.method,
179229
url=str(request.url),
180230
headers=dict(request.headers),
181231
data=content,
182232
)
183233
prepared_request = req.prepare()
184-
185-
# Sign the request using the OCI Signer
186234
self.signer.do_request_sign(prepared_request) # type: ignore
187-
188-
# Update the original HTTPX request with the signed headers
189235
request.headers.update(prepared_request.headers)
190236

191-
yield request
237+
@override
238+
def auth_flow(self, request: httpx.Request) -> Generator[httpx.Request, httpx.Response, None]:
239+
"""
240+
Authentication flow for HTTPX requests with automatic retry on 401 errors.
241+
This method:
242+
1. Checks if token needs refresh and refreshes if necessary
243+
2. Signs the request using OCI signer
244+
3. Yields the signed request
245+
4. If 401 error is received, attempts token refresh and retries once
246+
Args:
247+
request: The HTTPX request to be authenticated
248+
Yields:
249+
httpx.Request: The authenticated request
250+
"""
251+
# Check and refresh token if needed
252+
self._refresh_if_needed()
253+
254+
# Read the request content to handle streaming requests properly
255+
try:
256+
content = request.content
257+
except httpx.RequestNotRead:
258+
# For streaming requests, we need to read the content first
259+
content = request.read()
260+
261+
self._sign_request(request, content)
262+
263+
response = yield request
264+
265+
# If we get a 401 (Unauthorized), try refreshing the token once and retry
266+
if response.status_code == 401:
267+
logger.info("Received 401 Unauthorized, attempting token refresh and retry...")
268+
with self._lock:
269+
try:
270+
self._refresh_signer()
271+
self._last_refresh = time.time()
272+
self._sign_request(request, content)
273+
yield request
274+
except Exception as e:
275+
logger.exception("Token refresh on 401 failed:", e)
192276

193277

194278
class OciSessionAuth(HttpxOciAuth):
@@ -206,6 +290,7 @@ def __init__(
206290
self,
207291
config_file: str = DEFAULT_LOCATION,
208292
profile_name: str = DEFAULT_PROFILE,
293+
refresh_interval: int = 3600,
209294
**kwargs: Mapping[str, Any],
210295
):
211296
"""
@@ -218,6 +303,8 @@ def __init__(
218303
profile_name : str, optional
219304
Profile name inside the OCI configuration file to use.
220305
Defaults to "DEFAULT".
306+
refresh_interval: int, optional
307+
Seconds between token refreshes (default: 3600 - 1 hour)
221308
**kwargs : Mapping[str, Any]
222309
Optional keyword arguments:
223310
- `generic_headers`: Optional[Dict[str, str]]
@@ -237,6 +324,8 @@ def __init__(
237324
For any other initialization errors.
238325
"""
239326
# Load OCI configuration and token
327+
self.config_file = config_file
328+
self.profile_name = profile_name
240329
config = oci.config.from_file(config_file, profile_name)
241330
token = self._load_token(config)
242331

@@ -256,66 +345,163 @@ def __init__(
256345
if body_headers:
257346
additional_kwargs["body_headers"] = body_headers
258347

259-
self.signer = oci.auth.signers.SecurityTokenSigner(token, private_key, **additional_kwargs)
348+
self.additional_kwargs = additional_kwargs
349+
signer = oci.auth.signers.SecurityTokenSigner(token, private_key, **self.additional_kwargs)
350+
super().__init__(signer=signer, refresh_interval=refresh_interval)
260351

261352
def _load_token(self, config: Mapping[str, Any]) -> str:
353+
"""
354+
Load session token from file specified in configuration.
355+
Args:
356+
config: OCI configuration dictionary
357+
Returns:
358+
str: Session token content
359+
"""
262360
token_file = config["security_token_file"]
263361
with open(token_file, "r") as f:
264362
return f.read().strip()
265363

266364
def _load_private_key(self, config: Any) -> str:
365+
"""
366+
Load private key from file specified in configuration.
367+
Args:
368+
config: OCI configuration dictionary
369+
Returns:
370+
Private key object
371+
"""
267372
return oci.signer.load_private_key_from_file(config["key_file"])
268373

374+
def _refresh_signer(self) -> None:
375+
"""
376+
Refresh the session signer by reloading token and private key.
377+
This method creates a new SecurityTokenSigner with fresh credentials
378+
loaded from the configuration files.
379+
"""
380+
# Reload configuration in case it has changed
381+
config = oci.config.from_file(self.config_file, self.profile_name)
382+
token = self._load_token(config)
383+
private_key = self._load_private_key(config)
384+
self.signer = oci.auth.signers.SecurityTokenSigner(
385+
token, private_key, **self.additional_kwargs
386+
)
387+
269388

270389
class OciResourcePrincipalAuth(HttpxOciAuth):
271390
"""
272-
OCI authentication implementation using Resource Principal authentication.
391+
OCI authentication implementation using Resource Principal authentication with auto-refresh.
273392
274393
This class implements OCI authentication using Resource Principal credentials,
275-
which is suitable for services running within OCI that need to access other
276-
OCI services.
394+
which is suitable for services running within OCI (like Functions, Container Instances)
395+
that need to access other OCI services. The resource principal token is automatically
396+
refreshed at specified intervals.
277397
"""
278398

279-
def __init__(self, **kwargs: Any) -> None:
280-
super().__init__(signer=oci.auth.signers.get_resource_principals_signer(**kwargs))
399+
def __init__(self, refresh_interval: int = 3600, **kwargs: Any) -> None:
400+
"""
401+
Initialize resource principal authentication.
402+
Args:
403+
refresh_interval: Seconds between token refreshes (default: 3600 - 1 hour)
404+
**kwargs: Additional arguments passed to the resource principal signer
405+
"""
406+
self.kwargs = kwargs
407+
signer = oci.auth.signers.get_resource_principals_signer(**kwargs)
408+
super().__init__(signer=signer, refresh_interval=refresh_interval)
409+
410+
def _refresh_signer(self) -> None:
411+
"""
412+
Refresh the resource principal signer.
413+
This method creates a new resource principal signer which will
414+
automatically fetch fresh credentials from the OCI metadata service.
415+
"""
416+
self.signer = oci.auth.signers.get_resource_principals_signer(**self.kwargs)
281417

282418

283419
class OciInstancePrincipalAuth(HttpxOciAuth):
284420
"""
285-
OCI authentication implementation using Instance Principal authentication.
421+
OCI authentication implementation using Instance Principal authentication with auto-refresh.
286422
287423
This class implements OCI authentication using Instance Principal credentials,
288424
which is suitable for compute instances that need to access OCI services.
425+
The instance principal token is automatically refreshed at specified intervals.
289426
"""
290427

291-
def __init__(self, **kwargs: Any) -> None:
292-
super().__init__(signer=oci.auth.signers.InstancePrincipalsSecurityTokenSigner(**kwargs))
428+
def __init__(self, refresh_interval: int = 3600, **kwargs) -> None: # noqa: ANN003
429+
"""
430+
Initialize instance principal authentication.
431+
Args:
432+
refresh_interval: Seconds between token refreshes (default: 3600 - 1 hour)
433+
**kwargs: Additional arguments passed to InstancePrincipalsSecurityTokenSigner
434+
"""
435+
self.kwargs = kwargs
436+
signer = oci.auth.signers.InstancePrincipalsSecurityTokenSigner(**kwargs)
437+
super().__init__(signer=signer, refresh_interval=refresh_interval)
438+
439+
def _refresh_signer(self) -> None:
440+
"""
441+
Refresh the instance principal signer.
442+
This method creates a new InstancePrincipalsSecurityTokenSigner which will
443+
automatically fetch fresh credentials from the OCI metadata service.
444+
"""
445+
self.signer = oci.auth.signers.InstancePrincipalsSecurityTokenSigner(**self.kwargs)
293446

294447

295448
class OciUserPrincipalAuth(HttpxOciAuth):
296449
"""
297-
OCI authentication implementation using user principal authentication.
450+
OCI authentication implementation using user principal authentication with auto-refresh.
298451
299-
This class implements OCI authentication using API Key credentials loaded from
452+
This class implements OCI authentication using API Key credentials loaded from
300453
the OCI configuration file. It's suitable for programmatic access to OCI services.
301-
454+
Since API key authentication doesn't use tokens that expire, this class doesn't
455+
need frequent refresh but supports configuration reload at specified intervals.
302456
Attributes:
303-
signer (oci.signer.Signer): OCI signer configured with API key credentials
457+
config_file (str): Path to OCI configuration file
458+
profile_name (str): Profile name in the configuration file
459+
config (dict): OCI configuration dictionary
304460
"""
305461

306462
def __init__(
307-
self, config_file: str = DEFAULT_LOCATION, profile_name: str = DEFAULT_PROFILE
463+
self,
464+
config_file: str = DEFAULT_LOCATION,
465+
profile_name: str = DEFAULT_PROFILE,
466+
refresh_interval: int = 3600,
308467
) -> None:
309-
config = oci.config.from_file(config_file, profile_name)
310-
oci.config.validate_config(config) # type: ignore
468+
"""
469+
Initialize user principal authentication.
470+
Args:
471+
config_file: Path to OCI config file (default: ~/.oci/config)
472+
profile_name: Profile name to use (default: DEFAULT)
473+
refresh_interval: Seconds between config reloads (default: 3600 - 1 hour)
474+
"""
475+
self.config_file = config_file
476+
self.profile_name = profile_name
477+
self.config = oci.config.from_file(config_file, profile_name)
478+
oci.config.validate_config(self.config)
479+
signer = oci.signer.Signer(
480+
tenancy=self.config["tenancy"],
481+
user=self.config["user"],
482+
fingerprint=self.config["fingerprint"],
483+
private_key_file_location=self.config.get("key_file"),
484+
pass_phrase=oci.config.get_config_value_or_default(self.config, "pass_phrase"),
485+
private_key_content=self.config.get("key_content"),
486+
)
487+
super().__init__(signer=signer, refresh_interval=refresh_interval)
311488

489+
def _refresh_signer(self) -> None:
490+
"""
491+
Refresh the user principal signer.
492+
For API key authentication, this recreates the signer with the same credentials.
493+
This is mainly useful if the configuration file has been updated.
494+
"""
495+
# Reload configuration in case it has changed
496+
self.config = oci.config.from_file(self.config_file, self.profile_name)
497+
oci.config.validate_config(self.config)
312498
self.signer = oci.signer.Signer(
313-
tenancy=config["tenancy"],
314-
user=config["user"],
315-
fingerprint=config["fingerprint"],
316-
private_key_file_location=config.get("key_file"),
317-
pass_phrase=oci.config.get_config_value_or_default(config, "pass_phrase"), # type: ignore
318-
private_key_content=config.get("key_content"),
499+
tenancy=self.config["tenancy"],
500+
user=self.config["user"],
501+
fingerprint=self.config["fingerprint"],
502+
private_key_file_location=self.config.get("key_file"),
503+
pass_phrase=oci.config.get_config_value_or_default(self.config, "pass_phrase"),
504+
private_key_content=self.config.get("key_content"),
319505
)
320506

321507

0 commit comments

Comments
 (0)