|
1 | 1 | from urllib.parse import urlparse |
2 | 2 |
|
3 | | -from authlib.integrations.starlette_client import OAuthError |
4 | 3 | from fastapi import Request |
5 | 4 | from fastapi.responses import RedirectResponse |
6 | 5 |
|
@@ -39,14 +38,34 @@ class SSOLoginCallback(APIRequest): |
39 | 38 | methods = ["GET"] |
40 | 39 |
|
41 | 40 | async def handle(self, request: Request, provider: str) -> RedirectResponse: |
42 | | - client = NebulaSSO.client(provider) |
| 41 | + remote = NebulaSSO.client(provider) |
| 42 | + if not remote: |
| 43 | + return RedirectResponse("/?error=Invalid provider") |
43 | 44 |
|
44 | | - try: |
45 | | - token = await client.authorize_access_token(request) |
46 | | - except OAuthError as error: |
47 | | - return RedirectResponse(f"/?error={error}") |
48 | | - user = token.get("userinfo", {}) |
49 | | - email = user.get("email") |
| 45 | + code = request.query_params.get("code") |
| 46 | + id_token = request.query_params.get("id_token") |
| 47 | + oauth_verifier = request.query_params.get("oauth_verifier") |
| 48 | + |
| 49 | + user_info = {} |
| 50 | + |
| 51 | + if code: |
| 52 | + token = await remote.authorize_access_token(request) |
| 53 | + user_info = token.get("userinfo", {}) |
| 54 | + |
| 55 | + if id_token and not user_info: |
| 56 | + token = {"id_token": id_token} |
| 57 | + user_info = await remote.parse_id_token(request, token) |
| 58 | + |
| 59 | + if oauth_verifier and not user_info: |
| 60 | + token = await remote.authorize_access_token(request) |
| 61 | + |
| 62 | + if token and not user_info: |
| 63 | + user_info = await remote.userinfo(token=token) |
| 64 | + |
| 65 | + if not user_info: |
| 66 | + return RedirectResponse("/?error=Invalid response from provider") |
| 67 | + |
| 68 | + email = user_info.get("email") |
50 | 69 |
|
51 | 70 | if not email: |
52 | 71 | return RedirectResponse("/?error=User email not found") |
|
0 commit comments