Skip to content

Commit 7aef485

Browse files
Merge branch 'main' into looker-get-all-tests
2 parents 699037d + 8d05b4e commit 7aef485

12 files changed

Lines changed: 329 additions & 509 deletions

File tree

.hugo/hugo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,10 @@ ignoreFiles = ["quickstart/shared", "quickstart/python", "quickstart/js", "quick
5151
# Add a new version block here before every release
5252
# The order of versions in this file is mirrored into the dropdown
5353

54+
[[params.versions]]
55+
version = "v0.28.0"
56+
url = "https://googleapis.github.io/genai-toolbox/v0.28.0/"
57+
5458
[[params.versions]]
5559
version = "v0.27.0"
5660
url = "https://googleapis.github.io/genai-toolbox/v0.27.0/"

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ require (
1212
cloud.google.com/go/dataplex v1.28.0
1313
cloud.google.com/go/dataproc/v2 v2.15.0
1414
cloud.google.com/go/firestore v1.20.0
15-
cloud.google.com/go/geminidataanalytics v0.3.0
15+
cloud.google.com/go/geminidataanalytics v0.5.0
1616
cloud.google.com/go/logging v1.13.1
1717
cloud.google.com/go/longrunning v0.7.0
1818
cloud.google.com/go/spanner v1.86.1

go.sum

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -311,8 +311,8 @@ cloud.google.com/go/gaming v1.6.0/go.mod h1:YMU1GEvA39Qt3zWGyAVA9bpYz/yAhTvaQ1t2
311311
cloud.google.com/go/gaming v1.7.0/go.mod h1:LrB8U7MHdGgFG851iHAfqUdLcKBdQ55hzXy9xBJz0+w=
312312
cloud.google.com/go/gaming v1.8.0/go.mod h1:xAqjS8b7jAVW0KFYeRUxngo9My3f33kFmua++Pi+ggM=
313313
cloud.google.com/go/gaming v1.9.0/go.mod h1:Fc7kEmCObylSWLO334NcO+O9QMDyz+TKC4v1D7X+Bc0=
314-
cloud.google.com/go/geminidataanalytics v0.3.0 h1:2Wi/kqFb5OLuEGH7q+/miE19VTqK1MYHjBEHENap9HI=
315-
cloud.google.com/go/geminidataanalytics v0.3.0/go.mod h1:QRc0b6ywyc3Z7S3etFgslz7hippkW/jRvtops5rKqIg=
314+
cloud.google.com/go/geminidataanalytics v0.5.0 h1:+1usY81Cb+hE8BokpqCM7EgJtRCKzUKx7FvrHbT5hCA=
315+
cloud.google.com/go/geminidataanalytics v0.5.0/go.mod h1:QRc0b6ywyc3Z7S3etFgslz7hippkW/jRvtops5rKqIg=
316316
cloud.google.com/go/gkebackup v0.2.0/go.mod h1:XKvv/4LfG829/B8B7xRkk8zRrOEbKtEam6yNfuQNH60=
317317
cloud.google.com/go/gkebackup v0.3.0/go.mod h1:n/E671i1aOQvUxT541aTkCwExO/bTer2HDlj4TsBRAo=
318318
cloud.google.com/go/gkebackup v0.4.0/go.mod h1:byAyBGUwYGEEww7xsbnUTBHIYcOPy/PgUWUtOeRm9Vg=

internal/sources/cloudgda/cloud_gda.go

Lines changed: 38 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -14,23 +14,23 @@
1414
package cloudgda
1515

1616
import (
17-
"bytes"
1817
"context"
19-
"encoding/json"
2018
"fmt"
21-
"io"
22-
"net/http"
2319

20+
geminidataanalytics "cloud.google.com/go/geminidataanalytics/apiv1beta"
21+
"cloud.google.com/go/geminidataanalytics/apiv1beta/geminidataanalyticspb"
2422
"github.com/goccy/go-yaml"
2523
"github.com/googleapis/genai-toolbox/internal/sources"
2624
"github.com/googleapis/genai-toolbox/internal/util"
2725
"go.opentelemetry.io/otel/trace"
2826
"golang.org/x/oauth2"
29-
"golang.org/x/oauth2/google"
27+
"google.golang.org/api/option"
3028
)
3129

3230
const SourceType string = "cloud-gemini-data-analytics"
33-
const Endpoint string = "https://geminidataanalytics.googleapis.com"
31+
32+
// NewDataChatClient can be overridden for testing.
33+
var NewDataChatClient = geminidataanalytics.NewDataChatClient
3434

3535
// validate interface
3636
var _ sources.SourceConfig = Config{}
@@ -67,38 +67,27 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
6767
return nil, fmt.Errorf("error in User Agent retrieval: %s", err)
6868
}
6969

70-
var client *http.Client
71-
if r.UseClientOAuth {
72-
client = &http.Client{
73-
Transport: util.NewUserAgentRoundTripper(ua, http.DefaultTransport),
74-
}
75-
} else {
76-
// Use Application Default Credentials
77-
// Scope: "https://www.googleapis.com/auth/cloud-platform" is generally sufficient for GDA
78-
creds, err := google.FindDefaultCredentials(ctx, "https://www.googleapis.com/auth/cloud-platform")
79-
if err != nil {
80-
return nil, fmt.Errorf("failed to find default credentials: %w", err)
81-
}
82-
baseClient := oauth2.NewClient(ctx, creds.TokenSource)
83-
baseClient.Transport = util.NewUserAgentRoundTripper(ua, baseClient.Transport)
84-
client = baseClient
85-
}
86-
8770
s := &Source{
8871
Config: r,
89-
Client: client,
90-
BaseURL: Endpoint,
9172
userAgent: ua,
9273
}
74+
75+
if !r.UseClientOAuth {
76+
client, err := NewDataChatClient(ctx, option.WithUserAgent(ua))
77+
if err != nil {
78+
return nil, fmt.Errorf("failed to create DataChatClient: %w", err)
79+
}
80+
s.Client = client
81+
}
82+
9383
return s, nil
9484
}
9585

9686
var _ sources.Source = &Source{}
9787

9888
type Source struct {
9989
Config
100-
Client *http.Client
101-
BaseURL string
90+
Client *geminidataanalytics.DataChatClient
10291
userAgent string
10392
}
10493

@@ -114,63 +103,36 @@ func (s *Source) GetProjectID() string {
114103
return s.ProjectID
115104
}
116105

117-
func (s *Source) GetBaseURL() string {
118-
return s.BaseURL
119-
}
120-
121-
func (s *Source) GetClient(ctx context.Context, accessToken string) (*http.Client, error) {
122-
if s.UseClientOAuth {
123-
if accessToken == "" {
124-
return nil, fmt.Errorf("client-side OAuth is enabled but no access token was provided")
125-
}
126-
token := &oauth2.Token{AccessToken: accessToken}
127-
baseClient := oauth2.NewClient(ctx, oauth2.StaticTokenSource(token))
128-
baseClient.Transport = util.NewUserAgentRoundTripper(s.userAgent, baseClient.Transport)
129-
return baseClient, nil
130-
}
131-
return s.Client, nil
132-
}
133-
134106
func (s *Source) UseClientAuthorization() bool {
135107
return s.UseClientOAuth
136108
}
137109

138-
func (s *Source) RunQuery(ctx context.Context, tokenStr string, bodyBytes []byte) (any, error) {
139-
// The API endpoint itself always uses the "global" location.
140-
apiLocation := "global"
141-
apiParent := fmt.Sprintf("projects/%s/locations/%s", s.GetProjectID(), apiLocation)
142-
apiURL := fmt.Sprintf("%s/v1beta/%s:queryData", s.GetBaseURL(), apiParent)
143-
144-
client, err := s.GetClient(ctx, tokenStr)
145-
if err != nil {
146-
return nil, fmt.Errorf("failed to get HTTP client: %w", err)
147-
}
148-
149-
req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewBuffer(bodyBytes))
150-
if err != nil {
151-
return nil, fmt.Errorf("failed to create request: %w", err)
152-
}
153-
req.Header.Set("Content-Type", "application/json")
110+
func (s *Source) GetClient(ctx context.Context, tokenStr string) (*geminidataanalytics.DataChatClient, func(), error) {
111+
if s.UseClientOAuth {
112+
if tokenStr == "" {
113+
return nil, nil, fmt.Errorf("client-side OAuth is enabled but no access token was provided")
114+
}
115+
token := &oauth2.Token{AccessToken: tokenStr}
116+
opts := []option.ClientOption{
117+
option.WithUserAgent(s.userAgent),
118+
option.WithTokenSource(oauth2.StaticTokenSource(token)),
119+
}
154120

155-
resp, err := client.Do(req)
156-
if err != nil {
157-
return nil, fmt.Errorf("failed to execute request: %w", err)
121+
client, err := NewDataChatClient(ctx, opts...)
122+
if err != nil {
123+
return nil, nil, fmt.Errorf("failed to create per-request DataChatClient: %w", err)
124+
}
125+
return client, func() { client.Close() }, nil
158126
}
159-
defer resp.Body.Close()
127+
return s.Client, func() {}, nil
128+
}
160129

161-
respBody, err := io.ReadAll(resp.Body)
130+
func (s *Source) RunQuery(ctx context.Context, tokenStr string, req *geminidataanalyticspb.QueryDataRequest) (*geminidataanalyticspb.QueryDataResponse, error) {
131+
client, cleanup, err := s.GetClient(ctx, tokenStr)
162132
if err != nil {
163-
return nil, fmt.Errorf("failed to read response body: %w", err)
164-
}
165-
166-
if resp.StatusCode != http.StatusOK {
167-
return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(respBody))
168-
}
169-
170-
var result map[string]any
171-
if err := json.Unmarshal(respBody, &result); err != nil {
172-
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
133+
return nil, err
173134
}
135+
defer cleanup()
174136

175-
return result, nil
137+
return client.QueryData(ctx, req)
176138
}

internal/sources/cloudgda/cloud_gda_test.go

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -172,11 +172,9 @@ func TestInitialize(t *testing.T) {
172172
if gdaSrc.Client == nil && !tc.wantClientOAuth {
173173
t.Fatal("expected non-nil HTTP client for ADC, got nil")
174174
}
175-
// When client OAuth is true, the source's client should be initialized with a base HTTP client
176-
// that includes the user agent round tripper, but not the OAuth token. The token-aware
177-
// client is created by GetClient.
178-
if gdaSrc.Client == nil && tc.wantClientOAuth {
179-
t.Fatal("expected non-nil HTTP client for client OAuth config, got nil")
175+
// When client OAuth is true, the source's client should be nil.
176+
if gdaSrc.Client != nil && tc.wantClientOAuth {
177+
t.Fatal("expected nil HTTP client for client OAuth config, got non-nil")
180178
}
181179

182180
// Test UseClientAuthorization method
@@ -186,15 +184,16 @@ func TestInitialize(t *testing.T) {
186184

187185
// Test GetClient with accessToken for client OAuth scenarios
188186
if tc.wantClientOAuth {
189-
client, err := gdaSrc.GetClient(ctx, "dummy-token")
187+
client, cleanup, err := gdaSrc.GetClient(ctx, "dummy-token")
190188
if err != nil {
191189
t.Fatalf("GetClient with token failed: %v", err)
192190
}
191+
defer cleanup()
193192
if client == nil {
194193
t.Fatal("expected non-nil HTTP client from GetClient with token, got nil")
195194
}
196195
// Ensure passing empty token with UseClientOAuth enabled returns error
197-
_, err = gdaSrc.GetClient(ctx, "")
196+
_, _, err = gdaSrc.GetClient(ctx, "")
198197
if err == nil || err.Error() != "client-side OAuth is enabled but no access token was provided" {
199198
t.Errorf("expected 'client-side OAuth is enabled but no access token was provided' error, got: %v", err)
200199
}

internal/tools/bigquery/bigqueryconversationalanalytics/bigqueryconversationalanalytics.go

Lines changed: 31 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -482,30 +482,33 @@ func handleTextResponse(resp *TextResponse) map[string]any {
482482
}
483483

484484
func handleSchemaResponse(resp *SchemaResponse) map[string]any {
485+
res := make(map[string]any)
485486
if resp.Query != nil {
486-
return map[string]any{"Question": resp.Query.Question}
487+
res["Question"] = resp.Query.Question
487488
}
488489
if resp.Result != nil {
489490
var formattedSources []map[string]any
490491
for _, ds := range resp.Result.Datasources {
491492
formattedSources = append(formattedSources, formatDatasourceAsDict(&ds))
492493
}
493-
return map[string]any{"Schema Resolved": formattedSources}
494+
res["Schema Resolved"] = formattedSources
494495
}
495-
return nil
496+
if len(res) == 0 {
497+
return nil
498+
}
499+
return res
496500
}
497501

498502
func handleDataResponse(resp *DataResponse, maxRows int) map[string]any {
503+
res := make(map[string]any)
499504
if resp.Query != nil {
500-
return map[string]any{
501-
"Retrieval Query": map[string]any{
502-
"Query Name": resp.Query.Name,
503-
"Question": resp.Query.Question,
504-
},
505+
res["Retrieval Query"] = map[string]any{
506+
"Query Name": resp.Query.Name,
507+
"Question": resp.Query.Question,
505508
}
506509
}
507510
if resp.GeneratedSQL != "" {
508-
return map[string]any{"SQL Generated": resp.GeneratedSQL}
511+
res["SQL Generated"] = resp.GeneratedSQL
509512
}
510513
if resp.Result != nil {
511514
var headers []string
@@ -533,15 +536,16 @@ func handleDataResponse(resp *DataResponse, maxRows int) map[string]any {
533536
summary = fmt.Sprintf("Showing the first %d of %d total rows.", numRowsToDisplay, totalRows)
534537
}
535538

536-
return map[string]any{
537-
"Data Retrieved": map[string]any{
538-
"headers": headers,
539-
"rows": compactRows,
540-
"summary": summary,
541-
},
539+
res["Data Retrieved"] = map[string]any{
540+
"headers": headers,
541+
"rows": compactRows,
542+
"summary": summary,
542543
}
543544
}
544-
return nil
545+
if len(res) == 0 {
546+
return nil
547+
}
548+
return res
545549
}
546550

547551
func handleError(resp *ErrorResponse) map[string]any {
@@ -557,9 +561,17 @@ func appendMessage(messages []map[string]any, newMessage map[string]any) []map[s
557561
if newMessage == nil {
558562
return messages
559563
}
560-
if len(messages) > 0 {
561-
if _, ok := messages[len(messages)-1]["Data Retrieved"]; ok {
562-
messages = messages[:len(messages)-1]
564+
565+
if _, hasData := newMessage["Data Retrieved"]; hasData {
566+
// Only keep the last data result while preserving SQL and other metadata.
567+
for i := len(messages) - 1; i >= 0; i-- {
568+
if _, ok := messages[i]["Data Retrieved"]; ok {
569+
delete(messages[i], "Data Retrieved")
570+
if len(messages[i]) == 0 {
571+
messages = append(messages[:i], messages[i+1:]...)
572+
}
573+
break
574+
}
563575
}
564576
}
565577
return append(messages, newMessage)

0 commit comments

Comments
 (0)