Skip to content

Commit 16bccc2

Browse files
feat(ra-tls): add functions for attestation-based TLS certs
1 parent c665e92 commit 16bccc2

File tree

2 files changed

+237
-1
lines changed

2 files changed

+237
-1
lines changed

src/flare_ai_kit/tee/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,10 @@
11
from .attestation import VtpmAttestation
2+
from .ra_tls import create_ssl_context, generate_self_signed_cert
23
from .validation import VtpmValidation
34

4-
__all__ = ["VtpmAttestation", "VtpmValidation"]
5+
__all__ = [
6+
"VtpmAttestation",
7+
"VtpmValidation",
8+
"create_ssl_context",
9+
"generate_self_signed_cert",
10+
]

src/flare_ai_kit/tee/ra_tls.py

Lines changed: 230 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,230 @@
1+
# flare-ai-kit/tee/ra-tls.py
2+
import datetime
3+
import os
4+
import ssl
5+
import tempfile
6+
7+
from cryptography import x509
8+
from cryptography.hazmat.primitives import hashes, serialization
9+
from cryptography.hazmat.primitives.asymmetric import rsa
10+
from cryptography.x509.oid import NameOID
11+
12+
# Custom OID for the attestation token extension
13+
# ATTESTATION_TOKEN_OID = ObjectIdentifier("1.3.6.1.4.1.9999.1.1") # Registered or private OID
14+
15+
16+
def create_attestation_extension(attestation_token: str) -> x509.SubjectAlternativeName:
17+
"""
18+
Create a SubjectAlternativeName extension containing the attestation token as a URI.
19+
20+
Args:
21+
attestation_token: The attestation token to embed.
22+
23+
Returns:
24+
x509.SubjectAlternativeName: The extension with the encoded token as a URI.
25+
26+
"""
27+
token_uri = f"attestation:{attestation_token}"
28+
return x509.SubjectAlternativeName([x509.UniformResourceIdentifier(token_uri)])
29+
30+
31+
def generate_key_and_csr(
32+
attestation_token: str, common_name: str = "localhost"
33+
) -> tuple[bytes, bytes]:
34+
"""
35+
Generate a new private key and CSR with an attestation token extension, all in memory.
36+
37+
Args:
38+
attestation_token: The attestation token to embed in the CSR.
39+
common_name: The Common Name (CN) for the certificate subject.
40+
41+
Returns:
42+
tuple[bytes, bytes]: PEM-encoded private key and CSR.
43+
44+
"""
45+
# Generate private key
46+
private_key = rsa.generate_private_key(
47+
public_exponent=65537,
48+
key_size=2048,
49+
)
50+
51+
# Create subject name
52+
subject = x509.Name(
53+
[
54+
x509.NameAttribute(NameOID.COMMON_NAME, common_name),
55+
]
56+
)
57+
58+
# Build CSR
59+
csr = (
60+
x509.CertificateSigningRequestBuilder()
61+
.subject_name(subject)
62+
.add_extension(create_attestation_extension(attestation_token), critical=False)
63+
.sign(private_key, hashes.SHA256())
64+
)
65+
66+
# Serialize to PEM
67+
key_pem = private_key.private_bytes(
68+
encoding=serialization.Encoding.PEM,
69+
format=serialization.PrivateFormat.TraditionalOpenSSL,
70+
encryption_algorithm=serialization.NoEncryption(),
71+
)
72+
csr_pem = csr.public_bytes(serialization.Encoding.PEM)
73+
74+
return key_pem, csr_pem
75+
76+
77+
def generate_self_signed_cert(token, common_name, days_valid):
78+
# logger.info("Generating self-signed certificate with SAN")
79+
try:
80+
# Generate private key
81+
private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
82+
public_key = private_key.public_key()
83+
84+
# Create certificate builder
85+
builder = x509.CertificateBuilder()
86+
builder = builder.subject_name(
87+
x509.Name(
88+
[
89+
x509.NameAttribute(NameOID.COMMON_NAME, common_name),
90+
]
91+
)
92+
)
93+
builder = builder.issuer_name(
94+
x509.Name(
95+
[
96+
x509.NameAttribute(NameOID.COMMON_NAME, common_name),
97+
]
98+
)
99+
)
100+
builder = builder.not_valid_before(datetime.datetime.utcnow())
101+
builder = builder.not_valid_after(
102+
datetime.datetime.utcnow() + datetime.timedelta(days=days_valid)
103+
)
104+
builder = builder.serial_number(x509.random_serial_number())
105+
builder = builder.public_key(public_key)
106+
107+
# Add SAN for localhost
108+
builder = builder.add_extension(
109+
x509.SubjectAlternativeName([x509.DNSName("localhost")]),
110+
critical=False,
111+
)
112+
113+
# Sign certificate
114+
certificate = builder.sign(private_key=private_key, algorithm=hashes.SHA256())
115+
116+
# Serialize to PEM
117+
cert_pem = certificate.public_bytes(serialization.Encoding.PEM)
118+
key_pem = private_key.private_bytes(
119+
encoding=serialization.Encoding.PEM,
120+
format=serialization.PrivateFormat.TraditionalOpenSSL,
121+
encryption_algorithm=serialization.NoEncryption(),
122+
)
123+
# logger.info("Certificate generated with SAN: DNS=localhost")
124+
return key_pem, cert_pem
125+
except Exception:
126+
# logger.error("Failed to generate certificate", exc_info=e)
127+
raise
128+
129+
130+
def generate_self_signed_cert_OLD(
131+
attestation_token: str, common_name: str = "localhost", days_valid: int = 365
132+
) -> tuple[bytes, bytes]:
133+
"""
134+
Generate a new private key and self-signed certificate with an attestation token extension, all in memory.
135+
136+
Args:
137+
attestation_token: The attestation token to embed in the certificate.
138+
common_name: The Common Name (CN) for the certificate subject.
139+
days_valid: Validity period of the certificate in days.
140+
141+
Returns:
142+
tuple[bytes, bytes]: PEM-encoded private key and certificate.
143+
144+
"""
145+
# Generate private key
146+
private_key = rsa.generate_private_key(
147+
public_exponent=65537,
148+
key_size=2048,
149+
)
150+
151+
# Create subject and issuer (same for self-signed)
152+
subject = issuer = x509.Name(
153+
[
154+
x509.NameAttribute(NameOID.COMMON_NAME, common_name),
155+
]
156+
)
157+
158+
# Build certificate
159+
cert = (
160+
x509.CertificateBuilder()
161+
.subject_name(subject)
162+
.issuer_name(issuer)
163+
.public_key(private_key.public_key())
164+
.serial_number(x509.random_serial_number())
165+
.not_valid_before(datetime.datetime.utcnow())
166+
.not_valid_after(
167+
datetime.datetime.utcnow() + datetime.timedelta(days=days_valid)
168+
)
169+
.add_extension(
170+
x509.SubjectAlternativeName([x509.DNSName("localhost")]), critical=False
171+
)
172+
.add_extension(create_attestation_extension(attestation_token), critical=False)
173+
.sign(private_key, hashes.SHA256())
174+
)
175+
176+
# Serialize to PEM
177+
key_pem = private_key.private_bytes(
178+
encoding=serialization.Encoding.PEM,
179+
format=serialization.PrivateFormat.TraditionalOpenSSL,
180+
encryption_algorithm=serialization.NoEncryption(),
181+
)
182+
cert_pem = cert.public_bytes(serialization.Encoding.PEM)
183+
184+
return key_pem, cert_pem
185+
186+
187+
def create_ssl_context(
188+
attestation_token: str, common_name: str = "localhost", days_valid: int = 365
189+
) -> ssl.SSLContext:
190+
"""
191+
Create an SSLContext with an in-memory self-signed certificate containing the attestation token.
192+
193+
Args:
194+
attestation_token: The attestation token to embed.
195+
common_name: The Common Name (CN) for the certificate.
196+
days_valid: Validity period of the certificate in days.
197+
198+
Returns:
199+
ssl.SSLContext: Configured SSL context for server-side TLS.
200+
201+
"""
202+
key_pem, cert_pem = generate_self_signed_cert(
203+
attestation_token, common_name, days_valid
204+
)
205+
206+
# Create SSL context
207+
context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
208+
context.minimum_version = ssl.TLSVersion.TLSv1_2
209+
context.maximum_version = ssl.TLSVersion.TLSv1_3
210+
211+
# Use temporary files to load certificate and key
212+
cert_file = tempfile.NamedTemporaryFile(delete=False, suffix=".pem")
213+
key_file = tempfile.NamedTemporaryFile(delete=False, suffix=".pem")
214+
215+
try:
216+
cert_file.write(cert_pem)
217+
cert_file.flush()
218+
key_file.write(key_pem)
219+
key_file.flush()
220+
221+
# Load certificate and key into the SSL context
222+
context.load_cert_chain(certfile=cert_file.name, keyfile=key_file.name)
223+
finally:
224+
# Clean up temporary files
225+
cert_file.close()
226+
key_file.close()
227+
os.unlink(cert_file.name)
228+
os.unlink(key_file.name)
229+
230+
return context

0 commit comments

Comments
 (0)