Skip to content

Commit 8dd8d71

Browse files
committed
Refactor and add origin check to SIO
1 parent ce9b599 commit 8dd8d71

File tree

6 files changed

+64
-28
lines changed

6 files changed

+64
-28
lines changed

server/MMVCServerSIO.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -140,8 +140,8 @@ def localServer(logLevel: str = "critical", key_path: str | None = None, cert_pa
140140
mp.freeze_support()
141141

142142
voiceChangerManager = VoiceChangerManager.get_instance(voiceChangerParams)
143-
app_fastapi = MMVC_Rest.get_instance(voiceChangerManager, voiceChangerParams, PORT, args.allowed_origins)
144-
app_socketio = MMVC_SocketIOApp.get_instance(app_fastapi, voiceChangerManager)
143+
app_fastapi = MMVC_Rest.get_instance(voiceChangerManager, voiceChangerParams, args.allowed_origins, PORT)
144+
app_socketio = MMVC_SocketIOApp.get_instance(app_fastapi, voiceChangerManager, args.allowed_origins, PORT)
145145

146146

147147
if __name__ == "__mp_main__":

server/mods/origins.py

+24
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from typing import Optional, Sequence
2+
from urllib.parse import urlparse
3+
4+
ENFORCE_URL_ORIGIN_FORMAT = "Input origins must be well-formed URLs, i.e. https://google.com or https://www.google.com."
5+
SCHEMAS = ('http', 'https')
6+
LOCAL_ORIGINS = ('127.0.0.1', 'localhost')
7+
8+
def compute_local_origins(port: Optional[int] = None) -> list[str]:
9+
local_origins = [f'{schema}://{origin}' for schema in SCHEMAS for origin in LOCAL_ORIGINS]
10+
if port is not None:
11+
local_origins = [f'{origin}:{port}' for origin in local_origins]
12+
return local_origins
13+
14+
15+
def normalize_origins(origins: Sequence[str]) -> set[str]:
16+
allowed_origins = set()
17+
for origin in origins:
18+
url = urlparse(origin)
19+
assert url.scheme, ENFORCE_URL_ORIGIN_FORMAT
20+
valid_origin = f'{url.scheme}://{url.hostname}'
21+
if url.port:
22+
valid_origin += f':{url.port}'
23+
allowed_origins.add(valid_origin)
24+
return allowed_origins

server/restapi/MMVC_Rest.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from fastapi.routing import APIRoute
77
from fastapi.staticfiles import StaticFiles
88
from fastapi.exceptions import RequestValidationError
9-
from typing import Callable
9+
from typing import Callable, Optional, Sequence, Literal
1010
from mods.log_control import VoiceChangaerLogger
1111
from voice_changer.VoiceChangerManager import VoiceChangerManager
1212

@@ -43,8 +43,8 @@ def get_instance(
4343
cls,
4444
voiceChangerManager: VoiceChangerManager,
4545
voiceChangerParams: VoiceChangerParams,
46-
port: int,
47-
allowedOrigins: list[str],
46+
allowedOrigins: Optional[Sequence[str]] = None,
47+
port: Optional[int] = None,
4848
):
4949
if cls._instance is None:
5050
logger.info("[Voice Changer] MMVC_Rest initializing...")

server/restapi/mods/trustedorigin.py

+11-19
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,27 @@
1-
import typing
1+
from typing import Optional, Sequence, Literal
22

3-
from urllib.parse import urlparse
3+
from mods.origins import compute_local_origins, normalize_origins
44
from starlette.datastructures import Headers
55
from starlette.responses import PlainTextResponse
66
from starlette.types import ASGIApp, Receive, Scope, Send
77

8-
ENFORCE_URL_ORIGIN_FORMAT = "Input origins must be well-formed URLs, i.e. https://google.com or https://www.google.com."
9-
108

119
class TrustedOriginMiddleware:
1210
def __init__(
1311
self,
1412
app: ASGIApp,
15-
allowed_origins: typing.Optional[typing.Sequence[str]] = None,
16-
port: typing.Optional[int] = None,
13+
allowed_origins: Optional[Sequence[str]] = None,
14+
port: Optional[int] = None,
1715
) -> None:
18-
schemas = ['http', 'https']
19-
local_origins = [f'{schema}://{origin}' for schema in schemas for origin in ['127.0.0.1', 'localhost']]
20-
if port is not None:
21-
local_origins = [f'{origin}:{port}' for origin in local_origins]
22-
2316
self.allowed_origins: set[str] = set()
24-
if allowed_origins is not None:
25-
for origin in allowed_origins:
26-
url = urlparse(origin)
27-
assert url.scheme, ENFORCE_URL_ORIGIN_FORMAT
28-
valid_origin = f'{url.scheme}://{url.hostname}'
29-
if url.port:
30-
valid_origin += f':{url.port}'
31-
self.allowed_origins.add(valid_origin)
17+
18+
local_origins = compute_local_origins(port)
3219
self.allowed_origins.update(local_origins)
20+
21+
if allowed_origins is not None:
22+
normalized_origins = normalize_origins(allowed_origins)
23+
self.allowed_origins.update(normalized_origins)
24+
3325
self.app = app
3426

3527
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:

server/sio/MMVC_SocketIOApp.py

+18-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import socketio
22
from mods.log_control import VoiceChangaerLogger
3+
from mods.origins import compute_local_origins, normalize_origins
34

5+
from typing import Sequence, Optional
46
from sio.MMVC_SocketIOServer import MMVC_SocketIOServer
57
from voice_changer.VoiceChangerManager import VoiceChangerManager
68
from const import getFrontendPath
@@ -12,10 +14,24 @@ class MMVC_SocketIOApp:
1214
_instance: socketio.ASGIApp | None = None
1315

1416
@classmethod
15-
def get_instance(cls, app_fastapi, voiceChangerManager: VoiceChangerManager):
17+
def get_instance(
18+
cls,
19+
app_fastapi,
20+
voiceChangerManager: VoiceChangerManager,
21+
allowedOrigins: Optional[Sequence[str]] = None,
22+
port: Optional[int] = None,
23+
):
1624
if cls._instance is None:
1725
logger.info("[Voice Changer] MMVC_SocketIOApp initializing...")
18-
sio = MMVC_SocketIOServer.get_instance(voiceChangerManager)
26+
27+
allowed_origins: set[str] = set()
28+
local_origins = compute_local_origins(port)
29+
allowed_origins.update(local_origins)
30+
if allowedOrigins is not None:
31+
normalized_origins = normalize_origins(allowedOrigins)
32+
allowed_origins.update(normalized_origins)
33+
sio = MMVC_SocketIOServer.get_instance(voiceChangerManager, list(allowed_origins))
34+
1935
app_socketio = socketio.ASGIApp(
2036
sio,
2137
other_asgi_app=app_fastapi,

server/sio/MMVC_SocketIOServer.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,13 @@ class MMVC_SocketIOServer:
88
_instance: socketio.AsyncServer | None = None
99

1010
@classmethod
11-
def get_instance(cls, voiceChangerManager: VoiceChangerManager):
11+
def get_instance(
12+
cls,
13+
voiceChangerManager: VoiceChangerManager,
14+
allowedOrigins: list[str],
15+
):
1216
if cls._instance is None:
13-
sio = socketio.AsyncServer(async_mode="asgi", cors_allowed_origins="*")
17+
sio = socketio.AsyncServer(async_mode="asgi", cors_allowed_origins=allowedOrigins)
1418
namespace = MMVC_Namespace.get_instance(voiceChangerManager)
1519
sio.register_namespace(namespace)
1620
cls._instance = sio

0 commit comments

Comments
 (0)