Skip to content

Commit a7dd559

Browse files
committed
fix(vmcp): remove WriteTimeout to prevent SSE connection drops
Go's http.Server.WriteTimeout sets an absolute deadline on the entire response duration. For long-lived SSE connections used by the MCP Streamable HTTP transport, this caused streams to be killed after 30s regardless of write activity (see golang/go#16100). Remove WriteTimeout from the http.Server config and replace it with a writeTimeoutMiddleware that uses http.ResponseController.SetWriteDeadline to apply a per-request 30s write deadline only to non-GET (non-SSE) requests. GET requests remain exempt so SSE streams can stay open indefinitely. Tests added: - Unit tests verifying SetWriteDeadline is called/not called by method - Integration test over a real TCP connection confirming an SSE stream survives past the timeout without being cut Fixes #3691
1 parent 0de510d commit a7dd559

5 files changed

Lines changed: 467 additions & 6 deletions

File tree

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package middleware
5+
6+
import (
7+
"log/slog"
8+
"net/http"
9+
"strings"
10+
"time"
11+
)
12+
13+
// WriteTimeout returns an HTTP middleware that clears the write deadline for
14+
// qualifying SSE connections so the server-level WriteTimeout does not kill
15+
// long-lived streams (golang/go#16100).
16+
//
17+
// A request qualifies when all three conditions hold:
18+
// 1. HTTP method is GET
19+
// 2. Accept header contains "text/event-stream"
20+
// 3. URL path matches endpointPath exactly
21+
//
22+
// Qualifying requests have their write deadline set to zero (no deadline),
23+
// overriding http.Server.WriteTimeout for that connection only.
24+
//
25+
// Non-qualifying requests are left completely untouched: the server-level
26+
// WriteTimeout remains in effect as-is. Resetting the deadline here would
27+
// extend it (because time.Now() is always later than connection-accept time),
28+
// weakening the protection for health, metrics, and JSON-RPC POST routes.
29+
//
30+
// defaultTimeout is retained as a parameter so the call site remains
31+
// self-documenting about what timeout governs non-SSE requests, but it is
32+
// not applied inside this function.
33+
func WriteTimeout(endpointPath string, _ time.Duration) func(http.Handler) http.Handler {
34+
return func(next http.Handler) http.Handler {
35+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
36+
if isSSERequest(r, endpointPath) {
37+
rc := http.NewResponseController(w)
38+
// time.Time{} (zero value) means no deadline, overriding the server WriteTimeout.
39+
if err := rc.SetWriteDeadline(time.Time{}); err != nil {
40+
// Warn: if this fails, the server-level WriteTimeout will kill the SSE
41+
// stream — a user-visible connection drop. This should not happen with
42+
// standard net.TCPConn, but a wrapping ResponseWriter that doesn't
43+
// implement SetWriteDeadline would cause it.
44+
slog.Warn("failed to clear write deadline for SSE connection; stream may be killed by server WriteTimeout",
45+
"error", err,
46+
"method", r.Method,
47+
"path", r.URL.Path,
48+
"remote", r.RemoteAddr,
49+
)
50+
}
51+
}
52+
// Non-qualifying requests: leave the existing write deadline untouched.
53+
// http.Server.WriteTimeout already set it at connection-accept time.
54+
next.ServeHTTP(w, r)
55+
})
56+
}
57+
}
58+
59+
// isSSERequest reports whether r should be treated as an SSE connection
60+
// requiring an unlimited write deadline. All three conditions must hold:
61+
// - Method is GET (SSE is always a long-lived GET stream)
62+
// - Accept header contains "text/event-stream" (client has declared SSE intent)
63+
// - URL path matches the MCP endpoint path exactly (excludes health/metrics routes)
64+
func isSSERequest(r *http.Request, endpointPath string) bool {
65+
return r.Method == http.MethodGet &&
66+
strings.Contains(r.Header.Get("Accept"), "text/event-stream") &&
67+
r.URL.Path == endpointPath
68+
}
Lines changed: 259 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,259 @@
1+
// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package middleware_test
5+
6+
import (
7+
"bufio"
8+
"fmt"
9+
"net/http"
10+
"net/http/httptest"
11+
"strings"
12+
"testing"
13+
"time"
14+
15+
"github.com/stretchr/testify/assert"
16+
"github.com/stretchr/testify/require"
17+
18+
"github.com/stacklok/toolhive/pkg/transport/middleware"
19+
)
20+
21+
const (
22+
testEndpointPath = "/mcp"
23+
testDefaultTimeout = 30 * time.Second
24+
)
25+
26+
// deadlineTrackingResponseWriter wraps httptest.ResponseRecorder and implements
27+
// the SetWriteDeadline method so http.ResponseController can call it.
28+
// It records whether SetWriteDeadline was called and the deadline value passed.
29+
type deadlineTrackingResponseWriter struct {
30+
*httptest.ResponseRecorder
31+
deadlineSet bool
32+
deadline time.Time
33+
}
34+
35+
func (d *deadlineTrackingResponseWriter) SetWriteDeadline(t time.Time) error {
36+
d.deadlineSet = true
37+
d.deadline = t
38+
return nil
39+
}
40+
41+
func newDeadlineTracker() *deadlineTrackingResponseWriter {
42+
return &deadlineTrackingResponseWriter{
43+
ResponseRecorder: httptest.NewRecorder(),
44+
}
45+
}
46+
47+
var noopHandler = http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
48+
w.WriteHeader(http.StatusOK)
49+
})
50+
51+
func mw(next http.Handler) http.Handler {
52+
return middleware.WriteTimeout(testEndpointPath, testDefaultTimeout)(next)
53+
}
54+
55+
// TestWriteTimeout_SSERequestClearsDeadline verifies that a qualifying SSE request
56+
// (GET + Accept: text/event-stream + correct path) has its write deadline cleared
57+
// (set to zero), overriding the server-level WriteTimeout.
58+
func TestWriteTimeout_SSERequestClearsDeadline(t *testing.T) {
59+
t.Parallel()
60+
61+
w := newDeadlineTracker()
62+
r := httptest.NewRequest(http.MethodGet, testEndpointPath, nil)
63+
r.Header.Set("Accept", "text/event-stream")
64+
65+
mw(noopHandler).ServeHTTP(w, r)
66+
67+
require.True(t, w.deadlineSet, "qualifying SSE request must call SetWriteDeadline")
68+
assert.True(t, w.deadline.IsZero(), "deadline must be zero (no deadline) to override server WriteTimeout")
69+
assert.Equal(t, http.StatusOK, w.Code)
70+
}
71+
72+
// TestWriteTimeout_GETWithoutAcceptHeaderLeavesDeadlineUntouched verifies that a GET
73+
// request lacking Accept: text/event-stream is not treated as SSE and the middleware
74+
// does not touch its write deadline, leaving http.Server.WriteTimeout in effect.
75+
func TestWriteTimeout_GETWithoutAcceptHeaderLeavesDeadlineUntouched(t *testing.T) {
76+
t.Parallel()
77+
78+
w := newDeadlineTracker()
79+
r := httptest.NewRequest(http.MethodGet, testEndpointPath, nil)
80+
81+
mw(noopHandler).ServeHTTP(w, r)
82+
83+
assert.False(t, w.deadlineSet, "non-SSE GET must not have its deadline touched; server WriteTimeout remains in effect")
84+
assert.Equal(t, http.StatusOK, w.Code)
85+
}
86+
87+
// TestWriteTimeout_GETOnWrongPathLeavesDeadlineUntouched verifies that a GET request
88+
// with the SSE Accept header but targeting a non-MCP path (e.g. /health) is not treated
89+
// as SSE and the middleware does not touch its write deadline.
90+
func TestWriteTimeout_GETOnWrongPathLeavesDeadlineUntouched(t *testing.T) {
91+
t.Parallel()
92+
93+
w := newDeadlineTracker()
94+
r := httptest.NewRequest(http.MethodGet, "/health", nil)
95+
r.Header.Set("Accept", "text/event-stream")
96+
97+
mw(noopHandler).ServeHTTP(w, r)
98+
99+
assert.False(t, w.deadlineSet, "GET on non-MCP path must not have its deadline touched; server WriteTimeout remains in effect")
100+
assert.Equal(t, http.StatusOK, w.Code)
101+
}
102+
103+
// TestWriteTimeout_POSTLeavesDeadlineUntouched verifies that POST requests are not
104+
// touched by the middleware — their deadline comes from http.Server.WriteTimeout.
105+
func TestWriteTimeout_POSTLeavesDeadlineUntouched(t *testing.T) {
106+
t.Parallel()
107+
108+
w := newDeadlineTracker()
109+
r := httptest.NewRequest(http.MethodPost, testEndpointPath, nil)
110+
111+
mw(noopHandler).ServeHTTP(w, r)
112+
113+
assert.False(t, w.deadlineSet, "POST deadline is managed by http.Server.WriteTimeout, not the middleware")
114+
assert.Equal(t, http.StatusOK, w.Code)
115+
}
116+
117+
// TestWriteTimeout_DELETELeavesDeadlineUntouched verifies DELETE is also left alone.
118+
func TestWriteTimeout_DELETELeavesDeadlineUntouched(t *testing.T) {
119+
t.Parallel()
120+
121+
w := newDeadlineTracker()
122+
r := httptest.NewRequest(http.MethodDelete, testEndpointPath, nil)
123+
124+
mw(noopHandler).ServeHTTP(w, r)
125+
126+
assert.False(t, w.deadlineSet, "DELETE deadline is managed by http.Server.WriteTimeout, not the middleware")
127+
assert.Equal(t, http.StatusOK, w.Code)
128+
}
129+
130+
// TestWriteTimeout_HandlerIsAlwaysCalled verifies the inner handler is invoked for
131+
// every HTTP method, regardless of deadline management.
132+
func TestWriteTimeout_HandlerIsAlwaysCalled(t *testing.T) {
133+
t.Parallel()
134+
135+
cases := []struct {
136+
method string
137+
path string
138+
accept string
139+
}{
140+
{http.MethodGet, testEndpointPath, "text/event-stream"}, // qualifying SSE
141+
{http.MethodGet, testEndpointPath, ""}, // GET, no Accept
142+
{http.MethodGet, "/health", "text/event-stream"}, // GET, wrong path
143+
{http.MethodPost, testEndpointPath, ""},
144+
{http.MethodDelete, testEndpointPath, ""},
145+
}
146+
147+
for _, tc := range cases {
148+
t.Run(tc.method+tc.path+tc.accept, func(t *testing.T) {
149+
t.Parallel()
150+
151+
called := false
152+
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
153+
called = true
154+
w.WriteHeader(http.StatusOK)
155+
})
156+
157+
w := newDeadlineTracker()
158+
r := httptest.NewRequest(tc.method, tc.path, nil)
159+
if tc.accept != "" {
160+
r.Header.Set("Accept", tc.accept)
161+
}
162+
mw(handler).ServeHTTP(w, r)
163+
164+
assert.True(t, called, "inner handler must be called for %s %s", tc.method, tc.path)
165+
})
166+
}
167+
}
168+
169+
// TestWriteTimeout_SSEStreamSurvivesTimeout verifies over a real TCP connection (with
170+
// http.Server.WriteTimeout set) that a qualifying SSE stream is NOT killed after the
171+
// write timeout elapses.
172+
//
173+
// This is the end-to-end proof of the fix for the SSE connection drop bug
174+
// (golang/go#16100): the middleware clears the per-connection write deadline for
175+
// qualifying SSE requests via http.ResponseController.SetWriteDeadline(time.Time{}),
176+
// keeping SSE streams alive past the server-level WriteTimeout.
177+
func TestWriteTimeout_SSEStreamSurvivesTimeout(t *testing.T) {
178+
t.Parallel()
179+
180+
const shortTimeout = 100 * time.Millisecond
181+
const streamDuration = 3 * shortTimeout
182+
183+
sseHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
184+
w.Header().Set("Content-Type", "text/event-stream")
185+
w.Header().Set("Cache-Control", "no-cache")
186+
w.WriteHeader(http.StatusOK)
187+
188+
flusher, ok := w.(http.Flusher)
189+
require.True(t, ok, "ResponseWriter must implement http.Flusher")
190+
191+
ticker := time.NewTicker(shortTimeout / 5)
192+
defer ticker.Stop()
193+
deadline := time.NewTimer(streamDuration)
194+
defer deadline.Stop()
195+
196+
for {
197+
select {
198+
case <-r.Context().Done():
199+
return
200+
case <-deadline.C:
201+
return
202+
case <-ticker.C:
203+
fmt.Fprintf(w, "data: ping\n\n")
204+
flusher.Flush()
205+
}
206+
}
207+
})
208+
209+
ts := httptest.NewUnstartedServer(middleware.WriteTimeout(testEndpointPath, shortTimeout)(sseHandler))
210+
ts.Config.WriteTimeout = shortTimeout
211+
ts.Start()
212+
t.Cleanup(ts.Close)
213+
214+
req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, ts.URL+testEndpointPath, nil)
215+
require.NoError(t, err)
216+
req.Header.Set("Accept", "text/event-stream")
217+
218+
start := time.Now()
219+
220+
resp, err := ts.Client().Do(req)
221+
require.NoError(t, err)
222+
defer resp.Body.Close()
223+
224+
require.Equal(t, http.StatusOK, resp.StatusCode)
225+
226+
// tickInterval is shortTimeout/5; over the full streamDuration we expect
227+
// ~streamDuration/tickInterval = 15 events. If WriteTimeout fires early
228+
// (after shortTimeout = 100 ms) at most shortTimeout/tickInterval = 5
229+
// events could arrive before the connection is killed.
230+
const tickInterval = shortTimeout / 5
231+
minEvents := int(shortTimeout/tickInterval) + 1 // must exceed what's possible before WriteTimeout
232+
233+
scanner := bufio.NewScanner(resp.Body)
234+
var events []string
235+
for scanner.Scan() {
236+
if strings.HasPrefix(scanner.Text(), "data:") {
237+
events = append(events, scanner.Text())
238+
}
239+
}
240+
elapsed := time.Since(start)
241+
242+
// A clean EOF with scanner.Err() == nil is necessary but not sufficient:
243+
// if WriteTimeout kills the stream at shortTimeout the client may still
244+
// observe a clean close with a handful of events already received.
245+
assert.NoError(t, scanner.Err(), "SSE stream must close cleanly, not with a connection error")
246+
247+
// Elapsed time proves the stream ran for (at least) its intended lifetime.
248+
// If WriteTimeout had fired the handler would have been interrupted at ~100 ms,
249+
// far shorter than streamDuration (300 ms).
250+
assert.GreaterOrEqual(t, elapsed, streamDuration-50*time.Millisecond,
251+
"SSE stream must have lasted at least streamDuration (%v); elapsed %v suggests WriteTimeout fired early",
252+
streamDuration, elapsed)
253+
254+
// Event count provides a second, independent signal: the stream must have
255+
// delivered more events than could possibly arrive within shortTimeout.
256+
assert.GreaterOrEqual(t, len(events), minEvents,
257+
"expected >= %d events (more than possible before WriteTimeout); got %d",
258+
minEvents, len(events))
259+
}

pkg/vmcp/server/server.go

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ import (
2727
mcpparser "github.com/stacklok/toolhive/pkg/mcp"
2828
"github.com/stacklok/toolhive/pkg/recovery"
2929
"github.com/stacklok/toolhive/pkg/telemetry"
30+
transportmiddleware "github.com/stacklok/toolhive/pkg/transport/middleware"
3031
transportsession "github.com/stacklok/toolhive/pkg/transport/session"
3132
"github.com/stacklok/toolhive/pkg/vmcp"
3233
"github.com/stacklok/toolhive/pkg/vmcp/composer"
@@ -47,7 +48,11 @@ const (
4748
// defaultReadTimeout is the maximum duration for reading the entire request, including body.
4849
defaultReadTimeout = 30 * time.Second
4950

50-
// defaultWriteTimeout is the maximum duration before timing out writes of the response.
51+
// defaultWriteTimeout is the server-level write deadline set on http.Server.WriteTimeout.
52+
// It protects all routes (health, metrics, well-known, etc.) from slow-write clients.
53+
// For qualifying SSE (GET) connections, transportmiddleware.WriteTimeout overrides this
54+
// per-request by clearing the deadline via http.ResponseController.SetWriteDeadline(time.Time{}).
55+
// See golang/go#16100.
5156
defaultWriteTimeout = 30 * time.Second
5257

5358
// defaultIdleTimeout is the maximum amount of time to wait for the next request when keep-alive's are enabled.
@@ -557,6 +562,13 @@ func (s *Server) Handler(_ context.Context) (http.Handler, error) {
557562
// Apply Accept header validation (rejects GET requests without Accept: text/event-stream)
558563
mcpHandler = headerValidatingMiddleware(mcpHandler)
559564

565+
// Clear the write deadline for qualifying SSE connections (GET +
566+
// Accept: text/event-stream + MCP endpoint path) so the server-level
567+
// WriteTimeout does not kill long-lived SSE streams (see golang/go#16100).
568+
// Non-qualifying requests are left untouched; http.Server.WriteTimeout
569+
// (defaultWriteTimeout) remains in effect for them.
570+
mcpHandler = transportmiddleware.WriteTimeout(s.config.EndpointPath, defaultWriteTimeout)(mcpHandler)
571+
560572
// Apply recovery middleware as outermost (catches panics from all inner middleware)
561573
mcpHandler = recovery.Middleware(mcpHandler)
562574
slog.Info("recovery middleware enabled for MCP endpoints")

pkg/vmcp/server/session_management_realbackend_integration_test.go

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,10 @@ import (
3535

3636
// startRealMCPBackend is defined in testutil_test.go as a shared test utility.
3737

38-
// newRealTestServer builds a vMCP server with session management and and a
39-
// real SessionFactory. The BackendRegistry mock returns the backend at backendURL
40-
// so that CreateSession() opens a real HTTP connection to the MCP server.
41-
func newRealTestServer(t *testing.T, backendURL string) *httptest.Server {
38+
// newRealTestHandler builds the full vMCP handler backed by the MCP server at
39+
// backendURL. It is the low-level helper used by newRealTestServer and any test
40+
// that needs control over the httptest.Server configuration (e.g. WriteTimeout).
41+
func newRealTestHandler(t *testing.T, backendURL string) http.Handler {
4242
t.Helper()
4343

4444
ctrl := gomock.NewController(t)
@@ -88,8 +88,15 @@ func newRealTestServer(t *testing.T, backendURL string) *httptest.Server {
8888

8989
handler, err := srv.Handler(context.Background())
9090
require.NoError(t, err)
91+
return handler
92+
}
9193

92-
ts := httptest.NewServer(handler)
94+
// newRealTestServer builds a vMCP server with session management and a real
95+
// SessionFactory. The BackendRegistry mock returns the backend at backendURL
96+
// so that CreateSession() opens a real HTTP connection to the MCP server.
97+
func newRealTestServer(t *testing.T, backendURL string) *httptest.Server {
98+
t.Helper()
99+
ts := httptest.NewServer(newRealTestHandler(t, backendURL))
93100
t.Cleanup(ts.Close)
94101
return ts
95102
}

0 commit comments

Comments
 (0)