11import json
22import asyncio
33import functools
4+ import concurrent .futures
5+ import time
6+ import threading
47
58from typing import Optional , List , Dict , Any , Union
69from api_service .protocols import APIClient
912from mcp_service .guardrails import ToolGuardrails
1013from workflow_generator .workflow_planner import WorkflowPlanner
1114from utils .logging_config import get_logger
15+ from models import Collection
1216
1317logger = get_logger (__name__ )
1418
@@ -95,6 +99,7 @@ async def _cleanup(self):
9599 except Exception as e :
96100 logger .error (f"Error closing API client: { e } " )
97101
102+ # TODO: All this processing should really be done outside of the os mcp service level - and we need to cache the results
98103 async def get_workflow_context (self ) -> str :
99104 """Get workflow context and initialise planner if needed"""
100105 try :
@@ -104,96 +109,107 @@ async def get_workflow_context(self) -> str:
104109
105110 collections_info = {}
106111 if cached_collections and hasattr (cached_collections , "collections" ):
107- collections_list = getattr (cached_collections , "collections" , [])
112+ collections_list : List [ Collection ] = getattr (cached_collections , "collections" , [])
108113 if collections_list and hasattr (collections_list , "__iter__" ):
109- for collection in collections_list :
114+
async def fetch_collection_queryables(collection: Collection) -> Dict[str, Any]:
    """Fetch and normalise the queryables for a single collection.

    Returns a dict with keys:
        collection      -- the Collection object passed in (threaded through
                           so downstream processing can read id/title/description)
        all_queryables  -- {prop_name: metadata} for every property, with
                           None-valued metadata entries stripped
        enum_queryables -- subset of properties that declare an enumeration,
                           including their allowed "values"

    Never raises: any failure is logged and an empty queryables result is
    returned so one bad collection cannot abort the whole gather.
    """
    try:
        queryables_data = await self.api_client.make_request(
            "COLLECTION_QUERYABLES", path_params=[collection.id]
        )

        all_queryables: Dict[str, Any] = {}
        enum_queryables: Dict[str, Any] = {}
        properties = queryables_data.get("properties", {})

        for prop_name, prop_details in properties.items():
            # JSON Schema allows "type" to be a list (e.g. ["string", "null"]);
            # take the first entry as the main type and note nullability.
            prop_type = prop_details.get("type", ["string"])
            if isinstance(prop_type, list):
                main_type = prop_type[0] if prop_type else "string"
                is_nullable = "null" in prop_type
            else:
                main_type = prop_type
                is_nullable = False

            all_queryables[prop_name] = {
                "type": main_type,
                "nullable": is_nullable,
                "description": prop_details.get("description", ""),
                "max_length": prop_details.get("maxLength"),
                "format": prop_details.get("format"),
                "pattern": prop_details.get("pattern"),
                "minimum": prop_details.get("minimum"),
                "maximum": prop_details.get("maximum"),
                "is_enum": prop_details.get("enumeration", False),
            }

            # Enumerated properties are tracked separately as well, so callers
            # can quickly surface the exact filter values a field accepts.
            if prop_details.get("enumeration") and "enum" in prop_details:
                enum_queryables[prop_name] = {
                    "values": prop_details["enum"],
                    "type": main_type,
                    "nullable": is_nullable,
                    "description": prop_details.get("description", ""),
                    "max_length": prop_details.get("maxLength"),
                }
                all_queryables[prop_name]["enum_values"] = prop_details["enum"]

            # Drop metadata keys the schema did not supply (None values);
            # False/"" are meaningful and are kept.
            all_queryables[prop_name] = {
                k: v
                for k, v in all_queryables[prop_name].items()
                if v is not None
            }

        return {
            "collection": collection,
            "all_queryables": all_queryables,
            "enum_queryables": enum_queryables,
        }

    except Exception as e:
        # Best-effort: log and return an empty result rather than failing
        # the entire workflow-context build.
        logger.warning(f"Failed to fetch queryables for {collection.id}: {e}")
        return {
            "collection": collection,
            "all_queryables": {},
            "enum_queryables": {},
        }
174+
175+ # This should reduce the network io bottleneck and speed it up!
176+ tasks = [fetch_collection_queryables (collection ) for collection in collections_list ]
177+ queryables_results = await asyncio .gather (* tasks )
178+
179+ logger .debug (f"Starting thread pool processing for { len (queryables_results )} results..." )
180+
def process_collection_result(result: Dict[str, Any]) -> tuple:
    """Flatten one fetch result into a (collection_id, info_dict) pair.

    `result` is the dict produced by fetch_collection_queryables and must
    contain "collection", "all_queryables" and "enum_queryables".  The
    returned pairs are assembled into the collections_info mapping.
    """
    collection = result["collection"]
    logger.debug(
        f"Processing collection {collection.id} in thread {threading.current_thread().name}"
    )
    all_queryables = result["all_queryables"]
    enum_queryables = result["enum_queryables"]

    return (collection.id, {
        "id": collection.id,
        "title": collection.title,
        "description": collection.description,
        "all_queryables": all_queryables,
        "enum_queryables": enum_queryables,
        # Convenience summary fields for the workflow planner / prompts.
        "has_enum_filters": len(enum_queryables) > 0,
        "total_queryables": len(all_queryables),
        "enum_count": len(enum_queryables),
    })
197+
198+ thread_start = time .time ()
199+
200+ # This should reduce the work required to process the queryables results
201+ # TODO: need to check this really does speed it up
202+ with concurrent .futures .ThreadPoolExecutor (max_workers = 4 ) as executor :
203+ logger .debug (f"Thread pool created with { executor ._max_workers } workers" )
204+ processed = await asyncio .get_event_loop ().run_in_executor (
205+ executor ,
206+ lambda : list (map (process_collection_result , queryables_results ))
207+ )
208+
209+ thread_end = time .time ()
210+ logger .debug (f"Thread pool processing completed in { thread_end - thread_start :.4f} s" )
211+
212+ collections_info = dict (processed )
197213
198214 self .workflow_planner = WorkflowPlanner (cached_spec , collections_info )
199215
@@ -205,31 +221,35 @@ async def get_workflow_context(self) -> str:
205221 "required_explanation" : {
206222 "1" : "Which collection you will use and why" ,
207223 "2" : "What specific filters you will apply (show the exact filter string)" ,
208- "3" : "What steps you will take" ,
224+ "3" : "What steps you will take"
209225 },
210226 "workflow_enforcement" : "Do not proceed with tool calls until you have clearly explained your plan to the user" ,
211- "example_planning" : "I will search the 'lus-fts-site-1' collection using the filter 'oslandusetertiarygroup = \" Cinema\" ' to find all cinema locations in your specified area." ,
227+ "example_planning" : "I will search the 'lus-fts-site-1' collection using the filter 'oslandusetertiarygroup = \" Cinema\" ' to find all cinema locations in your specified area."
212228 },
229+
213230 "available_collections" : context ["available_collections" ],
214231 "openapi_endpoints" : context ["openapi_endpoints" ],
232+
215233 "QUICK_FILTERING_GUIDE" : {
216234 "primary_tool" : "search_features" ,
217235 "key_parameter" : "filter" ,
218236 "enum_fields" : "Use exact values from collection's enum_queryables (e.g., 'Cinema', 'A Road')" ,
219237 "simple_fields" : "Use direct values (e.g., usrn = 12345678)" ,
220238 },
239+
221240 "COMMON_EXAMPLES" : {
222241 "cinema_search" : "search_features(collection_id='lus-fts-site-1', bbox='...', filter=\" oslandusetertiarygroup = 'Cinema'\" )" ,
223242 "a_road_search" : "search_features(collection_id='trn-ntwk-street-1', bbox='...', filter=\" roadclassification = 'A Road'\" )" ,
224243 "usrn_search" : "search_features(collection_id='trn-ntwk-street-1', filter='usrn = 12345678')" ,
225- "street_name" : "search_features(collection_id='trn-ntwk-street-1', filter=\" designatedname1_text LIKE '%high%'\" )" ,
244+ "street_name" : "search_features(collection_id='trn-ntwk-street-1', filter=\" designatedname1_text LIKE '%high%'\" )"
226245 },
246+
227247 "CRITICAL_RULES" : {
228248 "1" : "ALWAYS explain your plan first" ,
229249 "2" : "Use exact enum values from the specific collection's enum_queryables" ,
230250 "3" : "Use the 'filter' parameter for all filtering" ,
231- "4" : "Quote string values in single quotes" ,
232- },
251+ "4" : "Quote string values in single quotes"
252+ }
233253 }
234254 )
235255
@@ -244,7 +264,6 @@ def _require_workflow_context(self, func):
244264
245265 @functools .wraps (func )
246266 async def wrapper (* args , ** kwargs ):
247- # Only allow get_workflow_context when workflow_planner is None
248267 if self .workflow_planner is None :
249268 if func .__name__ == "get_workflow_context" :
250269 return await func (* args , ** kwargs )
0 commit comments