diff --git a/internal/prebuiltconfigs/tools/cloud-sql-postgres.yaml b/internal/prebuiltconfigs/tools/cloud-sql-postgres.yaml index 9dd1f378a70..81cfc3626c1 100644 --- a/internal/prebuiltconfigs/tools/cloud-sql-postgres.yaml +++ b/internal/prebuiltconfigs/tools/cloud-sql-postgres.yaml @@ -20,7 +20,7 @@ sources: instance: ${CLOUD_SQL_POSTGRES_INSTANCE} database: ${CLOUD_SQL_POSTGRES_DATABASE} user: ${CLOUD_SQL_POSTGRES_USER:} - password: ${CLOUD_SQL_POSTGRES_PASSWORD:} + password: ${CLOUD_SQL_POSTGRES_PASS:} ipType: ${CLOUD_SQL_POSTGRES_IP_TYPE:public} cloud-sql-admin-source: kind: cloud-sql-admin diff --git a/tests/alloydb/alloydb_mcp_test.go b/tests/alloydb/alloydb_mcp_test.go index 9e195e0e514..9eb7a85f556 100644 --- a/tests/alloydb/alloydb_mcp_test.go +++ b/tests/alloydb/alloydb_mcp_test.go @@ -227,7 +227,7 @@ func TestAlloyDBListTools(t *testing.T) { }, } - tests.RunMCPToolsListMethod(t, expectedTools) + tests.RunMCPToolsListMethod(t, ctx, expectedTools) } func TestAlloyDBCallTool(t *testing.T) { @@ -252,15 +252,15 @@ func TestAlloyDBCallTool(t *testing.T) { } // Run tool-specific invoke tests - runAlloyDBListClustersMCPTest(t, vars) - runAlloyDBListInstancesMCPTest(t, vars) - runAlloyDBListUsersMCPTest(t, vars) - runAlloyDBGetClusterMCPTest(t, vars) - runAlloyDBGetInstanceMCPTest(t, vars) - runAlloyDBGetUserMCPTest(t, vars) + runAlloyDBListClustersMCPTest(t, ctx, vars) + runAlloyDBListInstancesMCPTest(t, ctx, vars) + runAlloyDBListUsersMCPTest(t, ctx, vars) + runAlloyDBGetClusterMCPTest(t, ctx, vars) + runAlloyDBGetInstanceMCPTest(t, ctx, vars) + runAlloyDBGetUserMCPTest(t, ctx, vars) t.Run("MCP Invoke invalid tool", func(t *testing.T) { - statusCode, mcpResp, err := tests.InvokeMCPTool(t, "non-existent-tool", map[string]any{}, nil) + statusCode, mcpResp, err := tests.InvokeMCPTool(t, ctx, "non-existent-tool", map[string]any{}, nil) if err != nil { t.Fatalf("native error executing %s: %s", "non-existent-tool", err) } @@ -271,7 +271,7 @@ func TestAlloyDBCallTool(t *testing.T) { }) } -func runAlloyDBListClustersMCPTest(t *testing.T, vars map[string]string) { +func runAlloyDBListClustersMCPTest(t *testing.T, ctx context.Context, vars map[string]string) { type ListClustersResponse struct { Clusters []struct { Name string `json:"name"` @@ -345,7 +345,7 @@ func runAlloyDBListClustersMCPTest(t *testing.T, vars map[string]string) { for _, tc := range invokeTcs { t.Run(tc.name, func(t *testing.T) { - statusCode, mcpResp, err := tests.InvokeMCPTool(t, "alloydb-list-clusters", tc.args, nil) + statusCode, mcpResp, err := tests.InvokeMCPTool(t, ctx, "alloydb-list-clusters", tc.args, nil) if err != nil { t.Fatalf("native error executing: %s", err) } @@ -382,7 +382,7 @@ func runAlloyDBListClustersMCPTest(t *testing.T, vars map[string]string) { } } -func runAlloyDBListInstancesMCPTest(t *testing.T, vars map[string]string) { +func runAlloyDBListInstancesMCPTest(t *testing.T, ctx context.Context, vars map[string]string) { type ListInstancesResponse struct { Instances []struct { Name string `json:"name"` @@ -460,7 +460,7 @@ func runAlloyDBListInstancesMCPTest(t *testing.T, vars map[string]string) { for _, tc := range invokeTcs { t.Run(tc.name, func(t *testing.T) { - statusCode, mcpResp, err := tests.InvokeMCPTool(t, "alloydb-list-instances", tc.args, nil) + statusCode, mcpResp, err := tests.InvokeMCPTool(t, ctx, "alloydb-list-instances", tc.args, nil) if err != nil { t.Fatalf("native error executing: %s", err) } @@ -497,7 +497,7 @@ func runAlloyDBListInstancesMCPTest(t *testing.T, vars map[string]string) { } } -func runAlloyDBListUsersMCPTest(t *testing.T, vars map[string]string) { +func runAlloyDBListUsersMCPTest(t *testing.T, ctx context.Context, vars map[string]string) { type UsersResponse struct { Users []struct { Name string `json:"name"` @@ -549,7 +549,7 @@ func runAlloyDBListUsersMCPTest(t *testing.T, vars map[string]string) { for _, tc := range invokeTcs { t.Run(tc.name, func(t *testing.T) { - statusCode, mcpResp, err := tests.InvokeMCPTool(t, "alloydb-list-users", tc.args, nil) + statusCode, mcpResp, err := tests.InvokeMCPTool(t, ctx, "alloydb-list-users", tc.args, nil) if err != nil { t.Fatalf("native error executing: %s", err) } @@ -585,7 +585,7 @@ func runAlloyDBListUsersMCPTest(t *testing.T, vars map[string]string) { } } -func runAlloyDBGetClusterMCPTest(t *testing.T, vars map[string]string) { +func runAlloyDBGetClusterMCPTest(t *testing.T, ctx context.Context, vars map[string]string) { invokeTcs := []struct { name string args map[string]any @@ -635,7 +635,7 @@ func runAlloyDBGetClusterMCPTest(t *testing.T, vars map[string]string) { for _, tc := range invokeTcs { t.Run(tc.name, func(t *testing.T) { - statusCode, mcpResp, err := tests.InvokeMCPTool(t, "alloydb-get-cluster", tc.args, nil) + statusCode, mcpResp, err := tests.InvokeMCPTool(t, ctx, "alloydb-get-cluster", tc.args, nil) if err != nil { t.Fatalf("native error executing: %s", err) } @@ -671,7 +671,7 @@ func runAlloyDBGetClusterMCPTest(t *testing.T, vars map[string]string) { } } -func runAlloyDBGetInstanceMCPTest(t *testing.T, vars map[string]string) { +func runAlloyDBGetInstanceMCPTest(t *testing.T, ctx context.Context, vars map[string]string) { invokeTcs := []struct { name string args map[string]any @@ -728,7 +728,7 @@ func runAlloyDBGetInstanceMCPTest(t *testing.T, vars map[string]string) { for _, tc := range invokeTcs { t.Run(tc.name, func(t *testing.T) { - statusCode, mcpResp, err := tests.InvokeMCPTool(t, "alloydb-get-instance", tc.args, nil) + statusCode, mcpResp, err := tests.InvokeMCPTool(t, ctx, "alloydb-get-instance", tc.args, nil) if err != nil { t.Fatalf("native error executing: %s", err) } @@ -764,7 +764,7 @@ func runAlloyDBGetInstanceMCPTest(t *testing.T, vars map[string]string) { } } -func runAlloyDBGetUserMCPTest(t *testing.T, vars map[string]string) { +func runAlloyDBGetUserMCPTest(t *testing.T, ctx context.Context, vars map[string]string) { invokeTcs := []struct { name string args map[string]any @@ -821,7 +821,7 @@ func runAlloyDBGetUserMCPTest(t *testing.T, vars map[string]string) { for _, tc := range invokeTcs { t.Run(tc.name, func(t *testing.T) { - statusCode, mcpResp, err := tests.InvokeMCPTool(t, "alloydb-get-user", tc.args, nil) + statusCode, mcpResp, err := tests.InvokeMCPTool(t, ctx, "alloydb-get-user", tc.args, nil) if err != nil { t.Fatalf("native error executing: %s", err) } @@ -1027,7 +1027,7 @@ func TestAlloyDBCreateClusterMCP(t *testing.T) { t.Fatalf("failed to unmarshal body: %v", err) } - statusCode, mcpResp, err := tests.InvokeMCPTool(t, "alloydb-create-cluster", args, nil) + statusCode, mcpResp, err := tests.InvokeMCPTool(t, ctx, "alloydb-create-cluster", args, nil) if err != nil { t.Fatalf("native error executing %s: %s", "alloydb-create-cluster", err) } @@ -1133,7 +1133,7 @@ func TestAlloyDBCreateInstanceMCP(t *testing.T) { t.Fatalf("failed to unmarshal body: %v", err) } - statusCode, mcpResp, err := tests.InvokeMCPTool(t, "alloydb-create-instance", args, nil) + statusCode, mcpResp, err := tests.InvokeMCPTool(t, ctx, "alloydb-create-instance", args, nil) if err != nil { t.Fatalf("native error executing %s: %s", "alloydb-create-instance", err) } @@ -1254,7 +1254,7 @@ func TestAlloyDBCreateUserMCP(t *testing.T) { t.Fatalf("failed to unmarshal body: %v", err) } - statusCode, mcpResp, err := tests.InvokeMCPTool(t, "alloydb-create-user", args, nil) + statusCode, mcpResp, err := tests.InvokeMCPTool(t, ctx, "alloydb-create-user", args, nil) if err != nil { t.Fatalf("native error executing %s: %s", "alloydb-create-user", err) } diff --git a/tests/alloydb/alloydb_wait_for_operation_mcp_test.go b/tests/alloydb/alloydb_wait_for_operation_mcp_test.go index 0bace16b796..ada9510793f 100644 --- a/tests/alloydb/alloydb_wait_for_operation_mcp_test.go +++ b/tests/alloydb/alloydb_wait_for_operation_mcp_test.go @@ -172,7 +172,7 @@ func TestWaitToolEndpointsMCP(t *testing.T) { t.Fatalf("failed to unmarshal body: %v", err) } - statusCode, mcpResp, err := tests.InvokeMCPTool(t, tc.toolName, args, nil) + statusCode, mcpResp, err := tests.InvokeMCPTool(t, ctx, tc.toolName, args, nil) if err != nil { t.Fatalf("native error executing %s: %s", tc.toolName, err) } diff --git a/tests/alloydbainl/alloydb_ai_nl_mcp_test.go b/tests/alloydbainl/alloydb_ai_nl_mcp_test.go index b25f8eedb8b..bbb7c513e26 100644 --- a/tests/alloydbainl/alloydb_ai_nl_mcp_test.go +++ b/tests/alloydbainl/alloydb_ai_nl_mcp_test.go @@ -16,6 +16,7 @@ package alloydbainl import ( "context" + "encoding/json" "net/http" "os" "regexp" @@ -97,7 +98,7 @@ func TestAlloyDBAINLListTools(t *testing.T) { }, } - tests.RunMCPToolsListMethod(t, expectedTools) + tests.RunMCPToolsListMethod(t, ctx, expectedTools) } func TestAlloyDBAINLCallTool(t *testing.T) { @@ -202,7 +203,7 @@ func TestAlloyDBAINLCallTool(t *testing.T) { for _, tc := range invokeTcs { t.Run(tc.name, func(t *testing.T) { - statusCode, mcpResp, err := tests.InvokeMCPTool(t, tc.toolName, tc.args, tc.requestHeader) + statusCode, mcpResp, err := tests.InvokeMCPTool(t, ctx, tc.toolName, tc.args, tc.requestHeader) if err != nil { t.Fatalf("native error executing %s: %s", tc.toolName, err) } @@ -226,12 +227,17 @@ func TestAlloyDBAINLCallTool(t *testing.T) { if mcpResp.Result.IsError { t.Fatalf("expected success result, got tool error: %v", mcpResp.Result) } - if len(mcpResp.Result.Content) == 0 { - t.Fatalf("expected at least one content item, got none") + got := tests.GetMCPResultText(t, mcpResp) + var gotStr string + if len(got) == 1 { + gotBytes, _ := json.Marshal(got[0]) + gotStr = string(gotBytes) + } else { + gotBytes, _ := json.Marshal(got) + gotStr = string(gotBytes) } - got := mcpResp.Result.Content[0].Text - if got != tc.want { - t.Fatalf("unexpected value: got %q, want %q", got, tc.want) + if gotStr != tc.want { + t.Fatalf("unexpected value: got %q, want %q", gotStr, tc.want) } } }) diff --git a/tests/alloydbpg/alloydb_pg_mcp_test.go b/tests/alloydbpg/alloydb_pg_mcp_test.go index f84f3ada9bb..8e6c3778f9e 100644 --- a/tests/alloydbpg/alloydb_pg_mcp_test.go +++ b/tests/alloydbpg/alloydb_pg_mcp_test.go @@ -138,7 +138,7 @@ func TestAlloyDBPgListTools(t *testing.T) { // We expect standard Postgres tools to be listed // This is a subset check, full list validation can be added if needed - _, tools, err := tests.GetMCPToolsList(t, nil, ctx) + _, tools, err := tests.GetMCPToolsList(t, ctx, nil) if err != nil { t.Fatalf("failed to get tools list: %v", err) } @@ -184,8 +184,8 @@ func TestAlloyDBPgCallTool(t *testing.T) { tests.RunMCPPostgresListActiveQueriesTest(t, ctx, pool) tests.RunMCPPostgresListTablesTest(t, ctx, pool, AlloyDBPostgresUser) tests.RunMCPPostgresListQueryStatsTest(t, ctx, pool) - tests.RunMCPPostgresListAvailableExtensionsTest(t) - tests.RunMCPPostgresListInstalledExtensionsTest(t) + tests.RunMCPPostgresListAvailableExtensionsTest(t, ctx) + tests.RunMCPPostgresListInstalledExtensionsTest(t, ctx) tests.RunMCPPostgresDatabaseOverviewTest(t, ctx, pool) tests.RunMCPPostgresListTriggersTest(t, ctx, pool) tests.RunMCPPostgresListIndexesTest(t, ctx, pool) @@ -196,7 +196,7 @@ func TestAlloyDBPgCallTool(t *testing.T) { tests.RunMCPPostgresGetColumnCardinalityTest(t, ctx, pool) tests.RunMCPPostgresListTableStatsTest(t, ctx, pool) tests.RunMCPPostgresListPublicationTablesTest(t, ctx, pool) - tests.RunMCPPostgresListTableSpacesTest(t) + tests.RunMCPPostgresListTableSpacesTest(t, ctx) tests.RunMCPPostgresListPgSettingsTest(t, ctx, pool) tests.RunMCPPostgresListDatabaseStatsTest(t, ctx, pool) tests.RunMCPPostgresListRolesTest(t, ctx, pool) diff --git a/tests/cloudsqlpg/cloud_sql_pg_integration_test.go b/tests/cloudsqlpg/cloud_sql_pg_integration_test.go index 57fc7739d26..2e4a46fe388 100644 --- a/tests/cloudsqlpg/cloud_sql_pg_integration_test.go +++ b/tests/cloudsqlpg/cloud_sql_pg_integration_test.go @@ -16,92 +16,16 @@ package cloudsqlpg import ( "context" - "fmt" - "net" - "os" "regexp" "strings" "testing" "time" - "cloud.google.com/go/cloudsqlconn" "github.com/google/uuid" "github.com/googleapis/mcp-toolbox/internal/testutils" "github.com/googleapis/mcp-toolbox/tests" - "github.com/jackc/pgx/v5/pgxpool" ) -var ( - CloudSQLPostgresSourceType = "cloud-sql-postgres" - CloudSQLPostgresToolType = "postgres-sql" - CloudSQLPostgresProject = os.Getenv("CLOUD_SQL_POSTGRES_PROJECT") - CloudSQLPostgresRegion = os.Getenv("CLOUD_SQL_POSTGRES_REGION") - CloudSQLPostgresInstance = os.Getenv("CLOUD_SQL_POSTGRES_INSTANCE") - CloudSQLPostgresDatabase = os.Getenv("CLOUD_SQL_POSTGRES_DATABASE") - CloudSQLPostgresUser = os.Getenv("CLOUD_SQL_POSTGRES_USER") - CloudSQLPostgresPass = os.Getenv("CLOUD_SQL_POSTGRES_PASS") -) - -func getCloudSQLPgVars(t *testing.T) map[string]any { - switch "" { - case CloudSQLPostgresProject: - t.Fatal("'CLOUD_SQL_POSTGRES_PROJECT' not set") - case CloudSQLPostgresRegion: - t.Fatal("'CLOUD_SQL_POSTGRES_REGION' not set") - case CloudSQLPostgresInstance: - t.Fatal("'CLOUD_SQL_POSTGRES_INSTANCE' not set") - case CloudSQLPostgresDatabase: - t.Fatal("'CLOUD_SQL_POSTGRES_DATABASE' not set") - case CloudSQLPostgresUser: - t.Fatal("'CLOUD_SQL_POSTGRES_USER' not set") - case CloudSQLPostgresPass: - t.Fatal("'CLOUD_SQL_POSTGRES_PASS' not set") - } - - return map[string]any{ - "type": CloudSQLPostgresSourceType, - "project": CloudSQLPostgresProject, - "instance": CloudSQLPostgresInstance, - "region": CloudSQLPostgresRegion, - "database": CloudSQLPostgresDatabase, - "user": CloudSQLPostgresUser, - "password": CloudSQLPostgresPass, - } -} - -// Copied over from cloud_sql_pg.go -func initCloudSQLPgConnectionPool(project, region, instance, ip_type, user, pass, dbname string) (*pgxpool.Pool, error) { - // Configure the driver to connect to the database - dsn := fmt.Sprintf("user=%s password=%s dbname=%s sslmode=disable", user, pass, dbname) - config, err := pgxpool.ParseConfig(dsn) - if err != nil { - return nil, fmt.Errorf("unable to parse connection uri: %w", err) - } - - // Create a new dialer with options - dialOpts, err := tests.GetCloudSQLDialOpts(ip_type) - if err != nil { - return nil, err - } - d, err := cloudsqlconn.NewDialer(context.Background(), cloudsqlconn.WithDefaultDialOptions(dialOpts...)) - if err != nil { - return nil, fmt.Errorf("unable to parse connection uri: %w", err) - } - - // Tell the driver to use the Cloud SQL Go Connector to create connections - i := fmt.Sprintf("%s:%s:%s", project, region, instance) - config.ConnConfig.DialFunc = func(ctx context.Context, _ string, instance string) (net.Conn, error) { - return d.Dial(ctx, i) - } - - // Interact with the driver directly as you normally would - pool, err := pgxpool.NewWithConfig(context.Background(), config) - if err != nil { - return nil, err - } - return pool, nil -} - func TestCloudSQLPgSimpleToolEndpoints(t *testing.T) { sourceConfig := getCloudSQLPgVars(t) ctx, cancel := context.WithTimeout(context.Background(), time.Minute) @@ -109,7 +33,7 @@ func TestCloudSQLPgSimpleToolEndpoints(t *testing.T) { args := []string{"--enable-api"} - pool, err := initCloudSQLPgConnectionPool(CloudSQLPostgresProject, CloudSQLPostgresRegion, CloudSQLPostgresInstance, "public", CloudSQLPostgresUser, CloudSQLPostgresPass, CloudSQLPostgresDatabase) + pool, err := initCloudSQLPgConnectionPool(ctx, CloudSQLPostgresProject, CloudSQLPostgresRegion, CloudSQLPostgresInstance, "public", CloudSQLPostgresUser, CloudSQLPostgresPass, CloudSQLPostgresDatabase) if err != nil { t.Fatalf("unable to create Cloud SQL connection pool: %s", err) } diff --git a/tests/cloudsqlpg/cloud_sql_pg_mcp_test.go b/tests/cloudsqlpg/cloud_sql_pg_mcp_test.go new file mode 100644 index 00000000000..8c93fda8fa9 --- /dev/null +++ b/tests/cloudsqlpg/cloud_sql_pg_mcp_test.go @@ -0,0 +1,199 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cloudsqlpg + +// TODO: We may want to add tests for custom tools defined in cloud-sql-postgres.yaml +// in the future, rather than just testing the prebuilt tools. + +import ( + "context" + "fmt" + "net" + "os" + "regexp" + "strings" + "testing" + "time" + + "cloud.google.com/go/cloudsqlconn" + "github.com/google/uuid" + "github.com/googleapis/mcp-toolbox/internal/testutils" + "github.com/googleapis/mcp-toolbox/tests" + "github.com/jackc/pgx/v5/pgxpool" +) + +var ( + CloudSQLPostgresSourceType = "cloud-sql-postgres" + CloudSQLPostgresToolType = "postgres-sql" + CloudSQLPostgresProject = os.Getenv("CLOUD_SQL_POSTGRES_PROJECT") + CloudSQLPostgresRegion = os.Getenv("CLOUD_SQL_POSTGRES_REGION") + CloudSQLPostgresInstance = os.Getenv("CLOUD_SQL_POSTGRES_INSTANCE") + CloudSQLPostgresDatabase = os.Getenv("CLOUD_SQL_POSTGRES_DATABASE") + CloudSQLPostgresUser = os.Getenv("CLOUD_SQL_POSTGRES_USER") + CloudSQLPostgresPass = os.Getenv("CLOUD_SQL_POSTGRES_PASS") +) + +func getCloudSQLPgVars(t *testing.T) map[string]any { + switch "" { + case CloudSQLPostgresProject: + t.Fatal("'CLOUD_SQL_POSTGRES_PROJECT' not set") + case CloudSQLPostgresRegion: + t.Fatal("'CLOUD_SQL_POSTGRES_REGION' not set") + case CloudSQLPostgresInstance: + t.Fatal("'CLOUD_SQL_POSTGRES_INSTANCE' not set") + case CloudSQLPostgresDatabase: + t.Fatal("'CLOUD_SQL_POSTGRES_DATABASE' not set") + case CloudSQLPostgresUser: + t.Fatal("'CLOUD_SQL_POSTGRES_USER' not set") + case CloudSQLPostgresPass: + t.Fatal("'CLOUD_SQL_POSTGRES_PASS' not set") + } + + return map[string]any{ + "type": CloudSQLPostgresSourceType, + "project": CloudSQLPostgresProject, + "instance": CloudSQLPostgresInstance, + "region": CloudSQLPostgresRegion, + "database": CloudSQLPostgresDatabase, + "user": CloudSQLPostgresUser, + "password": CloudSQLPostgresPass, + } +} + +func initCloudSQLPgConnectionPool(ctx context.Context, project, region, instance, ip_type, user, pass, dbname string) (*pgxpool.Pool, error) { + dsn := fmt.Sprintf("user=%s password=%s dbname=%s sslmode=disable", user, pass, dbname) + config, err := pgxpool.ParseConfig(dsn) + if err != nil { + return nil, fmt.Errorf("unable to parse connection uri: %w", err) + } + + // Create a new dialer with options + dialOpts, err := tests.GetCloudSQLDialOpts(ip_type) + if err != nil { + return nil, err + } + d, err := cloudsqlconn.NewDialer(ctx, cloudsqlconn.WithDefaultDialOptions(dialOpts...)) + if err != nil { + return nil, fmt.Errorf("unable to parse connection uri: %w", err) + } + + // Tell the driver to use the Cloud SQL Go Connector to create connections + i := fmt.Sprintf("%s:%s:%s", project, region, instance) + config.ConnConfig.DialFunc = func(ctx context.Context, _ string, instance string) (net.Conn, error) { + return d.Dial(ctx, i) + } + + // Interact with the driver directly as you normally would + pool, err := pgxpool.NewWithConfig(ctx, config) + if err != nil { + return nil, err + } + return pool, nil +} + +func TestCloudSQLPgListTools(t *testing.T) { + getCloudSQLPgVars(t) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + defer cancel() + + args := []string{"--prebuilt", "cloud-sql-postgres"} + + cmd, cleanup, err := tests.StartCmd(ctx, map[string]any{}, args...) + if err != nil { + t.Fatalf("command initialization returned an error: %v", err) + } + defer cleanup() + + waitCtx, cancelWait := context.WithTimeout(ctx, 10*time.Second) + defer cancelWait() + out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out) + if err != nil { + t.Logf("toolbox command logs: \n%s", out) + t.Fatalf("toolbox didn't start successfully: %v", err) + } + + // We expect standard Postgres tools to be listed + _, tools, err := tests.GetMCPToolsList(t, ctx, nil) + if err != nil { + t.Fatalf("failed to get tools list: %v", err) + } + + if len(tools) == 0 { + t.Errorf("expected tools to be listed, got none") + } +} + +func TestCloudSQLPgCallTool(t *testing.T) { + getCloudSQLPgVars(t) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + defer cancel() + + pool, err := initCloudSQLPgConnectionPool(ctx, CloudSQLPostgresProject, CloudSQLPostgresRegion, CloudSQLPostgresInstance, "public", CloudSQLPostgresUser, CloudSQLPostgresPass, CloudSQLPostgresDatabase) + if err != nil { + t.Fatalf("unable to create Cloud SQL connection pool: %s", err) + } + // Note: Don't defer pool.Close() here - the pool is only used for test setup/teardown. + // Closing it causes indefinite hangs due to background goroutines in the driver. + + uniqueID := strings.ReplaceAll(uuid.New().String(), "-", "") + + args := []string{"--prebuilt", "cloud-sql-postgres"} + + cmd, cleanup, err := tests.StartCmd(ctx, map[string]any{}, args...) + if err != nil { + t.Fatalf("command initialization returned an error: %v", err) + } + defer cleanup() + + waitCtx, cancelWait := context.WithTimeout(ctx, 10*time.Second) + defer cancelWait() + out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out) + if err != nil { + t.Logf("toolbox command logs: \n%s", out) + t.Fatalf("toolbox didn't start successfully: %v", err) + } + + // Run shared Postgres tests + tests.RunMCPPostgresListViewsTest(t, ctx, pool) + tests.RunMCPPostgresListSchemasTest(t, ctx, pool, CloudSQLPostgresUser, uniqueID) + tests.RunMCPPostgresListActiveQueriesTest(t, ctx, pool) + tests.RunMCPPostgresListAvailableExtensionsTest(t, ctx) + tests.RunMCPPostgresListInstalledExtensionsTest(t, ctx) + tests.RunMCPPostgresDatabaseOverviewTest(t, ctx, pool) + tests.RunMCPPostgresListTriggersTest(t, ctx, pool) + tests.RunMCPPostgresListIndexesTest(t, ctx, pool) + tests.RunMCPPostgresListSequencesTest(t, ctx, pool) + tests.RunMCPPostgresLongRunningTransactionsTest(t, ctx, pool) + tests.RunMCPPostgresListLocksTest(t, ctx, pool) + tests.RunMCPPostgresReplicationStatsTest(t, ctx, pool) + tests.RunMCPPostgresGetColumnCardinalityTest(t, ctx, pool) + tests.RunMCPPostgresListTableStatsTest(t, ctx, pool) + tests.RunMCPPostgresListPublicationTablesTest(t, ctx, pool) + tests.RunMCPPostgresListTableSpacesTest(t, ctx) + tests.RunMCPPostgresListPgSettingsTest(t, ctx, pool) + tests.RunMCPPostgresListDatabaseStatsTest(t, ctx, pool) + tests.RunMCPPostgresListRolesTest(t, ctx, pool) + tests.RunMCPPostgresListStoredProcedureTest(t, ctx, pool) + + toolsToTest := map[string]string{ + "list_autovacuum_configurations": `{}`, + "list_memory_configurations": `{}`, + "list_top_bloated_tables": `{"limit": 10}`, + "list_replication_slots": `{}`, + "list_invalid_indexes": `{}`, + "get_query_plan": `{"query": "SELECT 1"}`, + } + tests.RunMCPStatementToolsTest(t, ctx, toolsToTest) +} diff --git a/tests/cloudsqlpg/cloud_sql_pg_vectorassist_test.go b/tests/cloudsqlpg/cloud_sql_pg_vectorassist_test.go index 0f2c2159314..8a7b0d555a8 100644 --- a/tests/cloudsqlpg/cloud_sql_pg_vectorassist_test.go +++ b/tests/cloudsqlpg/cloud_sql_pg_vectorassist_test.go @@ -104,7 +104,7 @@ func TestVectorAssistIntegration(t *testing.T) { args := []string{"--enable-api"} - pool, err := initCloudSQLPgConnectionPool(CloudSQLPostgresProject, CloudSQLPostgresRegion, CloudSQLPostgresInstance, "public", CloudSQLPostgresUser, CloudSQLPostgresPass, CloudSQLPostgresDatabase) + pool, err := initCloudSQLPgConnectionPool(ctx, CloudSQLPostgresProject, CloudSQLPostgresRegion, CloudSQLPostgresInstance, "public", CloudSQLPostgresUser, CloudSQLPostgresPass, CloudSQLPostgresDatabase) if err != nil { t.Fatalf("unable to create Cloud SQL connection pool: %s", err) } diff --git a/tests/http/http_mcp_test.go b/tests/http/http_mcp_test.go index cef039563ae..b00f49f93aa 100644 --- a/tests/http/http_mcp_test.go +++ b/tests/http/http_mcp_test.go @@ -545,7 +545,7 @@ func TestHTTPListTools(t *testing.T) { }, } - tests.RunMCPToolsListMethod(t, expectedTools) + tests.RunMCPToolsListMethod(t, ctx, expectedTools) } func TestHTTPCallTool(t *testing.T) { @@ -608,24 +608,24 @@ func TestHTTPCallTool(t *testing.T) { } // Run Generic Auth Tests - runGenericAuthMCPInvokeTest(t, privateKey) + runGenericAuthMCPInvokeTest(t, ctx, privateKey) // Run Advanced Tool Tests - runAdvancedHTTPMCPInvokeTest(t) + runAdvancedHTTPMCPInvokeTest(t, ctx) // Run Query Parameter Tests - runQueryParamMCPInvokeTest(t) + runQueryParamMCPInvokeTest(t, ctx) // Use shared helper for standard database tools t.Run("use shared RunMCPToolInvokeTest", func(t *testing.T) { - tests.RunMCPToolInvokeTest(t, `"hello world"`, + tests.RunMCPToolInvokeTest(t, ctx, `"hello world"`, tests.WithMyToolId3NameAliceWant(`{"id":1,"name":"Alice"}`), tests.WithMyToolById4Want(`{"id":4,"name":null}`), ) }) } -func runGenericAuthMCPInvokeTest(t *testing.T, privateKey *rsa.PrivateKey) { +func runGenericAuthMCPInvokeTest(t *testing.T, ctx context.Context, privateKey *rsa.PrivateKey) { // Generic Auth Success t.Run("invoke generic auth tool with valid token", func(t *testing.T) { // Generate valid token @@ -642,7 +642,7 @@ func runGenericAuthMCPInvokeTest(t *testing.T, privateKey *rsa.PrivateKey) { } headers := map[string]string{"my-generic-auth_token": signedString} - statusCode, mcpResp, err := tests.InvokeMCPTool(t, "my-auth-required-generic-tool", map[string]any{}, headers) + statusCode, mcpResp, err := tests.InvokeMCPTool(t, ctx, "my-auth-required-generic-tool", map[string]any{}, headers) if err != nil { t.Fatalf("native error executing %s: %s", "my-auth-required-generic-tool", err) } @@ -656,7 +656,7 @@ func runGenericAuthMCPInvokeTest(t *testing.T, privateKey *rsa.PrivateKey) { // Auth Failure: Invoke generic auth tool without token t.Run("invoke generic auth tool without token", func(t *testing.T) { - statusCode, _, err := tests.InvokeMCPTool(t, "my-auth-required-generic-tool", map[string]any{}, nil) + statusCode, _, err := tests.InvokeMCPTool(t, ctx, "my-auth-required-generic-tool", map[string]any{}, nil) if err != nil { t.Fatalf("native error executing %s: %s", "my-auth-required-generic-tool", err) } @@ -666,25 +666,25 @@ func runGenericAuthMCPInvokeTest(t *testing.T, privateKey *rsa.PrivateKey) { }) } -func runQueryParamMCPInvokeTest(t *testing.T) { +func runQueryParamMCPInvokeTest(t *testing.T, ctx context.Context) { // Query Parameter Variations: Tests with optional parameters omitted or nil t.Run("invoke query-param-tool optional omitted", func(t *testing.T) { arguments := map[string]any{"reqId": "test1"} - tests.RunMCPCustomToolCallMethod(t, "my-query-param-tool", arguments, `"reqId=test1"`) + tests.RunMCPCustomToolCallMethod(t, ctx, "my-query-param-tool", arguments, `"reqId=test1"`) }) t.Run("invoke query-param-tool some optional nil", func(t *testing.T) { arguments := map[string]any{"reqId": "test2", "page": "5", "filter": nil} - tests.RunMCPCustomToolCallMethod(t, "my-query-param-tool", arguments, `"page=5\u0026reqId=test2"`) // 'filter' omitted! + tests.RunMCPCustomToolCallMethod(t, ctx, "my-query-param-tool", arguments, `"page=5\u0026reqId=test2"`) // 'filter' omitted! }) t.Run("invoke query-param-tool some optional absent", func(t *testing.T) { arguments := map[string]any{"reqId": "test2", "page": "5"} - tests.RunMCPCustomToolCallMethod(t, "my-query-param-tool", arguments, `"page=5\u0026reqId=test2"`) // 'filter' omitted! + tests.RunMCPCustomToolCallMethod(t, ctx, "my-query-param-tool", arguments, `"page=5\u0026reqId=test2"`) // 'filter' omitted! }) t.Run("invoke query-param-tool required param nil", func(t *testing.T) { - statusCode, mcpResp, err := tests.InvokeMCPTool(t, "my-query-param-tool", map[string]any{"reqId": nil, "page": "1"}, nil) + statusCode, mcpResp, err := tests.InvokeMCPTool(t, ctx, "my-query-param-tool", map[string]any{"reqId": nil, "page": "1"}, nil) if err != nil { t.Fatalf("native error executing %s: %s", "my-query-param-tool", err) } @@ -695,7 +695,7 @@ func runQueryParamMCPInvokeTest(t *testing.T) { }) } -func runAdvancedHTTPMCPInvokeTest(t *testing.T) { +func runAdvancedHTTPMCPInvokeTest(t *testing.T, ctx context.Context) { // Mock Server Error: Invoke tool with parameters that cause the mock server to return 400 t.Run("invoke my-advanced-tool with wrong params causing mock 400", func(t *testing.T) { arguments := map[string]any{ @@ -705,7 +705,7 @@ func runAdvancedHTTPMCPInvokeTest(t *testing.T) { "country": "US", "X-Other-Header": "test", } - statusCode, mcpResp, err := tests.InvokeMCPTool(t, "my-advanced-tool", arguments, nil) + statusCode, mcpResp, err := tests.InvokeMCPTool(t, ctx, "my-advanced-tool", arguments, nil) if err != nil { t.Fatalf("native error executing %s: %s", "my-advanced-tool", err) } @@ -724,6 +724,6 @@ func runAdvancedHTTPMCPInvokeTest(t *testing.T) { "country": "US", "X-Other-Header": "test", } - tests.RunMCPCustomToolCallMethod(t, "my-advanced-tool", arguments, `"hello world"`) + tests.RunMCPCustomToolCallMethod(t, ctx, "my-advanced-tool", arguments, `"hello world"`) }) } diff --git a/tests/mcp_tool.go b/tests/mcp_tool.go index e204c73878d..4affde143f1 100644 --- a/tests/mcp_tool.go +++ b/tests/mcp_tool.go @@ -24,7 +24,6 @@ import ( "fmt" "io" "net/http" - "reflect" "sort" "strconv" "strings" @@ -133,7 +132,7 @@ func NewMCPRequestHeader(t *testing.T, customHeaders map[string]string) map[stri } // InvokeMCPTool is a transparent, native JSON-RPC execution harness for tests. -func InvokeMCPTool(t *testing.T, toolName string, arguments map[string]any, requestHeader map[string]string, ctx ...context.Context) (int, *MCPCallToolResponse, error) { +func InvokeMCPTool(t *testing.T, ctx context.Context, toolName string, arguments map[string]any, requestHeader map[string]string) (int, *MCPCallToolResponse, error) { headers := NewMCPRequestHeader(t, requestHeader) req := NewMCPCallToolRequest(uuid.New().String(), toolName, arguments) @@ -141,9 +140,8 @@ func InvokeMCPTool(t *testing.T, toolName string, arguments map[string]any, requ if err != nil { t.Fatalf("error marshalling request body: %v", err) } - // TODO: We are using variadic ctx here to avoid breaking existing callers. - // Once all tests are updated to pass context, make it a regular first argument. - resp, respBody := RunRequest(t, http.MethodPost, "http://127.0.0.1:5000/mcp", bytes.NewBuffer(reqBody), headers, ctx...) + + resp, respBody := RunRequest(t, http.MethodPost, "http://127.0.0.1:5000/mcp", bytes.NewBuffer(reqBody), headers, ctx) var mcpResp MCPCallToolResponse if err := json.Unmarshal(respBody, &mcpResp); err != nil { @@ -156,12 +154,13 @@ func InvokeMCPTool(t *testing.T, toolName string, arguments map[string]any, requ return resp.StatusCode, &mcpResp, nil } -// getMCPResultText safely extracts the text from content blocks, unmarshaling them if they are valid JSON. +// GetMCPResultText safely extracts the text from content blocks, unmarshaling them if they are valid JSON. // // TODO: For tests that need to strictly validate the exact schema or structure of the output, // consider avoiding this helper and instead unmarshal the raw JSON directly into expected Go structs for comparison. -func getMCPResultText(t *testing.T, resp *MCPCallToolResponse) []any { +func GetMCPResultText(t *testing.T, resp *MCPCallToolResponse) []any { if len(resp.Result.Content) == 0 { + // Return an initialized empty slice instead of nil to avoid false failures with cmp.Diff return []any{} } @@ -180,13 +179,14 @@ func getMCPResultText(t *testing.T, resp *MCPCallToolResponse) []any { } if res == nil { + // Return an initialized empty slice instead of nil to avoid false failures with cmp.Diff return []any{} } return res } // GetMCPToolsList is a JSON-RPC harness that fetches the tools/list registry. -func GetMCPToolsList(t *testing.T, requestHeader map[string]string, ctx ...context.Context) (int, []any, error) { +func GetMCPToolsList(t *testing.T, ctx context.Context, requestHeader map[string]string) (int, []any, error) { headers := NewMCPRequestHeader(t, requestHeader) req := MCPListToolsRequest{ @@ -199,7 +199,7 @@ func GetMCPToolsList(t *testing.T, requestHeader map[string]string, ctx ...conte t.Fatalf("error marshalling tools/list request body: %v", err) } - resp, respBody := RunRequest(t, http.MethodPost, "http://127.0.0.1:5000/mcp", bytes.NewBuffer(reqBody), headers, ctx...) + resp, respBody := RunRequest(t, http.MethodPost, "http://127.0.0.1:5000/mcp", bytes.NewBuffer(reqBody), headers, ctx) var mcpResp jsonrpc.JSONRPCResponse if err := json.Unmarshal(respBody, &mcpResp); err != nil { @@ -244,9 +244,9 @@ func AssertMCPError(t *testing.T, mcpResp *MCPCallToolResponse, wantErrMsg strin } // RunMCPToolsListMethod calls tools/list and verifies that the returned tools match the expected list. -func RunMCPToolsListMethod(t *testing.T, expectedOutput []MCPToolManifest) { +func RunMCPToolsListMethod(t *testing.T, ctx context.Context, expectedOutput []MCPToolManifest) { t.Helper() - statusCodeList, toolsList, errList := GetMCPToolsList(t, nil) + statusCodeList, toolsList, errList := GetMCPToolsList(t, ctx, nil) if errList != nil { t.Fatalf("native error executing tools/list: %s", errList) } @@ -274,9 +274,8 @@ func RunMCPToolsListMethod(t *testing.T, expectedOutput []MCPToolManifest) { for _, actual := range actualTools { if actual.Name == expected.Name { found = true - // Use reflect.DeepEqual to check all fields (description, parameters, etc.) - if !reflect.DeepEqual(actual, expected) { - t.Fatalf("tool %s mismatch:\nwant: %+v\ngot: %+v", expected.Name, expected, actual) + if diff := cmp.Diff(expected, actual); diff != "" { + t.Fatalf("tool %s mismatch (-want +got):\n%s", expected.Name, diff) } break } @@ -288,9 +287,9 @@ func RunMCPToolsListMethod(t *testing.T, expectedOutput []MCPToolManifest) { } // RunMCPCustomToolCallMethod invokes a tool and compares the result with expected output. -func RunMCPCustomToolCallMethod(t *testing.T, toolName string, arguments map[string]any, want string) { +func RunMCPCustomToolCallMethod(t *testing.T, ctx context.Context, toolName string, arguments map[string]any, want string) { t.Helper() - statusCode, mcpResp, err := InvokeMCPTool(t, toolName, arguments, nil) + statusCode, mcpResp, err := InvokeMCPTool(t, ctx, toolName, arguments, nil) if err != nil { t.Fatalf("native error executing %s: %s", toolName, err) } @@ -300,7 +299,7 @@ func RunMCPCustomToolCallMethod(t *testing.T, toolName string, arguments map[str if mcpResp.Result.IsError { t.Fatalf("%s returned error result: %v", toolName, mcpResp.Result) } - got := getMCPResultText(t, mcpResp) + got := GetMCPResultText(t, mcpResp) gotBytes, _ := json.Marshal(got) gotStr := string(gotBytes) if !strings.Contains(gotStr, want) { @@ -310,7 +309,7 @@ func RunMCPCustomToolCallMethod(t *testing.T, toolName string, arguments map[str } // RunMCPToolInvokeTest runs the tool invoke test cases over MCP protocol. -func RunMCPToolInvokeTest(t *testing.T, select1Want string, options ...InvokeTestOption) { +func RunMCPToolInvokeTest(t *testing.T, ctx context.Context, select1Want string, options ...InvokeTestOption) { t.Helper() // Resolve options using existing InvokeTestOption and InvokeTestConfig from option.go configs := &InvokeTestConfig{ @@ -387,7 +386,7 @@ func RunMCPToolInvokeTest(t *testing.T, select1Want string, options ...InvokeTes if !tc.enabled { t.Skip("skipping disabled test case") } - statusCode, mcpResp, err := InvokeMCPTool(t, tc.toolName, tc.args, tc.headers) + statusCode, mcpResp, err := InvokeMCPTool(t, ctx, tc.toolName, tc.args, tc.headers) if err != nil { t.Fatalf("native error executing %s: %s", tc.toolName, err) } @@ -401,7 +400,7 @@ func RunMCPToolInvokeTest(t *testing.T, select1Want string, options ...InvokeTes if mcpResp.Result.IsError { t.Fatalf("%s returned error result: %v", tc.toolName, mcpResp.Result) } - got := getMCPResultText(t, mcpResp) + got := GetMCPResultText(t, mcpResp) gotBytes, _ := json.Marshal(got) gotStr := string(gotBytes) if !strings.Contains(gotStr, tc.wantResult) { @@ -455,7 +454,7 @@ func RunMCPPostgresListViewsTest(t *testing.T, ctx context.Context, pool *pgxpoo } for _, tc := range invokeTcs { t.Run(tc.name, func(t *testing.T) { - statusCode, mcpResp, err := InvokeMCPTool(t, "list_views", tc.args, nil) + statusCode, mcpResp, err := InvokeMCPTool(t, ctx, "list_views", tc.args, nil) if err != nil { t.Fatalf("native error executing list_views: %s", err) } @@ -469,7 +468,7 @@ func RunMCPPostgresListViewsTest(t *testing.T, ctx context.Context, pool *pgxpoo t.Fatalf("list_views returned error result: %v", mcpResp.Result) } - got := getMCPResultText(t, mcpResp) + got := GetMCPResultText(t, mcpResp) var wantObj []any if err := json.Unmarshal([]byte(tc.want), &wantObj); err != nil { @@ -599,7 +598,7 @@ func RunMCPPostgresListTablesTest(t *testing.T, ctx context.Context, pool *pgxpo for _, tc := range invokeTcs { t.Run(tc.name, func(t *testing.T) { - statusCode, mcpResp, err := InvokeMCPTool(t, "list_tables", tc.args, nil) + statusCode, mcpResp, err := InvokeMCPTool(t, ctx, "list_tables", tc.args, nil) if err != nil { t.Fatalf("native error executing list_tables: %s", err) } @@ -621,7 +620,7 @@ func RunMCPPostgresListTablesTest(t *testing.T, ctx context.Context, pool *pgxpo t.Fatalf("list_tables returned error result: %v", mcpResp.Result) } - got := getMCPResultText(t, mcpResp) + got := GetMCPResultText(t, mcpResp) var wantObj []any if err := json.Unmarshal([]byte(tc.want), &wantObj); err != nil { @@ -697,7 +696,7 @@ func RunMCPPostgresListQueryStatsTest(t *testing.T, ctx context.Context, pool *p for _, tc := range invokeTcs { t.Run(tc.name, func(t *testing.T) { - statusCode, mcpResp, err := InvokeMCPTool(t, "list_query_stats", tc.args, nil) + statusCode, mcpResp, err := InvokeMCPTool(t, ctx, "list_query_stats", tc.args, nil) if err != nil { t.Fatalf("native error executing list_query_stats: %s", err) } @@ -712,7 +711,7 @@ func RunMCPPostgresListQueryStatsTest(t *testing.T, ctx context.Context, pool *p t.Fatalf("list_query_stats returned error result: %v", mcpResp.Result) } - got := getMCPResultText(t, mcpResp) + got := GetMCPResultText(t, mcpResp) // Verify that we got a list (even if empty) if got == nil { @@ -777,7 +776,7 @@ func RunMCPPostgresListSchemasTest(t *testing.T, ctx context.Context, pool *pgxp } for _, tc := range invokeTcs { t.Run(tc.name, func(t *testing.T) { - statusCode, mcpResp, err := InvokeMCPTool(t, "list_schemas", tc.args, nil) + statusCode, mcpResp, err := InvokeMCPTool(t, ctx, "list_schemas", tc.args, nil) if err != nil { t.Fatalf("native error executing list_schemas: %s", err) } @@ -790,7 +789,7 @@ func RunMCPPostgresListSchemasTest(t *testing.T, ctx context.Context, pool *pgxp if mcpResp.Result.IsError { t.Fatalf("list_schemas returned error result: %v", mcpResp.Result) } - gotObj := getMCPResultText(t, mcpResp) + gotObj := GetMCPResultText(t, mcpResp) if tc.compareSubset { found := false @@ -893,7 +892,7 @@ func RunMCPPostgresListActiveQueriesTest(t *testing.T, ctx context.Context, pool time.Sleep(time.Duration(tc.waitSecsBeforeCheck) * time.Second) } - statusCode, mcpResp, err := InvokeMCPTool(t, "list_active_queries", tc.args, nil) + statusCode, mcpResp, err := InvokeMCPTool(t, ctx, "list_active_queries", tc.args, nil) if err != nil { t.Fatalf("native error executing list_active_queries: %s", err) } @@ -907,7 +906,7 @@ func RunMCPPostgresListActiveQueriesTest(t *testing.T, ctx context.Context, pool t.Fatalf("list_active_queries returned error result: %v", mcpResp.Result) } var details []queryListDetails - gotObj := getMCPResultText(t, mcpResp) + gotObj := GetMCPResultText(t, mcpResp) for _, item := range gotObj { if m, ok := item.(map[string]any); ok { if q, ok := m["query"].(string); ok { @@ -946,8 +945,8 @@ func RunMCPPostgresListActiveQueriesTest(t *testing.T, ctx context.Context, pool } // RunMCPPostgresListAvailableExtensionsTest tests the list_available_extensions tool via MCP. -func RunMCPPostgresListAvailableExtensionsTest(t *testing.T) { - statusCode, mcpResp, err := InvokeMCPTool(t, "list_available_extensions", map[string]any{}, nil) +func RunMCPPostgresListAvailableExtensionsTest(t *testing.T, ctx context.Context) { + statusCode, mcpResp, err := InvokeMCPTool(t, ctx, "list_available_extensions", map[string]any{}, nil) if err != nil { t.Fatalf("native error executing list_available_extensions: %s", err) } @@ -960,8 +959,8 @@ func RunMCPPostgresListAvailableExtensionsTest(t *testing.T) { } // RunMCPPostgresListInstalledExtensionsTest tests the list_installed_extensions tool via MCP. -func RunMCPPostgresListInstalledExtensionsTest(t *testing.T) { - statusCode, mcpResp, err := InvokeMCPTool(t, "list_installed_extensions", map[string]any{}, nil) +func RunMCPPostgresListInstalledExtensionsTest(t *testing.T, ctx context.Context) { + statusCode, mcpResp, err := InvokeMCPTool(t, ctx, "list_installed_extensions", map[string]any{}, nil) if err != nil { t.Fatalf("native error executing list_installed_extensions: %s", err) } @@ -1087,7 +1086,7 @@ func RunMCPPostgresListTriggersTest(t *testing.T, ctx context.Context, pool *pgx } for _, tc := range invokeTcs { t.Run(tc.name, func(t *testing.T) { - statusCode, mcpResp, err := InvokeMCPTool(t, "list_triggers", tc.args, nil) + statusCode, mcpResp, err := InvokeMCPTool(t, ctx, "list_triggers", tc.args, nil) if err != nil { t.Fatalf("native error executing list_triggers: %s", err) } @@ -1100,7 +1099,7 @@ func RunMCPPostgresListTriggersTest(t *testing.T, ctx context.Context, pool *pgx if mcpResp.Result.IsError { t.Fatalf("list_triggers returned error result: %v", mcpResp.Result) } - gotObj := getMCPResultText(t, mcpResp) + gotObj := GetMCPResultText(t, mcpResp) if tc.compareSubset { found := false @@ -1191,7 +1190,7 @@ func RunMCPPostgresListSequencesTest(t *testing.T, ctx context.Context, pool *pg } for _, tc := range invokeTcs { t.Run(tc.name, func(t *testing.T) { - statusCode, mcpResp, err := InvokeMCPTool(t, "list_sequences", tc.args, nil) + statusCode, mcpResp, err := InvokeMCPTool(t, ctx, "list_sequences", tc.args, nil) if err != nil { t.Fatalf("native error executing list_sequences: %s", err) } @@ -1204,7 +1203,7 @@ func RunMCPPostgresListSequencesTest(t *testing.T, ctx context.Context, pool *pg if mcpResp.Result.IsError { t.Fatalf("list_sequences returned error result: %v", mcpResp.Result) } - gotObj := getMCPResultText(t, mcpResp) + gotObj := GetMCPResultText(t, mcpResp) wantObj := []any{} for _, item := range tc.want { wantObj = append(wantObj, item) @@ -1325,7 +1324,7 @@ func RunMCPPostgresListIndexesTest(t *testing.T, ctx context.Context, pool *pgxp for _, tc := range invokeTcs { t.Run(tc.name, func(t *testing.T) { - statusCode, mcpResp, err := InvokeMCPTool(t, "list_indexes", tc.args, nil) + statusCode, mcpResp, err := InvokeMCPTool(t, ctx, "list_indexes", tc.args, nil) if err != nil { t.Fatalf("native error executing list_indexes: %s", err) } @@ -1338,7 +1337,7 @@ func RunMCPPostgresListIndexesTest(t *testing.T, ctx context.Context, pool *pgxp if mcpResp.Result.IsError { t.Fatalf("list_indexes returned error result: %v", mcpResp.Result) } - gotObj := getMCPResultText(t, mcpResp) + gotObj := GetMCPResultText(t, mcpResp) if tc.compareSubset { for _, wantIdx := range tc.want { @@ -1514,7 +1513,7 @@ func RunMCPPostgresListStoredProcedureTest(t *testing.T, ctx context.Context, po } for _, tc := range invokeTcs { t.Run(tc.name, func(t *testing.T) { - statusCode, mcpResp, err := InvokeMCPTool(t, "list_stored_procedure", tc.args, nil) + statusCode, mcpResp, err := InvokeMCPTool(t, ctx, "list_stored_procedure", tc.args, nil) if err != nil { t.Fatalf("native error executing list_stored_procedure: %s", err) } @@ -1528,7 +1527,7 @@ func RunMCPPostgresListStoredProcedureTest(t *testing.T, ctx context.Context, po t.Fatalf("list_stored_procedure returned error result: %v", mcpResp.Result) } var gotObj []storedProcedureDetails - got := getMCPResultText(t, mcpResp) + got := GetMCPResultText(t, mcpResp) for _, item := range got { if m, ok := item.(map[string]any); ok { proc := storedProcedureDetails{} @@ -1565,7 +1564,7 @@ func RunMCPPostgresListStoredProcedureTest(t *testing.T, ctx context.Context, po // RunMCPPostgresDatabaseOverviewTest tests the database_overview tool via MCP. func RunMCPPostgresDatabaseOverviewTest(t *testing.T, ctx context.Context, pool *pgxpool.Pool) { - statusCode, mcpResp, err := InvokeMCPTool(t, "database_overview", map[string]any{}, nil) + statusCode, mcpResp, err := InvokeMCPTool(t, ctx, "database_overview", map[string]any{}, nil) if err != nil { t.Fatalf("native error executing database_overview: %s", err) } @@ -1575,7 +1574,7 @@ func RunMCPPostgresDatabaseOverviewTest(t *testing.T, ctx context.Context, pool if mcpResp.Result.IsError { t.Fatalf("database_overview returned error result: %v", mcpResp.Result) } - gotObj := getMCPResultText(t, mcpResp) + gotObj := GetMCPResultText(t, mcpResp) if len(gotObj) != 1 { t.Fatalf("Expected exactly one row in the result, got %d", len(gotObj)) @@ -1633,7 +1632,7 @@ func RunMCPPostgresListLocksTest(t *testing.T, ctx context.Context, pool *pgxpoo cleanup := CreateAndLockMCPPostgresTable(t, ctx, pool, "test_postgres_list_locks_table") defer cleanup() - statusCode, mcpResp, err := InvokeMCPTool(t, "list_locks", map[string]any{}, nil) + statusCode, mcpResp, err := InvokeMCPTool(t, ctx, "list_locks", map[string]any{}, nil) if err != nil { t.Fatalf("native error executing list_locks: %s", err) } @@ -1643,7 +1642,7 @@ func RunMCPPostgresListLocksTest(t *testing.T, ctx context.Context, pool *pgxpoo if mcpResp.Result.IsError { t.Fatalf("list_locks returned error result: %v", mcpResp.Result) } - gotObj := getMCPResultText(t, mcpResp) + gotObj := GetMCPResultText(t, mcpResp) if len(gotObj) == 0 { t.Errorf("Expected to find locks, got none") @@ -1653,7 +1652,7 @@ func RunMCPPostgresListLocksTest(t *testing.T, ctx context.Context, pool *pgxpoo // RunMCPPostgresLongRunningTransactionsTest tests the long_running_transactions tool via MCP. func RunMCPPostgresLongRunningTransactionsTest(t *testing.T, ctx context.Context, pool *pgxpool.Pool) { - statusCode, mcpResp, err := InvokeMCPTool(t, "long_running_transactions", map[string]any{}, nil) + statusCode, mcpResp, err := InvokeMCPTool(t, ctx, "long_running_transactions", map[string]any{}, nil) if err != nil { t.Fatalf("native error executing long_running_transactions: %s", err) } @@ -1667,7 +1666,7 @@ func RunMCPPostgresLongRunningTransactionsTest(t *testing.T, ctx context.Context // RunMCPPostgresReplicationStatsTest tests the replication_stats tool via MCP. func RunMCPPostgresReplicationStatsTest(t *testing.T, ctx context.Context, pool *pgxpool.Pool) { - statusCode, mcpResp, err := InvokeMCPTool(t, "replication_stats", map[string]any{}, nil) + statusCode, mcpResp, err := InvokeMCPTool(t, ctx, "replication_stats", map[string]any{}, nil) if err != nil { t.Fatalf("native error executing replication_stats: %s", err) } @@ -1752,7 +1751,7 @@ func RunMCPPostgresGetColumnCardinalityTest(t *testing.T, ctx context.Context, p } for _, tc := range invokeTcs { t.Run(tc.name, func(t *testing.T) { - statusCode, mcpResp, err := InvokeMCPTool(t, "get_column_cardinality", tc.args, nil) + statusCode, mcpResp, err := InvokeMCPTool(t, ctx, "get_column_cardinality", tc.args, nil) if err != nil { t.Fatalf("native error executing get_column_cardinality: %s", err) } @@ -1772,7 +1771,7 @@ func RunMCPPostgresGetColumnCardinalityTest(t *testing.T, ctx context.Context, p t.Logf("DEBUG: get_column_cardinality returned empty content as expected for non-existent table") return } - gotObj := getMCPResultText(t, mcpResp) + gotObj := GetMCPResultText(t, mcpResp) if tc.shouldHaveData { if len(gotObj) == 0 { @@ -1910,7 +1909,7 @@ func RunMCPPostgresListTableStatsTest(t *testing.T, ctx context.Context, pool *p for _, tc := range invokeTcs { t.Run(tc.name, func(t *testing.T) { - statusCode, mcpResp, err := InvokeMCPTool(t, "list_table_stats", tc.arguments, nil) + statusCode, mcpResp, err := InvokeMCPTool(t, ctx, "list_table_stats", tc.arguments, nil) if err != nil { t.Fatalf("native error executing list_table_stats: %s", err) } @@ -1924,7 +1923,7 @@ func RunMCPPostgresListTableStatsTest(t *testing.T, ctx context.Context, pool *p t.Fatalf("list_table_stats returned error result: %v", mcpResp.Result) } - gotObj := getMCPResultText(t, mcpResp) + gotObj := GetMCPResultText(t, mcpResp) // Verify expected data presence if tc.shouldHaveData { @@ -2019,7 +2018,7 @@ func RunMCPPostgresListPublicationTablesTest(t *testing.T, ctx context.Context, t.Fatalf("unable to fetch current user: %v", err) } - statusCode, mcpResp, err := InvokeMCPTool(t, "list_publication_tables", map[string]any{}, nil) + statusCode, mcpResp, err := InvokeMCPTool(t, ctx, "list_publication_tables", map[string]any{}, nil) if err != nil { t.Fatalf("native error executing list_publication_tables: %s", err) } @@ -2029,7 +2028,7 @@ func RunMCPPostgresListPublicationTablesTest(t *testing.T, ctx context.Context, if mcpResp.Result.IsError { t.Fatalf("list_publication_tables returned error result: %v", mcpResp.Result) } - gotObj := getMCPResultText(t, mcpResp) + gotObj := GetMCPResultText(t, mcpResp) found := false for _, rowObj := range gotObj { @@ -2049,8 +2048,8 @@ func RunMCPPostgresListPublicationTablesTest(t *testing.T, ctx context.Context, } // RunMCPPostgresListTableSpacesTest tests the list_tablespaces tool via MCP. -func RunMCPPostgresListTableSpacesTest(t *testing.T) { - statusCode, mcpResp, err := InvokeMCPTool(t, "list_tablespaces", map[string]any{}, nil) +func RunMCPPostgresListTableSpacesTest(t *testing.T, ctx context.Context) { + statusCode, mcpResp, err := InvokeMCPTool(t, ctx, "list_tablespaces", map[string]any{}, nil) if err != nil { t.Fatalf("native error executing list_tablespaces: %s", err) } @@ -2094,7 +2093,7 @@ func RunMCPPostgresListPgSettingsTest(t *testing.T, ctx context.Context, pool *p "requires_restart": requiresRestart, } - statusCode, mcpResp, err := InvokeMCPTool(t, "list_pg_settings", map[string]any{"setting_name": targetSetting}, nil) + statusCode, mcpResp, err := InvokeMCPTool(t, ctx, "list_pg_settings", map[string]any{"setting_name": targetSetting}, nil) if err != nil { t.Fatalf("native error executing list_pg_settings: %s", err) } @@ -2104,7 +2103,7 @@ func RunMCPPostgresListPgSettingsTest(t *testing.T, ctx context.Context, pool *p if mcpResp.Result.IsError { t.Fatalf("list_pg_settings returned error result: %v", mcpResp.Result) } - gotObj := getMCPResultText(t, mcpResp) + gotObj := GetMCPResultText(t, mcpResp) if len(gotObj) != 1 { t.Fatalf("Expected exactly one row in the result, got %d", len(gotObj)) @@ -2150,7 +2149,7 @@ func RunMCPPostgresListDatabaseStatsTest(t *testing.T, ctx context.Context, pool cleanup1 := setUpMCPDatabase(t, ctx, pool, dbName1, dbOwner1) defer cleanup1() - statusCode, mcpResp, err := InvokeMCPTool(t, "list_database_stats", map[string]any{"database_name": dbName1}, nil) + statusCode, mcpResp, err := InvokeMCPTool(t, ctx, "list_database_stats", map[string]any{"database_name": dbName1}, nil) if err != nil { t.Fatalf("native error executing list_database_stats: %s", err) } @@ -2160,7 +2159,7 @@ func RunMCPPostgresListDatabaseStatsTest(t *testing.T, ctx context.Context, pool if mcpResp.Result.IsError { t.Fatalf("list_database_stats returned error result: %v", mcpResp.Result) } - gotObj := getMCPResultText(t, mcpResp) + gotObj := GetMCPResultText(t, mcpResp) found := false for _, rowObj := range gotObj { @@ -2224,7 +2223,7 @@ func RunMCPPostgresListRolesTest(t *testing.T, ctx context.Context, pool *pgxpoo adminUser, _, _, cleanup := setupMCPPostgresRoles(t, ctx, pool) defer cleanup(t) - statusCode, mcpResp, err := InvokeMCPTool(t, "list_roles", map[string]any{"role_name": "test_role_"}, nil) + statusCode, mcpResp, err := InvokeMCPTool(t, ctx, "list_roles", map[string]any{"role_name": "test_role_"}, nil) if err != nil { t.Fatalf("native error executing list_roles: %s", err) } @@ -2234,7 +2233,7 @@ func RunMCPPostgresListRolesTest(t *testing.T, ctx context.Context, pool *pgxpoo if mcpResp.Result.IsError { t.Fatalf("list_roles returned error result: %v", mcpResp.Result) } - gotObj := getMCPResultText(t, mcpResp) + gotObj := GetMCPResultText(t, mcpResp) found := false for _, rowObj := range gotObj { @@ -2263,7 +2262,7 @@ func RunMCPStatementToolsTest(t *testing.T, ctx context.Context, tools map[strin t.Fatalf("failed to unmarshal paramBody: %v", err) } } - statusCode, mcpResp, err := InvokeMCPTool(t, toolName, args, nil, ctx) + statusCode, mcpResp, err := InvokeMCPTool(t, ctx, toolName, args, nil) if err != nil { t.Fatalf("native error executing %s: %s", toolName, err) }