Skip to content

Commit 5d8d3c4

Browse files
immediately reject streams for closed sessions (#235)
1 parent af56b7f commit 5d8d3c4

File tree

2 files changed

+183
-5
lines changed

2 files changed

+183
-5
lines changed

session_manager.go

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package webtransport
22

33
import (
44
"context"
5+
"slices"
56
"sync"
67
"time"
78

@@ -21,11 +22,14 @@ type sessionEntry struct {
2122
Session *Session
2223
}
2324

25+
const maxRecentlyClosedSessions = 16
26+
2427
type sessionManager struct {
2528
timeout time.Duration
2629

27-
mx sync.Mutex
28-
sessions map[sessionID]sessionEntry
30+
mx sync.Mutex
31+
sessions map[sessionID]sessionEntry
32+
recentlyClosedSessions []sessionID
2933
}
3034

3135
func newSessionManager(timeout time.Duration) *sessionManager {
@@ -45,6 +49,13 @@ func (m *sessionManager) AddStream(str *quic.Stream, id sessionID) {
4549

4650
entry, ok := m.sessions[id]
4751
if !ok {
52+
// Receiving a stream for an unknown session is expected to be rare,
53+
// so the performance impact of searching through the slice is negligible.
54+
if slices.Contains(m.recentlyClosedSessions, id) {
55+
str.CancelRead(WTBufferedStreamRejectedErrorCode)
56+
str.CancelWrite(WTBufferedStreamRejectedErrorCode)
57+
return
58+
}
4859
entry = sessionEntry{Unestablished: &unestablishedSession{}}
4960
m.sessions[id] = entry
5061
}
@@ -67,6 +78,12 @@ func (m *sessionManager) AddUniStream(str *quic.ReceiveStream, id sessionID) {
6778

6879
entry, ok := m.sessions[id]
6980
if !ok {
81+
// Receiving a stream for an unknown session is expected to be rare,
82+
// so the performance impact of searching through the slice is negligible.
83+
if slices.Contains(m.recentlyClosedSessions, id) {
84+
str.CancelRead(WTBufferedStreamRejectedErrorCode)
85+
return
86+
}
7087
entry = sessionEntry{Unestablished: &unestablishedSession{}}
7188
m.sessions[id] = entry
7289
}
@@ -134,12 +151,21 @@ func (m *sessionManager) AddSession(id sessionID, s *Session) {
134151
m.sessions[id] = sessionEntry{Session: s}
135152

136153
context.AfterFunc(s.Context(), func() {
137-
m.mx.Lock()
138-
defer m.mx.Unlock()
139-
delete(m.sessions, id)
154+
m.deleteSession(id)
140155
})
141156
}
142157

158+
func (m *sessionManager) deleteSession(id sessionID) {
159+
m.mx.Lock()
160+
defer m.mx.Unlock()
161+
162+
delete(m.sessions, id)
163+
m.recentlyClosedSessions = append(m.recentlyClosedSessions, id)
164+
if len(m.recentlyClosedSessions) > maxRecentlyClosedSessions {
165+
m.recentlyClosedSessions = m.recentlyClosedSessions[1:]
166+
}
167+
}
168+
143169
func (m *sessionManager) Close() {
144170
m.mx.Lock()
145171
defer m.mx.Unlock()

session_manager_test.go

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,14 @@ package webtransport
33
import (
44
"context"
55
"io"
6+
"maps"
7+
"math/rand/v2"
68
"net"
79
"net/http"
810
"net/http/httptest"
911
"runtime"
12+
"slices"
13+
"strconv"
1014
"strings"
1115
"testing"
1216
"time"
@@ -15,6 +19,7 @@ import (
1519

1620
"github.com/quic-go/quic-go"
1721
"github.com/quic-go/quic-go/http3"
22+
"github.com/quic-go/quic-go/quicvarint"
1823
"github.com/quic-go/quic-go/testutils/simnet"
1924

2025
"github.com/stretchr/testify/require"
@@ -257,3 +262,150 @@ func TestSessionManagerStreamReordering(t *testing.T) {
257262
require.Equal(t, []byte("amet"), data)
258263
})
259264
}
265+
266+
func TestSessionManagerSessionClose(t *testing.T) {
267+
// synctest works slightly differently on Go 1.24,
268+
// so we skip the test
269+
if strings.HasPrefix(runtime.Version(), "go1.24") {
270+
t.Skip("skipping on Go 1.24 due to synctest issues")
271+
}
272+
273+
synctest.Test(t, func(t *testing.T) {
274+
const rtt = 10 * time.Millisecond
275+
clientPacketConn, serverPacketConn, closeFn := newSimnetLink(t, rtt)
276+
defer closeFn(t)
277+
clientConn, serverConn := newConnPair(t, clientPacketConn, serverPacketConn)
278+
t.Cleanup(func() {
279+
clientConn.CloseWithError(0, "")
280+
serverConn.CloseWithError(0, "")
281+
})
282+
283+
tr := http3.Transport{}
284+
cc := tr.NewRawClientConn(clientConn)
285+
286+
server := http3.Server{
287+
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
288+
w.WriteHeader(http.StatusOK)
289+
w.(http.Flusher).Flush()
290+
<-r.Context().Done()
291+
}),
292+
}
293+
sc, err := server.NewRawServerConn(serverConn)
294+
require.NoError(t, err)
295+
296+
openSession := func() (*http3.RequestStream, error) {
297+
reqStr, err := cc.OpenRequestStream(context.Background())
298+
if err != nil {
299+
return nil, err
300+
}
301+
if err := reqStr.SendRequestHeader(httptest.NewRequest(http.MethodGet, "/", nil)); err != nil {
302+
return nil, err
303+
}
304+
serverReqStr, err := serverConn.AcceptStream(context.Background())
305+
if err != nil {
306+
return nil, err
307+
}
308+
go sc.HandleRequestStream(serverReqStr)
309+
if _, err := reqStr.ReadResponse(); err != nil {
310+
return nil, err
311+
}
312+
return reqStr, nil
313+
}
314+
315+
sessMgr := newSessionManager(time.Hour)
316+
317+
reqStrs := make(map[sessionID]*http3.RequestStream)
318+
for range maxRecentlyClosedSessions {
319+
reqStr, err := openSession()
320+
require.NoError(t, err)
321+
sessID := sessionID(rand.Int64N(quicvarint.Max))
322+
reqStrs[sessID] = reqStr
323+
sess := newSession(context.Background(), sessID, clientConn, reqStr, "")
324+
sessMgr.AddSession(sessID, sess)
325+
synctest.Wait()
326+
}
327+
require.Equal(t, maxRecentlyClosedSessions, sessMgr.NumSessions())
328+
329+
// close a random session
330+
sessID := slices.Collect(maps.Keys(reqStrs))[rand.Int64N(int64(len(reqStrs)))]
331+
reqStrs[sessID].CancelWrite(0)
332+
reqStrs[sessID].CancelRead(0)
333+
synctest.Wait()
334+
require.Equal(t, maxRecentlyClosedSessions-1, sessMgr.NumSessions())
335+
delete(reqStrs, sessID)
336+
337+
// Consume the HTTP/3 control stream that the server opened during setup.
338+
// This is necessary because we're calling AcceptUniStream directly on the QUIC connection,
339+
// which returns ALL unidirectional streams, not just WebTransport streams.
340+
_, err = clientConn.AcceptUniStream(context.Background())
341+
require.NoError(t, err)
342+
343+
// enqueue streams for the remaining sessions
344+
for sessID := range reqStrs {
345+
// bidirectional stream
346+
serverStr, err := serverConn.OpenStream()
347+
require.NoError(t, err)
348+
_, err = serverStr.Write([]byte("stream " + strconv.Itoa(int(sessID))))
349+
require.NoError(t, err)
350+
clientStr, err := clientConn.AcceptStream(context.Background())
351+
require.NoError(t, err)
352+
sessMgr.AddStream(clientStr, sessID)
353+
synctest.Wait()
354+
// make sure the stream is not rejected
355+
select {
356+
case <-clientStr.Context().Done():
357+
require.Fail(t, "stream should not be rejected")
358+
case <-serverStr.Context().Done():
359+
require.Fail(t, "stream should not be rejected")
360+
default:
361+
}
362+
363+
// unidirectional stream
364+
serverUniStr, err := serverConn.OpenUniStream()
365+
require.NoError(t, err)
366+
_, err = serverUniStr.Write([]byte("unistream " + strconv.Itoa(int(sessID))))
367+
require.NoError(t, err)
368+
clientUniStr, err := clientConn.AcceptUniStream(context.Background())
369+
require.NoError(t, err)
370+
sessMgr.AddUniStream(clientUniStr, sessID)
371+
synctest.Wait()
372+
// make sure the stream is not rejected
373+
select {
374+
case <-serverUniStr.Context().Done():
375+
require.Fail(t, "unidirectional stream should not be rejected")
376+
default:
377+
}
378+
}
379+
380+
// test that streams for the closed session are immediately rejected
381+
start := time.Now()
382+
// bidirectional stream
383+
serverStr, err := serverConn.OpenStream()
384+
require.NoError(t, err)
385+
_, err = serverStr.Write([]byte("stream " + strconv.Itoa(int(sessID))))
386+
require.NoError(t, err)
387+
clientStrClosed, err := clientConn.AcceptStream(context.Background())
388+
require.NoError(t, err)
389+
sessMgr.AddStream(clientStrClosed, sessID)
390+
synctest.Wait()
391+
// make sure the stream is immediately rejected
392+
_, err = serverStr.Read([]byte{0})
393+
require.ErrorIs(t, err, &quic.StreamError{Remote: true, StreamID: serverStr.StreamID(), ErrorCode: WTBufferedStreamRejectedErrorCode})
394+
require.Equal(t, rtt, time.Since(start))
395+
396+
// unidirectional stream
397+
start = time.Now()
398+
serverUniStr, err := serverConn.OpenUniStream()
399+
require.NoError(t, err)
400+
_, err = serverUniStr.Write([]byte("unistream " + strconv.Itoa(int(sessID))))
401+
require.NoError(t, err)
402+
synctest.Wait()
403+
clientUniStr, err := clientConn.AcceptUniStream(context.Background())
404+
require.NoError(t, err)
405+
sessMgr.AddUniStream(clientUniStr, sessID)
406+
synctest.Wait()
407+
_, err = clientUniStr.Read([]byte{0})
408+
require.ErrorIs(t, err, &quic.StreamError{Remote: false, StreamID: clientUniStr.StreamID(), ErrorCode: WTBufferedStreamRejectedErrorCode})
409+
require.Equal(t, rtt/2, time.Since(start))
410+
})
411+
}

0 commit comments

Comments
 (0)