Skip to content

Commit 5d467fd

Browse files
committed
multi: decode zero-length onion message payloads
Since the onion message payload can be zero-length, we need to decode it correctly. This commit adds a boolean flag to the HopPayload Decode that tells whether the payload is an onion message payload or not. If it is, the payload is decoded as a tlv payload also if the first byte is 0x00.
1 parent 255c9b4 commit 5d467fd

File tree

2 files changed

+77
-39
lines changed

2 files changed

+77
-39
lines changed

payload.go

Lines changed: 51 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,9 @@ func (hp *HopPayload) Encode(w io.Writer) error {
8787
}
8888

8989
// Decode unpacks an encoded HopPayload from the passed reader into the target
90-
// HopPayload.
91-
func (hp *HopPayload) Decode(r io.Reader) error {
90+
// HopPayload. The isMessage boolean should be set to true if we're parsing a
91+
// payload that is known to be for an onion message.
92+
func (hp *HopPayload) Decode(r io.Reader, isMessage bool) error {
9293
bufReader := bufio.NewReader(r)
9394

9495
// In order to properly parse the payload, we'll need to check the
@@ -99,36 +100,16 @@ func (hp *HopPayload) Decode(r io.Reader) error {
99100
return err
100101
}
101102

102-
var (
103-
legacyPayload = isLegacyPayloadByte(peekByte[0])
104-
payloadSize uint16
105-
)
106-
107-
if legacyPayload {
108-
payloadSize = legacyPayloadSize()
109-
hp.Type = PayloadLegacy
110-
} else {
111-
payloadSize, err = tlvPayloadSize(bufReader)
112-
if err != nil {
113-
return err
114-
}
115-
116-
hp.Type = PayloadTLV
103+
// Per BOLT 7, onion messages MUST use the TLV format.
104+
if isMessage {
105+
return decodeTLVHopPayload(hp, bufReader)
117106
}
118107

119-
// Now that we know the payload size, we'll create a new buffer to
120-
// read it out in full.
121-
//
122-
// TODO(roasbeef): can avoid all these copies
123-
hp.Payload = make([]byte, payloadSize)
124-
if _, err := io.ReadFull(bufReader, hp.Payload[:]); err != nil {
125-
return err
126-
}
127-
if _, err := io.ReadFull(bufReader, hp.HMAC[:]); err != nil {
128-
return err
108+
if isLegacyPayloadByte(peekByte[0]) {
109+
return decodeLegacyHopPayload(hp, bufReader)
129110
}
130111

131-
return nil
112+
return decodeTLVHopPayload(hp, bufReader)
132113
}
133114

134115
// HopData attempts to extract a set of forwarding instructions from the target
@@ -146,6 +127,42 @@ func (hp *HopPayload) HopData() (*HopData, error) {
146127
return nil, nil
147128
}
148129

130+
// readPayloadAndHMAC reads the payload and HMAC from the reader into the
131+
// HopPayload.
132+
func readPayloadAndHMAC(hp *HopPayload, r io.Reader, payloadSize uint16) error {
133+
// Now that we know the payload size, we'll create a new buffer to read
134+
// it out in full.
135+
hp.Payload = make([]byte, payloadSize)
136+
if _, err := io.ReadFull(r, hp.Payload[:]); err != nil {
137+
return err
138+
}
139+
if _, err := io.ReadFull(r, hp.HMAC[:]); err != nil {
140+
return err
141+
}
142+
143+
return nil
144+
}
145+
146+
// decodeTLVHopPayload decodes a TLV hop payload from the passed reader.
147+
func decodeTLVHopPayload(hp *HopPayload, r io.Reader) error {
148+
payloadSize, err := tlvPayloadSize(r)
149+
if err != nil {
150+
return err
151+
}
152+
153+
hp.Type = PayloadTLV
154+
155+
return readPayloadAndHMAC(hp, r, payloadSize)
156+
}
157+
158+
// decodeLegacyHopPayload decodes a legacy hop payload from the passed reader.
159+
func decodeLegacyHopPayload(hp *HopPayload, r io.Reader) error {
160+
payloadSize := legacyPayloadSize()
161+
hp.Type = PayloadLegacy
162+
163+
return readPayloadAndHMAC(hp, r, payloadSize)
164+
}
165+
149166
// tlvPayloadSize uses the passed reader to extract the payload length encoded
150167
// as a var-int.
151168
func tlvPayloadSize(r io.Reader) (uint16, error) {
@@ -314,8 +331,12 @@ func legacyNumBytes() int {
314331
return LegacyHopDataSize
315332
}
316333

317-
// isLegacyPayload returns true if the given byte is equal to the 0x00 byte
318-
// which indicates that the payload should be decoded as a legacy payload.
334+
// isLegacyPayloadByte determines if the first byte of a hop payload indicates
335+
// that it is a legacy payload. The first byte of a legacy payload will always
336+
// be 0x00, as this is the realm. For TLV payloads, the first byte is a
337+
// var-int encoding the length of the payload. A TLV stream can be empty, in
338+
// which case its length is 0, which is also encoded as a 0x00 byte. This
339+
// creates an ambiguity between a legacy payload and an empty TLV payload.
319340
func isLegacyPayloadByte(b byte) bool {
320341
return b == 0x00
321342
}

sphinx.go

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -510,7 +510,8 @@ func (r *Router) Stop() {
510510
// processOnionCfg is a set of config values that can be used to modify how an
511511
// onion is processed.
512512
type processOnionCfg struct {
513-
blindingPoint *btcec.PublicKey
513+
blindingPoint *btcec.PublicKey
514+
isOnionMessage bool
514515
}
515516

516517
// ProcessOnionOpt defines the signature of a function option that can be used
@@ -525,6 +526,14 @@ func WithBlindingPoint(point *btcec.PublicKey) ProcessOnionOpt {
525526
}
526527
}
527528

529+
// WithIsOnionMessage is a functional option that signals that the onion packet
530+
// being processed is an onion message.
531+
func WithIsOnionMessage() ProcessOnionOpt {
532+
return func(cfg *processOnionCfg) {
533+
cfg.isOnionMessage = true
534+
}
535+
}
536+
528537
// ProcessOnionPacket processes an incoming onion packet which has been forward
529538
// to the target Sphinx router. If the encoded ephemeral key isn't on the
530539
// target Elliptic Curve, then the packet is rejected. Similarly, if the
@@ -560,7 +569,9 @@ func (r *Router) ProcessOnionPacket(onionPkt *OnionPacket, assocData []byte,
560569
// Continue to optimistically process this packet, deferring replay
561570
// protection until the end to reduce the penalty of multiple IO
562571
// operations.
563-
packet, err := processOnionPacket(onionPkt, &sharedSecret, assocData)
572+
packet, err := processOnionPacket(
573+
onionPkt, &sharedSecret, assocData, cfg.isOnionMessage,
574+
)
564575
if err != nil {
565576
return nil, err
566577
}
@@ -594,7 +605,9 @@ func (r *Router) ReconstructOnionPacket(onionPkt *OnionPacket, assocData []byte,
594605
return nil, err
595606
}
596607

597-
return processOnionPacket(onionPkt, &sharedSecret, assocData)
608+
return processOnionPacket(
609+
onionPkt, &sharedSecret, assocData, cfg.isOnionMessage,
610+
)
598611
}
599612

600613
// DecryptBlindedHopData uses the router's private key to decrypt data encrypted
@@ -625,7 +638,8 @@ func (r *Router) OnionPublicKey() *btcec.PublicKey {
625638
// packet. This function returns the next inner onion packet layer, along with
626639
// the hop data extracted from the outer onion packet.
627640
func unwrapPacket(onionPkt *OnionPacket, sharedSecret *Hash256,
628-
assocData []byte) (*OnionPacket, *HopPayload, error) {
641+
assocData []byte, isOnionMessage bool) (*OnionPacket, *HopPayload,
642+
error) {
629643

630644
dhKey := onionPkt.EphemeralKey
631645
routeInfo := onionPkt.RoutingInfo
@@ -661,7 +675,8 @@ func unwrapPacket(onionPkt *OnionPacket, sharedSecret *Hash256,
661675
// out the payload so we can derive the specified forwarding
662676
// instructions.
663677
var hopPayload HopPayload
664-
if err := hopPayload.Decode(bytes.NewReader(hopInfo[:])); err != nil {
678+
err := hopPayload.Decode(bytes.NewReader(hopInfo[:]), isOnionMessage)
679+
if err != nil {
665680
return nil, nil, err
666681
}
667682

@@ -683,7 +698,7 @@ func unwrapPacket(onionPkt *OnionPacket, sharedSecret *Hash256,
683698
// packets. The processed packets returned from this method should only be used
684699
// if the packet was not flagged as a replayed packet.
685700
func processOnionPacket(onionPkt *OnionPacket, sharedSecret *Hash256,
686-
assocData []byte) (*ProcessedPacket, error) {
701+
assocData []byte, isOnionMessage bool) (*ProcessedPacket, error) {
687702

688703
// First, we'll unwrap an initial layer of the onion packet. Typically,
689704
// we'll only have a single layer to unwrap, However, if the sender has
@@ -693,7 +708,7 @@ func processOnionPacket(onionPkt *OnionPacket, sharedSecret *Hash256,
693708
// they can properly check the HMAC and unwrap a layer for their
694709
// handoff hop.
695710
innerPkt, outerHopPayload, err := unwrapPacket(
696-
onionPkt, sharedSecret, assocData,
711+
onionPkt, sharedSecret, assocData, isOnionMessage,
697712
)
698713
if err != nil {
699714
return nil, err
@@ -703,7 +718,7 @@ func processOnionPacket(onionPkt *OnionPacket, sharedSecret *Hash256,
703718
// However if the uncovered 'nextMac' is all zeroes, then this
704719
// indicates that we're the final hop in the route.
705720
var action ProcessCode = MoreHops
706-
if bytes.Compare(zeroHMAC[:], outerHopPayload.HMAC[:]) == 0 {
721+
if bytes.Equal(zeroHMAC[:], outerHopPayload.HMAC[:]) {
707722
action = ExitNode
708723
}
709724

@@ -794,7 +809,9 @@ func (t *Tx) ProcessOnionPacket(seqNum uint16, onionPkt *OnionPacket,
794809
// Continue to optimistically process this packet, deferring replay
795810
// protection until the end to reduce the penalty of multiple IO
796811
// operations.
797-
packet, err := processOnionPacket(onionPkt, &sharedSecret, assocData)
812+
packet, err := processOnionPacket(
813+
onionPkt, &sharedSecret, assocData, cfg.isOnionMessage,
814+
)
798815
if err != nil {
799816
return err
800817
}

0 commit comments

Comments
 (0)