Skip to content

Commit c0f6939

Browse files
committed
Merge branch 'security/fix-sql-injection-postgres'
2 parents 012aaad + 813f4af commit c0f6939

1 file changed

Lines changed: 35 additions & 25 deletions

File tree

lightrag/kg/postgres_impl.py

Lines changed: 35 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -1843,10 +1843,11 @@ async def get_by_id(self, id: str) -> dict[str, Any] | None:
18431843
# Query by id
18441844
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
18451845
"""Get data by ids"""
1846-
sql = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
1847-
ids=",".join([f"'{id}'" for id in ids])
1848-
)
1849-
params = {"workspace": self.workspace}
1846+
if not ids:
1847+
return []
1848+
1849+
sql = SQL_TEMPLATES["get_by_ids_" + self.namespace]
1850+
params = {"workspace": self.workspace, "ids": ids}
18501851
results = await self.db.query(sql, list(params.values()), multirows=True)
18511852

18521853
def _order_results(
@@ -1949,11 +1950,12 @@ def _order_results(
19491950

19501951
async def filter_keys(self, keys: set[str]) -> set[str]:
19511952
"""Filter out duplicated content"""
1952-
sql = SQL_TEMPLATES["filter_keys"].format(
1953-
table_name=namespace_to_table_name(self.namespace),
1954-
ids=",".join([f"'{id}'" for id in keys]),
1955-
)
1956-
params = {"workspace": self.workspace}
1953+
if not keys:
1954+
return set()
1955+
1956+
table_name = namespace_to_table_name(self.namespace)
1957+
sql = f"SELECT id FROM {table_name} WHERE workspace=$1 AND id = ANY($2)"
1958+
params = {"workspace": self.workspace, "ids": list(keys)}
19571959
try:
19581960
res = await self.db.query(sql, list(params.values()), multirows=True)
19591961
if res:
@@ -2532,11 +2534,12 @@ async def finalize(self):
25322534

25332535
async def filter_keys(self, keys: set[str]) -> set[str]:
25342536
"""Filter out duplicated content"""
2535-
sql = SQL_TEMPLATES["filter_keys"].format(
2536-
table_name=namespace_to_table_name(self.namespace),
2537-
ids=",".join([f"'{id}'" for id in keys]),
2538-
)
2539-
params = {"workspace": self.workspace}
2537+
if not keys:
2538+
return set()
2539+
2540+
table_name = namespace_to_table_name(self.namespace)
2541+
sql = f"SELECT id FROM {table_name} WHERE workspace=$1 AND id = ANY($2)"
2542+
params = {"workspace": self.workspace, "ids": list(keys)}
25402543
try:
25412544
res = await self.db.query(sql, list(params.values()), multirows=True)
25422545
if res:
@@ -2849,34 +2852,41 @@ async def get_docs_paginated(
28492852
elif page_size > 200:
28502853
page_size = 200
28512854

2852-
if sort_field not in ["created_at", "updated_at", "id", "file_path"]:
2855+
# Whitelist validation for sort_field to prevent SQL injection
2856+
allowed_sort_fields = {"created_at", "updated_at", "id", "file_path"}
2857+
if sort_field not in allowed_sort_fields:
28532858
sort_field = "updated_at"
28542859

2860+
# Whitelist validation for sort_direction to prevent SQL injection
28552861
if sort_direction.lower() not in ["asc", "desc"]:
28562862
sort_direction = "desc"
2863+
else:
2864+
sort_direction = sort_direction.lower()
28572865

28582866
# Calculate offset
28592867
offset = (page - 1) * page_size
28602868

2861-
# Build WHERE clause
2862-
where_clause = "WHERE workspace=$1"
2869+
# Build parameterized query components
28632870
params = {"workspace": self.workspace}
28642871
param_count = 1
28652872

2873+
# Build WHERE clause with parameterized query
28662874
if status_filter is not None:
28672875
param_count += 1
2868-
where_clause += f" AND status=${param_count}"
2876+
where_clause = "WHERE workspace=$1 AND status=$2"
28692877
params["status"] = status_filter.value
2878+
else:
2879+
where_clause = "WHERE workspace=$1"
28702880

2871-
# Build ORDER BY clause
2881+
# Build ORDER BY clause using validated whitelist values
28722882
order_clause = f"ORDER BY {sort_field} {sort_direction.upper()}"
28732883

28742884
# Query for total count
28752885
count_sql = f"SELECT COUNT(*) as total FROM LIGHTRAG_DOC_STATUS {where_clause}"
28762886
count_result = await self.db.query(count_sql, list(params.values()))
28772887
total_count = count_result["total"] if count_result else 0
28782888

2879-
# Query for paginated data
2889+
# Query for paginated data with parameterized LIMIT and OFFSET
28802890
data_sql = f"""
28812891
SELECT * FROM LIGHTRAG_DOC_STATUS
28822892
{where_clause}
@@ -4874,19 +4884,19 @@ def namespace_to_table_name(namespace: str) -> str:
48744884
""",
48754885
"get_by_ids_full_docs": """SELECT id, COALESCE(content, '') as content,
48764886
COALESCE(doc_name, '') as file_path
4877-
FROM LIGHTRAG_DOC_FULL WHERE workspace=$1 AND id IN ({ids})
4887+
FROM LIGHTRAG_DOC_FULL WHERE workspace=$1 AND id = ANY($2)
48784888
""",
48794889
"get_by_ids_text_chunks": """SELECT id, tokens, COALESCE(content, '') as content,
48804890
chunk_order_index, full_doc_id, file_path,
48814891
COALESCE(llm_cache_list, '[]'::jsonb) as llm_cache_list,
48824892
EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
48834893
EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
4884-
FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id IN ({ids})
4894+
FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id = ANY($2)
48854895
""",
48864896
"get_by_ids_llm_response_cache": """SELECT id, original_prompt, return_value, chunk_id, cache_type, queryparam,
48874897
EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
48884898
EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
4889-
FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND id IN ({ids})
4899+
FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND id = ANY($2)
48904900
""",
48914901
"get_by_id_full_entities": """SELECT id, entity_names, count,
48924902
EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
@@ -4901,12 +4911,12 @@ def namespace_to_table_name(namespace: str) -> str:
49014911
"get_by_ids_full_entities": """SELECT id, entity_names, count,
49024912
EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
49034913
EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
4904-
FROM LIGHTRAG_FULL_ENTITIES WHERE workspace=$1 AND id IN ({ids})
4914+
FROM LIGHTRAG_FULL_ENTITIES WHERE workspace=$1 AND id = ANY($2)
49054915
""",
49064916
"get_by_ids_full_relations": """SELECT id, relation_pairs, count,
49074917
EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
49084918
EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
4909-
FROM LIGHTRAG_FULL_RELATIONS WHERE workspace=$1 AND id IN ({ids})
4919+
FROM LIGHTRAG_FULL_RELATIONS WHERE workspace=$1 AND id = ANY($2)
49104920
""",
49114921
"filter_keys": "SELECT id FROM {table_name} WHERE workspace=$1 AND id IN ({ids})",
49124922
"upsert_doc_full": """INSERT INTO LIGHTRAG_DOC_FULL (id, content, doc_name, workspace)

0 commit comments

Comments (0)