diff --git a/gateway/platforms/discord.py b/gateway/platforms/discord.py index c7ae2ada5..113acffe7 100644 --- a/gateway/platforms/discord.py +++ b/gateway/platforms/discord.py @@ -14,6 +14,8 @@ logger = logging.getLogger(__name__) +VALID_THREAD_AUTO_ARCHIVE_MINUTES = {60, 1440, 4320, 10080} + try: import discord from discord import Message as DiscordMessage, Intents @@ -608,6 +610,36 @@ async def slash_sethome(interaction: discord.Interaction): except Exception as e: logger.debug("Discord followup failed: %s", e) + @tree.command(name="thread", description="Create a new Discord thread here") + @discord.app_commands.describe( + name="Thread name", + message="Optional starter message", + auto_archive_duration="Auto-archive in minutes (60, 1440, 4320, 10080)", + ) + async def slash_thread( + interaction: discord.Interaction, + name: str, + message: str = "", + auto_archive_duration: int = 1440, + ): + await interaction.response.defer(ephemeral=True) + await self._handle_thread_create_slash(interaction, name, message, auto_archive_duration) + + @tree.command(name="channel", description="Create a new Discord channel in this server") + @discord.app_commands.describe( + name="Channel name", + topic="Optional channel topic", + nsfw="Mark the channel as NSFW", + ) + async def slash_channel( + interaction: discord.Interaction, + name: str, + topic: str = "", + nsfw: bool = False, + ): + await interaction.response.defer(ephemeral=True) + await self._handle_channel_create_slash(interaction, name, topic, nsfw) + @tree.command(name="stop", description="Stop the running Hermes agent") async def slash_stop(interaction: discord.Interaction): await interaction.response.defer(ephemeral=True) @@ -711,6 +743,215 @@ async def slash_update(interaction: discord.Interaction): except Exception as e: logger.debug("Discord followup failed: %s", e) + async def _handle_thread_create_slash( + self, + interaction: discord.Interaction, + name: str, + message: str = "", + auto_archive_duration: int = 1440, + ) -> None: + """Create a Discord thread natively from a slash command.""" + result = await self._create_thread_from_interaction( + interaction, + name=name, + message=message, + auto_archive_duration=auto_archive_duration, + ) + + if result.get("success"): + thread_id = result.get("thread_id") + thread_name = result.get("thread_name") or name + if thread_id: + await interaction.followup.send( + f"Created thread **{thread_name}**: <#{thread_id}>", + ephemeral=True, + ) + else: + await interaction.followup.send( + f"Created thread **{thread_name}**.", + ephemeral=True, + ) + return + + error = result.get("error", "unknown error") + await interaction.followup.send(f"Failed to create thread: {error}", ephemeral=True) + + async def _handle_channel_create_slash( + self, + interaction: discord.Interaction, + name: str, + topic: str = "", + nsfw: bool = False, + ) -> None: + """Create a Discord channel natively from a slash command.""" + result = await self._create_channel_from_interaction( + interaction, + name=name, + topic=topic, + nsfw=nsfw, + ) + + if result.get("success"): + channel_id = result.get("channel_id") + channel_name = result.get("channel_name") or name + if channel_id: + await interaction.followup.send( + f"Created channel **#{channel_name}**: <#{channel_id}>", + ephemeral=True, + ) + else: + await interaction.followup.send( + f"Created channel **#{channel_name}**.", + ephemeral=True, + ) + return + + error = result.get("error", "unknown error") + await interaction.followup.send(f"Failed to create channel: {error}", ephemeral=True) + + async def _resolve_interaction_channel(self, interaction: discord.Interaction) -> Optional[Any]: + """Return the interaction channel, fetching it if the payload is partial.""" + channel = getattr(interaction, "channel", None) + if channel is not None: + return channel + if not self._client: + return None + channel_id = getattr(interaction, "channel_id", None) + if channel_id is None: + return None + channel = self._client.get_channel(int(channel_id)) + if channel is not None: + return channel + try: + return await self._client.fetch_channel(int(channel_id)) + except Exception: + return None + + def _thread_parent_channel(self, channel: Any) -> Any: + """Use the parent text channel when invoked from a thread.""" + return getattr(channel, "parent", None) or channel + + def _thread_reason(self, interaction: discord.Interaction) -> str: + display_name = getattr(getattr(interaction, "user", None), "display_name", None) or "unknown user" + return f"Requested by {display_name} via /thread" + + def _channel_reason(self, interaction: discord.Interaction) -> str: + display_name = getattr(getattr(interaction, "user", None), "display_name", None) or "unknown user" + return f"Requested by {display_name} via /channel" + + async def _create_thread_from_interaction( + self, + interaction: discord.Interaction, + *, + name: str, + message: str = "", + auto_archive_duration: int = 1440, + ) -> Dict[str, Any]: + """Create a thread in the current Discord channel without going through an agent tool.""" + name = (name or "").strip() + if not name: + return {"error": "Thread name is required."} + + if auto_archive_duration not in VALID_THREAD_AUTO_ARCHIVE_MINUTES: + allowed = ", ".join(str(v) for v in sorted(VALID_THREAD_AUTO_ARCHIVE_MINUTES)) + return {"error": f"auto_archive_duration must be one of: {allowed}."} + + channel = await self._resolve_interaction_channel(interaction) + if channel is None: + return {"error": "Could not resolve the current Discord channel."} + if isinstance(channel, discord.DMChannel): + return {"error": "Discord threads can only be created inside server text channels, not DMs."} + + parent_channel = self._thread_parent_channel(channel) + if parent_channel is None: + return {"error": "Could not determine a parent text channel for the new thread."} + + reason = self._thread_reason(interaction) + starter_message = (message or "").strip() + + try: + thread = await parent_channel.create_thread( + name=name, + auto_archive_duration=auto_archive_duration, + reason=reason, + ) + if starter_message: + await thread.send(starter_message) + return { + "success": True, + "thread_id": str(thread.id), + "thread_name": getattr(thread, "name", None) or name, + } + except Exception as direct_error: + try: + seed_content = starter_message or f"🧵 Thread created by Hermes: **{name}**" + seed_message = await parent_channel.send(seed_content) + thread = await seed_message.create_thread( + name=name, + auto_archive_duration=auto_archive_duration, + reason=reason, + ) + return { + "success": True, + "thread_id": str(thread.id), + "thread_name": getattr(thread, "name", None) or name, + "starter_message_id": str(getattr(seed_message, "id", "")) or None, + } + except Exception as fallback_error: + return { + "error": ( + "Discord rejected direct thread creation and Hermes could not create a starter message either. " + f"Direct error: {direct_error}. Fallback error: {fallback_error}" + ) + } + + async def _create_channel_from_interaction( + self, + interaction: discord.Interaction, + *, + name: str, + topic: str = "", + nsfw: bool = False, + ) -> Dict[str, Any]: + """Create a text channel in the current guild without going through an agent tool.""" + name = (name or "").strip() + if not name: + return {"error": "Channel name is required."} + + channel = await self._resolve_interaction_channel(interaction) + if channel is None: + return {"error": "Could not resolve the current Discord channel."} + if isinstance(channel, discord.DMChannel): + return {"error": "Discord channels can only be created inside servers, not DMs."} + + base_channel = self._thread_parent_channel(channel) + guild = getattr(base_channel, "guild", None) or getattr(channel, "guild", None) + if guild is None: + return {"error": "Could not determine which Discord server should own the new channel."} + + kwargs = { + "name": name, + "nsfw": nsfw, + "reason": self._channel_reason(interaction), + } + topic = (topic or "").strip() + if topic: + kwargs["topic"] = topic + category = getattr(base_channel, "category", None) + if category is not None: + kwargs["category"] = category + + try: + created = await guild.create_text_channel(**kwargs) + except Exception as e: + return {"error": str(e)} + + return { + "success": True, + "channel_id": str(created.id), + "channel_name": getattr(created, "name", None) or name, + } + def _build_slash_event(self, interaction: discord.Interaction, text: str) -> MessageEvent: """Build a MessageEvent from a Discord slash command interaction.""" is_dm = isinstance(interaction.channel, discord.DMChannel) diff --git a/tests/gateway/test_discord_slash_commands.py b/tests/gateway/test_discord_slash_commands.py new file mode 100644 index 000000000..49d7d934c --- /dev/null +++ b/tests/gateway/test_discord_slash_commands.py @@ -0,0 +1,231 @@ +"""Tests for native Discord slash command fast-paths.""" + +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock +import sys + +import pytest + +from gateway.config import PlatformConfig + + +def _ensure_discord_mock(): + if "discord" in sys.modules and hasattr(sys.modules["discord"], "__file__"): + return + + discord_mod = MagicMock() + discord_mod.Intents.default.return_value = MagicMock() + discord_mod.DMChannel = type("DMChannel", (), {}) + discord_mod.Thread = type("Thread", (), {}) + discord_mod.Interaction = object + discord_mod.app_commands = SimpleNamespace( + describe=lambda **kwargs: (lambda fn: fn), + ) + + ext_mod = MagicMock() + commands_mod = MagicMock() + commands_mod.Bot = MagicMock + ext_mod.commands = commands_mod + + sys.modules.setdefault("discord", discord_mod) + sys.modules.setdefault("discord.ext", ext_mod) + sys.modules.setdefault("discord.ext.commands", commands_mod) + + +_ensure_discord_mock() + +from gateway.platforms.discord import DiscordAdapter # noqa: E402 + + +class FakeTree: + def __init__(self): + self.commands = {} + + def command(self, *, name, description): + def decorator(fn): + self.commands[name] = fn + return fn + + return decorator + + +@pytest.fixture +def adapter(): + config = PlatformConfig(enabled=True, token="***") + adapter = DiscordAdapter(config) + adapter._client = SimpleNamespace(tree=FakeTree(), get_channel=lambda _id: None, fetch_channel=AsyncMock()) + return adapter + + +@pytest.mark.asyncio +async def test_registers_native_thread_slash_command(adapter): + adapter._handle_thread_create_slash = AsyncMock() + adapter._register_slash_commands() + + command = adapter._client.tree.commands["thread"] + interaction = SimpleNamespace( + response=SimpleNamespace(defer=AsyncMock()), + ) + + await command(interaction, name="Planning", message="", auto_archive_duration=1440) + + interaction.response.defer.assert_awaited_once_with(ephemeral=True) + adapter._handle_thread_create_slash.assert_awaited_once_with(interaction, "Planning", "", 1440) + + +@pytest.mark.asyncio +async def test_registers_native_channel_slash_command(adapter): + adapter._handle_channel_create_slash = AsyncMock() + adapter._register_slash_commands() + + command = adapter._client.tree.commands["channel"] + interaction = SimpleNamespace( + response=SimpleNamespace(defer=AsyncMock()), + ) + + await command(interaction, name="planning-room", topic="Roadmap", nsfw=True) + + interaction.response.defer.assert_awaited_once_with(ephemeral=True) + adapter._handle_channel_create_slash.assert_awaited_once_with(interaction, "planning-room", "Roadmap", True) + + +@pytest.mark.asyncio +async def test_handle_thread_create_slash_reports_success(adapter): + created_thread = SimpleNamespace(id=555, name="Planning", send=AsyncMock()) + parent_channel = SimpleNamespace(create_thread=AsyncMock(return_value=created_thread), send=AsyncMock()) + interaction_channel = SimpleNamespace(parent=parent_channel) + interaction = SimpleNamespace( + channel=interaction_channel, + channel_id=123, + user=SimpleNamespace(display_name="Jezza"), + followup=SimpleNamespace(send=AsyncMock()), + ) + + await adapter._handle_thread_create_slash(interaction, "Planning", "Kickoff", 1440) + + parent_channel.create_thread.assert_awaited_once_with( + name="Planning", + auto_archive_duration=1440, + reason="Requested by Jezza via /thread", + ) + created_thread.send.assert_awaited_once_with("Kickoff") + interaction.followup.send.assert_awaited_once() + args, kwargs = interaction.followup.send.await_args + assert "<#555>" in args[0] + assert kwargs["ephemeral"] is True + + +@pytest.mark.asyncio +async def test_handle_thread_create_slash_falls_back_to_seed_message(adapter): + created_thread = SimpleNamespace(id=555, name="Planning") + seed_message = SimpleNamespace(id=777, create_thread=AsyncMock(return_value=created_thread)) + channel = SimpleNamespace( + create_thread=AsyncMock(side_effect=RuntimeError("direct failed")), + send=AsyncMock(return_value=seed_message), + ) + interaction = SimpleNamespace( + channel=channel, + channel_id=123, + user=SimpleNamespace(display_name="Jezza"), + followup=SimpleNamespace(send=AsyncMock()), + ) + + await adapter._handle_thread_create_slash(interaction, "Planning", "Kickoff", 1440) + + channel.send.assert_awaited_once_with("Kickoff") + seed_message.create_thread.assert_awaited_once_with( + name="Planning", + auto_archive_duration=1440, + reason="Requested by Jezza via /thread", + ) + interaction.followup.send.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_handle_thread_create_slash_reports_failure(adapter): + channel = SimpleNamespace( + create_thread=AsyncMock(side_effect=RuntimeError("direct failed")), + send=AsyncMock(side_effect=RuntimeError("nope")), + ) + interaction = SimpleNamespace( + channel=channel, + channel_id=123, + user=SimpleNamespace(display_name="Jezza"), + followup=SimpleNamespace(send=AsyncMock()), + ) + + await adapter._handle_thread_create_slash(interaction, "Planning", "", 1440) + + interaction.followup.send.assert_awaited_once() + args, kwargs = interaction.followup.send.await_args + assert "Failed to create thread:" in args[0] + assert "nope" in args[0] + assert kwargs["ephemeral"] is True + + +@pytest.mark.asyncio +async def test_handle_channel_create_slash_reports_success(adapter): + created_channel = SimpleNamespace(id=777, name="planning-room") + guild = SimpleNamespace(create_text_channel=AsyncMock(return_value=created_channel)) + category = object() + channel = SimpleNamespace(guild=guild, category=category) + interaction = SimpleNamespace( + channel=channel, + channel_id=123, + user=SimpleNamespace(display_name="Jezza"), + followup=SimpleNamespace(send=AsyncMock()), + ) + + await adapter._handle_channel_create_slash(interaction, "planning-room", "Roadmap", False) + + guild.create_text_channel.assert_awaited_once_with( + name="planning-room", + nsfw=False, + reason="Requested by Jezza via /channel", + topic="Roadmap", + category=category, + ) + interaction.followup.send.assert_awaited_once() + args, kwargs = interaction.followup.send.await_args + assert "<#777>" in args[0] + assert kwargs["ephemeral"] is True + + +@pytest.mark.asyncio +async def test_handle_channel_create_slash_from_thread_uses_parent_channel_category(adapter): + created_channel = SimpleNamespace(id=777, name="planning-room") + guild = SimpleNamespace(create_text_channel=AsyncMock(return_value=created_channel)) + category = object() + parent_channel = SimpleNamespace(guild=guild, category=category) + thread_channel = SimpleNamespace(parent=parent_channel, guild=guild) + interaction = SimpleNamespace( + channel=thread_channel, + channel_id=123, + user=SimpleNamespace(display_name="Jezza"), + followup=SimpleNamespace(send=AsyncMock()), + ) + + await adapter._handle_channel_create_slash(interaction, "planning-room", "", True) + + guild.create_text_channel.assert_awaited_once_with( + name="planning-room", + nsfw=True, + reason="Requested by Jezza via /channel", + category=category, + ) + + +@pytest.mark.asyncio +async def test_handle_channel_create_slash_reports_failure(adapter): + guild = SimpleNamespace(create_text_channel=AsyncMock(side_effect=RuntimeError("nope"))) + channel = SimpleNamespace(guild=guild, category=None) + interaction = SimpleNamespace( + channel=channel, + channel_id=123, + user=SimpleNamespace(display_name="Jezza"), + followup=SimpleNamespace(send=AsyncMock()), + ) + + await adapter._handle_channel_create_slash(interaction, "planning-room", "", False) + + interaction.followup.send.assert_awaited_once_with("Failed to create channel: nope", ephemeral=True)