From 5def18a60e48955e28bcf15bc8581ad2afb974e9 Mon Sep 17 00:00:00 2001 From: Abdelrahman Shawki Hassan Date: Wed, 2 Apr 2025 18:07:21 +0200 Subject: [PATCH 01/10] feat: MCP implementation [IDE-1099] (#815) --- application/server/server.go | 29 +- application/server/server_test.go | 39 -- domain/ide/command/execute_mcp_test.go | 105 ---- internal/mcp/llm_binding.go | 87 ++- internal/mcp/llm_binding_smoke_test.go | 114 ---- internal/mcp/llm_binding_test.go | 11 +- internal/mcp/options.go | 20 +- internal/mcp/scan_tool.go | 152 +++-- internal/mcp/scan_tool_test.go | 830 +++++++++++++++++++++++++ internal/mcp/snyk_tools.json | 216 +++++++ internal/mcp/utils.go | 108 ++++ mcp_extension/main.go | 111 ++++ mcp_extension/main_test.go | 39 ++ 13 files changed, 1483 insertions(+), 378 deletions(-) delete mode 100644 domain/ide/command/execute_mcp_test.go delete mode 100644 internal/mcp/llm_binding_smoke_test.go create mode 100644 internal/mcp/scan_tool_test.go create mode 100644 internal/mcp/snyk_tools.json create mode 100644 internal/mcp/utils.go create mode 100644 mcp_extension/main.go create mode 100644 mcp_extension/main_test.go diff --git a/application/server/server.go b/application/server/server.go index 06885179d..9bd10db50 100644 --- a/application/server/server.go +++ b/application/server/server.go @@ -28,7 +28,6 @@ import ( "github.com/snyk/snyk-ls/domain/snyk" "github.com/snyk/snyk-ls/domain/snyk/persistence" - mcp2 "github.com/snyk/snyk-ls/internal/mcp" "github.com/snyk/snyk-ls/internal/storedconfig" "github.com/snyk/snyk-ls/domain/snyk/scanner" @@ -78,25 +77,6 @@ func Start(c *config.Config) { di.Init() initHandlers(srv, handlers, c) - // start mcp server - logger.Info().Msg("Starting up MCP Server...") - var mcpServer *mcp2.McpLLMBinding - go func() { - mcpServer = mcp2.NewMcpLLMBinding(c, mcp2.WithScanner(di.Scanner()), mcp2.WithLogger(c.Logger())) - err := mcpServer.Start() - if err != nil { - c.Logger().Err(err).Msg("failed to start mcp server") - } - }() - - // shutdown mcp server once the lsp returns from wait status - defer func() { - if mcpServer != nil { - logger.Info().Msg("Shutting down MCP Server...") - mcpServer.Shutdown(context.Background()) - } - }() - logger.Info().Msg("Starting up Language Server...") srv = srv.Start(channel.Header("")(os.Stdin, os.Stdout)) status := srv.WaitStatus() @@ -477,14 +457,7 @@ func initializedHandler(c *config.Config, srv *jrpc2.Server) handler.Func { ) logger.Info().Msg(msg) } - defer func() { - // delay sending the mcp server URL - for c.GetMCPServerURL() == nil { - // wait until the server URL is available - time.Sleep(500 * time.Millisecond) - } - di.Notifier().Send(types.McpServerURLParams{URL: c.GetMCPServerURL().String()}) - }() + return nil, nil }) } diff --git a/application/server/server_test.go b/application/server/server_test.go index 151ded2d0..140154df0 100644 --- a/application/server/server_test.go +++ b/application/server/server_test.go @@ -18,7 +18,6 @@ package server import ( "context" - "net/url" "os" "os/exec" "path/filepath" @@ -238,44 +237,6 @@ func Test_initialized_shouldCheckRequiredProtocolVersion(t *testing.T) { "did not receive callback because of wrong protocol version") } -func Test_initialized_shouldSendMcpServerAddress(t *testing.T) { - c := testutil.UnitTest(t) - loc, jsonRpcRecorder := setupServer(t, c) - - params := types.InitializeParams{ - InitializationOptions: types.Settings{RequiredProtocolVersion: config.LsProtocolVersion}, - } - - rsp, err := loc.Client.Call(ctx, "initialize", params) - require.NoError(t, err) - var result types.InitializeResult - err = rsp.UnmarshalResult(&result) - require.NoError(t, err) - - testURL, err := url.Parse("http://localhost:1234") - require.NoError(t, err) - - c.SetMCPServerURL(testURL) - - _, err = loc.Client.Call(ctx, "initialized", params) - require.NoError(t, err) - require.Eventuallyf(t, func() bool { - n := jsonRpcRecorder.FindNotificationsByMethod("$/snyk.mcpServerURL") - if n == nil { - return false - } - if len(n) > 1 { - t.Fatal("can't succeed anymore, too many notifications ", n) - } - - var param types.McpServerURLParams - err = n[0].UnmarshalParams(¶m) - require.NoError(t, err) - return param.URL == testURL.String() - }, time.Minute*5, time.Millisecond, - "did not receive mcp server url") -} - func Test_initialize_shouldSupportAllCommands(t *testing.T) { c := testutil.UnitTest(t) loc, _ := setupServer(t, c) diff --git a/domain/ide/command/execute_mcp_test.go b/domain/ide/command/execute_mcp_test.go deleted file mode 100644 index ac173ed1c..000000000 --- a/domain/ide/command/execute_mcp_test.go +++ /dev/null @@ -1,105 +0,0 @@ -//go:build !race -// +build !race - -/* - * © 2024 Snyk Limited - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package command - -import ( - "context" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/snyk/snyk-ls/domain/ide/hover" - "github.com/snyk/snyk-ls/domain/ide/workspace" - "github.com/snyk/snyk-ls/domain/scanstates" - "github.com/snyk/snyk-ls/domain/snyk/persistence" - "github.com/snyk/snyk-ls/domain/snyk/scanner" - "github.com/snyk/snyk-ls/internal/mcp" - noti "github.com/snyk/snyk-ls/internal/notification" - "github.com/snyk/snyk-ls/internal/observability/performance" - "github.com/snyk/snyk-ls/internal/testutil" - "github.com/snyk/snyk-ls/internal/types" -) - -func Test_executeMcpCallCommand(t *testing.T) { - c := testutil.UnitTest(t) - c.SetAutomaticScanning(false) - - // start mcp server - sc := scanner.NewTestScanner() - scanNotifier := scanner.NewMockScanNotifier() - hoverService := hover.NewFakeHoverService() - notifier := noti.NewMockNotifier() - emitter := scanstates.NewSummaryEmitter(c, notifier) - scanStateAggregator := scanstates.NewScanStateAggregator(c, emitter) - scanPersister := persistence.NewNopScanPersister() - w := workspace.New(c, performance.NewInstrumentor(), sc, hoverService, scanNotifier, notifier, scanPersister, scanStateAggregator) - - f := workspace.NewFolder( - c, - "testPath", - "test", - sc, - hoverService, - scanNotifier, - notifier, - scanPersister, - scanStateAggregator, - ) - - w.AddFolder(f) - c.SetWorkspace(w) - - mcpBinding := mcp.NewMcpLLMBinding(c, mcp.WithLogger(c.Logger()), mcp.WithScanner(sc)) - go func() { - _ = mcpBinding.Start() - }() - - t.Cleanup(func() { - timeout, cancelFunc := context.WithTimeout(context.Background(), time.Second) - mcpBinding.Shutdown(timeout) - defer cancelFunc() - }) - - // wait for mcp server to start - assert.Eventually(t, func() bool { - return mcpBinding.Started() - }, time.Minute, time.Millisecond) - - // create command - command := executeMcpCallCommand{ - command: types.CommandData{ - Title: "Execute Snyk Scan", - CommandId: types.ExecuteMCPToolCall, - Arguments: []any{mcp.SnykScanWorkspaceScan}, - }, - notifier: noti.NewMockNotifier(), - logger: c.Logger(), - baseURL: c.GetMCPServerURL().String(), - } - - // execute command - _, err := command.Execute(context.Background()) - require.NoError(t, err) - require.Eventuallyf(t, func() bool { - return sc.Calls() > 0 - }, time.Minute, time.Millisecond, "should have called the scanner") -} diff --git a/internal/mcp/llm_binding.go b/internal/mcp/llm_binding.go index 193f4589e..06788e458 100644 --- a/internal/mcp/llm_binding.go +++ b/internal/mcp/llm_binding.go @@ -27,29 +27,29 @@ import ( "github.com/mark3labs/mcp-go/server" "github.com/pkg/errors" "github.com/rs/zerolog" + "github.com/snyk/go-application-framework/pkg/workflow" +) - "github.com/snyk/snyk-ls/application/config" - "github.com/snyk/snyk-ls/internal/types" +const ( + SseTransportType string = "sse" + StdioTransportType string = "stdio" ) // McpLLMBinding is an implementation of a mcp server that allows interaction between // a given SnykLLMBinding and a CommandService. type McpLLMBinding struct { - c *config.Config - scanner types.Scanner - logger *zerolog.Logger - mcpServer *server.MCPServer - sseServer *server.SSEServer - baseURL *url.URL - forwardingResultProcessor types.ScanResultProcessor - mutex sync.Mutex - started bool + logger *zerolog.Logger + mcpServer *server.MCPServer + sseServer *server.SSEServer + baseURL *url.URL + mutex sync.RWMutex + started bool + cliPath string } -func NewMcpLLMBinding(c *config.Config, opts ...McpOption) *McpLLMBinding { +func NewMcpLLMBinding(opts ...Option) *McpLLMBinding { logger := zerolog.Nop() mcpServerImpl := &McpLLMBinding{ - c: c, logger: &logger, } @@ -70,30 +70,60 @@ func defaultURL() *url.URL { } // Start starts the MCP server. It blocks until the server is stopped via Shutdown. -func (m *McpLLMBinding) Start() error { - // protect critical assignments with mutex - m.mutex.Lock() +func (m *McpLLMBinding) Start(invocationContext workflow.InvocationContext) error { + runTimeInfo := invocationContext.GetRuntimeInfo() + version := "" + if runTimeInfo != nil { + version = runTimeInfo.GetVersion() + } m.mcpServer = server.NewMCPServer( "Snyk MCP Server", - config.Version, + version, server.WithLogging(), server.WithResourceCapabilities(true, true), server.WithPromptCapabilities(true), ) - err := m.addSnykScanTool() + err := m.addSnykTools(invocationContext) if err != nil { - m.mutex.Unlock() return err } + transportType := invocationContext.GetConfiguration().GetString("transport") + if transportType == StdioTransportType { + return m.HandleStdioServer() + } else if transportType == SseTransportType { + return m.HandleSseServer() + } else { + return fmt.Errorf("invalid transport type: %s", transportType) + } +} + +func (m *McpLLMBinding) HandleStdioServer() error { + m.mutex.Lock() + m.started = true + m.mutex.Unlock() + + err := server.ServeStdio(m.mcpServer) + + if err != nil { + m.logger.Error().Err(err).Msg("Error starting MCP Stdio server") + return err + } + + return nil +} + +func (m *McpLLMBinding) HandleSseServer() error { // listen on default url/port if none was configured if m.baseURL == nil { m.baseURL = defaultURL() } m.sseServer = server.NewSSEServer(m.mcpServer, m.baseURL.String()) - m.mutex.Unlock() + + //nolint:forbidigo // stdio stream isn't started yet + fmt.Printf("Starting with base URL %s\n", m.baseURL.String()) m.logger.Info().Str("baseURL", m.baseURL.String()).Msg("starting") go func() { @@ -106,10 +136,10 @@ func (m *McpLLMBinding) Start() error { m.mutex.Lock() m.logger.Info().Str("baseURL", m.baseURL.String()).Msg("started") m.started = true - m.c.SetMCPServerURL(m.baseURL) m.mutex.Unlock() }() - err = m.sseServer.Start(m.baseURL.Host) + + err := m.sseServer.Start(m.baseURL.Host) if err != nil { // expect http.ErrServerClosed when shutting down if !errors.Is(err, http.ErrServerClosed) { @@ -124,14 +154,17 @@ func (m *McpLLMBinding) Shutdown(ctx context.Context) { m.mutex.Lock() defer m.mutex.Unlock() - err := m.sseServer.Shutdown(ctx) - if err != nil { - m.logger.Error().Err(err).Msg("Error shutting down MCP SSE server") + if m.sseServer != nil { + err := m.sseServer.Shutdown(ctx) + if err != nil { + m.logger.Error().Err(err).Msg("Error shutting down MCP SSE server") + } } } func (m *McpLLMBinding) Started() bool { - m.mutex.Lock() - defer m.mutex.Unlock() + m.mutex.RLock() + defer m.mutex.RUnlock() + return m.started } diff --git a/internal/mcp/llm_binding_smoke_test.go b/internal/mcp/llm_binding_smoke_test.go deleted file mode 100644 index bea283253..000000000 --- a/internal/mcp/llm_binding_smoke_test.go +++ /dev/null @@ -1,114 +0,0 @@ -//go:build !race -// +build !race - -/* - * © 2025 Snyk Limited - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package mcp - -import ( - "context" - "testing" - "time" - - "github.com/mark3labs/mcp-go/client" - "github.com/mark3labs/mcp-go/mcp" - "github.com/stretchr/testify/assert" - - "github.com/snyk/snyk-ls/domain/ide/hover" - "github.com/snyk/snyk-ls/domain/ide/workspace" - "github.com/snyk/snyk-ls/domain/scanstates" - "github.com/snyk/snyk-ls/domain/snyk/persistence" - "github.com/snyk/snyk-ls/domain/snyk/scanner" - noti "github.com/snyk/snyk-ls/internal/notification" - "github.com/snyk/snyk-ls/internal/observability/performance" - "github.com/snyk/snyk-ls/internal/testutil" -) - -// Test_WorkspaceScan does not run in race mode, due to races in the underlying framework -func Test_WorkspaceScan(t *testing.T) { - c := testutil.SmokeTest(t, false) - sc := scanner.NewTestScanner() - scanNotifier := scanner.NewMockScanNotifier() - hoverService := hover.NewFakeHoverService() - notifier := noti.NewMockNotifier() - emitter := scanstates.NewSummaryEmitter(c, notifier) - scanStateAggregator := scanstates.NewScanStateAggregator(c, emitter) - scanPersister := persistence.NewNopScanPersister() - w := workspace.New(c, performance.NewInstrumentor(), sc, hoverService, scanNotifier, notifier, scanPersister, scanStateAggregator) - - f := workspace.NewFolder( - c, - "testPath", - "test", - sc, - hoverService, - scanNotifier, - notifier, - scanPersister, - scanStateAggregator, - ) - - w.AddFolder(f) - - c.SetWorkspace(w) - server := NewMcpLLMBinding(c, WithScanner(sc), WithLogger(c.Logger())) - - go func() { - err := server.Start() - assert.NoError(t, err) - }() - - assert.Eventually(t, func() bool { - server.mutex.Lock() - defer server.mutex.Unlock() - portInUse := isPortInUse(server.baseURL) - return portInUse && server.baseURL == c.GetMCPServerURL() - }, time.Minute, time.Second) - - clientEndpoint := server.baseURL.String() + "/sse" - - mcpClient, err := client.NewSSEMCPClient(clientEndpoint) - assert.NoError(t, err) - defer mcpClient.Close() - - // start - err = mcpClient.Start(context.Background()) - assert.NoError(t, err) - - // initialize - initRequest := mcp.InitializeRequest{} - initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION - initRequest.Params.ClientInfo = mcp.Implementation{ - Name: "example-client", - Version: "1.0.0", - } - - _, err = mcpClient.Initialize(context.Background(), initRequest) - assert.NoError(t, err) - - toolsRequest := mcp.ListToolsRequest{} - tools, err := mcpClient.ListTools(context.Background(), toolsRequest) - assert.NoError(t, err) - assert.Len(t, tools.Tools, 1) - - scanRequest := mcp.CallToolRequest{} - scanRequest.Params.Name = SnykScanWorkspaceScan - - result, err := mcpClient.CallTool(context.Background(), scanRequest) - assert.NoError(t, err) - assert.NotNil(t, result) -} diff --git a/internal/mcp/llm_binding_test.go b/internal/mcp/llm_binding_test.go index 2d0dce314..e060c98d9 100644 --- a/internal/mcp/llm_binding_test.go +++ b/internal/mcp/llm_binding_test.go @@ -20,26 +20,19 @@ import ( "testing" "github.com/stretchr/testify/assert" - - "github.com/snyk/snyk-ls/domain/snyk/scanner" - "github.com/snyk/snyk-ls/internal/testutil" ) func TestNewMcpServer(t *testing.T) { - c := testutil.UnitTest(t) - mcpServer := NewMcpLLMBinding(c) + mcpServer := NewMcpLLMBinding() assert.NotNil(t, mcpServer) assert.NotNil(t, mcpServer.logger) } func TestNewMcpServerWithOptions(t *testing.T) { - c := testutil.UnitTest(t) baseURL, _ := url.Parse("http://test:8080") - s := scanner.NewTestScanner() - mcpServer := NewMcpLLMBinding(c, WithScanner(s), WithBaseURL(baseURL)) + mcpServer := NewMcpLLMBinding(WithBaseURL(baseURL)) - assert.Equal(t, s, mcpServer.scanner) assert.Equal(t, baseURL, mcpServer.baseURL) } diff --git a/internal/mcp/options.go b/internal/mcp/options.go index c4313c70b..b0e0e232b 100644 --- a/internal/mcp/options.go +++ b/internal/mcp/options.go @@ -20,33 +20,25 @@ import ( "net/url" "github.com/rs/zerolog" - - "github.com/snyk/snyk-ls/internal/types" ) -type McpOption func(server *McpLLMBinding) - -func WithScanner(scanner types.Scanner) McpOption { - return func(server *McpLLMBinding) { - server.scanner = scanner - } -} +type Option func(server *McpLLMBinding) -func WithLogger(logger *zerolog.Logger) McpOption { +func WithLogger(logger *zerolog.Logger) Option { return func(server *McpLLMBinding) { l := logger.With().Str("component", "mcp").Logger() server.logger = &l } } -func WithBaseURL(baseURL *url.URL) func(server *McpLLMBinding) { +func WithCliPath(cliPath string) Option { return func(server *McpLLMBinding) { - server.baseURL = baseURL + server.cliPath = cliPath } } -func WithScanResultProcessor(proc types.ScanResultProcessor) func(server *McpLLMBinding) { +func WithBaseURL(baseURL *url.URL) func(server *McpLLMBinding) { return func(server *McpLLMBinding) { - server.forwardingResultProcessor = proc + server.baseURL = baseURL } } diff --git a/internal/mcp/scan_tool.go b/internal/mcp/scan_tool.go index d0efd7dc0..247f377f7 100644 --- a/internal/mcp/scan_tool.go +++ b/internal/mcp/scan_tool.go @@ -1,5 +1,5 @@ /* - * © 2025 Snyk Limited + * 2025 Snyk Limited * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,68 +18,136 @@ package mcp import ( "context" + _ "embed" + "encoding/json" + "fmt" + "os/exec" "github.com/mark3labs/mcp-go/mcp" + "github.com/snyk/go-application-framework/pkg/workflow" +) - ctx2 "github.com/snyk/snyk-ls/internal/context" - "github.com/snyk/snyk-ls/internal/types" +// Tool name constants to maintain backward compatibility +const ( + SnykScaTest = "snyk_sca_test" + SnykCodeTest = "snyk_code_test" + SnykVersion = "snyk_version" + SnykAuth = "snyk_auth" + SnykAuthStatus = "snyk_auth_status" + SnykLogout = "snyk_logout" ) -const SnykScanWorkspaceScan = types.WorkspaceScanCommand +type SnykMcpToolsDefinition struct { + Name string `json:"name"` + Description string `json:"description"` + Command []string `json:"command"` + StandardParams []string `json:"standardParams"` + Params []SnykMcpToolParameter `json:"params"` +} + +type SnykMcpToolParameter struct { + Name string `json:"name"` + Type string `json:"type"` + IsRequired bool `json:"isRequired"` + Description string `json:"description"` +} -func (m *McpLLMBinding) addSnykScanTool() error { - tool := mcp.NewTool(SnykScanWorkspaceScan, - mcp.WithDescription("Perform Snyk scans on current workspace"), - ) +//go:embed snyk_tools.json +var snykToolsJson string - m.mcpServer.AddTool(tool, m.snykWorkSpaceScanHandler()) +type SnykMcpTools struct { + Tools []SnykMcpToolsDefinition `json:"tools"` +} + +func loadMcpToolsFromJson() (*SnykMcpTools, error) { + var config SnykMcpTools + if err := json.Unmarshal([]byte(snykToolsJson), &config); err != nil { + return nil, fmt.Errorf("failed to parse config file: %w", err) + } + + return &config, nil +} + +func (m *McpLLMBinding) addSnykTools(invocationCtx workflow.InvocationContext) error { + config, err := loadMcpToolsFromJson() + if err != nil || config == nil { + m.logger.Err(err).Msg("Failed to load Snyk tools configuration") + return err + } + + for _, toolDef := range config.Tools { + tool := createToolFromDefinition(&toolDef) + switch toolDef.Name { + case SnykLogout: + m.mcpServer.AddTool(tool, m.snykLogoutHandler(invocationCtx, toolDef)) + default: + m.mcpServer.AddTool(tool, m.defaultHandler(invocationCtx, toolDef)) + } + } return nil } -func (m *McpLLMBinding) snykWorkSpaceScanHandler() func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { +// runSnyk runs a Snyk command and returns the result +func (m *McpLLMBinding) runSnyk(ctx context.Context, invocationCtx workflow.InvocationContext, workingDir string, cmd []string) (string, error) { + command := exec.CommandContext(ctx, cmd[0], cmd[1:]...) + + if workingDir != "" { + command.Dir = workingDir + } + + command.Stderr = invocationCtx.GetEnhancedLogger() + res, err := command.Output() + + resAsString := string(res) + if err != nil { + m.logger.Err(err).Msg("Failed to execute command") + } + return resAsString, nil +} + +// defaultHandler creates a generic handler for Snyk commands that applies standard parameters +func (m *McpLLMBinding) defaultHandler(invocationCtx workflow.InvocationContext, toolDef SnykMcpToolsDefinition) func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - w := m.c.Workspace() - trusted, _ := w.GetFolderTrust() + params, workingDir := extractParamsFromRequestArgs(toolDef, request.Params.Arguments) - callToolResult := &mcp.CallToolResult{ - Content: make([]interface{}, 0), + // Apply standard parameters from tool definition + // e.g. all_projects and json + for _, paramName := range toolDef.StandardParams { + cliParamName := convertToCliParam(paramName) + params[cliParamName] = true } - resultProcessor := func(ctx context.Context, data types.ScanData) { - // add the scan results to the call tool response - // in the future, this could be a rendered markdown/html template - callToolResult.Content = append(callToolResult.Content, data) - if data.Err != nil { - callToolResult.IsError = true - } - - // standard processing for the folder - scanResultProcessor := folderScanResultProcessor(w, data.Path) - if scanResultProcessor != nil { - scanResultProcessor(ctx, data) - } - - // forward to forwarding processor - if m.forwardingResultProcessor != nil { - m.forwardingResultProcessor(ctx, data) - } + // Handle regular commands + if len(toolDef.Command) == 0 { + return nil, fmt.Errorf("empty command in tool definition for %s", toolDef.Name) } - enrichedContext := ctx2.NewContextWithScanSource(ctx, ctx2.LLM) - for _, folder := range trusted { - m.scanner.Scan(enrichedContext, folder.Path(), resultProcessor, folder.Path()) + args := buildArgs(m.cliPath, toolDef.Command, params) + + // Add working directory if specified + if workingDir != "" { + args = append(args, workingDir) } - return callToolResult, nil + // Run the command + output, err := m.runSnyk(ctx, invocationCtx, workingDir, args) + if err != nil { + return nil, err + } + return mcp.NewToolResultText(output), nil } } -func folderScanResultProcessor(w types.Workspace, path types.FilePath) types.ScanResultProcessor { - folder := w.GetFolderContaining(path) - if folder == nil { - return nil +func (m *McpLLMBinding) snykLogoutHandler(invocationCtx workflow.InvocationContext, _ SnykMcpToolsDefinition) func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // Special handling for logout which needs multiple commands + params := []string{m.cliPath, "config", "unset", "INTERNAL_OAUTH_TOKEN_STORAGE"} + _, _ = m.runSnyk(ctx, invocationCtx, "", params) + + params = []string{m.cliPath, "config", "unset", "token"} + _, _ = m.runSnyk(ctx, invocationCtx, "", params) + + return mcp.NewToolResultText("Successfully logged out"), nil } - scanResultProcessor := folder.ScanResultProcessor() - return scanResultProcessor } diff --git a/internal/mcp/scan_tool_test.go b/internal/mcp/scan_tool_test.go new file mode 100644 index 000000000..1bf204ca0 --- /dev/null +++ b/internal/mcp/scan_tool_test.go @@ -0,0 +1,830 @@ +/* +* 2025 Snyk Limited +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. + */ + +package mcp + +import ( + "context" + "encoding/json" + "fmt" + "io" + "os" + "path/filepath" + "runtime" + "strings" + "testing" + + "github.com/golang/mock/gomock" + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + "github.com/rs/zerolog" + "github.com/snyk/go-application-framework/pkg/configuration" + "github.com/snyk/go-application-framework/pkg/mocks" + "github.com/snyk/go-application-framework/pkg/workflow" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type testFixture struct { + t *testing.T + mockEngine *mocks.MockEngine + binding *McpLLMBinding + snykCliPath string + invocationContext *mocks.MockInvocationContext + tools *SnykMcpTools +} + +func SetupEngineMock(t *testing.T) (*mocks.MockEngine, configuration.Configuration) { + t.Helper() + ctrl := gomock.NewController(t) + mockEngine := mocks.NewMockEngine(ctrl) + engineConfig := configuration.NewWithOpts(configuration.WithAutomaticEnv()) + mockEngine.EXPECT().GetConfiguration().Return(engineConfig).AnyTimes() + return mockEngine, engineConfig +} + +func setupTestFixture(t *testing.T) *testFixture { + t.Helper() + engine, engineConfig := SetupEngineMock(t) + logger := zerolog.New(io.Discard) + + mockctl := gomock.NewController(t) + storage := mocks.NewMockStorage(mockctl) + engineConfig.SetStorage(storage) + + invocationCtx := mocks.NewMockInvocationContext(mockctl) + invocationCtx.EXPECT().GetConfiguration().Return(engineConfig).AnyTimes() + invocationCtx.EXPECT().GetEnhancedLogger().Return(&logger).AnyTimes() + + // Snyk CLI mock + tempDir := t.TempDir() + snykCliPath := filepath.Join(tempDir, "snyk") + if runtime.GOOS == "windows" { + snykCliPath += ".bat" + } + + // Create a default mock CLI that just echoes the command + defaultMockResponse := "{\"ok\": true}" + createMockSnykCli(t, snykCliPath, defaultMockResponse) + + // Create the binding + binding := NewMcpLLMBinding(WithCliPath(snykCliPath), WithLogger(invocationCtx.GetEnhancedLogger())) + binding.mcpServer = server.NewMCPServer("Snyk", "1.1.1") + tools, err := loadMcpToolsFromJson() + assert.NoError(t, err) + return &testFixture{ + t: t, + mockEngine: engine, + binding: binding, + snykCliPath: snykCliPath, + invocationContext: invocationCtx, + tools: tools, + } +} + +func (f *testFixture) mockCliOutput(output string) { + createMockSnykCli(f.t, f.snykCliPath, output) +} + +func getToolWithName(t *testing.T, tools *SnykMcpTools, toolName string) *SnykMcpToolsDefinition { + t.Helper() + for _, tool := range tools.Tools { + if tool.Name == toolName { + return &tool + } + } + return nil +} + +func TestMcpSnykToolRegistration(t *testing.T) { + fixture := setupTestFixture(t) + err := fixture.binding.addSnykTools(fixture.invocationContext) + assert.NoError(t, err) +} + +func TestSnykTestHandler(t *testing.T) { + // Setup + fixture := setupTestFixture(t) + + // Configure mock CLI to return a specific JSON response + mockOutput := `{ok": false,"vulnerabilities": [{"id": "SNYK-JS-ACORN-559469","title": "Regular Expression Denial of Service (ReDoS)","severity":"high","packageName": "acorn"},{"id": "SNYK-JS-TUNNELAGENT-1572284","title": "Uninitialized Memory Exposure","severity": "medium","packageName": "tunnel-agent"}],"dependencyCount": 42,"packageManager": "npm"}` + fixture.mockCliOutput(mockOutput) + tool := getToolWithName(t, fixture.tools, SnykScaTest) + assert.NotNil(t, tool) + // Create the handler + handler := fixture.binding.defaultHandler(fixture.invocationContext, *tool) + + tmpDir := t.TempDir() + // Define test cases + testCases := []struct { + name string + args map[string]interface{} + expectedParams []string + }{ + { + name: "Basic SCA Test", + args: map[string]interface{}{ + "path": tmpDir, + "all_projects": true, + "json": true, + }, + expectedParams: []string{"--all-projects", "--json"}, + }, + { + name: "Test with Organization", + args: map[string]interface{}{ + "path": tmpDir, + "all_projects": true, + "json": true, + "org": "my-snyk-org", + }, + expectedParams: []string{"--all-projects", "--json", "--org=my-snyk-org"}, + }, + { + name: "Test with Severity Threshold", + args: map[string]interface{}{ + "path": tmpDir, + "all_projects": false, + "json": true, + "severity_threshold": "high", + }, + expectedParams: []string{"--json", "--severity-threshold=high"}, + }, + { + name: "Test with Multiple Options", + args: map[string]interface{}{ + "path": tmpDir, + "all_projects": true, + "json": true, + "severity_threshold": "medium", + "dev": true, + "skip_unresolved": true, + "prune_repeated_subdependencies": true, + "fail_on": "upgradable", + }, + expectedParams: []string{ + "--all-projects", "--json", "--severity-threshold=medium", + "--dev", "--skip-unresolved", "--prune-repeated-subdependencies", + "--fail-on=upgradable", + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + requestObj := map[string]interface{}{ + "params": map[string]interface{}{ + "arguments": tc.args, + }, + } + requestJSON, err := json.Marshal(requestObj) + assert.NoError(t, err, "Failed to marshal request to JSON") + + // Parse the JSON string to CallToolRequest + var request mcp.CallToolRequest + err = json.Unmarshal(requestJSON, &request) + assert.NoError(t, err, "Failed to unmarshal JSON to CallToolRequest") + + result, err := handler(context.Background(), request) + + assert.NoError(t, err) + assert.NotNil(t, result) + + textContent, ok := result.Content[0].(mcp.TextContent) + assert.True(t, ok) + content := strings.TrimSpace(textContent.Text) + assert.Contains(t, content, "ok") + assert.Contains(t, content, "vulnerabilities") + assert.Contains(t, content, "dependencyCount") + assert.Contains(t, content, "packageManager") + }) + } +} + +func TestSnykCodeTestHandler(t *testing.T) { + // Setup + fixture := setupTestFixture(t) + + // Configure mock CLI + mockJsonResponse := `{"ok":false,"issues":[],"filesAnalyzed":10}` + fixture.mockCliOutput(mockJsonResponse) + + // Get the tool definition + toolDef := getToolWithName(t, fixture.tools, SnykCodeTest) + + // Create the handler + handler := fixture.binding.defaultHandler(fixture.invocationContext, *toolDef) + tmpDir := t.TempDir() + // Test cases with various combinations of arguments + testCases := []struct { + name string + args map[string]interface{} + }{ + { + name: "Basic Test", + args: map[string]interface{}{ + "path": tmpDir, + }, + }, + { + name: "Test with Custom File", + args: map[string]interface{}{ + "path": tmpDir, + "file": "specific_file.js", + }, + }, + { + name: "Test with Severity Threshold", + args: map[string]interface{}{ + "path": tmpDir, + "severity_threshold": "high", + }, + }, + { + name: "Test with Organization", + args: map[string]interface{}{ + "path": tmpDir, + "org": "my-snyk-org", + }, + }, + { + name: "Test with All Options", + args: map[string]interface{}{ + "path": tmpDir, + "file": "specific_file.js", + "severity_threshold": "high", + "org": "my-snyk-org", + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + requestObj := map[string]interface{}{ + "params": map[string]interface{}{ + "arguments": tc.args, + }, + } + requestJSON, err := json.Marshal(requestObj) + assert.NoError(t, err, "Failed to marshal request to JSON") + + var request mcp.CallToolRequest + err = json.Unmarshal(requestJSON, &request) + assert.NoError(t, err, "Failed to unmarshal JSON to CallToolRequest") + + result, err := handler(context.Background(), request) + + assert.NoError(t, err) + assert.NotNil(t, result) + textContent, ok := result.Content[0].(mcp.TextContent) + assert.True(t, ok) + content := strings.TrimSpace(textContent.Text) + assert.Contains(t, content, "ok") + assert.Contains(t, content, "issues") + assert.Contains(t, content, "filesAnalyzed") + }) + } +} + +func TestBasicSnykCommands(t *testing.T) { + // Setup + fixture := setupTestFixture(t) + + testCases := []struct { + name string + handlerFunc func(invocationCtx workflow.InvocationContext, toolDefinition SnykMcpToolsDefinition) func(ctx context.Context, arguments mcp.CallToolRequest) (*mcp.CallToolResult, error) + mockResponse string + expectedCmd string + command []string + }{ + { + name: "Version Command", + handlerFunc: fixture.binding.defaultHandler, + command: []string{"--version"}, + mockResponse: `{"client":{"version":"1.1192.0"}}`, + expectedCmd: "version", + }, + { + name: "Auth Status Command", + handlerFunc: fixture.binding.defaultHandler, + command: []string{"whoami", "--experimental"}, + mockResponse: `{"authenticated":true,"username":"user@example.com"}`, + expectedCmd: "auth", + }, + { + name: "Logout Command", + handlerFunc: fixture.binding.snykLogoutHandler, + command: []string{"--version"}, + mockResponse: `Successfully logged out`, + expectedCmd: "logout", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Configure mock CLI + fixture.mockCliOutput(tc.mockResponse) + + // Create the handler + handler := tc.handlerFunc(fixture.invocationContext, SnykMcpToolsDefinition{Command: tc.command}) + + // Create an empty request object as JSON string + requestObj := map[string]interface{}{ + "params": map[string]interface{}{ + "arguments": map[string]interface{}{}, + }, + } + requestJSON, err := json.Marshal(requestObj) + assert.NoError(t, err, "Failed to marshal request to JSON") + + // Parse the JSON string to CallToolRequest + var request mcp.CallToolRequest + err = json.Unmarshal(requestJSON, &request) + assert.NoError(t, err, "Failed to unmarshal JSON to CallToolRequest") + + // Call the handler + result, err := handler(context.Background(), request) + + // Assertions + assert.NoError(t, err) + assert.NotNil(t, result) + textContent, ok := result.Content[0].(mcp.TextContent) + assert.True(t, ok) + assert.Equal(t, tc.mockResponse, strings.TrimSpace(textContent.Text)) + }) + } +} + +func TestAuthHandler(t *testing.T) { + // Setup + fixture := setupTestFixture(t) + + // Configure mock CLI + mockAuthResponse := "Authenticated Successfully" + fixture.mockCliOutput(mockAuthResponse) + + // Create the handler + handler := fixture.binding.defaultHandler(fixture.invocationContext, SnykMcpToolsDefinition{Command: []string{"auth"}}) + + requestObj := map[string]interface{}{ + "params": map[string]interface{}{ + "arguments": map[string]interface{}{}, + }, + } + requestJSON, err := json.Marshal(requestObj) + assert.NoError(t, err, "Failed to marshal request to JSON") + + var request mcp.CallToolRequest + err = json.Unmarshal(requestJSON, &request) + assert.NoError(t, err, "Failed to unmarshal JSON to CallToolRequest") + + result, err := handler(context.Background(), request) + + // Assertions + assert.NoError(t, err) + assert.NotNil(t, result) + textContent, ok := result.Content[0].(mcp.TextContent) + assert.True(t, ok) + assert.Equal(t, mockAuthResponse, strings.TrimSpace(textContent.Text)) +} + +func TestGetSnykToolsConfig(t *testing.T) { + config, err := loadMcpToolsFromJson() + + assert.NoError(t, err) + assert.NotNil(t, config) + assert.NotEmpty(t, config.Tools) + + toolNames := map[string]bool{ + SnykScaTest: false, + SnykCodeTest: false, + SnykVersion: false, + SnykAuth: false, + SnykAuthStatus: false, + SnykLogout: false, + } + + for _, tool := range config.Tools { + toolNames[tool.Name] = true + } + + for name, found := range toolNames { + assert.True(t, found, "Tool %s not found in configuration", name) + } +} + +func TestCreateToolFromDefinition(t *testing.T) { + testCases := []struct { + name string + toolDefinition SnykMcpToolsDefinition + expectedName string + }{ + { + name: "Simple Tool", + toolDefinition: SnykMcpToolsDefinition{ + Name: "test_tool", + Description: "Test tool description", + Command: []string{"test"}, + Params: []SnykMcpToolParameter{}, + }, + expectedName: "test_tool", + }, + { + name: "Tool with String Params", + toolDefinition: SnykMcpToolsDefinition{ + Name: "string_param_tool", + Description: "Tool with string params", + Command: []string{"test"}, + Params: []SnykMcpToolParameter{ + { + Name: "param1", + Type: "string", + IsRequired: true, + Description: "Required string param", + }, + { + Name: "param2", + Type: "string", + IsRequired: false, + Description: "Optional string param", + }, + }, + }, + expectedName: "string_param_tool", + }, + { + name: "Tool with Boolean Params", + toolDefinition: SnykMcpToolsDefinition{ + Name: "bool_param_tool", + Description: "Tool with boolean params", + Command: []string{"test"}, + Params: []SnykMcpToolParameter{ + { + Name: "flag1", + Type: "boolean", + IsRequired: true, + Description: "Required boolean param", + }, + { + Name: "flag2", + Type: "boolean", + IsRequired: false, + Description: "Optional boolean param", + }, + }, + }, + expectedName: "bool_param_tool", + }, + { + name: "Tool with Mixed Params", + toolDefinition: SnykMcpToolsDefinition{ + Name: "mixed_param_tool", + Description: "Tool with mixed params", + Command: []string{"test"}, + Params: []SnykMcpToolParameter{ + { + Name: "str_param", + Type: "string", + IsRequired: true, + Description: "Required string param", + }, + { + Name: "bool_flag", + Type: "boolean", + IsRequired: false, + Description: "Optional boolean param", + }, + }, + }, + expectedName: "mixed_param_tool", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + tool := createToolFromDefinition(&tc.toolDefinition) + + assert.NotNil(t, tool) + assert.Equal(t, tc.expectedName, tool.Name) + }) + } +} + +func TestExtractParamsFromRequest(t *testing.T) { + testCases := []struct { + name string + toolDef SnykMcpToolsDefinition + arguments map[string]interface{} + expectedParamCount int + expectedWorkingDir string + expectedParams map[string]interface{} + }{ + { + name: "Empty Request", + toolDef: SnykMcpToolsDefinition{ + Name: "test_tool", + Params: []SnykMcpToolParameter{}, + }, + arguments: map[string]interface{}{}, + expectedParamCount: 0, + expectedWorkingDir: "", + expectedParams: map[string]interface{}{}, + }, + { + name: "String Parameters", + toolDef: SnykMcpToolsDefinition{ + Name: "string_tool", + Params: []SnykMcpToolParameter{ + { + Name: "org", + Type: "string", + }, + { + Name: "path", + Type: "string", + }, + }, + }, + arguments: map[string]interface{}{ + "org": "my-org", + "path": "/test/path", + }, + expectedParamCount: 2, + expectedWorkingDir: "/test/path", + expectedParams: map[string]interface{}{ + "org": "my-org", + "path": "/test/path", + }, + }, + { + name: "Boolean Parameters", + toolDef: SnykMcpToolsDefinition{ + Name: "bool_tool", + Params: []SnykMcpToolParameter{ + { + Name: "json", + Type: "boolean", + }, + { + Name: "all_projects", + Type: "boolean", + }, + }, + }, + arguments: map[string]interface{}{ + "json": true, + "all_projects": true, + }, + expectedParamCount: 2, + expectedWorkingDir: "", + expectedParams: map[string]interface{}{ + "json": true, + "all-projects": true, + }, + }, + { + name: "Mixed Parameters", + toolDef: SnykMcpToolsDefinition{ + Name: "mixed_tool", + Params: []SnykMcpToolParameter{ + { + Name: "path", + Type: "string", + }, + { + Name: "json", + Type: "boolean", + }, + { + Name: "severity_threshold", + Type: "string", + }, + }, + }, + arguments: map[string]interface{}{ + "path": "/test/path", + "json": true, + "severity_threshold": "high", + }, + expectedParamCount: 3, + expectedWorkingDir: "/test/path", + expectedParams: map[string]interface{}{ + "path": "/test/path", + "json": true, + "severity-threshold": "high", + }, + }, + { + name: "Empty String Parameters", + toolDef: SnykMcpToolsDefinition{ + Name: "empty_string_tool", + Params: []SnykMcpToolParameter{ + { + Name: "org", + Type: "string", + }, + }, + }, + arguments: map[string]interface{}{ + "org": "", + }, + expectedParamCount: 0, + expectedWorkingDir: "", + expectedParams: map[string]interface{}{}, + }, + { + name: "False Boolean Parameters", + toolDef: SnykMcpToolsDefinition{ + Name: "false_bool_tool", + Params: []SnykMcpToolParameter{ + { + Name: "json", + Type: "boolean", + }, + }, + }, + arguments: map[string]interface{}{ + "json": false, + }, + expectedParamCount: 0, + expectedWorkingDir: "", + expectedParams: map[string]interface{}{}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + params, workingDir := extractParamsFromRequestArgs(tc.toolDef, tc.arguments) + + assert.Equal(t, tc.expectedWorkingDir, workingDir) + + assert.Equal(t, len(tc.expectedParams), len(params)) + + for key, expectedValue := range tc.expectedParams { + switch key { + case "path": + continue + default: + expectedKey := strings.ReplaceAll(key, "_", "-") + actualValue, ok := params[expectedKey] + assert.True(t, ok, "Parameter %s not found", expectedKey) + assert.Equal(t, expectedValue, actualValue) + } + } + }) + } +} + +func TestBuildArgs(t *testing.T) { + testCases := []struct { + name string + cliPath string + command []string + params map[string]interface{} + expected []string + }{ + { + name: "No Parameters", + cliPath: "snyk", + command: []string{"test"}, + params: map[string]interface{}{}, + expected: []string{"snyk", "test"}, + }, + { + name: "String Parameters", + cliPath: "snyk", + command: []string{"test"}, + params: map[string]interface{}{ + "org": "my-org", + "file": "package.json", + }, + expected: []string{"snyk", "test", "--org=my-org", "--file=package.json"}, + }, + { + name: "Boolean Parameters", + cliPath: "snyk", + command: []string{"test"}, + params: map[string]interface{}{ + "json": true, + "all-projects": true, + }, + expected: []string{"snyk", "test", "--json", "--all-projects"}, + }, + { + name: "Mixed Parameters", + cliPath: "snyk", + command: []string{"test"}, + params: map[string]interface{}{ + "org": "my-org", + "json": true, + "all-projects": true, + }, + expected: []string{"snyk", "test", "--org=my-org", "--all-projects", "--json"}, + }, + { + name: "Empty String Parameters", + cliPath: "snyk", + command: []string{"test"}, + params: map[string]interface{}{ + "org": "", + }, + expected: []string{"snyk", "test"}, + }, + { + name: "False Boolean Parameters", + cliPath: "snyk", + command: []string{"test"}, + params: map[string]interface{}{ + "json": false, + }, + expected: []string{"snyk", "test"}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + args := buildArgs(tc.cliPath, tc.command, tc.params) + for _, arg := range args { + assert.Contains(t, tc.expected, arg) + } + }) + } +} + +func TestRunSnyk(t *testing.T) { + fixture := setupTestFixture(t) + + ctx := context.Background() + + testCases := []struct { + name string + mockOutput string + command []string + workingDir string + expectError bool + }{ + { + name: "Successful Command", + mockOutput: "Command executed successfully", + command: []string{fixture.snykCliPath, "test"}, + workingDir: "", + expectError: false, + }, + { + name: "Command with Working Directory", + mockOutput: "Command executed successfully", + command: []string{fixture.snykCliPath, "test"}, + workingDir: t.TempDir(), + expectError: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + fixture.mockCliOutput(tc.mockOutput) + + output, err := fixture.binding.runSnyk(ctx, fixture.invocationContext, tc.workingDir, tc.command) + + if tc.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.mockOutput, strings.TrimSpace(output)) + } + }) + } +} + +func createMockSnykCli(t *testing.T, path, output string) { + t.Helper() + + var script string + + if runtime.GOOS == "windows" { + script = fmt.Sprintf(`@echo off +echo %s +exit /b 0 +`, output) + } else { + script = fmt.Sprintf(`#!/bin/sh +echo '%s' +exit 0 +`, output) + } + + err := os.WriteFile(path, []byte(script), 0755) + require.NoError(t, err) +} diff --git a/internal/mcp/snyk_tools.json b/internal/mcp/snyk_tools.json new file mode 100644 index 000000000..1e6a2de49 --- /dev/null +++ b/internal/mcp/snyk_tools.json @@ -0,0 +1,216 @@ +{ + "tools": [ + { + "name": "snyk_sca_test", + "description": "Run a SCA test on project dependencies to detect known vulnerabilities. Use this to scan open-source packages in supported ecosystems like npm, Maven, etc. Supports monorepo scanning via `--all-projects`. Outputs vulnerability data in JSON if enabled.", + "command": ["test"], + "standardParams": ["all_projects", "json"], + "params": [ + { + "name": "path", + "type": "string", + "isRequired": true, + "description": "Path to the project to test (default is the absolute path of the current directory, formatted according to the operating system's conventions)." + }, + { + "name": "all_projects", + "type": "boolean", + "isRequired": false, + "description": "Scan all projects in the specified directory. (Default is true)." + }, + { + "name": "json", + "type": "boolean", + "isRequired": false, + "description": "Output results in JSON format. (Default is true)." + }, + { + "name": "severity_threshold", + "type": "string", + "isRequired": false, + "description": "Only report vulnerabilities of the specified level or higher (low, medium, high, critical). (Default is empty)" + }, + { + "name": "org", + "type": "string", + "isRequired": false, + "description": "Specify the organization under which to run the test. (Default is empty)." + }, + { + "name": "dev", + "type": "boolean", + "isRequired": false, + "description": "Include development dependencies. (Default is false)" + }, + { + "name": "skip_unresolved", + "type": "boolean", + "isRequired": false, + "description": "Skip testing of unresolved packages. (Default is false)" + }, + { + "name": "prune_repeated_subdependencies", + "type": "boolean", + "isRequired": false, + "description": "Prune repeated sub-dependencies. (Default is false)." + }, + { + "name": "fail_on", + "type": "string", + "isRequired": false, + "description": "Specify the failure criteria (all, upgradable, patchable). (Default is all)." + }, + { + "name": "file", + "type": "string", + "isRequired": false, + "description": "Specify a package file to test. (Default is empty)" + }, + { + "name": "fail_fast", + "type": "boolean", + "isRequired": false, + "description": "Use with --all-projects to interrupt scans when errors occur. (Default is false)" + }, + { + "name": "detection_depth", + "type": "string", + "isRequired": false, + "description": "Use with --all-projects to indicate how many subdirectories to search. (Default is empty)" + }, + { + "name": "exclude", + "type": "string", + "isRequired": false, + "description": "Use with --all-projects to exclude directory names and file names. (Default is empty)" + }, + { + "name": "print_deps", + "type": "boolean", + "isRequired": false, + "description": "Print the dependency tree before sending it for analysis. (Default is false)" + }, + { + "name": "remote_repo_url", + "type": "string", + "isRequired": false, + "description": "Set or override the remote URL for the repository to monitor. (Default is empty)" + }, + { + "name": "package_manager", + "type": "string", + "isRequired": false, + "description": "Specify the name of the package manager when the filename is not standard. (Default is empty)" + }, + { + "name": "unmanaged", + "type": "boolean", + "isRequired": false, + "description": "For C++ only, scan all files for known open source dependencies. (Default is false)" + }, + { + "name": "ignore_policy", + "type": "boolean", + "isRequired": false, + "description": "Ignore all set policies, the current policy in the .snyk file, Org level ignores, and the project policy. (Default is false)" + }, + { + "name": "trust_policies", + "type": "boolean", + "isRequired": false, + "description": "Apply and use ignore rules from the Snyk policies in your dependencies. (Default is false)" + }, + { + "name": "show_vulnerable_paths", + "type": "string", + "isRequired": false, + "description": "Display the dependency paths (none|some|all). (Default: none)." + }, + { + "name": "project_name", + "type": "string", + "isRequired": false, + "description": "Specify a custom Snyk project name. (Default is empty)" + }, + { + "name": "target_reference", + "type": "string", + "isRequired": false, + "description": "Specify a reference that differentiates this project, for example, a branch name. (Default is empty)" + }, + { + "name": "policy_path", + "type": "string", + "isRequired": false, + "description": "Manually pass a path to a .snyk policy file. (Default is empty)" + } + ] + }, + { + "name": "snyk_code_test", + "description": "Run a static application security test (SAST) on your source code to detect security issues like SQL injection, XSS, and hardcoded secrets. Designed to catch issues early in the development cycle.", + "command": ["code", "test"], + "standardParams": ["json"], + "params": [ + { + "name": "path", + "type": "string", + "isRequired": true, + "description": "Path to the project to test (default is the absolute path of the current directory, formatted according to the operating system's conventions)." + }, + { + "name": "file", + "type": "string", + "isRequired": false, + "description": "Specific file to scan (default: empty)." + }, + { + "name": "json", + "type": "boolean", + "isRequired": false, + "description": "Output results in JSON format. (default: true)" + }, + { + "name": "severity_threshold", + "type": "string", + "isRequired": false, + "description": "Only report vulnerabilities of the specified level or higher (low, medium, high). (default: empty)" + }, + { + "name": "org", + "type": "string", + "isRequired": false, + "description": "Specify the organization under which to run the test. (default: empty)" + } + ] + }, + { + "name": "snyk_version", + "description": "Get Snyk CLI version", + "command": ["--version"], + "standardParams": [], + "params": [] + }, + { + "name": "snyk_auth", + "description": "Authenticate with Snyk", + "command": ["auth"], + "standardParams": [], + "params": [] + }, + { + "name": "snyk_auth_status", + "description": "Check Snyk authentication status", + "command": ["whoami", "--experimental"], + "standardParams": [], + "params": [] + }, + { + "name": "snyk_logout", + "description": "Log out from Snyk", + "command": ["logout"], + "standardParams": [], + "params": [] + } + ] +} diff --git a/internal/mcp/utils.go b/internal/mcp/utils.go new file mode 100644 index 000000000..c1447a3d8 --- /dev/null +++ b/internal/mcp/utils.go @@ -0,0 +1,108 @@ +/* + * © 2025 Snyk Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package mcp + +import ( + "strings" + + "github.com/mark3labs/mcp-go/mcp" +) + +// buildArgs builds command-line arguments for Snyk CLI based on parameters +func buildArgs(cliPath string, command []string, params map[string]interface{}) []string { + args := []string{cliPath} + args = append(args, command...) + + // Add params as command-line flags + for key, value := range params { + switch v := value.(type) { + case bool: + if v { + args = append(args, "--"+key) + } + case string: + if v != "" { + args = append(args, "--"+key+"="+v) + } + } + } + + return args +} + +// createToolFromDefinition creates an MCP tool from a Snyk tool definition +func createToolFromDefinition(toolDef *SnykMcpToolsDefinition) mcp.Tool { + opts := []mcp.ToolOption{mcp.WithDescription(toolDef.Description)} + for _, param := range toolDef.Params { + if param.Type == "string" { + if param.IsRequired { + opts = append(opts, mcp.WithString(param.Name, mcp.Required(), mcp.Description(param.Description))) + } else { + opts = append(opts, mcp.WithString(param.Name, mcp.Description(param.Description))) + } + } else if param.Type == "boolean" { + if param.IsRequired { + opts = append(opts, mcp.WithBoolean(param.Name, mcp.Required(), mcp.Description(param.Description))) + } else { + opts = append(opts, mcp.WithBoolean(param.Name, mcp.Description(param.Description))) + } + } + } + + return mcp.NewTool(toolDef.Name, opts...) +} + +// extractParamsFromRequestArgs extracts parameters from the arguments based on the tool definition +func extractParamsFromRequestArgs(toolDef SnykMcpToolsDefinition, arguments map[string]interface{}) (map[string]interface{}, string) { + params := make(map[string]interface{}) + var workingDir string + + for _, paramDef := range toolDef.Params { + val, ok := arguments[paramDef.Name] + if !ok { + continue + } + + // Store path separately to use as working directory + if paramDef.Name == "path" { + if pathStr, ok := val.(string); ok { + workingDir = pathStr + } + } + + // Convert parameter name from snake_case to kebab-case for CLI arguments + cliParamName := strings.ReplaceAll(paramDef.Name, "_", "-") + + // Cast the value based on parameter type + if paramDef.Type == "string" { + if strVal, ok := val.(string); ok && strVal != "" { + params[cliParamName] = strVal + } + } else if paramDef.Type == "boolean" { + if boolVal, ok := val.(bool); ok && boolVal { + params[cliParamName] = true + } + } + } + + return params, workingDir +} + +// convertToCliParam Convert parameter name from snake_case to kebab-case for CLI arguments +func convertToCliParam(cliParam string) string { + return strings.ReplaceAll(cliParam, "_", "-") +} diff --git a/mcp_extension/main.go b/mcp_extension/main.go new file mode 100644 index 000000000..3f38c2ea7 --- /dev/null +++ b/mcp_extension/main.go @@ -0,0 +1,111 @@ +/* + * © 2023 Snyk Limited All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package mcp_extension + +import ( + "context" + "fmt" + "os" + "path/filepath" + + "github.com/snyk/go-application-framework/pkg/configuration" + + "github.com/snyk/snyk-ls/application/entrypoint" + "github.com/snyk/snyk-ls/internal/mcp" + + "github.com/spf13/pflag" + + "github.com/snyk/go-application-framework/pkg/workflow" +) + +var WORKFLOWID_MCP = workflow.NewWorkflowIdentifier("mcp") + +func Init(engine workflow.Engine) error { + flags := pflag.NewFlagSet("mcp", pflag.ContinueOnError) + + flags.StringP("transport", "t", "sse", "sets transport to ") + flags.Bool(configuration.FLAG_EXPERIMENTAL, false, "enable experimental mcp command") + + cfg := workflow.ConfigurationOptionsFromFlagset(flags) + entry, _ := engine.Register(WORKFLOWID_MCP, cfg, mcpWorkflow) + entry.SetVisibility(false) + + return nil +} + +func mcpWorkflow( + invocation workflow.InvocationContext, + _ []workflow.Data, +) (output []workflow.Data, err error) { + defer entrypoint.OnPanicRecover() + + config := invocation.GetConfiguration() + + // only run if experimental flag is set + if !config.GetBool(configuration.FLAG_EXPERIMENTAL) { + return nil, fmt.Errorf("set `--experimental` flag to enable mcp command") + } + + output = []workflow.Data{} + logger := invocation.GetEnhancedLogger() + + cliPath, err := getCliPath(invocation) + if err != nil { + logger.Err(err).Msg("Failed to set cli path") + return output, err + } + logger.Trace().Interface("environment", os.Environ()).Msg("start environment") + mcpStart(invocation, cliPath) + + return output, nil +} + +func mcpStart(invocationContext workflow.InvocationContext, cliPath string) { + mcpServer := mcp.NewMcpLLMBinding(mcp.WithLogger(invocationContext.GetEnhancedLogger()), mcp.WithCliPath(cliPath)) + logger := invocationContext.GetEnhancedLogger() + + // start mcp server + //nolint:forbidigo // stdio stream isn't started yet + fmt.Println("Starting up MCP Server...") + err := mcpServer.Start(invocationContext) + + if err != nil { + logger.Err(err).Msg("failed to start mcp server") + } + defer func() { + logger.Info().Msg("Shutting down MCP Server...") + mcpServer.Shutdown(context.Background()) + }() +} + +func getCliPath(ctx workflow.InvocationContext) (string, error) { + logger := ctx.GetEnhancedLogger() + exePath, err := os.Executable() + if err != nil { + logger.Err(err).Msg("Failed to get executable path") + return "", err + } + resolvedPath, err := filepath.EvalSymlinks(exePath) + + if err != nil { + logger.Err(err).Msg("Failed to eval symlink from path") + return "", err + } else { + // Set Cli path to current process path + return resolvedPath, nil + } +} diff --git a/mcp_extension/main_test.go b/mcp_extension/main_test.go new file mode 100644 index 000000000..773f2bb54 --- /dev/null +++ b/mcp_extension/main_test.go @@ -0,0 +1,39 @@ +package mcp_extension + +import ( + "testing" + "time" + + "github.com/snyk/go-application-framework/pkg/configuration" + "github.com/stretchr/testify/assert" + + "github.com/snyk/go-application-framework/pkg/app" +) + +func Test_ExtensionEntryPoint(t *testing.T) { + expectedTransportType := "stdio" + engine := app.CreateAppEngineWithOptions() + + engineConfig := configuration.NewWithOpts( + configuration.WithAutomaticEnv(), + ) + engineConfig.Set("transport", expectedTransportType) + engineConfig.Set(configuration.FLAG_EXPERIMENTAL, true) + + //register extension under test + err := Init(engine) + assert.Nil(t, err) + + go func() { + err = engine.Init() + assert.Nil(t, err) + + data, err := engine.InvokeWithConfig(WORKFLOWID_MCP, engineConfig) + assert.Nil(t, err) + assert.Empty(t, data) + }() + + assert.Eventuallyf(t, func() bool { + return expectedTransportType == engineConfig.GetString("transport") && engineConfig.GetBool(configuration.FLAG_EXPERIMENTAL) + }, time.Minute, time.Millisecond, "open browser was not called") +} From e190de4875e4dea33e948ac1ae3c6365b28f3cc1 Mon Sep 17 00:00:00 2001 From: Bastian Doetsch Date: Thu, 3 Apr 2025 12:44:11 +0200 Subject: [PATCH 02/10] fix: add analytics to mcp wrapper (#818) --- internal/mcp/llm_binding.go | 20 ++++++++++++++++++++ internal/mcp/llm_binding_test.go | 24 ++++++++++++++++++++++++ internal/mcp/scan_tool.go | 7 ++++++- internal/mcp/scan_tool_test.go | 2 ++ mcp_extension/main.go | 8 ++++++++ 5 files changed, 60 insertions(+), 1 deletion(-) diff --git a/internal/mcp/llm_binding.go b/internal/mcp/llm_binding.go index 06788e458..89d41c3b5 100644 --- a/internal/mcp/llm_binding.go +++ b/internal/mcp/llm_binding.go @@ -21,12 +21,15 @@ import ( "fmt" "net/http" "net/url" + "os" + "strings" "sync" "time" "github.com/mark3labs/mcp-go/server" "github.com/pkg/errors" "github.com/rs/zerolog" + "github.com/snyk/go-application-framework/pkg/configuration" "github.com/snyk/go-application-framework/pkg/workflow" ) @@ -168,3 +171,20 @@ func (m *McpLLMBinding) Started() bool { return m.started } + +func (m *McpLLMBinding) expandedEnv(version string) []string { + environ := os.Environ() + var expandedEnv = []string{} + for _, v := range environ { + if strings.HasPrefix(strings.ToLower(v), strings.ToLower(configuration.INTEGRATION_NAME)) { + continue + } + if strings.HasPrefix(strings.ToLower(v), strings.ToLower(configuration.INTEGRATION_VERSION)) { + continue + } + expandedEnv = append(expandedEnv, v) + } + expandedEnv = append(expandedEnv, configuration.INTEGRATION_NAME+"=MCP") + expandedEnv = append(expandedEnv, fmt.Sprintf("%s=%s", configuration.INTEGRATION_VERSION, version)) + return expandedEnv +} diff --git a/internal/mcp/llm_binding_test.go b/internal/mcp/llm_binding_test.go index e060c98d9..2d4e415dc 100644 --- a/internal/mcp/llm_binding_test.go +++ b/internal/mcp/llm_binding_test.go @@ -17,8 +17,11 @@ package mcp import ( "net/url" + "os" + "strings" "testing" + "github.com/snyk/go-application-framework/pkg/configuration" "github.com/stretchr/testify/assert" ) @@ -42,3 +45,24 @@ func TestDefaultURL(t *testing.T) { assert.Equal(t, "http", u.Scheme) assert.Contains(t, u.Host, DefaultHost) } + +func TestExpandedEnv(t *testing.T) { + t.Setenv(configuration.INTEGRATION_NAME, "abc") + t.Setenv(configuration.INTEGRATION_VERSION, "abc") + binding := NewMcpLLMBinding() + + env := binding.expandedEnv("1.x.1") + + for _, s := range os.Environ() { + if strings.HasPrefix(s, configuration.INTEGRATION_NAME) { + continue + } + if strings.HasPrefix(s, configuration.INTEGRATION_VERSION) { + continue + } + assert.Contains(t, env, s) + } + + assert.Contains(t, env, configuration.INTEGRATION_NAME+"=MCP") + assert.Contains(t, env, configuration.INTEGRATION_VERSION+"=1.x.1") +} diff --git a/internal/mcp/scan_tool.go b/internal/mcp/scan_tool.go index 247f377f7..f7a8939ce 100644 --- a/internal/mcp/scan_tool.go +++ b/internal/mcp/scan_tool.go @@ -95,7 +95,12 @@ func (m *McpLLMBinding) runSnyk(ctx context.Context, invocationCtx workflow.Invo if workingDir != "" { command.Dir = workingDir } - + runtimeInfo := invocationCtx.GetRuntimeInfo() + if runtimeInfo != nil { + command.Env = m.expandedEnv(runtimeInfo.GetVersion()) + } else { + command.Env = m.expandedEnv("unknown") + } command.Stderr = invocationCtx.GetEnhancedLogger() res, err := command.Output() diff --git a/internal/mcp/scan_tool_test.go b/internal/mcp/scan_tool_test.go index 1bf204ca0..255bdd0e3 100644 --- a/internal/mcp/scan_tool_test.go +++ b/internal/mcp/scan_tool_test.go @@ -33,6 +33,7 @@ import ( "github.com/rs/zerolog" "github.com/snyk/go-application-framework/pkg/configuration" "github.com/snyk/go-application-framework/pkg/mocks" + "github.com/snyk/go-application-framework/pkg/runtimeinfo" "github.com/snyk/go-application-framework/pkg/workflow" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -68,6 +69,7 @@ func setupTestFixture(t *testing.T) *testFixture { invocationCtx := mocks.NewMockInvocationContext(mockctl) invocationCtx.EXPECT().GetConfiguration().Return(engineConfig).AnyTimes() invocationCtx.EXPECT().GetEnhancedLogger().Return(&logger).AnyTimes() + invocationCtx.EXPECT().GetRuntimeInfo().Return(runtimeinfo.New(runtimeinfo.WithName("hurz"), runtimeinfo.WithVersion("1000.8.3"))).AnyTimes() // Snyk CLI mock tempDir := t.TempDir() diff --git a/mcp_extension/main.go b/mcp_extension/main.go index 3f38c2ea7..d9da67f8f 100644 --- a/mcp_extension/main.go +++ b/mcp_extension/main.go @@ -54,6 +54,14 @@ func mcpWorkflow( defer entrypoint.OnPanicRecover() config := invocation.GetConfiguration() + config.Set(configuration.INTEGRATION_NAME, "MCP") + + runtimeInfo := invocation.GetRuntimeInfo() + if runtimeInfo != nil { + config.Set(configuration.INTEGRATION_VERSION, runtimeInfo.GetVersion()) + } else { + config.Set(configuration.INTEGRATION_VERSION, "unknown") + } // only run if experimental flag is set if !config.GetBool(configuration.FLAG_EXPERIMENTAL) { From 0448899709544e62a253ace27b8fe0bbc08ccd70 Mon Sep 17 00:00:00 2001 From: Abdelrahman Shawki Hassan Date: Thu, 3 Apr 2025 15:21:04 +0200 Subject: [PATCH 03/10] fix: disable mcp json output (#819) --- internal/mcp/llm_binding.go | 3 +-- internal/mcp/snyk_tools.json | 16 ++-------------- mcp_extension/main.go | 2 -- 3 files changed, 3 insertions(+), 18 deletions(-) diff --git a/internal/mcp/llm_binding.go b/internal/mcp/llm_binding.go index 89d41c3b5..d7c993875 100644 --- a/internal/mcp/llm_binding.go +++ b/internal/mcp/llm_binding.go @@ -125,8 +125,7 @@ func (m *McpLLMBinding) HandleSseServer() error { m.sseServer = server.NewSSEServer(m.mcpServer, m.baseURL.String()) - //nolint:forbidigo // stdio stream isn't started yet - fmt.Printf("Starting with base URL %s\n", m.baseURL.String()) + _, _ = fmt.Fprintf(os.Stderr, "Starting with base URL %s\n", m.baseURL.String()) m.logger.Info().Str("baseURL", m.baseURL.String()).Msg("starting") go func() { diff --git a/internal/mcp/snyk_tools.json b/internal/mcp/snyk_tools.json index 1e6a2de49..481f353f3 100644 --- a/internal/mcp/snyk_tools.json +++ b/internal/mcp/snyk_tools.json @@ -4,7 +4,7 @@ "name": "snyk_sca_test", "description": "Run a SCA test on project dependencies to detect known vulnerabilities. Use this to scan open-source packages in supported ecosystems like npm, Maven, etc. Supports monorepo scanning via `--all-projects`. Outputs vulnerability data in JSON if enabled.", "command": ["test"], - "standardParams": ["all_projects", "json"], + "standardParams": ["all_projects"], "params": [ { "name": "path", @@ -18,12 +18,6 @@ "isRequired": false, "description": "Scan all projects in the specified directory. (Default is true)." }, - { - "name": "json", - "type": "boolean", - "isRequired": false, - "description": "Output results in JSON format. (Default is true)." - }, { "name": "severity_threshold", "type": "string", @@ -150,7 +144,7 @@ "name": "snyk_code_test", "description": "Run a static application security test (SAST) on your source code to detect security issues like SQL injection, XSS, and hardcoded secrets. Designed to catch issues early in the development cycle.", "command": ["code", "test"], - "standardParams": ["json"], + "standardParams": [], "params": [ { "name": "path", @@ -164,12 +158,6 @@ "isRequired": false, "description": "Specific file to scan (default: empty)." }, - { - "name": "json", - "type": "boolean", - "isRequired": false, - "description": "Output results in JSON format. (default: true)" - }, { "name": "severity_threshold", "type": "string", diff --git a/mcp_extension/main.go b/mcp_extension/main.go index d9da67f8f..dea81d27e 100644 --- a/mcp_extension/main.go +++ b/mcp_extension/main.go @@ -87,8 +87,6 @@ func mcpStart(invocationContext workflow.InvocationContext, cliPath string) { logger := invocationContext.GetEnhancedLogger() // start mcp server - //nolint:forbidigo // stdio stream isn't started yet - fmt.Println("Starting up MCP Server...") err := mcpServer.Start(invocationContext) if err != nil { From 6c421be3dea66c72d00bb1f7574b5c6a3d4bdff9 Mon Sep 17 00:00:00 2001 From: Bastian Doetsch Date: Fri, 4 Apr 2025 16:35:35 +0200 Subject: [PATCH 04/10] feat(mcp): update MCP tools and configuration (#822) * feat(mcp): update MCP tools and configuration * docs: updated licenses --------- Co-authored-by: bastiandoetsch --- go.mod | 3 ++- go.sum | 4 +++ internal/mcp/llm_binding.go | 2 +- internal/mcp/snyk_tools.json | 6 ----- .../yosida95/uritemplate/v3/LICENSE | 25 +++++++++++++++++++ 5 files changed, 32 insertions(+), 8 deletions(-) create mode 100644 licenses/github.com/yosida95/uritemplate/v3/LICENSE diff --git a/go.mod b/go.mod index b19e48ae6..ef9e18974 100644 --- a/go.mod +++ b/go.mod @@ -16,7 +16,7 @@ require ( github.com/google/uuid v1.6.0 github.com/gosimple/hashdir v1.0.2 github.com/hexops/gotextdiff v1.0.3 - github.com/mark3labs/mcp-go v0.8.4 + github.com/mark3labs/mcp-go v0.18.0 github.com/otiai10/copy v1.14.1 github.com/pact-foundation/pact-go v1.10.0 github.com/pingcap/errors v0.11.4 @@ -122,6 +122,7 @@ require ( github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f // indirect github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect github.com/xeipuuv/gojsonschema v1.2.0 // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect github.com/yusufpapurcu/wmi v1.2.4 // indirect golang.org/x/crypto v0.35.0 // indirect golang.org/x/sys v0.30.0 // indirect diff --git a/go.sum b/go.sum index 593adecd6..a2dfac0f7 100644 --- a/go.sum +++ b/go.sum @@ -218,6 +218,8 @@ github.com/mailru/easyjson v0.9.0 h1:PrnmzHw7262yW8sTBwxi1PdJA3Iw/EKBa8psRf7d9a4 github.com/mailru/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU= github.com/mark3labs/mcp-go v0.8.4 h1:/VxjJ0+4oN2eYLuAgVzixrYNfrmwJnV38EfPIX3VbPE= github.com/mark3labs/mcp-go v0.8.4/go.mod h1:cjMlBU0cv/cj9kjlgmRhoJ5JREdS7YX83xeIG9Ko/jE= +github.com/mark3labs/mcp-go v0.18.0 h1:YuhgIVjNlTG2ZOwmrkORWyPTp0dz1opPEqvsPtySXao= +github.com/mark3labs/mcp-go v0.18.0/go.mod h1:KmJndYv7GIgcPVwEKJjNcbhVQ+hJGJhrCCB/9xITzpE= github.com/maruel/natural v1.1.1 h1:Hja7XhhmvEFhcByqDoHz9QZbkWey+COd9xWfCfn1ioo= github.com/maruel/natural v1.1.1/go.mod h1:v+Rfd79xlw1AgVBjbO0BEQmptqb5HvL/k9GRHB7ZKEg= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= @@ -395,6 +397,8 @@ github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 h1:EzJWgHo github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415/go.mod h1:GwrjFmJcFw6At/Gs6z4yjiIwzuJ1/+UwLxMQDVQXShQ= github.com/xeipuuv/gojsonschema v1.2.0 h1:LhYJRs+L4fBtjZUfuSZIKGeVu0QRy8e5Xi7D17UxZ74= github.com/xeipuuv/gojsonschema v1.2.0/go.mod h1:anYRn/JVcOK2ZgGU+IjEV4nwlhoK5sQluxsYJ78Id3Y= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0= diff --git a/internal/mcp/llm_binding.go b/internal/mcp/llm_binding.go index d7c993875..d72f3060d 100644 --- a/internal/mcp/llm_binding.go +++ b/internal/mcp/llm_binding.go @@ -123,7 +123,7 @@ func (m *McpLLMBinding) HandleSseServer() error { m.baseURL = defaultURL() } - m.sseServer = server.NewSSEServer(m.mcpServer, m.baseURL.String()) + m.sseServer = server.NewSSEServer(m.mcpServer, server.WithBaseURL(m.baseURL.String())) _, _ = fmt.Fprintf(os.Stderr, "Starting with base URL %s\n", m.baseURL.String()) diff --git a/internal/mcp/snyk_tools.json b/internal/mcp/snyk_tools.json index 481f353f3..c1b2c0f5d 100644 --- a/internal/mcp/snyk_tools.json +++ b/internal/mcp/snyk_tools.json @@ -78,12 +78,6 @@ "isRequired": false, "description": "Use with --all-projects to exclude directory names and file names. (Default is empty)" }, - { - "name": "print_deps", - "type": "boolean", - "isRequired": false, - "description": "Print the dependency tree before sending it for analysis. (Default is false)" - }, { "name": "remote_repo_url", "type": "string", diff --git a/licenses/github.com/yosida95/uritemplate/v3/LICENSE b/licenses/github.com/yosida95/uritemplate/v3/LICENSE new file mode 100644 index 000000000..79e8f8757 --- /dev/null +++ b/licenses/github.com/yosida95/uritemplate/v3/LICENSE @@ -0,0 +1,25 @@ +Copyright (C) 2016, Kohei YOSHIDA . All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + * Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. From ba1f43e5b4bf3d3f6480fd7304b57fae01780064 Mon Sep 17 00:00:00 2001 From: Zdroba <72391375+DariusZdroba@users.noreply.github.com> Date: Tue, 8 Apr 2025 12:19:39 +0300 Subject: [PATCH 05/10] fix: update code-client-go version and LS usage [IDE-1098] (#816) Co-authored-by: Abdelrahman Shawki Hassan Co-authored-by: Andrew Robinson Hodges --- go.mod | 6 +++--- go.sum | 8 ++++---- infrastructure/code/ai_fix_handler.go | 15 +++++---------- 3 files changed, 12 insertions(+), 17 deletions(-) diff --git a/go.mod b/go.mod index ef9e18974..dd3ff2387 100644 --- a/go.mod +++ b/go.mod @@ -27,8 +27,8 @@ require ( github.com/rs/zerolog v1.33.0 github.com/sabhiram/go-gitignore v0.0.0-20210923224102-525f6e181f06 github.com/shirou/gopsutil v3.21.11+incompatible - github.com/snyk/code-client-go v1.16.2 - github.com/snyk/go-application-framework v0.0.0-20250307155453-ce7aaf72fe7d + github.com/snyk/code-client-go v1.18.0 + github.com/snyk/go-application-framework v0.0.0-20250325133828-3ffd1aa4f76f github.com/sourcegraph/go-lsp v0.0.0-20240223163137-f80c5dd31dfd github.com/spf13/pflag v1.0.6 github.com/stretchr/testify v1.10.0 @@ -41,7 +41,6 @@ require ( golang.org/x/oauth2 v0.27.0 golang.org/x/sync v0.11.0 gopkg.in/ini.v1 v1.67.0 - gopkg.in/yaml.v3 v3.0.1 ) require ( @@ -130,6 +129,7 @@ require ( golang.org/x/tools v0.30.0 // indirect gopkg.in/warnings.v0 v0.1.2 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) //replace github.com/snyk/go-application-framework => ../go-application-framework diff --git a/go.sum b/go.sum index a2dfac0f7..a6d33deba 100644 --- a/go.sum +++ b/go.sum @@ -320,12 +320,12 @@ github.com/shirou/gopsutil v3.21.11+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMT github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= github.com/skeema/knownhosts v1.3.1 h1:X2osQ+RAjK76shCbvhHHHVl3ZlgDm8apHEHFqRjnBY8= github.com/skeema/knownhosts v1.3.1/go.mod h1:r7KTdC8l4uxWRyK2TpQZ/1o5HaSzh06ePQNxPwTcfiY= -github.com/snyk/code-client-go v1.16.2 h1:07vMMpl+qUWYz8IJf+13CTL3lRSjKN1TdBk9UDhqxvI= -github.com/snyk/code-client-go v1.16.2/go.mod h1:WH6lNkJc785hfXmwhixxWHix3O6z+1zwz40oK8vl/zg= +github.com/snyk/code-client-go v1.18.0 h1:y+yVCiB2gt6SKau3O2etCb+lSpvARHMp5eDRdp9awY0= +github.com/snyk/code-client-go v1.18.0/go.mod h1:WH6lNkJc785hfXmwhixxWHix3O6z+1zwz40oK8vl/zg= github.com/snyk/error-catalog-golang-public v0.0.0-20250218074309-307ad7b38a60 h1:iB6z2BhBpfN9p0/dEZfwWvs7fpdZk3loooAih8yspS8= github.com/snyk/error-catalog-golang-public v0.0.0-20250218074309-307ad7b38a60/go.mod h1:Ytttq7Pw4vOCu9NtRQaOeDU2dhBYUyNBe6kX4+nIIQ4= -github.com/snyk/go-application-framework v0.0.0-20250307155453-ce7aaf72fe7d h1:eC0V150YUGvO3mEwXBMELQXzWGHFngxD1Y5fAtzWQC0= -github.com/snyk/go-application-framework v0.0.0-20250307155453-ce7aaf72fe7d/go.mod h1:oWN7a1ud3u5y8HxW+Qdroy+ofEEA8s0MwVAtb8qq/v4= +github.com/snyk/go-application-framework v0.0.0-20250325133828-3ffd1aa4f76f h1:1EPrRhLQ5Bo0SmIqoAU38Et1Bv2klCbyfgLmVJfUyvM= +github.com/snyk/go-application-framework v0.0.0-20250325133828-3ffd1aa4f76f/go.mod h1:A7oFVjMjNukzsMeiIWXEXjCrAf2ARvoK4aQOm9e3E/Y= github.com/snyk/go-httpauth v0.0.0-20231117135515-eb445fea7530 h1:s9PHNkL6ueYRiAKNfd8OVxlUOqU3qY0VDbgCD1f6WQY= github.com/snyk/go-httpauth v0.0.0-20231117135515-eb445fea7530/go.mod h1:88KbbvGYlmLgee4OcQ19yr0bNpXpOr2kciOthaSzCAg= github.com/sourcegraph/go-lsp v0.0.0-20240223163137-f80c5dd31dfd h1:Dq5WSzWsP1TbVi10zPWBI5LKEBDg4Y1OhWEph1wr5WQ= diff --git a/infrastructure/code/ai_fix_handler.go b/infrastructure/code/ai_fix_handler.go index dfa2b0b15..94b81e033 100644 --- a/infrastructure/code/ai_fix_handler.go +++ b/infrastructure/code/ai_fix_handler.go @@ -39,11 +39,10 @@ type AiFixHandler struct { type AiStatus string const ( - AiFixNotStarted AiStatus = "NOT_STARTED" - AiFixInProgress AiStatus = "IN_PROGRESS" - AiFixSuccess AiStatus = "SUCCESS" - AiFixError AiStatus = "ERROR" - shouldRunExplain = false + AiFixNotStarted AiStatus = "NOT_STARTED" + AiFixInProgress AiStatus = "IN_PROGRESS" + AiFixSuccess AiStatus = "SUCCESS" + AiFixError AiStatus = "ERROR" ) const ( ExplainApiVersion string = "2024-10-15" @@ -75,9 +74,6 @@ func (fixHandler *AiFixHandler) GetResults(fixId string) (filePath string, diff } func (fixHandler *AiFixHandler) EnrichWithExplain(ctx context.Context, c *config.Config, issue types.Issue, suggestions []AutofixUnifiedDiffSuggestion) { - if !shouldRunExplain { - return - } logger := c.Logger().With().Str("method", "EnrichWithExplain").Logger() if ctx.Err() != nil { logger.Debug().Msgf("EnrichWithExplain context canceled") @@ -94,12 +90,11 @@ func (fixHandler *AiFixHandler) EnrichWithExplain(ctx context.Context, c *config deepCodeLLMBinding := llm.NewDeepcodeLLMBinding( llm.WithLogger(c.Logger()), llm.WithOutputFormat(llm.HTML), - llm.WithEndpoint(getExplainEndpoint(c)), llm.WithHTTPClient(func() codeClientHTTP.HTTPClient { return c.Engine().GetNetworkAccess().GetHttpClient() }), ) - explanations, err := deepCodeLLMBinding.ExplainWithOptions(contextWithCancel, llm.ExplainOptions{RuleKey: issue.GetID(), Diffs: diffs}) + explanations, err := deepCodeLLMBinding.ExplainWithOptions(contextWithCancel, llm.ExplainOptions{RuleKey: issue.GetID(), Diffs: diffs, Endpoint: getExplainEndpoint(c)}) if err != nil { logger.Error().Err(err).Msgf("Failed to explain with explain for issue %s", issue.GetID()) return From 4ba82533164f8495c5527077c17c449b9303af54 Mon Sep 17 00:00:00 2001 From: Abdelrahman Shawki Hassan Date: Tue, 8 Apr 2025 16:02:03 +0200 Subject: [PATCH 06/10] fix: disable explain (#826) --- infrastructure/code/ai_fix_handler.go | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/infrastructure/code/ai_fix_handler.go b/infrastructure/code/ai_fix_handler.go index 94b81e033..ca297e79d 100644 --- a/infrastructure/code/ai_fix_handler.go +++ b/infrastructure/code/ai_fix_handler.go @@ -39,10 +39,11 @@ type AiFixHandler struct { type AiStatus string const ( - AiFixNotStarted AiStatus = "NOT_STARTED" - AiFixInProgress AiStatus = "IN_PROGRESS" - AiFixSuccess AiStatus = "SUCCESS" - AiFixError AiStatus = "ERROR" + AiFixNotStarted AiStatus = "NOT_STARTED" + AiFixInProgress AiStatus = "IN_PROGRESS" + AiFixSuccess AiStatus = "SUCCESS" + AiFixError AiStatus = "ERROR" + shouldRunExplain = false ) const ( ExplainApiVersion string = "2024-10-15" @@ -74,6 +75,9 @@ func (fixHandler *AiFixHandler) GetResults(fixId string) (filePath string, diff } func (fixHandler *AiFixHandler) EnrichWithExplain(ctx context.Context, c *config.Config, issue types.Issue, suggestions []AutofixUnifiedDiffSuggestion) { + if !shouldRunExplain { + return + } logger := c.Logger().With().Str("method", "EnrichWithExplain").Logger() if ctx.Err() != nil { logger.Debug().Msgf("EnrichWithExplain context canceled") From b936074201df29e13f0fedf82559e7516f45dcd9 Mon Sep 17 00:00:00 2001 From: Abdelrahman Shawki Hassan Date: Tue, 15 Apr 2025 12:47:18 +0200 Subject: [PATCH 07/10] chore: update code-client-go --- go.mod | 4 ++-- go.sum | 6 ++---- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/go.mod b/go.mod index dd3ff2387..29e0eff67 100644 --- a/go.mod +++ b/go.mod @@ -27,7 +27,7 @@ require ( github.com/rs/zerolog v1.33.0 github.com/sabhiram/go-gitignore v0.0.0-20210923224102-525f6e181f06 github.com/shirou/gopsutil v3.21.11+incompatible - github.com/snyk/code-client-go v1.18.0 + github.com/snyk/code-client-go v1.20.1 github.com/snyk/go-application-framework v0.0.0-20250325133828-3ffd1aa4f76f github.com/sourcegraph/go-lsp v0.0.0-20240223163137-f80c5dd31dfd github.com/spf13/pflag v1.0.6 @@ -41,6 +41,7 @@ require ( golang.org/x/oauth2 v0.27.0 golang.org/x/sync v0.11.0 gopkg.in/ini.v1 v1.67.0 + gopkg.in/yaml.v3 v3.0.1 ) require ( @@ -129,7 +130,6 @@ require ( golang.org/x/tools v0.30.0 // indirect gopkg.in/warnings.v0 v0.1.2 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect - gopkg.in/yaml.v3 v3.0.1 // indirect ) //replace github.com/snyk/go-application-framework => ../go-application-framework diff --git a/go.sum b/go.sum index a6d33deba..7369fb6e7 100644 --- a/go.sum +++ b/go.sum @@ -216,8 +216,6 @@ github.com/magiconair/properties v1.8.6 h1:5ibWZ6iY0NctNGWo87LalDlEZ6R41TqbbDamh github.com/magiconair/properties v1.8.6/go.mod h1:y3VJvCyxH9uVvJTWEGAELF3aiYNyPKd5NZ3oSwXrF60= github.com/mailru/easyjson v0.9.0 h1:PrnmzHw7262yW8sTBwxi1PdJA3Iw/EKBa8psRf7d9a4= github.com/mailru/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU= -github.com/mark3labs/mcp-go v0.8.4 h1:/VxjJ0+4oN2eYLuAgVzixrYNfrmwJnV38EfPIX3VbPE= -github.com/mark3labs/mcp-go v0.8.4/go.mod h1:cjMlBU0cv/cj9kjlgmRhoJ5JREdS7YX83xeIG9Ko/jE= github.com/mark3labs/mcp-go v0.18.0 h1:YuhgIVjNlTG2ZOwmrkORWyPTp0dz1opPEqvsPtySXao= github.com/mark3labs/mcp-go v0.18.0/go.mod h1:KmJndYv7GIgcPVwEKJjNcbhVQ+hJGJhrCCB/9xITzpE= github.com/maruel/natural v1.1.1 h1:Hja7XhhmvEFhcByqDoHz9QZbkWey+COd9xWfCfn1ioo= @@ -320,8 +318,8 @@ github.com/shirou/gopsutil v3.21.11+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMT github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= github.com/skeema/knownhosts v1.3.1 h1:X2osQ+RAjK76shCbvhHHHVl3ZlgDm8apHEHFqRjnBY8= github.com/skeema/knownhosts v1.3.1/go.mod h1:r7KTdC8l4uxWRyK2TpQZ/1o5HaSzh06ePQNxPwTcfiY= -github.com/snyk/code-client-go v1.18.0 h1:y+yVCiB2gt6SKau3O2etCb+lSpvARHMp5eDRdp9awY0= -github.com/snyk/code-client-go v1.18.0/go.mod h1:WH6lNkJc785hfXmwhixxWHix3O6z+1zwz40oK8vl/zg= +github.com/snyk/code-client-go v1.20.1 h1:38nEGzrQIh/aVLjR99jiTUQM0sL9SQAvhMfZGmd9G0w= +github.com/snyk/code-client-go v1.20.1/go.mod h1:WH6lNkJc785hfXmwhixxWHix3O6z+1zwz40oK8vl/zg= github.com/snyk/error-catalog-golang-public v0.0.0-20250218074309-307ad7b38a60 h1:iB6z2BhBpfN9p0/dEZfwWvs7fpdZk3loooAih8yspS8= github.com/snyk/error-catalog-golang-public v0.0.0-20250218074309-307ad7b38a60/go.mod h1:Ytttq7Pw4vOCu9NtRQaOeDU2dhBYUyNBe6kX4+nIIQ4= github.com/snyk/go-application-framework v0.0.0-20250325133828-3ffd1aa4f76f h1:1EPrRhLQ5Bo0SmIqoAU38Et1Bv2klCbyfgLmVJfUyvM= From f850f5a10c8b3d9ac29971db723b77dfd474fb09 Mon Sep 17 00:00:00 2001 From: Abdelrahman Shawki Hassan Date: Tue, 15 Apr 2025 12:56:21 +0200 Subject: [PATCH 08/10] fix: send dummy mcpserver notification --- application/server/server.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/application/server/server.go b/application/server/server.go index 9bd10db50..825ecadc5 100644 --- a/application/server/server.go +++ b/application/server/server.go @@ -457,7 +457,8 @@ func initializedHandler(c *config.Config, srv *jrpc2.Server) handler.Func { ) logger.Info().Msg(msg) } - + // this change is to avoid breaking current stable VS Code. URL value is not used. + di.Notifier().Send(types.McpServerURLParams{URL: "http://127.0.0.1:7695"}) return nil, nil }) } From d8ed9c3a2009c483ef68850ae408f1ae3ebf78fd Mon Sep 17 00:00:00 2001 From: Abdelrahman Shawki Hassan Date: Tue, 15 Apr 2025 18:06:29 +0200 Subject: [PATCH 09/10] fix: check request origin --- internal/mcp/llm_binding.go | 37 ++++++++++++++++++++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/internal/mcp/llm_binding.go b/internal/mcp/llm_binding.go index d72f3060d..7f9c979fe 100644 --- a/internal/mcp/llm_binding.go +++ b/internal/mcp/llm_binding.go @@ -141,7 +141,13 @@ func (m *McpLLMBinding) HandleSseServer() error { m.mutex.Unlock() }() - err := m.sseServer.Start(m.baseURL.Host) + srv := &http.Server{ + Addr: m.baseURL.Host, + Handler: middleware(m.sseServer), + } + + err := srv.ListenAndServe() + if err != nil { // expect http.ErrServerClosed when shutting down if !errors.Is(err, http.ErrServerClosed) { @@ -152,6 +158,35 @@ func (m *McpLLMBinding) HandleSseServer() error { return nil } +var allowedHostnames = map[string]bool{ + "localhost": true, + "127.0.0.1": true, + "::1": true, +} + +func middleware(sseServer *server.SSEServer) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + originHeader := r.Header.Get("Origin") + isValidOrigin := originHeader == "" + + if originHeader != "" { + parsedOrigin, err := url.Parse(originHeader) + if err == nil { + requestHost := parsedOrigin.Hostname() + if _, allowed := allowedHostnames[requestHost]; allowed { + isValidOrigin = true + } + } + } + + if isValidOrigin { + sseServer.ServeHTTP(w, r) + } else { + http.Error(w, "Forbidden: Access restricted to localhost origins", http.StatusForbidden) + } + }) +} + func (m *McpLLMBinding) Shutdown(ctx context.Context) { m.mutex.Lock() defer m.mutex.Unlock() From 13d9ea0f5abf46829ebbfed4e1de2227f0c27bf3 Mon Sep 17 00:00:00 2001 From: Abdelrahman Shawki Hassan Date: Thu, 10 Apr 2025 12:36:43 +0200 Subject: [PATCH 10/10] fix: add scansource to workspace scan command (#831) --- application/server/execute_command_test.go | 19 ++++++++++++++++ domain/ide/command/workspace_scan.go | 25 ++++++++++++++++++++-- 2 files changed, 42 insertions(+), 2 deletions(-) diff --git a/application/server/execute_command_test.go b/application/server/execute_command_test.go index 6ee6be1d6..7f9b7f692 100644 --- a/application/server/execute_command_test.go +++ b/application/server/execute_command_test.go @@ -122,6 +122,25 @@ func Test_executeWorkspaceScanCommand_shouldAskForTrust(t *testing.T) { }, 2*time.Second, time.Millisecond) } +func Test_executeWorkspaceScanCommand_shouldAcceptScanSourceParam(t *testing.T) { + c := testutil.UnitTest(t) + loc, jsonRPCRecorder := setupServerWithCustomDI(t, c, false) + + s := &scanner.TestScanner{} + c.Workspace().AddFolder(workspace.NewFolder(c, "dummy", "dummy", s, di.HoverService(), di.ScanNotifier(), di.Notifier(), di.ScanPersister(), di.ScanStateAggregator())) + // explicitly enable folder trust which is disabled by default in tests + config.CurrentConfig().SetTrustedFolderFeatureEnabled(true) + + params := lsp.ExecuteCommandParams{Command: types.WorkspaceScanCommand, Arguments: []any{"LLM"}} + _, err := loc.Client.Call(ctx, "workspace/executeCommand", params) + if err != nil { + t.Fatal(err) + } + assert.Eventually(t, func() bool { + return s.Calls() == 0 && checkTrustMessageRequest(jsonRPCRecorder, c) + }, 2*time.Second, time.Millisecond) +} + func Test_loginCommand_StartsAuthentication(t *testing.T) { c := testutil.UnitTest(t) loc, jsonRPCRecorder := setupServer(t, c) diff --git a/domain/ide/command/workspace_scan.go b/domain/ide/command/workspace_scan.go index 7bb82d1e4..b372ba8c4 100644 --- a/domain/ide/command/workspace_scan.go +++ b/domain/ide/command/workspace_scan.go @@ -20,6 +20,7 @@ import ( "context" "github.com/snyk/snyk-ls/application/config" + context2 "github.com/snyk/snyk-ls/internal/context" "github.com/snyk/snyk-ls/internal/types" ) @@ -36,7 +37,27 @@ func (cmd *workspaceScanCommand) Command() types.CommandData { func (cmd *workspaceScanCommand) Execute(ctx context.Context) (any, error) { w := cmd.c.Workspace() w.Clear() - w.ScanWorkspace(ctx) - HandleUntrustedFolders(ctx, cmd.c, cmd.srv) + args := cmd.command.Arguments + enrichedCtx := cmd.enrichContextWithScanSource(ctx, args) + w.ScanWorkspace(enrichedCtx) + HandleUntrustedFolders(enrichedCtx, cmd.c, cmd.srv) return nil, nil } + +func (cmd *workspaceScanCommand) enrichContextWithScanSource(ctx context.Context, args []any) context.Context { + if len(args) == 0 { + return ctx + } + + sc, ok := args[0].(string) + if !ok { + return ctx + } + + if sc != context2.IDE.String() && sc != context2.LLM.String() { + return ctx + } + + scanSource := context2.ScanSource(sc) + return context2.NewContextWithScanSource(ctx, scanSource) +}