Skip to content

Commit 4de31e2

Browse files
committed
htlcswitch: continue threading contexts through
In preparation for `failAliasUpdate` and `createFailureWithUpdate` calling `FetchLastChannelUpdate` which is a graph DB call that will eventually take a context. For `ForwardPackets`, we also make sure to listen on the newly added context param at any location where we listen on the `linkQuit` parameter. This commit adds a few `context.TODO()` calls that will be addressed in future commits.
1 parent 5058588 commit 4de31e2

10 files changed

+129
-94
lines changed

htlcswitch/interceptable_switch.go

+8-5
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,7 @@ func (s *InterceptableSwitch) run(ctx context.Context) error {
316316
}
317317
}
318318
err := s.htlcSwitch.ForwardPackets(
319-
packets.linkQuit, notIntercepted...,
319+
ctx, packets.linkQuit, notIntercepted...,
320320
)
321321
if err != nil {
322322
log.Errorf("Cannot forward packets: %v", err)
@@ -671,9 +671,10 @@ func (f *interceptedForward) Packet() InterceptedPacket {
671671

672672
// Resume resumes the default behavior as if the packet was not intercepted.
673673
func (f *interceptedForward) Resume() error {
674+
ctx := context.TODO()
674675
// Forward to the switch. A link quit channel isn't needed, because we
675676
// are on a different thread now.
676-
return f.htlcSwitch.ForwardPackets(nil, f.packet)
677+
return f.htlcSwitch.ForwardPackets(ctx, nil, f.packet)
677678
}
678679

679680
// ResumeModified resumes the default behavior with field modifications. The
@@ -686,6 +687,8 @@ func (f *interceptedForward) ResumeModified(
686687
outAmountMsat fn.Option[lnwire.MilliSatoshi],
687688
outWireCustomRecords fn.Option[lnwire.CustomRecords]) error {
688689

690+
ctx := context.TODO()
691+
689692
// Convert the optional custom records to the correct type and validate
690693
// them.
691694
validatedRecords, err := fn.MapOptionZ(
@@ -742,7 +745,7 @@ func (f *interceptedForward) ResumeModified(
742745

743746
// Forward to the switch. A link quit channel isn't needed, because we
744747
// are on a different thread now.
745-
return f.htlcSwitch.ForwardPackets(nil, f.packet)
748+
return f.htlcSwitch.ForwardPackets(ctx, nil, f.packet)
746749
}
747750

748751
// Fail notifies the intention to Fail an existing hold forward with an
@@ -757,7 +760,7 @@ func (f *interceptedForward) Fail(reason []byte) error {
757760

758761
// FailWithCode notifies the intention to fail an existing hold forward with the
759762
// specified failure code.
760-
func (f *interceptedForward) FailWithCode(_ context.Context,
763+
func (f *interceptedForward) FailWithCode(ctx context.Context,
761764
code lnwire.FailCode) error {
762765

763766
shaOnionBlob := func() [32]byte {
@@ -785,7 +788,7 @@ func (f *interceptedForward) FailWithCode(_ context.Context,
785788

786789
case lnwire.CodeTemporaryChannelFailure:
787790
update := f.htlcSwitch.failAliasUpdate(
788-
f.packet.incomingChanID, true,
791+
ctx, f.packet.incomingChanID, true,
789792
)
790793
if update == nil {
791794
// Fallback to the original, non-alias behavior.

htlcswitch/interfaces.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ type dustHandler interface {
8585
type scidAliasHandler interface {
8686
// attachFailAliasUpdate allows the link to properly fail incoming
8787
// HTLCs on option_scid_alias channels.
88-
attachFailAliasUpdate(failClosure func(
88+
attachFailAliasUpdate(failClosure func(ctx context.Context,
8989
sid lnwire.ShortChannelID,
9090
incoming bool) *lnwire.ChannelUpdate1)
9191

htlcswitch/link.go

+6-5
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,7 @@ type ChannelLinkConfig struct {
263263

264264
// FailAliasUpdate is a function used to fail an HTLC for an
265265
// option_scid_alias channel.
266-
FailAliasUpdate func(sid lnwire.ShortChannelID,
266+
FailAliasUpdate func(ctx context.Context, sid lnwire.ShortChannelID,
267267
incoming bool) *lnwire.ChannelUpdate1
268268

269269
// GetAliases is used by the link and switch to fetch the set of
@@ -859,8 +859,9 @@ type failCb func(update *lnwire.ChannelUpdate1) lnwire.FailureMessage
859859
// outgoing HTLC. It may return a FailureMessage that references a channel's
860860
// alias. If the channel does not have an alias, then the regular channel
861861
// update from disk will be returned.
862-
func (l *channelLink) createFailureWithUpdate(_ context.Context, incoming bool,
863-
outgoingScid lnwire.ShortChannelID, cb failCb) lnwire.FailureMessage {
862+
func (l *channelLink) createFailureWithUpdate(ctx context.Context,
863+
incoming bool, outgoingScid lnwire.ShortChannelID,
864+
cb failCb) lnwire.FailureMessage {
864865

865866
// Determine which SCID to use in case we need to use aliases in the
866867
// ChannelUpdate.
@@ -871,7 +872,7 @@ func (l *channelLink) createFailureWithUpdate(_ context.Context, incoming bool,
871872

872873
// Try using the FailAliasUpdate function. If it returns nil, fallback
873874
// to the non-alias behavior.
874-
update := l.cfg.FailAliasUpdate(scid, incoming)
875+
update := l.cfg.FailAliasUpdate(ctx, scid, incoming)
875876
if update == nil {
876877
// Fallback to the non-alias behavior.
877878
var err error
@@ -3209,7 +3210,7 @@ func (l *channelLink) getAliases() []lnwire.ShortChannelID {
32093210
// attachFailAliasUpdate sets the link's FailAliasUpdate function.
32103211
//
32113212
// Part of the scidAliasHandler interface.
3212-
func (l *channelLink) attachFailAliasUpdate(closure func(
3213+
func (l *channelLink) attachFailAliasUpdate(closure func(ctx context.Context,
32133214
sid lnwire.ShortChannelID, incoming bool) *lnwire.ChannelUpdate1) {
32143215

32153216
l.Lock()

htlcswitch/link_test.go

+7-3
Original file line numberDiff line numberDiff line change
@@ -2201,7 +2201,9 @@ func newSingleLinkTestHarness(t *testing.T, chanAmt,
22012201
forwardPackets := func(linkQuit <-chan struct{}, _ bool,
22022202
packets ...*htlcPacket) error {
22032203

2204-
return aliceSwitch.ForwardPackets(linkQuit, packets...)
2204+
return aliceSwitch.ForwardPackets(
2205+
context.Background(), linkQuit, packets...,
2206+
)
22052207
}
22062208

22072209
// Instantiate with a long interval, so that we can precisely control
@@ -4884,7 +4886,9 @@ func (h *persistentLinkHarness) restartLink(
48844886
forwardPackets := func(linkQuit <-chan struct{}, _ bool,
48854887
packets ...*htlcPacket) error {
48864888

4887-
return h.hSwitch.ForwardPackets(linkQuit, packets...)
4889+
return h.hSwitch.ForwardPackets(
4890+
context.Background(), linkQuit, packets...,
4891+
)
48884892
}
48894893

48904894
// Instantiate with a long interval, so that we can precisely control
@@ -6217,7 +6221,7 @@ func TestCheckHtlcForward(t *testing.T) {
62176221
return &lnwire.ChannelUpdate1{}, nil
62186222
}
62196223

6220-
failAliasUpdate := func(sid lnwire.ShortChannelID,
6224+
failAliasUpdate := func(_ context.Context, sid lnwire.ShortChannelID,
62216225
incoming bool) *lnwire.ChannelUpdate1 {
62226226

62236227
return nil

htlcswitch/mailbox.go

+11-6
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package htlcswitch
33
import (
44
"bytes"
55
"container/list"
6+
"context"
67
"errors"
78
"fmt"
89
"sync"
@@ -95,7 +96,8 @@ type mailBoxConfig struct {
9596
// forwardPackets send a varidic number of htlcPackets to the switch to
9697
// be routed. A quit channel should be provided so that the call can
9798
// properly exit during shutdown.
98-
forwardPackets func(<-chan struct{}, ...*htlcPacket) error
99+
forwardPackets func(context.Context, <-chan struct{},
100+
...*htlcPacket) error
99101

100102
// clock is a time source for the mailbox.
101103
clock clock.Clock
@@ -107,7 +109,7 @@ type mailBoxConfig struct {
107109

108110
// failMailboxUpdate is used to fail an expired HTLC and use the
109111
// correct SCID if the underlying channel uses aliases.
110-
failMailboxUpdate func(outScid,
112+
failMailboxUpdate func(ctx context.Context, outScid,
111113
mailboxScid lnwire.ShortChannelID) lnwire.FailureMessage
112114
}
113115

@@ -687,6 +689,8 @@ func (m *memoryMailBox) DustPackets() (lnwire.MilliSatoshi,
687689
// generated LinkError will show an OutgoingFailureDownstreamHtlcAdd
688690
// FailureDetail.
689691
func (m *memoryMailBox) FailAdd(pkt *htlcPacket) {
692+
ctx := context.TODO()
693+
690694
// First, remove the packet from mailbox. If we didn't find the packet
691695
// because it has already been acked, we'll exit early to avoid sending
692696
// a duplicate fail message through the switch.
@@ -703,7 +707,7 @@ func (m *memoryMailBox) FailAdd(pkt *htlcPacket) {
703707
// peer if this is a forward, or report to the user if the failed
704708
// payment was locally initiated.
705709
failure := m.cfg.failMailboxUpdate(
706-
pkt.originalOutgoingChanID, m.cfg.shortChanID,
710+
ctx, pkt.originalOutgoingChanID, m.cfg.shortChanID,
707711
)
708712

709713
// If the payment was locally initiated (which is indicated by a nil
@@ -748,7 +752,7 @@ func (m *memoryMailBox) FailAdd(pkt *htlcPacket) {
748752
},
749753
}
750754

751-
if err := m.cfg.forwardPackets(m.quit, failPkt); err != nil {
755+
if err := m.cfg.forwardPackets(ctx, m.quit, failPkt); err != nil {
752756
log.Errorf("Unhandled error while reforwarding packets "+
753757
"settle/fail over htlcswitch: %v", err)
754758
}
@@ -804,7 +808,8 @@ type mailOrchConfig struct {
804808
// forwardPackets send a varidic number of htlcPackets to the switch to
805809
// be routed. A quit channel should be provided so that the call can
806810
// properly exit during shutdown.
807-
forwardPackets func(<-chan struct{}, ...*htlcPacket) error
811+
forwardPackets func(context.Context, <-chan struct{},
812+
...*htlcPacket) error
808813

809814
// clock is a time source for the generated mailboxes.
810815
clock clock.Clock
@@ -816,7 +821,7 @@ type mailOrchConfig struct {
816821

817822
// failMailboxUpdate is used to fail an expired HTLC and use the
818823
// correct SCID if the underlying channel uses aliases.
819-
failMailboxUpdate func(outScid,
824+
failMailboxUpdate func(ctx context.Context, outScid,
820825
mailboxScid lnwire.ShortChannelID) lnwire.FailureMessage
821826
}
822827

htlcswitch/mailbox_test.go

+6-5
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package htlcswitch
22

33
import (
4+
"context"
45
prand "math/rand"
56
"reflect"
67
"testing"
@@ -206,7 +207,7 @@ func newMailboxContextWithClock(t *testing.T,
206207
forwards: make(chan *htlcPacket, 1),
207208
}
208209

209-
failMailboxUpdate := func(outScid,
210+
failMailboxUpdate := func(_ context.Context, outScid,
210211
mboxScid lnwire.ShortChannelID) lnwire.FailureMessage {
211212

212213
return &lnwire.FailTemporaryNodeFailure{}
@@ -232,7 +233,7 @@ func newMailboxContext(t *testing.T, startTime time.Time,
232233
forwards: make(chan *htlcPacket, 1),
233234
}
234235

235-
failMailboxUpdate := func(outScid,
236+
failMailboxUpdate := func(_ context.Context, outScid,
236237
mboxScid lnwire.ShortChannelID) lnwire.FailureMessage {
237238

238239
return &lnwire.FailTemporaryNodeFailure{}
@@ -250,7 +251,7 @@ func newMailboxContext(t *testing.T, startTime time.Time,
250251
return ctx
251252
}
252253

253-
func (c *mailboxContext) forward(_ <-chan struct{},
254+
func (c *mailboxContext) forward(_ context.Context, _ <-chan struct{},
254255
pkts ...*htlcPacket) error {
255256

256257
for _, pkt := range pkts {
@@ -697,7 +698,7 @@ func testMailBoxDust(t *testing.T, chantype channeldb.ChannelType) {
697698
func TestMailOrchestrator(t *testing.T) {
698699
t.Parallel()
699700

700-
failMailboxUpdate := func(outScid,
701+
failMailboxUpdate := func(_ context.Context, outScid,
701702
mboxScid lnwire.ShortChannelID) lnwire.FailureMessage {
702703

703704
return &lnwire.FailTemporaryNodeFailure{}
@@ -706,7 +707,7 @@ func TestMailOrchestrator(t *testing.T) {
706707
// First, we'll create a new instance of our orchestrator.
707708
mo := newMailOrchestrator(&mailOrchConfig{
708709
failMailboxUpdate: failMailboxUpdate,
709-
forwardPackets: func(_ <-chan struct{},
710+
forwardPackets: func(_ context.Context, _ <-chan struct{},
710711
pkts ...*htlcPacket) error {
711712

712713
return nil

htlcswitch/mock.go

+3-2
Original file line numberDiff line numberDiff line change
@@ -735,7 +735,7 @@ type mockChannelLink struct {
735735

736736
checkHtlcForwardResult *LinkError
737737

738-
failAliasUpdate func(sid lnwire.ShortChannelID,
738+
failAliasUpdate func(_ context.Context, sid lnwire.ShortChannelID,
739739
incoming bool) *lnwire.ChannelUpdate1
740740

741741
confirmedZC bool
@@ -872,7 +872,8 @@ func (f *mockChannelLink) AttachMailBox(mailBox MailBox) {
872872
}
873873

874874
func (f *mockChannelLink) attachFailAliasUpdate(closure func(
875-
sid lnwire.ShortChannelID, incoming bool) *lnwire.ChannelUpdate1) {
875+
ctx context.Context, sid lnwire.ShortChannelID,
876+
incoming bool) *lnwire.ChannelUpdate1) {
876877

877878
f.failAliasUpdate = closure
878879
}

0 commit comments

Comments
 (0)