|
1 |
| -import typing |
| 1 | +from typing import Optional, Sequence, Literal |
2 | 2 |
|
3 |
| -from urllib.parse import urlparse |
| 3 | +from mods.origins import compute_local_origins, normalize_origins |
4 | 4 | from starlette.datastructures import Headers
|
5 | 5 | from starlette.responses import PlainTextResponse
|
6 | 6 | from starlette.types import ASGIApp, Receive, Scope, Send
|
7 | 7 |
|
8 |
| -ENFORCE_URL_ORIGIN_FORMAT = "Input origins must be well-formed URLs, i.e. https://google.com or https://www.google.com." |
9 |
| - |
10 | 8 |
|
11 | 9 | class TrustedOriginMiddleware:
|
12 | 10 | def __init__(
|
13 | 11 | self,
|
14 | 12 | app: ASGIApp,
|
15 |
| - allowed_origins: typing.Optional[typing.Sequence[str]] = None, |
16 |
| - port: typing.Optional[int] = None, |
| 13 | + allowed_origins: Optional[Sequence[str]] = None, |
| 14 | + port: Optional[int] = None, |
17 | 15 | ) -> None:
|
18 |
| - schemas = ['http', 'https'] |
19 |
| - local_origins = [f'{schema}://{origin}' for schema in schemas for origin in ['127.0.0.1', 'localhost']] |
20 |
| - if port is not None: |
21 |
| - local_origins = [f'{origin}:{port}' for origin in local_origins] |
22 |
| - |
23 | 16 | self.allowed_origins: set[str] = set()
|
24 |
| - if allowed_origins is not None: |
25 |
| - for origin in allowed_origins: |
26 |
| - url = urlparse(origin) |
27 |
| - assert url.scheme, ENFORCE_URL_ORIGIN_FORMAT |
28 |
| - valid_origin = f'{url.scheme}://{url.hostname}' |
29 |
| - if url.port: |
30 |
| - valid_origin += f':{url.port}' |
31 |
| - self.allowed_origins.add(valid_origin) |
| 17 | + |
| 18 | + local_origins = compute_local_origins(port) |
32 | 19 | self.allowed_origins.update(local_origins)
|
| 20 | + |
| 21 | + if allowed_origins is not None: |
| 22 | + normalized_origins = normalize_origins(allowed_origins) |
| 23 | + self.allowed_origins.update(normalized_origins) |
| 24 | + |
33 | 25 | self.app = app
|
34 | 26 |
|
35 | 27 | async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
|
0 commit comments