diff --git a/application/server/execute_command_test.go b/application/server/execute_command_test.go
index 6ee6be1d6..7f9b7f692 100644
--- a/application/server/execute_command_test.go
+++ b/application/server/execute_command_test.go
@@ -122,6 +122,25 @@ func Test_executeWorkspaceScanCommand_shouldAskForTrust(t *testing.T) {
}, 2*time.Second, time.Millisecond)
}
+func Test_executeWorkspaceScanCommand_shouldAcceptScanSourceParam(t *testing.T) {
+ c := testutil.UnitTest(t)
+ loc, jsonRPCRecorder := setupServerWithCustomDI(t, c, false)
+
+ s := &scanner.TestScanner{}
+ c.Workspace().AddFolder(workspace.NewFolder(c, "dummy", "dummy", s, di.HoverService(), di.ScanNotifier(), di.Notifier(), di.ScanPersister(), di.ScanStateAggregator()))
+ // explicitly enable folder trust which is disabled by default in tests
+ config.CurrentConfig().SetTrustedFolderFeatureEnabled(true)
+
+ params := lsp.ExecuteCommandParams{Command: types.WorkspaceScanCommand, Arguments: []any{"LLM"}}
+ _, err := loc.Client.Call(ctx, "workspace/executeCommand", params)
+ if err != nil {
+ t.Fatal(err)
+ }
+ assert.Eventually(t, func() bool {
+ return s.Calls() == 0 && checkTrustMessageRequest(jsonRPCRecorder, c)
+ }, 2*time.Second, time.Millisecond)
+}
+
func Test_loginCommand_StartsAuthentication(t *testing.T) {
c := testutil.UnitTest(t)
loc, jsonRPCRecorder := setupServer(t, c)
diff --git a/application/server/server.go b/application/server/server.go
index 06885179d..825ecadc5 100644
--- a/application/server/server.go
+++ b/application/server/server.go
@@ -28,7 +28,6 @@ import (
"github.com/snyk/snyk-ls/domain/snyk"
"github.com/snyk/snyk-ls/domain/snyk/persistence"
- mcp2 "github.com/snyk/snyk-ls/internal/mcp"
"github.com/snyk/snyk-ls/internal/storedconfig"
"github.com/snyk/snyk-ls/domain/snyk/scanner"
@@ -78,25 +77,6 @@ func Start(c *config.Config) {
di.Init()
initHandlers(srv, handlers, c)
- // start mcp server
- logger.Info().Msg("Starting up MCP Server...")
- var mcpServer *mcp2.McpLLMBinding
- go func() {
- mcpServer = mcp2.NewMcpLLMBinding(c, mcp2.WithScanner(di.Scanner()), mcp2.WithLogger(c.Logger()))
- err := mcpServer.Start()
- if err != nil {
- c.Logger().Err(err).Msg("failed to start mcp server")
- }
- }()
-
- // shutdown mcp server once the lsp returns from wait status
- defer func() {
- if mcpServer != nil {
- logger.Info().Msg("Shutting down MCP Server...")
- mcpServer.Shutdown(context.Background())
- }
- }()
-
logger.Info().Msg("Starting up Language Server...")
srv = srv.Start(channel.Header("")(os.Stdin, os.Stdout))
status := srv.WaitStatus()
@@ -477,14 +457,8 @@ func initializedHandler(c *config.Config, srv *jrpc2.Server) handler.Func {
)
logger.Info().Msg(msg)
}
- defer func() {
- // delay sending the mcp server URL
- for c.GetMCPServerURL() == nil {
- // wait until the server URL is available
- time.Sleep(500 * time.Millisecond)
- }
- di.Notifier().Send(types.McpServerURLParams{URL: c.GetMCPServerURL().String()})
- }()
+ // this change is to avoid breaking current stable VS Code. URL value is not used.
+ di.Notifier().Send(types.McpServerURLParams{URL: "http://127.0.0.1:7695"})
return nil, nil
})
}
diff --git a/application/server/server_test.go b/application/server/server_test.go
index 151ded2d0..140154df0 100644
--- a/application/server/server_test.go
+++ b/application/server/server_test.go
@@ -18,7 +18,6 @@ package server
import (
"context"
- "net/url"
"os"
"os/exec"
"path/filepath"
@@ -238,44 +237,6 @@ func Test_initialized_shouldCheckRequiredProtocolVersion(t *testing.T) {
"did not receive callback because of wrong protocol version")
}
-func Test_initialized_shouldSendMcpServerAddress(t *testing.T) {
- c := testutil.UnitTest(t)
- loc, jsonRpcRecorder := setupServer(t, c)
-
- params := types.InitializeParams{
- InitializationOptions: types.Settings{RequiredProtocolVersion: config.LsProtocolVersion},
- }
-
- rsp, err := loc.Client.Call(ctx, "initialize", params)
- require.NoError(t, err)
- var result types.InitializeResult
- err = rsp.UnmarshalResult(&result)
- require.NoError(t, err)
-
- testURL, err := url.Parse("http://localhost:1234")
- require.NoError(t, err)
-
- c.SetMCPServerURL(testURL)
-
- _, err = loc.Client.Call(ctx, "initialized", params)
- require.NoError(t, err)
- require.Eventuallyf(t, func() bool {
- n := jsonRpcRecorder.FindNotificationsByMethod("$/snyk.mcpServerURL")
- if n == nil {
- return false
- }
- if len(n) > 1 {
- t.Fatal("can't succeed anymore, too many notifications ", n)
- }
-
- var param types.McpServerURLParams
- err = n[0].UnmarshalParams(¶m)
- require.NoError(t, err)
- return param.URL == testURL.String()
- }, time.Minute*5, time.Millisecond,
- "did not receive mcp server url")
-}
-
func Test_initialize_shouldSupportAllCommands(t *testing.T) {
c := testutil.UnitTest(t)
loc, _ := setupServer(t, c)
diff --git a/domain/ide/command/execute_mcp_test.go b/domain/ide/command/execute_mcp_test.go
deleted file mode 100644
index ac173ed1c..000000000
--- a/domain/ide/command/execute_mcp_test.go
+++ /dev/null
@@ -1,105 +0,0 @@
-//go:build !race
-// +build !race
-
-/*
- * © 2024 Snyk Limited
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package command
-
-import (
- "context"
- "testing"
- "time"
-
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
-
- "github.com/snyk/snyk-ls/domain/ide/hover"
- "github.com/snyk/snyk-ls/domain/ide/workspace"
- "github.com/snyk/snyk-ls/domain/scanstates"
- "github.com/snyk/snyk-ls/domain/snyk/persistence"
- "github.com/snyk/snyk-ls/domain/snyk/scanner"
- "github.com/snyk/snyk-ls/internal/mcp"
- noti "github.com/snyk/snyk-ls/internal/notification"
- "github.com/snyk/snyk-ls/internal/observability/performance"
- "github.com/snyk/snyk-ls/internal/testutil"
- "github.com/snyk/snyk-ls/internal/types"
-)
-
-func Test_executeMcpCallCommand(t *testing.T) {
- c := testutil.UnitTest(t)
- c.SetAutomaticScanning(false)
-
- // start mcp server
- sc := scanner.NewTestScanner()
- scanNotifier := scanner.NewMockScanNotifier()
- hoverService := hover.NewFakeHoverService()
- notifier := noti.NewMockNotifier()
- emitter := scanstates.NewSummaryEmitter(c, notifier)
- scanStateAggregator := scanstates.NewScanStateAggregator(c, emitter)
- scanPersister := persistence.NewNopScanPersister()
- w := workspace.New(c, performance.NewInstrumentor(), sc, hoverService, scanNotifier, notifier, scanPersister, scanStateAggregator)
-
- f := workspace.NewFolder(
- c,
- "testPath",
- "test",
- sc,
- hoverService,
- scanNotifier,
- notifier,
- scanPersister,
- scanStateAggregator,
- )
-
- w.AddFolder(f)
- c.SetWorkspace(w)
-
- mcpBinding := mcp.NewMcpLLMBinding(c, mcp.WithLogger(c.Logger()), mcp.WithScanner(sc))
- go func() {
- _ = mcpBinding.Start()
- }()
-
- t.Cleanup(func() {
- timeout, cancelFunc := context.WithTimeout(context.Background(), time.Second)
- mcpBinding.Shutdown(timeout)
- defer cancelFunc()
- })
-
- // wait for mcp server to start
- assert.Eventually(t, func() bool {
- return mcpBinding.Started()
- }, time.Minute, time.Millisecond)
-
- // create command
- command := executeMcpCallCommand{
- command: types.CommandData{
- Title: "Execute Snyk Scan",
- CommandId: types.ExecuteMCPToolCall,
- Arguments: []any{mcp.SnykScanWorkspaceScan},
- },
- notifier: noti.NewMockNotifier(),
- logger: c.Logger(),
- baseURL: c.GetMCPServerURL().String(),
- }
-
- // execute command
- _, err := command.Execute(context.Background())
- require.NoError(t, err)
- require.Eventuallyf(t, func() bool {
- return sc.Calls() > 0
- }, time.Minute, time.Millisecond, "should have called the scanner")
-}
diff --git a/domain/ide/command/workspace_scan.go b/domain/ide/command/workspace_scan.go
index 7bb82d1e4..b372ba8c4 100644
--- a/domain/ide/command/workspace_scan.go
+++ b/domain/ide/command/workspace_scan.go
@@ -20,6 +20,7 @@ import (
"context"
"github.com/snyk/snyk-ls/application/config"
+ context2 "github.com/snyk/snyk-ls/internal/context"
"github.com/snyk/snyk-ls/internal/types"
)
@@ -36,7 +37,27 @@ func (cmd *workspaceScanCommand) Command() types.CommandData {
func (cmd *workspaceScanCommand) Execute(ctx context.Context) (any, error) {
w := cmd.c.Workspace()
w.Clear()
- w.ScanWorkspace(ctx)
- HandleUntrustedFolders(ctx, cmd.c, cmd.srv)
+ args := cmd.command.Arguments
+ enrichedCtx := cmd.enrichContextWithScanSource(ctx, args)
+ w.ScanWorkspace(enrichedCtx)
+ HandleUntrustedFolders(enrichedCtx, cmd.c, cmd.srv)
return nil, nil
}
+
+func (cmd *workspaceScanCommand) enrichContextWithScanSource(ctx context.Context, args []any) context.Context {
+ if len(args) == 0 {
+ return ctx
+ }
+
+ sc, ok := args[0].(string)
+ if !ok {
+ return ctx
+ }
+
+ if sc != context2.IDE.String() && sc != context2.LLM.String() {
+ return ctx
+ }
+
+ scanSource := context2.ScanSource(sc)
+ return context2.NewContextWithScanSource(ctx, scanSource)
+}
diff --git a/go.mod b/go.mod
index b19e48ae6..29e0eff67 100644
--- a/go.mod
+++ b/go.mod
@@ -16,7 +16,7 @@ require (
github.com/google/uuid v1.6.0
github.com/gosimple/hashdir v1.0.2
github.com/hexops/gotextdiff v1.0.3
- github.com/mark3labs/mcp-go v0.8.4
+ github.com/mark3labs/mcp-go v0.18.0
github.com/otiai10/copy v1.14.1
github.com/pact-foundation/pact-go v1.10.0
github.com/pingcap/errors v0.11.4
@@ -27,8 +27,8 @@ require (
github.com/rs/zerolog v1.33.0
github.com/sabhiram/go-gitignore v0.0.0-20210923224102-525f6e181f06
github.com/shirou/gopsutil v3.21.11+incompatible
- github.com/snyk/code-client-go v1.16.2
- github.com/snyk/go-application-framework v0.0.0-20250307155453-ce7aaf72fe7d
+ github.com/snyk/code-client-go v1.20.1
+ github.com/snyk/go-application-framework v0.0.0-20250325133828-3ffd1aa4f76f
github.com/sourcegraph/go-lsp v0.0.0-20240223163137-f80c5dd31dfd
github.com/spf13/pflag v1.0.6
github.com/stretchr/testify v1.10.0
@@ -122,6 +122,7 @@ require (
github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f // indirect
github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect
github.com/xeipuuv/gojsonschema v1.2.0 // indirect
+ github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
github.com/yusufpapurcu/wmi v1.2.4 // indirect
golang.org/x/crypto v0.35.0 // indirect
golang.org/x/sys v0.30.0 // indirect
diff --git a/go.sum b/go.sum
index 593adecd6..7369fb6e7 100644
--- a/go.sum
+++ b/go.sum
@@ -216,8 +216,8 @@ github.com/magiconair/properties v1.8.6 h1:5ibWZ6iY0NctNGWo87LalDlEZ6R41TqbbDamh
github.com/magiconair/properties v1.8.6/go.mod h1:y3VJvCyxH9uVvJTWEGAELF3aiYNyPKd5NZ3oSwXrF60=
github.com/mailru/easyjson v0.9.0 h1:PrnmzHw7262yW8sTBwxi1PdJA3Iw/EKBa8psRf7d9a4=
github.com/mailru/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU=
-github.com/mark3labs/mcp-go v0.8.4 h1:/VxjJ0+4oN2eYLuAgVzixrYNfrmwJnV38EfPIX3VbPE=
-github.com/mark3labs/mcp-go v0.8.4/go.mod h1:cjMlBU0cv/cj9kjlgmRhoJ5JREdS7YX83xeIG9Ko/jE=
+github.com/mark3labs/mcp-go v0.18.0 h1:YuhgIVjNlTG2ZOwmrkORWyPTp0dz1opPEqvsPtySXao=
+github.com/mark3labs/mcp-go v0.18.0/go.mod h1:KmJndYv7GIgcPVwEKJjNcbhVQ+hJGJhrCCB/9xITzpE=
github.com/maruel/natural v1.1.1 h1:Hja7XhhmvEFhcByqDoHz9QZbkWey+COd9xWfCfn1ioo=
github.com/maruel/natural v1.1.1/go.mod h1:v+Rfd79xlw1AgVBjbO0BEQmptqb5HvL/k9GRHB7ZKEg=
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
@@ -318,12 +318,12 @@ github.com/shirou/gopsutil v3.21.11+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMT
github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0=
github.com/skeema/knownhosts v1.3.1 h1:X2osQ+RAjK76shCbvhHHHVl3ZlgDm8apHEHFqRjnBY8=
github.com/skeema/knownhosts v1.3.1/go.mod h1:r7KTdC8l4uxWRyK2TpQZ/1o5HaSzh06ePQNxPwTcfiY=
-github.com/snyk/code-client-go v1.16.2 h1:07vMMpl+qUWYz8IJf+13CTL3lRSjKN1TdBk9UDhqxvI=
-github.com/snyk/code-client-go v1.16.2/go.mod h1:WH6lNkJc785hfXmwhixxWHix3O6z+1zwz40oK8vl/zg=
+github.com/snyk/code-client-go v1.20.1 h1:38nEGzrQIh/aVLjR99jiTUQM0sL9SQAvhMfZGmd9G0w=
+github.com/snyk/code-client-go v1.20.1/go.mod h1:WH6lNkJc785hfXmwhixxWHix3O6z+1zwz40oK8vl/zg=
github.com/snyk/error-catalog-golang-public v0.0.0-20250218074309-307ad7b38a60 h1:iB6z2BhBpfN9p0/dEZfwWvs7fpdZk3loooAih8yspS8=
github.com/snyk/error-catalog-golang-public v0.0.0-20250218074309-307ad7b38a60/go.mod h1:Ytttq7Pw4vOCu9NtRQaOeDU2dhBYUyNBe6kX4+nIIQ4=
-github.com/snyk/go-application-framework v0.0.0-20250307155453-ce7aaf72fe7d h1:eC0V150YUGvO3mEwXBMELQXzWGHFngxD1Y5fAtzWQC0=
-github.com/snyk/go-application-framework v0.0.0-20250307155453-ce7aaf72fe7d/go.mod h1:oWN7a1ud3u5y8HxW+Qdroy+ofEEA8s0MwVAtb8qq/v4=
+github.com/snyk/go-application-framework v0.0.0-20250325133828-3ffd1aa4f76f h1:1EPrRhLQ5Bo0SmIqoAU38Et1Bv2klCbyfgLmVJfUyvM=
+github.com/snyk/go-application-framework v0.0.0-20250325133828-3ffd1aa4f76f/go.mod h1:A7oFVjMjNukzsMeiIWXEXjCrAf2ARvoK4aQOm9e3E/Y=
github.com/snyk/go-httpauth v0.0.0-20231117135515-eb445fea7530 h1:s9PHNkL6ueYRiAKNfd8OVxlUOqU3qY0VDbgCD1f6WQY=
github.com/snyk/go-httpauth v0.0.0-20231117135515-eb445fea7530/go.mod h1:88KbbvGYlmLgee4OcQ19yr0bNpXpOr2kciOthaSzCAg=
github.com/sourcegraph/go-lsp v0.0.0-20240223163137-f80c5dd31dfd h1:Dq5WSzWsP1TbVi10zPWBI5LKEBDg4Y1OhWEph1wr5WQ=
@@ -395,6 +395,8 @@ github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 h1:EzJWgHo
github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415/go.mod h1:GwrjFmJcFw6At/Gs6z4yjiIwzuJ1/+UwLxMQDVQXShQ=
github.com/xeipuuv/gojsonschema v1.2.0 h1:LhYJRs+L4fBtjZUfuSZIKGeVu0QRy8e5Xi7D17UxZ74=
github.com/xeipuuv/gojsonschema v1.2.0/go.mod h1:anYRn/JVcOK2ZgGU+IjEV4nwlhoK5sQluxsYJ78Id3Y=
+github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
+github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
diff --git a/infrastructure/code/ai_fix_handler.go b/infrastructure/code/ai_fix_handler.go
index dfa2b0b15..ca297e79d 100644
--- a/infrastructure/code/ai_fix_handler.go
+++ b/infrastructure/code/ai_fix_handler.go
@@ -94,12 +94,11 @@ func (fixHandler *AiFixHandler) EnrichWithExplain(ctx context.Context, c *config
deepCodeLLMBinding := llm.NewDeepcodeLLMBinding(
llm.WithLogger(c.Logger()),
llm.WithOutputFormat(llm.HTML),
- llm.WithEndpoint(getExplainEndpoint(c)),
llm.WithHTTPClient(func() codeClientHTTP.HTTPClient {
return c.Engine().GetNetworkAccess().GetHttpClient()
}),
)
- explanations, err := deepCodeLLMBinding.ExplainWithOptions(contextWithCancel, llm.ExplainOptions{RuleKey: issue.GetID(), Diffs: diffs})
+ explanations, err := deepCodeLLMBinding.ExplainWithOptions(contextWithCancel, llm.ExplainOptions{RuleKey: issue.GetID(), Diffs: diffs, Endpoint: getExplainEndpoint(c)})
if err != nil {
logger.Error().Err(err).Msgf("Failed to explain with explain for issue %s", issue.GetID())
return
diff --git a/internal/mcp/llm_binding.go b/internal/mcp/llm_binding.go
index 193f4589e..7f9c979fe 100644
--- a/internal/mcp/llm_binding.go
+++ b/internal/mcp/llm_binding.go
@@ -21,35 +21,38 @@ import (
"fmt"
"net/http"
"net/url"
+ "os"
+ "strings"
"sync"
"time"
"github.com/mark3labs/mcp-go/server"
"github.com/pkg/errors"
"github.com/rs/zerolog"
+ "github.com/snyk/go-application-framework/pkg/configuration"
+ "github.com/snyk/go-application-framework/pkg/workflow"
+)
- "github.com/snyk/snyk-ls/application/config"
- "github.com/snyk/snyk-ls/internal/types"
+const (
+ SseTransportType string = "sse"
+ StdioTransportType string = "stdio"
)
// McpLLMBinding is an implementation of a mcp server that allows interaction between
// a given SnykLLMBinding and a CommandService.
type McpLLMBinding struct {
- c *config.Config
- scanner types.Scanner
- logger *zerolog.Logger
- mcpServer *server.MCPServer
- sseServer *server.SSEServer
- baseURL *url.URL
- forwardingResultProcessor types.ScanResultProcessor
- mutex sync.Mutex
- started bool
+ logger *zerolog.Logger
+ mcpServer *server.MCPServer
+ sseServer *server.SSEServer
+ baseURL *url.URL
+ mutex sync.RWMutex
+ started bool
+ cliPath string
}
-func NewMcpLLMBinding(c *config.Config, opts ...McpOption) *McpLLMBinding {
+func NewMcpLLMBinding(opts ...Option) *McpLLMBinding {
logger := zerolog.Nop()
mcpServerImpl := &McpLLMBinding{
- c: c,
logger: &logger,
}
@@ -70,30 +73,59 @@ func defaultURL() *url.URL {
}
// Start starts the MCP server. It blocks until the server is stopped via Shutdown.
-func (m *McpLLMBinding) Start() error {
- // protect critical assignments with mutex
- m.mutex.Lock()
+func (m *McpLLMBinding) Start(invocationContext workflow.InvocationContext) error {
+ runTimeInfo := invocationContext.GetRuntimeInfo()
+ version := ""
+ if runTimeInfo != nil {
+ version = runTimeInfo.GetVersion()
+ }
m.mcpServer = server.NewMCPServer(
"Snyk MCP Server",
- config.Version,
+ version,
server.WithLogging(),
server.WithResourceCapabilities(true, true),
server.WithPromptCapabilities(true),
)
- err := m.addSnykScanTool()
+ err := m.addSnykTools(invocationContext)
if err != nil {
- m.mutex.Unlock()
return err
}
+ transportType := invocationContext.GetConfiguration().GetString("transport")
+ if transportType == StdioTransportType {
+ return m.HandleStdioServer()
+ } else if transportType == SseTransportType {
+ return m.HandleSseServer()
+ } else {
+ return fmt.Errorf("invalid transport type: %s", transportType)
+ }
+}
+
+func (m *McpLLMBinding) HandleStdioServer() error {
+ m.mutex.Lock()
+ m.started = true
+ m.mutex.Unlock()
+
+ err := server.ServeStdio(m.mcpServer)
+
+ if err != nil {
+ m.logger.Error().Err(err).Msg("Error starting MCP Stdio server")
+ return err
+ }
+
+ return nil
+}
+
+func (m *McpLLMBinding) HandleSseServer() error {
// listen on default url/port if none was configured
if m.baseURL == nil {
m.baseURL = defaultURL()
}
- m.sseServer = server.NewSSEServer(m.mcpServer, m.baseURL.String())
- m.mutex.Unlock()
+ m.sseServer = server.NewSSEServer(m.mcpServer, server.WithBaseURL(m.baseURL.String()))
+
+ _, _ = fmt.Fprintf(os.Stderr, "Starting with base URL %s\n", m.baseURL.String())
m.logger.Info().Str("baseURL", m.baseURL.String()).Msg("starting")
go func() {
@@ -106,10 +138,16 @@ func (m *McpLLMBinding) Start() error {
m.mutex.Lock()
m.logger.Info().Str("baseURL", m.baseURL.String()).Msg("started")
m.started = true
- m.c.SetMCPServerURL(m.baseURL)
m.mutex.Unlock()
}()
- err = m.sseServer.Start(m.baseURL.Host)
+
+ srv := &http.Server{
+ Addr: m.baseURL.Host,
+ Handler: middleware(m.sseServer),
+ }
+
+ err := srv.ListenAndServe()
+
if err != nil {
// expect http.ErrServerClosed when shutting down
if !errors.Is(err, http.ErrServerClosed) {
@@ -120,18 +158,67 @@ func (m *McpLLMBinding) Start() error {
return nil
}
+var allowedHostnames = map[string]bool{
+ "localhost": true,
+ "127.0.0.1": true,
+ "::1": true,
+}
+
+func middleware(sseServer *server.SSEServer) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ originHeader := r.Header.Get("Origin")
+ isValidOrigin := originHeader == ""
+
+ if originHeader != "" {
+ parsedOrigin, err := url.Parse(originHeader)
+ if err == nil {
+ requestHost := parsedOrigin.Hostname()
+ if _, allowed := allowedHostnames[requestHost]; allowed {
+ isValidOrigin = true
+ }
+ }
+ }
+
+ if isValidOrigin {
+ sseServer.ServeHTTP(w, r)
+ } else {
+ http.Error(w, "Forbidden: Access restricted to localhost origins", http.StatusForbidden)
+ }
+ })
+}
+
func (m *McpLLMBinding) Shutdown(ctx context.Context) {
m.mutex.Lock()
defer m.mutex.Unlock()
- err := m.sseServer.Shutdown(ctx)
- if err != nil {
- m.logger.Error().Err(err).Msg("Error shutting down MCP SSE server")
+ if m.sseServer != nil {
+ err := m.sseServer.Shutdown(ctx)
+ if err != nil {
+ m.logger.Error().Err(err).Msg("Error shutting down MCP SSE server")
+ }
}
}
func (m *McpLLMBinding) Started() bool {
- m.mutex.Lock()
- defer m.mutex.Unlock()
+ m.mutex.RLock()
+ defer m.mutex.RUnlock()
+
return m.started
}
+
+func (m *McpLLMBinding) expandedEnv(version string) []string {
+ environ := os.Environ()
+ var expandedEnv = []string{}
+ for _, v := range environ {
+ if strings.HasPrefix(strings.ToLower(v), strings.ToLower(configuration.INTEGRATION_NAME)) {
+ continue
+ }
+ if strings.HasPrefix(strings.ToLower(v), strings.ToLower(configuration.INTEGRATION_VERSION)) {
+ continue
+ }
+ expandedEnv = append(expandedEnv, v)
+ }
+ expandedEnv = append(expandedEnv, configuration.INTEGRATION_NAME+"=MCP")
+ expandedEnv = append(expandedEnv, fmt.Sprintf("%s=%s", configuration.INTEGRATION_VERSION, version))
+ return expandedEnv
+}
diff --git a/internal/mcp/llm_binding_smoke_test.go b/internal/mcp/llm_binding_smoke_test.go
deleted file mode 100644
index bea283253..000000000
--- a/internal/mcp/llm_binding_smoke_test.go
+++ /dev/null
@@ -1,114 +0,0 @@
-//go:build !race
-// +build !race
-
-/*
- * © 2025 Snyk Limited
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package mcp
-
-import (
- "context"
- "testing"
- "time"
-
- "github.com/mark3labs/mcp-go/client"
- "github.com/mark3labs/mcp-go/mcp"
- "github.com/stretchr/testify/assert"
-
- "github.com/snyk/snyk-ls/domain/ide/hover"
- "github.com/snyk/snyk-ls/domain/ide/workspace"
- "github.com/snyk/snyk-ls/domain/scanstates"
- "github.com/snyk/snyk-ls/domain/snyk/persistence"
- "github.com/snyk/snyk-ls/domain/snyk/scanner"
- noti "github.com/snyk/snyk-ls/internal/notification"
- "github.com/snyk/snyk-ls/internal/observability/performance"
- "github.com/snyk/snyk-ls/internal/testutil"
-)
-
-// Test_WorkspaceScan does not run in race mode, due to races in the underlying framework
-func Test_WorkspaceScan(t *testing.T) {
- c := testutil.SmokeTest(t, false)
- sc := scanner.NewTestScanner()
- scanNotifier := scanner.NewMockScanNotifier()
- hoverService := hover.NewFakeHoverService()
- notifier := noti.NewMockNotifier()
- emitter := scanstates.NewSummaryEmitter(c, notifier)
- scanStateAggregator := scanstates.NewScanStateAggregator(c, emitter)
- scanPersister := persistence.NewNopScanPersister()
- w := workspace.New(c, performance.NewInstrumentor(), sc, hoverService, scanNotifier, notifier, scanPersister, scanStateAggregator)
-
- f := workspace.NewFolder(
- c,
- "testPath",
- "test",
- sc,
- hoverService,
- scanNotifier,
- notifier,
- scanPersister,
- scanStateAggregator,
- )
-
- w.AddFolder(f)
-
- c.SetWorkspace(w)
- server := NewMcpLLMBinding(c, WithScanner(sc), WithLogger(c.Logger()))
-
- go func() {
- err := server.Start()
- assert.NoError(t, err)
- }()
-
- assert.Eventually(t, func() bool {
- server.mutex.Lock()
- defer server.mutex.Unlock()
- portInUse := isPortInUse(server.baseURL)
- return portInUse && server.baseURL == c.GetMCPServerURL()
- }, time.Minute, time.Second)
-
- clientEndpoint := server.baseURL.String() + "/sse"
-
- mcpClient, err := client.NewSSEMCPClient(clientEndpoint)
- assert.NoError(t, err)
- defer mcpClient.Close()
-
- // start
- err = mcpClient.Start(context.Background())
- assert.NoError(t, err)
-
- // initialize
- initRequest := mcp.InitializeRequest{}
- initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
- initRequest.Params.ClientInfo = mcp.Implementation{
- Name: "example-client",
- Version: "1.0.0",
- }
-
- _, err = mcpClient.Initialize(context.Background(), initRequest)
- assert.NoError(t, err)
-
- toolsRequest := mcp.ListToolsRequest{}
- tools, err := mcpClient.ListTools(context.Background(), toolsRequest)
- assert.NoError(t, err)
- assert.Len(t, tools.Tools, 1)
-
- scanRequest := mcp.CallToolRequest{}
- scanRequest.Params.Name = SnykScanWorkspaceScan
-
- result, err := mcpClient.CallTool(context.Background(), scanRequest)
- assert.NoError(t, err)
- assert.NotNil(t, result)
-}
diff --git a/internal/mcp/llm_binding_test.go b/internal/mcp/llm_binding_test.go
index 2d0dce314..2d4e415dc 100644
--- a/internal/mcp/llm_binding_test.go
+++ b/internal/mcp/llm_binding_test.go
@@ -17,29 +17,25 @@ package mcp
import (
"net/url"
+ "os"
+ "strings"
"testing"
+ "github.com/snyk/go-application-framework/pkg/configuration"
"github.com/stretchr/testify/assert"
-
- "github.com/snyk/snyk-ls/domain/snyk/scanner"
- "github.com/snyk/snyk-ls/internal/testutil"
)
func TestNewMcpServer(t *testing.T) {
- c := testutil.UnitTest(t)
- mcpServer := NewMcpLLMBinding(c)
+ mcpServer := NewMcpLLMBinding()
assert.NotNil(t, mcpServer)
assert.NotNil(t, mcpServer.logger)
}
func TestNewMcpServerWithOptions(t *testing.T) {
- c := testutil.UnitTest(t)
baseURL, _ := url.Parse("http://test:8080")
- s := scanner.NewTestScanner()
- mcpServer := NewMcpLLMBinding(c, WithScanner(s), WithBaseURL(baseURL))
+ mcpServer := NewMcpLLMBinding(WithBaseURL(baseURL))
- assert.Equal(t, s, mcpServer.scanner)
assert.Equal(t, baseURL, mcpServer.baseURL)
}
@@ -49,3 +45,24 @@ func TestDefaultURL(t *testing.T) {
assert.Equal(t, "http", u.Scheme)
assert.Contains(t, u.Host, DefaultHost)
}
+
+func TestExpandedEnv(t *testing.T) {
+ t.Setenv(configuration.INTEGRATION_NAME, "abc")
+ t.Setenv(configuration.INTEGRATION_VERSION, "abc")
+ binding := NewMcpLLMBinding()
+
+ env := binding.expandedEnv("1.x.1")
+
+ for _, s := range os.Environ() {
+ if strings.HasPrefix(s, configuration.INTEGRATION_NAME) {
+ continue
+ }
+ if strings.HasPrefix(s, configuration.INTEGRATION_VERSION) {
+ continue
+ }
+ assert.Contains(t, env, s)
+ }
+
+ assert.Contains(t, env, configuration.INTEGRATION_NAME+"=MCP")
+ assert.Contains(t, env, configuration.INTEGRATION_VERSION+"=1.x.1")
+}
diff --git a/internal/mcp/options.go b/internal/mcp/options.go
index c4313c70b..b0e0e232b 100644
--- a/internal/mcp/options.go
+++ b/internal/mcp/options.go
@@ -20,33 +20,25 @@ import (
"net/url"
"github.com/rs/zerolog"
-
- "github.com/snyk/snyk-ls/internal/types"
)
-type McpOption func(server *McpLLMBinding)
-
-func WithScanner(scanner types.Scanner) McpOption {
- return func(server *McpLLMBinding) {
- server.scanner = scanner
- }
-}
+type Option func(server *McpLLMBinding)
-func WithLogger(logger *zerolog.Logger) McpOption {
+func WithLogger(logger *zerolog.Logger) Option {
return func(server *McpLLMBinding) {
l := logger.With().Str("component", "mcp").Logger()
server.logger = &l
}
}
-func WithBaseURL(baseURL *url.URL) func(server *McpLLMBinding) {
+func WithCliPath(cliPath string) Option {
return func(server *McpLLMBinding) {
- server.baseURL = baseURL
+ server.cliPath = cliPath
}
}
-func WithScanResultProcessor(proc types.ScanResultProcessor) func(server *McpLLMBinding) {
+func WithBaseURL(baseURL *url.URL) func(server *McpLLMBinding) {
return func(server *McpLLMBinding) {
- server.forwardingResultProcessor = proc
+ server.baseURL = baseURL
}
}
diff --git a/internal/mcp/scan_tool.go b/internal/mcp/scan_tool.go
index d0efd7dc0..f7a8939ce 100644
--- a/internal/mcp/scan_tool.go
+++ b/internal/mcp/scan_tool.go
@@ -1,5 +1,5 @@
/*
- * © 2025 Snyk Limited
+ * 2025 Snyk Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -18,68 +18,141 @@ package mcp
import (
"context"
+ _ "embed"
+ "encoding/json"
+ "fmt"
+ "os/exec"
"github.com/mark3labs/mcp-go/mcp"
+ "github.com/snyk/go-application-framework/pkg/workflow"
+)
- ctx2 "github.com/snyk/snyk-ls/internal/context"
- "github.com/snyk/snyk-ls/internal/types"
+// Tool name constants to maintain backward compatibility
+const (
+ SnykScaTest = "snyk_sca_test"
+ SnykCodeTest = "snyk_code_test"
+ SnykVersion = "snyk_version"
+ SnykAuth = "snyk_auth"
+ SnykAuthStatus = "snyk_auth_status"
+ SnykLogout = "snyk_logout"
)
-const SnykScanWorkspaceScan = types.WorkspaceScanCommand
+type SnykMcpToolsDefinition struct {
+ Name string `json:"name"`
+ Description string `json:"description"`
+ Command []string `json:"command"`
+ StandardParams []string `json:"standardParams"`
+ Params []SnykMcpToolParameter `json:"params"`
+}
+
+type SnykMcpToolParameter struct {
+ Name string `json:"name"`
+ Type string `json:"type"`
+ IsRequired bool `json:"isRequired"`
+ Description string `json:"description"`
+}
-func (m *McpLLMBinding) addSnykScanTool() error {
- tool := mcp.NewTool(SnykScanWorkspaceScan,
- mcp.WithDescription("Perform Snyk scans on current workspace"),
- )
+//go:embed snyk_tools.json
+var snykToolsJson string
- m.mcpServer.AddTool(tool, m.snykWorkSpaceScanHandler())
+type SnykMcpTools struct {
+ Tools []SnykMcpToolsDefinition `json:"tools"`
+}
+
+func loadMcpToolsFromJson() (*SnykMcpTools, error) {
+ var config SnykMcpTools
+ if err := json.Unmarshal([]byte(snykToolsJson), &config); err != nil {
+ return nil, fmt.Errorf("failed to parse config file: %w", err)
+ }
+
+ return &config, nil
+}
+
+func (m *McpLLMBinding) addSnykTools(invocationCtx workflow.InvocationContext) error {
+ config, err := loadMcpToolsFromJson()
+ if err != nil || config == nil {
+ m.logger.Err(err).Msg("Failed to load Snyk tools configuration")
+ return err
+ }
+
+ for _, toolDef := range config.Tools {
+ tool := createToolFromDefinition(&toolDef)
+ switch toolDef.Name {
+ case SnykLogout:
+ m.mcpServer.AddTool(tool, m.snykLogoutHandler(invocationCtx, toolDef))
+ default:
+ m.mcpServer.AddTool(tool, m.defaultHandler(invocationCtx, toolDef))
+ }
+ }
return nil
}
-func (m *McpLLMBinding) snykWorkSpaceScanHandler() func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
+// runSnyk runs a Snyk command and returns the result
+func (m *McpLLMBinding) runSnyk(ctx context.Context, invocationCtx workflow.InvocationContext, workingDir string, cmd []string) (string, error) {
+ command := exec.CommandContext(ctx, cmd[0], cmd[1:]...)
+
+ if workingDir != "" {
+ command.Dir = workingDir
+ }
+ runtimeInfo := invocationCtx.GetRuntimeInfo()
+ if runtimeInfo != nil {
+ command.Env = m.expandedEnv(runtimeInfo.GetVersion())
+ } else {
+ command.Env = m.expandedEnv("unknown")
+ }
+ command.Stderr = invocationCtx.GetEnhancedLogger()
+ res, err := command.Output()
+
+ resAsString := string(res)
+ if err != nil {
+ m.logger.Err(err).Msg("Failed to execute command")
+ }
+ return resAsString, nil
+}
+
+// defaultHandler creates a generic handler for Snyk commands that applies standard parameters
+func (m *McpLLMBinding) defaultHandler(invocationCtx workflow.InvocationContext, toolDef SnykMcpToolsDefinition) func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
- w := m.c.Workspace()
- trusted, _ := w.GetFolderTrust()
+ params, workingDir := extractParamsFromRequestArgs(toolDef, request.Params.Arguments)
- callToolResult := &mcp.CallToolResult{
- Content: make([]interface{}, 0),
+ // Apply standard parameters from tool definition
+ // e.g. all_projects and json
+ for _, paramName := range toolDef.StandardParams {
+ cliParamName := convertToCliParam(paramName)
+ params[cliParamName] = true
}
- resultProcessor := func(ctx context.Context, data types.ScanData) {
- // add the scan results to the call tool response
- // in the future, this could be a rendered markdown/html template
- callToolResult.Content = append(callToolResult.Content, data)
- if data.Err != nil {
- callToolResult.IsError = true
- }
-
- // standard processing for the folder
- scanResultProcessor := folderScanResultProcessor(w, data.Path)
- if scanResultProcessor != nil {
- scanResultProcessor(ctx, data)
- }
-
- // forward to forwarding processor
- if m.forwardingResultProcessor != nil {
- m.forwardingResultProcessor(ctx, data)
- }
+ // Handle regular commands
+ if len(toolDef.Command) == 0 {
+ return nil, fmt.Errorf("empty command in tool definition for %s", toolDef.Name)
}
- enrichedContext := ctx2.NewContextWithScanSource(ctx, ctx2.LLM)
- for _, folder := range trusted {
- m.scanner.Scan(enrichedContext, folder.Path(), resultProcessor, folder.Path())
+ args := buildArgs(m.cliPath, toolDef.Command, params)
+
+ // Add working directory if specified
+ if workingDir != "" {
+ args = append(args, workingDir)
}
- return callToolResult, nil
+ // Run the command
+ output, err := m.runSnyk(ctx, invocationCtx, workingDir, args)
+ if err != nil {
+ return nil, err
+ }
+ return mcp.NewToolResultText(output), nil
}
}
-func folderScanResultProcessor(w types.Workspace, path types.FilePath) types.ScanResultProcessor {
- folder := w.GetFolderContaining(path)
- if folder == nil {
- return nil
+func (m *McpLLMBinding) snykLogoutHandler(invocationCtx workflow.InvocationContext, _ SnykMcpToolsDefinition) func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
+ return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
+ // Special handling for logout which needs multiple commands
+ params := []string{m.cliPath, "config", "unset", "INTERNAL_OAUTH_TOKEN_STORAGE"}
+ _, _ = m.runSnyk(ctx, invocationCtx, "", params)
+
+ params = []string{m.cliPath, "config", "unset", "token"}
+ _, _ = m.runSnyk(ctx, invocationCtx, "", params)
+
+ return mcp.NewToolResultText("Successfully logged out"), nil
}
- scanResultProcessor := folder.ScanResultProcessor()
- return scanResultProcessor
}
diff --git a/internal/mcp/scan_tool_test.go b/internal/mcp/scan_tool_test.go
new file mode 100644
index 000000000..255bdd0e3
--- /dev/null
+++ b/internal/mcp/scan_tool_test.go
@@ -0,0 +1,832 @@
+/*
+* 2025 Snyk Limited
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+ */
+
+package mcp
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "os"
+ "path/filepath"
+ "runtime"
+ "strings"
+ "testing"
+
+ "github.com/golang/mock/gomock"
+ "github.com/mark3labs/mcp-go/mcp"
+ "github.com/mark3labs/mcp-go/server"
+ "github.com/rs/zerolog"
+ "github.com/snyk/go-application-framework/pkg/configuration"
+ "github.com/snyk/go-application-framework/pkg/mocks"
+ "github.com/snyk/go-application-framework/pkg/runtimeinfo"
+ "github.com/snyk/go-application-framework/pkg/workflow"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+type testFixture struct {
+ t *testing.T
+ mockEngine *mocks.MockEngine
+ binding *McpLLMBinding
+ snykCliPath string
+ invocationContext *mocks.MockInvocationContext
+ tools *SnykMcpTools
+}
+
+func SetupEngineMock(t *testing.T) (*mocks.MockEngine, configuration.Configuration) {
+ t.Helper()
+ ctrl := gomock.NewController(t)
+ mockEngine := mocks.NewMockEngine(ctrl)
+ engineConfig := configuration.NewWithOpts(configuration.WithAutomaticEnv())
+ mockEngine.EXPECT().GetConfiguration().Return(engineConfig).AnyTimes()
+ return mockEngine, engineConfig
+}
+
+func setupTestFixture(t *testing.T) *testFixture {
+ t.Helper()
+ engine, engineConfig := SetupEngineMock(t)
+ logger := zerolog.New(io.Discard)
+
+ mockctl := gomock.NewController(t)
+ storage := mocks.NewMockStorage(mockctl)
+ engineConfig.SetStorage(storage)
+
+ invocationCtx := mocks.NewMockInvocationContext(mockctl)
+ invocationCtx.EXPECT().GetConfiguration().Return(engineConfig).AnyTimes()
+ invocationCtx.EXPECT().GetEnhancedLogger().Return(&logger).AnyTimes()
+ invocationCtx.EXPECT().GetRuntimeInfo().Return(runtimeinfo.New(runtimeinfo.WithName("hurz"), runtimeinfo.WithVersion("1000.8.3"))).AnyTimes()
+
+ // Snyk CLI mock
+ tempDir := t.TempDir()
+ snykCliPath := filepath.Join(tempDir, "snyk")
+ if runtime.GOOS == "windows" {
+ snykCliPath += ".bat"
+ }
+
+ // Create a default mock CLI that just echoes the command
+ defaultMockResponse := "{\"ok\": true}"
+ createMockSnykCli(t, snykCliPath, defaultMockResponse)
+
+ // Create the binding
+ binding := NewMcpLLMBinding(WithCliPath(snykCliPath), WithLogger(invocationCtx.GetEnhancedLogger()))
+ binding.mcpServer = server.NewMCPServer("Snyk", "1.1.1")
+ tools, err := loadMcpToolsFromJson()
+ assert.NoError(t, err)
+ return &testFixture{
+ t: t,
+ mockEngine: engine,
+ binding: binding,
+ snykCliPath: snykCliPath,
+ invocationContext: invocationCtx,
+ tools: tools,
+ }
+}
+
+func (f *testFixture) mockCliOutput(output string) {
+ createMockSnykCli(f.t, f.snykCliPath, output)
+}
+
+func getToolWithName(t *testing.T, tools *SnykMcpTools, toolName string) *SnykMcpToolsDefinition {
+ t.Helper()
+ for _, tool := range tools.Tools {
+ if tool.Name == toolName {
+ return &tool
+ }
+ }
+ return nil
+}
+
+func TestMcpSnykToolRegistration(t *testing.T) {
+ fixture := setupTestFixture(t)
+ err := fixture.binding.addSnykTools(fixture.invocationContext)
+ assert.NoError(t, err)
+}
+
+func TestSnykTestHandler(t *testing.T) {
+ // Setup
+ fixture := setupTestFixture(t)
+
+ // Configure mock CLI to return a specific JSON response
+ mockOutput := `{ok": false,"vulnerabilities": [{"id": "SNYK-JS-ACORN-559469","title": "Regular Expression Denial of Service (ReDoS)","severity":"high","packageName": "acorn"},{"id": "SNYK-JS-TUNNELAGENT-1572284","title": "Uninitialized Memory Exposure","severity": "medium","packageName": "tunnel-agent"}],"dependencyCount": 42,"packageManager": "npm"}`
+ fixture.mockCliOutput(mockOutput)
+ tool := getToolWithName(t, fixture.tools, SnykScaTest)
+ assert.NotNil(t, tool)
+ // Create the handler
+ handler := fixture.binding.defaultHandler(fixture.invocationContext, *tool)
+
+ tmpDir := t.TempDir()
+ // Define test cases
+ testCases := []struct {
+ name string
+ args map[string]interface{}
+ expectedParams []string
+ }{
+ {
+ name: "Basic SCA Test",
+ args: map[string]interface{}{
+ "path": tmpDir,
+ "all_projects": true,
+ "json": true,
+ },
+ expectedParams: []string{"--all-projects", "--json"},
+ },
+ {
+ name: "Test with Organization",
+ args: map[string]interface{}{
+ "path": tmpDir,
+ "all_projects": true,
+ "json": true,
+ "org": "my-snyk-org",
+ },
+ expectedParams: []string{"--all-projects", "--json", "--org=my-snyk-org"},
+ },
+ {
+ name: "Test with Severity Threshold",
+ args: map[string]interface{}{
+ "path": tmpDir,
+ "all_projects": false,
+ "json": true,
+ "severity_threshold": "high",
+ },
+ expectedParams: []string{"--json", "--severity-threshold=high"},
+ },
+ {
+ name: "Test with Multiple Options",
+ args: map[string]interface{}{
+ "path": tmpDir,
+ "all_projects": true,
+ "json": true,
+ "severity_threshold": "medium",
+ "dev": true,
+ "skip_unresolved": true,
+ "prune_repeated_subdependencies": true,
+ "fail_on": "upgradable",
+ },
+ expectedParams: []string{
+ "--all-projects", "--json", "--severity-threshold=medium",
+ "--dev", "--skip-unresolved", "--prune-repeated-subdependencies",
+ "--fail-on=upgradable",
+ },
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ requestObj := map[string]interface{}{
+ "params": map[string]interface{}{
+ "arguments": tc.args,
+ },
+ }
+ requestJSON, err := json.Marshal(requestObj)
+ assert.NoError(t, err, "Failed to marshal request to JSON")
+
+ // Parse the JSON string to CallToolRequest
+ var request mcp.CallToolRequest
+ err = json.Unmarshal(requestJSON, &request)
+ assert.NoError(t, err, "Failed to unmarshal JSON to CallToolRequest")
+
+ result, err := handler(context.Background(), request)
+
+ assert.NoError(t, err)
+ assert.NotNil(t, result)
+
+ textContent, ok := result.Content[0].(mcp.TextContent)
+ assert.True(t, ok)
+ content := strings.TrimSpace(textContent.Text)
+ assert.Contains(t, content, "ok")
+ assert.Contains(t, content, "vulnerabilities")
+ assert.Contains(t, content, "dependencyCount")
+ assert.Contains(t, content, "packageManager")
+ })
+ }
+}
+
+func TestSnykCodeTestHandler(t *testing.T) {
+ // Setup
+ fixture := setupTestFixture(t)
+
+ // Configure mock CLI
+ mockJsonResponse := `{"ok":false,"issues":[],"filesAnalyzed":10}`
+ fixture.mockCliOutput(mockJsonResponse)
+
+ // Get the tool definition
+ toolDef := getToolWithName(t, fixture.tools, SnykCodeTest)
+
+ // Create the handler
+ handler := fixture.binding.defaultHandler(fixture.invocationContext, *toolDef)
+ tmpDir := t.TempDir()
+ // Test cases with various combinations of arguments
+ testCases := []struct {
+ name string
+ args map[string]interface{}
+ }{
+ {
+ name: "Basic Test",
+ args: map[string]interface{}{
+ "path": tmpDir,
+ },
+ },
+ {
+ name: "Test with Custom File",
+ args: map[string]interface{}{
+ "path": tmpDir,
+ "file": "specific_file.js",
+ },
+ },
+ {
+ name: "Test with Severity Threshold",
+ args: map[string]interface{}{
+ "path": tmpDir,
+ "severity_threshold": "high",
+ },
+ },
+ {
+ name: "Test with Organization",
+ args: map[string]interface{}{
+ "path": tmpDir,
+ "org": "my-snyk-org",
+ },
+ },
+ {
+ name: "Test with All Options",
+ args: map[string]interface{}{
+ "path": tmpDir,
+ "file": "specific_file.js",
+ "severity_threshold": "high",
+ "org": "my-snyk-org",
+ },
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ requestObj := map[string]interface{}{
+ "params": map[string]interface{}{
+ "arguments": tc.args,
+ },
+ }
+ requestJSON, err := json.Marshal(requestObj)
+ assert.NoError(t, err, "Failed to marshal request to JSON")
+
+ var request mcp.CallToolRequest
+ err = json.Unmarshal(requestJSON, &request)
+ assert.NoError(t, err, "Failed to unmarshal JSON to CallToolRequest")
+
+ result, err := handler(context.Background(), request)
+
+ assert.NoError(t, err)
+ assert.NotNil(t, result)
+ textContent, ok := result.Content[0].(mcp.TextContent)
+ assert.True(t, ok)
+ content := strings.TrimSpace(textContent.Text)
+ assert.Contains(t, content, "ok")
+ assert.Contains(t, content, "issues")
+ assert.Contains(t, content, "filesAnalyzed")
+ })
+ }
+}
+
+func TestBasicSnykCommands(t *testing.T) {
+ // Setup
+ fixture := setupTestFixture(t)
+
+ testCases := []struct {
+ name string
+ handlerFunc func(invocationCtx workflow.InvocationContext, toolDefinition SnykMcpToolsDefinition) func(ctx context.Context, arguments mcp.CallToolRequest) (*mcp.CallToolResult, error)
+ mockResponse string
+ expectedCmd string
+ command []string
+ }{
+ {
+ name: "Version Command",
+ handlerFunc: fixture.binding.defaultHandler,
+ command: []string{"--version"},
+ mockResponse: `{"client":{"version":"1.1192.0"}}`,
+ expectedCmd: "version",
+ },
+ {
+ name: "Auth Status Command",
+ handlerFunc: fixture.binding.defaultHandler,
+ command: []string{"whoami", "--experimental"},
+ mockResponse: `{"authenticated":true,"username":"user@example.com"}`,
+ expectedCmd: "auth",
+ },
+ {
+ name: "Logout Command",
+ handlerFunc: fixture.binding.snykLogoutHandler,
+ command: []string{"--version"},
+ mockResponse: `Successfully logged out`,
+ expectedCmd: "logout",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ // Configure mock CLI
+ fixture.mockCliOutput(tc.mockResponse)
+
+ // Create the handler
+ handler := tc.handlerFunc(fixture.invocationContext, SnykMcpToolsDefinition{Command: tc.command})
+
+ // Create an empty request object as JSON string
+ requestObj := map[string]interface{}{
+ "params": map[string]interface{}{
+ "arguments": map[string]interface{}{},
+ },
+ }
+ requestJSON, err := json.Marshal(requestObj)
+ assert.NoError(t, err, "Failed to marshal request to JSON")
+
+ // Parse the JSON string to CallToolRequest
+ var request mcp.CallToolRequest
+ err = json.Unmarshal(requestJSON, &request)
+ assert.NoError(t, err, "Failed to unmarshal JSON to CallToolRequest")
+
+ // Call the handler
+ result, err := handler(context.Background(), request)
+
+ // Assertions
+ assert.NoError(t, err)
+ assert.NotNil(t, result)
+ textContent, ok := result.Content[0].(mcp.TextContent)
+ assert.True(t, ok)
+ assert.Equal(t, tc.mockResponse, strings.TrimSpace(textContent.Text))
+ })
+ }
+}
+
+func TestAuthHandler(t *testing.T) {
+ // Setup
+ fixture := setupTestFixture(t)
+
+ // Configure mock CLI
+ mockAuthResponse := "Authenticated Successfully"
+ fixture.mockCliOutput(mockAuthResponse)
+
+ // Create the handler
+ handler := fixture.binding.defaultHandler(fixture.invocationContext, SnykMcpToolsDefinition{Command: []string{"auth"}})
+
+ requestObj := map[string]interface{}{
+ "params": map[string]interface{}{
+ "arguments": map[string]interface{}{},
+ },
+ }
+ requestJSON, err := json.Marshal(requestObj)
+ assert.NoError(t, err, "Failed to marshal request to JSON")
+
+ var request mcp.CallToolRequest
+ err = json.Unmarshal(requestJSON, &request)
+ assert.NoError(t, err, "Failed to unmarshal JSON to CallToolRequest")
+
+ result, err := handler(context.Background(), request)
+
+ // Assertions
+ assert.NoError(t, err)
+ assert.NotNil(t, result)
+ textContent, ok := result.Content[0].(mcp.TextContent)
+ assert.True(t, ok)
+ assert.Equal(t, mockAuthResponse, strings.TrimSpace(textContent.Text))
+}
+
+func TestGetSnykToolsConfig(t *testing.T) {
+ config, err := loadMcpToolsFromJson()
+
+ assert.NoError(t, err)
+ assert.NotNil(t, config)
+ assert.NotEmpty(t, config.Tools)
+
+ toolNames := map[string]bool{
+ SnykScaTest: false,
+ SnykCodeTest: false,
+ SnykVersion: false,
+ SnykAuth: false,
+ SnykAuthStatus: false,
+ SnykLogout: false,
+ }
+
+ for _, tool := range config.Tools {
+ toolNames[tool.Name] = true
+ }
+
+ for name, found := range toolNames {
+ assert.True(t, found, "Tool %s not found in configuration", name)
+ }
+}
+
+func TestCreateToolFromDefinition(t *testing.T) {
+ testCases := []struct {
+ name string
+ toolDefinition SnykMcpToolsDefinition
+ expectedName string
+ }{
+ {
+ name: "Simple Tool",
+ toolDefinition: SnykMcpToolsDefinition{
+ Name: "test_tool",
+ Description: "Test tool description",
+ Command: []string{"test"},
+ Params: []SnykMcpToolParameter{},
+ },
+ expectedName: "test_tool",
+ },
+ {
+ name: "Tool with String Params",
+ toolDefinition: SnykMcpToolsDefinition{
+ Name: "string_param_tool",
+ Description: "Tool with string params",
+ Command: []string{"test"},
+ Params: []SnykMcpToolParameter{
+ {
+ Name: "param1",
+ Type: "string",
+ IsRequired: true,
+ Description: "Required string param",
+ },
+ {
+ Name: "param2",
+ Type: "string",
+ IsRequired: false,
+ Description: "Optional string param",
+ },
+ },
+ },
+ expectedName: "string_param_tool",
+ },
+ {
+ name: "Tool with Boolean Params",
+ toolDefinition: SnykMcpToolsDefinition{
+ Name: "bool_param_tool",
+ Description: "Tool with boolean params",
+ Command: []string{"test"},
+ Params: []SnykMcpToolParameter{
+ {
+ Name: "flag1",
+ Type: "boolean",
+ IsRequired: true,
+ Description: "Required boolean param",
+ },
+ {
+ Name: "flag2",
+ Type: "boolean",
+ IsRequired: false,
+ Description: "Optional boolean param",
+ },
+ },
+ },
+ expectedName: "bool_param_tool",
+ },
+ {
+ name: "Tool with Mixed Params",
+ toolDefinition: SnykMcpToolsDefinition{
+ Name: "mixed_param_tool",
+ Description: "Tool with mixed params",
+ Command: []string{"test"},
+ Params: []SnykMcpToolParameter{
+ {
+ Name: "str_param",
+ Type: "string",
+ IsRequired: true,
+ Description: "Required string param",
+ },
+ {
+ Name: "bool_flag",
+ Type: "boolean",
+ IsRequired: false,
+ Description: "Optional boolean param",
+ },
+ },
+ },
+ expectedName: "mixed_param_tool",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ tool := createToolFromDefinition(&tc.toolDefinition)
+
+ assert.NotNil(t, tool)
+ assert.Equal(t, tc.expectedName, tool.Name)
+ })
+ }
+}
+
+func TestExtractParamsFromRequest(t *testing.T) {
+ testCases := []struct {
+ name string
+ toolDef SnykMcpToolsDefinition
+ arguments map[string]interface{}
+ expectedParamCount int
+ expectedWorkingDir string
+ expectedParams map[string]interface{}
+ }{
+ {
+ name: "Empty Request",
+ toolDef: SnykMcpToolsDefinition{
+ Name: "test_tool",
+ Params: []SnykMcpToolParameter{},
+ },
+ arguments: map[string]interface{}{},
+ expectedParamCount: 0,
+ expectedWorkingDir: "",
+ expectedParams: map[string]interface{}{},
+ },
+ {
+ name: "String Parameters",
+ toolDef: SnykMcpToolsDefinition{
+ Name: "string_tool",
+ Params: []SnykMcpToolParameter{
+ {
+ Name: "org",
+ Type: "string",
+ },
+ {
+ Name: "path",
+ Type: "string",
+ },
+ },
+ },
+ arguments: map[string]interface{}{
+ "org": "my-org",
+ "path": "/test/path",
+ },
+ expectedParamCount: 2,
+ expectedWorkingDir: "/test/path",
+ expectedParams: map[string]interface{}{
+ "org": "my-org",
+ "path": "/test/path",
+ },
+ },
+ {
+ name: "Boolean Parameters",
+ toolDef: SnykMcpToolsDefinition{
+ Name: "bool_tool",
+ Params: []SnykMcpToolParameter{
+ {
+ Name: "json",
+ Type: "boolean",
+ },
+ {
+ Name: "all_projects",
+ Type: "boolean",
+ },
+ },
+ },
+ arguments: map[string]interface{}{
+ "json": true,
+ "all_projects": true,
+ },
+ expectedParamCount: 2,
+ expectedWorkingDir: "",
+ expectedParams: map[string]interface{}{
+ "json": true,
+ "all-projects": true,
+ },
+ },
+ {
+ name: "Mixed Parameters",
+ toolDef: SnykMcpToolsDefinition{
+ Name: "mixed_tool",
+ Params: []SnykMcpToolParameter{
+ {
+ Name: "path",
+ Type: "string",
+ },
+ {
+ Name: "json",
+ Type: "boolean",
+ },
+ {
+ Name: "severity_threshold",
+ Type: "string",
+ },
+ },
+ },
+ arguments: map[string]interface{}{
+ "path": "/test/path",
+ "json": true,
+ "severity_threshold": "high",
+ },
+ expectedParamCount: 3,
+ expectedWorkingDir: "/test/path",
+ expectedParams: map[string]interface{}{
+ "path": "/test/path",
+ "json": true,
+ "severity-threshold": "high",
+ },
+ },
+ {
+ name: "Empty String Parameters",
+ toolDef: SnykMcpToolsDefinition{
+ Name: "empty_string_tool",
+ Params: []SnykMcpToolParameter{
+ {
+ Name: "org",
+ Type: "string",
+ },
+ },
+ },
+ arguments: map[string]interface{}{
+ "org": "",
+ },
+ expectedParamCount: 0,
+ expectedWorkingDir: "",
+ expectedParams: map[string]interface{}{},
+ },
+ {
+ name: "False Boolean Parameters",
+ toolDef: SnykMcpToolsDefinition{
+ Name: "false_bool_tool",
+ Params: []SnykMcpToolParameter{
+ {
+ Name: "json",
+ Type: "boolean",
+ },
+ },
+ },
+ arguments: map[string]interface{}{
+ "json": false,
+ },
+ expectedParamCount: 0,
+ expectedWorkingDir: "",
+ expectedParams: map[string]interface{}{},
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ params, workingDir := extractParamsFromRequestArgs(tc.toolDef, tc.arguments)
+
+ assert.Equal(t, tc.expectedWorkingDir, workingDir)
+
+ assert.Equal(t, len(tc.expectedParams), len(params))
+
+ for key, expectedValue := range tc.expectedParams {
+ switch key {
+ case "path":
+ continue
+ default:
+ expectedKey := strings.ReplaceAll(key, "_", "-")
+ actualValue, ok := params[expectedKey]
+ assert.True(t, ok, "Parameter %s not found", expectedKey)
+ assert.Equal(t, expectedValue, actualValue)
+ }
+ }
+ })
+ }
+}
+
+func TestBuildArgs(t *testing.T) {
+ testCases := []struct {
+ name string
+ cliPath string
+ command []string
+ params map[string]interface{}
+ expected []string
+ }{
+ {
+ name: "No Parameters",
+ cliPath: "snyk",
+ command: []string{"test"},
+ params: map[string]interface{}{},
+ expected: []string{"snyk", "test"},
+ },
+ {
+ name: "String Parameters",
+ cliPath: "snyk",
+ command: []string{"test"},
+ params: map[string]interface{}{
+ "org": "my-org",
+ "file": "package.json",
+ },
+ expected: []string{"snyk", "test", "--org=my-org", "--file=package.json"},
+ },
+ {
+ name: "Boolean Parameters",
+ cliPath: "snyk",
+ command: []string{"test"},
+ params: map[string]interface{}{
+ "json": true,
+ "all-projects": true,
+ },
+ expected: []string{"snyk", "test", "--json", "--all-projects"},
+ },
+ {
+ name: "Mixed Parameters",
+ cliPath: "snyk",
+ command: []string{"test"},
+ params: map[string]interface{}{
+ "org": "my-org",
+ "json": true,
+ "all-projects": true,
+ },
+ expected: []string{"snyk", "test", "--org=my-org", "--all-projects", "--json"},
+ },
+ {
+ name: "Empty String Parameters",
+ cliPath: "snyk",
+ command: []string{"test"},
+ params: map[string]interface{}{
+ "org": "",
+ },
+ expected: []string{"snyk", "test"},
+ },
+ {
+ name: "False Boolean Parameters",
+ cliPath: "snyk",
+ command: []string{"test"},
+ params: map[string]interface{}{
+ "json": false,
+ },
+ expected: []string{"snyk", "test"},
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ args := buildArgs(tc.cliPath, tc.command, tc.params)
+ for _, arg := range args {
+ assert.Contains(t, tc.expected, arg)
+ }
+ })
+ }
+}
+
+func TestRunSnyk(t *testing.T) {
+ fixture := setupTestFixture(t)
+
+ ctx := context.Background()
+
+ testCases := []struct {
+ name string
+ mockOutput string
+ command []string
+ workingDir string
+ expectError bool
+ }{
+ {
+ name: "Successful Command",
+ mockOutput: "Command executed successfully",
+ command: []string{fixture.snykCliPath, "test"},
+ workingDir: "",
+ expectError: false,
+ },
+ {
+ name: "Command with Working Directory",
+ mockOutput: "Command executed successfully",
+ command: []string{fixture.snykCliPath, "test"},
+ workingDir: t.TempDir(),
+ expectError: false,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ fixture.mockCliOutput(tc.mockOutput)
+
+ output, err := fixture.binding.runSnyk(ctx, fixture.invocationContext, tc.workingDir, tc.command)
+
+ if tc.expectError {
+ assert.Error(t, err)
+ } else {
+ assert.NoError(t, err)
+ assert.Equal(t, tc.mockOutput, strings.TrimSpace(output))
+ }
+ })
+ }
+}
+
+func createMockSnykCli(t *testing.T, path, output string) {
+ t.Helper()
+
+ var script string
+
+ if runtime.GOOS == "windows" {
+ script = fmt.Sprintf(`@echo off
+echo %s
+exit /b 0
+`, output)
+ } else {
+ script = fmt.Sprintf(`#!/bin/sh
+echo '%s'
+exit 0
+`, output)
+ }
+
+ err := os.WriteFile(path, []byte(script), 0755)
+ require.NoError(t, err)
+}
diff --git a/internal/mcp/snyk_tools.json b/internal/mcp/snyk_tools.json
new file mode 100644
index 000000000..c1b2c0f5d
--- /dev/null
+++ b/internal/mcp/snyk_tools.json
@@ -0,0 +1,198 @@
+{
+ "tools": [
+ {
+ "name": "snyk_sca_test",
+ "description": "Run a SCA test on project dependencies to detect known vulnerabilities. Use this to scan open-source packages in supported ecosystems like npm, Maven, etc. Supports monorepo scanning via `--all-projects`. Outputs vulnerability data in JSON if enabled.",
+ "command": ["test"],
+ "standardParams": ["all_projects"],
+ "params": [
+ {
+ "name": "path",
+ "type": "string",
+ "isRequired": true,
+ "description": "Path to the project to test (default is the absolute path of the current directory, formatted according to the operating system's conventions)."
+ },
+ {
+ "name": "all_projects",
+ "type": "boolean",
+ "isRequired": false,
+ "description": "Scan all projects in the specified directory. (Default is true)."
+ },
+ {
+ "name": "severity_threshold",
+ "type": "string",
+ "isRequired": false,
+ "description": "Only report vulnerabilities of the specified level or higher (low, medium, high, critical). (Default is empty)"
+ },
+ {
+ "name": "org",
+ "type": "string",
+ "isRequired": false,
+ "description": "Specify the organization under which to run the test. (Default is empty)."
+ },
+ {
+ "name": "dev",
+ "type": "boolean",
+ "isRequired": false,
+ "description": "Include development dependencies. (Default is false)"
+ },
+ {
+ "name": "skip_unresolved",
+ "type": "boolean",
+ "isRequired": false,
+ "description": "Skip testing of unresolved packages. (Default is false)"
+ },
+ {
+ "name": "prune_repeated_subdependencies",
+ "type": "boolean",
+ "isRequired": false,
+ "description": "Prune repeated sub-dependencies. (Default is false)."
+ },
+ {
+ "name": "fail_on",
+ "type": "string",
+ "isRequired": false,
+ "description": "Specify the failure criteria (all, upgradable, patchable). (Default is all)."
+ },
+ {
+ "name": "file",
+ "type": "string",
+ "isRequired": false,
+ "description": "Specify a package file to test. (Default is empty)"
+ },
+ {
+ "name": "fail_fast",
+ "type": "boolean",
+ "isRequired": false,
+ "description": "Use with --all-projects to interrupt scans when errors occur. (Default is false)"
+ },
+ {
+ "name": "detection_depth",
+ "type": "string",
+ "isRequired": false,
+ "description": "Use with --all-projects to indicate how many subdirectories to search. (Default is empty)"
+ },
+ {
+ "name": "exclude",
+ "type": "string",
+ "isRequired": false,
+ "description": "Use with --all-projects to exclude directory names and file names. (Default is empty)"
+ },
+ {
+ "name": "remote_repo_url",
+ "type": "string",
+ "isRequired": false,
+ "description": "Set or override the remote URL for the repository to monitor. (Default is empty)"
+ },
+ {
+ "name": "package_manager",
+ "type": "string",
+ "isRequired": false,
+ "description": "Specify the name of the package manager when the filename is not standard. (Default is empty)"
+ },
+ {
+ "name": "unmanaged",
+ "type": "boolean",
+ "isRequired": false,
+ "description": "For C++ only, scan all files for known open source dependencies. (Default is false)"
+ },
+ {
+ "name": "ignore_policy",
+ "type": "boolean",
+ "isRequired": false,
+ "description": "Ignore all set policies, the current policy in the .snyk file, Org level ignores, and the project policy. (Default is false)"
+ },
+ {
+ "name": "trust_policies",
+ "type": "boolean",
+ "isRequired": false,
+ "description": "Apply and use ignore rules from the Snyk policies in your dependencies. (Default is false)"
+ },
+ {
+ "name": "show_vulnerable_paths",
+ "type": "string",
+ "isRequired": false,
+ "description": "Display the dependency paths (none|some|all). (Default: none)."
+ },
+ {
+ "name": "project_name",
+ "type": "string",
+ "isRequired": false,
+ "description": "Specify a custom Snyk project name. (Default is empty)"
+ },
+ {
+ "name": "target_reference",
+ "type": "string",
+ "isRequired": false,
+ "description": "Specify a reference that differentiates this project, for example, a branch name. (Default is empty)"
+ },
+ {
+ "name": "policy_path",
+ "type": "string",
+ "isRequired": false,
+ "description": "Manually pass a path to a .snyk policy file. (Default is empty)"
+ }
+ ]
+ },
+ {
+ "name": "snyk_code_test",
+ "description": "Run a static application security test (SAST) on your source code to detect security issues like SQL injection, XSS, and hardcoded secrets. Designed to catch issues early in the development cycle.",
+ "command": ["code", "test"],
+ "standardParams": [],
+ "params": [
+ {
+ "name": "path",
+ "type": "string",
+ "isRequired": true,
+ "description": "Path to the project to test (default is the absolute path of the current directory, formatted according to the operating system's conventions)."
+ },
+ {
+ "name": "file",
+ "type": "string",
+ "isRequired": false,
+ "description": "Specific file to scan (default: empty)."
+ },
+ {
+ "name": "severity_threshold",
+ "type": "string",
+ "isRequired": false,
+ "description": "Only report vulnerabilities of the specified level or higher (low, medium, high). (default: empty)"
+ },
+ {
+ "name": "org",
+ "type": "string",
+ "isRequired": false,
+ "description": "Specify the organization under which to run the test. (default: empty)"
+ }
+ ]
+ },
+ {
+ "name": "snyk_version",
+ "description": "Get Snyk CLI version",
+ "command": ["--version"],
+ "standardParams": [],
+ "params": []
+ },
+ {
+ "name": "snyk_auth",
+ "description": "Authenticate with Snyk",
+ "command": ["auth"],
+ "standardParams": [],
+ "params": []
+ },
+ {
+ "name": "snyk_auth_status",
+ "description": "Check Snyk authentication status",
+ "command": ["whoami", "--experimental"],
+ "standardParams": [],
+ "params": []
+ },
+ {
+ "name": "snyk_logout",
+ "description": "Log out from Snyk",
+ "command": ["logout"],
+ "standardParams": [],
+ "params": []
+ }
+ ]
+}
diff --git a/internal/mcp/utils.go b/internal/mcp/utils.go
new file mode 100644
index 000000000..c1447a3d8
--- /dev/null
+++ b/internal/mcp/utils.go
@@ -0,0 +1,108 @@
+/*
+ * © 2025 Snyk Limited
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package mcp
+
+import (
+ "strings"
+
+ "github.com/mark3labs/mcp-go/mcp"
+)
+
+// buildArgs builds command-line arguments for Snyk CLI based on parameters
+func buildArgs(cliPath string, command []string, params map[string]interface{}) []string {
+ args := []string{cliPath}
+ args = append(args, command...)
+
+ // Add params as command-line flags
+ for key, value := range params {
+ switch v := value.(type) {
+ case bool:
+ if v {
+ args = append(args, "--"+key)
+ }
+ case string:
+ if v != "" {
+ args = append(args, "--"+key+"="+v)
+ }
+ }
+ }
+
+ return args
+}
+
+// createToolFromDefinition creates an MCP tool from a Snyk tool definition
+func createToolFromDefinition(toolDef *SnykMcpToolsDefinition) mcp.Tool {
+ opts := []mcp.ToolOption{mcp.WithDescription(toolDef.Description)}
+ for _, param := range toolDef.Params {
+ if param.Type == "string" {
+ if param.IsRequired {
+ opts = append(opts, mcp.WithString(param.Name, mcp.Required(), mcp.Description(param.Description)))
+ } else {
+ opts = append(opts, mcp.WithString(param.Name, mcp.Description(param.Description)))
+ }
+ } else if param.Type == "boolean" {
+ if param.IsRequired {
+ opts = append(opts, mcp.WithBoolean(param.Name, mcp.Required(), mcp.Description(param.Description)))
+ } else {
+ opts = append(opts, mcp.WithBoolean(param.Name, mcp.Description(param.Description)))
+ }
+ }
+ }
+
+ return mcp.NewTool(toolDef.Name, opts...)
+}
+
+// extractParamsFromRequestArgs extracts parameters from the arguments based on the tool definition
+func extractParamsFromRequestArgs(toolDef SnykMcpToolsDefinition, arguments map[string]interface{}) (map[string]interface{}, string) {
+ params := make(map[string]interface{})
+ var workingDir string
+
+ for _, paramDef := range toolDef.Params {
+ val, ok := arguments[paramDef.Name]
+ if !ok {
+ continue
+ }
+
+ // Store path separately to use as working directory
+ if paramDef.Name == "path" {
+ if pathStr, ok := val.(string); ok {
+ workingDir = pathStr
+ }
+ }
+
+ // Convert parameter name from snake_case to kebab-case for CLI arguments
+ cliParamName := strings.ReplaceAll(paramDef.Name, "_", "-")
+
+ // Cast the value based on parameter type
+ if paramDef.Type == "string" {
+ if strVal, ok := val.(string); ok && strVal != "" {
+ params[cliParamName] = strVal
+ }
+ } else if paramDef.Type == "boolean" {
+ if boolVal, ok := val.(bool); ok && boolVal {
+ params[cliParamName] = true
+ }
+ }
+ }
+
+ return params, workingDir
+}
+
+// convertToCliParam Convert parameter name from snake_case to kebab-case for CLI arguments
+func convertToCliParam(cliParam string) string {
+ return strings.ReplaceAll(cliParam, "_", "-")
+}
diff --git a/licenses/github.com/yosida95/uritemplate/v3/LICENSE b/licenses/github.com/yosida95/uritemplate/v3/LICENSE
new file mode 100644
index 000000000..79e8f8757
--- /dev/null
+++ b/licenses/github.com/yosida95/uritemplate/v3/LICENSE
@@ -0,0 +1,25 @@
+Copyright (C) 2016, Kohei YOSHIDA . All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+ * Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+ * Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+ * Neither the name of the copyright holder nor the names of its
+ contributors may be used to endorse or promote products derived from
+ this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/mcp_extension/main.go b/mcp_extension/main.go
new file mode 100644
index 000000000..dea81d27e
--- /dev/null
+++ b/mcp_extension/main.go
@@ -0,0 +1,117 @@
+/*
+ * © 2023 Snyk Limited All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package mcp_extension
+
+import (
+ "context"
+ "fmt"
+ "os"
+ "path/filepath"
+
+ "github.com/snyk/go-application-framework/pkg/configuration"
+
+ "github.com/snyk/snyk-ls/application/entrypoint"
+ "github.com/snyk/snyk-ls/internal/mcp"
+
+ "github.com/spf13/pflag"
+
+ "github.com/snyk/go-application-framework/pkg/workflow"
+)
+
+var WORKFLOWID_MCP = workflow.NewWorkflowIdentifier("mcp")
+
+func Init(engine workflow.Engine) error {
+ flags := pflag.NewFlagSet("mcp", pflag.ContinueOnError)
+
+ flags.StringP("transport", "t", "sse", "sets transport to ")
+ flags.Bool(configuration.FLAG_EXPERIMENTAL, false, "enable experimental mcp command")
+
+ cfg := workflow.ConfigurationOptionsFromFlagset(flags)
+ entry, _ := engine.Register(WORKFLOWID_MCP, cfg, mcpWorkflow)
+ entry.SetVisibility(false)
+
+ return nil
+}
+
+func mcpWorkflow(
+ invocation workflow.InvocationContext,
+ _ []workflow.Data,
+) (output []workflow.Data, err error) {
+ defer entrypoint.OnPanicRecover()
+
+ config := invocation.GetConfiguration()
+ config.Set(configuration.INTEGRATION_NAME, "MCP")
+
+ runtimeInfo := invocation.GetRuntimeInfo()
+ if runtimeInfo != nil {
+ config.Set(configuration.INTEGRATION_VERSION, runtimeInfo.GetVersion())
+ } else {
+ config.Set(configuration.INTEGRATION_VERSION, "unknown")
+ }
+
+ // only run if experimental flag is set
+ if !config.GetBool(configuration.FLAG_EXPERIMENTAL) {
+ return nil, fmt.Errorf("set `--experimental` flag to enable mcp command")
+ }
+
+ output = []workflow.Data{}
+ logger := invocation.GetEnhancedLogger()
+
+ cliPath, err := getCliPath(invocation)
+ if err != nil {
+ logger.Err(err).Msg("Failed to set cli path")
+ return output, err
+ }
+ logger.Trace().Interface("environment", os.Environ()).Msg("start environment")
+ mcpStart(invocation, cliPath)
+
+ return output, nil
+}
+
+func mcpStart(invocationContext workflow.InvocationContext, cliPath string) {
+ mcpServer := mcp.NewMcpLLMBinding(mcp.WithLogger(invocationContext.GetEnhancedLogger()), mcp.WithCliPath(cliPath))
+ logger := invocationContext.GetEnhancedLogger()
+
+ // start mcp server
+ err := mcpServer.Start(invocationContext)
+
+ if err != nil {
+ logger.Err(err).Msg("failed to start mcp server")
+ }
+ defer func() {
+ logger.Info().Msg("Shutting down MCP Server...")
+ mcpServer.Shutdown(context.Background())
+ }()
+}
+
+func getCliPath(ctx workflow.InvocationContext) (string, error) {
+ logger := ctx.GetEnhancedLogger()
+ exePath, err := os.Executable()
+ if err != nil {
+ logger.Err(err).Msg("Failed to get executable path")
+ return "", err
+ }
+ resolvedPath, err := filepath.EvalSymlinks(exePath)
+
+ if err != nil {
+ logger.Err(err).Msg("Failed to eval symlink from path")
+ return "", err
+ } else {
+ // Set Cli path to current process path
+ return resolvedPath, nil
+ }
+}
diff --git a/mcp_extension/main_test.go b/mcp_extension/main_test.go
new file mode 100644
index 000000000..773f2bb54
--- /dev/null
+++ b/mcp_extension/main_test.go
@@ -0,0 +1,39 @@
+package mcp_extension
+
+import (
+ "testing"
+ "time"
+
+ "github.com/snyk/go-application-framework/pkg/configuration"
+ "github.com/stretchr/testify/assert"
+
+ "github.com/snyk/go-application-framework/pkg/app"
+)
+
+func Test_ExtensionEntryPoint(t *testing.T) {
+ expectedTransportType := "stdio"
+ engine := app.CreateAppEngineWithOptions()
+
+ engineConfig := configuration.NewWithOpts(
+ configuration.WithAutomaticEnv(),
+ )
+ engineConfig.Set("transport", expectedTransportType)
+ engineConfig.Set(configuration.FLAG_EXPERIMENTAL, true)
+
+ //register extension under test
+ err := Init(engine)
+ assert.Nil(t, err)
+
+ go func() {
+ err = engine.Init()
+ assert.Nil(t, err)
+
+ data, err := engine.InvokeWithConfig(WORKFLOWID_MCP, engineConfig)
+ assert.Nil(t, err)
+ assert.Empty(t, data)
+ }()
+
+ assert.Eventuallyf(t, func() bool {
+ return expectedTransportType == engineConfig.GetString("transport") && engineConfig.GetBool(configuration.FLAG_EXPERIMENTAL)
+ }, time.Minute, time.Millisecond, "open browser was not called")
+}