@@ -3,10 +3,14 @@ package webtransport
33import (
44 "context"
55 "io"
6+ "maps"
7+ "math/rand/v2"
68 "net"
79 "net/http"
810 "net/http/httptest"
911 "runtime"
12+ "slices"
13+ "strconv"
1014 "strings"
1115 "testing"
1216 "time"
@@ -15,6 +19,7 @@ import (
1519
1620 "github.com/quic-go/quic-go"
1721 "github.com/quic-go/quic-go/http3"
22+ "github.com/quic-go/quic-go/quicvarint"
1823 "github.com/quic-go/quic-go/testutils/simnet"
1924
2025 "github.com/stretchr/testify/require"
@@ -257,3 +262,150 @@ func TestSessionManagerStreamReordering(t *testing.T) {
257262 require .Equal (t , []byte ("amet" ), data )
258263 })
259264}
265+
266+ func TestSessionManagerSessionClose (t * testing.T ) {
267+ // synctest works slightly differently on Go 1.24,
268+ // so we skip the test
269+ if strings .HasPrefix (runtime .Version (), "go1.24" ) {
270+ t .Skip ("skipping on Go 1.24 due to synctest issues" )
271+ }
272+
273+ synctest .Test (t , func (t * testing.T ) {
274+ const rtt = 10 * time .Millisecond
275+ clientPacketConn , serverPacketConn , closeFn := newSimnetLink (t , rtt )
276+ defer closeFn (t )
277+ clientConn , serverConn := newConnPair (t , clientPacketConn , serverPacketConn )
278+ t .Cleanup (func () {
279+ clientConn .CloseWithError (0 , "" )
280+ serverConn .CloseWithError (0 , "" )
281+ })
282+
283+ tr := http3.Transport {}
284+ cc := tr .NewRawClientConn (clientConn )
285+
286+ server := http3.Server {
287+ Handler : http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
288+ w .WriteHeader (http .StatusOK )
289+ w .(http.Flusher ).Flush ()
290+ <- r .Context ().Done ()
291+ }),
292+ }
293+ sc , err := server .NewRawServerConn (serverConn )
294+ require .NoError (t , err )
295+
296+ openSession := func () (* http3.RequestStream , error ) {
297+ reqStr , err := cc .OpenRequestStream (context .Background ())
298+ if err != nil {
299+ return nil , err
300+ }
301+ if err := reqStr .SendRequestHeader (httptest .NewRequest (http .MethodGet , "/" , nil )); err != nil {
302+ return nil , err
303+ }
304+ serverReqStr , err := serverConn .AcceptStream (context .Background ())
305+ if err != nil {
306+ return nil , err
307+ }
308+ go sc .HandleRequestStream (serverReqStr )
309+ if _ , err := reqStr .ReadResponse (); err != nil {
310+ return nil , err
311+ }
312+ return reqStr , nil
313+ }
314+
315+ sessMgr := newSessionManager (time .Hour )
316+
317+ reqStrs := make (map [sessionID ]* http3.RequestStream )
318+ for range maxRecentlyClosedSessions {
319+ reqStr , err := openSession ()
320+ require .NoError (t , err )
321+ sessID := sessionID (rand .Int64N (quicvarint .Max ))
322+ reqStrs [sessID ] = reqStr
323+ sess := newSession (context .Background (), sessID , clientConn , reqStr , "" )
324+ sessMgr .AddSession (sessID , sess )
325+ synctest .Wait ()
326+ }
327+ require .Equal (t , maxRecentlyClosedSessions , sessMgr .NumSessions ())
328+
329+ // close a random session
330+ sessID := slices .Collect (maps .Keys (reqStrs ))[rand .Int64N (int64 (len (reqStrs )))]
331+ reqStrs [sessID ].CancelWrite (0 )
332+ reqStrs [sessID ].CancelRead (0 )
333+ synctest .Wait ()
334+ require .Equal (t , maxRecentlyClosedSessions - 1 , sessMgr .NumSessions ())
335+ delete (reqStrs , sessID )
336+
337+ // Consume the HTTP/3 control stream that the server opened during setup.
338+ // This is necessary because we're calling AcceptUniStream directly on the QUIC connection,
339+ // which returns ALL unidirectional streams, not just WebTransport streams.
340+ _ , err = clientConn .AcceptUniStream (context .Background ())
341+ require .NoError (t , err )
342+
343+ // enqueue streams for the remaining sessions
344+ for sessID := range reqStrs {
345+ // bidirectional stream
346+ serverStr , err := serverConn .OpenStream ()
347+ require .NoError (t , err )
348+ _ , err = serverStr .Write ([]byte ("stream " + strconv .Itoa (int (sessID ))))
349+ require .NoError (t , err )
350+ clientStr , err := clientConn .AcceptStream (context .Background ())
351+ require .NoError (t , err )
352+ sessMgr .AddStream (clientStr , sessID )
353+ synctest .Wait ()
354+ // make sure the stream is not rejected
355+ select {
356+ case <- clientStr .Context ().Done ():
357+ require .Fail (t , "stream should not be rejected" )
358+ case <- serverStr .Context ().Done ():
359+ require .Fail (t , "stream should not be rejected" )
360+ default :
361+ }
362+
363+ // unidirectional stream
364+ serverUniStr , err := serverConn .OpenUniStream ()
365+ require .NoError (t , err )
366+ _ , err = serverUniStr .Write ([]byte ("unistream " + strconv .Itoa (int (sessID ))))
367+ require .NoError (t , err )
368+ clientUniStr , err := clientConn .AcceptUniStream (context .Background ())
369+ require .NoError (t , err )
370+ sessMgr .AddUniStream (clientUniStr , sessID )
371+ synctest .Wait ()
372+ // make sure the stream is not rejected
373+ select {
374+ case <- serverUniStr .Context ().Done ():
375+ require .Fail (t , "unidirectional stream should not be rejected" )
376+ default :
377+ }
378+ }
379+
380+ // test that streams for the closed session are immediately rejected
381+ start := time .Now ()
382+ // bidirectional stream
383+ serverStr , err := serverConn .OpenStream ()
384+ require .NoError (t , err )
385+ _ , err = serverStr .Write ([]byte ("stream " + strconv .Itoa (int (sessID ))))
386+ require .NoError (t , err )
387+ clientStrClosed , err := clientConn .AcceptStream (context .Background ())
388+ require .NoError (t , err )
389+ sessMgr .AddStream (clientStrClosed , sessID )
390+ synctest .Wait ()
391+ // make sure the stream is immediately rejected
392+ _ , err = serverStr .Read ([]byte {0 })
393+ require .ErrorIs (t , err , & quic.StreamError {Remote : true , StreamID : serverStr .StreamID (), ErrorCode : WTBufferedStreamRejectedErrorCode })
394+ require .Equal (t , rtt , time .Since (start ))
395+
396+ // unidirectional stream
397+ start = time .Now ()
398+ serverUniStr , err := serverConn .OpenUniStream ()
399+ require .NoError (t , err )
400+ _ , err = serverUniStr .Write ([]byte ("unistream " + strconv .Itoa (int (sessID ))))
401+ require .NoError (t , err )
402+ synctest .Wait ()
403+ clientUniStr , err := clientConn .AcceptUniStream (context .Background ())
404+ require .NoError (t , err )
405+ sessMgr .AddUniStream (clientUniStr , sessID )
406+ synctest .Wait ()
407+ _ , err = clientUniStr .Read ([]byte {0 })
408+ require .ErrorIs (t , err , & quic.StreamError {Remote : false , StreamID : clientUniStr .StreamID (), ErrorCode : WTBufferedStreamRejectedErrorCode })
409+ require .Equal (t , rtt / 2 , time .Since (start ))
410+ })
411+ }
0 commit comments