Skip to content

Commit a527512

Browse files
committed
feat: add asyncio task fan-out for fetching collection queryables concurrently, and a ThreadPoolExecutor to process the fetched results [2025-07-30]
1 parent 252e907 commit a527512

File tree

1 file changed

+76
-57
lines changed

1 file changed

+76
-57
lines changed

src/mcp_service/os_service.py

Lines changed: 76 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import json
22
import asyncio
33
import functools
4+
import concurrent.futures
5+
import time
6+
import threading
47

58
from typing import Optional, List, Dict, Any, Union
69
from api_service.protocols import APIClient
@@ -9,6 +12,7 @@
912
from mcp_service.guardrails import ToolGuardrails
1013
from workflow_generator.workflow_planner import WorkflowPlanner
1114
from utils.logging_config import get_logger
15+
from models import Collection
1216

1317
logger = get_logger(__name__)
1418

@@ -95,6 +99,7 @@ async def _cleanup(self):
9599
except Exception as e:
96100
logger.error(f"Error closing API client: {e}")
97101

102+
# TODO: All this processing should really be done outside of the os mcp service level - and we need to cache the results
98103
async def get_workflow_context(self) -> str:
99104
"""Get workflow context and initialise planner if needed"""
100105
try:
@@ -104,96 +109,107 @@ async def get_workflow_context(self) -> str:
104109

105110
collections_info = {}
106111
if cached_collections and hasattr(cached_collections, "collections"):
107-
collections_list = getattr(cached_collections, "collections", [])
112+
collections_list: List[Collection] = getattr(cached_collections, "collections", [])
108113
if collections_list and hasattr(collections_list, "__iter__"):
109-
for collection in collections_list:
114+
115+
async def fetch_collection_queryables(collection: Collection) -> Dict[str, Any]:
110116
try:
111-
# TODO: This needs to be split into async tasks and run in parallel
112-
# TODO: This also needs to be shifted into the api_client and cached like the rest of the data
113117
queryables_data = await self.api_client.make_request(
114118
"COLLECTION_QUERYABLES", path_params=[collection.id]
115119
)
116-
120+
117121
all_queryables = {}
118122
enum_queryables = {}
119123
properties = queryables_data.get("properties", {})
120-
124+
121125
for prop_name, prop_details in properties.items():
122126
prop_type = prop_details.get("type", ["string"])
123127
if isinstance(prop_type, list):
124-
main_type = (
125-
prop_type[0] if prop_type else "string"
126-
)
128+
main_type = prop_type[0] if prop_type else "string"
127129
is_nullable = "null" in prop_type
128130
else:
129131
main_type = prop_type
130132
is_nullable = False
131-
133+
132134
all_queryables[prop_name] = {
133135
"type": main_type,
134136
"nullable": is_nullable,
135-
"description": prop_details.get(
136-
"description", ""
137-
),
137+
"description": prop_details.get("description", ""),
138138
"max_length": prop_details.get("maxLength"),
139139
"format": prop_details.get("format"),
140140
"pattern": prop_details.get("pattern"),
141141
"minimum": prop_details.get("minimum"),
142142
"maximum": prop_details.get("maximum"),
143-
"is_enum": prop_details.get(
144-
"enumeration", False
145-
),
143+
"is_enum": prop_details.get("enumeration", False)
146144
}
147-
148-
if (
149-
prop_details.get("enumeration")
150-
and "enum" in prop_details
151-
):
145+
146+
if prop_details.get("enumeration") and "enum" in prop_details:
152147
enum_queryables[prop_name] = {
153148
"values": prop_details["enum"],
154149
"type": main_type,
155150
"nullable": is_nullable,
156-
"description": prop_details.get(
157-
"description", ""
158-
),
159-
"max_length": prop_details.get("maxLength"),
151+
"description": prop_details.get("description", ""),
152+
"max_length": prop_details.get("maxLength")
160153
}
161-
all_queryables[prop_name]["enum_values"] = (
162-
prop_details["enum"]
163-
)
164-
154+
all_queryables[prop_name]["enum_values"] = prop_details["enum"]
155+
165156
all_queryables[prop_name] = {
166-
k: v
167-
for k, v in all_queryables[prop_name].items()
157+
k: v for k, v in all_queryables[prop_name].items()
168158
if v is not None
169159
}
170-
171-
collections_info[collection.id] = {
172-
"id": collection.id,
173-
"title": collection.title,
174-
"description": collection.description,
160+
161+
return {
162+
"collection": collection,
175163
"all_queryables": all_queryables,
176164
"enum_queryables": enum_queryables,
177-
"has_enum_filters": len(enum_queryables) > 0,
178-
"total_queryables": len(all_queryables),
179-
"enum_count": len(enum_queryables),
180165
}
181-
166+
182167
except Exception as e:
183-
logger.warning(
184-
f"Failed to fetch queryables for {collection.id}: {e}"
185-
)
186-
187-
collections_info[collection.id] = {
188-
"id": collection.id,
189-
"title": collection.title,
190-
"description": collection.description,
168+
logger.warning(f"Failed to fetch queryables for {collection.id}: {e}")
169+
return {
170+
"collection": collection,
191171
"all_queryables": {},
192172
"enum_queryables": {},
193-
"has_enum_filters": False,
194-
"total_queryables": 0,
195-
"enum_count": 0,
196173
}
174+
175+
# Fetch queryables for all collections concurrently; the requests overlap their network I/O waits, so total latency approaches the slowest single request rather than the sum.
176+
tasks = [fetch_collection_queryables(collection) for collection in collections_list]
177+
queryables_results = await asyncio.gather(*tasks)
178+
179+
logger.debug(f"Starting thread pool processing for {len(queryables_results)} results...")
180+
181+
def process_collection_result(result):
182+
collection = result["collection"]
183+
logger.debug(f"Processing collection {collection.id} in thread {threading.current_thread().name}")
184+
all_queryables = result["all_queryables"]
185+
enum_queryables = result["enum_queryables"]
186+
187+
return (collection.id, {
188+
"id": collection.id,
189+
"title": collection.title,
190+
"description": collection.description,
191+
"all_queryables": all_queryables,
192+
"enum_queryables": enum_queryables,
193+
"has_enum_filters": len(enum_queryables) > 0,
194+
"total_queryables": len(all_queryables),
195+
"enum_count": len(enum_queryables)
196+
})
197+
198+
thread_start = time.time()
199+
200+
# Offload result post-processing to a worker thread so the event loop stays responsive while results are reshaped.
201+
# TODO: benchmark this — the single run_in_executor call runs the whole map() on ONE thread, so the 4-worker pool adds no parallelism, and the GIL would serialize this pure-Python work anyway; a plain loop (or per-result executor.map) may be as fast or faster.
202+
with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
203+
logger.debug(f"Thread pool created with {executor._max_workers} workers")
204+
processed = await asyncio.get_event_loop().run_in_executor(
205+
executor,
206+
lambda: list(map(process_collection_result, queryables_results))
207+
)
208+
209+
thread_end = time.time()
210+
logger.debug(f"Thread pool processing completed in {thread_end - thread_start:.4f}s")
211+
212+
collections_info = dict(processed)
197213

198214
self.workflow_planner = WorkflowPlanner(cached_spec, collections_info)
199215

@@ -205,31 +221,35 @@ async def get_workflow_context(self) -> str:
205221
"required_explanation": {
206222
"1": "Which collection you will use and why",
207223
"2": "What specific filters you will apply (show the exact filter string)",
208-
"3": "What steps you will take",
224+
"3": "What steps you will take"
209225
},
210226
"workflow_enforcement": "Do not proceed with tool calls until you have clearly explained your plan to the user",
211-
"example_planning": "I will search the 'lus-fts-site-1' collection using the filter 'oslandusetertiarygroup = \"Cinema\"' to find all cinema locations in your specified area.",
227+
"example_planning": "I will search the 'lus-fts-site-1' collection using the filter 'oslandusetertiarygroup = \"Cinema\"' to find all cinema locations in your specified area."
212228
},
229+
213230
"available_collections": context["available_collections"],
214231
"openapi_endpoints": context["openapi_endpoints"],
232+
215233
"QUICK_FILTERING_GUIDE": {
216234
"primary_tool": "search_features",
217235
"key_parameter": "filter",
218236
"enum_fields": "Use exact values from collection's enum_queryables (e.g., 'Cinema', 'A Road')",
219237
"simple_fields": "Use direct values (e.g., usrn = 12345678)",
220238
},
239+
221240
"COMMON_EXAMPLES": {
222241
"cinema_search": "search_features(collection_id='lus-fts-site-1', bbox='...', filter=\"oslandusetertiarygroup = 'Cinema'\")",
223242
"a_road_search": "search_features(collection_id='trn-ntwk-street-1', bbox='...', filter=\"roadclassification = 'A Road'\")",
224243
"usrn_search": "search_features(collection_id='trn-ntwk-street-1', filter='usrn = 12345678')",
225-
"street_name": "search_features(collection_id='trn-ntwk-street-1', filter=\"designatedname1_text LIKE '%high%'\")",
244+
"street_name": "search_features(collection_id='trn-ntwk-street-1', filter=\"designatedname1_text LIKE '%high%'\")"
226245
},
246+
227247
"CRITICAL_RULES": {
228248
"1": "ALWAYS explain your plan first",
229249
"2": "Use exact enum values from the specific collection's enum_queryables",
230250
"3": "Use the 'filter' parameter for all filtering",
231-
"4": "Quote string values in single quotes",
232-
},
251+
"4": "Quote string values in single quotes"
252+
}
233253
}
234254
)
235255

@@ -244,7 +264,6 @@ def _require_workflow_context(self, func):
244264

245265
@functools.wraps(func)
246266
async def wrapper(*args, **kwargs):
247-
# Only allow get_workflow_context when workflow_planner is None
248267
if self.workflow_planner is None:
249268
if func.__name__ == "get_workflow_context":
250269
return await func(*args, **kwargs)

0 commit comments

Comments
 (0)