Skip to content

Commit 46ac71a

Browse files
committed
feat: add ChatUI project workspaces
1 parent 758e432 commit 46ac71a

35 files changed

Lines changed: 1812 additions & 455 deletions

astrbot/core/astr_main_agent.py

Lines changed: 62 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
TOOL_CALL_PROMPT_SKILLS_LIKE_MODE,
3030
)
3131
from astrbot.core.conversation_mgr import Conversation
32+
from astrbot.core.db import BaseDatabase
3233
from astrbot.core.message.components import File, Image, Record, Reply, Video
3334
from astrbot.core.persona_error_reply import (
3435
extract_persona_custom_error_message_from_persona,
@@ -73,7 +74,6 @@
7374
RollbackSkillReleaseTool,
7475
RunBrowserSkillTool,
7576
SyncSkillReleaseTool,
76-
normalize_umo_for_workspace,
7777
)
7878
from astrbot.core.tools.cron_tools import FutureTaskTool
7979
from astrbot.core.tools.knowledge_base_tools import (
@@ -115,6 +115,10 @@
115115
extract_quoted_message_text,
116116
)
117117
from astrbot.core.utils.string_utils import normalize_and_dedupe_strings
118+
from astrbot.core.workspace import (
119+
normalize_umo_for_workspace,
120+
resolve_workspace_root_for_umo,
121+
)
118122

119123
LLM_ERROR_MESSAGE_EXTRA_KEY = "_llm_error_message"
120124
WEEKDAY_NAMES = (
@@ -357,41 +361,63 @@ def _apply_prompt_prefix(req: ProviderRequest, cfg: dict) -> None:
357361
req.prompt = f"{prefix}{req.prompt}"
358362

359363

360-
def _get_workspace_path_for_umo(umo: str) -> Path:
361-
normalized_umo = normalize_umo_for_workspace(umo)
362-
return Path(get_astrbot_workspaces_path()) / normalized_umo
364+
async def _get_workspace_path_for_umo(umo: str, plugin_context: Context) -> Path:
365+
"""Resolve the workspace path for the current request.
366+
367+
Args:
368+
umo: Unified message origin.
369+
plugin_context: Star context containing the database instance.
370+
371+
Returns:
372+
Workspace path used as cwd.
373+
"""
374+
fallback_root = (
375+
Path(get_astrbot_workspaces_path()) / normalize_umo_for_workspace(umo)
376+
).resolve(strict=False)
377+
db = getattr(plugin_context, "_db", None)
378+
if not isinstance(db, BaseDatabase):
379+
return fallback_root
380+
try:
381+
return await resolve_workspace_root_for_umo(umo, db)
382+
except Exception:
383+
return fallback_root
363384

364385

365-
def _apply_workspace_extra_prompt(
386+
async def _apply_workspace_extra_prompt(
366387
event: AstrMessageEvent,
367388
req: ProviderRequest,
389+
plugin_context: Context,
368390
) -> None:
369-
extra_prompt_path = _get_workspace_path_for_umo(event.unified_msg_origin) / (
370-
"EXTRA_PROMPT.md"
391+
workspace_root = await _get_workspace_path_for_umo(
392+
event.unified_msg_origin,
393+
plugin_context,
371394
)
372-
if not extra_prompt_path.is_file():
373-
return
374-
375-
try:
376-
extra_prompt = extra_prompt_path.read_text(encoding="utf-8").strip()
377-
except Exception as exc: # noqa: BLE001
378-
logger.warning(
379-
"Failed to read workspace extra prompt for umo=%s from %s: %s",
380-
event.unified_msg_origin,
381-
extra_prompt_path,
382-
exc,
383-
)
384-
return
395+
extra_prompts: list[str] = []
396+
extra_prompt_path = workspace_root / "EXTRA_PROMPT.md"
397+
if extra_prompt_path.is_file():
398+
try:
399+
extra_prompt = extra_prompt_path.read_text(encoding="utf-8").strip()
400+
except Exception as exc: # noqa: BLE001
401+
logger.warning(
402+
"Failed to read workspace extra prompt for umo=%s from %s: %s",
403+
event.unified_msg_origin,
404+
extra_prompt_path,
405+
exc,
406+
)
407+
else:
408+
if extra_prompt:
409+
extra_prompts.append(f"From `{extra_prompt_path}`:\n{extra_prompt}")
385410

386-
if not extra_prompt:
411+
if not extra_prompts:
387412
return
388413

414+
extra_prompt_text = "\n\n".join(extra_prompts)
389415
req.system_prompt = (
390416
f"{req.system_prompt or ''}\n"
391417
"[Workspace Extra Prompt]\n"
392418
"The following instructions are loaded from the current workspace "
393419
"`EXTRA_PROMPT.md` file.\n"
394-
f"{extra_prompt}\n"
420+
f"{extra_prompt_text}\n"
395421
)
396422

397423

@@ -498,13 +524,13 @@ async def _ensure_persona_and_skills(
498524
skill_manager = SkillManager()
499525
skills = skill_manager.list_skills(active_only=True, runtime=runtime)
500526
skills = _filter_skills_for_current_config(skills, cfg)
501-
workspace_skills = (
502-
skill_manager.list_workspace_skills(
503-
_get_workspace_path_for_umo(event.unified_msg_origin)
527+
workspace_skills: list[SkillInfo] = []
528+
if runtime == "local":
529+
workspace_root = await _get_workspace_path_for_umo(
530+
event.unified_msg_origin,
531+
plugin_context,
504532
)
505-
if runtime == "local"
506-
else []
507-
)
533+
workspace_skills.extend(skill_manager.list_workspace_skills(workspace_root))
508534

509535
if skills or workspace_skills:
510536
if persona and persona.get("skills") is not None:
@@ -989,7 +1015,7 @@ async def _decorate_llm_request(
9891015
if tz is None:
9901016
tz = plugin_context.get_config().get("timezone")
9911017
_append_system_reminders(event, req, cfg, tz)
992-
_apply_workspace_extra_prompt(event, req)
1018+
await _apply_workspace_extra_prompt(event, req, plugin_context)
9931019

9941020

9951021
def _plugin_tool_fix(event: AstrMessageEvent, req: ProviderRequest) -> None:
@@ -1590,10 +1616,14 @@ async def build_main_agent(
15901616
)
15911617

15921618
if config.computer_use_runtime == "local":
1619+
workspace_root = await _get_workspace_path_for_umo(
1620+
event.unified_msg_origin,
1621+
plugin_context,
1622+
)
1623+
workspace_prompt = f"\nCurrent workspace you can use: `{workspace_root}`\n"
15931624
tool_prompt += (
1594-
f"\nCurrent workspace you can use: "
1595-
f"`{_get_workspace_path_for_umo(event.unified_msg_origin)}`\n"
1596-
"Unless the user explicitly specifies a different directory, "
1625+
workspace_prompt
1626+
+ "Unless the user explicitly specifies a different directory, "
15971627
"perform all file-related operations in this workspace.\n"
15981628
)
15991629

astrbot/core/db/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -851,6 +851,8 @@ async def create_chatui_project(
851851
title: str,
852852
emoji: str | None = "📁",
853853
description: str | None = None,
854+
workspace_type: str = "session",
855+
workspace_path: str | None = None,
854856
) -> ChatUIProject:
855857
"""Create a new ChatUI project."""
856858
...
@@ -877,6 +879,8 @@ async def update_chatui_project(
877879
title: str | None = None,
878880
emoji: str | None = None,
879881
description: str | None = None,
882+
workspace_type: str | None = None,
883+
workspace_path: str | None = None,
880884
) -> None:
881885
"""Update a ChatUI project."""
882886
...

astrbot/core/db/po.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -447,6 +447,10 @@ class ChatUIProject(TimestampMixin, SQLModel, table=True):
447447
"""Title of the project"""
448448
description: str | None = Field(default=None, max_length=1000)
449449
"""Description of the project"""
450+
workspace_type: str = Field(default="session", nullable=False, max_length=32)
451+
"""Workspace mode: session, project, or custom"""
452+
workspace_path: str | None = Field(default=None, max_length=1024)
453+
"""Custom workspace path"""
450454

451455
__table_args__ = (
452456
UniqueConstraint(

astrbot/core/db/sqlite.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ async def initialize(self) -> None:
6464
await self._ensure_persona_skills_column(conn)
6565
await self._ensure_persona_custom_error_message_column(conn)
6666
await self._ensure_platform_message_history_checkpoint_column(conn)
67+
await self._ensure_chatui_project_workspace_columns(conn)
6768
await conn.commit()
6869

6970
async def _ensure_persona_folder_columns(self, conn) -> None:
@@ -128,6 +129,23 @@ async def _ensure_platform_message_history_checkpoint_column(self, conn) -> None
128129
)
129130
)
130131

132+
async def _ensure_chatui_project_workspace_columns(self, conn) -> None:
133+
"""Ensure chatui_projects has workspace configuration columns."""
134+
result = await conn.execute(text("PRAGMA table_info(chatui_projects)"))
135+
columns = {row[1] for row in result.fetchall()}
136+
137+
if "workspace_type" not in columns:
138+
await conn.execute(
139+
text(
140+
"ALTER TABLE chatui_projects "
141+
"ADD COLUMN workspace_type VARCHAR(32) NOT NULL DEFAULT 'session'"
142+
)
143+
)
144+
if "workspace_path" not in columns:
145+
await conn.execute(
146+
text("ALTER TABLE chatui_projects ADD COLUMN workspace_path VARCHAR")
147+
)
148+
131149
# ====
132150
# Platform Statistics
133151
# ====
@@ -1877,6 +1895,8 @@ async def create_chatui_project(
18771895
title: str,
18781896
emoji: str | None = "📁",
18791897
description: str | None = None,
1898+
workspace_type: str = "session",
1899+
workspace_path: str | None = None,
18801900
) -> ChatUIProject:
18811901
"""Create a new ChatUI project."""
18821902
async with self.get_db() as session:
@@ -1887,6 +1907,8 @@ async def create_chatui_project(
18871907
title=title,
18881908
emoji=emoji,
18891909
description=description,
1910+
workspace_type=workspace_type,
1911+
workspace_path=workspace_path,
18901912
)
18911913
session.add(project)
18921914
await session.flush()
@@ -1929,6 +1951,8 @@ async def update_chatui_project(
19291951
title: str | None = None,
19301952
emoji: str | None = None,
19311953
description: str | None = None,
1954+
workspace_type: str | None = None,
1955+
workspace_path: str | None = None,
19321956
) -> None:
19331957
"""Update a ChatUI project."""
19341958
async with self.get_db() as session:
@@ -1941,6 +1965,9 @@ async def update_chatui_project(
19411965
values["emoji"] = emoji
19421966
if description is not None:
19431967
values["description"] = description
1968+
if workspace_type is not None:
1969+
values["workspace_type"] = workspace_type
1970+
values["workspace_path"] = workspace_path
19441971

19451972
await session.execute(
19461973
update(ChatUIProject)

0 commit comments

Comments
 (0)