Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lnwire/musig2.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func nonceTypeEncoder(w io.Writer, val interface{}, _ *[8]byte) error {
func nonceTypeDecoder(r io.Reader, val interface{}, _ *[8]byte,
l uint64) error {

if v, ok := val.(*Musig2Nonce); ok {
if v, ok := val.(*Musig2Nonce); ok && l == musig2.PubNonceSize {
_, err := io.ReadFull(r, v[:])
return err
}
Expand Down
58 changes: 58 additions & 0 deletions lnwire/musig2_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package lnwire

import (
"testing"

"github.com/btcsuite/btcd/btcec/v2/schnorr/musig2"
"github.com/stretchr/testify/require"
)

func makeNonce() Musig2Nonce {
var n Musig2Nonce
for i := range musig2.PubNonceSize {
n[i] = byte(i)
}

return n
}

// TestMusig2NonceEncodeDecode tests that we're able to properly encode and
// decode Musig2Nonce within TLV streams.
func TestMusig2NonceEncodeDecode(t *testing.T) {
t.Parallel()

nonce := makeNonce()

var extraData ExtraOpaqueData
require.NoError(t, extraData.PackRecords(&nonce))

var extractedNonce Musig2Nonce
_, err := extraData.ExtractRecords(&extractedNonce)
require.NoError(t, err)

require.Equal(t, nonce, extractedNonce)
}

// TestMusig2NonceTypeDecodeInvalidLength ensures that decoding a Musig2Nonce
// TLV with an invalid length (anything other than 66 bytes) fails with an
// error.
func TestMusig2NonceTypeDecodeInvalidLength(t *testing.T) {
t.Parallel()

nonce := makeNonce()

var extraData ExtraOpaqueData
require.NoError(t, extraData.PackRecords(&nonce))

// Corrupt the TLV length field to simulate malformed input.
extraData[1] = musig2.PubNonceSize + 1

var out Musig2Nonce
_, err := extraData.ExtractRecords(&out)
require.Error(t, err)

extraData[1] = musig2.PubNonceSize - 1

_, err = extraData.ExtractRecords(&out)
require.Error(t, err)
}
3 changes: 2 additions & 1 deletion lnwire/short_channel_id.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@ func DShortChannelID(r io.Reader, val interface{}, buf *[8]byte,

if v, ok := val.(*ShortChannelID); ok {
var scid uint64
err := tlv.DUint64(r, &scid, buf, 8)
// tlv.DUint64 forces the length to be 8 bytes.
err := tlv.DUint64(r, &scid, buf, l)
if err != nil {
return err
}
Expand Down
25 changes: 25 additions & 0 deletions lnwire/short_channel_id_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,28 @@ func TestScidTypeEncodeDecode(t *testing.T) {
require.Contains(t, tlvs, AliasScidRecordType)
require.Equal(t, aliasScid, aliasScid2)
}

// TestScidTypeDecodeInvalidLength ensures that decoding a ShortChannelID TLV
// with an invalid length (anything other than 8 bytes) fails with an error.
func TestScidTypeDecodeInvalidLength(t *testing.T) {
t.Parallel()

aliasScid := ShortChannelID{
BlockHeight: 1, TxIndex: 1, TxPosition: 1,
}

var extraData ExtraOpaqueData
require.NoError(t, extraData.PackRecords(&aliasScid))

// Corrupt the TLV length field to simulate malformed input.
extraData[1] = 8 + 1

var out ShortChannelID
_, err := extraData.ExtractRecords(&out)
require.Error(t, err)

extraData[1] = 8 - 1

_, err = extraData.ExtractRecords(&out)
require.Error(t, err)
}
2 changes: 1 addition & 1 deletion lnwire/typed_fee.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func feeEncoder(w io.Writer, val interface{}, buf *[8]byte) error {
// feeDecoder is a custom TLV decoder for the fee record.
func feeDecoder(r io.Reader, val interface{}, buf *[8]byte, l uint64) error {
v, ok := val.(*Fee)
if !ok {
if !ok || l != 8 {
return tlv.NewTypeForDecodingErr(val, "lnwire.Fee", l, 8)
}

Expand Down
25 changes: 25 additions & 0 deletions lnwire/typed_fee_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,28 @@ func testTypedFee(t *testing.T, fee Fee) { //nolint: thelper

require.Equal(t, fee, extractedFee)
}

// TestTypedFeeTypeDecodeInvalidLength ensures that decoding a Fee TLV
// with an invalid length (anything other than 8 bytes) fails with an error.
func TestTypedFeeTypeDecodeInvalidLength(t *testing.T) {
t.Parallel()

fee := Fee{
BaseFee: 1, FeeRate: 1,
}

var extraData ExtraOpaqueData
require.NoError(t, extraData.PackRecords(&fee))

// Corrupt the TLV length field to simulate malformed input.
extraData[3] = 8 + 1

var out Fee
_, err := extraData.ExtractRecords(&out)
require.Error(t, err)

extraData[3] = 8 - 1

_, err = extraData.ExtractRecords(&out)
require.Error(t, err)
}
2 changes: 1 addition & 1 deletion routing/route/route.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ func encodeVertex(w io.Writer, val interface{}, _ *[8]byte) error {
}

func decodeVertex(r io.Reader, val interface{}, _ *[8]byte, l uint64) error {
if b, ok := val.(*Vertex); ok {
if b, ok := val.(*Vertex); ok && l == VertexSize {
_, err := io.ReadFull(r, b[:])
return err
}
Expand Down
51 changes: 51 additions & 0 deletions routing/route/route_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/btcsuite/btcd/btcec/v2"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/record"
"github.com/lightningnetwork/lnd/tlv"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -430,3 +431,53 @@ func TestBlindedHopFee(t *testing.T) {
require.Equal(t, lnwire.MilliSatoshi(0), route.HopFee(3))
require.Equal(t, lnwire.MilliSatoshi(0), route.HopFee(4))
}

func makeVertex() Vertex {
var v Vertex
for i := range VertexSize {
v[i] = byte(i)
}

return v
}

// TestVertexTLVEncodeDecode tests that we're able to properly encode and decode
// Vertex within TLV streams.
func TestVertexTLVEncodeDecode(t *testing.T) {
t.Parallel()

vertex := makeVertex()

var extraData lnwire.ExtraOpaqueData
require.NoError(t, extraData.PackRecords(&vertex))

var vertex2 Vertex
tlvs, err := extraData.ExtractRecords(&vertex2)
require.NoError(t, err)

require.Contains(t, tlvs, tlv.Type(0))
require.Equal(t, vertex, vertex2)
}

// TestVertexTypeDecodeInvalidLength ensures that decoding a Vertex TLV
// with an invalid length (anything other than 33) fails with an error.
func TestVertexTypeDecodeInvalidLength(t *testing.T) {
t.Parallel()

vertex := makeVertex()

var extraData lnwire.ExtraOpaqueData
require.NoError(t, extraData.PackRecords(&vertex))

// Corrupt the TLV length field to simulate malformed input.
extraData[1] = VertexSize + 1

var out Vertex
_, err := extraData.ExtractRecords(&out)
require.Error(t, err)

extraData[1] = VertexSize - 1

_, err = extraData.ExtractRecords(&out)
require.Error(t, err)
}
2 changes: 1 addition & 1 deletion tlv/primitive.go
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ func EBytes33(w io.Writer, val interface{}, _ *[8]byte) error {
// DBytes33 is a Decoder for 33-byte arrays. An error is returned if val is not
// a *[33]byte.
func DBytes33(r io.Reader, val interface{}, _ *[8]byte, l uint64) error {
if b, ok := val.(*[33]byte); ok {
if b, ok := val.(*[33]byte); ok && l == 33 {
_, err := io.ReadFull(r, b[:])
return err
}
Expand Down
59 changes: 59 additions & 0 deletions tlv/primitive_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -253,3 +253,62 @@ func TestPrimitiveEncodings(t *testing.T) {
prim, prim2)
}
}

// TestPrimitiveWrongLength asserts that fixed-size primitive decoders fail
// with ErrTypeForDecoding when given an incorrect TLV length.
func TestPrimitiveWrongLength(t *testing.T) {
prim := primitive{
u8: 0x01,
u16: 0x0201,
u32: 0x02000001,
u64: 0x0200000000000001,
b32: [32]byte{0x02, 0x01},
b33: [33]byte{0x03, 0x01},
b64: [64]byte{0x02, 0x01},
pk: testPK,
boolean: true,
}

type item struct {
enc fieldEncoder
dec fieldDecoder
}

items := []item{
{fieldEncoder{&prim.u8, tlv.EUint8}, fieldDecoder{new(byte), tlv.DUint8, 1}},
{fieldEncoder{&prim.u16, tlv.EUint16}, fieldDecoder{new(uint16), tlv.DUint16, 2}},
{fieldEncoder{&prim.u32, tlv.EUint32}, fieldDecoder{new(uint32), tlv.DUint32, 4}},
{fieldEncoder{&prim.u64, tlv.EUint64}, fieldDecoder{new(uint64), tlv.DUint64, 8}},
{fieldEncoder{&prim.b32, tlv.EBytes32}, fieldDecoder{new([32]byte), tlv.DBytes32, 32}},
{fieldEncoder{&prim.b33, tlv.EBytes33}, fieldDecoder{new([33]byte), tlv.DBytes33, 33}},
{fieldEncoder{&prim.b64, tlv.EBytes64}, fieldDecoder{new([64]byte), tlv.DBytes64, 64}},
{fieldEncoder{&prim.pk, tlv.EPubKey}, fieldDecoder{new(*btcec.PublicKey), tlv.DPubKey, 33}},
{fieldEncoder{&prim.boolean, tlv.EBool}, fieldDecoder{new(bool), tlv.DBool, 1}},
}

for _, it := range items {
var buf [8]byte
var b bytes.Buffer
if err := it.enc.encoder(&b, it.enc.val, &buf); err != nil {
t.Fatalf("encode %T: %v", it.enc.val, err)
}
data := b.Bytes()

// Generate two wrong lengths: expected-1 (if >0) and expected+1.
wrongs := []uint64{it.dec.size + 1}
if it.dec.size > 0 {
wrongs = append(wrongs, it.dec.size-1)
}

for _, l := range wrongs {
r := bytes.NewReader(data)
if err := it.dec.decoder(r, it.dec.val, &buf, l); err == nil {
t.Fatalf("decoder %T accepted wrong length %d (expected %d)", it.dec.decoder, l, it.dec.size)
} else {
if _, ok := err.(tlv.ErrTypeForDecoding); !ok {
t.Fatalf("expected ErrTypeForDecoding, got %T: %v", err, err)
}
}
}
}
}
Loading