Skip to content

Commit 8765800

Browse files
add proper tests for the session manager
1 parent d4de9bf commit 8765800

File tree

9 files changed

+228
-101
lines changed

9 files changed

+228
-101
lines changed

.github/workflows/lint.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ jobs:
1111
go: [ "1.24.x", "1.25.x" ]
1212
runs-on: ubuntu-latest
1313
name: Lint
14+
env:
15+
GOEXPERIMENT: ${{ matrix.go == '1.24.x' && 'synctest' || '' }}
1416
steps:
1517
- uses: actions/checkout@v6
1618
- uses: actions/setup-go@v6

.github/workflows/unit.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ jobs:
1212
go: [ "1.24.x", "1.25.x" ]
1313
runs-on: ${{ fromJSON(vars[format('UNIT_RUNNER_{0}', matrix.os)] || format('"{0}-latest"', matrix.os)) }}
1414
name: Unit tests (${{ matrix.os}}, Go ${{ matrix.go }})
15+
env:
16+
GOEXPERIMENT: ${{ matrix.go == '1.24.x' && 'synctest' || '' }}
1517
steps:
1618
- uses: actions/checkout@v6
1719
- uses: actions/setup-go@v6
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
//go:build go1.24 && !go1.25
2+
3+
package synctest
4+
5+
import (
6+
"testing"
7+
"testing/synctest"
8+
)
9+
10+
func Test(t *testing.T, f func(t *testing.T)) {
11+
synctest.Run(func() {
12+
f(t)
13+
})
14+
}
15+
16+
func Wait() {
17+
//nolint:govet // the CI configuration sets the GOEXPERIMENT=synctest flag
18+
synctest.Wait()
19+
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
//go:build go1.25
2+
3+
package synctest
4+
5+
import (
6+
"testing"
7+
"testing/synctest"
8+
)
9+
10+
func Test(t *testing.T, f func(t *testing.T)) {
11+
synctest.Test(t, f)
12+
}
13+
14+
func Wait() {
15+
synctest.Wait()
16+
}

server_test.go

Lines changed: 0 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -188,88 +188,6 @@ func TestServerReorderedUpgradeRequestTimeout(t *testing.T) {
188188
require.Equal(t, []byte("raboof"), data)
189189
}
190190

191-
func TestServerReorderedMultipleStreams(t *testing.T) {
192-
timeout := scaleDuration(150 * time.Millisecond)
193-
s := webtransport.Server{
194-
H3: &http3.Server{TLSConfig: webtransport.TLSConf, EnableDatagrams: true},
195-
ReorderingTimeout: timeout,
196-
}
197-
defer s.Close()
198-
connChan := make(chan *webtransport.Session)
199-
addHandler(t, &s, func(c *webtransport.Session) {
200-
connChan <- c
201-
})
202-
203-
udpConn, err := net.ListenUDP("udp", nil)
204-
require.NoError(t, err)
205-
port := udpConn.LocalAddr().(*net.UDPAddr).Port
206-
webtransport.ConfigureHTTP3Server(s.H3)
207-
go s.Serve(udpConn)
208-
209-
cconn, err := quic.DialAddr(
210-
context.Background(),
211-
fmt.Sprintf("localhost:%d", port),
212-
&tls.Config{RootCAs: webtransport.CertPool, NextProtos: []string{http3.NextProtoH3}},
213-
&quic.Config{EnableDatagrams: true},
214-
)
215-
require.NoError(t, err)
216-
start := time.Now()
217-
// Open a new stream for a WebTransport session we'll establish later.
218-
str1 := createStreamAndWrite(t, cconn, 8, []byte("foobar"))
219-
220-
// After a while, open another stream for the same session.
221-
// This resets the timer, so the timeout will be timeout after this point.
222-
time.Sleep(timeout / 2)
223-
str2 := createStreamAndWrite(t, cconn, 8, []byte("raboof"))
224-
225-
// Wait for timeout after the second stream was added.
226-
// The timer was reset when str2 was added, so both streams should be reset
227-
// timeout after str2 was added, which is timeout/2 + timeout = 1.5*timeout from start.
228-
time.Sleep(timeout + timeout/4)
229-
230-
// Both streams should now have been reset by the server.
231-
_, err = str1.Read([]byte{0})
232-
var streamErr *quic.StreamError
233-
require.ErrorAs(t, err, &streamErr)
234-
require.Equal(t, webtransport.WTBufferedStreamRejectedErrorCode, streamErr.ErrorCode)
235-
236-
_, err = str2.Read([]byte{0})
237-
require.ErrorAs(t, err, &streamErr)
238-
require.Equal(t, webtransport.WTBufferedStreamRejectedErrorCode, streamErr.ErrorCode)
239-
240-
took := time.Since(start)
241-
require.GreaterOrEqual(t, took, timeout*3/2)
242-
require.Less(t, took, timeout*2)
243-
244-
tr := &http3.Transport{EnableDatagrams: true}
245-
conn := tr.NewClientConn(cconn)
246-
// Now establish the session. Make sure we don't accept the streams (they were reset).
247-
requestStr, err := conn.OpenRequestStream(context.Background())
248-
require.NoError(t, err)
249-
require.NoError(t, requestStr.SendRequestHeader(
250-
webtransport.NewWebTransportRequest(t, fmt.Sprintf("https://localhost:%d/webtransport", port)),
251-
))
252-
rsp, err := requestStr.ReadResponse()
253-
require.NoError(t, err)
254-
require.Equal(t, http.StatusOK, rsp.StatusCode)
255-
sconn := <-connChan
256-
defer sconn.CloseWithError(0, "")
257-
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
258-
defer cancel()
259-
_, err = sconn.AcceptStream(ctx)
260-
require.ErrorIs(t, err, context.DeadlineExceeded)
261-
262-
// Establish another stream and make sure it's accepted now.
263-
createStreamAndWrite(t, cconn, 8, []byte("baz"))
264-
ctx, cancel = context.WithTimeout(context.Background(), 200*time.Millisecond)
265-
defer cancel()
266-
sstr, err := sconn.AcceptStream(ctx)
267-
require.NoError(t, err)
268-
data, err := io.ReadAll(sstr)
269-
require.NoError(t, err)
270-
require.Equal(t, []byte("baz"), data)
271-
}
272-
273191
func TestServerSettingsCheck(t *testing.T) {
274192
timeout := scaleDuration(150 * time.Millisecond)
275193
s := webtransport.Server{

session_manager_test.go

Lines changed: 170 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,46 @@ package webtransport
33
import (
44
"context"
55
"io"
6+
"net"
67
"net/http"
78
"net/http/httptest"
89
"testing"
910
"time"
1011

12+
"github.com/quic-go/webtransport-go/internal/synctest"
13+
14+
"github.com/quic-go/quic-go"
1115
"github.com/quic-go/quic-go/http3"
16+
"github.com/quic-go/quic-go/testutils/simnet"
1217

1318
"github.com/stretchr/testify/require"
1419
)
1520

16-
func TestSessionManager(t *testing.T) {
17-
clientConn, serverConn := newConnPair(t)
21+
func (m *sessionManager) NumSessions() int {
22+
m.mx.Lock()
23+
defer m.mx.Unlock()
24+
return len(m.sessions)
25+
}
26+
27+
func newSimnetLink(t *testing.T, rtt time.Duration) (client, server *simnet.SimConn, close func(t *testing.T)) {
28+
t.Helper()
29+
30+
n := &simnet.Simnet{Router: &simnet.PerfectRouter{}}
31+
settings := simnet.NodeBiDiLinkSettings{Latency: rtt / 2}
32+
clientPacketConn := n.NewEndpoint(&net.UDPAddr{IP: net.ParseIP("1.0.0.1"), Port: 9001}, settings)
33+
serverPacketConn := n.NewEndpoint(&net.UDPAddr{IP: net.ParseIP("1.0.0.2"), Port: 9002}, settings)
34+
35+
require.NoError(t, n.Start())
36+
37+
return clientPacketConn, serverPacketConn, func(t *testing.T) {
38+
require.NoError(t, clientPacketConn.Close())
39+
require.NoError(t, serverPacketConn.Close())
40+
require.NoError(t, n.Close())
41+
}
42+
}
43+
44+
func TestSessionManagerAddingStreams(t *testing.T) {
45+
clientConn, serverConn := newConnPair(t, newUDPConnLocalhost(t), newUDPConnLocalhost(t))
1846
t.Cleanup(func() {
1947
clientConn.CloseWithError(0, "")
2048
serverConn.CloseWithError(0, "")
@@ -44,18 +72,34 @@ func TestSessionManager(t *testing.T) {
4472
reqStr, err := cc.OpenRequestStream(ctx)
4573
require.NoError(t, err)
4674
require.NoError(t, reqStr.SendRequestHeader(httptest.NewRequest(http.MethodGet, "/", nil)))
47-
// TODO: somehow send an HTTP response so we can call ReadResponse()
48-
reqStr.ReadResponse()
75+
server := http3.Server{
76+
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
77+
w.WriteHeader(http.StatusOK)
78+
w.(http.Flusher).Flush()
79+
<-r.Context().Done()
80+
}),
81+
}
82+
sc, err := server.NewRawServerConn(serverConn)
83+
require.NoError(t, err)
84+
serverReqStr, err := serverConn.AcceptStream(ctx)
85+
require.NoError(t, err)
86+
go sc.HandleRequestStream(serverReqStr)
87+
88+
_, err = reqStr.ReadResponse()
89+
require.NoError(t, err)
4990

5091
const sessionID = 42
5192

5293
sessMgr := newSessionManager(time.Hour)
94+
require.Zero(t, sessMgr.NumSessions())
5395
// first add the streams...
96+
sess := newSession(context.Background(), sessionID, clientConn, reqStr, "")
97+
// ...then add the session
5498
sessMgr.AddStream(clientStr, sessionID)
99+
require.Equal(t, 1, sessMgr.NumSessions())
55100
sessMgr.AddUniStream(clientUniStr, sessionID)
56-
// ...then add the session
57-
sess := newSession(context.Background(), sessionID, clientConn, reqStr, "")
58101
sessMgr.AddSession(sessionID, sess)
102+
require.Equal(t, 1, sessMgr.NumSessions())
59103

60104
// the streams should now be returned from the session
61105
sessStr, err := sess.AcceptStream(ctx)
@@ -70,3 +114,123 @@ func TestSessionManager(t *testing.T) {
70114
require.NoError(t, err)
71115
require.Equal(t, []byte("world"), data)
72116
}
117+
118+
func TestSessionManagerStreamReordering(t *testing.T) {
119+
synctest.Test(t, func(t *testing.T) {
120+
clientPacketConn, serverPacketConn, closeFn := newSimnetLink(t, 10*time.Millisecond)
121+
defer closeFn(t)
122+
clientConn, serverConn := newConnPair(t, clientPacketConn, serverPacketConn)
123+
t.Cleanup(func() {
124+
clientConn.CloseWithError(0, "")
125+
serverConn.CloseWithError(0, "")
126+
})
127+
128+
serverStr1, err := serverConn.OpenStream()
129+
require.NoError(t, err)
130+
_, err = serverStr1.Write([]byte("lorem"))
131+
require.NoError(t, err)
132+
require.NoError(t, serverStr1.Close())
133+
134+
serverUniStr1, err := serverConn.OpenUniStream()
135+
require.NoError(t, err)
136+
_, err = serverUniStr1.Write([]byte("ipsum"))
137+
require.NoError(t, err)
138+
require.NoError(t, serverUniStr1.Close())
139+
140+
serverStr2, err := serverConn.OpenStream()
141+
require.NoError(t, err)
142+
_, err = serverStr2.Write([]byte("dolor"))
143+
require.NoError(t, err)
144+
require.NoError(t, serverStr2.Close())
145+
146+
serverUniStr2, err := serverConn.OpenUniStream()
147+
require.NoError(t, err)
148+
_, err = serverUniStr2.Write([]byte("sit"))
149+
require.NoError(t, err)
150+
require.NoError(t, serverUniStr2.Close())
151+
152+
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
153+
defer cancel()
154+
clientStr1, err := clientConn.AcceptStream(ctx)
155+
require.NoError(t, err)
156+
clientUniStr1, err := clientConn.AcceptUniStream(ctx)
157+
require.NoError(t, err)
158+
clientStr2, err := clientConn.AcceptStream(ctx)
159+
require.NoError(t, err)
160+
clientUniStr2, err := clientConn.AcceptUniStream(ctx)
161+
require.NoError(t, err)
162+
163+
tr := http3.Transport{}
164+
cc := tr.NewRawClientConn(clientConn)
165+
reqStr, err := cc.OpenRequestStream(ctx)
166+
require.NoError(t, err)
167+
require.NoError(t, reqStr.SendRequestHeader(httptest.NewRequest(http.MethodGet, "/", nil)))
168+
server := http3.Server{
169+
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
170+
w.WriteHeader(http.StatusOK)
171+
w.(http.Flusher).Flush()
172+
<-r.Context().Done()
173+
}),
174+
}
175+
sc, err := server.NewRawServerConn(serverConn)
176+
require.NoError(t, err)
177+
serverReqStr, err := serverConn.AcceptStream(ctx)
178+
require.NoError(t, err)
179+
go sc.HandleRequestStream(serverReqStr)
180+
181+
_, err = reqStr.ReadResponse()
182+
require.NoError(t, err)
183+
184+
const sessionID = 42
185+
const timeout = 3 * time.Second
186+
187+
sessMgr := newSessionManager(timeout)
188+
require.Zero(t, sessMgr.NumSessions())
189+
// add the first stream
190+
sessMgr.AddStream(clientStr1, sessionID)
191+
require.Equal(t, 1, sessMgr.NumSessions())
192+
time.Sleep(timeout + time.Second)
193+
// the stream should have been reset and the session manager should have no sessions
194+
require.Zero(t, sessMgr.NumSessions())
195+
_, err = serverStr1.Read([]byte{0})
196+
require.ErrorIs(t, err, &quic.StreamError{Remote: true, StreamID: serverStr1.StreamID(), ErrorCode: WTBufferedStreamRejectedErrorCode})
197+
198+
sessMgr.AddUniStream(clientUniStr1, sessionID)
199+
require.Equal(t, 1, sessMgr.NumSessions())
200+
time.Sleep(timeout - time.Second)
201+
// adding another stream resets the timer
202+
sessMgr.AddStream(clientStr2, sessionID)
203+
time.Sleep(timeout - time.Second)
204+
require.Equal(t, 1, sessMgr.NumSessions())
205+
206+
// now add the session
207+
sess := newSession(context.Background(), sessionID, clientConn, reqStr, "")
208+
sessMgr.AddSession(sessionID, sess)
209+
require.Equal(t, 1, sessMgr.NumSessions())
210+
211+
// wait for a long time, then add another stream
212+
time.Sleep(timeout + time.Second)
213+
require.Equal(t, 1, sessMgr.NumSessions())
214+
sessMgr.AddUniStream(clientUniStr2, sessionID)
215+
time.Sleep(timeout + time.Second)
216+
217+
// the "lorem" stream should have been reset and the "dolor" stream should have been returned
218+
sessStr, err := sess.AcceptStream(ctx)
219+
require.NoError(t, err)
220+
data, err := io.ReadAll(sessStr)
221+
require.NoError(t, err)
222+
require.Equal(t, []byte("dolor"), data)
223+
224+
sessUniStr1, err := sess.AcceptUniStream(ctx)
225+
require.NoError(t, err)
226+
data, err = io.ReadAll(sessUniStr1)
227+
require.NoError(t, err)
228+
require.Equal(t, []byte("ipsum"), data)
229+
230+
sessUniStr2, err := sess.AcceptUniStream(ctx)
231+
require.NoError(t, err)
232+
data, err = io.ReadAll(sessUniStr2)
233+
require.NoError(t, err)
234+
require.Equal(t, []byte("sit"), data)
235+
})
236+
}

0 commit comments

Comments
 (0)