Skip to content

Commit 39f9550

Browse files
committed
fix(agent): 修复添加订阅时的用户名映射
1 parent 367ecaf commit 39f9550

2 files changed

Lines changed: 71 additions & 3 deletions

File tree

app/agent/tools/impl/add_subscribe.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
"""添加订阅工具"""
22

3-
from typing import Optional, Type, List
3+
from typing import List, Optional, Type
44

55
from pydantic import BaseModel, Field
66

77
from app.agent.tools.base import MoviePilotTool
88
from app.chain.subscribe import SubscribeChain
9+
from app.db.user_oper import UserOper
910
from app.log import logger
10-
from app.schemas.types import MediaType
11+
from app.schemas.types import MediaType, MessageChannel
1112

1213

1314
class AddSubscribeInput(BaseModel):
@@ -101,6 +102,36 @@ def get_tool_message(self, **kwargs) -> Optional[str]:
101102

102103
return message
103104

105+
async def _resolve_subscribe_username(self) -> Optional[str]:
106+
"""优先映射为系统用户名,未绑定时回退当前渠道用户名。"""
107+
resolved_username = self._username
108+
if not self._channel or not self._user_id:
109+
return resolved_username
110+
111+
try:
112+
channel = MessageChannel(self._channel)
113+
except ValueError:
114+
return resolved_username
115+
116+
binding_keys = {
117+
MessageChannel.Telegram: ("telegram_userid",),
118+
MessageChannel.Discord: ("discord_userid",),
119+
MessageChannel.Wechat: ("wechat_userid",),
120+
MessageChannel.Slack: ("slack_userid",),
121+
MessageChannel.VoceChat: ("vocechat_userid",),
122+
MessageChannel.SynologyChat: ("synologychat_userid",),
123+
MessageChannel.QQ: ("qq_userid", "qq_openid"),
124+
}.get(channel)
125+
if not binding_keys:
126+
return resolved_username
127+
128+
mapped_username = await self.run_blocking(
129+
"db",
130+
UserOper().get_name,
131+
**{key: self._user_id for key in binding_keys},
132+
)
133+
return mapped_username or resolved_username
134+
104135
async def run(
105136
self,
106137
title: str,
@@ -137,6 +168,7 @@ async def run(
137168
if media_type_enum == MediaType.TV
138169
else None
139170
)
171+
subscribe_username = await self._resolve_subscribe_username()
140172

141173
# 构建额外的订阅参数
142174
subscribe_kwargs = {}
@@ -162,7 +194,7 @@ async def run(
162194
tmdbid=tmdb_id,
163195
doubanid=douban_id,
164196
season=season,
165-
username=self._user_id,
197+
username=subscribe_username,
166198
**subscribe_kwargs,
167199
)
168200
if sid:

tests/test_agent_add_subscribe_tool.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,24 @@
33
from unittest.mock import AsyncMock, patch
44

55
from app.agent.tools.impl.add_subscribe import AddSubscribeTool
6+
from app.schemas.types import MessageChannel
67

78

89
class TestAgentAddSubscribeTool(unittest.TestCase):
910
def test_tv_subscription_without_season_reports_default_first_season(self):
1011
tool = AddSubscribeTool(session_id="session-1", user_id="10001")
12+
tool.set_message_attr(
13+
channel=MessageChannel.Telegram.value,
14+
source="telegram-main",
15+
username="tg_display_name",
16+
)
1117

1218
with patch(
1319
"app.agent.tools.impl.add_subscribe.SubscribeChain.async_add",
1420
new=AsyncMock(return_value=(1, "")),
21+
) as async_add, patch(
22+
"app.agent.tools.impl.add_subscribe.UserOper.get_name",
23+
return_value="moviepilot-user",
1524
):
1625
result = asyncio.run(
1726
tool.run(
@@ -21,9 +30,36 @@ def test_tv_subscription_without_season_reports_default_first_season(self):
2130
)
2231
)
2332

33+
self.assertEqual(async_add.await_args.kwargs["username"], "moviepilot-user")
2434
self.assertIn("第1季", result)
2535
self.assertIn("默认按第一季订阅", result)
2636

37+
def test_subscription_falls_back_to_channel_username_when_no_binding_exists(self):
38+
tool = AddSubscribeTool(session_id="session-1", user_id="10001")
39+
tool.set_message_attr(
40+
channel=MessageChannel.Telegram.value,
41+
source="telegram-main",
42+
username="tg_display_name",
43+
)
44+
45+
with patch(
46+
"app.agent.tools.impl.add_subscribe.SubscribeChain.async_add",
47+
new=AsyncMock(return_value=(1, "")),
48+
) as async_add, patch(
49+
"app.agent.tools.impl.add_subscribe.UserOper.get_name",
50+
return_value=None,
51+
):
52+
result = asyncio.run(
53+
tool.run(
54+
title="The Matrix",
55+
year="1999",
56+
media_type="movie",
57+
)
58+
)
59+
60+
self.assertEqual(async_add.await_args.kwargs["username"], "tg_display_name")
61+
self.assertIn("成功添加订阅:The Matrix (1999)", result)
62+
2763

2864
if __name__ == "__main__":
2965
unittest.main()

0 commit comments

Comments
 (0)