-
Notifications
You must be signed in to change notification settings - Fork 175
Expand file tree
/
Copy pathhook.py
More file actions
175 lines (148 loc) · 6.79 KB
/
hook.py
File metadata and controls
175 lines (148 loc) · 6.79 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
from __future__ import annotations
import hashlib
import hmac
import logging
import re
from collections.abc import Mapping, Sequence
from typing import Any
import sqlalchemy as sa
import trafaret as t
from aiohttp import web
from dateutil.parser import parse as dateutil_parse
from ai.backend.common.logging_utils import BraceStyleAdapter
from ai.backend.common.plugin.hook import HookHandler, HookPlugin, Reject
from ai.backend.common.utils import nmget
from ai.backend.manager.errors.auth import AuthorizationFailed, InvalidAuthParameters
from ai.backend.manager.models.keypair import KeyPairRow, keypairs
from ai.backend.manager.models.user import UserStatus, users
from .utils import deserialize_stoken
# Module-level logger, named after this module's import spec. It is wrapped in
# BraceStyleAdapter; log calls in this file pass "{}"-style placeholders.
log = BraceStyleAdapter(logging.getLogger(__spec__.name))
# Validator for this plugin's configuration mapping:
#   - "auth_token_name" may be null or a string and defaults to "sToken";
#   - any other keys are accepted as-is (allow_extra("*")).
plugin_config_checker = t.Dict({
t.Key("auth_token_name", default="sToken"): t.Null | t.String,
}).allow_extra("*")
# NOTE(review): this constant is not referenced anywhere in the visible part of
# this file — confirm whether it is still used by another module.
DEFAULT_STOKEN_COOKIE_VALUE = "BackendAI"
class KeypairAuthHookPlugin(HookPlugin):
    """Hook plugin that resolves a user from a keypair-based auth token.

    It registers a single ``AUTHORIZE`` handler (:meth:`authorize`).  The
    handler accepts the token in one of two forms, tried in this order:

    1. an "stoken" that :func:`deserialize_stoken` can decode with the secret
       found under ``plugins.webapp.keypair_auth`` in the shared config;
    2. a ``BackendAI signMethod=..., credential=<access_key>:<signature>``
       string whose signature is re-computed by :meth:`sign_token`.
    """

    # Compiled once at class creation; ``parse_token`` *searches* for this
    # pattern, so it may appear anywhere inside the token string.
    _TOKEN_PATTERN = re.compile(
        r"BackendAI signMethod=(?P<sign_method>[A-Z0-9-]+), credential=(?P<access_key>\w+):(?P<signature>\w+)"
    )

    def __init__(self, plugin_config: Mapping[str, Any], local_config: Mapping[str, Any]) -> None:
        """Initialise the base plugin, then validate/normalise the plugin config."""
        super().__init__(plugin_config, local_config)
        self.plugin_config = plugin_config_checker.check(self.plugin_config)

    def get_handlers(self) -> Sequence[tuple[str, HookHandler]]:
        """Return the ``(event name, handler)`` pairs provided by this plugin."""
        return [
            ("AUTHORIZE", self.authorize),
        ]

    async def update_plugin_config(self, plugin_config: Mapping[str, Any]) -> None:
        """Replace the plugin config with a validated copy of *plugin_config*.

        The same checker as in ``__init__`` is applied so that defaults
        (notably ``auth_token_name``) are always present afterwards.
        """
        self.plugin_config = plugin_config_checker.check(plugin_config)

    async def init(self, context: Any = None) -> None:
        """No initialisation is required."""
        pass

    async def cleanup(self) -> None:
        """No cleanup is required."""
        pass

    def parse_token(self, token: str) -> tuple[str, str, str] | None:
        """Extract ``(sign_method, access_key, signature)`` from *token*.

        Returns ``None`` when the token does not contain the expected
        ``BackendAI signMethod=..., credential=...:...`` pattern.
        """
        match = self._TOKEN_PATTERN.search(token)
        if match is None:
            return None
        return (
            match.group("sign_method"),
            match.group("access_key"),
            match.group("signature"),
        )

    async def sign_token(self, sign_method: str, secret_key: str, params: Mapping[str, Any]) -> str:
        """Compute the expected hex signature for the given request parameters.

        :param sign_method: ``"<MAC>-<HASH>"`` (case-insensitive); only the
            ``hmac`` MAC type and hash names in
            ``hashlib.algorithms_guaranteed`` are accepted.
        :param secret_key: The keypair's secret key.
        :param params: Mapping providing ``date``, ``endpoint`` and
            ``api_version``.
        :returns: The hex digest of the HMAC over the canonical request string.
        :raises InvalidAuthParameters: For an unsupported signing method or a
            missing ``date`` / ``endpoint`` / ``api_version``.
        :raises AuthorizationFailed: When a ``ValueError``/``OverflowError``
            occurs while parsing (malformed sign method or unparsable date).
        """
        try:
            # Unpacking raises ValueError unless there are exactly two parts.
            mac_type, hash_type = sign_method.lower().split("-")
            if mac_type != "hmac":
                raise InvalidAuthParameters("Unsupported signing method (MAC type)")
            if hash_type not in hashlib.algorithms_guaranteed:
                raise InvalidAuthParameters("Unsupported signing method (hash type)")

            # Validate presence *before* parsing/using the values so that a
            # missing field is reported as such instead of surfacing as an
            # unrelated TypeError/KeyError.
            raw_date = params.get("date")
            endpoint = params.get("endpoint")
            api_version = params.get("api_version")
            if raw_date is None:
                raise InvalidAuthParameters("Request date is missing")
            if endpoint is None:
                raise InvalidAuthParameters("Request endpoint is missing")
            if api_version is None:
                raise InvalidAuthParameters("API version is missing")

            date_obj = dateutil_parse(raw_date)
            date = date_obj.isoformat()

            # The signed request always has an empty body.
            body_hash = hashlib.new(hash_type, b"").hexdigest()
            sign_bytes = (
                "POST\n"
                "/authorize/keypair\n"
                f"{date}\n"
                f"host:{endpoint}\n"
                "content-type:application/json\n"
                f"x-backendai-version:{api_version}\n"
                f"{body_hash}"
            ).encode()

            # Derive the signing key: secret -> date (YYYYMMDD) -> endpoint.
            sign_key = hmac.new(
                secret_key.encode(), date_obj.strftime("%Y%m%d").encode(), hash_type
            ).digest()
            sign_key = hmac.new(sign_key, endpoint.encode(), hash_type).digest()
            return hmac.new(sign_key, sign_bytes, hash_type).hexdigest()
        except (ValueError, OverflowError):
            raise AuthorizationFailed("Invalid signature") from None

    async def _user_id_from_stoken(self, db: Any, stoken: str, plugin_config: Any) -> Any:
        """Resolve the owning user of the access key embedded in an stoken.

        Raises on any failure (missing secret config, undecodable token,
        unknown access key); the caller treats every exception as "try the
        signature form instead".
        """
        # Read inside this helper so a missing config section triggers the
        # fallback instead of crashing the handler.
        secret = plugin_config["secret"]
        payload = deserialize_stoken(stoken, secret)
        query = sa.select(KeyPairRow).where(KeyPairRow.access_key == payload.access_key)
        async with db.begin_readonly_session() as db_session:
            keypair_row = await db_session.scalar(query)
        if keypair_row is None:
            raise Reject("invalid authentication token")
        return keypair_row.user

    async def _user_id_from_signature(self, db: Any, stoken: str, body: Mapping[str, Any]) -> Any:
        """Resolve the user from a ``BackendAI signMethod=...`` signed token.

        The signature is re-computed from the keypair's secret key and the
        ``date`` / ``endpoint`` / ``api_version`` fields of the request body.
        Raises when the token is malformed, the access key is unknown, or the
        signature does not match.
        """
        parsed = self.parse_token(stoken)
        if not parsed:
            raise Reject("invalid authentication token")
        sign_method, access_key, signature = parsed
        query = (
            sa.select(keypairs.c.user, keypairs.c.secret_key)
            .select_from(keypairs)
            .where(keypairs.c.access_key == access_key)
        )
        async with db.begin() as conn:
            result = await conn.execute(query)
            keypair = result.fetchone()
        if keypair is None:
            raise Reject("Invalid auth token")
        sign_params = {
            "date": body.get("date"),
            "endpoint": body.get("endpoint"),
            "api_version": body.get("api_version"),
        }
        generated_token = await self.sign_token(sign_method, keypair.secret_key, sign_params)
        # Constant-time comparison; compare bytes because ``\w`` in the token
        # pattern may match non-ASCII characters, which compare_digest rejects
        # for str arguments.
        if not hmac.compare_digest(generated_token.encode(), signature.encode()):
            raise Reject("Invalid auth token")
        return keypair.user

    async def authorize(
        self,
        request: web.Request,
        params: Mapping[str, Any],
    ) -> Any:
        """Handle the ``AUTHORIZE`` hook.

        :param request: The incoming request; its JSON body supplies the
            signing parameters for the signature-based token form.
        :param params: Request parameters; the token is read from the key
            named by the ``auth_token_name`` plugin setting.
        :returns: The matching row from ``users``, or ``None`` when no token
            was supplied (no-op for normal login).
        :raises Reject: When the token is invalid, no user matches, or the
            user's status is not ``UserStatus.ACTIVE``.
        """
        auth_token_name = self.plugin_config["auth_token_name"]
        # ``.get`` so that an absent field (or a null token name) is treated
        # the same as an empty token.
        stoken = params.get(auth_token_name)
        if not stoken:
            return None  # no-op for normal login

        root_app = request.app["_root_app"]
        db = root_app["_db"]
        config_provider = root_app["_config_provider"]
        shared_config = await config_provider.legacy_etcd_config_loader.load()
        plugin_config = nmget(shared_config, "plugins.webapp.keypair_auth")
        try:
            body = await request.json()
        except Exception:
            body = {}
        if not isinstance(body, Mapping):
            # Valid JSON that is not an object carries no signing parameters.
            body = {}

        try:
            user_id = await self._user_id_from_stoken(db, stoken, plugin_config)
        except Exception:
            # Not a decodable stoken: fall back to the signed-token form.
            try:
                user_id = await self._user_id_from_signature(db, stoken, body)
            except Exception as e:
                # NOTE(review): this writes the raw client-supplied token to
                # the error log — consider redacting it.
                log.error("AUTHORIZE_KEYPAIR_HOOK: invalid auth token {}", stoken)
                log.error(repr(e))
                raise Reject("Invalid auth token") from None

        async with db.begin() as conn:
            query = sa.select(users).select_from(users).where(users.c.uuid == user_id)
            result = await conn.execute(query)
            user = result.fetchone()
        if not user:
            raise Reject("No such user with access key")
        if user.status != UserStatus.ACTIVE:
            raise Reject("user is inactivated with access key")
        return user