
Commit dcbce13

Add support to dump agent episode trajectories

1 parent: 50662f7

12 files changed: 238 additions, 1 deletion

gollm/azopenai.go (4 additions, 0 deletions)

@@ -275,6 +275,10 @@ func (c *AzureOpenAIChat) Initialize(messages []*api.Message) error {
     return nil
 }
 
+func (c *AzureOpenAIChat) SaveMessages(path string) error {
+    return fmt.Errorf("SaveMessages is not implemented for AzureOpenAI")
+}
+
 func (c *AzureOpenAIChat) SendStreaming(ctx context.Context, contents ...any) (ChatResponseIterator, error) {
     // TODO: Implement streaming
     response, err := c.Send(ctx, contents...)

gollm/bedrock.go (4 additions, 0 deletions)

@@ -527,6 +527,10 @@ func (c *bedrockChat) IsRetryableError(err error) bool {
     return DefaultIsRetryableError(err)
 }
 
+func (c *bedrockChat) SaveMessages(path string) error {
+    return fmt.Errorf("SaveMessages is not implemented for Bedrock")
+}
+
 // bedrockResponse implements ChatResponse for regular (non-streaming) responses
 type bedrockResponse struct {
     output *bedrockruntime.ConverseOutput

gollm/factory.go (4 additions, 0 deletions)

@@ -346,3 +346,7 @@ func (rc *retryChat[C]) IsRetryableError(err error) bool {
 func (rc *retryChat[C]) Initialize(messages []*api.Message) error {
     return rc.underlying.Initialize(messages)
 }
+
+func (rc *retryChat[C]) SaveMessages(path string) error {
+    return rc.underlying.SaveMessages(path)
+}

gollm/gemini.go (178 additions, 0 deletions)

@@ -277,6 +277,7 @@ func (c *GoogleAIClient) StartChat(systemPrompt string, model string) Chat {
             TopP:             &topP,
             MaxOutputTokens:  maxOutputTokens,
             ResponseMIMEType: "text/plain",
+            ThinkingConfig:   &genai.ThinkingConfig{IncludeThoughts: true},
         },
         history: []*genai.Content{},
     }
@@ -502,6 +503,183 @@ func (c *GeminiChat) Initialize(messages []*api.Message) error {
     return nil
 }
 
+func (c *GeminiChat) SaveMessages(path string) error {
+    klog.Infof("Saving messages to %s", path)
+    f, err := os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
+    if err != nil {
+        return fmt.Errorf("failed to open history file: %w", err)
+    }
+    defer f.Close()
+
+    // Serialize the full message history (c.history is []*genai.Content) and
+    // append it to the file as a single JSON object per line.
+    bytes, err := json.Marshal(c.history)
+    if err != nil {
+        return fmt.Errorf("failed to marshal history: %w", err)
+    }
+
+    // Unmarshal into a generic structure so we can sanitize it
+    // (strip thoughtSignature and coalesce fragmented messages and parts).
+    var rawHistory []map[string]interface{}
+    if err := json.Unmarshal(bytes, &rawHistory); err != nil {
+        return fmt.Errorf("failed to unmarshal history for cleaning: %w", err)
+    }
+
+    // Recursively remove thoughtSignature
+    removeKey(rawHistory, "thoughtSignature")
+
+    // Coalesce history to merge consecutive messages and parts
+    cleanHistory := coalesceHistory(rawHistory)
+
+    cleanBytes, err := json.Marshal(cleanHistory)
+    if err != nil {
+        return fmt.Errorf("failed to marshal clean history: %w", err)
+    }
+
+    if _, err := f.Write(cleanBytes); err != nil {
+        return fmt.Errorf("failed to write history to file: %w", err)
+    }
+    if _, err := f.WriteString("\n"); err != nil {
+        return fmt.Errorf("failed to write newline to file: %w", err)
+    }
+    return nil
+}
+
+// coalesceHistory merges consecutive messages that share the same role,
+// then coalesces the parts within each merged message.
+func coalesceHistory(history []map[string]interface{}) []map[string]interface{} {
+    if len(history) == 0 {
+        return history
+    }
+
+    var coalesced []map[string]interface{}
+    var currentMsg map[string]interface{}
+
+    for _, msg := range history {
+        if currentMsg == nil {
+            currentMsg = msg
+            continue
+        }
+
+        // Merge into currentMsg if the roles match.
+        currRole, _ := currentMsg["role"].(string)
+        nextRole, _ := msg["role"].(string)
+
+        if currRole == nextRole && currRole != "" {
+            // Merge parts
+            currParts, ok1 := currentMsg["parts"].([]interface{})
+            nextParts, ok2 := msg["parts"].([]interface{})
+            if ok1 && ok2 {
+                currentMsg["parts"] = append(currParts, nextParts...)
+            }
+        } else {
+            coalesced = append(coalesced, currentMsg)
+            currentMsg = msg
+        }
+    }
+    if currentMsg != nil {
+        coalesced = append(coalesced, currentMsg)
+    }
+
+    // Now coalesce parts within each message
+    for _, msg := range coalesced {
+        if parts, ok := msg["parts"].([]interface{}); ok {
+            msg["parts"] = coalesceParts(parts)
+        }
+    }
+
+    return coalesced
+}
+
+// coalesceParts merges consecutive text parts whose thought status matches;
+// function calls and other non-text parts are kept as distinct entries.
+func coalesceParts(parts []interface{}) []interface{} {
+    if len(parts) == 0 {
+        return parts
+    }
+
+    var coalesced []interface{}
+    var currentPart map[string]interface{}
+
+    for _, p := range parts {
+        part, ok := p.(map[string]interface{})
+        if !ok {
+            // This part is not a map: flush the current part and append this one as-is.
+            if currentPart != nil {
+                coalesced = append(coalesced, currentPart)
+                currentPart = nil
+            }
+            coalesced = append(coalesced, p)
+            continue
+        }
+
+        if currentPart == nil {
+            currentPart = part
+            continue
+        }
+
+        // Merge only when both parts are text and agree on whether they are thoughts.
+        isText1, text1 := getText(currentPart)
+        isText2, text2 := getText(part)
+
+        isThought1 := isThought(currentPart)
+        isThought2 := isThought(part)
+
+        if isText1 && isText2 && isThought1 == isThought2 {
+            // Merge text
+            currentPart["text"] = text1 + text2
+            continue
+        }
+
+        // If not mergeable, append current and start new
+        coalesced = append(coalesced, currentPart)
+        currentPart = part
+    }
+    if currentPart != nil {
+        coalesced = append(coalesced, currentPart)
+    }
+
+    return coalesced
+}
+
+// getText reports whether the part is a text part and returns its text.
+func getText(part map[string]interface{}) (bool, string) {
+    if t, ok := part["text"].(string); ok {
+        return true, t
+    }
+    return false, ""
+}
+
+// isThought reports whether the part is marked with "thought": true.
+func isThought(part map[string]interface{}) bool {
+    if v, ok := part["thought"].(bool); ok {
+        return v
+    }
+    return false
+}
+
+// removeKey recursively deletes key from maps nested anywhere inside v.
+func removeKey(v interface{}, key string) {
+    switch v := v.(type) {
+    case []map[string]interface{}:
+        for _, m := range v {
+            removeKey(m, key)
+        }
+    case map[string]interface{}:
+        delete(v, key)
+        for _, val := range v {
+            removeKey(val, key)
+        }
+    case []interface{}:
+        for _, val := range v {
+            removeKey(val, key)
+        }
+    }
+}
+
 func (c *GeminiChat) messageToContent(msg *api.Message) (*genai.Content, error) {
     var role string
     switch msg.Source {
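
Each call to SaveMessages above opens the file in append mode and writes the entire sanitized, coalesced history as one JSON array followed by a newline, so the dump is effectively a JSONL file in which every line is a snapshot of the trajectory up to that point. A minimal reader-side sketch (not part of this commit; the file name is a placeholder) could look like this:

package main

import (
    "bufio"
    "encoding/json"
    "fmt"
    "os"
)

func main() {
    // Placeholder path; SaveMessages appends one JSON-encoded history snapshot per line.
    f, err := os.Open("episode.jsonl")
    if err != nil {
        panic(err)
    }
    defer f.Close()

    scanner := bufio.NewScanner(f)
    // Trajectory snapshots can exceed the scanner's default 64 KiB token limit.
    scanner.Buffer(make([]byte, 0, 1024*1024), 16*1024*1024)
    for i := 1; scanner.Scan(); i++ {
        // Each line is a JSON array of {role, parts} message objects.
        var history []map[string]interface{}
        if err := json.Unmarshal(scanner.Bytes(), &history); err != nil {
            panic(err)
        }
        fmt.Printf("snapshot %d: %d messages\n", i, len(history))
    }
    if err := scanner.Err(); err != nil {
        panic(err)
    }
}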

gollm/grok.go (4 additions, 0 deletions)

@@ -385,6 +385,10 @@ func (cs *grokChatSession) Initialize(messages []*api.Message) error {
     return nil
 }
 
+func (cs *grokChatSession) SaveMessages(path string) error {
+    return fmt.Errorf("SaveMessages is not implemented for Grok")
+}
+
 // --- Helper structs for ChatResponse interface ---
 
 type grokChatResponse struct {

gollm/interfaces.go (3 additions, 0 deletions)

@@ -62,6 +62,9 @@ type Chat interface {
 
     // Initialize initializes the chat with a previous conversation history.
     Initialize(messages []*api.Message) error
+
+    // SaveMessages saves the conversation history to a file.
+    SaveMessages(path string) error
 }
 
 // CompletionRequest is a request to generate a completion for a given prompt.
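
Because SaveMessages is now part of the Chat interface, any code holding a Chat can dump the episode trajectory without knowing which provider backs it; providers that do not support it simply return an error. A minimal caller-side sketch (the helper, logging choice, and import path are illustrative and not part of this diff):

package agent

import (
    "k8s.io/klog/v2"

    // Placeholder module path for the gollm package shown in this commit.
    "example.com/yourmodule/gollm"
)

// dumpTrajectory appends the current conversation history to path after an
// agent turn. Providers without SaveMessages support return an error, which
// is only logged so the agent loop keeps running.
func dumpTrajectory(chat gollm.Chat, path string) {
    if err := chat.SaveMessages(path); err != nil {
        klog.Warningf("failed to save trajectory to %s: %v", path, err)
    }
}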

gollm/llamacpp.go (4 additions, 0 deletions)

@@ -300,6 +300,10 @@ func (c *LlamaCppChat) Initialize(messages []*api.Message) error {
     return nil
 }
 
+func (c *LlamaCppChat) SaveMessages(path string) error {
+    return fmt.Errorf("SaveMessages is not implemented for LlamaCpp")
+}
+
 func ptrTo[T any](t T) *T {
     return &t
 }

gollm/ollama.go (4 additions, 0 deletions)

@@ -214,6 +214,10 @@ func (c *OllamaChat) Initialize(messages []*kctlApi.Message) error {
     return nil
 }
 
+func (c *OllamaChat) SaveMessages(path string) error {
+    return fmt.Errorf("SaveMessages is not implemented for Ollama")
+}
+
 type OllamaChatResponse struct {
     candidates     []*OllamaCandidate
     ollamaResponse api.ChatResponse

gollm/openai.go (4 additions, 0 deletions)

@@ -469,6 +469,10 @@ func (cs *openAIChatSession) Initialize(messages []*api.Message) error {
     return nil
 }
 
+func (cs *openAIChatSession) SaveMessages(path string) error {
+    return fmt.Errorf("SaveMessages is not implemented for OpenAI")
+}
+
 // Helper structs for ChatResponse interface
 
 type openAIChatResponse struct {

gollm/openai_response.go (5 additions, 1 deletion)

@@ -145,10 +145,14 @@ func (cs *openAIResponseChatSession) IsRetryableError(err error) bool {
 }
 
 func (cs *openAIResponseChatSession) Initialize(messages []*api.Message) error {
-    klog.Warning("chat history persistence is not supported for provider 'openai', using in-memory chat history")
+    klog.Warning("chat history persistence is not supported for provider 'openai-responses', using in-memory chat history")
     return nil
 }
 
+func (cs *openAIResponseChatSession) SaveMessages(path string) error {
+    return fmt.Errorf("SaveMessages is not implemented for OpenAI Response API")
+}
+
 // Helper structs for ChatResponse interface
 type openAIResponseChatResponse struct {
     resp *responses.Response
