diff --git a/cmd/internal/imports.go b/cmd/internal/imports.go index 66aa0e3dc8bb..126f9451a0a0 100644 --- a/cmd/internal/imports.go +++ b/cmd/internal/imports.go @@ -83,6 +83,9 @@ import ( _ "github.com/googleapis/genai-toolbox/internal/tools/cockroachdb/cockroachdblistschemas" _ "github.com/googleapis/genai-toolbox/internal/tools/cockroachdb/cockroachdblisttables" _ "github.com/googleapis/genai-toolbox/internal/tools/cockroachdb/cockroachdbsql" + _ "github.com/googleapis/genai-toolbox/internal/tools/conversationalanalytics/conversationalanalyticsaskdataagent" + _ "github.com/googleapis/genai-toolbox/internal/tools/conversationalanalytics/conversationalanalyticsgetdataagentinfo" + _ "github.com/googleapis/genai-toolbox/internal/tools/conversationalanalytics/conversationalanalyticslistaccessibledataagents" _ "github.com/googleapis/genai-toolbox/internal/tools/couchbase" _ "github.com/googleapis/genai-toolbox/internal/tools/dataform/dataformcompilelocal" _ "github.com/googleapis/genai-toolbox/internal/tools/dataplex/dataplexlookupcontext" diff --git a/docs/en/integrations/cloudgda/prebuilt-configs/_index.md b/docs/en/integrations/cloudgda/prebuilt-configs/_index.md new file mode 100644 index 000000000000..eb9da2e99643 --- /dev/null +++ b/docs/en/integrations/cloudgda/prebuilt-configs/_index.md @@ -0,0 +1,5 @@ +--- +title: "Prebuilt Configs" +type: docs +description: "Prebuilt configurations for Conversational Analytics to perform natural language data analysis via Data Agents." +--- diff --git a/docs/en/integrations/cloudgda/prebuilt-configs/conversational-analytics-with-data-agent.md b/docs/en/integrations/cloudgda/prebuilt-configs/conversational-analytics-with-data-agent.md new file mode 100644 index 000000000000..ab94b2a8ec3e --- /dev/null +++ b/docs/en/integrations/cloudgda/prebuilt-configs/conversational-analytics-with-data-agent.md @@ -0,0 +1,28 @@ +--- +title: "Conversational Analytics with Data Agent" +type: docs +description: "Details of the Conversational Analytics with Data Agent prebuilt configuration." +--- + +## Conversational Analytics with Data Agent + +* `--prebuilt` value: `conversational-analytics-with-data-agent` +* **Environment Variables:** + * `CLOUD_GDA_PROJECT`: The GCP project ID. + * `CLOUD_GDA_LOCATION`: (Optional) The location of the data agent (e.g., `us` or `eu`). Defaults to `global`. + * `CLOUD_GDA_USE_CLIENT_OAUTH`: (Optional) If `true`, forwards the client's + OAuth access token for authentication. Defaults to `false`. + * `CLOUD_GDA_MAX_RESULTS`: (Optional) The maximum number of rows + to return. Defaults to `50`. +* **Permissions:** + * **Gemini Data Analytics Stateless Chat User (Beta)** (`roles/geminidataanalytics.dataAgentStatelessUser`) to interact with the data agent. + * **BigQuery User** (`roles/bigquery.user`) and **BigQuery Data Viewer** (`roles/bigquery.dataViewer`) on the underlying datasets/tables to allow the data agent to execute queries. +* **Tools:** + * `ask_data_agent`: Use this tool to perform natural language data analysis, + get insights, or answer complex questions using pre-configured data + sources via a specific Data Agent. For more information on + required roles, API setup, and IAM configuration, see the setup and + authentication section of the [Conversational Analytics API + documentation](https://cloud.google.com/gemini/docs/conversational-analytics-api/overview). + * `get_data_agent_info`: Retrieve details about a specific data agent. + * `list_accessible_data_agents`: List data agents that are accessible. diff --git a/docs/en/integrations/cloudgda/tools/conversational-analytics-ask-data-agent.md b/docs/en/integrations/cloudgda/tools/conversational-analytics-ask-data-agent.md new file mode 100644 index 000000000000..96f2f7e36481 --- /dev/null +++ b/docs/en/integrations/cloudgda/tools/conversational-analytics-ask-data-agent.md @@ -0,0 +1,64 @@ +--- +title: "conversational-analytics-ask-data-agent" +type: docs +weight: 1 +description: > + A "conversational-analytics-ask-data-agent" tool allows conversational interaction with a Conversational Analytics source. +aliases: +- /resources/tools/conversational-analytics-ask-data-agent +--- + +## About + +A `conversational-analytics-ask-data-agent` tool allows you to ask questions about +your data in natural language. + +This function takes a user's question (which can include conversational history +for context) and references to a specific BigQuery Data Agent, and sends them to a +stateless conversational API. + +The API uses a GenAI agent to understand the question, generate and execute SQL +queries and Python code, and formulate an answer. This function returns a +detailed, sequential log of this entire process, which includes any generated +SQL or Python code, the data retrieved, and the final text answer. + +**Note**: This tool requires additional setup in your project. Please refer to +the official Conversational Analytics API +documentation +for instructions. + +It's compatible with the following sources: + +- cloud-gemini-data-analytics + +`conversational-analytics-ask-data-agent` accepts the following parameters: + +- **`user_query_with_context`:** The question to ask the agent, potentially + including conversation history for context. +- **`data_agent_id`:** The ID of the data agent to ask. + +## Example + +```yaml +tools: + ask_data_agent: + kind: conversational-analytics-ask-data-agent + source: my-conversational-analytics-source + location: global + maxResults: 50 + description: | + Perform natural language data analysis and get insights by interacting + with a specific BigQuery Data Agent. This tool allows for conversational + queries and provides detailed responses based on the agent's configured + data sources. +``` + +## Reference + +| **field** | **type** | **required** | **description** | +|-------------|:--------:|:------------:|----------------------------------------------------| +| kind | string | true | Must be "conversational-analytics-ask-data-agent". | +| source | string | true | Name of the source for chat. | +| description | string | true | Description of the tool that is passed to the LLM. | +| location | string | false | The Google Cloud location (default: "global"). | +| maxResults | integer | false | The maximum number of data rows to return in the tool's final response (default: 50). This only limits the amount of data included in the final tool return to prevent excessive token consumption, and does not affect the internal analytical process or intermediate steps. | \ No newline at end of file diff --git a/docs/en/integrations/cloudgda/tools/conversational-analytics-get-data-agent-info.md b/docs/en/integrations/cloudgda/tools/conversational-analytics-get-data-agent-info.md new file mode 100644 index 000000000000..62d89db229eb --- /dev/null +++ b/docs/en/integrations/cloudgda/tools/conversational-analytics-get-data-agent-info.md @@ -0,0 +1,43 @@ +--- +title: "conversational-analytics-get-data-agent-info" +type: docs +weight: 1 +description: > + A "conversational-analytics-get-data-agent-info" tool allows retrieving information about a specific Conversational Analytics data agent. +aliases: +- /resources/tools/conversational-analytics-get-data-agent-info +--- + +## About + +A `conversational-analytics-get-data-agent-info` tool allows you to retrieve +details about a specific data agent. + +It's compatible with the following sources: + +- cloud-gemini-data-analytics + +`conversational-analytics-get-data-agent-info` accepts the following parameters: + +- **`data_agent_id`:** The ID of the data agent to retrieve information for. + +## Example + +```yaml +tools: + get_agent_info: + kind: conversational-analytics-get-data-agent-info + source: my-conversational-analytics-source + location: global + description: | + Use this tool to get details about a specific data agent. +``` + +## Reference + +| **field** | **type** | **required** | **description** | +|-------------|:--------:|:------------:|----------------------------------------------------| +| kind | string | true | Must be "conversational-analytics-get-data-agent-info". | +| source | string | true | Name of the source. | +| description | string | true | Description of the tool that is passed to the LLM. | +| location | string | false | The Google Cloud location (default: "global"). | \ No newline at end of file diff --git a/docs/en/integrations/cloudgda/tools/conversational-analytics-list-accessible-data-agents.md b/docs/en/integrations/cloudgda/tools/conversational-analytics-list-accessible-data-agents.md new file mode 100644 index 000000000000..da0e6750be28 --- /dev/null +++ b/docs/en/integrations/cloudgda/tools/conversational-analytics-list-accessible-data-agents.md @@ -0,0 +1,41 @@ +--- +title: "conversational-analytics-list-accessible-data-agents" +type: docs +weight: 1 +description: > + A "conversational-analytics-list-accessible-data-agents" tool allows listing accessible Conversational Analytics data agents. +aliases: +- /resources/tools/conversational-analytics-list-accessible-data-agents +--- + +## About + +A `conversational-analytics-list-accessible-data-agents` tool allows you to list +data agents that are accessible. + +It's compatible with the following sources: + +- cloud-gemini-data-analytics + +`conversational-analytics-list-accessible-data-agents` does not accept any parameters. + +## Example + +```yaml +tools: + list_agents: + kind: conversational-analytics-list-accessible-data-agents + source: my-conversational-analytics-source + location: global + description: | + Use this tool to list available data agents. +``` + +## Reference + +| **field** | **type** | **required** | **description** | +|-------------|:--------:|:------------:|----------------------------------------------------| +| kind | string | true | Must be "conversational-analytics-list-accessible-data-agents". | +| source | string | true | Name of the source. | +| description | string | true | Description of the tool that is passed to the LLM. | +| location | string | false | The Google Cloud location (default: "global"). | \ No newline at end of file diff --git a/internal/prebuiltconfigs/prebuiltconfigs_test.go b/internal/prebuiltconfigs/prebuiltconfigs_test.go index 0fdfc8b4f6df..c62afb1a6061 100644 --- a/internal/prebuiltconfigs/prebuiltconfigs_test.go +++ b/internal/prebuiltconfigs/prebuiltconfigs_test.go @@ -26,6 +26,7 @@ var expectedToolSources = []string{ "alloydb-postgres-admin", "alloydb-postgres-observability", "alloydb-postgres", + "conversational-analytics-with-data-agent", "bigquery", "clickhouse", "cloud-healthcare", @@ -113,6 +114,7 @@ func TestGetPrebuiltTool(t *testing.T) { alloydb_observability_config := getOrFatal(t, "alloydb-postgres-observability") alloydb_config := getOrFatal(t, "alloydb-postgres") bigquery_config := getOrFatal(t, "bigquery") + conversational_analytics_config := getOrFatal(t, "conversational-analytics-with-data-agent") clickhouse_config := getOrFatal(t, "clickhouse") cloudsqlpg_observability_config := getOrFatal(t, "cloud-sql-postgres-observability") cloudsqlpg_config := getOrFatal(t, "cloud-sql-postgres") @@ -156,6 +158,9 @@ func TestGetPrebuiltTool(t *testing.T) { if len(bigquery_config) <= 0 { t.Fatalf("unexpected error: could not fetch bigquery prebuilt tools yaml") } + if len(conversational_analytics_config) <= 0 { + t.Fatalf("unexpected error: could not fetch bigquery conversational analytics prebuilt tools yaml") + } if len(clickhouse_config) <= 0 { t.Fatalf("unexpected error: could not fetch clickhouse prebuilt tools yaml") } diff --git a/internal/prebuiltconfigs/tools/conversational-analytics-with-data-agent.yaml b/internal/prebuiltconfigs/tools/conversational-analytics-with-data-agent.yaml new file mode 100644 index 000000000000..1d2e85fef7cf --- /dev/null +++ b/internal/prebuiltconfigs/tools/conversational-analytics-with-data-agent.yaml @@ -0,0 +1,53 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +sources: + conversational-analytics-source: + kind: "cloud-gemini-data-analytics" + projectId: ${CLOUD_GDA_PROJECT} + useClientOAuth: ${CLOUD_GDA_USE_CLIENT_OAUTH:false} + +tools: + list-accessible-data-agents: + kind: conversational-analytics-list-accessible-data-agents + source: conversational-analytics-source + location: ${CLOUD_GDA_LOCATION:global} + description: | + List all available Data Agents that can be used for + conversational analytics in the current project. + + get-data-agent-info: + kind: conversational-analytics-get-data-agent-info + source: conversational-analytics-source + location: ${CLOUD_GDA_LOCATION:global} + description: | + Retrieve detailed information about a specific Data Agent + using its ID. + + ask-data-agent: + kind: conversational-analytics-ask-data-agent + source: conversational-analytics-source + location: ${CLOUD_GDA_LOCATION:global} + maxResults: ${CLOUD_GDA_MAX_RESULTS:50} + description: | + Perform natural language data analysis and get insights by interacting + with a specific Data Agent. This tool allows for conversational + queries and provides detailed responses based on the agent's configured + data sources. + +toolsets: + conversational_analytics_tools: + - list-accessible-data-agents + - get-data-agent-info + - ask-data-agent diff --git a/internal/sources/cloudgda/cloud_gda.go b/internal/sources/cloudgda/cloud_gda.go index 4c977418c6a9..5073b2738a3b 100644 --- a/internal/sources/cloudgda/cloud_gda.go +++ b/internal/sources/cloudgda/cloud_gda.go @@ -24,10 +24,12 @@ import ( "github.com/googleapis/genai-toolbox/internal/util" "go.opentelemetry.io/otel/trace" "golang.org/x/oauth2" + "golang.org/x/oauth2/google" "google.golang.org/api/option" ) const SourceType string = "cloud-gemini-data-analytics" +const CloudPlatformScope string = "https://www.googleapis.com/auth/cloud-platform" // NewDataChatClient can be overridden for testing. var NewDataChatClient = geminidataanalytics.NewDataChatClient @@ -103,6 +105,18 @@ func (s *Source) GetProjectID() string { return s.ProjectID } +func (s *Source) GoogleCloudTokenSourceWithScope(ctx context.Context, scope string) (oauth2.TokenSource, error) { + if scope == "" { + scope = CloudPlatformScope + } + + creds, err := google.FindDefaultCredentials(ctx, scope) + if err != nil { + return nil, fmt.Errorf("failed to find default credentials: %w", err) + } + return creds.TokenSource, nil +} + func (s *Source) UseClientAuthorization() bool { return s.UseClientOAuth } diff --git a/internal/tools/conversationalanalytics/conversationalanalyticsaskdataagent/conversationalanalyticsaskdataagent.go b/internal/tools/conversationalanalytics/conversationalanalyticsaskdataagent/conversationalanalyticsaskdataagent.go new file mode 100644 index 000000000000..e55b7a4bfe80 --- /dev/null +++ b/internal/tools/conversationalanalytics/conversationalanalyticsaskdataagent/conversationalanalyticsaskdataagent.go @@ -0,0 +1,418 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package conversationalanalyticsaskdataagent + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" + + yaml "github.com/goccy/go-yaml" + "github.com/googleapis/genai-toolbox/internal/embeddingmodels" + "github.com/googleapis/genai-toolbox/internal/sources" + cloudgdads "github.com/googleapis/genai-toolbox/internal/sources/cloudgda" + "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" + "github.com/googleapis/genai-toolbox/internal/util/parameters" + "golang.org/x/oauth2" +) + +const resourceType string = "conversational-analytics-ask-data-agent" + +func init() { + if !tools.Register(resourceType, newConfig) { + panic(fmt.Sprintf("tool type %q already registered", resourceType)) + } +} + +func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) { + actual := Config{Name: name} + if err := decoder.DecodeContext(ctx, &actual); err != nil { + return nil, err + } + return actual, nil +} + +type compatibleSource interface { + GoogleCloudTokenSourceWithScope(ctx context.Context, scope string) (oauth2.TokenSource, error) + GetProjectID() string + UseClientAuthorization() bool +} + +// validate compatible sources are still compatible +var _ compatibleSource = &cloudgdads.Source{} + +var compatibleSources = [...]string{cloudgdads.SourceType} + +type BQTableReference struct { + ProjectID string `json:"projectId"` + DatasetID string `json:"datasetId"` + TableID string `json:"tableId"` +} + +// Structs for building the JSON payload +type UserMessage struct { + Text string `json:"text"` +} +type Message struct { + UserMessage UserMessage `json:"userMessage"` +} + +type DataAgentContext struct { + DataAgent string `json:"dataAgent"` +} + +type CAPayload struct { + Project string `json:"project"` + Messages []Message `json:"messages"` + DataAgentContext DataAgentContext `json:"dataAgentContext"` + ClientIdEnum string `json:"clientIdEnum"` +} + +type Config struct { + Name string `yaml:"name" validate:"required"` + Type string `yaml:"type" validate:"required"` + Source string `yaml:"source" validate:"required"` + Description string `yaml:"description" validate:"required"` + Location string `yaml:"location"` + MaxResults int `yaml:"maxResults"` + AuthRequired []string `yaml:"authRequired"` +} + +// validate interface +var _ tools.ToolConfig = Config{} + +func (cfg Config) ToolConfigType() string { + return resourceType +} + +func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { + // verify source exists + rawS, ok := srcs[cfg.Source] + if !ok { + return nil, fmt.Errorf("no source named %q configured", cfg.Source) + } + + // verify the source is compatible + _, ok = rawS.(compatibleSource) + if !ok { + return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", resourceType, compatibleSources) + } + + if cfg.Location == "" { + cfg.Location = "global" + } + if cfg.MaxResults <= 0 { + cfg.MaxResults = 50 + } + + dataAgentIdDescription := `The ID of the data agent to ask.` + userQueryParameter := parameters.NewStringParameter("user_query_with_context", "The question to ask the agent, potentially including conversation history for context.") + dataAgentIdParameter := parameters.NewStringParameter("data_agent_id", dataAgentIdDescription) + params := parameters.Parameters{dataAgentIdParameter, userQueryParameter} + mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil) + + // finish tool setup + t := Tool{ + Config: cfg, + Parameters: params, + manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, + } + return t, nil +} + +// validate interface +var _ tools.Tool = Tool{} + +type Tool struct { + Config + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest +} + +func (t Tool) ToConfig() tools.ToolConfig { + return t.Config +} + +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) + if err != nil { + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) + } + + var tokenStr string + + // Get credentials for the API call + if source.UseClientAuthorization() { + // Use client-side access token + if accessToken == "" { + return nil, util.NewClientServerError("tool is configured for client OAuth but no token was provided in the request header", http.StatusUnauthorized, nil) + } + tokenStr, err = accessToken.ParseBearerToken() + if err != nil { + return nil, util.NewClientServerError("error parsing access token", http.StatusUnauthorized, err) + } + } else { + // Get a token source for the Gemini Data Analytics API. + tokenSource, err := source.GoogleCloudTokenSourceWithScope(ctx, "") + if err != nil { + return nil, util.NewClientServerError("failed to get token source", http.StatusInternalServerError, err) + } + + // Use cloud-platform token source for Gemini Data Analytics API + if tokenSource == nil { + return nil, util.NewClientServerError("cloud-platform token source is missing", http.StatusInternalServerError, nil) + } + token, err := tokenSource.Token() + if err != nil { + return nil, util.NewClientServerError("failed to get token from cloud-platform token source", http.StatusInternalServerError, err) + } + tokenStr = token.AccessToken + } + + // Extract parameters from the map + mapParams := params.AsMap() + dataAgentId, _ := mapParams["data_agent_id"].(string) + userQuery, _ := mapParams["user_query_with_context"].(string) + + // Construct URL, headers, and payload + projectID := source.GetProjectID() + caURL := fmt.Sprintf("https://geminidataanalytics.googleapis.com/v1beta/projects/%s/locations/%s:chat", projectID, t.Location) + + headers := map[string]string{ + "Authorization": fmt.Sprintf("Bearer %s", tokenStr), + "Content-Type": "application/json", + "X-Goog-API-Client": util.GDAClientID, + } + + dataAgentName := fmt.Sprintf("projects/%s/locations/%s/dataAgents/%s", projectID, t.Location, dataAgentId) + + payload := CAPayload{ + Project: fmt.Sprintf("projects/%s", projectID), + Messages: []Message{{UserMessage: UserMessage{Text: userQuery}}}, + DataAgentContext: DataAgentContext{ + DataAgent: dataAgentName, + }, + ClientIdEnum: util.GDAClientID, + } + + // Call the streaming API + response, err := getStream(caURL, payload, headers, t.MaxResults) + if err != nil { + return nil, util.NewAgentError("failed to get response from conversational analytics API", err) + } + + return response, nil +} + +func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { + return parameters.EmbedParams(ctx, t.Parameters, paramValues, embeddingModelsMap, nil) +} + +func (t Tool) Manifest() tools.Manifest { + return t.manifest +} + +func (t Tool) McpManifest() tools.McpManifest { + return t.mcpManifest +} + +func (t Tool) Authorized(verifiedAuthServices []string) bool { + return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) +} + +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil +} + +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil +} + +func (t Tool) GetParameters() parameters.Parameters { + return t.Parameters +} + +func getStream(url string, payload CAPayload, headers map[string]string, maxRows int) (string, error) { + payloadBytes, err := json.Marshal(payload) + if err != nil { + return "", fmt.Errorf("failed to marshal payload: %w", err) + } + + req, err := http.NewRequest("POST", url, bytes.NewBuffer(payloadBytes)) + if err != nil { + return "", fmt.Errorf("failed to create request: %w", err) + } + for k, v := range headers { + req.Header.Set(k, v) + } + + client := &http.Client{Timeout: 330 * time.Second} + resp, err := client.Do(req) + if err != nil { + return "", fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return "", fmt.Errorf("API returned non-200 status: %d %s", resp.StatusCode, string(body)) + } + + var messages []map[string]any + decoder := json.NewDecoder(resp.Body) + dataMsgIdx := -1 + + // The response is a JSON array, so we read the opening bracket. + if _, err := decoder.Token(); err != nil { + if err == io.EOF { + return "", nil // Empty response is valid + } + return "", fmt.Errorf("error reading start of json array: %w", err) + } + + for decoder.More() { + var rawMsg json.RawMessage + if err := decoder.Decode(&rawMsg); err != nil { + if err == io.EOF { + break + } + return "", fmt.Errorf("error decoding raw message: %w", err) + } + + var msg map[string]any + if err := json.Unmarshal(rawMsg, &msg); err != nil { + return "", fmt.Errorf("error unmarshaling raw message: %w", err) + } + + var processedMsg map[string]any + if dataResult := extractDataResult(msg); dataResult != nil { + // 1. If it's a data result, format it. + processedMsg = formatDataRetrieved(dataResult, maxRows) + if dataMsgIdx >= 0 { + // Replace previous data with a placeholder. Intermediate data results in a + // stream are redundant and consume unnecessary tokens. + messages[dataMsgIdx] = map[string]any{"Data Retrieved": "Intermediate result omitted"} + } + dataMsgIdx = len(messages) + } else if sm, ok := msg["systemMessage"].(map[string]any); ok { + // 2. If it's a system message, unwrap it. + processedMsg = sm + } else { + // 3. Otherwise (e.g. error), pass it through raw. + processedMsg = msg + } + + if processedMsg != nil { + messages = append(messages, processedMsg) + } + } + + var acc strings.Builder + for i, msg := range messages { + jsonBytes, err := json.Marshal(msg) + if err != nil { + return "", fmt.Errorf("error marshalling message: %w", err) + } + acc.Write(jsonBytes) + if i < len(messages)-1 { + acc.WriteString("\n") + } + } + + return acc.String(), nil +} + +// extractDataResult attempts to find the result.data deep inside the generic map. +func extractDataResult(msg map[string]any) map[string]any { + sm, ok := msg["systemMessage"].(map[string]any) + if !ok { + return nil + } + data, ok := sm["data"].(map[string]any) + if !ok { + return nil + } + result, ok := data["result"].(map[string]any) + if !ok { + return nil + } + if _, hasData := result["data"].([]any); hasData { + return result + } + return nil +} + +// formatDataRetrieved transforms the raw result map into the simplified Toolbox format. +func formatDataRetrieved(result map[string]any, maxRows int) map[string]any { + rawData, _ := result["data"].([]any) + + var fields []any + if schema, ok := result["schema"].(map[string]any); ok { + if f, ok := schema["fields"].([]any); ok { + fields = f + } + } + + var headers []string + for _, f := range fields { + if fm, ok := f.(map[string]any); ok { + if name, ok := fm["name"].(string); ok { + headers = append(headers, name) + } + } + } + + totalRows := len(rawData) + numToDisplay := totalRows + if numToDisplay > maxRows { + numToDisplay = maxRows + } + + var rows [][]any + for _, r := range rawData[:numToDisplay] { + if rm, ok := r.(map[string]any); ok { + var row []any + for _, h := range headers { + row = append(row, rm[h]) + } + rows = append(rows, row) + } + } + + summary := fmt.Sprintf("Showing all %d rows.", totalRows) + if totalRows > maxRows { + summary = fmt.Sprintf("Showing the first %d of %d total rows.", numToDisplay, totalRows) + } + + return map[string]any{ + "Data Retrieved": map[string]any{ + "headers": headers, + "rows": rows, + "summary": summary, + }, + } +} diff --git a/internal/tools/conversationalanalytics/conversationalanalyticsaskdataagent/conversationalanalyticsaskdataagent_test.go b/internal/tools/conversationalanalytics/conversationalanalyticsaskdataagent/conversationalanalyticsaskdataagent_test.go new file mode 100644 index 000000000000..ab5421327838 --- /dev/null +++ b/internal/tools/conversationalanalytics/conversationalanalyticsaskdataagent/conversationalanalyticsaskdataagent_test.go @@ -0,0 +1,68 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package conversationalanalyticsaskdataagent_test + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/googleapis/genai-toolbox/internal/server" + "github.com/googleapis/genai-toolbox/internal/testutils" + "github.com/googleapis/genai-toolbox/internal/tools/conversationalanalytics/conversationalanalyticsaskdataagent" +) + +func TestParseFromYamlConversationalAnalyticsAskDataAgent(t *testing.T) { + ctx, err := testutils.ContextWithNewLogger() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + tcs := []struct { + desc string + in string + want server.ToolConfigs + }{ + { + desc: "basic example", + in: ` + kind: tool + name: example_tool + type: conversational-analytics-ask-data-agent + source: my-instance + description: some description + `, + want: server.ToolConfigs{ + "example_tool": conversationalanalyticsaskdataagent.Config{ + Name: "example_tool", + Type: "conversational-analytics-ask-data-agent", + Source: "my-instance", + Description: "some description", + AuthRequired: []string{}, + }, + }, + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + // Parse contents + _, _, _, got, _, _, err := server.UnmarshalResourceConfig(ctx, testutils.FormatYaml(tc.in)) + if err != nil { + t.Fatalf("unable to unmarshal: %s", err) + } + if diff := cmp.Diff(tc.want, got); diff != "" { + t.Fatalf("incorrect parse: diff %v", diff) + } + }) + } +} diff --git a/internal/tools/conversationalanalytics/conversationalanalyticsgetdataagentinfo/conversationalanalyticsgetdataagentinfo.go b/internal/tools/conversationalanalytics/conversationalanalyticsgetdataagentinfo/conversationalanalyticsgetdataagentinfo.go new file mode 100644 index 000000000000..875e93e439cc --- /dev/null +++ b/internal/tools/conversationalanalytics/conversationalanalyticsgetdataagentinfo/conversationalanalyticsgetdataagentinfo.go @@ -0,0 +1,224 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package conversationalanalyticsgetdataagentinfo + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + + yaml "github.com/goccy/go-yaml" + "github.com/googleapis/genai-toolbox/internal/embeddingmodels" + "github.com/googleapis/genai-toolbox/internal/sources" + cloudgdads "github.com/googleapis/genai-toolbox/internal/sources/cloudgda" + "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" + "github.com/googleapis/genai-toolbox/internal/util/parameters" + "golang.org/x/oauth2" +) + +const resourceType string = "conversational-analytics-get-data-agent-info" + +func init() { + if !tools.Register(resourceType, newConfig) { + panic(fmt.Sprintf("tool type %q already registered", resourceType)) + } +} + +func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) { + actual := Config{Name: name} + if err := decoder.DecodeContext(ctx, &actual); err != nil { + return nil, err + } + return actual, nil +} + +type compatibleSource interface { + GoogleCloudTokenSourceWithScope(ctx context.Context, scope string) (oauth2.TokenSource, error) + GetProjectID() string + UseClientAuthorization() bool +} + +// validate compatible sources are still compatible +var _ compatibleSource = &cloudgdads.Source{} + +var compatibleSources = [...]string{cloudgdads.SourceType} + +type Config struct { + Name string `yaml:"name" validate:"required"` + Type string `yaml:"type" validate:"required"` + Source string `yaml:"source" validate:"required"` + Description string `yaml:"description" validate:"required"` + Location string `yaml:"location"` + AuthRequired []string `yaml:"authRequired"` +} + +// validate interface +var _ tools.ToolConfig = Config{} + +func (cfg Config) ToolConfigType() string { + return resourceType +} + +func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { + // verify source exists + rawS, ok := srcs[cfg.Source] + if !ok { + return nil, fmt.Errorf("no source named %q configured", cfg.Source) + } + + // verify the source is compatible + _, ok = rawS.(compatibleSource) + if !ok { + return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", resourceType, compatibleSources) + } + + if cfg.Location == "" { + cfg.Location = "global" + } + + dataAgentIdParameter := parameters.NewStringParameter("data_agent_id", "The ID of the data agent to retrieve info for.") + params := parameters.Parameters{dataAgentIdParameter} + mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil) + + // finish tool setup + t := Tool{ + Config: cfg, + Parameters: params, + manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, + } + return t, nil +} + +// validate interface +var _ tools.Tool = Tool{} + +type Tool struct { + Config + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest +} + +func (t Tool) ToConfig() tools.ToolConfig { + return t.Config +} + +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) + if err != nil { + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) + } + + var tokenStr string + + // Get credentials for the API call + if source.UseClientAuthorization() { + // Use client-side access token + if accessToken == "" { + return nil, util.NewClientServerError("tool is configured for client OAuth but no token was provided in the request header", http.StatusUnauthorized, nil) + } + tokenStr, err = accessToken.ParseBearerToken() + if err != nil { + return nil, util.NewClientServerError("error parsing access token", http.StatusUnauthorized, err) + } + } else { + // Get a token source for the Gemini Data Analytics API. + tokenSource, err := source.GoogleCloudTokenSourceWithScope(ctx, "") + if err != nil { + return nil, util.NewClientServerError("failed to get token source", http.StatusInternalServerError, err) + } + + // Use cloud-platform token source for Gemini Data Analytics API + if tokenSource == nil { + return nil, util.NewClientServerError("cloud-platform token source is missing", http.StatusInternalServerError, nil) + } + token, err := tokenSource.Token() + if err != nil { + return nil, util.NewClientServerError("failed to get token from cloud-platform token source", http.StatusInternalServerError, err) + } + tokenStr = token.AccessToken + } + + // Extract parameters from the map + mapParams := params.AsMap() + dataAgentId, _ := mapParams["data_agent_id"].(string) + + // Construct URL + projectID := source.GetProjectID() + caURL := fmt.Sprintf("https://geminidataanalytics.googleapis.com/v1beta/projects/%s/locations/%s/dataAgents/%s", projectID, t.Location, url.PathEscape(dataAgentId)) + + req, err := http.NewRequest("GET", caURL, nil) + if err != nil { + return nil, util.NewClientServerError("failed to create request", http.StatusInternalServerError, err) + } + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", tokenStr)) + req.Header.Set("X-Goog-API-Client", util.GDAClientID) + + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return nil, util.NewClientServerError("failed to send request", http.StatusInternalServerError, err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, util.NewAgentError(fmt.Sprintf("API returned non-200 status: %d %s", resp.StatusCode, string(body)), nil) + } + + var result map[string]any + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, util.NewClientServerError("failed to decode response", http.StatusInternalServerError, err) + } + + return result, nil +} + +func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { + return parameters.EmbedParams(ctx, t.Parameters, paramValues, embeddingModelsMap, nil) +} + +func (t Tool) Manifest() tools.Manifest { + return t.manifest +} + +func (t Tool) McpManifest() tools.McpManifest { + return t.mcpManifest +} + +func (t Tool) Authorized(verifiedAuthServices []string) bool { + return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) +} + +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil +} + +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil +} + +func (t Tool) GetParameters() parameters.Parameters { + return t.Parameters +} diff --git a/internal/tools/conversationalanalytics/conversationalanalyticsgetdataagentinfo/conversationalanalyticsgetdataagentinfo_test.go b/internal/tools/conversationalanalytics/conversationalanalyticsgetdataagentinfo/conversationalanalyticsgetdataagentinfo_test.go new file mode 100644 index 000000000000..056671a2f15f --- /dev/null +++ b/internal/tools/conversationalanalytics/conversationalanalyticsgetdataagentinfo/conversationalanalyticsgetdataagentinfo_test.go @@ -0,0 +1,91 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package conversationalanalyticsgetdataagentinfo_test + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/googleapis/genai-toolbox/internal/server" + "github.com/googleapis/genai-toolbox/internal/testutils" + "github.com/googleapis/genai-toolbox/internal/tools/conversationalanalytics/conversationalanalyticsgetdataagentinfo" +) + +func TestParseFromYamlConversationalAnalyticsGetDataAgent(t *testing.T) { + ctx, err := testutils.ContextWithNewLogger() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + tcs := []struct { + desc string + in string + want server.ToolConfigs + }{ + { + desc: "basic example", + in: ` + kind: tool + name: example_tool + type: conversational-analytics-get-data-agent-info + source: my-instance + description: some description + `, + want: server.ToolConfigs{ + "example_tool": conversationalanalyticsgetdataagentinfo.Config{ + Name: "example_tool", + Type: "conversational-analytics-get-data-agent-info", + Source: "my-instance", + Description: "some description", + AuthRequired: []string{}, + }, + }, + }, + { + desc: "with auth required", + in: ` + kind: tool + name: example_tool + type: conversational-analytics-get-data-agent-info + source: my-instance + description: some description + authRequired: + - my-google-auth + `, + want: server.ToolConfigs{ + "example_tool": conversationalanalyticsgetdataagentinfo.Config{ + Name: "example_tool", + Type: "conversational-analytics-get-data-agent-info", + Source: "my-instance", + Description: "some description", + AuthRequired: []string{ + "my-google-auth", + }, + }, + }, + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + // Parse contents + _, _, _, got, _, _, err := server.UnmarshalResourceConfig(ctx, testutils.FormatYaml(tc.in)) + if err != nil { + t.Fatalf("unable to unmarshal: %s", err) + } + if diff := cmp.Diff(tc.want, got); diff != "" { + t.Fatalf("incorrect parse: diff %v", diff) + } + }) + } +} diff --git a/internal/tools/conversationalanalytics/conversationalanalyticslistaccessibledataagents/conversationalanalyticslistaccessibledataagents.go b/internal/tools/conversationalanalytics/conversationalanalyticslistaccessibledataagents/conversationalanalyticslistaccessibledataagents.go new file mode 100644 index 000000000000..d867e02f56f9 --- /dev/null +++ b/internal/tools/conversationalanalytics/conversationalanalyticslistaccessibledataagents/conversationalanalyticslistaccessibledataagents.go @@ -0,0 +1,218 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package conversationalanalyticslistaccessibledataagents + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + + yaml "github.com/goccy/go-yaml" + "github.com/googleapis/genai-toolbox/internal/embeddingmodels" + "github.com/googleapis/genai-toolbox/internal/sources" + cloudgdads "github.com/googleapis/genai-toolbox/internal/sources/cloudgda" + "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" + "github.com/googleapis/genai-toolbox/internal/util/parameters" + "golang.org/x/oauth2" +) + +const resourceType string = "conversational-analytics-list-accessible-data-agents" + +func init() { + if !tools.Register(resourceType, newConfig) { + panic(fmt.Sprintf("tool type %q already registered", resourceType)) + } +} + +func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) { + actual := Config{Name: name} + if err := decoder.DecodeContext(ctx, &actual); err != nil { + return nil, err + } + return actual, nil +} + +type compatibleSource interface { + GoogleCloudTokenSourceWithScope(ctx context.Context, scope string) (oauth2.TokenSource, error) + GetProjectID() string + UseClientAuthorization() bool +} + +// validate compatible sources are still compatible +var _ compatibleSource = &cloudgdads.Source{} + +var compatibleSources = [...]string{cloudgdads.SourceType} + +type Config struct { + Name string `yaml:"name" validate:"required"` + Type string `yaml:"type" validate:"required"` + Source string `yaml:"source" validate:"required"` + Description string `yaml:"description" validate:"required"` + Location string `yaml:"location"` + AuthRequired []string `yaml:"authRequired"` +} + +// validate interface +var _ tools.ToolConfig = Config{} + +func (cfg Config) ToolConfigType() string { + return resourceType +} + +func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { + // verify source exists + rawS, ok := srcs[cfg.Source] + if !ok { + return nil, fmt.Errorf("no source named %q configured", cfg.Source) + } + + // verify the source is compatible + _, ok = rawS.(compatibleSource) + if !ok { + return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", resourceType, compatibleSources) + } + + if cfg.Location == "" { + cfg.Location = "global" + } + + params := parameters.Parameters{} + mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil) + + // finish tool setup + t := Tool{ + Config: cfg, + Parameters: params, + manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, + } + return t, nil +} + +// validate interface +var _ tools.Tool = Tool{} + +type Tool struct { + Config + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest +} + +func (t Tool) ToConfig() tools.ToolConfig { + return t.Config +} + +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) + if err != nil { + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) + } + + var tokenStr string + + // Get credentials for the API call + if source.UseClientAuthorization() { + // Use client-side access token + if accessToken == "" { + return nil, util.NewClientServerError("tool is configured for client OAuth but no token was provided in the request header", http.StatusUnauthorized, nil) + } + tokenStr, err = accessToken.ParseBearerToken() + if err != nil { + return nil, util.NewClientServerError("error parsing access token", http.StatusUnauthorized, err) + } + } else { + // Get a token source for the Gemini Data Analytics API. + tokenSource, err := source.GoogleCloudTokenSourceWithScope(ctx, "") + if err != nil { + return nil, util.NewClientServerError("failed to get token source", http.StatusInternalServerError, err) + } + + // Use cloud-platform token source for Gemini Data Analytics API + if tokenSource == nil { + return nil, util.NewClientServerError("cloud-platform token source is missing", http.StatusInternalServerError, nil) + } + token, err := tokenSource.Token() + if err != nil { + return nil, util.NewClientServerError("failed to get token from cloud-platform token source", http.StatusInternalServerError, err) + } + tokenStr = token.AccessToken + } + + // Construct URL + projectID := source.GetProjectID() + caURL := fmt.Sprintf("https://geminidataanalytics.googleapis.com/v1beta/projects/%s/locations/%s/dataAgents:listAccessible", projectID, t.Location) + + req, err := http.NewRequest("GET", caURL, nil) + if err != nil { + return nil, util.NewClientServerError("failed to create request", http.StatusInternalServerError, err) + } + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", tokenStr)) + req.Header.Set("X-Goog-API-Client", util.GDAClientID) + + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return nil, util.NewClientServerError("failed to send request", http.StatusInternalServerError, err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, util.NewAgentError(fmt.Sprintf("API returned non-200 status: %d %s", resp.StatusCode, string(body)), nil) + } + + var result any + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, util.NewClientServerError("failed to decode response", http.StatusInternalServerError, err) + } + + return result, nil +} + +func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { + return parameters.EmbedParams(ctx, t.Parameters, paramValues, embeddingModelsMap, nil) +} + +func (t Tool) Manifest() tools.Manifest { + return t.manifest +} + +func (t Tool) McpManifest() tools.McpManifest { + return t.mcpManifest +} + +func (t Tool) Authorized(verifiedAuthServices []string) bool { + return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) +} + +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil +} + +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil +} + +func (t Tool) GetParameters() parameters.Parameters { + return t.Parameters +} diff --git a/internal/tools/conversationalanalytics/conversationalanalyticslistaccessibledataagents/conversationalanalyticslistaccessibledataagents_test.go b/internal/tools/conversationalanalytics/conversationalanalyticslistaccessibledataagents/conversationalanalyticslistaccessibledataagents_test.go new file mode 100644 index 000000000000..a583f40a5611 --- /dev/null +++ b/internal/tools/conversationalanalytics/conversationalanalyticslistaccessibledataagents/conversationalanalyticslistaccessibledataagents_test.go @@ -0,0 +1,91 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package conversationalanalyticslistaccessibledataagents_test + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/googleapis/genai-toolbox/internal/server" + "github.com/googleapis/genai-toolbox/internal/testutils" + "github.com/googleapis/genai-toolbox/internal/tools/conversationalanalytics/conversationalanalyticslistaccessibledataagents" +) + +func TestParseFromYamlConversationalAnalyticsListDataAgents(t *testing.T) { + ctx, err := testutils.ContextWithNewLogger() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + tcs := []struct { + desc string + in string + want server.ToolConfigs + }{ + { + desc: "basic example", + in: ` + kind: tool + name: example_tool + type: conversational-analytics-list-accessible-data-agents + source: my-instance + description: some description + `, + want: server.ToolConfigs{ + "example_tool": conversationalanalyticslistaccessibledataagents.Config{ + Name: "example_tool", + Type: "conversational-analytics-list-accessible-data-agents", + Source: "my-instance", + Description: "some description", + AuthRequired: []string{}, + }, + }, + }, + { + desc: "with auth required", + in: ` + kind: tool + name: example_tool + type: conversational-analytics-list-accessible-data-agents + source: my-instance + description: some description + authRequired: + - my-google-auth + `, + want: server.ToolConfigs{ + "example_tool": conversationalanalyticslistaccessibledataagents.Config{ + Name: "example_tool", + Type: "conversational-analytics-list-accessible-data-agents", + Source: "my-instance", + Description: "some description", + AuthRequired: []string{ + "my-google-auth", + }, + }, + }, + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + // Parse contents + _, _, _, got, _, _, err := server.UnmarshalResourceConfig(ctx, testutils.FormatYaml(tc.in)) + if err != nil { + t.Fatalf("unable to unmarshal: %s", err) + } + if diff := cmp.Diff(tc.want, got); diff != "" { + t.Fatalf("incorrect parse: diff %v", diff) + } + }) + } +} diff --git a/tests/cloudgda/cloud_gda_integration_test.go b/tests/cloudgda/cloud_gda_integration_test.go index fdfd8460dce8..31b395c33b1a 100644 --- a/tests/cloudgda/cloud_gda_integration_test.go +++ b/tests/cloudgda/cloud_gda_integration_test.go @@ -19,29 +19,46 @@ import ( "context" "encoding/json" "fmt" + "io" "net" "net/http" + "os" "regexp" "strings" "testing" "time" + bigqueryapi "cloud.google.com/go/bigquery" geminidataanalytics "cloud.google.com/go/geminidataanalytics/apiv1beta" "cloud.google.com/go/geminidataanalytics/apiv1beta/geminidataanalyticspb" + "github.com/google/uuid" "github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc" + "github.com/googleapis/genai-toolbox/internal/sources" source "github.com/googleapis/genai-toolbox/internal/sources/cloudgda" "github.com/googleapis/genai-toolbox/internal/testutils" "github.com/googleapis/genai-toolbox/internal/tools/cloudgda" "github.com/googleapis/genai-toolbox/tests" + "golang.org/x/oauth2/google" + "google.golang.org/api/googleapi" + "google.golang.org/api/iterator" "google.golang.org/api/option" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" ) var ( - cloudGdaToolType = "cloud-gemini-data-analytics-query" + cloudGdaToolType = "cloud-gemini-data-analytics-query" + CloudGDASourceType = "cloud-gemini-data-analytics" + CloudGdaProject = os.Getenv("CLOUD_GDA_PROJECT") ) +func getCloudGDAProject(t *testing.T) string { + if CloudGdaProject == "" { + t.Fatal("'CLOUD_GDA_PROJECT' not set") + } + return CloudGdaProject +} + type mockDataChatServer struct { geminidataanalyticspb.UnimplementedDataChatServiceServer t *testing.T @@ -204,3 +221,683 @@ func TestCloudGdaToolEndpoints(t *testing.T) { t.Errorf("MCP response does not contain expected query result: %s", respStr) } } + +// Copied over from bigquery_integration_test.go +func initBigQueryConnection(project string) (*bigqueryapi.Client, error) { + ctx := context.Background() + cred, err := google.FindDefaultCredentials(ctx, bigqueryapi.Scope) + if err != nil { + return nil, fmt.Errorf("failed to find default Google Cloud credentials with scope %q: %w", bigqueryapi.Scope, err) + } + + client, err := bigqueryapi.NewClient(ctx, project, option.WithCredentials(cred)) + if err != nil { + return nil, fmt.Errorf("failed to create BigQuery client for project %q: %w", project, err) + } + return client, nil +} + +func setupBigQueryTable(t *testing.T, ctx context.Context, client *bigqueryapi.Client, createStatement, insertStatement, datasetName string, tableName string) func(*testing.T) { + // Create dataset + dataset := client.Dataset(datasetName) + _, err := dataset.Metadata(ctx) + + if err != nil { + apiErr, ok := err.(*googleapi.Error) + if !ok || apiErr.Code != 404 { + t.Fatalf("Failed to check dataset %q existence: %v", datasetName, err) + } + metadataToCreate := &bigqueryapi.DatasetMetadata{Name: datasetName} + if err := dataset.Create(ctx, metadataToCreate); err != nil { + t.Fatalf("Failed to create dataset %q: %v", datasetName, err) + } + } + + // Create table + createJob, err := client.Query(createStatement).Run(ctx) + if err != nil { + t.Fatalf("Failed to start create table job for %s: %v", tableName, err) + } + createStatus, err := createJob.Wait(ctx) + if err != nil { + t.Fatalf("Failed to wait for create table job for %s: %v", tableName, err) + } + if err := createStatus.Err(); err != nil { + t.Fatalf("Create table job for %s failed: %v", tableName, err) + } + + if insertStatement != "" { + // Insert test data + insertQuery := client.Query(insertStatement) + insertJob, err := insertQuery.Run(ctx) + if err != nil { + t.Fatalf("Failed to start insert job for %s: %v", tableName, err) + } + insertStatus, err := insertJob.Wait(ctx) + if err != nil { + t.Fatalf("Failed to wait for insert job for %s: %v", tableName, err) + } + if err := insertStatus.Err(); err != nil { + t.Fatalf("Insert job for %s failed: %v", tableName, err) + } + } + + return func(t *testing.T) { + // tear down table + dropSQL := fmt.Sprintf("drop table %s", tableName) + dropJob, err := client.Query(dropSQL).Run(ctx) + if err != nil { + t.Errorf("Failed to start drop table job for %s: %v", tableName, err) + return + } + dropStatus, err := dropJob.Wait(ctx) + if err != nil { + t.Errorf("Failed to wait for drop table job for %s: %v", tableName, err) + return + } + if err := dropStatus.Err(); err != nil { + t.Errorf("Error dropping table %s: %v", tableName, err) + } + + // tear down dataset + datasetToTeardown := client.Dataset(datasetName) + tablesIterator := datasetToTeardown.Tables(ctx) + _, err = tablesIterator.Next() + + if err == iterator.Done { + if err := datasetToTeardown.Delete(ctx); err != nil { + t.Errorf("Failed to delete dataset %s: %v", datasetName, err) + } + } else if err != nil { + t.Errorf("Failed to list tables in dataset %s to check emptiness: %v.", datasetName, err) + } + } +} + +func setupDataAgent(t *testing.T, ctx context.Context, projectID, datasetID, tableID, dataAgentDisplayName string) (string, func(*testing.T)) { + t.Logf("Setting up data agent with ProjectID: %q, DatasetID: %q, TableID: %q, DisplayName: %q", projectID, datasetID, tableID, dataAgentDisplayName) + + accessToken, err := sources.GetIAMAccessToken(ctx) + if err != nil { + t.Fatalf("failed to get access token: %v", err) + } + + dataAgentId := "test" + strings.ReplaceAll(uuid.New().String(), "-", "") + parent := fmt.Sprintf("projects/%s/locations/global", projectID) + url := fmt.Sprintf("https://geminidataanalytics.googleapis.com/v1beta/%s/dataAgents?dataAgentId=%s", parent, dataAgentId) + + requestBody := map[string]any{ + "displayName": dataAgentDisplayName, + "dataAnalyticsAgent": map[string]any{ + "publishedContext": map[string]any{ + "datasourceReferences": map[string]any{ + "bq": map[string]any{ + "tableReferences": []map[string]string{ + { + "projectId": projectID, + "datasetId": datasetID, + "tableId": tableID, + }, + }, + }, + }, + }, + }, + } + + bodyBytes, err := json.Marshal(requestBody) + if err != nil { + t.Fatalf("failed to marshal create data agent request: %v", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBuffer(bodyBytes)) + if err != nil { + t.Fatalf("failed to create request: %v", err) + } + req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("Content-Type", "application/json") + + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + t.Fatalf("failed to create data agent: %v", err) + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("failed to read response body: %v", err) + } + + if resp.StatusCode != http.StatusOK { + t.Fatalf("failed to create data agent, status: %d, body: %s", resp.StatusCode, string(respBody)) + } + + var op map[string]any + if err := json.Unmarshal(respBody, &op); err != nil { + t.Fatalf("failed to unmarshal operation: %v", err) + } + + opName, ok := op["name"].(string) + if !ok { + t.Fatalf("operation response missing name: %s", string(respBody)) + } + + // Poll for operation completion + ticker := time.NewTicker(2 * time.Second) + defer ticker.Stop() + timeout := time.After(60 * time.Second) + + done := false + for !done { + select { + case <-ctx.Done(): + t.Fatalf("context cancelled while waiting for data agent creation") + case <-timeout: + t.Fatalf("timed out waiting for data agent creation") + case <-ticker.C: + opUrl := fmt.Sprintf("https://geminidataanalytics.googleapis.com/v1beta/%s", opName) + opReq, _ := http.NewRequestWithContext(ctx, http.MethodGet, opUrl, nil) + opReq.Header.Set("Authorization", "Bearer "+accessToken) + opResp, err := client.Do(opReq) + if err != nil { + t.Logf("failed to poll operation: %v", err) + continue + } + opRespBody, _ := io.ReadAll(opResp.Body) + opResp.Body.Close() + + var pollOp map[string]any + if err := json.Unmarshal(opRespBody, &pollOp); err != nil { + t.Logf("failed to unmarshal polling response: %v", err) + continue + } + + if d, ok := pollOp["done"].(bool); ok && d { + if errVal, ok := pollOp["error"]; ok && errVal != nil { + t.Fatalf("data agent creation failed: %v", errVal) + } + done = true + } + } + } + + teardown := func(t *testing.T) { + agentName := fmt.Sprintf("%s/dataAgents/%s", parent, dataAgentId) + deleteUrl := fmt.Sprintf("https://geminidataanalytics.googleapis.com/v1beta/%s", agentName) + delReq, _ := http.NewRequest(http.MethodDelete, deleteUrl, nil) + delReq.Header.Set("Authorization", "Bearer "+accessToken) + delResp, err := client.Do(delReq) + if err != nil { + t.Errorf("failed to delete data agent %s: %v", agentName, err) + return + } + defer delResp.Body.Close() + if delResp.StatusCode != http.StatusOK { + t.Errorf("failed to delete data agent %s, status: %d", agentName, delResp.StatusCode) + } + } + + return dataAgentId, teardown +} + +func TestCloudGDAConservationalAnalyticsTools(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) + defer cancel() + + projectID := getCloudGDAProject(t) + client, err := initBigQueryConnection(projectID) + if err != nil { + t.Fatalf("unable to create BigQuery client: %s", err) + } + + // Setup dataset and table for Data Agent + datasetName := fmt.Sprintf("data_agent_test_%s", strings.ReplaceAll(uuid.New().String(), "-", "")) + tableName := "test_table" + tableNameParam := fmt.Sprintf("`%s.%s.%s`", projectID, datasetName, tableName) + + createTableStmt := fmt.Sprintf("CREATE TABLE %s (id INT64, name STRING)", tableNameParam) + teardownTable := setupBigQueryTable(t, ctx, client, createTableStmt, "", datasetName, tableNameParam) + defer teardownTable(t) + + // Create Data Agent + dataAgentDisplayName := fmt.Sprintf("test-agent-%s", strings.ReplaceAll(uuid.New().String(), "-", "")) + dataAgentID, teardownDataAgent := setupDataAgent(t, ctx, projectID, datasetName, tableName, dataAgentDisplayName) + defer teardownDataAgent(t) + + // Configure tools with cloud-gemini-data-analytics source + toolsFile := map[string]any{ + "sources": map[string]any{ + "my-instance": map[string]any{ + "type": "cloud-gemini-data-analytics", + "projectId": projectID, + }, + "my-client-auth-source": map[string]any{ + "type": "cloud-gemini-data-analytics", + "projectId": projectID, + "useClientOAuth": true, + }, + }, + "authServices": map[string]any{ + "my-google-auth": map[string]any{ + "kind": "google", + "clientId": tests.ClientId, + }, + }, + "tools": map[string]any{ + "my-list-accessible-data-agents-tool": map[string]any{ + "type": "conversational-analytics-list-accessible-data-agents", + "source": "my-instance", + "description": "Tool to list data agents.", + }, + "my-auth-list-accessible-data-agents-tool": map[string]any{ + "type": "conversational-analytics-list-accessible-data-agents", + "source": "my-instance", + "description": "Tool to list data agents with auth.", + "authRequired": []string{"my-google-auth"}, + }, + "my-client-auth-list-accessible-data-agents-tool": map[string]any{ + "type": "conversational-analytics-list-accessible-data-agents", + "source": "my-client-auth-source", + "description": "Tool to list data agents with client auth.", + }, + "my-get-data-agent-info-tool": map[string]any{ + "type": "conversational-analytics-get-data-agent-info", + "source": "my-instance", + "description": "Tool to get data agent info.", + }, + "my-auth-get-data-agent-info-tool": map[string]any{ + "type": "conversational-analytics-get-data-agent-info", + "source": "my-instance", + "description": "Tool to get data agent info with auth.", + "authRequired": []string{"my-google-auth"}, + }, + "my-client-auth-get-data-agent-info-tool": map[string]any{ + "type": "conversational-analytics-get-data-agent-info", + "source": "my-client-auth-source", + "description": "Tool to get data agent info with client auth.", + }, + "my-ask-data-agent-tool": map[string]any{ + "type": "conversational-analytics-ask-data-agent", + "source": "my-instance", + "description": "Tool to ask data agent.", + }, + "my-auth-ask-data-agent-tool": map[string]any{ + "type": "conversational-analytics-ask-data-agent", + "source": "my-instance", + "description": "Tool to ask data agent with auth.", + "authRequired": []string{"my-google-auth"}, + }, + "my-client-auth-ask-data-agent-tool": map[string]any{ + "type": "conversational-analytics-ask-data-agent", + "source": "my-client-auth-source", + "description": "Tool to ask data agent with client auth.", + }, + }, + } + + args := []string{"--enable-api"} + cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...) + if err != nil { + t.Fatalf("command initialization returned an error: %s", err) + } + defer cleanup() + + waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out) + if err != nil { + t.Logf("toolbox command logs: \n%s", out) + t.Fatalf("toolbox didn't start successfully: %s", err) + } + + runListAccessibleDataAgentsInvokeTest(t, dataAgentDisplayName) + runGetDataAgentInfoInvokeTest(t, dataAgentID, dataAgentDisplayName) + runAskDataAgentInvokeTest(t, dataAgentID) +} + +func runListAccessibleDataAgentsInvokeTest(t *testing.T, dataAgentDisplayName string) { + idToken, err := tests.GetGoogleIdToken(tests.ClientId) + if err != nil { + t.Fatalf("error getting Google ID token: %s", err) + } + + accessToken, err := sources.GetIAMAccessToken(t.Context()) + if err != nil { + t.Fatalf("error getting access token from ADC: %s", err) + } + accessToken = "Bearer " + accessToken + + invokeTcs := []struct { + name string + api string + requestHeader map[string]string + requestBody io.Reader + want string + isErr bool + }{ + { + name: "invoke my-list-accessible-data-agents-tool", + api: "http://127.0.0.1:5000/api/tool/my-list-accessible-data-agents-tool/invoke", + requestHeader: map[string]string{}, + requestBody: bytes.NewBuffer([]byte(`{}`)), + want: dataAgentDisplayName, + isErr: false, + }, + { + name: "invoke my-auth-list-accessible-data-agents-tool with auth token", + api: "http://127.0.0.1:5000/api/tool/my-auth-list-accessible-data-agents-tool/invoke", + requestHeader: map[string]string{"my-google-auth_token": idToken}, + requestBody: bytes.NewBuffer([]byte(`{}`)), + want: dataAgentDisplayName, + isErr: false, + }, + { + name: "invoke my-auth-list-accessible-data-agents-tool without auth token", + api: "http://127.0.0.1:5000/api/tool/my-auth-list-accessible-data-agents-tool/invoke", + requestHeader: map[string]string{}, + requestBody: bytes.NewBuffer([]byte(`{}`)), + isErr: true, + }, + { + name: "invoke my-client-auth-list-accessible-data-agents-tool with auth token", + api: "http://127.0.0.1:5000/api/tool/my-client-auth-list-accessible-data-agents-tool/invoke", + requestHeader: map[string]string{"Authorization": accessToken}, + requestBody: bytes.NewBuffer([]byte(`{}`)), + want: dataAgentDisplayName, + isErr: false, + }, + { + name: "invoke my-client-auth-list-accessible-data-agents-tool without auth token", + api: "http://127.0.0.1:5000/api/tool/my-client-auth-list-accessible-data-agents-tool/invoke", + requestHeader: map[string]string{}, + requestBody: bytes.NewBuffer([]byte(`{}`)), + isErr: true, + }, + } + + for _, tc := range invokeTcs { + t.Run(tc.name, func(t *testing.T) { + req, err := http.NewRequest(http.MethodPost, tc.api, tc.requestBody) + if err != nil { + t.Fatalf("unable to create request: %s", err) + } + req.Header.Add("Content-type", "application/json") + for k, v := range tc.requestHeader { + req.Header.Add(k, v) + } + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("unable to send request: %s", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + if tc.isErr { + return + } + bodyBytes, _ := io.ReadAll(resp.Body) + t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes)) + } + + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("error reading response body: %v", err) + } + + var body map[string]interface{} + if err := json.Unmarshal(bodyBytes, &body); err != nil { + t.Fatalf("error parsing response body") + } + + got, ok := body["result"].(string) + if !ok { + t.Fatalf("unable to find result in response body") + } + + if !strings.Contains(got, tc.want) { + t.Fatalf("expected %q to contain %q, but it did not", got, tc.want) + } + }) + } +} + +func runGetDataAgentInfoInvokeTest(t *testing.T, dataAgentName, dataAgentDisplayName string) { + idToken, err := tests.GetGoogleIdToken(tests.ClientId) + if err != nil { + t.Fatalf("error getting Google ID token: %s", err) + } + + accessToken, err := sources.GetIAMAccessToken(t.Context()) + if err != nil { + t.Fatalf("error getting access token from ADC: %s", err) + } + accessToken = "Bearer " + accessToken + + invokeTcs := []struct { + name string + api string + requestHeader map[string]string + requestBody io.Reader + want string + isErr bool + }{ + { + name: "invoke my-get-data-agent-info-tool", + api: "http://127.0.0.1:5000/api/tool/my-get-data-agent-info-tool/invoke", + requestHeader: map[string]string{}, + requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"data_agent_id": "%s"}`, dataAgentName))), + want: dataAgentDisplayName, + isErr: false, + }, + { + name: "invoke my-auth-get-data-agent-info-tool with auth token", + api: "http://127.0.0.1:5000/api/tool/my-auth-get-data-agent-info-tool/invoke", + requestHeader: map[string]string{"my-google-auth_token": idToken}, + requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"data_agent_id": "%s"}`, dataAgentName))), + want: dataAgentDisplayName, + isErr: false, + }, + { + name: "invoke my-auth-get-data-agent-info-tool without auth token", + api: "http://127.0.0.1:5000/api/tool/my-auth-get-data-agent-info-tool/invoke", + requestHeader: map[string]string{}, + requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"data_agent_id": "%s"}`, dataAgentName))), + isErr: true, + }, + { + name: "invoke my-client-auth-get-data-agent-info-tool with auth token", + api: "http://127.0.0.1:5000/api/tool/my-client-auth-get-data-agent-info-tool/invoke", + requestHeader: map[string]string{"Authorization": accessToken}, + requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"data_agent_id": "%s"}`, dataAgentName))), + want: dataAgentDisplayName, + isErr: false, + }, + { + name: "invoke my-client-auth-get-data-agent-info-tool without auth token", + api: "http://127.0.0.1:5000/api/tool/my-client-auth-get-data-agent-info-tool/invoke", + requestHeader: map[string]string{}, + requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"data_agent_id": "%s"}`, dataAgentName))), + isErr: true, + }, + } + + for _, tc := range invokeTcs { + t.Run(tc.name, func(t *testing.T) { + req, err := http.NewRequest(http.MethodPost, tc.api, tc.requestBody) + if err != nil { + t.Fatalf("unable to create request: %s", err) + } + req.Header.Add("Content-type", "application/json") + for k, v := range tc.requestHeader { + req.Header.Add(k, v) + } + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("unable to send request: %s", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + if tc.isErr { + return + } + bodyBytes, _ := io.ReadAll(resp.Body) + t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes)) + } + + var body map[string]interface{} + if err := json.NewDecoder(resp.Body).Decode(&body); err != nil { + t.Fatalf("error parsing response body") + } + + got, ok := body["result"].(string) + if !ok { + t.Fatalf("unable to find result in response body") + } + + if !strings.Contains(got, tc.want) { + t.Fatalf("expected %q to contain %q, but it did not", got, tc.want) + } + }) + } +} + +func runAskDataAgentInvokeTest(t *testing.T, dataAgentID string) { + const maxRetries = 3 + const requestTimeout = 340 * time.Second + + idToken, err := tests.GetGoogleIdToken(tests.ClientId) + if err != nil { + t.Fatalf("error getting Google ID token: %s", err) + } + + accessToken, err := sources.GetIAMAccessToken(t.Context()) + if err != nil { + t.Fatalf("error getting access token from ADC: %s", err) + } + accessToken = "Bearer " + accessToken + + dataAgentWant := `FINAL_RESPONSE` + + invokeTcs := []struct { + name string + api string + requestHeader map[string]string + requestBody string + want string + isErr bool + }{ + { + name: "invoke my-ask-data-agent-tool", + api: "http://127.0.0.1:5000/api/tool/my-ask-data-agent-tool/invoke", + requestHeader: map[string]string{}, + requestBody: fmt.Sprintf(`{"user_query_with_context": "What are the names in the table?", "data_agent_id": "%s"}`, dataAgentID), + want: dataAgentWant, + isErr: false, + }, + { + name: "invoke my-auth-ask-data-agent-tool with auth token", + api: "http://127.0.0.1:5000/api/tool/my-auth-ask-data-agent-tool/invoke", + requestHeader: map[string]string{"my-google-auth_token": idToken}, + requestBody: fmt.Sprintf(`{"user_query_with_context": "What are the names in the table?", "data_agent_id": "%s"}`, dataAgentID), + want: dataAgentWant, + isErr: false, + }, + { + name: "invoke my-auth-ask-data-agent-tool without auth token", + api: "http://127.0.0.1:5000/api/tool/my-auth-ask-data-agent-tool/invoke", + requestHeader: map[string]string{}, + requestBody: fmt.Sprintf(`{"user_query_with_context": "What are the names in the table?", "data_agent_id": "%s"}`, dataAgentID), + isErr: true, + }, + { + name: "invoke my-client-auth-ask-data-agent-tool with auth token", + api: "http://127.0.0.1:5000/api/tool/my-client-auth-ask-data-agent-tool/invoke", + requestHeader: map[string]string{"Authorization": accessToken}, + requestBody: fmt.Sprintf(`{"user_query_with_context": "What are the names in the table?", "data_agent_id": "%s"}`, dataAgentID), + want: dataAgentWant, + isErr: false, + }, + { + name: "invoke my-client-auth-ask-data-agent-tool without auth token", + api: "http://127.0.0.1:5000/api/tool/my-client-auth-ask-data-agent-tool/invoke", + requestHeader: map[string]string{}, + requestBody: fmt.Sprintf(`{"user_query_with_context": "What are the names in the table?", "data_agent_id": "%s"}`, dataAgentID), + isErr: true, + }, + } + + for _, tc := range invokeTcs { + t.Run(tc.name, func(t *testing.T) { + var resp *http.Response + var err error + bodyBytes := []byte(tc.requestBody) + + req, err := http.NewRequest(http.MethodPost, tc.api, nil) + if err != nil { + t.Fatalf("unable to create request: %s", err) + } + req.Header.Set("Content-type", "application/json") + for k, v := range tc.requestHeader { + req.Header.Add(k, v) + } + + for i := 0; i < maxRetries; i++ { + ctx, cancel := context.WithTimeout(context.Background(), requestTimeout) + defer cancel() + + req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + req.GetBody = func() (io.ReadCloser, error) { + return io.NopCloser(bytes.NewReader(bodyBytes)), nil + } + reqWithCtx := req.WithContext(ctx) + + resp, err = http.DefaultClient.Do(reqWithCtx) + if err != nil { + // Retry on time out. + if os.IsTimeout(err) { + t.Logf("Request timed out (attempt %d/%d), retrying...", i+1, maxRetries) + time.Sleep(5 * time.Second) + continue + } + t.Fatalf("unable to send request: %s", err) + } + if resp.StatusCode == http.StatusServiceUnavailable { + t.Logf("Received 503 Service Unavailable (attempt %d/%d), retrying...", i+1, maxRetries) + time.Sleep(15 * time.Second) + continue + } + break + } + + if err != nil { + t.Fatalf("Request failed after %d retries: %v", maxRetries, err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + if tc.isErr { + return + } + bodyBytes, _ := io.ReadAll(resp.Body) + t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes)) + } + + var body map[string]interface{} + if err = json.NewDecoder(resp.Body).Decode(&body); err != nil { + t.Fatalf("error parsing response body: %v", err) + } + + got, ok := body["result"].(string) + if !ok { + t.Fatalf("unable to find result in response body") + } + + wantPattern := regexp.MustCompile(tc.want) + if !wantPattern.MatchString(got) { + t.Fatalf("response did not match the expected pattern.\nFull response:\n%s", got) + } + }) + } +}