
Commit 1779e04

Add sampling support example and add ask_user (#138)
* Add sampling support to chatbot example
* Use LLM in sampling callback with version check (#139)
* Use info logging in chatbot (#142)
* Add ask_user support (#143)
* Upgrade MCP and add ask_user
* Fix example lint issues
1 parent f13e027 commit 1779e04

File tree: 8 files changed (+150 -14 lines)


docs/api/context.md

Lines changed: 16 additions & 0 deletions
```diff
@@ -40,6 +40,22 @@ result = await ctx.ask_llm(
 print(result.content.text)
 ```
 
+## User Elicitation
+
+Call `ask_user()` when you need additional input from the client. It wraps the
+underlying MCP `elicit()` API:
+
+```python
+class BookingPreferences(BaseModel):
+    alternativeDate: str | None
+    checkAlternative: bool = False
+
+result = await ctx.ask_user(
+    message="No tables available. Try another date?",
+    schema=BookingPreferences,
+)
+```
+
 ## Extending Context
 
 For now, if you need context functionality, you can extend the base class:
```
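The documented snippet stops at the `ask_user()` call. As a rough sketch of consuming its return value — assuming the MCP Python SDK's elicitation result types, where `action` is "accept", "decline", or "cancel" and an accepted result carries the validated model in `data` — a hypothetical resolver might branch like this (`book_table` and the returned strings are made up for illustration):

```python
# A minimal sketch, not part of the commit. Assumes the MCP SDK's
# ElicitationResult: `.action` distinguishes accept/decline/cancel and an
# accepted result exposes the validated Pydantic model as `.data`.
from pydantic import BaseModel


class BookingPreferences(BaseModel):
    alternativeDate: str | None
    checkAlternative: bool = False


async def book_table(ctx) -> str:  # hypothetical resolver; ctx is an EnrichContext
    result = await ctx.ask_user(
        message="No tables available. Try another date?",
        schema=BookingPreferences,
    )
    if result.action == "accept":
        prefs = result.data  # a validated BookingPreferences instance
        if prefs.checkAlternative and prefs.alternativeDate:
            return f"Checking availability for {prefs.alternativeDate}"
        return "Booking cancelled: no alternative requested"
    # The user declined or dismissed the prompt.
    return "Booking cancelled"
```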

examples/openai_chat_agent/app.py

Lines changed: 78 additions & 4 deletions
```diff
@@ -7,17 +7,75 @@
 from __future__ import annotations
 
 import asyncio
+import logging
 import os
+from importlib import metadata
+from typing import TYPE_CHECKING
 
 import httpx
 from dotenv import load_dotenv
+from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
 from langchain_ollama import ChatOllama
 from langchain_openai import ChatOpenAI
-from mcp_use import MCPAgent, MCPClient
+from mcp.types import (
+    CreateMessageRequestParams,
+    CreateMessageResult,
+    ErrorData,
+    TextContent,
+)
+from mcp_use import MCPAgent, MCPClient, load_config_file
+from packaging.version import Version
+
+if TYPE_CHECKING:  # pragma: no cover - only for type hints
+    from mcp import ClientSession
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO)
 
 SYSTEM_MESSAGE = "You are a helpful assistant that talks to the user and uses tools via MCP."
 
 
+def make_sampling_callback(llm: ChatOpenAI | ChatOllama):
+    async def sampling_callback(
+        context: ClientSession, params: CreateMessageRequestParams
+    ) -> CreateMessageResult | ErrorData:
+        lc_messages = []
+        system_prompt = getattr(params, "systemPrompt", None)
+        if system_prompt:
+            lc_messages.append(SystemMessage(content=system_prompt))
+        for msg in params.messages:
+            content = msg.content.text
+            if msg.role == "assistant":
+                lc_messages.append(AIMessage(content=content))
+            else:
+                lc_messages.append(HumanMessage(content=content))
+
+        try:
+            logger.info(f"Sampling with messages: {lc_messages}")
+            max_tokens = getattr(params, "maxTokens", None)
+            stop_sequences = getattr(params, "stopSequences", None)
+            result_msg = await llm.ainvoke(
+                lc_messages,
+                temperature=params.temperature,
+                max_tokens=max_tokens,
+                stop=stop_sequences,
+            )
+        except Exception as exc:
+            logger.error(f"Failed to invoke llm for sampling: {exc}")
+            return ErrorData(code=400, message=str(exc))
+
+        text = getattr(result_msg, "content", str(result_msg))
+        model_name = getattr(llm, "model", "llm")
+        logger.info(f"Sampling result: {text}")
+        return CreateMessageResult(
+            content=TextContent(text=text, type="text"),
+            model=model_name,
+            role="assistant",
+        )
+
+    return sampling_callback
+
+
 async def ensure_ollama_running(model: str) -> None:
     """Check that an Ollama server is running."""
     try:
@@ -40,18 +98,34 @@ async def run_memory_chat() -> None:
     load_dotenv()
     config_file = os.path.join(os.path.dirname(__file__), "config.json")
 
-    print("Initializing chat...")
-    client = MCPClient.from_config_file(config_file)
-
     openai_key = os.getenv("OPENAI_API_KEY")
     ollama_model = os.getenv("OLLAMA_MODEL", "llama3.2")
 
+    print("Initializing chat...")
+
     if openai_key:
         llm = ChatOpenAI(model="gpt-4o")
     else:
         await ensure_ollama_running(ollama_model)
         llm = ChatOllama(model=ollama_model)
 
+    try:
+        mcp_use_version = metadata.version("mcp_use")
+    except metadata.PackageNotFoundError:  # pragma: no cover - dev env only
+        mcp_use_version = "0"
+
+    if Version(mcp_use_version) > Version("1.3.6"):
+        client = MCPClient(
+            load_config_file(config_file),
+            sampling_callback=make_sampling_callback(llm),
+        )
+    else:
+        logger.warning(
+            "mcp-use %s does not support sampling, install >1.3.6. Disabling sampling callback",
+            mcp_use_version,
+        )
+        client = MCPClient(load_config_file(config_file))
+
     agent = MCPAgent(
         llm=llm,
         client=client,
```
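For orientation: mcp-use invokes this callback whenever a connected server issues a sampling request (for example, the travel planner calling `ctx.ask_llm()`), and the callback must answer with a `CreateMessageResult` or an `ErrorData`. Below is a hypothetical smoke test that exercises `make_sampling_callback` directly with a stand-in LLM; it assumes the `mcp.types` constructors already imported in the diff plus `SamplingMessage`, and that `make_sampling_callback` from the example is in scope. No server or real model is involved.

```python
# Rough sketch only. FakeLLM stands in for ChatOpenAI/ChatOllama and returns a
# LangChain-style message object with a `content` attribute.
import asyncio

from mcp.types import CreateMessageRequestParams, SamplingMessage, TextContent


class FakeLLM:
    model = "fake-model"

    async def ainvoke(self, messages, **kwargs):
        class _Msg:
            content = "Lisbon"

        return _Msg()


async def main() -> None:
    callback = make_sampling_callback(FakeLLM())  # from the example above
    params = CreateMessageRequestParams(
        messages=[
            SamplingMessage(
                role="user",
                content=TextContent(type="text", text="Pick one city."),
            )
        ],
        maxTokens=50,
    )
    result = await callback(None, params)  # the session argument is unused here
    print(result.content.text)  # -> "Lisbon"


asyncio.run(main())
```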
Lines changed: 2 additions & 2 deletions
```diff
@@ -1,8 +1,8 @@
 {
   "mcpServers": {
-    "shop_api": {
+    "travel_agent": {
       "command": "python",
-      "args": ["../shop_api/app.py"]
+      "args": ["../server_side_llm_travel_planner/app.py"]
     }
   }
 }
```

examples/openai_chat_agent/requirements.txt

Lines changed: 1 addition & 0 deletions
```diff
@@ -4,3 +4,4 @@ langchain_ollama
 langchain_community
 mcp_use
 python-dotenv
+httpx
```

examples/server_side_llm_travel_planner/app.py

Lines changed: 6 additions & 7 deletions
```diff
@@ -63,22 +63,21 @@ async def plan_trip(
 ) -> list[Destination]:
     """Return three destinations that best match the given preferences."""
     ctx = app.get_context()
-
     bullet_list = "\n".join(f"- {d.name}: {d.summary}" for d in DESTINATIONS)
     prompt = (
         "Select the three best destinations from the list below based on the "
-        "given preferences. Reply with a JSON list of names only.\nPreferences: "
+        "given preferences. Reply with a JSON list of names only. "
+        "The text should be directly parsable with json.loads in Python. "
+        'Do NOT add ```json like markdown. Example response:\n["San Francisco"]'
+        "\n\n\nPreferences: "
         f"{preferences}\n\n{bullet_list}"
     )
-    result = await ctx.sampling(
+    result = await ctx.ask_llm(
         prompt,
         model_preferences=prefer_fast_model(),
         max_tokens=50,
     )
-    try:
-        names = json.loads(result.content.text)
-    except Exception:
-        return []
+    names = json.loads(result.content.text)
     return [d for d in DESTINATIONS if d.name in names]
 
 
```
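Two changes worth noting: the prompt now spells out the output contract (a bare JSON list that `json.loads` can parse, with no Markdown fences), and the old `try/except` fallback to an empty list is gone, so a malformed reply surfaces as an exception rather than being swallowed. A small, hypothetical illustration of that contract follows; `Destination` and `DESTINATIONS` here are simplified stand-ins, not the example's real definitions.

```python
# Illustration only: what a compliant sampled reply must look like for the
# fence-free json.loads parse above to succeed.
import json

from pydantic import BaseModel


class Destination(BaseModel):  # simplified stand-in
    name: str
    summary: str


DESTINATIONS = [
    Destination(name="San Francisco", summary="Fog, hills, sourdough."),
    Destination(name="Kyoto", summary="Temples and tea houses."),
]

reply = '["San Francisco"]'  # bare JSON list, no ```json wrapper
names = json.loads(reply)  # raises on malformed or fenced output
picked = [d for d in DESTINATIONS if d.name in names]
assert [d.name for d in picked] == ["San Francisco"]
```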

src/enrichmcp/context.py

Lines changed: 10 additions & 0 deletions
```diff
@@ -6,6 +6,7 @@
 
 from typing import Literal
 
+from mcp.server.elicitation import ElicitationResult, ElicitSchemaModelT
 from mcp.server.fastmcp import Context  # pyright: ignore[reportMissingTypeArgument]
 from mcp.types import (
     CreateMessageResult,
@@ -108,6 +109,15 @@ async def sampling(
 
         return await self.ask_llm(messages, **kwargs)
 
+    async def ask_user(
+        self,
+        message: str,
+        schema: type[ElicitSchemaModelT],
+    ) -> ElicitationResult:
+        """Interactively ask the client for input using MCP elicitation."""
+
+        return await super().elicit(message=message, schema=schema)
+
 
 def prefer_fast_model() -> ModelPreferences:
     """Model preferences optimized for speed and cost."""
```

tests/test_elicitation.py

Lines changed: 28 additions & 0 deletions
```diff
@@ -0,0 +1,28 @@
+from unittest.mock import AsyncMock, Mock, patch
+
+import pytest
+from pydantic import BaseModel
+
+from enrichmcp import EnrichContext
+
+
+class Prefs(BaseModel):
+    choice: bool
+
+
+@pytest.mark.asyncio
+async def test_ask_user_delegates_to_context_elicit():
+    ctx = EnrichContext.model_construct(_request_context=Mock())
+
+    with patch("enrichmcp.context.Context.elicit", AsyncMock(return_value="ok")) as mock:
+        got = await ctx.ask_user("hi", Prefs)
+    assert got == "ok"
+    mock.assert_awaited_once_with(message="hi", schema=Prefs)
+
+
+@pytest.mark.asyncio
+async def test_ask_user_requires_request_context():
+    ctx = EnrichContext()
+
+    with pytest.raises(ValueError, match="outside of a request"):
+        await ctx.ask_user("hi", Prefs)
```

uv.lock

Lines changed: 9 additions & 1 deletion
Generated file; diff not rendered.
