11package e2e
22
33import (
4- "bytes"
54 "context"
65 "encoding/json"
76 "fmt"
8- "io"
97 "net"
108 "net/http"
119 "net/url"
@@ -18,6 +16,8 @@ import (
1816 "github.com/Azure/azure-sdk-for-go/sdk/azidentity"
1917 "github.com/coder/websocket"
2018 "golang.org/x/crypto/ssh"
19+
20+ "github.com/Azure/agentbaker/e2e/toolkit"
2121)
2222
2323var AllowedSSHPrefixes = []string {ssh .KeyAlgoED25519 , ssh .KeyAlgoRSA , ssh .KeyAlgoRSASHA256 , ssh .KeyAlgoRSASHA512 }
@@ -53,18 +53,26 @@ type tunnelSession struct {
5353 bastion * Bastion
5454 ws * websocket.Conn
5555 session * sessionToken
56+ ctx context.Context
57+
58+ readDeadline time.Time
59+ writeDeadline time.Time
60+ readBuf []byte
61+
62+ targetHost string
63+ targetPort uint16
5664}
5765
58- func (b * Bastion ) NewTunnelSession (targetHost string , port uint16 ) (* tunnelSession , error ) {
66+ func (b * Bastion ) NewTunnelSession (ctx context. Context , targetHost string , port uint16 ) (* tunnelSession , error ) {
5967 session , err := b .newSessionToken (targetHost , port )
6068 if err != nil {
6169 return nil , err
6270 }
6371
6472 wsUrl := fmt .Sprintf ("wss://%v/webtunnelv2/%v?X-Node-Id=%v" , b .dnsName , session .WebsocketToken , session .NodeID )
6573
66- ctx , cancel := context .WithTimeout (context . Background () , 15 * time .Second )
67- ws , _ , err := websocket .Dial (ctx , wsUrl , & websocket.DialOptions {
74+ dialCtx , cancel := context .WithTimeout (ctx , 15 * time .Second )
75+ ws , _ , err := websocket .Dial (dialCtx , wsUrl , & websocket.DialOptions {
6876 CompressionMode : websocket .CompressionDisabled ,
6977 })
7078 cancel ()
@@ -75,9 +83,12 @@ func (b *Bastion) NewTunnelSession(targetHost string, port uint16) (*tunnelSessi
7583 ws .SetReadLimit (32 * 1024 * 1024 )
7684
7785 return & tunnelSession {
78- bastion : b ,
79- ws : ws ,
80- session : session ,
86+ bastion : b ,
87+ ws : ws ,
88+ session : session ,
89+ ctx : ctx ,
90+ targetHost : targetHost ,
91+ targetPort : port ,
8192 }, nil
8293}
8394
@@ -167,48 +178,81 @@ func (b *Bastion) newSessionToken(targetHost string, port uint16) (*sessionToken
167178 return & response , nil
168179}
169180
170- func (t * tunnelSession ) Pipe (conn net.Conn ) error {
181+ func (t * tunnelSession ) Read (p []byte ) (int , error ) {
182+ if len (t .readBuf ) == 0 {
183+ ctx := t .ctx
184+ if ! t .readDeadline .IsZero () {
185+ var cancel context.CancelFunc
186+ ctx , cancel = context .WithDeadline (t .ctx , t .readDeadline )
187+ defer cancel ()
188+ }
189+ typ , data , err := t .ws .Read (ctx )
190+ if err != nil {
191+ return 0 , err
192+ }
193+ if typ != websocket .MessageBinary {
194+ return 0 , fmt .Errorf ("unexpected websocket message type: %v" , typ )
195+ }
196+ t .readBuf = data
197+ }
171198
172- defer t .Close ()
173- defer conn .Close ()
199+ n := copy (p , t .readBuf )
200+ t .readBuf = t .readBuf [n :]
201+ return n , nil
202+ }
174203
175- done := make (chan error , 2 )
204+ func (t * tunnelSession ) Write (p []byte ) (int , error ) {
205+ ctx := t .ctx
206+ if ! t .writeDeadline .IsZero () {
207+ var cancel context.CancelFunc
208+ ctx , cancel = context .WithDeadline (t .ctx , t .writeDeadline )
209+ defer cancel ()
210+ }
211+ if err := t .ws .Write (ctx , websocket .MessageBinary , p ); err != nil {
212+ return 0 , err
213+ }
176214
177- go func () {
178- for {
179- _ , data , err := t .ws .Read (context .Background ())
180- if err != nil {
181- done <- err
182- return
183- }
215+ return len (p ), nil
216+ }
184217
185- if _ , err := io . Copy ( conn , bytes . NewReader ( data )); err != nil {
186- done <- err
187- return
188- }
189- }
190- }()
218+ func ( t * tunnelSession ) LocalAddr () net. Addr {
219+ return bastionAddr {
220+ network : "bastion" ,
221+ address : "local" ,
222+ }
223+ }
191224
192- go func () {
193- buf := make ([]byte , 4096 ) // 4096 is copy from az cli bastion code
225+ func (t * tunnelSession ) RemoteAddr () net.Addr {
226+ return bastionAddr {
227+ network : "bastion" ,
228+ address : fmt .Sprintf ("%s:%d" , t .targetHost , t .targetPort ),
229+ }
230+ }
194231
195- for {
196- n , err := conn .Read (buf )
197- if err != nil {
198- done <- err
199- return
200- }
232+ func (t * tunnelSession ) SetDeadline (deadline time.Time ) error {
233+ t .readDeadline = deadline
234+ t .writeDeadline = deadline
235+ return nil
236+ }
201237
202- if err := t .ws .Write (context .Background (), websocket .MessageBinary , buf [:n ]); err != nil {
203- done <- err
204- return
205- }
206- }
207- }()
238+ func (t * tunnelSession ) SetReadDeadline (deadline time.Time ) error {
239+ t .readDeadline = deadline
240+ return nil
241+ }
242+
243+ func (t * tunnelSession ) SetWriteDeadline (deadline time.Time ) error {
244+ t .writeDeadline = deadline
245+ return nil
246+ }
208247
209- return <- done
248+ type bastionAddr struct {
249+ network string
250+ address string
210251}
211252
253+ func (a bastionAddr ) Network () string { return a .network }
254+ func (a bastionAddr ) String () string { return a .address }
255+
212256func sshClientConfig (user string , privateKey []byte ) (* ssh.ClientConfig , error ) {
213257 signer , err := ssh .ParsePrivateKey (privateKey )
214258 if err != nil {
@@ -237,41 +281,55 @@ func DialSSHOverBastion(
237281 vmPrivateIP string ,
238282 sshPrivateKey []byte ,
239283) (* ssh.Client , error ) {
240-
241- // Create Bastion tunnel session (SSH = port 22)
242- tunnel , err := bastion .NewTunnelSession (
243- vmPrivateIP ,
244- 22 ,
245- )
284+ sshConfig , err := sshClientConfig ("azureuser" , sshPrivateKey )
246285 if err != nil {
247286 return nil , err
248287 }
249288
250- // Create in-memory connection pair
251- sshSide , tunnelSide := net .Pipe ()
289+ const (
290+ sshDialAttempts = 5
291+ sshDialTimeout = 30 * time .Second
292+ sshDialBackoff = 10 * time .Second
293+ )
252294
253- // Start Bastion tunnel piping
254- go func () {
255- _ = tunnel .Pipe (tunnelSide )
256- fmt .Printf ("Closed tunnel for VM IP %s\n " , vmPrivateIP )
257- }()
295+ var lastErr error
296+ for attempt := 1 ; attempt <= sshDialAttempts ; attempt ++ {
297+ if attempt > 1 {
298+ select {
299+ case <- time .After (sshDialBackoff ):
300+ case <- ctx .Done ():
301+ return nil , ctx .Err ()
302+ }
303+ }
304+ toolkit .Logf (ctx , "Attempt %d/%d establishing SSH over bastion to %s" , attempt , sshDialAttempts , vmPrivateIP )
305+
306+ // Intentionally use a background context to prevent cancelling the SSH connection before
307+ // we fetch logs during cleanup.
308+ tunnel , err := bastion .NewTunnelSession (context .Background (), vmPrivateIP , 22 )
309+ if err != nil {
310+ lastErr = err
311+ toolkit .Logf (ctx , "Attempt %d/%d failed to create bastion tunnel: %v" , attempt , sshDialAttempts , err )
312+ continue
313+ }
258314
259- // SSH client configuration
260- sshConfig , err := sshClientConfig ("azureuser" , sshPrivateKey )
261- if err != nil {
262- return nil , err
315+ _ = tunnel .SetDeadline (time .Now ().Add (sshDialTimeout ))
316+ sshConn , chans , reqs , err := ssh .NewClientConn (
317+ tunnel ,
318+ vmPrivateIP ,
319+ sshConfig ,
320+ )
321+ if err != nil {
322+ lastErr = err
323+ toolkit .Logf (ctx , "Attempt %d/%d SSH handshake failed: %v" , attempt , sshDialAttempts , err )
324+ _ = tunnel .Close ()
325+ continue
326+ }
327+ _ = tunnel .SetDeadline (time.Time {})
328+ return ssh .NewClient (sshConn , chans , reqs ), nil
263329 }
264330
265- // Establish SSH over the Bastion tunnel
266- sshConn , chans , reqs , err := ssh .NewClientConn (
267- sshSide ,
268- vmPrivateIP ,
269- sshConfig ,
270- )
271- if err != nil {
272- sshSide .Close ()
273- return nil , err
331+ if lastErr == nil {
332+ lastErr = fmt .Errorf ("failed to establish SSH connection over bastion" )
274333 }
275-
276- return ssh .NewClient (sshConn , chans , reqs ), nil
334+ return nil , lastErr
277335}
0 commit comments