-
Notifications
You must be signed in to change notification settings - Fork 65
/
Copy pathcodec.py
37 lines (28 loc) · 1.4 KB
/
codec.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
from typing import Iterable, List
from temporalio.api.common.v1 import Payload
from temporalio.converter import PayloadCodec
from encryption_jwt.encryptor import KMSEncryptor
class EncryptionCodec(PayloadCodec):
    """Temporal payload codec that encrypts payloads with ``KMSEncryptor``.

    ``encode`` serializes each whole ``Payload`` (metadata and data) and
    encrypts it. It then wraps the ciphertext in a new ``Payload`` tagged
    ``encoding: binary/encrypted``, which also carries the encrypted data key
    returned by the encryptor under ``data_key_encrypted``.

    ``decode`` reverses this for payloads carrying that tag. It returns every
    other payload unchanged, so payloads this codec did not produce are never
    sent to the decryptor.
    """

    # Metadata markers shared by encode() and decode() so the two cannot drift.
    _ENCODING = b"binary/encrypted"
    _KEY_METADATA = "data_key_encrypted"

    def __init__(self, namespace: str):
        """Create a codec backed by a ``KMSEncryptor`` for *namespace*.

        Args:
            namespace: Passed straight to ``KMSEncryptor``. Presumably it
                selects the key/configuration used; confirm against
                KMSEncryptor.
        """
        self._encryptor = KMSEncryptor(namespace)

    async def encode(self, payloads: Iterable[Payload]) -> List[Payload]:
        """Encrypt every payload unconditionally.

        Payloads are encrypted one at a time, in input order (each
        ``encrypt`` call is awaited before the next starts).

        Args:
            payloads: Payloads to encrypt.

        Returns:
            New payloads, in input order. Each has ``encoding`` set to
            ``binary/encrypted``, the encrypted data key under
            ``data_key_encrypted``, and the ciphertext as ``data``.
        """
        encoded: List[Payload] = []
        for p in payloads:
            # Serialize the entire Payload so its original metadata
            # (e.g. its own "encoding") survives the round trip.
            ciphertext, encrypted_key = await self._encryptor.encrypt(
                p.SerializeToString()
            )
            encoded.append(
                Payload(
                    metadata={
                        "encoding": self._ENCODING,
                        self._KEY_METADATA: encrypted_key,
                    },
                    data=ciphertext,
                )
            )
        return encoded

    async def decode(self, payloads: Iterable[Payload]) -> List[Payload]:
        """Decrypt payloads produced by ``encode``; pass all others through.

        Payloads are decrypted one at a time, in input order.

        Args:
            payloads: Payloads to decode. Payloads whose ``encoding`` metadata
                is not ``binary/encrypted`` are returned unchanged.

        Returns:
            Decoded payloads, in input order.

        Raises:
            ValueError: A payload is tagged ``binary/encrypted`` but carries no
                ``data_key_encrypted`` metadata.
        """
        decoded: List[Payload] = []
        for p in payloads:
            # Only payloads written by this codec can be decrypted. Anything
            # else (unencrypted history, other codecs' output) is left as-is
            # instead of being fed to the decryptor and failing obscurely.
            if p.metadata.get("encoding") != self._ENCODING:
                decoded.append(p)
                continue
            encrypted_key = p.metadata.get(self._KEY_METADATA)
            if not encrypted_key:
                # Previously an empty key was passed to decrypt(); fail
                # loudly with a clear message instead.
                raise ValueError(
                    f"Encrypted payload is missing '{self._KEY_METADATA}' metadata"
                )
            plaintext = await self._encryptor.decrypt(encrypted_key, p.data)
            # encode() encrypted a serialized Payload, so parse it back.
            decoded.append(Payload.FromString(plaintext))
        return decoded