-
Notifications
You must be signed in to change notification settings - Fork 4.4k
Expand file tree
/
Copy pathlightrag.py
More file actions
4451 lines (3887 loc) · 192 KB
/
lightrag.py
File metadata and controls
4451 lines (3887 loc) · 192 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
from __future__ import annotations
import traceback
import asyncio
import configparser
import inspect
import os
import time
import warnings
from dataclasses import asdict, dataclass, field, replace
from datetime import datetime, timezone
from functools import partial
from typing import (
Any,
AsyncIterator,
Awaitable,
Callable,
Iterator,
cast,
final,
Literal,
Optional,
List,
Dict,
Union,
)
from lightrag.prompt import PROMPTS
from lightrag.exceptions import PipelineCancelledException
from lightrag.constants import (
DEFAULT_MAX_GLEANING,
DEFAULT_FORCE_LLM_SUMMARY_ON_MERGE,
DEFAULT_TOP_K,
DEFAULT_CHUNK_TOP_K,
DEFAULT_MAX_ENTITY_TOKENS,
DEFAULT_MAX_RELATION_TOKENS,
DEFAULT_MAX_TOTAL_TOKENS,
DEFAULT_COSINE_THRESHOLD,
DEFAULT_RELATED_CHUNK_NUMBER,
DEFAULT_KG_CHUNK_PICK_METHOD,
DEFAULT_MIN_RERANK_SCORE,
DEFAULT_SUMMARY_MAX_TOKENS,
DEFAULT_SUMMARY_CONTEXT_SIZE,
DEFAULT_SUMMARY_LENGTH_RECOMMENDED,
DEFAULT_MAX_EXTRACT_INPUT_TOKENS,
DEFAULT_MAX_ASYNC,
DEFAULT_MAX_PARALLEL_INSERT,
DEFAULT_MAX_GRAPH_NODES,
DEFAULT_MAX_SOURCE_IDS_PER_ENTITY,
DEFAULT_MAX_SOURCE_IDS_PER_RELATION,
DEFAULT_ENTITY_TYPES,
DEFAULT_SUMMARY_LANGUAGE,
DEFAULT_LLM_TIMEOUT,
DEFAULT_EMBEDDING_TIMEOUT,
DEFAULT_SOURCE_IDS_LIMIT_METHOD,
DEFAULT_MAX_FILE_PATHS,
DEFAULT_FILE_PATH_MORE_PLACEHOLDER,
)
from lightrag.utils import get_env_value
from lightrag.kg import (
STORAGES,
verify_storage_implementation,
)
from lightrag.kg.shared_storage import (
get_namespace_data,
get_data_init_lock,
get_default_workspace,
set_default_workspace,
get_namespace_lock,
)
from lightrag.base import (
BaseGraphStorage,
BaseKVStorage,
BaseVectorStorage,
DocProcessingStatus,
DocStatus,
DocStatusStorage,
QueryParam,
StorageNameSpace,
StoragesStatus,
DeletionResult,
OllamaServerInfos,
QueryResult,
)
from lightrag.namespace import NameSpace
from lightrag.operate import (
chunking_by_token_size,
extract_entities,
merge_nodes_and_edges,
kg_query,
naive_query,
rebuild_knowledge_from_chunks,
)
from lightrag.constants import GRAPH_FIELD_SEP
from lightrag.utils import (
Tokenizer,
TiktokenTokenizer,
EmbeddingFunc,
always_get_an_event_loop,
compute_mdhash_id,
lazy_external_import,
priority_limit_async_func_call,
get_content_summary,
sanitize_text_for_encoding,
check_storage_env_vars,
generate_track_id,
convert_to_user_format,
logger,
subtract_source_ids,
make_relation_chunk_key,
normalize_source_ids_limit_method,
)
from lightrag.types import KnowledgeGraph
from dotenv import load_dotenv
# Load environment variables from the .env file in the current folder.
# use the .env that is inside the current folder
# allows to use different .env file for each lightrag instance
# the OS environment variables take precedence over the .env file
load_dotenv(dotenv_path=".env", override=False)

# Optional legacy INI configuration; silently ignored when config.ini is absent.
# TODO: TO REMOVE @Yannick
config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
def _chunk_fields_from_status_doc(
status_doc: "DocProcessingStatus",
) -> tuple[list[str], int]:
"""Return (chunks_list, chunks_count) preserved from a status document.
Filters out any non-string or empty chunk IDs. When chunks_count is
absent or invalid, it is inferred from the length of chunks_list.
"""
chunks_list: list[str] = []
if isinstance(status_doc.chunks_list, list):
chunks_list = [
chunk_id
for chunk_id in status_doc.chunks_list
if isinstance(chunk_id, str) and chunk_id
]
if isinstance(status_doc.chunks_count, int) and status_doc.chunks_count >= 0:
return chunks_list, status_doc.chunks_count
return chunks_list, len(chunks_list)
def _normalize_string_list(raw_values: Any, context: str = "") -> list[str]:
"""Return a list of non-empty strings from raw_values.
Non-string elements are dropped and logged as warnings. If raw_values is
not a list, an empty list is returned.
"""
if not isinstance(raw_values, list):
return []
result = []
for i, value in enumerate(raw_values):
if isinstance(value, str) and value:
result.append(value)
else:
logger.warning(
"Non-string element dropped from list%s at index %d: %r",
f" ({context})" if context else "",
i,
value,
)
return result
@final
@dataclass
class LightRAG:
    """LightRAG: Simple and Fast Retrieval-Augmented Generation."""

    # Directory
    # ---

    working_dir: str = field(default="./rag_storage")
    """Directory where cache and temporary files are stored."""

    # Storage
    # ---

    kv_storage: str = field(default="JsonKVStorage")
    """Storage backend for key-value data."""

    vector_storage: str = field(default="NanoVectorDBStorage")
    """Storage backend for vector embeddings."""

    graph_storage: str = field(default="NetworkXStorage")
    """Storage backend for knowledge graphs."""

    doc_status_storage: str = field(default="JsonDocStatusStorage")
    """Storage type for tracking document processing statuses."""

    # Workspace
    # ---

    workspace: str = field(default_factory=lambda: os.getenv("WORKSPACE", ""))
    """Workspace for data isolation. Defaults to empty string if WORKSPACE environment variable is not set."""

    # ---
    # TODO: Deprecated, use setup_logger in utils.py instead
    # NOTE: __post_init__ warns about and then deletes both attributes.
    log_level: int | None = field(default=None)
    log_file_path: str | None = field(default=None)

    # Query parameters
    # ---

    top_k: int = field(default=get_env_value("TOP_K", DEFAULT_TOP_K, int))
    """Number of entities/relations to retrieve for each query."""

    chunk_top_k: int = field(
        default=get_env_value("CHUNK_TOP_K", DEFAULT_CHUNK_TOP_K, int)
    )
    """Maximum number of chunks in context."""

    max_entity_tokens: int = field(
        default=get_env_value("MAX_ENTITY_TOKENS", DEFAULT_MAX_ENTITY_TOKENS, int)
    )
    """Maximum number of tokens for entity in context."""

    max_relation_tokens: int = field(
        default=get_env_value("MAX_RELATION_TOKENS", DEFAULT_MAX_RELATION_TOKENS, int)
    )
    """Maximum number of tokens for relation in context."""

    max_total_tokens: int = field(
        default=get_env_value("MAX_TOTAL_TOKENS", DEFAULT_MAX_TOTAL_TOKENS, int)
    )
    """Maximum total tokens in context (including system prompt, entities, relations and chunks)."""

    cosine_threshold: int = field(
        default=get_env_value("COSINE_THRESHOLD", DEFAULT_COSINE_THRESHOLD, int)
    )
    """Cosine threshold of vector DB retrieval for entities, relations and chunks."""

    related_chunk_number: int = field(
        default=get_env_value("RELATED_CHUNK_NUMBER", DEFAULT_RELATED_CHUNK_NUMBER, int)
    )
    """Number of related chunks to grab from single entity or relation."""

    kg_chunk_pick_method: str = field(
        default=get_env_value("KG_CHUNK_PICK_METHOD", DEFAULT_KG_CHUNK_PICK_METHOD, str)
    )
    """Method for selecting text chunks: 'WEIGHT' for weight-based selection, 'VECTOR' for embedding similarity-based selection."""

    # Entity extraction
    # ---

    entity_extract_max_gleaning: int = field(
        default=get_env_value("MAX_GLEANING", DEFAULT_MAX_GLEANING, int)
    )
    """Maximum number of entity extraction attempts for ambiguous content."""

    max_extract_input_tokens: int = field(
        default=get_env_value(
            "MAX_EXTRACT_INPUT_TOKENS", DEFAULT_MAX_EXTRACT_INPUT_TOKENS, int
        )
    )
    """Maximum tokens allowed for entity extraction input context."""

    # Number of merged fragments above which an LLM call is forced to
    # re-summarize the merged description; __post_init__ warns when < 3.
    force_llm_summary_on_merge: int = field(
        default=get_env_value(
            "FORCE_LLM_SUMMARY_ON_MERGE", DEFAULT_FORCE_LLM_SUMMARY_ON_MERGE, int
        )
    )

    # Text chunking
    # ---

    chunk_token_size: int = field(default=int(os.getenv("CHUNK_SIZE", 1200)))
    """Maximum number of tokens per text chunk when splitting documents."""

    chunk_overlap_token_size: int = field(
        default=int(os.getenv("CHUNK_OVERLAP_SIZE", 100))
    )
    """Number of overlapping tokens between consecutive text chunks to preserve context."""

    tokenizer: Optional[Tokenizer] = field(default=None)
    """
    A function that returns a Tokenizer instance.
    If None, and a `tiktoken_model_name` is provided, a TiktokenTokenizer will be created.
    If both are None, the default TiktokenTokenizer is used.
    """

    tiktoken_model_name: str = field(default="gpt-4o-mini")
    """Model name used for tokenization when chunking text with tiktoken. Defaults to `gpt-4o-mini`."""

    chunking_func: Callable[
        [
            Tokenizer,
            str,
            Optional[str],
            bool,
            int,
            int,
        ],
        Union[List[Dict[str, Any]], Awaitable[List[Dict[str, Any]]]],
    ] = field(default_factory=lambda: chunking_by_token_size)
    """
    Custom chunking function for splitting text into chunks before processing.

    The function can be either synchronous or asynchronous.

    The function should take the following parameters:

        - `tokenizer`: A Tokenizer instance to use for tokenization.
        - `content`: The text to be split into chunks.
        - `split_by_character`: The character to split the text on. If None, the text is split into chunks of `chunk_token_size` tokens.
        - `split_by_character_only`: If True, the text is split only on the specified character.
        - `chunk_overlap_token_size`: The number of overlapping tokens between consecutive chunks.
        - `chunk_token_size`: The maximum number of tokens per chunk.

    The function should return a list of dictionaries (or an awaitable that resolves to a list),
    where each dictionary contains the following keys:

        - `tokens` (int): The number of tokens in the chunk.
        - `content` (str): The text content of the chunk.
        - `chunk_order_index` (int): Zero-based index indicating the chunk's order in the document.

    Defaults to `chunking_by_token_size` if not specified.
    """

    # Embedding
    # ---

    embedding_func: EmbeddingFunc | None = field(default=None)
    """Function for computing text embeddings. Must be set before use."""

    embedding_token_limit: int | None = field(default=None, init=False)
    """Token limit for embedding model. Set automatically from embedding_func.max_token_size in __post_init__."""

    embedding_batch_num: int = field(default=int(os.getenv("EMBEDDING_BATCH_NUM", 10)))
    """Batch size for embedding computations."""

    embedding_func_max_async: int = field(
        default=int(os.getenv("EMBEDDING_FUNC_MAX_ASYNC", 8))
    )
    """Maximum number of concurrent embedding function calls."""

    embedding_cache_config: dict[str, Any] = field(
        default_factory=lambda: {
            "enabled": False,
            "similarity_threshold": 0.95,
            "use_llm_check": False,
        }
    )
    """Configuration for embedding cache.
    - enabled: If True, enables caching to avoid redundant computations.
    - similarity_threshold: Minimum similarity score to use cached embeddings.
    - use_llm_check: If True, validates cached embeddings using an LLM.
    """

    # Timeout (seconds) applied to each embedding call by the priority wrapper.
    default_embedding_timeout: int = field(
        default=int(os.getenv("EMBEDDING_TIMEOUT", DEFAULT_EMBEDDING_TIMEOUT))
    )

    # LLM Configuration
    # ---

    llm_model_func: Callable[..., object] | None = field(default=None)
    """Function for interacting with the large language model (LLM). Must be set before use."""

    llm_model_name: str = field(default="gpt-4o-mini")
    """Name of the LLM model used for generating responses."""

    summary_max_tokens: int = field(
        default=int(os.getenv("SUMMARY_MAX_TOKENS", DEFAULT_SUMMARY_MAX_TOKENS))
    )
    """Maximum tokens allowed for entity/relation description."""

    summary_context_size: int = field(
        default=int(os.getenv("SUMMARY_CONTEXT_SIZE", DEFAULT_SUMMARY_CONTEXT_SIZE))
    )
    """Maximum number of tokens allowed per LLM response."""

    summary_length_recommended: int = field(
        default=int(
            os.getenv("SUMMARY_LENGTH_RECOMMENDED", DEFAULT_SUMMARY_LENGTH_RECOMMENDED)
        )
    )
    """Recommended length of LLM summary output."""

    llm_model_max_async: int = field(
        default=int(os.getenv("MAX_ASYNC", DEFAULT_MAX_ASYNC))
    )
    """Maximum number of concurrent LLM calls."""

    llm_model_kwargs: dict[str, Any] = field(default_factory=dict)
    """Additional keyword arguments passed to the LLM model function."""

    # Timeout (seconds) applied to each LLM call by the priority wrapper.
    default_llm_timeout: int = field(
        default=int(os.getenv("LLM_TIMEOUT", DEFAULT_LLM_TIMEOUT))
    )

    # Rerank Configuration
    # ---

    rerank_model_func: Callable[..., object] | None = field(default=None)
    """Function for reranking retrieved documents. All rerank configurations (model name, API keys, top_k, etc.) should be included in this function. Optional."""

    min_rerank_score: float = field(
        default=get_env_value("MIN_RERANK_SCORE", DEFAULT_MIN_RERANK_SCORE, float)
    )
    """Minimum rerank score threshold for filtering chunks after reranking."""

    # Storage
    # ---

    vector_db_storage_cls_kwargs: dict[str, Any] = field(default_factory=dict)
    """Additional parameters for vector database storage."""

    enable_llm_cache: bool = field(default=True)
    """Enables caching for LLM responses to avoid redundant computations."""

    enable_llm_cache_for_entity_extract: bool = field(default=True)
    """If True, enables caching for entity extraction steps to reduce LLM costs."""

    # Extensions
    # ---

    max_parallel_insert: int = field(
        default=int(os.getenv("MAX_PARALLEL_INSERT", DEFAULT_MAX_PARALLEL_INSERT))
    )
    """Maximum number of parallel insert operations."""

    max_graph_nodes: int = field(
        default=get_env_value("MAX_GRAPH_NODES", DEFAULT_MAX_GRAPH_NODES, int)
    )
    """Maximum number of graph nodes to return in knowledge graph queries."""

    max_source_ids_per_entity: int = field(
        default=get_env_value(
            "MAX_SOURCE_IDS_PER_ENTITY", DEFAULT_MAX_SOURCE_IDS_PER_ENTITY, int
        )
    )
    """Maximum number of source (chunk) ids in entity Graph + VDB."""

    max_source_ids_per_relation: int = field(
        default=get_env_value(
            "MAX_SOURCE_IDS_PER_RELATION",
            DEFAULT_MAX_SOURCE_IDS_PER_RELATION,
            int,
        )
    )
    """Maximum number of source (chunk) ids in relation Graph + VDB."""

    source_ids_limit_method: str = field(
        default_factory=lambda: normalize_source_ids_limit_method(
            get_env_value(
                "SOURCE_IDS_LIMIT_METHOD",
                DEFAULT_SOURCE_IDS_LIMIT_METHOD,
                str,
            )
        )
    )
    """Strategy for enforcing source_id limits: IGNORE_NEW or FIFO."""

    max_file_paths: int = field(
        default=get_env_value("MAX_FILE_PATHS", DEFAULT_MAX_FILE_PATHS, int)
    )
    """Maximum number of file paths to store in entity/relation file_path field."""

    file_path_more_placeholder: str = field(default=DEFAULT_FILE_PATH_MORE_PLACEHOLDER)
    """Placeholder text when file paths exceed max_file_paths limit."""

    # Extra knobs consumed downstream (prompt language, entity type list, ...).
    addon_params: dict[str, Any] = field(
        default_factory=lambda: {
            "language": get_env_value(
                "SUMMARY_LANGUAGE", DEFAULT_SUMMARY_LANGUAGE, str
            ),
            "entity_types": get_env_value("ENTITY_TYPES", DEFAULT_ENTITY_TYPES, list),
        }
    )

    # Storages Management
    # ---

    # TODO: Deprecated (LightRAG will never initialize storage automatically on creation,and finalize should be call before destroying)
    auto_manage_storages_states: bool = field(default=False)
    """If True, lightrag will automatically calls initialize_storages and finalize_storages at the appropriate times."""

    # Threshold injected into vector_db_storage_cls_kwargs by __post_init__.
    cosine_better_than_threshold: float = field(
        default=float(os.getenv("COSINE_THRESHOLD", 0.2))
    )

    ollama_server_infos: Optional[OllamaServerInfos] = field(default=None)
    """Configuration for Ollama server information."""

    # Lifecycle marker: NOT_CREATED -> CREATED -> INITIALIZED -> FINALIZED.
    _storages_status: StoragesStatus = field(default=StoragesStatus.NOT_CREATED)
def __post_init__(self):
from lightrag.kg.shared_storage import (
initialize_share_data,
)
# Handle deprecated parameters
if self.log_level is not None:
warnings.warn(
"WARNING: log_level parameter is deprecated, use setup_logger in utils.py instead",
UserWarning,
stacklevel=2,
)
if self.log_file_path is not None:
warnings.warn(
"WARNING: log_file_path parameter is deprecated, use setup_logger in utils.py instead",
UserWarning,
stacklevel=2,
)
# Remove these attributes to prevent their use
if hasattr(self, "log_level"):
delattr(self, "log_level")
if hasattr(self, "log_file_path"):
delattr(self, "log_file_path")
initialize_share_data()
if not os.path.exists(self.working_dir):
logger.info(f"Creating working directory {self.working_dir}")
os.makedirs(self.working_dir)
# Verify storage implementation compatibility and environment variables
storage_configs = [
("KV_STORAGE", self.kv_storage),
("VECTOR_STORAGE", self.vector_storage),
("GRAPH_STORAGE", self.graph_storage),
("DOC_STATUS_STORAGE", self.doc_status_storage),
]
for storage_type, storage_name in storage_configs:
# Verify storage implementation compatibility
verify_storage_implementation(storage_type, storage_name)
# Check environment variables
check_storage_env_vars(storage_name)
# Ensure vector_db_storage_cls_kwargs has required fields
self.vector_db_storage_cls_kwargs = {
"cosine_better_than_threshold": self.cosine_better_than_threshold,
**self.vector_db_storage_cls_kwargs,
}
# Init Tokenizer
# Post-initialization hook to handle backward compatabile tokenizer initialization based on provided parameters
if self.tokenizer is None:
if self.tiktoken_model_name:
self.tokenizer = TiktokenTokenizer(self.tiktoken_model_name)
else:
self.tokenizer = TiktokenTokenizer()
# Initialize ollama_server_infos if not provided
if self.ollama_server_infos is None:
self.ollama_server_infos = OllamaServerInfos()
# Validate config
if self.force_llm_summary_on_merge < 3:
logger.warning(
f"force_llm_summary_on_merge should be at least 3, got {self.force_llm_summary_on_merge}"
)
if self.summary_context_size > self.max_total_tokens:
logger.warning(
f"summary_context_size({self.summary_context_size}) should no greater than max_total_tokens({self.max_total_tokens})"
)
if self.summary_length_recommended > self.summary_max_tokens:
logger.warning(
f"max_total_tokens({self.summary_max_tokens}) should greater than summary_length_recommended({self.summary_length_recommended})"
)
# Init Embedding
# Step 1: Capture embedding_func and max_token_size before applying rate_limit decorator
original_embedding_func = self.embedding_func
embedding_max_token_size = None
if self.embedding_func and hasattr(self.embedding_func, "max_token_size"):
embedding_max_token_size = self.embedding_func.max_token_size
logger.debug(
f"Captured embedding max_token_size: {embedding_max_token_size}"
)
self.embedding_token_limit = embedding_max_token_size
# Fix global_config now
global_config = asdict(self)
# Restore original EmbeddingFunc object (asdict converts it to dict)
global_config["embedding_func"] = original_embedding_func
_print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()])
logger.debug(f"LightRAG init with param:\n {_print_config}\n")
# Step 2: Apply priority wrapper decorator to EmbeddingFunc's inner func
# Create a NEW EmbeddingFunc instance with the wrapped func to avoid mutating the caller's object
# This ensures _generate_collection_suffix can still access attributes (model_name, embedding_dim)
# while preventing side effects when the same EmbeddingFunc is reused across multiple LightRAG instances
if self.embedding_func is not None:
wrapped_func = priority_limit_async_func_call(
self.embedding_func_max_async,
llm_timeout=self.default_embedding_timeout,
queue_name="Embedding func",
)(self.embedding_func.func)
# Use dataclasses.replace() to create a new instance, leaving the original unchanged
self.embedding_func = replace(self.embedding_func, func=wrapped_func)
# Initialize all storages
self.key_string_value_json_storage_cls: type[BaseKVStorage] = (
self._get_storage_class(self.kv_storage)
) # type: ignore
self.vector_db_storage_cls: type[BaseVectorStorage] = self._get_storage_class(
self.vector_storage
) # type: ignore
self.graph_storage_cls: type[BaseGraphStorage] = self._get_storage_class(
self.graph_storage
) # type: ignore
self.key_string_value_json_storage_cls = partial( # type: ignore
self.key_string_value_json_storage_cls, global_config=global_config
)
self.vector_db_storage_cls = partial( # type: ignore
self.vector_db_storage_cls, global_config=global_config
)
self.graph_storage_cls = partial( # type: ignore
self.graph_storage_cls, global_config=global_config
)
# Initialize document status storage
self.doc_status_storage_cls = self._get_storage_class(self.doc_status_storage)
self.llm_response_cache: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore
namespace=NameSpace.KV_STORE_LLM_RESPONSE_CACHE,
workspace=self.workspace,
global_config=global_config,
embedding_func=self.embedding_func,
)
self.text_chunks: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore
namespace=NameSpace.KV_STORE_TEXT_CHUNKS,
workspace=self.workspace,
embedding_func=self.embedding_func,
)
self.full_docs: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore
namespace=NameSpace.KV_STORE_FULL_DOCS,
workspace=self.workspace,
embedding_func=self.embedding_func,
)
self.full_entities: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore
namespace=NameSpace.KV_STORE_FULL_ENTITIES,
workspace=self.workspace,
embedding_func=self.embedding_func,
)
self.full_relations: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore
namespace=NameSpace.KV_STORE_FULL_RELATIONS,
workspace=self.workspace,
embedding_func=self.embedding_func,
)
self.entity_chunks: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore
namespace=NameSpace.KV_STORE_ENTITY_CHUNKS,
workspace=self.workspace,
embedding_func=self.embedding_func,
)
self.relation_chunks: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore
namespace=NameSpace.KV_STORE_RELATION_CHUNKS,
workspace=self.workspace,
embedding_func=self.embedding_func,
)
self.chunk_entity_relation_graph: BaseGraphStorage = self.graph_storage_cls( # type: ignore
namespace=NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION,
workspace=self.workspace,
embedding_func=self.embedding_func,
)
self.entities_vdb: BaseVectorStorage = self.vector_db_storage_cls( # type: ignore
namespace=NameSpace.VECTOR_STORE_ENTITIES,
workspace=self.workspace,
embedding_func=self.embedding_func,
meta_fields={"entity_name", "source_id", "content", "file_path"},
)
self.relationships_vdb: BaseVectorStorage = self.vector_db_storage_cls( # type: ignore
namespace=NameSpace.VECTOR_STORE_RELATIONSHIPS,
workspace=self.workspace,
embedding_func=self.embedding_func,
meta_fields={"src_id", "tgt_id", "source_id", "content", "file_path"},
)
self.chunks_vdb: BaseVectorStorage = self.vector_db_storage_cls( # type: ignore
namespace=NameSpace.VECTOR_STORE_CHUNKS,
workspace=self.workspace,
embedding_func=self.embedding_func,
meta_fields={"full_doc_id", "content", "file_path"},
)
# Initialize document status storage
self.doc_status: DocStatusStorage = self.doc_status_storage_cls(
namespace=NameSpace.DOC_STATUS,
workspace=self.workspace,
global_config=global_config,
embedding_func=None,
)
# Directly use llm_response_cache, don't create a new object
hashing_kv = self.llm_response_cache
# Get timeout from LLM model kwargs for dynamic timeout calculation
self.llm_model_func = priority_limit_async_func_call(
self.llm_model_max_async,
llm_timeout=self.default_llm_timeout,
queue_name="LLM func",
)(
partial(
self.llm_model_func, # type: ignore
hashing_kv=hashing_kv,
**self.llm_model_kwargs,
)
)
self._storages_status = StoragesStatus.CREATED
async def initialize_storages(self):
"""Storage initialization must be called one by one to prevent deadlock"""
if self._storages_status == StoragesStatus.CREATED:
# Set the first initialized workspace will set the default workspace
# Allows namespace operation without specifying workspace for backward compatibility
default_workspace = get_default_workspace()
if default_workspace is None:
set_default_workspace(self.workspace)
elif default_workspace != self.workspace:
logger.info(
f"Creating LightRAG instance with workspace='{self.workspace}' "
f"while default workspace is set to '{default_workspace}'"
)
# Auto-initialize pipeline_status for this workspace
from lightrag.kg.shared_storage import initialize_pipeline_status
await initialize_pipeline_status(workspace=self.workspace)
for storage in (
self.full_docs,
self.text_chunks,
self.full_entities,
self.full_relations,
self.entity_chunks,
self.relation_chunks,
self.entities_vdb,
self.relationships_vdb,
self.chunks_vdb,
self.chunk_entity_relation_graph,
self.llm_response_cache,
self.doc_status,
):
if storage:
# logger.debug(f"Initializing storage: {storage}")
await storage.initialize()
self._storages_status = StoragesStatus.INITIALIZED
logger.debug("All storage types initialized")
async def finalize_storages(self):
"""Asynchronously finalize the storages with improved error handling"""
if self._storages_status == StoragesStatus.INITIALIZED:
storages = [
("full_docs", self.full_docs),
("text_chunks", self.text_chunks),
("full_entities", self.full_entities),
("full_relations", self.full_relations),
("entity_chunks", self.entity_chunks),
("relation_chunks", self.relation_chunks),
("entities_vdb", self.entities_vdb),
("relationships_vdb", self.relationships_vdb),
("chunks_vdb", self.chunks_vdb),
("chunk_entity_relation_graph", self.chunk_entity_relation_graph),
("llm_response_cache", self.llm_response_cache),
("doc_status", self.doc_status),
]
# Finalize each storage individually to ensure one failure doesn't prevent others from closing
successful_finalizations = []
failed_finalizations = []
for storage_name, storage in storages:
if storage:
try:
await storage.finalize()
successful_finalizations.append(storage_name)
logger.debug(f"Successfully finalized {storage_name}")
except Exception as e:
error_msg = f"Failed to finalize {storage_name}: {e}"
logger.error(error_msg)
failed_finalizations.append(storage_name)
# Log summary of finalization results
if successful_finalizations:
logger.info(
f"Successfully finalized {len(successful_finalizations)} storages"
)
if failed_finalizations:
logger.error(
f"Failed to finalize {len(failed_finalizations)} storages: {', '.join(failed_finalizations)}"
)
else:
logger.debug("All storages finalized successfully")
self._storages_status = StoragesStatus.FINALIZED
    async def check_and_migrate_data(self):
        """Check if data migration is needed and perform migration if necessary.

        Runs under the shared data-init lock so only one worker performs the
        check. Every failure is logged and re-raised to the caller.
        """
        async with get_data_init_lock():
            try:
                # Check if migration is needed:
                # 1. chunk_entity_relation_graph has entities and relations (count > 0)
                # 2. full_entities and full_relations are empty

                # Get all entity labels from graph
                all_entity_labels = (
                    await self.chunk_entity_relation_graph.get_all_labels()
                )

                # An empty graph means a fresh install — nothing to migrate.
                if not all_entity_labels:
                    logger.debug("No entities found in graph, skipping migration check")
                    return

                try:
                    # Initialize chunk tracking storage after migration
                    await self._migrate_chunk_tracking_storage()
                except Exception as e:
                    logger.error(f"Error during chunk_tracking migration: {e}")
                    raise e

                # Check if full_entities and full_relations are empty
                # Get all processed documents to check their entity/relation data
                try:
                    processed_docs = await self.doc_status.get_docs_by_status(
                        DocStatus.PROCESSED
                    )

                    if not processed_docs:
                        logger.debug("No processed documents found, skipping migration")
                        return

                    # Check first few documents to see if they have full_entities/full_relations data
                    # Sampling: any hit means the new stores are already populated.
                    migration_needed = True
                    checked_count = 0  # NOTE(review): incremented but never read
                    max_check = min(5, len(processed_docs))  # Check up to 5 documents

                    for doc_id in list(processed_docs.keys())[:max_check]:
                        checked_count += 1
                        entity_data = await self.full_entities.get_by_id(doc_id)
                        relation_data = await self.full_relations.get_by_id(doc_id)

                        if entity_data or relation_data:
                            migration_needed = False
                            break

                    if not migration_needed:
                        logger.debug(
                            "Full entities/relations data already exists, no migration needed"
                        )
                        return

                    logger.info(
                        f"Data migration needed: found {len(all_entity_labels)} entities in graph but no full_entities/full_relations data"
                    )

                    # Perform migration
                    await self._migrate_entity_relation_data(processed_docs)

                except Exception as e:
                    logger.error(f"Error during migration check: {e}")
                    raise e

            except Exception as e:
                logger.error(f"Error in data migration check: {e}")
                raise e
async def _migrate_entity_relation_data(self, processed_docs: dict):
"""Migrate existing entity and relation data to full_entities and full_relations storage"""
logger.info(f"Starting data migration for {len(processed_docs)} documents")
# Create mapping from chunk_id to doc_id
chunk_to_doc = {}
for doc_id, doc_status in processed_docs.items():
chunk_ids = (
doc_status.chunks_list
if hasattr(doc_status, "chunks_list") and doc_status.chunks_list
else []
)
for chunk_id in chunk_ids:
chunk_to_doc[chunk_id] = doc_id
# Initialize document entity and relation mappings
doc_entities = {} # doc_id -> set of entity_names
doc_relations = {} # doc_id -> set of relation_pairs (as tuples)
# Get all nodes and edges from graph
all_nodes = await self.chunk_entity_relation_graph.get_all_nodes()
all_edges = await self.chunk_entity_relation_graph.get_all_edges()
# Process all nodes once
for node in all_nodes:
if "source_id" in node:
entity_id = node.get("entity_id") or node.get("id")
if not entity_id:
continue
# Get chunk IDs from source_id
source_ids = node["source_id"].split(GRAPH_FIELD_SEP)
# Find which documents this entity belongs to
for chunk_id in source_ids:
doc_id = chunk_to_doc.get(chunk_id)
if doc_id:
if doc_id not in doc_entities:
doc_entities[doc_id] = set()
doc_entities[doc_id].add(entity_id)
# Process all edges once
for edge in all_edges:
if "source_id" in edge:
src = edge.get("source")
tgt = edge.get("target")
if not src or not tgt:
continue
# Get chunk IDs from source_id
source_ids = edge["source_id"].split(GRAPH_FIELD_SEP)
# Find which documents this relation belongs to
for chunk_id in source_ids:
doc_id = chunk_to_doc.get(chunk_id)
if doc_id:
if doc_id not in doc_relations:
doc_relations[doc_id] = set()
# Use tuple for set operations, convert to list later
doc_relations[doc_id].add(tuple(sorted((src, tgt))))
# Store the results in full_entities and full_relations
migration_count = 0
# Store entities
if doc_entities:
entities_data = {}
for doc_id, entity_set in doc_entities.items():
entities_data[doc_id] = {
"entity_names": list(entity_set),
"count": len(entity_set),
}
await self.full_entities.upsert(entities_data)
# Store relations
if doc_relations:
relations_data = {}
for doc_id, relation_set in doc_relations.items():
# Convert tuples back to lists
relations_data[doc_id] = {
"relation_pairs": [list(pair) for pair in relation_set],
"count": len(relation_set),
}
await self.full_relations.upsert(relations_data)
migration_count = len(
set(list(doc_entities.keys()) + list(doc_relations.keys()))
)
# Persist the migrated data
await self.full_entities.index_done_callback()
await self.full_relations.index_done_callback()
logger.info(
f"Data migration completed: migrated {migration_count} documents with entities/relations"
)
async def _migrate_chunk_tracking_storage(self) -> None:
"""Ensure entity/relation chunk tracking KV stores exist and are seeded."""
if not self.entity_chunks or not self.relation_chunks:
return
need_entity_migration = False
need_relation_migration = False
try:
need_entity_migration = await self.entity_chunks.is_empty()
except Exception as exc: # pragma: no cover - defensive logging
logger.error(f"Failed to check entity chunks storage: {exc}")
raise exc
try:
need_relation_migration = await self.relation_chunks.is_empty()
except Exception as exc: # pragma: no cover - defensive logging
logger.error(f"Failed to check relation chunks storage: {exc}")
raise exc
if not need_entity_migration and not need_relation_migration:
return