diff --git a/docs/docs.json b/docs/docs.json index 843d67c7ed..addac0fd0e 100644 --- a/docs/docs.json +++ b/docs/docs.json @@ -115,6 +115,7 @@ "servers/composition", "servers/dependency-injection", "servers/elicitation", + "servers/events", "servers/icons", "servers/lifespan", "servers/logging", diff --git a/docs/servers/context.mdx b/docs/servers/context.mdx index a10161baf7..fe8abf3c27 100644 --- a/docs/servers/context.mdx +++ b/docs/servers/context.mdx @@ -23,6 +23,7 @@ The `Context` object provides a clean interface to access MCP features within yo - **Prompt Access**: List and retrieve prompts registered with the server - **LLM Sampling**: Request the client's LLM to generate text based on provided messages - **User Elicitation**: Request structured input from users during tool execution +- **Event Publishing**: [Broadcast events](/servers/events) to subscribed clients - **Session State**: Store data that persists across requests within an MCP session - **Session Visibility**: [Control which components are visible](/servers/visibility#per-session-visibility) to the current session - **Request Information**: Access metadata about the current request @@ -211,6 +212,21 @@ messages = result.messages - **`ctx.list_prompts() -> list[MCPPrompt]`**: Returns list of all available prompts - **`ctx.get_prompt(name: str, arguments: dict[str, Any] | None = None) -> GetPromptResult`**: Get a specific prompt with optional arguments +### Event Publishing + +Broadcast [events](/servers/events) to all clients subscribed to a topic. Events are delivered as server-initiated notifications; subscribers do not need to poll. + +```python +@mcp.tool +async def notify_users(message: str, ctx: Context) -> str: + """Send a notification event to all subscribers.""" + await ctx.emit_event("myapp/notifications", {"text": message}) + return "sent" +``` + +**Method signature:** +- **`await ctx.emit_event(topic, payload=None, *, priority="normal", source=None, expires_at=None, event_id=None, retained=None)`**: Publish an event to matching subscribers. See [Events](/servers/events) for full parameter documentation. + ### Session State diff --git a/docs/servers/events.mdx b/docs/servers/events.mdx new file mode 100644 index 0000000000..ca03a21ac3 --- /dev/null +++ b/docs/servers/events.mdx @@ -0,0 +1,324 @@ +--- +title: Events +sidebarTitle: Events +description: Publish real-time notifications to subscribed clients through topic-based event streams. +icon: tower-broadcast +tag: "NEW" +--- + +import { VersionBadge } from '/snippets/version-badge.mdx' + + + +MCP events let servers push notifications to clients without waiting for a request. Clients subscribe to topics they care about, and the server broadcasts events to all matching subscribers. This is useful for status updates, progress streams, chat messages, sensor readings, or any data that changes over time. + +Events are topic-based. A topic is a hierarchical string like `myapp/status` or `chat/rooms/general/messages`. Servers declare the topics they publish, clients subscribe to patterns, and the server delivers matching events to each session. + +## Declaring Event Topics + +Before emitting events, declare the topics your server publishes. Declared topics are advertised to clients during connection setup, so clients know what they can subscribe to. + +### The `@event` Decorator + +The `@event` decorator declares a topic and uses the decorated function's return type to generate a JSON Schema for the event payload. The function itself is not called automatically; it serves as a schema definition. + +```python +from fastmcp import FastMCP + +mcp = FastMCP("EventServer") + +@mcp.event("myapp/status", kind="content") +def status_event() -> dict: + """Server status updates.""" + ... + +@mcp.event("myapp/metrics", kind="signal") +def metrics_event() -> dict: + """Periodic metric snapshots.""" + ... +``` + +Every declaration MUST specify a `kind`: `"content"` means payloads are suitable for LLM context injection, `"signal"` means machine-only events (programmatic consumption, not LLM injection). + +The decorator reads the docstring as the topic description and extracts a JSON Schema from the return type annotation. If you provide an explicit `description`, it takes precedence over the docstring. + +```python +@mcp.event("myapp/alerts", kind="content", description="Critical system alerts") +def alert_event() -> dict: + ... +``` + +### `declare_event()` + +For cases where a decorator does not fit (dynamic topic registration, topics without a schema function), use `declare_event()` directly: + +```python +mcp = FastMCP("EventServer") + +mcp.declare_event( + "myapp/status", + kind="content", + description="Server status updates", + suggested_handle="notify", + retained=True, + schema={"type": "object", "properties": {"state": {"type": "string"}}}, +) +``` + + + + Topic pattern string. Supports `{param}` placeholders for parameterized topics (see [Topic Patterns](#topic-patterns)). Maximum depth: 8 segments. + + + + REQUIRED. `"content"` for events whose payloads are safe to inject into LLM context. `"signal"` for machine-only events that clients process programmatically without LLM visibility. + + + + Human-readable description of the topic. + + + + Advisory hint for how clients SHOULD handle events on this topic. Clients remain free to override based on their own configuration (zero-trust model). + + + + When `True`, the server stores the most recent event for this topic and delivers it to new subscribers immediately on subscribe. See [Retained Events](#retained-events). + + + + JSON Schema describing the event payload structure. + + + +## Emitting Events + +Once topics are declared, emit events using `mcp.emit_event()` or `ctx.emit_event()`. Both broadcast to all sessions whose subscriptions match the topic. + +### From a Tool (via Context) + +Inside a tool, resource, or prompt function, use `ctx.emit_event()`: + +```python +from fastmcp import FastMCP, Context + +mcp = FastMCP("EventServer") + +mcp.declare_event("myapp/notifications", kind="content", description="User notifications") + +@mcp.tool +async def send_notification(message: str, ctx: Context) -> str: + """Send a notification to all subscribers.""" + await ctx.emit_event("myapp/notifications", {"text": message}) + return "sent" +``` + +`ctx.emit_event()` delegates to the server's `emit_event()`, so the behavior is identical. Use whichever you have access to. + +### From Background Code + +Outside of a tool or request handler, call `emit_event()` on the FastMCP instance directly. This works from lifespan hooks, background tasks, or any code that holds a reference to the server: + +```python +import asyncio +from fastmcp import FastMCP + +mcp = FastMCP("SensorServer") + +mcp.declare_event("sensors/temperature", kind="signal", description="Temperature readings", retained=True) + +@mcp.lifespan +async def publish_temperature(server: FastMCP): + """Emit temperature readings every 5 seconds.""" + async def _loop(): + while True: + reading = await read_sensor() + await server.emit_event("sensors/temperature", {"celsius": reading}) + await asyncio.sleep(5) + + task = asyncio.create_task(_loop()) + try: + yield {} + finally: + task.cancel() +``` + +### `emit_event()` Parameters + + + + Concrete topic string (no wildcards). Must match a declared topic pattern. + + + + Event payload. Any JSON-serializable value. Optional; may be `None` for pure signal events whose topic alone carries the information. + + + + Delivery priority hint per MCP Events Spec v2. Only `"urgent"` may cancel in-progress LLM generation; the others influence when the client processes the event. + + + + Source identifier for tracing where the event originated. Auto-set to `tool/` when called from a tool context. + + + + ISO 8601 timestamp after which the event should be discarded. Servers SHOULD NOT emit already-expired events; clients MUST drop events whose `expires_at` has passed. + + + + Unique event identifier. If omitted, a ULID is generated automatically. + + + + Override the topic's retained setting for this specific event. If `None`, uses the topic descriptor's `retained` value. + + + +## Retained Events + +Retained events solve the "late subscriber" problem. When a topic is marked as retained, the server stores the most recent event for that topic. When a new client subscribes, it receives the stored value immediately without waiting for the next publish. + +This is useful for state that has a "current value" semantic: connection status, latest configuration, sensor readings. + +```python +mcp = FastMCP("StatusServer") + +# The retained flag causes the last emitted event to be stored +mcp.declare_event("service/status", kind="content", description="Service health status", retained=True) + +# Or with the decorator: +@mcp.event("service/config", kind="content", retained=True) +def config_event() -> dict: + """Current service configuration.""" + ... +``` + +When retained is enabled: + +1. Each `emit_event()` call replaces the stored value for that topic. +2. New subscribers receive the stored value as part of the subscribe response. +3. If `expires_at` is set and the timestamp has passed, the stored value is discarded instead of delivered. + +You can also override the retained behavior per-emit: + +```python +# Force retention even if the topic descriptor says retained=False +await mcp.emit_event("myapp/status", {"state": "running"}, retained=True) +``` + +## Topic Patterns + +Topic strings use `/` as a segment separator. Declared topics can include `{param}` placeholders to describe families of related topics: + +```python +mcp = FastMCP("ChatServer") + +# A parameterized topic pattern +mcp.declare_event( + "chat/rooms/{room_id}/messages", + kind="content", + description="Messages in a chat room", +) + +# Emit to a concrete topic that matches the pattern +await mcp.emit_event("chat/rooms/general/messages", {"text": "hello"}) +await mcp.emit_event("chat/rooms/random/messages", {"text": "world"}) +``` + +The `{room_id}` placeholder matches any single segment. When the server receives a subscription or emits an event, it matches concrete topics against declared patterns segment-by-segment. + +### Subscription Wildcards + +Clients subscribe using MQTT-style wildcards: + +| Wildcard | Matches | Example | +|----------|---------|---------| +| `+` | Exactly one segment | `chat/rooms/+/messages` matches `chat/rooms/general/messages` | +| `#` | Zero or more trailing segments (must be last) | `chat/#` matches `chat/rooms/general/messages` and `chat` | + +A session that subscribes to `chat/rooms/+/messages` receives events from all rooms. A session that subscribes to `chat/#` receives all chat-related events. + +Each session receives at most one copy of an event, even if multiple subscription patterns overlap. + +## Authorization + +### The `{agent_id}` Placeholder + +The placeholder name `{agent_id}` is special. Per MCP Events Spec v2, `{agent_id}` is the application-level identity placeholder: clients resolve it to their own agent identity before subscribing, and servers see fully resolved topic strings. For fastmcp's default enforcement, the "agent id" is the MCP transport session UUID that fastmcp assigns to each connection. + +When a topic is declared with `{agent_id}`, FastMCP automatically enforces that each subscriber can only subscribe to their own slot: + +- A subscriber **must** substitute their own session UUID in the `{agent_id}` position. +- Wildcards (`+`, `#`) in the `{agent_id}` position are rejected. +- Another session's UUID in that position is rejected. +- Rejected subscriptions receive reason `"permission_denied"`. + +Other `{param}` names have no special enforcement and allow wildcards freely. + +```python +# Agent-scoped: only the owning agent can subscribe +mcp.declare_event("app/agents/{agent_id}/messages", kind="content") + +# Public: any session can subscribe +mcp.declare_event("app/server/status", kind="signal") + +# Public with descriptive placeholder: {project} is not magic +mcp.declare_event("app/builds/{project}/status", kind="signal") +``` + +### The `authorize` Callback + +For custom authorization beyond `{agent_id}`, use the `authorize` parameter on `declare_event`: + +```python +mcp.declare_event( + "app/rooms/{room}/chat", + kind="content", + authorize=lambda session_id, topic_params: session_id in get_room_members(topic_params["room"]), +) +``` + +The callback signature is `(session_id: str, topic_params: dict[str, str])`: + +- `session_id` is the subscribing session's UUID. +- `topic_params` maps placeholder names to the concrete values from the subscribe pattern. Wildcards are passed as literal `"+"` or `"#"`. +- Return `True` to allow, `False` to reject. +- If `authorize` is set, it **overrides** the default `{agent_id}` check entirely. The callback is fully responsible for authorization. +- If the callback raises, the subscription is denied (fail-closed) and a warning is logged. + +### Client `session_id` Discovery + +Clients can discover their server-assigned `session_id` from the initialize response. The server includes `session_id` in `InitializeResult._meta`. Python SDK clients access it via `session.session_id`. + +This is the UUID that the `{agent_id}` convention enforces against. + +### Targeted Emission + +Use `target_session_ids` on `emit_event` to restrict delivery to specific sessions: + +```python +await mcp.emit_event( + "topic/name", + payload={"data": "value"}, + target_session_ids=["session-uuid-1", "session-uuid-2"], +) +``` + +- If `None` (default): broadcast to all matching subscribers. +- If set: only deliver to sessions in this list that also have a matching subscription. +- Useful for defense-in-depth routing alongside `{agent_id}` scoping. + +## Session Registry and Broadcast + +FastMCP maintains a registry of active sessions and their subscriptions. When `emit_event()` is called: + +1. The subscription registry finds all sessions with patterns matching the topic. +2. The event notification is sent to each matching session. +3. Delivery failures to individual sessions are logged but do not block delivery to other sessions. + +Sessions are automatically registered when they connect and removed when they disconnect. You do not need to manage the session lifecycle. + +## Capabilities + +When at least one event topic is declared, the server advertises the `events` capability during initialization. Clients that do not support events can still connect and use other server features; event-related methods return an error only if the client tries to subscribe to a server with no declared topics. diff --git a/pyproject.toml b/pyproject.toml index be8794e950..8818f3a293 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ dependencies = [ "python-dotenv>=1.1.0", "exceptiongroup>=1.2.2", "httpx>=0.28.1,<1.0", - "mcp>=1.24.0,<2.0", + "mcp @ git+https://github.com/axiomantic/python-sdk.git@mcp-events", "openapi-pydantic>=0.5.1", "opentelemetry-api>=1.20.0", "packaging>=24.0", @@ -25,6 +25,7 @@ dependencies = [ "jsonref>=1.1.0", "uncalled-for>=0.2.0", "watchfiles>=1.0.0", + "python-ulid>=3.0.0", ] requires-python = ">=3.10" diff --git a/src/fastmcp/server/context.py b/src/fastmcp/server/context.py index 7fe3bb62e2..1fa4060248 100644 --- a/src/fastmcp/server/context.py +++ b/src/fastmcp/server/context.py @@ -2,7 +2,7 @@ import logging import weakref -from collections.abc import Callable, Generator, Mapping, Sequence +from collections.abc import Callable, Collection, Generator, Mapping, Sequence from contextlib import contextmanager from contextvars import ContextVar, Token from dataclasses import dataclass @@ -193,6 +193,7 @@ def __init__( *, task_id: str | None = None, origin_request_id: str | None = None, + _tool_name: str | None = None, ): self._fastmcp: weakref.ref[FastMCP] = weakref.ref(fastmcp) self._session: ServerSession | None = session # For state ops during init @@ -200,6 +201,8 @@ def __init__( # Background task support (SEP-1686) self._task_id: str | None = task_id self._origin_request_id: str | None = origin_request_id + # Tool name for auto-source in emit_event + self._tool_name: str | None = _tool_name # Request-scoped state for non-serializable values (serializable=False) self._request_state: dict[str, Any] = {} @@ -242,6 +245,14 @@ def origin_request_id(self) -> str | None: return str(self.request_context.request_id) return self._origin_request_id + @property + def tool_name(self) -> str | None: + """Get the tool name if this context was created for a tool call. + + Returns None if the context was not created from a call_tool invocation. + """ + return self._tool_name + @property def fastmcp(self) -> FastMCP: """Get the FastMCP instance.""" @@ -774,6 +785,63 @@ async def error( extra=extra, ) + async def emit_event( + self, + topic: str, + payload: Any = None, + *, + priority: Literal["urgent", "high", "normal", "low"] = "normal", + source: str | None = None, + expires_at: str | None = None, + event_id: str | None = None, + retained: bool | None = None, + target_session_ids: Collection[str] | None = None, + ) -> None: + """Publish an event to all sessions subscribed to the given topic. + + This delegates to the FastMCP instance's ``emit_event()`` method, + which broadcasts to all matching subscribers across all active sessions. + + Example:: + + @server.tool + async def notify(ctx: Context, message: str) -> str: + await ctx.emit_event("myapp/notifications", {"text": message}) + return "sent" + + Args: + topic: Concrete topic string (no wildcards). + payload: Event payload (any JSON-serializable value). Optional; + may be ``None`` for pure signal events. + priority: Delivery priority hint (``"urgent"``, ``"high"``, + ``"normal"`` (default), or ``"low"``). Per MCP Events + Spec v2. + source: Optional source identifier. Auto-set to ``tool/`` + when called from a tool context if not provided. + expires_at: Optional ISO 8601 expiry timestamp. + event_id: Optional event ID (auto-generated if not provided). + retained: If True, store as retained value for the topic. + target_session_ids: Optional defense-in-depth filter. When + provided, delivery is restricted to sessions whose + fastmcp session_id is in this collection. Used as a + routing safety net alongside subscription-time + authorization; see ``FastMCP.emit_event`` for details. + """ + # Auto-set source from tool name when not explicitly provided + if source is None and self._tool_name is not None: + source = f"tool/{self._tool_name}" + + await self.fastmcp.emit_event( + topic=topic, + payload=payload, + priority=priority, + source=source, + expires_at=expires_at, + event_id=event_id, + retained=retained, + target_session_ids=target_session_ids, + ) + async def list_roots(self) -> list[Root]: """List the roots available to the server, as indicated by the client.""" result = await self.session.list_roots() diff --git a/src/fastmcp/server/events.py b/src/fastmcp/server/events.py new file mode 100644 index 0000000000..8a6d3bb31c --- /dev/null +++ b/src/fastmcp/server/events.py @@ -0,0 +1,228 @@ +"""Subscription infrastructure for MCP events. + +This module provides: +- SubscriptionRegistry for managing session-to-topic subscriptions with MQTT wildcards +- RetainedValueStore for storing the most recent event per topic + +Event types (EventTopicDescriptor, EventParams, etc.) are imported from the +mcp SDK (mcp.types). This module re-exports them for convenience. + +NOTE: This is completely separate from ``event_store.py`` which handles +SSE transport-level resumability for Streamable HTTP. + +Wildcard rules: +- ``+`` matches exactly one segment (between ``/`` separators) +- ``#`` matches zero or more trailing segments (must be last segment) +- Literal segments match exactly +""" + +from __future__ import annotations + +import asyncio +import functools +import re +from datetime import datetime, timezone + +# Re-export event types from the SDK for convenience +from mcp.types import ( # noqa: F401 + EventEmitNotification, + EventListRequest, + EventListResult, + EventParams, + EventsCapability, + EventSubscribeParams, + EventSubscribeRequest, + EventSubscribeResult, + EventTopicDescriptor, + EventUnsubscribeParams, + EventUnsubscribeRequest, + EventUnsubscribeResult, + RejectedTopic, + RetainedEvent, + SubscribedTopic, +) + +# --------------------------------------------------------------------------- +# Wildcard pattern matching +# --------------------------------------------------------------------------- + + +@functools.lru_cache(maxsize=256) +def _pattern_to_regex(pattern: str) -> re.Pattern[str]: + """Convert an MQTT-style topic pattern to a compiled regex. + + ``+`` becomes a single-segment match, ``#`` becomes a greedy + multi-segment match (only valid as the final segment). + """ + parts = pattern.split("/") + regex_parts: list[str] = [] + for i, part in enumerate(parts): + if part == "#": + if i != len(parts) - 1: + raise ValueError("'#' wildcard is only valid as the last segment") + # # matches zero or more trailing segments + # If preceding segments exist, the / before # is optional + # so "myapp/#" matches both "myapp" and "myapp/anything" + if regex_parts: + return re.compile("^" + "/".join(regex_parts) + "(/.*)?$") + else: + return re.compile("^.*$") + elif part == "+": + regex_parts.append("[^/]+") + else: + regex_parts.append(re.escape(part)) + return re.compile("^" + "/".join(regex_parts) + "$") + + +# --------------------------------------------------------------------------- +# Subscription registry +# --------------------------------------------------------------------------- + + +class SubscriptionRegistry: + """Thread-safe registry mapping session IDs to topic subscription patterns. + + Supports MQTT-style wildcards (``+`` for single segment, ``#`` for + trailing multi-segment). ``match()`` guarantees at-most-once delivery + per session regardless of how many patterns overlap. + """ + + def __init__(self) -> None: + self._lock = asyncio.Lock() + self._subscriptions: dict[str, set[str]] = {} + self._compiled: dict[str, re.Pattern[str]] = {} + + def _compile(self, pattern: str) -> re.Pattern[str]: + if pattern not in self._compiled: + self._compiled[pattern] = _pattern_to_regex(pattern) + return self._compiled[pattern] + + async def add(self, session_id: str, pattern: str) -> None: + """Register a subscription for *session_id* on *pattern*.""" + async with self._lock: + self._subscriptions.setdefault(session_id, set()).add(pattern) + self._compile(pattern) + + async def remove(self, session_id: str, pattern: str) -> None: + """Remove a single subscription.""" + async with self._lock: + if session_id in self._subscriptions: + self._subscriptions[session_id].discard(pattern) + if not self._subscriptions[session_id]: + del self._subscriptions[session_id] + + async def remove_all(self, session_id: str) -> None: + """Remove all subscriptions for *session_id* (disconnect cleanup).""" + async with self._lock: + self._subscriptions.pop(session_id, None) + + async def match(self, topic: str) -> set[str]: + """Return session IDs whose subscriptions match *topic*. + + Each session appears at most once (at-most-once delivery guarantee). + """ + async with self._lock: + result: set[str] = set() + for session_id, patterns in self._subscriptions.items(): + for pattern in patterns: + regex = self._compile(pattern) + if regex.match(topic): + result.add(session_id) + break # at-most-once per session + return result + + async def get_subscriptions(self, session_id: str) -> set[str]: + """Return the set of patterns a session is subscribed to.""" + async with self._lock: + return set(self._subscriptions.get(session_id, set())) + + +# --------------------------------------------------------------------------- +# Retained value store +# --------------------------------------------------------------------------- + + +class RetainedValueStore: + """Stores the most recent event per topic for replay on subscribe. + + This is an *application-level* retained value store, distinct from + ``event_store.py`` which is an SSE transport-level event store for + Streamable HTTP resumability. + + All mutating and reading methods are async and guarded by an + ``asyncio.Lock`` to prevent races between concurrent emit and + subscribe operations (mirrors ``SubscriptionRegistry``'s pattern). + """ + + def __init__(self) -> None: + self._lock = asyncio.Lock() + self._store: dict[str, RetainedEvent] = {} + self._expires: dict[str, str] = {} + self._regex_cache: dict[str, re.Pattern[str]] = {} + + async def set( + self, topic: str, event: RetainedEvent, expires_at: str | None = None + ) -> None: + """Store or replace the retained value for *topic*.""" + async with self._lock: + self._store[topic] = event + if expires_at is not None: + self._expires[topic] = expires_at + else: + self._expires.pop(topic, None) + + async def get(self, topic: str) -> RetainedEvent | None: + """Retrieve the retained value, or ``None`` if expired/absent.""" + async with self._lock: + event = self._store.get(topic) + if event is None: + return None + if self._is_expired(topic): + del self._store[topic] + self._expires.pop(topic, None) + return None + return event + + async def get_matching(self, pattern: str) -> list[RetainedEvent]: + """Return all non-expired retained events whose topic matches *pattern*.""" + async with self._lock: + if pattern not in self._regex_cache: + self._regex_cache[pattern] = _pattern_to_regex(pattern) + regex = self._regex_cache[pattern] + result: list[RetainedEvent] = [] + expired_topics: list[str] = [] + for topic, event in self._store.items(): + if self._is_expired(topic): + expired_topics.append(topic) + continue + if regex.match(topic): + # Each topic has exactly one retained event in the store, + # so no per-topic deduplication is needed here. The caller + # is responsible for deduplicating across multiple pattern + # matches (e.g. when processing a subscribe request with + # overlapping patterns). + result.append(event) + for topic in expired_topics: + del self._store[topic] + self._expires.pop(topic, None) + return result + + async def delete(self, topic: str) -> None: + """Remove the retained value for *topic*.""" + async with self._lock: + self._store.pop(topic, None) + self._expires.pop(topic, None) + + def _is_expired(self, topic: str) -> bool: + expires_at = self._expires.get(topic) + if expires_at is None: + return False + try: + # Python 3.10 fromisoformat() doesn't support "Z" suffix; + # replace with "+00:00" for compatibility. + expiry = datetime.fromisoformat(expires_at.replace("Z", "+00:00")) + if expiry.tzinfo is None: + expiry = expiry.replace(tzinfo=timezone.utc) + return datetime.now(timezone.utc) >= expiry + except (ValueError, TypeError): + return False diff --git a/src/fastmcp/server/low_level.py b/src/fastmcp/server/low_level.py index 36255f4c74..e9bff644c3 100644 --- a/src/fastmcp/server/low_level.py +++ b/src/fastmcp/server/low_level.py @@ -36,6 +36,8 @@ class MiddlewareServerSession(ServerSession): """ServerSession that routes initialization requests through FastMCP middleware.""" + _fastmcp_event_session_id: str | None = None + def __init__(self, fastmcp: FastMCP, *args, **kwargs): super().__init__(*args, **kwargs) self._fastmcp_ref: weakref.ref[FastMCP] = weakref.ref(fastmcp) @@ -64,7 +66,7 @@ def client_supports_extension(self, extension_id: str) -> bool: caps = client_params.capabilities if caps is None: return False - # ClientCapabilities uses extra="allow" — extensions is an extra field + # ClientCapabilities uses extra="allow" -- extensions is an extra field extras = caps.model_extra or {} extensions: dict[str, Any] | None = extras.get("extensions") if not extensions: @@ -99,9 +101,22 @@ async def _received_request( original_respond = responder.respond async def capturing_respond( - response: mcp.types.ServerResult, + response: mcp.types.ServerResult | mcp.types.ErrorData, ) -> None: nonlocal captured_response + # Inject the fastmcp session_id into InitializeResult._meta so + # clients can learn their own session identifier synchronously + # during the initialize handshake. This value is the same UUID + # set on the session in LowLevelServer.run() and used by the + # event subscription system for cross-session authorization. + if not isinstance(response, mcp.types.ErrorData) and isinstance( + response.root, mcp.types.InitializeResult + ): + session_id = getattr(self, "_fastmcp_event_session_id", None) + if session_id is not None: + existing_meta = response.root.meta or {} + merged_meta = {**existing_meta, "session_id": session_id} + response.root.meta = merged_meta captured_response = response return await original_respond(response) @@ -211,8 +226,21 @@ def get_capabilities( # Set tasks as a first-class field (not experimental) per SEP-1686 capabilities.tasks = get_task_capabilities() + # Event handlers are always registered so the SDK's get_capabilities + # will set events_capability. Override: only advertise when topics + # are actually declared, and include the declared topic descriptors. + if self.fastmcp._event_topics: + from fastmcp.server.events import EventsCapability + + capabilities.events = EventsCapability( + topics=list(self.fastmcp._event_topics.values()), + ) + else: + # No topics declared: suppress events capability + capabilities.events = None + # Advertise MCP Apps extension support (io.modelcontextprotocol/ui) - # Uses the same extra-field pattern as tasks above — ServerCapabilities + # Uses the same extra-field pattern as tasks above -- ServerCapabilities # has extra="allow" so this survives serialization. # Merge with any existing extensions to avoid clobbering other features. existing_extensions: dict[str, Any] = ( @@ -245,18 +273,30 @@ async def run( ) ) - async with anyio.create_task_group() as tg: - # Store task group on session for subscription tasks (SEP-1686) - session._subscription_task_group = tg - - async for message in session.incoming_messages: - tg.start_soon( - self._handle_message, - message, - session, - lifespan_context, - raise_exceptions, - ) + # Register session for event broadcasting (dict for O(1) lookup) + from uuid import uuid4 + + session_id = str(uuid4()) + session._fastmcp_event_session_id = session_id + self.fastmcp._active_sessions[session_id] = session + + try: + async with anyio.create_task_group() as tg: + # Store task group on session for subscription tasks (SEP-1686) + session._subscription_task_group = tg + + async for message in session.incoming_messages: + tg.start_soon( + self._handle_message, + message, + session, + lifespan_context, + raise_exceptions, + ) + finally: + # Cleanup: remove session and its subscriptions + self.fastmcp._active_sessions.pop(session_id, None) + await self.fastmcp._subscription_registry.remove_all(session_id) def read_resource( self, diff --git a/src/fastmcp/server/mixins/mcp_operations.py b/src/fastmcp/server/mixins/mcp_operations.py index 70bd656072..c5417c7926 100644 --- a/src/fastmcp/server/mixins/mcp_operations.py +++ b/src/fastmcp/server/mixins/mcp_operations.py @@ -24,6 +24,56 @@ PaginateT = TypeVar("PaginateT") +def _is_placeholder(segment: str) -> bool: + """Return True if ``segment`` is a ``{param}`` placeholder.""" + return len(segment) >= 2 and segment.startswith("{") and segment.endswith("}") + + +def _placeholder_name(segment: str) -> str: + """Extract the name from a ``{param}`` placeholder segment.""" + return segment[1:-1] + + +def _extract_topic_params( + declared_segments: list[str], + subscribe_segments: list[str], +) -> dict[str, str]: + """Build the ``topic_params`` dict passed to ``authorize`` callbacks. + + For each placeholder segment in the declared pattern, determine the + substituted value from the corresponding subscribe-pattern segment. If + the subscribe pattern uses a single-segment wildcard (``+``), the value + is the literal string ``"+"``. If the subscribe pattern uses the + multi-segment wildcard (``#``), ALL placeholder slots at that position + or later receive the literal string ``"#"``. + + Segments without placeholders do not contribute to the dict. + """ + params: dict[str, str] = {} + hash_active = False + for index, declared_seg in enumerate(declared_segments): + if not _is_placeholder(declared_seg): + continue + name = _placeholder_name(declared_seg) + if hash_active: + params[name] = "#" + continue + if index >= len(subscribe_segments): + # Subscribe pattern is shorter than declared. This should only + # happen when `#` consumed earlier segments, which is handled + # above. Record a sentinel for safety. + params[name] = "#" + continue + sub_seg = subscribe_segments[index] + if sub_seg == "#": + params[name] = "#" + hash_active = True + else: + # Covers literal values and the "+" single-segment wildcard. + params[name] = sub_seg + return params + + def _apply_pagination( items: Sequence[PaginateT], cursor: str | None, @@ -81,6 +131,9 @@ def _setup_handlers(self: FastMCP) -> None: self._mcp_server.get_prompt()(self._get_prompt_mcp) self._mcp_server.set_logging_level()(self._set_logging_level_mcp) + # Register event protocol handlers + self._setup_event_protocol_handlers() + # Register SEP-1686 task protocol handlers self._setup_task_protocol_handlers() @@ -371,3 +424,291 @@ async def _set_logging_level_mcp(self, level: mcp.types.LoggingLevel) -> None: session._minimum_logging_level = level except LookupError: pass + + # ------------------------------------------------------------------------- + # Event protocol handlers + # ------------------------------------------------------------------------- + + def _setup_event_protocol_handlers(self: FastMCP) -> None: + """Register event protocol handlers through the SDK's request_handlers. + + Event request types (EventSubscribeRequest, EventUnsubscribeRequest, + EventListRequest) are part of the SDK's ClientRequest union, so the + SDK's built-in dispatch routes them to registered handlers automatically. + + Capabilities are advertised by the SDK based on the presence of + EventSubscribeRequest in request_handlers, and overridden by + LowLevelServer.get_capabilities() to include declared topic descriptors. + """ + from fastmcp.server.events import ( + EventListRequest, + EventSubscribeRequest, + EventUnsubscribeRequest, + ) + + server = self + + def _check_events_capability() -> None: + """Raise -32601 if no event topics are declared.""" + if not server._event_topics: + raise McpError( + mcp.types.ErrorData( + code=-32601, + message="Method not found: server has no events capability", + ) + ) + + async def handle_subscribe( + req: EventSubscribeRequest, + ) -> mcp.types.ServerResult: + _check_events_capability() + result = await server._handle_subscribe_events(req) + return mcp.types.ServerResult(result) + + async def handle_unsubscribe( + req: EventUnsubscribeRequest, + ) -> mcp.types.ServerResult: + _check_events_capability() + result = await server._handle_unsubscribe_events(req) + return mcp.types.ServerResult(result) + + async def handle_list(req: EventListRequest) -> mcp.types.ServerResult: + _check_events_capability() + result = await server._handle_list_events(req) + return mcp.types.ServerResult(result) + + server._mcp_server.request_handlers[EventSubscribeRequest] = handle_subscribe + server._mcp_server.request_handlers[EventUnsubscribeRequest] = ( + handle_unsubscribe + ) + server._mcp_server.request_handlers[EventListRequest] = handle_list + + async def _handle_subscribe_events( + self, req: mcp.types.EventSubscribeRequest + ) -> mcp.types.EventSubscribeResult: + """Handle events/subscribe requests.""" + from mcp.server.lowlevel.server import request_ctx + + from fastmcp.server.events import ( + EventSubscribeResult, + RejectedTopic, + SubscribedTopic, + ) + + server = cast("FastMCP", self) + logger.debug(f"[{server.name}] Handler called: events/subscribe") + + # Get the session from the SDK request context + ctx = request_ctx.get() + session = ctx.session + session_id = getattr(session, "_fastmcp_event_session_id", None) + + if session_id is None: + raise McpError( + mcp.types.ErrorData( + code=-32603, + message="No session context available for subscription", + ) + ) + + topics = req.params.topics + + subscribed: list[SubscribedTopic] = [] + rejected: list[RejectedTopic] = [] + retained_events = [] + seen_event_ids: set[str] = set() + + for pattern in topics: + # Validate topic depth (max 8 segments) + segments = pattern.split("/") + if len(segments) > server._MAX_TOPIC_DEPTH: + raise McpError( + mcp.types.ErrorData( + code=-32602, + message=( + f"Subscription pattern has {len(segments)} segments, " + f"maximum depth is {server._MAX_TOPIC_DEPTH}: {pattern!r}" + ), + ) + ) + + # Check if the pattern matches any declared topic + matched_declared = server._find_matching_declared_topics(pattern) + if not matched_declared: + rejected.append(RejectedTopic(pattern=pattern, reason="unknown_topic")) + continue + + # Authorize the subscription against each declared pattern that + # the subscribe pattern matches. Require every match to + # authorize: a single denial rejects the whole subscription so + # a client cannot smuggle in a forbidden pattern by combining + # it with a permissive one via wildcards. + authorized = True + for declared_pattern in matched_declared: + if not server._authorize_subscription( + declared_pattern, pattern, session_id + ): + authorized = False + break + if not authorized: + rejected.append( + RejectedTopic(pattern=pattern, reason="permission_denied") + ) + continue + + try: + await server._subscription_registry.add(session_id, pattern) + except ValueError as e: + rejected.append( + RejectedTopic(pattern=pattern, reason=f"invalid_pattern: {e}") + ) + continue + subscribed.append(SubscribedTopic(pattern=pattern)) + + # Deliver retained values for this pattern (deduplicated) + matching = await server._retained_store.get_matching(pattern) + for evt in matching: + if evt.eventId not in seen_event_ids: + seen_event_ids.add(evt.eventId) + retained_events.append(evt) + + return EventSubscribeResult( + subscribed=subscribed, + rejected=rejected, + retained=retained_events, + ) + + async def _handle_unsubscribe_events( + self, req: mcp.types.EventUnsubscribeRequest + ) -> mcp.types.EventUnsubscribeResult: + """Handle events/unsubscribe requests.""" + from mcp.server.lowlevel.server import request_ctx + + from fastmcp.server.events import EventUnsubscribeResult + + server = cast("FastMCP", self) + logger.debug(f"[{server.name}] Handler called: events/unsubscribe") + + ctx = request_ctx.get() + session = ctx.session + session_id = getattr(session, "_fastmcp_event_session_id", None) + + topics = req.params.topics + + unsubscribed: list[str] = [] + if session_id is not None: + for pattern in topics: + await server._subscription_registry.remove(session_id, pattern) + unsubscribed.append(pattern) + + return EventUnsubscribeResult(unsubscribed=unsubscribed) + + async def _handle_list_events( + self, req: mcp.types.EventListRequest + ) -> mcp.types.EventListResult: + """Handle events/list requests.""" + from fastmcp.server.events import EventListResult + + server = cast("FastMCP", self) + logger.debug(f"[{server.name}] Handler called: events/list") + + topics = list(server._event_topics.values()) + return EventListResult(topics=topics) + + def _authorize_subscription( + self: FastMCP, + declared_pattern: str, + subscribe_pattern: str, + session_id: str, + ) -> bool: + """Run the declared topic's authorize callback if one is registered. + + Default policy is permissive: any subscriber that matches the topic + pattern is allowed. Per-agent or per-tenant isolation requires an + explicit authorize callback registered on the declared topic via + ``declare_event(authorize=...)``. + + The callback receives ``(session_id, topic_params)`` where + ``topic_params`` maps each ``{param}`` placeholder name in the + declared pattern to the value supplied by the subscribe pattern + (a literal, ``"+"`` for a single-segment wildcard, or ``"#"`` for + the multi-segment wildcard). The callback returns True to allow + the subscription or False to reject it. If the callback raises, + the subscription is denied and a warning is logged. + + Returns True to allow the subscription, False to reject it. + """ + authorize_cb = self._event_topic_authorize.get(declared_pattern) + if authorize_cb is None: + return True + declared_segments = declared_pattern.split("/") + subscribe_segments = subscribe_pattern.split("/") + topic_params = _extract_topic_params(declared_segments, subscribe_segments) + try: + return bool(authorize_cb(session_id, topic_params)) + except Exception: + logger.warning( + "authorize callback raised for declared topic %r; denying subscription", + declared_pattern, + exc_info=True, + ) + return False + + def _match_declared_topic(self: FastMCP, pattern: str) -> bool: + """Check whether a subscription pattern matches any declared event topic. + + See ``_find_matching_declared_topics`` for the underlying logic. + """ + return bool(self._find_matching_declared_topics(pattern)) + + def _find_matching_declared_topics(self: FastMCP, pattern: str) -> list[str]: + """Return the declared topic patterns that a subscription pattern matches. + + Handles both exact matches and wildcard patterns that could match + declared topic patterns. For example, subscription pattern "myapp/+" + matches declared topic "myapp/{param}". + + Uses regex-based matching in both directions: the subscription pattern + is checked against declared patterns (with {param} as single-segment + wildcards), and declared patterns are checked against the subscription + pattern (with + and # as MQTT wildcards). + """ + import re as _re + + from fastmcp.server.events import _pattern_to_regex + + matches: list[str] = [] + + for declared_pattern in self._event_topics: + # Forward: build regex from declared pattern's {param} placeholders + # and test whether the subscription pattern (with wildcards replaced + # by a synthetic single-segment value) matches. + declared_regex = self._declared_topic_regex_cache.get(declared_pattern) + if declared_regex is None: + declared_regex_parts = [] + for segment in declared_pattern.split("/"): + if segment.startswith("{") and segment.endswith("}"): + declared_regex_parts.append("[^/]+") + else: + declared_regex_parts.append(_re.escape(segment)) + declared_regex = _re.compile("^" + "/".join(declared_regex_parts) + "$") + self._declared_topic_regex_cache[declared_pattern] = declared_regex + + # Replace MQTT wildcards with a synthetic literal segment for + # testing against the declared pattern regex. + test_pattern = _re.sub(r"[+#]", "x", pattern) + if declared_regex.match(test_pattern): + matches.append(declared_pattern) + continue + + # Reverse: does the declared pattern (with {param} replaced by a + # synthetic literal) match the subscription pattern's MQTT regex? + concrete_declared = _re.sub(r"\{[^}]+\}", "x", declared_pattern) + try: + sub_regex = _pattern_to_regex(pattern) + if sub_regex.match(concrete_declared): + matches.append(declared_pattern) + except ValueError: + continue + + return matches diff --git a/src/fastmcp/server/server.py b/src/fastmcp/server/server.py index 789bc42ba2..28077cdb92 100644 --- a/src/fastmcp/server/server.py +++ b/src/fastmcp/server/server.py @@ -11,6 +11,7 @@ AsyncIterator, Awaitable, Callable, + Collection, Sequence, ) from contextlib import ( @@ -33,12 +34,15 @@ Annotations, AnyFunction, CallToolRequestParams, + EventTopicDescriptor, + RetainedEvent, ToolAnnotations, ) from pydantic import AnyUrl from pydantic import ValidationError as PydanticValidationError from starlette.routing import BaseRoute from typing_extensions import Self +from ulid import ULID import fastmcp import fastmcp.server @@ -64,6 +68,10 @@ from fastmcp.resources.base import Resource, ResourceResult from fastmcp.resources.template import ResourceTemplate from fastmcp.server.auth import AuthCheck, AuthContext, AuthProvider, run_auth_checks +from fastmcp.server.events import ( + RetainedValueStore, + SubscriptionRegistry, +) from fastmcp.server.lifespan import Lifespan from fastmcp.server.low_level import LowLevelServer from fastmcp.server.middleware import Middleware, MiddlewareContext @@ -94,6 +102,7 @@ from fastmcp.client.client import FastMCP1Server from fastmcp.client.sampling import SamplingHandler from fastmcp.client.transports import ClientTransport, ClientTransportT + from fastmcp.server.low_level import MiddlewareServerSession from fastmcp.server.providers.openapi import ComponentFn as OpenAPIComponentFn from fastmcp.server.providers.openapi import RouteMap from fastmcp.server.providers.openapi import RouteMapFn as OpenAPIRouteMapFn @@ -298,6 +307,23 @@ def __init__( self._docket = None self._worker = None + # Event topics and subscription infrastructure + self._event_topics: dict[str, EventTopicDescriptor] = {} + # Per-declared-pattern authorization callbacks. Stored separately from + # EventTopicDescriptor (a python-sdk protocol type we do not mutate). + # Key: declared pattern string; value: callback(session_id, topic_params) + # returning True to allow the subscription, False to reject it. + self._event_topic_authorize: dict[ + str, Callable[[str, dict[str, str]], bool] + ] = {} + # Cached regex compiled from each declared topic pattern's {param} + # placeholders, used to accelerate subscription pattern matching. + # Keyed by declared pattern string; populated lazily on first use. + self._declared_topic_regex_cache: dict[str, re.Pattern[str]] = {} + self._subscription_registry: SubscriptionRegistry = SubscriptionRegistry() + self._retained_store: RetainedValueStore = RetainedValueStore() + self._active_sessions: dict[str, MiddlewareServerSession] = {} + self._additional_http_routes: list[BaseRoute] = [] # Session-scoped state store (shared across all requests) @@ -1109,7 +1135,7 @@ async def call_tool( # For mounted servers, the parent's provider sets fn_key to the # namespaced key before delegating, ensuring correct Docket routing. - async with fastmcp.server.context.Context(fastmcp=self) as ctx: + async with fastmcp.server.context.Context(fastmcp=self, _tool_name=name) as ctx: if run_middleware: mw_context = MiddlewareContext[CallToolRequestParams]( message=mcp.types.CallToolRequestParams( @@ -1619,6 +1645,334 @@ def my_tool(x: int) -> str: return result + # ------------------------------------------------------------------------- + # Event declaration and publishing + # ------------------------------------------------------------------------- + + _MAX_TOPIC_DEPTH = 8 + + def declare_event( + self, + pattern: str, + *, + kind: Literal["content", "signal"], + description: str | None = None, + suggested_handle: Literal[ + "drop", "silent", "notify", "ask", "inject", "interrupt" + ] + | None = None, + retained: bool = False, + schema: dict[str, Any] | None = None, + authorize: Callable[[str, dict[str, str]], bool] | None = None, + ) -> EventTopicDescriptor: + """Declare an event topic that this server can publish to. + + Args: + pattern: Topic pattern (e.g., "agents/{agent_id}/messages"). + ``{param}`` placeholders describe parameterized segments. + Maximum depth: 8 segments. + kind: REQUIRED. Either ``"content"`` (payloads are suitable for + LLM context injection) or ``"signal"`` (machine-only events + intended for programmatic consumption, not LLM injection). + description: Human-readable description of the topic. + suggested_handle: Optional advisory hint for how clients SHOULD + handle events on this topic. One of + ``drop``, ``silent``, ``notify``, ``ask``, + ``inject``, ``interrupt``. Clients remain free + to override based on their own configuration + (zero-trust model). + retained: Whether the most recent event per topic is stored and + delivered to new subscribers on subscribe. + schema: Optional JSON Schema for the event payload. + authorize: Optional callback invoked when a client subscribes to a + pattern that matches this declaration. Receives + ``(session_id, topic_params)`` and returns True to + permit the subscription or False to reject it with + ``reason="permission_denied"``. ``topic_params`` is a + dict mapping each placeholder name in the declared + pattern to the value supplied by the subscribing + pattern: either a literal string, the wildcard + character ``"+"`` (single segment wildcard), or + ``"#"`` (multi-segment wildcard). + + Authorization model: + By default any subscriber whose subscribe pattern matches a + declared topic is allowed. Per-agent or per-tenant isolation + is opt-in via an explicit ``authorize`` callback. The callback + is fully responsible for deciding whether the subscription is + permitted; fastmcp applies no additional policy. + + Per MCP Events Spec v2, ``{agent_id}`` is an application-level + identity placeholder declared by the server and resolved by + the client before subscribing. It has no special meaning in + fastmcp's authorization path: the transport session UUID is + not the agent identity, and multiple agents may share one + transport. Servers that want to gate subscriptions by agent + identity should register an ``authorize`` callback and + consult ``topic_params["agent_id"]`` together with any + out-of-band binding between sessions and agent identities. + + Returns: + The registered EventTopicDescriptor. + + Raises: + ValueError: If pattern has more than 8 segments. + """ + segments = pattern.split("/") + if len(segments) > self._MAX_TOPIC_DEPTH: + raise ValueError( + f"Topic pattern has {len(segments)} segments, " + f"maximum depth is {self._MAX_TOPIC_DEPTH}: {pattern!r}" + ) + descriptor = EventTopicDescriptor( + pattern=pattern, + kind=kind, + description=description, + suggestedHandle=suggested_handle, + retained=retained, + schema=schema, + ) + self._event_topics[pattern] = descriptor + if authorize is not None: + self._event_topic_authorize[pattern] = authorize + else: + # Clear any stale callback from a prior declaration of the same + # pattern so redeclaration behaves predictably. + self._event_topic_authorize.pop(pattern, None) + return descriptor + + @staticmethod + def _topic_matches_pattern(concrete_topic: str, declared_pattern: str) -> bool: + """Check if a concrete topic matches a declared pattern with {param} placeholders. + + Compares segment-by-segment: literal segments must match exactly, + ``{param}`` segments match any single non-empty segment. + """ + if not concrete_topic or not declared_pattern: + return False + concrete_parts = concrete_topic.split("/") + pattern_parts = declared_pattern.split("/") + if any(not s for s in concrete_parts) or any(not s for s in pattern_parts): + return False + if len(concrete_parts) != len(pattern_parts): + return False + for concrete_seg, pattern_seg in zip( + concrete_parts, pattern_parts, strict=True + ): + if pattern_seg.startswith("{") and pattern_seg.endswith("}"): + continue # {param} matches any single segment + if concrete_seg != pattern_seg: + return False + return True + + def _find_topic_descriptor(self, topic: str) -> EventTopicDescriptor | None: + """Find the EventTopicDescriptor for a concrete topic. + + Tries a direct lookup first, then falls back to segment-by-segment + matching against declared parameterized patterns. + """ + # Direct match (fast path) + descriptor = self._event_topics.get(topic) + if descriptor is not None: + return descriptor + # Fall back to parameterized pattern matching + for pattern, desc in self._event_topics.items(): + if self._topic_matches_pattern(topic, pattern): + return desc + return None + + def event( + self, + pattern: str, + *, + kind: Literal["content", "signal"], + description: str | None = None, + suggested_handle: Literal[ + "drop", "silent", "notify", "ask", "inject", "interrupt" + ] + | None = None, + retained: bool = False, + authorize: Callable[[str, dict[str, str]], bool] | None = None, + ) -> Callable[[F], F]: + """Decorator to declare an event topic. + + The decorated function's return type annotation is used to generate + the payload JSON Schema. The function itself is not called automatically; + it serves as a schema source and can be called manually to construct + payloads. + + Example:: + + @mcp.event("myapp/status", kind="content") + def status_event() -> dict: + '''Server status updates.''' + ... + + Args: + pattern: Topic pattern for the event. + kind: REQUIRED. ``"content"`` or ``"signal"``. See + ``declare_event`` for full semantics. + description: Optional description (falls back to docstring). + suggested_handle: Optional advisory hint for client handle + behavior. See ``declare_event`` for details. + retained: Whether to store the most recent value per topic. + authorize: Optional subscription authorization callback. See + ``declare_event`` for full semantics. + """ + + def decorator(fn: F) -> F: + desc = description or (fn.__doc__.strip() if fn.__doc__ else None) + + # Try to extract JSON Schema from the return type annotation + payload_schema: dict[str, Any] | None = None + import typing + + hints = typing.get_type_hints(fn) + return_type = hints.get("return") + if return_type is not None and return_type is not type(None): + try: + from pydantic import TypeAdapter + + adapter = TypeAdapter(return_type) + payload_schema = adapter.json_schema() + except (TypeError, ValueError, PydanticValidationError): + logger.debug( + "Could not generate JSON schema for event %r return type %r", + pattern, + return_type, + ) + + self.declare_event( + pattern, + kind=kind, + description=desc, + suggested_handle=suggested_handle, + retained=retained, + schema=payload_schema, + authorize=authorize, + ) + return fn + + return decorator + + async def emit_event( + self, + topic: str, + payload: Any = None, + *, + priority: Literal["urgent", "high", "normal", "low"] = "normal", + source: str | None = None, + expires_at: str | None = None, + event_id: str | None = None, + retained: bool | None = None, + target_session_ids: Collection[str] | None = None, + ) -> None: + """Broadcast an event to all sessions subscribed to the given topic. + + This is the instance-level method for publishing events from anywhere + (tools via ``ctx.emit_event()``, background tasks, lifespan code, etc.). + + Delivery failures to individual sessions are logged but do not prevent + delivery to other sessions. + + Args: + topic: Concrete topic string (no wildcards). + payload: Event payload (any JSON-serializable value). Optional; + may be ``None`` for pure signal events whose topic alone + carries the information. + priority: Delivery priority hint per MCP Events Spec v2. One of + ``"urgent"``, ``"high"``, ``"normal"`` (default), + ``"low"``. Only ``"urgent"`` may cancel in-progress + LLM generation; the others influence when the client + processes the event. + source: Optional source identifier (e.g., ``"tool/build"``, + ``"spellbook/messaging"``). + expires_at: Optional ISO 8601 expiry timestamp. Servers SHOULD + NOT emit events that are already expired; clients + MUST drop events whose ``expires_at`` has passed. + event_id: Optional event ID (auto-generated ULID if not + provided). + retained: If True, store as retained value for the topic. + Defaults to the topic descriptor's ``retained`` + setting if declared. + target_session_ids: Optional defense-in-depth filter. When None + (default), the event is delivered to every session whose + subscriptions match the topic. When a collection is + provided, the set of matching subscribers is + intersected (AND logic) with this set so that only + sessions whose ``_fastmcp_event_session_id`` is in + ``target_session_ids`` receive the notification. An + empty intersection is a silent no-op. This is intended + as a routing safety net paired with subscription-time + authorization; subscription auth is still the primary + boundary. + """ + if event_id is None: + event_id = str(ULID()) + + # Determine whether to retain based on topic descriptor if not explicit + if retained is None: + descriptor = self._find_topic_descriptor(topic) + retained = descriptor.retained if descriptor is not None else False + + # Store retained value if applicable + if retained: + retained_event = RetainedEvent( + topic=topic, + eventId=event_id, + payload=payload, + ) + await self._retained_store.set(topic, retained_event, expires_at=expires_at) + + # Find matching sessions via subscription registry + matching_session_ids = await self._subscription_registry.match(topic) + if not matching_session_ids: + return + + # Defense-in-depth routing filter: restrict delivery to the provided + # session-id whitelist. Empty intersection is a silent no-op. + if target_session_ids is not None: + target_set = set(target_session_ids) + matching_session_ids = [ + sid for sid in matching_session_ids if sid in target_set + ] + if not matching_session_ids: + return + + # Build the event notification + from mcp.types import EventEmitNotification, EventParams, ServerNotification + + notification = EventEmitNotification( + params=EventParams( + topic=topic, + eventId=event_id, + payload=payload, + priority=priority, + retained=retained, + source=source, + expiresAt=expires_at, + ), + ) + + # Broadcast to matching active sessions in parallel so slow sessions + # don't block delivery to others. + async def _deliver(sid: str) -> None: + session = self._active_sessions.get(sid) + if session is None: + return + try: + await session.send_notification(cast(ServerNotification, notification)) + except Exception: + logger.warning( + f"Failed to deliver event to session {sid}", + exc_info=True, + ) + + await asyncio.gather( + *[_deliver(sid) for sid in matching_session_ids], + return_exceptions=True, + ) + def add_resource( self, resource: Resource | Callable[..., Any] ) -> Resource | ResourceTemplate: diff --git a/tests/client/test_sse.py b/tests/client/test_sse.py index 4f9e51f040..0cbfc05499 100644 --- a/tests/client/test_sse.py +++ b/tests/client/test_sse.py @@ -132,8 +132,10 @@ async def nested_sse_server(): try: yield f"http://127.0.0.1:{port}/nest-outer/nest-inner/mcp/sse/" finally: - # Graceful shutdown - required for uvicorn 0.39+ due to context isolation + # Cancel the server task directly; should_exit doesn't reliably + # propagate on recent uvicorn versions (0.39+). uvicorn_server.should_exit = True + server_task.cancel() try: await server_task except asyncio.CancelledError: diff --git a/tests/client/test_stdio.py b/tests/client/test_stdio.py index 3d1335ae54..df47bc9fe8 100644 --- a/tests/client/test_stdio.py +++ b/tests/client/test_stdio.py @@ -2,6 +2,7 @@ import gc import inspect import os +import time import weakref import psutil @@ -158,10 +159,15 @@ async def test_server(): # This test may fail/hang while debugging because the debugger holds a reference to the underlying transport - with pytest.raises(psutil.NoSuchProcess): - while True: + # Poll for process exit with a bounded timeout + deadline = time.monotonic() + 10 + while time.monotonic() < deadline: + try: psutil.Process(pid) - await asyncio.sleep(0.1) + except psutil.NoSuchProcess: + return # Process exited as expected + await asyncio.sleep(0.1) + pytest.fail(f"Process {pid} still alive after 10s") async def test_keep_alive_false_exit_scope_kills_server(self, stdio_script): pid: int | None = None @@ -179,10 +185,15 @@ async def test_server(): await test_server() - with pytest.raises(psutil.NoSuchProcess): - while True: + # Poll for process exit with a bounded timeout + deadline = time.monotonic() + 10 + while time.monotonic() < deadline: + try: psutil.Process(pid) - await asyncio.sleep(0.1) + except psutil.NoSuchProcess: + return # Process exited as expected + await asyncio.sleep(0.1) + pytest.fail(f"Process {pid} still alive after 10s") async def test_keep_alive_false_starts_new_session_across_multiple_calls( self, stdio_script diff --git a/tests/server/test_events.py b/tests/server/test_events.py new file mode 100644 index 0000000000..d27982d774 --- /dev/null +++ b/tests/server/test_events.py @@ -0,0 +1,2414 @@ +"""Tests for application-level MCP events support (Phase 2). + +Tests event declaration, emission, subscription, retained values, +session registry, context integration, and capabilities. +""" + +from __future__ import annotations + +import asyncio +import contextlib +import logging +from collections.abc import AsyncIterator +from typing import Any + +import pytest +from mcp.server.lowlevel.server import request_ctx +from mcp.shared.context import RequestContext +from mcp.types import EventSubscribeRequest, ServerNotification + +from fastmcp import Client, FastMCP +from fastmcp.server.context import Context +from fastmcp.server.events import ( + EventEmitNotification, + EventParams, + EventSubscribeParams, + EventSubscribeResult, + EventTopicDescriptor, + RetainedEvent, + RetainedValueStore, + SubscriptionRegistry, + _pattern_to_regex, +) + +# --------------------------------------------------------------------------- +# SubscriptionRegistry unit tests +# --------------------------------------------------------------------------- + + +class TestSubscriptionRegistry: + async def test_exact_match(self): + reg = SubscriptionRegistry() + await reg.add("s1", "myapp/status") + result = await reg.match("myapp/status") + assert result == {"s1"} + + async def test_no_match(self): + reg = SubscriptionRegistry() + await reg.add("s1", "myapp/status") + result = await reg.match("myapp/other") + assert result == set() + + async def test_plus_wildcard(self): + reg = SubscriptionRegistry() + await reg.add("s1", "myapp/+/messages") + assert await reg.match("myapp/session1/messages") == {"s1"} + assert await reg.match("myapp/session2/messages") == {"s1"} + # + does not match segment separator + assert await reg.match("myapp/a/b/messages") == set() + + async def test_hash_wildcard(self): + reg = SubscriptionRegistry() + await reg.add("s1", "myapp/#") + assert await reg.match("myapp/status") == {"s1"} + assert await reg.match("myapp/a/b/c") == {"s1"} + assert await reg.match("myapp") == {"s1"} + + async def test_hash_only_valid_last(self): + with pytest.raises(ValueError, match="last segment"): + _pattern_to_regex("myapp/#/invalid") + + async def test_at_most_once(self): + """A session with overlapping subscriptions receives at most once.""" + reg = SubscriptionRegistry() + await reg.add("s1", "myapp/+/messages") + await reg.add("s1", "myapp/#") + result = await reg.match("myapp/session1/messages") + assert result == {"s1"} + assert len(result) == 1 + + async def test_multiple_sessions(self): + reg = SubscriptionRegistry() + await reg.add("s1", "myapp/status") + await reg.add("s2", "myapp/status") + result = await reg.match("myapp/status") + assert result == {"s1", "s2"} + + async def test_remove(self): + reg = SubscriptionRegistry() + await reg.add("s1", "myapp/status") + await reg.remove("s1", "myapp/status") + result = await reg.match("myapp/status") + assert result == set() + + async def test_remove_all(self): + reg = SubscriptionRegistry() + await reg.add("s1", "myapp/status") + await reg.add("s1", "myapp/messages") + await reg.remove_all("s1") + assert await reg.match("myapp/status") == set() + assert await reg.match("myapp/messages") == set() + + async def test_get_subscriptions(self): + reg = SubscriptionRegistry() + await reg.add("s1", "myapp/status") + await reg.add("s1", "myapp/messages") + subs = await reg.get_subscriptions("s1") + assert subs == {"myapp/status", "myapp/messages"} + + async def test_get_subscriptions_empty(self): + reg = SubscriptionRegistry() + subs = await reg.get_subscriptions("nonexistent") + assert subs == set() + + +# --------------------------------------------------------------------------- +# RetainedValueStore unit tests +# --------------------------------------------------------------------------- + + +class TestRetainedValueStore: + async def test_set_and_get(self): + store = RetainedValueStore() + event = RetainedEvent(topic="t", eventId="e1", payload={"x": 1}) + await store.set("t", event) + assert await store.get("t") == event + + async def test_get_nonexistent(self): + store = RetainedValueStore() + assert await store.get("nonexistent") is None + + async def test_get_matching(self): + store = RetainedValueStore() + await store.set( + "myapp/a", RetainedEvent(topic="myapp/a", eventId="e1", payload=1) + ) + await store.set( + "myapp/b", RetainedEvent(topic="myapp/b", eventId="e2", payload=2) + ) + await store.set( + "other/c", RetainedEvent(topic="other/c", eventId="e3", payload=3) + ) + result = await store.get_matching("myapp/+") + assert len(result) == 2 + assert {e.event_id for e in result} == {"e1", "e2"} + # Verify payload values for each retained event + by_id = {e.event_id: e for e in result} + assert by_id["e1"].topic == "myapp/a" + assert by_id["e1"].payload == 1 + assert by_id["e2"].topic == "myapp/b" + assert by_id["e2"].payload == 2 + + async def test_delete(self): + store = RetainedValueStore() + event = RetainedEvent(topic="t", eventId="e1", payload=1) + await store.set("t", event) + await store.delete("t") + assert await store.get("t") is None + + async def test_expiry(self): + store = RetainedValueStore() + event = RetainedEvent(topic="t", eventId="e1", payload=1) + # Set with expired timestamp + await store.set("t", event, expires_at="2000-01-01T00:00:00Z") + assert await store.get("t") is None + + async def test_not_expired(self): + store = RetainedValueStore() + event = RetainedEvent(topic="t", eventId="e1", payload=1) + await store.set("t", event, expires_at="2099-01-01T00:00:00Z") + assert await store.get("t") == event + + async def test_get_matching_skips_expired(self): + store = RetainedValueStore() + await store.set( + "myapp/a", + RetainedEvent(topic="myapp/a", eventId="e1", payload=1), + expires_at="2000-01-01T00:00:00Z", + ) + await store.set( + "myapp/b", + RetainedEvent(topic="myapp/b", eventId="e2", payload=2), + ) + result = await store.get_matching("myapp/+") + assert len(result) == 1 + assert result[0].event_id == "e2" + + +# --------------------------------------------------------------------------- +# EventTopicDescriptor tests +# --------------------------------------------------------------------------- + + +class TestEventTopicDescriptor: + def test_basic_creation(self): + desc = EventTopicDescriptor( + pattern="myapp/status", kind="content", description="Status" + ) + assert desc.pattern == "myapp/status" + assert desc.kind == "content" + assert desc.description == "Status" + assert desc.retained is False + + def test_schema_alias(self): + desc = EventTopicDescriptor( + pattern="myapp/status", + kind="content", + schema={"type": "object"}, + ) + dumped = desc.model_dump(by_alias=True) + assert "schema" in dumped + assert dumped["schema"] == {"type": "object"} + + def test_suggested_handle_camel_case_alias(self): + desc = EventTopicDescriptor( + pattern="myapp/alerts", + kind="content", + suggestedHandle="inject", + ) + dumped = desc.model_dump(by_alias=True, exclude_none=True) + assert dumped["suggestedHandle"] == "inject" + + +# --------------------------------------------------------------------------- +# EventEmitNotification tests +# --------------------------------------------------------------------------- + + +class TestEventEmitNotification: + def test_creation(self): + notification = EventEmitNotification( + params=EventParams( + topic="myapp/status", + eventId="e1", + payload={"status": "running"}, + ) + ) + assert notification.method == "events/emit" + assert notification.params.topic == "myapp/status" + + def test_serialization(self): + notification = EventEmitNotification( + params=EventParams( + topic="myapp/status", + eventId="e1", + payload={"status": "running"}, + retained=True, + priority="high", + ) + ) + data = notification.model_dump(by_alias=True, exclude_none=True) + assert data["method"] == "events/emit" + assert data["params"]["topic"] == "myapp/status" + assert data["params"]["eventId"] == "e1" + assert data["params"]["payload"] == {"status": "running"} + assert data["params"]["retained"] is True + assert data["params"]["priority"] == "high" + # v2 removed fields must not appear + assert "requestedEffects" not in data["params"] + assert "correlationId" not in data["params"] + + +# --------------------------------------------------------------------------- +# FastMCP event declaration tests +# --------------------------------------------------------------------------- + + +class TestFastMCPEventDeclaration: + def test_declare_event(self): + mcp = FastMCP("test") + desc = mcp.declare_event( + "myapp/status", kind="content", description="Status", retained=True + ) + assert desc.pattern == "myapp/status" + assert desc.retained is True + assert "myapp/status" in mcp._event_topics + + def test_event_decorator(self): + mcp = FastMCP("test") + + @mcp.event("myapp/messages", kind="content") + def message_event() -> dict: + """Message notifications.""" + return {} + + assert "myapp/messages" in mcp._event_topics + desc = mcp._event_topics["myapp/messages"] + assert desc.description == "Message notifications." + + def test_event_decorator_schema_from_return_type(self): + mcp = FastMCP("test") + + @mcp.event("myapp/typed", kind="content") + def typed_event() -> int: + return 0 + + desc = mcp._event_topics["myapp/typed"] + assert desc.schema_ is not None + assert desc.schema_.get("type") == "integer" + + def test_multiple_topics(self): + mcp = FastMCP("test") + mcp.declare_event("a/b", kind="content") + mcp.declare_event("c/d", kind="content") + assert len(mcp._event_topics) == 2 + + +# --------------------------------------------------------------------------- +# FastMCP capability tests +# --------------------------------------------------------------------------- + + +class TestEventCapability: + async def test_capability_advertised(self): + """When event topics are declared, the events capability is advertised.""" + mcp = FastMCP("test") + mcp.declare_event("myapp/status", kind="content", description="Status updates") + + async with Client(mcp) as client: + # The client's initialize_result should have the events capability + result = client._session_state.initialize_result + assert result is not None + # events is a first-class field on ServerCapabilities + events_cap = result.capabilities.events + assert events_cap is not None, "Expected events capability to be set" + assert len(events_cap.topics) == 1, ( + f"Expected 1 topic, got {len(events_cap.topics)}" + ) + topic = events_cap.topics[0] + assert topic.pattern == "myapp/status" + assert topic.description == "Status updates" + assert topic.retained is False + + async def test_no_capability_without_topics(self): + """Without declared topics, events capability is not advertised.""" + mcp = FastMCP("test") + + async with Client(mcp) as client: + result = client._session_state.initialize_result + assert result is not None + assert result.capabilities.events is None + + +# --------------------------------------------------------------------------- +# FastMCP emit_event tests +# --------------------------------------------------------------------------- + + +class TestFastMCPEmitEvent: + async def test_emit_event_generates_id(self): + """emit_event generates a ULID if no event_id provided.""" + import re + + mcp = FastMCP("test") + mcp.declare_event("myapp/status", kind="content", retained=True) + + await mcp.emit_event("myapp/status", {"state": "running"}) + + # Verify a valid ULID was generated and stored via retained event + stored = await mcp._retained_store.get("myapp/status") + assert stored is not None, "Event should be stored as retained" + assert stored.event_id, "event_id should be non-empty" + # ULID is 26 characters, Crockford Base32 + assert re.fullmatch(r"[0-9A-HJKMNP-TV-Z]{26}", stored.event_id), ( + f"event_id should be a valid ULID string, got {stored.event_id!r}" + ) + + # Verify event is actually delivered to a subscribed session + received_notifications: list[Any] = [] + async with Client(mcp) as _client: + session = list(mcp._active_sessions.values())[0] + session_id = getattr(session, "_fastmcp_event_session_id") + await mcp._subscription_registry.add(session_id, "myapp/status") + + _original_send = session.send_notification + + async def capturing_send( + notification: ServerNotification, + related_request_id: str | int | None = None, + ) -> None: + received_notifications.append(notification) + + setattr(session, "send_notification", capturing_send) + + await mcp.emit_event("myapp/status", {"state": "updated"}) + + assert len(received_notifications) == 1, ( + "Event should be delivered to client" + ) + notif = received_notifications[0] + assert notif.params.event_id, "Delivered event should have an event_id" + assert re.fullmatch(r"[0-9A-HJKMNP-TV-Z]{26}", notif.params.event_id) + + async def test_emit_event_retained(self): + """emit_event stores retained value when topic is declared retained.""" + mcp = FastMCP("test") + mcp.declare_event("myapp/status", kind="content", retained=True) + + await mcp.emit_event("myapp/status", {"state": "running"}) + + stored = await mcp._retained_store.get("myapp/status") + assert stored is not None + assert stored.payload == {"state": "running"} + + async def test_emit_event_not_retained(self): + """emit_event does not store when topic is not retained.""" + mcp = FastMCP("test") + mcp.declare_event("myapp/status", kind="content", retained=False) + + await mcp.emit_event("myapp/status", {"state": "running"}) + + stored = await mcp._retained_store.get("myapp/status") + assert stored is None + + async def test_emit_event_explicit_retained_override(self): + """retained=True on emit overrides topic descriptor.""" + mcp = FastMCP("test") + mcp.declare_event("myapp/status", kind="content", retained=False) + + await mcp.emit_event("myapp/status", {"state": "running"}, retained=True) + + stored = await mcp._retained_store.get("myapp/status") + assert stored is not None + assert stored.topic == "myapp/status" + assert stored.payload == {"state": "running"} + assert stored.event_id, "Stored event should have an event_id" + + async def test_emit_event_with_expires_at(self): + """emit_event stores retained value with expires_at, and expired ones are cleaned.""" + mcp = FastMCP("test") + mcp.declare_event("myapp/status", kind="content", retained=True) + + # Emit with a far-future expiry - should be retrievable + await mcp.emit_event( + "myapp/status", + {"state": "running"}, + expires_at="2099-01-01T00:00:00Z", + ) + stored = await mcp._retained_store.get("myapp/status") + assert stored is not None, "Non-expired event should be retrievable" + assert stored.topic == "myapp/status" + assert stored.payload == {"state": "running"} + assert stored.event_id, "Stored event should have an event_id" + + # Emit with an already-past expiry - should NOT be retrievable + await mcp.emit_event( + "myapp/status", + {"state": "stopped"}, + expires_at="2000-01-01T00:00:00Z", + ) + stored = await mcp._retained_store.get("myapp/status") + assert stored is None, "Expired event should not be retrievable" + + # Verify expired events are cleaned from get_matching too + mcp2 = FastMCP("test2") + mcp2.declare_event("myapp/a", kind="content", retained=True) + mcp2.declare_event("myapp/b", kind="content", retained=True) + + await mcp2.emit_event( + "myapp/a", + {"val": "expired"}, + expires_at="2000-01-01T00:00:00Z", + ) + await mcp2.emit_event( + "myapp/b", + {"val": "valid"}, + expires_at="2099-01-01T00:00:00Z", + ) + matching = await mcp2._retained_store.get_matching("myapp/+") + assert len(matching) == 1, f"Expected 1 non-expired match, got {len(matching)}" + assert matching[0].topic == "myapp/b" + assert matching[0].payload == {"val": "valid"} + + +# --------------------------------------------------------------------------- +# Context integration tests +# --------------------------------------------------------------------------- + + +class TestContextEmitEvent: + async def test_emit_event_from_tool(self): + """Tools can emit events via ctx.emit_event().""" + mcp = FastMCP("test") + mcp.declare_event("myapp/notifications", kind="content") + + emitted_calls: list[dict[str, Any]] = [] + original_emit = mcp.emit_event + + async def tracking_emit( + topic: str, + payload: Any = None, + *, + priority: str = "normal", + source: str | None = None, + expires_at: str | None = None, + event_id: str | None = None, + retained: bool | None = None, + target_session_ids: Any = None, + ) -> None: + kwargs: dict[str, Any] = {"priority": priority} + if source is not None: + kwargs["source"] = source + if expires_at is not None: + kwargs["expires_at"] = expires_at + if event_id is not None: + kwargs["event_id"] = event_id + if retained is not None: + kwargs["retained"] = retained + if target_session_ids is not None: + kwargs["target_session_ids"] = target_session_ids + emitted_calls.append({"topic": topic, "payload": payload, **kwargs}) + await original_emit(topic, payload, **kwargs) + + setattr(mcp, "emit_event", tracking_emit) + + @mcp.tool + async def notify(message: str, ctx: Context) -> str: + await ctx.emit_event("myapp/notifications", {"text": message}) + return "sent" + + async with Client(mcp) as client: + result = await client.call_tool("notify", {"message": "hello"}) + assert result.data == "sent" + assert len(emitted_calls) == 1, ( + f"Expected exactly 1 emit call, got {len(emitted_calls)}" + ) + call = emitted_calls[0] + assert call["topic"] == "myapp/notifications" + assert call["payload"] == {"text": "hello"} + + +# --------------------------------------------------------------------------- +# Topic matching tests +# --------------------------------------------------------------------------- + + +class TestTopicMatching: + def test_exact_match(self): + mcp = FastMCP("test") + mcp.declare_event("myapp/status", kind="content") + assert mcp._match_declared_topic("myapp/status") is True + + def test_wildcard_plus_matches_param(self): + mcp = FastMCP("test") + mcp.declare_event("myapp/{agent_id}/messages", kind="content") + assert mcp._match_declared_topic("myapp/+/messages") is True + + def test_wildcard_hash_matches_param(self): + mcp = FastMCP("test") + mcp.declare_event("myapp/{agent_id}/messages", kind="content") + assert mcp._match_declared_topic("myapp/#") is True + + def test_no_match(self): + mcp = FastMCP("test") + mcp.declare_event("myapp/status", kind="content") + assert mcp._match_declared_topic("other/status") is False + + +# --------------------------------------------------------------------------- +# Parameterized topic pattern matching tests +# --------------------------------------------------------------------------- + + +class TestTopicMatchesPattern: + def test_exact_match(self): + assert FastMCP._topic_matches_pattern("a/b/c", "a/b/c") is True + + def test_single_param_match(self): + assert ( + FastMCP._topic_matches_pattern( + "myapp/worker-42/messages", + "myapp/{agent_id}/messages", + ) + is True + ) + + def test_multiple_params_match(self): + assert FastMCP._topic_matches_pattern("a/1/b/2/c", "a/{x}/b/{y}/c") is True + + def test_segment_count_mismatch(self): + assert FastMCP._topic_matches_pattern("a/b", "a/{x}/c") is False + + def test_literal_segment_mismatch(self): + assert ( + FastMCP._topic_matches_pattern( + "other/worker-42/messages", + "myapp/{agent_id}/messages", + ) + is False + ) + + def test_no_params_no_match(self): + assert FastMCP._topic_matches_pattern("a/b/c", "a/b/d") is False + + @pytest.mark.parametrize( + "topic, pattern", + [ + ("a//b", "a/{x}/b"), + ("a/{x}/b", "a//b"), + ("a//b", "a//b"), + ("/a/b", "/a/b"), + ("a/b/", "a/b/"), + ], + ids=[ + "empty-segment-in-topic", + "empty-segment-in-pattern", + "empty-segment-in-both", + "leading-slash-empty-first-segment", + "trailing-slash-empty-last-segment", + ], + ) + def test_empty_segments_rejected(self, topic: str, pattern: str): + assert FastMCP._topic_matches_pattern(topic, pattern) is False + + @pytest.mark.parametrize( + "topic, pattern", + [ + ("", "a/b"), + ("a/b", ""), + ("", ""), + ], + ids=["empty-topic", "empty-pattern", "both-empty"], + ) + def test_empty_strings_rejected(self, topic: str, pattern: str): + assert FastMCP._topic_matches_pattern(topic, pattern) is False + + +class TestFindTopicDescriptor: + def test_direct_match(self): + mcp = FastMCP("test") + mcp.declare_event("myapp/status", kind="content", retained=True) + desc = mcp._find_topic_descriptor("myapp/status") + assert desc is not None + assert desc.retained is True + + def test_parameterized_match(self): + mcp = FastMCP("test") + mcp.declare_event( + "spellbook/sessions/{agent_id}/messages", kind="content", retained=True + ) + desc = mcp._find_topic_descriptor("spellbook/sessions/worker-42/messages") + assert desc is not None + assert desc.retained is True + + def test_no_match_returns_none(self): + mcp = FastMCP("test") + mcp.declare_event("myapp/status", kind="content") + assert mcp._find_topic_descriptor("other/topic") is None + + +class TestEmitEventParameterizedRetained: + async def test_parameterized_retained_auto_stores(self): + """Emitting to a concrete topic that matches a parameterized + retained declaration auto-retains the event.""" + mcp = FastMCP("test") + mcp.declare_event( + "spellbook/sessions/{agent_id}/messages", kind="content", retained=True + ) + + await mcp.emit_event("spellbook/sessions/worker-42/messages", {"text": "hello"}) + + stored = await mcp._retained_store.get("spellbook/sessions/worker-42/messages") + assert stored is not None + assert stored.payload == {"text": "hello"} + + async def test_parameterized_not_retained(self): + """Emitting to a concrete topic that matches a parameterized + non-retained declaration does not retain.""" + mcp = FastMCP("test") + mcp.declare_event( + "spellbook/sessions/{agent_id}/messages", kind="content", retained=False + ) + + await mcp.emit_event("spellbook/sessions/worker-42/messages", {"text": "hello"}) + + stored = await mcp._retained_store.get("spellbook/sessions/worker-42/messages") + assert stored is None + + async def test_undeclared_topic_defaults_not_retained(self): + """Emitting to a topic that doesn't match any declaration + defaults retained to False.""" + mcp = FastMCP("test") + + await mcp.emit_event("unknown/topic", {"data": 1}) + + stored = await mcp._retained_store.get("unknown/topic") + assert stored is None + + +# --------------------------------------------------------------------------- +# Session registry and event delivery integration tests +# --------------------------------------------------------------------------- + + +class TestSessionRegistry: + async def test_session_registered_on_connect(self): + """Sessions are registered in _active_sessions on connect.""" + mcp = FastMCP("test") + assert len(mcp._active_sessions) == 0 + + async with Client(mcp) as _client: + assert len(mcp._active_sessions) == 1 + + # After disconnect + assert len(mcp._active_sessions) == 0 + + async def test_session_cleanup_on_disconnect(self): + """Session subscriptions are cleaned up on disconnect.""" + mcp = FastMCP("test") + mcp.declare_event("myapp/status", kind="content") + + async with Client(mcp) as _client: + assert len(mcp._active_sessions) == 1 + # Get the session ID + session = list(mcp._active_sessions.values())[0] + session_id = getattr(session, "_fastmcp_event_session_id") + assert session_id is not None + + # Manually add a subscription (normally done via events/subscribe) + await mcp._subscription_registry.add(session_id, "myapp/status") + subs = await mcp._subscription_registry.get_subscriptions(session_id) + assert len(subs) == 1 + + # After disconnect, subscriptions should be cleaned up + subs = await mcp._subscription_registry.get_subscriptions(session_id) + assert len(subs) == 0 + + async def test_emit_to_subscribed_session(self): + """Events are delivered to subscribed sessions.""" + mcp = FastMCP("test") + mcp.declare_event("myapp/status", kind="content") + + received_notifications: list[Any] = [] + + async with Client(mcp) as _client: + # Get session and subscribe + session = list(mcp._active_sessions.values())[0] + session_id = getattr(session, "_fastmcp_event_session_id") + await mcp._subscription_registry.add(session_id, "myapp/status") + + # Monkey-patch send_notification on the session to capture it + _original_send = session.send_notification + + async def capturing_send( + notification: ServerNotification, + related_request_id: str | int | None = None, + ) -> None: + received_notifications.append(notification) + # Don't actually send to avoid protocol issues in test + + setattr(session, "send_notification", capturing_send) + + await mcp.emit_event("myapp/status", {"state": "running"}) + + assert len(received_notifications) == 1 + notif = received_notifications[0] + assert notif.params.topic == "myapp/status" + assert notif.params.payload == {"state": "running"} + + async def test_emit_to_multiple_sessions(self): + """Events are broadcast to all matching sessions.""" + mcp = FastMCP("test") + mcp.declare_event("myapp/status", kind="content") + + received: dict[str, list] = {} + + async with Client(mcp) as _client1: + s1 = list(mcp._active_sessions.values())[0] + s1_id = getattr(s1, "_fastmcp_event_session_id") + await mcp._subscription_registry.add(s1_id, "myapp/status") + received[s1_id] = [] + + async with Client(mcp) as _client2: + s2 = [s for s in mcp._active_sessions.values() if s is not s1][0] + s2_id = getattr(s2, "_fastmcp_event_session_id") + await mcp._subscription_registry.add(s2_id, "myapp/status") + received[s2_id] = [] + + for s, sid in [(s1, s1_id), (s2, s2_id)]: + + async def make_capture(target_list: list[Any]) -> Any: + async def capture( + notification: ServerNotification, + related_request_id: str | int | None = None, + ) -> None: + target_list.append(notification) + + return capture + + setattr(s, "send_notification", await make_capture(received[sid])) + + await mcp.emit_event("myapp/status", {"state": "running"}) + + assert len(received[s1_id]) == 1 + assert len(received[s2_id]) == 1 + # Verify notification content for both sessions + for sid in [s1_id, s2_id]: + notif = received[sid][0] + assert notif.params.topic == "myapp/status" + assert notif.params.payload == {"state": "running"} + assert notif.params.event_id, ( + f"event_id should be non-empty for session {sid}" + ) + + async def test_emit_failure_does_not_block_others(self): + """Delivery failure to one session does not prevent delivery to others.""" + mcp = FastMCP("test") + mcp.declare_event("myapp/status", kind="content") + + delivered_to: list[tuple[str, Any]] = [] + + async with Client(mcp) as _client1: + s1 = list(mcp._active_sessions.values())[0] + s1_id = getattr(s1, "_fastmcp_event_session_id") + await mcp._subscription_registry.add(s1_id, "myapp/status") + + async with Client(mcp) as _client2: + s2 = [s for s in mcp._active_sessions.values() if s is not s1][0] + s2_id = getattr(s2, "_fastmcp_event_session_id") + await mcp._subscription_registry.add(s2_id, "myapp/status") + + # Make s1 fail on send + async def failing_send( + notification: ServerNotification, + related_request_id: str | int | None = None, + ) -> None: + raise ConnectionError("broken pipe") + + setattr(s1, "send_notification", failing_send) + + async def tracking_send( + notification: ServerNotification, + related_request_id: str | int | None = None, + ) -> None: + delivered_to.append((s2_id, notification)) + + setattr(s2, "send_notification", tracking_send) + + await mcp.emit_event("myapp/status", {"state": "running"}) + + # s2 should still receive despite s1 failure + assert len(delivered_to) == 1, ( + f"Expected exactly 1 delivery, got {len(delivered_to)}" + ) + sid, notif = delivered_to[0] + assert sid == s2_id + assert notif.params.topic == "myapp/status" + assert notif.params.payload == {"state": "running"} + assert notif.params.event_id, "Delivered event should have an event_id" + + +# --------------------------------------------------------------------------- +# Protocol-layer subscribe/unsubscribe tests (full round-trip) +# --------------------------------------------------------------------------- + + +class TestProtocolRoundTrip: + @pytest.fixture + def timeout(self): + return 10 + + async def test_full_subscribe_emit_unsubscribe_cycle(self): + """Test the full protocol path: subscribe, receive event, unsubscribe, no more events. + + This exercises the _receive_loop interception, _handle_event_request, + and the complete subscribe/unsubscribe/emit flow at the JSON-RPC level. + """ + import anyio + from mcp.shared.memory import create_client_server_memory_streams + from mcp.shared.message import SessionMessage + from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse + + mcp_server = FastMCP("test-events") + mcp_server.declare_event( + "myapp/status", kind="content", description="Status updates" + ) + + async with mcp_server._lifespan_manager(): + async with create_client_server_memory_streams() as ( + client_streams, + server_streams, + ): + client_read, client_write = client_streams + server_read, server_write = server_streams + + async with anyio.create_task_group() as tg: + tg.start_soon( + lambda: mcp_server._mcp_server.run( + server_read, + server_write, + mcp_server._mcp_server.create_initialization_options(), + raise_exceptions=True, + ) + ) + + try: + # ---- Step 0: Initialize (required handshake) ---- + init_request = JSONRPCRequest( + jsonrpc="2.0", + id=1, + method="initialize", + params={ + "protocolVersion": "2025-03-26", + "capabilities": {}, + "clientInfo": { + "name": "test-client", + "version": "1.0", + }, + }, + ) + await client_write.send( + SessionMessage(message=JSONRPCMessage(init_request)) + ) + init_resp = await client_read.receive() + assert not isinstance(init_resp, Exception) + assert isinstance(init_resp.message.root, JSONRPCResponse) + + # Send initialized notification + from mcp.types import JSONRPCNotification + + initialized_notif = JSONRPCNotification( + jsonrpc="2.0", + method="notifications/initialized", + ) + await client_write.send( + SessionMessage(message=JSONRPCMessage(initialized_notif)) + ) + + # Wait for server to register the session + for _ in range(50): + if mcp_server._active_sessions: + break + await asyncio.sleep(0.05) + assert len(mcp_server._active_sessions) == 1 + + server_session = list(mcp_server._active_sessions.values())[0] + session_id = getattr( + server_session, "_fastmcp_event_session_id" + ) + + # ---- Step 1: Subscribe via raw JSON-RPC ---- + subscribe_request = JSONRPCRequest( + jsonrpc="2.0", + id=2, + method="events/subscribe", + params={"topics": ["myapp/status"]}, + ) + await client_write.send( + SessionMessage(message=JSONRPCMessage(subscribe_request)) + ) + + sub_resp = await client_read.receive() + assert not isinstance(sub_resp, Exception) + response = sub_resp.message.root + assert isinstance(response, JSONRPCResponse), ( + f"Expected result, got: {response}" + ) + result = response.result + assert len(result["subscribed"]) == 1 + assert result["subscribed"][0]["pattern"] == "myapp/status" + + # Verify subscription was registered + subs = ( + await mcp_server._subscription_registry.get_subscriptions( + session_id + ) + ) + assert "myapp/status" in subs + + # ---- Step 2: Server emits event, client receives it ---- + await mcp_server.emit_event( + "myapp/status", {"state": "running"} + ) + + # The event notification should arrive on client_read + event_msg = await client_read.receive() + assert not isinstance(event_msg, Exception) + event_root = event_msg.message.root + # Should be a notification (no id field with result) + assert isinstance(event_root, JSONRPCNotification) + assert event_root.method == "events/emit" + assert event_root.params is not None + assert event_root.params["topic"] == "myapp/status" + assert event_root.params["payload"] == {"state": "running"} + + # ---- Step 3: Unsubscribe ---- + unsubscribe_request = JSONRPCRequest( + jsonrpc="2.0", + id=3, + method="events/unsubscribe", + params={"topics": ["myapp/status"]}, + ) + await client_write.send( + SessionMessage(message=JSONRPCMessage(unsubscribe_request)) + ) + + unsub_resp = await client_read.receive() + assert not isinstance(unsub_resp, Exception) + unsub_root = unsub_resp.message.root + assert isinstance(unsub_root, JSONRPCResponse) + assert "myapp/status" in unsub_root.result["unsubscribed"] + + # Verify subscription was removed + subs = ( + await mcp_server._subscription_registry.get_subscriptions( + session_id + ) + ) + assert "myapp/status" not in subs + + # ---- Step 4: Emit again, verify no delivery ---- + # Subscription registry has no match, so emit + # won't attempt delivery (proven by step 2) + matching = await mcp_server._subscription_registry.match( + "myapp/status" + ) + assert len(matching) == 0, ( + "No sessions should match after unsubscribe" + ) + + finally: + tg.cancel_scope.cancel() + + async def test_events_list_via_protocol(self): + """events/list returns declared topics via raw JSON-RPC.""" + import anyio + from mcp.shared.memory import create_client_server_memory_streams + from mcp.shared.message import SessionMessage + from mcp.types import ( + JSONRPCMessage, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + ) + + mcp_server = FastMCP("test-events-list") + mcp_server.declare_event( + "myapp/status", kind="content", description="Status updates", retained=True + ) + mcp_server.declare_event("myapp/logs", kind="content", description="Log stream") + + async with mcp_server._lifespan_manager(): + async with create_client_server_memory_streams() as ( + client_streams, + server_streams, + ): + client_read, client_write = client_streams + server_read, server_write = server_streams + + async with anyio.create_task_group() as tg: + tg.start_soon( + lambda: mcp_server._mcp_server.run( + server_read, + server_write, + mcp_server._mcp_server.create_initialization_options(), + raise_exceptions=True, + ) + ) + + try: + # Initialize handshake + init_request = JSONRPCRequest( + jsonrpc="2.0", + id=1, + method="initialize", + params={ + "protocolVersion": "2025-03-26", + "capabilities": {}, + "clientInfo": {"name": "test-client", "version": "1.0"}, + }, + ) + await client_write.send( + SessionMessage(message=JSONRPCMessage(init_request)) + ) + await client_read.receive() + await client_write.send( + SessionMessage( + message=JSONRPCMessage( + JSONRPCNotification( + jsonrpc="2.0", + method="notifications/initialized", + ) + ) + ) + ) + for _ in range(50): + if mcp_server._active_sessions: + break + await asyncio.sleep(0.05) + + # Send events/list request + list_request = JSONRPCRequest( + jsonrpc="2.0", + id=2, + method="events/list", + params={}, + ) + await client_write.send( + SessionMessage(message=JSONRPCMessage(list_request)) + ) + + list_resp = await client_read.receive() + assert not isinstance(list_resp, Exception) + response = list_resp.message.root + assert isinstance(response, JSONRPCResponse), ( + f"Expected result, got: {response}" + ) + result = response.result + assert "topics" in result + topics = result["topics"] + assert len(topics) == 2 + patterns = {t["pattern"] for t in topics} + assert patterns == {"myapp/status", "myapp/logs"} + # Verify topic details + by_pattern = {t["pattern"]: t for t in topics} + assert ( + by_pattern["myapp/status"]["description"] + == "Status updates" + ) + assert by_pattern["myapp/status"]["retained"] is True + assert by_pattern["myapp/logs"]["description"] == "Log stream" + assert by_pattern["myapp/logs"]["retained"] is False + finally: + tg.cancel_scope.cancel() + + +# --------------------------------------------------------------------------- +# Error path tests (Finding 15/16) +# --------------------------------------------------------------------------- + + +class TestErrorPaths: + async def test_malformed_request_returns_json_rpc_error(self): + """A malformed/unknown request returns a JSON-RPC error, not a crash.""" + import anyio + from mcp.shared.memory import create_client_server_memory_streams + from mcp.shared.message import SessionMessage + from mcp.types import ( + JSONRPCError, + JSONRPCMessage, + JSONRPCNotification, + JSONRPCRequest, + ) + + mcp_server = FastMCP("test-malformed") + + async with mcp_server._lifespan_manager(): + async with create_client_server_memory_streams() as ( + client_streams, + server_streams, + ): + client_read, client_write = client_streams + server_read, server_write = server_streams + + async with anyio.create_task_group() as tg: + tg.start_soon( + lambda: mcp_server._mcp_server.run( + server_read, + server_write, + mcp_server._mcp_server.create_initialization_options(), + raise_exceptions=True, + ) + ) + + try: + # Initialize + init_req = JSONRPCRequest( + jsonrpc="2.0", + id=1, + method="initialize", + params={ + "protocolVersion": "2025-03-26", + "capabilities": {}, + "clientInfo": {"name": "test", "version": "1.0"}, + }, + ) + await client_write.send( + SessionMessage(message=JSONRPCMessage(init_req)) + ) + await client_read.receive() + await client_write.send( + SessionMessage( + message=JSONRPCMessage( + JSONRPCNotification( + jsonrpc="2.0", + method="notifications/initialized", + ) + ) + ) + ) + for _ in range(50): + if mcp_server._active_sessions: + break + await asyncio.sleep(0.05) + + # Send a completely bogus method + bogus_request = JSONRPCRequest( + jsonrpc="2.0", + id=99, + method="nonexistent/method", + params={}, + ) + await client_write.send( + SessionMessage(message=JSONRPCMessage(bogus_request)) + ) + + resp_msg = await client_read.receive() + assert not isinstance(resp_msg, Exception) + response = resp_msg.message.root + # Should be an error response, not a crash + assert isinstance(response, JSONRPCError), ( + f"Expected JSON-RPC error for unknown method, got: {response}" + ) + finally: + tg.cancel_scope.cancel() + + async def test_event_handler_error_returns_json_rpc_error(self): + """When events/subscribe is sent with invalid params, it returns a JSON-RPC error.""" + import anyio + from mcp.shared.memory import create_client_server_memory_streams + from mcp.shared.message import SessionMessage + from mcp.types import ( + JSONRPCError, + JSONRPCMessage, + JSONRPCNotification, + JSONRPCRequest, + ) + + mcp_server = FastMCP("test-event-error") + mcp_server.declare_event("myapp/status", kind="content") + + async with mcp_server._lifespan_manager(): + async with create_client_server_memory_streams() as ( + client_streams, + server_streams, + ): + client_read, client_write = client_streams + server_read, server_write = server_streams + + async with anyio.create_task_group() as tg: + tg.start_soon( + lambda: mcp_server._mcp_server.run( + server_read, + server_write, + mcp_server._mcp_server.create_initialization_options(), + raise_exceptions=True, + ) + ) + + try: + # Initialize + init_req = JSONRPCRequest( + jsonrpc="2.0", + id=1, + method="initialize", + params={ + "protocolVersion": "2025-03-26", + "capabilities": {}, + "clientInfo": {"name": "test", "version": "1.0"}, + }, + ) + await client_write.send( + SessionMessage(message=JSONRPCMessage(init_req)) + ) + await client_read.receive() + await client_write.send( + SessionMessage( + message=JSONRPCMessage( + JSONRPCNotification( + jsonrpc="2.0", + method="notifications/initialized", + ) + ) + ) + ) + for _ in range(50): + if mcp_server._active_sessions: + break + await asyncio.sleep(0.05) + + # Subscribe with invalid params (missing topics field) + # The SDK validates the request params before dispatching + # to the handler, returning INVALID_PARAMS (-32602) + bad_sub_req = JSONRPCRequest( + jsonrpc="2.0", + id=2, + method="events/subscribe", + params={"invalid_field": "value"}, + ) + await client_write.send( + SessionMessage(message=JSONRPCMessage(bad_sub_req)) + ) + + resp_msg = await client_read.receive() + assert not isinstance(resp_msg, Exception) + response = resp_msg.message.root + # Should get a JSON-RPC error, not a crash + assert isinstance(response, JSONRPCError), ( + f"Expected JSON-RPC error for bad event params, got: {response}" + ) + assert response.error.code == -32602 # Invalid params + finally: + tg.cancel_scope.cancel() + + +# --------------------------------------------------------------------------- +# Topic depth enforcement tests +# --------------------------------------------------------------------------- + + +class TestTopicDepthEnforcement: + def test_declare_event_rejects_deep_topic(self): + """declare_event rejects patterns with more than 8 segments.""" + mcp_server = FastMCP("test") + # 9 segments should be rejected + with pytest.raises(ValueError, match="maximum depth is 8"): + mcp_server.declare_event("a/b/c/d/e/f/g/h/i", kind="content") + + def test_declare_event_accepts_max_depth(self): + """declare_event accepts patterns with exactly 8 segments.""" + mcp_server = FastMCP("test") + # Exactly 8 segments should work + desc = mcp_server.declare_event("a/b/c/d/e/f/g/h", kind="content") + assert desc.pattern == "a/b/c/d/e/f/g/h" + + async def test_subscribe_rejects_deep_pattern(self): + """events/subscribe rejects patterns with more than 8 segments via JSON-RPC.""" + import anyio + from mcp.shared.memory import create_client_server_memory_streams + from mcp.shared.message import SessionMessage + from mcp.types import ( + JSONRPCError, + JSONRPCMessage, + JSONRPCNotification, + JSONRPCRequest, + ) + + mcp_server = FastMCP("test") + mcp_server.declare_event("myapp/status", kind="content") + + async with mcp_server._lifespan_manager(): + async with create_client_server_memory_streams() as ( + client_streams, + server_streams, + ): + client_read, client_write = client_streams + server_read, server_write = server_streams + + async with anyio.create_task_group() as tg: + tg.start_soon( + lambda: mcp_server._mcp_server.run( + server_read, + server_write, + mcp_server._mcp_server.create_initialization_options(), + raise_exceptions=True, + ) + ) + + try: + # Initialize + init_req = JSONRPCRequest( + jsonrpc="2.0", + id=1, + method="initialize", + params={ + "protocolVersion": "2025-03-26", + "capabilities": {}, + "clientInfo": {"name": "test", "version": "1.0"}, + }, + ) + await client_write.send( + SessionMessage(message=JSONRPCMessage(init_req)) + ) + await client_read.receive() + await client_write.send( + SessionMessage( + message=JSONRPCMessage( + JSONRPCNotification( + jsonrpc="2.0", + method="notifications/initialized", + ) + ) + ) + ) + for _ in range(50): + if mcp_server._active_sessions: + break + await asyncio.sleep(0.05) + + # Subscribe with too-deep pattern + sub_req = JSONRPCRequest( + jsonrpc="2.0", + id=2, + method="events/subscribe", + params={"topics": ["a/b/c/d/e/f/g/h/i"]}, + ) + await client_write.send( + SessionMessage(message=JSONRPCMessage(sub_req)) + ) + + resp_msg = await client_read.receive() + assert not isinstance(resp_msg, Exception) + response = resp_msg.message.root + assert isinstance(response, JSONRPCError), ( + f"Expected error for deep pattern, got: {response}" + ) + assert response.error.code == -32602 + assert "maximum depth" in response.error.message + finally: + tg.cancel_scope.cancel() + + +# --------------------------------------------------------------------------- +# No-events-capability error tests +# --------------------------------------------------------------------------- + + +class TestNoEventsCapability: + async def test_events_method_returns_error_without_capability(self): + """events/* methods return -32601 when server has no declared event topics.""" + import anyio + from mcp.shared.memory import create_client_server_memory_streams + from mcp.shared.message import SessionMessage + from mcp.types import ( + JSONRPCError, + JSONRPCMessage, + JSONRPCNotification, + JSONRPCRequest, + ) + + mcp_server = FastMCP("test-no-events") + # No events declared + + async with mcp_server._lifespan_manager(): + async with create_client_server_memory_streams() as ( + client_streams, + server_streams, + ): + client_read, client_write = client_streams + server_read, server_write = server_streams + + async with anyio.create_task_group() as tg: + tg.start_soon( + lambda: mcp_server._mcp_server.run( + server_read, + server_write, + mcp_server._mcp_server.create_initialization_options(), + raise_exceptions=True, + ) + ) + + try: + # Initialize + init_req = JSONRPCRequest( + jsonrpc="2.0", + id=1, + method="initialize", + params={ + "protocolVersion": "2025-03-26", + "capabilities": {}, + "clientInfo": {"name": "test", "version": "1.0"}, + }, + ) + await client_write.send( + SessionMessage(message=JSONRPCMessage(init_req)) + ) + await client_read.receive() + await client_write.send( + SessionMessage( + message=JSONRPCMessage( + JSONRPCNotification( + jsonrpc="2.0", + method="notifications/initialized", + ) + ) + ) + ) + for _ in range(50): + if mcp_server._active_sessions: + break + await asyncio.sleep(0.05) + + # Send events/subscribe - should fail with -32601 + sub_req = JSONRPCRequest( + jsonrpc="2.0", + id=2, + method="events/subscribe", + params={"topics": ["anything"]}, + ) + await client_write.send( + SessionMessage(message=JSONRPCMessage(sub_req)) + ) + + resp_msg = await client_read.receive() + assert not isinstance(resp_msg, Exception) + response = resp_msg.message.root + assert isinstance(response, JSONRPCError), ( + f"Expected error without events capability, got: {response}" + ) + assert response.error.code == -32601 + assert "Method not found" in response.error.message + finally: + tg.cancel_scope.cancel() + + +# --------------------------------------------------------------------------- +# Authorization helpers +# --------------------------------------------------------------------------- + + +@contextlib.asynccontextmanager +async def _request_ctx_for_session(session: Any) -> AsyncIterator[None]: + """Push a minimal RequestContext containing ``session`` for the duration. + + The event subscribe handler reads the session from ``request_ctx``; tests + that invoke ``_handle_subscribe_events`` directly need this context to be + populated. + """ + rctx = RequestContext( + request_id=0, + meta=None, + session=session, + lifespan_context=None, + experimental=None, + ) + token = request_ctx.set(rctx) # type: ignore[arg-type] + try: + yield + finally: + request_ctx.reset(token) + + +async def _subscribe_via_handler( + mcp: FastMCP, session: Any, topics: list[str] +) -> EventSubscribeResult: + """Drive ``_handle_subscribe_events`` using the active session. + + Exercises the full authorization path (the same code that the JSON-RPC + handler runs) without the verbosity of a raw protocol round-trip. + """ + req = EventSubscribeRequest(params=EventSubscribeParams(topics=topics)) + async with _request_ctx_for_session(session): + return await mcp._handle_subscribe_events(req) + + +def _get_active_session(mcp: FastMCP) -> Any: + """Return the single active session, asserting there is exactly one.""" + sessions = list(mcp._active_sessions.values()) + assert len(sessions) == 1, f"Expected 1 active session, got {len(sessions)}" + return sessions[0] + + +# --------------------------------------------------------------------------- +# InitializeResult._meta.session_id exposure +# --------------------------------------------------------------------------- + + +class TestInitializeResultSessionId: + async def test_initialize_result_meta_contains_session_id(self): + """The initialize handshake exposes the server-side session_id via _meta.""" + mcp = FastMCP("test") + + async with Client(mcp) as client: + init_result = client._session_state.initialize_result + assert init_result is not None + assert init_result.meta is not None + session_id = init_result.meta.get("session_id") # type: ignore[union-attr] + assert isinstance(session_id, str) and session_id, ( + "session_id should be a non-empty string" + ) + + server_session = _get_active_session(mcp) + server_side_id = getattr(server_session, "_fastmcp_event_session_id") + assert session_id == server_side_id + + async def test_initialize_result_meta_session_id_is_stable_per_session(self): + """Multiple operations within the same session see a stable id.""" + mcp = FastMCP("test") + + @mcp.tool + def ping() -> str: + return "pong" + + async with Client(mcp) as client: + init_result = client._session_state.initialize_result + assert init_result is not None + assert init_result.meta is not None + session_id_first = init_result.meta.get("session_id") # type: ignore[union-attr] + + # Issue several operations and confirm the underlying session and + # its id remain stable. + await client.call_tool("ping", {}) + await client.call_tool("ping", {}) + + server_session = _get_active_session(mcp) + server_side_id = getattr(server_session, "_fastmcp_event_session_id") + assert session_id_first == server_side_id + + init_result_after = client._session_state.initialize_result + assert init_result_after is not None + assert init_result_after.meta is not None + assert ( + init_result_after.meta.get("session_id") # type: ignore[union-attr] + == session_id_first + ) + + +# --------------------------------------------------------------------------- +# Default authorization policy (no authorize callback): permissive +# --------------------------------------------------------------------------- + + +class TestDefaultAuthorizationPolicy: + """Without an explicit ``authorize`` callback any matching subscribe is + allowed. ``{agent_id}`` has no special meaning in fastmcp's authorization + path per MCP Events v2: multiple agents may share one transport, so the + transport session UUID is not a reliable agent identity. Per-agent + isolation is opt-in via ``declare_event(authorize=...)``. + """ + + async def test_public_topic_allows_any_subscriber(self): + mcp = FastMCP("test") + mcp.declare_event("spellbook/server/status", kind="content") + + async with Client(mcp) as _client: + session = _get_active_session(mcp) + result = await _subscribe_via_handler( + mcp, session, ["spellbook/server/status"] + ) + + assert len(result.subscribed) == 1 + assert result.rejected == [] + + async def test_agent_id_literal_allowed_by_default(self): + """A concrete ``{agent_id}`` slot value is allowed without a callback. + The server has no reliable way to bind a subscriber to a specific + agent identity on a shared transport, so default policy permits any + literal in that slot. + """ + mcp = FastMCP("test") + mcp.declare_event("spellbook/sessions/{agent_id}/messages", kind="content") + + async with Client(mcp) as _client: + session = _get_active_session(mcp) + result = await _subscribe_via_handler( + mcp, + session, + ["spellbook/sessions/00000000-0000-0000-0000-000000000000/messages"], + ) + + assert len(result.subscribed) == 1 + assert ( + result.subscribed[0].pattern + == "spellbook/sessions/00000000-0000-0000-0000-000000000000/messages" + ) + assert result.rejected == [] + + async def test_single_wildcard_in_agent_id_slot_allowed_by_default(self): + mcp = FastMCP("test") + mcp.declare_event("spellbook/sessions/{agent_id}/messages", kind="content") + + async with Client(mcp) as _client: + session = _get_active_session(mcp) + result = await _subscribe_via_handler( + mcp, session, ["spellbook/sessions/+/messages"] + ) + + assert len(result.subscribed) == 1 + assert result.rejected == [] + + async def test_hash_wildcard_over_agent_id_slot_allowed_by_default(self): + mcp = FastMCP("test") + mcp.declare_event("spellbook/sessions/{agent_id}/messages", kind="content") + + async with Client(mcp) as _client: + session = _get_active_session(mcp) + result = await _subscribe_via_handler( + mcp, session, ["spellbook/sessions/#"] + ) + + assert len(result.subscribed) == 1 + assert result.rejected == [] + + async def test_non_agent_placeholder_allows_wildcard(self): + """A non-``{agent_id}`` placeholder accepts wildcards with no callback.""" + mcp = FastMCP("test") + mcp.declare_event("spellbook/builds/{project}/status", kind="content") + + async with Client(mcp) as _client: + session = _get_active_session(mcp) + result = await _subscribe_via_handler( + mcp, session, ["spellbook/builds/+/status"] + ) + + assert len(result.subscribed) == 1 + assert result.rejected == [] + + +# --------------------------------------------------------------------------- +# authorize callback escape hatch +# --------------------------------------------------------------------------- + + +class TestAuthorizeCallback: + async def test_authorize_callback_called_with_correct_params(self): + captured: list[tuple[str, dict[str, str]]] = [] + + def authorize(session_id: str, params: dict[str, str]) -> bool: + captured.append((session_id, params)) + return True + + mcp = FastMCP("test") + mcp.declare_event("rooms/{room}/chat", kind="content", authorize=authorize) + + async with Client(mcp) as _client: + session = _get_active_session(mcp) + sid = getattr(session, "_fastmcp_event_session_id") + result = await _subscribe_via_handler(mcp, session, ["rooms/lobby/chat"]) + + assert len(result.subscribed) == 1 + assert captured == [(sid, {"room": "lobby"})] + + async def test_authorize_callback_receives_wildcard_literal(self): + captured: list[tuple[str, dict[str, str]]] = [] + + def authorize(session_id: str, params: dict[str, str]) -> bool: + captured.append((session_id, params)) + return True + + mcp = FastMCP("test") + mcp.declare_event("rooms/{room}/chat", kind="content", authorize=authorize) + + async with Client(mcp) as _client: + session = _get_active_session(mcp) + sid = getattr(session, "_fastmcp_event_session_id") + result = await _subscribe_via_handler(mcp, session, ["rooms/+/chat"]) + + assert len(result.subscribed) == 1 + assert captured == [(sid, {"room": "+"})] + + async def test_authorize_callback_denies_rejects_subscription(self): + def authorize(session_id: str, params: dict[str, str]) -> bool: + return False + + mcp = FastMCP("test") + mcp.declare_event("rooms/{room}/chat", kind="content", authorize=authorize) + + async with Client(mcp) as _client: + session = _get_active_session(mcp) + result = await _subscribe_via_handler(mcp, session, ["rooms/lobby/chat"]) + + assert result.subscribed == [] + assert len(result.rejected) == 1 + assert result.rejected[0].reason == "permission_denied" + + async def test_authorize_callback_exception_fails_closed( + self, caplog: pytest.LogCaptureFixture + ): + def authorize(session_id: str, params: dict[str, str]) -> bool: + raise RuntimeError("intentional failure for test") + + mcp = FastMCP("test") + mcp.declare_event("rooms/{room}/chat", kind="content", authorize=authorize) + + async with Client(mcp) as _client: + session = _get_active_session(mcp) + with caplog.at_level( + logging.WARNING, logger="fastmcp.server.mixins.mcp_operations" + ): + result = await _subscribe_via_handler( + mcp, session, ["rooms/lobby/chat"] + ) + + assert result.subscribed == [] + assert len(result.rejected) == 1 + assert result.rejected[0].reason == "permission_denied" + assert any( + "authorize callback raised" in record.message for record in caplog.records + ), "Expected a warning log when authorize raises" + + async def test_authorize_callback_can_reject_wildcard_by_inspecting_params( + self, + ): + """A callback can reject wildcard subscribes on an ``{agent_id}`` slot + by inspecting ``topic_params`` and refusing the wildcard literal. + This is how per-agent isolation is opted into under v2.""" + + def authorize(session_id: str, params: dict[str, str]) -> bool: + return params.get("agent_id") not in ("+", "#") + + mcp = FastMCP("test") + mcp.declare_event( + "sessions/{agent_id}/messages", kind="content", authorize=authorize + ) + + async with Client(mcp) as _client: + session = _get_active_session(mcp) + # Concrete literal: permitted. + ok = await _subscribe_via_handler( + mcp, + session, + ["sessions/00000000-0000-0000-0000-000000000000/messages"], + ) + # Single-segment wildcard: rejected by the callback. + plus = await _subscribe_via_handler(mcp, session, ["sessions/+/messages"]) + # Multi-segment wildcard over the agent_id slot: also rejected. + hashed = await _subscribe_via_handler(mcp, session, ["sessions/#"]) + + assert len(ok.subscribed) == 1 + assert ok.rejected == [] + + assert plus.subscribed == [] + assert len(plus.rejected) == 1 + assert plus.rejected[0].reason == "permission_denied" + + assert hashed.subscribed == [] + assert len(hashed.rejected) == 1 + assert hashed.rejected[0].reason == "permission_denied" + + async def test_authorize_callback_can_bind_specific_agent_id(self): + """A callback can implement per-session/agent binding by comparing + ``topic_params["agent_id"]`` against a server-maintained binding + between session IDs and permitted agent identities.""" + + bindings: dict[str, set[str]] = {} + + def authorize(session_id: str, params: dict[str, str]) -> bool: + permitted = bindings.get(session_id, set()) + return params.get("agent_id") in permitted + + mcp = FastMCP("test") + mcp.declare_event( + "sessions/{agent_id}/messages", kind="content", authorize=authorize + ) + + async with Client(mcp) as _client: + session = _get_active_session(mcp) + sid = getattr(session, "_fastmcp_event_session_id") + # Bind this session to a specific agent identity out of band. + bindings[sid] = {"agent-alpha"} + + ok = await _subscribe_via_handler( + mcp, session, ["sessions/agent-alpha/messages"] + ) + bad = await _subscribe_via_handler( + mcp, session, ["sessions/agent-beta/messages"] + ) + + assert len(ok.subscribed) == 1 + assert ok.rejected == [] + + assert bad.subscribed == [] + assert len(bad.rejected) == 1 + assert bad.rejected[0].reason == "permission_denied" + + +# --------------------------------------------------------------------------- +# target_session_ids on emit_event +# --------------------------------------------------------------------------- + + +class TestTargetSessionIds: + async def _capture_session(self, session: Any, sink: list[Any]) -> None: + async def capturing_send( + notification: ServerNotification, + related_request_id: str | int | None = None, + ) -> None: + sink.append(notification) + + setattr(session, "send_notification", capturing_send) + + async def test_emit_without_target_session_ids_broadcasts(self): + """Default behavior: every matching subscriber receives the event.""" + mcp = FastMCP("test") + mcp.declare_event("public/topic", kind="content") + + sinks: dict[str, list[Any]] = {} + + async with Client(mcp) as _c1: + s1 = _get_active_session(mcp) + s1_id = getattr(s1, "_fastmcp_event_session_id") + await mcp._subscription_registry.add(s1_id, "public/topic") + sinks[s1_id] = [] + await self._capture_session(s1, sinks[s1_id]) + + async with Client(mcp) as _c2: + s2 = next(s for s in mcp._active_sessions.values() if s is not s1) + s2_id = getattr(s2, "_fastmcp_event_session_id") + await mcp._subscription_registry.add(s2_id, "public/topic") + sinks[s2_id] = [] + await self._capture_session(s2, sinks[s2_id]) + + async with Client(mcp) as _c3: + s3 = next( + s + for s in mcp._active_sessions.values() + if s is not s1 and s is not s2 + ) + s3_id = getattr(s3, "_fastmcp_event_session_id") + await mcp._subscription_registry.add(s3_id, "public/topic") + sinks[s3_id] = [] + await self._capture_session(s3, sinks[s3_id]) + + await mcp.emit_event("public/topic", {"v": 1}) + + assert len(sinks[s1_id]) == 1 + assert len(sinks[s2_id]) == 1 + assert len(sinks[s3_id]) == 1 + + async def test_emit_with_target_session_ids_filters(self): + mcp = FastMCP("test") + mcp.declare_event("public/topic", kind="content") + + sinks: dict[str, list[Any]] = {} + + async with Client(mcp) as _c1: + s1 = _get_active_session(mcp) + s1_id = getattr(s1, "_fastmcp_event_session_id") + await mcp._subscription_registry.add(s1_id, "public/topic") + sinks[s1_id] = [] + await self._capture_session(s1, sinks[s1_id]) + + async with Client(mcp) as _c2: + s2 = next(s for s in mcp._active_sessions.values() if s is not s1) + s2_id = getattr(s2, "_fastmcp_event_session_id") + await mcp._subscription_registry.add(s2_id, "public/topic") + sinks[s2_id] = [] + await self._capture_session(s2, sinks[s2_id]) + + async with Client(mcp) as _c3: + s3 = next( + s + for s in mcp._active_sessions.values() + if s is not s1 and s is not s2 + ) + s3_id = getattr(s3, "_fastmcp_event_session_id") + await mcp._subscription_registry.add(s3_id, "public/topic") + sinks[s3_id] = [] + await self._capture_session(s3, sinks[s3_id]) + + await mcp.emit_event( + "public/topic", + {"v": 1}, + target_session_ids=[s1_id, s2_id], + ) + + assert len(sinks[s1_id]) == 1 + assert len(sinks[s2_id]) == 1 + assert sinks[s3_id] == [] + + async def test_emit_with_target_session_ids_intersection_empty_is_noop(self): + """A target list that overlaps no subscribers delivers nothing, no error.""" + mcp = FastMCP("test") + mcp.declare_event("public/topic", kind="content") + + async with Client(mcp) as _client: + session = _get_active_session(mcp) + sid = getattr(session, "_fastmcp_event_session_id") + await mcp._subscription_registry.add(sid, "public/topic") + sink: list[Any] = [] + await self._capture_session(session, sink) + + await mcp.emit_event( + "public/topic", + {"v": 1}, + target_session_ids=["nope-not-a-real-session-id"], + ) + + assert sink == [] + + async def test_emit_with_target_session_ids_and_subscription_mismatch(self): + """Targeted session that lacks a matching subscription does not receive.""" + mcp = FastMCP("test") + mcp.declare_event("public/topic", kind="content") + mcp.declare_event("other/topic", kind="content") + + async with Client(mcp) as _c1: + s1 = _get_active_session(mcp) + s1_id = getattr(s1, "_fastmcp_event_session_id") + await mcp._subscription_registry.add(s1_id, "public/topic") + sink_s1: list[Any] = [] + await self._capture_session(s1, sink_s1) + + async with Client(mcp) as _c2: + s2 = next(s for s in mcp._active_sessions.values() if s is not s1) + s2_id = getattr(s2, "_fastmcp_event_session_id") + # s2 subscribes to a DIFFERENT topic + await mcp._subscription_registry.add(s2_id, "other/topic") + sink_s2: list[Any] = [] + await self._capture_session(s2, sink_s2) + + # Target both sessions but emit on a topic only s1 subscribes to. + await mcp.emit_event( + "public/topic", + {"v": 1}, + target_session_ids=[s1_id, s2_id], + ) + + assert len(sink_s1) == 1 + assert sink_s2 == [] + + async def test_context_emit_event_supports_target_session_ids(self): + """Context.emit_event passes target_session_ids through to FastMCP.emit_event.""" + mcp = FastMCP("test") + mcp.declare_event("public/topic", kind="content") + + @mcp.tool + async def fan_out(target: str, ctx: Context) -> str: + await ctx.emit_event( + "public/topic", {"hello": "world"}, target_session_ids=[target] + ) + return "ok" + + async with Client(mcp) as caller: + # caller is the session that invokes the tool; spin up two + # additional subscriber sessions. + caller_session = _get_active_session(mcp) + + async with Client(mcp) as _c2: + s2 = next( + s for s in mcp._active_sessions.values() if s is not caller_session + ) + s2_id = getattr(s2, "_fastmcp_event_session_id") + await mcp._subscription_registry.add(s2_id, "public/topic") + sink_s2: list[Any] = [] + await TestTargetSessionIds()._capture_session(s2, sink_s2) + + async with Client(mcp) as _c3: + s3 = next( + s + for s in mcp._active_sessions.values() + if s is not caller_session and s is not s2 + ) + s3_id = getattr(s3, "_fastmcp_event_session_id") + await mcp._subscription_registry.add(s3_id, "public/topic") + sink_s3: list[Any] = [] + await TestTargetSessionIds()._capture_session(s3, sink_s3) + + result = await caller.call_tool("fan_out", {"target": s2_id}) + assert result.data == "ok" + + assert len(sink_s2) == 1 + assert sink_s3 == [] + + +# --------------------------------------------------------------------------- +# Wildcard-smuggling regression guards +# --------------------------------------------------------------------------- + + +# --------------------------------------------------------------------------- +# C1: Overlapping declarations - all must authorize +# --------------------------------------------------------------------------- + + +class TestOverlappingDeclarations: + async def test_overlapping_declarations_all_must_authorize(self): + """When a subscribe pattern matches multiple declarations, ALL must + authorize. Here ``myapp/events`` matches both declarations; the + second is guarded by a callback that rejects everything, so the + subscription must be rejected (a client cannot smuggle a forbidden + pattern through by also matching a permissive one).""" + + def deny_all(session_id: str, params: dict[str, str]) -> bool: + return False + + mcp = FastMCP("test") + # Permissive: no auth required + mcp.declare_event("myapp/events", kind="content") + # Restrictive: explicit deny-all callback + mcp.declare_event("myapp/{slot}", kind="content", authorize=deny_all) + + async with Client(mcp) as _client: + session = _get_active_session(mcp) + result = await _subscribe_via_handler(mcp, session, ["myapp/events"]) + + assert result.subscribed == [] + assert len(result.rejected) == 1 + assert result.rejected[0].reason == "permission_denied" + + async def test_exact_match_still_works_after_fix(self): + """A simple exact-match subscription with no overlapping declarations + continues to work after removing the early-return short-circuit.""" + mcp = FastMCP("test") + mcp.declare_event("myapp/events", kind="content") + + async with Client(mcp) as _client: + session = _get_active_session(mcp) + result = await _subscribe_via_handler(mcp, session, ["myapp/events"]) + + assert len(result.subscribed) == 1 + assert result.subscribed[0].pattern == "myapp/events" + assert result.rejected == [] + + async def test_overlapping_permissive_declarations_both_pass(self): + """Two overlapping declarations both lack an authorize callback, so + the subscription succeeds under the default permissive policy.""" + mcp = FastMCP("test") + mcp.declare_event("myapp/events", kind="content") + mcp.declare_event("myapp/{project}", kind="content") # permissive + + async with Client(mcp) as _client: + session = _get_active_session(mcp) + result = await _subscribe_via_handler(mcp, session, ["myapp/events"]) + + assert len(result.subscribed) == 1 + assert result.subscribed[0].pattern == "myapp/events" + assert result.rejected == [] + + +# --------------------------------------------------------------------------- +# C2: Malformed pattern handling +# --------------------------------------------------------------------------- + + +class TestMalformedPatternHandling: + async def test_malformed_hash_pattern_rejected_gracefully(self): + """Subscribe to ``myapp/#/messages`` (# not terminal) must produce a + RejectedTopic with reason containing ``invalid_pattern``, not crash.""" + mcp = FastMCP("test") + # Use a parameterized declaration whose forward regex will match the + # malformed subscription pattern (after wildcard replacement to "x"). + mcp.declare_event("myapp/{kind}/messages", kind="content") + + async with Client(mcp) as _client: + session = _get_active_session(mcp) + result = await _subscribe_via_handler(mcp, session, ["myapp/#/messages"]) + + assert result.subscribed == [] + assert len(result.rejected) == 1 + assert "invalid_pattern" in result.rejected[0].reason + + async def test_malformed_pattern_doesnt_break_other_topics(self): + """A batch subscribe with one valid and one malformed pattern should + succeed for the valid one and reject the malformed one.""" + mcp = FastMCP("test") + mcp.declare_event("valid/topic", kind="content") + # Parameterized so forward regex matches bad/#/topic after wildcard sub + mcp.declare_event("bad/{kind}/topic", kind="content") + + async with Client(mcp) as _client: + session = _get_active_session(mcp) + result = await _subscribe_via_handler( + mcp, session, ["valid/topic", "bad/#/topic"] + ) + + subscribed_patterns = [s.pattern for s in result.subscribed] + rejected_patterns = [r.pattern for r in result.rejected] + assert "valid/topic" in subscribed_patterns + assert "bad/#/topic" in rejected_patterns + assert "invalid_pattern" in result.rejected[0].reason + + async def test_valid_hash_terminal_still_works(self): + """Subscribe to ``myapp/#`` (terminal #) must succeed.""" + mcp = FastMCP("test") + mcp.declare_event("myapp/status", kind="content") + + async with Client(mcp) as _client: + session = _get_active_session(mcp) + result = await _subscribe_via_handler(mcp, session, ["myapp/#"]) + + assert len(result.subscribed) == 1 + assert result.subscribed[0].pattern == "myapp/#" + assert result.rejected == [] + + async def test_empty_pattern_rejected(self): + """Subscribe to an empty string should be rejected gracefully.""" + mcp = FastMCP("test") + mcp.declare_event("myapp/status", kind="content") + + async with Client(mcp) as _client: + session = _get_active_session(mcp) + result = await _subscribe_via_handler(mcp, session, [""]) + + assert result.subscribed == [] + assert len(result.rejected) == 1 + + +class TestWildcardSmuggling: + async def test_wildcard_smuggling_rejected(self): + """A subscribe pattern that touches a privately guarded declaration + via wildcard must be rejected even if it ALSO matches an open + declaration. + + Two declarations: + - ``sessions/{agent_id}/messages`` (private, guarded by an + explicit authorize callback that rejects wildcards) + - ``sessions/{room}/public`` (open, no callback) + + The subscribe pattern ``sessions/+/messages`` is a wildcard superset + of the private pattern. It must be rejected because its callback + denies wildcard access, even though the public declaration is also + nominally matched by the wildcard subscribe. + """ + + def reject_wildcards(session_id: str, params: dict[str, str]) -> bool: + return params.get("agent_id") not in ("+", "#") + + mcp = FastMCP("test") + mcp.declare_event( + "sessions/{agent_id}/messages", + kind="content", + authorize=reject_wildcards, + ) + mcp.declare_event("sessions/{room}/public", kind="content") + + async with Client(mcp) as _client: + session = _get_active_session(mcp) + result = await _subscribe_via_handler(mcp, session, ["sessions/+/messages"]) + + assert result.subscribed == [] + assert len(result.rejected) == 1 + assert result.rejected[0].reason == "permission_denied" + + +# --------------------------------------------------------------------------- +# Context._tool_name and auto-source tests +# --------------------------------------------------------------------------- + + +class TestContextToolName: + def test_tool_name_set(self): + """Context created with _tool_name exposes it via tool_name property.""" + mcp = FastMCP("test") + ctx = Context(mcp, _tool_name="my_tool") + assert ctx.tool_name == "my_tool" + + def test_tool_name_default_none(self): + """Context created without _tool_name has tool_name == None.""" + mcp = FastMCP("test") + ctx = Context(mcp) + assert ctx.tool_name is None + + +class TestAutoSourceFromToolName: + async def test_auto_source_set_from_tool_name(self): + """When a tool emits an event without explicit source, source is auto-set + to 'tool/' and is present in the delivered notification.""" + mcp_server = FastMCP("test") + mcp_server.declare_event("myapp/notifications", kind="content") + + captured_sources: list[str | None] = [] + original_emit = mcp_server.emit_event + + async def tracking_emit( + topic: str, + payload: Any = None, + *, + priority: str = "normal", + source: str | None = None, + expires_at: str | None = None, + event_id: str | None = None, + retained: bool | None = None, + target_session_ids: Any = None, + ) -> None: + captured_sources.append(source) + await original_emit( + topic, + payload, + priority=priority, + source=source, + expires_at=expires_at, + event_id=event_id, + retained=retained, + target_session_ids=target_session_ids, + ) + + setattr(mcp_server, "emit_event", tracking_emit) + + @mcp_server.tool + async def notify(message: str, ctx: Context) -> str: + await ctx.emit_event("myapp/notifications", {"text": message}) + return "sent" + + received_notifications: list[Any] = [] + + async with Client(mcp_server) as client: + # Subscribe the client's session so it receives the notification + session = list(mcp_server._active_sessions.values())[0] + session_id = getattr(session, "_fastmcp_event_session_id") + await mcp_server._subscription_registry.add( + session_id, "myapp/notifications" + ) + + async def capturing_send( + notification: ServerNotification, + related_request_id: str | int | None = None, + ) -> None: + received_notifications.append(notification) + + setattr(session, "send_notification", capturing_send) + + result = await client.call_tool("notify", {"message": "hello"}) + assert result.data == "sent" + + # Verify source kwarg passed to emit_event + assert len(captured_sources) == 1 + assert captured_sources[0] == "tool/notify" + + # Verify source is present in the delivered notification + assert len(received_notifications) == 1, ( + "Expected the subscribed session to receive the notification" + ) + notif = received_notifications[0] + assert notif.params.source == "tool/notify", ( + f"Expected source 'tool/notify' in delivered notification, " + f"got {notif.params.source!r}" + ) + + async def test_explicit_source_overrides_auto_source(self): + """When a tool provides explicit source, it is not overridden and is + present in the delivered notification.""" + mcp_server = FastMCP("test") + mcp_server.declare_event("myapp/notifications", kind="content") + + captured_sources: list[str | None] = [] + original_emit = mcp_server.emit_event + + async def tracking_emit( + topic: str, + payload: Any = None, + *, + priority: str = "normal", + source: str | None = None, + expires_at: str | None = None, + event_id: str | None = None, + retained: bool | None = None, + target_session_ids: Any = None, + ) -> None: + captured_sources.append(source) + await original_emit( + topic, + payload, + priority=priority, + source=source, + expires_at=expires_at, + event_id=event_id, + retained=retained, + target_session_ids=target_session_ids, + ) + + setattr(mcp_server, "emit_event", tracking_emit) + + @mcp_server.tool + async def notify_custom(message: str, ctx: Context) -> str: + await ctx.emit_event( + "myapp/notifications", + {"text": message}, + source="custom/source", + ) + return "sent" + + received_notifications: list[Any] = [] + + async with Client(mcp_server) as client: + # Subscribe the client's session so it receives the notification + session = list(mcp_server._active_sessions.values())[0] + session_id = getattr(session, "_fastmcp_event_session_id") + await mcp_server._subscription_registry.add( + session_id, "myapp/notifications" + ) + + async def capturing_send( + notification: ServerNotification, + related_request_id: str | int | None = None, + ) -> None: + received_notifications.append(notification) + + setattr(session, "send_notification", capturing_send) + + result = await client.call_tool("notify_custom", {"message": "hello"}) + assert result.data == "sent" + + # Verify source kwarg passed to emit_event + assert len(captured_sources) == 1 + assert captured_sources[0] == "custom/source" + + # Verify source is present in the delivered notification + assert len(received_notifications) == 1, ( + "Expected the subscribed session to receive the notification" + ) + notif = received_notifications[0] + assert notif.params.source == "custom/source", ( + f"Expected source 'custom/source' in delivered notification, " + f"got {notif.params.source!r}" + ) + + async def test_source_none_when_not_in_tool_context(self): + """When emit_event is called directly on the server (not via a tool), + source remains None in the delivered notification.""" + mcp_server = FastMCP("test") + mcp_server.declare_event("myapp/status", kind="content") + + received_notifications: list[Any] = [] + + async with Client(mcp_server) as _client: + session = list(mcp_server._active_sessions.values())[0] + session_id = getattr(session, "_fastmcp_event_session_id") + await mcp_server._subscription_registry.add(session_id, "myapp/status") + + async def capturing_send( + notification: ServerNotification, + related_request_id: str | int | None = None, + ) -> None: + received_notifications.append(notification) + + setattr(session, "send_notification", capturing_send) + + # Emit directly on server, not through a tool + await mcp_server.emit_event("myapp/status", {"state": "running"}) + + assert len(received_notifications) == 1 + notif = received_notifications[0] + # Server-level emit has no tool context, so source should be None + assert notif.params.source is None diff --git a/uv.lock b/uv.lock index 49a18196ac..08c06538a5 100644 --- a/uv.lock +++ b/uv.lock @@ -798,6 +798,7 @@ dependencies = [ { name = "pydantic", extra = ["email"] }, { name = "pyperclip" }, { name = "python-dotenv" }, + { name = "python-ulid" }, { name = "pyyaml" }, { name = "rich" }, { name = "uncalled-for" }, @@ -872,7 +873,7 @@ requires-dist = [ { name = "httpx", specifier = ">=0.28.1,<1.0" }, { name = "jsonref", specifier = ">=1.1.0" }, { name = "jsonschema-path", specifier = ">=0.3.4" }, - { name = "mcp", specifier = ">=1.24.0,<2.0" }, + { name = "mcp", git = "https://github.com/axiomantic/python-sdk.git?rev=mcp-events" }, { name = "openai", marker = "extra == 'openai'", specifier = ">=1.102.0" }, { name = "openapi-pydantic", specifier = ">=0.5.1" }, { name = "opentelemetry-api", specifier = ">=1.20.0" }, @@ -886,6 +887,7 @@ requires-dist = [ { name = "pyjwt", marker = "extra == 'azure'", specifier = ">=2.12.0" }, { name = "pyperclip", specifier = ">=1.9.0" }, { name = "python-dotenv", specifier = ">=1.1.0" }, + { name = "python-ulid", specifier = ">=3.0.0" }, { name = "pyyaml", specifier = ">=6.0,<7.0" }, { name = "rich", specifier = ">=13.9.4" }, { name = "uncalled-for", specifier = ">=0.2.0" }, @@ -1559,8 +1561,8 @@ wheels = [ [[package]] name = "mcp" -version = "1.26.0" -source = { registry = "https://pypi.org/simple" } +version = "1.27.1.dev15+8844b1a" +source = { git = "https://github.com/axiomantic/python-sdk.git?rev=mcp-events#8844b1a2e2dc2bda69183a32ee88cfb9334af291" } dependencies = [ { name = "anyio" }, { name = "httpx" }, @@ -1570,6 +1572,7 @@ dependencies = [ { name = "pydantic-settings" }, { name = "pyjwt", extra = ["crypto"] }, { name = "python-multipart" }, + { name = "python-ulid" }, { name = "pywin32", marker = "sys_platform == 'win32'" }, { name = "sse-starlette" }, { name = "starlette" }, @@ -1577,10 +1580,6 @@ dependencies = [ { name = "typing-inspection" }, { name = "uvicorn", marker = "sys_platform != 'emscripten'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/fc/6d/62e76bbb8144d6ed86e202b5edd8a4cb631e7c8130f3f4893c3f90262b10/mcp-1.26.0.tar.gz", hash = "sha256:db6e2ef491eecc1a0d93711a76f28dec2e05999f93afd48795da1c1137142c66", size = 608005, upload-time = "2026-01-24T19:40:32.468Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/fd/d9/eaa1f80170d2b7c5ba23f3b59f766f3a0bb41155fbc32a69adfa1adaaef9/mcp-1.26.0-py3-none-any.whl", hash = "sha256:904a21c33c25aa98ddbeb47273033c435e595bbacfdb177f4bd87f6dceebe1ca", size = 233615, upload-time = "2026-01-24T19:40:30.652Z" }, -] [[package]] name = "mdurl" @@ -2500,6 +2499,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1b/d0/397f9626e711ff749a95d96b7af99b9c566a9bb5129b8e4c10fc4d100304/python_multipart-0.0.22-py3-none-any.whl", hash = "sha256:2b2cd894c83d21bf49d702499531c7bafd057d730c201782048f7945d82de155", size = 24579, upload-time = "2026-01-25T10:15:54.811Z" }, ] +[[package]] +name = "python-ulid" +version = "3.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/40/7e/0d6c82b5ccc71e7c833aed43d9e8468e1f2ff0be1b3f657a6fcafbb8433d/python_ulid-3.1.0.tar.gz", hash = "sha256:ff0410a598bc5f6b01b602851a3296ede6f91389f913a5d5f8c496003836f636", size = 93175, upload-time = "2025-08-18T16:09:26.305Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6c/a0/4ed6632b70a52de845df056654162acdebaf97c20e3212c559ac43e7216e/python_ulid-3.1.0-py3-none-any.whl", hash = "sha256:e2cdc979c8c877029b4b7a38a6fba3bc4578e4f109a308419ff4d3ccf0a46619", size = 11577, upload-time = "2025-08-18T16:09:25.047Z" }, +] + [[package]] name = "pywin32" version = "311"