@@ -269,7 +269,9 @@ func (d *DB) InvoicesAddedSince(_ context.Context, sinceAddIndex uint64) (
269
269
270
270
// For each key found, we'll look up the actual
271
271
// invoice, then accumulate it into our return value.
272
- invoice , err := fetchInvoice (invoiceKey , invoices )
272
+ invoice , err := fetchInvoice (
273
+ invoiceKey , invoices , nil , false ,
274
+ )
273
275
if err != nil {
274
276
return err
275
277
}
@@ -341,7 +343,9 @@ func (d *DB) LookupInvoice(_ context.Context, ref invpkg.InvoiceRef) (
341
343
342
344
// An invoice was found, retrieve the remainder of the invoice
343
345
// body.
344
- i , err := fetchInvoice (invoiceNum , invoices , setID )
346
+ i , err := fetchInvoice (
347
+ invoiceNum , invoices , []* invpkg.SetID {setID }, true ,
348
+ )
345
349
if err != nil {
346
350
return err
347
351
}
@@ -468,7 +472,7 @@ func (d *DB) FetchPendingInvoices(_ context.Context) (
468
472
return nil
469
473
}
470
474
471
- invoice , err := fetchInvoice (v , invoices )
475
+ invoice , err := fetchInvoice (v , invoices , nil , false )
472
476
if err != nil {
473
477
return err
474
478
}
@@ -526,7 +530,9 @@ func (d *DB) QueryInvoices(_ context.Context, q invpkg.InvoiceQuery) (
526
530
// characteristics for our query and returns the number of items
527
531
// we have added to our set of invoices.
528
532
accumulateInvoices := func (_ , indexValue []byte ) (bool , error ) {
529
- invoice , err := fetchInvoice (indexValue , invoices )
533
+ invoice , err := fetchInvoice (
534
+ indexValue , invoices , nil , false ,
535
+ )
530
536
if err != nil {
531
537
return false , err
532
538
}
@@ -654,7 +660,9 @@ func (d *DB) UpdateInvoice(_ context.Context, ref invpkg.InvoiceRef,
654
660
if setIDHint != nil {
655
661
invSetID = * setIDHint
656
662
}
657
- invoice , err := fetchInvoice (invoiceNum , invoices , & invSetID )
663
+ invoice , err := fetchInvoice (
664
+ invoiceNum , invoices , []* invpkg.SetID {& invSetID }, false ,
665
+ )
658
666
if err != nil {
659
667
return err
660
668
}
@@ -676,15 +684,43 @@ func (d *DB) UpdateInvoice(_ context.Context, ref invpkg.InvoiceRef,
676
684
updatedInvoice , err = invpkg .UpdateInvoice (
677
685
payHash , updater .invoice , now , callback , updater ,
678
686
)
687
+ if err != nil {
688
+ return err
689
+ }
679
690
680
- return err
691
+ // If this is an AMP update, then limit the returned AMP state
692
+ // to only the requested set ID.
693
+ if setIDHint != nil {
694
+ filterInvoiceAMPState (updatedInvoice , & invSetID )
695
+ }
696
+
697
+ return nil
681
698
}, func () {
682
699
updatedInvoice = nil
683
700
})
684
701
685
702
return updatedInvoice , err
686
703
}
687
704
705
+ // filterInvoiceAMPState filters the AMP state of the invoice to only include
706
+ // state for the specified set IDs.
707
+ func filterInvoiceAMPState (invoice * invpkg.Invoice , setIDs ... * invpkg.SetID ) {
708
+ filteredAMPState := make (invpkg.AMPInvoiceState )
709
+
710
+ for _ , setID := range setIDs {
711
+ if setID == nil {
712
+ return
713
+ }
714
+
715
+ ampState , ok := invoice .AMPState [* setID ]
716
+ if ok {
717
+ filteredAMPState [* setID ] = ampState
718
+ }
719
+ }
720
+
721
+ invoice .AMPState = filteredAMPState
722
+ }
723
+
688
724
// ampHTLCsMap is a map of AMP HTLCs affected by an invoice update.
689
725
type ampHTLCsMap map [invpkg.SetID ]map [models.CircuitKey ]* invpkg.InvoiceHTLC
690
726
@@ -1056,7 +1092,8 @@ func (d *DB) InvoicesSettledSince(_ context.Context, sinceSettleIndex uint64) (
1056
1092
// For each key found, we'll look up the actual
1057
1093
// invoice, then accumulate it into our return value.
1058
1094
invoice , err := fetchInvoice (
1059
- invoiceKey [:], invoices , setID ,
1095
+ invoiceKey [:], invoices , []* invpkg.SetID {setID },
1096
+ true ,
1060
1097
)
1061
1098
if err != nil {
1062
1099
return err
@@ -1485,7 +1522,7 @@ func fetchAmpSubInvoices(invoiceBucket kvdb.RBucket, invoiceNum []byte,
1485
1522
// specified by the invoice number. If the setID fields are set, then only the
1486
1523
// HTLC information pertaining to those set IDs is returned.
1487
1524
func fetchInvoice (invoiceNum []byte , invoices kvdb.RBucket ,
1488
- setIDs ... * invpkg.SetID ) (invpkg.Invoice , error ) {
1525
+ setIDs [] * invpkg.SetID , filterAMPState bool ) (invpkg.Invoice , error ) {
1489
1526
1490
1527
invoiceBytes := invoices .Get (invoiceNum )
1491
1528
if invoiceBytes == nil {
@@ -1518,6 +1555,10 @@ func fetchInvoice(invoiceNum []byte, invoices kvdb.RBucket,
1518
1555
log .Errorf ("unable to fetch amp htlcs for inv " +
1519
1556
"%v and setIDs %v: %w" , invoiceNum , setIDs , err )
1520
1557
}
1558
+
1559
+ if filterAMPState {
1560
+ filterInvoiceAMPState (& invoice , setIDs ... )
1561
+ }
1521
1562
}
1522
1563
1523
1564
return invoice , nil
@@ -2163,7 +2204,7 @@ func (d *DB) DeleteCanceledInvoices(_ context.Context) error {
2163
2204
return nil
2164
2205
}
2165
2206
2166
- invoice , err := fetchInvoice (v , invoices )
2207
+ invoice , err := fetchInvoice (v , invoices , nil , false )
2167
2208
if err != nil {
2168
2209
return err
2169
2210
}
0 commit comments