Skip to content

Commit e420eb5

Browse files
authored
feat: add semantic search for component definitions (#33)
* feat: add model protocol for embeddings * feat: add component helper functions * feat: update imports for search functionality * refactor: move component info extraction to helper function * refactor: move io info formatting to helper function * feat: add search_component_definition function * test: add tests for component helper functions * test: update imports for search functionality tests * test: add fake model implementation * test: add io response parameter to fake service * test: use separate response for io information * test: update component definition test to use updated fake service * test: add success case for search component functionality * test: add error cases for search component functionality * feat: component search mcp * fix: linting and types
1 parent 9f538fd commit e420eb5

8 files changed

Lines changed: 691 additions & 89 deletions

File tree

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ dependencies = [
1212
"httpx",
1313
"pydantic>=2.0.0",
1414
"pyyaml",
15+
"numpy",
16+
"model2vec",
1517
]
1618

1719
[project.scripts]

src/deepset_mcp/main.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
import os
22

33
from mcp.server.fastmcp import FastMCP
4+
from model2vec import StaticModel
45

56
from deepset_mcp.api.client import AsyncDeepsetClient
67
from deepset_mcp.tools.haystack_service import (
78
get_component_definition as get_component_definition_tool,
89
list_component_families as list_component_families_tool,
10+
search_component_definition as search_component_definition_tool,
911
)
1012
from deepset_mcp.tools.pipeline import (
1113
create_pipeline as create_pipeline_tool,
@@ -15,6 +17,8 @@
1517
validate_pipeline as validate_pipeline_tool,
1618
)
1719

20+
INITIALIZED_MODEL = StaticModel.from_pretrained("minishlab/potion-base-2M")
21+
1822
# Initialize MCP Server
1923
mcp = FastMCP("Deepset Cloud MCP")
2024

@@ -134,6 +138,21 @@ async def validate_pipeline(yaml_configuration: str) -> str:
134138
return response
135139

136140

141+
@mcp.tool()
async def search_component_definitions(query: str) -> str:
    """Use this to search for components in deepset.

    You can use full natural language queries to find components.
    You can also use simple keywords.
    Use this if you want to find the definition for a component,
    but you are not sure what the exact name of the component is.
    """
    # NOTE(review): the docstring above is presumably published by FastMCP as the
    # tool description shown to the calling model, so its wording is left untouched.
    # A fresh client is opened per call; the module-level embedding model is reused.
    async with AsyncDeepsetClient() as client:
        return await search_component_definition_tool(client=client, query=query, model=INITIALIZED_MODEL)
154+
155+
137156
#
138157
#
139158
# @mcp.tool()
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
from typing import Any
2+
3+
4+
def extract_component_info(components: dict[str, Any], component_def: dict[str, Any]) -> str:
    """Render a single component definition as a plain-text summary.

    Args:
        components: The components dictionary from the schema. It is not read
            by this function; it is accepted to keep the call signature stable.
        component_def: The schema definition of one component. It must contain
            ``properties -> type -> const``; a ``KeyError`` is raised otherwise.

    Returns:
        A newline-joined string listing the component type, name, family,
        description and every initialization parameter.
    """
    properties = component_def["properties"]
    type_info = properties["type"]
    init_params = properties.get("init_parameters", {}).get("properties", {})
    component_type = type_info["const"]

    # Header section: optional fields fall back to placeholder text.
    lines = [
        f"Component: {component_type}",
        f"Name: {component_def.get('title', 'Unknown')}",
        f"Family: {type_info.get('family', 'Unknown')}",
        f"Family Description: {type_info.get('family_description', 'No description available.')}",
        f"\nDescription:\n{component_def.get('description', 'No description available.')}\n",
        "\nInitialization Parameters:",
    ]

    if init_params:
        for name, spec in init_params.items():
            # Prefer the "_annotation" entry over the plain "type" entry when both exist.
            type_label = spec.get("_annotation", spec.get("type", "Unknown"))
            description = spec.get("description", "No description available.")
            default_note = ""
            if "default" in spec:
                default_note = f" (default: {spec['default']})"
            lines.append(f" {name}: {type_label}{default_note}\n {description}")
    else:
        lines.append(" No initialization parameters")

    return "\n".join(lines)
38+
39+
40+
def format_io_info(io_info: dict[str, Any]) -> str:
    """Formats the input/output information for a component.

    The input and output sections are handled symmetrically: a section whose
    value is missing or is not a dictionary is reported as "not available"
    instead of raising.

    Args:
        io_info: The input/output information dictionary. The keys read are
            ``"input"`` and ``"output"``; both are optional.

    Returns:
        A formatted string containing the IO information
    """
    parts = ["\nInput Schema:"]

    input_info = io_info.get("input")
    # Same dict guard as the output branch, so a null/non-dict "input" no longer
    # raises AttributeError on ``.get``.
    if isinstance(input_info, dict):
        input_props = input_info.get("properties", {})
        if input_props:
            parts.extend(_format_schema_properties(input_props, input_info.get("required", [])))
        else:
            parts.append(" No input parameters")
    else:
        parts.append(" Input schema not available")

    parts.append("\nOutput Schema:")

    output_info = io_info.get("output")
    if not isinstance(output_info, dict):
        parts.append(" Output schema not available")
    elif "properties" in output_info:
        output_props = output_info["properties"]
        if output_props:
            parts.extend(_format_schema_properties(output_props, output_info.get("required", [])))
        else:
            parts.append(" No output parameters")

        # Include any definitions if they exist
        if "definitions" in output_info:
            parts.append("\n Definitions:")
            for def_name, def_info in output_info["definitions"].items():
                parts.append(f"\n {def_name}:")
                if "properties" in def_info:
                    parts.extend(_format_schema_properties(def_info["properties"], def_info.get("required", [])))
    else:
        # Simple output schema: no "properties" mapping, only a type and description.
        desc = output_info.get("description", "No description available.")
        output_type = output_info.get("type", "Unknown")
        parts.append(f" Type: {output_type}\n {desc}")

    return "\n".join(parts)


def _format_schema_properties(properties: dict[str, Any], required: list[str]) -> list[str]:
    """Return one formatted entry per item of a schema ``properties`` mapping.

    Args:
        properties: Mapping of parameter name to its schema dictionary.
        required: Names that should be flagged with ``(required)``.

    Returns:
        A list of strings, each holding the parameter line followed by its description.
    """
    entries = []
    for name, spec in properties.items():
        req_marker = " (required)" if name in required else ""
        # Prefer the "_annotation" entry over the plain "type" entry when both exist.
        type_label = spec.get("_annotation", spec.get("type", "Unknown"))
        description = spec.get("description", "No description available.")
        default = f" (default: {spec['default']})" if "default" in spec else ""
        entries.append(f" {name}: {type_label}{req_marker}{default}\n {description}")
    return entries
109+
110+
111+
def extract_component_texts(component_def: dict[str, Any]) -> tuple[str, str]:
    """Build the identifier and the text to embed for one component.

    Args:
        component_def: The component definition. It must contain
            ``properties -> type -> const``; a ``KeyError`` is raised otherwise.

    Returns:
        A ``(component_type, text)`` tuple, where ``text`` is the title and the
        description separated by a single space (missing fields become empty strings).
    """
    type_path = component_def["properties"]["type"]["const"]
    title, summary = component_def.get("title", ""), component_def.get("description", "")
    embedding_text = f"{title} {summary}"
    return type_path, embedding_text

src/deepset_mcp/tools/haystack_service.py

Lines changed: 66 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,13 @@
1+
import numpy as np
2+
13
from deepset_mcp.api.exceptions import UnexpectedAPIError
24
from deepset_mcp.api.protocols import AsyncClientProtocol
5+
from deepset_mcp.tools.component_helper import (
6+
extract_component_info,
7+
extract_component_texts,
8+
format_io_info,
9+
)
10+
from deepset_mcp.tools.model_protocol import ModelProtocol
311

412

513
async def get_component_definition(client: AsyncClientProtocol, component_type: str) -> str:
@@ -32,97 +40,76 @@ async def get_component_definition(client: AsyncClientProtocol, component_type:
3240
if not component_def:
3341
return f"Component not found: {component_type}"
3442

35-
# Extract relevant information
36-
component_type_info = component_def["properties"]["type"]
37-
init_params = component_def["properties"].get("init_parameters", {}).get("properties", {})
38-
39-
# Format the basic component information
40-
parts = [
41-
f"Component: {component_type}",
42-
f"Name: {component_def.get('title', 'Unknown')}",
43-
f"Family: {component_type_info.get('family', 'Unknown')}",
44-
f"Family Description: {component_type_info.get('family_description', 'No description available.')}",
45-
f"\nDescription:\n{component_def.get('description', 'No description available.')}\n",
46-
"\nInitialization Parameters:",
47-
]
48-
49-
if not init_params:
50-
parts.append(" No initialization parameters")
51-
else:
52-
for param_name, param_info in init_params.items():
53-
param_type = param_info.get("_annotation", param_info.get("type", "Unknown"))
54-
param_desc = param_info.get("description", "No description available.")
55-
default = f" (default: {param_info['default']})" if "default" in param_info else ""
56-
parts.append(f" {param_name}: {param_type}{default}\n {param_desc}")
43+
# Get component information
44+
parts = [extract_component_info(components, component_def)]
5745

5846
# Fetch and add input/output information
5947
try:
6048
# Extract component name from the full path
6149
component_name = component_type.split(".")[-1]
6250
io_info = await haystack_service.get_component_input_output(component_name)
63-
64-
# Add Input Schema
65-
parts.append("\nInput Schema:")
66-
if "input" in io_info:
67-
input_props = io_info["input"].get("properties", {})
68-
if not input_props:
69-
parts.append(" No input parameters")
70-
else:
71-
required = io_info["input"].get("required", [])
72-
for param_name, param_info in input_props.items():
73-
req_marker = " (required)" if param_name in required else ""
74-
param_type = param_info.get("_annotation", param_info.get("type", "Unknown"))
75-
param_desc = param_info.get("description", "No description available.")
76-
default = f" (default: {param_info['default']})" if "default" in param_info else ""
77-
parts.append(f" {param_name}: {param_type}{req_marker}{default}\n {param_desc}")
78-
else:
79-
parts.append(" Input schema not available")
80-
81-
# Add Output Schema
82-
parts.append("\nOutput Schema:")
83-
if "output" in io_info and isinstance(io_info["output"], dict):
84-
output_info = io_info["output"]
85-
if "properties" in output_info:
86-
output_props = output_info.get("properties", {})
87-
if not output_props:
88-
parts.append(" No output parameters")
89-
else:
90-
required = output_info.get("required", [])
91-
for param_name, param_info in output_props.items():
92-
req_marker = " (required)" if param_name in required else ""
93-
param_type = param_info.get("_annotation", param_info.get("type", "Unknown"))
94-
param_desc = param_info.get("description", "No description available.")
95-
default = f" (default: {param_info['default']})" if "default" in param_info else ""
96-
parts.append(f" {param_name}: {param_type}{req_marker}{default}\n {param_desc}")
97-
98-
# Include any definitions if they exist
99-
if "definitions" in output_info:
100-
parts.append("\n Definitions:")
101-
for def_name, def_info in output_info["definitions"].items():
102-
parts.append(f"\n {def_name}:")
103-
if "properties" in def_info:
104-
def_required = def_info.get("required", [])
105-
for prop_name, prop_info in def_info["properties"].items():
106-
req_marker = " (required)" if prop_name in def_required else ""
107-
prop_type = prop_info.get("_annotation", prop_info.get("type", "Unknown"))
108-
prop_desc = prop_info.get("description", "No description available.")
109-
default = f" (default: {prop_info['default']})" if "default" in prop_info else ""
110-
parts.append(
111-
f" {prop_name}: {prop_type}{req_marker}{default}\n {prop_desc}"
112-
)
113-
else:
114-
# Simple output schema
115-
desc = output_info.get("description", "No description available.")
116-
output_type = output_info.get("type", "Unknown")
117-
parts.append(f" Type: {output_type}\n {desc}")
118-
else:
119-
parts.append(" Output schema not available")
51+
parts.append(format_io_info(io_info))
12052
except Exception as e:
12153
parts.append(f"\nFailed to fetch input/output schema: {str(e)}")
12254

12355
return "\n".join(parts)
12456

12557

58+
async def search_component_definition(
    client: AsyncClientProtocol, query: str, model: ModelProtocol, top_k: int = 5
) -> str:
    """Rank components against a query by embedding similarity and describe the best ones.

    Args:
        client: The API client to use
        query: The search query
        model: The model to use for computing embeddings
        top_k: Maximum number of results to return (default: 5)

    Returns:
        The definitions of the highest-scoring components, each prefixed with its
        similarity score and followed by a separator line. An error message is
        returned if the schemas cannot be fetched, and "No components found" if
        the schema lists no components.
    """
    service = client.haystack_service()

    try:
        schema_response = await service.get_component_schemas()
    except UnexpectedAPIError as e:
        return f"Failed to retrieve component schemas: {e}"

    definitions = schema_response["component_schema"]["definitions"]["Components"]

    # One (component type, text to embed) pair per component in the schema.
    pairs: list[tuple[str, str]] = [extract_component_texts(definition) for definition in definitions.values()]
    type_names: list[str] = [pair[0] for pair in pairs]

    if not pairs:
        return "No components found"

    # Embed the query first, then all component texts in a single batch.
    query_vector = model.encode(query)
    component_matrix = model.encode([text for _, text in pairs])

    # Dot product of every component embedding with the query embedding
    # (query turned into a column vector) yields one score per component.
    scores = np.dot(component_matrix, query_vector.reshape(1, -1).T).flatten()

    # Highest score first; sorted() is stable, so ties keep schema order.
    ranked = sorted(zip(type_names, scores, strict=False), key=lambda pair: pair[1], reverse=True)

    # Definitions are fetched one after another for the best-ranked components only.
    blocks = [
        f"Similarity Score: {score:.3f}\n{await get_component_definition(client, type_name)}\n{'-' * 80}\n"
        for type_name, score in ranked[:top_k]
    ]

    return "\n".join(blocks)
111+
112+
126113
async def list_component_families(client: AsyncClientProtocol) -> str:
127114
"""Lists all Haystack component families that are available on deepset."""
128115
haystack_service = client.haystack_service()
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
from typing import Any, Protocol
2+
3+
import numpy as np
4+
5+
6+
class ModelProtocol(Protocol):
7+
"""Protocol for static embedding models."""
8+
9+
def encode(self, sentences: list[str] | str) -> np.ndarray[Any, Any]:
10+
"""
11+
Encodes a single or multiple sentences.
12+
13+
:param sentences: Single sentence or list of sentences to encode
14+
:returns: Numpy array of encoded sentences
15+
"""
16+
...

0 commit comments

Comments
 (0)