Skip to content

Commit da877a9

Browse files
committed
feat: add SAML feature and optimize the code
Signed-off-by: Youngjin Jo <[email protected]>
1 parent 6692c31 commit da877a9

File tree

2 files changed

+114
-59
lines changed

2 files changed

+114
-59
lines changed

src/cloudforet/console_api_v2/interface/rest/extension/auth.py

+10-56
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,13 @@
22

33
from fastapi import Depends, Request
44
from fastapi.concurrency import run_in_threadpool
5-
from fastapi.responses import RedirectResponse
65
from fastapi.security import HTTPBasic, HTTPBasicCredentials
76
from fastapi_utils.cbv import cbv
87
from fastapi_utils.inferring_router import InferringRouter
9-
from spaceone.core import config
108
from spaceone.core.error import ERROR_REQUIRED_PARAMETER
119
from spaceone.core.fastapi.api import BaseAPI, exception_handler
1210

13-
from cloudforet.console_api_v2.manager.cloudforet_manager import CloudforetManager
1411
from cloudforet.console_api_v2.service.auth_service import AuthService
15-
from cloudforet.console_api_v2.service.proxy_service import ProxyService
1612

1713
_LOGGER = logging.getLogger(__name__)
1814
_AUTH_SCHEME = HTTPBasic()
@@ -40,57 +36,15 @@ async def basic(
4036
@router.post("/saml/{domain_id}")
4137
@exception_handler
4238
async def saml(self, request: Request, domain_id: str):
39+
saml_service: AuthService = AuthService()
4340
form_data = await request.form()
44-
credentials = self._extract_credentials(request, dict(form_data))
45-
refresh_token = self._issue_token(credentials, domain_id)
46-
domain_name = self._get_domain_name(domain_id)
47-
return self._redirect_response(domain_name, refresh_token)
41+
params = {"request": request, "form_data": form_data, "domain_id": domain_id}
42+
response = await run_in_threadpool(saml_service.saml, params)
43+
return response
4844

49-
@staticmethod
50-
def _extract_credentials(request: Request, form_data: dict) -> dict:
51-
return {
52-
"http_host": request.client.host,
53-
"server_port": request.url.port,
54-
"script_name": request.url.path,
55-
"post_data": form_data,
56-
}
57-
58-
@staticmethod
59-
def _issue_token(credentials: dict, domain_id: str) -> str:
60-
dispatch_params = {
61-
"grpc_method": "identity.Token.issue",
62-
"credentials": credentials,
63-
"auth_type": "EXTERNAL",
64-
"domain_id": domain_id,
65-
}
66-
67-
proxy_service = ProxyService()
68-
response = proxy_service.dispatch_api(dispatch_params)
69-
70-
return response.get("refresh_token")
71-
72-
@staticmethod
73-
def _get_domain_name(domain_id: str) -> str:
74-
cloudforet_mgr = CloudforetManager()
75-
grpc_method = "identity.Domain.get"
76-
dispatch_params = {"domain_id": domain_id}
77-
system_token = config.get_global("TOKEN")
78-
79-
response = cloudforet_mgr.dispatch_api(
80-
grpc_method, dispatch_params, system_token
81-
)
82-
83-
return response.get("name")
84-
85-
@staticmethod
86-
def _redirect_response(domain_name: str, refresh_token: str) -> RedirectResponse:
87-
console_domain: str = config.get_global("CONSOLE_DOMAIN").format(
88-
domain_name=domain_name
89-
)
90-
if refresh_token:
91-
return RedirectResponse(
92-
f"{console_domain}/saml?refresh_token={refresh_token}",
93-
status_code=302,
94-
)
95-
96-
return RedirectResponse(f"{console_domain}", status_code=302)
45+
@router.get("/saml/{domain_id}/metadata")
46+
@exception_handler
47+
async def saml_sp_metadata(self, domain_id: str):
48+
saml_service: AuthService = AuthService()
49+
response = await run_in_threadpool(saml_service.saml_sp_metadata, domain_id)
50+
return response

src/cloudforet/console_api_v2/service/auth_service.py

+104-3
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
import json
22
import logging
33

4-
from spaceone.core import cache
5-
from spaceone.core import config
4+
from fastapi import Request, Response
5+
from fastapi.responses import RedirectResponse
6+
from spaceone.core import cache, config
67
from spaceone.core.auth.jwt import JWTAuthenticator, JWTUtil
78
from spaceone.core.error import ERROR_AUTHENTICATE_FAILURE
8-
from spaceone.core.service import *
9+
from spaceone.core.service import BaseService, event_handler, transaction
910

1011
from cloudforet.console_api_v2.manager.cloudforet_manager import CloudforetManager
12+
from cloudforet.console_api_v2.service.proxy_service import ProxyService
1113

1214
_LOGGER = logging.getLogger(__name__)
1315

@@ -51,6 +53,23 @@ def basic(self, params: dict) -> None:
5153
self._check_app(client_id, domain_id)
5254
self._authenticate(token, domain_id)
5355

56+
def saml(self, params: dict) -> RedirectResponse:
57+
request = params.get("request")
58+
form_data = params.get("form_data")
59+
domain_id = params.get("domain_id")
60+
61+
credentials = self._extract_credentials(request, dict(form_data))
62+
refresh_token = self._issue_token(credentials, domain_id)
63+
domain_name = self._get_domain_name(domain_id)
64+
return self._redirect_response(domain_name, refresh_token)
65+
66+
def saml_sp_metadata(self, domain_id: str) -> Response:
67+
sp_entity_id = domain_id
68+
domain_name = self._get_domain_name(domain_id)
69+
acs_url = self._get_acs_url(domain_name, domain_id)
70+
metadata_xml = self._generate_sp_metadata(sp_entity_id, acs_url)
71+
return Response(content=metadata_xml, media_type="application/xml")
72+
5473
def _authenticate(self, token: str, domain_id: str) -> dict:
5574
public_key = self._get_public_key(domain_id)
5675
return JWTAuthenticator(json.loads(public_key)).validate(token)
@@ -100,3 +119,85 @@ def _check_app(client_id: str, domain_id: str):
100119
{"client_id": client_id, "domain_id": domain_id},
101120
token=system_token,
102121
)
122+
123+
@staticmethod
124+
def _extract_credentials(request: Request, form_data: dict) -> dict:
125+
return {
126+
"http_host": request.client.host,
127+
"server_port": str(request.url.port),
128+
"script_name": request.url.path,
129+
"post_data": form_data,
130+
}
131+
132+
@staticmethod
133+
def _issue_token(credentials: dict, domain_id: str) -> str:
134+
dispatch_params = {
135+
"grpc_method": "identity.Token.issue",
136+
"credentials": credentials,
137+
"auth_type": "EXTERNAL",
138+
"domain_id": domain_id,
139+
}
140+
141+
proxy_service = ProxyService()
142+
response = proxy_service.dispatch_api(dispatch_params)
143+
144+
return response.get("refresh_token")
145+
146+
@staticmethod
147+
def _get_domain_name(domain_id: str) -> str:
148+
cloudforet_mgr = CloudforetManager()
149+
grpc_method = "identity.Domain.get"
150+
dispatch_params = {"domain_id": domain_id}
151+
system_token = config.get_global("TOKEN")
152+
153+
response = cloudforet_mgr.dispatch_api(
154+
grpc_method, dispatch_params, system_token
155+
)
156+
157+
return response.get("name")
158+
159+
@staticmethod
160+
def _redirect_response(domain_name: str, refresh_token: str) -> RedirectResponse:
161+
console_domain: str = config.get_global("CONSOLE_DOMAIN").format(
162+
domain_name=domain_name
163+
)
164+
165+
if refresh_token:
166+
return RedirectResponse(
167+
f"{console_domain}/saml?refresh_token={refresh_token}",
168+
status_code=302,
169+
)
170+
171+
return RedirectResponse(f"{console_domain}", status_code=302)
172+
173+
@staticmethod
174+
def _get_acs_url(domain_name: str, domain_id: str) -> str:
175+
console_api_v2_endpoint = config.get_global("CONSOLE_API_V2_ENDPOINT")
176+
acs_url = (
177+
f"{console_api_v2_endpoint}/console-api/extension/auth/saml/{domain_id}"
178+
)
179+
180+
return acs_url
181+
182+
@staticmethod
183+
def _generate_sp_metadata(sp_entity_id: str, acs_url: str) -> str:
184+
"""Generates SP metadata XML.
185+
186+
Args:
187+
'sp_entity_id': 'str' (Service Provider Entity ID),
188+
'acs_url': 'str' (Assertion Consumer Service URL),
189+
'x509_cert': 'str' (X.509 certificate),
190+
191+
Returns:
192+
'metadata_template': 'str' (SP metadata XML)
193+
"""
194+
metadata_template = f"""
195+
<?xml version="1.0" encoding="UTF-8"?>
196+
<md:EntityDescriptor xmlns:md="urn:oasis:names:tc:SAML:2.0:metadata" entityID="{sp_entity_id}">
197+
<md:SPSSODescriptor protocolSupportEnumeration="urn:oasis:names:tc:SAML:2.0:protocol">
198+
<md:AssertionConsumerService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST" Location="{acs_url}" index="1"/>
199+
</md:SPSSODescriptor>
200+
</md:EntityDescriptor>
201+
"""
202+
203+
return metadata_template.strip()

0 commit comments

Comments
 (0)