Skip to content

Commit 6cdfdfd

Browse files
committed
routing: remove context.TODOs
1 parent 46ce108 commit 6cdfdfd

25 files changed

+242
-189
lines changed

graph/session/graph_session.go

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,10 @@ func NewGraphSessionFactory(graph ReadOnlyGraph) routing.GraphSessionFactory {
3030
// was created at Graph construction time.
3131
//
3232
// NOTE: This is part of the routing.GraphSessionFactory interface.
33-
func (g *Factory) NewGraphSession() (routing.Graph, func() error, error) {
34-
tx, err := g.graph.NewPathFindTx(context.TODO())
33+
func (g *Factory) NewGraphSession(ctx context.Context) (routing.Graph,
34+
func() error, error) {
35+
36+
tx, err := g.graph.NewPathFindTx(ctx)
3537
if err != nil {
3638
return nil, nil, err
3739
}
@@ -83,22 +85,20 @@ func (g *session) close() error {
8385
// ForEachNodeChannel calls the callback for every channel of the given node.
8486
//
8587
// NOTE: Part of the routing.Graph interface.
86-
func (g *session) ForEachNodeChannel(nodePub route.Vertex,
88+
func (g *session) ForEachNodeChannel(ctx context.Context, nodePub route.Vertex,
8789
cb func(channel *graphdb.DirectedChannel) error) error {
8890

89-
return g.graph.ForEachNodeDirectedChannel(
90-
context.TODO(), g.tx, nodePub, cb,
91-
)
91+
return g.graph.ForEachNodeDirectedChannel(ctx, g.tx, nodePub, cb)
9292
}
9393

9494
// FetchNodeFeatures returns the features of the given node. If the node is
9595
// unknown, assume no additional features are supported.
9696
//
9797
// NOTE: Part of the routing.Graph interface.
98-
func (g *session) FetchNodeFeatures(nodePub route.Vertex) (
99-
*lnwire.FeatureVector, error) {
98+
func (g *session) FetchNodeFeatures(ctx context.Context,
99+
nodePub route.Vertex) (*lnwire.FeatureVector, error) {
100100

101-
return g.graph.FetchNodeFeatures(context.TODO(), g.tx, nodePub)
101+
return g.graph.FetchNodeFeatures(ctx, g.tx, nodePub)
102102
}
103103

104104
// A compile-time check to ensure that *session implements the

lnrpc/invoicesrpc/addinvoice.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,8 @@ type AddInvoiceConfig struct {
9696

9797
// QueryBlindedRoutes can be used to generate a few routes to this node
9898
// that can then be used in the construction of a blinded payment path.
99-
QueryBlindedRoutes func(lnwire.MilliSatoshi) ([]*route.Route, error)
99+
QueryBlindedRoutes func(context.Context, lnwire.MilliSatoshi) (
100+
[]*route.Route, error)
100101
}
101102

102103
// AddInvoiceData contains the required data to create a new invoice.

lnrpc/routerrpc/router_backend.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ type RouterBackend struct {
5050

5151
// FetchAmountPairCapacity determines the maximal channel capacity
5252
// between two nodes given a certain amount.
53-
FetchAmountPairCapacity func(nodeFrom, nodeTo route.Vertex,
53+
FetchAmountPairCapacity func(ctx context.Context, nodeFrom,
54+
nodeTo route.Vertex,
5455
amount lnwire.MilliSatoshi) (btcutil.Amount, error)
5556

5657
// FetchChannelEndpoints returns the pubkeys of both endpoints of the
@@ -60,7 +61,8 @@ type RouterBackend struct {
6061

6162
// FindRoute is a closure that abstracts away how we locate/query for
6263
// routes.
63-
FindRoute func(*routing.RouteRequest) (*route.Route, float64, error)
64+
FindRoute func(context.Context, *routing.RouteRequest) (*route.Route,
65+
float64, error)
6466

6567
MissionControl MissionControl
6668

@@ -165,7 +167,7 @@ func (r *RouterBackend) QueryRoutes(ctx context.Context,
165167
// Query the channel router for a possible path to the destination that
166168
// can carry `in.Amt` satoshis _including_ the total fee required on
167169
// the route
168-
route, successProb, err := r.FindRoute(routeReq)
170+
route, successProb, err := r.FindRoute(ctx, routeReq)
169171
if err != nil {
170172
return nil, err
171173
}

lnrpc/routerrpc/router_backend_test.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -120,8 +120,8 @@ func testQueryRoutes(t *testing.T, useMissionControl bool, useMsat bool,
120120
}
121121
}
122122

123-
findRoute := func(req *routing.RouteRequest) (*route.Route, float64,
124-
error) {
123+
findRoute := func(_ context.Context, req *routing.RouteRequest) (
124+
*route.Route, float64, error) {
125125

126126
if int64(req.Amount) != amtSat*1000 {
127127
t.Fatal("unexpected amount")
@@ -200,7 +200,8 @@ func testQueryRoutes(t *testing.T, useMissionControl bool, useMsat bool,
200200

201201
return 1, nil
202202
},
203-
FetchAmountPairCapacity: func(nodeFrom, nodeTo route.Vertex,
203+
FetchAmountPairCapacity: func(_ context.Context, nodeFrom,
204+
nodeTo route.Vertex,
204205
amount lnwire.MilliSatoshi) (btcutil.Amount, error) {
205206

206207
return 1, nil

lnrpc/routerrpc/router_server.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -426,7 +426,7 @@ func (s *Server) EstimateRouteFee(ctx context.Context,
426426
return nil, errors.New("amount must be greater than 0")
427427

428428
default:
429-
return s.probeDestination(req.Dest, req.AmtSat)
429+
return s.probeDestination(ctx, req.Dest, req.AmtSat)
430430
}
431431

432432
case isProbeInvoice:
@@ -440,8 +440,8 @@ func (s *Server) EstimateRouteFee(ctx context.Context,
440440

441441
// probeDestination estimates fees along a route to a destination based on the
442442
// contents of the local graph.
443-
func (s *Server) probeDestination(dest []byte, amtSat int64) (*RouteFeeResponse,
444-
error) {
443+
func (s *Server) probeDestination(ctx context.Context, dest []byte,
444+
amtSat int64) (*RouteFeeResponse, error) {
445445

446446
destNode, err := route.NewVertexFromBytes(dest)
447447
if err != nil {
@@ -469,7 +469,7 @@ func (s *Server) probeDestination(dest []byte, amtSat int64) (*RouteFeeResponse,
469469
return nil, err
470470
}
471471

472-
route, _, err := s.cfg.Router.FindRoute(routeReq)
472+
route, _, err := s.cfg.Router.FindRoute(ctx, routeReq)
473473
if err != nil {
474474
return nil, err
475475
}
@@ -1429,7 +1429,7 @@ func (s *Server) trackPaymentStream(context context.Context,
14291429
}
14301430

14311431
// BuildRoute builds a route from a list of hop addresses.
1432-
func (s *Server) BuildRoute(_ context.Context,
1432+
func (s *Server) BuildRoute(ctx context.Context,
14331433
req *BuildRouteRequest) (*BuildRouteResponse, error) {
14341434

14351435
if len(req.HopPubkeys) == 0 {
@@ -1490,7 +1490,7 @@ func (s *Server) BuildRoute(_ context.Context,
14901490

14911491
// Build the route and return it to the caller.
14921492
route, err := s.cfg.Router.BuildRoute(
1493-
amt, hops, outgoingChan, req.FinalCltvDelta, payAddr,
1493+
ctx, amt, hops, outgoingChan, req.FinalCltvDelta, payAddr,
14941494
firstHopBlob,
14951495
)
14961496
if err != nil {

lnrpc/routerrpc/router_server_deprecated.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ func (s *Server) SendToRoute(ctx context.Context,
123123

124124
// QueryProbability returns the current success probability estimate for a
125125
// given node pair and amount.
126-
func (s *Server) QueryProbability(_ context.Context,
126+
func (s *Server) QueryProbability(ctx context.Context,
127127
req *QueryProbabilityRequest) (*QueryProbabilityResponse, error) {
128128

129129
fromNode, err := route.NewVertexFromBytes(req.FromNode)
@@ -142,7 +142,7 @@ func (s *Server) QueryProbability(_ context.Context,
142142
var prob float64
143143
mc := s.cfg.RouterBackend.MissionControl
144144
capacity, err := s.cfg.RouterBackend.FetchAmountPairCapacity(
145-
fromNode, toNode, amt,
145+
ctx, fromNode, toNode, amt,
146146
)
147147

148148
// If we cannot query the capacity this means that either we don't have

routing/bandwidth.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package routing
22

33
import (
4+
"context"
45
"fmt"
56

67
"github.com/lightningnetwork/lnd/fn"
@@ -82,8 +83,9 @@ type bandwidthManager struct {
8283
// hints for the edges we directly have open ourselves. Obtaining these hints
8384
// allows us to reduce the number of extraneous attempts as we can skip channels
8485
// that are inactive, or just don't have enough bandwidth to carry the payment.
85-
func newBandwidthManager(graph Graph, sourceNode route.Vertex,
86-
linkQuery getLinkQuery, firstHopBlob fn.Option[tlv.Blob],
86+
func newBandwidthManager(ctx context.Context, graph Graph,
87+
sourceNode route.Vertex, linkQuery getLinkQuery,
88+
firstHopBlob fn.Option[tlv.Blob],
8789
trafficShaper fn.Option[TlvTrafficShaper]) (*bandwidthManager, error) {
8890

8991
manager := &bandwidthManager{
@@ -95,7 +97,7 @@ func newBandwidthManager(graph Graph, sourceNode route.Vertex,
9597

9698
// First, we'll collect the set of outbound edges from the target
9799
// source node and add them to our bandwidth manager's map of channels.
98-
err := graph.ForEachNodeChannel(sourceNode,
100+
err := graph.ForEachNodeChannel(ctx, sourceNode,
99101
func(channel *graphdb.DirectedChannel) error {
100102
shortID := lnwire.NewShortChanIDFromInt(
101103
channel.ChannelID,

routing/bandwidth_test.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package routing
22

33
import (
4+
"context"
45
"testing"
56

67
"github.com/btcsuite/btcd/btcutil"
@@ -116,7 +117,8 @@ func TestBandwidthManager(t *testing.T) {
116117
)
117118

118119
m, err := newBandwidthManager(
119-
g, sourceNode.pubkey, testCase.linkQuery,
120+
context.Background(), g, sourceNode.pubkey,
121+
testCase.linkQuery,
120122
fn.None[[]byte](),
121123
fn.Some[TlvTrafficShaper](&mockTrafficShaper{}),
122124
)

routing/blindedpath/blinded_path.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@ type BuildBlindedPathCfg struct {
3939
// various lengths and may even contain only a single hop. Any route
4040
// shorter than MinNumHops will be padded with dummy hops during route
4141
// construction.
42-
FindRoutes func(value lnwire.MilliSatoshi) ([]*route.Route, error)
42+
FindRoutes func(ctx context.Context, value lnwire.MilliSatoshi) (
43+
[]*route.Route, error)
4344

4445
// FetchChannelEdgesByID attempts to look up the two directed edges for
4546
// the channel identified by the channel ID.
@@ -118,7 +119,7 @@ func BuildBlindedPaymentPaths(ctx context.Context, cfg *BuildBlindedPathCfg) (
118119

119120
// Find some appropriate routes for the value to be routed. This will
120121
// return a set of routes made up of real nodes.
121-
routes, err := cfg.FindRoutes(cfg.ValueMsat)
122+
routes, err := cfg.FindRoutes(ctx, cfg.ValueMsat)
122123
if err != nil {
123124
return nil, err
124125
}

routing/blindedpath/blinded_path_test.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -595,8 +595,8 @@ func TestBuildBlindedPath(t *testing.T) {
595595
}
596596

597597
paths, err := BuildBlindedPaymentPaths(ctx, &BuildBlindedPathCfg{
598-
FindRoutes: func(_ lnwire.MilliSatoshi) ([]*route.Route,
599-
error) {
598+
FindRoutes: func(_ context.Context, _ lnwire.MilliSatoshi) (
599+
[]*route.Route, error) {
600600

601601
return []*route.Route{realRoute}, nil
602602
},
@@ -765,8 +765,8 @@ func TestBuildBlindedPathWithDummyHops(t *testing.T) {
765765
}
766766

767767
paths, err := BuildBlindedPaymentPaths(ctx, &BuildBlindedPathCfg{
768-
FindRoutes: func(_ lnwire.MilliSatoshi) ([]*route.Route,
769-
error) {
768+
FindRoutes: func(_ context.Context, _ lnwire.MilliSatoshi) (
769+
[]*route.Route, error) {
770770

771771
return []*route.Route{realRoute}, nil
772772
},
@@ -935,8 +935,8 @@ func TestBuildBlindedPathWithDummyHops(t *testing.T) {
935935
// still get 1 valid path.
936936
var errCount int
937937
paths, err = BuildBlindedPaymentPaths(ctx, &BuildBlindedPathCfg{
938-
FindRoutes: func(_ lnwire.MilliSatoshi) ([]*route.Route,
939-
error) {
938+
FindRoutes: func(_ context.Context, _ lnwire.MilliSatoshi) (
939+
[]*route.Route, error) {
940940

941941
return []*route.Route{realRoute, realRoute, realRoute},
942942
nil
@@ -1016,8 +1016,8 @@ func TestSingleHopBlindedPath(t *testing.T) {
10161016
}
10171017

10181018
paths, err := BuildBlindedPaymentPaths(ctx, &BuildBlindedPathCfg{
1019-
FindRoutes: func(_ lnwire.MilliSatoshi) ([]*route.Route,
1020-
error) {
1019+
FindRoutes: func(_ context.Context, _ lnwire.MilliSatoshi) (
1020+
[]*route.Route, error) {
10211021

10221022
return []*route.Route{realRoute}, nil
10231023
},

0 commit comments

Comments
 (0)