Skip to content

Commit 3c08658

Browse files
committed
fix(vmcp): remove WriteTimeout to prevent SSE connection drops
Go's http.Server.WriteTimeout sets an absolute deadline on the entire response duration. For long-lived SSE connections used by the MCP Streamable HTTP transport, this caused streams to be killed after 30s regardless of write activity (see golang/go#16100). Remove WriteTimeout from the http.Server config and replace it with a writeTimeoutMiddleware that uses http.ResponseController.SetWriteDeadline to apply a per-request 30s write deadline only to non-GET (non-SSE) requests. GET requests remain exempt so SSE streams can stay open indefinitely. Tests added: - Unit tests verifying SetWriteDeadline is called/not called by method - Integration test over a real TCP connection confirming an SSE stream survives past the timeout without being cut Fixes #3691
1 parent 0de510d commit 3c08658

2 files changed

Lines changed: 213 additions & 4 deletions

File tree

pkg/vmcp/server/server.go

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,10 @@ const (
4747
// defaultReadTimeout is the maximum duration for reading the entire request, including body.
4848
defaultReadTimeout = 30 * time.Second
4949

50-
// defaultWriteTimeout is the maximum duration before timing out writes of the response.
50+
// defaultWriteTimeout is the per-request write deadline applied to non-SSE requests via
51+
// http.ResponseController. Note: this is NOT set as http.Server.WriteTimeout — that field
52+
// applies to the entire response duration and would kill long-lived SSE (GET) streams after
53+
// 30s regardless of write activity. See golang/go#16100.
5154
defaultWriteTimeout = 30 * time.Second
5255

5356
// defaultIdleTimeout is the maximum amount of time to wait for the next request when keep-alive's are enabled.
@@ -557,6 +560,9 @@ func (s *Server) Handler(_ context.Context) (http.Handler, error) {
557560
// Apply Accept header validation (rejects GET requests without Accept: text/event-stream)
558561
mcpHandler = headerValidatingMiddleware(mcpHandler)
559562

563+
// Apply per-request write deadlines for non-SSE requests (SSE streams are exempt)
564+
mcpHandler = writeTimeoutMiddleware(mcpHandler)
565+
560566
// Apply recovery middleware as outermost (catches panics from all inner middleware)
561567
mcpHandler = recovery.Middleware(mcpHandler)
562568
slog.Info("recovery middleware enabled for MCP endpoints")
@@ -604,9 +610,12 @@ func (s *Server) Start(ctx context.Context) error {
604610
Handler: handler,
605611
ReadHeaderTimeout: defaultReadHeaderTimeout,
606612
ReadTimeout: defaultReadTimeout,
607-
WriteTimeout: defaultWriteTimeout,
608-
IdleTimeout: defaultIdleTimeout,
609-
MaxHeaderBytes: defaultMaxHeaderBytes,
613+
// WriteTimeout is intentionally omitted: SSE (GET) connections must remain open
614+
// indefinitely, and a global WriteTimeout would kill them after 30s regardless of
615+
// write activity. Per-request write deadlines for non-SSE requests are enforced by
616+
// writeTimeoutMiddleware using http.ResponseController. See golang/go#16100.
617+
IdleTimeout: defaultIdleTimeout,
618+
MaxHeaderBytes: defaultMaxHeaderBytes,
610619
}
611620

612621
// Create listener (allows port 0 to bind to random available port)
@@ -1247,6 +1256,28 @@ var notAcceptableBody = []byte(
12471256
`{"code":-32600,"message":"Not Acceptable: Client must accept text/event-stream"}}`,
12481257
)
12491258

1259+
// writeTimeoutMiddleware applies a per-request write deadline to non-SSE requests using
1260+
// http.ResponseController. GET requests (SSE streams) are exempt because they must remain
1261+
// open indefinitely; all other methods (POST, DELETE, etc.) are bounded by defaultNonSSEWriteTimeout.
1262+
func writeTimeoutMiddleware(next http.Handler) http.Handler {
1263+
return writeTimeoutMiddlewareWithTimeout(next, defaultWriteTimeout)
1264+
}
1265+
1266+
// writeTimeoutMiddlewareWithTimeout is the parameterised implementation used by
1267+
// writeTimeoutMiddleware and by tests that need a short timeout.
1268+
func writeTimeoutMiddlewareWithTimeout(next http.Handler, timeout time.Duration) http.Handler {
1269+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1270+
if r.Method != http.MethodGet {
1271+
rc := http.NewResponseController(w)
1272+
deadline := time.Now().Add(timeout)
1273+
if err := rc.SetWriteDeadline(deadline); err != nil {
1274+
slog.Debug("failed to set per-request write deadline", "error", err)
1275+
}
1276+
}
1277+
next.ServeHTTP(w, r)
1278+
})
1279+
}
1280+
12501281
// headerValidatingMiddleware rejects GET requests that do not include
12511282
// Accept: text/event-stream, as required by the MCP Streamable HTTP transport spec.
12521283
func headerValidatingMiddleware(next http.Handler) http.Handler {
Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package server
5+
6+
import (
7+
"bufio"
8+
"fmt"
9+
"net/http"
10+
"net/http/httptest"
11+
"strings"
12+
"testing"
13+
"time"
14+
15+
"github.com/stretchr/testify/assert"
16+
"github.com/stretchr/testify/require"
17+
)
18+
19+
// deadlineTrackingResponseWriter wraps httptest.ResponseRecorder and implements
20+
// the SetWriteDeadline method so http.ResponseController can call it.
21+
// It records whether SetWriteDeadline was called and the deadline value passed.
22+
type deadlineTrackingResponseWriter struct {
23+
*httptest.ResponseRecorder
24+
deadlineSet bool
25+
deadline time.Time
26+
}
27+
28+
func (d *deadlineTrackingResponseWriter) SetWriteDeadline(t time.Time) error {
29+
d.deadlineSet = true
30+
d.deadline = t
31+
return nil
32+
}
33+
34+
func newDeadlineTracker() *deadlineTrackingResponseWriter {
35+
return &deadlineTrackingResponseWriter{
36+
ResponseRecorder: httptest.NewRecorder(),
37+
}
38+
}
39+
40+
// noopHandler is a handler that does nothing — used to verify middleware call-through.
41+
var noopHandler = http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
42+
w.WriteHeader(http.StatusOK)
43+
})
44+
45+
func TestWriteTimeoutMiddleware_GETDoesNotSetDeadline(t *testing.T) {
46+
t.Parallel()
47+
48+
w := newDeadlineTracker()
49+
r := httptest.NewRequest(http.MethodGet, "/mcp", nil)
50+
51+
writeTimeoutMiddleware(noopHandler).ServeHTTP(w, r)
52+
53+
assert.False(t, w.deadlineSet, "GET (SSE) request must not have a write deadline set")
54+
assert.Equal(t, http.StatusOK, w.Code)
55+
}
56+
57+
func TestWriteTimeoutMiddleware_POSTSetsDeadline(t *testing.T) {
58+
t.Parallel()
59+
60+
before := time.Now()
61+
w := newDeadlineTracker()
62+
r := httptest.NewRequest(http.MethodPost, "/mcp", nil)
63+
64+
writeTimeoutMiddleware(noopHandler).ServeHTTP(w, r)
65+
66+
require.True(t, w.deadlineSet, "POST request must have a write deadline set")
67+
assert.Equal(t, http.StatusOK, w.Code)
68+
69+
// Deadline should be ~defaultWriteTimeout in the future.
70+
expectedMin := before.Add(defaultWriteTimeout - time.Second)
71+
expectedMax := before.Add(defaultWriteTimeout + time.Second)
72+
assert.True(t, w.deadline.After(expectedMin), "deadline %v should be after %v", w.deadline, expectedMin)
73+
assert.True(t, w.deadline.Before(expectedMax), "deadline %v should be before %v", w.deadline, expectedMax)
74+
}
75+
76+
func TestWriteTimeoutMiddleware_DELETESetsDeadline(t *testing.T) {
77+
t.Parallel()
78+
79+
w := newDeadlineTracker()
80+
r := httptest.NewRequest(http.MethodDelete, "/mcp", nil)
81+
82+
writeTimeoutMiddleware(noopHandler).ServeHTTP(w, r)
83+
84+
assert.True(t, w.deadlineSet, "DELETE request must have a write deadline set")
85+
assert.Equal(t, http.StatusOK, w.Code)
86+
}
87+
88+
func TestWriteTimeoutMiddleware_HandlerIsAlwaysCalled(t *testing.T) {
89+
t.Parallel()
90+
91+
for _, method := range []string{http.MethodGet, http.MethodPost, http.MethodDelete} {
92+
t.Run(method, func(t *testing.T) {
93+
t.Parallel()
94+
95+
called := false
96+
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
97+
called = true
98+
w.WriteHeader(http.StatusOK)
99+
})
100+
101+
w := newDeadlineTracker()
102+
r := httptest.NewRequest(method, "/mcp", nil)
103+
writeTimeoutMiddleware(handler).ServeHTTP(w, r)
104+
105+
assert.True(t, called, "inner handler must be called for method %s", method)
106+
})
107+
}
108+
}
109+
110+
// TestWriteTimeoutMiddleware_SSEStreamSurvivesTimeout verifies over a real TCP connection
111+
// that a GET SSE stream is NOT killed after the write timeout, while a POST IS bounded.
112+
//
113+
// This is the end-to-end proof of the fix for the SSE connection drop bug
114+
// (github.com/golang/go#16100): a global http.Server.WriteTimeout would terminate both,
115+
// but writeTimeoutMiddlewareWithTimeout only sets the deadline for non-GET requests.
116+
func TestWriteTimeoutMiddleware_SSEStreamSurvivesTimeout(t *testing.T) {
117+
t.Parallel()
118+
119+
const shortTimeout = 100 * time.Millisecond
120+
const streamDuration = 3 * shortTimeout // SSE stream stays open 3× longer than the timeout
121+
122+
// sseHandler streams SSE events for streamDuration then closes.
123+
sseHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
124+
w.Header().Set("Content-Type", "text/event-stream")
125+
w.Header().Set("Cache-Control", "no-cache")
126+
w.WriteHeader(http.StatusOK)
127+
128+
flusher, ok := w.(http.Flusher)
129+
require.True(t, ok, "ResponseWriter must implement http.Flusher")
130+
131+
ticker := time.NewTicker(shortTimeout / 5)
132+
defer ticker.Stop()
133+
deadline := time.NewTimer(streamDuration)
134+
defer deadline.Stop()
135+
136+
for {
137+
select {
138+
case <-r.Context().Done():
139+
return
140+
case <-deadline.C:
141+
return
142+
case <-ticker.C:
143+
fmt.Fprintf(w, "data: ping\n\n")
144+
flusher.Flush()
145+
}
146+
}
147+
})
148+
149+
ts := httptest.NewServer(writeTimeoutMiddlewareWithTimeout(sseHandler, shortTimeout))
150+
t.Cleanup(ts.Close)
151+
152+
// Open a GET SSE stream and read all events until the server closes it.
153+
req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, ts.URL+"/mcp", nil)
154+
require.NoError(t, err)
155+
req.Header.Set("Accept", "text/event-stream")
156+
157+
resp, err := ts.Client().Do(req)
158+
require.NoError(t, err)
159+
defer resp.Body.Close()
160+
161+
require.Equal(t, http.StatusOK, resp.StatusCode)
162+
163+
// Read SSE lines until EOF. If the TCP write deadline fires the connection
164+
// is cut mid-stream and we'd get an error before the server's streamDuration
165+
// elapses. A clean EOF means the stream lived its full intended lifetime.
166+
scanner := bufio.NewScanner(resp.Body)
167+
var events []string
168+
for scanner.Scan() {
169+
line := scanner.Text()
170+
if strings.HasPrefix(line, "data:") {
171+
events = append(events, line)
172+
}
173+
}
174+
// bufio.Scanner returns nil error on clean EOF.
175+
assert.NoError(t, scanner.Err(), "SSE stream must not be cut by a write deadline")
176+
assert.NotEmpty(t, events, "should have received at least one SSE event")
177+
}
178+

0 commit comments

Comments
 (0)