|
| 1 | +import json |
| 2 | +import os |
| 3 | +from typing import Any |
| 4 | + |
| 5 | +from litellm.types.utils import ModelResponse |
| 6 | +import pytest |
| 7 | +from langevals import expect |
| 8 | +from langevals_langevals.llm_boolean import ( |
| 9 | + CustomLLMBooleanEvaluator, |
| 10 | + CustomLLMBooleanSettings, |
| 11 | +) |
| 12 | +from litellm import ChatCompletionMessageToolCall, Choices, Message, acompletion |
| 13 | +from mcp.types import TextContent, Tool |
| 14 | +from mcp import ClientSession |
| 15 | +from mcp.client.sse import sse_client |
| 16 | +from dotenv import load_dotenv |
| 17 | + |
| 18 | +load_dotenv() |
| 19 | + |
| 20 | +DEFAULT_GRAFANA_URL = "http://localhost:3000" |
| 21 | +DEFAULT_MCP_URL = "http://localhost:8000/sse" |
| 22 | + |
| 23 | +models = ["gpt-4o", "claude-3-5-sonnet-20240620"] |
| 24 | + |
| 25 | +pytestmark = pytest.mark.anyio |
| 26 | + |
| 27 | + |
| 28 | +@pytest.fixture |
| 29 | +def mcp_url(): |
| 30 | + return os.environ.get("MCP_GRAFANA_URL", DEFAULT_MCP_URL) |
| 31 | + |
| 32 | + |
| 33 | +@pytest.fixture |
| 34 | +def grafana_headers(): |
| 35 | + headers = { |
| 36 | + "X-Grafana-URL": os.environ.get("GRAFANA_URL", DEFAULT_GRAFANA_URL), |
| 37 | + } |
| 38 | + if key := os.environ.get("GRAFANA_API_KEY"): |
| 39 | + headers["X-Grafana-API-Key"] = key |
| 40 | + return headers |
| 41 | + |
| 42 | + |
| 43 | +@pytest.fixture |
| 44 | +async def mcp_client(mcp_url, grafana_headers): |
| 45 | + async with sse_client(mcp_url, headers=grafana_headers) as ( |
| 46 | + read, |
| 47 | + write, |
| 48 | + ): |
| 49 | + async with ClientSession(read, write) as session: |
| 50 | + await session.initialize() |
| 51 | + yield session |
| 52 | + |
| 53 | + |
| 54 | +@pytest.mark.parametrize("model", models) |
| 55 | +@pytest.mark.flaky(max_runs=3) |
| 56 | +async def test_loki_logs_tool(model: str, mcp_client: ClientSession): |
| 57 | + tools = await mcp_client.list_tools() |
| 58 | + prompt = "Can you list the last 10 log lines from all containers using any available Loki datasource? Give me the raw log lines. Please use only the necessary tools to get this information." |
| 59 | + |
| 60 | + messages: list[Message] = [ |
| 61 | + Message(role="system", content="You are a helpful assistant."), |
| 62 | + Message(role="user", content=prompt), |
| 63 | + ] |
| 64 | + tools = [convert_tool(t) for t in tools.tools] |
| 65 | + |
| 66 | + response = await acompletion( |
| 67 | + model=model, |
| 68 | + messages=messages, |
| 69 | + tools=tools, |
| 70 | + ) |
| 71 | + |
| 72 | + # Check that there's a datasources tool call. |
| 73 | + assert isinstance(response, ModelResponse) |
| 74 | + messages.extend( |
| 75 | + await assert_and_handle_tool_call(response, mcp_client, "list_datasources") |
| 76 | + ) |
| 77 | + |
| 78 | + datasources_response = messages[-1].content |
| 79 | + datasources_data = json.loads(datasources_response) |
| 80 | + assert len(datasources_data) > 0, "Should have at least one datasource" |
| 81 | + |
| 82 | + # Verify Loki datasource exists |
| 83 | + loki_datasources = [ds for ds in datasources_data if ds.get("type") == "loki"] |
| 84 | + assert len(loki_datasources) > 0, "No Loki datasource found" |
| 85 | + print( |
| 86 | + f"\nFound Loki datasource: {loki_datasources[0]['name']} (uid: {loki_datasources[0]['uid']})" |
| 87 | + ) |
| 88 | + |
| 89 | + # Call the LLM including the tool call result. |
| 90 | + response = await acompletion( |
| 91 | + model=model, |
| 92 | + messages=messages, |
| 93 | + tools=tools, |
| 94 | + ) |
| 95 | + |
| 96 | + # Check that there's a loki logstool call. |
| 97 | + assert isinstance(response, ModelResponse) |
| 98 | + messages.extend( |
| 99 | + await assert_and_handle_tool_call( |
| 100 | + response, |
| 101 | + mcp_client, |
| 102 | + "query_loki_logs", |
| 103 | + {"datasourceUid": "loki"}, |
| 104 | + ) |
| 105 | + ) |
| 106 | + |
| 107 | + # Call the LLM including the tool call result. |
| 108 | + response = await acompletion( |
| 109 | + model=model, |
| 110 | + messages=messages, |
| 111 | + tools=tools, |
| 112 | + ) |
| 113 | + |
| 114 | + # Check that the response has some log lines. |
| 115 | + content = response.choices[0].message.content |
| 116 | + log_lines_checker = CustomLLMBooleanEvaluator( |
| 117 | + settings=CustomLLMBooleanSettings( |
| 118 | + prompt="Does the response contain specific information that could only come from a Loki datasource? This could be actual log lines with timestamps, container names, or a summary that references specific log data. The response should show evidence of real data rather than generic statements.", |
| 119 | + ) |
| 120 | + ) |
| 121 | + print("content", content) |
| 122 | + expect(input=prompt, output=content).to_pass(log_lines_checker) |
| 123 | + |
| 124 | + |
| 125 | +@pytest.mark.parametrize("model", models) |
| 126 | +@pytest.mark.flaky(max_runs=3) |
| 127 | +async def test_loki_container_labels(model: str, mcp_client: ClientSession): |
| 128 | + tools = await mcp_client.list_tools() |
| 129 | + prompt = "Can you list the values for the label container in any available loki datasource? Please use only the necessary tools to get this information." |
| 130 | + |
| 131 | + messages: list[Message] = [ |
| 132 | + Message(role="system", content="You are a helpful assistant."), |
| 133 | + Message(role="user", content=prompt), |
| 134 | + ] |
| 135 | + tools = [convert_tool(t) for t in tools.tools] |
| 136 | + |
| 137 | + response = await acompletion( |
| 138 | + model=model, |
| 139 | + messages=messages, |
| 140 | + tools=tools, |
| 141 | + ) |
| 142 | + |
| 143 | + # Check that there's a datasources tool call. |
| 144 | + assert isinstance(response, ModelResponse) |
| 145 | + messages.extend( |
| 146 | + await assert_and_handle_tool_call(response, mcp_client, "list_datasources") |
| 147 | + ) |
| 148 | + |
| 149 | + datasources_response = messages[-1].content |
| 150 | + datasources_data = json.loads(datasources_response) |
| 151 | + assert len(datasources_data) > 0, "Should have at least one datasource" |
| 152 | + |
| 153 | + # Verify Loki datasource exists |
| 154 | + loki_datasources = [ds for ds in datasources_data if ds.get("type") == "loki"] |
| 155 | + assert len(loki_datasources) > 0, "No Loki datasource found" |
| 156 | + print( |
| 157 | + f"\nFound Loki datasource: {loki_datasources[0]['name']} (uid: {loki_datasources[0]['uid']})" |
| 158 | + ) |
| 159 | + |
| 160 | + # Call the LLM including the tool call result. |
| 161 | + response = await acompletion( |
| 162 | + model=model, |
| 163 | + messages=messages, |
| 164 | + tools=tools, |
| 165 | + ) |
| 166 | + |
| 167 | + # Check that there's a list_loki_label_values tool call. |
| 168 | + assert isinstance(response, ModelResponse) |
| 169 | + messages.extend( |
| 170 | + await assert_and_handle_tool_call( |
| 171 | + response, |
| 172 | + mcp_client, |
| 173 | + "list_loki_label_values", |
| 174 | + {"datasourceUid": "loki", "labelName": "container"}, |
| 175 | + ) |
| 176 | + ) |
| 177 | + |
| 178 | + # Call the LLM including the tool call result. |
| 179 | + response = await acompletion( |
| 180 | + model=model, |
| 181 | + messages=messages, |
| 182 | + tools=tools, |
| 183 | + ) |
| 184 | + |
| 185 | + # Check that the response provides a meaningful summary of container labels |
| 186 | + content = response.choices[0].message.content |
| 187 | + label_checker = CustomLLMBooleanEvaluator( |
| 188 | + settings=CustomLLMBooleanSettings( |
| 189 | + prompt="Does the response provide a clear and organized list of container names found in the logs? It should present the container names in a readable format and may include additional context about their usage.", |
| 190 | + ) |
| 191 | + ) |
| 192 | + expect(input=prompt, output=content).to_pass(label_checker) |
| 193 | + |
| 194 | + |
| 195 | +async def assert_and_handle_tool_call( |
| 196 | + response: ModelResponse, |
| 197 | + mcp_client: ClientSession, |
| 198 | + expected_tool: str, |
| 199 | + expected_args: dict[str, Any] | None = None, |
| 200 | +) -> list[Message]: |
| 201 | + messages: list[Message] = [] |
| 202 | + tool_calls: list[ChatCompletionMessageToolCall] = [] |
| 203 | + for c in response.choices: |
| 204 | + assert isinstance(c, Choices) |
| 205 | + tool_calls.extend(c.message.tool_calls or []) |
| 206 | + # Add the message to the list of messages. |
| 207 | + # We'll need to send these back to the LLM with the tool call result. |
| 208 | + messages.append(c.message) |
| 209 | + |
| 210 | + # Check that the expected tool call is in the response. |
| 211 | + assert len(tool_calls) == 1 |
| 212 | + |
| 213 | + # Call the tool(s) with the requested args. |
| 214 | + for tool_call in tool_calls: |
| 215 | + assert isinstance(tool_call.function.name, str) |
| 216 | + arguments = ( |
| 217 | + {} |
| 218 | + if len(tool_call.function.arguments) == 0 |
| 219 | + else json.loads(tool_call.function.arguments) |
| 220 | + ) |
| 221 | + assert tool_call.function.name == expected_tool |
| 222 | + |
| 223 | + if expected_args: |
| 224 | + for key, value in expected_args.items(): |
| 225 | + assert key in arguments, ( |
| 226 | + f"Missing required argument '{key}' in tool call" |
| 227 | + ) |
| 228 | + assert arguments[key] == value, ( |
| 229 | + f"Argument '{key}' has wrong value. Expected: {value}, Got: {arguments[key]}" |
| 230 | + ) |
| 231 | + |
| 232 | + print(f"calling tool: {tool_call.function.name}({arguments})") |
| 233 | + result = await mcp_client.call_tool(tool_call.function.name, arguments) |
| 234 | + # Assume each tool returns a single text content for now |
| 235 | + assert len(result.content) == 1 |
| 236 | + assert isinstance(result.content[0], TextContent) |
| 237 | + messages.append( |
| 238 | + Message( |
| 239 | + role="tool", tool_call_id=tool_call.id, content=result.content[0].text |
| 240 | + ) |
| 241 | + ) |
| 242 | + return messages |
| 243 | + |
| 244 | + |
| 245 | +def convert_tool(tool: Tool) -> dict: |
| 246 | + return { |
| 247 | + "type": "function", |
| 248 | + "function": { |
| 249 | + "name": tool.name, |
| 250 | + "description": tool.description, |
| 251 | + "parameters": { |
| 252 | + **tool.inputSchema, |
| 253 | + "properties": tool.inputSchema.get("properties", {}), |
| 254 | + }, |
| 255 | + }, |
| 256 | + } |
0 commit comments