Skip to content

Commit e8f1a24

Browse files
authored
Feat: update check_embedding API (infiniflow#11254)
### What problem does this PR solve?

PR: infiniflow#10854
Change: update the check_embedding API.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
1 parent 9084505 commit e8f1a24

2 files changed

Lines changed: 27 additions & 8 deletions

File tree

api/apps/kb_app.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import json
1717
import logging
1818
import random
19+
import re
1920

2021
from flask import request
2122
from flask_login import login_required, current_user
@@ -847,8 +848,13 @@ def sample_random_chunks_with_vectors(
847848
"position_int": full_doc.get("position_int"),
848849
"top_int": full_doc.get("top_int"),
849850
"content_with_weight": full_doc.get("content_with_weight") or "",
851+
"question_kwd": full_doc.get("question_kwd") or []
850852
})
851853
return out
854+
855+
def _clean(s: str) -> str:
856+
s = re.sub(r"</?(table|td|caption|tr|th)( [^<>]{0,12})?>", " ", s or "")
857+
return s if s else "None"
852858
req = request.json
853859
kb_id = req.get("kb_id", "")
854860
embd_id = req.get("embd_id", "")
@@ -861,8 +867,10 @@ def sample_random_chunks_with_vectors(
861867

862868
results, eff_sims = [], []
863869
for ck in samples:
864-
txt = (ck.get("content_with_weight") or "").strip()
865-
if not txt:
870+
title = ck.get("doc_name") or "Title"
871+
txt_in = "\n".join(ck.get("question_kwd") or []) or ck.get("content_with_weight") or ""
872+
txt_in = _clean(txt_in)
873+
if not txt_in:
866874
results.append({"chunk_id": ck["chunk_id"], "reason": "no_text"})
867875
continue
868876

@@ -871,8 +879,16 @@ def sample_random_chunks_with_vectors(
871879
continue
872880

873881
try:
874-
qv, _ = emb_mdl.encode_queries(txt)
875-
sim = _cos_sim(qv, ck["vector"])
882+
v, _ = emb_mdl.encode([title, txt_in])
883+
sim_content = _cos_sim(v[1], ck["vector"])
884+
title_w = 0.1
885+
qv_mix = title_w * v[0] + (1 - title_w) * v[1]
886+
sim_mix = _cos_sim(qv_mix, ck["vector"])
887+
sim = sim_content
888+
mode = "content_only"
889+
if sim_mix > sim:
890+
sim = sim_mix
891+
mode = "title+content"
876892
except Exception:
877893
return get_error_data_result(message="embedding failure")
878894

@@ -894,8 +910,9 @@ def sample_random_chunks_with_vectors(
894910
"avg_cos_sim": round(float(np.mean(eff_sims)) if eff_sims else 0.0, 6),
895911
"min_cos_sim": round(float(np.min(eff_sims)) if eff_sims else 0.0, 6),
896912
"max_cos_sim": round(float(np.max(eff_sims)) if eff_sims else 0.0, 6),
913+
"match_mode": mode,
897914
}
898-
if summary["avg_cos_sim"] > 0.99:
915+
if summary["avg_cos_sim"] > 0.9:
899916
return get_json_result(data={"summary": summary, "results": results})
900917
return get_json_result(code=RetCode.NOT_EFFECTIVE, message="failed", data={"summary": summary, "results": results})
901918

rag/svr/task_executor.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -442,7 +442,7 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
442442
tk_count = 0
443443
if len(tts) == len(cnts):
444444
vts, c = await trio.to_thread.run_sync(lambda: mdl.encode(tts[0: 1]))
445-
tts = np.concatenate([vts[0] for _ in range(len(tts))], axis=0)
445+
tts = np.tile(vts[0], (len(cnts), 1))
446446
tk_count += c
447447

448448
@timeout(60)
@@ -465,8 +465,10 @@ def batch_encode(txts):
465465
if not filename_embd_weight:
466466
filename_embd_weight = 0.1
467467
title_w = float(filename_embd_weight)
468-
vects = (title_w * tts + (1 - title_w) *
469-
cnts) if len(tts) == len(cnts) else cnts
468+
if tts.ndim == 2 and cnts.ndim == 2 and tts.shape == cnts.shape:
469+
vects = title_w * tts + (1 - title_w) * cnts
470+
else:
471+
vects = cnts
470472

471473
assert len(vects) == len(docs)
472474
vector_size = 0

0 commit comments

Comments
 (0)