Skip to content

htlcswitch+peer [1/2]: thread context through in preparation for passing to graph DB calls #9691

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 8 commits into from
75 changes: 47 additions & 28 deletions htlcswitch/interceptable_switch.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package htlcswitch

import (
"context"
"crypto/sha256"
"fmt"
"sync"
Expand Down Expand Up @@ -89,8 +90,9 @@ type InterceptableSwitch struct {
// currentHeight is the currently best known height.
currentHeight int32

wg sync.WaitGroup
quit chan struct{}
wg sync.WaitGroup
quit chan struct{}
cancel fn.Option[context.CancelFunc]
}

type interceptedPackets struct {
Expand Down Expand Up @@ -229,6 +231,9 @@ func (s *InterceptableSwitch) Start() error {
return fmt.Errorf("InterceptableSwitch started more than once")
}

ctx, cancel := context.WithCancel(context.Background())
s.cancel = fn.Some(cancel)

blockEpochStream, err := s.notifier.RegisterBlockEpochNtfn(nil)
if err != nil {
return err
Expand All @@ -239,7 +244,7 @@ func (s *InterceptableSwitch) Start() error {
go func() {
defer s.wg.Done()

err := s.run()
err := s.run(ctx)
if err != nil {
log.Errorf("InterceptableSwitch stopped: %v", err)
}
Expand All @@ -257,6 +262,7 @@ func (s *InterceptableSwitch) Stop() error {
return fmt.Errorf("InterceptableSwitch stopped more than once")
}

s.cancel.WhenSome(func(fn context.CancelFunc) { fn() })
close(s.quit)
s.wg.Wait()

Expand All @@ -271,7 +277,7 @@ func (s *InterceptableSwitch) Stop() error {
return nil
}

func (s *InterceptableSwitch) run() error {
func (s *InterceptableSwitch) run(ctx context.Context) error {
// The block epoch stream will immediately stream the current height.
// Read it out here.
select {
Expand All @@ -283,6 +289,9 @@ func (s *InterceptableSwitch) run() error {

case <-s.quit:
return nil

case <-ctx.Done():
return nil
}

log.Debugf("InterceptableSwitch running: height=%v, "+
Expand All @@ -298,7 +307,7 @@ func (s *InterceptableSwitch) run() error {
var notIntercepted []*htlcPacket
for _, p := range packets.packets {
intercepted, err := s.interceptForward(
p, packets.isReplay,
ctx, p, packets.isReplay,
)
if err != nil {
return err
Expand All @@ -311,7 +320,7 @@ func (s *InterceptableSwitch) run() error {
}
}
err := s.htlcSwitch.ForwardPackets(
packets.linkQuit, notIntercepted...,
ctx, packets.linkQuit, notIntercepted...,
)
if err != nil {
log.Errorf("Cannot forward packets: %v", err)
Expand All @@ -325,12 +334,12 @@ func (s *InterceptableSwitch) run() error {
// already intercepted in the off-chain flow. And even
// if not, it is safe to signal replay so that we won't
// unexpectedly skip over this htlc.
if _, err := s.forward(fwd, true); err != nil {
if _, err := s.forward(ctx, fwd, true); err != nil {
return err
}

case res := <-s.resolutionChan:
res.errChan <- s.resolve(res.resolution)
res.errChan <- s.resolve(ctx, res.resolution)

case currentBlock, ok := <-s.blockEpochStream.Epochs:
if !ok {
Expand All @@ -341,20 +350,23 @@ func (s *InterceptableSwitch) run() error {

// A new block is appended. Fail any held htlcs that
// expire at this height to prevent channel force-close.
s.failExpiredHtlcs()
s.failExpiredHtlcs(ctx)

case <-s.quit:
return nil

case <-ctx.Done():
return nil
}
}
}

func (s *InterceptableSwitch) failExpiredHtlcs() {
func (s *InterceptableSwitch) failExpiredHtlcs(ctx context.Context) {
s.heldHtlcSet.popAutoFails(
uint32(s.currentHeight),
func(fwd InterceptedForward) {
err := fwd.FailWithCode(
lnwire.CodeTemporaryChannelFailure,
ctx, lnwire.CodeTemporaryChannelFailure,
)
if err != nil {
log.Errorf("Cannot fail packet: %v", err)
Expand Down Expand Up @@ -407,7 +419,9 @@ func (s *InterceptableSwitch) setInterceptor(interceptor ForwardInterceptor) {

// resolve processes a HTLC given the resolution type specified by the
// intercepting client.
func (s *InterceptableSwitch) resolve(res *FwdResolution) error {
func (s *InterceptableSwitch) resolve(ctx context.Context,
res *FwdResolution) error {

intercepted, err := s.heldHtlcSet.pop(res.Key)
if err != nil {
return err
Expand All @@ -431,7 +445,7 @@ func (s *InterceptableSwitch) resolve(res *FwdResolution) error {
return intercepted.Fail(res.FailureMessage)
}

return intercepted.FailWithCode(res.FailureCode)
return intercepted.FailWithCode(ctx, res.FailureCode)

default:
return fmt.Errorf("unrecognized action %v", res.Action)
Expand Down Expand Up @@ -503,8 +517,8 @@ func (s *InterceptableSwitch) ForwardPacket(

// interceptForward forwards the packet to the external interceptor after
// checking the interception criteria.
func (s *InterceptableSwitch) interceptForward(packet *htlcPacket,
isReplay bool) (bool, error) {
func (s *InterceptableSwitch) interceptForward(ctx context.Context,
packet *htlcPacket, isReplay bool) (bool, error) {

switch htlc := packet.htlc.(type) {
case *lnwire.UpdateAddHTLC:
Expand All @@ -522,7 +536,7 @@ func (s *InterceptableSwitch) interceptForward(packet *htlcPacket,
}

// Handle forwards that are too close to expiry.
handled, err := s.handleExpired(intercepted)
handled, err := s.handleExpired(ctx, intercepted)
if err != nil {
log.Errorf("Error handling intercepted htlc "+
"that expires too soon: circuit=%v, "+
Expand All @@ -542,15 +556,15 @@ func (s *InterceptableSwitch) interceptForward(packet *htlcPacket,
return true, nil
}

return s.forward(intercepted, isReplay)
return s.forward(ctx, intercepted, isReplay)

default:
return false, nil
}
}

// forward records the intercepted htlc and forwards it to the interceptor.
func (s *InterceptableSwitch) forward(
func (s *InterceptableSwitch) forward(ctx context.Context,
fwd InterceptedForward, isReplay bool) (bool, error) {

inKey := fwd.Packet().IncomingCircuit
Expand All @@ -573,7 +587,7 @@ func (s *InterceptableSwitch) forward(
// yet. This limits the backlog of htlcs when the interceptor is down.
if !isReplay {
err := fwd.FailWithCode(
lnwire.CodeTemporaryChannelFailure,
ctx, lnwire.CodeTemporaryChannelFailure,
)
if err != nil {
log.Errorf("Cannot fail packet: %v", err)
Expand Down Expand Up @@ -605,8 +619,8 @@ func (s *InterceptableSwitch) forward(

// handleExpired checks that the htlc isn't too close to the channel
// force-close broadcast height. If it is, it is cancelled back.
func (s *InterceptableSwitch) handleExpired(fwd *interceptedForward) (
bool, error) {
func (s *InterceptableSwitch) handleExpired(ctx context.Context,
fwd *interceptedForward) (bool, error) {

height := uint32(s.currentHeight)
if fwd.packet.incomingTimeout >= height+s.cltvInterceptDelta {
Expand All @@ -620,7 +634,7 @@ func (s *InterceptableSwitch) handleExpired(fwd *interceptedForward) (
fwd.packet.incomingTimeout)

err := fwd.FailWithCode(
lnwire.CodeExpiryTooSoon,
ctx, lnwire.CodeExpiryTooSoon,
)
if err != nil {
return false, err
Expand Down Expand Up @@ -661,9 +675,10 @@ func (f *interceptedForward) Packet() InterceptedPacket {

// Resume resumes the default behavior as if the packet was not intercepted.
func (f *interceptedForward) Resume() error {
ctx := context.TODO()
// Forward to the switch. A link quit channel isn't needed, because we
// are on a different thread now.
return f.htlcSwitch.ForwardPackets(nil, f.packet)
return f.htlcSwitch.ForwardPackets(ctx, nil, f.packet)
}

// ResumeModified resumes the default behavior with field modifications. The
Expand All @@ -676,6 +691,8 @@ func (f *interceptedForward) ResumeModified(
outAmountMsat fn.Option[lnwire.MilliSatoshi],
outWireCustomRecords fn.Option[lnwire.CustomRecords]) error {

ctx := context.TODO()

// Convert the optional custom records to the correct type and validate
// them.
validatedRecords, err := fn.MapOptionZ(
Expand Down Expand Up @@ -732,7 +749,7 @@ func (f *interceptedForward) ResumeModified(

// Forward to the switch. A link quit channel isn't needed, because we
// are on a different thread now.
return f.htlcSwitch.ForwardPackets(nil, f.packet)
return f.htlcSwitch.ForwardPackets(ctx, nil, f.packet)
}

// Fail notifies the intention to Fail an existing hold forward with an
Expand All @@ -747,7 +764,9 @@ func (f *interceptedForward) Fail(reason []byte) error {

// FailWithCode notifies the intention to fail an existing hold forward with the
// specified failure code.
func (f *interceptedForward) FailWithCode(code lnwire.FailCode) error {
func (f *interceptedForward) FailWithCode(ctx context.Context,
code lnwire.FailCode) error {

shaOnionBlob := func() [32]byte {
return sha256.Sum256(f.htlc.OnionBlob[:])
}
Expand All @@ -773,13 +792,13 @@ func (f *interceptedForward) FailWithCode(code lnwire.FailCode) error {

case lnwire.CodeTemporaryChannelFailure:
update := f.htlcSwitch.failAliasUpdate(
f.packet.incomingChanID, true,
ctx, f.packet.incomingChanID, true,
)
if update == nil {
// Fallback to the original, non-alias behavior.
var err error
update, err = f.htlcSwitch.cfg.FetchLastChannelUpdate(
f.packet.incomingChanID,
ctx, f.packet.incomingChanID,
)
if err != nil {
return err
Expand All @@ -790,7 +809,7 @@ func (f *interceptedForward) FailWithCode(code lnwire.FailCode) error {

case lnwire.CodeExpiryTooSoon:
update, err := f.htlcSwitch.cfg.FetchLastChannelUpdate(
f.packet.incomingChanID,
ctx, f.packet.incomingChanID,
)
if err != nil {
return err
Expand Down
11 changes: 6 additions & 5 deletions htlcswitch/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ type dustHandler interface {
type scidAliasHandler interface {
// attachFailAliasUpdate allows the link to properly fail incoming
// HTLCs on option_scid_alias channels.
attachFailAliasUpdate(failClosure func(
attachFailAliasUpdate(failClosure func(ctx context.Context,
sid lnwire.ShortChannelID,
incoming bool) *lnwire.ChannelUpdate1)

Expand Down Expand Up @@ -281,7 +281,8 @@ type ChannelLink interface {
// a LinkError with a valid protocol failure message should be returned
// in order to signal to the source of the HTLC, the policy consistency
// issue.
CheckHtlcForward(payHash [32]byte, incomingAmt lnwire.MilliSatoshi,
CheckHtlcForward(ctx context.Context, payHash [32]byte,
incomingAmt lnwire.MilliSatoshi,
amtToForward lnwire.MilliSatoshi, incomingTimeout,
outgoingTimeout uint32, inboundFee models.InboundFee,
heightNow uint32, scid lnwire.ShortChannelID,
Expand All @@ -292,8 +293,8 @@ type ChannelLink interface {
// valid protocol failure message should be returned in order to signal
// the violation. This call is intended to be used for locally initiated
// payments for which there is no corresponding incoming htlc.
CheckHtlcTransit(payHash [32]byte, amt lnwire.MilliSatoshi,
timeout uint32, heightNow uint32,
CheckHtlcTransit(ctx context.Context, payHash [32]byte,
amt lnwire.MilliSatoshi, timeout uint32, heightNow uint32,
customRecords lnwire.CustomRecords) *LinkError

// Stats return the statistics of channel link. Number of updates,
Expand Down Expand Up @@ -452,7 +453,7 @@ type InterceptedForward interface {

// FailWithCode notifies the intention to fail an existing hold forward
// with the specified failure code.
FailWithCode(code lnwire.FailCode) error
FailWithCode(ctx context.Context, code lnwire.FailCode) error
}

// htlcNotifier is an interface which represents the input side of the
Expand Down
Loading
Loading