Skip to content

Commit 776248b

Browse files
committed
test(source/cloud-sql-mssql): create MCP integration tests
1 parent 6e1e8b1 commit 776248b

File tree

3 files changed

+542
-123
lines changed

3 files changed

+542
-123
lines changed
Lines changed: 261 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,261 @@
1+
// Copyright 2026 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package cloudsqlmssql_test
16+
17+
import (
18+
"context"
19+
"encoding/json"
20+
"fmt"
21+
"net/http"
22+
"net/http/httptest"
23+
"net/url"
24+
"regexp"
25+
"strings"
26+
"testing"
27+
"time"
28+
29+
"github.com/google/go-cmp/cmp"
30+
"github.com/googleapis/genai-toolbox/internal/testutils"
31+
"github.com/googleapis/genai-toolbox/tests"
32+
"google.golang.org/api/sqladmin/v1"
33+
)
34+
35+
const createInstanceToolTypeMCP = "cloud-sql-mssql-create-instance"
36+
37+
type createInstanceTransportMCP struct {
38+
transport http.RoundTripper
39+
url *url.URL
40+
}
41+
42+
func (t *createInstanceTransportMCP) RoundTrip(req *http.Request) (*http.Response, error) {
43+
if strings.HasPrefix(req.URL.String(), "https://sqladmin.googleapis.com") {
44+
req.URL.Scheme = t.url.Scheme
45+
req.URL.Host = t.url.Host
46+
}
47+
return t.transport.RoundTrip(req)
48+
}
49+
50+
type masterHandlerMCP struct {
51+
t *testing.T
52+
}
53+
54+
func (h *masterHandlerMCP) ServeHTTP(w http.ResponseWriter, r *http.Request) {
55+
if !strings.Contains(r.UserAgent(), "genai-toolbox/") {
56+
h.t.Errorf("User-Agent header not found")
57+
}
58+
59+
var body sqladmin.DatabaseInstance
60+
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
61+
h.t.Fatalf("failed to decode request body: %v", err)
62+
}
63+
64+
instanceName := body.Name
65+
if instanceName == "" {
66+
http.Error(w, "missing instance name", http.StatusBadRequest)
67+
return
68+
}
69+
70+
var expectedBody sqladmin.DatabaseInstance
71+
var response any
72+
var statusCode int
73+
74+
switch instanceName {
75+
case "instance1":
76+
expectedBody = sqladmin.DatabaseInstance{
77+
Project: "p1",
78+
Name: "instance1",
79+
DatabaseVersion: "SQLSERVER_2022_ENTERPRISE",
80+
RootPassword: "password123",
81+
Settings: &sqladmin.Settings{
82+
AvailabilityType: "REGIONAL",
83+
Edition: "ENTERPRISE",
84+
Tier: "db-custom-4-26624",
85+
DataDiskSizeGb: 250,
86+
DataDiskType: "PD_SSD",
87+
},
88+
}
89+
response = map[string]any{"name": "op1", "status": "PENDING"}
90+
statusCode = http.StatusOK
91+
case "instance2":
92+
expectedBody = sqladmin.DatabaseInstance{
93+
Project: "p2",
94+
Name: "instance2",
95+
DatabaseVersion: "SQLSERVER_2022_STANDARD",
96+
RootPassword: "password456",
97+
Settings: &sqladmin.Settings{
98+
AvailabilityType: "ZONAL",
99+
Edition: "ENTERPRISE",
100+
Tier: "db-custom-2-8192",
101+
DataDiskSizeGb: 100,
102+
DataDiskType: "PD_SSD",
103+
},
104+
}
105+
response = map[string]any{"name": "op2", "status": "RUNNING"}
106+
statusCode = http.StatusOK
107+
default:
108+
http.Error(w, fmt.Sprintf("unhandled instance name: %s", instanceName), http.StatusInternalServerError)
109+
return
110+
}
111+
112+
if expectedBody.Project != body.Project {
113+
h.t.Errorf("unexpected project: got %q, want %q", body.Project, expectedBody.Project)
114+
}
115+
if expectedBody.Name != body.Name {
116+
h.t.Errorf("unexpected name: got %q, want %q", body.Name, expectedBody.Name)
117+
}
118+
if expectedBody.DatabaseVersion != body.DatabaseVersion {
119+
h.t.Errorf("unexpected databaseVersion: got %q, want %q", body.DatabaseVersion, expectedBody.DatabaseVersion)
120+
}
121+
if expectedBody.RootPassword != body.RootPassword {
122+
h.t.Errorf("unexpected rootPassword: got %q, want %q", body.RootPassword, expectedBody.RootPassword)
123+
}
124+
if diff := cmp.Diff(expectedBody.Settings, body.Settings); diff != "" {
125+
h.t.Errorf("unexpected request body settings (-want +got):\n%s", diff)
126+
}
127+
128+
w.Header().Set("Content-Type", "application/json")
129+
w.WriteHeader(statusCode)
130+
if err := json.NewEncoder(w).Encode(response); err != nil {
131+
http.Error(w, err.Error(), http.StatusInternalServerError)
132+
}
133+
}
134+
135+
func TestCreateInstanceToolEndpointsMCP(t *testing.T) {
136+
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
137+
defer cancel()
138+
139+
handler := &masterHandlerMCP{t: t}
140+
server := httptest.NewServer(handler)
141+
defer server.Close()
142+
143+
serverURL, err := url.Parse(server.URL)
144+
if err != nil {
145+
t.Fatalf("failed to parse server URL: %v", err)
146+
}
147+
148+
originalTransport := http.DefaultClient.Transport
149+
if originalTransport == nil {
150+
originalTransport = http.DefaultTransport
151+
}
152+
http.DefaultClient.Transport = &createInstanceTransportMCP{
153+
transport: originalTransport,
154+
url: serverURL,
155+
}
156+
t.Cleanup(func() {
157+
http.DefaultClient.Transport = originalTransport
158+
})
159+
160+
toolsFile := getCreateInstanceToolsConfigMCP()
161+
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile)
162+
if err != nil {
163+
t.Fatalf("command initialization returned an error: %v", err)
164+
}
165+
defer cleanup()
166+
167+
waitCtx, cancelWait := context.WithTimeout(ctx, 10*time.Second)
168+
defer cancelWait()
169+
out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
170+
if err != nil {
171+
t.Logf("toolbox command logs: \n%s", out)
172+
t.Fatalf("toolbox didn't start successfully: %v", err)
173+
}
174+
175+
tcs := []struct {
176+
name string
177+
toolName string
178+
body string
179+
want string
180+
expectError bool
181+
}{
182+
{
183+
name: "verify successful instance creation with production preset",
184+
toolName: "create-instance-prod",
185+
body: `{"project": "p1", "name": "instance1", "databaseVersion": "SQLSERVER_2022_ENTERPRISE", "rootPassword": "password123", "editionPreset": "Production"}`,
186+
want: `{"name":"op1","status":"PENDING"}`,
187+
expectError: false,
188+
},
189+
{
190+
name: "verify successful instance creation with development preset",
191+
toolName: "create-instance-dev",
192+
body: `{"project": "p2", "name": "instance2", "rootPassword": "password456", "editionPreset": "Development"}`,
193+
want: `{"name":"op2","status":"RUNNING"}`,
194+
expectError: false,
195+
},
196+
{
197+
name: "verify missing required parameter returns schema error",
198+
toolName: "create-instance-prod",
199+
body: `{"name": "instance1"}`,
200+
want: `parameter "project" is required`,
201+
expectError: true,
202+
},
203+
}
204+
205+
for _, tc := range tcs {
206+
t.Run(tc.name, func(t *testing.T) {
207+
var args map[string]any
208+
if err := json.Unmarshal([]byte(tc.body), &args); err != nil {
209+
t.Fatalf("failed to unmarshal body: %v", err)
210+
}
211+
212+
statusCode, mcpResp, err := tests.InvokeMCPTool(t, tc.toolName, args, nil)
213+
if err != nil {
214+
t.Fatalf("native error executing %s: %v", tc.toolName, err)
215+
}
216+
217+
if statusCode != http.StatusOK {
218+
t.Fatalf("expected status 200, got %d", statusCode)
219+
}
220+
221+
if tc.expectError {
222+
tests.AssertMCPError(t, mcpResp, tc.want)
223+
} else {
224+
if mcpResp.Result.IsError {
225+
t.Fatalf("expected success, got error result: %v", mcpResp.Result)
226+
}
227+
gotStr := mcpResp.Result.Content[0].Text
228+
var got, want map[string]any
229+
if err := json.Unmarshal([]byte(gotStr), &got); err != nil {
230+
t.Fatalf("failed to unmarshal result: %v", err)
231+
}
232+
if err := json.Unmarshal([]byte(tc.want), &want); err != nil {
233+
t.Fatalf("failed to unmarshal want: %v", err)
234+
}
235+
if diff := cmp.Diff(want, got); diff != "" {
236+
t.Errorf("unexpected result (-want +got):\n%s", diff)
237+
}
238+
}
239+
})
240+
}
241+
}
242+
243+
func getCreateInstanceToolsConfigMCP() map[string]any {
244+
return map[string]any{
245+
"sources": map[string]any{
246+
"my-cloud-sql-source": map[string]any{
247+
"type": "cloud-sql-admin",
248+
},
249+
},
250+
"tools": map[string]any{
251+
"create-instance-prod": map[string]any{
252+
"type": createInstanceToolTypeMCP,
253+
"source": "my-cloud-sql-source",
254+
},
255+
"create-instance-dev": map[string]any{
256+
"type": createInstanceToolTypeMCP,
257+
"source": "my-cloud-sql-source",
258+
},
259+
},
260+
}
261+
}
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
// Copyright 2026 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package cloudsqlmssql
16+
17+
import (
18+
"context"
19+
"os"
20+
"regexp"
21+
"strings"
22+
"testing"
23+
"time"
24+
25+
"github.com/google/uuid"
26+
"github.com/googleapis/genai-toolbox/internal/testutils"
27+
"github.com/googleapis/genai-toolbox/tests"
28+
)
29+
30+
func TestCloudSQLMSSQLMCPToolEndpoints(t *testing.T) {
31+
if os.Getenv("CLOUD_SQL_MSSQL_PROJECT") == "" {
32+
t.Skip("Skipping Cloud SQL MSSQL MCP test because environment variables are not set")
33+
}
34+
35+
sourceConfig := getCloudSQLMSSQLVars(t)
36+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
37+
defer cancel()
38+
39+
db, err := initCloudSQLMSSQLConnection(CloudSQLMSSQLProject, CloudSQLMSSQLRegion, CloudSQLMSSQLInstance, "public", CloudSQLMSSQLUser, CloudSQLMSSQLPass, CloudSQLMSSQLDatabase)
40+
if err != nil {
41+
t.Fatalf("unable to create Cloud SQL connection pool: %s", err)
42+
}
43+
defer db.Close()
44+
45+
// cleanup test environment
46+
tests.CleanupMSSQLTables(t, ctx, db)
47+
48+
// create table name with UUID
49+
tableNameParam := "param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
50+
tableNameAuth := "auth_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
51+
tableNameTemplateParam := "template_param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
52+
53+
// set up data for param tool
54+
createParamTableStmt, insertParamTableStmt, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, paramTestParams := tests.GetMSSQLParamToolInfo(tableNameParam)
55+
teardownTable1 := tests.SetupMsSQLTable(t, ctx, db, createParamTableStmt, insertParamTableStmt, tableNameParam, paramTestParams)
56+
defer teardownTable1(t)
57+
58+
// set up data for auth tool
59+
createAuthTableStmt, insertAuthTableStmt, authToolStmt, authTestParams := tests.GetMSSQLAuthToolInfo(tableNameAuth)
60+
teardownTable2 := tests.SetupMsSQLTable(t, ctx, db, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, authTestParams)
61+
defer teardownTable2(t)
62+
63+
// Write config into a file and pass it to command
64+
toolsConfig := tests.GetToolsConfig(sourceConfig, CloudSQLMSSQLToolType, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, authToolStmt)
65+
toolsConfig = tests.AddMSSQLExecuteSqlConfig(t, toolsConfig)
66+
tmplSelectCombined, tmplSelectFilterCombined := tests.GetMSSQLTmplToolStatement()
67+
toolsConfig = tests.AddTemplateParamConfig(t, toolsConfig, CloudSQLMSSQLToolType, tmplSelectCombined, tmplSelectFilterCombined, "")
68+
toolsConfig = tests.AddMSSQLPrebuiltToolConfig(t, toolsConfig)
69+
70+
cmd, cleanup, err := tests.StartCmd(ctx, toolsConfig)
71+
if err != nil {
72+
t.Fatalf("command initialization returned an error: %s", err)
73+
}
74+
defer cleanup()
75+
76+
waitCtx, cancelWait := context.WithTimeout(ctx, 10*time.Second)
77+
defer cancelWait()
78+
out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
79+
if err != nil {
80+
t.Logf("toolbox command logs: \n%s", out)
81+
t.Fatalf("toolbox didn't start successfully: %s", err)
82+
}
83+
84+
// Get configs for tests
85+
_, mcpMyFailToolWant, _, mcpSelect1Want := tests.GetMSSQLWants()
86+
87+
// Verify the tools list manifest
88+
expectedTools := tests.GetBaseMCPExpectedTools()
89+
expectedTools = append(expectedTools, tests.GetExecuteSQLMCPExpectedTools()...)
90+
expectedTools = append(expectedTools, tests.GetTemplateParamMCPExpectedTools()...)
91+
expectedTools = append(expectedTools, tests.MCPToolManifest{
92+
Name: "list_tables",
93+
Description: "Lists tables in the database.",
94+
InputSchema: map[string]any{
95+
"type": "object",
96+
"properties": map[string]any{
97+
"table_names": map[string]any{
98+
"default": "",
99+
"description": "Optional: A comma-separated list of table names. If empty, details for all tables will be listed.",
100+
"type": "string",
101+
},
102+
"output_format": map[string]any{
103+
"default": "detailed",
104+
"description": "Optional: Use 'simple' for names only or 'detailed' for full info.",
105+
"type": "string",
106+
},
107+
},
108+
"required": []any{},
109+
},
110+
})
111+
112+
t.Run("verify tools/list registry returns complete manifest", func(t *testing.T) {
113+
tests.RunMCPToolsListMethod(t, expectedTools)
114+
})
115+
116+
// Run tests via MCP
117+
tests.RunMCPToolCallMethod(t, mcpMyFailToolWant, mcpSelect1Want)
118+
119+
// Run specific MSSQL tool tests via MCP
120+
tests.RunMSSQLListTablesTest(t, tableNameParam, tableNameAuth, tests.WithMCPExec())
121+
}

0 commit comments

Comments
 (0)