Skip to content

Commit 1198cd8

Browse files
simplify error handling in Session.Open{Uni}StreamSync when session is closed
1 parent b15cfe1 commit 1198cd8

File tree

2 files changed

+39
-69
lines changed

2 files changed

+39
-69
lines changed

session.go

Lines changed: 33 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ package webtransport
33
import (
44
"context"
55
"encoding/binary"
6-
"errors"
76
"io"
87
"math/rand"
98
"net"
@@ -112,22 +111,8 @@ func newSession(sessionID sessionID, qconn http3.Connection, requestStr http3.St
112111
}
113112

114113
func (s *Session) handleConn() {
115-
var closeErr *SessionError
116114
err := s.parseNextCapsule()
117-
if !errors.As(err, &closeErr) {
118-
closeErr = &SessionError{Remote: true}
119-
}
120-
121-
s.closeMx.Lock()
122-
defer s.closeMx.Unlock()
123-
// If we closed the connection, the closeErr will be set in Close.
124-
if s.closeErr == nil {
125-
s.closeErr = closeErr
126-
}
127-
for _, cancel := range s.streamCtxs {
128-
cancel()
129-
}
130-
s.streams.CloseSession()
115+
s.closeWithError(err)
131116
}
132117

133118
// parseNextCapsule parses the next Capsule sent on the request stream.
@@ -176,9 +161,7 @@ func (s *Session) addStream(qstr quic.Stream, addStreamHeader bool) Stream {
176161

177162
func (s *Session) addReceiveStream(qstr quic.ReceiveStream) ReceiveStream {
178163
str := newReceiveStream(qstr, func() { s.streams.RemoveStream(qstr.StreamID()) })
179-
s.streams.AddStream(qstr.StreamID(), func() {
180-
str.closeWithSession()
181-
})
164+
s.streams.AddStream(qstr.StreamID(), str.closeWithSession)
182165
return str
183166
}
184167

@@ -307,23 +290,19 @@ func (s *Session) OpenStreamSync(ctx context.Context) (Stream, error) {
307290
id := s.addStreamCtxCancel(cancel)
308291
s.closeMx.Unlock()
309292

293+
// open a new bidirectional stream without holding the mutex: this call might block
310294
qstr, err := s.qconn.OpenStreamSync(ctx)
295+
296+
s.closeMx.Lock()
297+
defer s.closeMx.Unlock()
298+
delete(s.streamCtxs, id)
299+
311300
if err != nil {
312301
if s.closeErr != nil {
313302
return nil, s.closeErr
314303
}
315304
return nil, err
316305
}
317-
318-
s.closeMx.Lock()
319-
defer s.closeMx.Unlock()
320-
delete(s.streamCtxs, id)
321-
// Some time might have passed. Check if the session is still alive
322-
if s.closeErr != nil {
323-
qstr.CancelWrite(sessionCloseErrorCode)
324-
qstr.CancelRead(sessionCloseErrorCode)
325-
return nil, s.closeErr
326-
}
327306
return s.addStream(qstr, true), nil
328307
}
329308

@@ -351,22 +330,19 @@ func (s *Session) OpenUniStreamSync(ctx context.Context) (str SendStream, err er
351330
id := s.addStreamCtxCancel(cancel)
352331
s.closeMx.Unlock()
353332

333+
// open a new unidirectional stream without holding the mutex: this call might block
354334
qstr, err := s.qconn.OpenUniStreamSync(ctx)
335+
336+
s.closeMx.Lock()
337+
defer s.closeMx.Unlock()
338+
delete(s.streamCtxs, id)
339+
355340
if err != nil {
356341
if s.closeErr != nil {
357342
return nil, s.closeErr
358343
}
359344
return nil, err
360345
}
361-
362-
s.closeMx.Lock()
363-
defer s.closeMx.Unlock()
364-
delete(s.streamCtxs, id)
365-
// Some time might have passed. Check if the session is still alive
366-
if s.closeErr != nil {
367-
qstr.CancelWrite(sessionCloseErrorCode)
368-
return nil, s.closeErr
369-
}
370346
return s.addSendStream(qstr), nil
371347
}
372348

@@ -379,11 +355,22 @@ func (s *Session) RemoteAddr() net.Addr {
379355
}
380356

381357
func (s *Session) CloseWithError(code SessionErrorCode, msg string) error {
382-
first, err := s.closeWithError(code, msg)
358+
first, err := s.closeWithError(&SessionError{ErrorCode: code, Message: msg})
383359
if err != nil || !first {
384360
return err
385361
}
386362

363+
b := make([]byte, 4, 4+len(msg))
364+
binary.BigEndian.PutUint32(b, uint32(code))
365+
b = append(b, []byte(msg)...)
366+
if err := http3.WriteCapsule(
367+
quicvarint.NewWriter(s.requestStr),
368+
closeWebtransportSessionCapsuleType,
369+
b,
370+
); err != nil {
371+
return err
372+
}
373+
387374
s.requestStr.CancelRead(1337)
388375
err = s.requestStr.Close()
389376
<-s.ctx.Done()
@@ -398,27 +385,21 @@ func (s *Session) ReceiveDatagram(ctx context.Context) ([]byte, error) {
398385
return s.requestStr.ReceiveDatagram(ctx)
399386
}
400387

401-
func (s *Session) closeWithError(code SessionErrorCode, msg string) (bool /* first call to close session */, error) {
388+
func (s *Session) closeWithError(closeErr error) (bool /* first call to close session */, error) {
402389
s.closeMx.Lock()
403390
defer s.closeMx.Unlock()
404391
// Duplicate call, or the remote already closed this session.
405392
if s.closeErr != nil {
406393
return false, nil
407394
}
408-
s.closeErr = &SessionError{
409-
ErrorCode: code,
410-
Message: msg,
411-
}
395+
s.closeErr = closeErr
412396

413-
b := make([]byte, 4, 4+len(msg))
414-
binary.BigEndian.PutUint32(b, uint32(code))
415-
b = append(b, []byte(msg)...)
397+
for _, cancel := range s.streamCtxs {
398+
cancel()
399+
}
400+
s.streams.CloseSession()
416401

417-
return true, http3.WriteCapsule(
418-
quicvarint.NewWriter(s.requestStr),
419-
closeWebtransportSessionCapsuleType,
420-
b,
421-
)
402+
return true, nil
422403
}
423404

424405
func (s *Session) ConnectionState() quic.ConnectionState {

session_test.go

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -124,15 +124,11 @@ func TestOpenStreamAfterSessionClose(t *testing.T) {
124124

125125
mockSess := NewMockConnection(ctrl)
126126
mockSess.EXPECT().Context().Return(context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1337)))
127-
wait := make(chan struct{})
128127
streamOpen := make(chan struct{})
129-
mockSess.EXPECT().OpenStreamSync(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) {
128+
mockSess.EXPECT().OpenStreamSync(gomock.Any()).DoAndReturn(func(ctx context.Context) (quic.Stream, error) {
130129
streamOpen <- struct{}{}
131-
str := NewMockStream(ctrl)
132-
str.EXPECT().CancelRead(sessionCloseErrorCode)
133-
str.EXPECT().CancelWrite(sessionCloseErrorCode)
134-
<-wait
135-
return str, nil
130+
<-ctx.Done()
131+
return nil, ctx.Err()
136132
})
137133

138134
sess := newSession(42, mockSess, newMockRequestStream(ctrl))
@@ -145,8 +141,6 @@ func TestOpenStreamAfterSessionClose(t *testing.T) {
145141
<-streamOpen
146142

147143
require.NoError(t, sess.CloseWithError(0, "session closed"))
148-
149-
close(wait)
150144
require.EqualError(t, <-errChan, "session closed")
151145
}
152146

@@ -156,14 +150,11 @@ func TestOpenUniStreamAfterSessionClose(t *testing.T) {
156150

157151
mockSess := NewMockConnection(ctrl)
158152
mockSess.EXPECT().Context().Return(context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1337)))
159-
wait := make(chan struct{})
160153
streamOpen := make(chan struct{})
161-
mockSess.EXPECT().OpenUniStreamSync(gomock.Any()).DoAndReturn(func(context.Context) (quic.SendStream, error) {
154+
mockSess.EXPECT().OpenUniStreamSync(gomock.Any()).DoAndReturn(func(ctx context.Context) (quic.SendStream, error) {
162155
streamOpen <- struct{}{}
163-
str := NewMockStream(ctrl)
164-
str.EXPECT().CancelWrite(sessionCloseErrorCode)
165-
<-wait
166-
return str, nil
156+
<-ctx.Done()
157+
return nil, ctx.Err()
167158
})
168159

169160
sess := newSession(42, mockSess, newMockRequestStream(ctrl))
@@ -176,7 +167,5 @@ func TestOpenUniStreamAfterSessionClose(t *testing.T) {
176167
<-streamOpen
177168

178169
require.NoError(t, sess.CloseWithError(0, "session closed"))
179-
180-
close(wait)
181170
require.EqualError(t, <-errChan, "session closed")
182171
}

0 commit comments

Comments
 (0)