diff --git a/session.go b/session.go index bc56066..11a3189 100644 --- a/session.go +++ b/session.go @@ -338,9 +338,24 @@ func (s *Session) recvLoop() { } s.streamLock.Unlock() case cmdFIN: + var reason string + if hdr.Length() > 0 { + newbuf := defaultAllocator.Get(int(hdr.Length())) + if _, err := io.ReadFull(s.conn, newbuf); err == nil { + reason = string(newbuf) + } else { + s.notifyReadError(err) + return + } + } + s.streamLock.Lock() if stream, ok := s.streams[sid]; ok { - stream.fin() + if reason != "" { + stream.rst(reason) + } else { + stream.fin() + } stream.notifyReadEvent() } s.streamLock.Unlock() diff --git a/session_test.go b/session_test.go index 3479570..ca4b322 100644 --- a/session_test.go +++ b/session_test.go @@ -4,6 +4,7 @@ import ( "bytes" crand "crypto/rand" "encoding/binary" + "errors" "fmt" "io" "log" @@ -451,6 +452,24 @@ func TestStreamDoubleClose(t *testing.T) { session.Close() } +func TestStreamCloseWithError(t *testing.T) { + cs, ss, err := getSmuxStreamPair() + if err != nil { + t.Fatal(err) + } + ss.CloseWithError(errors.New("test error")) + + tinybuf := make([]byte, 6) + _, err = cs.Read(tinybuf) + if err == nil { + t.Fatal("stream cancel must return error") + } + + if err.Error() != "test error" { + t.Fatal("client stream must handle the same error as server stream send after cancel") + } +} + func TestConcurrentClose(t *testing.T) { _, stop, cli, err := setupServer(t) if err != nil { diff --git a/stream.go b/stream.go index b21bdef..aef31d1 100644 --- a/stream.go +++ b/stream.go @@ -2,6 +2,7 @@ package smux import ( "encoding/binary" + "errors" "io" "net" "sync" @@ -31,6 +32,8 @@ type Stream struct { chFinEvent chan struct{} finEventOnce sync.Once + finReason string + // deadlines readDeadline atomic.Value writeDeadline atomic.Value @@ -110,7 +113,11 @@ func (s *Stream) tryRead(b []byte) (n int, err error) { select { case <-s.die: - return 0, io.EOF + if s.finReason != "" { + return 0, errors.New(s.finReason) + } else { + return 0, io.EOF + } default: return 0, ErrWouldBlock } @@ -161,7 +168,11 @@ func (s *Stream) tryReadv2(b []byte) (n int, err error) { select { case <-s.die: - return 0, io.EOF + if s.finReason != "" { + return 0, errors.New(s.finReason) + } else { + return 0, io.EOF + } default: return 0, ErrWouldBlock } @@ -272,7 +283,11 @@ func (s *Stream) waitRead() error { case <-s.chReadEvent: return nil case <-s.chFinEvent: - return io.EOF + if s.finReason != "" { + return errors.New(s.finReason) + } else { + return io.EOF + } case <-s.sess.chSocketReadError: return s.sess.socketReadError.Load().(error) case <-s.sess.chProtoError: @@ -402,7 +417,11 @@ func (s *Stream) writeV2(b []byte) (n int, err error) { if len(b) > 0 { select { case <-s.chFinEvent: // if fin arrived, future window update is impossible - return 0, io.EOF + if s.finReason != "" { + return 0, errors.New(s.finReason) + } else { + return 0, io.EOF + } case <-s.die: return sent, io.ErrClosedPipe case <-deadline: @@ -420,15 +439,31 @@ func (s *Stream) writeV2(b []byte) (n int, err error) { // Close implements net.Conn func (s *Stream) Close() error { + return s.close(nil) +} + +func (s *Stream) CloseWithError(reason error) error { + return s.close(reason) +} + +func (s *Stream) close(reason error) error { var once bool var err error s.dieOnce.Do(func() { close(s.die) once = true + if reason != nil { + s.finReason = reason.Error() + } }) if once { - _, err = s.sess.writeFrame(newFrame(byte(s.sess.config.Version), cmdFIN, s.id)) + fr := newFrame(byte(s.sess.config.Version), cmdFIN, s.id) + if reason != nil && s.sess.config.Version > 1 { + fr.data = []byte(reason.Error()) + } + + _, err = s.sess.writeFrame(fr) s.sess.streamClosed(s.id) return err } else { @@ -541,3 +576,11 @@ func (s *Stream) fin() { close(s.chFinEvent) }) } + +// mark this stream has been canceled in protocol +func (s *Stream) rst(reason string) { + s.finEventOnce.Do(func() { + s.finReason = reason + close(s.chFinEvent) + }) +} \ No newline at end of file