1414package cloudgda
1515
1616import (
17- "bytes"
1817 "context"
19- "encoding/json"
2018 "fmt"
21- "io"
22- "net/http"
2319
20+ geminidataanalytics "cloud.google.com/go/geminidataanalytics/apiv1beta"
21+ "cloud.google.com/go/geminidataanalytics/apiv1beta/geminidataanalyticspb"
2422 "github.com/goccy/go-yaml"
2523 "github.com/googleapis/genai-toolbox/internal/sources"
2624 "github.com/googleapis/genai-toolbox/internal/util"
2725 "go.opentelemetry.io/otel/trace"
2826 "golang.org/x/oauth2"
29- "golang.org/x/oauth2/google "
27+ "google. golang.org/api/option "
3028)
3129
3230const SourceType string = "cloud-gemini-data-analytics"
33- const Endpoint string = "https://geminidataanalytics.googleapis.com"
31+
32+ // NewDataChatClient can be overridden for testing.
33+ var NewDataChatClient = geminidataanalytics .NewDataChatClient
3434
3535// validate interface
3636var _ sources.SourceConfig = Config {}
@@ -67,38 +67,27 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
6767 return nil , fmt .Errorf ("error in User Agent retrieval: %s" , err )
6868 }
6969
70- var client * http.Client
71- if r .UseClientOAuth {
72- client = & http.Client {
73- Transport : util .NewUserAgentRoundTripper (ua , http .DefaultTransport ),
74- }
75- } else {
76- // Use Application Default Credentials
77- // Scope: "https://www.googleapis.com/auth/cloud-platform" is generally sufficient for GDA
78- creds , err := google .FindDefaultCredentials (ctx , "https://www.googleapis.com/auth/cloud-platform" )
79- if err != nil {
80- return nil , fmt .Errorf ("failed to find default credentials: %w" , err )
81- }
82- baseClient := oauth2 .NewClient (ctx , creds .TokenSource )
83- baseClient .Transport = util .NewUserAgentRoundTripper (ua , baseClient .Transport )
84- client = baseClient
85- }
86-
8770 s := & Source {
8871 Config : r ,
89- Client : client ,
90- BaseURL : Endpoint ,
9172 userAgent : ua ,
9273 }
74+
75+ if ! r .UseClientOAuth {
76+ client , err := NewDataChatClient (ctx , option .WithUserAgent (ua ))
77+ if err != nil {
78+ return nil , fmt .Errorf ("failed to create DataChatClient: %w" , err )
79+ }
80+ s .Client = client
81+ }
82+
9383 return s , nil
9484}
9585
9686var _ sources.Source = & Source {}
9787
9888type Source struct {
9989 Config
100- Client * http.Client
101- BaseURL string
90+ Client * geminidataanalytics.DataChatClient
10291 userAgent string
10392}
10493
@@ -114,63 +103,36 @@ func (s *Source) GetProjectID() string {
114103 return s .ProjectID
115104}
116105
117- func (s * Source ) GetBaseURL () string {
118- return s .BaseURL
119- }
120-
121- func (s * Source ) GetClient (ctx context.Context , accessToken string ) (* http.Client , error ) {
122- if s .UseClientOAuth {
123- if accessToken == "" {
124- return nil , fmt .Errorf ("client-side OAuth is enabled but no access token was provided" )
125- }
126- token := & oauth2.Token {AccessToken : accessToken }
127- baseClient := oauth2 .NewClient (ctx , oauth2 .StaticTokenSource (token ))
128- baseClient .Transport = util .NewUserAgentRoundTripper (s .userAgent , baseClient .Transport )
129- return baseClient , nil
130- }
131- return s .Client , nil
132- }
133-
134106func (s * Source ) UseClientAuthorization () bool {
135107 return s .UseClientOAuth
136108}
137109
138- func (s * Source ) RunQuery (ctx context.Context , tokenStr string , bodyBytes []byte ) (any , error ) {
139- // The API endpoint itself always uses the "global" location.
140- apiLocation := "global"
141- apiParent := fmt .Sprintf ("projects/%s/locations/%s" , s .GetProjectID (), apiLocation )
142- apiURL := fmt .Sprintf ("%s/v1beta/%s:queryData" , s .GetBaseURL (), apiParent )
143-
144- client , err := s .GetClient (ctx , tokenStr )
145- if err != nil {
146- return nil , fmt .Errorf ("failed to get HTTP client: %w" , err )
147- }
148-
149- req , err := http .NewRequestWithContext (ctx , http .MethodPost , apiURL , bytes .NewBuffer (bodyBytes ))
150- if err != nil {
151- return nil , fmt .Errorf ("failed to create request: %w" , err )
152- }
153- req .Header .Set ("Content-Type" , "application/json" )
110+ func (s * Source ) GetClient (ctx context.Context , tokenStr string ) (* geminidataanalytics.DataChatClient , func (), error ) {
111+ if s .UseClientOAuth {
112+ if tokenStr == "" {
113+ return nil , nil , fmt .Errorf ("client-side OAuth is enabled but no access token was provided" )
114+ }
115+ token := & oauth2.Token {AccessToken : tokenStr }
116+ opts := []option.ClientOption {
117+ option .WithUserAgent (s .userAgent ),
118+ option .WithTokenSource (oauth2 .StaticTokenSource (token )),
119+ }
154120
155- resp , err := client .Do (req )
156- if err != nil {
157- return nil , fmt .Errorf ("failed to execute request: %w" , err )
121+ client , err := NewDataChatClient (ctx , opts ... )
122+ if err != nil {
123+ return nil , nil , fmt .Errorf ("failed to create per-request DataChatClient: %w" , err )
124+ }
125+ return client , func () { client .Close () }, nil
158126 }
159- defer resp .Body .Close ()
127+ return s .Client , func () {}, nil
128+ }
160129
161- respBody , err := io .ReadAll (resp .Body )
130+ func (s * Source ) RunQuery (ctx context.Context , tokenStr string , req * geminidataanalyticspb.QueryDataRequest ) (* geminidataanalyticspb.QueryDataResponse , error ) {
131+ client , cleanup , err := s .GetClient (ctx , tokenStr )
162132 if err != nil {
163- return nil , fmt .Errorf ("failed to read response body: %w" , err )
164- }
165-
166- if resp .StatusCode != http .StatusOK {
167- return nil , fmt .Errorf ("API request failed with status %d: %s" , resp .StatusCode , string (respBody ))
168- }
169-
170- var result map [string ]any
171- if err := json .Unmarshal (respBody , & result ); err != nil {
172- return nil , fmt .Errorf ("failed to unmarshal response: %w" , err )
133+ return nil , err
173134 }
135+ defer cleanup ()
174136
175- return result , nil
137+ return client . QueryData ( ctx , req )
176138}
0 commit comments