Skip to content

Commit bdd82b8

Browse files
authored
feat: Add search_pipeline_templates tool with semantic similarity search (#59)
* feat: add import for pipeline template formatting utility * feat: add search_pipeline_templates function with semantic similarity search * feat: add import for search_pipeline_templates function * feat: add import for pipeline template models for testing * feat: add fake pipeline templates resource for testing * feat: extend fake client to support pipeline templates resource * fix: use keyword arguments in fake client initialization * fix: use keyword arguments in fake client initialization * fix: use keyword arguments in fake client initialization * fix: use keyword arguments in fake client initialization * fix: use keyword arguments in fake client initialization * fix: use keyword arguments in fake client initialization * fix: use keyword arguments in fake client initialization * fix: use keyword arguments in fake client initialization * fix: use keyword arguments in fake client initialization * feat: add uuid import for test data generation * feat: add more embedding patterns for pipeline template testing * feat: add unit tests for search_pipeline_templates function * feat: import search_pipeline_templates tool in main module * feat: add search_pipeline_templates tool to main MCP server * feat: add test runner to verify our implementation * feat: create simpler test runner for verification * fix: handle import issues in test runner * feat: create simple import checker * chore: remove test runner as it's not needed for this PR * chore: remove import checker as it's not needed for this PR * refactor: remove search_pipeline_templates from haystack_service.py * feat: add imports for search_pipeline_templates function * feat: add search_pipeline_templates function to pipeline_template.py * fix: remove search_pipeline_templates import from haystack_service * feat: import search_pipeline_templates from pipeline_template module * test: remove search_pipeline_templates import from haystack_service tests * test: remove unused PipelineTemplate imports from haystack_service tests * test: remove FakePipelineTemplatesResource from haystack_service tests * test: simplify FakeClient to remove pipeline_templates functionality * test: remove search_pipeline_templates tests from haystack_service tests * test: add imports for search_pipeline_templates tests * test: add FakeModel class for search_pipeline_templates tests * test: add search_pipeline_templates unit tests * test: fix PipelineTemplate instantiation with correct field names * fix: lint
1 parent d3d8950 commit bdd82b8

4 files changed

Lines changed: 191 additions & 12 deletions

File tree

src/deepset_mcp/main.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from deepset_mcp.tools.pipeline_template import (
2828
get_pipeline_template as get_pipeline_template_tool,
2929
list_pipeline_templates as list_pipeline_templates_tool,
30+
search_pipeline_templates as search_pipeline_templates_tool,
3031
)
3132

3233
INITIALIZED_MODEL = StaticModel.from_pretrained("minishlab/potion-base-2M")
@@ -202,6 +203,24 @@ async def search_component_definitions(query: str) -> str:
202203
return response
203204

204205

206+
@mcp.tool()
207+
async def search_pipeline_templates(query: str) -> str:
208+
"""Use this to search for pipeline templates in deepset.
209+
210+
You can use full natural language queries to find templates.
211+
You can also use simple keywords.
212+
Use this if you want to find pipeline templates for specific use cases,
213+
but you are not sure what the exact name of the template is.
214+
"""
215+
workspace = get_workspace()
216+
async with AsyncDeepsetClient() as client:
217+
response = await search_pipeline_templates_tool(
218+
client=client, query=query, model=INITIALIZED_MODEL, workspace=workspace
219+
)
220+
221+
return response
222+
223+
205224
@mcp.tool()
206225
async def list_indexes() -> str:
207226
"""Retrieves a list of all indexes available in the deepset workspace.

src/deepset_mcp/tools/pipeline_template.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1+
import numpy as np
2+
13
from deepset_mcp.api.exceptions import ResourceNotFoundError, UnexpectedAPIError
24
from deepset_mcp.api.protocols import AsyncClientProtocol
35
from deepset_mcp.tools.formatting_utils import pipeline_template_to_llm_readable_string
6+
from deepset_mcp.tools.model_protocol import ModelProtocol
47

58

69
async def list_pipeline_templates(
@@ -43,3 +46,60 @@ async def get_pipeline_template(client: AsyncClientProtocol, workspace: str, tem
4346
return f"There is no pipeline template named '{template_name}' in workspace '{workspace}'."
4447
except UnexpectedAPIError as e:
4548
return f"Failed to fetch pipeline template '{template_name}': {e}"
49+
50+
51+
async def search_pipeline_templates(
52+
client: AsyncClientProtocol, query: str, model: ModelProtocol, workspace: str, top_k: int = 5
53+
) -> str:
54+
"""Searches for pipeline templates based on name or description using semantic similarity.
55+
56+
Args:
57+
client: The API client to use
58+
query: The search query
59+
model: The model to use for computing embeddings
60+
workspace: The workspace to search templates from
61+
top_k: Maximum number of results to return (default: 5)
62+
63+
Returns:
64+
A formatted string containing the matched pipeline template definitions
65+
"""
66+
try:
67+
response = await client.pipeline_templates(workspace=workspace).list_templates()
68+
except UnexpectedAPIError as e:
69+
return f"Failed to retrieve pipeline templates: {e}"
70+
71+
if not response:
72+
return "No pipeline templates found"
73+
74+
# Extract text for embedding from all templates
75+
template_texts: list[tuple[str, str]] = [
76+
(template.template_name, f"{template.template_name} {template.description}") for template in response
77+
]
78+
template_names: list[str] = [t[0] for t in template_texts]
79+
80+
# Compute embeddings
81+
query_embedding = model.encode(query)
82+
template_embeddings = model.encode([text for _, text in template_texts])
83+
84+
query_embedding_reshaped = query_embedding.reshape(1, -1)
85+
86+
# Calculate dot product between target and all templates
87+
# This gives us a similarity score for each template
88+
similarities = np.dot(template_embeddings, query_embedding_reshaped.T).flatten()
89+
90+
# Create (template_name, similarity) pairs
91+
template_similarities = list(zip(template_names, similarities, strict=False))
92+
93+
# Sort by similarity score in descending order
94+
template_similarities.sort(key=lambda x: x[1], reverse=True)
95+
96+
top_templates = template_similarities[:top_k]
97+
results = []
98+
for template_name, sim in top_templates:
99+
# Find the template object by name
100+
template = next((t for t in response if t.template_name == template_name), None)
101+
if template:
102+
template_str = pipeline_template_to_llm_readable_string(template)
103+
results.append(f"Similarity Score: {sim:.3f}\n{template_str}\n{'-' * 80}\n")
104+
105+
return "\n".join(results)

test/unit/tools/test_haystack_service.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@ def encode(self, sentences: list[str] | str) -> np.ndarray[Any, Any]:
2626
embeddings[i] = [0, 0, 0.9]
2727
elif "reader" in sentence.lower():
2828
embeddings[i] = [0, 1, 0]
29+
elif "rag" in sentence.lower() or "retrieval" in sentence.lower():
30+
embeddings[i] = [1, 0, 0]
31+
elif "chat" in sentence.lower() or "conversation" in sentence.lower():
32+
embeddings[i] = [0.8, 0.2, 0]
2933
else:
3034
embeddings[i] = [0, 0, 1]
3135
return embeddings
@@ -58,11 +62,16 @@ async def get_component_input_output(self, component_name: str) -> dict[str, Any
5862

5963

6064
class FakeClient(BaseFakeClient):
61-
def __init__(self, resource: FakeHaystackServiceResource):
65+
def __init__(
66+
self,
67+
resource: FakeHaystackServiceResource | None = None,
68+
):
6269
self._resource = resource
6370
super().__init__()
6471

6572
def haystack_service(self) -> FakeHaystackServiceResource:
73+
if self._resource is None:
74+
raise ValueError("Haystack service resource not configured")
6675
return self._resource
6776

6877

@@ -139,7 +148,7 @@ async def test_get_component_definition_success() -> None:
139148
resource = FakeHaystackServiceResource(
140149
get_component_schemas_response=schema_response, get_component_io_response=io_response
141150
)
142-
client = FakeClient(resource)
151+
client = FakeClient(resource=resource)
143152
result = await get_component_definition(client, component_type)
144153

145154
# Check that all required information is present
@@ -169,7 +178,7 @@ async def test_get_component_definition_success() -> None:
169178
async def test_get_component_definition_not_found() -> None:
170179
response: dict[str, Any] = {"component_schema": {"definitions": {"Components": {}}}}
171180
resource = FakeHaystackServiceResource(get_component_schemas_response=response)
172-
client = FakeClient(resource)
181+
client = FakeClient(resource=resource)
173182
result = await get_component_definition(client, "nonexistent.component")
174183
assert "Component not found" in result
175184

@@ -213,7 +222,7 @@ async def test_search_component_definition_success() -> None:
213222
resource = FakeHaystackServiceResource(
214223
get_component_schemas_response=schema_response, get_component_io_response=io_response
215224
)
216-
client = FakeClient(resource)
225+
client = FakeClient(resource=resource)
217226
model = FakeModel()
218227

219228
# Search for converters
@@ -232,7 +241,7 @@ async def test_search_component_definition_success() -> None:
232241
@pytest.mark.asyncio
233242
async def test_get_component_definition_api_error() -> None:
234243
resource = FakeHaystackServiceResource(exception=UnexpectedAPIError(status_code=500, message="API Error"))
235-
client = FakeClient(resource)
244+
client = FakeClient(resource=resource)
236245
result = await get_component_definition(client, "some.component")
237246
assert "Failed to retrieve component definition" in result
238247
assert "API Error" in result
@@ -242,7 +251,7 @@ async def test_get_component_definition_api_error() -> None:
242251
async def test_search_component_definition_no_components() -> None:
243252
schema_response: dict[str, Any] = {"component_schema": {"definitions": {"Components": {}}}}
244253
resource = FakeHaystackServiceResource(get_component_schemas_response=schema_response)
245-
client = FakeClient(resource)
254+
client = FakeClient(resource=resource)
246255
model = FakeModel()
247256

248257
result = await search_component_definition(client, "test query", model)
@@ -252,7 +261,7 @@ async def test_search_component_definition_no_components() -> None:
252261
@pytest.mark.asyncio
253262
async def test_search_component_definition_api_error() -> None:
254263
resource = FakeHaystackServiceResource(exception=UnexpectedAPIError(status_code=500, message="API Error"))
255-
client = FakeClient(resource)
264+
client = FakeClient(resource=resource)
256265
model = FakeModel()
257266

258267
result = await search_component_definition(client, "test query", model)
@@ -263,7 +272,7 @@ async def test_search_component_definition_api_error() -> None:
263272
async def test_list_component_families_no_families() -> None:
264273
response: dict[str, Any] = {"component_schema": {"definitions": {"Components": {}}}}
265274
resource = FakeHaystackServiceResource(get_component_schemas_response=response)
266-
client = FakeClient(resource)
275+
client = FakeClient(resource=resource)
267276
result = await list_component_families(client)
268277
assert "No component families found" in result
269278

@@ -287,7 +296,7 @@ async def test_list_component_families_success() -> None:
287296
}
288297
}
289298
resource = FakeHaystackServiceResource(get_component_schemas_response=response)
290-
client = FakeClient(resource)
299+
client = FakeClient(resource=resource)
291300
result = await list_component_families(client)
292301

293302
assert "Available Haystack component families" in result
@@ -302,7 +311,7 @@ async def test_list_component_families_success() -> None:
302311
@pytest.mark.asyncio
303312
async def test_list_component_families_api_error() -> None:
304313
resource = FakeHaystackServiceResource(exception=UnexpectedAPIError(status_code=500, message="API Error"))
305-
client = FakeClient(resource)
314+
client = FakeClient(resource=resource)
306315
result = await list_component_families(client)
307316
assert "Failed to retrieve component families" in result
308317
assert "API Error" in result

test/unit/tools/test_pipeline_template.py

Lines changed: 93 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,38 @@
11
from typing import Any
2-
from uuid import UUID
2+
from uuid import UUID, uuid4
33

4+
import numpy as np
45
import pytest
56

67
from deepset_mcp.api.exceptions import ResourceNotFoundError, UnexpectedAPIError
78
from deepset_mcp.api.pipeline_template.models import PipelineTemplate, PipelineTemplateTag, PipelineType
8-
from deepset_mcp.tools.pipeline_template import get_pipeline_template, list_pipeline_templates
9+
from deepset_mcp.tools.model_protocol import ModelProtocol
10+
from deepset_mcp.tools.pipeline_template import (
11+
get_pipeline_template,
12+
list_pipeline_templates,
13+
search_pipeline_templates,
14+
)
915
from test.unit.conftest import BaseFakeClient
1016

1117

18+
class FakeModel(ModelProtocol):
19+
def encode(self, sentences: list[str] | str) -> np.ndarray[Any, Any]:
20+
# Convert input to list if it's a single string
21+
if isinstance(sentences, str):
22+
sentences = [sentences]
23+
24+
# Create fake embeddings with consistent similarities
25+
embeddings = np.zeros((len(sentences), 3))
26+
for i, sentence in enumerate(sentences):
27+
if "rag" in sentence.lower() or "retrieval" in sentence.lower():
28+
embeddings[i] = [1, 0, 0]
29+
elif "chat" in sentence.lower() or "conversation" in sentence.lower():
30+
embeddings[i] = [0.8, 0.2, 0]
31+
else:
32+
embeddings[i] = [0, 0, 1]
33+
return embeddings
34+
35+
1236
class FakePipelineTemplateResource:
1337
def __init__(
1438
self,
@@ -231,3 +255,70 @@ async def test_list_pipeline_templates_with_filter_and_sorting() -> None:
231255
assert resource.last_list_call_params["field"] == "name"
232256
assert resource.last_list_call_params["order"] == "ASC"
233257
assert resource.last_list_call_params["filter"] == filter_value
258+
259+
260+
@pytest.mark.asyncio
261+
async def test_search_pipeline_templates_success() -> None:
262+
# Create sample pipeline templates
263+
templates = [
264+
PipelineTemplate(
265+
author="Deepset",
266+
best_for=["Document Q&A"],
267+
description="A retrieval-augmented generation template for answering questions",
268+
pipeline_name="rag-pipeline",
269+
name="RAG Pipeline",
270+
pipeline_template_id=uuid4(),
271+
potential_applications=["FAQ systems", "Document search"],
272+
query_yaml="components:\n retriever: ...\n generator: ...",
273+
tags=[],
274+
pipeline_type=PipelineType.QUERY,
275+
),
276+
PipelineTemplate(
277+
author="Deepset",
278+
best_for=["Conversational AI"],
279+
description="A chat-based conversational pipeline for interactive responses",
280+
pipeline_name="chat-pipeline",
281+
name="Chat Pipeline",
282+
pipeline_template_id=uuid4(),
283+
potential_applications=["Chatbots", "Virtual assistants"],
284+
query_yaml="components:\n chat_generator: ...\n memory: ...",
285+
tags=[],
286+
pipeline_type=PipelineType.QUERY,
287+
),
288+
]
289+
290+
resource = FakePipelineTemplateResource(list_response=templates)
291+
client = FakeClient(resource)
292+
model = FakeModel()
293+
294+
# Search for RAG templates
295+
result = await search_pipeline_templates(client, "retrieval augmented generation", model, "test_workspace")
296+
assert "rag-pipeline" in result
297+
assert "Similarity Score:" in result
298+
assert "retrieval-augmented generation" in result
299+
300+
# Search for chat templates
301+
result = await search_pipeline_templates(client, "conversational chat interface", model, "test_workspace")
302+
assert "chat-pipeline" in result
303+
assert "Similarity Score:" in result
304+
assert "chat-based conversational" in result
305+
306+
307+
@pytest.mark.asyncio
308+
async def test_search_pipeline_templates_no_templates() -> None:
309+
resource = FakePipelineTemplateResource(list_response=[])
310+
client = FakeClient(resource)
311+
model = FakeModel()
312+
313+
result = await search_pipeline_templates(client, "test query", model, "test_workspace")
314+
assert "No pipeline templates found" in result
315+
316+
317+
@pytest.mark.asyncio
318+
async def test_search_pipeline_templates_api_error() -> None:
319+
resource = FakePipelineTemplateResource(list_exception=UnexpectedAPIError(status_code=500, message="API Error"))
320+
client = FakeClient(resource)
321+
model = FakeModel()
322+
323+
result = await search_pipeline_templates(client, "test query", model, "test_workspace")
324+
assert "Failed to retrieve pipeline templates" in result

0 commit comments

Comments
 (0)