|
| 1 | +// Copyright 2026 Google LLC |
| 2 | +// |
| 3 | +// Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +// you may not use this file except in compliance with the License. |
| 5 | +// You may obtain a copy of the License at |
| 6 | +// |
| 7 | +// http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +// |
| 9 | +// Unless required by applicable law or agreed to in writing, software |
| 10 | +// distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +// See the License for the specific language governing permissions and |
| 13 | +// limitations under the License. |
| 14 | + |
| 15 | +package alloydbainl |
| 16 | + |
| 17 | +import ( |
| 18 | + "context" |
| 19 | + "net/http" |
| 20 | + "os" |
| 21 | + "regexp" |
| 22 | + "testing" |
| 23 | + "time" |
| 24 | + |
| 25 | + "github.com/googleapis/genai-toolbox/internal/testutils" |
| 26 | + "github.com/googleapis/genai-toolbox/tests" |
| 27 | +) |
| 28 | + |
| 29 | +func TestAlloyDBAINLToolEndpointsMCP(t *testing.T) { |
| 30 | + sourceConfig := getAlloyDBAINLVars(t) |
| 31 | + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) |
| 32 | + defer cancel() |
| 33 | + |
| 34 | + args := []string{"--enable-api"} |
| 35 | + |
| 36 | + toolsFile := getAINLToolsConfig(sourceConfig) |
| 37 | + |
| 38 | + cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...) |
| 39 | + if err != nil { |
| 40 | + t.Fatalf("command initialization returned an error: %s", err) |
| 41 | + } |
| 42 | + defer cleanup() |
| 43 | + |
| 44 | + waitCtx, cancelWait := context.WithTimeout(ctx, 10*time.Second) |
| 45 | + defer cancelWait() |
| 46 | + out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out) |
| 47 | + if err != nil { |
| 48 | + t.Logf("toolbox command logs: \n%s", out) |
| 49 | + t.Fatalf("toolbox didn't start successfully: %s", err) |
| 50 | + } |
| 51 | + |
| 52 | + runAINLToolGetMCPTest(t) |
| 53 | + runAINLToolInvokeMCPTest(t) |
| 54 | +} |
| 55 | + |
| 56 | +func runAINLToolGetMCPTest(t *testing.T) { |
| 57 | + t.Run("list tools via MCP", func(t *testing.T) { |
| 58 | + statusCode, toolsList, err := tests.GetMCPToolsList(t, nil) |
| 59 | + if err != nil { |
| 60 | + t.Fatalf("native error executing tools/list: %s", err) |
| 61 | + } |
| 62 | + if statusCode != http.StatusOK { |
| 63 | + t.Fatalf("expected status 200, got %d", statusCode) |
| 64 | + } |
| 65 | + |
| 66 | + // Verify that my-simple-tool is in the list |
| 67 | + found := false |
| 68 | + for _, tool := range toolsList { |
| 69 | + toolMap, ok := tool.(map[string]any) |
| 70 | + if !ok { |
| 71 | + continue |
| 72 | + } |
| 73 | + if toolMap["name"] == "my-simple-tool" { |
| 74 | + found = true |
| 75 | + break |
| 76 | + } |
| 77 | + } |
| 78 | + if !found { |
| 79 | + t.Errorf("expected tool 'my-simple-tool' not found in list") |
| 80 | + } |
| 81 | + }) |
| 82 | +} |
| 83 | + |
| 84 | +func runAINLToolInvokeMCPTest(t *testing.T) { |
| 85 | + idToken, err := tests.GetGoogleIdToken(tests.ClientId) |
| 86 | + if err != nil { |
| 87 | + t.Fatalf("error getting Google ID token: %s", err) |
| 88 | + } |
| 89 | + |
| 90 | + invokeTcs := []struct { |
| 91 | + name string |
| 92 | + toolName string |
| 93 | + args map[string]any |
| 94 | + requestHeader map[string]string |
| 95 | + want string |
| 96 | + isErr bool |
| 97 | + }{ |
| 98 | + { |
| 99 | + name: "invoke my-simple-tool", |
| 100 | + toolName: "my-simple-tool", |
| 101 | + args: map[string]any{"question": "return the number 1"}, |
| 102 | + want: "[{\"execute_nl_query\":{\"?column?\":1}}]", |
| 103 | + isErr: false, |
| 104 | + }, |
| 105 | + { |
| 106 | + name: "Invoke my-auth-tool with auth token", |
| 107 | + toolName: "my-auth-tool", |
| 108 | + args: map[string]any{"question": "can you show me the name of this user?"}, |
| 109 | + requestHeader: map[string]string{"my-google-auth_token": idToken}, |
| 110 | + want: "[{\"execute_nl_query\":{\"name\":\"Alice\"}}]", |
| 111 | + isErr: false, |
| 112 | + }, |
| 113 | + { |
| 114 | + name: "Invoke my-auth-tool with invalid auth token", |
| 115 | + toolName: "my-auth-tool", |
| 116 | + args: map[string]any{"question": "return the number 1"}, |
| 117 | + requestHeader: map[string]string{"my-google-auth_token": "INVALID_TOKEN"}, |
| 118 | + isErr: true, |
| 119 | + }, |
| 120 | + { |
| 121 | + name: "Invoke my-auth-tool without auth token", |
| 122 | + toolName: "my-auth-tool", |
| 123 | + args: map[string]any{"question": "return the number 1"}, |
| 124 | + isErr: true, |
| 125 | + }, |
| 126 | + { |
| 127 | + name: "Invoke my-auth-required-tool with auth token", |
| 128 | + toolName: "my-auth-required-tool", |
| 129 | + args: map[string]any{"question": "return the number 1"}, |
| 130 | + requestHeader: map[string]string{"my-google-auth_token": idToken}, |
| 131 | + isErr: false, |
| 132 | + want: "[{\"execute_nl_query\":{\"?column?\":1}}]", |
| 133 | + }, |
| 134 | + { |
| 135 | + name: "Invoke my-auth-required-tool with invalid auth token", |
| 136 | + toolName: "my-auth-required-tool", |
| 137 | + args: map[string]any{"question": "return the number 1"}, |
| 138 | + requestHeader: map[string]string{"my-google-auth_token": "INVALID_TOKEN"}, |
| 139 | + isErr: true, |
| 140 | + }, |
| 141 | + { |
| 142 | + name: "Invoke my-auth-required-tool without auth token", |
| 143 | + toolName: "my-auth-required-tool", |
| 144 | + args: map[string]any{"question": "return the number 1"}, |
| 145 | + isErr: true, |
| 146 | + }, |
| 147 | + } |
| 148 | + |
| 149 | + for _, tc := range invokeTcs { |
| 150 | + t.Run(tc.name, func(t *testing.T) { |
| 151 | + statusCode, mcpResp, err := tests.InvokeMCPTool(t, tc.toolName, tc.args, tc.requestHeader) |
| 152 | + if err != nil { |
| 153 | + t.Fatalf("native error executing %s: %s", tc.toolName, err) |
| 154 | + } |
| 155 | + |
| 156 | + if statusCode != http.StatusOK { |
| 157 | + t.Fatalf("expected status 200, got %d", statusCode) |
| 158 | + } |
| 159 | + |
| 160 | + if tc.isErr { |
| 161 | + if !mcpResp.Result.IsError { |
| 162 | + t.Fatalf("expected error result, got success") |
| 163 | + } |
| 164 | + } else { |
| 165 | + if mcpResp.Result.IsError { |
| 166 | + t.Fatalf("expected success result, got error: %v", mcpResp.Result) |
| 167 | + } |
| 168 | + got := mcpResp.Result.Content[0].Text |
| 169 | + if got != tc.want { |
| 170 | + t.Fatalf("unexpected value: got %q, want %q", got, tc.want) |
| 171 | + } |
| 172 | + } |
| 173 | + }) |
| 174 | + } |
| 175 | +} |
| 176 | + |
| 177 | +var ( |
| 178 | + AlloyDBAINLSourceType = "alloydb-postgres" |
| 179 | + AlloyDBAINLToolType = "alloydb-ai-nl" |
| 180 | + AlloyDBAINLProject = os.Getenv("ALLOYDB_AI_NL_PROJECT") |
| 181 | + AlloyDBAINLRegion = os.Getenv("ALLOYDB_AI_NL_REGION") |
| 182 | + AlloyDBAINLCluster = os.Getenv("ALLOYDB_AI_NL_CLUSTER") |
| 183 | + AlloyDBAINLInstance = os.Getenv("ALLOYDB_AI_NL_INSTANCE") |
| 184 | + AlloyDBAINLDatabase = os.Getenv("ALLOYDB_AI_NL_DATABASE") |
| 185 | + AlloyDBAINLUser = os.Getenv("ALLOYDB_AI_NL_USER") |
| 186 | + AlloyDBAINLPass = os.Getenv("ALLOYDB_AI_NL_PASS") |
| 187 | +) |
| 188 | + |
| 189 | +func getAlloyDBAINLVars(t *testing.T) map[string]any { |
| 190 | + switch "" { |
| 191 | + case AlloyDBAINLProject: |
| 192 | + t.Fatal("'ALLOYDB_AI_NL_PROJECT' not set") |
| 193 | + case AlloyDBAINLRegion: |
| 194 | + t.Fatal("'ALLOYDB_AI_NL_REGION' not set") |
| 195 | + case AlloyDBAINLCluster: |
| 196 | + t.Fatal("'ALLOYDB_AI_NL_CLUSTER' not set") |
| 197 | + case AlloyDBAINLInstance: |
| 198 | + t.Fatal("'ALLOYDB_AI_NL_INSTANCE' not set") |
| 199 | + case AlloyDBAINLDatabase: |
| 200 | + t.Fatal("'ALLOYDB_AI_NL_DATABASE' not set") |
| 201 | + case AlloyDBAINLUser: |
| 202 | + t.Fatal("'ALLOYDB_AI_NL_USER' not set") |
| 203 | + case AlloyDBAINLPass: |
| 204 | + t.Fatal("'ALLOYDB_AI_NL_PASS' not set") |
| 205 | + } |
| 206 | + return map[string]any{ |
| 207 | + "type": AlloyDBAINLSourceType, |
| 208 | + "project": AlloyDBAINLProject, |
| 209 | + "cluster": AlloyDBAINLCluster, |
| 210 | + "instance": AlloyDBAINLInstance, |
| 211 | + "region": AlloyDBAINLRegion, |
| 212 | + "database": AlloyDBAINLDatabase, |
| 213 | + "user": AlloyDBAINLUser, |
| 214 | + "password": AlloyDBAINLPass, |
| 215 | + } |
| 216 | +} |
| 217 | + |
| 218 | +func getAINLToolsConfig(sourceConfig map[string]any) map[string]any { |
| 219 | + // Write config into a file and pass it to command |
| 220 | + toolsFile := map[string]any{ |
| 221 | + "sources": map[string]any{ |
| 222 | + "my-instance": sourceConfig, |
| 223 | + }, |
| 224 | + "authServices": map[string]any{ |
| 225 | + "my-google-auth": map[string]any{ |
| 226 | + "type": "google", |
| 227 | + "clientId": tests.ClientId, |
| 228 | + }, |
| 229 | + }, |
| 230 | + "tools": map[string]any{ |
| 231 | + "my-simple-tool": map[string]any{ |
| 232 | + "type": AlloyDBAINLToolType, |
| 233 | + "source": "my-instance", |
| 234 | + "description": "Simple tool to test end to end functionality.", |
| 235 | + "nlConfig": "my_nl_config", |
| 236 | + }, |
| 237 | + "my-auth-tool": map[string]any{ |
| 238 | + "type": AlloyDBAINLToolType, |
| 239 | + "source": "my-instance", |
| 240 | + "description": "Tool to test authenticated parameters.", |
| 241 | + "nlConfig": "my_nl_config", |
| 242 | + "nlConfigParameters": []map[string]any{ |
| 243 | + { |
| 244 | + "name": "email", |
| 245 | + "type": "string", |
| 246 | + "description": "user email", |
| 247 | + "authServices": []map[string]string{ |
| 248 | + { |
| 249 | + "name": "my-google-auth", |
| 250 | + "field": "email", |
| 251 | + }, |
| 252 | + }, |
| 253 | + }, |
| 254 | + }, |
| 255 | + }, |
| 256 | + "my-auth-required-tool": map[string]any{ |
| 257 | + "type": AlloyDBAINLToolType, |
| 258 | + "source": "my-instance", |
| 259 | + "description": "Tool to test auth required invocation.", |
| 260 | + "nlConfig": "my_nl_config", |
| 261 | + "authRequired": []string{ |
| 262 | + "my-google-auth", |
| 263 | + }, |
| 264 | + }, |
| 265 | + }, |
| 266 | + } |
| 267 | + |
| 268 | + return toolsFile |
| 269 | +} |
0 commit comments