Skip to content

Commit 1443687

Browse files
committed
add support for oauth credentials (refresh token)
1 parent 150ad68 commit 1443687

File tree

3 files changed

+87
-7
lines changed

3 files changed

+87
-7
lines changed

finegrain/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ authors = [
88
dependencies = [
99
"httpx>=0.27.0",
1010
"httpx-sse>=0.4.0",
11+
"pyjwt[crypto]>=2.10.1",
1112
]
1213
readme = "README.md"
1314
requires-python = ">= 3.12, <3.13"

finegrain/requirements.lock

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,16 @@
1212
-e file:.
1313
anyio==4.9.0
1414
# via httpx
15-
certifi==2025.1.31
15+
certifi==2025.4.26
1616
# via httpcore
1717
# via httpx
18-
h11==0.14.0
18+
cffi==1.17.1
19+
# via cryptography
20+
cryptography==45.0.3
21+
# via pyjwt
22+
h11==0.16.0
1923
# via httpcore
20-
httpcore==1.0.8
24+
httpcore==1.0.9
2125
# via httpx
2226
httpx==0.28.1
2327
# via finegrain
@@ -26,6 +30,10 @@ httpx-sse==0.4.0
2630
idna==3.10
2731
# via anyio
2832
# via httpx
33+
pycparser==2.22
34+
# via cffi
35+
pyjwt==2.10.1
36+
# via finegrain
2937
sniffio==1.3.1
3038
# via anyio
3139
typing-extensions==4.13.2

finegrain/src/finegrain/__init__.py

Lines changed: 75 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,15 @@
1010

1111
import httpx
1212
import httpx_sse
13+
import jwt
1314
from httpx._types import QueryParamTypes, RequestData, RequestFiles # pyright: ignore[reportPrivateImportUsage]
1415

1516
logger = logging.getLogger(__name__)
1617

1718
Priority = Literal["low", "standard", "high"]
1819
StateID = NewType("StateID", str)
1920

20-
VERSION = "0.2"
21+
VERSION = "0.3"
2122

2223
API_KEY_PATTERN = re.compile(r"^FGAPI(\-[A-Z0-9]{6}){4}$")
2324
EMAIL_PWD_PATTERN = re.compile(r"^\s*(?P<email>[\S]+?@[\S]+?):(?P<pwd>\S+)\s*$")
@@ -293,7 +294,61 @@ def description(self) -> str:
293294
return f"API key {self.api_key[:13]}..."
294295

295296

296-
type Credentials = LoginCredentials | ApiKeyCredentials
297+
@dc.dataclass(kw_only=True)
298+
class OAuthCredentials:
299+
access_token: str
300+
refresh_token: str
301+
client_id: str
302+
client_secret: str
303+
account_url: str = "https://account.finegrain.ai"
304+
account_verify: bool | str = True
305+
306+
def __post_init__(self):
307+
self.validate_tokens()
308+
309+
def validate_tokens(self) -> None:
310+
assert self.access_token, "access_token must not be empty"
311+
assert self.refresh_token, "refresh_token must not be empty"
312+
313+
decoded_access = jwt.decode(self.access_token, options={"verify_signature": False})
314+
decoded_refresh = jwt.decode(self.refresh_token, options={"verify_signature": False})
315+
316+
assert decoded_access["aud"] == "access"
317+
assert decoded_refresh["aud"] == "refresh"
318+
319+
sub = decoded_access.get("sub", "")
320+
assert sub.startswith("FGUSR-")
321+
322+
assert decoded_refresh["iss"] == self.client_id
323+
assert decoded_refresh["sub"] == sub
324+
325+
@property
326+
def as_login_params(self) -> dict[str, str]:
327+
raise ValueError("cannot login with OAuth credentials")
328+
329+
@property
330+
def description(self) -> str:
331+
return f"OAuth client {self.client_id}"
332+
333+
async def renew(self) -> None:
334+
async with httpx.AsyncClient(verify=self.account_verify) as client:
335+
response = await client.post(
336+
url=f"{self.account_url}/oauth/token",
337+
data={
338+
"grant_type": "refresh_token",
339+
"client_id": self.client_id,
340+
"client_secret": self.client_secret,
341+
"refresh_token": self.refresh_token,
342+
},
343+
)
344+
check_status(response)
345+
r = response.json()
346+
self.access_token = r["access_token"]
347+
self.refresh_token = r["refresh_token"]
348+
self.validate_tokens()
349+
350+
351+
Credentials = LoginCredentials | ApiKeyCredentials | OAuthCredentials
297352

298353

299354
class EditorAPIContext:
@@ -323,7 +378,7 @@ class EditorAPIContext:
323378
def __init__(
324379
self,
325380
*,
326-
credentials: str | None = None,
381+
credentials: Credentials | str | None = None,
327382
api_key: str | None = None,
328383
user: str | None = None,
329384
password: str | None = None,
@@ -340,7 +395,9 @@ def __init__(
340395
self.default_timeout = default_timeout
341396
self.subscription_topic = subscription_topic
342397

343-
if credentials is not None:
398+
if isinstance(credentials, Credentials):
399+
self.credentials = credentials
400+
elif credentials is not None:
344401
if (m := API_KEY_PATTERN.match(credentials)) is not None:
345402
self.credentials = ApiKeyCredentials(api_key=m[0])
346403
elif (m := EMAIL_PWD_PATTERN.match(credentials)) is not None:
@@ -367,6 +424,9 @@ def __init__(
367424
verify=self.verify,
368425
)
369426
self.reset()
427+
if isinstance(self.credentials, OAuthCredentials):
428+
# Use token provided in credentials initially (avoids a useless refresh).
429+
self.token = self.credentials.access_token
370430

371431
def reset(self) -> None:
372432
self.token = None
@@ -404,6 +464,16 @@ def auth_headers(self) -> dict[str, str]:
404464
return {"Authorization": f"Bearer {self.token}"}
405465

406466
async def login(self) -> None:
467+
if isinstance(self.credentials, OAuthCredentials):
468+
if self.token is None:
469+
await self.credentials.renew()
470+
self.token = self.credentials.access_token
471+
# If the token is set but invalid, `me` will fail with 401.
472+
# The token will be unset and `login` will be called again.
473+
r = await self.me()
474+
self.credits = r["credits"]
475+
self.logger.debug(f"logged in as {self.credentials.description} - {r["username"]}")
476+
return
407477
async with self as client:
408478
response = await client.post(
409479
f"{self.base_url}/auth/login",
@@ -447,6 +517,7 @@ async def _q() -> httpx.Response:
447517
r = await _q()
448518
if r.status_code == 401:
449519
self.logger.debug("renewing token")
520+
self.token = None
450521
await self.login()
451522
r = await _q()
452523

0 commit comments

Comments
 (0)