@@ -3,18 +3,46 @@ package webtransport
33import (
44 "context"
55 "io"
6+ "net"
67 "net/http"
78 "net/http/httptest"
89 "testing"
910 "time"
1011
12+ "github.com/quic-go/webtransport-go/internal/synctest"
13+
14+ "github.com/quic-go/quic-go"
1115 "github.com/quic-go/quic-go/http3"
16+ "github.com/quic-go/quic-go/testutils/simnet"
1217
1318 "github.com/stretchr/testify/require"
1419)
1520
16- func TestSessionManager (t * testing.T ) {
17- clientConn , serverConn := newConnPair (t )
21+ func (m * sessionManager ) NumSessions () int {
22+ m .mx .Lock ()
23+ defer m .mx .Unlock ()
24+ return len (m .sessions )
25+ }
26+
27+ func newSimnetLink (t * testing.T , rtt time.Duration ) (client , server * simnet.SimConn , close func (t * testing.T )) {
28+ t .Helper ()
29+
30+ n := & simnet.Simnet {Router : & simnet.PerfectRouter {}}
31+ settings := simnet.NodeBiDiLinkSettings {Latency : rtt / 2 }
32+ clientPacketConn := n .NewEndpoint (& net.UDPAddr {IP : net .ParseIP ("1.0.0.1" ), Port : 9001 }, settings )
33+ serverPacketConn := n .NewEndpoint (& net.UDPAddr {IP : net .ParseIP ("1.0.0.2" ), Port : 9002 }, settings )
34+
35+ require .NoError (t , n .Start ())
36+
37+ return clientPacketConn , serverPacketConn , func (t * testing.T ) {
38+ require .NoError (t , clientPacketConn .Close ())
39+ require .NoError (t , serverPacketConn .Close ())
40+ require .NoError (t , n .Close ())
41+ }
42+ }
43+
44+ func TestSessionManagerAddingStreams (t * testing.T ) {
45+ clientConn , serverConn := newConnPair (t , newUDPConnLocalhost (t ), newUDPConnLocalhost (t ))
1846 t .Cleanup (func () {
1947 clientConn .CloseWithError (0 , "" )
2048 serverConn .CloseWithError (0 , "" )
@@ -44,18 +72,34 @@ func TestSessionManager(t *testing.T) {
4472 reqStr , err := cc .OpenRequestStream (ctx )
4573 require .NoError (t , err )
4674 require .NoError (t , reqStr .SendRequestHeader (httptest .NewRequest (http .MethodGet , "/" , nil )))
47- // TODO: somehow send an HTTP response so we can call ReadResponse()
48- reqStr .ReadResponse ()
75+ server := http3.Server {
76+ Handler : http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
77+ w .WriteHeader (http .StatusOK )
78+ w .(http.Flusher ).Flush ()
79+ <- r .Context ().Done ()
80+ }),
81+ }
82+ sc , err := server .NewRawServerConn (serverConn )
83+ require .NoError (t , err )
84+ serverReqStr , err := serverConn .AcceptStream (ctx )
85+ require .NoError (t , err )
86+ go sc .HandleRequestStream (serverReqStr )
87+
88+ _ , err = reqStr .ReadResponse ()
89+ require .NoError (t , err )
4990
5091 const sessionID = 42
5192
5293 sessMgr := newSessionManager (time .Hour )
94+ require .Zero (t , sessMgr .NumSessions ())
5395 // first add the streams...
96+ sess := newSession (context .Background (), sessionID , clientConn , reqStr , "" )
97+ // ...then add the session
5498 sessMgr .AddStream (clientStr , sessionID )
99+ require .Equal (t , 1 , sessMgr .NumSessions ())
55100 sessMgr .AddUniStream (clientUniStr , sessionID )
56- // ...then add the session
57- sess := newSession (context .Background (), sessionID , clientConn , reqStr , "" )
58101 sessMgr .AddSession (sessionID , sess )
102+ require .Equal (t , 1 , sessMgr .NumSessions ())
59103
60104 // the streams should now be returned from the session
61105 sessStr , err := sess .AcceptStream (ctx )
@@ -70,3 +114,123 @@ func TestSessionManager(t *testing.T) {
70114 require .NoError (t , err )
71115 require .Equal (t , []byte ("world" ), data )
72116}
117+
118+ func TestSessionManagerStreamReordering (t * testing.T ) {
119+ synctest .Test (t , func (t * testing.T ) {
120+ clientPacketConn , serverPacketConn , closeFn := newSimnetLink (t , 10 * time .Millisecond )
121+ defer closeFn (t )
122+ clientConn , serverConn := newConnPair (t , clientPacketConn , serverPacketConn )
123+ t .Cleanup (func () {
124+ clientConn .CloseWithError (0 , "" )
125+ serverConn .CloseWithError (0 , "" )
126+ })
127+
128+ serverStr1 , err := serverConn .OpenStream ()
129+ require .NoError (t , err )
130+ _ , err = serverStr1 .Write ([]byte ("lorem" ))
131+ require .NoError (t , err )
132+ require .NoError (t , serverStr1 .Close ())
133+
134+ serverUniStr1 , err := serverConn .OpenUniStream ()
135+ require .NoError (t , err )
136+ _ , err = serverUniStr1 .Write ([]byte ("ipsum" ))
137+ require .NoError (t , err )
138+ require .NoError (t , serverUniStr1 .Close ())
139+
140+ serverStr2 , err := serverConn .OpenStream ()
141+ require .NoError (t , err )
142+ _ , err = serverStr2 .Write ([]byte ("dolor" ))
143+ require .NoError (t , err )
144+ require .NoError (t , serverStr2 .Close ())
145+
146+ serverUniStr2 , err := serverConn .OpenUniStream ()
147+ require .NoError (t , err )
148+ _ , err = serverUniStr2 .Write ([]byte ("sit" ))
149+ require .NoError (t , err )
150+ require .NoError (t , serverUniStr2 .Close ())
151+
152+ ctx , cancel := context .WithTimeout (context .Background (), time .Second )
153+ defer cancel ()
154+ clientStr1 , err := clientConn .AcceptStream (ctx )
155+ require .NoError (t , err )
156+ clientUniStr1 , err := clientConn .AcceptUniStream (ctx )
157+ require .NoError (t , err )
158+ clientStr2 , err := clientConn .AcceptStream (ctx )
159+ require .NoError (t , err )
160+ clientUniStr2 , err := clientConn .AcceptUniStream (ctx )
161+ require .NoError (t , err )
162+
163+ tr := http3.Transport {}
164+ cc := tr .NewRawClientConn (clientConn )
165+ reqStr , err := cc .OpenRequestStream (ctx )
166+ require .NoError (t , err )
167+ require .NoError (t , reqStr .SendRequestHeader (httptest .NewRequest (http .MethodGet , "/" , nil )))
168+ server := http3.Server {
169+ Handler : http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
170+ w .WriteHeader (http .StatusOK )
171+ w .(http.Flusher ).Flush ()
172+ <- r .Context ().Done ()
173+ }),
174+ }
175+ sc , err := server .NewRawServerConn (serverConn )
176+ require .NoError (t , err )
177+ serverReqStr , err := serverConn .AcceptStream (ctx )
178+ require .NoError (t , err )
179+ go sc .HandleRequestStream (serverReqStr )
180+
181+ _ , err = reqStr .ReadResponse ()
182+ require .NoError (t , err )
183+
184+ const sessionID = 42
185+ const timeout = 3 * time .Second
186+
187+ sessMgr := newSessionManager (timeout )
188+ require .Zero (t , sessMgr .NumSessions ())
189+ // add the first stream
190+ sessMgr .AddStream (clientStr1 , sessionID )
191+ require .Equal (t , 1 , sessMgr .NumSessions ())
192+ time .Sleep (timeout + time .Second )
193+ // the stream should have been reset and the session manager should have no sessions
194+ require .Zero (t , sessMgr .NumSessions ())
195+ _ , err = serverStr1 .Read ([]byte {0 })
196+ require .ErrorIs (t , err , & quic.StreamError {Remote : true , StreamID : serverStr1 .StreamID (), ErrorCode : WTBufferedStreamRejectedErrorCode })
197+
198+ sessMgr .AddUniStream (clientUniStr1 , sessionID )
199+ require .Equal (t , 1 , sessMgr .NumSessions ())
200+ time .Sleep (timeout - time .Second )
201+ // adding another stream resets the timer
202+ sessMgr .AddStream (clientStr2 , sessionID )
203+ time .Sleep (timeout - time .Second )
204+ require .Equal (t , 1 , sessMgr .NumSessions ())
205+
206+ // now add the session
207+ sess := newSession (context .Background (), sessionID , clientConn , reqStr , "" )
208+ sessMgr .AddSession (sessionID , sess )
209+ require .Equal (t , 1 , sessMgr .NumSessions ())
210+
211+ // wait for a long time, then add another stream
212+ time .Sleep (timeout + time .Second )
213+ require .Equal (t , 1 , sessMgr .NumSessions ())
214+ sessMgr .AddUniStream (clientUniStr2 , sessionID )
215+ time .Sleep (timeout + time .Second )
216+
217+ // the "lorem" stream should have been reset and the "dolor" stream should have been returned
218+ sessStr , err := sess .AcceptStream (ctx )
219+ require .NoError (t , err )
220+ data , err := io .ReadAll (sessStr )
221+ require .NoError (t , err )
222+ require .Equal (t , []byte ("dolor" ), data )
223+
224+ sessUniStr1 , err := sess .AcceptUniStream (ctx )
225+ require .NoError (t , err )
226+ data , err = io .ReadAll (sessUniStr1 )
227+ require .NoError (t , err )
228+ require .Equal (t , []byte ("ipsum" ), data )
229+
230+ sessUniStr2 , err := sess .AcceptUniStream (ctx )
231+ require .NoError (t , err )
232+ data , err = io .ReadAll (sessUniStr2 )
233+ require .NoError (t , err )
234+ require .Equal (t , []byte ("sit" ), data )
235+ })
236+ }
0 commit comments