Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 11 additions & 10 deletions api_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,12 @@ import (
"time"
)

const maxChunkSize = 8 * 1024 * 1024 // 8 MB chunk size
const maxRetryCount = 3
const initialRetryDelay = time.Second
const delayMultiplier = 2
const (
maxChunkSize = 8 * 1024 * 1024 // 8 MB chunk size
maxRetryCount = 3
initialRetryDelay = time.Second
delayMultiplier = 2
)

type apiClient struct {
clientConfig *ClientConfig
Expand Down Expand Up @@ -75,7 +77,6 @@ func sendStreamRequest[T responseStream[R], R any](ctx context.Context, ac *apiC

// sendRequest issues an API request and returns a map of the response contents.
func sendRequest(ctx context.Context, ac *apiClient, path string, method string, body map[string]any, httpOptions *HTTPOptions) (map[string]any, error) {

req, httpOptions, err := buildRequest(ctx, ac, path, body, method, httpOptions)
if err != nil {
return nil, err
Expand Down Expand Up @@ -435,7 +436,7 @@ func iterateResponseStream[R any](rs *responseStream[R], responseConverter func(
default:
var err error
if len(line) > 0 {
var respWithError = new(responseWithError)
respWithError := new(responseWithError)
// Stream chunk that doesn't matches error format.
if marshalErr := json.Unmarshal(line, respWithError); marshalErr != nil {
err = fmt.Errorf("iterateResponseStream: invalid stream chunk: %s:%s", string(prefix), string(data))
Expand Down Expand Up @@ -479,7 +480,7 @@ type responseWithError struct {
}

func newAPIError(resp *http.Response) error {
var respWithError = new(responseWithError)
respWithError := new(responseWithError)
body, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("newAPIError: error reading response body: %w. Response: %v", err, string(body))
Expand Down Expand Up @@ -561,7 +562,7 @@ func (ac *apiClient) uploadFile(ctx context.Context, r io.Reader, uploadURL stri
var offset int64 = 0
var resp *http.Response
var respBody map[string]any
var uploadCommand = "upload"
uploadCommand := "upload"

buffer := make([]byte, maxChunkSize)
for {
Expand All @@ -574,7 +575,7 @@ func (ac *apiClient) uploadFile(ctx context.Context, r io.Reader, uploadURL stri
} else if err != nil {
return nil, fmt.Errorf("Failed to read bytes from file at offset %d: %w. Bytes actually read: %d", offset, err, bytesRead)
}
for attempt := 0; attempt < maxRetryCount; attempt++ {
for attempt := range maxRetryCount {
patchedHTTPOptions, err := patchHTTPOptions(ac.clientConfig.HTTPOptions, *httpOptions)
if err != nil {
return nil, err
Expand Down Expand Up @@ -641,7 +642,7 @@ func (ac *apiClient) uploadFile(ctx context.Context, r io.Reader, uploadURL stri
return nil, fmt.Errorf("Failed to upload file: Upload status is not finalized")
}

var response = new(File)
response := new(File)
err := mapToStruct(respBody["file"].(map[string]any), &response)
if err != nil {
return nil, err
Expand Down
11 changes: 5 additions & 6 deletions api_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -658,11 +658,13 @@ func TestMapToStruct(t *testing.T) {
inputMap: map[string]any{
"role": "test-role",
"TokenIDs": []string{"123", "456"},
"Tokens": [][]byte{[]byte("token1"), []byte("token2")}},
"Tokens": [][]byte{[]byte("token1"), []byte("token2")},
},
wantValue: TokensInfo{
Role: "test-role",
TokenIDs: []int64{123, 456},
Tokens: [][]byte{[]byte("token1"), []byte("token2")}},
Tokens: [][]byte{[]byte("token1"), []byte("token2")},
},
},
{
name: "Citation",
Expand Down Expand Up @@ -705,7 +707,6 @@ func TestMapToStruct(t *testing.T) {
outputValue := reflect.New(reflect.TypeOf(tc.wantValue)).Interface()

err := mapToStruct(tc.inputMap, &outputValue)

if err != nil {
t.Fatalf("mapToStruct failed: %v", err)
}
Expand Down Expand Up @@ -1309,7 +1310,7 @@ func createTestFile(t *testing.T, size int64) (string, func()) {

buf := make([]byte, 1024*1024) // 1MB buffer
pattern := []byte("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!@#$%^&*()")
for i := 0; i < len(buf); i++ {
for i := range buf {
buf[i] = pattern[i%len(pattern)]
}

Expand Down Expand Up @@ -1504,7 +1505,6 @@ func TestUploadFile(t *testing.T) {
uploadURL := server.URL + "/upload"

uploadedFile, err := ac.uploadFile(ctx, fileReader, uploadURL, httpOpts)

if err != nil {
t.Fatalf("uploadFile failed: %v", err)
}
Expand All @@ -1530,7 +1530,6 @@ func TestUploadFile(t *testing.T) {
if uploadedFile.MIMEType != "text/plain" { // Matches mock server response
t.Errorf("uploadedFile.MIMEType mismatch: want 'text/plain', got '%s'", uploadedFile.MIMEType)
}

})
}
}
Expand Down
6 changes: 4 additions & 2 deletions base_url.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@

package genai

var defaultBaseGeminiURL string = ""
var defaultBaseVertexURL string = ""
var (
defaultBaseGeminiURL string = ""
defaultBaseVertexURL string = ""
)

// BaseURLParameters are parameters for setting the base URLs for the Gemini API and Vertex AI API.
type BaseURLParameters struct {
Expand Down
3 changes: 2 additions & 1 deletion caches_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,8 @@ func TestCachesAll(t *testing.T) {
defer ts.Close()

// Create a client with the test server
client, err := NewClient(context.Background(), &ClientConfig{HTTPOptions: HTTPOptions{BaseURL: ts.URL},
client, err := NewClient(context.Background(), &ClientConfig{
HTTPOptions: HTTPOptions{BaseURL: ts.URL},
envVarProvider: func() map[string]string {
return map[string]string{
"GOOGLE_API_KEY": "test-api-key",
Expand Down
24 changes: 11 additions & 13 deletions chats_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ func TestValidateContent(t *testing.T) {
{"NilContent", nil, false},
{"EmptyParts", &Content{Parts: []*Part{}}, false},
{"NilPart", &Content{Parts: []*Part{nil}}, false},
{"EmptyTextPart", &Content{Parts: []*Part{&Part{Text: ""}}}, false},
{"ValidTextPart", &Content{Parts: []*Part{&Part{Text: "hello"}}}, true},
{"ValidFunctionCall", &Content{Parts: []*Part{&Part{FunctionCall: &FunctionCall{Name: "test"}}}}, true},
{"EmptyTextPart", &Content{Parts: []*Part{{Text: ""}}}, false},
{"ValidTextPart", &Content{Parts: []*Part{{Text: "hello"}}}, true},
{"ValidFunctionCall", &Content{Parts: []*Part{{FunctionCall: &FunctionCall{Name: "test"}}}}, true},
}

for _, tt := range tests {
Expand Down Expand Up @@ -176,7 +176,6 @@ func TestChatsUnitTest(t *testing.T) {
break
}
})

}

func TestChatsText(t *testing.T) {
Expand Down Expand Up @@ -316,16 +315,16 @@ func TestChatsHistory(t *testing.T) {
// Create a new Chat with handwritten history.
var config *GenerateContentConfig = &GenerateContentConfig{Temperature: Ptr[float32](0.5)}
history := []*Content{
&Content{
{
Role: "user",
Parts: []*Part{
&Part{Text: "What is 1 + 2?"},
{Text: "What is 1 + 2?"},
},
},
&Content{
{
Role: "model",
Parts: []*Part{
&Part{Text: "3"},
{Text: "3"},
},
},
}
Expand Down Expand Up @@ -721,10 +720,10 @@ data:{
}

var expectedResponses []*Content
expectedResponses = append(expectedResponses, &Content{Role: "model", Parts: []*Part{&Part{Text: "text1_candidate1"}}})
expectedResponses = append(expectedResponses, &Content{Role: "model", Parts: []*Part{&Part{Text: " "}}})
expectedResponses = append(expectedResponses, &Content{Role: "model", Parts: []*Part{&Part{Text: "text3_candidate1"}, &Part{Text: " additional text3_candidate1 "}}})
expectedResponses = append(expectedResponses, &Content{Role: "model", Parts: []*Part{&Part{Text: "text4_candidate1"}, &Part{Text: " additional text4_candidate1"}}})
expectedResponses = append(expectedResponses, &Content{Role: "model", Parts: []*Part{{Text: "text1_candidate1"}}})
expectedResponses = append(expectedResponses, &Content{Role: "model", Parts: []*Part{{Text: " "}}})
expectedResponses = append(expectedResponses, &Content{Role: "model", Parts: []*Part{{Text: "text3_candidate1"}, {Text: " additional text3_candidate1 "}}})
expectedResponses = append(expectedResponses, &Content{Role: "model", Parts: []*Part{{Text: "text4_candidate1"}, {Text: " additional text4_candidate1"}}})

history := chat.History(false)
expectedUserMessage := "What is 1 + 2?"
Expand All @@ -738,6 +737,5 @@ data:{
}
}
}

})
}
Loading