From 29643602415fafeae2dbf472db6a3956732571bd Mon Sep 17 00:00:00 2001 From: Randy Reddig Date: Wed, 13 Dec 2023 07:42:58 +1300 Subject: [PATCH 1/3] ssh: add func DialContext DialContext starts a client connection to the given SSH server using the supplied Context. The supplied Context affects the dial and handshake. If it expires after the connection is opened, it has no effect on the resulting Client. --- ssh/client.go | 45 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/ssh/client.go b/ssh/client.go index fd8c49749e..20d5af4f71 100644 --- a/ssh/client.go +++ b/ssh/client.go @@ -6,6 +6,7 @@ package ssh import ( "bytes" + "context" "errors" "fmt" "net" @@ -168,6 +169,50 @@ func (c *Client) handleChannelOpens(in <-chan NewChannel) { c.mu.Unlock() } +// DialContext starts a client connection to the given SSH server. It is a +// convenience function that connects to the given network address, +// initiates the SSH handshake, and then sets up a Client. +// +// The provided Context must be non-nil. If the context expires before the +// connection is complete, an error is returned. Once successfully connected, +// any expiration of the context will not affect the connection. +// +// See [Dial] for additional information. +func DialContext(ctx context.Context, network, addr string, config *ClientConfig) (*Client, error) { + d := net.Dialer{ + Timeout: config.Timeout, + } + conn, err := d.DialContext(ctx, network, addr) + if err != nil { + return nil, err + } + type result struct { + client *Client + err error + } + ch := make(chan result) + go func() { + var client *Client + c, chans, reqs, err := NewClientConn(conn, addr, config) + if err == nil { + client = NewClient(c, chans, reqs) + } + select { + case ch <- result{client, err}: + case <-ctx.Done(): + if client != nil { + client.Close() + } + } + }() + select { + case res := <-ch: + return res.client, res.err + case <-ctx.Done(): + return nil, context.Cause(ctx) + } +} + // Dial starts a client connection to the given SSH server. It is a // convenience function that connects to the given network address, // initiates the SSH handshake, and then sets up a Client. For access From c5764670eba74778329ffe4b97031c352710df0d Mon Sep 17 00:00:00 2001 From: Randy Reddig Date: Mon, 18 Dec 2023 11:28:38 +1300 Subject: [PATCH 2/3] ssh: add tests for DialContext --- ssh/client_test.go | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/ssh/client_test.go b/ssh/client_test.go index 2621f0ea52..d011fb4f3b 100644 --- a/ssh/client_test.go +++ b/ssh/client_test.go @@ -6,12 +6,14 @@ package ssh import ( "bytes" + "context" "crypto/rand" "errors" "fmt" "net" "strings" "testing" + "time" ) func TestClientVersion(t *testing.T) { @@ -365,3 +367,27 @@ func TestUnsupportedAlgorithm(t *testing.T) { }) } } + +func TestDialContext(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + _, err := DialContext(ctx, "tcp", ":22", &ClientConfig{}) + wantErr := context.Canceled + if !errors.Is(err, wantErr) { + t.Errorf("DialContext: err == %v, expected %v", err, wantErr) + } + + ctx, cancel = context.WithDeadline(context.Background(), time.Now()) + defer cancel() + _, err = DialContext(ctx, "tcp", ":22", &ClientConfig{}) + wantErr = context.DeadlineExceeded + if !errors.Is(err, wantErr) { + t.Errorf("DialContext: err == %v, expected %v", err, wantErr) + } + + ctx = context.Background() + _, err = DialContext(ctx, "tcp", ":22", &ClientConfig{}) + if _, ok := err.(*net.OpError); !ok { + t.Errorf("DialContext: err == %#v, expected *net.OpError", err) + } +} From 69b3f59aa296fcbb7d8fb9dc936a02744e7741ea Mon Sep 17 00:00:00 2001 From: Randy Reddig Date: Mon, 18 Dec 2023 11:28:56 +1300 Subject: [PATCH 3/3] ssh: add example for DialContext --- ssh/example_test.go | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/ssh/example_test.go b/ssh/example_test.go index 3920832c1a..c21c7a61dc 100644 --- a/ssh/example_test.go +++ b/ssh/example_test.go @@ -7,6 +7,7 @@ package ssh_test import ( "bufio" "bytes" + "context" "crypto/rand" "crypto/rsa" "fmt" @@ -17,6 +18,7 @@ import ( "path/filepath" "strings" "sync" + "time" "golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh/terminal" @@ -262,6 +264,30 @@ func ExampleDial() { fmt.Println(b.String()) } +func ExampleDialContext() { + var hostKey ssh.PublicKey + config := &ssh.ClientConfig{ + User: "username", + Auth: []ssh.AuthMethod{ + ssh.Password("yourpassword"), + }, + HostKeyCallback: ssh.FixedHostKey(hostKey), + } + + // The Context supplied to DialContext allows the caller to control + // the timeout or cancel opening an SSH connection. + // + // Cancelling the context after DialContext returns will not effect + // the resulting Client. + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + client, err := ssh.DialContext(ctx, "tcp", "yourserver.com:22", config) + if err != nil { + log.Fatal("Failed to dial: ", err) + } + defer client.Close() +} + func ExamplePublicKeys() { var hostKey ssh.PublicKey // A public key may be used to authenticate against the remote