Skip to content

Commit b71c83e

Browse files
close client’s QUIC connection on WebTransport session close or failure
1 parent 093f58b commit b71c83e

File tree

2 files changed

+81
-27
lines changed

2 files changed

+81
-27
lines changed

client.go

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ func (d *Dialer) Dial(ctx context.Context, urlStr string, reqHdr http.Header) (*
9191
}
9292
reqHdr.Set(wtAvailableProtocolsHeader, protocols)
9393
}
94+
9495
req := &http.Request{
9596
Method: http.MethodConnect,
9697
Header: reqHdr,
@@ -109,17 +110,29 @@ func (d *Dialer) Dial(ctx context.Context, urlStr string, reqHdr http.Header) (*
109110
return nil, nil, err
110111
}
111112

113+
tr := &http3.Transport{EnableDatagrams: true}
114+
rsp, sess, err := d.handleConn(ctx, tr, qconn, req)
115+
if err != nil {
116+
// TODO: use a more specific error code
117+
// see https://github.com/ietf-wg-webtrans/draft-ietf-webtrans-http3/issues/245
118+
qconn.CloseWithError(quic.ApplicationErrorCode(http3.ErrCodeNoError), "")
119+
tr.Close()
120+
return rsp, nil, err
121+
}
122+
context.AfterFunc(sess.Context(), func() {
123+
qconn.CloseWithError(quic.ApplicationErrorCode(http3.ErrCodeNoError), "")
124+
tr.Close()
125+
})
126+
return rsp, sess, nil
127+
}
128+
129+
func (d *Dialer) handleConn(ctx context.Context, tr *http3.Transport, qconn *quic.Conn, req *http.Request) (*http.Response, *Session, error) {
112130
timeout := d.StreamReorderingTimeout
113131
if timeout == 0 {
114132
timeout = 5 * time.Second
115133
}
116134
sessMgr := newSessionManager(timeout)
117-
118-
context.AfterFunc(qconn.Context(), func() {
119-
sessMgr.Close()
120-
})
121-
122-
tr := &http3.Transport{EnableDatagrams: true}
135+
context.AfterFunc(qconn.Context(), sessMgr.Close)
123136

124137
conn := tr.NewRawClientConn(qconn)
125138

client_test.go

Lines changed: 62 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -45,24 +45,38 @@ func appendSettingsFrame(b []byte, values map[uint64]uint64) []byte {
4545
}
4646

4747
func TestClientInvalidResponseHandling(t *testing.T) {
48-
s, err := quic.ListenAddr("localhost:0", webtransport.TLSConf, &quic.Config{EnableDatagrams: true})
48+
ln, err := quic.ListenAddr("localhost:0", webtransport.TLSConf, &quic.Config{EnableDatagrams: true})
4949
require.NoError(t, err)
50-
errChan := make(chan error)
50+
defer ln.Close()
51+
52+
errChan := make(chan error, 1)
5153
go func() {
52-
conn, err := s.Accept(context.Background())
53-
require.NoError(t, err)
54+
conn, err := ln.Accept(context.Background())
55+
if err != nil {
56+
errChan <- err
57+
return
58+
}
5459
// send the SETTINGS frame
5560
settingsStr, err := conn.OpenUniStream()
56-
require.NoError(t, err)
61+
if err != nil {
62+
errChan <- err
63+
return
64+
}
5765
_, err = settingsStr.Write(appendSettingsFrame([]byte{0} /* stream type */, map[uint64]uint64{
5866
settingDatagram: 1,
5967
settingExtendedConnect: 1,
6068
settingsEnableWebtransport: 1,
6169
}))
62-
require.NoError(t, err)
70+
if err != nil {
71+
errChan <- err
72+
return
73+
}
6374

6475
str, err := conn.AcceptStream(context.Background())
65-
require.NoError(t, err)
76+
if err != nil {
77+
errChan <- err
78+
return
79+
}
6680
// write an HTTP/3 data frame. This will cause an error, since a HEADERS frame is expected
6781
var b []byte
6882
b = quicvarint.Append(b, 0x0)
@@ -78,7 +92,7 @@ func TestClientInvalidResponseHandling(t *testing.T) {
7892
}()
7993

8094
d := webtransport.Dialer{TLSClientConfig: &tls.Config{RootCAs: webtransport.CertPool}}
81-
_, _, err = d.Dial(context.Background(), fmt.Sprintf("https://localhost:%d", s.Addr().(*net.UDPAddr).Port), nil)
95+
_, _, err = d.Dial(context.Background(), fmt.Sprintf("https://localhost:%d", ln.Addr().(*net.UDPAddr).Port), nil)
8296
require.Error(t, err)
8397
var sErr error
8498
select {
@@ -114,9 +128,9 @@ func TestClientWaitForSettingsTimeout(t *testing.T) {
114128
errChan <- err
115129
}()
116130

131+
var serverConn *quic.Conn
117132
select {
118-
case conn := <-connChan:
119-
defer conn.CloseWithError(0, "")
133+
case serverConn = <-connChan:
120134
case <-time.After(time.Second):
121135
t.Fatal("timeout waiting for connection")
122136
}
@@ -131,6 +145,17 @@ func TestClientWaitForSettingsTimeout(t *testing.T) {
131145
case <-time.After(time.Second):
132146
t.Fatal("timeout waiting for dial to complete")
133147
}
148+
149+
// the client should close the underlying QUIC connection
150+
select {
151+
case <-serverConn.Context().Done():
152+
require.ErrorIs(t,
153+
context.Cause(serverConn.Context()),
154+
&quic.ApplicationError{ErrorCode: quic.ApplicationErrorCode(http3.ErrCodeNoError), Remote: true},
155+
)
156+
case <-time.After(time.Second):
157+
t.Fatal("timeout waiting for client to close connection")
158+
}
134159
}
135160

136161
func TestClientInvalidSettingsHandling(t *testing.T) {
@@ -172,32 +197,48 @@ func TestClientInvalidSettingsHandling(t *testing.T) {
172197
require.NoError(t, err)
173198
defer ln.Close()
174199

175-
done := make(chan struct{})
176-
ctx, cancel := context.WithCancel(context.Background())
200+
connChan := make(chan *quic.Conn, 1)
177201
go func() {
178-
defer close(done)
179202
conn, err := ln.Accept(context.Background())
180-
require.NoError(t, err)
203+
if err != nil {
204+
t.Errorf("failed to accept connection: %v", err)
205+
return
206+
}
181207
// send the SETTINGS frame
182208
settingsStr, err := conn.OpenUniStream()
183-
require.NoError(t, err)
184-
_, err = settingsStr.Write(appendSettingsFrame([]byte{0} /* stream type */, tc.settings))
185-
require.NoError(t, err)
186-
if _, err := conn.AcceptStream(ctx); err == nil || !errors.Is(err, context.Canceled) {
187-
require.Fail(t, "didn't expect any stream to be accepted")
209+
if err != nil {
210+
t.Errorf("failed to open uni stream: %v", err)
211+
return
188212
}
213+
if _, err := settingsStr.Write(appendSettingsFrame([]byte{0} /* stream type */, tc.settings)); err != nil {
214+
t.Errorf("failed to write settings frame: %v", err)
215+
return
216+
}
217+
connChan <- conn
189218
}()
190219

191220
d := webtransport.Dialer{TLSClientConfig: &tls.Config{RootCAs: webtransport.CertPool}}
192221
_, _, err = d.Dial(context.Background(), fmt.Sprintf("https://localhost:%d", ln.Addr().(*net.UDPAddr).Port), nil)
193222
require.Error(t, err)
194223
require.ErrorContains(t, err, tc.errorStr)
195-
cancel()
224+
225+
var serverConn *quic.Conn
196226
select {
197-
case <-done:
227+
case serverConn = <-connChan:
198228
case <-time.After(5 * time.Second):
199229
t.Fatal("timeout")
200230
}
231+
232+
// the client should close the underlying QUIC connection
233+
select {
234+
case <-serverConn.Context().Done():
235+
require.ErrorIs(t,
236+
context.Cause(serverConn.Context()),
237+
&quic.ApplicationError{ErrorCode: quic.ApplicationErrorCode(http3.ErrCodeNoError), Remote: true},
238+
)
239+
case <-time.After(time.Second):
240+
t.Fatal("timeout waiting for client to close connection")
241+
}
201242
})
202243
}
203244
}

0 commit comments

Comments
 (0)