Skip to content

Commit 5058588

Browse files
committed
htlcswitch: continue threading context through
Here, we focus on the `interceptForward.FailWithCode` method which also makes a call to the switch's `FetchLastChannelUpdate` method which will later take a context since it is a call to the graph DB.
1 parent 5282f0e commit 5058588

7 files changed

+65
-41
lines changed

htlcswitch/held_htlc_set.go

+5-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package htlcswitch
22

33
import (
4+
"context"
45
"errors"
56
"fmt"
67

@@ -39,13 +40,15 @@ func (h *heldHtlcSet) popAll(cb func(InterceptedForward)) {
3940

4041
// popAutoFails calls the callback for each forward that has an auto-fail height
4142
// equal or less then the specified pop height and removes them from the set.
42-
func (h *heldHtlcSet) popAutoFails(height uint32, cb func(InterceptedForward)) {
43+
func (h *heldHtlcSet) popAutoFails(ctx context.Context, height uint32,
44+
cb func(context.Context, InterceptedForward)) {
45+
4346
for key, fwd := range h.set {
4447
if uint32(fwd.Packet().AutoFailHeight) > height {
4548
continue
4649
}
4750

48-
cb(fwd)
51+
cb(ctx, fwd)
4952

5053
delete(h.set, key)
5154
}

htlcswitch/held_htlc_set_test.go

+9-6
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package htlcswitch
22

33
import (
4+
"context"
45
"testing"
56

67
"github.com/lightningnetwork/lnd/graph/db/models"
@@ -80,6 +81,8 @@ func TestHeldHtlcSet(t *testing.T) {
8081
}
8182

8283
func TestHeldHtlcSetAutoFails(t *testing.T) {
84+
t.Parallel()
85+
ctx := context.Background()
8386
set := newHeldHtlcSet()
8487

8588
key := models.CircuitKey{
@@ -98,17 +101,17 @@ func TestHeldHtlcSetAutoFails(t *testing.T) {
98101
// Test popping auto fails up to one block before the auto-fail height
99102
// of our forward.
100103
set.popAutoFails(
101-
autoFailHeight-1,
102-
func(_ InterceptedForward) {
104+
ctx, autoFailHeight-1,
105+
func(_ context.Context, _ InterceptedForward) {
103106
require.Fail(t, "unexpected fwd")
104107
},
105108
)
106109

107110
// Popping succeeds at the auto-fail height.
108111
cbCalled := false
109112
set.popAutoFails(
110-
autoFailHeight,
111-
func(poppedFwd InterceptedForward) {
113+
ctx, autoFailHeight,
114+
func(_ context.Context, poppedFwd InterceptedForward) {
112115
cbCalled = true
113116

114117
require.Equal(t, fwd, poppedFwd)
@@ -118,8 +121,8 @@ func TestHeldHtlcSetAutoFails(t *testing.T) {
118121

119122
// After this, there should be nothing more to pop.
120123
set.popAutoFails(
121-
autoFailHeight,
122-
func(_ InterceptedForward) {
124+
ctx, autoFailHeight,
125+
func(_ context.Context, _ InterceptedForward) {
123126
require.Fail(t, "unexpected fwd")
124127
},
125128
)

htlcswitch/interceptable_switch.go

+37-25
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package htlcswitch
22

33
import (
4+
"context"
45
"crypto/sha256"
56
"fmt"
67
"sync"
@@ -89,8 +90,9 @@ type InterceptableSwitch struct {
8990
// currentHeight is the currently best known height.
9091
currentHeight int32
9192

92-
wg sync.WaitGroup
93-
quit chan struct{}
93+
wg sync.WaitGroup
94+
quit chan struct{}
95+
cancel fn.Option[context.CancelFunc]
9496
}
9597

9698
type interceptedPackets struct {
@@ -222,12 +224,14 @@ func (s *InterceptableSwitch) SetInterceptor(
222224
}
223225
}
224226

225-
func (s *InterceptableSwitch) Start() error {
227+
func (s *InterceptableSwitch) Start(ctx context.Context) error {
226228
log.Info("InterceptableSwitch starting...")
227229

228230
if s.started.Swap(true) {
229231
return fmt.Errorf("InterceptableSwitch started more than once")
230232
}
233+
ctx, cancel := context.WithCancel(ctx)
234+
s.cancel = fn.Some(cancel)
231235

232236
blockEpochStream, err := s.notifier.RegisterBlockEpochNtfn(nil)
233237
if err != nil {
@@ -239,7 +243,7 @@ func (s *InterceptableSwitch) Start() error {
239243
go func() {
240244
defer s.wg.Done()
241245

242-
err := s.run()
246+
err := s.run(ctx)
243247
if err != nil {
244248
log.Errorf("InterceptableSwitch stopped: %v", err)
245249
}
@@ -257,6 +261,7 @@ func (s *InterceptableSwitch) Stop() error {
257261
return fmt.Errorf("InterceptableSwitch stopped more than once")
258262
}
259263

264+
s.cancel.WhenSome(func(fn context.CancelFunc) { fn() })
260265
close(s.quit)
261266
s.wg.Wait()
262267

@@ -271,7 +276,7 @@ func (s *InterceptableSwitch) Stop() error {
271276
return nil
272277
}
273278

274-
func (s *InterceptableSwitch) run() error {
279+
func (s *InterceptableSwitch) run(ctx context.Context) error {
275280
// The block epoch stream will immediately stream the current height.
276281
// Read it out here.
277282
select {
@@ -298,7 +303,7 @@ func (s *InterceptableSwitch) run() error {
298303
var notIntercepted []*htlcPacket
299304
for _, p := range packets.packets {
300305
intercepted, err := s.interceptForward(
301-
p, packets.isReplay,
306+
ctx, p, packets.isReplay,
302307
)
303308
if err != nil {
304309
return err
@@ -325,12 +330,12 @@ func (s *InterceptableSwitch) run() error {
325330
// already intercepted in the off-chain flow. And even
326331
// if not, it is safe to signal replay so that we won't
327332
// unexpectedly skip over this htlc.
328-
if _, err := s.forward(fwd, true); err != nil {
333+
if _, err := s.forward(ctx, fwd, true); err != nil {
329334
return err
330335
}
331336

332337
case res := <-s.resolutionChan:
333-
res.errChan <- s.resolve(res.resolution)
338+
res.errChan <- s.resolve(ctx, res.resolution)
334339

335340
case currentBlock, ok := <-s.blockEpochStream.Epochs:
336341
if !ok {
@@ -341,20 +346,23 @@ func (s *InterceptableSwitch) run() error {
341346

342347
// A new block is appended. Fail any held htlcs that
343348
// expire at this height to prevent channel force-close.
344-
s.failExpiredHtlcs()
349+
s.failExpiredHtlcs(ctx)
345350

346351
case <-s.quit:
347352
return nil
353+
354+
case <-ctx.Done():
355+
return ctx.Err()
348356
}
349357
}
350358
}
351359

352-
func (s *InterceptableSwitch) failExpiredHtlcs() {
360+
func (s *InterceptableSwitch) failExpiredHtlcs(ctx context.Context) {
353361
s.heldHtlcSet.popAutoFails(
354-
uint32(s.currentHeight),
355-
func(fwd InterceptedForward) {
362+
ctx, uint32(s.currentHeight),
363+
func(ctx context.Context, fwd InterceptedForward) {
356364
err := fwd.FailWithCode(
357-
lnwire.CodeTemporaryChannelFailure,
365+
ctx, lnwire.CodeTemporaryChannelFailure,
358366
)
359367
if err != nil {
360368
log.Errorf("Cannot fail packet: %v", err)
@@ -407,7 +415,9 @@ func (s *InterceptableSwitch) setInterceptor(interceptor ForwardInterceptor) {
407415

408416
// resolve processes a HTLC given the resolution type specified by the
409417
// intercepting client.
410-
func (s *InterceptableSwitch) resolve(res *FwdResolution) error {
418+
func (s *InterceptableSwitch) resolve(ctx context.Context,
419+
res *FwdResolution) error {
420+
411421
intercepted, err := s.heldHtlcSet.pop(res.Key)
412422
if err != nil {
413423
return err
@@ -431,7 +441,7 @@ func (s *InterceptableSwitch) resolve(res *FwdResolution) error {
431441
return intercepted.Fail(res.FailureMessage)
432442
}
433443

434-
return intercepted.FailWithCode(res.FailureCode)
444+
return intercepted.FailWithCode(ctx, res.FailureCode)
435445

436446
default:
437447
return fmt.Errorf("unrecognized action %v", res.Action)
@@ -503,8 +513,8 @@ func (s *InterceptableSwitch) ForwardPacket(
503513

504514
// interceptForward forwards the packet to the external interceptor after
505515
// checking the interception criteria.
506-
func (s *InterceptableSwitch) interceptForward(packet *htlcPacket,
507-
isReplay bool) (bool, error) {
516+
func (s *InterceptableSwitch) interceptForward(ctx context.Context,
517+
packet *htlcPacket, isReplay bool) (bool, error) {
508518

509519
switch htlc := packet.htlc.(type) {
510520
case *lnwire.UpdateAddHTLC:
@@ -522,7 +532,7 @@ func (s *InterceptableSwitch) interceptForward(packet *htlcPacket,
522532
}
523533

524534
// Handle forwards that are too close to expiry.
525-
handled, err := s.handleExpired(intercepted)
535+
handled, err := s.handleExpired(ctx, intercepted)
526536
if err != nil {
527537
log.Errorf("Error handling intercepted htlc "+
528538
"that expires too soon: circuit=%v, "+
@@ -542,15 +552,15 @@ func (s *InterceptableSwitch) interceptForward(packet *htlcPacket,
542552
return true, nil
543553
}
544554

545-
return s.forward(intercepted, isReplay)
555+
return s.forward(ctx, intercepted, isReplay)
546556

547557
default:
548558
return false, nil
549559
}
550560
}
551561

552562
// forward records the intercepted htlc and forwards it to the interceptor.
553-
func (s *InterceptableSwitch) forward(
563+
func (s *InterceptableSwitch) forward(ctx context.Context,
554564
fwd InterceptedForward, isReplay bool) (bool, error) {
555565

556566
inKey := fwd.Packet().IncomingCircuit
@@ -573,7 +583,7 @@ func (s *InterceptableSwitch) forward(
573583
// yet. This limits the backlog of htlcs when the interceptor is down.
574584
if !isReplay {
575585
err := fwd.FailWithCode(
576-
lnwire.CodeTemporaryChannelFailure,
586+
ctx, lnwire.CodeTemporaryChannelFailure,
577587
)
578588
if err != nil {
579589
log.Errorf("Cannot fail packet: %v", err)
@@ -605,8 +615,8 @@ func (s *InterceptableSwitch) forward(
605615

606616
// handleExpired checks that the htlc isn't too close to the channel
607617
// force-close broadcast height. If it is, it is cancelled back.
608-
func (s *InterceptableSwitch) handleExpired(fwd *interceptedForward) (
609-
bool, error) {
618+
func (s *InterceptableSwitch) handleExpired(ctx context.Context,
619+
fwd *interceptedForward) (bool, error) {
610620

611621
height := uint32(s.currentHeight)
612622
if fwd.packet.incomingTimeout >= height+s.cltvInterceptDelta {
@@ -620,7 +630,7 @@ func (s *InterceptableSwitch) handleExpired(fwd *interceptedForward) (
620630
fwd.packet.incomingTimeout)
621631

622632
err := fwd.FailWithCode(
623-
lnwire.CodeExpiryTooSoon,
633+
ctx, lnwire.CodeExpiryTooSoon,
624634
)
625635
if err != nil {
626636
return false, err
@@ -747,7 +757,9 @@ func (f *interceptedForward) Fail(reason []byte) error {
747757

748758
// FailWithCode notifies the intention to fail an existing hold forward with the
749759
// specified failure code.
750-
func (f *interceptedForward) FailWithCode(code lnwire.FailCode) error {
760+
func (f *interceptedForward) FailWithCode(_ context.Context,
761+
code lnwire.FailCode) error {
762+
751763
shaOnionBlob := func() [32]byte {
752764
return sha256.Sum256(f.htlc.OnionBlob[:])
753765
}

htlcswitch/interfaces.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -454,7 +454,7 @@ type InterceptedForward interface {
454454

455455
// FailWithCode notifies the intention to fail an existing hold forward
456456
// with the specified failure code.
457-
FailWithCode(code lnwire.FailCode) error
457+
FailWithCode(ctx context.Context, code lnwire.FailCode) error
458458
}
459459

460460
// htlcNotifier is an interface which represents the input side of the

htlcswitch/switch_test.go

+8-5
Original file line numberDiff line numberDiff line change
@@ -3893,6 +3893,7 @@ func (c *interceptableSwitchTestContext) createSettlePacket(
38933893

38943894
func TestSwitchHoldForward(t *testing.T) {
38953895
t.Parallel()
3896+
ctx := context.Background()
38963897

38973898
c := newInterceptableSwitchTestContext(t)
38983899
defer c.finish()
@@ -3911,7 +3912,7 @@ func TestSwitchHoldForward(t *testing.T) {
39113912
},
39123913
)
39133914
require.NoError(t, err)
3914-
require.NoError(t, switchForwardInterceptor.Start())
3915+
require.NoError(t, switchForwardInterceptor.Start(ctx))
39153916

39163917
switchForwardInterceptor.SetInterceptor(c.forwardInterceptor.InterceptForwardHtlc)
39173918
linkQuit := make(chan struct{})
@@ -4115,7 +4116,7 @@ func TestSwitchHoldForward(t *testing.T) {
41154116
},
41164117
)
41174118
require.NoError(t, err)
4118-
require.NoError(t, switchForwardInterceptor.Start())
4119+
require.NoError(t, switchForwardInterceptor.Start(ctx))
41194120

41204121
// Forward a fresh packet. It is expected to be failed immediately,
41214122
// because there is no interceptor registered.
@@ -4183,6 +4184,7 @@ func TestSwitchHoldForward(t *testing.T) {
41834184

41844185
func TestInterceptableSwitchWatchDog(t *testing.T) {
41854186
t.Parallel()
4187+
ctx := context.Background()
41864188

41874189
c := newInterceptableSwitchTestContext(t)
41884190
defer c.finish()
@@ -4202,7 +4204,7 @@ func TestInterceptableSwitchWatchDog(t *testing.T) {
42024204
},
42034205
)
42044206
require.NoError(t, err)
4205-
require.NoError(t, switchForwardInterceptor.Start())
4207+
require.NoError(t, switchForwardInterceptor.Start(ctx))
42064208

42074209
// Set interceptor.
42084210
switchForwardInterceptor.SetInterceptor(
@@ -5439,6 +5441,7 @@ func TestSwitchAliasInterceptFail(t *testing.T) {
54395441

54405442
func testSwitchAliasInterceptFail(t *testing.T, zeroConf bool) {
54415443
t.Parallel()
5444+
ctx := context.Background()
54425445

54435446
chanID, aliceScid := genID()
54445447

@@ -5524,14 +5527,14 @@ func testSwitchAliasInterceptFail(t *testing.T, zeroConf bool) {
55245527
},
55255528
)
55265529
require.NoError(t, err)
5527-
require.NoError(t, interceptSwitch.Start())
5530+
require.NoError(t, interceptSwitch.Start(ctx))
55285531
interceptSwitch.SetInterceptor(forwardInterceptor.InterceptForwardHtlc)
55295532

55305533
err = interceptSwitch.ForwardPackets(nil, false, ogPacket)
55315534
require.NoError(t, err)
55325535

55335536
inCircuit := forwardInterceptor.getIntercepted().IncomingCircuit
5534-
require.NoError(t, interceptSwitch.resolve(&FwdResolution{
5537+
require.NoError(t, interceptSwitch.resolve(ctx, &FwdResolution{
55355538
Action: FwdActionFail,
55365539
Key: inCircuit,
55375540
FailureCode: lnwire.CodeTemporaryChannelFailure,

intercepted_forward.go

+4-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package lnd
22

33
import (
4+
"context"
45
"errors"
56

67
"github.com/lightningnetwork/lnd/fn/v2"
@@ -72,7 +73,9 @@ func (f *interceptedForward) Fail(_ []byte) error {
7273

7374
// FailWithCode notifies the intention to fail an existing hold forward with the
7475
// specified failure code.
75-
func (f *interceptedForward) FailWithCode(_ lnwire.FailCode) error {
76+
func (f *interceptedForward) FailWithCode(_ context.Context,
77+
_ lnwire.FailCode) error {
78+
7679
return ErrCannotFail
7780
}
7881

server.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -2353,7 +2353,7 @@ func (s *server) Start(ctx context.Context) error {
23532353
}
23542354

23552355
cleanup = cleanup.add(s.interceptableSwitch.Stop)
2356-
if err := s.interceptableSwitch.Start(); err != nil {
2356+
if err := s.interceptableSwitch.Start(ctx); err != nil {
23572357
startErr = err
23582358
return
23592359
}

0 commit comments

Comments
 (0)