Skip to content

Commit 9d448b1

Browse files
enable and use the QUIC Stream Resets with Partial Delivery extension (#239)
* enable and use the QUIC Stream Resets with Partial Delivery extension * fix stream header sending for concurrent Close / CancelWrite calls
1 parent 128538a commit 9d448b1

File tree

6 files changed

+195
-71
lines changed

6 files changed

+195
-71
lines changed

client.go

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,17 @@ func (d *Dialer) Dial(ctx context.Context, urlStr string, reqHdr http.Header) (*
5858

5959
quicConf := d.QUICConfig
6060
if quicConf == nil {
61-
quicConf = &quic.Config{EnableDatagrams: true}
62-
} else if !d.QUICConfig.EnableDatagrams {
63-
return nil, nil, errors.New("webtransport: DATAGRAM support required, enable it via QUICConfig.EnableDatagrams")
61+
quicConf = &quic.Config{
62+
EnableDatagrams: true,
63+
EnableStreamResetPartialDelivery: true,
64+
}
65+
} else {
66+
if !d.QUICConfig.EnableDatagrams {
67+
return nil, nil, errors.New("webtransport: DATAGRAM support required, enable it via QUICConfig.EnableDatagrams")
68+
}
69+
if !d.QUICConfig.EnableStreamResetPartialDelivery {
70+
return nil, nil, errors.New("webtransport: stream reset partial delivery required, enable it via QUICConfig.EnableStreamResetPartialDelivery")
71+
}
6472
}
6573

6674
tlsConf := d.TLSClientConfig

server.go

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,12 +112,14 @@ func (s *Server) Serve(conn net.PacketConn) error {
112112
if err := s.initialize(); err != nil {
113113
return err
114114
}
115-
quicConf := s.H3.QUICConfig
116-
if quicConf == nil {
115+
var quicConf *quic.Config
116+
if s.H3.QUICConfig != nil {
117+
quicConf = s.H3.QUICConfig.Clone()
118+
} else {
117119
quicConf = &quic.Config{}
118120
}
119-
quicConf = quicConf.Clone()
120121
quicConf.EnableDatagrams = true
122+
quicConf.EnableStreamResetPartialDelivery = true
121123
ln, err := quic.ListenEarly(conn, s.H3.TLSConfig, quicConf)
122124
if err != nil {
123125
return err
@@ -142,6 +144,13 @@ func (s *Server) Serve(conn net.PacketConn) error {
142144

143145
// ServeQUICConn serves a single QUIC connection.
144146
func (s *Server) ServeQUICConn(conn *quic.Conn) error {
147+
connState := conn.ConnectionState()
148+
if !connState.SupportsDatagrams.Local {
149+
return errors.New("webtransport: QUIC DATAGRAM support required, enable it via QUICConfig.EnableDatagrams")
150+
}
151+
if !connState.SupportsStreamResetPartialDelivery.Local {
152+
return errors.New("webtransport: QUIC Stream Resets with Partial Delivery required, enable it via QUICConfig.EnableStreamResetPartialDelivery")
153+
}
145154
if err := s.initialize(); err != nil {
146155
return err
147156
}

server_test.go

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ func TestServerReorderedUpgradeRequest(t *testing.T) {
8282
context.Background(),
8383
fmt.Sprintf("localhost:%d", port),
8484
&tls.Config{RootCAs: webtransport.CertPool, NextProtos: []string{http3.NextProtoH3}},
85-
&quic.Config{EnableDatagrams: true},
85+
&quic.Config{EnableDatagrams: true, EnableStreamResetPartialDelivery: true},
8686
)
8787
require.NoError(t, err)
8888
// Open a new stream for a WebTransport session we'll establish later. Stream ID: 0.
@@ -235,3 +235,65 @@ func TestImmediateClose(t *testing.T) {
235235
s := webtransport.Server{H3: &http3.Server{}}
236236
require.NoError(t, s.Close())
237237
}
238+
239+
func TestServerConnectionStateChecks(t *testing.T) {
240+
tests := []struct {
241+
name string
242+
enableDatagrams bool
243+
enableStreamResetPartial bool
244+
wantErr string
245+
}{
246+
{
247+
name: "missing datagram support",
248+
enableDatagrams: false,
249+
enableStreamResetPartial: true,
250+
wantErr: "webtransport: QUIC DATAGRAM support required",
251+
},
252+
{
253+
name: "missing stream reset partial delivery support",
254+
enableDatagrams: true,
255+
enableStreamResetPartial: false,
256+
wantErr: "webtransport: QUIC Stream Resets with Partial Delivery required",
257+
},
258+
}
259+
260+
for _, tt := range tests {
261+
t.Run(tt.name, func(t *testing.T) {
262+
s := webtransport.Server{H3: &http3.Server{TLSConfig: webtransport.TLSConf}}
263+
defer s.Close()
264+
265+
serverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
266+
require.NoError(t, err)
267+
defer serverConn.Close()
268+
269+
clientConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
270+
require.NoError(t, err)
271+
defer clientConn.Close()
272+
273+
ln, err := quic.ListenEarly(serverConn, webtransport.TLSConf, &quic.Config{
274+
EnableDatagrams: tt.enableDatagrams,
275+
EnableStreamResetPartialDelivery: tt.enableStreamResetPartial,
276+
})
277+
require.NoError(t, err)
278+
defer ln.Close()
279+
280+
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
281+
defer cancel()
282+
_, err = quic.DialEarly(ctx, clientConn, ln.Addr(), &tls.Config{
283+
ServerName: "localhost",
284+
NextProtos: []string{http3.NextProtoH3},
285+
RootCAs: webtransport.CertPool,
286+
}, &quic.Config{
287+
EnableDatagrams: tt.enableDatagrams,
288+
EnableStreamResetPartialDelivery: tt.enableStreamResetPartial,
289+
})
290+
require.NoError(t, err)
291+
292+
qconn, err := ln.Accept(ctx)
293+
require.NoError(t, err)
294+
defer qconn.CloseWithError(0, "")
295+
296+
require.ErrorContains(t, s.ServeQUICConn(qconn), tt.wantErr)
297+
})
298+
}
299+
}

session_test.go

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919
"github.com/quic-go/quic-go/http3"
2020
"github.com/quic-go/quic-go/quicvarint"
2121

22+
"github.com/stretchr/testify/assert"
2223
"github.com/stretchr/testify/require"
2324
)
2425

@@ -54,9 +55,10 @@ func newConnPair(t *testing.T, clientConn, serverConn net.PacketConn) (client, s
5455
serverConn,
5556
TLSConf,
5657
&quic.Config{
57-
InitialStreamReceiveWindow: 1 << 60,
58-
InitialConnectionReceiveWindow: 1 << 60,
59-
EnableDatagrams: true,
58+
InitialStreamReceiveWindow: 1 << 60,
59+
InitialConnectionReceiveWindow: 1 << 60,
60+
EnableDatagrams: true,
61+
EnableStreamResetPartialDelivery: true,
6062
},
6163
)
6264
require.NoError(t, err)
@@ -72,18 +74,27 @@ func newConnPair(t *testing.T, clientConn, serverConn net.PacketConn) (client, s
7274
NextProtos: []string{http3.NextProtoH3},
7375
RootCAs: CertPool,
7476
},
75-
&quic.Config{EnableDatagrams: true},
77+
&quic.Config{
78+
EnableDatagrams: true,
79+
EnableStreamResetPartialDelivery: true,
80+
},
7681
)
7782
require.NoError(t, err)
78-
require.True(t, cl.ConnectionState().SupportsDatagrams.Remote)
83+
assert.True(t, cl.ConnectionState().SupportsDatagrams.Local)
84+
assert.True(t, cl.ConnectionState().SupportsDatagrams.Remote)
85+
assert.True(t, cl.ConnectionState().SupportsStreamResetPartialDelivery.Local)
86+
assert.True(t, cl.ConnectionState().SupportsStreamResetPartialDelivery.Remote)
7987
t.Cleanup(func() { cl.CloseWithError(0, "") })
8088

8189
conn, err := ln.Accept(ctx)
8290
require.NoError(t, err)
8391
t.Cleanup(func() { conn.CloseWithError(0, "") })
8492
select {
8593
case <-conn.HandshakeComplete():
86-
require.True(t, conn.ConnectionState().SupportsDatagrams.Remote)
94+
assert.True(t, conn.ConnectionState().SupportsDatagrams.Local)
95+
assert.True(t, conn.ConnectionState().SupportsDatagrams.Remote)
96+
assert.True(t, conn.ConnectionState().SupportsStreamResetPartialDelivery.Local)
97+
assert.True(t, conn.ConnectionState().SupportsStreamResetPartialDelivery.Remote)
8798
case <-ctx.Done():
8899
t.Fatal("timeout")
89100
}

stream.go

Lines changed: 60 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ type quicSendStream interface {
1818
CancelWrite(quic.StreamErrorCode)
1919
Context() context.Context
2020
SetWriteDeadline(time.Time) error
21+
SetReliableBoundary()
2122
}
2223

2324
var (
@@ -44,6 +45,9 @@ type SendStream struct {
4445
// Might be initialized to nil if this sendStream is part of an incoming bidirectional stream.
4546
streamHdr []byte
4647
streamHdrMu sync.Mutex
48+
// Set to true when a goroutine is spawned to send the header asynchronously.
49+
// This only happens if the stream is closed / reset immediately after creation.
50+
sendingHdrAsync bool
4751

4852
onClose func() // to remove the stream from the streamsMap
4953

@@ -116,23 +120,24 @@ func (s *SendStream) handleSessionGoneError() error {
116120
}
117121

118122
func (s *SendStream) write(b []byte) (int, error) {
119-
if err := s.maybeSendStreamHeader(); err != nil {
123+
s.streamHdrMu.Lock()
124+
err := s.maybeSendStreamHeader()
125+
s.streamHdrMu.Unlock()
126+
if err != nil {
120127
return 0, err
121128
}
122129
return s.str.Write(b)
123130
}
124131

125132
func (s *SendStream) maybeSendStreamHeader() error {
126-
s.streamHdrMu.Lock()
127-
defer s.streamHdrMu.Unlock()
128-
129133
if len(s.streamHdr) == 0 {
130134
return nil
131135
}
132136
n, err := s.str.Write(s.streamHdr)
133137
if n > 0 {
134138
s.streamHdr = s.streamHdr[n:]
135139
}
140+
s.str.SetReliableBoundary()
136141
if err != nil {
137142
return err
138143
}
@@ -145,6 +150,32 @@ func (s *SendStream) maybeSendStreamHeader() error {
145150
// Write will unblock immediately, and future calls to Write will fail.
146151
// When called multiple times it is a no-op.
147152
func (s *SendStream) CancelWrite(e StreamErrorCode) {
153+
// if a Goroutine is already sending the header, return immediately
154+
s.streamHdrMu.Lock()
155+
if s.sendingHdrAsync {
156+
s.streamHdrMu.Unlock()
157+
return
158+
}
159+
160+
if len(s.streamHdr) > 0 {
161+
// Sending the stream header might block if we are blocked by flow control.
162+
// Send a stream header async so that CancelWrite can return immediately.
163+
s.sendingHdrAsync = true
164+
streamHdr := s.streamHdr
165+
s.streamHdr = nil
166+
s.streamHdrMu.Unlock()
167+
168+
go func() {
169+
s.SetWriteDeadline(time.Time{})
170+
_, _ = s.str.Write(streamHdr)
171+
s.str.SetReliableBoundary()
172+
s.str.CancelWrite(webtransportCodeToHTTPCode(e))
173+
s.onClose()
174+
}()
175+
return
176+
}
177+
s.streamHdrMu.Unlock()
178+
148179
s.str.CancelWrite(webtransportCodeToHTTPCode(e))
149180
s.onClose()
150181
}
@@ -160,9 +191,32 @@ func (s *SendStream) closeWithSession(err error) {
160191
// Close closes the write-direction of the stream.
161192
// Future calls to Write are not permitted after calling Close.
162193
func (s *SendStream) Close() error {
163-
if err := s.maybeSendStreamHeader(); err != nil {
164-
return err
194+
// if a Goroutine is already sending the header, return immediately
195+
s.streamHdrMu.Lock()
196+
if s.sendingHdrAsync {
197+
s.streamHdrMu.Unlock()
198+
return nil
165199
}
200+
201+
if len(s.streamHdr) > 0 {
202+
// Sending the stream header might block if we are blocked by flow control.
203+
// Send a stream header async so that CancelWrite can return immediately.
204+
s.sendingHdrAsync = true
205+
streamHdr := s.streamHdr
206+
s.streamHdr = nil
207+
s.streamHdrMu.Unlock()
208+
209+
go func() {
210+
s.SetWriteDeadline(time.Time{})
211+
_, _ = s.str.Write(streamHdr)
212+
s.str.SetReliableBoundary()
213+
_ = s.str.Close()
214+
s.onClose()
215+
}()
216+
return nil
217+
}
218+
s.streamHdrMu.Unlock()
219+
166220
s.onClose()
167221
return maybeConvertStreamError(s.str.Close())
168222
}

0 commit comments

Comments
 (0)