Skip to content

Commit 7706aca

Browse files
authored
feat: migrate index tools to decorator pattern (#111)
1 parent b295d52 commit 7706aca

3 files changed

Lines changed: 122 additions & 53 deletions

File tree

src/deepset_mcp/tool_factory.py

Lines changed: 35 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from collections.abc import Awaitable, Callable
66
from dataclasses import dataclass
77
from enum import StrEnum
8-
from typing import Any, Literal
8+
from typing import Any
99

1010
from mcp.server.fastmcp import FastMCP
1111

@@ -98,13 +98,22 @@ class WorkspaceMode(StrEnum):
9898
EXPLICIT = "explicit" # workspace as required parameter in tool signature
9999

100100

101+
class MemoryType(StrEnum):
102+
"""Configuration for how memory is provided to tools."""
103+
104+
EXPLORABLE = "explorable"
105+
REFERENCEABLE = "referenceable"
106+
BOTH = "both"
107+
NO_MEMORY = "no_memory"
108+
109+
101110
@dataclass
102111
class ToolConfig:
103112
"""Configuration for tool registration."""
104113

105114
needs_client: bool = False
106115
needs_workspace: bool = False
107-
decorator_type: Literal["none", "explorable", "referenceable", "both"] = "none"
116+
memory_type: MemoryType = MemoryType.NO_MEMORY
108117
custom_args: dict[str, Any] | None = None # For special cases like search_component_definition
109118

110119

@@ -127,11 +136,26 @@ def get_workspace_from_env() -> str:
127136
"validate_pipeline": (validate_pipeline_tool, ToolConfig(needs_client=True, needs_workspace=True)),
128137
"get_pipeline_logs": (get_pipeline_logs_tool, ToolConfig(needs_client=True, needs_workspace=True)),
129138
"search_pipeline": (search_pipeline_tool, ToolConfig(needs_client=True, needs_workspace=True)),
130-
"list_indexes": (list_indexes_tool, ToolConfig(needs_client=True, needs_workspace=True)),
131-
"get_index": (get_index_tool, ToolConfig(needs_client=True, needs_workspace=True)),
132-
"create_index": (create_index_tool, ToolConfig(needs_client=True, needs_workspace=True)),
133-
"update_index": (update_index_tool, ToolConfig(needs_client=True, needs_workspace=True)),
134-
"deploy_index": (deploy_index_tool, ToolConfig(needs_client=True, needs_workspace=True)),
139+
"list_indexes": (
140+
list_indexes_tool,
141+
ToolConfig(needs_client=True, needs_workspace=True, memory_type=MemoryType.EXPLORABLE),
142+
),
143+
"get_index": (
144+
get_index_tool,
145+
ToolConfig(needs_client=True, needs_workspace=True, memory_type=MemoryType.EXPLORABLE),
146+
),
147+
"create_index": (
148+
create_index_tool,
149+
ToolConfig(needs_client=True, needs_workspace=True, memory_type=MemoryType.BOTH),
150+
),
151+
"update_index": (
152+
update_index_tool,
153+
ToolConfig(needs_client=True, needs_workspace=True, memory_type=MemoryType.BOTH),
154+
),
155+
"deploy_index": (
156+
deploy_index_tool,
157+
ToolConfig(needs_client=True, needs_workspace=True, memory_type=MemoryType.EXPLORABLE),
158+
),
135159
"list_pipeline_templates": (list_pipeline_templates_tool, ToolConfig(needs_client=True, needs_workspace=True)),
136160
"get_pipeline_template": (get_pipeline_template_tool, ToolConfig(needs_client=True, needs_workspace=True)),
137161
"search_pipeline_templates": (
@@ -172,15 +196,15 @@ def create_enhanced_tool(
172196
"""
173197
# Apply decorators first (if needed)
174198
decorated_func = base_func
175-
if config.decorator_type != "none":
199+
if config.memory_type != "none":
176200
store = STORE
177201
explorer = RichExplorer(store)
178202

179-
if config.decorator_type == "explorable":
203+
if config.memory_type == "explorable":
180204
decorated_func = explorable(object_store=store, explorer=explorer)(decorated_func)
181-
elif config.decorator_type == "referenceable":
205+
elif config.memory_type == "referenceable":
182206
decorated_func = referenceable(object_store=store, explorer=explorer)(decorated_func)
183-
elif config.decorator_type == "both":
207+
elif config.memory_type == "both":
184208
decorated_func = explorable_and_referenceable(object_store=store, explorer=explorer)(decorated_func)
185209

186210
# Handle client and workspace injection

src/deepset_mcp/tools/indexes.py

Lines changed: 47 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,36 @@
11
from deepset_mcp.api.exceptions import BadRequestError, ResourceNotFoundError, UnexpectedAPIError
2+
from deepset_mcp.api.indexes.models import Index, IndexList
3+
from deepset_mcp.api.pipeline import PipelineValidationResult
24
from deepset_mcp.api.protocols import AsyncClientProtocol
3-
from deepset_mcp.tools.formatting_utils import validation_result_to_llm_readable_string
4-
from deepset_mcp.tools.formatting_utils_index import index_list_to_llm_readable_string, index_to_llm_readable_string
55

66

7-
async def list_indexes(client: AsyncClientProtocol, workspace: str) -> str:
8-
"""Retrieves a list of all indexes available within the currently configured deepset workspace."""
9-
response = await client.indexes(workspace=workspace).list()
10-
return index_list_to_llm_readable_string(response)
7+
async def list_indexes(client: AsyncClientProtocol, workspace: str) -> IndexList | str:
8+
"""Use this to list available indexes on the deepset platform in your workspace.
9+
10+
:param client: Deepset API client to use for requesting indexes.
11+
:param workspace: Workspace of which to list indexes.
12+
"""
13+
try:
14+
result = await client.indexes(workspace=workspace).list()
15+
except ResourceNotFoundError as e:
16+
return f"Error listing indexes. Error: {e.message} ({e.status_code})"
17+
18+
return result
1119

1220

13-
async def get_index(client: AsyncClientProtocol, workspace: str, index_name: str) -> str:
14-
"""Fetches detailed configuration information for a specific index, identified by its unique `index_name`."""
21+
async def get_index(client: AsyncClientProtocol, workspace: str, index_name: str) -> Index | str:
22+
"""Fetches detailed configuration information for a specific index, identified by its unique `index_name`.
23+
24+
:param client: Deepset API client to use for requesting the index.
25+
:param workspace: Workspace of which to get the index from.
26+
:param index_name: Unique name of the index to fetch.
27+
"""
1528
try:
1629
response = await client.indexes(workspace=workspace).get(index_name)
1730
except ResourceNotFoundError:
1831
return f"There is no index named '{index_name}'. Did you mean to create it?"
1932

20-
return index_to_llm_readable_string(response)
33+
return response
2134

2235

2336
async def create_index(
@@ -26,10 +39,17 @@ async def create_index(
2639
index_name: str,
2740
yaml_configuration: str,
2841
description: str | None = None,
29-
) -> str:
30-
"""Creates a new index within the currently configured deepset workspace."""
42+
) -> dict[str, str | Index] | str:
43+
"""Creates a new index within your deepset platform workspace.
44+
45+
:param client: Deepset API client to use.
46+
:param workspace: Workspace in which to create the index.
47+
:param index_name: Unique name of the index to create.
48+
:param yaml_configuration: YAML configuration to use for the index.
49+
:param description: Description of the index to create.
50+
"""
3151
try:
32-
await client.indexes(workspace=workspace).create(
52+
result = await client.indexes(workspace=workspace).create(
3353
name=index_name, yaml_config=yaml_configuration, description=description
3454
)
3555
except ResourceNotFoundError:
@@ -39,7 +59,7 @@ async def create_index(
3959
except UnexpectedAPIError as e:
4060
return f"Failed to create index '{index_name}': {e}"
4161

42-
return f"Index '{index_name}' created successfully."
62+
return {"message": f"Index '{index_name}' created successfully.", "index": result}
4363

4464

4565
async def update_index(
@@ -48,17 +68,23 @@ async def update_index(
4868
index_name: str,
4969
updated_index_name: str | None = None,
5070
yaml_configuration: str | None = None,
51-
) -> str:
52-
"""Updates an existing index in the specified workspace.
71+
) -> dict[str, str | Index] | str:
72+
"""Updates an existing index in your deepset platform workspace.
5373
5474
This function can update either the name or the configuration of an existing index, or both.
5575
At least one of updated_index_name or yaml_configuration must be provided.
76+
77+
:param client: Deepset API client to use.
78+
:param workspace: Workspace in which to update the index.
79+
:param index_name: Unique name of the index to update.
80+
:param updated_index_name: Updated name of the index.
81+
:param yaml_configuration: YAML configuration to update the index with.
5682
"""
5783
if not updated_index_name and not yaml_configuration:
5884
return "You must provide either a new name or a new configuration to update the index."
5985

6086
try:
61-
await client.indexes(workspace=workspace).update(
87+
result = await client.indexes(workspace=workspace).update(
6288
index_name=index_name, updated_index_name=updated_index_name, yaml_config=yaml_configuration
6389
)
6490
except ResourceNotFoundError:
@@ -68,21 +94,21 @@ async def update_index(
6894
except UnexpectedAPIError as e:
6995
return f"Failed to update index '{index_name}': {e}"
7096

71-
return f"Index '{index_name}' updated successfully."
97+
return {"message": f"Index '{index_name}' updated successfully.", "index": result}
7298

7399

74-
async def deploy_index(client: AsyncClientProtocol, workspace: str, index_name: str) -> str:
100+
async def deploy_index(client: AsyncClientProtocol, workspace: str, index_name: str) -> str | PipelineValidationResult:
75101
"""Deploys an index to production.
76102
77103
This function attempts to deploy the specified index in the given workspace.
78-
If the deployment fails due to validation errors, it returns a readable string
104+
If the deployment fails due to validation errors, it returns an object
79105
describing the validation errors.
80106
81107
:param client: The async client for API communication.
82108
:param workspace: The workspace name.
83109
:param index_name: Name of the index to deploy.
84110
85-
:returns: A string indicating the deployment result.
111+
:returns: A string indicating the deployment result or the validation results including errors.
86112
"""
87113
try:
88114
deployment_result = await client.indexes(workspace=workspace).deploy(index_name=index_name)
@@ -94,6 +120,6 @@ async def deploy_index(client: AsyncClientProtocol, workspace: str, index_name:
94120
return f"Failed to deploy index '{index_name}': {e}"
95121

96122
if not deployment_result.valid:
97-
return validation_result_to_llm_readable_string(deployment_result)
123+
return deployment_result
98124

99125
return f"Index '{index_name}' deployed successfully."

test/unit/tools/test_indexes.py

Lines changed: 40 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ def __init__(
1919
create_response: Index | None = None,
2020
update_response: Index | None = None,
2121
deploy_response: PipelineValidationResult | None = None,
22+
list_exception: Exception | None = None,
2223
get_exception: Exception | None = None,
2324
create_exception: Exception | None = None,
2425
update_exception: Exception | None = None,
@@ -29,12 +30,15 @@ def __init__(
2930
self._create_response = create_response
3031
self._update_response = update_response
3132
self._deploy_response = deploy_response
33+
self._list_exception = list_exception
3234
self._get_exception = get_exception
3335
self._create_exception = create_exception
3436
self._update_exception = update_exception
3537
self._deploy_exception = deploy_exception
3638

3739
async def list(self, limit: int = 10, page_number: int = 1) -> IndexList:
40+
if self._list_exception:
41+
raise self._list_exception
3842
if self._list_response is not None:
3943
return self._list_response
4044
return IndexList(data=[], has_more=False, total=0)
@@ -118,17 +122,18 @@ def create_test_index(
118122

119123

120124
@pytest.mark.asyncio
121-
async def test_list_indexes_returns_formatted_string_when_no_indexes() -> None:
125+
async def test_list_indexes_without_indexes() -> None:
122126
resource = FakeIndexResource(list_response=IndexList(data=[], has_more=False, total=0))
123127
client = FakeClient(resource)
124128

125129
result = await list_indexes(client=client, workspace="test")
126130

127-
assert result == "No indexes found."
131+
assert isinstance(result, IndexList)
132+
assert len(result.data) == 0
128133

129134

130135
@pytest.mark.asyncio
131-
async def test_list_indexes_returns_formatted_string_with_indexes() -> None:
136+
async def test_list_indexes_returns_indexes() -> None:
132137
index1 = create_test_index(name="index1", description="First index")
133138
index2 = create_test_index(name="index2", description="Second index")
134139

@@ -137,25 +142,33 @@ async def test_list_indexes_returns_formatted_string_with_indexes() -> None:
137142

138143
result = await list_indexes(client=client, workspace="test")
139144

140-
assert "index1" in result
141-
assert "index2" in result
142-
assert "First index" in result
143-
assert "Second index" in result
144-
assert "idx_123" in result
145+
assert isinstance(result, IndexList)
146+
assert len(result.data) == 2
147+
assert result.data[0].name == index1.name
148+
assert result.data[1].name == index2.name
145149

146150

147151
@pytest.mark.asyncio
148-
async def test_get_index_returns_formatted_string() -> None:
152+
async def test_list_indexes_returns_string_on_non_existant_workspace() -> None:
153+
resource = FakeIndexResource(list_exception=ResourceNotFoundError(message="Resource not found."))
154+
client = FakeClient(resource)
155+
156+
result = await list_indexes(client=client, workspace="test")
157+
158+
assert isinstance(result, str)
159+
assert result == "Error listing indexes. Error: Resource not found. (404)"
160+
161+
162+
@pytest.mark.asyncio
163+
async def test_get_index_returns_index() -> None:
149164
index = create_test_index(name="my_index", description="My special index")
150165
resource = FakeIndexResource(get_response=index)
151166
client = FakeClient(resource)
152167

153168
result = await get_index(client=client, workspace="test", index_name="my_index")
154169

155-
assert "my_index" in result
156-
assert "config: value" in result
157-
assert "idx_123" in result
158-
assert "My special index" in result
170+
assert isinstance(result, Index)
171+
assert result.name == "my_index"
159172

160173

161174
@pytest.mark.asyncio
@@ -169,7 +182,7 @@ async def test_get_index_returns_error_message_when_index_not_found() -> None:
169182

170183

171184
@pytest.mark.asyncio
172-
async def test_create_index_returns_success_message() -> None:
185+
async def test_create_index_returns_success_message_and_index() -> None:
173186
created_index = create_test_index(name="new_index")
174187
resource = FakeIndexResource(create_response=created_index)
175188
client = FakeClient(resource)
@@ -182,7 +195,11 @@ async def test_create_index_returns_success_message() -> None:
182195
description="New index description",
183196
)
184197

185-
assert "Index 'new_index' created successfully." == result
198+
assert isinstance(result, dict)
199+
assert isinstance(result.get("message"), str)
200+
index = result.get("index")
201+
assert isinstance(index, Index)
202+
assert index.name == "new_index"
186203

187204

188205
@pytest.mark.parametrize(
@@ -225,7 +242,12 @@ async def test_update_index_returns_success_message() -> None:
225242
yaml_configuration="new_config",
226243
)
227244

228-
assert "Index 'test_index' updated successfully." == result
245+
assert isinstance(result, dict)
246+
assert isinstance(result.get("message"), str)
247+
248+
index = result.get("index")
249+
assert isinstance(index, Index)
250+
assert index.name == "new_test_index"
229251

230252

231253
@pytest.mark.asyncio
@@ -355,11 +377,8 @@ async def test_deploy_index_returns_validation_errors() -> None:
355377

356378
result = await deploy_index(client=client, workspace="test", index_name="test_index")
357379

358-
assert "The provided pipeline configuration is invalid" in result
359-
assert "invalid_config" in result
360-
assert "Index configuration is invalid" in result
361-
assert "missing_dependency" in result
362-
assert "Required dependency not found" in result
380+
assert isinstance(result, PipelineValidationResult)
381+
assert result.errors == validation_errors
363382

364383

365384
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)