Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 60 additions & 36 deletions cmd/docker-mcp/internal/gateway/capabilitites.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,36 @@ import (
"strings"
"sync"

"github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcp-go/server"
"github.com/modelcontextprotocol/go-sdk/jsonschema"
"github.com/modelcontextprotocol/go-sdk/mcp"
"golang.org/x/sync/errgroup"
)

type Capabilities struct {
Tools []server.ServerTool
Prompts []server.ServerPrompt
Resources []server.ServerResource
ResourceTemplates []server.ServerResourceTemplate
Tools []ToolRegistration
Prompts []PromptRegistration
Resources []ResourceRegistration
ResourceTemplates []ResourceTemplateRegistration
}

type ToolRegistration struct {
Tool *mcp.Tool
Handler mcp.ToolHandler
}

type PromptRegistration struct {
Prompt *mcp.Prompt
Handler mcp.PromptHandler
}

type ResourceRegistration struct {
Resource *mcp.Resource
Handler mcp.ResourceHandler
}

type ResourceTemplateRegistration struct {
ResourceTemplate mcp.ResourceTemplate
Handler mcp.ResourceHandler
}

func (g *Gateway) listCapabilities(ctx context.Context, configuration Configuration, serverNames []string) (*Capabilities, error) {
Expand All @@ -38,7 +58,7 @@ func (g *Gateway) listCapabilities(ctx context.Context, configuration Configurat
// It's an MCP Server
case serverConfig != nil:
errs.Go(func() error {
client, err := g.clientPool.AcquireClient(ctx, *serverConfig, nil)
client, err := g.clientPool.AcquireClient(context.Background(), *serverConfig, nil)
if err != nil {
logf(" > Can't start %s: %s", serverConfig.Name, err)
return nil
Expand All @@ -47,47 +67,47 @@ func (g *Gateway) listCapabilities(ctx context.Context, configuration Configurat

var capabilities Capabilities

tools, err := client.ListTools(ctx, mcp.ListToolsRequest{})
tools, err := client.Session().ListTools(ctx, &mcp.ListToolsParams{})
if err != nil {
logf(" > Can't list tools %s: %s", serverConfig.Name, err)
} else {
for _, tool := range tools.Tools {
if !isToolEnabled(configuration, serverConfig.Name, serverConfig.Spec.Image, tool.Name, g.ToolNames) {
continue
}
capabilities.Tools = append(capabilities.Tools, server.ServerTool{
capabilities.Tools = append(capabilities.Tools, ToolRegistration{
Tool: tool,
Handler: g.mcpServerToolHandler(*serverConfig, tool.Annotations),
Handler: g.mcpServerToolHandler(*serverConfig, g.mcpServer, tool.Annotations),
})
}
}

prompts, err := client.ListPrompts(ctx, mcp.ListPromptsRequest{})
prompts, err := client.Session().ListPrompts(ctx, &mcp.ListPromptsParams{})
if err == nil {
for _, prompt := range prompts.Prompts {
capabilities.Prompts = append(capabilities.Prompts, server.ServerPrompt{
capabilities.Prompts = append(capabilities.Prompts, PromptRegistration{
Prompt: prompt,
Handler: g.mcpServerPromptHandler(*serverConfig),
Handler: g.mcpServerPromptHandler(*serverConfig, g.mcpServer),
})
}
}

resources, err := client.ListResources(ctx, mcp.ListResourcesRequest{})
resources, err := client.Session().ListResources(ctx, &mcp.ListResourcesParams{})
if err == nil {
for _, resource := range resources.Resources {
capabilities.Resources = append(capabilities.Resources, server.ServerResource{
capabilities.Resources = append(capabilities.Resources, ResourceRegistration{
Resource: resource,
Handler: g.mcpServerResourceHandler(*serverConfig),
Handler: g.mcpServerResourceHandler(*serverConfig, g.mcpServer),
})
}
}

resourceTemplates, err := client.ListResourceTemplates(ctx, mcp.ListResourceTemplatesRequest{})
resourceTemplates, err := client.Session().ListResourceTemplates(ctx, &mcp.ListResourceTemplatesParams{})
if err == nil {
for _, resourceTemplate := range resourceTemplates.ResourceTemplates {
capabilities.ResourceTemplates = append(capabilities.ResourceTemplates, server.ServerResourceTemplate{
Template: resourceTemplate,
Handler: g.mcpServerResourceTemplateHandler(*serverConfig),
capabilities.ResourceTemplates = append(capabilities.ResourceTemplates, ResourceTemplateRegistration{
ResourceTemplate: *resourceTemplate,
Handler: g.mcpServerResourceHandler(*serverConfig, g.mcpServer),
})
}
}
Expand Down Expand Up @@ -128,17 +148,21 @@ func (g *Gateway) listCapabilities(ctx context.Context, configuration Configurat
mcpTool := mcp.Tool{
Name: tool.Name,
Description: tool.Description,
InputSchema: &jsonschema.Schema{},
}
// TODO: Properly convert tool.Parameters to jsonschema.Schema
// For now, we'll create a simple schema structure
if len(tool.Parameters.Properties) == 0 {
mcpTool.InputSchema.Type = "object"
} else {
mcpTool.InputSchema.Type = tool.Parameters.Type
mcpTool.InputSchema.Properties = tool.Parameters.Properties.ToMap()
mcpTool.InputSchema.Required = tool.Parameters.Required
// Note: tool.Parameters.Properties.ToMap() returns map[string]any
// but we need map[string]*jsonschema.Schema
// This is a complex conversion that needs proper implementation
}

capabilities.Tools = append(capabilities.Tools, server.ServerTool{
Tool: mcpTool,
capabilities.Tools = append(capabilities.Tools, ToolRegistration{
Tool: &mcpTool,
Handler: g.mcpToolHandler(tool),
})
}
Expand All @@ -154,22 +178,22 @@ func (g *Gateway) listCapabilities(ctx context.Context, configuration Configurat
}

// Merge all capabilities
var serverTools []server.ServerTool
var serverPrompts []server.ServerPrompt
var serverResources []server.ServerResource
var serverResourceTemplates []server.ServerResourceTemplate
var allTools []ToolRegistration
var allPrompts []PromptRegistration
var allResources []ResourceRegistration
var allResourceTemplates []ResourceTemplateRegistration
for _, capabilities := range allCapabilities {
serverTools = append(serverTools, capabilities.Tools...)
serverPrompts = append(serverPrompts, capabilities.Prompts...)
serverResources = append(serverResources, capabilities.Resources...)
serverResourceTemplates = append(serverResourceTemplates, capabilities.ResourceTemplates...)
allTools = append(allTools, capabilities.Tools...)
allPrompts = append(allPrompts, capabilities.Prompts...)
allResources = append(allResources, capabilities.Resources...)
allResourceTemplates = append(allResourceTemplates, capabilities.ResourceTemplates...)
}

return &Capabilities{
Tools: serverTools,
Prompts: serverPrompts,
Resources: serverResources,
ResourceTemplates: serverResourceTemplates,
Tools: allTools,
Prompts: allPrompts,
Resources: allResources,
ResourceTemplates: allResourceTemplates,
}, nil
}

Expand Down
79 changes: 56 additions & 23 deletions cmd/docker-mcp/internal/gateway/clientpool.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,8 @@ import (
"os/exec"
"strings"
"sync"
"time"

"github.com/mark3labs/mcp-go/mcp"
"github.com/modelcontextprotocol/go-sdk/mcp"

"github.com/docker/mcp-gateway/cmd/docker-mcp/internal/catalog"
"github.com/docker/mcp-gateway/cmd/docker-mcp/internal/docker"
Expand All @@ -32,6 +31,12 @@ type clientPool struct {
docker docker.Client
}

type clientConfig struct {
readOnly *bool
serverSession *mcp.ServerSession
server *mcp.Server
}

func newClientPool(options Options, docker docker.Client) *clientPool {
return &clientPool{
Options: options,
Expand All @@ -40,7 +45,7 @@ func newClientPool(options Options, docker docker.Client) *clientPool {
}
}

func (cp *clientPool) AcquireClient(ctx context.Context, serverConfig catalog.ServerConfig, readOnly *bool) (mcpclient.Client, error) {
func (cp *clientPool) AcquireClient(ctx context.Context, serverConfig catalog.ServerConfig, config *clientConfig) (mcpclient.Client, error) {
var getter *clientGetter

// Check if client is kept, can be returned immediately
Expand All @@ -55,7 +60,7 @@ func (cp *clientPool) AcquireClient(ctx context.Context, serverConfig catalog.Se

// No client found, create a new one
if getter == nil {
getter = newClientGetter(serverConfig, cp, readOnly)
getter = newClientGetter(serverConfig, cp, config)

// If the client is long running, save it for later
if serverConfig.Spec.LongLived || cp.LongLived {
Expand Down Expand Up @@ -103,7 +108,7 @@ func (cp *clientPool) ReleaseClient(client mcpclient.Client) {

// Client was not kept, close it
if !foundKept {
client.Close()
client.Session().Close()
return
}

Expand All @@ -120,7 +125,7 @@ func (cp *clientPool) Close() {
for _, keptClient := range existingMap {
client, err := keptClient.Getter.GetClient(context.TODO()) // should be cached
if err == nil {
client.Close()
client.Session().Close()
}
}
}
Expand All @@ -129,16 +134,22 @@ func (cp *clientPool) SetNetworks(networks []string) {
cp.networks = networks
}

func (cp *clientPool) runToolContainer(ctx context.Context, tool catalog.Tool, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
func (cp *clientPool) runToolContainer(ctx context.Context, tool catalog.Tool, params *mcp.CallToolParams) (*mcp.CallToolResult, error) {
args := cp.baseArgs(tool.Name)

// Attach the MCP servers to the same network as the gateway.
for _, network := range cp.networks {
args = append(args, "--network", network)
}

// Convert params.Arguments to map[string]any
arguments, ok := params.Arguments.(map[string]any)
if !ok {
arguments = make(map[string]any)
}

// Volumes
for _, mount := range eval.EvaluateList(tool.Container.Volumes, request.GetArguments()) {
for _, mount := range eval.EvaluateList(tool.Container.Volumes, arguments) {
if mount == "" {
continue
}
Expand All @@ -150,7 +161,7 @@ func (cp *clientPool) runToolContainer(ctx context.Context, tool catalog.Tool, r
args = append(args, tool.Container.Image)

// Command
command := eval.EvaluateList(tool.Container.Command, request.GetArguments())
command := eval.EvaluateList(tool.Container.Command, arguments)
args = append(args, command...)

log(" - Running container", tool.Container.Image, "with args", args)
Expand All @@ -161,10 +172,20 @@ func (cp *clientPool) runToolContainer(ctx context.Context, tool catalog.Tool, r
}
out, err := cmd.Output()
if err != nil {
return mcp.NewToolResultError(string(out)), nil
return &mcp.CallToolResult{
Content: []mcp.Content{&mcp.TextContent{
Text: string(out),
}},
IsError: true,
}, nil
}

return mcp.NewToolResultText(string(out)), nil
return &mcp.CallToolResult{
Content: []mcp.Content{&mcp.TextContent{
Text: string(out),
}},
IsError: false,
}, nil
}

func (cp *clientPool) baseArgs(name string) []string {
Expand Down Expand Up @@ -289,14 +310,15 @@ type clientGetter struct {

serverConfig catalog.ServerConfig
cp *clientPool
readOnly *bool

clientConfig *clientConfig
}

func newClientGetter(serverConfig catalog.ServerConfig, cp *clientPool, readOnly *bool) *clientGetter {
func newClientGetter(serverConfig catalog.ServerConfig, cp *clientPool, config *clientConfig) *clientGetter {
return &clientGetter{
serverConfig: serverConfig,
cp: cp,
readOnly: readOnly,
clientConfig: config,
}
}

Expand Down Expand Up @@ -328,7 +350,11 @@ func (cg *clientGetter) GetClient(ctx context.Context) (mcpclient.Client, error)
}

image := cg.serverConfig.Spec.Image
args, env := cg.cp.argsAndEnv(cg.serverConfig, cg.readOnly, targetConfig)
var readOnly *bool
if cg.clientConfig != nil {
readOnly = cg.clientConfig.readOnly
}
args, env := cg.cp.argsAndEnv(cg.serverConfig, readOnly, targetConfig)

command := expandEnvList(eval.EvaluateList(cg.serverConfig.Spec.Command, cg.serverConfig.Config), env)
if len(command) == 0 {
Expand All @@ -345,17 +371,24 @@ func (cg *clientGetter) GetClient(ctx context.Context) (mcpclient.Client, error)
client = mcpclient.NewStdioCmdClient(cg.serverConfig.Name, "docker", env, runArgs...)
}

initRequest := mcp.InitializeRequest{}
initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
initRequest.Params.ClientInfo = mcp.Implementation{
Name: "docker",
Version: "1.0.0",
initParams := &mcp.InitializeParams{
ProtocolVersion: "2024-11-05",
ClientInfo: &mcp.Implementation{
Name: "docker",
Version: "1.0.0",
},
}

ctx, cancel := context.WithTimeout(ctx, 20*time.Second)
defer cancel()
var ss *mcp.ServerSession
var server *mcp.Server
if cg.clientConfig != nil {
ss = cg.clientConfig.serverSession
server = cg.clientConfig.server
}
// ctx, cancel := context.WithTimeout(ctx, 20*time.Second)
// defer cancel()

if _, err := client.Initialize(ctx, initRequest, cg.cp.Verbose); err != nil {
if err := client.Initialize(ctx, initParams, cg.cp.Verbose, ss, server); err != nil {
return nil, err
}

Expand Down
Loading