99 "sync"
1010 "time"
1111
12+ "github.com/avast/retry-go"
1213 "github.com/limanmys/render-engine/pkg/logger"
1314 "github.com/phayes/freeport"
1415 "golang.org/x/crypto/ssh"
@@ -44,26 +45,28 @@ var mut sync.Mutex = sync.Mutex{}
4445
4546// CreateTunnel starts a new tunnel instance and sets it into TunnelPool
4647func CreateTunnel (remoteHost , remotePort , username , password , sshPort string ) int {
48+ mut .Lock ()
49+ defer mut .Unlock ()
50+
4751 ch := make (chan int )
4852 time .AfterFunc (30 * time .Second , func () {
4953 ch <- 1
5054 })
51-
5255 t , err := Tunnels .Get (remoteHost , remotePort , username )
5356 if err == nil {
5457 if t .password != password {
5558 return 0
5659 }
5760
58- OL :
61+ startedLoop :
5962 for {
6063 if t .Started {
6164 break
6265 }
6366
6467 select {
6568 case <- ch :
66- break OL
69+ break startedLoop
6770 default :
6871 time .Sleep (5 * time .Millisecond )
6972 continue
@@ -74,9 +77,6 @@ func CreateTunnel(remoteHost, remotePort, username, password, sshPort string) in
7477 return t .Port
7578 }
7679
77- mut .Lock ()
78- defer mut .Unlock ()
79-
8080 port , err := freeport .GetFreePort ()
8181 if err != nil {
8282 logger .Sugar ().Errorw (err .Error ())
@@ -92,7 +92,7 @@ func CreateTunnel(remoteHost, remotePort, username, password, sshPort string) in
9292 }
9393
9494 sshTunnel := & Tunnel {
95- auth : []ssh.AuthMethod {ssh .Password (password )},
95+ auth : []ssh.AuthMethod {ssh .RetryableAuthMethod ( ssh . Password (password ), 3 )},
9696 hostKeys : ssh .InsecureIgnoreHostKey (),
9797 user : username ,
9898 mode : '>' ,
@@ -114,15 +114,15 @@ func CreateTunnel(remoteHost, remotePort, username, password, sshPort string) in
114114
115115 hasError := sshTunnel .Start ()
116116 if ! hasError {
117- L :
117+ loop :
118118 for {
119119 if sshTunnel .Started {
120120 break
121121 }
122122
123123 select {
124124 case <- ch :
125- break L
125+ break loop
126126 default :
127127 time .Sleep (5 * time .Millisecond )
128128 continue
@@ -177,20 +177,33 @@ func (t *Tunnel) bindTunnel(ctx context.Context, wg *sync.WaitGroup, hasError *b
177177 for {
178178 var once sync.Once // Only print errors once per session
179179 func () {
180- // Connect to the server host via SSH.
181- cl , err := ssh .Dial ("tcp" , t .hostAddr , & ssh.ClientConfig {
182- User : t .user ,
183- Auth : t .auth ,
184- HostKeyCallback : t .hostKeys ,
185- Timeout : 5 * time .Second ,
186- })
180+ var cl * ssh.Client
181+ var err error
182+
183+ err = retry .Do (
184+ func () error {
185+ cl , err = ssh .Dial ("tcp" , t .hostAddr , & ssh.ClientConfig {
186+ User : t .user ,
187+ Auth : t .auth ,
188+ HostKeyCallback : t .hostKeys ,
189+ Timeout : 5 * time .Second ,
190+ })
191+ if err != nil {
192+ return err
193+ }
194+ return nil
195+ },
196+ retry .Attempts (5 ),
197+ retry .Delay (1 * time .Second ),
198+ )
199+
187200 if err != nil {
188201 once .Do (func () {
189202 t .log .Errorw ("ssh dial error" , "details" , fmt .Sprintf ("%v, %v" , t , err ))
190- t .errHandler ()
191203 t .Stop ()
192204 * hasError = true
193205 wg .Done ()
206+ t .errHandler ()
194207 })
195208 return
196209 }
@@ -210,10 +223,10 @@ func (t *Tunnel) bindTunnel(ctx context.Context, wg *sync.WaitGroup, hasError *b
210223 if err != nil {
211224 once .Do (func () {
212225 t .log .Errorw ("bind error" , "details" , fmt .Sprintf ("%v, %v" , t , err ))
213- t .errHandler ()
214226 t .Stop ()
215227 * hasError = true
216228 wg .Done ()
229+ t .errHandler ()
217230 })
218231 return
219232 }
@@ -244,8 +257,8 @@ func (t *Tunnel) bindTunnel(ctx context.Context, wg *sync.WaitGroup, hasError *b
244257 t .log .Errorw ("accept error" , "details" , fmt .Sprintf ("%v, %v" , t , err ))
245258 t .Stop ()
246259 * hasError = true
247- t .errHandler ()
248260 wg .Done ()
261+ t .errHandler ()
249262 })
250263 return
251264 }
@@ -273,21 +286,35 @@ func (t *Tunnel) dialTunnel(ctx context.Context, wg *sync.WaitGroup, client *ssh
273286 }()
274287
275288 // Establish the outbound connection.
289+ var once sync.Once
276290 var cn2 net.Conn
277291 var err error
278- switch t .mode {
279- case '>' :
280- cn2 , err = client .Dial (t .dialType , t .dialAddr )
281- case '<' :
282- cn2 , err = net .Dial (t .dialType , t .dialAddr )
283- }
284- if err != nil {
285- t .Stop ()
286- t .log .Errorw ("ssh dial error" , "details" , fmt .Sprintf ("%v, %v" , t , err ))
287- t .errHandler ()
288- * hasError = true
289292
290- wg .Done ()
293+ err = retry .Do (
294+ func () error {
295+ switch t .mode {
296+ case '>' :
297+ cn2 , err = client .Dial (t .dialType , t .dialAddr )
298+ case '<' :
299+ cn2 , err = net .Dial (t .dialType , t .dialAddr )
300+ }
301+
302+ if err != nil {
303+ return err
304+ }
305+ return nil
306+ },
307+ retry .Attempts (5 ),
308+ retry .Delay (1 * time .Second ),
309+ )
310+ if err != nil {
311+ once .Do (func () {
312+ t .Stop ()
313+ t .log .Errorw ("ssh dial error" , "details" , fmt .Sprintf ("%v, %v" , t , err ))
314+ * hasError = true
315+ wg .Done ()
316+ t .errHandler ()
317+ })
291318 return
292319 }
293320
@@ -300,7 +327,6 @@ func (t *Tunnel) dialTunnel(ctx context.Context, wg *sync.WaitGroup, client *ssh
300327 //defer t.log.Infow("connection closed", "details", t)
301328
302329 // Copy bytes from one connection to the other until one side closes.
303- var once sync.Once
304330 var wg2 sync.WaitGroup
305331 wg2 .Add (2 )
306332 go func () {
0 commit comments