Skip to content

Commit 8ad3cb6

Browse files
authored
♻️ Performance optimization of the knowledge base file list query interface (#3025)
* ♻️ Performance optimization of the knowledge base file list query interface * ♻️ Performance optimization of the knowledge base file list query interface: ut
1 parent ee181cf commit 8ad3cb6

9 files changed

Lines changed: 507 additions & 320 deletions

backend/apps/data_process_app.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,9 +204,14 @@ async def get_index_tasks(index_name: str):
204204
205205
Returns tasks that are being processed or waiting to be processed
206206
"""
207+
import time
208+
start = time.time()
207209
try:
208-
return await service.get_index_tasks(index_name)
210+
result = await service.get_index_tasks(index_name)
211+
logger.info(f"[get_index_tasks] index={index_name}, tasks={len(result)}, duration={time.time()-start:.3f}s")
212+
return result
209213
except Exception as e:
214+
logger.error(f"[get_index_tasks] error: {e}")
210215
raise HTTPException(
211216
status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail=str(e))
212217

backend/services/data_process_service.py

Lines changed: 33 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def __init__(self):
5454

5555
self._inspector = None
5656
self._inspector_last_time = 0
57-
self._inspector_ttl = 60 # Inspector cache time in seconds
57+
self._inspector_ttl = 300 # 5 minutes - inspector is expensive to create (ping all workers)
5858
self._inspector_lock = None
5959
self._inspector_lock = threading.Lock()
6060

@@ -105,7 +105,7 @@ async def stop(self):
105105
logger.info("Data processing service stopped")
106106

107107
def _get_celery_inspector(self):
108-
"""Get Celery inspector"""
108+
"""Get Celery inspector (cached for performance)"""
109109
with self._inspector_lock:
110110
now = time.time()
111111
if self._inspector and now - self._inspector_last_time < self._inspector_ttl:
@@ -117,9 +117,9 @@ def _get_celery_inspector(self):
117117
f"Celery broker URL is not configured properly, reconfiguring to {celery_app.conf.broker_url}")
118118
try:
119119
inspector = celery_app.control.inspect()
120-
inspector.ping()
121120
self._inspector = inspector
122121
self._inspector_last_time = now
122+
self._inspector_init_time = now
123123
return inspector
124124
except Exception as e:
125125
self._inspector = None
@@ -142,11 +142,9 @@ async def get_all_tasks(self, filter: bool = True) -> List[Dict[str, Any]]:
142142
all_tasks = []
143143
try:
144144
start_time = time.time()
145-
logger.debug(
146-
"Getting inspector to check for active and reserved tasks (concurrent)")
145+
inspector_start = time.time()
147146
inspector = self._get_celery_inspector()
148-
logger.debug(
149-
f"⏰ Inspector initialization took {time.time() - start_time}s")
147+
inspector_duration = time.time() - inspector_start
150148

151149
# Collect task IDs from different sources and keep runtime metadata
152150
task_ids = set()
@@ -171,18 +169,37 @@ def _normalize_runtime_meta(task: Dict[str, Any]) -> Dict[str, Any]:
171169
'original_filename': kwargs.get('original_filename', ''),
172170
}
173171

172+
celery_start = time.time()
173+
174+
# Use short timeout for inspector since workers can respond in ~0.1s
175+
# Default 1s timeout is unnecessary and causes delay
176+
short_timeout = 0.2
177+
174178
def get_active():
175-
return inspector.active()
179+
t = time.time()
180+
# Create fresh inspector with short timeout for each call
181+
short_inspector = celery_app.control.inspect(timeout=short_timeout)
182+
result = short_inspector.active()
183+
elapsed = time.time() - t
184+
logger.info(f"[get_all_tasks] inspector.active() took {elapsed:.3f}s")
185+
return result if result else {}
176186

177187
def get_reserved():
178-
return inspector.reserved()
188+
t = time.time()
189+
short_inspector = celery_app.control.inspect(timeout=short_timeout)
190+
result = short_inspector.reserved()
191+
elapsed = time.time() - t
192+
logger.info(f"[get_all_tasks] inspector.reserved() took {elapsed:.3f}s")
193+
return result if result else {}
194+
179195
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
180196
future_active = executor.submit(get_active)
181197
future_reserved = executor.submit(get_reserved)
182-
active_tasks_dict = future_active.result()
183-
reserved_tasks_dict = future_reserved.result()
184-
logger.debug(
185-
f"⏰ Get active and reserved tasks (concurrent) took {time.time() - start_time}s")
198+
active_tasks_dict = future_active.result(timeout=short_timeout + 0.5)
199+
reserved_tasks_dict = future_reserved.result(timeout=short_timeout + 0.5)
200+
celery_duration = time.time() - celery_start
201+
if celery_duration > 0.5:
202+
logger.warning(f"[get_all_tasks] Inspector took {celery_duration:.3f}s (expected <0.5s)")
186203
if active_tasks_dict:
187204
for worker, tasks in active_tasks_dict.items():
188205
for task in tasks:
@@ -199,23 +216,17 @@ def get_reserved():
199216
# Keep active metadata if already present
200217
runtime_task_meta.setdefault(task_id, _normalize_runtime_meta(task))
201218

202-
# Currently, we don't have scheduled tasks, so skip getting scheduled tasks here
203-
start_time = time.time()
204-
logger.debug("Getting task IDs from Redis backend")
205-
# Also get task IDs from Redis backend (covers completed/failed tasks within expiry)
219+
# Get task IDs from Redis backend (covers completed/failed tasks within expiry)
206220
try:
207221
redis_task_ids = get_all_task_ids_from_redis(self.redis_client)
208-
logger.debug(
209-
f"⏰ Get Redis task IDs took {time.time() - start_time}s")
210222
for task_id in redis_task_ids:
211-
# Add to the set, duplicates will be handled
212223
task_ids.add(task_id)
213224
except Exception as redis_error:
214225
logger.warning(
215226
f"Failed to query Redis for stored task IDs: {str(redis_error)}")
216-
logger.debug(
217-
f"Total unique task IDs collected (inspector + Redis): {len(task_ids)}")
227+
218228
task_id_list = list(task_ids)
229+
# Batch fetch all task info
219230
tasks = [get_task_info(task_id) for task_id in task_id_list]
220231
all_task_infos = await asyncio.gather(*tasks, return_exceptions=True)
221232
for idx, task_info in enumerate(all_task_infos):
@@ -243,7 +254,6 @@ def get_reserved():
243254
if not task_info.get('index_name'):
244255
continue
245256
all_tasks.append(task_info)
246-
logger.debug(f"Retrieved {len(all_tasks)} tasks.")
247257
except Exception as e:
248258
logger.error(f"Error retrieving all tasks: {str(e)}")
249259
all_tasks = []

backend/services/redis_service.py

Lines changed: 80 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import json
22
import logging
33
import re
4-
from typing import Dict, Any, Optional, Tuple, Set
4+
from typing import Dict, Any, Optional, Tuple, Set, List
55

66
import redis
77

@@ -24,8 +24,8 @@ def client(self) -> redis.Redis:
2424
if not REDIS_URL:
2525
raise ValueError("REDIS_URL environment variable is not set")
2626
self._client = redis.from_url(
27-
REDIS_URL,
28-
socket_timeout=5,
27+
REDIS_URL,
28+
socket_timeout=5,
2929
socket_connect_timeout=5,
3030
decode_responses=True
3131
)
@@ -654,13 +654,13 @@ def save_error_info(self, task_id: str, error_reason: str, ttl_days: int = 30) -
654654
if not error_reason:
655655
logger.error(f"Cannot save error info for task {task_id}: error_reason is empty")
656656
return False
657-
657+
658658
ttl_seconds = ttl_days * 24 * 60 * 60
659659
reason_key = f"error:reason:{task_id}"
660660

661661
# Save error reason
662662
result = self.client.setex(reason_key, ttl_seconds, error_reason)
663-
663+
664664
if result:
665665
logger.info(f"Successfully saved error info to Redis for task {task_id}, key: {reason_key}")
666666
# Verify the save by reading it back
@@ -695,13 +695,13 @@ def save_progress_info(self, task_id: str, processed_chunks: int, total_chunks:
695695
if not task_id:
696696
logger.error("Cannot save progress info: task_id is empty")
697697
return False
698-
698+
699699
progress_key = f"progress:{task_id}"
700700
progress_data = {
701701
'processed_chunks': processed_chunks,
702702
'total_chunks': total_chunks
703703
}
704-
704+
705705
ttl_seconds = ttl_hours * 3600
706706
progress_json = json.dumps(progress_data)
707707
self.client.setex(
@@ -874,6 +874,79 @@ def get_error_info(self, task_id: str) -> Optional[str]:
874874
f"Failed to get error info for task {task_id}: {str(e)}")
875875
return None
876876

877+
def batch_get_progress_info(self, task_ids: List[str]) -> Dict[str, Optional[Dict[str, int]]]:
    """
    Batch get progress information for multiple tasks in one Redis round trip.

    Each task's progress is read from the key ``progress:<task_id>`` and
    parsed as JSON. A value that cannot be used (missing key, undecodable
    bytes, invalid JSON, or JSON that is not an object) only affects its own
    entry; the other tasks in the batch are still returned.

    Args:
        task_ids: List of Celery task IDs

    Returns:
        Dict mapping every requested task_id to its progress info dict, or
        to None if no usable progress is stored for it. If the Redis call
        itself fails, every task_id maps to None.
    """
    if not task_ids:
        return {}

    try:
        # Read-only batch: there is nothing to make atomic, so skip the
        # MULTI/EXEC wrapper. The context manager resets the pipeline even
        # if queuing or executing raises.
        with self.client.pipeline(transaction=False) as pipe:
            for task_id in task_ids:
                pipe.get(f"progress:{task_id}")
            raw_values = pipe.execute()

        result: Dict[str, Optional[Dict[str, int]]] = {}
        for task_id, raw in zip(task_ids, raw_values):
            result[task_id] = None
            if not raw:
                continue
            try:
                # Clients created without decode_responses return bytes.
                if isinstance(raw, bytes):
                    raw = raw.decode('utf-8')
                parsed = json.loads(raw)
            except (ValueError, TypeError):
                # ValueError covers both json.JSONDecodeError and
                # UnicodeDecodeError, so one corrupt entry cannot discard
                # the whole batch.
                continue
            # Callers are promised a dict; treat any other JSON as absent.
            if isinstance(parsed, dict):
                result[task_id] = parsed
        return result
    except Exception as e:
        logger.warning(f"Failed to batch get progress info: {str(e)}")
        return {tid: None for tid in task_ids}
916+
917+
def batch_get_error_info(self, task_ids: List[str]) -> Dict[str, Optional[str]]:
    """
    Batch get error information for multiple tasks in one Redis round trip.

    Each task's error reason is read from the key ``error:reason:<task_id>``.

    Args:
        task_ids: List of Celery task IDs

    Returns:
        Dict mapping every requested task_id to its error reason string, or
        to None if no (or an empty) reason is stored. If the Redis call
        itself fails, every task_id maps to None.
    """
    if not task_ids:
        return {}

    try:
        # Read-only batch: there is nothing to make atomic, so skip the
        # MULTI/EXEC wrapper. The context manager resets the pipeline even
        # if queuing or executing raises.
        with self.client.pipeline(transaction=False) as pipe:
            for task_id in task_ids:
                pipe.get(f"error:reason:{task_id}")
            raw_values = pipe.execute()

        result: Dict[str, Optional[str]] = {}
        for task_id, reason in zip(task_ids, raw_values):
            # With decode_responses=True the reply is already a str; decode
            # defensively so the declared return type also holds for a
            # bytes-returning client.
            if isinstance(reason, bytes):
                reason = reason.decode('utf-8', errors='replace')
            result[task_id] = reason if reason else None
        return result
    except Exception as e:
        logger.warning(f"Failed to batch get error info: {str(e)}")
        return {tid: None for tid in task_ids}
949+
877950
# Global Redis service instance
878951
_redis_service = None
879952

0 commit comments

Comments
 (0)