-
-
Notifications
You must be signed in to change notification settings - Fork 883
/
Copy pathconftest.py
295 lines (241 loc) · 8.63 KB
/
conftest.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
import asyncio
import json
import os
import threading
import time
import typing
import pytest
import trustme
from blockbuster import blockbuster_ctx
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.serialization import (
BestAvailableEncryption,
Encoding,
PrivateFormat,
load_pem_private_key,
)
from uvicorn.config import Config
from uvicorn.server import Server
import httpx
from tests.concurrency import sleep
# Environment variable names that the autouse ``clean_environ`` fixture strips
# from ``os.environ`` before each test, so every test starts without them set.
ENVIRONMENT_VARIABLES = {
    # SSL / certificate related
    "SSL_CERT_FILE",
    "SSL_CERT_DIR",
    "SSLKEYLOGFILE",
    # Proxy related
    "HTTP_PROXY",
    "HTTPS_PROXY",
    "ALL_PROXY",
    "NO_PROXY",
}
@pytest.fixture(autouse=True)
def blockbuster():
    """Run every test inside ``blockbuster_ctx()`` and yield the context object.

    ``os.stat`` is marked via ``can_block_in`` for calls made from ``init`` in
    ``mimetypes.py`` (presumably an allowed blocking call site — confirm against
    the blockbuster docs).
    """
    with blockbuster_ctx() as detector:
        stat_rule = detector.functions["os.stat"]
        stat_rule.can_block_in("/mimetypes.py", "init")
        yield detector
@pytest.fixture(scope="function", autouse=True)
def clean_environ():
    """Keeps os.environ clean for every test without having to mock os.environ.

    Before the test runs, every variable whose name matches an entry of
    ``ENVIRONMENT_VARIABLES`` case-insensitively (e.g. both ``HTTP_PROXY`` and
    ``http_proxy``) is removed from ``os.environ``. After the test, the
    environment is restored to exactly what it was before.
    """
    original_environ = os.environ.copy()
    os.environ.clear()
    os.environ.update(
        {
            k: v
            for k, v in original_environ.items()
            # ENVIRONMENT_VARIABLES holds upper-case names; normalising with
            # upper() also catches lower/mixed-case spellings such as
            # ``http_proxy``. It subsumes the exact-name check.
            if k.upper() not in ENVIRONMENT_VARIABLES
        }
    )
    yield
    os.environ.clear()
    os.environ.update(original_environ)
# Loose type aliases for the ASGI-style dicts and callables used by the test
# application's handlers.
Scope = typing.Dict[str, typing.Any]
Message = typing.Dict[str, typing.Any]
Receive = typing.Callable[[], typing.Awaitable[Message]]
Send = typing.Callable[[Message], typing.Coroutine[None, None, None]]
async def app(scope: Scope, receive: Receive, send: Send) -> None:
    """Application served by the test server.

    Only ``http`` scopes are accepted (asserted). The request is routed to the
    handler of the first path prefix that matches; any other path is answered
    by ``hello_world``.
    """
    assert scope["type"] == "http"
    path = scope["path"]
    # Checked in order; the first matching prefix wins.
    routes = (
        ("/slow_response", slow_response),
        ("/status", status_code),
        ("/echo_body", echo_body),
        ("/echo_binary", echo_binary),
        ("/echo_headers", echo_headers),
        ("/redirect_301", redirect_301),
        ("/json", hello_world_json),
    )
    handler = next(
        (route_handler for prefix, route_handler in routes if path.startswith(prefix)),
        hello_world,
    )
    await handler(scope, receive, send)
async def hello_world(scope: Scope, receive: Receive, send: Send) -> None:
    """Respond 200 with the plain-text body ``Hello, world!``."""
    start_message = {
        "type": "http.response.start",
        "status": 200,
        "headers": [[b"content-type", b"text/plain"]],
    }
    body_message = {"type": "http.response.body", "body": b"Hello, world!"}
    await send(start_message)
    await send(body_message)
async def hello_world_json(scope: Scope, receive: Receive, send: Send) -> None:
    """Respond 200 with the fixed JSON document ``{"Hello": "world!"}``."""
    start_message = {
        "type": "http.response.start",
        "status": 200,
        "headers": [[b"content-type", b"application/json"]],
    }
    body_message = {"type": "http.response.body", "body": b'{"Hello": "world!"}'}
    await send(start_message)
    await send(body_message)
async def slow_response(scope: Scope, receive: Receive, send: Send) -> None:
    """Send the 200 response headers, wait one second, then send the body.

    The pause between headers and body gives tests a window to trigger a
    read timeout.
    """
    start_message = {
        "type": "http.response.start",
        "status": 200,
        "headers": [[b"content-type", b"text/plain"]],
    }
    await send(start_message)
    await sleep(1.0)
    body_message = {"type": "http.response.body", "body": b"Hello, world!"}
    await send(body_message)
async def status_code(scope: Scope, receive: Receive, send: Send) -> None:
    """Reply with the HTTP status taken from the path, e.g. ``/status/404``.

    The status is whatever remains of the path after removing ``/status/``;
    ``int()`` raises ValueError if that remainder is not an integer.
    """
    requested_status = int(scope["path"].replace("/status/", ""))
    response_start = {
        "type": "http.response.start",
        "status": requested_status,
        "headers": [[b"content-type", b"text/plain"]],
    }
    await send(response_start)
    await send({"type": "http.response.body", "body": b"Hello, world!"})
async def echo_body(scope: Scope, receive: Receive, send: Send) -> None:
    """Read the complete request body and send it back as ``text/plain``.

    Messages are received until one arrives without a truthy ``more_body``;
    a message lacking ``body`` contributes no bytes.
    """
    chunks = []
    more_body = True
    while more_body:
        message = await receive()
        chunks.append(message.get("body", b""))
        more_body = message.get("more_body", False)
    # Join once at the end: ``body += chunk`` per message re-copies everything
    # accumulated so far, which is quadratic in the number of chunks.
    body = b"".join(chunks)
    await send(
        {
            "type": "http.response.start",
            "status": 200,
            "headers": [[b"content-type", b"text/plain"]],
        }
    )
    await send({"type": "http.response.body", "body": body})
async def echo_binary(scope: Scope, receive: Receive, send: Send) -> None:
    """Read the complete request body and send it back as ``application/octet-stream``.

    Messages are received until one arrives without a truthy ``more_body``;
    a message lacking ``body`` contributes no bytes.
    """
    chunks = []
    more_body = True
    while more_body:
        message = await receive()
        chunks.append(message.get("body", b""))
        more_body = message.get("more_body", False)
    # Join once at the end: ``body += chunk`` per message re-copies everything
    # accumulated so far, which is quadratic in the number of chunks.
    body = b"".join(chunks)
    await send(
        {
            "type": "http.response.start",
            "status": 200,
            "headers": [[b"content-type", b"application/octet-stream"]],
        }
    )
    await send({"type": "http.response.body", "body": body})
async def echo_headers(scope: Scope, receive: Receive, send: Send) -> None:
    """Respond with the request headers encoded as a JSON object.

    Each key is the raw header name passed through ``bytes.capitalize()``
    (so ``user-agent`` becomes ``User-agent``). When a header name repeats,
    only its last value is kept.
    """
    echoed = {}
    for raw_name, raw_value in scope.get("headers", []):
        key = raw_name.capitalize().decode()
        echoed[key] = raw_value.decode()
    await send(
        {
            "type": "http.response.start",
            "status": 200,
            "headers": [[b"content-type", b"application/json"]],
        }
    )
    payload = json.dumps(echoed).encode()
    await send({"type": "http.response.body", "body": payload})
async def redirect_301(scope: Scope, receive: Receive, send: Send) -> None:
    """Answer with status 301 and ``location: /``; the body message has no ``body`` key."""
    redirect_start = {
        "type": "http.response.start",
        "status": 301,
        "headers": [[b"location", b"/"]],
    }
    await send(redirect_start)
    await send({"type": "http.response.body"})
@pytest.fixture(scope="session")
def cert_authority():
    """Session-scoped ``trustme.CA`` used to issue the test certificates."""
    authority = trustme.CA()
    return authority
@pytest.fixture(scope="session")
def localhost_cert(cert_authority):
    """Certificate for the hostname ``localhost`` issued by ``cert_authority``."""
    issued = cert_authority.issue_cert("localhost")
    return issued
@pytest.fixture(scope="session")
def cert_pem_file(localhost_cert):
    """Yield a temporary file holding the first PEM of ``localhost_cert``'s chain.

    The file exists for the whole session, i.e. while the ``tempfile()``
    context stays open.
    """
    leading_pem = localhost_cert.cert_chain_pems[0]
    with leading_pem.tempfile() as pem_path:
        yield pem_path
@pytest.fixture(scope="session")
def cert_private_key_file(localhost_cert):
    """Yield a temporary file holding ``localhost_cert``'s unencrypted private key PEM."""
    key_blob = localhost_cert.private_key_pem
    with key_blob.tempfile() as key_path:
        yield key_path
@pytest.fixture(scope="session")
def cert_encrypted_private_key_file(localhost_cert):
    """Yield a temporary file with ``localhost_cert``'s key encrypted by ``b"password"``.

    The key is written as PEM in TraditionalOpenSSL format using
    ``BestAvailableEncryption``.
    """
    plain_pem = localhost_cert.private_key_pem.bytes()
    # Load the unencrypted key so it can be serialized again with a passphrase.
    key = load_pem_private_key(plain_pem, password=None, backend=default_backend())
    encrypted_pem = key.private_bytes(
        Encoding.PEM,
        PrivateFormat.TraditionalOpenSSL,
        BestAvailableEncryption(password=b"password"),
    )
    with trustme.Blob(encrypted_pem).tempfile() as key_path:
        yield key_path
class TestServer(Server):
    """uvicorn ``Server`` subclass used by the test suite.

    It is started with ``Server.run`` on a non-main thread, exposes its base
    URL, and can be restarted on request through ``restart()``.
    """

    @property
    def url(self) -> httpx.URL:
        """Base URL built from the config's scheme (http/https), host and port."""
        protocol = "https" if self.config.is_ssl else "http"
        host = self.config.host
        # An IPv6 literal such as "::1" must be bracketed inside a URL
        # authority; otherwise its colons are read as the port separator.
        if ":" in host and not host.startswith("["):
            host = f"[{host}]"
        return httpx.URL(f"{protocol}://{host}:{self.config.port}/")

    def install_signal_handlers(self) -> None:
        # Disable the default installation of handlers for signals such as SIGTERM,
        # because it can only be done in the main thread.
        pass  # pragma: nocover

    async def serve(self, sockets=None):
        """Run the server and the restart watcher concurrently until both finish."""
        self.restart_requested = asyncio.Event()
        # Inside a coroutine the running loop is the one to use; remember it
        # so restart() can hand work to it from another thread.
        loop = asyncio.get_running_loop()
        self._serve_loop = loop
        tasks = {
            loop.create_task(super().serve(sockets=sockets)),
            loop.create_task(self.watch_restarts()),
        }
        await asyncio.wait(tasks)

    async def restart(self) -> None:  # pragma: no cover
        # This coroutine may be called from a different thread than the one the
        # server is running on, and from an async environment that's not asyncio.
        # For this reason, we use an event to coordinate with the server
        # instead of calling shutdown()/startup() directly, and should not make
        # any asyncio-specific operations on the caller's side.
        # asyncio.Event is not thread-safe, so its set() is scheduled on the
        # server's own loop via call_soon_threadsafe() instead of being called
        # from this (possibly foreign) thread.
        self.started = False
        self._serve_loop.call_soon_threadsafe(self.restart_requested.set)
        while not self.started:
            await sleep(0.2)

    async def watch_restarts(self) -> None:  # pragma: no cover
        """Until ``should_exit`` is set, perform shutdown()/startup() on each request.

        The event is waited on with a 0.1 s timeout so ``should_exit`` is
        re-checked regularly.
        """
        while True:
            if self.should_exit:
                return
            try:
                await asyncio.wait_for(self.restart_requested.wait(), timeout=0.1)
            except asyncio.TimeoutError:
                continue
            self.restart_requested.clear()
            await self.shutdown()
            await self.startup()
def serve_in_thread(server: TestServer) -> typing.Iterator[TestServer]:
    """Run ``server.run`` on a background thread while the generator is active.

    Yields ``server`` once ``server.started`` is true. When the generator is
    closed (normally or via an exception), it sets ``server.should_exit`` and
    joins the thread.

    Raises:
        RuntimeError: if the thread finishes before the server reports that it
            started (for example because startup failed), instead of polling
            forever.
    """
    thread = threading.Thread(target=server.run)
    thread.start()
    try:
        while not server.started:
            # Re-check ``started`` after is_alive() so a server that started
            # and then exited between the two reads is not misreported.
            if not thread.is_alive() and not server.started:
                raise RuntimeError("Server thread exited before the server started.")
            time.sleep(1e-3)
        yield server
    finally:
        server.should_exit = True
        thread.join()
@pytest.fixture(scope="session")
def server() -> typing.Iterator[TestServer]:
    """Session-wide ``TestServer`` for ``app`` (lifespan off, asyncio loop), run in a thread."""
    test_server = TestServer(
        config=Config(app=app, lifespan="off", loop="asyncio"),
    )
    yield from serve_in_thread(test_server)