Skip to content

Commit dd67fae

Browse files
authored
test: timeout/retry ssh connections and remove intermediate pipe (#7850)
1 parent 93fa489 commit dd67fae

File tree

1 file changed

+126 -68 lines changed

1 file changed

+126 -68 lines changed

e2e/bastionssh.go

Lines changed: 126 additions & 68 deletions
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,9 @@
11
package e2e
22

33
import (
4-
"bytes"
54
"context"
65
"encoding/json"
76
"fmt"
8-
"io"
97
"net"
108
"net/http"
119
"net/url"
@@ -18,6 +16,8 @@ import (
1816
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
1917
"github.com/coder/websocket"
2018
"golang.org/x/crypto/ssh"
19+
20+
"github.com/Azure/agentbaker/e2e/toolkit"
2121
)
2222

2323
var AllowedSSHPrefixes = []string{ssh.KeyAlgoED25519, ssh.KeyAlgoRSA, ssh.KeyAlgoRSASHA256, ssh.KeyAlgoRSASHA512}
@@ -53,18 +53,26 @@ type tunnelSession struct {
5353
bastion *Bastion
5454
ws *websocket.Conn
5555
session *sessionToken
56+
ctx context.Context
57+
58+
readDeadline time.Time
59+
writeDeadline time.Time
60+
readBuf []byte
61+
62+
targetHost string
63+
targetPort uint16
5664
}
5765

58-
func (b *Bastion) NewTunnelSession(targetHost string, port uint16) (*tunnelSession, error) {
66+
func (b *Bastion) NewTunnelSession(ctx context.Context, targetHost string, port uint16) (*tunnelSession, error) {
5967
session, err := b.newSessionToken(targetHost, port)
6068
if err != nil {
6169
return nil, err
6270
}
6371

6472
wsUrl := fmt.Sprintf("wss://%v/webtunnelv2/%v?X-Node-Id=%v", b.dnsName, session.WebsocketToken, session.NodeID)
6573

66-
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
67-
ws, _, err := websocket.Dial(ctx, wsUrl, &websocket.DialOptions{
74+
dialCtx, cancel := context.WithTimeout(ctx, 15*time.Second)
75+
ws, _, err := websocket.Dial(dialCtx, wsUrl, &websocket.DialOptions{
6876
CompressionMode: websocket.CompressionDisabled,
6977
})
7078
cancel()
@@ -75,9 +83,12 @@ func (b *Bastion) NewTunnelSession(targetHost string, port uint16) (*tunnelSessi
7583
ws.SetReadLimit(32 * 1024 * 1024)
7684

7785
return &tunnelSession{
78-
bastion: b,
79-
ws: ws,
80-
session: session,
86+
bastion: b,
87+
ws: ws,
88+
session: session,
89+
ctx: ctx,
90+
targetHost: targetHost,
91+
targetPort: port,
8192
}, nil
8293
}
8394

@@ -167,48 +178,81 @@ func (b *Bastion) newSessionToken(targetHost string, port uint16) (*sessionToken
167178
return &response, nil
168179
}
169180

170-
func (t *tunnelSession) Pipe(conn net.Conn) error {
181+
func (t *tunnelSession) Read(p []byte) (int, error) {
182+
if len(t.readBuf) == 0 {
183+
ctx := t.ctx
184+
if !t.readDeadline.IsZero() {
185+
var cancel context.CancelFunc
186+
ctx, cancel = context.WithDeadline(t.ctx, t.readDeadline)
187+
defer cancel()
188+
}
189+
typ, data, err := t.ws.Read(ctx)
190+
if err != nil {
191+
return 0, err
192+
}
193+
if typ != websocket.MessageBinary {
194+
return 0, fmt.Errorf("unexpected websocket message type: %v", typ)
195+
}
196+
t.readBuf = data
197+
}
171198

172-
defer t.Close()
173-
defer conn.Close()
199+
n := copy(p, t.readBuf)
200+
t.readBuf = t.readBuf[n:]
201+
return n, nil
202+
}
174203

175-
done := make(chan error, 2)
204+
func (t *tunnelSession) Write(p []byte) (int, error) {
205+
ctx := t.ctx
206+
if !t.writeDeadline.IsZero() {
207+
var cancel context.CancelFunc
208+
ctx, cancel = context.WithDeadline(t.ctx, t.writeDeadline)
209+
defer cancel()
210+
}
211+
if err := t.ws.Write(ctx, websocket.MessageBinary, p); err != nil {
212+
return 0, err
213+
}
176214

177-
go func() {
178-
for {
179-
_, data, err := t.ws.Read(context.Background())
180-
if err != nil {
181-
done <- err
182-
return
183-
}
215+
return len(p), nil
216+
}
184217

185-
if _, err := io.Copy(conn, bytes.NewReader(data)); err != nil {
186-
done <- err
187-
return
188-
}
189-
}
190-
}()
218+
func (t *tunnelSession) LocalAddr() net.Addr {
219+
return bastionAddr{
220+
network: "bastion",
221+
address: "local",
222+
}
223+
}
191224

192-
go func() {
193-
buf := make([]byte, 4096) // 4096 is copy from az cli bastion code
225+
func (t *tunnelSession) RemoteAddr() net.Addr {
226+
return bastionAddr{
227+
network: "bastion",
228+
address: fmt.Sprintf("%s:%d", t.targetHost, t.targetPort),
229+
}
230+
}
194231

195-
for {
196-
n, err := conn.Read(buf)
197-
if err != nil {
198-
done <- err
199-
return
200-
}
232+
func (t *tunnelSession) SetDeadline(deadline time.Time) error {
233+
t.readDeadline = deadline
234+
t.writeDeadline = deadline
235+
return nil
236+
}
201237

202-
if err := t.ws.Write(context.Background(), websocket.MessageBinary, buf[:n]); err != nil {
203-
done <- err
204-
return
205-
}
206-
}
207-
}()
238+
func (t *tunnelSession) SetReadDeadline(deadline time.Time) error {
239+
t.readDeadline = deadline
240+
return nil
241+
}
242+
243+
func (t *tunnelSession) SetWriteDeadline(deadline time.Time) error {
244+
t.writeDeadline = deadline
245+
return nil
246+
}
208247

209-
return <-done
248+
type bastionAddr struct {
249+
network string
250+
address string
210251
}
211252

253+
func (a bastionAddr) Network() string { return a.network }
254+
func (a bastionAddr) String() string { return a.address }
255+
212256
func sshClientConfig(user string, privateKey []byte) (*ssh.ClientConfig, error) {
213257
signer, err := ssh.ParsePrivateKey(privateKey)
214258
if err != nil {
@@ -237,41 +281,55 @@ func DialSSHOverBastion(
237281
vmPrivateIP string,
238282
sshPrivateKey []byte,
239283
) (*ssh.Client, error) {
240-
241-
// Create Bastion tunnel session (SSH = port 22)
242-
tunnel, err := bastion.NewTunnelSession(
243-
vmPrivateIP,
244-
22,
245-
)
284+
sshConfig, err := sshClientConfig("azureuser", sshPrivateKey)
246285
if err != nil {
247286
return nil, err
248287
}
249288

250-
// Create in-memory connection pair
251-
sshSide, tunnelSide := net.Pipe()
289+
const (
290+
sshDialAttempts = 5
291+
sshDialTimeout = 30 * time.Second
292+
sshDialBackoff = 10 * time.Second
293+
)
252294

253-
// Start Bastion tunnel piping
254-
go func() {
255-
_ = tunnel.Pipe(tunnelSide)
256-
fmt.Printf("Closed tunnel for VM IP %s\n", vmPrivateIP)
257-
}()
295+
var lastErr error
296+
for attempt := 1; attempt <= sshDialAttempts; attempt++ {
297+
if attempt > 1 {
298+
select {
299+
case <-time.After(sshDialBackoff):
300+
case <-ctx.Done():
301+
return nil, ctx.Err()
302+
}
303+
}
304+
toolkit.Logf(ctx, "Attempt %d/%d establishing SSH over bastion to %s", attempt, sshDialAttempts, vmPrivateIP)
305+
306+
// Intentionally use a background context to prevent cancelling the SSH connection before
307+
// we fetch logs during cleanup.
308+
tunnel, err := bastion.NewTunnelSession(context.Background(), vmPrivateIP, 22)
309+
if err != nil {
310+
lastErr = err
311+
toolkit.Logf(ctx, "Attempt %d/%d failed to create bastion tunnel: %v", attempt, sshDialAttempts, err)
312+
continue
313+
}
258314

259-
// SSH client configuration
260-
sshConfig, err := sshClientConfig("azureuser", sshPrivateKey)
261-
if err != nil {
262-
return nil, err
315+
_ = tunnel.SetDeadline(time.Now().Add(sshDialTimeout))
316+
sshConn, chans, reqs, err := ssh.NewClientConn(
317+
tunnel,
318+
vmPrivateIP,
319+
sshConfig,
320+
)
321+
if err != nil {
322+
lastErr = err
323+
toolkit.Logf(ctx, "Attempt %d/%d SSH handshake failed: %v", attempt, sshDialAttempts, err)
324+
_ = tunnel.Close()
325+
continue
326+
}
327+
_ = tunnel.SetDeadline(time.Time{})
328+
return ssh.NewClient(sshConn, chans, reqs), nil
263329
}
264330

265-
// Establish SSH over the Bastion tunnel
266-
sshConn, chans, reqs, err := ssh.NewClientConn(
267-
sshSide,
268-
vmPrivateIP,
269-
sshConfig,
270-
)
271-
if err != nil {
272-
sshSide.Close()
273-
return nil, err
331+
if lastErr == nil {
332+
lastErr = fmt.Errorf("failed to establish SSH connection over bastion")
274333
}
275-
276-
return ssh.NewClient(sshConn, chans, reqs), nil
334+
return nil, lastErr
277335
}

0 commit comments

Comments
 (0)