Skip to content

Commit 1e54913

Browse files
committed
test vector
1 parent 3ffbf6d commit 1e54913

File tree

7 files changed

+70
-25
lines changed

7 files changed

+70
-25
lines changed

eclair-core/src/main/scala/fr/acinq/eclair/crypto/Sphinx.scala

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ import fr.acinq.bitcoin.scalacompat.Crypto.{PrivateKey, PublicKey}
2020
import fr.acinq.bitcoin.scalacompat.{ByteVector32, Crypto}
2121
import fr.acinq.eclair.wire.protocol._
2222
import grizzled.slf4j.Logging
23+
import scodec.Attempt
2324
import scodec.bits.ByteVector
24-
import scodec.{Attempt, Codec}
2525

2626
import scala.annotation.tailrec
2727
import scala.collection.mutable.ArrayBuffer
@@ -343,14 +343,14 @@ object Sphinx extends Logging {
343343
private val payloadAndPadLength = 256
344344
private val hopPayloadLength = 9
345345
private val maxNumHop = 27
346-
private val codec: Codec[FatError] = fatErrorCodec(payloadAndPadLength, hopPayloadLength, maxNumHop)
346+
private val totalLength = 12599
347347

348348
def create(sharedSecret: ByteVector32, failure: FailureMessage): ByteVector = {
349349
val failurePayload = FailureMessageCodecs.failureOnionPayload(payloadAndPadLength).encode(failure).require.toByteVector
350350
val hopPayload = HopPayload(ErrorSource, 0 millis)
351351
val zeroPayloads = Seq.fill(maxNumHop)(ByteVector.fill(hopPayloadLength)(0))
352352
val zeroHmacs = (maxNumHop.to(1, -1)).map(Seq.fill(_)(ByteVector32.Zeroes))
353-
val plainError = codec.encode(FatError(failurePayload, zeroPayloads, zeroHmacs)).require.bytes
353+
val plainError = fatErrorCodec(totalLength, hopPayloadLength, maxNumHop).encode(FatError(failurePayload, zeroPayloads, zeroHmacs)).require.bytes
354354
wrap(plainError, sharedSecret, hopPayload).get
355355
}
356356

@@ -366,32 +366,32 @@ object Sphinx extends Logging {
366366

367367
def wrap(errorPacket: ByteVector, sharedSecret: ByteVector32, hopPayload: HopPayload): Try[ByteVector] = Try {
368368
val um = generateKey("um", sharedSecret)
369-
val error = codec.decode(errorPacket.bits).require.value
369+
val error = fatErrorCodec(errorPacket.length.toInt, hopPayloadLength, maxNumHop).decode(errorPacket.bits).require.value
370370
val hopPayloads = hopPayloadCodec.encode(hopPayload).require.bytes +: error.hopPayloads.dropRight(1)
371-
val hmacs = computeHmacs(Hmac256(um), error.failurePayload, hopPayloads, error.hmacs, 0) +: error.hmacs.dropRight(1).map(_.drop(1))
372-
val newError = codec.encode(FatError(error.failurePayload, hopPayloads, hmacs)).require.bytes
371+
val hmacs = computeHmacs(Hmac256(um), error.failurePayload, hopPayloads, error.hmacs.map(_.drop(1)), 0) +: error.hmacs.dropRight(1).map(_.drop(1))
372+
val newError = fatErrorCodec(errorPacket.length.toInt, hopPayloadLength, maxNumHop).encode(FatError(error.failurePayload, hopPayloads, hmacs)).require.bytes
373373
val key = generateKey("ammag", sharedSecret)
374374
val stream = generateStream(key, newError.length.toInt)
375375
newError xor stream
376376
}
377377

378-
private def unwrap(errorPacket: ByteVector, sharedSecret: ByteVector32, minNumHop: Int): Try[(ByteVector, HopPayload)] = Try {
378+
def unwrap(errorPacket: ByteVector, sharedSecret: ByteVector32, minNumHop: Int): Try[(ByteVector, HopPayload)] = Try {
379379
val key = generateKey("ammag", sharedSecret)
380380
val stream = generateStream(key, errorPacket.length.toInt)
381-
val error = codec.decode((errorPacket xor stream).bits).require.value
381+
val error = fatErrorCodec(errorPacket.length.toInt, hopPayloadLength, maxNumHop).decode((errorPacket xor stream).bits).require.value
382382
val um = generateKey("um", sharedSecret)
383383
val shiftedHmacs = error.hmacs.tail.map(ByteVector32.Zeroes +: _) :+ Seq(ByteVector32.Zeroes)
384-
val hmacs = computeHmacs(Hmac256(um), error.failurePayload, error.hopPayloads, shiftedHmacs, minNumHop)
384+
val hmacs = computeHmacs(Hmac256(um), error.failurePayload, error.hopPayloads, error.hmacs.tail, minNumHop)
385385
require(hmacs == error.hmacs.head.drop(minNumHop), "Invalid HMAC")
386386
val shiftedHopPayloads = error.hopPayloads.tail :+ ByteVector.fill(hopPayloadLength)(0)
387387
val unwrapedError = FatError(error.failurePayload, shiftedHopPayloads, shiftedHmacs)
388-
(codec.encode(unwrapedError).require.bytes,
388+
(fatErrorCodec(errorPacket.length.toInt, hopPayloadLength, maxNumHop).encode(unwrapedError).require.bytes,
389389
hopPayloadCodec.decode(error.hopPayloads.head.bits).require.value)
390390
}
391391

392392
def decrypt(errorPacket: ByteVector, sharedSecrets: Seq[(ByteVector32, PublicKey)]): Either[InvalidFatErrorPacket, DecryptedFailurePacket] = {
393393
var packet = errorPacket
394-
var minNumHop = 1
394+
var minNumHop = 0
395395
val hopPayloads = ArrayBuffer.empty[(PublicKey, HopPayload)]
396396
for ((sharedSecret, nodeId) <- sharedSecrets) {
397397
unwrap(packet, sharedSecret, minNumHop) match {
@@ -403,7 +403,7 @@ object Sphinx extends Logging {
403403
minNumHop += 1
404404
hopPayloads += ((nodeId, hopPayload))
405405
case FatError.ErrorSource =>
406-
val failurePayload = codec.decode(unwrapedPacket.bits).require.value.failurePayload
406+
val failurePayload = fatErrorCodec(errorPacket.length.toInt, hopPayloadLength, maxNumHop).decode(unwrapedPacket.bits).require.value.failurePayload
407407
FailureMessageCodecs.failureOnionPayload(payloadAndPadLength).decode(failurePayload.bits) match {
408408
case Attempt.Successful(failureMessage) =>
409409
return Right(DecryptedFailurePacket(nodeId, failureMessage.value))

eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/FatError.scala

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,9 @@ object FatError {
5252
.xmap(pair => pair._1 +: pair._2, seq => (seq.head, seq.tail))
5353
}
5454

55-
def fatErrorCodec(payloadAndPadLength: Int = 256, hopPayloadLength: Int = 9, maxHop: Int = 27): Codec[FatError] = (
56-
("failure_payload" | bytes(payloadAndPadLength + 4)) ::
57-
("hop_payloads" | listOfN(provide(maxHop), bytes(hopPayloadLength)).xmap[Seq[ByteVector]](_.toSeq, _.toList)) ::
58-
("hmacs" | hmacsCodec(maxHop))).as[FatError].complete
55+
def fatErrorCodec(totalLength: Int, hopPayloadLength: Int, maxNumHop: Int): Codec[FatError] = {
56+
val metadataLength = maxNumHop * hopPayloadLength + (maxNumHop * (maxNumHop + 1)) / 2 * 32
57+
(("failure_payload" | bytes(totalLength - metadataLength)) ::
58+
("hop_payloads" | listOfN(provide(maxNumHop), bytes(hopPayloadLength)).xmap[Seq[ByteVector]](_.toSeq, _.toList)) ::
59+
("hmacs" | hmacsCodec(maxNumHop))).as[FatError].complete}
5960
}

eclair-core/src/test/resources/fat_error.json

Lines changed: 25 additions & 0 deletions
Large diffs are not rendered by default.

eclair-core/src/test/scala/fr/acinq/eclair/crypto/SphinxSpec.scala

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,14 @@ import fr.acinq.eclair.crypto.Sphinx.RouteBlinding.{BlindedRoute, BlindedRouteDe
2222
import fr.acinq.eclair.wire.protocol
2323
import fr.acinq.eclair.wire.protocol._
2424
import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, MilliSatoshiLong, ShortChannelId, UInt64, randomBytes, randomKey}
25+
import org.json4s.JsonAST._
26+
import org.json4s.jackson.JsonMethods
2527
import org.scalatest.funsuite.AnyFunSuite
2628
import scodec.bits._
2729

30+
import java.io.File
2831
import scala.concurrent.duration.DurationInt
32+
import scala.io.Source
2933
import scala.util.Success
3034

3135
/**
@@ -444,6 +448,21 @@ class SphinxSpec extends AnyFunSuite {
444448
assert(decryptionError == expected)
445449
}
446450

451+
test("fat error test vector") {
452+
val src = Source.fromFile(new File(getClass.getResource(s"/fat_error.json").getFile))
453+
try {
454+
val testVector = JsonMethods.parse(src.mkString).asInstanceOf[JObject].values
455+
val hops = testVector("hops").asInstanceOf[List[Map[String, String]]]
456+
val sharedSecrets = hops.map(hop => ByteVector32(ByteVector.fromValidHex(hop("sharedSecret"))))
457+
val encryptedMessages = hops.map(hop => ByteVector.fromValidHex(hop("encryptedMessage")))
458+
val nodeIds = (1 to 5).map(_ => randomKey().publicKey)
459+
//println(FatErrorPacket.unwrap(encryptedMessages(0), sharedSecrets(0), 0))
460+
//println(FatErrorPacket.decrypt(encryptedMessages.head, sharedSecrets.zip(nodeIds)))
461+
} finally {
462+
src.close()
463+
}
464+
}
465+
447466
test("create blinded route (reference test vector)") {
448467
val alice = PrivateKey(hex"4141414141414141414141414141414141414141414141414141414141414141")
449468
val bob = PrivateKey(hex"4242424242424242424242424242424242424242424242424242424242424242")

eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartPaymentLifecycleSpec.scala

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -378,7 +378,7 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS
378378
val failures = Seq(
379379
LocalFailure(finalAmount, Nil, ChannelUnavailable(randomBytes32())),
380380
RemoteFailure(finalAmount, Nil, Sphinx.DecryptedFailurePacket(b, FeeInsufficient(100 msat, makeChannelUpdate(ShortChannelId(2), 15 msat, 150, CltvExpiryDelta(48))))),
381-
UnreadableRemoteFailure(finalAmount, Nil)
381+
UnreadableRemoteFailure(finalAmount, Nil, ???)
382382
)
383383
val extraEdges1 = Seq(
384384
BasicEdge(a, b, ShortChannelId(1), 10 msat, 0, CltvExpiryDelta(12)), BasicEdge(b, c, ShortChannelId(2), 15 msat, 150, CltvExpiryDelta(48)),
@@ -412,14 +412,14 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS
412412
childPayFsm.expectMsgType[SendPaymentToRoute]
413413

414414
val (failedId1, failedRoute1) = payFsm.stateData.asInstanceOf[PaymentProgress].pending.head
415-
childPayFsm.send(payFsm, PaymentFailed(failedId1, paymentHash, Seq(UnreadableRemoteFailure(failedRoute1.amount, failedRoute1.hops))))
415+
childPayFsm.send(payFsm, PaymentFailed(failedId1, paymentHash, Seq(UnreadableRemoteFailure(failedRoute1.amount, failedRoute1.hops, ???))))
416416
router.expectMsgType[RouteRequest]
417417
router.send(payFsm, RouteResponse(Seq(Route(500000 msat, hop_ad :: hop_de :: Nil))))
418418
childPayFsm.expectMsgType[SendPaymentToRoute]
419419

420420
assert(!payFsm.stateData.asInstanceOf[PaymentProgress].pending.contains(failedId1))
421421
val (failedId2, failedRoute2) = payFsm.stateData.asInstanceOf[PaymentProgress].pending.head
422-
val result = abortAfterFailure(f, PaymentFailed(failedId2, paymentHash, Seq(UnreadableRemoteFailure(failedRoute2.amount, failedRoute2.hops))))
422+
val result = abortAfterFailure(f, PaymentFailed(failedId2, paymentHash, Seq(UnreadableRemoteFailure(failedRoute2.amount, failedRoute2.hops, ???))))
423423
assert(result.failures.length >= 3)
424424
assert(result.failures.contains(LocalFailure(finalAmount, Nil, RetryExhausted)))
425425

@@ -508,7 +508,7 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS
508508
childPayFsm.expectMsgType[SendPaymentToRoute]
509509

510510
val (failedId1, failedRoute1) :: (failedId2, failedRoute2) :: Nil = payFsm.stateData.asInstanceOf[PaymentProgress].pending.toSeq
511-
childPayFsm.send(payFsm, PaymentFailed(failedId1, paymentHash, Seq(UnreadableRemoteFailure(failedRoute1.amount, failedRoute1.hops))))
511+
childPayFsm.send(payFsm, PaymentFailed(failedId1, paymentHash, Seq(UnreadableRemoteFailure(failedRoute1.amount, failedRoute1.hops, ???))))
512512
router.expectMsgType[RouteRequest]
513513

514514
val result = abortAfterFailure(f, PaymentFailed(failedId2, paymentHash, Seq(RemoteFailure(failedRoute2.amount, failedRoute2.hops, Sphinx.DecryptedFailurePacket(e, PaymentTimeout)))))
@@ -526,7 +526,7 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS
526526
childPayFsm.expectMsgType[SendPaymentToRoute]
527527

528528
val (failedId, failedRoute) :: (successId, successRoute) :: Nil = payFsm.stateData.asInstanceOf[PaymentProgress].pending.toSeq
529-
childPayFsm.send(payFsm, PaymentFailed(failedId, paymentHash, Seq(UnreadableRemoteFailure(failedRoute.amount, failedRoute.hops))))
529+
childPayFsm.send(payFsm, PaymentFailed(failedId, paymentHash, Seq(UnreadableRemoteFailure(failedRoute.amount, failedRoute.hops, ???))))
530530
router.expectMsgType[RouteRequest]
531531

532532
val result = fulfillPendingPayments(f, 1)

eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec {
327327
sender.send(paymentFSM, addCompleted(HtlcResult.RemoteFail(UpdateFailHtlc(ByteVector32.Zeroes, 0, defaultPaymentHash)))) // unparsable message
328328

329329
// we allow 2 tries, so we send a 2nd request to the router
330-
assert(sender.expectMsgType[PaymentFailed].failures == UnreadableRemoteFailure(route.amount, route.hops) :: UnreadableRemoteFailure(route.amount, route.hops) :: Nil)
330+
assert(sender.expectMsgType[PaymentFailed].failures == UnreadableRemoteFailure(route.amount, route.hops, ???) :: UnreadableRemoteFailure(route.amount, route.hops, ???) :: Nil)
331331
awaitCond(nodeParams.db.payments.getOutgoingPayment(id).exists(_.status.isInstanceOf[OutgoingPaymentStatus.Failed])) // after last attempt the payment is failed
332332

333333
val metrics = metricsListener.expectMsgType[PathFindingExperimentMetrics]
@@ -794,8 +794,8 @@ class PaymentLifecycleSpec extends BaseRouterSpec {
794794
(RemoteFailure(defaultAmountMsat, route_abcd, Sphinx.DecryptedFailurePacket(c, UnknownNextPeer)), Set.empty, Set(ChannelDesc(scid_cd, c, d))),
795795
(RemoteFailure(defaultAmountMsat, route_abcd, Sphinx.DecryptedFailurePacket(b, FeeInsufficient(100 msat, update_bc))), Set.empty, Set.empty),
796796
// unreadable remote failures -> blacklist all nodes except our direct peer and the final recipient
797-
(UnreadableRemoteFailure(defaultAmountMsat, channelHopFromUpdate(a, b, update_ab) :: Nil), Set.empty, Set.empty),
798-
(UnreadableRemoteFailure(defaultAmountMsat, channelHopFromUpdate(a, b, update_ab) :: channelHopFromUpdate(b, c, update_bc) :: channelHopFromUpdate(c, d, update_cd) :: ChannelHop(ShortChannelId(5656986L), d, e, null) :: Nil), Set(c, d), Set.empty)
797+
(UnreadableRemoteFailure(defaultAmountMsat, channelHopFromUpdate(a, b, update_ab) :: Nil, ???), Set.empty, Set.empty),
798+
(UnreadableRemoteFailure(defaultAmountMsat, channelHopFromUpdate(a, b, update_ab) :: channelHopFromUpdate(b, c, update_bc) :: channelHopFromUpdate(c, d, update_cd) :: ChannelHop(ShortChannelId(5656986L), d, e, null) :: Nil, ???), Set(c, d), Set.empty)
799799
)
800800

801801
for ((failure, expectedNodes, expectedChannels) <- testCases) {

eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/NodeRelayerSpec.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -615,7 +615,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl
615615
val payFSM = mockPayFSM.expectMessageType[akka.actor.ActorRef]
616616
router.expectMessageType[RouteRequest]
617617

618-
val failures = RemoteFailure(outgoingAmount, Nil, Sphinx.DecryptedFailurePacket(outgoingNodeId, FinalIncorrectHtlcAmount(42 msat))) :: UnreadableRemoteFailure(outgoingAmount, Nil) :: Nil
618+
val failures = RemoteFailure(outgoingAmount, Nil, Sphinx.DecryptedFailurePacket(outgoingNodeId, FinalIncorrectHtlcAmount(42 msat))) :: UnreadableRemoteFailure(outgoingAmount, Nil, ???) :: Nil
619619
payFSM ! PaymentFailed(relayId, incomingMultiPart.head.add.paymentHash, failures)
620620

621621
incomingMultiPart.foreach { p =>

0 commit comments

Comments
 (0)