Skip to content

Commit 604d1c3

Browse files
sd2kclaude
andauthored
fix: register ephemeral sessions to fix horizontal scaling of proxied tools (#754)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent caf4095 commit 604d1c3

3 files changed

Lines changed: 256 additions & 3 deletions

File tree

cmd/mcp-grafana/main.go

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,8 +136,10 @@ func newServer(transport string, dt disabledTools, obs *observability.Observabil
136136
mcpgrafana.WithSessionTTL(time.Duration(sessionIdleTimeoutMinutes) * time.Minute),
137137
)
138138

139-
// Declare variable for ToolManager that will be initialized after server creation
139+
// Declare variables that will be initialized after server creation.
140+
// The hooks below capture these by pointer, so they must be declared first.
140141
var stm *mcpgrafana.ToolManager
142+
var s *server.MCPServer
141143

142144
// Create hooks
143145
hooks := &server.Hooks{
@@ -149,9 +151,24 @@ func newServer(transport string, dt disabledTools, obs *observability.Observabil
149151
// (stdio mode is handled by InitializeAndRegisterServerTools; per-session tools
150152
// are not supported).
151153
if transport != "stdio" && !dt.proxied {
154+
// ensureSessionRegistered registers an ephemeral session in MCPServer.sessions
155+
// if it's not already there. This is needed for horizontal scaling: when a
156+
// request lands on a pod that didn't handle the initialize call, the SDK
157+
// creates an ephemeral session that isn't registered, causing AddSessionTools
158+
// to fail with ErrSessionNotFound. RegisterSession uses LoadOrStore
159+
// internally, so this is a no-op for already-registered sessions.
160+
ensureSessionRegistered := func(ctx context.Context) {
161+
if s != nil {
162+
if session := server.ClientSessionFromContext(ctx); session != nil {
163+
_ = s.RegisterSession(ctx, session)
164+
}
165+
}
166+
}
167+
152168
// OnBeforeListTools: Discover, connect, and register tools
153169
hooks.OnBeforeListTools = []server.OnBeforeListToolsFunc{
154170
func(ctx context.Context, id any, request *mcp.ListToolsRequest) {
171+
ensureSessionRegistered(ctx)
155172
if stm != nil {
156173
if session := server.ClientSessionFromContext(ctx); session != nil {
157174
stm.InitializeAndRegisterProxiedTools(ctx, session)
@@ -163,6 +180,7 @@ func newServer(transport string, dt disabledTools, obs *observability.Observabil
163180
// OnBeforeCallTool: Fallback in case client calls tool without listing first
164181
hooks.OnBeforeCallTool = []server.OnBeforeCallToolFunc{
165182
func(ctx context.Context, id any, request *mcp.CallToolRequest) {
183+
ensureSessionRegistered(ctx)
166184
if stm != nil {
167185
if session := server.ClientSessionFromContext(ctx); session != nil {
168186
stm.InitializeAndRegisterProxiedTools(ctx, session)
@@ -175,7 +193,7 @@ func newServer(transport string, dt disabledTools, obs *observability.Observabil
175193
// Merge observability hooks with existing hooks
176194
hooks = observability.MergeHooks(hooks, obs.MCPHooks())
177195

178-
s := server.NewMCPServer("mcp-grafana", mcpgrafana.Version(),
196+
s = server.NewMCPServer("mcp-grafana", mcpgrafana.Version(),
179197
server.WithInstructions(`
180198
This server provides access to your Grafana instance and the surrounding ecosystem.
181199
@@ -203,6 +221,10 @@ Note that some of these capabilities may be disabled. Do not try to use features
203221
// Initialize ToolManager now that server is created
204222
stm = mcpgrafana.NewToolManager(sm, s, mcpgrafana.WithProxiedTools(!dt.proxied))
205223

224+
// Give the SessionManager a reference to the MCPServer so the reaper can
225+
// unregister sessions from the SDK's internal session map.
226+
sm.SetMCPServer(s)
227+
206228
dt.addTools(s)
207229
return s, stm, sm
208230
}

session.go

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,22 @@ type SessionManager struct {
8989
reaperDone chan struct{}
9090
closeOnce sync.Once
9191
metrics sessionMetrics
92+
93+
// mcpServer is an optional reference to the MCP server, used to unregister
94+
// sessions from the SDK's internal session map when they are reaped. This
95+
// prevents a memory leak when sessions are registered via RegisterSession
96+
// in horizontal scaling scenarios (where ephemeral sessions are registered
97+
// so that AddSessionTools can find them).
98+
mcpServer *server.MCPServer
99+
}
100+
101+
// SetMCPServer sets the MCP server reference for session cleanup. When set,
102+
// the reaper will call MCPServer.UnregisterSession for reaped sessions to
103+
// prevent a memory leak in the SDK's internal session map.
104+
func (sm *SessionManager) SetMCPServer(s *server.MCPServer) {
105+
sm.mutex.Lock()
106+
defer sm.mutex.Unlock()
107+
sm.mcpServer = s
92108
}
93109

94110
func NewSessionManager(opts ...SessionManagerOption) *SessionManager {
@@ -215,6 +231,7 @@ func (sm *SessionManager) reapStaleSessions() {
215231
sm.mutex.Lock()
216232
var stale []*SessionState
217233
var staleIDs []string
234+
mcpSrv := sm.mcpServer
218235
for id, state := range sm.sessions {
219236
if now.Sub(state.lastActivity) > sm.sessionTTL {
220237
stale = append(stale, state)
@@ -232,8 +249,15 @@ func (sm *SessionManager) reapStaleSessions() {
232249
slog.Info("Reaping stale sessions", "count", len(stale), "session_ids", staleIDs)
233250
}
234251

235-
for _, state := range stale {
252+
ctx := context.Background()
253+
for i, state := range stale {
236254
cleanupSessionState(state)
255+
// Also unregister from MCPServer.sessions to prevent a memory leak.
256+
// Sessions may have been registered there via RegisterSession in the
257+
// OnBeforeListTools/OnBeforeCallTool hooks for horizontal scaling support.
258+
if mcpSrv != nil {
259+
mcpSrv.UnregisterSession(ctx, staleIDs[i])
260+
}
237261
}
238262
}
239263

session_horizontal_scaling_test.go

Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
package mcpgrafana
2+
3+
import (
4+
"context"
5+
"sync"
6+
"testing"
7+
"time"
8+
9+
"github.com/mark3labs/mcp-go/mcp"
10+
"github.com/mark3labs/mcp-go/server"
11+
"github.com/stretchr/testify/assert"
12+
"github.com/stretchr/testify/require"
13+
)
14+
15+
// mockSessionWithTools implements both server.ClientSession and server.SessionWithTools
16+
// for testing the horizontal scaling fix where AddSessionTools needs to find
17+
// the session in MCPServer.sessions.
18+
type mockSessionWithTools struct {
19+
id string
20+
notifChannel chan mcp.JSONRPCNotification
21+
isInitialized bool
22+
tools map[string]server.ServerTool
23+
mu sync.RWMutex
24+
}
25+
26+
func newMockSessionWithTools(id string) *mockSessionWithTools {
27+
return &mockSessionWithTools{
28+
id: id,
29+
notifChannel: make(chan mcp.JSONRPCNotification, 10),
30+
tools: make(map[string]server.ServerTool),
31+
}
32+
}
33+
34+
func (m *mockSessionWithTools) SessionID() string {
35+
return m.id
36+
}
37+
38+
func (m *mockSessionWithTools) NotificationChannel() chan<- mcp.JSONRPCNotification {
39+
return m.notifChannel
40+
}
41+
42+
func (m *mockSessionWithTools) Initialize() {
43+
m.isInitialized = true
44+
}
45+
46+
func (m *mockSessionWithTools) Initialized() bool {
47+
return m.isInitialized
48+
}
49+
50+
func (m *mockSessionWithTools) GetSessionTools() map[string]server.ServerTool {
51+
m.mu.RLock()
52+
defer m.mu.RUnlock()
53+
cp := make(map[string]server.ServerTool, len(m.tools))
54+
for k, v := range m.tools {
55+
cp[k] = v
56+
}
57+
return cp
58+
}
59+
60+
func (m *mockSessionWithTools) SetSessionTools(tools map[string]server.ServerTool) {
61+
m.mu.Lock()
62+
defer m.mu.Unlock()
63+
m.tools = tools
64+
}
65+
66+
// TestRegisterSessionFixesAddSessionTools verifies the core fix for issue #749:
67+
// when an ephemeral session is registered in MCPServer.sessions via RegisterSession,
68+
// AddSessionTools succeeds instead of returning ErrSessionNotFound.
69+
func TestRegisterSessionFixesAddSessionTools(t *testing.T) {
70+
t.Run("AddSessionTools fails without RegisterSession", func(t *testing.T) {
71+
s := server.NewMCPServer("test", "1.0.0")
72+
session := newMockSessionWithTools("unregistered-session")
73+
74+
err := s.AddSessionTools(session.SessionID(), server.ServerTool{
75+
Tool: mcp.NewTool("test-tool"),
76+
Handler: func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return nil, nil },
77+
})
78+
assert.Error(t, err, "AddSessionTools should fail for unregistered session")
79+
})
80+
81+
t.Run("AddSessionTools succeeds after RegisterSession", func(t *testing.T) {
82+
s := server.NewMCPServer("test", "1.0.0")
83+
session := newMockSessionWithTools("cross-pod-session")
84+
85+
// Simulate what the fix does: register the ephemeral session
86+
err := s.RegisterSession(context.Background(), session)
87+
require.NoError(t, err)
88+
89+
// Now AddSessionTools should succeed
90+
err = s.AddSessionTools(session.SessionID(), server.ServerTool{
91+
Tool: mcp.NewTool("tempo_traceql-search"),
92+
Handler: func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return nil, nil },
93+
})
94+
assert.NoError(t, err, "AddSessionTools should succeed after RegisterSession")
95+
96+
// Verify the tool was actually registered
97+
tools := session.GetSessionTools()
98+
assert.Contains(t, tools, "tempo_traceql-search")
99+
})
100+
101+
t.Run("RegisterSession is idempotent for already-registered sessions", func(t *testing.T) {
102+
s := server.NewMCPServer("test", "1.0.0")
103+
session := newMockSessionWithTools("existing-session")
104+
105+
// First registration
106+
err := s.RegisterSession(context.Background(), session)
107+
require.NoError(t, err)
108+
109+
// Second registration should return ErrSessionExists but not panic
110+
err = s.RegisterSession(context.Background(), session)
111+
assert.Error(t, err, "Second RegisterSession should return error")
112+
113+
// AddSessionTools should still work
114+
err = s.AddSessionTools(session.SessionID(), server.ServerTool{
115+
Tool: mcp.NewTool("test-tool"),
116+
Handler: func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return nil, nil },
117+
})
118+
assert.NoError(t, err)
119+
})
120+
}
121+
122+
// TestReaperUnregistersFromMCPServer verifies that the session reaper also
123+
// cleans up sessions from MCPServer.sessions when SetMCPServer has been called.
124+
func TestReaperUnregistersFromMCPServer(t *testing.T) {
125+
t.Run("reaper cleans up MCPServer sessions", func(t *testing.T) {
126+
s := server.NewMCPServer("test", "1.0.0")
127+
128+
sm := NewSessionManager(
129+
WithSessionTTL(50 * time.Millisecond),
130+
)
131+
sm.SetMCPServer(s)
132+
defer sm.Close()
133+
134+
session := newMockSessionWithTools("reap-me")
135+
136+
// Register in both the application SessionManager and MCPServer
137+
sm.CreateSession(context.Background(), session)
138+
err := s.RegisterSession(context.Background(), session)
139+
require.NoError(t, err)
140+
141+
// Add a tool to prove the session is registered
142+
err = s.AddSessionTools(session.SessionID(), server.ServerTool{
143+
Tool: mcp.NewTool("test-tool"),
144+
Handler: func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return nil, nil },
145+
})
146+
require.NoError(t, err)
147+
148+
// Wait for the session to become stale and be reaped
149+
time.Sleep(200 * time.Millisecond)
150+
151+
// Session should be removed from both managers
152+
_, exists := sm.GetSession("reap-me")
153+
assert.False(t, exists, "Session should be removed from SessionManager")
154+
155+
// Verify the session is gone from MCPServer too by trying to add tools
156+
err = s.AddSessionTools("reap-me", server.ServerTool{
157+
Tool: mcp.NewTool("another-tool"),
158+
Handler: func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return nil, nil },
159+
})
160+
assert.Error(t, err, "Session should be unregistered from MCPServer after reaping")
161+
})
162+
163+
t.Run("reaper works without MCPServer reference", func(t *testing.T) {
164+
sm := NewSessionManager(
165+
WithSessionTTL(50 * time.Millisecond),
166+
)
167+
// Deliberately NOT calling sm.SetMCPServer()
168+
defer sm.Close()
169+
170+
session := &mockClientSession{id: "no-mcp-server"}
171+
sm.CreateSession(context.Background(), session)
172+
173+
// Wait for reaping
174+
time.Sleep(200 * time.Millisecond)
175+
176+
_, exists := sm.GetSession("no-mcp-server")
177+
assert.False(t, exists, "Session should still be reaped without MCPServer reference")
178+
})
179+
}
180+
181+
// TestSetMCPServer verifies the SetMCPServer method.
182+
func TestSetMCPServer(t *testing.T) {
183+
t.Run("SetMCPServer sets the reference", func(t *testing.T) {
184+
sm := NewSessionManager()
185+
defer sm.Close()
186+
187+
s := server.NewMCPServer("test", "1.0.0")
188+
sm.SetMCPServer(s)
189+
190+
sm.mutex.RLock()
191+
assert.Equal(t, s, sm.mcpServer)
192+
sm.mutex.RUnlock()
193+
})
194+
195+
t.Run("SetMCPServer can be called with nil", func(t *testing.T) {
196+
sm := NewSessionManager()
197+
defer sm.Close()
198+
199+
s := server.NewMCPServer("test", "1.0.0")
200+
sm.SetMCPServer(s)
201+
sm.SetMCPServer(nil)
202+
203+
sm.mutex.RLock()
204+
assert.Nil(t, sm.mcpServer)
205+
sm.mutex.RUnlock()
206+
})
207+
}

0 commit comments

Comments
 (0)