Skip to content

Commit 4e614ce

Browse files
committed
multi: decode zero-length onion message payloads
Since the onion message payload can be zero-length, we need to decode it correctly. This commit adds a boolean flag to the HopPayload Decode that tells whether the payload is an onion message payload or not. If it is, the payload is decoded as a tlv payload also if the first byte is 0x00. sphinx_test: Add zero-length payload om test
1 parent fcae597 commit 4e614ce

File tree

3 files changed

+132
-40
lines changed

3 files changed

+132
-40
lines changed

payload.go

Lines changed: 51 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,10 @@ type HopPayload struct {
4444
// HMAC is an HMAC computed over the entire per-hop payload that also
4545
// includes the higher-level (optional) associated data bytes.
4646
HMAC [HMACSize]byte
47+
48+
// TLVPayloadGuaranteed is set to true if the payload is guaranteed to
49+
// be a TLVPayload. E.g. in the case of an onion message.
50+
TLVPayloadGuaranteed bool
4751
}
4852

4953
// NewTLVHopPayload creates a new TLV encoded HopPayload. The payload will be
@@ -99,36 +103,13 @@ func (hp *HopPayload) Decode(r io.Reader) error {
99103
return err
100104
}
101105

102-
var (
103-
legacyPayload = isLegacyPayloadByte(peekByte[0])
104-
payloadSize uint16
105-
)
106-
107-
if legacyPayload {
108-
payloadSize = legacyPayloadSize()
109-
hp.Type = PayloadLegacy
110-
} else {
111-
payloadSize, err = tlvPayloadSize(bufReader)
112-
if err != nil {
113-
return err
114-
}
115-
116-
hp.Type = PayloadTLV
106+
// If the HopPayload is guaranteed to be a TLV payload, we can skip the
107+
// check for the legacy payload byte.
108+
if !hp.TLVPayloadGuaranteed && isLegacyPayloadByte(peekByte[0]) {
109+
return decodeLegacyHopPayload(hp, bufReader)
117110
}
118111

119-
// Now that we know the payload size, we'll create a new buffer to
120-
// read it out in full.
121-
//
122-
// TODO(roasbeef): can avoid all these copies
123-
hp.Payload = make([]byte, payloadSize)
124-
if _, err := io.ReadFull(bufReader, hp.Payload[:]); err != nil {
125-
return err
126-
}
127-
if _, err := io.ReadFull(bufReader, hp.HMAC[:]); err != nil {
128-
return err
129-
}
130-
131-
return nil
112+
return decodeTLVHopPayload(hp, bufReader)
132113
}
133114

134115
// HopData attempts to extract a set of forwarding instructions from the target
@@ -146,6 +127,42 @@ func (hp *HopPayload) HopData() (*HopData, error) {
146127
return nil, nil
147128
}
148129

130+
// readPayloadAndHMAC reads the payload and HMAC from the reader into the
131+
// HopPayload.
132+
func readPayloadAndHMAC(hp *HopPayload, r io.Reader, payloadSize uint16) error {
133+
// Now that we know the payload size, we'll create a new buffer to read
134+
// it out in full.
135+
hp.Payload = make([]byte, payloadSize)
136+
if _, err := io.ReadFull(r, hp.Payload[:]); err != nil {
137+
return err
138+
}
139+
if _, err := io.ReadFull(r, hp.HMAC[:]); err != nil {
140+
return err
141+
}
142+
143+
return nil
144+
}
145+
146+
// decodeTLVHopPayload decodes a TLV hop payload from the passed reader.
147+
func decodeTLVHopPayload(hp *HopPayload, r io.Reader) error {
148+
payloadSize, err := tlvPayloadSize(r)
149+
if err != nil {
150+
return err
151+
}
152+
153+
hp.Type = PayloadTLV
154+
155+
return readPayloadAndHMAC(hp, r, payloadSize)
156+
}
157+
158+
// decodeLegacyHopPayload decodes a legacy hop payload from the passed reader.
159+
func decodeLegacyHopPayload(hp *HopPayload, r io.Reader) error {
160+
payloadSize := legacyPayloadSize()
161+
hp.Type = PayloadLegacy
162+
163+
return readPayloadAndHMAC(hp, r, payloadSize)
164+
}
165+
149166
// tlvPayloadSize uses the passed reader to extract the payload length encoded
150167
// as a var-int.
151168
func tlvPayloadSize(r io.Reader) (uint16, error) {
@@ -314,8 +331,12 @@ func legacyNumBytes() int {
314331
return LegacyHopDataSize
315332
}
316333

317-
// isLegacyPayload returns true if the given byte is equal to the 0x00 byte
318-
// which indicates that the payload should be decoded as a legacy payload.
334+
// isLegacyPayloadByte determines if the first byte of a hop payload indicates
335+
// that it is a legacy payload. The first byte of a legacy payload will always
336+
// be 0x00, as this is the realm. For TLV payloads, the first byte is a
337+
// var-int encoding the length of the payload. A TLV stream can be empty, in
338+
// which case its length is 0, which is also encoded as a 0x00 byte. This
339+
// creates an ambiguity between a legacy payload and an empty TLV payload.
319340
func isLegacyPayloadByte(b byte) bool {
320341
return b == 0x00
321342
}

sphinx.go

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -510,7 +510,8 @@ func (r *Router) Stop() {
510510
// processOnionCfg is a set of config values that can be used to modify how an
511511
// onion is processed.
512512
type processOnionCfg struct {
513-
blindingPoint *btcec.PublicKey
513+
blindingPoint *btcec.PublicKey
514+
isOnionMessage bool
514515
}
515516

516517
// ProcessOnionOpt defines the signature of a function option that can be used
@@ -525,6 +526,14 @@ func WithBlindingPoint(point *btcec.PublicKey) ProcessOnionOpt {
525526
}
526527
}
527528

529+
// WithIsOnionMessage is a functional option that signals that the onion packet
530+
// being processed is from onion message.
531+
func WithIsOnionMessage() ProcessOnionOpt {
532+
return func(cfg *processOnionCfg) {
533+
cfg.isOnionMessage = true
534+
}
535+
}
536+
528537
// ProcessOnionPacket processes an incoming onion packet which has been forward
529538
// to the target Sphinx router. If the encoded ephemeral key isn't on the
530539
// target Elliptic Curve, then the packet is rejected. Similarly, if the
@@ -560,7 +569,9 @@ func (r *Router) ProcessOnionPacket(onionPkt *OnionPacket, assocData []byte,
560569
// Continue to optimistically process this packet, deferring replay
561570
// protection until the end to reduce the penalty of multiple IO
562571
// operations.
563-
packet, err := processOnionPacket(onionPkt, &sharedSecret, assocData)
572+
packet, err := processOnionPacket(
573+
onionPkt, &sharedSecret, assocData, cfg.isOnionMessage,
574+
)
564575
if err != nil {
565576
return nil, err
566577
}
@@ -594,7 +605,9 @@ func (r *Router) ReconstructOnionPacket(onionPkt *OnionPacket, assocData []byte,
594605
return nil, err
595606
}
596607

597-
return processOnionPacket(onionPkt, &sharedSecret, assocData)
608+
return processOnionPacket(
609+
onionPkt, &sharedSecret, assocData, cfg.isOnionMessage,
610+
)
598611
}
599612

600613
// DecryptBlindedHopData uses the router's private key to decrypt data encrypted
@@ -625,7 +638,8 @@ func (r *Router) OnionPublicKey() *btcec.PublicKey {
625638
// packet. This function returns the next inner onion packet layer, along with
626639
// the hop data extracted from the outer onion packet.
627640
func unwrapPacket(onionPkt *OnionPacket, sharedSecret *Hash256,
628-
assocData []byte) (*OnionPacket, *HopPayload, error) {
641+
assocData []byte, isOnionMessage bool) (*OnionPacket, *HopPayload,
642+
error) {
629643

630644
dhKey := onionPkt.EphemeralKey
631645
routeInfo := onionPkt.RoutingInfo
@@ -660,8 +674,9 @@ func unwrapPacket(onionPkt *OnionPacket, sharedSecret *Hash256,
660674
// With the MAC checked, and the payload decrypted, we can now parse
661675
// out the payload so we can derive the specified forwarding
662676
// instructions.
663-
var hopPayload HopPayload
664-
if err := hopPayload.Decode(bytes.NewReader(hopInfo[:])); err != nil {
677+
hopPayload := HopPayload{TLVPayloadGuaranteed: isOnionMessage}
678+
err := hopPayload.Decode(bytes.NewReader(hopInfo[:]))
679+
if err != nil {
665680
return nil, nil, err
666681
}
667682

@@ -683,7 +698,7 @@ func unwrapPacket(onionPkt *OnionPacket, sharedSecret *Hash256,
683698
// packets. The processed packets returned from this method should only be used
684699
// if the packet was not flagged as a replayed packet.
685700
func processOnionPacket(onionPkt *OnionPacket, sharedSecret *Hash256,
686-
assocData []byte) (*ProcessedPacket, error) {
701+
assocData []byte, isOnionMessage bool) (*ProcessedPacket, error) {
687702

688703
// First, we'll unwrap an initial layer of the onion packet. Typically,
689704
// we'll only have a single layer to unwrap, However, if the sender has
@@ -693,7 +708,7 @@ func processOnionPacket(onionPkt *OnionPacket, sharedSecret *Hash256,
693708
// they can properly check the HMAC and unwrap a layer for their
694709
// handoff hop.
695710
innerPkt, outerHopPayload, err := unwrapPacket(
696-
onionPkt, sharedSecret, assocData,
711+
onionPkt, sharedSecret, assocData, isOnionMessage,
697712
)
698713
if err != nil {
699714
return nil, err
@@ -703,7 +718,7 @@ func processOnionPacket(onionPkt *OnionPacket, sharedSecret *Hash256,
703718
// However if the uncovered 'nextMac' is all zeroes, then this
704719
// indicates that we're the final hop in the route.
705720
var action ProcessCode = MoreHops
706-
if bytes.Compare(zeroHMAC[:], outerHopPayload.HMAC[:]) == 0 {
721+
if bytes.Equal(zeroHMAC[:], outerHopPayload.HMAC[:]) {
707722
action = ExitNode
708723
}
709724

@@ -794,7 +809,9 @@ func (t *Tx) ProcessOnionPacket(seqNum uint16, onionPkt *OnionPacket,
794809
// Continue to optimistically process this packet, deferring replay
795810
// protection until the end to reduce the penalty of multiple IO
796811
// operations.
797-
packet, err := processOnionPacket(onionPkt, &sharedSecret, assocData)
812+
packet, err := processOnionPacket(
813+
onionPkt, &sharedSecret, assocData, cfg.isOnionMessage,
814+
)
798815
if err != nil {
799816
return err
800817
}

sphinx_test.go

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,60 @@ func TestTLVPayloadMessagePacket(t *testing.T) {
288288
hex.EncodeToString(finalPacket), hex.EncodeToString(b.Bytes()))
289289
}
290290

291+
// TestProcessOnionMessageZeroLengthPayload tests that we can properly process an
292+
// onion message that has a zero-length payload.
293+
func TestProcessOnionMessageZeroLengthPayload(t *testing.T) {
294+
t.Parallel()
295+
296+
// First, create a router that will be the destination of the onion
297+
// message.
298+
privKey, err := btcec.NewPrivateKey()
299+
require.NoError(t, err)
300+
301+
router := NewRouter(&PrivKeyECDH{privKey}, NewMemoryReplayLog())
302+
err = router.Start()
303+
require.NoError(t, err)
304+
defer router.Stop()
305+
306+
// Next, create a session key for the onion packet.
307+
sessionKey, err := btcec.NewPrivateKey()
308+
require.NoError(t, err)
309+
310+
// We'll create a simple one-hop path.
311+
path := &PaymentPath{
312+
{
313+
NodePub: *privKey.PubKey(),
314+
},
315+
}
316+
317+
// The hop payload will be an empty TLV payload.
318+
payload, err := NewTLVHopPayload(nil)
319+
require.NoError(t, err)
320+
path[0].HopPayload = payload
321+
322+
// Now, create the onion packet.
323+
onionPacket, err := NewOnionPacket(
324+
path, sessionKey, nil, DeterministicPacketFiller,
325+
)
326+
require.NoError(t, err)
327+
328+
// We'll now process the packet, making sure to indicate that this is
329+
// an onion message.
330+
processedPacket, err := router.ProcessOnionPacket(
331+
onionPacket, nil, 0, WithIsOnionMessage(),
332+
)
333+
require.NoError(t, err)
334+
335+
// The packet should be decoded as an exit node.
336+
require.EqualValues(t, ExitNode, processedPacket.Action)
337+
338+
// The payload should be of type TLV.
339+
require.Equal(t, PayloadTLV, processedPacket.Payload.Type)
340+
341+
// And the payload should be empty.
342+
require.Empty(t, processedPacket.Payload.Payload)
343+
}
344+
291345
func TestSphinxCorrectness(t *testing.T) {
292346
nodes, _, hopDatas, fwdMsg, err := newTestRoute(testLegacyRouteNumHops)
293347
if err != nil {

0 commit comments

Comments
 (0)