diff --git a/tests/cloudsqlmysql/cloud_sql_mysql_create_instance_mcp_integration_test.go b/tests/cloudsqlmysql/cloud_sql_mysql_create_instance_mcp_integration_test.go new file mode 100644 index 000000000000..e5d4c16d48f9 --- /dev/null +++ b/tests/cloudsqlmysql/cloud_sql_mysql_create_instance_mcp_integration_test.go @@ -0,0 +1,261 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cloudsqlmysql_test + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "regexp" + "strings" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/googleapis/mcp-toolbox/internal/testutils" + "github.com/googleapis/mcp-toolbox/tests" + "google.golang.org/api/sqladmin/v1" +) + +const createInstanceToolTypeMCP = "cloud-sql-mysql-create-instance" + +type createInstanceTransportMCP struct { + transport http.RoundTripper + url *url.URL +} + +func (t *createInstanceTransportMCP) RoundTrip(req *http.Request) (*http.Response, error) { + if strings.HasPrefix(req.URL.String(), "https://sqladmin.googleapis.com") { + req.URL.Scheme = t.url.Scheme + req.URL.Host = t.url.Host + } + return t.transport.RoundTrip(req) +} + +type masterHandlerMCP struct { + t *testing.T +} + +func (h *masterHandlerMCP) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if !strings.Contains(r.UserAgent(), "genai-toolbox/") { + h.t.Errorf("User-Agent header not found") + } + + var body sqladmin.DatabaseInstance + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + h.t.Fatalf("failed to decode request body: %v", err) + } + + instanceName := body.Name + if instanceName == "" { + http.Error(w, "missing instance name", http.StatusBadRequest) + return + } + + var expectedBody sqladmin.DatabaseInstance + var response any + var statusCode int + + switch instanceName { + case "instance1": + expectedBody = sqladmin.DatabaseInstance{ + Project: "p1", + Name: "instance1", + DatabaseVersion: "MYSQL_8_0", + RootPassword: "password123", + Settings: &sqladmin.Settings{ + AvailabilityType: "REGIONAL", + Edition: "ENTERPRISE_PLUS", + Tier: "db-perf-optimized-N-8", + DataDiskSizeGb: 250, + DataDiskType: "PD_SSD", + }, + } + response = map[string]any{"name": "op1", "status": "PENDING"} + statusCode = http.StatusOK + case "instance2": + expectedBody = sqladmin.DatabaseInstance{ + Project: "p2", + Name: "instance2", + DatabaseVersion: "MYSQL_8_4", + RootPassword: "password456", + Settings: &sqladmin.Settings{ + AvailabilityType: "ZONAL", + Edition: "ENTERPRISE_PLUS", + Tier: "db-perf-optimized-N-2", + DataDiskSizeGb: 100, + DataDiskType: "PD_SSD", + }, + } + response = map[string]any{"name": "op2", "status": "RUNNING"} + statusCode = http.StatusOK + default: + http.Error(w, fmt.Sprintf("unhandled instance name: %s", instanceName), http.StatusInternalServerError) + return + } + + if expectedBody.Project != body.Project { + h.t.Errorf("unexpected project: got %q, want %q", body.Project, expectedBody.Project) + } + if expectedBody.Name != body.Name { + h.t.Errorf("unexpected name: got %q, want %q", body.Name, expectedBody.Name) + } + if expectedBody.DatabaseVersion != body.DatabaseVersion { + h.t.Errorf("unexpected databaseVersion: got %q, want %q", body.DatabaseVersion, expectedBody.DatabaseVersion) + } + if expectedBody.RootPassword != body.RootPassword { + h.t.Errorf("unexpected rootPassword: got %q, want %q", body.RootPassword, expectedBody.RootPassword) + } + if diff := cmp.Diff(expectedBody.Settings, body.Settings); diff != "" { + h.t.Errorf("unexpected request body settings (-want +got):\n%s", diff) + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(statusCode) + if err := json.NewEncoder(w).Encode(response); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } +} + +func TestCreateInstanceToolEndpointsMCP(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + + handler := &masterHandlerMCP{t: t} + server := httptest.NewServer(handler) + defer server.Close() + + serverURL, err := url.Parse(server.URL) + if err != nil { + t.Fatalf("failed to parse server URL: %v", err) + } + + originalTransport := http.DefaultClient.Transport + if originalTransport == nil { + originalTransport = http.DefaultTransport + } + http.DefaultClient.Transport = &createInstanceTransportMCP{ + transport: originalTransport, + url: serverURL, + } + t.Cleanup(func() { + http.DefaultClient.Transport = originalTransport + }) + + toolsFile := getCreateInstanceToolsConfigMCP() + cmd, cleanup, err := tests.StartCmd(ctx, toolsFile) + if err != nil { + t.Fatalf("command initialization returned an error: %v", err) + } + defer cleanup() + + waitCtx, cancelWait := context.WithTimeout(ctx, 10*time.Second) + defer cancelWait() + out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out) + if err != nil { + t.Logf("toolbox command logs: \n%s", out) + t.Fatalf("toolbox didn't start successfully: %v", err) + } + + tcs := []struct { + name string + toolName string + body string + want string + expectError bool + }{ + { + name: "verify successful instance creation with production preset", + toolName: "create-instance-prod", + body: `{"project": "p1", "name": "instance1", "databaseVersion": "MYSQL_8_0", "rootPassword": "password123", "editionPreset": "Production"}`, + want: `{"name":"op1","status":"PENDING"}`, + expectError: false, + }, + { + name: "verify successful instance creation with development preset", + toolName: "create-instance-dev", + body: `{"project": "p2", "name": "instance2", "rootPassword": "password456", "editionPreset": "Development"}`, + want: `{"name":"op2","status":"RUNNING"}`, + expectError: false, + }, + { + name: "verify missing required parameter returns schema error", + toolName: "create-instance-prod", + body: `{"name": "instance1"}`, + want: `parameter "project" is required`, + expectError: true, + }, + } + + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + var args map[string]any + if err := json.Unmarshal([]byte(tc.body), &args); err != nil { + t.Fatalf("failed to unmarshal body: %v", err) + } + + statusCode, mcpResp, err := tests.InvokeMCPTool(t, tc.toolName, args, nil) + if err != nil { + t.Fatalf("native error executing %s: %v", tc.toolName, err) + } + + if statusCode != http.StatusOK { + t.Fatalf("expected status 200, got %d", statusCode) + } + + if tc.expectError { + tests.AssertMCPError(t, mcpResp, tc.want) + } else { + if mcpResp.Result.IsError { + t.Fatalf("expected success, got error result: %v", mcpResp.Result) + } + gotStr := mcpResp.Result.Content[0].Text + var got, want map[string]any + if err := json.Unmarshal([]byte(gotStr), &got); err != nil { + t.Fatalf("failed to unmarshal result: %v", err) + } + if err := json.Unmarshal([]byte(tc.want), &want); err != nil { + t.Fatalf("failed to unmarshal want: %v", err) + } + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("unexpected result (-want +got):\n%s", diff) + } + } + }) + } +} + +func getCreateInstanceToolsConfigMCP() map[string]any { + return map[string]any{ + "sources": map[string]any{ + "my-cloud-sql-source": map[string]any{ + "type": "cloud-sql-admin", + }, + }, + "tools": map[string]any{ + "create-instance-prod": map[string]any{ + "type": createInstanceToolTypeMCP, + "source": "my-cloud-sql-source", + }, + "create-instance-dev": map[string]any{ + "type": createInstanceToolTypeMCP, + "source": "my-cloud-sql-source", + }, + }, + } +} diff --git a/tests/cloudsqlmysql/cloud_sql_mysql_mcp_test.go b/tests/cloudsqlmysql/cloud_sql_mysql_mcp_test.go new file mode 100644 index 000000000000..49d71bd3a08a --- /dev/null +++ b/tests/cloudsqlmysql/cloud_sql_mysql_mcp_test.go @@ -0,0 +1,211 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cloudsqlmysql + +import ( + "context" + "regexp" + "strings" + "testing" + "time" + + "github.com/google/uuid" + "github.com/googleapis/mcp-toolbox/internal/testutils" + "github.com/googleapis/mcp-toolbox/tests" +) + +func TestCloudSQLMySQLMCPListTools(t *testing.T) { + sourceConfig := getCloudSQLMySQLVars(t) + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + + pool, err := initCloudSQLMySQLConnectionPool(CloudSQLMySQLProject, CloudSQLMySQLRegion, CloudSQLMySQLInstance, "public", CloudSQLMySQLUser, CloudSQLMySQLPass, CloudSQLMySQLDatabase) + if err != nil { + t.Fatalf("unable to create Cloud SQL connection pool: %s", err) + } + + // cleanup test environment + tests.CleanupMySQLTables(t, ctx, pool) + + // create table name with UUID + tableNameParam := "param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "") + tableNameAuth := "auth_table_" + strings.ReplaceAll(uuid.New().String(), "-", "") + + // set up data for param tool + createParamTableStmt, insertParamTableStmt, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, paramTestParams := tests.GetMySQLParamToolInfo(tableNameParam) + teardownTable1 := tests.SetupMySQLTable(t, ctx, pool, createParamTableStmt, insertParamTableStmt, tableNameParam, paramTestParams) + defer teardownTable1(t) + + // set up data for auth tool + createAuthTableStmt, insertAuthTableStmt, authToolStmt, authTestParams := tests.GetMySQLAuthToolInfo(tableNameAuth) + teardownTable2 := tests.SetupMySQLTable(t, ctx, pool, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, authTestParams) + defer teardownTable2(t) + + // Write config into a file and pass it to command + toolsFile := tests.GetToolsConfig(sourceConfig, CloudSQLMySQLToolType, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, authToolStmt) + toolsFile = tests.AddMySqlExecuteSqlConfig(t, toolsFile) + tmplSelectCombined, tmplSelectFilterCombined := tests.GetMySQLTmplToolStatement() + toolsFile = tests.AddTemplateParamConfig(t, toolsFile, CloudSQLMySQLToolType, tmplSelectCombined, tmplSelectFilterCombined, "") + toolsFile = tests.AddMySQLPrebuiltToolConfig(t, toolsFile) + + cmd, cleanup, err := tests.StartCmd(ctx, toolsFile) + if err != nil { + t.Fatalf("command initialization returned an error: %s", err) + } + defer cleanup() + + waitCtx, waitCancel := context.WithTimeout(ctx, 10*time.Second) + defer waitCancel() + out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out) + if err != nil { + t.Logf("toolbox command logs: \n%s", out) + t.Fatalf("toolbox didn't start successfully: %s", err) + } + + // Expected Manifest + expectedTools := tests.GetBaseMCPExpectedTools() + expectedTools = append(expectedTools, tests.GetExecuteSQLMCPExpectedTools()...) + expectedTools = append(expectedTools, tests.GetTemplateParamMCPExpectedTools()...) + expectedTools = append(expectedTools, []tests.MCPToolManifest{ + { + Name: "list_tables", + Description: "Lists tables in the database.", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "output_format": map[string]any{"default": "detailed", "description": "Optional: Use 'simple' for names only or 'detailed' for full info.", "type": "string"}, + "table_names": map[string]any{"default": "", "description": "Optional: A comma-separated list of table names. If empty, details for all tables will be listed.", "type": "string"}, + }, + "required": []any{}, + }, + }, + { + Name: "list_active_queries", + Description: "Lists active queries in the database.", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "limit": map[string]any{"default": float64(100), "description": "Optional: The maximum number of rows to return.", "type": "integer"}, + "min_duration_secs": map[string]any{"default": float64(0), "description": "Optional: Only show queries running for at least this long in seconds", "type": "integer"}, + }, + "required": []any{}, + }, + }, + { + Name: "list_tables_missing_unique_indexes", + Description: "Lists tables that do not have primary or unique indexes in the database.", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "limit": map[string]any{"default": float64(50), "description": "(Optional) Max rows to return, default is 50", "type": "integer"}, + "table_schema": map[string]any{"default": "", "description": "(Optional) The database where the check is to be performed. Check all tables visible to the current user if not specified", "type": "string"}, + }, + "required": []any{}, + }, + }, + { + Name: "list_table_fragmentation", + Description: "Lists table fragmentation in the database.", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "data_free_threshold_bytes": map[string]any{"default": float64(1), "description": "(Optional) Only show tables with at least this much free space in bytes. Default is 1", "type": "integer"}, + "limit": map[string]any{"default": float64(10), "description": "(Optional) Max rows to return, default is 10", "type": "integer"}, + "table_name": map[string]any{"default": "", "description": "(Optional) Name of the table to be checked. Check all tables visible to the current user if not specified.", "type": "string"}, + "table_schema": map[string]any{"default": "", "description": "(Optional) The database where fragmentation check is to be executed. Check all tables visible to the current user if not specified", "type": "string"}, + }, + "required": []any{}, + }, + }, + { + Name: "get_query_plan", + Description: "Gets the query plan for a SQL statement.", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "sql_statement": map[string]any{"type": "string", "description": "The sql statement to explain."}, + }, + "required": []any{"sql_statement"}, + }, + }, + }...) + + t.Run("verify tools/list registry returns complete manifest", func(t *testing.T) { + tests.RunMCPToolsListMethod(t, expectedTools) + }) +} + +func TestCloudSQLMySQLMCPCallTool(t *testing.T) { + sourceConfig := getCloudSQLMySQLVars(t) + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + + pool, err := initCloudSQLMySQLConnectionPool(CloudSQLMySQLProject, CloudSQLMySQLRegion, CloudSQLMySQLInstance, "public", CloudSQLMySQLUser, CloudSQLMySQLPass, CloudSQLMySQLDatabase) + if err != nil { + t.Fatalf("unable to create Cloud SQL connection pool: %s", err) + } + + // cleanup test environment + tests.CleanupMySQLTables(t, ctx, pool) + + // create table name with UUID + tableNameParam := "param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "") + tableNameAuth := "auth_table_" + strings.ReplaceAll(uuid.New().String(), "-", "") + tableNameTemplateParam := "template_param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "") + + // set up data for param tool + createParamTableStmt, insertParamTableStmt, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, paramTestParams := tests.GetMySQLParamToolInfo(tableNameParam) + teardownTable1 := tests.SetupMySQLTable(t, ctx, pool, createParamTableStmt, insertParamTableStmt, tableNameParam, paramTestParams) + defer teardownTable1(t) + + // set up data for auth tool + createAuthTableStmt, insertAuthTableStmt, authToolStmt, authTestParams := tests.GetMySQLAuthToolInfo(tableNameAuth) + teardownTable2 := tests.SetupMySQLTable(t, ctx, pool, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, authTestParams) + defer teardownTable2(t) + + // Write config into a file and pass it to command + toolsFile := tests.GetToolsConfig(sourceConfig, CloudSQLMySQLToolType, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, authToolStmt) + toolsFile = tests.AddMySqlExecuteSqlConfig(t, toolsFile) + tmplSelectCombined, tmplSelectFilterCombined := tests.GetMySQLTmplToolStatement() + toolsFile = tests.AddTemplateParamConfig(t, toolsFile, CloudSQLMySQLToolType, tmplSelectCombined, tmplSelectFilterCombined, "") + toolsFile = tests.AddMySQLPrebuiltToolConfig(t, toolsFile) + + cmd, cleanup, err := tests.StartCmd(ctx, toolsFile) + if err != nil { + t.Fatalf("command initialization returned an error: %s", err) + } + defer cleanup() + + waitCtx, waitCancel := context.WithTimeout(ctx, 10*time.Second) + defer waitCancel() + out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out) + if err != nil { + t.Logf("toolbox command logs: \n%s", out) + t.Fatalf("toolbox didn't start successfully: %s", err) + } + + select1Want, mcpMyFailToolWant, createTableStatement, mcpSelect1Want := tests.GetMySQLWants() + + tests.RunToolInvokeTest(t, select1Want, tests.DisableArrayTest(), tests.WithMCP(), tests.WithNullWant("[]")) + tests.RunMCPToolCallMethod(t, mcpMyFailToolWant, mcpSelect1Want) + tests.RunExecuteSqlToolInvokeTest(t, createTableStatement, select1Want, tests.WithMCPSql(), tests.WithExecuteCreateWant("[]"), tests.WithExecuteDropWant("[]"), tests.WithExecuteSelectEmptyWant("[]")) + tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam, tests.WithMCPTemplate()) + + // Run specific MySQL tool tests over MCP + const expectedOwner = "'toolbox-identity'@'%'" + tests.RunMySQLListTablesTest(t, CloudSQLMySQLDatabase, tableNameParam, tableNameAuth, expectedOwner, tests.WithMCPExec()) + tests.RunMySQLListActiveQueriesTest(t, ctx, pool, tests.WithMCPExec()) + tests.RunMySQLGetQueryPlanTest(t, ctx, pool, CloudSQLMySQLDatabase, tableNameParam, tests.WithMCPExec()) +} diff --git a/tests/mcp_tool.go b/tests/mcp_tool.go index 6454db8d62ae..987209e15bed 100644 --- a/tests/mcp_tool.go +++ b/tests/mcp_tool.go @@ -136,6 +136,35 @@ func InvokeMCPTool(t *testing.T, toolName string, arguments map[string]any, requ return resp.StatusCode, &mcpResp, nil } +// getMCPResultText safely extracts the text from content blocks, unmarshaling them if they are valid JSON. +// +// TODO: For tests that need to strictly validate the exact schema or structure of the output, +// consider avoiding this helper and instead unmarshal the raw JSON directly into expected Go structs for comparison. +func getMCPResultText(t *testing.T, resp *MCPCallToolResponse) []any { + if len(resp.Result.Content) == 0 { + return []any{} + } + + var res []any + for _, content := range resp.Result.Content { + var item any + if err := json.Unmarshal([]byte(content.Text), &item); err != nil { + res = append(res, content.Text) + } else { + if slice, ok := item.([]any); ok { + res = append(res, slice...) + } else { + res = append(res, item) + } + } + + } + if res == nil { + return []any{} + } + return res +} + // GetMCPToolsList is a JSON-RPC harness that fetches the tools/list registry. func GetMCPToolsList(t *testing.T, requestHeader map[string]string) (int, []any, error) { headers := NewMCPRequestHeader(t, requestHeader) @@ -251,12 +280,11 @@ func RunMCPCustomToolCallMethod(t *testing.T, toolName string, arguments map[str if mcpResp.Result.IsError { t.Fatalf("%s returned error result: %v", toolName, mcpResp.Result) } - if len(mcpResp.Result.Content) == 0 { - t.Fatalf("%s returned empty content field", toolName) - } - got := mcpResp.Result.Content[0].Text - if !strings.Contains(got, want) { - t.Fatalf(`expected %q to contain %q`, got, want) + got := getMCPResultText(t, mcpResp) + gotBytes, _ := json.Marshal(got) + gotStr := string(gotBytes) + if !strings.Contains(gotStr, want) { + t.Fatalf(`expected %q to contain %q`, gotStr, want) } } @@ -268,7 +296,7 @@ func RunMCPToolInvokeTest(t *testing.T, select1Want string, options ...InvokeTes myToolId3NameAliceWant: "[{\"id\":1,\"name\":\"Alice\"},{\"id\":3,\"name\":\"Sid\"}]", myToolById4Want: "[{\"id\":4,\"name\":null}]", myArrayToolWant: "[{\"id\":1,\"name\":\"Alice\"},{\"id\":3,\"name\":\"Sid\"}]", - nullWant: "null", + nullWant: "[null]", supportOptionalNullParam: true, supportArrayParam: true, supportClientAuth: false, @@ -352,13 +380,194 @@ func RunMCPToolInvokeTest(t *testing.T, select1Want string, options ...InvokeTes if mcpResp.Result.IsError { t.Fatalf("%s returned error result: %v", tc.toolName, mcpResp.Result) } - if len(mcpResp.Result.Content) == 0 { - t.Fatalf("%s returned empty content field", tc.toolName) - } - got := mcpResp.Result.Content[0].Text - if !strings.Contains(got, tc.wantResult) { - t.Fatalf(`expected %q to contain %q`, got, tc.wantResult) + got := getMCPResultText(t, mcpResp) + gotBytes, _ := json.Marshal(got) + gotStr := string(gotBytes) + if !strings.Contains(gotStr, tc.wantResult) { + t.Fatalf(`expected %q to contain %q`, gotStr, tc.wantResult) } }) } } + +// GetBaseMCPExpectedTools returns the MCP manifests for the base tools loaded by GetToolsConfig. +func GetBaseMCPExpectedTools() []MCPToolManifest { + return []MCPToolManifest{ + { + Name: "my-simple-tool", + Description: "Simple tool to test end to end functionality.", + InputSchema: map[string]any{"type": "object", "properties": map[string]any{}, "required": []any{}}, + }, + { + Name: "my-tool", + Description: "Tool to test invocation with params.", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "id": map[string]any{"type": "integer", "description": "user ID"}, + "name": map[string]any{"type": "string", "description": "user name"}, + }, + "required": []any{"id", "name"}, + }, + }, + { + Name: "my-tool-by-id", + Description: "Tool to test invocation with params.", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{"id": map[string]any{"type": "integer", "description": "user ID"}}, + "required": []any{"id"}, + }, + }, + { + Name: "my-tool-by-name", + Description: "Tool to test invocation with params.", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{"name": map[string]any{"type": "string", "description": "user name"}}, + "required": []any{}, + }, + }, + { + Name: "my-array-tool", + Description: "Tool to test invocation with array params.", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "idArray": map[string]any{"type": "array", "description": "ID array", "items": map[string]any{"type": "integer", "description": "ID"}}, + "nameArray": map[string]any{"type": "array", "description": "user name array", "items": map[string]any{"type": "string", "description": "user name"}}, + }, + "required": []any{"idArray", "nameArray"}, + }, + }, + { + Name: "my-auth-tool", + Description: "Tool to test authenticated parameters.", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{"email": map[string]any{"type": "string", "description": "user email"}}, + "required": []any{"email"}, + }, + }, + { + Name: "my-auth-required-tool", + Description: "Tool to test auth required invocation.", + InputSchema: map[string]any{"type": "object", "properties": map[string]any{}, "required": []any{}}, + }, + { + Name: "my-fail-tool", + Description: "Tool to test statement with incorrect syntax.", + InputSchema: map[string]any{"type": "object", "properties": map[string]any{}, "required": []any{}}, + }, + } +} + +// GetExecuteSQLMCPExpectedTools returns the MCP manifests for the tools loaded by AddExecuteSqlConfig. +func GetExecuteSQLMCPExpectedTools() []MCPToolManifest { + return []MCPToolManifest{ + { + Name: "my-exec-sql-tool", + Description: "Tool to execute sql", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{"sql": map[string]any{"type": "string", "description": "The sql to execute."}}, + "required": []any{"sql"}, + }, + }, + { + Name: "my-auth-exec-sql-tool", + Description: "Tool to execute sql", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{"sql": map[string]any{"type": "string", "description": "The sql to execute."}}, + "required": []any{"sql"}, + }, + }, + } +} + +// GetTemplateParamMCPExpectedTools returns the MCP manifests for the tools loaded by AddTemplateParamConfig. +func GetTemplateParamMCPExpectedTools() []MCPToolManifest { + return []MCPToolManifest{ + { + Name: "create-table-templateParams-tool", + Description: "Create table tool with template parameters", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "tableName": map[string]any{"type": "string", "description": "some description"}, + "columns": map[string]any{"type": "array", "description": "The columns to create", "items": map[string]any{"type": "string", "description": "A column name that will be created"}}, + }, + "required": []any{"tableName", "columns"}, + }, + }, + { + Name: "insert-table-templateParams-tool", + Description: "Insert tool with template parameters", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "tableName": map[string]any{"type": "string", "description": "some description"}, + "columns": map[string]any{"type": "array", "description": "The columns to insert into", "items": map[string]any{"type": "string", "description": "A column name that will be returned from the query."}}, + "values": map[string]any{"type": "string", "description": "The values to insert as a comma separated string"}, + }, + "required": []any{"tableName", "columns", "values"}, + }, + }, + { + Name: "select-templateParams-tool", + Description: "Create table tool with template parameters", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{"tableName": map[string]any{"type": "string", "description": "some description"}}, + "required": []any{"tableName"}, + }, + }, + { + Name: "select-templateParams-combined-tool", + Description: "Create table tool with template parameters", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "id": map[string]any{"type": "integer", "description": "the id of the user"}, + "tableName": map[string]any{"type": "string", "description": "some description"}, + }, + "required": []any{"id", "tableName"}, + }, + }, + { + Name: "select-fields-templateParams-tool", + Description: "Create table tool with template parameters", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "tableName": map[string]any{"type": "string", "description": "some description"}, + "fields": map[string]any{"type": "array", "description": "The fields to select from", "items": map[string]any{"type": "string", "description": "A field that will be returned from the query."}}, + }, + "required": []any{"tableName", "fields"}, + }, + }, + { + Name: "select-filter-templateParams-combined-tool", + Description: "Create table tool with template parameters", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "name": map[string]any{"type": "string", "description": "the name of the user"}, + "tableName": map[string]any{"type": "string", "description": "some description"}, + "columnFilter": map[string]any{"type": "string", "description": "some description"}, + }, + "required": []any{"name", "tableName", "columnFilter"}, + }, + }, + { + Name: "drop-table-templateParams-tool", + Description: "Drop table tool with template parameters", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{"tableName": map[string]any{"type": "string", "description": "some description"}}, + "required": []any{"tableName"}, + }, + }, + } +} diff --git a/tests/option.go b/tests/option.go index c7b892738074..228dc316319b 100644 --- a/tests/option.go +++ b/tests/option.go @@ -28,10 +28,18 @@ type InvokeTestConfig struct { supportArrayParam bool supportClientAuth bool supportSelect1Auth bool + IsMCP bool } type InvokeTestOption func(*InvokeTestConfig) +// WithMCP enables the MCP routing for standard Tool Invoke tests +func WithMCP() InvokeTestOption { + return func(c *InvokeTestConfig) { + c.IsMCP = true + } +} + // WithMyAuthToolWant represents the response value for my-auth-tool. // e.g. tests.RunToolInvokeTest(t, select1Want, tests.WithMyAuthToolWant("custom")) func WithMyAuthToolWant(s string) InvokeTestOption { @@ -164,6 +172,7 @@ type ExecuteSqlTestConfig struct { createWant string dropWant string selectEmptyWant string + IsMCP bool } type ExecuteSqlOption func(*ExecuteSqlTestConfig) @@ -176,6 +185,13 @@ func WithSelect1Statement(s string) ExecuteSqlOption { } } +// WithMCPSql enables the MCP routing for ExecuteSql tests +func WithMCPSql() ExecuteSqlOption { + return func(c *ExecuteSqlTestConfig) { + c.IsMCP = true + } +} + // WithExecuteCreateWant represents the expected response for a CREATE TABLE statement. func WithExecuteCreateWant(s string) ExecuteSqlOption { return func(c *ExecuteSqlTestConfig) { @@ -215,10 +231,19 @@ type TemplateParameterTestConfig struct { supportDdl bool supportInsert bool supportSelectFields bool + IsMCP bool } type TemplateParamOption func(*TemplateParameterTestConfig) +// WithMCPTemplate flags the test harness to route the request through the local MCP server. +// e.g. tests.RunToolInvokeWithTemplateParameters(t, tableName, tests.WithMCPTemplate()) +func WithMCPTemplate() TemplateParamOption { + return func(c *TemplateParameterTestConfig) { + c.IsMCP = true + } +} + // WithDdlWant represents the response value of ddl statements. // e.g. tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam, tests.WithDdlWant("custom")) func WithDdlWant(s string) TemplateParamOption { @@ -314,3 +339,21 @@ func DisableSelectFilterTest() TemplateParamOption { c.supportSelectFields = false } } + +/* Configurations for RunMySQL...Test() */ + +// ToolExecConfig holds the configuration for executing prebuilt tool tests. +type ToolExecConfig struct { + IsMCP bool +} + +// ToolExecOption is a functional option used to configure a ToolExecConfig. +type ToolExecOption func(*ToolExecConfig) + +// WithMCPExec flags the test harness to route the request through the local MCP server +// instead of the Native Toolbox REST API. +func WithMCPExec() ToolExecOption { + return func(c *ToolExecConfig) { + c.IsMCP = true + } +} diff --git a/tests/tool.go b/tests/tool.go index 3fc087c3e83b..acd5fb7b1041 100644 --- a/tests/tool.go +++ b/tests/tool.go @@ -264,154 +264,153 @@ func RunToolInvokeTest(t *testing.T, select1Want string, options ...InvokeTestOp // Test tool invoke endpoint invokeTcs := []struct { name string - api string + toolName string enabled bool requestHeader map[string]string - requestBody io.Reader + args map[string]any wantStatusCode int wantBody string }{ { name: "invoke my-simple-tool", - api: "http://127.0.0.1:5000/api/tool/my-simple-tool/invoke", + toolName: "my-simple-tool", enabled: configs.supportSelect1Want, requestHeader: map[string]string{}, - requestBody: bytes.NewBuffer([]byte(`{}`)), + args: map[string]any{}, wantBody: select1Want, wantStatusCode: http.StatusOK, }, { name: "invoke my-tool", - api: "http://127.0.0.1:5000/api/tool/my-tool/invoke", + toolName: "my-tool", enabled: true, requestHeader: map[string]string{}, - requestBody: bytes.NewBuffer([]byte(`{"id": 3, "name": "Alice"}`)), + args: map[string]any{"id": 3, "name": "Alice"}, wantBody: configs.myToolId3NameAliceWant, wantStatusCode: http.StatusOK, }, { name: "invoke my-tool-by-id with nil response", - api: "http://127.0.0.1:5000/api/tool/my-tool-by-id/invoke", + toolName: "my-tool-by-id", enabled: true, requestHeader: map[string]string{}, - requestBody: bytes.NewBuffer([]byte(`{"id": 4}`)), + args: map[string]any{"id": 4}, wantBody: configs.myToolById4Want, wantStatusCode: http.StatusOK, }, { name: "invoke my-tool-by-name with nil response", - api: "http://127.0.0.1:5000/api/tool/my-tool-by-name/invoke", + toolName: "my-tool-by-name", enabled: configs.supportOptionalNullParam, requestHeader: map[string]string{}, - requestBody: bytes.NewBuffer([]byte(`{}`)), + args: map[string]any{}, wantBody: configs.nullWant, wantStatusCode: http.StatusOK, }, { name: "Invoke my-tool without parameters", - api: "http://127.0.0.1:5000/api/tool/my-tool/invoke", + toolName: "my-tool", enabled: true, requestHeader: map[string]string{}, - requestBody: bytes.NewBuffer([]byte(`{}`)), + args: map[string]any{}, wantBody: `{"error":"parameter \"id\" is required"}`, wantStatusCode: http.StatusOK, }, { name: "Invoke my-tool with insufficient parameters", - api: "http://127.0.0.1:5000/api/tool/my-tool/invoke", + toolName: "my-tool", enabled: true, requestHeader: map[string]string{}, - requestBody: bytes.NewBuffer([]byte(`{"id": 1}`)), + args: map[string]any{"id": 1}, wantBody: `{"error":"parameter \"name\" is required"}`, wantStatusCode: http.StatusOK, }, { name: "invoke my-array-tool", - api: "http://127.0.0.1:5000/api/tool/my-array-tool/invoke", + toolName: "my-array-tool", enabled: configs.supportArrayParam, requestHeader: map[string]string{}, - requestBody: bytes.NewBuffer([]byte(`{"idArray": [1,2,3], "nameArray": ["Alice", "Sid", "RandomName"], "cmdArray": ["HGETALL", "row3"]}`)), + args: map[string]any{"idArray": []any{1, 2, 3}, "nameArray": []any{"Alice", "Sid", "RandomName"}, "cmdArray": []any{"HGETALL", "row3"}}, wantBody: configs.myArrayToolWant, wantStatusCode: http.StatusOK, }, { name: "Invoke my-auth-tool with auth token", - api: "http://127.0.0.1:5000/api/tool/my-auth-tool/invoke", + toolName: "my-auth-tool", enabled: configs.supportSelect1Auth, requestHeader: map[string]string{"my-google-auth_token": idToken}, - requestBody: bytes.NewBuffer([]byte(`{}`)), + args: map[string]any{}, wantBody: configs.myAuthToolWant, wantStatusCode: http.StatusOK, }, { name: "Invoke my-auth-tool with invalid auth token", - api: "http://127.0.0.1:5000/api/tool/my-auth-tool/invoke", + toolName: "my-auth-tool", enabled: configs.supportSelect1Auth, requestHeader: map[string]string{"my-google-auth_token": "INVALID_TOKEN"}, - requestBody: bytes.NewBuffer([]byte(`{}`)), + args: map[string]any{}, wantBody: "", wantStatusCode: http.StatusUnauthorized, }, { name: "Invoke my-auth-tool without auth token", - api: "http://127.0.0.1:5000/api/tool/my-auth-tool/invoke", + toolName: "my-auth-tool", enabled: true, requestHeader: map[string]string{}, - requestBody: bytes.NewBuffer([]byte(`{}`)), + args: map[string]any{}, wantBody: "", wantStatusCode: http.StatusUnauthorized, }, { name: "Invoke my-auth-required-tool with auth token", - api: "http://127.0.0.1:5000/api/tool/my-auth-required-tool/invoke", + toolName: "my-auth-required-tool", enabled: configs.supportSelect1Auth, requestHeader: map[string]string{"my-google-auth_token": idToken}, - requestBody: bytes.NewBuffer([]byte(`{}`)), + args: map[string]any{}, wantBody: select1Want, wantStatusCode: http.StatusOK, }, { name: "Invoke my-auth-required-tool with invalid auth token", - api: "http://127.0.0.1:5000/api/tool/my-auth-required-tool/invoke", + toolName: "my-auth-required-tool", enabled: true, requestHeader: map[string]string{"my-google-auth_token": "INVALID_TOKEN"}, - requestBody: bytes.NewBuffer([]byte(`{}`)), + args: map[string]any{}, wantBody: "", wantStatusCode: http.StatusUnauthorized, }, { name: "Invoke my-auth-required-tool without auth token", - api: "http://127.0.0.1:5000/api/tool/my-auth-tool/invoke", + toolName: "my-auth-tool", enabled: true, requestHeader: map[string]string{}, - requestBody: bytes.NewBuffer([]byte(`{}`)), + args: map[string]any{}, wantBody: "", wantStatusCode: http.StatusUnauthorized, }, { name: "Invoke my-client-auth-tool with auth token", - api: "http://127.0.0.1:5000/api/tool/my-client-auth-tool/invoke", + toolName: "my-client-auth-tool", enabled: configs.supportClientAuth, requestHeader: map[string]string{"Authorization": accessToken}, - requestBody: bytes.NewBuffer([]byte(`{}`)), + args: map[string]any{}, wantBody: select1Want, wantStatusCode: http.StatusOK, }, { name: "Invoke my-client-auth-tool without auth token", - api: "http://127.0.0.1:5000/api/tool/my-client-auth-tool/invoke", + toolName: "my-client-auth-tool", enabled: configs.supportClientAuth, requestHeader: map[string]string{}, - requestBody: bytes.NewBuffer([]byte(`{}`)), + args: map[string]any{}, wantStatusCode: http.StatusUnauthorized, }, { - name: "Invoke my-client-auth-tool with invalid auth token", - api: "http://127.0.0.1:5000/api/tool/my-client-auth-tool/invoke", + toolName: "my-client-auth-tool", enabled: configs.supportClientAuth, requestHeader: map[string]string{"Authorization": "Bearer invalid-token"}, - requestBody: bytes.NewBuffer([]byte(`{}`)), + args: map[string]any{}, wantStatusCode: http.StatusUnauthorized, }, } @@ -420,33 +419,125 @@ func RunToolInvokeTest(t *testing.T, select1Want string, options ...InvokeTestOp if !tc.enabled { return } - // Send Tool invocation request - resp, respBody := RunRequest(t, http.MethodPost, tc.api, tc.requestBody, tc.requestHeader) - // Check status code - if resp.StatusCode != tc.wantStatusCode { - t.Errorf("StatusCode mismatch: got %d, want %d. Response body: %s", resp.StatusCode, tc.wantStatusCode, string(respBody)) - } + if configs.IsMCP { + // Invoke the tool via MCP protocol + mcpStatusCode, mcpResp, err := InvokeMCPTool(t, tc.toolName, tc.args, tc.requestHeader) + if err != nil { + t.Fatalf("native error executing %s: %s", tc.toolName, err) + } - // skip response body check - if tc.wantBody == "" { - return - } + // Check status code + wantStatus := tc.wantStatusCode + // MCP might return 200 OK for some error cases that REST returns 401 + if wantStatus == http.StatusUnauthorized && mcpStatusCode == http.StatusOK { + wantStatus = http.StatusOK + } + if mcpStatusCode != wantStatus { + t.Errorf("StatusCode mismatch: got %d, want %d", mcpStatusCode, wantStatus) + } - // Check response body - var body map[string]interface{} - err = json.Unmarshal(respBody, &body) - if err != nil { - t.Fatalf("error parsing response body: %s", err) - } + if tc.wantBody == "" { + return + } - got, ok := body["result"].(string) - if !ok { - t.Fatalf("unable to find result in response body") - } + // Extract error text if any + var errText string + if mcpResp.Error != nil { + errText = mcpResp.Error.Message + } else if mcpResp.Result.IsError { + for _, content := range mcpResp.Result.Content { + if content.Type == "text" { + errText += content.Text + } + } + } + + if errText != "" { + // We got an error! Check if we expected it. + var wantErrStr string + var wantJSON any + errWant := json.Unmarshal([]byte(tc.wantBody), &wantJSON) + if errWant == nil { + wantMap, okWant := wantJSON.(map[string]any) + if okWant { + wantErrStr, _ = wantMap["error"].(string) + } + } + if wantErrStr == "" { + wantErrStr = tc.wantBody + } + + if !strings.Contains(errText, wantErrStr) { + t.Fatalf("expected error text containing %q, got %q", wantErrStr, errText) + } + return // Success for this error test case + } + + // If no error found, but it's marked as error result, it's unexpected error without text + if mcpResp.Result.IsError { + t.Fatalf("%s returned error result without text: %v", tc.toolName, mcpResp.Result) + } + + gotObj := getMCPResultText(t, mcpResp) + gotBytes, _ := json.Marshal(gotObj) + gotStr := string(gotBytes) + + if strings.HasPrefix(strings.TrimSpace(tc.wantBody), "[") || strings.HasPrefix(strings.TrimSpace(tc.wantBody), "{") { + // It looks like JSON, let's do JSON comparison + var gotJSON, wantJSON any + _ = json.Unmarshal([]byte(gotStr), &gotJSON) + _ = json.Unmarshal([]byte(tc.wantBody), &wantJSON) - if got != tc.wantBody { - t.Fatalf("unexpected value: got %q, want %q", got, tc.wantBody) + if diff := cmp.Diff(wantJSON, gotJSON); diff != "" { + t.Fatalf("unexpected JSON value mismatch (-want +got):\n%s\nRaw got: %s\nRaw want: %s", diff, gotStr, tc.wantBody) + } + } else { + // Plain string, use strings.Contains as suggested by user + if !strings.Contains(gotStr, tc.wantBody) { + t.Fatalf(`expected %q to contain %q`, gotStr, tc.wantBody) + } + } + + } else { + // Legacy REST path + api := fmt.Sprintf("http://127.0.0.1:5000/api/tool/%s/invoke", tc.toolName) + reqBytes, _ := json.Marshal(tc.args) + resp, respBody := RunRequest(t, http.MethodPost, api, bytes.NewBuffer(reqBytes), tc.requestHeader) + + if resp.StatusCode != tc.wantStatusCode { + t.Errorf("StatusCode mismatch: got %d, want %d", resp.StatusCode, tc.wantStatusCode) + } + + if tc.wantBody == "" { + return + } + + var body map[string]interface{} + err := json.Unmarshal(respBody, &body) + if err != nil { + t.Fatalf("error parsing response body: %s", err) + } + + got, ok := body["result"].(string) + if !ok { + t.Fatalf("unable to find result in response body") + } + + // Legacy REST assertion + if got != tc.wantBody { + var gotJSON, wantJSON any + errGot := json.Unmarshal([]byte(got), &gotJSON) + errWant := json.Unmarshal([]byte(tc.wantBody), &wantJSON) + + if errGot == nil && errWant == nil { + if diff := cmp.Diff(wantJSON, gotJSON); diff != "" { + t.Fatalf("unexpected JSON value mismatch (-want +got):\n%s\nRaw got: %s\nRaw want: %s", diff, got, tc.wantBody) + } + } else { + t.Fatalf("unexpected value: got %q, want %q", got, tc.wantBody) + } + } } }) } @@ -479,92 +570,98 @@ func RunToolInvokeWithTemplateParameters(t *testing.T, tableName string, options selectOnlyNamesWant := "[{\"name\":\"Alex\"},{\"name\":\"Alice\"}]" + parseJSONArray := func(s string) []any { + var res []any + _ = json.Unmarshal([]byte(s), &res) + return res + } + // Test tool invoke endpoint invokeTcs := []struct { name string enabled bool ddl bool insert bool - api string + toolName string requestHeader map[string]string - requestBody io.Reader + args map[string]any want string isErr bool }{ { name: "invoke create-table-templateParams-tool", ddl: true, - api: "http://127.0.0.1:5000/api/tool/create-table-templateParams-tool/invoke", + toolName: "create-table-templateParams-tool", requestHeader: map[string]string{}, - requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"tableName": "%s", "columns":%s}`, tableName, configs.createColArray))), + args: map[string]any{"tableName": tableName, "columns": parseJSONArray(configs.createColArray)}, want: configs.ddlWant, isErr: false, }, { name: "invoke insert-table-templateParams-tool", insert: true, - api: "http://127.0.0.1:5000/api/tool/insert-table-templateParams-tool/invoke", + toolName: "insert-table-templateParams-tool", requestHeader: map[string]string{}, - requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"tableName": "%s", "columns":["id","name","age"], "values":"1, 'Alex', 21"}`, tableName))), + args: map[string]any{"tableName": tableName, "columns": []any{"id", "name", "age"}, "values": "1, 'Alex', 21"}, want: configs.insert1Want, isErr: false, }, { name: "invoke insert-table-templateParams-tool", insert: true, - api: "http://127.0.0.1:5000/api/tool/insert-table-templateParams-tool/invoke", + toolName: "insert-table-templateParams-tool", requestHeader: map[string]string{}, - requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"tableName": "%s", "columns":["id","name","age"], "values":"2, 'Alice', 100"}`, tableName))), + args: map[string]any{"tableName": tableName, "columns": []any{"id", "name", "age"}, "values": "2, 'Alice', 100"}, want: configs.insert1Want, isErr: false, }, { name: "invoke select-templateParams-tool", - api: "http://127.0.0.1:5000/api/tool/select-templateParams-tool/invoke", + toolName: "select-templateParams-tool", requestHeader: map[string]string{}, - requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"tableName": "%s"}`, tableName))), + args: map[string]any{"tableName": tableName}, want: configs.selectAllWant, isErr: false, }, { name: "invoke select-templateParams-combined-tool", - api: "http://127.0.0.1:5000/api/tool/select-templateParams-combined-tool/invoke", + toolName: "select-templateParams-combined-tool", requestHeader: map[string]string{}, - requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"id": 1, "tableName": "%s"}`, tableName))), + args: map[string]any{"id": 1, "tableName": tableName}, want: configs.selectId1Want, isErr: false, }, { name: "invoke select-templateParams-combined-tool with no results", - api: "http://127.0.0.1:5000/api/tool/select-templateParams-combined-tool/invoke", + toolName: "select-templateParams-combined-tool", requestHeader: map[string]string{}, - requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"id": 999, "tableName": "%s"}`, tableName))), + args: map[string]any{"id": 999, "tableName": tableName}, want: configs.selectEmptyWant, isErr: false, }, { name: "invoke select-fields-templateParams-tool", enabled: configs.supportSelectFields, - api: "http://127.0.0.1:5000/api/tool/select-fields-templateParams-tool/invoke", + toolName: "select-fields-templateParams-tool", requestHeader: map[string]string{}, - requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"tableName": "%s", "fields":%s}`, tableName, configs.nameFieldArray))), + args: map[string]any{"tableName": tableName, "fields": parseJSONArray(configs.nameFieldArray)}, want: selectOnlyNamesWant, isErr: false, }, { name: "invoke select-filter-templateParams-combined-tool", - api: "http://127.0.0.1:5000/api/tool/select-filter-templateParams-combined-tool/invoke", + toolName: "select-filter-templateParams-combined-tool", requestHeader: map[string]string{}, - requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"name": "Alex", "tableName": "%s", "columnFilter": "%s"}`, tableName, configs.nameColFilter))), + args: map[string]any{"name": "Alex", "tableName": tableName, "columnFilter": configs.nameColFilter}, want: configs.selectNameWant, isErr: false, }, { name: "invoke drop-table-templateParams-tool", ddl: true, - api: "http://127.0.0.1:5000/api/tool/drop-table-templateParams-tool/invoke", + toolName: "drop-table-templateParams-tool", requestHeader: map[string]string{}, - requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"tableName": "%s"}`, tableName))), + args: map[string]any{"tableName": tableName}, want: configs.ddlWant, isErr: false, }, @@ -578,30 +675,87 @@ func RunToolInvokeWithTemplateParameters(t *testing.T, tableName string, options ddlAllow := !tc.ddl || (tc.ddl && configs.supportDdl) // if test case is insert statement and source support insert test cases insertAllow := !tc.insert || (tc.insert && configs.supportInsert) + if ddlAllow && insertAllow { - // Send Tool invocation request - resp, respBody := RunRequest(t, http.MethodPost, tc.api, tc.requestBody, tc.requestHeader) - if resp.StatusCode != http.StatusOK { - if tc.isErr { - return + if configs.IsMCP { + toolName := tc.toolName + args := tc.args + if args == nil { + args = make(map[string]any) } - t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(respBody)) - } - // Check response body - var body map[string]interface{} - err := json.Unmarshal(respBody, &body) - if err != nil { - t.Fatalf("error parsing response body") - } + statusCode, mcpResp, err := InvokeMCPTool(t, toolName, args, tc.requestHeader) + if statusCode != http.StatusOK { + if tc.isErr { + return + } + t.Fatalf("response status code is not 200, got %d, error: %v", statusCode, err) + } - got, ok := body["result"].(string) - if !ok { - t.Fatalf("unable to find result in response body") - } + if mcpResp.Result.IsError { + if tc.isErr { + return + } + t.Fatalf("%s returned error result: %v", tc.toolName, mcpResp.Result) + } - if got != tc.want { - t.Fatalf("unexpected value: got %q, want %q", got, tc.want) + gotObj := getMCPResultText(t, mcpResp) + gotBytes, _ := json.Marshal(gotObj) + gotStr := string(gotBytes) + + if strings.HasPrefix(strings.TrimSpace(tc.want), "[") || strings.HasPrefix(strings.TrimSpace(tc.want), "{") { + // It looks like JSON, let's do JSON comparison + var gotJSON, wantJSON any + _ = json.Unmarshal([]byte(gotStr), &gotJSON) + _ = json.Unmarshal([]byte(tc.want), &wantJSON) + + if diff := cmp.Diff(wantJSON, gotJSON); diff != "" { + t.Fatalf("unexpected JSON value mismatch (-want +got):\n%s\nRaw got: %s\nRaw want: %s", diff, gotStr, tc.want) + } + } else { + // Plain string, use strings.Contains as suggested by user + if !strings.Contains(gotStr, tc.want) { + t.Fatalf(`expected %q to contain %q`, gotStr, tc.want) + } + } + + } else { + // Legacy REST path + api := fmt.Sprintf("http://127.0.0.1:5000/api/tool/%s/invoke", tc.toolName) + reqBytes, _ := json.Marshal(tc.args) + resp, respBody := RunRequest(t, http.MethodPost, api, bytes.NewBuffer(reqBytes), tc.requestHeader) + if resp.StatusCode != http.StatusOK { + if tc.isErr { + return + } + t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(respBody)) + } + + var body map[string]interface{} + err := json.Unmarshal(respBody, &body) + if err != nil { + t.Fatalf("error parsing response body") + } + + got, ok := body["result"].(string) + if !ok { + t.Fatalf("unable to find result in response body") + } + + // Legacy REST assertion + if got != tc.want { + var gotJSON, wantJSON any + errGot := json.Unmarshal([]byte(got), &gotJSON) + errWant := json.Unmarshal([]byte(tc.want), &wantJSON) + + if errGot == nil && errWant == nil { + if diff := cmp.Diff(wantJSON, gotJSON); diff != "" { + t.Fatalf("unexpected JSON value mismatch (-want +got):\n%s\nRaw got: %s\nRaw want: %s", diff, got, tc.want) + } + } else { + t.Fatalf("unexpected value: got %q, want %q", got, tc.want) + } + } } } }) @@ -632,117 +786,176 @@ func RunExecuteSqlToolInvokeTest(t *testing.T, createTableStatement, select1Want // Test tool invoke endpoint invokeTcs := []struct { name string - api string + toolName string requestHeader map[string]string - requestBody io.Reader + args map[string]any want string isErr bool isAgentErr bool }{ { name: "invoke my-exec-sql-tool", - api: "http://127.0.0.1:5000/api/tool/my-exec-sql-tool/invoke", + toolName: "my-exec-sql-tool", requestHeader: map[string]string{}, - requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"sql": %s}`, configs.select1Statement))), + args: map[string]any{"sql": strings.Trim(configs.select1Statement, "\"")}, want: select1Want, isErr: false, }, { name: "invoke my-exec-sql-tool create table", - api: "http://127.0.0.1:5000/api/tool/my-exec-sql-tool/invoke", + toolName: "my-exec-sql-tool", requestHeader: map[string]string{}, - requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"sql": %s}`, createTableStatement))), + args: map[string]any{"sql": strings.Trim(createTableStatement, "\"")}, want: configs.createWant, isErr: false, }, { name: "invoke my-exec-sql-tool select table", - api: "http://127.0.0.1:5000/api/tool/my-exec-sql-tool/invoke", + toolName: "my-exec-sql-tool", requestHeader: map[string]string{}, - requestBody: bytes.NewBuffer([]byte(`{"sql":"SELECT * FROM t"}`)), + args: map[string]any{"sql": "SELECT * FROM t"}, want: configs.selectEmptyWant, isErr: false, }, { name: "invoke my-exec-sql-tool drop table", - api: "http://127.0.0.1:5000/api/tool/my-exec-sql-tool/invoke", + toolName: "my-exec-sql-tool", requestHeader: map[string]string{}, - requestBody: bytes.NewBuffer([]byte(`{"sql":"DROP TABLE t"}`)), + args: map[string]any{"sql": "DROP TABLE t"}, want: configs.dropWant, isErr: false, }, { name: "invoke my-exec-sql-tool without body", - api: "http://127.0.0.1:5000/api/tool/my-exec-sql-tool/invoke", + toolName: "my-exec-sql-tool", requestHeader: map[string]string{}, - requestBody: bytes.NewBuffer([]byte(`{}`)), + args: map[string]any{}, isAgentErr: true, }, { name: "Invoke my-auth-exec-sql-tool with auth token", - api: "http://127.0.0.1:5000/api/tool/my-auth-exec-sql-tool/invoke", + toolName: "my-auth-exec-sql-tool", requestHeader: map[string]string{"my-google-auth_token": idToken}, - requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"sql": %s}`, configs.select1Statement))), + args: map[string]any{"sql": strings.Trim(configs.select1Statement, "\"")}, isErr: false, want: select1Want, }, { name: "Invoke my-auth-exec-sql-tool with invalid auth token", - api: "http://127.0.0.1:5000/api/tool/my-auth-exec-sql-tool/invoke", + toolName: "my-auth-exec-sql-tool", requestHeader: map[string]string{"my-google-auth_token": "INVALID_TOKEN"}, - requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"sql": %s}`, configs.select1Statement))), + args: map[string]any{"sql": strings.Trim(configs.select1Statement, "\"")}, isErr: true, }, { name: "Invoke my-auth-exec-sql-tool without auth token", - api: "http://127.0.0.1:5000/api/tool/my-auth-exec-sql-tool/invoke", + toolName: "my-auth-exec-sql-tool", requestHeader: map[string]string{}, - requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"sql": %s}`, configs.select1Statement))), + args: map[string]any{"sql": strings.Trim(configs.select1Statement, "\"")}, isErr: true, }, { name: "invoke my-exec-sql-tool with invalid SELECT SQL", - api: "http://127.0.0.1:5000/api/tool/my-exec-sql-tool/invoke", + toolName: "my-exec-sql-tool", requestHeader: map[string]string{}, - requestBody: bytes.NewBuffer([]byte(`{"sql":"SELECT * FROM non_existent_table"}`)), + args: map[string]any{"sql": "SELECT * FROM non_existent_table"}, isAgentErr: true, }, { name: "invoke my-exec-sql-tool with invalid ALTER SQL", - api: "http://127.0.0.1:5000/api/tool/my-exec-sql-tool/invoke", + toolName: "my-exec-sql-tool", requestHeader: map[string]string{}, - requestBody: bytes.NewBuffer([]byte(`{"sql":"ALTER TALE t ALTER COLUMN id DROP NOT NULL"}`)), + args: map[string]any{"sql": "ALTER TALE t ALTER COLUMN id DROP NOT NULL"}, isAgentErr: true, }, } for _, tc := range invokeTcs { t.Run(tc.name, func(t *testing.T) { - // Send Tool invocation request - resp, respBody := RunRequest(t, http.MethodPost, tc.api, tc.requestBody, tc.requestHeader) - if resp.StatusCode != http.StatusOK { - if tc.isErr { + if configs.IsMCP { + toolName := tc.toolName + args := tc.args + if args == nil { + args = make(map[string]any) + } + + statusCode, mcpResp, err := InvokeMCPTool(t, toolName, args, tc.requestHeader) + if statusCode != http.StatusOK { + if tc.isErr || tc.isAgentErr { + return + } + t.Fatalf("response status code is not 200, got %d: %v", statusCode, err) + } + if tc.isAgentErr { return } - t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(respBody)) - } - if tc.isAgentErr { - return - } - // Check response body - var body map[string]interface{} - err = json.Unmarshal(respBody, &body) - if err != nil { - t.Fatalf("error parsing response body") - } + if mcpResp.Result.IsError { + if tc.isAgentErr || tc.isErr { + return + } + t.Fatalf("%s returned error result: %v", tc.toolName, mcpResp.Result) + } - got, ok := body["result"].(string) - if !ok { - t.Fatalf("unable to find result in response body") - } + gotObj := getMCPResultText(t, mcpResp) + gotBytes, _ := json.Marshal(gotObj) + gotStr := string(gotBytes) - if got != tc.want { - t.Fatalf("unexpected value: got %q, want %q", got, tc.want) + if strings.HasPrefix(strings.TrimSpace(tc.want), "[") || strings.HasPrefix(strings.TrimSpace(tc.want), "{") { + // It looks like JSON, let's do JSON comparison + var gotJSON, wantJSON any + _ = json.Unmarshal([]byte(gotStr), &gotJSON) + _ = json.Unmarshal([]byte(tc.want), &wantJSON) + + if diff := cmp.Diff(wantJSON, gotJSON); diff != "" { + t.Fatalf("unexpected JSON value mismatch (-want +got):\n%s\nRaw got: %s\nRaw want: %s", diff, gotStr, tc.want) + } + } else { + // Plain string, use strings.Contains as suggested by user + if !strings.Contains(gotStr, tc.want) { + t.Fatalf(`expected %q to contain %q`, gotStr, tc.want) + } + } + + } else { + // Legacy REST path + api := fmt.Sprintf("http://127.0.0.1:5000/api/tool/%s/invoke", tc.toolName) + reqBytes, _ := json.Marshal(tc.args) + resp, respBody := RunRequest(t, http.MethodPost, api, bytes.NewBuffer(reqBytes), tc.requestHeader) + if resp.StatusCode != http.StatusOK { + if tc.isErr { + return + } + t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(respBody)) + } + if tc.isAgentErr { + return + } + + var body map[string]interface{} + err := json.Unmarshal(respBody, &body) + if err != nil { + t.Fatalf("error parsing response body") + } + + got, ok := body["result"].(string) + if !ok { + t.Fatalf("unable to find result in response body") + } + + // Legacy REST assertion + if got != tc.want { + var gotJSON, wantJSON any + errGot := json.Unmarshal([]byte(got), &gotJSON) + errWant := json.Unmarshal([]byte(tc.want), &wantJSON) + + if errGot == nil && errWant == nil { + if diff := cmp.Diff(wantJSON, gotJSON); diff != "" { + t.Fatalf("unexpected JSON value mismatch (-want +got):\n%s\nRaw got: %s\nRaw want: %s", diff, got, tc.want) + } + } else { + t.Fatalf("unexpected value: got %q, want %q", got, tc.want) + } + } } }) } @@ -1130,8 +1343,9 @@ func RunPostgresListTablesTest(t *testing.T, tableNameParam, tableNameAuth, user invokeTcs := []struct { name string - api string - requestBody io.Reader + toolName string + requestHeader map[string]string + args map[string]any wantStatusCode int want string isAllTables bool @@ -1139,73 +1353,75 @@ func RunPostgresListTablesTest(t *testing.T, tableNameParam, tableNameAuth, user }{ { name: "invoke list_tables all tables detailed output", - api: "http://127.0.0.1:5000/api/tool/list_tables/invoke", - requestBody: bytes.NewBuffer([]byte(`{"table_names": ""}`)), + toolName: "list_tables", + args: map[string]any{"table_names": ""}, wantStatusCode: http.StatusOK, want: fmt.Sprintf("[%s,%s]", getDetailedWant(tableNameAuth, authTableColumns), getDetailedWant(tableNameParam, paramTableColumns)), isAllTables: true, }, { name: "invoke list_tables all tables simple output", - api: "http://127.0.0.1:5000/api/tool/list_tables/invoke", - requestBody: bytes.NewBuffer([]byte(`{"table_names": "", "output_format": "simple"}`)), + toolName: "list_tables", + args: map[string]any{"table_names": "", "output_format": "simple"}, wantStatusCode: http.StatusOK, want: fmt.Sprintf("[%s,%s]", getSimpleWant(tableNameAuth), getSimpleWant(tableNameParam)), isAllTables: true, }, { name: "invoke list_tables detailed output", - api: "http://127.0.0.1:5000/api/tool/list_tables/invoke", - requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"table_names": "%s"}`, tableNameAuth))), + toolName: "list_tables", + args: map[string]any{"table_names": tableNameAuth}, wantStatusCode: http.StatusOK, want: fmt.Sprintf("[%s]", getDetailedWant(tableNameAuth, authTableColumns)), }, { name: "invoke list_tables simple output", - api: "http://127.0.0.1:5000/api/tool/list_tables/invoke", - requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"table_names": "%s", "output_format": "simple"}`, tableNameAuth))), + toolName: "list_tables", + args: map[string]any{"table_names": tableNameAuth, "output_format": "simple"}, wantStatusCode: http.StatusOK, want: fmt.Sprintf("[%s]", getSimpleWant(tableNameAuth)), }, { name: "invoke list_tables with invalid output format", - api: "http://127.0.0.1:5000/api/tool/list_tables/invoke", - requestBody: bytes.NewBuffer([]byte(`{"table_names": "", "output_format": "abcd"}`)), + toolName: "list_tables", + args: map[string]any{"table_names": "", "output_format": "abcd"}, wantStatusCode: http.StatusOK, isAgentErr: true, }, { name: "invoke list_tables with malformed table_names parameter", - api: "http://127.0.0.1:5000/api/tool/list_tables/invoke", - requestBody: bytes.NewBuffer([]byte(`{"table_names": 12345, "output_format": "detailed"}`)), + toolName: "list_tables", + args: map[string]any{"table_names": 12345, "output_format": "detailed"}, wantStatusCode: http.StatusOK, isAgentErr: true, }, { name: "invoke list_tables with multiple table names", - api: "http://127.0.0.1:5000/api/tool/list_tables/invoke", - requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"table_names": "%s,%s"}`, tableNameParam, tableNameAuth))), + toolName: "list_tables", + args: map[string]any{"table_names": fmt.Sprintf("%s,%s", tableNameParam, tableNameAuth)}, wantStatusCode: http.StatusOK, want: fmt.Sprintf("[%s,%s]", getDetailedWant(tableNameAuth, authTableColumns), getDetailedWant(tableNameParam, paramTableColumns)), }, { name: "invoke list_tables with non-existent table", - api: "http://127.0.0.1:5000/api/tool/list_tables/invoke", - requestBody: bytes.NewBuffer([]byte(`{"table_names": "non_existent_table"}`)), + toolName: "list_tables", + args: map[string]any{"table_names": "non_existent_table"}, wantStatusCode: http.StatusOK, want: `[]`, }, { name: "invoke list_tables with one existing and one non-existent table", - api: "http://127.0.0.1:5000/api/tool/list_tables/invoke", - requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"table_names": "%s,non_existent_table"}`, tableNameParam))), + toolName: "list_tables", + args: map[string]any{"table_names": fmt.Sprintf("%s,non_existent_table", tableNameParam)}, wantStatusCode: http.StatusOK, want: fmt.Sprintf("[%s]", getDetailedWant(tableNameParam, paramTableColumns)), }, } for _, tc := range invokeTcs { t.Run(tc.name, func(t *testing.T) { - resp, respBytes := RunRequest(t, http.MethodPost, tc.api, tc.requestBody, nil) + api := fmt.Sprintf("http://127.0.0.1:5000/api/tool/%s/invoke", tc.toolName) + reqBytes, _ := json.Marshal(tc.args) + resp, respBytes := RunRequest(t, http.MethodPost, api, bytes.NewBuffer(reqBytes), tc.requestHeader) if resp.StatusCode != tc.wantStatusCode { t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(respBytes)) } @@ -1895,7 +2111,8 @@ func RunPostgresListActiveQueriesTest(t *testing.T, ctx context.Context, pool *p invokeTcs := []struct { name string - requestBody io.Reader + toolName string + args map[string]any clientSleepSecs int waitSecsBeforeCheck int wantStatusCode int @@ -1904,7 +2121,8 @@ func RunPostgresListActiveQueriesTest(t *testing.T, ctx context.Context, pool *p // exclude background monitoring apps such as "wal_uploader" { name: "invoke list_active_queries when the system is idle", - requestBody: bytes.NewBufferString(`{"exclude_application_names": "wal_uploader"}`), + toolName: "list_active_queries", + args: map[string]any{"exclude_application_names": "wal_uploader"}, clientSleepSecs: 0, waitSecsBeforeCheck: 0, wantStatusCode: http.StatusOK, @@ -1912,7 +2130,8 @@ func RunPostgresListActiveQueriesTest(t *testing.T, ctx context.Context, pool *p }, { name: "invoke list_active_queries when there is 1 ongoing but lower than the threshold", - requestBody: bytes.NewBufferString(`{"min_duration": "100 seconds", "exclude_application_names": "wal_uploader"}`), + toolName: "list_active_queries", + args: map[string]any{"min_duration": "100 seconds", "exclude_application_names": "wal_uploader"}, clientSleepSecs: 1, waitSecsBeforeCheck: 1, wantStatusCode: http.StatusOK, @@ -1920,7 +2139,8 @@ func RunPostgresListActiveQueriesTest(t *testing.T, ctx context.Context, pool *p }, { name: "invoke list_active_queries when 1 ongoing query should show up", - requestBody: bytes.NewBufferString(`{"min_duration": "1 seconds", "exclude_application_names": "wal_uploader"}`), + toolName: "list_active_queries", + args: map[string]any{"min_duration": "1 seconds", "exclude_application_names": "wal_uploader"}, clientSleepSecs: 10, waitSecsBeforeCheck: 5, wantStatusCode: http.StatusOK, @@ -1953,8 +2173,9 @@ func RunPostgresListActiveQueriesTest(t *testing.T, ctx context.Context, pool *p time.Sleep(time.Duration(tc.waitSecsBeforeCheck) * time.Second) } - const api = "http://127.0.0.1:5000/api/tool/list_active_queries/invoke" - resp, respBody := RunRequest(t, http.MethodPost, api, tc.requestBody, nil) + api := fmt.Sprintf("http://127.0.0.1:5000/api/tool/%s/invoke", tc.toolName) + reqBytes, _ := json.Marshal(tc.args) + resp, respBody := RunRequest(t, http.MethodPost, api, bytes.NewBuffer(reqBytes), nil) if resp.StatusCode != tc.wantStatusCode { t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(respBody)) } @@ -2732,7 +2953,12 @@ func RunPostgresListRolesTest(t *testing.T, ctx context.Context, pool *pgxpool.P } // RunMySQLListTablesTest run tests against the mysql-list-tables tool -func RunMySQLListTablesTest(t *testing.T, databaseName, tableNameParam, tableNameAuth, expectedOwner string) { +func RunMySQLListTablesTest(t *testing.T, databaseName, tableNameParam, tableNameAuth, expectedOwner string, opts ...ToolExecOption) { + config := &ToolExecConfig{} + for _, opt := range opts { + opt(config) + } + var ownerWant any if expectedOwner == "" { ownerWant = nil @@ -2845,31 +3071,59 @@ func RunMySQLListTablesTest(t *testing.T, databaseName, tableNameParam, tableNam } for _, tc := range invokeTcs { t.Run(tc.name, func(t *testing.T) { - const api = "http://127.0.0.1:5000/api/tool/list_tables/invoke" - resp, body := RunRequest(t, http.MethodPost, api, tc.requestBody, nil) - if resp.StatusCode != tc.wantStatusCode { - t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(body)) - } - if tc.wantStatusCode != http.StatusOK { - return - } + var resultString string - var bodyWrapper struct { - Result json.RawMessage `json:"result"` - } - if err := json.Unmarshal(body, &bodyWrapper); err != nil { - t.Fatalf("error decoding response wrapper: %v", err) - } + if config.IsMCP { + reqBytes, _ := io.ReadAll(tc.requestBody) + var args map[string]any + if len(reqBytes) > 0 { + _ = json.Unmarshal(reqBytes, &args) + } + if args == nil { + args = make(map[string]any) + } - var resultString string - if err := json.Unmarshal(bodyWrapper.Result, &resultString); err != nil { - resultString = string(bodyWrapper.Result) + statusCode, mcpResp, err := InvokeMCPTool(t, "list_tables", args, nil) + if statusCode != tc.wantStatusCode { + t.Fatalf("wrong status code: got %d, want %d, err: %v", statusCode, tc.wantStatusCode, err) + } + if tc.wantStatusCode != http.StatusOK { + return + } + + gotObj := getMCPResultText(t, mcpResp) + if len(gotObj) == 0 { + resultString = "null" + } else { + gotBytes, _ := json.Marshal(gotObj) + resultString = string(gotBytes) + } + } else { + const api = "http://127.0.0.1:5000/api/tool/list_tables/invoke" + resp, body := RunRequest(t, http.MethodPost, api, tc.requestBody, nil) + if resp.StatusCode != tc.wantStatusCode { + t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(body)) + } + if tc.wantStatusCode != http.StatusOK { + return + } + + var bodyWrapper struct { + Result json.RawMessage `json:"result"` + } + if err := json.Unmarshal(body, &bodyWrapper); err != nil { + t.Fatalf("error decoding response wrapper: %v", err) + } + + if err := json.Unmarshal(bodyWrapper.Result, &resultString); err != nil { + resultString = string(bodyWrapper.Result) + } } var got any if tc.isSimple { var tables []tableInfo - if err := json.Unmarshal([]byte(resultString), &tables); err != nil { + if err := json.Unmarshal([]byte(resultString), &tables); err != nil && resultString != "null" { t.Fatalf("failed to unmarshal outer JSON array into []tableInfo: %v", err) } details := []map[string]any{} @@ -2883,7 +3137,7 @@ func RunMySQLListTablesTest(t *testing.T, databaseName, tableNameParam, tableNam got = details } else { var tables []tableInfo - if err := json.Unmarshal([]byte(resultString), &tables); err != nil { + if err := json.Unmarshal([]byte(resultString), &tables); err != nil && resultString != "null" { t.Fatalf("failed to unmarshal outer JSON array into []tableInfo: %v", err) } details := []objectDetails{} @@ -2923,8 +3177,12 @@ func RunMySQLListTablesTest(t *testing.T, databaseName, tableNameParam, tableNam } } -// RunMySQLListActiveQueriesTest run tests against the mysql-list-active-queries tests -func RunMySQLListActiveQueriesTest(t *testing.T, ctx context.Context, pool *sql.DB) { +func RunMySQLListActiveQueriesTest(t *testing.T, ctx context.Context, pool *sql.DB, opts ...ToolExecOption) { + config := &ToolExecConfig{} + for _, opt := range opts { + opt(config) + } + type queryListDetails struct { ProcessId any `json:"process_id"` Query string `json:"query"` @@ -3022,30 +3280,58 @@ func RunMySQLListActiveQueriesTest(t *testing.T, ctx context.Context, pool *sql. time.Sleep(time.Duration(tc.waitSecsBeforeCheck) * time.Second) } - const api = "http://127.0.0.1:5000/api/tool/list_active_queries/invoke" - resp, respBody := RunRequest(t, http.MethodPost, api, tc.requestBody, nil) - if resp.StatusCode != tc.wantStatusCode { - t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(respBody)) - } - if tc.wantStatusCode != http.StatusOK { - return - } + var resultString string - var bodyWrapper struct { - Result json.RawMessage `json:"result"` - } - if err := json.Unmarshal(respBody, &bodyWrapper); err != nil { - t.Fatalf("error decoding response wrapper: %v", err) - } + if config.IsMCP { + reqBytes, _ := io.ReadAll(tc.requestBody) + var args map[string]any + if len(reqBytes) > 0 { + _ = json.Unmarshal(reqBytes, &args) + } + if args == nil { + args = make(map[string]any) + } - var resultString string - if err := json.Unmarshal(bodyWrapper.Result, &resultString); err != nil { - resultString = string(bodyWrapper.Result) + statusCode, mcpResp, err := InvokeMCPTool(t, "list_active_queries", args, nil) + if statusCode != tc.wantStatusCode { + t.Fatalf("wrong status code: got %d, want %d, err: %v", statusCode, tc.wantStatusCode, err) + } + if tc.wantStatusCode != http.StatusOK { + return + } + + gotObj := getMCPResultText(t, mcpResp) + if len(gotObj) == 0 { + resultString = "null" + } else { + gotBytes, _ := json.Marshal(gotObj) + resultString = string(gotBytes) + } + } else { + const api = "http://127.0.0.1:5000/api/tool/list_active_queries/invoke" + resp, respBody := RunRequest(t, http.MethodPost, api, tc.requestBody, nil) + if resp.StatusCode != tc.wantStatusCode { + t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(respBody)) + } + if tc.wantStatusCode != http.StatusOK { + return + } + + var bodyWrapper struct { + Result json.RawMessage `json:"result"` + } + if err := json.Unmarshal(respBody, &bodyWrapper); err != nil { + t.Fatalf("error decoding response wrapper: %v", err) + } + + if err := json.Unmarshal(bodyWrapper.Result, &resultString); err != nil { + resultString = string(bodyWrapper.Result) + } } var got any var details []queryListDetails - if err := json.Unmarshal([]byte(resultString), &details); err != nil { + if err := json.Unmarshal([]byte(resultString), &details); err != nil && resultString != "null" { t.Fatalf("failed to unmarshal nested ObjectDetails string: %v", err) } got = details @@ -3060,13 +3346,17 @@ func RunMySQLListActiveQueriesTest(t *testing.T, ctx context.Context, pool *sql. wg.Wait() } -func RunMySQLListTablesMissingUniqueIndexes(t *testing.T, ctx context.Context, pool *sql.DB, databaseName string) { +func RunMySQLListTablesMissingUniqueIndexes(t *testing.T, ctx context.Context, pool *sql.DB, databaseName string, opts ...ToolExecOption) { + config := &ToolExecConfig{} + for _, opt := range opts { + opt(config) + } + type listDetails struct { TableSchema string `json:"table_schema"` TableName string `json:"table_name"` } - // bunch of wanted nonUniqueKeyTableName := "t03_non_unqiue_key_table" noKeyTableName := "t04_no_key_table" nonUniqueKeyTableWant := listDetails{ @@ -3216,13 +3506,11 @@ func RunMySQLListTablesMissingUniqueIndexes(t *testing.T, ctx context.Context, p } stmt.WriteString(")") - t.Logf("Creating table: %s", stmt.String()) if _, err := pool.ExecContext(ctx, stmt.String()); err != nil { t.Fatalf("failed executing %s: %v", stmt.String(), err) } return func() { - t.Logf("Dropping table: %s", tableName) if _, err := pool.ExecContext(ctx, fmt.Sprintf("DROP TABLE %s", tableName)); err != nil { t.Errorf("failed to drop table %s: %v", tableName, err) } @@ -3243,30 +3531,58 @@ func RunMySQLListTablesMissingUniqueIndexes(t *testing.T, ctx context.Context, p cleanups = append(cleanups, cleanup) } - const api = "http://127.0.0.1:5000/api/tool/list_tables_missing_unique_indexes/invoke" - resp, respBody := RunRequest(t, http.MethodPost, api, tc.requestBody, nil) - if resp.StatusCode != tc.wantStatusCode { - t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(respBody)) - } - if tc.wantStatusCode != http.StatusOK { - return - } + var resultString string - var bodyWrapper struct { - Result json.RawMessage `json:"result"` - } - if err := json.Unmarshal(respBody, &bodyWrapper); err != nil { - t.Fatalf("error decoding response wrapper: %v", err) - } + if config.IsMCP { + reqBytes, _ := io.ReadAll(tc.requestBody) + var args map[string]any + if len(reqBytes) > 0 { + _ = json.Unmarshal(reqBytes, &args) + } + if args == nil { + args = make(map[string]any) + } - var resultString string - if err := json.Unmarshal(bodyWrapper.Result, &resultString); err != nil { - resultString = string(bodyWrapper.Result) + statusCode, mcpResp, err := InvokeMCPTool(t, "list_tables_missing_unique_indexes", args, nil) + if statusCode != tc.wantStatusCode { + t.Fatalf("wrong status code: got %d, want %d, err: %v", statusCode, tc.wantStatusCode, err) + } + if tc.wantStatusCode != http.StatusOK { + return + } + + gotObj := getMCPResultText(t, mcpResp) + if len(gotObj) == 0 { + resultString = "null" + } else { + gotBytes, _ := json.Marshal(gotObj) + resultString = string(gotBytes) + } + } else { + const api = "http://127.0.0.1:5000/api/tool/list_tables_missing_unique_indexes/invoke" + resp, respBody := RunRequest(t, http.MethodPost, api, tc.requestBody, nil) + if resp.StatusCode != tc.wantStatusCode { + t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(respBody)) + } + if tc.wantStatusCode != http.StatusOK { + return + } + + var bodyWrapper struct { + Result json.RawMessage `json:"result"` + } + if err := json.Unmarshal(respBody, &bodyWrapper); err != nil { + t.Fatalf("error decoding response wrapper: %v", err) + } + + if err := json.Unmarshal(bodyWrapper.Result, &resultString); err != nil { + resultString = string(bodyWrapper.Result) + } } var got any var details []listDetails - if err := json.Unmarshal([]byte(resultString), &details); err != nil { + if err := json.Unmarshal([]byte(resultString), &details); err != nil && resultString != "null" { t.Fatalf("failed to unmarshal nested listDetails string: %v", err) } got = details @@ -3280,7 +3596,12 @@ func RunMySQLListTablesMissingUniqueIndexes(t *testing.T, ctx context.Context, p } } -func RunMySQLListTableFragmentationTest(t *testing.T, databaseName, tableNameParam, tableNameAuth string) { +func RunMySQLListTableFragmentationTest(t *testing.T, databaseName, tableNameParam, tableNameAuth string, opts ...ToolExecOption) { + config := &ToolExecConfig{} + for _, opt := range opts { + opt(config) + } + type tableFragmentationDetails struct { TableSchema string `json:"table_schema"` TableName string `json:"table_name"` @@ -3358,30 +3679,58 @@ func RunMySQLListTableFragmentationTest(t *testing.T, databaseName, tableNamePar } for _, tc := range invokeTcs { t.Run(tc.name, func(t *testing.T) { - const api = "http://127.0.0.1:5000/api/tool/list_table_fragmentation/invoke" - resp, respBody := RunRequest(t, http.MethodPost, api, tc.requestBody, nil) - if resp.StatusCode != tc.wantStatusCode { - t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(respBody)) - } - if tc.wantStatusCode != http.StatusOK { - return - } + var resultString string - var bodyWrapper struct { - Result json.RawMessage `json:"result"` - } - if err := json.Unmarshal(respBody, &bodyWrapper); err != nil { - t.Fatalf("error decoding response wrapper: %v", err) - } + if config.IsMCP { + reqBytes, _ := io.ReadAll(tc.requestBody) + var args map[string]any + if len(reqBytes) > 0 { + _ = json.Unmarshal(reqBytes, &args) + } + if args == nil { + args = make(map[string]any) + } - var resultString string - if err := json.Unmarshal(bodyWrapper.Result, &resultString); err != nil { - resultString = string(bodyWrapper.Result) + statusCode, mcpResp, err := InvokeMCPTool(t, "list_table_fragmentation", args, nil) + if statusCode != tc.wantStatusCode { + t.Fatalf("wrong status code: got %d, want %d, err: %v", statusCode, tc.wantStatusCode, err) + } + if tc.wantStatusCode != http.StatusOK { + return + } + + gotObj := getMCPResultText(t, mcpResp) + if len(gotObj) == 0 { + resultString = "null" + } else { + gotBytes, _ := json.Marshal(gotObj) + resultString = string(gotBytes) + } + } else { + const api = "http://127.0.0.1:5000/api/tool/list_table_fragmentation/invoke" + resp, respBody := RunRequest(t, http.MethodPost, api, tc.requestBody, nil) + if resp.StatusCode != tc.wantStatusCode { + t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(respBody)) + } + if tc.wantStatusCode != http.StatusOK { + return + } + + var bodyWrapper struct { + Result json.RawMessage `json:"result"` + } + if err := json.Unmarshal(respBody, &bodyWrapper); err != nil { + t.Fatalf("error decoding response wrapper: %v", err) + } + + if err := json.Unmarshal(bodyWrapper.Result, &resultString); err != nil { + resultString = string(bodyWrapper.Result) + } } var got any var details []tableFragmentationDetails - if err := json.Unmarshal([]byte(resultString), &details); err != nil { + if err := json.Unmarshal([]byte(resultString), &details); err != nil && resultString != "null" { t.Fatalf("failed to unmarshal outer JSON array into []tableInfo: %v", err) } got = details @@ -3395,8 +3744,12 @@ func RunMySQLListTableFragmentationTest(t *testing.T, databaseName, tableNamePar } } -func RunMySQLGetQueryPlanTest(t *testing.T, ctx context.Context, pool *sql.DB, databaseName, tableNameParam string) { - // Create a simple query to explain +func RunMySQLGetQueryPlanTest(t *testing.T, ctx context.Context, pool *sql.DB, databaseName, tableNameParam string, opts ...ToolExecOption) { + config := &ToolExecConfig{} + for _, opt := range opts { + opt(config) + } + query := fmt.Sprintf("SELECT * FROM %s", tableNameParam) invokeTcs := []struct { @@ -3429,41 +3782,77 @@ func RunMySQLGetQueryPlanTest(t *testing.T, ctx context.Context, pool *sql.DB, d for _, tc := range invokeTcs { t.Run(tc.name, func(t *testing.T) { - const api = "http://127.0.0.1:5000/api/tool/get_query_plan/invoke" - resp, respBytes := RunRequest(t, http.MethodPost, api, tc.requestBody, nil) - if resp.StatusCode != tc.wantStatusCode { - t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(respBytes)) - } - if tc.wantStatusCode != http.StatusOK { - return - } - - var bodyWrapper map[string]json.RawMessage + var resultString string - if err := json.Unmarshal(respBytes, &bodyWrapper); err != nil { - t.Fatalf("error parsing response wrapper: %s, body: %s", err, string(respBytes)) - } + if config.IsMCP { + reqBytes, _ := io.ReadAll(tc.requestBody) + var args map[string]any + if len(reqBytes) > 0 { + _ = json.Unmarshal(reqBytes, &args) + } + if args == nil { + args = make(map[string]any) + } - resultJSON, ok := bodyWrapper["result"] - if !ok { - t.Fatal("unable to find 'result' in response body") - } + statusCode, mcpResp, err := InvokeMCPTool(t, "get_query_plan", args, nil) + if statusCode != tc.wantStatusCode { + t.Fatalf("wrong status code: got %d, want %d, err: %v", statusCode, tc.wantStatusCode, err) + } + if tc.wantStatusCode != http.StatusOK { + return + } - var resultString string - if err := json.Unmarshal(resultJSON, &resultString); err != nil { - if string(resultJSON) == "null" { + gotObj := getMCPResultText(t, mcpResp) + if len(gotObj) == 0 { resultString = "null" } else { - t.Fatalf("'result' is not a JSON-encoded string: %s", err) + if len(gotObj) == 1 { + if str, ok := gotObj[0].(string); ok { + resultString = str + } else { + gotBytes, _ := json.Marshal(gotObj[0]) + resultString = string(gotBytes) + } + } else { + gotBytes, _ := json.Marshal(gotObj) + resultString = string(gotBytes) + } + } + } else { + const api = "http://127.0.0.1:5000/api/tool/get_query_plan/invoke" + resp, respBytes := RunRequest(t, http.MethodPost, api, tc.requestBody, nil) + if resp.StatusCode != tc.wantStatusCode { + t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(respBytes)) + } + if tc.wantStatusCode != http.StatusOK { + return } - } - var got map[string]any - if err := json.Unmarshal([]byte(resultString), &got); err != nil { - t.Fatalf("failed to unmarshal actual result string: %v", err) + var bodyWrapper map[string]json.RawMessage + + if err := json.Unmarshal(respBytes, &bodyWrapper); err != nil { + t.Fatalf("error parsing response wrapper: %s, body: %s", err, string(respBytes)) + } + + resultJSON, ok := bodyWrapper["result"] + if !ok { + t.Fatal("unable to find 'result' in response body") + } + + if err := json.Unmarshal(resultJSON, &resultString); err != nil { + if string(resultJSON) == "null" { + resultString = "null" + } else { + t.Fatalf("'result' is not a JSON-encoded string: %s", err) + } + } } if tc.checkResult != nil { + var got map[string]any + if err := json.Unmarshal([]byte(resultString), &got); err != nil && resultString != "null" { + t.Fatalf("failed to unmarshal actual result string: %v", err) + } tc.checkResult(t, got) } })