Skip to content

Commit f26e597

Browse files
committed
m
1 parent f838ce4 commit f26e597

23 files changed

+1211
-111
lines changed

channeldb/migration/lnwire21/custom_records.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ const (
1818

1919
// CustomRecords stores a set of custom key/value pairs. Map keys are TLV types
2020
// which must be greater than or equal to MinCustomRecordsTlvType.
21-
type CustomRecords map[uint64][]byte
21+
type CustomRecords map[tlv.Type][]byte
2222

2323
// NewCustomRecords creates a new CustomRecords instance from a
2424
// tlv.TypeMap.
@@ -31,7 +31,7 @@ func NewCustomRecords(tlvMap tlv.TypeMap) (CustomRecords, error) {
3131

3232
customRecords := make(CustomRecords, len(tlvMap))
3333
for k, v := range tlvMap {
34-
customRecords[uint64(k)] = v
34+
customRecords[k] = v
3535
}
3636

3737
// Validate the custom records.
@@ -111,7 +111,7 @@ func (c CustomRecords) ExtendRecordProducers(
111111
// slice would erroneously contain duplicate TLV types.
112112
for _, rp := range producers {
113113
record := rp.Record()
114-
recordTlvType := uint64(record.Type())
114+
recordTlvType := record.Type()
115115

116116
_, foundDuplicateTlvType := c[recordTlvType]
117117
if foundDuplicateTlvType {
@@ -185,9 +185,9 @@ func SortProducers(producers []tlv.RecordProducer) {
185185

186186
// TlvMapToRecords converts a TLV map into a slice of records.
187187
func TlvMapToRecords(tlvMap tlv.TypeMap) []tlv.Record {
188-
tlvMapGeneric := make(map[uint64][]byte)
188+
tlvMapGeneric := make(map[tlv.Type][]byte)
189189
for k, v := range tlvMap {
190-
tlvMapGeneric[uint64(k)] = v
190+
tlvMapGeneric[k] = v
191191
}
192192

193193
return tlv.MapToRecords(tlvMapGeneric)

channeldb/migration32/route.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -473,9 +473,9 @@ func deserializeHop(r io.Reader) (*Hop, error) {
473473
return h, nil
474474
}
475475

476-
tlvMap := make(map[uint64][]byte)
476+
tlvMap := make(map[tlv.Type][]byte)
477477
for i := uint32(0); i < numElements; i++ {
478-
var tlvType uint64
478+
var tlvType tlv.Type
479479
if err := ReadElements(r, &tlvType); err != nil {
480480
return nil, err
481481
}
@@ -497,7 +497,7 @@ func deserializeHop(r io.Reader) (*Hop, error) {
497497
// blobs. The split approach will cause headaches down the road as more
498498
// fields are added, which we can avoid by having a single TLV stream
499499
// for all payload fields.
500-
mppType := uint64(MPPOnionType)
500+
mppType := MPPOnionType
501501
if mppBytes, ok := tlvMap[mppType]; ok {
502502
delete(tlvMap, mppType)
503503

@@ -515,13 +515,13 @@ func deserializeHop(r io.Reader) (*Hop, error) {
515515

516516
// If encrypted data or blinding key are present, remove them from
517517
// the TLV map and parse into proper types.
518-
encryptedDataType := uint64(EncryptedDataOnionType)
518+
encryptedDataType := EncryptedDataOnionType
519519
if data, ok := tlvMap[encryptedDataType]; ok {
520520
delete(tlvMap, encryptedDataType)
521521
h.EncryptedData = data
522522
}
523523

524-
blindingType := uint64(BlindingPointOnionType)
524+
blindingType := BlindingPointOnionType
525525
if blindingPoint, ok := tlvMap[blindingType]; ok {
526526
delete(tlvMap, blindingType)
527527

@@ -532,7 +532,7 @@ func deserializeHop(r io.Reader) (*Hop, error) {
532532
}
533533
}
534534

535-
ampType := uint64(AMPOnionType)
535+
ampType := AMPOnionType
536536
if ampBytes, ok := tlvMap[ampType]; ok {
537537
delete(tlvMap, ampType)
538538

@@ -550,14 +550,14 @@ func deserializeHop(r io.Reader) (*Hop, error) {
550550

551551
// If the metadata type is present, remove it from the tlv map and
552552
// populate directly on the hop.
553-
metadataType := uint64(MetadataOnionType)
553+
metadataType := MetadataOnionType
554554
if metadata, ok := tlvMap[metadataType]; ok {
555555
delete(tlvMap, metadataType)
556556

557557
h.Metadata = metadata
558558
}
559559

560-
totalAmtMsatType := uint64(TotalAmtMsatBlindedType)
560+
totalAmtMsatType := TotalAmtMsatBlindedType
561561
if totalAmtMsat, ok := tlvMap[totalAmtMsatType]; ok {
562562
delete(tlvMap, totalAmtMsatType)
563563

channeldb/migration_01_to_11/payments.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -546,9 +546,9 @@ func deserializeHop(r io.Reader) (*Hop, error) {
546546
return h, nil
547547
}
548548

549-
tlvMap := make(map[uint64][]byte)
549+
tlvMap := make(map[tlv.Type][]byte)
550550
for i := uint32(0); i < numElements; i++ {
551-
var tlvType uint64
551+
var tlvType tlv.Type
552552
if err := ReadElements(r, &tlvType); err != nil {
553553
return nil, err
554554
}

channeldb/payments.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1306,9 +1306,9 @@ func deserializeHop(r io.Reader) (*route.Hop, error) {
13061306
return h, nil
13071307
}
13081308

1309-
tlvMap := make(map[uint64][]byte)
1309+
tlvMap := make(map[tlv.Type][]byte)
13101310
for i := uint32(0); i < numElements; i++ {
1311-
var tlvType uint64
1311+
var tlvType tlv.Type
13121312
if err := ReadElements(r, &tlvType); err != nil {
13131313
return nil, err
13141314
}
@@ -1330,7 +1330,7 @@ func deserializeHop(r io.Reader) (*route.Hop, error) {
13301330
// blobs. The split approach will cause headaches down the road as more
13311331
// fields are added, which we can avoid by having a single TLV stream
13321332
// for all payload fields.
1333-
mppType := uint64(record.MPPOnionType)
1333+
mppType := record.MPPOnionType
13341334
if mppBytes, ok := tlvMap[mppType]; ok {
13351335
delete(tlvMap, mppType)
13361336

@@ -1348,13 +1348,13 @@ func deserializeHop(r io.Reader) (*route.Hop, error) {
13481348

13491349
// If encrypted data or blinding key are present, remove them from
13501350
// the TLV map and parse into proper types.
1351-
encryptedDataType := uint64(record.EncryptedDataOnionType)
1351+
encryptedDataType := record.EncryptedDataOnionType
13521352
if data, ok := tlvMap[encryptedDataType]; ok {
13531353
delete(tlvMap, encryptedDataType)
13541354
h.EncryptedData = data
13551355
}
13561356

1357-
blindingType := uint64(record.BlindingPointOnionType)
1357+
blindingType := record.BlindingPointOnionType
13581358
if blindingPoint, ok := tlvMap[blindingType]; ok {
13591359
delete(tlvMap, blindingType)
13601360

@@ -1365,7 +1365,7 @@ func deserializeHop(r io.Reader) (*route.Hop, error) {
13651365
}
13661366
}
13671367

1368-
ampType := uint64(record.AMPOnionType)
1368+
ampType := record.AMPOnionType
13691369
if ampBytes, ok := tlvMap[ampType]; ok {
13701370
delete(tlvMap, ampType)
13711371

@@ -1383,14 +1383,14 @@ func deserializeHop(r io.Reader) (*route.Hop, error) {
13831383

13841384
// If the metadata type is present, remove it from the tlv map and
13851385
// populate directly on the hop.
1386-
metadataType := uint64(record.MetadataOnionType)
1386+
metadataType := record.MetadataOnionType
13871387
if metadata, ok := tlvMap[metadataType]; ok {
13881388
delete(tlvMap, metadataType)
13891389

13901390
h.Metadata = metadata
13911391
}
13921392

1393-
totalAmtMsatType := uint64(record.TotalAmtMsatBlindedType)
1393+
totalAmtMsatType := record.TotalAmtMsatBlindedType
13941394
if totalAmtMsat, ok := tlvMap[totalAmtMsatType]; ok {
13951395
delete(tlvMap, totalAmtMsatType)
13961396

htlcswitch/hop/payload.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ func NewCustomRecords(parsedTypes tlv.TypeMap) record.CustomSet {
256256
if parseResult == nil || t < record.CustomTypeStart {
257257
continue
258258
}
259-
customRecords[uint64(t)] = parseResult
259+
customRecords[t] = parseResult
260260
}
261261
return customRecords
262262
}

invoices/sql_store.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import (
1818
"github.com/lightningnetwork/lnd/record"
1919
"github.com/lightningnetwork/lnd/sqldb"
2020
"github.com/lightningnetwork/lnd/sqldb/sqlc"
21+
"github.com/lightningnetwork/lnd/tlv"
2122
)
2223

2324
const (
@@ -494,7 +495,7 @@ func fetchAmpState(ctx context.Context, db SQLInvoiceQueries, invoiceID int64,
494495
value = []byte{}
495496
}
496497

497-
customRecords[row.HtlcID][uint64(row.Key)] = value
498+
customRecords[row.HtlcID][tlv.Type(row.Key)] = value
498499
}
499500

500501
// Now fetch all the AMP HTLCs for this invoice or the one matching the
@@ -1566,7 +1567,7 @@ func getInvoiceHtlcs(ctx context.Context, db SQLInvoiceQueries,
15661567
if value == nil {
15671568
value = []byte{}
15681569
}
1569-
cr[row.HtlcID][uint64(row.Key)] = value
1570+
cr[row.HtlcID][tlv.Type(row.Key)] = value
15701571
}
15711572

15721573
htlcs := make(map[CircuitKey]*InvoiceHTLC, len(htlcRows))

lnrpc/invoicesrpc/htlc_modifier.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,14 +49,19 @@ func (r *htlcModifier) onIntercept(
4949
return nil, err
5050
}
5151

52+
customRecs := make(map[uint64][]byte, len(req.WireCustomRecords))
53+
for t, b := range req.WireCustomRecords {
54+
customRecs[uint64(t)] = b
55+
}
56+
5257
// Send the modification request to the client.
5358
err = r.serverStream.Send(&HtlcModifyRequest{
5459
Invoice: rpcInvoice,
5560
ExitHtlcCircuitKey: rpcCircuitKey,
5661
ExitHtlcAmt: uint64(req.ExitHtlcAmt),
5762
ExitHtlcExpiry: req.ExitHtlcExpiry,
5863
CurrentHeight: req.CurrentHeight,
59-
ExitHtlcWireCustomRecords: req.WireCustomRecords,
64+
ExitHtlcWireCustomRecords: customRecs,
6065
})
6166
if err != nil {
6267
return nil, err

lnrpc/invoicesrpc/utils.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,11 @@ func CreateRPCInvoice(invoice *invoices.Invoice,
111111
return nil, fmt.Errorf("unknown state %v", htlc.State)
112112
}
113113

114+
customRecs := make(map[uint64][]byte, len(htlc.CustomRecords))
115+
for t, b := range htlc.CustomRecords {
116+
customRecs[uint64(t)] = b
117+
}
118+
114119
rpcHtlc := lnrpc.InvoiceHTLC{
115120
ChanId: key.ChanID.ToUint64(),
116121
HtlcIndex: key.HtlcID,
@@ -119,7 +124,7 @@ func CreateRPCInvoice(invoice *invoices.Invoice,
119124
ExpiryHeight: int32(htlc.Expiry),
120125
AmtMsat: uint64(htlc.Amt),
121126
State: state,
122-
CustomRecords: htlc.CustomRecords,
127+
CustomRecords: customRecs,
123128
MppTotalAmtMsat: uint64(htlc.MppTotalAmt),
124129
}
125130

lnrpc/routerrpc/forward_interceptor.go

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"github.com/lightningnetwork/lnd/lnrpc"
1010
"github.com/lightningnetwork/lnd/lntypes"
1111
"github.com/lightningnetwork/lnd/lnwire"
12+
"github.com/lightningnetwork/lnd/tlv"
1213
"google.golang.org/grpc/codes"
1314
"google.golang.org/grpc/status"
1415
)
@@ -77,6 +78,15 @@ func (r *forwardInterceptor) onIntercept(
7778

7879
inKey := htlc.IncomingCircuit
7980

81+
customRecs := make(map[uint64][]byte, len(htlc.InOnionCustomRecords))
82+
for t, b := range htlc.InOnionCustomRecords {
83+
customRecs[uint64(t)] = b
84+
}
85+
customRecs2 := make(map[uint64][]byte, len(htlc.InWireCustomRecords))
86+
for t, b := range htlc.InWireCustomRecords {
87+
customRecs2[uint64(t)] = b
88+
}
89+
8090
// First hold the forward, then send to client.
8191
interceptionRequest := &ForwardHtlcInterceptRequest{
8292
IncomingCircuitKey: &CircuitKey{
@@ -89,10 +99,10 @@ func (r *forwardInterceptor) onIntercept(
8999
OutgoingExpiry: htlc.OutgoingExpiry,
90100
IncomingAmountMsat: uint64(htlc.IncomingAmount),
91101
IncomingExpiry: htlc.IncomingExpiry,
92-
CustomRecords: htlc.InOnionCustomRecords,
102+
CustomRecords: customRecs,
93103
OnionBlob: htlc.OnionBlob[:],
94104
AutoFailHeight: htlc.AutoFailHeight,
95-
InWireCustomRecords: htlc.InWireCustomRecords,
105+
InWireCustomRecords: customRecs2,
96106
}
97107

98108
return r.stream.Send(interceptionRequest)
@@ -139,10 +149,15 @@ func (r *forwardInterceptor) resolveFromClient(
139149
))
140150
}
141151

152+
customRecs := make(map[tlv.Type][]byte, len(in.OutWireCustomRecords))
153+
for t, b := range in.OutWireCustomRecords {
154+
customRecs[tlv.Type(t)] = b
155+
}
156+
142157
outWireCustomRecords := fn.None[lnwire.CustomRecords]()
143158
if len(in.OutWireCustomRecords) > 0 {
144159
// Validate custom records.
145-
cr := lnwire.CustomRecords(in.OutWireCustomRecords)
160+
cr := lnwire.CustomRecords(customRecs)
146161
if err := cr.Validate(); err != nil {
147162
return status.Errorf(
148163
codes.InvalidArgument,

lnwallet/channel.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4996,7 +4996,7 @@ func genHtlcSigValidationJobs(chanState *channeldb.OpenChannel,
49964996
// it, store it in the custom records map so we can
49974997
// write to disk later.
49984998
sigType := htlcCustomSigType.TypeVal()
4999-
htlc.CustomRecords[uint64(sigType)] = auxSig.UnwrapOr(
4999+
htlc.CustomRecords[sigType] = auxSig.UnwrapOr(
50005000
nil,
50015001
)
50025002

0 commit comments

Comments
 (0)