Skip to content

Commit 101fa28

Browse files
Fix CE mode OAuth and preserve raw tool arguments through gateway
Three fixes for running the MCP Gateway without Docker Desktop (CE mode): 1. **Forward tool arguments as raw JSON** (`handlers.go`): The gateway previously unmarshaled `CallToolParamsRaw.Arguments` into `any` before forwarding to tool handlers. This loses type fidelity for tools that rely on structured/typed inputs. Arguments are now forwarded as `json.RawMessage` unchanged, keeping the gateway schema-agnostic. 2. **Normalize argument types in clientpool** (`clientpool.go`): `runToolContainer` used a single type assertion (`map[string]any`) which silently dropped arguments arriving as `json.RawMessage` or `[]byte`. A `normalizeArguments` function now handles all expected argument representations safely via type switch. 3. **CE mode OAuth redirect URI and state validation** (`manager.go`): - When `DOCKER_MCP_USE_CE=true`, the redirect URI now defaults to the local callback (`http://localhost:5000/callback`) instead of the SaaS endpoint (`mcp.docker.com`), with override via `DOCKER_MCP_OAUTH_REDIRECT_URI`. - `ExchangeCode` now strips the `mcp-gateway:PORT:` prefix from the state parameter before validation, fixing a mismatch where `BuildAuthorizationURL` adds the prefix for proxy routing but `StateManager` only stores the base UUID. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 0c22732 commit 101fa28

File tree

5 files changed

+243
-31
lines changed

5 files changed

+243
-31
lines changed

pkg/gateway/clientpool.go

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package gateway
22

33
import (
44
"context"
5+
"encoding/json"
56
"fmt"
67
"os"
78
"os/exec"
@@ -211,6 +212,36 @@ func (cp *clientPool) InvalidateOAuthClients(provider string) {
211212
}
212213
}
213214

215+
// normalizeArguments converts tool arguments from various representations
216+
// (json.RawMessage, []byte, map[string]any, nil) into a consistent
217+
// map[string]any for template evaluation. MCP transports may deliver
218+
// arguments in any of these forms depending on the caller.
219+
func normalizeArguments(args any) map[string]any {
220+
switch v := args.(type) {
221+
case map[string]any:
222+
return v
223+
case json.RawMessage:
224+
var m map[string]any
225+
if err := json.Unmarshal(v, &m); err != nil {
226+
log.Logf("Warning: failed to decode tool arguments RawMessage: %v", err)
227+
return make(map[string]any)
228+
}
229+
return m
230+
case []byte:
231+
var m map[string]any
232+
if err := json.Unmarshal(v, &m); err != nil {
233+
log.Logf("Warning: failed to decode tool arguments JSON: %v", err)
234+
return make(map[string]any)
235+
}
236+
return m
237+
case nil:
238+
return make(map[string]any)
239+
default:
240+
log.Logf("Warning: unsupported tool arguments type: %T", args)
241+
return make(map[string]any)
242+
}
243+
}
244+
214245
func (cp *clientPool) runToolContainer(ctx context.Context, tool catalog.Tool, params *mcp.CallToolParams) (*mcp.CallToolResult, error) {
215246
args := cp.baseArgs(tool.Name)
216247

@@ -219,11 +250,7 @@ func (cp *clientPool) runToolContainer(ctx context.Context, tool catalog.Tool, p
219250
args = append(args, "--network", network)
220251
}
221252

222-
// Convert params.Arguments to map[string]any
223-
arguments, ok := params.Arguments.(map[string]any)
224-
if !ok {
225-
arguments = make(map[string]any)
226-
}
253+
arguments := normalizeArguments(params.Arguments)
227254

228255
// Volumes
229256
for _, mount := range eval.EvaluateList(tool.Container.Volumes, arguments) {

pkg/gateway/clientpool_test.go

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package gateway
22

33
import (
44
"context"
5+
"encoding/json"
56
"fmt"
67
"os"
78
"testing"
@@ -278,6 +279,59 @@ func parseConfig(t *testing.T, contentYAML string) map[string]any {
278279
return config
279280
}
280281

282+
func TestNormalizeArguments_MapStringAny(t *testing.T) {
283+
input := map[string]any{"key": "value", "num": float64(42)}
284+
result := normalizeArguments(input)
285+
assert.Equal(t, input, result)
286+
}
287+
288+
func TestNormalizeArguments_JSONRawMessage(t *testing.T) {
289+
raw := json.RawMessage(`{"url":"https://example.com","count":3}`)
290+
result := normalizeArguments(raw)
291+
assert.Equal(t, "https://example.com", result["url"])
292+
assert.Equal(t, float64(3), result["count"])
293+
}
294+
295+
func TestNormalizeArguments_ByteSlice(t *testing.T) {
296+
raw := []byte(`{"path":"/tmp/data","verbose":true}`)
297+
result := normalizeArguments(raw)
298+
assert.Equal(t, "/tmp/data", result["path"])
299+
assert.Equal(t, true, result["verbose"])
300+
}
301+
302+
func TestNormalizeArguments_Nil(t *testing.T) {
303+
result := normalizeArguments(nil)
304+
assert.NotNil(t, result)
305+
assert.Empty(t, result)
306+
}
307+
308+
func TestNormalizeArguments_UnexpectedType(t *testing.T) {
309+
result := normalizeArguments("unexpected string")
310+
assert.NotNil(t, result)
311+
assert.Empty(t, result)
312+
}
313+
314+
func TestNormalizeArguments_InvalidJSON(t *testing.T) {
315+
raw := json.RawMessage(`{not valid json}`)
316+
result := normalizeArguments(raw)
317+
assert.NotNil(t, result)
318+
assert.Empty(t, result)
319+
}
320+
321+
func TestNormalizeArguments_InvalidByteSlice(t *testing.T) {
322+
raw := []byte(`{not valid json}`)
323+
result := normalizeArguments(raw)
324+
assert.NotNil(t, result)
325+
assert.Empty(t, result)
326+
}
327+
328+
func TestNormalizeArguments_EmptyJSONObject(t *testing.T) {
329+
raw := json.RawMessage(`{}`)
330+
result := normalizeArguments(raw)
331+
assert.NotNil(t, result)
332+
assert.Empty(t, result)
333+
}
334+
281335
func TestInvalidateOAuthClients_MatchesCommunityServer(t *testing.T) {
282336
// Community server: remote URL set, but no Spec.OAuth metadata.
283337
// This verifies Gap 3: InvalidateOAuthClients matches community servers

pkg/gateway/handlers.go

Lines changed: 28 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package gateway
22

33
import (
44
"context"
5-
"encoding/json"
65
"fmt"
76
"os"
87
"time"
@@ -43,18 +42,22 @@ func inferServerTransportType(serverConfig *catalog.ServerConfig) string {
4342

4443
func (g *Gateway) mcpToolHandler(tool catalog.Tool) mcp.ToolHandler {
4544
return func(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) {
46-
// Convert CallToolParamsRaw to CallToolParams
47-
var args any
48-
if len(req.Params.Arguments) > 0 {
49-
if err := json.Unmarshal(req.Params.Arguments, &args); err != nil {
50-
return nil, fmt.Errorf("failed to unmarshal arguments: %w", err)
51-
}
52-
}
45+
// Convert CallToolParamsRaw to CallToolParams.
46+
//
47+
// Arguments are forwarded as raw JSON (json.RawMessage) and intentionally
48+
// not unmarshaled here. The gateway must remain schema-agnostic and avoid
49+
// coercing tool inputs, preserving full argument fidelity for tools that
50+
// rely on structured or typed inputs.
5351
params := &mcp.CallToolParams{
54-
Meta: req.Params.Meta,
55-
Name: req.Params.Name,
56-
Arguments: args,
52+
Meta: req.Params.Meta,
53+
Name: req.Params.Name,
5754
}
55+
56+
// Forward raw arguments unchanged, if present.
57+
if len(req.Params.Arguments) > 0 {
58+
params.Arguments = req.Params.Arguments
59+
}
60+
5861
return g.clientPool.runToolContainer(ctx, tool, params)
5962
}
6063
}
@@ -132,19 +135,21 @@ func (g *Gateway) mcpServerToolHandler(serverName string, server *mcp.Server, _
132135
}
133136
defer g.clientPool.ReleaseClient(client)
134137

135-
// Convert CallToolParamsRaw to CallToolParams
136-
var args any
137-
if len(req.Params.Arguments) > 0 {
138-
if jsonErr := json.Unmarshal(req.Params.Arguments, &args); jsonErr != nil {
139-
telemetry.RecordToolError(ctx, span, serverConfig.Name, serverTransportType, req.Params.Name)
140-
span.SetStatus(codes.Error, "Failed to unmarshal arguments")
141-
return nil, fmt.Errorf("failed to unmarshal arguments: %w", jsonErr)
142-
}
143-
}
138+
// Convert CallToolParamsRaw to CallToolParams.
139+
//
140+
// NOTE: Arguments are forwarded as raw JSON (json.RawMessage) instead of being
141+
// unmarshaled here. The gateway must not interpret or coerce tool arguments,
142+
// as it does not own the tool schema. Preserving the raw payload ensures full
143+
// fidelity for schema-based and typed tools and matches the MCP Go SDK
144+
// expectations.
144145
params := &mcp.CallToolParams{
145-
Meta: req.Params.Meta,
146-
Name: originalToolName,
147-
Arguments: args,
146+
Meta: req.Params.Meta,
147+
Name: originalToolName,
148+
}
149+
150+
// Forward raw arguments unchanged, if present.
151+
if len(req.Params.Arguments) > 0 {
152+
params.Arguments = req.Params.Arguments
148153
}
149154

150155
// Execute the tool call

pkg/oauth/manager.go

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ import (
44
"context"
55
"fmt"
66
"net/url"
7+
"os"
8+
"strings"
79

810
"github.com/docker/docker-credential-helpers/credentials"
911
"golang.org/x/oauth2"
@@ -23,13 +25,31 @@ type Manager struct {
2325
redirectURI string
2426
}
2527

26-
// NewManager creates a new OAuth manager for CE mode
28+
// NewManager creates a new OAuth manager.
29+
// In CE mode, the redirect URI must point to the local gateway callback
30+
// instead of the SaaS endpoint (mcp.docker.com).
2731
func NewManager(credHelper credentials.Helper) *Manager {
32+
redirectURI := DefaultRedirectURI
33+
34+
// CE mode requires a local redirect URI because the OAuth callback
35+
// is handled by the local mcp-gateway process, not by Docker SaaS.
36+
//
37+
// Example:
38+
// http://localhost:5000/callback
39+
if os.Getenv("DOCKER_MCP_USE_CE") == "true" {
40+
if v := os.Getenv("DOCKER_MCP_OAUTH_REDIRECT_URI"); v != "" {
41+
redirectURI = v
42+
} else {
43+
// Default CE callback used by the local OAuth proxy
44+
redirectURI = "http://localhost:5000/callback"
45+
}
46+
}
47+
2848
return &Manager{
29-
dcrManager: dcr.NewManager(credHelper, DefaultRedirectURI),
49+
dcrManager: dcr.NewManager(credHelper, redirectURI),
3050
tokenStore: NewTokenStore(credHelper),
3151
stateManager: NewStateManager(),
32-
redirectURI: DefaultRedirectURI,
52+
redirectURI: redirectURI,
3353
}
3454
}
3555

@@ -125,6 +145,16 @@ func (m *Manager) BuildAuthorizationURL(_ context.Context, serverName string, sc
125145

126146
// ExchangeCode exchanges an authorization code for an access token
127147
func (m *Manager) ExchangeCode(ctx context.Context, code string, state string) error {
148+
// Strip the mcp-gateway:PORT: prefix if present.
149+
// BuildAuthorizationURL formats state as "mcp-gateway:PORT:UUID" for proxy routing,
150+
// but the StateManager only stores the base UUID.
151+
if strings.HasPrefix(state, "mcp-gateway:") {
152+
parts := strings.SplitN(state, ":", 3)
153+
if len(parts) == 3 {
154+
state = parts[2]
155+
}
156+
}
157+
128158
// Validate state and retrieve verifier
129159
serverName, verifier, err := m.stateManager.Validate(state)
130160
if err != nil {

pkg/oauth/manager_test.go

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package oauth
22

33
import (
44
"context"
5+
"os"
56
"strings"
67
"testing"
78
"time"
@@ -344,6 +345,101 @@ func TestManager_CallbackURLParsing(t *testing.T) {
344345
}
345346
}
346347

348+
func TestManager_ExchangeCode_StripsStatePrefix(t *testing.T) {
349+
manager := setupTestManager(t)
350+
serverName := "test-server"
351+
352+
setupTestDCRClient(t, manager, serverName)
353+
354+
// Generate a state via BuildAuthorizationURL with a callback URL,
355+
// which produces "mcp-gateway:PORT:UUID" format.
356+
_, baseState, _, err := manager.BuildAuthorizationURL(
357+
context.Background(),
358+
serverName,
359+
[]string{"read"},
360+
"http://localhost:8080/callback",
361+
)
362+
require.NoError(t, err)
363+
364+
// Simulate the prefixed state that would come back from the OAuth callback
365+
prefixedState := "mcp-gateway:8080:" + baseState
366+
367+
// ExchangeCode will fail at token exchange (no real server), but it should
368+
// get past state validation — meaning the prefix was correctly stripped.
369+
err = manager.ExchangeCode(context.Background(), "test-code", prefixedState)
370+
require.Error(t, err)
371+
// If prefix stripping failed, we'd get "invalid state parameter".
372+
// If it succeeded, we get a token exchange error instead.
373+
assert.NotContains(t, err.Error(), "invalid state parameter")
374+
}
375+
376+
func TestManager_ExchangeCode_PlainStateStillWorks(t *testing.T) {
377+
manager := setupTestManager(t)
378+
serverName := "test-server"
379+
380+
setupTestDCRClient(t, manager, serverName)
381+
382+
// Generate a state without callback URL (no prefix)
383+
_, baseState, _, err := manager.BuildAuthorizationURL(
384+
context.Background(),
385+
serverName,
386+
[]string{"read"},
387+
"",
388+
)
389+
require.NoError(t, err)
390+
391+
// ExchangeCode should still validate plain UUIDs (no prefix to strip)
392+
err = manager.ExchangeCode(context.Background(), "test-code", baseState)
393+
require.Error(t, err)
394+
assert.NotContains(t, err.Error(), "invalid state parameter")
395+
}
396+
397+
func TestManager_NewManager_CEModeRedirectURI(t *testing.T) {
398+
// Save and restore env vars
399+
origCE := os.Getenv("DOCKER_MCP_USE_CE")
400+
origURI := os.Getenv("DOCKER_MCP_OAUTH_REDIRECT_URI")
401+
defer func() {
402+
os.Setenv("DOCKER_MCP_USE_CE", origCE)
403+
os.Setenv("DOCKER_MCP_OAUTH_REDIRECT_URI", origURI)
404+
}()
405+
406+
t.Run("default mode uses SaaS redirect", func(t *testing.T) {
407+
os.Setenv("DOCKER_MCP_USE_CE", "")
408+
os.Setenv("DOCKER_MCP_OAUTH_REDIRECT_URI", "")
409+
410+
helper := newFakeCredentialHelper()
411+
manager := NewManager(helper)
412+
assert.Equal(t, DefaultRedirectURI, manager.redirectURI)
413+
})
414+
415+
t.Run("CE mode uses localhost redirect", func(t *testing.T) {
416+
os.Setenv("DOCKER_MCP_USE_CE", "true")
417+
os.Setenv("DOCKER_MCP_OAUTH_REDIRECT_URI", "")
418+
419+
helper := newFakeCredentialHelper()
420+
manager := NewManager(helper)
421+
assert.Equal(t, "http://localhost:5000/callback", manager.redirectURI)
422+
})
423+
424+
t.Run("CE mode with custom redirect URI", func(t *testing.T) {
425+
os.Setenv("DOCKER_MCP_USE_CE", "true")
426+
os.Setenv("DOCKER_MCP_OAUTH_REDIRECT_URI", "http://localhost:9999/custom")
427+
428+
helper := newFakeCredentialHelper()
429+
manager := NewManager(helper)
430+
assert.Equal(t, "http://localhost:9999/custom", manager.redirectURI)
431+
})
432+
433+
t.Run("non-CE mode ignores custom redirect URI", func(t *testing.T) {
434+
os.Setenv("DOCKER_MCP_USE_CE", "false")
435+
os.Setenv("DOCKER_MCP_OAUTH_REDIRECT_URI", "http://localhost:9999/custom")
436+
437+
helper := newFakeCredentialHelper()
438+
manager := NewManager(helper)
439+
assert.Equal(t, DefaultRedirectURI, manager.redirectURI)
440+
})
441+
}
442+
347443
func TestManager_StateFormatWithPort(t *testing.T) {
348444
manager := setupTestManager(t)
349445
serverName := "test-server"

0 commit comments

Comments
 (0)