1
+ import logging
1
2
import os
2
3
import ssl
3
- import logging
4
- import jwt
4
+
5
5
import grpc
6
+ import jwt
7
+ import requests
6
8
from aiohttp import hdrs , web
7
-
8
- from temporalio .api .common .v1 import Payload , Payloads
9
- from temporalio .api .cloud .cloudservice .v1 import request_response_pb2 , service_pb2_grpc
10
9
from google .protobuf import json_format
10
+ from jwt .algorithms import RSAAlgorithm
11
+ from temporalio .api .cloud .cloudservice .v1 import request_response_pb2 , service_pb2_grpc
12
+ from temporalio .api .common .v1 import Payload , Payloads
13
+
11
14
from encryption_jwt .codec import EncryptionCodec
12
15
13
- AUTHORIZED_ACCOUNT_ACCESS_ROLES = ["admin" ]
16
+ AUTHORIZED_ACCOUNT_ACCESS_ROLES = ["owner" , " admin" ]
14
17
AUTHORIZED_NAMESPACE_ACCESS_ROLES = ["read" , "write" , "admin" ]
15
18
16
19
temporal_ops_address = "saas-api.tmprl.cloud:443"
@@ -43,51 +46,101 @@ async def cors_options(req: web.Request) -> web.Response:
43
46
return resp
44
47
45
48
def decryption_authorized (email : str , namespace : str ) -> bool :
46
- credentials = grpc .composite_channel_credentials (grpc .ssl_channel_credentials (
47
- ), grpc .access_token_call_credentials (os .environ .get ("TEMPORAL_API_KEY" )))
49
+ credentials = grpc .composite_channel_credentials (
50
+ grpc .ssl_channel_credentials (),
51
+ grpc .access_token_call_credentials (os .environ .get ("TEMPORAL_API_KEY" )),
52
+ )
48
53
49
54
with grpc .secure_channel (temporal_ops_address , credentials ) as channel :
50
55
client = service_pb2_grpc .CloudServiceStub (channel )
51
56
request = request_response_pb2 .GetUsersRequest ()
52
57
53
- response = client .GetUsers (request , metadata = (
54
- ("temporal-cloud-api-version" , os .environ .get ("TEMPORAL_OPS_API_VERSION" )),))
58
+ response = client .GetUsers (
59
+ request ,
60
+ metadata = (
61
+ (
62
+ "temporal-cloud-api-version" ,
63
+ os .environ .get ("TEMPORAL_OPS_API_VERSION" ),
64
+ ),
65
+ ),
66
+ )
55
67
56
- authorized = False
57
68
for user in response .users :
58
69
if user .spec .email .lower () == email .lower ():
59
- if user .spec .access .account_access .role in AUTHORIZED_ACCOUNT_ACCESS_ROLES :
60
- authorized = True
70
+ if (
71
+ user .spec .access .account_access .role
72
+ in AUTHORIZED_ACCOUNT_ACCESS_ROLES
73
+ ):
74
+ return True
61
75
else :
62
76
if namespace in user .spec .access .namespace_accesses :
63
- if user .spec .access .namespace_accesses [namespace ].permission in AUTHORIZED_NAMESPACE_ACCESS_ROLES :
64
- authorized = True
77
+ if (
78
+ user .spec .access .namespace_accesses [
79
+ namespace
80
+ ].permission
81
+ in AUTHORIZED_NAMESPACE_ACCESS_ROLES
82
+ ):
83
+ return True
65
84
66
- return authorized
85
+ return False
67
86
68
87
def make_handler (fn : str ):
69
88
async def handler (req : web .Request ):
70
- # Read payloads as JSON
71
- assert req .content_type == "application/json"
72
- payloads = json_format .Parse (await req .read (), Payloads ())
73
-
74
- # Extract the email from the JWT.
75
- auth_header = req .headers .get ("Authorization" )
76
89
namespace = req .headers .get ("x-namespace" )
90
+ auth_header = req .headers .get ("Authorization" )
77
91
_bearer , encoded = auth_header .split (" " )
78
- decoded = jwt .decode (encoded , options = {"verify_signature" : False })
79
92
80
- # Use the email to determine if the payload should be decrypted.
81
- authorized = decryption_authorized (decoded ["https://saas-api.tmprl.cloud/user/email" ], namespace )
93
+ # Extract the kid from the Auth header
94
+ jwt_dict = jwt .get_unverified_header (encoded )
95
+ kid = jwt_dict ["kid" ]
96
+ algorithm = jwt_dict ["alg" ]
97
+
98
+ # Fetch Temporal Cloud JWKS
99
+ jwks_url = "https://login.tmprl.cloud/.well-known/jwks.json"
100
+ jwks = requests .get (jwks_url ).json ()
101
+
102
+ # Extract Temporal Cloud's public key
103
+ public_key = None
104
+ for key in jwks ["keys" ]:
105
+ if key ["kid" ] == kid :
106
+ # Convert JWKS key to PEM format
107
+ public_key = RSAAlgorithm .from_jwk (key )
108
+ break
109
+
110
+ if public_key is None :
111
+ raise ValueError ("Public key not found in JWKS" )
112
+
113
+ # Decode the jwt, verifying against Temporal Cloud's public key
114
+ decoded = jwt .decode (
115
+ encoded ,
116
+ public_key ,
117
+ algorithms = [algorithm ],
118
+ audience = [
119
+ "https://saas-api.tmprl.cloud" ,
120
+ "https://prod-tmprl.us.auth0.com/userinfo" ,
121
+ ],
122
+ )
123
+
124
+ # Use the email to determine if the user is authorized to decrypt the payload
125
+ authorized = decryption_authorized (
126
+ decoded ["https://saas-api.tmprl.cloud/user/email" ], namespace
127
+ )
128
+
82
129
if authorized :
130
+ # Read payloads as JSON
131
+ assert req .content_type == "application/json"
132
+ payloads = json_format .Parse (await req .read (), Payloads ())
83
133
encryptionCodec = EncryptionCodec (namespace )
84
- payloads = Payloads (payloads = await getattr (encryptionCodec , fn )(payloads .payloads ))
134
+ payloads = Payloads (
135
+ payloads = await getattr (encryptionCodec , fn )(payloads .payloads )
136
+ )
85
137
86
138
# Apply CORS and return JSON
87
139
resp = await cors_options (req )
88
140
resp .content_type = "application/json"
89
141
resp .text = json_format .MessageToJson (payloads )
90
142
return resp
143
+
91
144
return handler
92
145
93
146
# Build app
@@ -97,8 +150,8 @@ async def handler(req: web.Request):
97
150
logger = logging .getLogger (__name__ )
98
151
app .add_routes (
99
152
[
100
- web .post ("/encode" , make_handler (' encode' )),
101
- web .post ("/decode" , make_handler (' decode' )),
153
+ web .post ("/encode" , make_handler (" encode" )),
154
+ web .post ("/decode" , make_handler (" decode" )),
102
155
web .options ("/decode" , cors_options ),
103
156
]
104
157
)
@@ -112,8 +165,10 @@ async def handler(req: web.Request):
112
165
if os .environ .get ("SSL_PEM" ) and os .environ .get ("SSL_KEY" ):
113
166
ssl_context = ssl .create_default_context (ssl .Purpose .CLIENT_AUTH )
114
167
ssl_context .check_hostname = False
115
- ssl_context .load_cert_chain (os .environ .get (
116
- "SSL_PEM" ), os .environ .get ("SSL_KEY" ))
168
+ ssl_context .load_cert_chain (
169
+ os .environ .get ("SSL_PEM" ), os .environ .get ("SSL_KEY" )
170
+ )
117
171
118
- web .run_app (build_codec_server (), host = "0.0.0.0" ,
119
- port = 8081 , ssl_context = ssl_context )
172
+ web .run_app (
173
+ build_codec_server (), host = "0.0.0.0" , port = 8081 , ssl_context = ssl_context
174
+ )
0 commit comments