Skip to content

Commit b24ae75

Browse files
KevinHuSh and clifftseng
authored and committed
Fix: parent-children chunking method. (infiniflow#11997)
### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] New Feature (non-breaking change which adds functionality)
1 parent 0e600f3 commit b24ae75

File tree

10 files changed

+160
-57
lines changed

10 files changed

+160
-57
lines changed

api/apps/canvas_app.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ async def run():
147147
if cvs.canvas_category == CanvasCategory.DataFlow:
148148
task_id = get_uuid()
149149
Pipeline(cvs.dsl, tenant_id=current_user.id, doc_id=CANVAS_DEBUG_DOC_ID, task_id=task_id, flow_id=req["id"])
150-
ok, error_message = await asyncio.to_thread(queue_dataflow, user_id, req["id"], task_id, files[0], 0)
150+
ok, error_message = await asyncio.to_thread(queue_dataflow, user_id, req["id"], task_id, CANVAS_DEBUG_DOC_ID, files[0], 0)
151151
if not ok:
152152
return get_data_error_result(message=error_message)
153153
return get_json_result(data={"message_id": task_id})

api/apps/chunk_app.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -386,6 +386,7 @@ async def _retrieval():
386386
LLMBundle(kb.tenant_id, LLMType.CHAT))
387387
if ck["content_with_weight"]:
388388
ranks["chunks"].insert(0, ck)
389+
ranks["chunks"] = settings.retriever.retrieval_by_children(ranks["chunks"], tenant_ids)
389390

390391
for c in ranks["chunks"]:
391392
c.pop("vector", None)

common/metadata_utils.py

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,10 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515
#
16-
from typing import Any, Callable
16+
import logging
17+
from typing import Any, Callable, Dict
18+
19+
import json_repair
1720

1821
from rag.prompts.generator import gen_meta_filter
1922

@@ -140,3 +143,63 @@ async def apply_meta_data_filter(
140143
doc_ids = ["-999"]
141144

142145
return doc_ids
146+
147+
148+
def update_metadata_to(metadata, meta):
149+
if not meta:
150+
return metadata
151+
if isinstance(meta, str):
152+
try:
153+
meta = json_repair.loads(meta)
154+
except Exception:
155+
logging.error("Meta data format error.")
156+
return metadata
157+
if not isinstance(meta, dict):
158+
return metadata
159+
for k, v in meta.items():
160+
if isinstance(v, list):
161+
v = [vv for vv in v if isinstance(vv, str)]
162+
if not v:
163+
continue
164+
if not isinstance(v, list) and not isinstance(v, str):
165+
continue
166+
if k not in metadata:
167+
metadata[k] = v
168+
continue
169+
if isinstance(metadata[k], list):
170+
if isinstance(v, list):
171+
metadata[k].extend(v)
172+
else:
173+
metadata[k].append(v)
174+
else:
175+
metadata[k] = v
176+
177+
return metadata
178+
179+
180+
def metadata_schema(metadata: list|None) -> Dict[str, Any]:
181+
if not metadata:
182+
return {}
183+
properties = {}
184+
185+
for item in metadata:
186+
key = item.get("key")
187+
if not key:
188+
continue
189+
190+
prop_schema = {
191+
"description": item.get("description", "")
192+
}
193+
if "enum" in item and item["enum"]:
194+
prop_schema["enum"] = item["enum"]
195+
prop_schema["type"] = "string"
196+
197+
properties[key] = prop_schema
198+
199+
json_schema = {
200+
"type": "object",
201+
"properties": properties,
202+
}
203+
204+
json_schema["additionalProperties"] = False
205+
return json_schema

graphrag/general/extractor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def _chat(self, system, history, gen_conf={}, task_id=""):
7878
raise TaskCanceledException(f"Task {task_id} was cancelled")
7979

8080
try:
81-
response = self._llm.chat(system_msg[0]["content"], hist, conf)
81+
response = asyncio.run(self._llm.async_chat(system_msg[0]["content"], hist, conf))
8282
response = re.sub(r"^.*</think>", "", response, flags=re.DOTALL)
8383
if response.find("**ERROR**") >= 0:
8484
raise Exception(response)

rag/app/naive.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -635,9 +635,14 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
635635
"parser_config", {
636636
"chunk_token_num": 512, "delimiter": "\n!?。;!?", "layout_recognize": "DeepDOC", "analyze_hyperlink": True})
637637

638-
child_deli = re.findall(r"`([^`]+)`", parser_config.get("children_delimiter", ""))
639-
child_deli = sorted(set(child_deli), key=lambda x: -len(x))
640-
child_deli = "|".join(re.escape(t) for t in child_deli if t)
638+
child_deli = parser_config.get("children_delimiter", "").encode('utf-8').decode('unicode_escape').encode('latin1').decode('utf-8')
639+
cust_child_deli = re.findall(r"`([^`]+)`", child_deli)
640+
child_deli = "|".join(re.sub(r"`([^`]+)`", "", child_deli))
641+
if cust_child_deli:
642+
cust_child_deli = sorted(set(cust_child_deli), key=lambda x: -len(x))
643+
cust_child_deli = "|".join(re.escape(t) for t in cust_child_deli if t)
644+
child_deli += cust_child_deli
645+
641646
is_markdown = False
642647
table_context_size = max(0, int(parser_config.get("table_context_size", 0) or 0))
643648
image_context_size = max(0, int(parser_config.get("image_context_size", 0) or 0))

rag/flow/splitter/splitter.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -60,14 +60,7 @@ async def _invoke(self, **kwargs):
6060
deli += f"`{d}`"
6161
else:
6262
deli += d
63-
child_deli = ""
64-
for d in self._param.children_delimiters:
65-
if len(d) > 1:
66-
child_deli += f"`{d}`"
67-
else:
68-
child_deli += d
69-
child_deli = [m.group(1) for m in re.finditer(r"`([^`]+)`", child_deli)]
70-
custom_pattern = "|".join(re.escape(t) for t in sorted(set(child_deli), key=len, reverse=True))
63+
custom_pattern = "|".join(re.escape(t) for t in sorted(set(self._param.children_delimiters), key=len, reverse=True))
7164

7265
self.set_output("output_format", "chunks")
7366
self.callback(random.randint(1, 5) / 100.0, "Start to split into chunks.")

rag/nlp/__init__.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,21 @@ def tokenize(d, txt, eng):
273273
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
274274

275275

276+
def split_with_pattern(d, pattern:str, content:str, eng) -> list:
277+
docs = []
278+
txts = [txt for txt in re.split(r"(%s)" % pattern, content, flags=re.DOTALL)]
279+
for j in range(0, len(txts), 2):
280+
txt = txts[j]
281+
if not txt:
282+
continue
283+
if j + 1 < len(txts):
284+
txt += txts[j+1]
285+
dd = copy.deepcopy(d)
286+
tokenize(dd, txt, eng)
287+
docs.append(dd)
288+
return docs
289+
290+
276291
def tokenize_chunks(chunks, doc, eng, pdf_parser=None, child_delimiters_pattern=None):
277292
res = []
278293
# wrap up as es documents
@@ -293,10 +308,7 @@ def tokenize_chunks(chunks, doc, eng, pdf_parser=None, child_delimiters_pattern=
293308

294309
if child_delimiters_pattern:
295310
d["mom_with_weight"] = ck
296-
for txt in re.split(r"(%s)" % child_delimiters_pattern, ck, flags=re.DOTALL):
297-
dd = copy.deepcopy(d)
298-
tokenize(dd, txt, eng)
299-
res.append(dd)
311+
res.extend(split_with_pattern(d, child_delimiters_pattern, ck, eng))
300312
continue
301313

302314
tokenize(d, ck, eng)
@@ -316,10 +328,7 @@ def tokenize_chunks_with_images(chunks, doc, eng, images, child_delimiters_patte
316328
add_positions(d, [[ii]*5])
317329
if child_delimiters_pattern:
318330
d["mom_with_weight"] = ck
319-
for txt in re.split(r"(%s)" % child_delimiters_pattern, ck, flags=re.DOTALL):
320-
dd = copy.deepcopy(d)
321-
tokenize(dd, txt, eng)
322-
res.append(dd)
331+
res.extend(split_with_pattern(d, child_delimiters_pattern, ck, eng))
323332
continue
324333
tokenize(d, ck, eng)
325334
res.append(d)

rag/prompts/generator.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -821,3 +821,13 @@ async def relevant_chunks_with_toc(query: str, toc:list[dict], chat_mdl, topn: i
821821
except Exception as e:
822822
logging.exception(e)
823823
return []
824+
825+
826+
META_DATA = load_prompt("meta_data")
827+
async def gen_metadata(chat_mdl, schema:dict, content:str):
828+
template = PROMPT_JINJA_ENV.from_string(META_DATA)
829+
system_prompt = template.render(content=content, schema=schema)
830+
user_prompt = "Output: "
831+
_, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length)
832+
ans = await chat_mdl.async_chat(msg[0]["content"], msg[1:])
833+
return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)

rag/prompts/meta_data.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
Extract important structured information from the given content.
2+
Output ONLY a valid JSON string with no additional text.
3+
If no important structured information is found, output an empty JSON object: {}.
4+
5+
Important structured information structure as following:
6+
7+
{{ schema }}
8+
9+
---------------------------
10+
The given content as following:
11+
12+
{{ content }}
13+

rag/svr/task_executor.py

Lines changed: 44 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -23,19 +23,19 @@
2323
import threading
2424
import time
2525

26-
import json_repair
27-
2826
from api.db import PIPELINE_SPECIAL_PROGRESS_FREEZE_TASK_TYPES
2927
from api.db.services.knowledgebase_service import KnowledgebaseService
3028
from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
3129
from common.connection_utils import timeout
30+
from common.metadata_utils import update_metadata_to, metadata_schema
3231
from rag.utils.base64_image import image2id
3332
from rag.utils.raptor_utils import should_skip_raptor, get_skip_reason
3433
from common.log_utils import init_root_logger
3534
from common.config_utils import show_configs
3635
from graphrag.general.index import run_graphrag_for_kb
3736
from graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache
38-
from rag.prompts.generator import keyword_extraction, question_proposal, content_tagging, run_toc_from_text
37+
from rag.prompts.generator import keyword_extraction, question_proposal, content_tagging, run_toc_from_text, \
38+
gen_metadata
3939
import logging
4040
import os
4141
from datetime import datetime
@@ -368,6 +368,45 @@ async def doc_question_proposal(chat_mdl, d, topn):
368368
raise
369369
progress_callback(msg="Question generation {} chunks completed in {:.2f}s".format(len(docs), timer() - st))
370370

371+
if task["parser_config"].get("enable_metadata", False) and task["parser_config"].get("metadata"):
372+
st = timer()
373+
progress_callback(msg="Start to generate meta-data for every chunk ...")
374+
chat_mdl = LLMBundle(task["tenant_id"], LLMType.CHAT, llm_name=task["llm_id"], lang=task["language"])
375+
376+
async def gen_metadata_task(chat_mdl, d):
377+
cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "metadata")
378+
if not cached:
379+
async with chat_limiter:
380+
cached = await gen_metadata(chat_mdl,
381+
metadata_schema(task["parser_config"]["metadata"]),
382+
d["content_with_weight"])
383+
set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "metadata")
384+
if cached:
385+
d["metadata_obj"] = cached
386+
tasks = []
387+
for d in docs:
388+
tasks.append(asyncio.create_task(gen_metadata_task(chat_mdl, d)))
389+
try:
390+
await asyncio.gather(*tasks, return_exceptions=False)
391+
except Exception as e:
392+
logging.error("Error in doc_question_proposal", exc_info=e)
393+
for t in tasks:
394+
t.cancel()
395+
await asyncio.gather(*tasks, return_exceptions=True)
396+
raise
397+
metadata = {}
398+
for ck in cks:
399+
metadata = update_metadata_to(metadata, ck["metadata_obj"])
400+
del ck["metadata_obj"]
401+
if metadata:
402+
e, doc = DocumentService.get_by_id(task["doc_id"])
403+
if e:
404+
if isinstance(doc.meta_fields, str):
405+
doc.meta_fields = json.loads(doc.meta_fields)
406+
metadata = update_metadata_to(metadata, doc.meta_fields)
407+
DocumentService.update_by_id(task["doc_id"], {"meta_fields": metadata})
408+
progress_callback(msg="Question generation {} chunks completed in {:.2f}s".format(len(docs), timer() - st))
409+
371410
if task["kb_parser_config"].get("tag_kb_ids", []):
372411
progress_callback(msg="Start to tag for every chunk ...")
373412
kb_ids = task["kb_parser_config"]["tag_kb_ids"]
@@ -602,36 +641,6 @@ def batch_encode(txts):
602641

603642

604643
metadata = {}
605-
def dict_update(meta):
606-
nonlocal metadata
607-
if not meta:
608-
return
609-
if isinstance(meta, str):
610-
try:
611-
meta = json_repair.loads(meta)
612-
except Exception:
613-
logging.error("Meta data format error.")
614-
return
615-
if not isinstance(meta, dict):
616-
return
617-
for k, v in meta.items():
618-
if isinstance(v, list):
619-
v = [vv for vv in v if isinstance(vv, str)]
620-
if not v:
621-
continue
622-
if not isinstance(v, list) and not isinstance(v, str):
623-
continue
624-
if k not in metadata:
625-
metadata[k] = v
626-
continue
627-
if isinstance(metadata[k], list):
628-
if isinstance(v, list):
629-
metadata[k].extend(v)
630-
else:
631-
metadata[k].append(v)
632-
else:
633-
metadata[k] = v
634-
635644
for ck in chunks:
636645
ck["doc_id"] = doc_id
637646
ck["kb_id"] = [str(task["kb_id"])]
@@ -656,7 +665,7 @@ def dict_update(meta):
656665
ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
657666
del ck["summary"]
658667
if "metadata" in ck:
659-
dict_update(ck["metadata"])
668+
metadata = update_metadata_to(metadata, ck["metadata"])
660669
del ck["metadata"]
661670
if "content_with_weight" not in ck:
662671
ck["content_with_weight"] = ck["text"]
@@ -670,7 +679,7 @@ def dict_update(meta):
670679
if e:
671680
if isinstance(doc.meta_fields, str):
672681
doc.meta_fields = json.loads(doc.meta_fields)
673-
dict_update(doc.meta_fields)
682+
metadata = update_metadata_to(metadata, doc.meta_fields)
674683
DocumentService.update_by_id(doc_id, {"meta_fields": metadata})
675684

676685
start_ts = timer()

0 commit comments

Comments (0)