|
1 | 1 | import json
|
2 | 2 | import logging
|
3 | 3 |
|
4 |
| -from spaceone.core import cache |
5 |
| -from spaceone.core import config |
| 4 | +from fastapi import Request, Response |
| 5 | +from fastapi.responses import RedirectResponse |
| 6 | +from spaceone.core import cache, config |
6 | 7 | from spaceone.core.auth.jwt import JWTAuthenticator, JWTUtil
|
7 | 8 | from spaceone.core.error import ERROR_AUTHENTICATE_FAILURE
|
8 |
| -from spaceone.core.service import * |
| 9 | +from spaceone.core.service import BaseService, event_handler, transaction |
9 | 10 |
|
10 | 11 | from cloudforet.console_api_v2.manager.cloudforet_manager import CloudforetManager
|
| 12 | +from cloudforet.console_api_v2.service.proxy_service import ProxyService |
11 | 13 |
|
12 | 14 | _LOGGER = logging.getLogger(__name__)
|
13 | 15 |
|
@@ -51,6 +53,23 @@ def basic(self, params: dict) -> None:
|
51 | 53 | self._check_app(client_id, domain_id)
|
52 | 54 | self._authenticate(token, domain_id)
|
53 | 55 |
|
| 56 | + def saml(self, params: dict) -> RedirectResponse: |
| 57 | + request = params.get("request") |
| 58 | + form_data = params.get("form_data") |
| 59 | + domain_id = params.get("domain_id") |
| 60 | + |
| 61 | + credentials = self._extract_credentials(request, dict(form_data)) |
| 62 | + refresh_token = self._issue_token(credentials, domain_id) |
| 63 | + domain_name = self._get_domain_name(domain_id) |
| 64 | + return self._redirect_response(domain_name, refresh_token) |
| 65 | + |
| 66 | + def saml_sp_metadata(self, domain_id: str) -> Response: |
| 67 | + sp_entity_id = domain_id |
| 68 | + domain_name = self._get_domain_name(domain_id) |
| 69 | + acs_url = self._get_acs_url(domain_name, domain_id) |
| 70 | + metadata_xml = self._generate_sp_metadata(sp_entity_id, acs_url) |
| 71 | + return Response(content=metadata_xml, media_type="application/xml") |
| 72 | + |
54 | 73 | def _authenticate(self, token: str, domain_id: str) -> dict:
|
55 | 74 | public_key = self._get_public_key(domain_id)
|
56 | 75 | return JWTAuthenticator(json.loads(public_key)).validate(token)
|
@@ -100,3 +119,85 @@ def _check_app(client_id: str, domain_id: str):
|
100 | 119 | {"client_id": client_id, "domain_id": domain_id},
|
101 | 120 | token=system_token,
|
102 | 121 | )
|
| 122 | + |
| 123 | + @staticmethod |
| 124 | + def _extract_credentials(request: Request, form_data: dict) -> dict: |
| 125 | + return { |
| 126 | + "http_host": request.client.host, |
| 127 | + "server_port": str(request.url.port), |
| 128 | + "script_name": request.url.path, |
| 129 | + "post_data": form_data, |
| 130 | + } |
| 131 | + |
| 132 | + @staticmethod |
| 133 | + def _issue_token(credentials: dict, domain_id: str) -> str: |
| 134 | + dispatch_params = { |
| 135 | + "grpc_method": "identity.Token.issue", |
| 136 | + "credentials": credentials, |
| 137 | + "auth_type": "EXTERNAL", |
| 138 | + "domain_id": domain_id, |
| 139 | + } |
| 140 | + |
| 141 | + proxy_service = ProxyService() |
| 142 | + response = proxy_service.dispatch_api(dispatch_params) |
| 143 | + |
| 144 | + return response.get("refresh_token") |
| 145 | + |
| 146 | + @staticmethod |
| 147 | + def _get_domain_name(domain_id: str) -> str: |
| 148 | + cloudforet_mgr = CloudforetManager() |
| 149 | + grpc_method = "identity.Domain.get" |
| 150 | + dispatch_params = {"domain_id": domain_id} |
| 151 | + system_token = config.get_global("TOKEN") |
| 152 | + |
| 153 | + response = cloudforet_mgr.dispatch_api( |
| 154 | + grpc_method, dispatch_params, system_token |
| 155 | + ) |
| 156 | + |
| 157 | + return response.get("name") |
| 158 | + |
| 159 | + @staticmethod |
| 160 | + def _redirect_response(domain_name: str, refresh_token: str) -> RedirectResponse: |
| 161 | + console_domain: str = config.get_global("CONSOLE_DOMAIN").format( |
| 162 | + domain_name=domain_name |
| 163 | + ) |
| 164 | + |
| 165 | + if refresh_token: |
| 166 | + return RedirectResponse( |
| 167 | + f"{console_domain}/saml?refresh_token={refresh_token}", |
| 168 | + status_code=302, |
| 169 | + ) |
| 170 | + |
| 171 | + return RedirectResponse(f"{console_domain}", status_code=302) |
| 172 | + |
| 173 | + @staticmethod |
| 174 | + def _get_acs_url(domain_name: str, domain_id: str) -> str: |
| 175 | + console_api_v2_endpoint = config.get_global("CONSOLE_API_V2_ENDPOINT") |
| 176 | + acs_url = ( |
| 177 | + f"{console_api_v2_endpoint}/console-api/extension/auth/saml/{domain_id}" |
| 178 | + ) |
| 179 | + |
| 180 | + return acs_url |
| 181 | + |
| 182 | + @staticmethod |
| 183 | + def _generate_sp_metadata(sp_entity_id: str, acs_url: str) -> str: |
| 184 | + """Generates SP metadata XML. |
| 185 | +
|
| 186 | + Args: |
| 187 | + 'sp_entity_id': 'str' (Service Provider Entity ID), |
| 188 | + 'acs_url': 'str' (Assertion Consumer Service URL), |
| 189 | + 'x509_cert': 'str' (X.509 certificate), |
| 190 | +
|
| 191 | + Returns: |
| 192 | + 'metadata_template': 'str' (SP metadata XML) |
| 193 | + """ |
| 194 | + metadata_template = f""" |
| 195 | + <?xml version="1.0" encoding="UTF-8"?> |
| 196 | + <md:EntityDescriptor xmlns:md="urn:oasis:names:tc:SAML:2.0:metadata" entityID="{sp_entity_id}"> |
| 197 | + <md:SPSSODescriptor protocolSupportEnumeration="urn:oasis:names:tc:SAML:2.0:protocol"> |
| 198 | + <md:AssertionConsumerService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST" Location="{acs_url}" index="1"/> |
| 199 | + </md:SPSSODescriptor> |
| 200 | + </md:EntityDescriptor> |
| 201 | + """ |
| 202 | + |
| 203 | + return metadata_template.strip() |
0 commit comments