Skip to content

Commit d38c972

Browse files
committed
PR feedback
1 parent 2e12453 commit d38c972

11 files changed

+1563
-1152
lines changed

Diff for: .gitignore

-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,2 @@
11
.venv
22
__pycache__
3-
_certs

Diff for: README.md

+1
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ Some examples require extra dependencies. See each sample's directory for specif
6060
* [custom_decorator](custom_decorator) - Custom decorator to auto-heartbeat a long-running activity.
6161
* [dsl](dsl) - DSL workflow that executes steps defined in a YAML file.
6262
* [encryption](encryption) - Apply end-to-end encryption for all input/output.
63+
* [encryption_jwt](encryption_jwt) - Apply end-to-end encryption for all input/output using a KMS and per-namespace JWT-based auth.
6364
* [gevent_async](gevent_async) - Combine gevent and Temporal.
6465
* [langchain](langchain) - Orchestrate workflows for LangChain.
6566
* [message-passing introduction](message_passing/introduction/) - Introduction to queries, signals, and updates.

Diff for: encryption_jwt/.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
_certs

Diff for: encryption_jwt/README.md

+6-6
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@ The Codec Server uses the [Operations API](https://docs.temporal.io/ops) to get
1111

1212
## Install
1313

14-
For this sample, the optional `encryption` and `bedrock` dependency groups must be included. To include, run:
14+
For this sample, the optional `encryption_jwt` and `bedrock` dependency groups must be included. To include, run:
1515

1616
```sh
17-
poetry install --with encryption,bedrock
17+
poetry install --with encryption_jwt,bedrock
1818
```
1919

2020
## Setup
@@ -31,17 +31,17 @@ Alternately replace the key management portion with your own implementation.
3131
### Self-signed certificates
3232

3333
The codec server will need to use HTTPS, self-signed certificates will work in the development
34-
environment. Run the following command in a `_certs` directory that's a subdirectory of the
35-
repository root, it will create certificate files that are good for 10 years.
34+
environment. Run the following command in a `_certs` directory that's a subdirectory of this one.
35+
It will create certificate files that are good for 10 years.
3636

3737
```sh
3838
openssl req -x509 -newkey rsa:4096 -sha256 -days 3650 -nodes -keyout localhost.key -out localhost.pem -subj "/CN=localhost"
3939
```
4040

4141
In the projects you can access the files using the following relative paths.
4242

43-
- `../_certs/localhost.pem`
44-
- `../_certs/localhost.key`
43+
- `./_certs/localhost.pem`
44+
- `./_certs/localhost.key`
4545

4646
## Run
4747

Diff for: encryption_jwt/codec.py

+10-7
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,21 @@
11
from typing import Iterable, List
2+
23
from temporalio.api.common.v1 import Payload
34
from temporalio.converter import PayloadCodec
5+
46
from encryption_jwt.encryptor import KMSEncryptor
57

68

79
class EncryptionCodec(PayloadCodec):
8-
910
def __init__(self, namespace: str):
1011
self._encryptor = KMSEncryptor(namespace)
1112

1213
async def encode(self, payloads: Iterable[Payload]) -> List[Payload]:
1314
# We blindly encode all payloads with the key and set the metadata with the key that was
1415
# used (base64 encoded).
1516

16-
def encrypt_payload(p: Payload):
17-
data, key = self._encryptor.encrypt(p.SerializeToString())
17+
async def encrypt_payload(p: Payload):
18+
data, key = await self._encryptor.encrypt(p.SerializeToString())
1819
return Payload(
1920
metadata={
2021
"encoding": b"binary/encrypted",
@@ -23,12 +24,14 @@ def encrypt_payload(p: Payload):
2324
data=data,
2425
)
2526

26-
return list(map(encrypt_payload, payloads))
27+
# return list(map(encrypt_payload, payloads))
28+
return [await encrypt_payload(payload) for payload in payloads]
2729

2830
async def decode(self, payloads: Iterable[Payload]) -> List[Payload]:
29-
def decrypt_payload(p: Payload):
31+
async def decrypt_payload(p: Payload):
3032
data_key_encrypted_base64 = p.metadata.get("data_key_encrypted", b"")
31-
data = self._encryptor.decrypt(data_key_encrypted_base64, p.data)
33+
data = await self._encryptor.decrypt(data_key_encrypted_base64, p.data)
3234
return Payload.FromString(data)
3335

34-
return list(map(decrypt_payload, payloads))
36+
# return list(map(decrypt_payload, payloads))
37+
return [await decrypt_payload(payload) for payload in payloads]

Diff for: encryption_jwt/codec_server.py

+87-32
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,19 @@
1+
import logging
12
import os
23
import ssl
3-
import logging
4-
import jwt
4+
55
import grpc
6+
import jwt
7+
import requests
68
from aiohttp import hdrs, web
7-
8-
from temporalio.api.common.v1 import Payload, Payloads
9-
from temporalio.api.cloud.cloudservice.v1 import request_response_pb2, service_pb2_grpc
109
from google.protobuf import json_format
10+
from jwt.algorithms import RSAAlgorithm
11+
from temporalio.api.cloud.cloudservice.v1 import request_response_pb2, service_pb2_grpc
12+
from temporalio.api.common.v1 import Payload, Payloads
13+
1114
from encryption_jwt.codec import EncryptionCodec
1215

13-
AUTHORIZED_ACCOUNT_ACCESS_ROLES = ["admin"]
16+
AUTHORIZED_ACCOUNT_ACCESS_ROLES = ["owner", "admin"]
1417
AUTHORIZED_NAMESPACE_ACCESS_ROLES = ["read", "write", "admin"]
1518

1619
temporal_ops_address = "saas-api.tmprl.cloud:443"
@@ -43,51 +46,101 @@ async def cors_options(req: web.Request) -> web.Response:
4346
return resp
4447

4548
def decryption_authorized(email: str, namespace: str) -> bool:
46-
credentials = grpc.composite_channel_credentials(grpc.ssl_channel_credentials(
47-
), grpc.access_token_call_credentials(os.environ.get("TEMPORAL_API_KEY")))
49+
credentials = grpc.composite_channel_credentials(
50+
grpc.ssl_channel_credentials(),
51+
grpc.access_token_call_credentials(os.environ.get("TEMPORAL_API_KEY")),
52+
)
4853

4954
with grpc.secure_channel(temporal_ops_address, credentials) as channel:
5055
client = service_pb2_grpc.CloudServiceStub(channel)
5156
request = request_response_pb2.GetUsersRequest()
5257

53-
response = client.GetUsers(request, metadata=(
54-
("temporal-cloud-api-version", os.environ.get("TEMPORAL_OPS_API_VERSION")),))
58+
response = client.GetUsers(
59+
request,
60+
metadata=(
61+
(
62+
"temporal-cloud-api-version",
63+
os.environ.get("TEMPORAL_OPS_API_VERSION"),
64+
),
65+
),
66+
)
5567

56-
authorized = False
5768
for user in response.users:
5869
if user.spec.email.lower() == email.lower():
59-
if user.spec.access.account_access.role in AUTHORIZED_ACCOUNT_ACCESS_ROLES:
60-
authorized = True
70+
if (
71+
user.spec.access.account_access.role
72+
in AUTHORIZED_ACCOUNT_ACCESS_ROLES
73+
):
74+
return True
6175
else:
6276
if namespace in user.spec.access.namespace_accesses:
63-
if user.spec.access.namespace_accesses[namespace].permission in AUTHORIZED_NAMESPACE_ACCESS_ROLES:
64-
authorized = True
77+
if (
78+
user.spec.access.namespace_accesses[
79+
namespace
80+
].permission
81+
in AUTHORIZED_NAMESPACE_ACCESS_ROLES
82+
):
83+
return True
6584

66-
return authorized
85+
return False
6786

6887
def make_handler(fn: str):
6988
async def handler(req: web.Request):
70-
# Read payloads as JSON
71-
assert req.content_type == "application/json"
72-
payloads = json_format.Parse(await req.read(), Payloads())
73-
74-
# Extract the email from the JWT.
75-
auth_header = req.headers.get("Authorization")
7689
namespace = req.headers.get("x-namespace")
90+
auth_header = req.headers.get("Authorization")
7791
_bearer, encoded = auth_header.split(" ")
78-
decoded = jwt.decode(encoded, options={"verify_signature": False})
7992

80-
# Use the email to determine if the payload should be decrypted.
81-
authorized = decryption_authorized(decoded["https://saas-api.tmprl.cloud/user/email"], namespace)
93+
# Extract the kid from the Auth header
94+
jwt_dict = jwt.get_unverified_header(encoded)
95+
kid = jwt_dict["kid"]
96+
algorithm = jwt_dict["alg"]
97+
98+
# Fetch Temporal Cloud JWKS
99+
jwks_url = "https://login.tmprl.cloud/.well-known/jwks.json"
100+
jwks = requests.get(jwks_url).json()
101+
102+
# Extract Temporal Cloud's public key
103+
public_key = None
104+
for key in jwks["keys"]:
105+
if key["kid"] == kid:
106+
# Convert JWKS key to PEM format
107+
public_key = RSAAlgorithm.from_jwk(key)
108+
break
109+
110+
if public_key is None:
111+
raise ValueError("Public key not found in JWKS")
112+
113+
# Decode the jwt, verifying against Temporal Cloud's public key
114+
decoded = jwt.decode(
115+
encoded,
116+
public_key,
117+
algorithms=[algorithm],
118+
audience=[
119+
"https://saas-api.tmprl.cloud",
120+
"https://prod-tmprl.us.auth0.com/userinfo",
121+
],
122+
)
123+
124+
# Use the email to determine if the user is authorized to decrypt the payload
125+
authorized = decryption_authorized(
126+
decoded["https://saas-api.tmprl.cloud/user/email"], namespace
127+
)
128+
82129
if authorized:
130+
# Read payloads as JSON
131+
assert req.content_type == "application/json"
132+
payloads = json_format.Parse(await req.read(), Payloads())
83133
encryptionCodec = EncryptionCodec(namespace)
84-
payloads = Payloads(payloads=await getattr(encryptionCodec, fn)(payloads.payloads))
134+
payloads = Payloads(
135+
payloads=await getattr(encryptionCodec, fn)(payloads.payloads)
136+
)
85137

86138
# Apply CORS and return JSON
87139
resp = await cors_options(req)
88140
resp.content_type = "application/json"
89141
resp.text = json_format.MessageToJson(payloads)
90142
return resp
143+
91144
return handler
92145

93146
# Build app
@@ -97,8 +150,8 @@ async def handler(req: web.Request):
97150
logger = logging.getLogger(__name__)
98151
app.add_routes(
99152
[
100-
web.post("/encode", make_handler('encode')),
101-
web.post("/decode", make_handler('decode')),
153+
web.post("/encode", make_handler("encode")),
154+
web.post("/decode", make_handler("decode")),
102155
web.options("/decode", cors_options),
103156
]
104157
)
@@ -112,8 +165,10 @@ async def handler(req: web.Request):
112165
if os.environ.get("SSL_PEM") and os.environ.get("SSL_KEY"):
113166
ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
114167
ssl_context.check_hostname = False
115-
ssl_context.load_cert_chain(os.environ.get(
116-
"SSL_PEM"), os.environ.get("SSL_KEY"))
168+
ssl_context.load_cert_chain(
169+
os.environ.get("SSL_PEM"), os.environ.get("SSL_KEY")
170+
)
117171

118-
web.run_app(build_codec_server(), host="0.0.0.0",
119-
port=8081, ssl_context=ssl_context)
172+
web.run_app(
173+
build_codec_server(), host="0.0.0.0", port=8081, ssl_context=ssl_context
174+
)

Diff for: encryption_jwt/encryptor.py

+43-35
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,37 @@
1-
import os
21
import base64
32
import logging
4-
from temporalio import workflow
5-
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
3+
import os
4+
65
from botocore.exceptions import ClientError
6+
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
7+
from temporalio import workflow
78

89
with workflow.unsafe.imports_passed_through():
9-
import boto3
10+
import aioboto3
1011

1112

1213
class KMSEncryptor:
1314
"""Encrypts and decrypts using keys from AWS KMS."""
1415

1516
def __init__(self, namespace: str):
1617
self._namespace = namespace
17-
self._kms_client = None
18+
self._boto_session = None
1819

1920
@property
20-
def kms_client(self):
21+
def boto_session(self):
2122
"""Get a KMS client from boto3."""
22-
if not self._kms_client:
23-
self._kms_client = boto3.client("kms")
23+
if not self._boto_session:
24+
session = aioboto3.Session()
25+
self._boto_session = session
2426

25-
return self._kms_client
27+
return self._boto_session
2628

27-
def encrypt(self, data: bytes) -> tuple[bytes, bytes]:
29+
async def encrypt(self, data: bytes) -> tuple[bytes, bytes]:
2830
"""Encrypt data using a key from KMS."""
2931
# The keys are rotated automatically by KMS, so fetch a new key to encrypt the data.
30-
data_key_encrypted, data_key_plaintext = self.__create_data_key(self._namespace)
32+
data_key_encrypted, data_key_plaintext = await self.__create_data_key(
33+
self._namespace
34+
)
3135

3236
if data_key_encrypted is None:
3337
raise ValueError("No data key!")
@@ -38,38 +42,42 @@ def encrypt(self, data: bytes) -> tuple[bytes, bytes]:
3842
data_key_encrypted
3943
)
4044

41-
def decrypt(self, data_key_encrypted_base64, data: bytes) -> bytes:
45+
async def decrypt(self, data_key_encrypted_base64, data: bytes) -> bytes:
4246
"""Encrypt data using a key from KMS."""
4347
data_key_encrypted = base64.b64decode(data_key_encrypted_base64)
44-
data_key_plaintext = self.__decrypt_data_key(data_key_encrypted)
48+
data_key_plaintext = await self.__decrypt_data_key(data_key_encrypted)
4549
encryptor = AESGCM(data_key_plaintext)
4650
return encryptor.decrypt(data[:12], data[12:], None)
4751

48-
def __create_data_key(self, namespace: str):
52+
async def __create_data_key(self, namespace: str):
4953
"""Get a set of keys from AWS KMS that can be used to encrypt data."""
5054

5155
# Create data key
52-
alias_name = 'alias/' + namespace.replace('.', '_')
53-
response = self.kms_client.describe_key(KeyId=alias_name)
54-
cmk_id = response['KeyMetadata']['Arn']
55-
key_spec = "AES_256"
56-
try:
57-
response = self.kms_client.generate_data_key(KeyId=cmk_id, KeySpec=key_spec)
58-
except ClientError as e:
59-
logging.error(e)
60-
return None, None
61-
62-
# Return the encrypted and plaintext data key
63-
return response["CiphertextBlob"], response["Plaintext"]
64-
65-
def __decrypt_data_key(self, data_key_encrypted):
56+
alias_name = "alias/" + namespace.replace(".", "_")
57+
async with self.boto_session.client("kms") as kms_client:
58+
response = await kms_client.describe_key(KeyId=alias_name)
59+
cmk_id = response["KeyMetadata"]["Arn"]
60+
key_spec = "AES_256"
61+
try:
62+
response = await kms_client.generate_data_key(
63+
KeyId=cmk_id, KeySpec=key_spec
64+
)
65+
except ClientError as e:
66+
logging.error(e)
67+
return None, None
68+
69+
# Return the encrypted and plaintext data key
70+
return response["CiphertextBlob"], response["Plaintext"]
71+
72+
async def __decrypt_data_key(self, data_key_encrypted):
6673
"""Use AWS KMS to exchange an encrypted key for its plaintext value."""
6774

68-
# Decrypt the data key
69-
try:
70-
response = self.kms_client.decrypt(CiphertextBlob=data_key_encrypted)
71-
except ClientError as e:
72-
logging.error(e)
73-
return None
75+
async with self.boto_session.client("kms") as kms_client:
76+
# Decrypt the data key
77+
try:
78+
response = await kms_client.decrypt(CiphertextBlob=data_key_encrypted)
79+
except ClientError as e:
80+
logging.error(e)
81+
return None
7482

75-
return response["Plaintext"]
83+
return response["Plaintext"]

0 commit comments

Comments
 (0)