Skip to content

Commit

Permalink
time out a client connection if a successful handshake has not happen…
Browse files Browse the repository at this point in the history
…ed within the duration HandshakeTimeout
  • Loading branch information
aderouineau-amz committed Apr 19, 2024
1 parent adec695 commit ccf6f02
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 16 deletions.
34 changes: 21 additions & 13 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,16 @@ import (
type serverConn struct {
net.Conn

idleTimeout time.Duration
maxDeadline time.Time
closeCanceler context.CancelFunc
idleTimeout time.Duration
handshakeDeadline time.Time
maxDeadline time.Time
closeCanceler context.CancelFunc
}

func (c *serverConn) Write(p []byte) (n int, err error) {
c.updateDeadline()
if c.idleTimeout > 0 {
c.updateDeadline()
}
n, err = c.Conn.Write(p)
if _, isNetErr := err.(net.Error); isNetErr && c.closeCanceler != nil {
c.closeCanceler()
Expand All @@ -24,7 +27,9 @@ func (c *serverConn) Write(p []byte) (n int, err error) {
}

func (c *serverConn) Read(b []byte) (n int, err error) {
c.updateDeadline()
if c.idleTimeout > 0 {
c.updateDeadline()
}
n, err = c.Conn.Read(b)
if _, isNetErr := err.(net.Error); isNetErr && c.closeCanceler != nil {
c.closeCanceler()
Expand All @@ -41,15 +46,18 @@ func (c *serverConn) Close() (err error) {
}

func (c *serverConn) updateDeadline() {
switch {
case c.idleTimeout > 0:
deadline := c.maxDeadline

if !c.handshakeDeadline.IsZero() && (deadline.IsZero() || c.handshakeDeadline.Before(deadline)) {
deadline = c.handshakeDeadline
}

if c.idleTimeout > 0 {
idleDeadline := time.Now().Add(c.idleTimeout)
if idleDeadline.Unix() < c.maxDeadline.Unix() || c.maxDeadline.IsZero() {
c.Conn.SetDeadline(idleDeadline)
return
if deadline.IsZero() || idleDeadline.Before(deadline) {
deadline = idleDeadline
}
fallthrough
default:
c.Conn.SetDeadline(c.maxDeadline)
}

c.Conn.SetDeadline(deadline)
}
12 changes: 9 additions & 3 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,9 @@ type Server struct {

ConnectionFailedCallback ConnectionFailedCallback // callback to report connection failures

IdleTimeout time.Duration // connection timeout when no activity, none if empty
MaxTimeout time.Duration // absolute connection timeout, none if empty
HandshakeTimeout time.Duration // connection timeout until successful handshake, none if empty
IdleTimeout time.Duration // connection timeout when no activity, none if empty
MaxTimeout time.Duration // absolute connection timeout, none if empty

// ChannelHandlers allow overriding the built-in session handlers or provide
// extensions to the protocol, such as tcpip forwarding. By default only the
Expand Down Expand Up @@ -290,6 +291,10 @@ func (srv *Server) HandleConn(newConn net.Conn) {
if srv.MaxTimeout > 0 {
conn.maxDeadline = time.Now().Add(srv.MaxTimeout)
}
if srv.HandshakeTimeout > 0 {
conn.handshakeDeadline = time.Now().Add(srv.HandshakeTimeout)
}
conn.updateDeadline()
defer conn.Close()
sshConn, chans, reqs, err := gossh.NewServerConn(conn, srv.config(ctx))
if err != nil {
Expand All @@ -298,7 +303,8 @@ func (srv *Server) HandleConn(newConn net.Conn) {
}
return
}

conn.handshakeDeadline = time.Time{}
conn.updateDeadline()
srv.trackConn(sshConn, true)
defer srv.trackConn(sshConn, false)

Expand Down
34 changes: 34 additions & 0 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"context"
"io"
"net"
"testing"
"time"
)
Expand Down Expand Up @@ -124,3 +125,36 @@ func TestServerClose(t *testing.T) {
return
}
}

func TestServerHandshakeTimeout(t *testing.T) {
l := newLocalListener()

s := &Server{
HandshakeTimeout: time.Millisecond,
}
go func() {
if err := s.Serve(l); err != nil {
t.Error(err)
}
}()

conn, err := net.Dial("tcp", l.Addr().String())
if err != nil {
t.Fatal(err)
}
defer conn.Close()

ch := make(chan struct{})
go func() {
defer close(ch)
io.Copy(io.Discard, conn)
}()

select {
case <-ch:
return
case <-time.After(time.Second):
t.Fatal("client connection was not force-closed")
return
}
}

0 comments on commit ccf6f02

Please sign in to comment.