Skip to content

Commit cf2b693

Browse files
committed
Harden web server security
1 parent 11672e9 commit cf2b693

File tree

3 files changed

+71
-24
lines changed

3 files changed

+71
-24
lines changed

server/MMVCServerSIO.py

+16-18
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,9 @@ def setupArgParser():
6565
parser.add_argument("--rmvpe", type=str, default="pretrain/rmvpe.pt", help="path to rmvpe")
6666
parser.add_argument("--rmvpe_onnx", type=str, default="pretrain/rmvpe.onnx", help="path to rmvpe onnx")
6767

68+
parser.add_argument("--host", type=str, default='127.0.0.1', help="IP address of the network interface to listen for HTTP connections. Specify 0.0.0.0 to listen on all interfaces.")
69+
parser.add_argument("--allowed-origins", action='append', default=[], help="List of URLs to allow connection from, i.e. https://example.com. Allows http(s)://127.0.0.1:{port} and http(s)://localhost:{port} by default.")
70+
6871
return parser
6972

7073

@@ -114,16 +117,19 @@ def printMessage(message, level=0):
114117

115118
printMessage(f"Booting PHASE :{__name__}", level=2)
116119

120+
HOST = args.host
117121
PORT = args.p
118122

119123

120-
def localServer(logLevel: str = "critical"):
124+
def localServer(logLevel: str = "critical", key_path: str | None = None, cert_path: str | None = None):
121125
try:
122126
uvicorn.run(
123127
f"{os.path.basename(__file__)[:-3]}:app_socketio",
124-
host="0.0.0.0",
128+
host=HOST,
125129
port=int(PORT),
126130
reload=False if hasattr(sys, "_MEIPASS") else True,
131+
ssl_keyfile=key_path,
132+
ssl_certfile=cert_path,
127133
log_level=logLevel,
128134
)
129135
except Exception as e:
@@ -134,7 +140,7 @@ def localServer(logLevel: str = "critical"):
134140
mp.freeze_support()
135141

136142
voiceChangerManager = VoiceChangerManager.get_instance(voiceChangerParams)
137-
app_fastapi = MMVC_Rest.get_instance(voiceChangerManager, voiceChangerParams)
143+
app_fastapi = MMVC_Rest.get_instance(voiceChangerManager, voiceChangerParams, PORT, args.allowed_origins)
138144
app_socketio = MMVC_SocketIOApp.get_instance(app_fastapi, voiceChangerManager)
139145

140146

@@ -220,34 +226,26 @@ def localServer(logLevel: str = "critical"):
220226
printMessage("In many cases, it will launch when you access any of the following URLs.", level=2)
221227
if "EX_PORT" in locals() and "EX_IP" in locals(): # シェルスクリプト経由起動(docker)
222228
if args.https == 1:
223-
printMessage(f"https://127.0.0.1:{EX_PORT}/", level=1)
229+
printMessage(f"https://localhost:{EX_PORT}/", level=1)
224230
for ip in EX_IP.strip().split(" "):
225231
printMessage(f"https://{ip}:{EX_PORT}/", level=1)
226232
else:
227-
printMessage(f"http://127.0.0.1:{EX_PORT}/", level=1)
233+
printMessage(f"http://localhost:{EX_PORT}/", level=1)
228234
else: # 直接python起動
229235
if args.https == 1:
230236
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
231237
s.connect((args.test_connect, 80))
232238
hostname = s.getsockname()[0]
233-
printMessage(f"https://127.0.0.1:{PORT}/", level=1)
239+
printMessage(f"https://localhost:{PORT}/", level=1)
234240
printMessage(f"https://{hostname}:{PORT}/", level=1)
235241
else:
236-
printMessage(f"http://127.0.0.1:{PORT}/", level=1)
242+
printMessage(f"http://localhost:{PORT}/", level=1)
237243

238244
# サーバ起動
239245
if args.https:
240246
# HTTPS サーバ起動
241247
try:
242-
uvicorn.run(
243-
f"{os.path.basename(__file__)[:-3]}:app_socketio",
244-
host="0.0.0.0",
245-
port=int(PORT),
246-
reload=False if hasattr(sys, "_MEIPASS") else True,
247-
ssl_keyfile=key_path,
248-
ssl_certfile=cert_path,
249-
log_level=args.logLevel,
250-
)
248+
localServer(args.logLevel, key_path, cert_path)
251249
except Exception as e:
252250
logger.error(f"[Voice Changer] Web Server(https) Launch Exception, {e}")
253251

@@ -256,12 +254,12 @@ def localServer(logLevel: str = "critical"):
256254
p.start()
257255
try:
258256
if sys.platform.startswith("win"):
259-
process = subprocess.Popen([NATIVE_CLIENT_FILE_WIN, "--disable-gpu", "-u", f"http://127.0.0.1:{PORT}/"])
257+
process = subprocess.Popen([NATIVE_CLIENT_FILE_WIN, "--disable-gpu", "-u", f"http://localhost:{PORT}/"])
260258
return_code = process.wait()
261259
logger.info("client closed.")
262260
p.terminate()
263261
elif sys.platform.startswith("darwin"):
264-
process = subprocess.Popen([NATIVE_CLIENT_FILE_MAC, "--disable-gpu", "-u", f"http://127.0.0.1:{PORT}/"])
262+
process = subprocess.Popen([NATIVE_CLIENT_FILE_MAC, "--disable-gpu", "-u", f"http://localhost:{PORT}/"])
265263
return_code = process.wait()
266264
logger.info("client closed.")
267265
p.terminate()

server/restapi/MMVC_Rest.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import os
22
import sys
33

4+
from restapi.mods.trustedorigin import TrustedOriginMiddleware
45
from fastapi import FastAPI, Request, Response, HTTPException
56
from fastapi.routing import APIRoute
6-
from fastapi.middleware.cors import CORSMiddleware
77
from fastapi.staticfiles import StaticFiles
88
from fastapi.exceptions import RequestValidationError
99
from typing import Callable
@@ -43,17 +43,17 @@ def get_instance(
4343
cls,
4444
voiceChangerManager: VoiceChangerManager,
4545
voiceChangerParams: VoiceChangerParams,
46+
port: int,
47+
allowedOrigins: list[str],
4648
):
4749
if cls._instance is None:
4850
logger.info("[Voice Changer] MMVC_Rest initializing...")
4951
app_fastapi = FastAPI()
5052
app_fastapi.router.route_class = ValidationErrorLoggingRoute
5153
app_fastapi.add_middleware(
52-
CORSMiddleware,
53-
allow_origins=["*"],
54-
allow_credentials=True,
55-
allow_methods=["*"],
56-
allow_headers=["*"],
54+
TrustedOriginMiddleware,
55+
allowed_origins=allowedOrigins,
56+
port=port
5757
)
5858

5959
app_fastapi.mount(

server/restapi/mods/trustedorigin.py

+49
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import typing
2+
3+
from urllib.parse import urlparse
4+
from starlette.datastructures import Headers
5+
from starlette.responses import PlainTextResponse
6+
from starlette.types import ASGIApp, Receive, Scope, Send
7+
8+
ENFORCE_URL_ORIGIN_FORMAT = "Input origins must be well-formed URLs, i.e. https://google.com or https://www.google.com."
9+
10+
11+
class TrustedOriginMiddleware:
12+
def __init__(
13+
self,
14+
app: ASGIApp,
15+
allowed_origins: typing.Optional[typing.Sequence[str]] = None,
16+
port: typing.Optional[int] = None,
17+
) -> None:
18+
schemas = ['http', 'https']
19+
local_origins = [f'{schema}://{origin}' for schema in schemas for origin in ['127.0.0.1', 'localhost']]
20+
if port is not None:
21+
local_origins = [f'{origin}:{port}' for origin in local_origins]
22+
23+
if not allowed_origins:
24+
allowed_origins = local_origins
25+
else:
26+
for origin in allowed_origins:
27+
assert urlparse(origin).scheme, ENFORCE_URL_ORIGIN_FORMAT
28+
allowed_origins = local_origins + allowed_origins
29+
30+
self.app = app
31+
self.allowed_origins = list(allowed_origins)
32+
33+
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
34+
if scope["type"] not in (
35+
"http",
36+
"websocket",
37+
): # pragma: no cover
38+
await self.app(scope, receive, send)
39+
return
40+
41+
headers = Headers(scope=scope)
42+
origin = headers.get("origin", "")
43+
# Origin header is not present for same origin
44+
if not origin or origin in self.allowed_origins:
45+
await self.app(scope, receive, send)
46+
return
47+
48+
response = PlainTextResponse("Invalid origin header", status_code=400)
49+
await response(scope, receive, send)

0 commit comments

Comments
 (0)