1010
1111import httpx
1212import httpx_sse
13+ import jwt
1314from httpx ._types import QueryParamTypes , RequestData , RequestFiles # pyright: ignore[reportPrivateImportUsage]
1415
1516logger = logging .getLogger (__name__ )
1617
1718Priority = Literal ["low" , "standard" , "high" ]
1819StateID = NewType ("StateID" , str )
1920
20- VERSION = "0.2 "
21+ VERSION = "0.3 "
2122
2223API_KEY_PATTERN = re .compile (r"^FGAPI(\-[A-Z0-9]{6}){4}$" )
2324EMAIL_PWD_PATTERN = re .compile (r"^\s*(?P<email>[\S]+?@[\S]+?):(?P<pwd>\S+)\s*$" )
@@ -293,7 +294,61 @@ def description(self) -> str:
293294 return f"API key { self .api_key [:13 ]} ..."
294295
295296
296- type Credentials = LoginCredentials | ApiKeyCredentials
297+ @dc .dataclass (kw_only = True )
298+ class OAuthCredentials :
299+ access_token : str
300+ refresh_token : str
301+ client_id : str
302+ client_secret : str
303+ account_url : str = "https://account.finegrain.ai"
304+ account_verify : bool | str = True
305+
306+ def __post_init__ (self ):
307+ self .validate_tokens ()
308+
309+ def validate_tokens (self ) -> None :
310+ assert self .access_token , "access_token must not be empty"
311+ assert self .refresh_token , "refresh_token must not be empty"
312+
313+ decoded_access = jwt .decode (self .access_token , options = {"verify_signature" : False })
314+ decoded_refresh = jwt .decode (self .refresh_token , options = {"verify_signature" : False })
315+
316+ assert decoded_access ["aud" ] == "access"
317+ assert decoded_refresh ["aud" ] == "refresh"
318+
319+ sub = decoded_access .get ("sub" , "" )
320+ assert sub .startswith ("FGUSR-" )
321+
322+ assert decoded_refresh ["iss" ] == self .client_id
323+ assert decoded_refresh ["sub" ] == sub
324+
325+ @property
326+ def as_login_params (self ) -> dict [str , str ]:
327+ raise ValueError ("cannot login with OAuth credentials" )
328+
329+ @property
330+ def description (self ) -> str :
331+ return f"OAuth client { self .client_id } "
332+
333+ async def renew (self ) -> None :
334+ async with httpx .AsyncClient (verify = self .account_verify ) as client :
335+ response = await client .post (
336+ url = f"{ self .account_url } /oauth/token" ,
337+ data = {
338+ "grant_type" : "refresh_token" ,
339+ "client_id" : self .client_id ,
340+ "client_secret" : self .client_secret ,
341+ "refresh_token" : self .refresh_token ,
342+ },
343+ )
344+ check_status (response )
345+ r = response .json ()
346+ self .access_token = r ["access_token" ]
347+ self .refresh_token = r ["refresh_token" ]
348+ self .validate_tokens ()
349+
350+
351+ Credentials = LoginCredentials | ApiKeyCredentials | OAuthCredentials
297352
298353
299354class EditorAPIContext :
@@ -323,7 +378,7 @@ class EditorAPIContext:
323378 def __init__ (
324379 self ,
325380 * ,
326- credentials : str | None = None ,
381+ credentials : Credentials | str | None = None ,
327382 api_key : str | None = None ,
328383 user : str | None = None ,
329384 password : str | None = None ,
@@ -340,7 +395,9 @@ def __init__(
340395 self .default_timeout = default_timeout
341396 self .subscription_topic = subscription_topic
342397
343- if credentials is not None :
398+ if isinstance (credentials , Credentials ):
399+ self .credentials = credentials
400+ elif credentials is not None :
344401 if (m := API_KEY_PATTERN .match (credentials )) is not None :
345402 self .credentials = ApiKeyCredentials (api_key = m [0 ])
346403 elif (m := EMAIL_PWD_PATTERN .match (credentials )) is not None :
@@ -367,6 +424,9 @@ def __init__(
367424 verify = self .verify ,
368425 )
369426 self .reset ()
427+ if isinstance (self .credentials , OAuthCredentials ):
428+ # Use token provided in credentials initially (avoids a useless refresh).
429+ self .token = self .credentials .access_token
370430
371431 def reset (self ) -> None :
372432 self .token = None
@@ -404,6 +464,16 @@ def auth_headers(self) -> dict[str, str]:
404464 return {"Authorization" : f"Bearer { self .token } " }
405465
406466 async def login (self ) -> None :
467+ if isinstance (self .credentials , OAuthCredentials ):
468+ if self .token is None :
469+ await self .credentials .renew ()
470+ self .token = self .credentials .access_token
471+ # If the token is set but invalid, `me` will fail with 401.
472+ # The token will be unset and `login` will be called again.
473+ r = await self .me ()
474+ self .credits = r ["credits" ]
475+ self .logger .debug (f"logged in as { self .credentials .description } - { r ["username" ]} " )
476+ return
407477 async with self as client :
408478 response = await client .post (
409479 f"{ self .base_url } /auth/login" ,
@@ -447,6 +517,7 @@ async def _q() -> httpx.Response:
447517 r = await _q ()
448518 if r .status_code == 401 :
449519 self .logger .debug ("renewing token" )
520+ self .token = None
450521 await self .login ()
451522 r = await _q ()
452523
0 commit comments