diff --git a/.github/workflows/deploy_dev_docs_to_cf.yaml b/.github/workflows/deploy_dev_docs_to_cf.yaml index 830d8fffdc11..cf56f4ca88de 100644 --- a/.github/workflows/deploy_dev_docs_to_cf.yaml +++ b/.github/workflows/deploy_dev_docs_to_cf.yaml @@ -49,7 +49,7 @@ jobs: - name: Setup Hugo uses: peaceiris/actions-hugo@75d2e84710de30f6ff7268e08f310b60ef14033f # v3 with: - hugo-version: "0.145.0" + hugo-version: "0.160.0" extended: true - name: Setup Node diff --git a/.github/workflows/deploy_previous_version_docs_to_cf.yaml b/.github/workflows/deploy_previous_version_docs_to_cf.yaml index c41b125430c5..27d122979bc8 100644 --- a/.github/workflows/deploy_previous_version_docs_to_cf.yaml +++ b/.github/workflows/deploy_previous_version_docs_to_cf.yaml @@ -76,7 +76,7 @@ jobs: - name: Setup Hugo and Node uses: peaceiris/actions-hugo@75d2e84710de30f6ff7268e08f310b60ef14033f # v3 with: - hugo-version: "0.145.0" + hugo-version: "0.160.0" extended: true - uses: actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f # v6 with: diff --git a/.github/workflows/deploy_versioned_docs_to_cf.yaml b/.github/workflows/deploy_versioned_docs_to_cf.yaml index 992b9222a962..1682413bf0d8 100644 --- a/.github/workflows/deploy_versioned_docs_to_cf.yaml +++ b/.github/workflows/deploy_versioned_docs_to_cf.yaml @@ -43,7 +43,7 @@ jobs: - name: Setup Hugo uses: peaceiris/actions-hugo@75d2e84710de30f6ff7268e08f310b60ef14033f # v3 with: - hugo-version: "0.145.0" + hugo-version: "0.160.0" extended: true - name: Setup Node diff --git a/.github/workflows/docs_preview_build_cf.yaml b/.github/workflows/docs_preview_build_cf.yaml index 9242db8097ef..ffddb831ace7 100644 --- a/.github/workflows/docs_preview_build_cf.yaml +++ b/.github/workflows/docs_preview_build_cf.yaml @@ -47,7 +47,7 @@ jobs: - name: Setup Hugo uses: peaceiris/actions-hugo@75d2e84710de30f6ff7268e08f310b60ef14033f # v3 with: - hugo-version: "0.145.0" + hugo-version: "0.160.0" extended: true - name: Setup Node diff --git a/internal/sources/cloudsqlpg/cloud_sql_pg.go b/internal/sources/cloudsqlpg/cloud_sql_pg.go index 6ab289149cb6..818e2271cb29 100644 --- a/internal/sources/cloudsqlpg/cloud_sql_pg.go +++ b/internal/sources/cloudsqlpg/cloud_sql_pg.go @@ -71,9 +71,17 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So err = pool.Ping(ctx) if err != nil { + pool.Close() return nil, fmt.Errorf("unable to connect successfully: %w", err) } + var res int + err = pool.QueryRow(ctx, "SELECT 1").Scan(&res) + if err != nil { + pool.Close() + return nil, fmt.Errorf("failed to execute 'SELECT 1' after connection: %w", err) + } + s := &Source{ Config: r, Pool: pool, diff --git a/tests/alloydb/alloydb_integration_test.go b/tests/alloydb/alloydb_integration_test.go index 5ff270bad531..27c11eb2fbc7 100644 --- a/tests/alloydb/alloydb_integration_test.go +++ b/tests/alloydb/alloydb_integration_test.go @@ -23,7 +23,6 @@ import ( "net/http" "net/http/httptest" "net/url" - "os" "reflect" "regexp" "sort" @@ -37,115 +36,6 @@ import ( "github.com/googleapis/mcp-toolbox/tests" ) -var ( - AlloyDBProject = os.Getenv("ALLOYDB_PROJECT") - AlloyDBLocation = os.Getenv("ALLOYDB_REGION") - AlloyDBCluster = os.Getenv("ALLOYDB_CLUSTER") - AlloyDBInstance = os.Getenv("ALLOYDB_INSTANCE") - AlloyDBUser = os.Getenv("ALLOYDB_POSTGRES_USER") -) - -func getAlloyDBVars(t *testing.T) map[string]string { - if AlloyDBProject == "" { - t.Fatal("'ALLOYDB_PROJECT' not set") - } - if AlloyDBLocation == "" { - t.Fatal("'ALLOYDB_REGION' not set") - } - if AlloyDBCluster == "" { - t.Fatal("'ALLOYDB_CLUSTER' not set") - } - if AlloyDBInstance == "" { - t.Fatal("'ALLOYDB_INSTANCE' not set") - } - if AlloyDBUser == "" { - t.Fatal("'ALLOYDB_USER' not set") - } - return map[string]string{ - "project": AlloyDBProject, - "location": AlloyDBLocation, - "cluster": AlloyDBCluster, - "instance": AlloyDBInstance, - "user": AlloyDBUser, - } -} - -func getAlloyDBToolsConfig() map[string]any { - return map[string]any{ - "sources": map[string]any{ - "alloydb-admin-source": map[string]any{ - "type": "alloydb-admin", - }, - }, - "tools": map[string]any{ - // Tool for RunAlloyDBToolGetTest - "my-simple-tool": map[string]any{ - "type": "alloydb-list-clusters", - "source": "alloydb-admin-source", - "description": "Simple tool to test end to end functionality.", - }, - // Tool for MCP test - "my-param-tool": map[string]any{ - "type": "alloydb-list-clusters", - "source": "alloydb-admin-source", - "description": "Tool to list clusters", - }, - // Tool for MCP test that fails - "my-fail-tool": map[string]any{ - "type": "alloydb-list-clusters", - "source": "alloydb-admin-source", - "description": "Tool that will fail", - }, - // AlloyDB specific tools - "alloydb-list-clusters": map[string]any{ - "type": "alloydb-list-clusters", - "source": "alloydb-admin-source", - "description": "Lists all AlloyDB clusters in a given project and location.", - }, - "alloydb-list-users": map[string]any{ - "type": "alloydb-list-users", - "source": "alloydb-admin-source", - "description": "Lists all AlloyDB users within a specific cluster.", - }, - "alloydb-list-instances": map[string]any{ - "type": "alloydb-list-instances", - "source": "alloydb-admin-source", - "description": "Lists all AlloyDB instances within a specific cluster.", - }, - "alloydb-get-cluster": map[string]any{ - "type": "alloydb-get-cluster", - "source": "alloydb-admin-source", - "description": "Retrieves details of a specific AlloyDB cluster.", - }, - "alloydb-get-instance": map[string]any{ - "type": "alloydb-get-instance", - "source": "alloydb-admin-source", - "description": "Retrieves details of a specific AlloyDB instance.", - }, - "alloydb-get-user": map[string]any{ - "type": "alloydb-get-user", - "source": "alloydb-admin-source", - "description": "Retrieves details of a specific AlloyDB user.", - }, - "alloydb-create-cluster": map[string]any{ - "type": "alloydb-create-cluster", - "description": "create cluster", - "source": "alloydb-admin-source", - }, - "alloydb-create-instance": map[string]any{ - "type": "alloydb-create-instance", - "description": "create instance", - "source": "alloydb-admin-source", - }, - "alloydb-create-user": map[string]any{ - "type": "alloydb-create-user", - "description": "create user", - "source": "alloydb-admin-source", - }, - }, - } -} - func TestAlloyDBToolEndpoints(t *testing.T) { vars := getAlloyDBVars(t) ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) diff --git a/tests/alloydb/alloydb_mcp_test.go b/tests/alloydb/alloydb_mcp_test.go index 200a61981672..9e195e0e5141 100644 --- a/tests/alloydb/alloydb_mcp_test.go +++ b/tests/alloydb/alloydb_mcp_test.go @@ -21,6 +21,7 @@ import ( "net/http" "net/http/httptest" "net/url" + "os" "reflect" "regexp" "sort" @@ -33,11 +34,138 @@ import ( "github.com/googleapis/mcp-toolbox/tests" ) +var ( + AlloyDBProject = os.Getenv("ALLOYDB_PROJECT") + AlloyDBLocation = os.Getenv("ALLOYDB_REGION") + AlloyDBCluster = os.Getenv("ALLOYDB_CLUSTER") + AlloyDBInstance = os.Getenv("ALLOYDB_INSTANCE") + AlloyDBUser = os.Getenv("ALLOYDB_POSTGRES_USER") +) + +func getAlloyDBVars(t *testing.T) map[string]string { + if AlloyDBProject == "" { + t.Fatal("'ALLOYDB_PROJECT' not set") + } + if AlloyDBLocation == "" { + t.Fatal("'ALLOYDB_REGION' not set") + } + if AlloyDBCluster == "" { + t.Fatal("'ALLOYDB_CLUSTER' not set") + } + if AlloyDBInstance == "" { + t.Fatal("'ALLOYDB_INSTANCE' not set") + } + if AlloyDBUser == "" { + t.Fatal("'ALLOYDB_USER' not set") + } + return map[string]string{ + "project": AlloyDBProject, + "location": AlloyDBLocation, + "cluster": AlloyDBCluster, + "instance": AlloyDBInstance, + "user": AlloyDBUser, + } +} + +func getAlloyDBToolsConfig() map[string]any { + return map[string]any{ + "sources": map[string]any{ + "alloydb-admin-source": map[string]any{ + "type": "alloydb-admin", + }, + }, + "tools": map[string]any{ + // Tool for RunAlloyDBToolGetTest + "my-simple-tool": map[string]any{ + "type": "alloydb-list-clusters", + "source": "alloydb-admin-source", + "description": "Simple tool to test end to end functionality.", + }, + // Tool for MCP test + "my-param-tool": map[string]any{ + "type": "alloydb-list-clusters", + "source": "alloydb-admin-source", + "description": "Tool to list clusters", + }, + // Tool for MCP test that fails + "my-fail-tool": map[string]any{ + "type": "alloydb-list-clusters", + "source": "alloydb-admin-source", + "description": "Tool that will fail", + }, + // AlloyDB specific tools + "alloydb-list-clusters": map[string]any{ + "type": "alloydb-list-clusters", + "source": "alloydb-admin-source", + "description": "Lists all AlloyDB clusters in a given project and location.", + }, + "alloydb-list-users": map[string]any{ + "type": "alloydb-list-users", + "source": "alloydb-admin-source", + "description": "Lists all AlloyDB users within a specific cluster.", + }, + "alloydb-list-instances": map[string]any{ + "type": "alloydb-list-instances", + "source": "alloydb-admin-source", + "description": "Lists all AlloyDB instances within a specific cluster.", + }, + "alloydb-get-cluster": map[string]any{ + "type": "alloydb-get-cluster", + "source": "alloydb-admin-source", + "description": "Retrieves details of a specific AlloyDB cluster.", + }, + "alloydb-get-instance": map[string]any{ + "type": "alloydb-get-instance", + "source": "alloydb-admin-source", + "description": "Retrieves details of a specific AlloyDB instance.", + }, + "alloydb-get-user": map[string]any{ + "type": "alloydb-get-user", + "source": "alloydb-admin-source", + "description": "Retrieves details of a specific AlloyDB user.", + }, + "alloydb-create-cluster": map[string]any{ + "type": "alloydb-create-cluster", + "description": "create cluster", + "source": "alloydb-admin-source", + }, + "alloydb-create-instance": map[string]any{ + "type": "alloydb-create-instance", + "description": "create instance", + "source": "alloydb-admin-source", + }, + "alloydb-create-user": map[string]any{ + "type": "alloydb-create-user", + "description": "create user", + "source": "alloydb-admin-source", + }, + }, + } +} + func TestAlloyDBListTools(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) defer cancel() - toolsFile := getAlloyDBToolsConfig() + toolsFile := map[string]any{ + "sources": map[string]any{ + "alloydb-admin-source": map[string]any{ + "type": "alloydb-admin", + }, + }, + "tools": map[string]any{ + "alloydb-list-clusters": map[string]any{ + "type": "alloydb-list-clusters", + "source": "alloydb-admin-source", + "description": "Lists all AlloyDB clusters in a given project and location.", + }, + "alloydb-list-users": map[string]any{ + "type": "alloydb-list-users", + "source": "alloydb-admin-source", + "description": "Lists all AlloyDB users within a specific cluster.", + }, + }, + } // Start the toolbox server cmd, cleanup, err := tests.StartCmd(ctx, toolsFile) @@ -845,9 +973,8 @@ func TestAlloyDBCreateClusterMCP(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() - args := []string{"--enable-api"} toolsFile := getAlloyDBToolsConfig() - cmd, cleanupCmd, err := tests.StartCmd(ctx, toolsFile, args...) + cmd, cleanupCmd, err := tests.StartCmd(ctx, toolsFile) if err != nil { t.Fatalf("command initialization returned an error: %v", err) } @@ -942,9 +1069,8 @@ func TestAlloyDBCreateInstanceMCP(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() - args := []string{"--enable-api"} toolsFile := getAlloyDBToolsConfig() - cmd, cleanupCmd, err := tests.StartCmd(ctx, toolsFile, args...) + cmd, cleanupCmd, err := tests.StartCmd(ctx, toolsFile) if err != nil { t.Fatalf("command initialization returned an error: %v", err) } @@ -1049,9 +1175,8 @@ func TestAlloyDBCreateUserMCP(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() - args := []string{"--enable-api"} toolsFile := getAlloyDBToolsConfig() - cmd, cleanupCmd, err := tests.StartCmd(ctx, toolsFile, args...) + cmd, cleanupCmd, err := tests.StartCmd(ctx, toolsFile) if err != nil { t.Fatalf("command initialization returned an error: %v", err) } diff --git a/tests/alloydbainl/alloydb_ai_nl_integration_test.go b/tests/alloydbainl/alloydb_ai_nl_integration_test.go index 42fc0e09c193..96a375f1181a 100644 --- a/tests/alloydbainl/alloydb_ai_nl_integration_test.go +++ b/tests/alloydbainl/alloydb_ai_nl_integration_test.go @@ -20,7 +20,6 @@ import ( "encoding/json" "io" "net/http" - "os" "reflect" "regexp" "strings" @@ -32,47 +31,6 @@ import ( "github.com/googleapis/mcp-toolbox/tests" ) -var ( - AlloyDBAINLSourceType = "alloydb-postgres" - AlloyDBAINLToolType = "alloydb-ai-nl" - AlloyDBAINLProject = os.Getenv("ALLOYDB_AI_NL_PROJECT") - AlloyDBAINLRegion = os.Getenv("ALLOYDB_AI_NL_REGION") - AlloyDBAINLCluster = os.Getenv("ALLOYDB_AI_NL_CLUSTER") - AlloyDBAINLInstance = os.Getenv("ALLOYDB_AI_NL_INSTANCE") - AlloyDBAINLDatabase = os.Getenv("ALLOYDB_AI_NL_DATABASE") - AlloyDBAINLUser = os.Getenv("ALLOYDB_AI_NL_USER") - AlloyDBAINLPass = os.Getenv("ALLOYDB_AI_NL_PASS") -) - -func getAlloyDBAINLVars(t *testing.T) map[string]any { - switch "" { - case AlloyDBAINLProject: - t.Fatal("'ALLOYDB_AI_NL_PROJECT' not set") - case AlloyDBAINLRegion: - t.Fatal("'ALLOYDB_AI_NL_REGION' not set") - case AlloyDBAINLCluster: - t.Fatal("'ALLOYDB_AI_NL_CLUSTER' not set") - case AlloyDBAINLInstance: - t.Fatal("'ALLOYDB_AI_NL_INSTANCE' not set") - case AlloyDBAINLDatabase: - t.Fatal("'ALLOYDB_AI_NL_DATABASE' not set") - case AlloyDBAINLUser: - t.Fatal("'ALLOYDB_AI_NL_USER' not set") - case AlloyDBAINLPass: - t.Fatal("'ALLOYDB_AI_NL_PASS' not set") - } - return map[string]any{ - "type": AlloyDBAINLSourceType, - "project": AlloyDBAINLProject, - "cluster": AlloyDBAINLCluster, - "instance": AlloyDBAINLInstance, - "region": AlloyDBAINLRegion, - "database": AlloyDBAINLDatabase, - "user": AlloyDBAINLUser, - "password": AlloyDBAINLPass, - } -} - func TestAlloyDBAINLToolEndpoints(t *testing.T) { sourceConfig := getAlloyDBAINLVars(t) ctx, cancel := context.WithTimeout(context.Background(), time.Minute) @@ -277,59 +235,6 @@ func runAINLToolInvokeTest(t *testing.T) { } -func getAINLToolsConfig(sourceConfig map[string]any) map[string]any { - // Write config into a file and pass it to command - toolsFile := map[string]any{ - "sources": map[string]any{ - "my-instance": sourceConfig, - }, - "authServices": map[string]any{ - "my-google-auth": map[string]any{ - "type": "google", - "clientId": tests.ClientId, - }, - }, - "tools": map[string]any{ - "my-simple-tool": map[string]any{ - "type": AlloyDBAINLToolType, - "source": "my-instance", - "description": "Simple tool to test end to end functionality.", - "nlConfig": "my_nl_config", - }, - "my-auth-tool": map[string]any{ - "type": AlloyDBAINLToolType, - "source": "my-instance", - "description": "Tool to test authenticated parameters.", - "nlConfig": "my_nl_config", - "nlConfigParameters": []map[string]any{ - { - "name": "email", - "type": "string", - "description": "user email", - "authServices": []map[string]string{ - { - "name": "my-google-auth", - "field": "email", - }, - }, - }, - }, - }, - "my-auth-required-tool": map[string]any{ - "type": AlloyDBAINLToolType, - "source": "my-instance", - "description": "Tool to test auth required invocation.", - "nlConfig": "my_nl_config", - "authRequired": []string{ - "my-google-auth", - }, - }, - }, - } - - return toolsFile -} - func runAINLMCPToolCallMethod(t *testing.T) { sessionId := tests.RunInitialize(t, "2024-11-05") header := map[string]string{} diff --git a/tests/alloydbainl/alloydb_ai_nl_mcp_test.go b/tests/alloydbainl/alloydb_ai_nl_mcp_test.go new file mode 100644 index 000000000000..b25f8eedb8b5 --- /dev/null +++ b/tests/alloydbainl/alloydb_ai_nl_mcp_test.go @@ -0,0 +1,333 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package alloydbainl + +import ( + "context" + "net/http" + "os" + "regexp" + "testing" + "time" + + "github.com/googleapis/mcp-toolbox/internal/testutils" + "github.com/googleapis/mcp-toolbox/tests" +) + +func TestAlloyDBAINLListTools(t *testing.T) { + sourceConfig := getAlloyDBAINLVars(t) + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + + toolsFile := getAINLToolsConfig(sourceConfig) + + cmd, cleanup, err := tests.StartCmd(ctx, toolsFile) + if err != nil { + t.Fatalf("command initialization returned an error: %s", err) + } + defer cleanup() + + waitCtx, cancelWait := context.WithTimeout(ctx, 10*time.Second) + defer cancelWait() + out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out) + if err != nil { + t.Logf("toolbox command logs: \n%s", out) + t.Fatalf("toolbox didn't start successfully: %s", err) + } + + // Verify list of tools + expectedTools := []tests.MCPToolManifest{ + { + Name: "my-simple-tool", + Description: "Simple tool to test end to end functionality.", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "question": map[string]any{ + "description": "The natural language question to ask.", + "type": "string", + }, + }, + "required": []any{"question"}, + }, + }, + { + Name: "my-auth-tool", + Description: "Tool to test authenticated parameters.", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "question": map[string]any{ + "description": "The natural language question to ask.", + "type": "string", + }, + "email": map[string]any{ + "description": "user email", + "type": "string", + }, + }, + "required": []any{"question", "email"}, + }, + }, + { + Name: "my-auth-required-tool", + Description: "Tool to test auth required invocation.", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "question": map[string]any{ + "description": "The natural language question to ask.", + "type": "string", + }, + }, + "required": []any{"question"}, + }, + }, + } + + tests.RunMCPToolsListMethod(t, expectedTools) +} + +func TestAlloyDBAINLCallTool(t *testing.T) { + sourceConfig := getAlloyDBAINLVars(t) + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + + toolsFile := getAINLToolsConfig(sourceConfig) + + cmd, cleanup, err := tests.StartCmd(ctx, toolsFile) + if err != nil { + t.Fatalf("command initialization returned an error: %s", err) + } + defer cleanup() + + waitCtx, cancelWait := context.WithTimeout(ctx, 10*time.Second) + defer cancelWait() + out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out) + if err != nil { + t.Logf("toolbox command logs: \n%s", out) + t.Fatalf("toolbox didn't start successfully: %s", err) + } + + idToken, err := tests.GetGoogleIdToken(t) + if err != nil { + t.Fatalf("error getting Google ID token: %s", err) + } + + invokeTcs := []struct { + name string + toolName string + args map[string]any + requestHeader map[string]string + want string + isErr bool + wantStatusCode int + }{ + { + name: "invoke my-simple-tool", + toolName: "my-simple-tool", + args: map[string]any{"question": "return the number 1"}, + want: "{\"execute_nl_query\":{\"?column?\":1}}", + isErr: false, + }, + { + name: "Invoke my-auth-tool with auth token", + toolName: "my-auth-tool", + args: map[string]any{"question": "can you show me the name of this user?"}, + requestHeader: map[string]string{"my-google-auth_token": idToken}, + want: "{\"execute_nl_query\":{\"name\":\"Alice\"}}", + isErr: false, + }, + { + name: "Invoke my-auth-tool with invalid auth token", + toolName: "my-auth-tool", + args: map[string]any{"question": "return the number 1"}, + requestHeader: map[string]string{"my-google-auth_token": "INVALID_TOKEN"}, + isErr: true, + }, + { + name: "Invoke my-auth-tool without auth token", + toolName: "my-auth-tool", + args: map[string]any{"question": "return the number 1"}, + isErr: true, + }, + { + name: "Invoke my-auth-required-tool with auth token", + toolName: "my-auth-required-tool", + args: map[string]any{"question": "return the number 1"}, + requestHeader: map[string]string{"my-google-auth_token": idToken}, + isErr: false, + want: "{\"execute_nl_query\":{\"?column?\":1}}", + }, + { + name: "Invoke my-auth-required-tool with invalid auth token", + toolName: "my-auth-required-tool", + args: map[string]any{"question": "return the number 1"}, + requestHeader: map[string]string{"my-google-auth_token": "INVALID_TOKEN"}, + isErr: true, + wantStatusCode: 401, + }, + { + name: "Invoke my-auth-required-tool without auth token", + toolName: "my-auth-required-tool", + args: map[string]any{"question": "return the number 1"}, + isErr: true, + wantStatusCode: 401, + }, + { + name: "Invoke invalid tool", + toolName: "foo", + args: map[string]any{}, + isErr: true, + }, + { + name: "Invoke my-auth-tool without parameters", + toolName: "my-auth-tool", + args: map[string]any{}, + isErr: true, + }, + } + + for _, tc := range invokeTcs { + t.Run(tc.name, func(t *testing.T) { + statusCode, mcpResp, err := tests.InvokeMCPTool(t, tc.toolName, tc.args, tc.requestHeader) + if err != nil { + t.Fatalf("native error executing %s: %s", tc.toolName, err) + } + + expectedStatus := tc.wantStatusCode + if expectedStatus == 0 { + expectedStatus = http.StatusOK + } + if statusCode != expectedStatus { + t.Fatalf("expected status %d, got %d", expectedStatus, statusCode) + } + + if tc.isErr { + if mcpResp.Error == nil && !mcpResp.Result.IsError { + t.Fatalf("expected error result or JSON-RPC error, got success") + } + } else { + if mcpResp.Error != nil { + t.Fatalf("expected success, got JSON-RPC error: %v", mcpResp.Error) + } + if mcpResp.Result.IsError { + t.Fatalf("expected success result, got tool error: %v", mcpResp.Result) + } + if len(mcpResp.Result.Content) == 0 { + t.Fatalf("expected at least one content item, got none") + } + got := mcpResp.Result.Content[0].Text + if got != tc.want { + t.Fatalf("unexpected value: got %q, want %q", got, tc.want) + } + } + }) + } +} + +var ( + AlloyDBAINLSourceType = "alloydb-postgres" + AlloyDBAINLToolType = "alloydb-ai-nl" + AlloyDBAINLProject = os.Getenv("ALLOYDB_AI_NL_PROJECT") + AlloyDBAINLRegion = os.Getenv("ALLOYDB_AI_NL_REGION") + AlloyDBAINLCluster = os.Getenv("ALLOYDB_AI_NL_CLUSTER") + AlloyDBAINLInstance = os.Getenv("ALLOYDB_AI_NL_INSTANCE") + AlloyDBAINLDatabase = os.Getenv("ALLOYDB_AI_NL_DATABASE") + AlloyDBAINLUser = os.Getenv("ALLOYDB_AI_NL_USER") + AlloyDBAINLPass = os.Getenv("ALLOYDB_AI_NL_PASS") +) + +func getAlloyDBAINLVars(t *testing.T) map[string]any { + switch "" { + case AlloyDBAINLProject: + t.Fatal("'ALLOYDB_AI_NL_PROJECT' not set") + case AlloyDBAINLRegion: + t.Fatal("'ALLOYDB_AI_NL_REGION' not set") + case AlloyDBAINLCluster: + t.Fatal("'ALLOYDB_AI_NL_CLUSTER' not set") + case AlloyDBAINLInstance: + t.Fatal("'ALLOYDB_AI_NL_INSTANCE' not set") + case AlloyDBAINLDatabase: + t.Fatal("'ALLOYDB_AI_NL_DATABASE' not set") + case AlloyDBAINLUser: + t.Fatal("'ALLOYDB_AI_NL_USER' not set") + case AlloyDBAINLPass: + t.Fatal("'ALLOYDB_AI_NL_PASS' not set") + } + return map[string]any{ + "type": AlloyDBAINLSourceType, + "project": AlloyDBAINLProject, + "cluster": AlloyDBAINLCluster, + "instance": AlloyDBAINLInstance, + "region": AlloyDBAINLRegion, + "database": AlloyDBAINLDatabase, + "user": AlloyDBAINLUser, + "password": AlloyDBAINLPass, + } +} + +func getAINLToolsConfig(sourceConfig map[string]any) map[string]any { + // Write config into a file and pass it to command + toolsFile := map[string]any{ + "sources": map[string]any{ + "my-instance": sourceConfig, + }, + "authServices": map[string]any{ + "my-google-auth": map[string]any{ + "type": "google", + "clientId": tests.ClientId, + }, + }, + "tools": map[string]any{ + "my-simple-tool": map[string]any{ + "type": AlloyDBAINLToolType, + "source": "my-instance", + "description": "Simple tool to test end to end functionality.", + "nlConfig": "my_nl_config", + }, + "my-auth-tool": map[string]any{ + "type": AlloyDBAINLToolType, + "source": "my-instance", + "description": "Tool to test authenticated parameters.", + "nlConfig": "my_nl_config", + "nlConfigParameters": []map[string]any{ + { + "name": "email", + "type": "string", + "description": "user email", + "authServices": []map[string]string{ + { + "name": "my-google-auth", + "field": "email", + }, + }, + }, + }, + }, + "my-auth-required-tool": map[string]any{ + "type": AlloyDBAINLToolType, + "source": "my-instance", + "description": "Tool to test auth required invocation.", + "nlConfig": "my_nl_config", + "authRequired": []string{ + "my-google-auth", + }, + }, + }, + } + + return toolsFile +} diff --git a/tests/http/http_integration_test.go b/tests/http/http_integration_test.go index 7e1cc66316e5..84aedf7200a5 100644 --- a/tests/http/http_integration_test.go +++ b/tests/http/http_integration_test.go @@ -20,11 +20,9 @@ import ( "crypto/rand" "crypto/rsa" "encoding/json" - "fmt" "io" "net/http" "net/http/httptest" - "reflect" "regexp" "strings" "testing" @@ -33,15 +31,9 @@ import ( "github.com/MicahParks/jwkset" "github.com/golang-jwt/jwt/v5" "github.com/googleapis/mcp-toolbox/internal/testutils" - "github.com/googleapis/mcp-toolbox/internal/util/parameters" "github.com/googleapis/mcp-toolbox/tests" ) -var ( - HttpSourceType = "http" - HttpToolType = "http" -) - func getHTTPSourceConfig(t *testing.T) map[string]any { idToken, err := tests.GetGoogleIdToken(t) if err != nil { @@ -55,252 +47,6 @@ func getHTTPSourceConfig(t *testing.T) map[string]any { } } -// handler function for the test server -func multiTool(w http.ResponseWriter, r *http.Request) { - path := r.URL.Path - path = strings.TrimPrefix(path, "/") // Remove leading slash - - switch path { - case "tool0": - handleTool0(w, r) - case "tool1": - handleTool1(w, r) - case "tool1id": - handleTool1Id(w, r) - case "tool1name": - handleTool1Name(w, r) - case "tool2": - handleTool2(w, r) - case "tool3": - handleTool3(w, r) - case "toolQueryTest": - handleQueryTest(w, r) - default: - http.NotFound(w, r) // Return 404 for unknown paths - } -} - -// handleQueryTest simply returns the raw query string it received so the test -// can verify it's formatted correctly. -func handleQueryTest(w http.ResponseWriter, r *http.Request) { - // expect GET method - if r.Method != http.MethodGet { - errorMessage := fmt.Sprintf("expected GET method but got: %s", string(r.Method)) - http.Error(w, errorMessage, http.StatusBadRequest) - return - } - - w.WriteHeader(http.StatusOK) - enc := json.NewEncoder(w) - enc.SetEscapeHTML(false) - - err := enc.Encode(r.URL.RawQuery) - if err != nil { - http.Error(w, "Failed to write response", http.StatusInternalServerError) - return - } -} - -// handler function for the test server -func handleTool0(w http.ResponseWriter, r *http.Request) { - // expect POST method - if r.Method != http.MethodPost { - errorMessage := fmt.Sprintf("expected POST method but got: %s", string(r.Method)) - http.Error(w, errorMessage, http.StatusBadRequest) - return - } - w.WriteHeader(http.StatusOK) - response := "hello world" - err := json.NewEncoder(w).Encode(response) - if err != nil { - http.Error(w, "Failed to encode JSON", http.StatusInternalServerError) - return - } -} - -// handler function for the test server -func handleTool1(w http.ResponseWriter, r *http.Request) { - // expect GET method - if r.Method != http.MethodGet { - errorMessage := fmt.Sprintf("expected GET method but got: %s", string(r.Method)) - http.Error(w, errorMessage, http.StatusBadRequest) - return - } - // Parse request body - var requestBody map[string]interface{} - bodyBytes, readErr := io.ReadAll(r.Body) - if readErr != nil { - http.Error(w, "Bad Request: Failed to read request body", http.StatusBadRequest) - return - } - defer r.Body.Close() - err := json.Unmarshal(bodyBytes, &requestBody) - if err != nil { - errorMessage := fmt.Sprintf("Bad Request: Error unmarshalling request body: %s, Raw body: %s", err, string(bodyBytes)) - http.Error(w, errorMessage, http.StatusBadRequest) - return - } - - // Extract name - name, ok := requestBody["name"].(string) - if !ok || name == "" { - http.Error(w, "Bad Request: Missing or invalid name", http.StatusBadRequest) - return - } - - if name == "Alice" { - response := `[{"id":1,"name":"Alice"},{"id":3,"name":"Sid"}]` - _, err := w.Write([]byte(response)) - if err != nil { - http.Error(w, "Failed to write response", http.StatusInternalServerError) - } - return - } - - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) -} - -// handler function for the test server -func handleTool1Id(w http.ResponseWriter, r *http.Request) { - // expect GET method - if r.Method != http.MethodGet { - errorMessage := fmt.Sprintf("expected GET method but got: %s", string(r.Method)) - http.Error(w, errorMessage, http.StatusBadRequest) - return - } - - id := r.URL.Query().Get("id") - if id == "4" { - response := `[{"id":4,"name":null}]` - _, err := w.Write([]byte(response)) - if err != nil { - http.Error(w, "Failed to write response", http.StatusInternalServerError) - } - return - } - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) -} - -// handler function for the test server -func handleTool1Name(w http.ResponseWriter, r *http.Request) { - // expect GET method - if r.Method != http.MethodGet { - errorMessage := fmt.Sprintf("expected GET method but got: %s", string(r.Method)) - http.Error(w, errorMessage, http.StatusBadRequest) - return - } - - if !r.URL.Query().Has("name") { - response := "null" - _, err := w.Write([]byte(response)) - if err != nil { - http.Error(w, "Failed to write response", http.StatusInternalServerError) - } - return - } - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) -} - -// handler function for the test server -func handleTool2(w http.ResponseWriter, r *http.Request) { - // expect GET method - if r.Method != http.MethodGet { - errorMessage := fmt.Sprintf("expected GET method but got: %s", string(r.Method)) - http.Error(w, errorMessage, http.StatusBadRequest) - return - } - email := r.URL.Query().Get("email") - if email != "" { - response := `[{"name":"Alice"}]` - _, err := w.Write([]byte(response)) - if err != nil { - http.Error(w, "Failed to write response", http.StatusInternalServerError) - } - return - } - - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) -} - -// handler function for the test server -func handleTool3(w http.ResponseWriter, r *http.Request) { - // expect GET method - if r.Method != http.MethodGet { - errorMessage := fmt.Sprintf("expected GET method but got: %s", string(r.Method)) - http.Error(w, errorMessage, http.StatusBadRequest) - return - } - - // Check request headers - expectedHeaders := map[string]string{ - "Content-Type": "application/json", - "X-Custom-Header": "example", - "X-Other-Header": "test", - } - for header, expectedValue := range expectedHeaders { - if r.Header.Get(header) != expectedValue { - errorMessage := fmt.Sprintf("Bad Request: Missing or incorrect header: %s", header) - http.Error(w, errorMessage, http.StatusBadRequest) - return - } - } - - // Check query parameters - expectedQueryParams := map[string][]string{ - "id": []string{"2", "1", "3"}, - "country": []string{"US"}, - } - query := r.URL.Query() - for param, expectedValueSlice := range expectedQueryParams { - values, ok := query[param] - if ok { - if !reflect.DeepEqual(expectedValueSlice, values) { - errorMessage := fmt.Sprintf("Bad Request: Incorrect query parameter: %s, actual: %s", param, query[param]) - http.Error(w, errorMessage, http.StatusBadRequest) - return - } - } else { - errorMessage := fmt.Sprintf("Bad Request: Missing query parameter: %s, actual: %s", param, query[param]) - http.Error(w, errorMessage, http.StatusBadRequest) - return - } - } - - // Parse request body - var requestBody map[string]interface{} - bodyBytes, readErr := io.ReadAll(r.Body) - if readErr != nil { - http.Error(w, "Bad Request: Failed to read request body", http.StatusBadRequest) - return - } - defer r.Body.Close() - err := json.Unmarshal(bodyBytes, &requestBody) - if err != nil { - errorMessage := fmt.Sprintf("Bad Request: Error unmarshalling request body: %s, Raw body: %s", err, string(bodyBytes)) - http.Error(w, errorMessage, http.StatusBadRequest) - return - } - - // Check request body - expectedBody := map[string]interface{}{ - "place": "zoo", - "animals": []any{"rabbit", "ostrich", "whale"}, - } - - if !reflect.DeepEqual(requestBody, expectedBody) { - errorMessage := fmt.Sprintf("Bad Request: Incorrect request body. Expected: %v, Got: %v", expectedBody, requestBody) - http.Error(w, errorMessage, http.StatusBadRequest) - return - } - - response := "hello world" - err = json.NewEncoder(w).Encode(response) - if err != nil { - http.Error(w, "Failed to encode JSON", http.StatusInternalServerError) - return - } -} - func TestHttpToolEndpoints(t *testing.T) { // start a test server server := httptest.NewServer(http.HandlerFunc(multiTool)) @@ -597,152 +343,3 @@ func runAdvancedHTTPInvokeTest(t *testing.T) { }) } } - -// getHTTPToolsConfig returns a mock HTTP tool's config file -func getHTTPToolsConfig(sourceConfig map[string]any, toolType string, jwksURL string) map[string]any { - // Write config into a file and pass it to command - otherSourceConfig := make(map[string]any) - for k, v := range sourceConfig { - otherSourceConfig[k] = v - } - otherSourceConfig["headers"] = map[string]string{"X-Custom-Header": "unexpected", "Content-Type": "application/json"} - otherSourceConfig["queryParams"] = map[string]any{"id": 1, "name": "Sid"} - - clientID := tests.ClientId - if clientID == "" { - clientID = "test-client-id" - } - - toolsFile := map[string]any{ - "sources": map[string]any{ - "my-instance": sourceConfig, - "other-instance": otherSourceConfig, - }, - "authServices": map[string]any{ - "my-google-auth": map[string]any{ - "type": "google", - "clientId": clientID, - }, - "my-generic-auth": map[string]any{ - "type": "generic", - "audience": "test-audience", - "authorizationServer": jwksURL, - "scopesRequired": []string{"read:files"}, - }, - }, - "tools": map[string]any{ - "my-simple-tool": map[string]any{ - "type": toolType, - "path": "/tool0", - "method": "POST", - "source": "my-instance", - "requestBody": "{}", - "description": "Simple tool to test end to end functionality.", - }, - "my-tool": map[string]any{ - "type": toolType, - "source": "my-instance", - "method": "GET", - "path": "/tool1", - "description": "some description", - "queryParams": []parameters.Parameter{ - parameters.NewIntParameter("id", "user ID")}, - "bodyParams": []parameters.Parameter{parameters.NewStringParameter("name", "user name")}, - "requestBody": `{ -"age": 36, -"name": "{{.name}}" -} -`, - "headers": map[string]string{"Content-Type": "application/json"}, - }, - "my-tool-by-id": map[string]any{ - "type": toolType, - "source": "my-instance", - "method": "GET", - "path": "/tool1id", - "description": "some description", - "queryParams": []parameters.Parameter{ - parameters.NewIntParameter("id", "user ID")}, - "headers": map[string]string{"Content-Type": "application/json"}, - }, - "my-tool-by-name": map[string]any{ - "type": toolType, - "source": "my-instance", - "method": "GET", - "path": "/tool1name", - "description": "some description", - "queryParams": []parameters.Parameter{ - parameters.NewStringParameterWithRequired("name", "user name", false)}, - "headers": map[string]string{"Content-Type": "application/json"}, - }, - "my-query-param-tool": map[string]any{ - "type": toolType, - "source": "my-instance", - "method": "GET", - "path": "/toolQueryTest", - "description": "Tool to test optional query parameters.", - "queryParams": []parameters.Parameter{ - parameters.NewStringParameterWithRequired("reqId", "required ID", true), - parameters.NewStringParameterWithRequired("page", "optional page number", false), - parameters.NewStringParameterWithRequired("filter", "optional filter string", false), - }, - }, - "my-auth-tool": map[string]any{ - "type": toolType, - "source": "my-instance", - "method": "GET", - "path": "/tool2", - "description": "some description", - "requestBody": "{}", - "queryParams": []parameters.Parameter{ - parameters.NewStringParameterWithAuth("email", "some description", - []parameters.ParamAuthService{{Name: "my-google-auth", Field: "email"}}), - }, - }, - "my-auth-required-tool": map[string]any{ - "type": toolType, - "source": "my-instance", - "method": "POST", - "path": "/tool0", - "description": "some description", - "requestBody": "{}", - "authRequired": []string{"my-google-auth"}, - }, - "my-auth-required-generic-tool": map[string]any{ - "type": toolType, - "source": "my-instance", - "method": "POST", - "path": "/tool0", - "description": "some description", - "requestBody": "{}", - "authRequired": []string{"my-generic-auth"}, - }, - "my-advanced-tool": map[string]any{ - "type": toolType, - "source": "other-instance", - "method": "get", - "path": "/{{.path}}?id=2", - "description": "some description", - "headers": map[string]string{ - "X-Custom-Header": "example", - }, - "pathParams": []parameters.Parameter{ - ¶meters.StringParameter{ - CommonParameter: parameters.CommonParameter{Name: "path", Type: "string", Desc: "path param"}, - }, - }, - "queryParams": []parameters.Parameter{ - parameters.NewIntParameter("id", "user ID"), parameters.NewStringParameter("country", "country"), - }, - "requestBody": `{ - "place": "zoo", - "animals": {{json .animalArray }} - } - `, - "bodyParams": []parameters.Parameter{parameters.NewArrayParameter("animalArray", "animals in the zoo", parameters.NewStringParameter("animals", "desc"))}, - "headerParams": []parameters.Parameter{parameters.NewStringParameter("X-Other-Header", "custom header")}, - }, - }, - } - return toolsFile -} diff --git a/tests/http/http_mcp_test.go b/tests/http/http_mcp_test.go index dc94de5136af..cef039563ae5 100644 --- a/tests/http/http_mcp_test.go +++ b/tests/http/http_mcp_test.go @@ -19,18 +19,403 @@ import ( "crypto/rand" "crypto/rsa" "encoding/json" + "fmt" + "io" "net/http" "net/http/httptest" + "reflect" "regexp" + "strings" "testing" "time" "github.com/MicahParks/jwkset" "github.com/golang-jwt/jwt/v5" "github.com/googleapis/mcp-toolbox/internal/testutils" + "github.com/googleapis/mcp-toolbox/internal/util/parameters" "github.com/googleapis/mcp-toolbox/tests" ) +var ( + HttpSourceType = "http" + HttpToolType = "http" +) + +// handler function for the test server +func multiTool(w http.ResponseWriter, r *http.Request) { + path := r.URL.Path + path = strings.TrimPrefix(path, "/") // Remove leading slash + + switch path { + case "tool0": + handleTool0(w, r) + case "tool1": + handleTool1(w, r) + case "tool1id": + handleTool1Id(w, r) + case "tool1name": + handleTool1Name(w, r) + case "tool2": + handleTool2(w, r) + case "tool3": + handleTool3(w, r) + case "toolQueryTest": + handleQueryTest(w, r) + default: + http.NotFound(w, r) // Return 404 for unknown paths + } +} + +// handleQueryTest simply returns the raw query string it received so the test +// can verify it's formatted correctly. +func handleQueryTest(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + errorMessage := fmt.Sprintf("expected GET method but got: %s", string(r.Method)) + http.Error(w, errorMessage, http.StatusBadRequest) + return + } + + w.WriteHeader(http.StatusOK) + enc := json.NewEncoder(w) + enc.SetEscapeHTML(false) + + err := enc.Encode(r.URL.RawQuery) + if err != nil { + http.Error(w, "Failed to write response", http.StatusInternalServerError) + return + } +} + +func handleTool0(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + errorMessage := fmt.Sprintf("expected POST method but got: %s", string(r.Method)) + http.Error(w, errorMessage, http.StatusBadRequest) + return + } + w.WriteHeader(http.StatusOK) + response := "hello world" + err := json.NewEncoder(w).Encode(response) + if err != nil { + http.Error(w, "Failed to encode JSON", http.StatusInternalServerError) + return + } +} + +func handleTool1(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + errorMessage := fmt.Sprintf("expected GET method but got: %s", string(r.Method)) + http.Error(w, errorMessage, http.StatusBadRequest) + return + } + var requestBody map[string]interface{} + bodyBytes, readErr := io.ReadAll(r.Body) + if readErr != nil { + http.Error(w, "Bad Request: Failed to read request body", http.StatusBadRequest) + return + } + defer r.Body.Close() + err := json.Unmarshal(bodyBytes, &requestBody) + if err != nil { + errorMessage := fmt.Sprintf("Bad Request: Error unmarshalling request body: %s, Raw body: %s", err, string(bodyBytes)) + http.Error(w, errorMessage, http.StatusBadRequest) + return + } + + name, ok := requestBody["name"].(string) + if !ok || name == "" { + http.Error(w, "Bad Request: Missing or invalid name", http.StatusBadRequest) + return + } + + if name == "Alice" { + response := `[{"id":1,"name":"Alice"},{"id":3,"name":"Sid"}]` + _, err := w.Write([]byte(response)) + if err != nil { + http.Error(w, "Failed to write response", http.StatusInternalServerError) + } + return + } + + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) +} + +func handleTool1Id(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + errorMessage := fmt.Sprintf("expected GET method but got: %s", string(r.Method)) + http.Error(w, errorMessage, http.StatusBadRequest) + return + } + + id := r.URL.Query().Get("id") + if id == "4" { + response := `[{"id":4,"name":null}]` + _, err := w.Write([]byte(response)) + if err != nil { + http.Error(w, "Failed to write response", http.StatusInternalServerError) + } + return + } + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) +} + +func handleTool1Name(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + errorMessage := fmt.Sprintf("expected GET method but got: %s", string(r.Method)) + http.Error(w, errorMessage, http.StatusBadRequest) + return + } + + if !r.URL.Query().Has("name") { + response := "null" + _, err := w.Write([]byte(response)) + if err != nil { + http.Error(w, "Failed to write response", http.StatusInternalServerError) + } + return + } + + http.Error(w, "Bad Request: Unexpected query parameter 'name'", http.StatusBadRequest) +} + +func handleTool2(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + errorMessage := fmt.Sprintf("expected GET method but got: %s", string(r.Method)) + http.Error(w, errorMessage, http.StatusBadRequest) + return + } + email := r.URL.Query().Get("email") + if email != "" { + response := `[{"name":"Alice"}]` + _, err := w.Write([]byte(response)) + if err != nil { + http.Error(w, "Failed to write response", http.StatusInternalServerError) + } + return + } + + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) +} + +func handleTool3(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + errorMessage := fmt.Sprintf("expected GET method but got: %s", string(r.Method)) + http.Error(w, errorMessage, http.StatusBadRequest) + return + } + + expectedHeaders := map[string]string{ + "Content-Type": "application/json", + "X-Custom-Header": "example", + "X-Other-Header": "test", + } + for header, expectedValue := range expectedHeaders { + if r.Header.Get(header) != expectedValue { + errorMessage := fmt.Sprintf("Bad Request: Missing or incorrect header: %s", header) + http.Error(w, errorMessage, http.StatusBadRequest) + return + } + } + + expectedQueryParams := map[string][]string{ + "id": []string{"2", "1", "3"}, + "country": []string{"US"}, + } + query := r.URL.Query() + for param, expectedValueSlice := range expectedQueryParams { + values, ok := query[param] + if ok { + if !reflect.DeepEqual(expectedValueSlice, values) { + errorMessage := fmt.Sprintf("Bad Request: Incorrect query parameter: %s, actual: %s", param, query[param]) + http.Error(w, errorMessage, http.StatusBadRequest) + return + } + } else { + errorMessage := fmt.Sprintf("Bad Request: Missing query parameter: %s, actual: %s", param, query[param]) + http.Error(w, errorMessage, http.StatusBadRequest) + return + } + } + + var requestBody map[string]interface{} + bodyBytes, readErr := io.ReadAll(r.Body) + if readErr != nil { + http.Error(w, "Bad Request: Failed to read request body", http.StatusBadRequest) + return + } + defer r.Body.Close() + err := json.Unmarshal(bodyBytes, &requestBody) + if err != nil { + errorMessage := fmt.Sprintf("Bad Request: Error unmarshalling request body: %s, Raw body: %s", err, string(bodyBytes)) + http.Error(w, errorMessage, http.StatusBadRequest) + return + } + + expectedBody := map[string]interface{}{ + "place": "zoo", + "animals": []any{"rabbit", "ostrich", "whale"}, + } + + if !reflect.DeepEqual(requestBody, expectedBody) { + errorMessage := fmt.Sprintf("Bad Request: Incorrect request body. Expected: %v, Got: %v", expectedBody, requestBody) + http.Error(w, errorMessage, http.StatusBadRequest) + return + } + + response := "hello world" + err = json.NewEncoder(w).Encode(response) + if err != nil { + http.Error(w, "Failed to encode JSON", http.StatusInternalServerError) + return + } +} + +func getHTTPToolsConfig(sourceConfig map[string]any, toolType string, jwksURL string) map[string]any { + otherSourceConfig := make(map[string]any) + for k, v := range sourceConfig { + otherSourceConfig[k] = v + } + otherSourceConfig["headers"] = map[string]string{"X-Custom-Header": "unexpected", "Content-Type": "application/json"} + otherSourceConfig["queryParams"] = map[string]any{"id": 1, "name": "Sid"} + + clientID := tests.ClientId + if clientID == "" { + clientID = "test-client-id" + } + + toolsFile := map[string]any{ + "sources": map[string]any{ + "my-instance": sourceConfig, + "other-instance": otherSourceConfig, + }, + "authServices": map[string]any{ + "my-google-auth": map[string]any{ + "type": "google", + "clientId": clientID, + }, + "my-generic-auth": map[string]any{ + "type": "generic", + "audience": "test-audience", + "authorizationServer": jwksURL, + "scopesRequired": []string{"read:files"}, + }, + }, + "tools": map[string]any{ + "my-simple-tool": map[string]any{ + "type": toolType, + "path": "/tool0", + "method": "POST", + "source": "my-instance", + "requestBody": "{}", + "description": "Simple tool to test end to end functionality.", + }, + "my-tool": map[string]any{ + "type": toolType, + "source": "my-instance", + "method": "GET", + "path": "/tool1", + "description": "some description", + "queryParams": []parameters.Parameter{ + parameters.NewIntParameter("id", "user ID")}, + "bodyParams": []parameters.Parameter{parameters.NewStringParameter("name", "user name")}, + "requestBody": `{ +"age": 36, +"name": "{{.name}}" +} +`, + "headers": map[string]string{"Content-Type": "application/json"}, + }, + "my-tool-by-id": map[string]any{ + "type": toolType, + "source": "my-instance", + "method": "GET", + "path": "/tool1id", + "description": "some description", + "queryParams": []parameters.Parameter{ + parameters.NewIntParameter("id", "user ID")}, + "headers": map[string]string{"Content-Type": "application/json"}, + }, + "my-tool-by-name": map[string]any{ + "type": toolType, + "source": "my-instance", + "method": "GET", + "path": "/tool1name", + "description": "some description", + "queryParams": []parameters.Parameter{ + parameters.NewStringParameterWithRequired("name", "user name", false)}, + "headers": map[string]string{"Content-Type": "application/json"}, + }, + "my-query-param-tool": map[string]any{ + "type": toolType, + "source": "my-instance", + "method": "GET", + "path": "/toolQueryTest", + "description": "Tool to test optional query parameters.", + "queryParams": []parameters.Parameter{ + parameters.NewStringParameterWithRequired("reqId", "required ID", true), + parameters.NewStringParameterWithRequired("page", "optional page number", false), + parameters.NewStringParameterWithRequired("filter", "optional filter string", false), + }, + }, + "my-auth-tool": map[string]any{ + "type": toolType, + "source": "my-instance", + "method": "GET", + "path": "/tool2", + "description": "some description", + "requestBody": "{}", + "queryParams": []parameters.Parameter{ + parameters.NewStringParameterWithAuth("email", "some description", + []parameters.ParamAuthService{{Name: "my-google-auth", Field: "email"}}), + }, + }, + "my-auth-required-tool": map[string]any{ + "type": toolType, + "source": "my-instance", + "method": "POST", + "path": "/tool0", + "description": "some description", + "requestBody": "{}", + "authRequired": []string{"my-google-auth"}, + }, + "my-auth-required-generic-tool": map[string]any{ + "type": toolType, + "source": "my-instance", + "method": "POST", + "path": "/tool0", + "description": "some description", + "requestBody": "{}", + "authRequired": []string{"my-generic-auth"}, + }, + "my-advanced-tool": map[string]any{ + "type": toolType, + "source": "other-instance", + "method": "get", + "path": "/{{.path}}?id=2", + "description": "some description", + "headers": map[string]string{ + "X-Custom-Header": "example", + }, + "pathParams": []parameters.Parameter{ + ¶meters.StringParameter{ + CommonParameter: parameters.CommonParameter{Name: "path", Type: "string", Desc: "path param"}, + }, + }, + "queryParams": []parameters.Parameter{ + parameters.NewIntParameter("id", "user ID"), parameters.NewStringParameter("country", "country"), + }, + "requestBody": `{ + "place": "zoo", + "animals": {{json .animalArray }} + } + `, + "bodyParams": []parameters.Parameter{parameters.NewArrayParameter("animalArray", "animals in the zoo", parameters.NewStringParameter("animals", "desc"))}, + "headerParams": []parameters.Parameter{parameters.NewStringParameter("X-Other-Header", "custom header")}, + }, + }, + } + return toolsFile +} + func getMCPHTTPSourceConfig(t *testing.T) map[string]any { idToken, err := tests.GetGoogleIdToken(t) if err != nil { @@ -86,7 +471,37 @@ func TestHTTPListTools(t *testing.T) { })) defer jwksServer.Close() - toolsFile := getHTTPToolsConfig(sourceConfig, HttpToolType, jwksServer.URL) + toolsFile := map[string]any{ + "sources": map[string]any{ + "my-instance": sourceConfig, + }, + "tools": map[string]any{ + "my-simple-tool": map[string]any{ + "type": HttpToolType, + "path": "/tool0", + "method": "POST", + "source": "my-instance", + "requestBody": "{}", + "description": "Simple tool to test end to end functionality.", + }, + "my-tool": map[string]any{ + "type": HttpToolType, + "source": "my-instance", + "method": "GET", + "path": "/tool1", + "description": "some description", + "queryParams": []parameters.Parameter{ + parameters.NewIntParameter("id", "user ID")}, + "bodyParams": []parameters.Parameter{parameters.NewStringParameter("name", "user name")}, + "requestBody": `{ +"age": 36, +"name": "{{.name}}" +} +`, + "headers": map[string]string{"Content-Type": "application/json"}, + }, + }, + } // Start the toolbox server. cmd, cleanup, err := tests.StartCmd(ctx, toolsFile) diff --git a/tests/mcp_tool.go b/tests/mcp_tool.go index d76ab10f26f6..6454db8d62ae 100644 --- a/tests/mcp_tool.go +++ b/tests/mcp_tool.go @@ -18,6 +18,7 @@ import ( "bytes" "encoding/json" "fmt" + "io" "net/http" "reflect" "strings" @@ -29,6 +30,78 @@ import ( v20251125 "github.com/googleapis/mcp-toolbox/internal/server/mcp/v20251125" ) +// RunRequest is a helper function to send HTTP requests and return the response +func RunRequest(t *testing.T, method, url string, body io.Reader, headers map[string]string) (*http.Response, []byte) { + req, err := http.NewRequest(method, url, body) + if err != nil { + t.Fatalf("unable to create request: %s", err) + } + + req.Header.Set("Content-type", "application/json") + + for k, v := range headers { + req.Header.Set(k, v) + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("unable to send request: %s", err) + } + respBody, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("unable to read request body: %s", err) + } + + defer resp.Body.Close() + return resp, respBody +} + +// RunInitialize runs the initialize lifecycle for mcp to set up client-server connection +func RunInitialize(t *testing.T, protocolVersion string) string { + url := "http://127.0.0.1:5000/mcp" + + initializeRequestBody := map[string]any{ + "jsonrpc": "2.0", + "id": "mcp-initialize", + "method": "initialize", + "params": map[string]any{ + "protocolVersion": protocolVersion, + }, + } + reqMarshal, err := json.Marshal(initializeRequestBody) + if err != nil { + t.Fatalf("unexpected error during marshaling of body") + } + + resp, _ := RunRequest(t, http.MethodPost, url, bytes.NewBuffer(reqMarshal), nil) + if resp.StatusCode != 200 { + t.Fatalf("response status code is not 200") + } + + if contentType := resp.Header.Get("Content-type"); contentType != "application/json" { + t.Fatalf("unexpected content-type header: want %s, got %s", "application/json", contentType) + } + + sessionId := resp.Header.Get("Mcp-Session-Id") + + header := map[string]string{} + if sessionId != "" { + header["Mcp-Session-Id"] = sessionId + } + + initializeNotificationBody := map[string]any{ + "jsonrpc": "2.0", + "method": "notifications/initialized", + } + notiMarshal, err := json.Marshal(initializeNotificationBody) + if err != nil { + t.Fatalf("unexpected error during marshaling of notifications body") + } + + _, _ = RunRequest(t, http.MethodPost, url, bytes.NewBuffer(notiMarshal), header) + return sessionId +} + // NewMCPRequestHeader takes custom headers and appends headers required for MCP. func NewMCPRequestHeader(t *testing.T, customHeaders map[string]string) map[string]string { headers := make(map[string]string) @@ -143,6 +216,10 @@ func RunMCPToolsListMethod(t *testing.T, expectedOutput []MCPToolManifest) { t.Fatalf("error unmarshalling tools into MCPToolManifest: %v", err) } + if len(actualTools) != len(expectedOutput) { + t.Fatalf("expected %d tools, got %d. Actual tools: %+v", len(expectedOutput), len(actualTools), actualTools) + } + for _, expected := range expectedOutput { found := false for _, actual := range actualTools { diff --git a/tests/tool.go b/tests/tool.go index 0b0c207d6f90..3fc087c3e83b 100644 --- a/tests/tool.go +++ b/tests/tool.go @@ -748,52 +748,6 @@ func RunExecuteSqlToolInvokeTest(t *testing.T, createTableStatement, select1Want } } -// RunInitialize runs the initialize lifecycle for mcp to set up client-server connection -func RunInitialize(t *testing.T, protocolVersion string) string { - url := "http://127.0.0.1:5000/mcp" - - initializeRequestBody := map[string]any{ - "jsonrpc": "2.0", - "id": "mcp-initialize", - "method": "initialize", - "params": map[string]any{ - "protocolVersion": protocolVersion, - }, - } - reqMarshal, err := json.Marshal(initializeRequestBody) - if err != nil { - t.Fatalf("unexpected error during marshaling of body") - } - - resp, _ := RunRequest(t, http.MethodPost, url, bytes.NewBuffer(reqMarshal), nil) - if resp.StatusCode != 200 { - t.Fatalf("response status code is not 200") - } - - if contentType := resp.Header.Get("Content-type"); contentType != "application/json" { - t.Fatalf("unexpected content-type header: want %s, got %s", "application/json", contentType) - } - - sessionId := resp.Header.Get("Mcp-Session-Id") - - header := map[string]string{} - if sessionId != "" { - header["Mcp-Session-Id"] = sessionId - } - - initializeNotificationBody := map[string]any{ - "jsonrpc": "2.0", - "method": "notifications/initialized", - } - notiMarshal, err := json.Marshal(initializeNotificationBody) - if err != nil { - t.Fatalf("unexpected error during marshaling of notifications body") - } - - _, _ = RunRequest(t, http.MethodPost, url, bytes.NewBuffer(notiMarshal), header) - return sessionId -} - // RunMCPToolCallMethod runs the tool/call for mcp endpoint func RunMCPToolCallMethod(t *testing.T, myFailToolWant, select1Want string, options ...McpTestOption) { // Resolve options @@ -4767,33 +4721,6 @@ func RunPostgresListStoredProcedureTest(t *testing.T, ctx context.Context, pool } } -// RunRequest is a helper function to send HTTP requests and return the response -func RunRequest(t *testing.T, method, url string, body io.Reader, headers map[string]string) (*http.Response, []byte) { - // Send request - req, err := http.NewRequest(method, url, body) - if err != nil { - t.Fatalf("unable to create request: %s", err) - } - - req.Header.Set("Content-type", "application/json") - - for k, v := range headers { - req.Header.Set(k, v) - } - - resp, err := http.DefaultClient.Do(req) - if err != nil { - t.Fatalf("unable to send request: %s", err) - } - respBody, err := io.ReadAll(resp.Body) - if err != nil { - t.Fatalf("unable to read request body: %s", err) - } - - defer resp.Body.Close() - return resp, respBody -} - func RunStatementToolsTest(t *testing.T, tools map[string]string) { for toolName, paramBody := range tools { t.Run(toolName, func(t *testing.T) {