Skip to content

Commit 02d9b4b

Browse files
author
Working On It
committed
fix: don't talk about carriers in CoT (casibase#1335)
1 parent 412156a commit 02d9b4b

File tree

3 files changed

+153
-2
lines changed

3 files changed

+153
-2
lines changed

controllers/carrier_writer.go

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
// Copyright 2025 The Casibase Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package controllers
16+
17+
import (
18+
"bytes"
19+
)
20+
21+
type CarrierWriter struct {
22+
writerCleaner Cleaner
23+
messageBuf []byte
24+
}
25+
26+
func (w *CarrierWriter) Write(p []byte) (n int, err error) {
27+
var eventType string
28+
var data string
29+
30+
if bytes.HasPrefix(p, []byte("event: message")) {
31+
eventType = "message"
32+
prefix := []byte("event: message\ndata: ")
33+
suffix := []byte("\n\n")
34+
data = string(bytes.TrimSuffix(bytes.TrimPrefix(p, prefix), suffix))
35+
}
36+
37+
if eventType == "message" {
38+
w.messageBuf = append(w.messageBuf, []byte(data)...)
39+
}
40+
return len(p), nil
41+
}
42+
43+
func (w *CarrierWriter) MessageString() string {
44+
return string(w.messageBuf)
45+
}
46+
47+
func (w *CarrierWriter) Flush() {}

controllers/message_answer.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ func (c *ApiController) GetMessageAnswer() {
205205
// fmt.Printf("Refined Question: [%s]\n", realQuestion)
206206
fmt.Printf("Answer: [")
207207

208-
if modelProvider.Type != "Dummy" {
208+
if modelProvider.Type != "Dummy" && !isReasonModel(modelProvider.SubType) {
209209
question, err = getQuestionWithCarriers(question, store.SuggestionCount, chat.NeedTitle)
210210
}
211211
if err != nil {
@@ -224,7 +224,11 @@ func (c *ApiController) GetMessageAnswer() {
224224
}
225225
modelResult, err = model.QueryTextWithTools(modelProviderObj, question, writer, history, store.Prompt, knowledge, agentInfo)
226226
} else {
227-
modelResult, err = modelProviderObj.QueryText(question, writer, history, store.Prompt, knowledge, nil)
227+
if isReasonModel(modelProvider.SubType) {
228+
modelResult, err = QueryCarrierText(question, writer, history, store.Prompt, knowledge, modelProviderObj, chat.NeedTitle, store.SuggestionCount)
229+
} else {
230+
modelResult, err = modelProviderObj.QueryText(question, writer, history, store.Prompt, knowledge, nil)
231+
}
228232
}
229233
if err != nil {
230234
if strings.Contains(err.Error(), "write tcp") {

controllers/message_carrier.go

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,12 @@
1515
package controllers
1616

1717
import (
18+
"fmt"
19+
"strings"
20+
"sync"
21+
1822
"github.com/casibase/casibase/carrier"
23+
"github.com/casibase/casibase/model"
1924
"github.com/casibase/casibase/object"
2025
)
2126

@@ -72,3 +77,98 @@ func parseAnswerWithCarriers(answer string, suggestionCount int, needTitle bool)
7277

7378
return parsedAnswer, suggestions, title, nil
7479
}
80+
81+
func isReasonModel(typ string) bool {
82+
typ = strings.ToLower(typ)
83+
if strings.Contains(typ, "r1") {
84+
return true
85+
} else if strings.Contains(typ, "reasoner") {
86+
return true
87+
}
88+
return false
89+
}
90+
91+
func getResultWithSuggestionsAndTitle(modelResult *model.ModelResult, writer *CarrierWriter, question string, modelProviderObj model.ModelProvider, needTitle bool, suggestionCount int) (*model.ModelResult, error) {
92+
var fullPrompt strings.Builder
93+
94+
fullPrompt.WriteString(fmt.Sprintf("User question: %s\n\n", question))
95+
if suggestionCount > 0 {
96+
divider := "|||"
97+
suggestionPrompt := fmt.Sprintf(`**Based on the user question, generate %d possible follow-up questions. No need to answer user quesion.
98+
They must:
99+
- Be in the same language as the original question.
100+
- Start with the separator "%s".
101+
- Be separated by "%s" without any other formatting or explanation.
102+
- Do not include any explanation, analysis, or answers—only output the %d questions.
103+
104+
`, suggestionCount, divider, divider, suggestionCount)
105+
fullPrompt.WriteString(suggestionPrompt)
106+
}
107+
if needTitle {
108+
fullPrompt.WriteString(`
109+
**Finally, generate a concise and meaningful title for the original question. No need to answer user quesion.
110+
111+
- The title must be in the same language.
112+
- The title must start with "=====" (five equals signs, no space).
113+
- Do not include the divider or title if a meaningful title cannot be generated.
114+
- Do NOT include any explanations or extra text—just output the title.`)
115+
}
116+
117+
carrierResult, err := modelProviderObj.QueryText(fullPrompt.String(), writer, nil, "", nil, nil)
118+
if err != nil {
119+
return nil, err
120+
}
121+
122+
return carrierResult, nil
123+
}
124+
125+
func QueryCarrierText(question string, writer *RefinedWriter, history []*model.RawMessage, prompt string, knowledge []*model.RawMessage, modelProviderObj model.ModelProvider, needTitle bool, suggestionCount int) (*model.ModelResult, error) {
126+
var (
127+
wg sync.WaitGroup
128+
mainErr error
129+
carrierErr error
130+
)
131+
132+
var modelResult *model.ModelResult
133+
134+
wg.Add(1)
135+
go func() {
136+
defer wg.Done()
137+
var err error
138+
modelResult, err = modelProviderObj.QueryText(question, writer, history, prompt, knowledge, nil)
139+
if err != nil {
140+
mainErr = err
141+
}
142+
}()
143+
144+
CarrierWriter := &CarrierWriter{*NewCleaner(6), []byte{}}
145+
var carrierResult *model.ModelResult
146+
147+
wg.Add(1)
148+
go func() {
149+
defer wg.Done()
150+
var err error
151+
carrierResult, err = getResultWithSuggestionsAndTitle(modelResult, CarrierWriter, question, modelProviderObj, needTitle, suggestionCount)
152+
if err != nil {
153+
carrierErr = err
154+
}
155+
}()
156+
157+
wg.Wait()
158+
159+
if mainErr != nil {
160+
return nil, mainErr
161+
}
162+
if carrierErr != nil {
163+
return nil, carrierErr
164+
}
165+
166+
modelResult.PromptTokenCount += carrierResult.PromptTokenCount
167+
modelResult.ResponseTokenCount += carrierResult.ResponseTokenCount
168+
modelResult.TotalPrice += carrierResult.TotalPrice
169+
modelResult.TotalTokenCount += carrierResult.TotalTokenCount
170+
171+
writer.Write(CarrierWriter.messageBuf)
172+
173+
return modelResult, nil
174+
}

0 commit comments

Comments
 (0)