Skip to content

Commit 4dd7579

Browse files
authored
Merge pull request #2792 from danielaskdd/fix/cypher-injection-workspace-label
fix(api): sanitize workspace from CLI args and HTTP headers to prevent injection
2 parents 8f51a1c + 0ba5c76 commit 4dd7579

10 files changed

+305
-72
lines changed

examples/modalprocessors_example.py

Lines changed: 50 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,8 @@
1919

2020

2121
def get_llm_model_func(api_key: str, base_url: str = None):
22-
return (
23-
lambda prompt,
24-
system_prompt=None,
25-
history_messages=[],
26-
**kwargs: openai_complete_if_cache(
22+
return lambda prompt, system_prompt=None, history_messages=[], **kwargs: (
23+
openai_complete_if_cache(
2724
"gpt-4o-mini",
2825
prompt,
2926
system_prompt=system_prompt,
@@ -41,41 +38,45 @@ def get_vision_model_func(api_key: str, base_url: str = None):
4138
system_prompt=None,
4239
history_messages=[],
4340
image_data=None,
44-
**kwargs: openai_complete_if_cache(
45-
"gpt-4o",
46-
"",
47-
system_prompt=None,
48-
history_messages=[],
49-
messages=[
50-
{"role": "system", "content": system_prompt} if system_prompt else None,
51-
{
52-
"role": "user",
53-
"content": [
54-
{"type": "text", "text": prompt},
55-
{
56-
"type": "image_url",
57-
"image_url": {
58-
"url": f"data:image/jpeg;base64,{image_data}"
41+
**kwargs: (
42+
openai_complete_if_cache(
43+
"gpt-4o",
44+
"",
45+
system_prompt=None,
46+
history_messages=[],
47+
messages=[
48+
{"role": "system", "content": system_prompt}
49+
if system_prompt
50+
else None,
51+
{
52+
"role": "user",
53+
"content": [
54+
{"type": "text", "text": prompt},
55+
{
56+
"type": "image_url",
57+
"image_url": {
58+
"url": f"data:image/jpeg;base64,{image_data}"
59+
},
5960
},
60-
},
61-
],
62-
}
63-
if image_data
64-
else {"role": "user", "content": prompt},
65-
],
66-
api_key=api_key,
67-
base_url=base_url,
68-
**kwargs,
69-
)
70-
if image_data
71-
else openai_complete_if_cache(
72-
"gpt-4o-mini",
73-
prompt,
74-
system_prompt=system_prompt,
75-
history_messages=history_messages,
76-
api_key=api_key,
77-
base_url=base_url,
78-
**kwargs,
61+
],
62+
}
63+
if image_data
64+
else {"role": "user", "content": prompt},
65+
],
66+
api_key=api_key,
67+
base_url=base_url,
68+
**kwargs,
69+
)
70+
if image_data
71+
else openai_complete_if_cache(
72+
"gpt-4o-mini",
73+
prompt,
74+
system_prompt=system_prompt,
75+
history_messages=history_messages,
76+
api_key=api_key,
77+
base_url=base_url,
78+
**kwargs,
79+
)
7980
)
8081
)
8182

@@ -178,14 +179,16 @@ async def initialize_rag(api_key: str, base_url: str = None):
178179
llm_model_func=lambda prompt,
179180
system_prompt=None,
180181
history_messages=[],
181-
**kwargs: openai_complete_if_cache(
182-
"gpt-4o-mini",
183-
prompt,
184-
system_prompt=system_prompt,
185-
history_messages=history_messages,
186-
api_key=api_key,
187-
base_url=base_url,
188-
**kwargs,
182+
**kwargs: (
183+
openai_complete_if_cache(
184+
"gpt-4o-mini",
185+
prompt,
186+
system_prompt=system_prompt,
187+
history_messages=history_messages,
188+
api_key=api_key,
189+
base_url=base_url,
190+
**kwargs,
191+
)
189192
),
190193
)
191194

lightrag/api/config.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""
44

55
import os
6+
import re
67
import argparse
78
import logging
89
from dotenv import load_dotenv
@@ -461,6 +462,17 @@ def parse_args() -> argparse.Namespace:
461462
ollama_server_infos.LIGHTRAG_NAME = args.simulated_model_name
462463
ollama_server_infos.LIGHTRAG_TAG = args.simulated_model_tag
463464

465+
# Sanitize workspace: only alphanumeric characters and underscores are allowed
466+
if args.workspace:
467+
sanitized = re.sub(r"[^a-zA-Z0-9_]", "_", args.workspace)
468+
if sanitized != args.workspace:
469+
logging.warning(
470+
f"Workspace name '{args.workspace}' contains invalid characters. "
471+
f"It has been sanitized to '{sanitized}'. "
472+
"Only alphanumeric characters and underscores are allowed."
473+
)
474+
args.workspace = sanitized
475+
464476
return args
465477

466478

lightrag/api/lightrag_server.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
get_swagger_ui_oauth2_redirect_html,
1111
)
1212
import os
13+
import re
1314
import logging
1415
import logging.config
1516
import sys
@@ -478,6 +479,14 @@ def get_workspace_from_request(request: Request) -> str | None:
478479

479480
if not workspace:
480481
workspace = None
482+
else:
483+
sanitized = re.sub(r"[^a-zA-Z0-9_]", "_", workspace)
484+
if sanitized != workspace:
485+
logger.warning(
486+
f"Workspace header '{workspace}' contains invalid characters. "
487+
f"Sanitized to '{sanitized}'."
488+
)
489+
workspace = sanitized
481490

482491
return workspace
483492

lightrag/kg/memgraph_impl.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,18 @@ def __init__(self, namespace, global_config, embedding_func, workspace=None):
5959
self._driver = None
6060

6161
def _get_workspace_label(self) -> str:
62-
"""Return workspace label (guaranteed non-empty during initialization)"""
63-
return self.workspace
62+
"""Return sanitized workspace label safe for use as a backtick-quoted identifier in Cypher queries.
63+
64+
Escapes backticks by doubling them to prevent Cypher injection
65+
via the LIGHTRAG-WORKSPACE header, while preserving a 1-to-1 mapping
66+
for all other characters. The returned value is intended to be used
67+
inside backticks (for example, MATCH (n:`{label}`)) and is not
68+
validated as a standalone unquoted identifier.
69+
"""
70+
workspace = self.workspace.strip()
71+
if not workspace:
72+
return "base"
73+
return workspace.replace("`", "``")
6474

6575
async def initialize(self):
6676
async with get_data_init_lock():

lightrag/kg/neo4j_impl.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,18 @@ def __init__(self, namespace, global_config, embedding_func, workspace=None):
9191
self._driver = None
9292

9393
def _get_workspace_label(self) -> str:
94-
"""Return workspace label (guaranteed non-empty during initialization)"""
95-
return self.workspace
94+
"""Return sanitized workspace label safe for use as a backtick-quoted identifier in Cypher queries.
95+
96+
Escapes backticks by doubling them to prevent Cypher injection
97+
via the LIGHTRAG-WORKSPACE header, while preserving a 1-to-1 mapping
98+
for all other characters. The returned value is intended to be used
99+
inside backticks (for example, MATCH (n:`{label}`)) and is not
100+
validated as a standalone unquoted identifier.
101+
"""
102+
workspace = self.workspace.strip()
103+
if not workspace:
104+
return "base"
105+
return workspace.replace("`", "``")
96106

97107
def _normalize_index_suffix(self, workspace_label: str) -> str:
98108
"""Normalize workspace label for safe use in index names."""

tests/test_aquery_data_endpoint.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -605,7 +605,7 @@ def print_query_results(data: Dict[str, Any]):
605605
file_path = entity.get("file_path", "Unknown source")
606606
reference_id = entity.get("reference_id", "No reference")
607607

608-
print(f" {i+1}. {entity_name} ({entity_type})")
608+
print(f" {i + 1}. {entity_name} ({entity_type})")
609609
print(
610610
f" Description: {description[:100]}{'...' if len(description) > 100 else ''}"
611611
)
@@ -624,7 +624,7 @@ def print_query_results(data: Dict[str, Any]):
624624
file_path = rel.get("file_path", "Unknown source")
625625
reference_id = rel.get("reference_id", "No reference")
626626

627-
print(f" {i+1}. {src}{tgt}")
627+
print(f" {i + 1}. {src}{tgt}")
628628
print(f" Keywords: {keywords}")
629629
print(
630630
f" Description: {description[:100]}{'...' if len(description) > 100 else ''}"
@@ -642,7 +642,7 @@ def print_query_results(data: Dict[str, Any]):
642642
chunk_id = chunk.get("chunk_id", "Unknown ID")
643643
reference_id = chunk.get("reference_id", "No reference")
644644

645-
print(f" {i+1}. Text chunk ID: {chunk_id}")
645+
print(f" {i + 1}. Text chunk ID: {chunk_id}")
646646
print(f" Source: {file_path}")
647647
print(f" Reference ID: {reference_id}")
648648
print(
@@ -656,7 +656,7 @@ def print_query_results(data: Dict[str, Any]):
656656
for i, ref in enumerate(references):
657657
reference_id = ref.get("reference_id", "Unknown ID")
658658
file_path = ref.get("file_path", "Unknown source")
659-
print(f" {i+1}. Reference ID: {reference_id}")
659+
print(f" {i + 1}. Reference ID: {reference_id}")
660660
print(f" File Path: {file_path}")
661661
print()
662662

tests/test_lightrag_ollama_chat.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -714,7 +714,7 @@ async def run_concurrent_requests():
714714

715715
for i, result in enumerate(results):
716716
if isinstance(result, Exception):
717-
error_messages.append(f"Request {i+1} failed: {str(result)}")
717+
error_messages.append(f"Request {i + 1} failed: {str(result)}")
718718
else:
719719
success_results.append((i + 1, result))
720720

tests/test_qdrant_migration.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,8 @@ async def test_qdrant_migration_trigger(mock_qdrant_client, mock_embedding_func)
8080

8181
# Setup mocks for migration scenario
8282
# 1. New collection does not exist, only legacy exists
83-
mock_qdrant_client.collection_exists.side_effect = (
84-
lambda name: name == legacy_collection
83+
mock_qdrant_client.collection_exists.side_effect = lambda name: (
84+
name == legacy_collection
8585
)
8686

8787
# 2. Legacy collection exists and has data
@@ -173,8 +173,8 @@ async def test_qdrant_no_migration_needed(mock_qdrant_client, mock_embedding_fun
173173
)
174174

175175
# Only new collection exists (no legacy collection found)
176-
mock_qdrant_client.collection_exists.side_effect = (
177-
lambda name: name == storage.final_namespace
176+
mock_qdrant_client.collection_exists.side_effect = lambda name: (
177+
name == storage.final_namespace
178178
)
179179

180180
# Initialize
@@ -285,8 +285,8 @@ async def test_scenario_2_legacy_upgrade_migration(
285285
new_collection = storage.final_namespace
286286

287287
# Case 4: Only legacy collection exists
288-
mock_qdrant_client.collection_exists.side_effect = (
289-
lambda name: name == legacy_collection
288+
mock_qdrant_client.collection_exists.side_effect = lambda name: (
289+
name == legacy_collection
290290
)
291291

292292
# Mock legacy collection info with 1536d vectors
@@ -454,10 +454,13 @@ async def test_case1_empty_legacy_auto_cleanup(mock_qdrant_client, mock_embeddin
454454
new_collection = storage.final_namespace
455455

456456
# Mock: Both collections exist
457-
mock_qdrant_client.collection_exists.side_effect = lambda name: name in [
458-
legacy_collection,
459-
new_collection,
460-
]
457+
mock_qdrant_client.collection_exists.side_effect = lambda name: (
458+
name
459+
in [
460+
legacy_collection,
461+
new_collection,
462+
]
463+
)
461464

462465
# Mock: Legacy collection is empty (0 records)
463466
def count_mock(collection_name, exact=True, count_filter=None):
@@ -520,10 +523,13 @@ async def test_case1_nonempty_legacy_warning(mock_qdrant_client, mock_embedding_
520523
new_collection = storage.final_namespace
521524

522525
# Mock: Both collections exist
523-
mock_qdrant_client.collection_exists.side_effect = lambda name: name in [
524-
legacy_collection,
525-
new_collection,
526-
]
526+
mock_qdrant_client.collection_exists.side_effect = lambda name: (
527+
name
528+
in [
529+
legacy_collection,
530+
new_collection,
531+
]
532+
)
527533

528534
# Mock: Legacy collection has data (50 records)
529535
def count_mock(collection_name, exact=True, count_filter=None):

tests/test_workspace_isolation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ async def test_lock_mechanism(stress_test_mode, parallel_workers):
222222
# Support stress testing with configurable number of workers
223223
num_workers = parallel_workers if stress_test_mode else 3
224224
parallel_workload = [
225-
(f"ws_{chr(97+i)}", f"ws_{chr(97+i)}", "test_namespace")
225+
(f"ws_{chr(97 + i)}", f"ws_{chr(97 + i)}", "test_namespace")
226226
for i in range(num_workers)
227227
]
228228

@@ -491,7 +491,7 @@ async def use_shared_lock(coroutine_id):
491491

492492
print("✅ PASSED: NamespaceLock Concurrent Reuse")
493493
print(
494-
f" Same NamespaceLock instance used successfully in {expected_entries//2} concurrent coroutines"
494+
f" Same NamespaceLock instance used successfully in {expected_entries // 2} concurrent coroutines"
495495
)
496496

497497

0 commit comments

Comments
 (0)