1
1
package htlcswitch
2
2
3
3
import (
4
+ "context"
4
5
"crypto/sha256"
5
6
"fmt"
6
7
"sync"
@@ -89,8 +90,9 @@ type InterceptableSwitch struct {
89
90
// currentHeight is the currently best known height.
90
91
currentHeight int32
91
92
92
- wg sync.WaitGroup
93
- quit chan struct {}
93
+ wg sync.WaitGroup
94
+ quit chan struct {}
95
+ cancel fn.Option [context.CancelFunc ]
94
96
}
95
97
96
98
type interceptedPackets struct {
@@ -222,12 +224,14 @@ func (s *InterceptableSwitch) SetInterceptor(
222
224
}
223
225
}
224
226
225
- func (s * InterceptableSwitch ) Start () error {
227
+ func (s * InterceptableSwitch ) Start (ctx context. Context ) error {
226
228
log .Info ("InterceptableSwitch starting..." )
227
229
228
230
if s .started .Swap (true ) {
229
231
return fmt .Errorf ("InterceptableSwitch started more than once" )
230
232
}
233
+ ctx , cancel := context .WithCancel (ctx )
234
+ s .cancel = fn .Some (cancel )
231
235
232
236
blockEpochStream , err := s .notifier .RegisterBlockEpochNtfn (nil )
233
237
if err != nil {
@@ -239,7 +243,7 @@ func (s *InterceptableSwitch) Start() error {
239
243
go func () {
240
244
defer s .wg .Done ()
241
245
242
- err := s .run ()
246
+ err := s .run (ctx )
243
247
if err != nil {
244
248
log .Errorf ("InterceptableSwitch stopped: %v" , err )
245
249
}
@@ -257,6 +261,7 @@ func (s *InterceptableSwitch) Stop() error {
257
261
return fmt .Errorf ("InterceptableSwitch stopped more than once" )
258
262
}
259
263
264
+ s .cancel .WhenSome (func (fn context.CancelFunc ) { fn () })
260
265
close (s .quit )
261
266
s .wg .Wait ()
262
267
@@ -271,7 +276,7 @@ func (s *InterceptableSwitch) Stop() error {
271
276
return nil
272
277
}
273
278
274
- func (s * InterceptableSwitch ) run () error {
279
+ func (s * InterceptableSwitch ) run (ctx context. Context ) error {
275
280
// The block epoch stream will immediately stream the current height.
276
281
// Read it out here.
277
282
select {
@@ -298,7 +303,7 @@ func (s *InterceptableSwitch) run() error {
298
303
var notIntercepted []* htlcPacket
299
304
for _ , p := range packets .packets {
300
305
intercepted , err := s .interceptForward (
301
- p , packets .isReplay ,
306
+ ctx , p , packets .isReplay ,
302
307
)
303
308
if err != nil {
304
309
return err
@@ -325,12 +330,12 @@ func (s *InterceptableSwitch) run() error {
325
330
// already intercepted in the off-chain flow. And even
326
331
// if not, it is safe to signal replay so that we won't
327
332
// unexpectedly skip over this htlc.
328
- if _ , err := s .forward (fwd , true ); err != nil {
333
+ if _ , err := s .forward (ctx , fwd , true ); err != nil {
329
334
return err
330
335
}
331
336
332
337
case res := <- s .resolutionChan :
333
- res .errChan <- s .resolve (res .resolution )
338
+ res .errChan <- s .resolve (ctx , res .resolution )
334
339
335
340
case currentBlock , ok := <- s .blockEpochStream .Epochs :
336
341
if ! ok {
@@ -341,20 +346,23 @@ func (s *InterceptableSwitch) run() error {
341
346
342
347
// A new block is appended. Fail any held htlcs that
343
348
// expire at this height to prevent channel force-close.
344
- s .failExpiredHtlcs ()
349
+ s .failExpiredHtlcs (ctx )
345
350
346
351
case <- s .quit :
347
352
return nil
353
+
354
+ case <- ctx .Done ():
355
+ return ctx .Err ()
348
356
}
349
357
}
350
358
}
351
359
352
- func (s * InterceptableSwitch ) failExpiredHtlcs () {
360
+ func (s * InterceptableSwitch ) failExpiredHtlcs (ctx context. Context ) {
353
361
s .heldHtlcSet .popAutoFails (
354
- uint32 (s .currentHeight ),
355
- func (fwd InterceptedForward ) {
362
+ ctx , uint32 (s .currentHeight ),
363
+ func (ctx context. Context , fwd InterceptedForward ) {
356
364
err := fwd .FailWithCode (
357
- lnwire .CodeTemporaryChannelFailure ,
365
+ ctx , lnwire .CodeTemporaryChannelFailure ,
358
366
)
359
367
if err != nil {
360
368
log .Errorf ("Cannot fail packet: %v" , err )
@@ -407,7 +415,9 @@ func (s *InterceptableSwitch) setInterceptor(interceptor ForwardInterceptor) {
407
415
408
416
// resolve processes a HTLC given the resolution type specified by the
409
417
// intercepting client.
410
- func (s * InterceptableSwitch ) resolve (res * FwdResolution ) error {
418
+ func (s * InterceptableSwitch ) resolve (ctx context.Context ,
419
+ res * FwdResolution ) error {
420
+
411
421
intercepted , err := s .heldHtlcSet .pop (res .Key )
412
422
if err != nil {
413
423
return err
@@ -431,7 +441,7 @@ func (s *InterceptableSwitch) resolve(res *FwdResolution) error {
431
441
return intercepted .Fail (res .FailureMessage )
432
442
}
433
443
434
- return intercepted .FailWithCode (res .FailureCode )
444
+ return intercepted .FailWithCode (ctx , res .FailureCode )
435
445
436
446
default :
437
447
return fmt .Errorf ("unrecognized action %v" , res .Action )
@@ -503,8 +513,8 @@ func (s *InterceptableSwitch) ForwardPacket(
503
513
504
514
// interceptForward forwards the packet to the external interceptor after
505
515
// checking the interception criteria.
506
- func (s * InterceptableSwitch ) interceptForward (packet * htlcPacket ,
507
- isReplay bool ) (bool , error ) {
516
+ func (s * InterceptableSwitch ) interceptForward (ctx context. Context ,
517
+ packet * htlcPacket , isReplay bool ) (bool , error ) {
508
518
509
519
switch htlc := packet .htlc .(type ) {
510
520
case * lnwire.UpdateAddHTLC :
@@ -522,7 +532,7 @@ func (s *InterceptableSwitch) interceptForward(packet *htlcPacket,
522
532
}
523
533
524
534
// Handle forwards that are too close to expiry.
525
- handled , err := s .handleExpired (intercepted )
535
+ handled , err := s .handleExpired (ctx , intercepted )
526
536
if err != nil {
527
537
log .Errorf ("Error handling intercepted htlc " +
528
538
"that expires too soon: circuit=%v, " +
@@ -542,15 +552,15 @@ func (s *InterceptableSwitch) interceptForward(packet *htlcPacket,
542
552
return true , nil
543
553
}
544
554
545
- return s .forward (intercepted , isReplay )
555
+ return s .forward (ctx , intercepted , isReplay )
546
556
547
557
default :
548
558
return false , nil
549
559
}
550
560
}
551
561
552
562
// forward records the intercepted htlc and forwards it to the interceptor.
553
- func (s * InterceptableSwitch ) forward (
563
+ func (s * InterceptableSwitch ) forward (ctx context. Context ,
554
564
fwd InterceptedForward , isReplay bool ) (bool , error ) {
555
565
556
566
inKey := fwd .Packet ().IncomingCircuit
@@ -573,7 +583,7 @@ func (s *InterceptableSwitch) forward(
573
583
// yet. This limits the backlog of htlcs when the interceptor is down.
574
584
if ! isReplay {
575
585
err := fwd .FailWithCode (
576
- lnwire .CodeTemporaryChannelFailure ,
586
+ ctx , lnwire .CodeTemporaryChannelFailure ,
577
587
)
578
588
if err != nil {
579
589
log .Errorf ("Cannot fail packet: %v" , err )
@@ -605,8 +615,8 @@ func (s *InterceptableSwitch) forward(
605
615
606
616
// handleExpired checks that the htlc isn't too close to the channel
607
617
// force-close broadcast height. If it is, it is cancelled back.
608
- func (s * InterceptableSwitch ) handleExpired (fwd * interceptedForward ) (
609
- bool , error ) {
618
+ func (s * InterceptableSwitch ) handleExpired (ctx context. Context ,
619
+ fwd * interceptedForward ) ( bool , error ) {
610
620
611
621
height := uint32 (s .currentHeight )
612
622
if fwd .packet .incomingTimeout >= height + s .cltvInterceptDelta {
@@ -620,7 +630,7 @@ func (s *InterceptableSwitch) handleExpired(fwd *interceptedForward) (
620
630
fwd .packet .incomingTimeout )
621
631
622
632
err := fwd .FailWithCode (
623
- lnwire .CodeExpiryTooSoon ,
633
+ ctx , lnwire .CodeExpiryTooSoon ,
624
634
)
625
635
if err != nil {
626
636
return false , err
@@ -747,7 +757,9 @@ func (f *interceptedForward) Fail(reason []byte) error {
747
757
748
758
// FailWithCode notifies the intention to fail an existing hold forward with the
749
759
// specified failure code.
750
- func (f * interceptedForward ) FailWithCode (code lnwire.FailCode ) error {
760
+ func (f * interceptedForward ) FailWithCode (_ context.Context ,
761
+ code lnwire.FailCode ) error {
762
+
751
763
shaOnionBlob := func () [32 ]byte {
752
764
return sha256 .Sum256 (f .htlc .OnionBlob [:])
753
765
}
0 commit comments