Skip to content

Commit 1473000

Browse files
authored
Implement retrieval_test in GO (infiniflow#14231)
### What problem does this PR solve? Implement retrieval_test in GO ### Type of change - [x] Refactoring
1 parent aadd9a3 commit 1473000

42 files changed

Lines changed: 4735 additions & 1522 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

api/apps/chunk_app.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@ async def _retrieval():
157157
if ck["content_with_weight"]:
158158
ranks["chunks"].insert(0, ck)
159159
ranks["chunks"] = settings.retriever.retrieval_by_children(ranks["chunks"], tenant_ids)
160+
ranks["total"] = len(ranks["chunks"])
160161

161162
for c in ranks["chunks"]:
162163
c.pop("vector", None)

conf/models/siliconflow.json

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
{
2+
"name": "SILICONFLOW",
3+
"tags": "LLM,TEXT EMBEDDING,TEXT RE-RANK,IMAGE2TEXT",
4+
"url": {
5+
"default": "https://api.siliconflow.cn/v1"
6+
},
7+
"url_suffix": {
8+
"chat": "chat/completions",
9+
"async_chat": "async/chat/completions",
10+
"async_result": "async-result",
11+
"embedding": "embedding",
12+
"rerank": "rerank"
13+
},
14+
"models": [
15+
{
16+
"name": "BAAI/bge-reranker-v2-m3",
17+
"max_tokens": 8192,
18+
"model_types": [
19+
"rerank"
20+
],
21+
"features": {}
22+
}
23+
]
24+
}
25+
26+

go.mod

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ require (
88
github.com/aws/aws-sdk-go-v2/credentials v1.19.11
99
github.com/aws/aws-sdk-go-v2/service/s3 v1.96.4
1010
github.com/aws/smithy-go v1.24.2
11+
github.com/cespare/xxhash/v2 v2.3.0
1112
github.com/elastic/go-elasticsearch/v8 v8.19.1
1213
github.com/gin-gonic/gin v1.9.1
1314
github.com/google/uuid v1.6.0
@@ -43,7 +44,6 @@ require (
4344
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.16 // indirect
4445
github.com/aws/aws-sdk-go-v2/service/sts v1.41.8 // indirect
4546
github.com/bytedance/sonic v1.9.1 // indirect
46-
github.com/cespare/xxhash/v2 v2.3.0 // indirect
4747
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
4848
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
4949
github.com/dustin/go-humanize v1.0.1 // indirect
@@ -106,4 +106,4 @@ require (
106106
gopkg.in/ini.v1 v1.67.0 // indirect
107107
)
108108

109-
replace github.com/infiniflow/infinity-go-sdk => github.com/infiniflow/infinity/go v0.0.0-20260331112649-9bcd52a3d364
109+
replace github.com/infiniflow/infinity-go-sdk => github.com/infiniflow/infinity/go v0.0.0-20260424025959-72028e662929

go.sum

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,8 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
9898
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
9999
github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4=
100100
github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ=
101-
github.com/infiniflow/infinity/go v0.0.0-20260331112649-9bcd52a3d364 h1:0v5TjSirmCAUX3oaIV8Rd9d5B+kHPdymveETUU8OcC0=
102-
github.com/infiniflow/infinity/go v0.0.0-20260331112649-9bcd52a3d364/go.mod h1:hw3z5AwNFsGy1cdrE0Mfjot2y9jqVHTxBufUx9VzZ+0=
101+
github.com/infiniflow/infinity/go v0.0.0-20260424025959-72028e662929 h1:0M1BNouFVpnF12XEmF/42aR8CRU0bt/rMEVEsRUtSfQ=
102+
github.com/infiniflow/infinity/go v0.0.0-20260424025959-72028e662929/go.mod h1:hw3z5AwNFsGy1cdrE0Mfjot2y9jqVHTxBufUx9VzZ+0=
103103
github.com/iromli/go-itsdangerous v0.0.0-20220223194502-9c8bef8dac6a h1:Inib12UR9HAfBubrGNraPjKt/Cu8xPbTJbC50+0wP5U=
104104
github.com/iromli/go-itsdangerous v0.0.0-20220223194502-9c8bef8dac6a/go.mod h1:8N0Hlye5Lzw+H/yHWpZMkT0QLA+iOHG7KLdvAm95DZg=
105105
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=

internal/cli/user_parser.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1907,7 +1907,7 @@ func (p *Parser) parseInsertDatasetFromFile() (*Command, error) {
19071907
}
19081908

19091909
// Internal CLI for GO
1910-
// parseInsertMetadataFromFile parses: INSERT INTO METADATA FROM FILE "file_path"
1910+
// parseInsertMetadataFromFile parses: INSERT METADATA FROM FILE "file_path"
19111911
func (p *Parser) parseInsertMetadataFromFile() (*Command, error) {
19121912
p.nextToken() // consume METADATA
19131913

@@ -2617,6 +2617,7 @@ func (p *Parser) parseUpdateCommand() (*Command, error) {
26172617
return nil, fmt.Errorf("unknown UPDATE target: %s", p.curToken.Value)
26182618
}
26192619

2620+
// Internal CLI for GO
26202621
// parseUpdateChunk parses: UPDATE CHUNK 'chunk_id' OF DATASET 'dataset_name' SET '{"content": "..."}'
26212622
func (p *Parser) parseUpdateChunk() (*Command, error) {
26222623
p.nextToken() // consume CHUNK

internal/common/constants.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
package common
2+
3+
const (
4+
// PAGERANK_FLD is the field name for pagerank score
5+
PAGERANK_FLD = "pagerank_fea"
6+
// TAG_FLD is the field name for tag features
7+
TAG_FLD = "tag_feas"
8+
)

internal/dao/tenant_llm.go

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package dao
1818

1919
import (
20+
"fmt"
2021
"ragflow/internal/entity"
2122
)
2223

@@ -28,6 +29,16 @@ func NewTenantLLMDAO() *TenantLLMDAO {
2829
return &TenantLLMDAO{}
2930
}
3031

32+
// GetByID get tenant LLM by primary key ID
33+
func (dao *TenantLLMDAO) GetByID(id int64) (*entity.TenantLLM, error) {
34+
var tenantLLM entity.TenantLLM
35+
err := DB.Where("id = ?", id).First(&tenantLLM).Error
36+
if err != nil {
37+
return nil, err
38+
}
39+
return &tenantLLM, nil
40+
}
41+
3142
// GetByTenantAndModelName get tenant LLM by tenant ID and model name
3243
func (dao *TenantLLMDAO) GetByTenantAndModelName(tenantID, providerName string, modelName string) (*entity.TenantLLM, error) {
3344
var tenantLLM entity.TenantLLM
@@ -38,6 +49,16 @@ func (dao *TenantLLMDAO) GetByTenantAndModelName(tenantID, providerName string,
3849
return &tenantLLM, nil
3950
}
4051

52+
// GetByTenantNameAndType get tenant LLM by tenant ID, model name, and model type
53+
func (dao *TenantLLMDAO) GetByTenantNameAndType(tenantID, modelName string, modelType entity.ModelType) (*entity.TenantLLM, error) {
54+
var tenantLLM entity.TenantLLM
55+
err := DB.Where("tenant_id = ? AND llm_name = ? AND model_type = ?", tenantID, modelName, modelType).First(&tenantLLM).Error
56+
if err != nil {
57+
return nil, err
58+
}
59+
return &tenantLLM, nil
60+
}
61+
4162
// GetByTenantAndType get tenant LLM by tenant ID and model type
4263
func (dao *TenantLLMDAO) GetByTenantAndType(tenantID string, modelType entity.ModelType) (*entity.TenantLLM, error) {
4364
var tenantLLM entity.TenantLLM
@@ -268,3 +289,50 @@ func (dao *TenantLLMDAO) GetByTenantIDLLMNameAndFactory(tenantID, llmName, facto
268289
}
269290
return &tenantLLM, nil
270291
}
292+
293+
// LookupTenantLLMByID looks up a TenantLLM record by ID and returns the record plus composite model name.
294+
func LookupTenantLLMByID(tenantLLMDao *TenantLLMDAO, id int64) (*entity.TenantLLM, string, error) {
295+
tenantLLM, err := tenantLLMDao.GetByID(id)
296+
if err != nil {
297+
return nil, "", fmt.Errorf("failed to get tenant_llm by id %d: %w", id, err)
298+
}
299+
if tenantLLM == nil || tenantLLM.LLMName == nil || *tenantLLM.LLMName == "" {
300+
return nil, "", fmt.Errorf("tenant_llm record not found for id %d", id)
301+
}
302+
compositeName := fmt.Sprintf("%s@%s", *tenantLLM.LLMName, tenantLLM.LLMFactory)
303+
return tenantLLM, compositeName, nil
304+
}
305+
306+
// LookupTenantLLMByName looks up a TenantLLM record by tenant name and model type.
307+
func LookupTenantLLMByName(tenantLLMDao *TenantLLMDAO, tenantID, name string, modelType entity.ModelType) (*entity.TenantLLM, string, error) {
308+
// Parse factory from name if present (e.g., "model@Factory")
309+
modelName, factory := splitModelNameAndFactory(name)
310+
311+
// If factory is found, use factory-based lookup
312+
if factory != "" {
313+
return LookupTenantLLMByFactory(tenantLLMDao, tenantID, factory, modelName, modelType)
314+
}
315+
316+
tenantLLM, err := tenantLLMDao.GetByTenantNameAndType(tenantID, modelName, modelType)
317+
if err != nil {
318+
return nil, "", fmt.Errorf("failed to get tenant_llm by name %s: %w", name, err)
319+
}
320+
if tenantLLM == nil || tenantLLM.LLMName == nil || *tenantLLM.LLMName == "" {
321+
return nil, "", fmt.Errorf("tenant_llm record not found for name %s", name)
322+
}
323+
compositeName := fmt.Sprintf("%s@%s", *tenantLLM.LLMName, tenantLLM.LLMFactory)
324+
return tenantLLM, compositeName, nil
325+
}
326+
327+
// LookupTenantLLMByFactory looks up a TenantLLM record by tenant, factory, and model name.
328+
func LookupTenantLLMByFactory(tenantLLMDao *TenantLLMDAO, tenantID, factory, name string, modelType entity.ModelType) (*entity.TenantLLM, string, error) {
329+
tenantLLM, err := tenantLLMDao.GetByTenantFactoryAndModelName(tenantID, factory, name)
330+
if err != nil {
331+
return nil, "", fmt.Errorf("failed to get tenant_llm by factory %s and name %s: %w", factory, name, err)
332+
}
333+
if tenantLLM == nil || tenantLLM.LLMName == nil || *tenantLLM.LLMName == "" {
334+
return nil, "", fmt.Errorf("tenant_llm record not found for factory %s and name %s", factory, name)
335+
}
336+
compositeName := fmt.Sprintf("%s@%s", *tenantLLM.LLMName, tenantLLM.LLMFactory)
337+
return tenantLLM, compositeName, nil
338+
}

internal/engine/elasticsearch/get.go

Lines changed: 12 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -19,38 +19,31 @@ package elasticsearch
1919
import (
2020
"context"
2121
"fmt"
22+
23+
"ragflow/internal/engine/types"
2224
)
2325

2426
// GetChunk gets a chunk by ID
2527
func (e *elasticsearchEngine) GetChunk(ctx context.Context, indexName, chunkID string, kbIDs []string) (interface{}, error) {
26-
// Build query to get the chunk by ID
27-
query := map[string]interface{}{
28-
"term": map[string]interface{}{
28+
// Build unified search request to get the chunk by ID
29+
searchReq := &types.SearchRequest{
30+
IndexNames: []string{indexName},
31+
Limit: 1,
32+
Offset: 0,
33+
Filter: map[string]interface{}{
2934
"id": chunkID,
3035
},
3136
}
3237

33-
searchReq := &SearchRequest{
34-
IndexNames: []string{indexName},
35-
Query: query,
36-
Size: 1,
37-
From: 0,
38-
}
39-
4038
// Execute search
41-
result, err := e.Search(ctx, searchReq)
39+
searchResp, err := e.Search(ctx, searchReq)
4240
if err != nil {
4341
return nil, fmt.Errorf("failed to search: %w", err)
4442
}
4543

46-
esResp, ok := result.(*SearchResponse)
47-
if !ok {
48-
return nil, fmt.Errorf("invalid search response type")
49-
}
50-
51-
if len(esResp.Hits.Hits) == 0 {
44+
if len(searchResp.Chunks) == 0 {
5245
return nil, nil
5346
}
5447

55-
return esResp.Hits.Hits[0].Source, nil
56-
}
48+
return searchResp.Chunks[0], nil
49+
}

0 commit comments

Comments
 (0)