Skip to content

Commit 8f060d2

Browse files
authored
feat: template tools should support indexes (#153)
* feat: template tools should support indexes * fix: format
1 parent 90b9c8f commit 8f060d2

6 files changed

Lines changed: 494 additions & 59 deletions

File tree

src/deepset_mcp/api/pipeline_template/models.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from enum import StrEnum
2+
from typing import Any
23
from uuid import UUID
34

4-
from pydantic import BaseModel, Field
5+
from pydantic import BaseModel, Field, model_validator
56

67

78
class PipelineType(StrEnum):
@@ -28,10 +29,36 @@ class PipelineTemplate(BaseModel):
2829
display_name: str = Field(alias="name")
2930
pipeline_template_id: UUID = Field(alias="pipeline_template_id")
3031
potential_applications: list[str] = Field(alias="potential_applications")
31-
yaml_config: str | None = Field(None, alias="query_yaml")
32+
yaml_config: str | None = None
3233
tags: list[PipelineTemplateTag]
3334
pipeline_type: PipelineType
3435

36+
@model_validator(mode="before")
37+
@classmethod
38+
def populate_yaml_config(cls, values: Any) -> Any:
39+
"""Populate yaml_config from query_yaml or indexing_yaml based on pipeline_type."""
40+
if not isinstance(values, dict):
41+
return values
42+
43+
# Skip if yaml_config is already set
44+
if values.get("yaml_config") is not None:
45+
return values
46+
47+
# Get pipeline_type from the model data
48+
pipeline_type = values.get("pipeline_type")
49+
50+
if pipeline_type == PipelineType.INDEXING or pipeline_type == "indexing":
51+
yaml_config = values.get("indexing_yaml")
52+
elif pipeline_type == PipelineType.QUERY or pipeline_type == "query":
53+
yaml_config = values.get("query_yaml")
54+
else:
55+
yaml_config = None
56+
57+
if yaml_config is not None:
58+
values["yaml_config"] = yaml_config
59+
60+
return values
61+
3562

3663
class PipelineTemplateList(BaseModel):
3764
"""Response model for listing pipeline templates."""

src/deepset_mcp/tool_factory.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,9 @@
4747
validate_pipeline as validate_pipeline_tool,
4848
)
4949
from deepset_mcp.tools.pipeline_template import (
50-
get_pipeline_template as get_pipeline_template_tool,
51-
list_pipeline_templates as list_pipeline_templates_tool,
52-
search_pipeline_templates as search_pipeline_templates_tool,
50+
get_template as get_pipeline_template_tool,
51+
list_templates as list_pipeline_templates_tool,
52+
search_templates as search_pipeline_templates_tool,
5353
)
5454
from deepset_mcp.tools.secrets import (
5555
get_secret as get_secret_tool,
@@ -230,20 +230,20 @@ def get_workspace_from_env() -> str:
230230
deploy_index_tool,
231231
ToolConfig(needs_client=True, needs_workspace=True, memory_type=MemoryType.EXPLORABLE),
232232
),
233-
"list_pipeline_templates": (
233+
"list_templates": (
234234
list_pipeline_templates_tool,
235235
ToolConfig(
236236
needs_client=True,
237237
needs_workspace=True,
238238
memory_type=MemoryType.EXPLORABLE,
239-
custom_args={"filter": None, "field": "created_at", "order": "DESC", "limit": 100},
239+
custom_args={"field": "created_at", "order": "DESC", "limit": 100},
240240
),
241241
),
242-
"get_pipeline_template": (
242+
"get_template": (
243243
get_pipeline_template_tool,
244244
ToolConfig(needs_client=True, needs_workspace=True, memory_type=MemoryType.EXPLORABLE),
245245
),
246-
"search_pipeline_templates": (
246+
"search_templates": (
247247
search_pipeline_templates_tool,
248248
ToolConfig(
249249
needs_client=True,

src/deepset_mcp/tools/pipeline_template.py

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,51 +6,53 @@
66
PipelineTemplateList,
77
PipelineTemplateSearchResult,
88
PipelineTemplateSearchResults,
9+
PipelineType,
910
)
1011
from deepset_mcp.api.protocols import AsyncClientProtocol
1112
from deepset_mcp.tools.model_protocol import ModelProtocol
1213

1314

14-
async def list_pipeline_templates(
15+
async def list_templates(
1516
*,
1617
client: AsyncClientProtocol,
1718
workspace: str,
1819
limit: int = 100,
1920
field: str = "created_at",
2021
order: str = "DESC",
21-
filter: str | None = None,
22+
pipeline_type: PipelineType | str | None = None,
2223
) -> PipelineTemplateList | str:
23-
"""Retrieves a list of all available pipeline templates.
24+
"""Retrieves a list of all available pipeline and indexing templates.
2425
2526
:param client: The async client for API requests.
2627
:param workspace: The workspace to list templates from.
2728
:param limit: Maximum number of templates to return (default: 100).
2829
:param field: Field to sort by (default: "created_at").
2930
:param order: Sort order, either "ASC" or "DESC" (default: "DESC").
30-
:param filter: OData filter expression to filter templates by criteria.
31+
:param pipeline_type: The type of pipeline to return.
3132
3233
:returns: List of pipeline templates or error message.
3334
"""
3435
try:
3536
return await client.pipeline_templates(workspace=workspace).list_templates(
36-
limit=limit, field=field, order=order, filter=filter
37+
limit=limit,
38+
field=field,
39+
order=order,
40+
filter=f"pipeline_type eq '{pipeline_type}'" if pipeline_type else None,
3741
)
3842
except ResourceNotFoundError:
3943
return f"There is no workspace named '{workspace}'. Did you mean to configure it?"
4044
except UnexpectedAPIError as e:
4145
return f"Failed to list pipeline templates: {e}"
4246

4347

44-
async def get_pipeline_template(
45-
*, client: AsyncClientProtocol, workspace: str, template_name: str
46-
) -> PipelineTemplate | str:
47-
"""Fetches detailed information for a specific pipeline template, identified by its `template_name`.
48+
async def get_template(*, client: AsyncClientProtocol, workspace: str, template_name: str) -> PipelineTemplate | str:
49+
"""Fetches detailed information for a specific pipeline or indexing template, identified by its `template_name`.
4850
4951
:param client: The async client for API requests.
5052
:param workspace: The workspace to fetch template from.
5153
:param template_name: The name of the template to fetch.
5254
53-
:returns: Pipeline template details or error message.
55+
:returns: Pipeline or indexing template details or error message.
5456
"""
5557
try:
5658
return await client.pipeline_templates(workspace=workspace).get_template(template_name=template_name)
@@ -60,22 +62,29 @@ async def get_pipeline_template(
6062
return f"Failed to fetch pipeline template '{template_name}': {e}"
6163

6264

63-
async def search_pipeline_templates(
64-
*, client: AsyncClientProtocol, query: str, model: ModelProtocol, workspace: str, top_k: int = 10
65+
async def search_templates(
66+
*,
67+
client: AsyncClientProtocol,
68+
query: str,
69+
model: ModelProtocol,
70+
workspace: str,
71+
top_k: int = 10,
72+
pipeline_type: PipelineType | str = PipelineType.QUERY,
6573
) -> PipelineTemplateSearchResults | str:
66-
"""Searches for pipeline templates based on name or description using semantic similarity.
74+
"""Searches for pipeline or indexing templates based on name or description using semantic similarity.
6775
6876
:param client: The API client to use.
6977
:param query: The search query.
7078
:param model: The model to use for computing embeddings.
7179
:param workspace: The workspace to search templates from.
7280
:param top_k: Maximum number of results to return (default: 10).
81+
:param pipeline_type: The type of pipeline to return ('indexing' or 'query'; default: 'query').
7382
7483
:returns: Search results with similarity scores or error message.
7584
"""
7685
try:
7786
response = await client.pipeline_templates(workspace=workspace).list_templates(
78-
filter="pipeline_type eq 'QUERY'"
87+
filter=f"pipeline_type eq '{pipeline_type}'"
7988
)
8089
except UnexpectedAPIError as e:
8190
return f"Failed to retrieve pipeline templates: {e}"

test/integration/test_integration_pipeline_template_resource.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,3 +143,86 @@ async def test_list_templates_with_custom_sorting(
143143
if len(templates_list.data) > 1:
144144
for i in range(len(templates_list.data) - 1):
145145
assert templates_list.data[i].display_name <= templates_list.data[i + 1].display_name
146+
147+
148+
@pytest.mark.asyncio
149+
async def test_get_indexing_template(
150+
template_resource: PipelineTemplateResource,
151+
) -> None:
152+
"""Test getting a single indexing template by name.
153+
154+
First lists all indexing templates, then gets the first one by name.
155+
"""
156+
# Get all indexing templates
157+
indexing_templates_list = await template_resource.list_templates(filter="pipeline_type eq 'INDEXING'")
158+
159+
# Skip if no indexing templates are available
160+
if not indexing_templates_list.data:
161+
pytest.skip("No indexing templates available in the test environment")
162+
163+
# Get the first indexing template's name
164+
template_name = indexing_templates_list.data[0].template_name
165+
166+
# Now get that specific template
167+
template = await template_resource.get_template(template_name=template_name)
168+
169+
# Verify the template was retrieved correctly
170+
assert template.template_name == template_name
171+
assert template.pipeline_type == "indexing"
172+
assert template.pipeline_template_id is not None
173+
assert template.yaml_config is not None # Should have indexing_yaml content
174+
assert isinstance(template.best_for, list)
175+
assert isinstance(template.potential_applications, list)
176+
assert isinstance(template.tags, list)
177+
178+
179+
@pytest.mark.asyncio
180+
async def test_list_indexing_templates_with_filter(
181+
template_resource: PipelineTemplateResource,
182+
) -> None:
183+
"""Test listing templates with an indexing pipeline type filter."""
184+
# Test filtering by INDEXING pipeline type
185+
indexing_templates_list = await template_resource.list_templates(filter="pipeline_type eq 'INDEXING'")
186+
187+
# Verify that all returned templates are INDEXING type
188+
assert isinstance(indexing_templates_list, PipelineTemplateList)
189+
assert isinstance(indexing_templates_list.data, list)
190+
191+
# If templates are available, verify they are all INDEXING type
192+
for template in indexing_templates_list.data:
193+
assert isinstance(template, PipelineTemplate)
194+
assert template.pipeline_type == "indexing"
195+
assert template.yaml_config is not None # Should have indexing_yaml content
196+
197+
198+
@pytest.mark.asyncio
199+
async def test_mixed_pipeline_types_integration(
200+
template_resource: PipelineTemplateResource,
201+
) -> None:
202+
"""Test that query and indexing templates can be retrieved together."""
203+
# Get all templates
204+
all_templates = await template_resource.list_templates()
205+
206+
# Skip if no templates are available
207+
if not all_templates.data:
208+
pytest.skip("No templates available in the test environment")
209+
210+
# Separate templates by type
211+
query_templates = [t for t in all_templates.data if t.pipeline_type == "query"]
212+
indexing_templates = [t for t in all_templates.data if t.pipeline_type == "indexing"]
213+
214+
# Verify that both types can exist and have proper yaml_config
215+
for template in query_templates:
216+
assert template.pipeline_type == "query"
217+
if template.yaml_config is not None:
218+
# Query templates should have query_yaml content
219+
assert isinstance(template.yaml_config, str)
220+
221+
for template in indexing_templates:
222+
assert template.pipeline_type == "indexing"
223+
if template.yaml_config is not None:
224+
# Indexing templates should have indexing_yaml content
225+
assert isinstance(template.yaml_config, str)
226+
227+
# Verify that the total matches the sum of individual types
228+
assert len(query_templates) + len(indexing_templates) == len(all_templates.data)

0 commit comments

Comments
 (0)