
Commit 968ad00: fix linting
Parent: d83171f

5 files changed: +53 −28 lines

examples/lightrag_gemini_demo_no_tiktoken.py (+13 −6)

@@ -51,10 +51,12 @@ class _TokenizerConfig:
         "google/gemma3": _TokenizerConfig(
             tokenizer_model_url="https://raw.githubusercontent.com/google/gemma_pytorch/cb7c0152a369e43908e769eb09e1ce6043afe084/tokenizer/gemma3_cleaned_262144_v2.spiece.model",
             tokenizer_model_hash="1299c11d7cf632ef3b4e11937501358ada021bbdf7c47638d13c0ee982f2e79c",
-        )
-    }
+        ),
+    }

-    def __init__(self, model_name: str = "gemini-2.0-flash", tokenizer_dir: Optional[str] = None):
+    def __init__(
+        self, model_name: str = "gemini-2.0-flash", tokenizer_dir: Optional[str] = None
+    ):
         # https://github.com/google/gemma_pytorch/tree/main/tokenizer
         if "1.5" in model_name or "1.0" in model_name:
             # up to gemini 1.5 gemma2 is a comparable local tokenizer
@@ -77,7 +79,9 @@ def __init__(self, model_name: str = "gemini-2.0-flash", tokenizer_dir: Optional
         else:
             model_data = None
         if not model_data:
-            model_data = self._load_from_url(file_url=file_url, expected_hash=expected_hash)
+            model_data = self._load_from_url(
+                file_url=file_url, expected_hash=expected_hash
+            )
             self.save_tokenizer_to_cache(cache_path=file_path, model_data=model_data)

         tokenizer = spm.SentencePieceProcessor()
@@ -140,7 +144,7 @@ def _maybe_remove_file(file_path: Path) -> None:

     # def encode(self, content: str) -> list[int]:
     #     return self.tokenizer.encode(content)
-
+
     # def decode(self, tokens: list[int]) -> str:
     #     return self.tokenizer.decode(tokens)

@@ -187,7 +191,10 @@ async def initialize_rag():
     rag = LightRAG(
         working_dir=WORKING_DIR,
         # tiktoken_model_name="gpt-4o-mini",
-        tokenizer=GemmaTokenizer(tokenizer_dir=(Path(WORKING_DIR) / "vertexai_tokenizer_model"), model_name="gemini-2.0-flash"),
+        tokenizer=GemmaTokenizer(
+            tokenizer_dir=(Path(WORKING_DIR) / "vertexai_tokenizer_model"),
+            model_name="gemini-2.0-flash",
+        ),
         llm_model_func=llm_model_func,
         embedding_func=EmbeddingFunc(
             embedding_dim=384,
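Aside for readers of the demo: GemmaTokenizer resolves its SentencePiece model from a pinned URL and verifies a SHA-256 hash before loading, as the config above shows. A minimal standalone sketch of that pattern, using only the URL and hash visible in this diff (fetch_verified is an illustrative helper, not the demo's actual method):

import hashlib
import urllib.request

import sentencepiece as spm

# URL and SHA-256 pin taken from the diff above; everything else is illustrative.
MODEL_URL = "https://raw.githubusercontent.com/google/gemma_pytorch/cb7c0152a369e43908e769eb09e1ce6043afe084/tokenizer/gemma3_cleaned_262144_v2.spiece.model"
EXPECTED_SHA256 = "1299c11d7cf632ef3b4e11937501358ada021bbdf7c47638d13c0ee982f2e79c"


def fetch_verified(url: str, expected_sha256: str) -> bytes:
    """Download the tokenizer model and refuse to use it on a hash mismatch."""
    data = urllib.request.urlopen(url).read()
    digest = hashlib.sha256(data).hexdigest()
    if digest != expected_sha256:
        raise ValueError(f"tokenizer hash mismatch: got {digest}")
    return data


tokenizer = spm.SentencePieceProcessor()
tokenizer.LoadFromSerializedProto(fetch_verified(MODEL_URL, EXPECTED_SHA256))
print(tokenizer.decode(tokenizer.encode("hello world")))

The hash pin means a silently changed upstream file fails loudly instead of producing a tokenizer that disagrees with the one used at indexing time.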

lightrag/api/routers/ollama_api.py (+1 −1)

@@ -10,7 +10,7 @@
 import asyncio
 from ascii_colors import trace_exception
 from lightrag import LightRAG, QueryParam
-from lightrag.utils import TiktokenTokenizer
+from lightrag.utils import TiktokenTokenizer
 from lightrag.api.utils_api import ollama_server_infos, get_combined_auth_dependency
 from fastapi import Depends

lightrag/lightrag.py (+13 −6)

@@ -7,7 +7,18 @@
 from dataclasses import asdict, dataclass, field
 from datetime import datetime
 from functools import partial
-from typing import Any, AsyncIterator, Callable, Iterator, cast, final, Literal, Optional, List, Dict
+from typing import (
+    Any,
+    AsyncIterator,
+    Callable,
+    Iterator,
+    cast,
+    final,
+    Literal,
+    Optional,
+    List,
+    Dict,
+)

 from lightrag.kg import (
     STORAGES,
@@ -1139,11 +1150,7 @@ async def ainsert_custom_kg(
         for chunk_data in custom_kg.get("chunks", []):
             chunk_content = clean_text(chunk_data["content"])
             source_id = chunk_data["source_id"]
-            tokens = len(
-                self.tokenizer.encode(
-                    chunk_content
-                )
-            )
+            tokens = len(self.tokenizer.encode(chunk_content))
             chunk_order_index = (
                 0
                 if "chunk_order_index" not in chunk_data.keys()

lightrag/operate.py (+16 −10)

@@ -88,9 +88,7 @@ def chunking_by_token_size(
     for index, start in enumerate(
         range(0, len(tokens), max_token_size - overlap_token_size)
     ):
-        chunk_content = tokenizer.decode(
-            tokens[start : start + max_token_size]
-        )
+        chunk_content = tokenizer.decode(tokens[start : start + max_token_size])
         results.append(
             {
                 "tokens": min(max_token_size, len(tokens) - start),
@@ -126,9 +124,7 @@ async def _handle_entity_relation_summary(
     if len(tokens) < summary_max_tokens:  # No need for summary
         return description
     prompt_template = PROMPTS["summarize_entity_descriptions"]
-    use_description = tokenizer.decode(
-        tokens[:llm_max_tokens]
-    )
+    use_description = tokenizer.decode(tokens[:llm_max_tokens])
     context_base = dict(
         entity_name=entity_or_relation_name,
         description_list=use_description.split(GRAPH_FIELD_SEP),
@@ -1380,10 +1376,15 @@ async def _get_node_data(
     ]  # what is this text_chunks_db doing. dont remember it in airvx. check the diagram.
     # get entitytext chunk
     use_text_units = await _find_most_related_text_unit_from_entities(
-        node_datas, query_param, text_chunks_db, knowledge_graph_inst,
+        node_datas,
+        query_param,
+        text_chunks_db,
+        knowledge_graph_inst,
     )
     use_relations = await _find_most_related_edges_from_entities(
-        node_datas, query_param, knowledge_graph_inst,
+        node_datas,
+        query_param,
+        knowledge_graph_inst,
     )

     tokenizer: Tokenizer = text_chunks_db.global_config.get("tokenizer")
@@ -1705,10 +1706,15 @@ async def _get_edge_data(
     )
     use_entities, use_text_units = await asyncio.gather(
         _find_most_related_entities_from_relationships(
-            edge_datas, query_param, knowledge_graph_inst,
+            edge_datas,
+            query_param,
+            knowledge_graph_inst,
        ),
         _find_related_text_unit_from_relationships(
-            edge_datas, query_param, text_chunks_db, knowledge_graph_inst,
+            edge_datas,
+            query_param,
+            text_chunks_db,
+            knowledge_graph_inst,
         ),
     )
     logger.info(
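The loop reformatted in the first hunk is a sliding-window chunker: the window is max_token_size tokens wide and advances by max_token_size - overlap_token_size, so consecutive chunks share overlap_token_size tokens. A self-contained sketch of the same pattern (tiktoken's cl100k_base is an assumed stand-in for the configured tokenizer):

import tiktoken


def chunk_by_token_size_sketch(
    content: str, overlap_token_size: int = 128, max_token_size: int = 1024
) -> list[dict]:
    tokenizer = tiktoken.get_encoding("cl100k_base")  # assumed tokenizer
    tokens = tokenizer.encode(content)
    results = []
    # window of max_token_size tokens, stepping by (max - overlap),
    # so consecutive chunks share overlap_token_size tokens
    for index, start in enumerate(
        range(0, len(tokens), max_token_size - overlap_token_size)
    ):
        chunk_content = tokenizer.decode(tokens[start : start + max_token_size])
        results.append(
            {
                "tokens": min(max_token_size, len(tokens) - start),
                "content": chunk_content.strip(),
                "chunk_order_index": index,
            }
        )
    return results

The overlap keeps sentences that straddle a chunk boundary fully visible in at least one chunk, at the cost of some duplicated tokens.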

lightrag/utils.py (+10 −5)

@@ -12,7 +12,7 @@
 from dataclasses import dataclass
 from functools import wraps
 from hashlib import md5
-from typing import Any, Protocol, Callable, TYPE_CHECKING, List, Optional, Union
+from typing import Any, Protocol, Callable, TYPE_CHECKING, List
 import xml.etree.ElementTree as ET
 import numpy as np
 from lightrag.prompt import PROMPTS
@@ -311,6 +311,7 @@ class TokenizerInterface(Protocol):
     """
     Defines the interface for a tokenizer, requiring encode and decode methods.
     """
+
     def encode(self, content: str) -> List[int]:
         """Encodes a string into a list of tokens."""
         ...
@@ -319,10 +320,12 @@ def decode(self, tokens: List[int]) -> str:
         """Decodes a list of tokens into a string."""
         ...

+
 class Tokenizer:
     """
     A wrapper around a tokenizer to provide a consistent interface for encoding and decoding.
     """
+
     def __init__(self, model_name: str, tokenizer: TokenizerInterface):
         """
         Initializes the Tokenizer with a tokenizer model name and a tokenizer instance.
@@ -363,6 +366,7 @@ class TiktokenTokenizer(Tokenizer):
     """
     A Tokenizer implementation using the tiktoken library.
     """
+
     def __init__(self, model_name: str = "gpt-4o-mini"):
         """
         Initializes the TiktokenTokenizer with a specified model name.
@@ -385,9 +389,7 @@ def __init__(self, model_name: str = "gpt-4o-mini"):
             tokenizer = tiktoken.encoding_for_model(model_name)
             super().__init__(model_name=model_name, tokenizer=tokenizer)
         except KeyError:
-            raise ValueError(
-                f"Invalid model_name: {model_name}."
-            )
+            raise ValueError(f"Invalid model_name: {model_name}.")


 def pack_user_ass_to_openai_messages(*args: str):
@@ -424,7 +426,10 @@ def is_float_regex(value: str) -> bool:


 def truncate_list_by_token_size(
-    list_data: list[Any], key: Callable[[Any], str], max_token_size: int, tokenizer: Tokenizer
+    list_data: list[Any],
+    key: Callable[[Any], str],
+    max_token_size: int,
+    tokenizer: Tokenizer,
 ) -> list[int]:
     """Truncate a list of data by token size"""
     if max_token_size <= 0:
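Note on the interface these lint fixes touch: TokenizerInterface is a typing.Protocol, so structural typing applies; any object with matching encode/decode methods can be passed wherever LightRAG expects a tokenizer, with no inheritance required. A toy sketch (ByteTokenizer is hypothetical, for illustration only):

from typing import List, Protocol


class TokenizerInterface(Protocol):
    def encode(self, content: str) -> List[int]: ...
    def decode(self, tokens: List[int]) -> str: ...


class ByteTokenizer:
    """Toy structural match: one token per UTF-8 byte (illustration only)."""

    def encode(self, content: str) -> List[int]:
        return list(content.encode("utf-8"))

    def decode(self, tokens: List[int]) -> str:
        return bytes(tokens).decode("utf-8")


def count_tokens(tokenizer: TokenizerInterface, text: str) -> int:
    # ByteTokenizer never declares the Protocol, yet type-checks here
    # because its method signatures match.
    return len(tokenizer.encode(text))


assert count_tokens(ByteTokenizer(), "abc") == 3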
