Skip to content

Commit 2e02973

Browse files
committed
feat: Compatibility with rerank API output fields
1 parent d74ae91 commit 2e02973

File tree

2 files changed

+184
-8
lines changed

2 files changed

+184
-8
lines changed

internal/models/rerank/reranker.go

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import (
55
"encoding/json"
66
"fmt"
77
"strings"
8-
8+
99
"github.com/Tencent/WeKnora/internal/types"
1010
)
1111

@@ -26,17 +26,17 @@ type RankResult struct {
2626
Document DocumentInfo `json:"document"`
2727
RelevanceScore float64 `json:"relevance_score"`
2828
}
29-
//Handles the RelevanceScore field by checking if RelevanceScore exists first, otherwise falls back to Score field
29+
30+
// Handles the RelevanceScore field by checking if RelevanceScore exists first, otherwise falls back to Score field
3031
func (r *RankResult) UnmarshalJSON(data []byte) error {
3132

3233
var temp struct {
33-
Index int `json:"index"`
34-
Document DocumentInfo `json:"document"`
35-
RelevanceScore *float64 `json:"relevance_score"`
36-
Score *float64 `json:"score"`
34+
Index int `json:"index"`
35+
Document DocumentInfo `json:"document"`
36+
RelevanceScore *float64 `json:"relevance_score"`
37+
Score *float64 `json:"score"`
3738
}
3839

39-
4040
if err := json.Unmarshal(data, &temp); err != nil {
4141
return fmt.Errorf("failed to unmarshal rank result: %w", err)
4242
}
@@ -50,14 +50,34 @@ func (r *RankResult) UnmarshalJSON(data []byte) error {
5050
r.RelevanceScore = *temp.Score
5151
}
5252

53-
5453
return nil
5554
}
5655

5756
type DocumentInfo struct {
5857
Text string `json:"text"`
5958
}
6059

60+
// UnmarshalJSON handles both string and object formats for DocumentInfo
61+
func (d *DocumentInfo) UnmarshalJSON(data []byte) error {
62+
// First try to unmarshal as a string
63+
var text string
64+
if err := json.Unmarshal(data, &text); err == nil {
65+
d.Text = text
66+
return nil
67+
}
68+
69+
// If that fails, try to unmarshal as an object with text field
70+
var temp struct {
71+
Text string `json:"text"`
72+
}
73+
if err := json.Unmarshal(data, &temp); err != nil {
74+
return fmt.Errorf("failed to unmarshal DocumentInfo: %w", err)
75+
}
76+
77+
d.Text = temp.Text
78+
return nil
79+
}
80+
6181
type RerankerConfig struct {
6282
APIKey string
6383
BaseURL string
Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
package rerank
2+
3+
import (
4+
"encoding/json"
5+
"testing"
6+
)
7+
8+
func TestRankResultUnmarshalJSON(t *testing.T) {
9+
tests := []struct {
10+
name string
11+
input string
12+
expectedText string
13+
expectedIndex int
14+
expectedScore float64
15+
expectError bool
16+
}{
17+
{
18+
name: "document as string with relevance_score",
19+
input: `{"index": 0, "document": "This is a document", "relevance_score": 0.95}`,
20+
expectedText: "This is a document",
21+
expectedIndex: 0,
22+
expectedScore: 0.95,
23+
expectError: false,
24+
},
25+
{
26+
name: "document as object with relevance_score",
27+
input: `{"index": 1, "document": {"text": "This is a document"}, "relevance_score": 0.87}`,
28+
expectedText: "This is a document",
29+
expectedIndex: 1,
30+
expectedScore: 0.87,
31+
expectError: false,
32+
},
33+
{
34+
name: "document as string with score field",
35+
input: `{"index": 2, "document": "This is a document", "score": 0.92}`,
36+
expectedText: "This is a document",
37+
expectedIndex: 2,
38+
expectedScore: 0.92,
39+
expectError: false,
40+
},
41+
{
42+
name: "document as object with score field",
43+
input: `{"index": 3, "document": {"text": "This is a document"}, "score": 0.78}`,
44+
expectedText: "This is a document",
45+
expectedIndex: 3,
46+
expectedScore: 0.78,
47+
expectError: false,
48+
},
49+
{
50+
name: "document as string with both score fields - relevance_score takes priority",
51+
input: `{"index": 4, "document": "This is a document", "relevance_score": 0.95, "score": 0.80}`,
52+
expectedText: "This is a document",
53+
expectedIndex: 4,
54+
expectedScore: 0.95,
55+
expectError: false,
56+
},
57+
{
58+
name: "document as object with both score fields - relevance_score takes priority",
59+
input: `{"index": 5, "document": {"text": "This is a document"}, "relevance_score": 0.88, "score": 0.75}`,
60+
expectedText: "This is a document",
61+
expectedIndex: 5,
62+
expectedScore: 0.88,
63+
expectError: false,
64+
},
65+
{
66+
name: "document as string with no score fields",
67+
input: `{"index": 6, "document": "This is a document"}`,
68+
expectedText: "This is a document",
69+
expectedIndex: 6,
70+
expectedScore: 0.0,
71+
expectError: false,
72+
},
73+
{
74+
name: "document as object with no score fields",
75+
input: `{"index": 7, "document": {"text": "This is a document"}}`,
76+
expectedText: "This is a document",
77+
expectedIndex: 7,
78+
expectedScore: 0.0,
79+
expectError: false,
80+
},
81+
}
82+
83+
for _, tt := range tests {
84+
t.Run(tt.name, func(t *testing.T) {
85+
var result RankResult
86+
err := json.Unmarshal([]byte(tt.input), &result)
87+
88+
if tt.expectError {
89+
if err == nil {
90+
t.Errorf("Expected error but got none")
91+
}
92+
return
93+
}
94+
95+
if err != nil {
96+
t.Fatalf("Unmarshal failed: %v", err)
97+
}
98+
99+
if result.Document.Text != tt.expectedText {
100+
t.Errorf("Expected document text %q, got %q", tt.expectedText, result.Document.Text)
101+
}
102+
if result.Index != tt.expectedIndex {
103+
t.Errorf("Expected index %d, got %d", tt.expectedIndex, result.Index)
104+
}
105+
if result.RelevanceScore != tt.expectedScore {
106+
t.Errorf("Expected score %f, got %f", tt.expectedScore, result.RelevanceScore)
107+
}
108+
})
109+
}
110+
}
111+
112+
// TestDocumentInfoMarshalJSON tests that DocumentInfo can be marshaled back to JSON
113+
func TestDocumentInfoMarshalJSON(t *testing.T) {
114+
doc := DocumentInfo{Text: "Test document content"}
115+
116+
data, err := json.Marshal(doc)
117+
if err != nil {
118+
t.Fatalf("Marshal failed: %v", err)
119+
}
120+
121+
expected := `{"text":"Test document content"}`
122+
if string(data) != expected {
123+
t.Errorf("Expected %s, got %s", expected, string(data))
124+
}
125+
}
126+
127+
// TestRankResultMarshalJSON tests that RankResult can be marshaled back to JSON
128+
func TestRankResultMarshalJSON(t *testing.T) {
129+
result := RankResult{
130+
Index: 1,
131+
Document: DocumentInfo{Text: "Test document"},
132+
RelevanceScore: 0.95,
133+
}
134+
135+
data, err := json.Marshal(result)
136+
if err != nil {
137+
t.Fatalf("Marshal failed: %v", err)
138+
}
139+
140+
// Parse back to verify structure
141+
var parsed RankResult
142+
err = json.Unmarshal(data, &parsed)
143+
if err != nil {
144+
t.Fatalf("Round-trip unmarshal failed: %v", err)
145+
}
146+
147+
if parsed.Index != result.Index {
148+
t.Errorf("Index mismatch: expected %d, got %d", result.Index, parsed.Index)
149+
}
150+
if parsed.Document.Text != result.Document.Text {
151+
t.Errorf("Document text mismatch: expected %q, got %q", result.Document.Text, parsed.Document.Text)
152+
}
153+
if parsed.RelevanceScore != result.RelevanceScore {
154+
t.Errorf("Score mismatch: expected %f, got %f", result.RelevanceScore, parsed.RelevanceScore)
155+
}
156+
}

0 commit comments

Comments
 (0)