diff --git a/ssh/client.go b/ssh/client.go index fd8c49749e..20d5af4f71 100644 --- a/ssh/client.go +++ b/ssh/client.go @@ -6,6 +6,7 @@ package ssh import ( "bytes" + "context" "errors" "fmt" "net" @@ -168,6 +169,50 @@ func (c *Client) handleChannelOpens(in <-chan NewChannel) { c.mu.Unlock() } +// DialContext starts a client connection to the given SSH server. It is a +// convenience function that connects to the given network address, +// initiates the SSH handshake, and then sets up a Client. +// +// The provided Context must be non-nil. If the context expires before the +// connection is complete, an error is returned. Once successfully connected, +// any expiration of the context will not affect the connection. +// +// See [Dial] for additional information. +func DialContext(ctx context.Context, network, addr string, config *ClientConfig) (*Client, error) { + d := net.Dialer{ + Timeout: config.Timeout, + } + conn, err := d.DialContext(ctx, network, addr) + if err != nil { + return nil, err + } + type result struct { + client *Client + err error + } + ch := make(chan result) + go func() { + var client *Client + c, chans, reqs, err := NewClientConn(conn, addr, config) + if err == nil { + client = NewClient(c, chans, reqs) + } + select { + case ch <- result{client, err}: + case <-ctx.Done(): + if client != nil { + client.Close() + } + } + }() + select { + case res := <-ch: + return res.client, res.err + case <-ctx.Done(): + return nil, context.Cause(ctx) + } +} + // Dial starts a client connection to the given SSH server. It is a // convenience function that connects to the given network address, // initiates the SSH handshake, and then sets up a Client. For access diff --git a/ssh/client_test.go b/ssh/client_test.go index 2621f0ea52..d011fb4f3b 100644 --- a/ssh/client_test.go +++ b/ssh/client_test.go @@ -6,12 +6,14 @@ package ssh import ( "bytes" + "context" "crypto/rand" "errors" "fmt" "net" "strings" "testing" + "time" ) func TestClientVersion(t *testing.T) { @@ -365,3 +367,27 @@ func TestUnsupportedAlgorithm(t *testing.T) { }) } } + +func TestDialContext(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + _, err := DialContext(ctx, "tcp", ":22", &ClientConfig{}) + wantErr := context.Canceled + if !errors.Is(err, wantErr) { + t.Errorf("DialContext: err == %v, expected %v", err, wantErr) + } + + ctx, cancel = context.WithDeadline(context.Background(), time.Now()) + defer cancel() + _, err = DialContext(ctx, "tcp", ":22", &ClientConfig{}) + wantErr = context.DeadlineExceeded + if !errors.Is(err, wantErr) { + t.Errorf("DialContext: err == %v, expected %v", err, wantErr) + } + + ctx = context.Background() + _, err = DialContext(ctx, "tcp", ":22", &ClientConfig{}) + if _, ok := err.(*net.OpError); !ok { + t.Errorf("DialContext: err == %#v, expected *net.OpError", err) + } +} diff --git a/ssh/example_test.go b/ssh/example_test.go index 3920832c1a..c21c7a61dc 100644 --- a/ssh/example_test.go +++ b/ssh/example_test.go @@ -7,6 +7,7 @@ package ssh_test import ( "bufio" "bytes" + "context" "crypto/rand" "crypto/rsa" "fmt" @@ -17,6 +18,7 @@ import ( "path/filepath" "strings" "sync" + "time" "golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh/terminal" @@ -262,6 +264,30 @@ func ExampleDial() { fmt.Println(b.String()) } +func ExampleDialContext() { + var hostKey ssh.PublicKey + config := &ssh.ClientConfig{ + User: "username", + Auth: []ssh.AuthMethod{ + ssh.Password("yourpassword"), + }, + HostKeyCallback: ssh.FixedHostKey(hostKey), + } + + // The Context supplied to DialContext allows the caller to control + // the timeout or cancel opening an SSH connection. + // + // Cancelling the context after DialContext returns will not effect + // the resulting Client. + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + client, err := ssh.DialContext(ctx, "tcp", "yourserver.com:22", config) + if err != nil { + log.Fatal("Failed to dial: ", err) + } + defer client.Close() +} + func ExamplePublicKeys() { var hostKey ssh.PublicKey // A public key may be used to authenticate against the remote