Skip to content

Commit

Permalink
Add serverCertificateHashes test server
Browse files Browse the repository at this point in the history
Add a second webtransport server, in order to test connection with a
server, that has a self-signed certificate together with
serverCertificateHashes.

See web-platform-tests/rfcs#216
  • Loading branch information
jgraham committed Jan 24, 2025
1 parent d852a80 commit c75a377
Show file tree
Hide file tree
Showing 7 changed files with 130 additions and 12 deletions.
33 changes: 29 additions & 4 deletions tools/serve/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import traceback
import urllib
import uuid
import datetime
from collections import defaultdict, OrderedDict
from io import IOBase
from itertools import chain, product
Expand Down Expand Up @@ -991,7 +992,7 @@ def start_servers(logger, host, ports, paths, routes, bind_address, config,
continue

# Skip WebTransport over HTTP/3 server unless if is enabled explicitly.
if scheme == 'webtransport-h3' and not kwargs.get("webtransport_h3"):
if scheme in ['webtransport-h3', 'webtransport-h3-cert-hash'] and not kwargs.get("webtransport_h3"):
continue

for port in ports:
Expand All @@ -1009,6 +1010,7 @@ def start_servers(logger, host, ports, paths, routes, bind_address, config,
"ws": start_ws_server,
"wss": start_wss_server,
"webtransport-h3": start_webtransport_h3_server,
"webtransport-h3-cert-hash": start_webtransport_h3_server_cert_hash,
}[scheme]

server_proc = ServerProc(mp_context, scheme=scheme)
Expand Down Expand Up @@ -1174,18 +1176,40 @@ def start_webtransport_h3_server(logger, host, port, paths, routes, bind_address
try:
# TODO(bashi): Move the following import to the beginning of this file
# once WebTransportH3Server is enabled by default.
from webtransport.h3.webtransport_h3_server import WebTransportH3Server # type: ignore
from webtransport.h3.webtransport_h3_server import WebTransportH3Server, WebTransportCertificateGeneration # type: ignore
return WebTransportH3Server(host=host,
port=port,
doc_root=paths["doc_root"],
cert_mode=WebTransportCertificateGeneration.USEPREGENERATED,
cert_path=config.ssl_config["cert_path"],
key_path=config.ssl_config["key_path"],
logger=logger)
logger=logger,
cert_hash_info=None
)
except Exception as error:
logger.critical(
f"Failed to start WebTransport over HTTP/3 server: {error}")
sys.exit(0)

def start_webtransport_h3_server_cert_hash(logger, host, port, paths, routes, bind_address, config, **kwargs):
try:
# TODO(bashi): Move the following import to the beginning of this file
# once WebTransportH3Server is enabled by default.
from webtransport.h3.webtransport_h3_server import WebTransportH3Server, WebTransportCertificateGeneration
return WebTransportH3Server(host=host,
port=port,
doc_root=paths["doc_root"],
cert_mode=WebTransportCertificateGeneration.GENERATEDVALIDSERVERCERTIFICATEHASHCERT,
cert_path=None,
key_path=None,
logger=logger,
cert_hash_info=config["cert_hash_info"]
)
except Exception as error:
logger.critical(
f"Failed to start WebTransport over HTTP/3 server with certificate hashes: {error}")
sys.exit(0)


def start(logger, config, routes, mp_context, log_handlers, **kwargs):
host = config["server_host"]
Expand Down Expand Up @@ -1249,6 +1273,7 @@ class ConfigBuilder(config.ConfigBuilder):
"ws": ["auto"],
"wss": ["auto"],
"webtransport-h3": ["auto"],
"webtransport-h3-cert-hash": ["auto"],
},
"check_subdomains": True,
"bind_address": True,
Expand Down Expand Up @@ -1372,7 +1397,7 @@ def get_parser():
parser.add_argument("--no-h2", action="store_false", dest="h2", default=None,
help="Disable the HTTP/2.0 server")
parser.add_argument("--webtransport-h3", action="store_true",
help="Enable WebTransport over HTTP/3 server")
help="Enable WebTransport over HTTP/3 servers")
parser.add_argument("--exit-after-start", action="store_true",
help="Exit after starting servers")
parser.add_argument("--verbose", action="store_true", help="Enable verbose logging")
Expand Down
48 changes: 42 additions & 6 deletions tools/webtransport/h3/webtransport_h3_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,13 @@
import sys
import threading
import traceback
from enum import IntEnum
from enum import IntEnum, Enum
from urllib.parse import urlparse
from typing import Any, Dict, List, Optional, Tuple, cast

from cryptography import x509
from cryptography.hazmat.primitives import serialization

# TODO(bashi): Remove import check suppressions once aioquic dependency is resolved.
from aioquic.buffer import Buffer # type: ignore
from aioquic.asyncio import QuicConnectionProtocol, serve # type: ignore
Expand All @@ -31,6 +34,7 @@
from tools import localpaths # noqa: F401
from wptserve import stash
from .capsule import H3Capsule, H3CapsuleDecoder, CapsuleType
from http.server import BaseHTTPRequestHandler, HTTPServer

"""
A WebTransport over HTTP/3 server for testing.
Expand Down Expand Up @@ -499,6 +503,16 @@ def add(self, ticket: SessionTicket) -> None:
def pop(self, label: bytes) -> Optional[SessionTicket]:
return self.tickets.pop(label, None)

class WebTransportCertificateGeneration(Enum):
"""
Specify, if the server should generate a certificate or use an existing certificate
USEPREGENERATED: use existing certificate
GENERATEDVALIDSERVERCERTIFICATEHASHCERT: generate a certificate compatible to server cert hashes
"""
USEPREGENERATED = 1,
GENERATEDVALIDSERVERCERTIFICATEHASHCERT = 2
# TODO add cases for invalid certificates


class WebTransportH3Server:
"""
Expand All @@ -507,18 +521,31 @@ class WebTransportH3Server:
:param host: Host from which to serve.
:param port: Port from which to serve.
:param doc_root: Document root for serving handlers.
:paran cert_mode: The used certificate mode can be
USEPREGENERATED or GENERATEDVALIDSERVERCERTIFICATEHASHCERT
:param cert_path: Path to certificate file to use.
:param key_path: Path to key file to use.
:param logger: a Logger object for this server.
"""

def __init__(self, host: str, port: int, doc_root: str, cert_path: str,
key_path: str, logger: Optional[logging.Logger]) -> None:
def __init__(self, host: str, port: int, doc_root: str, cert_mode: WebTransportCertificateGeneration,
cert_path: Optional[str], key_path: Optional[str], logger: Optional[logging.Logger],
cert_hash_info: Optional[Dict]) -> None:
self.host = host
self.port = port
self.doc_root = doc_root
self.cert_path = cert_path
self.key_path = key_path
if cert_path is not None:
self.cert_path = cert_path
if key_path is not None:
self.key_path = key_path
if cert_hash_info is not None:
self.cert_hash_info = cert_hash_info
self.cert_mode = cert_mode
if (cert_path is None or key_path is None) and cert_mode == WebTransportCertificateGeneration.USEPREGENERATED:
raise ValueError("Both cert_path and key_path must be provided, if cert_mode is USEPREGENERATED")
if (cert_hash_info is None and cert_mode == WebTransportCertificateGeneration.GENERATEDVALIDSERVERCERTIFICATEHASHCERT):
raise ValueError("cert_hash_info must be provided, if cert_mode is GENERATEDVALIDSERVERCERTIFICATEHASHCERT")

self.started = False
global _doc_root
_doc_root = self.doc_root
Expand Down Expand Up @@ -551,7 +578,16 @@ def _start_on_server_thread(self) -> None:
_logger.info("Starting WebTransport over HTTP/3 server on %s:%s",
self.host, self.port)

configuration.load_cert_chain(self.cert_path, self.key_path)
if self.cert_mode == WebTransportCertificateGeneration.USEPREGENERATED:
configuration.load_cert_chain(self.cert_path, self.key_path)
else: # GENERATEDVALIDSERVERCERTIFICATEHASHCERT case
configuration.private_key = serialization.load_pem_private_key(self.cert_hash_info["private_key"],
password=None
)
configuration.certificate = x509.load_pem_x509_certificate(self.cert_hash_info["certificate"])
configuration.certificate_chain = []



ticket_store = SessionTicketStore()

Expand Down
1 change: 1 addition & 0 deletions tools/webtransport/requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
aioquic==1.2.0
cryptography
48 changes: 47 additions & 1 deletion tools/wptrunner/wptrunner/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,14 @@
import socket
import sys
import time
import datetime
from typing import Optional

from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.x509.oid import NameOID
from cryptography import x509

import mozprocess
from mozlog import get_default_logger, handlers
from mozlog.structuredlog import StructuredLogger
Expand Down Expand Up @@ -46,6 +52,37 @@ def do_delayed_imports(logger, test_paths):
(", ".join(failed), serve_root))
sys.exit(1)

def generate_hash_certificate(host: str) -> str:
private_key = ec.generate_private_key(ec.SECP256R1())
subject = issuer = x509.Name([
x509.NameAttribute(NameOID.COUNTRY_NAME, "DE"),
x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, "Berlin"),
x509.NameAttribute(NameOID.LOCALITY_NAME, "Berlin"),
x509.NameAttribute(NameOID.ORGANIZATION_NAME, "Wpt tests"),
x509.NameAttribute(NameOID.COMMON_NAME, host),
])
now = datetime.datetime.now(datetime.timezone.utc)
certificate = (
x509.CertificateBuilder()
.subject_name(subject)
.issuer_name(issuer)
.public_key(private_key.public_key())
.serial_number(x509.random_serial_number())
.not_valid_before(now)
.not_valid_after(now + datetime.timedelta(days=13))
.sign(private_key, hashes.SHA256())
)
fingerprint = certificate.fingerprint(hashes.SHA256())
server_certificate_hash = ":".join(f"{byte:02x}" for byte in fingerprint)
return { "certificate": certificate.public_bytes(
encoding=serialization.Encoding.PEM
),
"private_key": private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.NoEncryption()),
"hash": server_certificate_hash
}

def serve_path(test_paths):
return test_paths["/"].tests_path
Expand Down Expand Up @@ -150,7 +187,8 @@ def __enter__(self):
self.get_routes(),
mp_context=mpcontext.get_context(),
log_handlers=[server_log_handler],
webtransport_h3=self.enable_webtransport)
webtransport_h3=self.enable_webtransport,
webtransport_h3_cert_hash=self.enable_webtransport)

if self.options.get("supports_debugger") and self.debug_info and self.debug_info.interactive:
self._stack.enter_context(self.ignore_interrupts())
Expand Down Expand Up @@ -197,6 +235,7 @@ def build_config(self):
"wss": [8889],
"h2": [9000],
"webtransport-h3": [11000],
"webtransport-h3-cert-hash": [11001],
}
config.ports = ports

Expand All @@ -221,6 +260,8 @@ def build_config(self):
config.doc_root = serve_path(self.test_paths)
config.inject_script = self.inject_script

config.cert_hash_info = generate_hash_certificate(config.server_host)

if self.suppress_handler_traceback is not None:
config.logging["suppress_handler_traceback"] = self.suppress_handler_traceback

Expand Down Expand Up @@ -323,10 +364,15 @@ def test_servers(self):
for port, server in self.servers.get("webtransport-h3", []):
if not webtranport_h3_server_is_running(host, port, timeout=5):
pending.append((host, port))
for port, server in self.servers.get("webtransport-h3-cert-hash", []):
if not webtranport_h3_server_is_running(host, port, timeout=5):
pending.append((host, port))

for scheme, servers in self.servers.items():
if scheme == "webtransport-h3":
continue
if scheme == "webtransport-h3-cert-hash":
continue
for port, server in servers:
s = socket.socket()
s.settimeout(0.1)
Expand Down
1 change: 1 addition & 0 deletions tools/wptserve/wptserve/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ class ConfigBuilder:

_default = {
"browser_host": "localhost",
"certificate_hash": {},
"alternate_hosts": {},
"doc_root": os.path.dirname("__file__"),
"server_host": None,
Expand Down
2 changes: 2 additions & 0 deletions tools/wptserve/wptserve/pipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,8 @@ def config_replacement(match):
value = variables[field]
elif hasattr(SubFunctions, field):
value = getattr(SubFunctions, field)
elif field == "server_certificate_hash":
value = request.server.config["cert_hash_info"]["hash"]
elif field == "headers":
value = request.headers
elif field == "GET":
Expand Down
9 changes: 8 additions & 1 deletion webtransport/resources/webtransport-test-helpers.sub.js
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,23 @@

const HOST = get_host_info().ORIGINAL_HOST;
const PORT = '{{ports[webtransport-h3][0]}}';
const PORT_CERT_HASH = '{{ports[webtransport-h3-cert-hash][0]}}';
const BASE = `https://${HOST}:${PORT}`;
const BASE_CERT_HASH = `https://${HOST}:${PORT_CERT_HASH}`;

// Wait for the given number of milliseconds (ms).
function wait(ms) { return new Promise(res => step_timeout(res, ms)); }

// Create URL for WebTransport session.
function webtransport_url(handler) {
function webtransport_url(handler, options) {
if (options?.cert_hashes) {
return `${BASE_CERT_HASH}/webtransport/handlers/${handler}`;
}
return `${BASE}/webtransport/handlers/${handler}`;
}

const cert_hash = new Uint8Array('{{server_certificate_hash}}'.split(':').map((el) => parseInt(el, 16)));
const cert_hash_str = '{{server_certificate_hash}}'
// Converts WebTransport stream error code to HTTP/3 error code.
// https://ietf-wg-webtrans.github.io/draft-ietf-webtrans-http3/draft-ietf-webtrans-http3.html#section-4.3
function webtransport_code_to_http_code(n) {
Expand Down

0 comments on commit c75a377

Please sign in to comment.