 Provides a thin wrapper over FastMCP's Context for request handling.
 """
 
+from typing import Literal
+
 from mcp.server.fastmcp import Context  # pyright: ignore[reportMissingTypeArgument]
+from mcp.types import (
+    CreateMessageResult,
+    ModelHint,
+    ModelPreferences,
+    SamplingMessage,
+    TextContent,
+)
 
 from .cache import ContextCache
 
@@ -36,3 +45,86 @@ def cache(self) -> ContextCache:
         if self._cache is None:
             raise ValueError("Cache is not configured")
         return self._cache
+
+    # ------------------------------------------------------------------
+    # LLM Integration
+    # ------------------------------------------------------------------
+
+    def _convert_messages(
+        self, messages: str | list[str | SamplingMessage]
+    ) -> list[SamplingMessage]:
+        """Convert plain strings to ``SamplingMessage`` objects."""
+
+        if isinstance(messages, str):
+            messages = [messages]
+
+        converted: list[SamplingMessage] = []
+        for msg in messages:
+            if isinstance(msg, SamplingMessage):
+                converted.append(msg)
+            elif isinstance(msg, str):
+                converted.append(
+                    SamplingMessage(
+                        role="user",
+                        content=TextContent(type="text", text=msg),
+                    )
+                )
+            else:
+                raise TypeError("messages must be str or SamplingMessage")
+        return converted
+
+    async def ask_llm(
+        self,
+        messages: str | list[str | SamplingMessage],
+        *,
+        system_prompt: str | None = None,
+        max_tokens: int = 1000,
+        temperature: float | None = None,
+        model_preferences: ModelPreferences | None = None,
+        allow_tools: Literal["none", "thisServer", "allServers"] | None = "none",
+        stop_sequences: list[str] | None = None,
+    ) -> CreateMessageResult:
+        """Request LLM sampling via the connected client."""
+
+        sampling_messages = self._convert_messages(messages)
+        session = self._request_context.session  # type: ignore[attr-defined]
+        return await session.create_message(
+            messages=sampling_messages,
+            system_prompt=system_prompt,
+            max_tokens=max_tokens,
+            temperature=temperature,
+            model_preferences=model_preferences,
+            include_context=allow_tools,
+            stop_sequences=stop_sequences,
+        )
+
+    async def sampling(
+        self,
+        messages: str | list[str | SamplingMessage],
+        **kwargs,
+    ) -> CreateMessageResult:
+        """Alias for :meth:`ask_llm`."""
+
+        return await self.ask_llm(messages, **kwargs)
+
+
+def prefer_fast_model() -> ModelPreferences:
+    """Model preferences optimized for speed and cost."""
+
+    return ModelPreferences(
+        hints=[ModelHint(name="gpt-4o-mini"), ModelHint(name="claude-3-haiku")],
+        costPriority=0.8,
+        speedPriority=0.9,
+        intelligencePriority=0.3,
+    )
+
+
+def prefer_smart_model() -> ModelPreferences:
+    """Model preferences optimized for intelligence and capability."""
+
+    return ModelPreferences(
+        hints=[ModelHint(name="gpt-4o"), ModelHint(name="claude-3-opus")],
+        costPriority=0.2,
+        speedPriority=0.3,
+        intelligencePriority=0.9,
+    )
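
A minimal usage sketch of the new helpers, assuming a tool receives an instance of this wrapper (the class name is outside this hunk; `AppContext` and `summarize` below are placeholders) and that the connected client advertises sampling support:

async def summarize(ctx: AppContext, text: str) -> str:
    # Ask a fast, cheap model for a one-line summary via client-side sampling.
    result = await ctx.ask_llm(
        f"Summarize in one sentence:\n{text}",
        system_prompt="You are a concise assistant.",
        max_tokens=200,
        model_preferences=prefer_fast_model(),
    )
    # CreateMessageResult.content is TextContent for text responses.
    return result.content.text if result.content.type == "text" else ""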