Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 48 additions & 11 deletions chats.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"io"
"iter"
"log"
"strings"
)

// Chats provides util functions for creating a new chat session.
Expand Down Expand Up @@ -56,12 +57,9 @@ func (c *Chats) Create(ctx context.Context, model string, config *GenerateConten
return chat, nil
}

func (c *Chat) recordHistory(ctx context.Context, inputContent *Content, outputContents []*Content) {
func (c *Chat) recordHistory(ctx context.Context, inputContent *Content, outputContent *Content) {
c.comprehensiveHistory = append(c.comprehensiveHistory, inputContent)

for _, outputContent := range outputContents {
c.comprehensiveHistory = append(c.comprehensiveHistory, copySanitizedModelContent(outputContent))
}
c.comprehensiveHistory = append(c.comprehensiveHistory, copySanitizedModelContent(outputContent))
}

// copySanitizedModelContent creates a (shallow) copy of modelContent with role set to
Expand Down Expand Up @@ -105,11 +103,9 @@ func (c *Chat) Send(ctx context.Context, parts ...*Part) (*GenerateContentRespon
}

// Record history. By default, use the first candidate for history.
var outputContents []*Content
if len(modelOutput.Candidates) > 0 && modelOutput.Candidates[0].Content != nil {
outputContents = append(outputContents, modelOutput.Candidates[0].Content)
c.recordHistory(ctx, inputContent, modelOutput.Candidates[0].Content)
}
c.recordHistory(ctx, inputContent, outputContents)

return modelOutput, err
}
Expand All @@ -136,7 +132,7 @@ func (c *Chat) SendStream(ctx context.Context, parts ...*Part) iter.Seq2[*Genera

// Return a new iterator that will yield the responses and record history with merged response.
return func(yield func(*GenerateContentResponse, error) bool) {
var outputContents []*Content
outputContent := &Content{}
for chunk, err := range response {
if err == io.EOF {
break
Expand All @@ -146,13 +142,54 @@ func (c *Chat) SendStream(ctx context.Context, parts ...*Part) iter.Seq2[*Genera
return
}
if len(chunk.Candidates) > 0 && chunk.Candidates[0].Content != nil {
outputContents = append(outputContents, chunk.Candidates[0].Content)
outputContent = joinContent(outputContent, chunk.Candidates[0].Content)
}
if !yield(chunk, nil) {
return
}
}
// Record history. By default, use the first candidate for history.
c.recordHistory(ctx, inputContent, outputContents)
c.recordHistory(ctx, inputContent, outputContent)
}
}

func joinContent(dest, src *Content) *Content {
if dest == nil {
return src
}
if src == nil {
return dest
}
// Assume roles are the same.
dest.Parts = joinParts(dest.Parts, src.Parts)
return dest
}

func joinParts(dest, src []*Part) []*Part {
return mergeTexts(append(dest, src...))
}

func mergeTexts(in []*Part) []*Part {
var out []*Part
i := 0
for i < len(in) {
if in[i].Text != "" {
texts := []string{in[i].Text}
var j int
for j = i + 1; j < len(in); j++ {
if in[j].Text != "" {
texts = append(texts, in[j].Text)
} else {
break
}
}
// j is just after the last Text.
out = append(out, NewPartFromText(strings.Join(texts, "")))
i = j
} else {
out = append(out, in[i])
i++
}
}
return out
}
47 changes: 28 additions & 19 deletions chats_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"testing"

"cloud.google.com/go/auth"
"github.com/google/go-cmp/cmp"
)

func TestChatsUnitTest(t *testing.T) {
Expand Down Expand Up @@ -413,6 +414,8 @@ data:{
}

part := Part{Text: "What is 1 + 2?"}
expectedResponses := []string{"1 + ", "2", " = 3"}
i := 0

for result, err := range chat.SendMessageStream(ctx, part) {
if err != nil {
Expand All @@ -421,19 +424,25 @@ data:{
if result.Text() == "" {
t.Errorf("Response text should not be empty")
}
if result.Text() != expectedResponses[i] {
t.Errorf("Expected response to be %s, got %s", expectedResponses[i], result.Text())
}
i++
}

expectedResponses := []string{"1 + ", "2", " = 3"}
history := chat.History(false)
expectedUserMessage := "What is 1 + 2?"
if history[0].Parts[0].Text != expectedUserMessage {
t.Errorf("Expected history to start with %s, got %s", expectedUserMessage, history[0].Parts[0].Text)
}
for i, expectedResponse := range expectedResponses {
gotResponse := history[i+1].Parts[0].Text
if gotResponse != expectedResponse {
t.Errorf("Expected model response to be %s, got %s", expectedResponse, gotResponse)
}
// for i, expectedResponse := range expectedResponses {
// gotResponse := history[i+1].Parts[0].Text
// if gotResponse != expectedResponse {
// t.Errorf("Expected model response to be %s, got %s", expectedResponse, gotResponse)
// }
// }
if history[1].Parts[0].Text != "1 + 2 = 3" {
t.Errorf("Expected history to be %s, got %s", "1 + 2 = 3", history[1].Parts[0].Text)
}
})
}
Expand Down Expand Up @@ -502,29 +511,29 @@ data:{

part := Part{Text: "What is 1 + 2?"}

for _, err := range chat.SendMessageStream(ctx, part) {
var expectedContents []*Content
expectedContents = append(expectedContents, &Content{Role: "model", Parts: []*Part{&Part{Text: "text1_candidate1"}}})
expectedContents = append(expectedContents, &Content{Role: "model", Parts: []*Part{&Part{Text: " "}}})
expectedContents = append(expectedContents, &Content{Role: "model", Parts: []*Part{&Part{Text: "text3_candidate1"}, &Part{Text: " additional text3_candidate1 "}}})
expectedContents = append(expectedContents, &Content{Role: "model", Parts: []*Part{&Part{Text: "text4_candidate1"}, &Part{Text: " additional text4_candidate1"}}})
i := 0
for resp, err := range chat.SendMessageStream(ctx, part) {
if err != nil {
log.Fatal(err)
}
if diff := cmp.Diff(resp.Candidates[0].Content, expectedContents[i]); diff != "" {
t.Errorf("Content mismatch (-want +got):\n%s", diff)
}
i++
}

var expectedResponses []*Content
expectedResponses = append(expectedResponses, &Content{Role: "model", Parts: []*Part{&Part{Text: "text1_candidate1"}}})
expectedResponses = append(expectedResponses, &Content{Role: "model", Parts: []*Part{&Part{Text: " "}}})
expectedResponses = append(expectedResponses, &Content{Role: "model", Parts: []*Part{&Part{Text: "text3_candidate1"}, &Part{Text: " additional text3_candidate1 "}}})
expectedResponses = append(expectedResponses, &Content{Role: "model", Parts: []*Part{&Part{Text: "text4_candidate1"}, &Part{Text: " additional text4_candidate1"}}})

history := chat.History(false)
expectedUserMessage := "What is 1 + 2?"
if history[0].Parts[0].Text != expectedUserMessage {
t.Errorf("Expected history to start with %s, got %s", expectedUserMessage, history[0].Parts[0].Text)
}
for i, expectedResponse := range expectedResponses {
for j, expectedPart := range history[i+1].Parts {
if expectedPart.Text != expectedResponse.Parts[j].Text {
t.Errorf("Expected model response to be %s, got %s", expectedResponse.Parts[j].Text, part.Text)
}
}
if history[1].Parts[0].Text != "text1_candidate1 text3_candidate1 additional text3_candidate1 text4_candidate1 additional text4_candidate1" {
t.Errorf("Expected history to be %s, got %s", "text1_candidate1", history[1].Parts[0].Text)
}

})
Expand Down
Loading