diff --git a/fn/mvar.go b/fn/mvar.go new file mode 100644 index 00000000000..6f5312d00d6 --- /dev/null +++ b/fn/mvar.go @@ -0,0 +1,180 @@ +package fn + +import ( + "sync/atomic" +) + +// MVar[A any] is a structure that is designed to store a single value in an API +// that dispenses with data races. Think of it as a box for a value. +// +// It has two states: full and empty. +// +// It supports two operations: take and put. +// +// The state transition rules are as follows: +// 1. put while full blocks. +// 2. put while empty sets. +// 3. take while full resets. +// 4. take while empty blocks. +// 5. read while full nops. +// 6. read while empty blocks. +type MVar[A any] struct { + // current is an immediately available copy of whatever is inside the + // value channel that is served to readers. It is updated whenever a + // change to the value channel is successful. + current *atomic.Pointer[A] + // readers is used to wake all blocked readers when a new value is + // written. + readers chan chan A + + // takers is used to wake a single taker when a new value is written. + takers chan chan A + + // value is a bounded channel of size 1 that represents the core state + // oof the channel. + value chan A +} + +// Zero initializes an MVar that has no values in it. In this state, TakeMVar +// will block and PutMVar will immediately succeed. +// +// Zero : () -> MVar[A]. +func Zero[A any]() MVar[A] { + ptr := atomic.Pointer[A]{} + + return MVar[A]{ + current: &ptr, + readers: make(chan chan A), + takers: make(chan chan A), + value: make(chan A, 1), + } +} + +// NewMVar initializes an MVar that has a value in it from the getgo. In this +// state, TakeMVar will succeed immediately and PutMVar will block. +// +// NewMVar : A -> MVar[A]. +func NewMVar[A any](a A) MVar[A] { + z := Zero[A]() + z.value <- a + z.current.Store(&a) + + return z +} + +// Take will wait for a value to be put into the MVar and then immediately +// take it out. +// +// Take : MVar[A] -> A. +func (m *MVar[A]) Take() A { + select { + case v := <-m.value: + m.current.Store(nil) + return v + default: + t := make(chan A) + m.takers <- t + return <-t + } +} + +// TryTake is the non-blocking version of TakeMVar, it will return an +// None() Option if it would have blocked. +// +// TryTake : MVar[A] -> Option[A]. +func (m *MVar[A]) TryTake() Option[A] { + select { + case v := <-m.value: + m.current.Store(nil) + return Some(v) + default: + return None[A]() + } +} + +// Put will wait for a value to be made empty and will immediately replace it +// with the argument. +// +// Put : (MVar[A], A) -> (). +func (m *MVar[A]) Put(a A) { +readLoop: + // Give the newly put value to all of the waiting readers. + for { + select { + case r := <-m.readers: + r <- a + default: + break readLoop + } + } + + // Give the newly put value to a single taker if one exists. If there + // are no available takers, then store it in the MVar. Since the value + // channel is bounded with capacity 1, subsequent put operations will + // block. + select { + case t := <-m.takers: + t <- a + default: + m.value <- a + m.current.Store(&a) + } +} + +// TryPut is the non-blocking version of Put and will return true if the MVar is +// successfully set. +// +// TryPut : (MVar[A], A) -> bool. +func (m *MVar[A]) TryPut(a A) bool { + select { + case m.value <- a: + m.current.Store(&a) + return true + default: + return false + } +} + +// Read will atomically read the contents of the MVar. If the MVar is empty, +// Read will block until a value is put in. Callers of Read are guaranteed to +// be woken up before callers of Take. +// +// Read : MVar[A] -> A. +func (m *MVar[A]) Read() A { + // Check to see if MVar has something in it. + if ptr := m.current.Load(); ptr != nil { + return *ptr + } + + // It's empty so we need to wait. + r := make(chan A) + m.readers <- r + + return <-r +} + +// TryRead will atomically read the contents of the MVar if it is full. +// Otherwise, it will return None. +// +// TryRead : MVar[A] -> Option[A]. +func (m *MVar[A]) TryRead() Option[A] { + if ptr := m.current.Load(); ptr != nil { + return Some(*ptr) + } + + return None[A]() +} + +// IsFull will return true if the MVar currently has a value in it. +// +// IsFull : MVar[A] -> bool. +func (m *MVar[A]) IsFull() bool { + return m.current.Load() != nil +} + +// IsEmpty will return true if the MVar currently does not have a value in it. +// +// IsEmpty : MVar[A] -> bool. +func (m *MVar[A]) IsEmpty() bool { + return m.current.Load() == nil +} diff --git a/fn/mvar_test.go b/fn/mvar_test.go new file mode 100644 index 00000000000..0e06cf41505 --- /dev/null +++ b/fn/mvar_test.go @@ -0,0 +1,327 @@ +package fn + +import ( + "sync" + "sync/atomic" + "testing" + "testing/quick" + "time" + + "github.com/stretchr/testify/require" +) + +// blockTimeout is a parameter that defines all of the waiting periods for the +// tests in this file. Generally there is a tradeoff here where the higher this +// value is, the less flaky the tests will be, at the expense of the tests +// taking longer to execute. +const blockTimeout = time.Millisecond + +// TestTakeZeroBlocks ensures that if we initialize an empty MVar and +// immediately try to Take from it, it will block. +func TestTakeZeroBlocks(t *testing.T) { + t.Parallel() + + m := Zero[uint8]() + require.True(t, blocks(func() { m.Take() })) +} + +// TestTakeNewMVarProceeds ensures that if we initialize an MVar with a value +// in it and immediately try to Take from it, it will succeed. +func TestTakeNewMVarProceeds(t *testing.T) { + t.Parallel() + + m := NewMVar[uint8](0) + require.False(t, blocks(func() { m.Take() })) +} + +// TestPutNewMVarBlocks ensures that if we initialize an MVar with a value in +// it and immediately try to Put a new value into it, it will block. +func TestPutNewMVarBlocks(t *testing.T) { + t.Parallel() + + m := NewMVar[uint8](0) + require.True(t, blocks(func() { m.Put(1) })) +} + +// TestPutZeroProceeds ensures that if we initialize an empty Mvar and then try +// to Put a new value into it, it will succeed. +func TestPutZeroProceeds(t *testing.T) { + t.Parallel() + + m := Zero[uint8]() + require.False(t, blocks(func() { m.Put(1) })) +} + +// TestPutWhenEmptyLeavesFull ensures that we successfully leave the Mvar in a +// full state after executing a Put in an empty state. +func TestPutWhenEmptyLeavesFull(t *testing.T) { + t.Parallel() + + m := Zero[uint8]() + m.Put(0) + if m.IsEmpty() { + t.Fatal("Put left empty") + } +} + +// TestTakeWhenFullLeavesEmpty ensures that we successfully leave the Mvar in an +// empty state after executing a Take in a full state. +func TestTakeWhenFullLeavesEmpty(t *testing.T) { + t.Parallel() + + m := NewMVar[uint8](0) + m.Take() + if m.IsFull() { + t.Fatal("Take left full") + } +} + +// TestReadWhenFullLeavesFull ensures that a Read when in a full state does not +// change the state of the MVar. +func TestReadWhenFullLeavesFull(t *testing.T) { + t.Parallel() + + m := NewMVar[uint8](0) + m.Read() + if m.IsEmpty() { + t.Fatal("Read left empty") + } +} + +// TestTakeAfterTryTakeBlocks ensures that regardless of what state the Mvar +// begins in, if we try to Take immediately after a TryTake, it will always +// block. +func TestTakeAfterTryTakeBlocks(t *testing.T) { + t.Parallel() + + err := quick.Check(func(set bool, n uint8) bool { + m := gen(set, n) + m.TryTake() + return blocks(func() { m.Take() }) + }, nil) + require.NoError(t, err, "Take after TryTake did not block") +} + +// TestPutAfterTryPutBlocks ensures that regardless of what state the MVar +// begins in, if we try to Put immediately after a TryPut, it will always block. +func TestPutAfterTryPutBlocks(t *testing.T) { + t.Parallel() + + err := quick.Check(func(set bool, n uint8) bool { + m := gen(set, n) + m.TryPut(0) + return blocks(func() { m.Put(1) }) + }, nil) + require.NoError(t, err, "Put after TryPut did not block") +} + +// TestTryTakeLeavesEmpty ensures that regardless of what state the MVar begins +// in, if we execute a TryTake, the resulting MVar state is empty. +func TestTryTakeLeavesEmpty(t *testing.T) { + t.Parallel() + + err := quick.Check(func(set bool, n uint8) bool { + m := gen(set, n) + m.TryTake() + return m.IsEmpty() + }, nil) + require.NoError(t, err, "TryTake did not leave empty") +} + +// TestTryPutLeavesFull ensures that regardless of what state the MVar begins +// in, if we execute a TryPut, the resulting MVar state is full. +func TestTryPutLeavesFull(t *testing.T) { + t.Parallel() + + err := quick.Check(func(set bool, n uint8) bool { + m := gen(set, n) + m.TryPut(n) + return m.IsFull() + }, nil) + require.NoError(t, err, "TryPut did not leave full") +} + +// TestReadWhenEmptyBlocks ensures that if an MVar is in an empty state, a +// Read operation will block. +func TestReadWhenEmptyBlocks(t *testing.T) { + t.Parallel() + + err := quick.Check(func(set bool, n uint8) bool { + m := gen(set, n) + return implies(m.IsEmpty(), blocks(func() { m.Read() })) + }, nil) + require.NoError(t, err, "Read did not block when empty") +} + +// TestTryReadNops ensures that a TryRead will not change the state of the MVar. +// It implicitly tests a second property which is that TryRead will never block. +func TestTryReadNops(t *testing.T) { + t.Parallel() + + err := quick.Check(func(set bool, n uint8) bool { + m := gen(set, n) + before := m.IsEmpty() + tryReadBlocked := blocks(func() { m.TryRead() }) + after := m.IsEmpty() + + return before == after && !tryReadBlocked + }, nil) + require.NoError(t, err, "TryRead did not leave state unchanged") +} + +// TestPutWakesAllReaders ensures the property that if we have many blocked +// Read operations that are waiting for the MVar to be filled, all of them are +// woken up when we execute a Put operation. +func TestPutWakesAllReaders(t *testing.T) { + t.Parallel() + + m := Zero[uint8]() + v := uint8(21) + n := uint32(10) + + counter := atomic.Uint32{} + wg := sync.WaitGroup{} + for i := uint32(0); i < n; i++ { + wg.Add(1) + go func() { + x := m.Read() + if x == v { + counter.Add(1) + } + wg.Done() + }() + } + m.Put(v) + wg.Wait() + require.Equal(t, counter.Load(), n, "not all readers given same value") +} + +// TestPutWakesReadersBeforeTaker ensures the property that a waiting taker +// does not preempt any waiting readers. This test construction is a bit +// delicate, using a Sleep to ensure that all goroutines that are set up to +// wait on the Put have gotten to the point where they actually block. +func TestPutWakesReadersBeforeTaker(t *testing.T) { + t.Parallel() + + m := Zero[uint8]() + v := uint8(21) + n := uint32(10) + + counter := atomic.Uint32{} + wg := sync.WaitGroup{} + + // Set up taker first. + wg.Add(1) + go func() { + x := m.Take() + if x == v { + counter.Add(1) + } + wg.Done() + }() + + // Set up readers. + for i := uint32(0); i < n; i++ { + wg.Add(1) + go func() { + x := m.Read() + if x == v { + counter.Add(1) + } + wg.Done() + }() + } + + time.Sleep(blockTimeout) // Forgive me + + m.Put(v) + wg.Wait() + require.Equal( + t, counter.Load(), n+1, "readers did not wake before taker", + ) +} + +// TestPutWakesOneTaker ensures the property that only a single blocked Take +// operation wakes when a Put comes in. This test construction is a bit delicate +// using a Sleep to wait for the counter to be incremented by the Take goroutine +// after waking. +func TestPutWakesOneTaker(t *testing.T) { + t.Parallel() + + m := Zero[uint8]() + v := uint8(21) + n := uint32(10) + + counter := atomic.Uint32{} + wg := sync.WaitGroup{} + for i := uint32(0); i < n; i++ { + wg.Add(1) + go func() { + x := m.Take() + if x == v { + counter.Add(1) + } + wg.Done() + }() + } + m.Put(v) + + time.Sleep(blockTimeout) // Forgive me + + require.Equal( + t, + counter.Load(), + uint32(1), + "put wakes zero or more than one taker ", + ) +} + +// TestTakeWakesPutter ensures that if there is a blocked Put operation due to +// the MVar being full, that it unblocks when a Take operation is executed. This +// is because the Take operation would set the MVar to an empty state, allowing +// the blocked Put to proceed. +func TestTakeWakesPutter(t *testing.T) { + t.Parallel() + + m := NewMVar[uint8](0) + wg := sync.WaitGroup{} + wg.Add(1) + go func() { m.Put(1); wg.Done() }() + m.Take() + wg.Wait() + require.Equal(t, m.Read(), uint8(1)) +} + +// blocks is a helper function to decide if the supplied function blocks or not. +// This is not fool-proof since it does make a judgement call based off of the +// file-global timeout parameter. +func blocks(f func()) bool { + unblocked := make(chan struct{}) + go func() { f(); unblocked <- struct{}{} }() + + select { + case <-unblocked: + return false + case <-time.NewTimer(blockTimeout).C: + return true + } +} + +// implies is a helper function that computes the `=>` operation from boolean +// algebra. It is true when the first argument is false, or if both arguments +// are true. +func implies(b bool, b2 bool) bool { + return !b || b && b2 +} + +// gen is a helper function whose first argument decides if the returned MVar +// should be full or empty, and the second argument decides what value should +// be in the MVar if it is full. We use this because it is substantially easier +// than teaching testing.Quick how to generate random MVar[uint8] values. +func gen(set bool, n uint8) MVar[uint8] { + if set { + return NewMVar[uint8](n) + } + + return Zero[uint8]() +} diff --git a/fn/option.go b/fn/option.go new file mode 100644 index 00000000000..a2c3afdc252 --- /dev/null +++ b/fn/option.go @@ -0,0 +1,149 @@ +package fn + +// Option[A] represents a value which may or may not be there. This is very +// often preferable to nil-able pointers. +type Option[A any] struct { + isSome bool + some A +} + +// Some trivially injects a value into an optional context. +// +// Some : A -> Option[A]. +func Some[A any](a A) Option[A] { + return Option[A]{ + isSome: true, + some: a, + } +} + +// None trivially constructs an empty option +// +// None : Option[A]. +func None[A any]() Option[A] { + return Option[A]{} +} + +// ElimOption is the universal Option eliminator. It can be used to safely +// handle all possible values inside the Option by supplying two continuations. +// +// ElimOption : (Option[A], () -> B, A -> B) -> B. +func ElimOption[A, B any](o Option[A], b func() B, f func(A) B) B { + if o.isSome { + return f(o.some) + } + + return b() +} + +// UnwrapOr is used to extract a value from an option, and we supply the default +// value in the case when the Option is empty. +// +// UnwrapOr : (Option[A], A) -> A. +func (o Option[A]) UnwrapOr(a A) A { + if o.isSome { + return o.some + } + + return a +} + +// WhenSome is used to conditionally perform a side-effecting function that +// accepts a value of the type that parameterizes the option. If this function +// performs no side effects, WhenSome is useless. +// +// WhenSome : (Option[A], A -> ()) -> (). +func (o Option[A]) WhenSome(f func(A)) { + if o.isSome { + f(o.some) + } +} + +// IsSome returns true if the Option contains a value +// +// IsSome : Option[A] -> bool. +func (o Option[A]) IsSome() bool { + return o.isSome +} + +// IsNone returns true if the Option is empty +// +// IsNone : Option[A] -> bool. +func (o Option[A]) IsNone() bool { + return !o.isSome +} + +// FlattenOption joins multiple layers of Options together such that if any of +// the layers is None, then the joined value is None. Otherwise the innermost +// Some value is returned. +// +// FlattenOption : Option[Option[A]] -> Option[A]. +func FlattenOption[A any](oo Option[Option[A]]) Option[A] { + if oo.IsNone() { + return None[A]() + } + if oo.some.IsNone() { + return None[A]() + } + + return oo.some +} + +// ChainOption transforms a function A -> Option[B] into one that accepts an +// Option[A] as an argument. +// +// ChainOption : (A -> Option[B]) -> Option[A] -> Option[B]. +func ChainOption[A, B any](f func(A) Option[B]) func(Option[A]) Option[B] { + return func(o Option[A]) Option[B] { + if o.isSome { + return f(o.some) + } + + return None[B]() + } +} + +// MapOption transforms a pure function A -> B into one that will operate +// inside the Option context. +// +// MapOption : (A -> B) -> Option[A] -> Option[B]. +func MapOption[A, B any](f func(A) B) func(Option[A]) Option[B] { + return func(o Option[A]) Option[B] { + if o.isSome { + return Some(f(o.some)) + } + + return None[B]() + } +} + +// LiftA2Option transforms a pure function (A, B) -> C into one that will +// operate in an Option context. For the returned function, if either of its +// arguments are None, then the result will be None. +// +// LiftA2Option : ((A, B) -> C) -> (Option[A], Option[B]) -> Option[C]. +func LiftA2Option[A, B, C any]( + f func(A, B) C, +) func(Option[A], Option[B]) Option[C] { + + return func(o1 Option[A], o2 Option[B]) Option[C] { + if o1.isSome && o2.isSome { + return Some(f(o1.some, o2.some)) + } + + return None[C]() + } +} + +// Alt chooses the left Option if it is full, otherwise it chooses the right +// option. This can be useful in a long chain if you want to choose between +// many different ways of producing the needed value. +// +// Alt : Option[A] -> Option[A] -> Option[A]. +func (o Option[A]) Alt(o2 Option[A]) Option[A] { + if o.isSome { + return o + } + + return o2 +} diff --git a/htlcswitch/interfaces.go b/htlcswitch/interfaces.go index 304f4a5957a..3d068dd8136 100644 --- a/htlcswitch/interfaces.go +++ b/htlcswitch/interfaces.go @@ -135,7 +135,21 @@ type ChannelUpdateHandler interface { // ShutdownIfChannelClean shuts the link down if the channel state is // clean. This can be used with dynamic commitment negotiation or coop // close negotiation which require a clean channel state. - ShutdownIfChannelClean() error + ShutdownHtlcManager() + + // Flush is a method that disables htlc adds to the channel until it has + // reached an empty state. When we reach zero HTLCs, the supplied + // function will be called. + Flush(func()) error + + // CancelFlush will abort an in-progress flush. If there is no + // current flush operation taking place, then this function will return + // an error. + CancelFlush() error + + // IsFlushing returns true if there is a currently in-progress flush + // operation. + IsFlushing() bool } // ChannelLink is an interface which represents the subsystem for managing the diff --git a/htlcswitch/link.go b/htlcswitch/link.go index 46aee939d12..eb7efcdc331 100644 --- a/htlcswitch/link.go +++ b/htlcswitch/link.go @@ -19,6 +19,7 @@ import ( "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/contractcourt" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/htlcswitch/hodl" "github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/invoices" @@ -273,13 +274,6 @@ type ChannelLinkConfig struct { GetAliases func(base lnwire.ShortChannelID) []lnwire.ShortChannelID } -// shutdownReq contains an error channel that will be used by the channelLink -// to send an error if shutdown failed. If shutdown succeeded, the channel will -// be closed. -type shutdownReq struct { - err chan error -} - // channelLink is the service which drives a channel's commitment update // state-machine. In the event that an HTLC needs to be propagated to another // link, the forward handler from config is used which sends HTLC to the @@ -342,7 +336,7 @@ type channelLink struct { // shutdownRequest is a channel that the channelLink will listen on to // service shutdown requests from ShutdownIfChannelClean calls. - shutdownRequest chan *shutdownReq + shutdownRequest chan struct{} // updateFeeTimer is the timer responsible for updating the link's // commitment fee every time it fires. @@ -364,6 +358,10 @@ type channelLink struct { // resolving those htlcs when we receive a message on hodlQueue. hodlMap map[models.CircuitKey]hodlHtlc + // flushCont is a function that is called when the channel finishes + // flushing. + flushCont fn.MVar[func()] + // log is a link-specific logging instance. log btclog.Logger @@ -388,11 +386,12 @@ func NewChannelLink(cfg ChannelLinkConfig, cfg: cfg, channel: channel, shortChanID: channel.ShortChanID(), - shutdownRequest: make(chan *shutdownReq), + shutdownRequest: make(chan struct{}), hodlMap: make(map[models.CircuitKey]hodlHtlc), hodlQueue: queue.NewConcurrentQueue(10), log: build.NewPrefixLog(logPrefix, log), quit: make(chan struct{}), + flushCont: fn.Zero[func()](), } } @@ -528,6 +527,28 @@ func (l *channelLink) Stop() { } } +func (l *channelLink) Flush(onFlushed func()) error { + if !l.flushCont.TryPut(onFlushed) { + return errors.New( + "can't flush because flush already in progress", + ) + } + + return nil +} + +func (l *channelLink) CancelFlush() error { + if l.flushCont.TryTake().IsNone() { + return errors.New("no flush in progress to cancel") + } + + return nil +} + +func (l *channelLink) IsFlushing() bool { + return l.flushCont.IsFull() +} + // WaitForShutdown blocks until the link finishes shutting down, which includes // termination of all dependent goroutines. func (l *channelLink) WaitForShutdown() { @@ -542,7 +563,8 @@ func (l *channelLink) WaitForShutdown() { func (l *channelLink) EligibleToForward() bool { return l.channel.RemoteNextRevocation() != nil && l.ShortChanID() != hop.Source && - l.isReestablished() + l.isReestablished() && + !l.IsFlushing() } // isReestablished returns true if the link has successfully completed the @@ -1257,23 +1279,26 @@ func (l *channelLink) htlcManager() { ) } - case req := <-l.shutdownRequest: - // If the channel is clean, we send nil on the err chan - // and return to prevent the htlcManager goroutine from - // processing any more updates. The full link shutdown - // will be triggered by RemoveLink in the peer. - if l.channel.IsChannelClean() { - req.err <- nil - return - } - - // Otherwise, the channel has lingering updates, send - // an error and continue. - req.err <- ErrLinkFailedShutdown + case <-l.shutdownRequest: + return case <-l.quit: return } + + // After we are finished processing the event, if the link is + // flushing, we check if the channel is clean and invoke the + // post-flush hook if it is. + if l.IsFlushing() && l.channel.IsChannelClean() { + // This will not block since flushCont must be full. + // We Read instead of Take to ensure a new flush + // operation can't be initiated until the continuation + // for the current flush has completed. + l.flushCont.Read()() + + // Reset the flushCont MVar. This will also not block. + l.flushCont.Take() + } } } @@ -2138,7 +2163,6 @@ func (l *channelLink) handleUpstreamMsg(msg lnwire.Message) { default: l.log.Warnf("received unknown message of type %T", msg) } - } // ackDownStreamPackets is responsible for removing htlcs from a link's mailbox @@ -2750,26 +2774,14 @@ func (l *channelLink) HandleChannelUpdate(message lnwire.Message) { l.mailBox.AddMessage(message) } -// ShutdownIfChannelClean triggers a link shutdown if the channel is in a clean +// ShutdownHtlcManager triggers a link shutdown if the channel is in a clean // state and errors if the channel has lingering updates. // // NOTE: Part of the ChannelUpdateHandler interface. -func (l *channelLink) ShutdownIfChannelClean() error { - errChan := make(chan error, 1) - +func (l *channelLink) ShutdownHtlcManager() { select { - case l.shutdownRequest <- &shutdownReq{ - err: errChan, - }: + case l.shutdownRequest <- struct{}{}: case <-l.quit: - return ErrLinkShuttingDown - } - - select { - case err := <-errChan: - return err - case <-l.quit: - return ErrLinkShuttingDown } } @@ -3053,6 +3065,25 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg, fwdInfo := pld.ForwardingInfo() + // If we are in a flush state we need to cancel back all of the + // net new HTLCs rather than forwarding them. This is the first + // opportunity we have to bounce invalid HTLC adds without + // doing a force-close. + if l.IsFlushing() { + var isReceive bool + switch fwdInfo.NextHop { + case hop.Exit: + isReceive = true + default: + isReceive = false + } + failure := lnwire.NewTemporaryChannelFailure(nil) + l.sendHTLCError( + pd, NewLinkError(failure), obfuscator, + isReceive, + ) + } + switch fwdInfo.NextHop { case hop.Exit: err := l.processExitHop( diff --git a/htlcswitch/link_test.go b/htlcswitch/link_test.go index 37e306559dc..ea103e0d4eb 100644 --- a/htlcswitch/link_test.go +++ b/htlcswitch/link_test.go @@ -6490,92 +6490,99 @@ func TestPendingCommitTicker(t *testing.T) { } } -// TestShutdownIfChannelClean tests that a link will exit the htlcManager loop -// if and only if the underlying channel state is clean. -func TestShutdownIfChannelClean(t *testing.T) { +func TestFlushInvokesCallbackWhenDrained(t *testing.T) { t.Parallel() const chanAmt = btcutil.SatoshiPerBitcoin * 5 - const chanReserve = btcutil.SatoshiPerBitcoin * 1 - aliceLink, bobChannel, batchTicker, start, _, err := - newSingleLinkTestHarness(t, chanAmt, chanReserve) + const reserve = btcutil.SatoshiPerBitcoin * 1 + aliceLink, bobChannel, _, start, _, err := + newSingleLinkTestHarness(t, chanAmt, reserve) require.NoError(t, err) - var ( - coreLink = aliceLink.(*channelLink) - aliceMsgs = coreLink.cfg.Peer.(*mockPeer).sentMsgs - ) + coreLink := aliceLink.(*channelLink) + aliceMsgs := coreLink.cfg.Peer.(*mockPeer).sentMsgs - shutdownAssert := func(expectedErr error) { - err = aliceLink.ShutdownIfChannelClean() - if expectedErr != nil { - require.Error(t, err, expectedErr) - } else { - require.NoError(t, err) - } + if err := start(); err != nil { + t.Fatalf("unable to start test harness: %v", err) } - err = start() - require.NoError(t, err) - ctx := linkTestContext{ - t: t, - aliceLink: aliceLink, + t: t, + aliceLink: aliceLink, bobChannel: bobChannel, - aliceMsgs: aliceMsgs, + aliceMsgs: aliceMsgs, } - // First send an HTLC from Bob to Alice and assert that the link can't - // be shutdown while the update is outstanding. + flushFinished := make(chan struct{}) + assertFlushFinished := func(exp bool) { + select { + case <-flushFinished: + if !exp { + t.Fatal("flush callback invoked") + } + default: + if exp { + t.Fatal("flush callback not invoked") + } + } + } + + htlc := generateHtlc(t, coreLink, 0) - // <---add----- + // <-- add --- ctx.sendHtlcBobToAlice(htlc) - // <---sig----- + // <-- sig --- ctx.sendCommitSigBobToAlice(1) - // ----rev----> + // --- rev --> ctx.receiveRevAndAckAliceToBob() - shutdownAssert(ErrLinkFailedShutdown) - // ----sig----> + // put the link into a flush state + aliceLink.Flush(func() { + flushFinished <- struct{}{} + }) + assertFlushFinished(false) + + // --- sig --> ctx.receiveCommitSigAliceToBob(1) - shutdownAssert(ErrLinkFailedShutdown) + assertFlushFinished(false) - // <---rev----- + // <-- rev --- ctx.sendRevAndAckBobToAlice() - shutdownAssert(ErrLinkFailedShutdown) + assertFlushFinished(false) - // ---settle--> + // --- set --> ctx.receiveSettleAliceToBob() - shutdownAssert(ErrLinkFailedShutdown) + assertFlushFinished(false) - // ----sig----> + // --- sig --> ctx.receiveCommitSigAliceToBob(0) - shutdownAssert(ErrLinkFailedShutdown) + assertFlushFinished(false) - // <---rev----- + // <-- rev --- ctx.sendRevAndAckBobToAlice() - shutdownAssert(ErrLinkFailedShutdown) + assertFlushFinished(false) // There is currently no controllable breakpoint between Alice // receiving the CommitSig and her sending out the RevokeAndAck. As // soon as the RevokeAndAck is generated, the channel becomes clean. // This can happen right after the CommitSig is received, so there is // no shutdown assertion here. - // <---sig----- + // <-- sig --- ctx.sendCommitSigBobToAlice(0) - // ----rev----> + // --- rev --> ctx.receiveRevAndAckAliceToBob() - shutdownAssert(nil) + <-flushFinished +} - // Now that the link has exited the htlcManager loop, attempt to - // trigger the batch ticker. It should not be possible. - select { - case batchTicker <- time.Now(): - t.Fatalf("expected batch ticker to be inactive") - case <-time.After(5 * time.Second): - } +func TestFlushBlocksAdds(t *testing.T) { + // TODO +} + + +func TestFlushNotEligibleToFwd(t *testing.T) { + // TODO } // TestPipelineSettle tests that a link should only pipeline a settle if the diff --git a/htlcswitch/mock.go b/htlcswitch/mock.go index fe593c9c52f..e0ad2843210 100644 --- a/htlcswitch/mock.go +++ b/htlcswitch/mock.go @@ -898,13 +898,23 @@ func (f *mockChannelLink) ChannelPoint() *wire.OutPoint { return func (f *mockChannelLink) Stop() {} func (f *mockChannelLink) EligibleToForward() bool { return f.eligible } func (f *mockChannelLink) MayAddOutgoingHtlc(lnwire.MilliSatoshi) error { return nil } -func (f *mockChannelLink) ShutdownIfChannelClean() error { return nil } +func (f *mockChannelLink) ShutdownHtlcManager() {} func (f *mockChannelLink) setLiveShortChanID(sid lnwire.ShortChannelID) { f.shortChanID = sid } func (f *mockChannelLink) IsUnadvertised() bool { return f.unadvertised } func (f *mockChannelLink) UpdateShortChanID() (lnwire.ShortChannelID, error) { f.eligible = true return f.shortChanID, nil } +func (f *mockChannelLink) Flush(onFlushed func()) error { + onFlushed() + return nil +} +func (f *mockChannelLink) CancelFlush() error { + return errors.New("no flush in progress to cancel") +} +func (f *mockChannelLink) IsFlushing() bool { + return false +} var _ ChannelLink = (*mockChannelLink)(nil) diff --git a/peer/brontide.go b/peer/brontide.go index 4c43d4a4740..b766a323040 100644 --- a/peer/brontide.go +++ b/peer/brontide.go @@ -3136,19 +3136,18 @@ func (p *Brontide) tryLinkShutdown(cid lnwire.ChannelID) error { return ErrChannelNotFound } - // Else, the link exists, so attempt to trigger shutdown. If this - // fails, we'll send an error message to the remote peer. - if err := chanLink.ShutdownIfChannelClean(); err != nil { - return err - } - - // Next, we remove the link from the switch to shut down all of the - // link's goroutines and remove it from the switch's internal maps. We - // don't call WipeChannel as the channel must still be in the - // activeChannels map to process coop close messages. - p.cfg.Switch.RemoveLink(cid) - - return nil + return chanLink.Flush(func() { + // Else, the link exists, so attempt to trigger shutdown. If + // this fails, we'll send an error message to the remote peer. + chanLink.ShutdownHtlcManager() + + // Next, we remove the link from the switch to shut down all of + // the link's goroutines and remove it from the switch's + // internal maps. We don't call WipeChannel as the channel must + // still be in the activeChannels map to process coop close + // messages. + p.cfg.Switch.RemoveLink(cid) + }) } // fetchLinkFromKeyAndCid fetches a link from the switch via the remote's diff --git a/peer/test_utils.go b/peer/test_utils.go index add15cf19dc..72a976ebc30 100644 --- a/peer/test_utils.go +++ b/peer/test_utils.go @@ -4,6 +4,7 @@ import ( "bytes" crand "crypto/rand" "encoding/binary" + "errors" "io" "math/rand" "net" @@ -482,7 +483,7 @@ func (m *mockUpdateHandler) EligibleToForward() bool { return false } func (m *mockUpdateHandler) MayAddOutgoingHtlc(lnwire.MilliSatoshi) error { return nil } // ShutdownIfChannelClean currently returns nil. -func (m *mockUpdateHandler) ShutdownIfChannelClean() error { return nil } +func (m *mockUpdateHandler) ShutdownHtlcManager() {} type mockMessageConn struct { t *testing.T @@ -499,6 +500,19 @@ type mockMessageConn struct { curReadMessage []byte } +func (m *mockUpdateHandler) Flush(onFlushed func()) error { + onFlushed() + return nil +} + +func (m *mockUpdateHandler) CancelFlush() error { + return errors.New("no flush in progress to cancel") +} + +func (m *mockUpdateHandler) IsFlushing() bool { + return false +} + func newMockConn(t *testing.T, expectedMessages int) *mockMessageConn { return &mockMessageConn{ t: t,