Skip to content

Commit 2fb911a

Browse files
committed
feat: 添加启动绑定地址解析功能,支持从配置文件和环境变量迁移
1 parent d87e6ec commit 2fb911a

11 files changed

Lines changed: 437 additions & 41 deletions

File tree

bot.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
load_dotenv(str(env_path), override=True)
3131
else:
3232
print("[WIP] no .env file found, and templates is not ready yet.")
33-
raise
33+
print("[WIP] continue startup, use environment and existing config values.")
3434
# try:
3535
# if template_env_path.exists():
3636
# shutil.copyfile(template_env_path, env_path)

dashboard/src/routes/config/modelProvider/ProviderForm.tsx

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { useCallback, useMemo, useState } from 'react'
1+
import { useCallback, useEffect, useMemo, useState } from 'react'
22
import { Check, ChevronsUpDown, Copy, Eye, EyeOff } from 'lucide-react'
33

44
import { Button } from '@/components/ui/button'
@@ -42,8 +42,16 @@ export function ProviderForm({
4242
const [localProvider, setLocalProvider] = useState<APIProvider | null>(editingProvider)
4343
const { toast } = useToast()
4444

45-
// 同步外部状态到本地
46-
if (editingProvider !== localProvider && open) {
45+
// 当弹窗打开时,根据当前编辑对象同步一次本地编辑状态
46+
useEffect(() => {
47+
if (!open) {
48+
setLocalProvider(null)
49+
setFormErrors({})
50+
setShowApiKey(false)
51+
setSelectedTemplate('custom')
52+
return
53+
}
54+
4755
setLocalProvider(editingProvider)
4856
setFormErrors({})
4957
setShowApiKey(false)
@@ -57,7 +65,7 @@ export function ProviderForm({
5765
} else {
5866
setSelectedTemplate('custom')
5967
}
60-
}
68+
}, [open, editingProvider, editingIndex])
6169

6270
const isUsingTemplate = useMemo(() => selectedTemplate !== 'custom', [selectedTemplate])
6371

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
from pathlib import Path
2+
from types import SimpleNamespace
3+
import sys
4+
5+
from src.config.legacy_migration import migrate_legacy_bind_env_to_bot_config_dict
6+
from src.config.startup_bindings import (
7+
BindAddress,
8+
get_startup_main_bind_address,
9+
get_startup_webui_bind_address,
10+
resolve_main_bind_address,
11+
resolve_webui_bind_address,
12+
)
13+
14+
15+
def test_startup_bindings_use_defaults_when_config_file_missing(tmp_path: Path):
16+
missing_path = tmp_path / "missing_bot_config.toml"
17+
18+
assert get_startup_main_bind_address(missing_path) == BindAddress(host="127.0.0.1", port=8080)
19+
assert get_startup_webui_bind_address(missing_path) == BindAddress(host="127.0.0.1", port=8001)
20+
21+
22+
def test_startup_bindings_can_read_addresses_from_bot_config(tmp_path: Path):
23+
config_path = tmp_path / "bot_config.toml"
24+
config_path.write_text(
25+
"""
26+
[inner]
27+
version = "8.3.1"
28+
29+
[maim_message]
30+
ws_server_host = "0.0.0.0"
31+
ws_server_port = 22345
32+
33+
[webui]
34+
host = "192.168.1.9"
35+
port = 18001
36+
""".strip(),
37+
encoding="utf-8",
38+
)
39+
40+
assert get_startup_main_bind_address(config_path) == BindAddress(host="0.0.0.0", port=22345)
41+
assert get_startup_webui_bind_address(config_path) == BindAddress(host="192.168.1.9", port=18001)
42+
43+
44+
def test_resolve_bindings_prefer_initialized_global_config(monkeypatch):
45+
fake_config_module = SimpleNamespace(
46+
global_config=SimpleNamespace(
47+
maim_message=SimpleNamespace(ws_server_host="10.0.0.2", ws_server_port=32000),
48+
webui=SimpleNamespace(host="10.0.0.3", port=32001),
49+
)
50+
)
51+
52+
monkeypatch.setitem(sys.modules, "src.config.config", fake_config_module)
53+
54+
assert resolve_main_bind_address() == BindAddress(host="10.0.0.2", port=32000)
55+
assert resolve_webui_bind_address() == BindAddress(host="10.0.0.3", port=32001)
56+
57+
58+
def test_legacy_env_bindings_are_migrated_when_fields_missing_or_default(monkeypatch):
59+
monkeypatch.setenv("HOST", "0.0.0.0")
60+
monkeypatch.setenv("PORT", "22345")
61+
monkeypatch.setenv("WEBUI_HOST", "192.168.1.8")
62+
monkeypatch.setenv("WEBUI_PORT", "19001")
63+
64+
payload = {
65+
"maim_message": {
66+
"ws_server_host": "127.0.0.1",
67+
"ws_server_port": 8080,
68+
},
69+
"webui": {},
70+
}
71+
72+
result = migrate_legacy_bind_env_to_bot_config_dict(payload)
73+
74+
assert result.migrated is True
75+
assert payload["maim_message"]["ws_server_host"] == "0.0.0.0"
76+
assert payload["maim_message"]["ws_server_port"] == 22345
77+
assert payload["webui"]["host"] == "192.168.1.8"
78+
assert payload["webui"]["port"] == 19001
79+
80+
81+
def test_legacy_env_bindings_do_not_override_explicit_config(monkeypatch):
82+
monkeypatch.setenv("HOST", "0.0.0.0")
83+
monkeypatch.setenv("PORT", "22345")
84+
monkeypatch.setenv("WEBUI_HOST", "192.168.1.8")
85+
monkeypatch.setenv("WEBUI_PORT", "19001")
86+
87+
payload = {
88+
"maim_message": {
89+
"ws_server_host": "10.1.1.1",
90+
"ws_server_port": 30000,
91+
},
92+
"webui": {
93+
"host": "10.1.1.2",
94+
"port": 30001,
95+
},
96+
}
97+
98+
result = migrate_legacy_bind_env_to_bot_config_dict(payload)
99+
100+
assert result.migrated is False
101+
assert payload["maim_message"]["ws_server_host"] == "10.1.1.1"
102+
assert payload["maim_message"]["ws_server_port"] == 30000
103+
assert payload["webui"]["host"] == "10.1.1.2"
104+
assert payload["webui"]["port"] == 30001

src/chat/emoji_system/emoji_manager.py

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -480,8 +480,20 @@ def register_emoji_to_db(self, emoji: MaiEmoji) -> bool:
480480
logger.error(f"[注册表情包] 表情包文件不存在: {emoji.full_path}")
481481
return False
482482

483-
# 将表情包移动到已注册目录
484483
target_path = EMOJI_REGISTERED_DIR / emoji.file_name
484+
485+
# 先查库,避免重复记录导致文件被误移动后无法回收
486+
original_path = emoji.full_path
487+
try:
488+
with get_db_session() as session:
489+
statement = select(Images).filter_by(image_hash=emoji.file_hash, image_type=ImageType.EMOJI).limit(1)
490+
existing_record = session.exec(statement).first()
491+
if existing_record and not existing_record.no_file_flag:
492+
logger.warning(f"[注册表情包] 数据库中已存在表情包记录: {emoji.file_hash}")
493+
return False
494+
except Exception as e:
495+
logger.error(f"[注册表情包] 查询数据库时出错: {e}")
496+
return False
485497
try:
486498
emoji.full_path.replace(target_path)
487499
emoji.full_path = target_path
@@ -490,6 +502,7 @@ def register_emoji_to_db(self, emoji: MaiEmoji) -> bool:
490502
return False
491503

492504
# 注册到数据库
505+
restore_file = False
493506
try:
494507
with get_db_session() as session:
495508
statement = select(Images).filter_by(image_hash=emoji.file_hash, image_type=ImageType.EMOJI).limit(1)
@@ -509,6 +522,7 @@ def register_emoji_to_db(self, emoji: MaiEmoji) -> bool:
509522
)
510523
else:
511524
logger.warning(f"[注册表情包] 数据库中已存在表情包记录: {emoji.file_hash}")
525+
restore_file = True
512526
return False
513527
else:
514528
image_record = emoji.to_db_instance()
@@ -521,7 +535,15 @@ def register_emoji_to_db(self, emoji: MaiEmoji) -> bool:
521535
logger.info(f"[注册表情包] 成功注册表情包到数据库, ID: {record_id}, 路径: {emoji.full_path}")
522536
except Exception as e:
523537
logger.error(f"[注册表情包] 注册到数据库时出错: {e}")
538+
restore_file = True
524539
return False
540+
finally:
541+
if restore_file:
542+
try:
543+
emoji.full_path.replace(original_path)
544+
emoji.full_path = original_path
545+
except Exception as e:
546+
logger.error(f"[注册表情包] 回滚文件移动失败: {e}")
525547
return True
526548

527549
def delete_emoji(self, emoji: MaiEmoji, no_desc: bool = False) -> bool:
@@ -1045,7 +1067,13 @@ async def register_emoji_by_filename(self, filename: Path | str) -> bool:
10451067
logger.error(f"[注册表情包] 创建表情包对象时出错: {e}")
10461068
return False
10471069

1048-
# 0. 先验证数据库中是否已经存在相同哈希的表情包,避免重复构建
1070+
calc_success = await target_emoji.calculate_hash_format()
1071+
if not calc_success:
1072+
logger.error(f"[注册表情包] 计算表情包哈希值和格式失败: {file_full_path}")
1073+
return False
1074+
file_full_path = target_emoji.full_path # 更新为可能修正后的路径
1075+
1076+
# 2. 先验证数据库中是否已经存在相同哈希的表情包,避免重复构建
10491077
try:
10501078
with get_db_session_manual() as session:
10511079
statement = (
@@ -1068,13 +1096,7 @@ async def register_emoji_by_filename(self, filename: Path | str) -> bool:
10681096
logger.error(f"[注册表情包] 查询数据库时出错: {e}")
10691097
return False
10701098

1071-
# 1. 计算哈希值和格式
1072-
calc_success = await target_emoji.calculate_hash_format()
1073-
if not calc_success:
1074-
logger.error(f"[注册表情包] 计算表情包哈希值和格式失败: {file_full_path}")
1075-
return False
1076-
file_full_path = target_emoji.full_path # 更新为可能修正后的路径
1077-
# 2. 检查是否已经存在过
1099+
# 3. 检查内存缓存是否已经存在
10781100
if existing_emoji := self.get_emoji_by_hash(target_emoji.file_hash):
10791101
logger.warning(f"[注册表情包] 表情包已存在,跳过注册: {existing_emoji.file_name}")
10801102
return False

src/common/message_server/api.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
1-
from maim_message import MessageServer
2-
1+
from importlib import metadata
32
import traceback
4-
import importlib.metadata
3+
4+
from maim_message import MessageServer
55

66
from src.common.logger import adopt_library_logger, get_logger
77
from src.common.utils.port_checker import assert_port_available
8-
from src.config.config import global_config
98
from .server import get_global_server
109

1110
global_api = None
@@ -14,10 +13,12 @@
1413

1514
def get_global_api() -> MessageServer: # sourcery skip: extract-method
1615
"""获取全局MessageServer实例"""
16+
from src.config.config import global_config
17+
1718
global global_api
1819
if global_api is None:
1920
# 检查maim_message版本
20-
maim_message_version = importlib.metadata.version("maim_message")
21+
maim_message_version = metadata.version("maim_message")
2122
version_int = [int(x) for x in maim_message_version.split(".")]
2223
if version_int < [0, 6, 2]:
2324
raise RuntimeError("maim_message 版本过低,请升级到 0.6.2 或更高版本。")

src/common/message_server/server.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
1+
from typing import Optional
2+
13
import asyncio
24

3-
from fastapi import FastAPI, APIRouter
5+
from fastapi import APIRouter, FastAPI
46
from rich.traceback import install
5-
from typing import Optional
67
from uvicorn import Config, Server as UvicornServer
78

89
from src.common.logger import get_logger
910
from src.common.utils.port_checker import assert_port_available, is_port_conflict_error, log_port_conflict
11+
from src.config.startup_bindings import resolve_main_bind_address
1012

1113
install(extra_lines=3)
1214

@@ -21,7 +23,7 @@ def __init__(self, host: Optional[str] = None, port: Optional[int] = None, app_n
2123
self._server: Optional[UvicornServer] = None
2224
self.set_address(host, port)
2325

24-
def register_router(self, router: APIRouter, prefix: str = ""):
26+
def register_router(self, router: APIRouter, prefix: str = ""):
2527
"""注册路由
2628
2729
APIRouter 用于对相关的路由端点进行分组和模块化管理:
@@ -121,11 +123,8 @@ def get_app(self) -> FastAPI:
121123

122124
def get_global_server() -> Server:
123125
"""获取全局服务器实例"""
124-
from src.config.config import global_config
125-
126126
global global_server
127127
if global_server is None:
128-
global_server = Server(
129-
host=global_config.maim_message.ws_server_host, port=global_config.maim_message.ws_server_port
130-
)
128+
bind_address = resolve_main_bind_address()
129+
global_server = Server(host=bind_address.host, port=bind_address.port)
131130
return global_server

src/config/config.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from .config_base import AttributeData, ConfigBase, Field
1313
from .config_utils import compare_versions, output_config_changes, recursive_parse_item_to_table
1414
from .file_watcher import FileChange, FileWatcher
15-
from .legacy_migration import try_migrate_legacy_bot_config_dict
15+
from .legacy_migration import migrate_legacy_bind_env_to_bot_config_dict, try_migrate_legacy_bot_config_dict
1616
from .model_configs import APIProvider, ModelInfo, ModelTaskConfig
1717
from .official_configs import (
1818
BotConfig,
@@ -55,7 +55,7 @@
5555
BOT_CONFIG_PATH: Path = (CONFIG_DIR / "bot_config.toml").resolve().absolute()
5656
MODEL_CONFIG_PATH: Path = (CONFIG_DIR / "model_config.toml").resolve().absolute()
5757
MMC_VERSION: str = "1.0.0"
58-
CONFIG_VERSION: str = "8.3.0"
58+
CONFIG_VERSION: str = "8.3.1"
5959
MODEL_CONFIG_VERSION: str = "1.13.1"
6060

6161
logger = get_logger("config")
@@ -472,6 +472,11 @@ def load_config_from_file(
472472
old_ver: str = inner_version
473473
config_data.remove("inner") # 移除 inner 部分,避免干扰后续处理
474474
config_data = config_data.unwrap() # 转换为普通字典,方便后续处理
475+
if config_path.name == "bot_config.toml" and config_class.__name__ == "Config":
476+
env_migration = migrate_legacy_bind_env_to_bot_config_dict(config_data)
477+
if env_migration.migrated:
478+
logger.warning(f"检测到旧版环境变量绑定配置,已迁移到主配置: {env_migration.reason}")
479+
config_data = env_migration.data
475480
# 保留一份“干净”的原始数据副本,避免第一次 from_dict 过程中对 dict 的就地修改
476481
original_data: dict[str, Any] = copy.deepcopy(config_data)
477482
try:

0 commit comments

Comments
 (0)