Skip to content

Commit b9ebb54

Browse files
committed
htlcswitch: start threading contexts through
The switch makes a call to a `FetchLastChannelUpdate` call-back. This is a graph DB call behind the scenes and so will eventually take a context. In preparation for this, we start threading contexts through the htlcswitch sub-system to the various call sites of `s.cfg.FetchLastChannelUpdate`.
1 parent e214b57 commit b9ebb54

File tree

7 files changed

+149
-115
lines changed

7 files changed

+149
-115
lines changed

htlcswitch/interfaces.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,8 @@ type ChannelLink interface {
281281
// a LinkError with a valid protocol failure message should be returned
282282
// in order to signal to the source of the HTLC, the policy consistency
283283
// issue.
284-
CheckHtlcForward(payHash [32]byte, incomingAmt lnwire.MilliSatoshi,
284+
CheckHtlcForward(ctx context.Context, payHash [32]byte,
285+
incomingAmt lnwire.MilliSatoshi,
285286
amtToForward lnwire.MilliSatoshi, incomingTimeout,
286287
outgoingTimeout uint32, inboundFee models.InboundFee,
287288
heightNow uint32, scid lnwire.ShortChannelID,
@@ -292,7 +293,8 @@ type ChannelLink interface {
292293
// valid protocol failure message should be returned in order to signal
293294
// the violation. This call is intended to be used for locally initiated
294295
// payments for which there is no corresponding incoming htlc.
295-
CheckHtlcTransit(payHash [32]byte, amt lnwire.MilliSatoshi,
296+
CheckHtlcTransit(ctx context.Context, payHash [32]byte,
297+
amt lnwire.MilliSatoshi,
296298
timeout uint32, heightNow uint32,
297299
customRecords lnwire.CustomRecords) *LinkError
298300

htlcswitch/link.go

Lines changed: 46 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -859,7 +859,7 @@ type failCb func(update *lnwire.ChannelUpdate1) lnwire.FailureMessage
859859
// outgoing HTLC. It may return a FailureMessage that references a channel's
860860
// alias. If the channel does not have an alias, then the regular channel
861861
// update from disk will be returned.
862-
func (l *channelLink) createFailureWithUpdate(incoming bool,
862+
func (l *channelLink) createFailureWithUpdate(_ context.Context, incoming bool,
863863
outgoingScid lnwire.ShortChannelID, cb failCb) lnwire.FailureMessage {
864864

865865
// Determine which SCID to use in case we need to use aliases in the
@@ -1043,7 +1043,7 @@ func (l *channelLink) resolveFwdPkgs(ctx context.Context) error {
10431043
l.log.Debugf("loaded %d fwd pks", len(fwdPkgs))
10441044

10451045
for _, fwdPkg := range fwdPkgs {
1046-
if err := l.resolveFwdPkg(fwdPkg); err != nil {
1046+
if err := l.resolveFwdPkg(ctx, fwdPkg); err != nil {
10471047
return err
10481048
}
10491049
}
@@ -1060,7 +1060,9 @@ func (l *channelLink) resolveFwdPkgs(ctx context.Context) error {
10601060
// resolveFwdPkg interprets the FwdState of the provided package, either
10611061
// reprocesses any outstanding htlcs in the package, or performs garbage
10621062
// collection on the package.
1063-
func (l *channelLink) resolveFwdPkg(fwdPkg *channeldb.FwdPkg) error {
1063+
func (l *channelLink) resolveFwdPkg(ctx context.Context,
1064+
fwdPkg *channeldb.FwdPkg) error {
1065+
10641066
// Remove any completed packages to clear up space.
10651067
if fwdPkg.State == channeldb.FwdStateCompleted {
10661068
l.log.Debugf("removing completed fwd pkg for height=%d",
@@ -1091,7 +1093,7 @@ func (l *channelLink) resolveFwdPkg(fwdPkg *channeldb.FwdPkg) error {
10911093
// shove the entire, original set of adds down the pipeline so that the
10921094
// batch of adds presented to the sphinx router does not ever change.
10931095
if !fwdPkg.AckFilter.IsFull() {
1094-
l.processRemoteAdds(fwdPkg)
1096+
l.processRemoteAdds(ctx, fwdPkg)
10951097

10961098
// If the link failed during processing the adds, we must
10971099
// return to ensure we won't attempted to update the state
@@ -2539,10 +2541,10 @@ func (l *channelLink) handleUpstreamMsg(ctx context.Context,
25392541
// check since processing those can't result in further updates
25402542
// to this channel link.
25412543
if l.quiescer.CanSendUpdates() {
2542-
l.processRemoteAdds(fwdPkg)
2544+
l.processRemoteAdds(ctx, fwdPkg)
25432545
} else {
25442546
l.quiescer.OnResume(func() {
2545-
l.processRemoteAdds(fwdPkg)
2547+
l.processRemoteAdds(ctx, fwdPkg)
25462548
})
25472549
}
25482550
l.processRemoteSettleFails(fwdPkg)
@@ -3257,8 +3259,8 @@ func (l *channelLink) UpdateForwardingPolicy(
32573259
// issue.
32583260
//
32593261
// NOTE: Part of the ChannelLink interface.
3260-
func (l *channelLink) CheckHtlcForward(payHash [32]byte, incomingHtlcAmt,
3261-
amtToForward lnwire.MilliSatoshi, incomingTimeout,
3262+
func (l *channelLink) CheckHtlcForward(ctx context.Context, payHash [32]byte,
3263+
incomingHtlcAmt, amtToForward lnwire.MilliSatoshi, incomingTimeout,
32623264
outgoingTimeout uint32, inboundFee models.InboundFee,
32633265
heightNow uint32, originalScid lnwire.ShortChannelID,
32643266
customRecords lnwire.CustomRecords) *LinkError {
@@ -3303,13 +3305,15 @@ func (l *channelLink) CheckHtlcForward(payHash [32]byte, incomingHtlcAmt,
33033305
cb := func(upd *lnwire.ChannelUpdate1) lnwire.FailureMessage {
33043306
return lnwire.NewFeeInsufficient(amtToForward, *upd)
33053307
}
3306-
failure := l.createFailureWithUpdate(false, originalScid, cb)
3308+
failure := l.createFailureWithUpdate(
3309+
ctx, false, originalScid, cb,
3310+
)
33073311
return NewLinkError(failure)
33083312
}
33093313

33103314
// Check whether the outgoing htlc satisfies the channel policy.
33113315
err := l.canSendHtlc(
3312-
policy, payHash, amtToForward, outgoingTimeout, heightNow,
3316+
ctx, policy, payHash, amtToForward, outgoingTimeout, heightNow,
33133317
originalScid, customRecords,
33143318
)
33153319
if err != nil {
@@ -3333,7 +3337,10 @@ func (l *channelLink) CheckHtlcForward(payHash [32]byte, incomingHtlcAmt,
33333337
incomingTimeout, *upd,
33343338
)
33353339
}
3336-
failure := l.createFailureWithUpdate(false, originalScid, cb)
3340+
failure := l.createFailureWithUpdate(
3341+
ctx, false, originalScid, cb,
3342+
)
3343+
33373344
return NewLinkError(failure)
33383345
}
33393346

@@ -3345,7 +3352,7 @@ func (l *channelLink) CheckHtlcForward(payHash [32]byte, incomingHtlcAmt,
33453352
// valid protocol failure message should be returned in order to signal
33463353
// the violation. This call is intended to be used for locally initiated
33473354
// payments for which there is no corresponding incoming htlc.
3348-
func (l *channelLink) CheckHtlcTransit(payHash [32]byte,
3355+
func (l *channelLink) CheckHtlcTransit(ctx context.Context, payHash [32]byte,
33493356
amt lnwire.MilliSatoshi, timeout uint32, heightNow uint32,
33503357
customRecords lnwire.CustomRecords) *LinkError {
33513358

@@ -3357,15 +3364,16 @@ func (l *channelLink) CheckHtlcTransit(payHash [32]byte,
33573364
// trying to send over a local link. This causes the fallback mechanism
33583365
// to occur.
33593366
return l.canSendHtlc(
3360-
policy, payHash, amt, timeout, heightNow, hop.Source,
3367+
ctx, policy, payHash, amt, timeout, heightNow, hop.Source,
33613368
customRecords,
33623369
)
33633370
}
33643371

33653372
// canSendHtlc checks whether the given htlc parameters satisfy
33663373
// the channel's amount and time lock constraints.
3367-
func (l *channelLink) canSendHtlc(policy models.ForwardingPolicy,
3368-
payHash [32]byte, amt lnwire.MilliSatoshi, timeout uint32,
3374+
func (l *channelLink) canSendHtlc(ctx context.Context,
3375+
policy models.ForwardingPolicy, payHash [32]byte,
3376+
amt lnwire.MilliSatoshi, timeout uint32,
33693377
heightNow uint32, originalScid lnwire.ShortChannelID,
33703378
customRecords lnwire.CustomRecords) *LinkError {
33713379

@@ -3382,7 +3390,10 @@ func (l *channelLink) canSendHtlc(policy models.ForwardingPolicy,
33823390
cb := func(upd *lnwire.ChannelUpdate1) lnwire.FailureMessage {
33833391
return lnwire.NewAmountBelowMinimum(amt, *upd)
33843392
}
3385-
failure := l.createFailureWithUpdate(false, originalScid, cb)
3393+
failure := l.createFailureWithUpdate(
3394+
ctx, false, originalScid, cb,
3395+
)
3396+
33863397
return NewLinkError(failure)
33873398
}
33883399

@@ -3397,8 +3408,13 @@ func (l *channelLink) canSendHtlc(policy models.ForwardingPolicy,
33973408
cb := func(upd *lnwire.ChannelUpdate1) lnwire.FailureMessage {
33983409
return lnwire.NewTemporaryChannelFailure(upd)
33993410
}
3400-
failure := l.createFailureWithUpdate(false, originalScid, cb)
3401-
return NewDetailedLinkError(failure, OutgoingFailureHTLCExceedsMax)
3411+
failure := l.createFailureWithUpdate(
3412+
ctx, false, originalScid, cb,
3413+
)
3414+
3415+
return NewDetailedLinkError(
3416+
failure, OutgoingFailureHTLCExceedsMax,
3417+
)
34023418
}
34033419

34043420
// We want to avoid offering an HTLC which will expire in the near
@@ -3412,7 +3428,10 @@ func (l *channelLink) canSendHtlc(policy models.ForwardingPolicy,
34123428
cb := func(upd *lnwire.ChannelUpdate1) lnwire.FailureMessage {
34133429
return lnwire.NewExpiryTooSoon(*upd)
34143430
}
3415-
failure := l.createFailureWithUpdate(false, originalScid, cb)
3431+
failure := l.createFailureWithUpdate(
3432+
ctx, false, originalScid, cb,
3433+
)
3434+
34163435
return NewLinkError(failure)
34173436
}
34183437

@@ -3466,7 +3485,10 @@ func (l *channelLink) canSendHtlc(policy models.ForwardingPolicy,
34663485
cb := func(upd *lnwire.ChannelUpdate1) lnwire.FailureMessage {
34673486
return lnwire.NewTemporaryChannelFailure(upd)
34683487
}
3469-
failure := l.createFailureWithUpdate(false, originalScid, cb)
3488+
failure := l.createFailureWithUpdate(
3489+
ctx, false, originalScid, cb,
3490+
)
3491+
34703492
return NewDetailedLinkError(
34713493
failure, OutgoingFailureInsufficientBalance,
34723494
)
@@ -3726,7 +3748,9 @@ func (l *channelLink) processRemoteSettleFails(fwdPkg *channeldb.FwdPkg) {
37263748
// have already been acknowledged in the forwarding package will be ignored.
37273749
//
37283750
//nolint:funlen
3729-
func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg) {
3751+
func (l *channelLink) processRemoteAdds(ctx context.Context,
3752+
fwdPkg *channeldb.FwdPkg) {
3753+
37303754
l.log.Tracef("processing %d remote adds for height %d",
37313755
len(fwdPkg.Adds), fwdPkg.Height)
37323756

@@ -4049,7 +4073,7 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg) {
40494073
}
40504074

40514075
failure := l.createFailureWithUpdate(
4052-
true, hop.Source, cb,
4076+
ctx, true, hop.Source, cb,
40534077
)
40544078

40554079
l.sendHTLCError(

htlcswitch/link_test.go

Lines changed: 28 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -6206,6 +6206,9 @@ func TestForwardingAsymmetricTimeLockPolicies(t *testing.T) {
62066206
// TestCheckHtlcForward tests that a link is properly enforcing the HTLC
62076207
// forwarding policy.
62086208
func TestCheckHtlcForward(t *testing.T) {
6209+
t.Parallel()
6210+
ctx := context.Background()
6211+
62096212
fetchLastChannelUpdate := func(lnwire.ShortChannelID) (
62106213
*lnwire.ChannelUpdate1, error) {
62116214

@@ -6248,7 +6251,7 @@ func TestCheckHtlcForward(t *testing.T) {
62486251

62496252
t.Run("satisfied", func(t *testing.T) {
62506253
result := link.CheckHtlcForward(
6251-
hash, 1500, 1000, 200, 150, models.InboundFee{}, 0,
6254+
ctx, hash, 1500, 1000, 200, 150, models.InboundFee{}, 0,
62526255
lnwire.ShortChannelID{}, nil,
62536256
)
62546257
if result != nil {
@@ -6258,32 +6261,29 @@ func TestCheckHtlcForward(t *testing.T) {
62586261

62596262
t.Run("below minhtlc", func(t *testing.T) {
62606263
result := link.CheckHtlcForward(
6261-
hash, 100, 50, 200, 150, models.InboundFee{}, 0,
6264+
ctx, hash, 100, 50, 200, 150, models.InboundFee{}, 0,
62626265
lnwire.ShortChannelID{}, nil,
62636266
)
6264-
if _, ok := result.WireMessage().(*lnwire.FailAmountBelowMinimum); !ok {
6265-
t.Fatalf("expected FailAmountBelowMinimum failure code")
6266-
}
6267+
_, ok := result.WireMessage().(*lnwire.FailAmountBelowMinimum)
6268+
require.True(t, ok)
62676269
})
62686270

62696271
t.Run("above maxhtlc", func(t *testing.T) {
62706272
result := link.CheckHtlcForward(
6271-
hash, 1500, 1200, 200, 150, models.InboundFee{}, 0,
6273+
ctx, hash, 1500, 1200, 200, 150, models.InboundFee{}, 0,
62726274
lnwire.ShortChannelID{}, nil,
62736275
)
6274-
if _, ok := result.WireMessage().(*lnwire.FailTemporaryChannelFailure); !ok {
6275-
t.Fatalf("expected FailTemporaryChannelFailure failure code")
6276-
}
6276+
_, ok := result.WireMessage().(*lnwire.FailTemporaryChannelFailure) //nolint:ll
6277+
require.True(t, ok)
62776278
})
62786279

62796280
t.Run("insufficient fee", func(t *testing.T) {
62806281
result := link.CheckHtlcForward(
6281-
hash, 1005, 1000, 200, 150, models.InboundFee{}, 0,
6282+
ctx, hash, 1005, 1000, 200, 150, models.InboundFee{}, 0,
62826283
lnwire.ShortChannelID{}, nil,
62836284
)
6284-
if _, ok := result.WireMessage().(*lnwire.FailFeeInsufficient); !ok {
6285-
t.Fatalf("expected FailFeeInsufficient failure code")
6286-
}
6285+
_, ok := result.WireMessage().(*lnwire.FailFeeInsufficient)
6286+
require.True(t, ok)
62876287
})
62886288

62896289
// Test that insufficient fee error takes preference over insufficient
@@ -6292,50 +6292,46 @@ func TestCheckHtlcForward(t *testing.T) {
62926292
t.Parallel()
62936293

62946294
result := link.CheckHtlcForward(
6295-
hash, 100005, 100000, 200, 150, models.InboundFee{}, 0,
6296-
lnwire.ShortChannelID{}, nil,
6295+
ctx, hash, 100005, 100000, 200, 150,
6296+
models.InboundFee{}, 0, lnwire.ShortChannelID{}, nil,
62976297
)
62986298
_, ok := result.WireMessage().(*lnwire.FailFeeInsufficient)
62996299
require.True(t, ok, "expected FailFeeInsufficient failure code")
63006300
})
63016301

63026302
t.Run("expiry too soon", func(t *testing.T) {
63036303
result := link.CheckHtlcForward(
6304-
hash, 1500, 1000, 200, 150, models.InboundFee{}, 190,
6305-
lnwire.ShortChannelID{}, nil,
6304+
ctx, hash, 1500, 1000, 200, 150, models.InboundFee{},
6305+
190, lnwire.ShortChannelID{}, nil,
63066306
)
6307-
if _, ok := result.WireMessage().(*lnwire.FailExpiryTooSoon); !ok {
6308-
t.Fatalf("expected FailExpiryTooSoon failure code")
6309-
}
6307+
_, ok := result.WireMessage().(*lnwire.FailExpiryTooSoon)
6308+
require.True(t, ok)
63106309
})
63116310

63126311
t.Run("incorrect cltv expiry", func(t *testing.T) {
63136312
result := link.CheckHtlcForward(
6314-
hash, 1500, 1000, 200, 190, models.InboundFee{}, 0,
6313+
ctx, hash, 1500, 1000, 200, 190, models.InboundFee{}, 0,
63156314
lnwire.ShortChannelID{}, nil,
63166315
)
6317-
if _, ok := result.WireMessage().(*lnwire.FailIncorrectCltvExpiry); !ok {
6318-
t.Fatalf("expected FailIncorrectCltvExpiry failure code")
6319-
}
6320-
6316+
_, ok := result.WireMessage().(*lnwire.FailIncorrectCltvExpiry)
6317+
require.True(t, ok)
63216318
})
63226319

63236320
t.Run("cltv expiry too far in the future", func(t *testing.T) {
63246321
// Check that expiry isn't too far in the future.
63256322
result := link.CheckHtlcForward(
6326-
hash, 1500, 1000, 10200, 10100, models.InboundFee{}, 0,
6327-
lnwire.ShortChannelID{}, nil,
6323+
ctx, hash, 1500, 1000, 10200, 10100,
6324+
models.InboundFee{}, 0, lnwire.ShortChannelID{}, nil,
63286325
)
6329-
if _, ok := result.WireMessage().(*lnwire.FailExpiryTooFar); !ok {
6330-
t.Fatalf("expected FailExpiryTooFar failure code")
6331-
}
6326+
_, ok := result.WireMessage().(*lnwire.FailExpiryTooFar)
6327+
require.True(t, ok)
63326328
})
63336329

63346330
t.Run("inbound fee satisfied", func(t *testing.T) {
63356331
t.Parallel()
63366332

63376333
result := link.CheckHtlcForward(
6338-
hash, 1000+10-2-1, 1000, 200, 150,
6334+
ctx, hash, 1000+10-2-1, 1000, 200, 150,
63396335
models.InboundFee{Base: -2, Rate: -1_000},
63406336
0, lnwire.ShortChannelID{}, nil,
63416337
)
@@ -6348,7 +6344,7 @@ func TestCheckHtlcForward(t *testing.T) {
63486344
t.Parallel()
63496345

63506346
result := link.CheckHtlcForward(
6351-
hash, 1000+10-10-101-1, 1000,
6347+
ctx, hash, 1000+10-10-101-1, 1000,
63526348
200, 150, models.InboundFee{Base: -10, Rate: -100_000},
63536349
0, lnwire.ShortChannelID{}, nil,
63546350
)

htlcswitch/mock.go

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,7 @@ func (s *mockServer) Start() error {
269269
return errors.New("mock server already started")
270270
}
271271

272-
if err := s.htlcSwitch.Start(); err != nil {
272+
if err := s.htlcSwitch.Start(context.Background()); err != nil {
273273
return err
274274
}
275275

@@ -844,14 +844,15 @@ func (f *mockChannelLink) HandleChannelUpdate(lnwire.Message) {
844844

845845
func (f *mockChannelLink) UpdateForwardingPolicy(_ models.ForwardingPolicy) {
846846
}
847-
func (f *mockChannelLink) CheckHtlcForward([32]byte, lnwire.MilliSatoshi,
848-
lnwire.MilliSatoshi, uint32, uint32, models.InboundFee, uint32,
849-
lnwire.ShortChannelID, lnwire.CustomRecords) *LinkError {
847+
func (f *mockChannelLink) CheckHtlcForward(context.Context, [32]byte,
848+
lnwire.MilliSatoshi, lnwire.MilliSatoshi, uint32, uint32,
849+
models.InboundFee, uint32, lnwire.ShortChannelID,
850+
lnwire.CustomRecords) *LinkError {
850851

851852
return f.checkHtlcForwardResult
852853
}
853854

854-
func (f *mockChannelLink) CheckHtlcTransit(payHash [32]byte,
855+
func (f *mockChannelLink) CheckHtlcTransit(_ context.Context, payHash [32]byte,
855856
amt lnwire.MilliSatoshi, timeout uint32,
856857
heightNow uint32, _ lnwire.CustomRecords) *LinkError {
857858

0 commit comments

Comments
 (0)