-
Notifications
You must be signed in to change notification settings - Fork 8.7k
Expand file tree
/
Copy pathdify_retrieval.py
More file actions
197 lines (187 loc) · 7.21 KB
/
dify_retrieval.py
File metadata and controls
197 lines (187 loc) · 7.21 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
from quart import jsonify, request
from api.db.services.document_service import DocumentService
from api.db.services.doc_metadata_service import DocMetadataService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from api.db.joint_services.tenant_model_service import get_model_config_by_id, get_model_config_by_type_and_name, get_tenant_default_model_by_type
from common.metadata_utils import meta_filter, convert_conditions
from api.utils.api_utils import apikey_required, build_error_result, get_json_result, get_request_json, validate_request
from rag.app.tag import label_question
from common.constants import RetCode, LLMType
from common import settings
@manager.route('/dify/retrieval', methods=['GET'])  # noqa: F821
async def retrieval_health_check():
    """Answer Dify's connectivity probe for this external knowledge base.

    Always reports success (``data=True``); no authentication or lookup is
    performed, so a response only proves the service is reachable.
    """
    reachable = True
    return get_json_result(data=reachable)
@manager.route('/dify/retrieval', methods=['POST'])  # noqa: F821
@apikey_required
@validate_request("knowledge_id", "query")
async def retrieval(tenant_id):
    """
    Dify-compatible retrieval API
    ---
    tags:
      - SDK
    security:
      - ApiKeyAuth: []
    parameters:
      - in: body
        name: body
        required: true
        schema:
          type: object
          required:
            - knowledge_id
            - query
          properties:
            knowledge_id:
              type: string
              description: Knowledge base ID
            query:
              type: string
              description: Query text
            use_kg:
              type: boolean
              description: Whether to use knowledge graph
              default: false
            retrieval_setting:
              type: object
              description: Retrieval configuration
              properties:
                score_threshold:
                  type: number
                  description: Similarity threshold
                  default: 0.0
                top_k:
                  type: integer
                  description: Number of results to return
                  default: 1024
            metadata_condition:
              type: object
              description: Metadata filter condition
              properties:
                logic:
                  type: string
                  description: Operator used to combine the conditions
                  default: and
                conditions:
                  type: array
                  items:
                    type: object
                    properties:
                      name:
                        type: string
                        description: Field name
                      comparison_operator:
                        type: string
                        description: Comparison operator
                      value:
                        type: string
                        description: Field value
    responses:
      200:
        description: Retrieval succeeded
        schema:
          type: object
          properties:
            records:
              type: array
              items:
                type: object
                properties:
                  content:
                    type: string
                    description: Content text
                  score:
                    type: number
                    description: Similarity score
                  title:
                    type: string
                    description: Document title
                  metadata:
                    type: object
                    description: Document meta fields plus doc_id and document_id
      404:
        description: Knowledge base or document not found
    """
    # Implementation notes (kept out of the OpenAPI text above):
    # - `tenant_id` is supplied by `apikey_required`; presumably the tenant that
    #   owns the API key -- confirm in api.utils.api_utils.
    # - Search, child-chunk expansion and KG retrieval all run against the
    #   knowledge base owner's tenant (`kb.tenant_id`), where its chunks live.
    req = await get_request_json()
    question = req["query"]
    kb_id = req["knowledge_id"]
    use_kg = req.get("use_kg", False)
    # Treat explicit JSON nulls the same as absent objects.
    retrieval_setting = req.get("retrieval_setting") or {}
    metadata_condition = req.get("metadata_condition") or {}

    try:
        similarity_threshold = float(retrieval_setting.get("score_threshold", 0.0))
        top = int(retrieval_setting.get("top_k", 1024))
    except (AttributeError, TypeError, ValueError):
        # Malformed client input is an argument error, not a server failure.
        return build_error_result(
            message="`retrieval_setting.score_threshold` and `retrieval_setting.top_k` must be numeric!",
            code=RetCode.ARGUMENT_ERROR,
        )

    try:
        exists, kb = KnowledgebaseService.get_by_id(kb_id)
        if not exists:
            return build_error_result(message="Knowledgebase not found!", code=RetCode.NOT_FOUND)
        # NOTE(review): nothing here verifies that `tenant_id` is allowed to
        # read `kb_id`; any valid API key can query any knowledge base by id.
        # Confirm whether an access check belongs here.

        if kb.tenant_embd_id:
            model_config = get_model_config_by_id(kb.tenant_embd_id)
        else:
            model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.EMBEDDING, kb.embd_id)
        embd_mdl = LLMBundle(kb.tenant_id, model_config)

        doc_ids = []
        if metadata_condition:
            # Metadata is only needed for filtering, so load it lazily.
            metas = DocMetadataService.get_flatted_meta_by_kbs([kb_id])
            doc_ids = list(meta_filter(
                metas,
                convert_conditions(metadata_condition),
                metadata_condition.get("logic", "and"),
            ))
            if not doc_ids:
                # The filter matched nothing. Use a sentinel id (presumably never
                # a real document id) so the search returns no hits instead of
                # running unfiltered, as an empty doc_ids list would.
                doc_ids = ["-999"]

        ranks = await settings.retriever.retrieval(
            question,
            embd_mdl,
            kb.tenant_id,
            [kb_id],
            page=1,
            page_size=top,
            similarity_threshold=similarity_threshold,
            vector_similarity_weight=0.3,
            top=top,
            doc_ids=doc_ids,
            rank_feature=label_question(question, [kb]),
        )
        # Same tenant as the search above, so related chunks come from the same index.
        ranks["chunks"] = settings.retriever.retrieval_by_children(ranks["chunks"], [kb.tenant_id])

        if use_kg:
            chat_model_config = get_tenant_default_model_by_type(kb.tenant_id, LLMType.CHAT)
            ck = await settings.kg_retriever.retrieval(
                question,
                [kb.tenant_id],
                [kb_id],
                embd_mdl,
                LLMBundle(kb.tenant_id, chat_model_config),
            )
            if ck["content_with_weight"]:
                # Knowledge-graph context is placed ahead of ordinary chunks.
                ranks["chunks"].insert(0, ck)

        records = []
        meta_by_doc = {}  # doc_id -> meta_fields, so each document is loaded once
        for chunk in ranks["chunks"]:
            chunk.pop("vector", None)
            doc_id = chunk["doc_id"]
            if doc_id not in meta_by_doc:
                _, doc = DocumentService.get_by_id(doc_id)
                meta_by_doc[doc_id] = getattr(doc, "meta_fields", None) or {}
            # Copy so the keys added below never mutate the document's own meta_fields.
            meta = dict(meta_by_doc[doc_id])
            meta["doc_id"] = doc_id
            # Dify expects metadata.document_id for external retrieval sources.
            meta["document_id"] = doc_id
            records.append({
                "content": chunk["content_with_weight"],
                "score": chunk["similarity"],
                "title": chunk["docnm_kwd"],
                "metadata": meta,
            })
        return jsonify({"records": records})
    except Exception as e:
        # Substring test: `find(...) > 0` missed messages starting with "not_found".
        if "not_found" in str(e):
            return build_error_result(
                message='No chunk found! Check the chunk status please!',
                code=RetCode.NOT_FOUND
            )
        logging.exception(e)
        return build_error_result(message=str(e), code=RetCode.SERVER_ERROR)