diff --git a/lnwire/musig2.go b/lnwire/musig2.go index cfc753f820b..10dbc27af39 100644 --- a/lnwire/musig2.go +++ b/lnwire/musig2.go @@ -51,7 +51,7 @@ func nonceTypeEncoder(w io.Writer, val interface{}, _ *[8]byte) error { func nonceTypeDecoder(r io.Reader, val interface{}, _ *[8]byte, l uint64) error { - if v, ok := val.(*Musig2Nonce); ok { + if v, ok := val.(*Musig2Nonce); ok && l == musig2.PubNonceSize { _, err := io.ReadFull(r, v[:]) return err } diff --git a/lnwire/musig2_test.go b/lnwire/musig2_test.go new file mode 100644 index 00000000000..179a8163e33 --- /dev/null +++ b/lnwire/musig2_test.go @@ -0,0 +1,58 @@ +package lnwire + +import ( + "testing" + + "github.com/btcsuite/btcd/btcec/v2/schnorr/musig2" + "github.com/stretchr/testify/require" +) + +func makeNonce() Musig2Nonce { + var n Musig2Nonce + for i := range musig2.PubNonceSize { + n[i] = byte(i) + } + + return n +} + +// TestMusig2NonceEncodeDecode tests that we're able to properly encode and +// decode Musig2Nonce within TLV streams. +func TestMusig2NonceEncodeDecode(t *testing.T) { + t.Parallel() + + nonce := makeNonce() + + var extraData ExtraOpaqueData + require.NoError(t, extraData.PackRecords(&nonce)) + + var extractedNonce Musig2Nonce + _, err := extraData.ExtractRecords(&extractedNonce) + require.NoError(t, err) + + require.Equal(t, nonce, extractedNonce) +} + +// TestMusig2NonceTypeDecodeInvalidLength ensures that decoding a Musig2Nonce +// TLV with an invalid length (anything other than 66 bytes) fails with an +// error. +func TestMusig2NonceTypeDecodeInvalidLength(t *testing.T) { + t.Parallel() + + nonce := makeNonce() + + var extraData ExtraOpaqueData + require.NoError(t, extraData.PackRecords(&nonce)) + + // Corrupt the TLV length field to simulate malformed input. + extraData[1] = musig2.PubNonceSize + 1 + + var out Musig2Nonce + _, err := extraData.ExtractRecords(&out) + require.Error(t, err) + + extraData[1] = musig2.PubNonceSize - 1 + + _, err = extraData.ExtractRecords(&out) + require.Error(t, err) +} diff --git a/lnwire/short_channel_id.go b/lnwire/short_channel_id.go index 73e37ab96fd..e26575001a1 100644 --- a/lnwire/short_channel_id.go +++ b/lnwire/short_channel_id.go @@ -92,7 +92,8 @@ func DShortChannelID(r io.Reader, val interface{}, buf *[8]byte, if v, ok := val.(*ShortChannelID); ok { var scid uint64 - err := tlv.DUint64(r, &scid, buf, 8) + // tlv.DUint64 forces the length to be 8 bytes. + err := tlv.DUint64(r, &scid, buf, l) if err != nil { return err } diff --git a/lnwire/short_channel_id_test.go b/lnwire/short_channel_id_test.go index 2916f20d17c..efc0cba40f7 100644 --- a/lnwire/short_channel_id_test.go +++ b/lnwire/short_channel_id_test.go @@ -62,3 +62,28 @@ func TestScidTypeEncodeDecode(t *testing.T) { require.Contains(t, tlvs, AliasScidRecordType) require.Equal(t, aliasScid, aliasScid2) } + +// TestScidTypeDecodeInvalidLength ensures that decoding a ShortChannelID TLV +// with an invalid length (anything other than 8 bytes) fails with an error. +func TestScidTypeDecodeInvalidLength(t *testing.T) { + t.Parallel() + + aliasScid := ShortChannelID{ + BlockHeight: 1, TxIndex: 1, TxPosition: 1, + } + + var extraData ExtraOpaqueData + require.NoError(t, extraData.PackRecords(&aliasScid)) + + // Corrupt the TLV length field to simulate malformed input. + extraData[1] = 8 + 1 + + var out ShortChannelID + _, err := extraData.ExtractRecords(&out) + require.Error(t, err) + + extraData[1] = 8 - 1 + + _, err = extraData.ExtractRecords(&out) + require.Error(t, err) +} diff --git a/lnwire/typed_fee.go b/lnwire/typed_fee.go index 6b139f196f7..95b89d29720 100644 --- a/lnwire/typed_fee.go +++ b/lnwire/typed_fee.go @@ -41,7 +41,7 @@ func feeEncoder(w io.Writer, val interface{}, buf *[8]byte) error { // feeDecoder is a custom TLV decoder for the fee record. func feeDecoder(r io.Reader, val interface{}, buf *[8]byte, l uint64) error { v, ok := val.(*Fee) - if !ok { + if !ok || l != 8 { return tlv.NewTypeForDecodingErr(val, "lnwire.Fee", l, 8) } diff --git a/lnwire/typed_fee_test.go b/lnwire/typed_fee_test.go index a54b765ea61..f9d41e58f38 100644 --- a/lnwire/typed_fee_test.go +++ b/lnwire/typed_fee_test.go @@ -38,3 +38,28 @@ func testTypedFee(t *testing.T, fee Fee) { //nolint: thelper require.Equal(t, fee, extractedFee) } + +// TestTypedFeeTypeDecodeInvalidLength ensures that decoding a Fee TLV +// with an invalid length (anything other than 8 bytes) fails with an error. +func TestTypedFeeTypeDecodeInvalidLength(t *testing.T) { + t.Parallel() + + fee := Fee{ + BaseFee: 1, FeeRate: 1, + } + + var extraData ExtraOpaqueData + require.NoError(t, extraData.PackRecords(&fee)) + + // Corrupt the TLV length field to simulate malformed input. + extraData[3] = 8 + 1 + + var out Fee + _, err := extraData.ExtractRecords(&out) + require.Error(t, err) + + extraData[3] = 8 - 1 + + _, err = extraData.ExtractRecords(&out) + require.Error(t, err) +} diff --git a/routing/route/route.go b/routing/route/route.go index b3e91a6f4a2..fc5ac4f8eb0 100644 --- a/routing/route/route.go +++ b/routing/route/route.go @@ -112,7 +112,7 @@ func encodeVertex(w io.Writer, val interface{}, _ *[8]byte) error { } func decodeVertex(r io.Reader, val interface{}, _ *[8]byte, l uint64) error { - if b, ok := val.(*Vertex); ok { + if b, ok := val.(*Vertex); ok && l == VertexSize { _, err := io.ReadFull(r, b[:]) return err } diff --git a/routing/route/route_test.go b/routing/route/route_test.go index 99594833fcd..960f30a55b2 100644 --- a/routing/route/route_test.go +++ b/routing/route/route_test.go @@ -8,6 +8,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/record" + "github.com/lightningnetwork/lnd/tlv" "github.com/stretchr/testify/require" ) @@ -430,3 +431,53 @@ func TestBlindedHopFee(t *testing.T) { require.Equal(t, lnwire.MilliSatoshi(0), route.HopFee(3)) require.Equal(t, lnwire.MilliSatoshi(0), route.HopFee(4)) } + +func makeVertex() Vertex { + var v Vertex + for i := range VertexSize { + v[i] = byte(i) + } + + return v +} + +// TestVertexTLVEncodeDecode tests that we're able to properly encode and decode +// Vertex within TLV streams. +func TestVertexTLVEncodeDecode(t *testing.T) { + t.Parallel() + + vertex := makeVertex() + + var extraData lnwire.ExtraOpaqueData + require.NoError(t, extraData.PackRecords(&vertex)) + + var vertex2 Vertex + tlvs, err := extraData.ExtractRecords(&vertex2) + require.NoError(t, err) + + require.Contains(t, tlvs, tlv.Type(0)) + require.Equal(t, vertex, vertex2) +} + +// TestVertexTypeDecodeInvalidLength ensures that decoding a Vertex TLV +// with an invalid length (anything other than 33) fails with an error. +func TestVertexTypeDecodeInvalidLength(t *testing.T) { + t.Parallel() + + vertex := makeVertex() + + var extraData lnwire.ExtraOpaqueData + require.NoError(t, extraData.PackRecords(&vertex)) + + // Corrupt the TLV length field to simulate malformed input. + extraData[1] = VertexSize + 1 + + var out Vertex + _, err := extraData.ExtractRecords(&out) + require.Error(t, err) + + extraData[1] = VertexSize - 1 + + _, err = extraData.ExtractRecords(&out) + require.Error(t, err) +} diff --git a/tlv/primitive.go b/tlv/primitive.go index fb8bb70c3b4..969d590331c 100644 --- a/tlv/primitive.go +++ b/tlv/primitive.go @@ -257,7 +257,7 @@ func EBytes33(w io.Writer, val interface{}, _ *[8]byte) error { // DBytes33 is a Decoder for 33-byte arrays. An error is returned if val is not // a *[33]byte. func DBytes33(r io.Reader, val interface{}, _ *[8]byte, l uint64) error { - if b, ok := val.(*[33]byte); ok { + if b, ok := val.(*[33]byte); ok && l == 33 { _, err := io.ReadFull(r, b[:]) return err } diff --git a/tlv/primitive_test.go b/tlv/primitive_test.go index ba84320e613..3323ad27750 100644 --- a/tlv/primitive_test.go +++ b/tlv/primitive_test.go @@ -253,3 +253,62 @@ func TestPrimitiveEncodings(t *testing.T) { prim, prim2) } } + +// TestPrimitiveWrongLength asserts that fixed-size primitive decoders fail +// with ErrTypeForDecoding when given an incorrect TLV length. +func TestPrimitiveWrongLength(t *testing.T) { + prim := primitive{ + u8: 0x01, + u16: 0x0201, + u32: 0x02000001, + u64: 0x0200000000000001, + b32: [32]byte{0x02, 0x01}, + b33: [33]byte{0x03, 0x01}, + b64: [64]byte{0x02, 0x01}, + pk: testPK, + boolean: true, + } + + type item struct { + enc fieldEncoder + dec fieldDecoder + } + + items := []item{ + {fieldEncoder{&prim.u8, tlv.EUint8}, fieldDecoder{new(byte), tlv.DUint8, 1}}, + {fieldEncoder{&prim.u16, tlv.EUint16}, fieldDecoder{new(uint16), tlv.DUint16, 2}}, + {fieldEncoder{&prim.u32, tlv.EUint32}, fieldDecoder{new(uint32), tlv.DUint32, 4}}, + {fieldEncoder{&prim.u64, tlv.EUint64}, fieldDecoder{new(uint64), tlv.DUint64, 8}}, + {fieldEncoder{&prim.b32, tlv.EBytes32}, fieldDecoder{new([32]byte), tlv.DBytes32, 32}}, + {fieldEncoder{&prim.b33, tlv.EBytes33}, fieldDecoder{new([33]byte), tlv.DBytes33, 33}}, + {fieldEncoder{&prim.b64, tlv.EBytes64}, fieldDecoder{new([64]byte), tlv.DBytes64, 64}}, + {fieldEncoder{&prim.pk, tlv.EPubKey}, fieldDecoder{new(*btcec.PublicKey), tlv.DPubKey, 33}}, + {fieldEncoder{&prim.boolean, tlv.EBool}, fieldDecoder{new(bool), tlv.DBool, 1}}, + } + + for _, it := range items { + var buf [8]byte + var b bytes.Buffer + if err := it.enc.encoder(&b, it.enc.val, &buf); err != nil { + t.Fatalf("encode %T: %v", it.enc.val, err) + } + data := b.Bytes() + + // Generate two wrong lengths: expected-1 (if >0) and expected+1. + wrongs := []uint64{it.dec.size + 1} + if it.dec.size > 0 { + wrongs = append(wrongs, it.dec.size-1) + } + + for _, l := range wrongs { + r := bytes.NewReader(data) + if err := it.dec.decoder(r, it.dec.val, &buf, l); err == nil { + t.Fatalf("decoder %T accepted wrong length %d (expected %d)", it.dec.decoder, l, it.dec.size) + } else { + if _, ok := err.(tlv.ErrTypeForDecoding); !ok { + t.Fatalf("expected ErrTypeForDecoding, got %T: %v", err, err) + } + } + } + } +}