@@ -20,8 +20,8 @@ import fr.acinq.bitcoin.scalacompat.Crypto.{PrivateKey, PublicKey}
20
20
import fr .acinq .bitcoin .scalacompat .{ByteVector32 , Crypto }
21
21
import fr .acinq .eclair .wire .protocol ._
22
22
import grizzled .slf4j .Logging
23
+ import scodec .Attempt
23
24
import scodec .bits .ByteVector
24
- import scodec .{Attempt , Codec }
25
25
26
26
import scala .annotation .tailrec
27
27
import scala .collection .mutable .ArrayBuffer
@@ -343,14 +343,14 @@ object Sphinx extends Logging {
343
343
private val payloadAndPadLength = 256
344
344
private val hopPayloadLength = 9
345
345
private val maxNumHop = 27
346
- private val codec : Codec [ FatError ] = fatErrorCodec(payloadAndPadLength, hopPayloadLength, maxNumHop)
346
+ private val totalLength = 12599
347
347
348
348
def create (sharedSecret : ByteVector32 , failure : FailureMessage ): ByteVector = {
349
349
val failurePayload = FailureMessageCodecs .failureOnionPayload(payloadAndPadLength).encode(failure).require.toByteVector
350
350
val hopPayload = HopPayload (ErrorSource , 0 millis)
351
351
val zeroPayloads = Seq .fill(maxNumHop)(ByteVector .fill(hopPayloadLength)(0 ))
352
352
val zeroHmacs = (maxNumHop.to(1 , - 1 )).map(Seq .fill(_)(ByteVector32 .Zeroes ))
353
- val plainError = codec .encode(FatError (failurePayload, zeroPayloads, zeroHmacs)).require.bytes
353
+ val plainError = fatErrorCodec(totalLength, hopPayloadLength, maxNumHop) .encode(FatError (failurePayload, zeroPayloads, zeroHmacs)).require.bytes
354
354
wrap(plainError, sharedSecret, hopPayload).get
355
355
}
356
356
@@ -366,32 +366,32 @@ object Sphinx extends Logging {
366
366
367
367
def wrap (errorPacket : ByteVector , sharedSecret : ByteVector32 , hopPayload : HopPayload ): Try [ByteVector ] = Try {
368
368
val um = generateKey(" um" , sharedSecret)
369
- val error = codec .decode(errorPacket.bits).require.value
369
+ val error = fatErrorCodec(errorPacket.length.toInt, hopPayloadLength, maxNumHop) .decode(errorPacket.bits).require.value
370
370
val hopPayloads = hopPayloadCodec.encode(hopPayload).require.bytes +: error.hopPayloads.dropRight(1 )
371
- val hmacs = computeHmacs(Hmac256 (um), error.failurePayload, hopPayloads, error.hmacs, 0 ) +: error.hmacs.dropRight(1 ).map(_.drop(1 ))
372
- val newError = codec .encode(FatError (error.failurePayload, hopPayloads, hmacs)).require.bytes
371
+ val hmacs = computeHmacs(Hmac256 (um), error.failurePayload, hopPayloads, error.hmacs.map(_.drop( 1 )) , 0 ) +: error.hmacs.dropRight(1 ).map(_.drop(1 ))
372
+ val newError = fatErrorCodec(errorPacket.length.toInt, hopPayloadLength, maxNumHop) .encode(FatError (error.failurePayload, hopPayloads, hmacs)).require.bytes
373
373
val key = generateKey(" ammag" , sharedSecret)
374
374
val stream = generateStream(key, newError.length.toInt)
375
375
newError xor stream
376
376
}
377
377
378
- private def unwrap (errorPacket : ByteVector , sharedSecret : ByteVector32 , minNumHop : Int ): Try [(ByteVector , HopPayload )] = Try {
378
+ def unwrap (errorPacket : ByteVector , sharedSecret : ByteVector32 , minNumHop : Int ): Try [(ByteVector , HopPayload )] = Try {
379
379
val key = generateKey(" ammag" , sharedSecret)
380
380
val stream = generateStream(key, errorPacket.length.toInt)
381
- val error = codec .decode((errorPacket xor stream).bits).require.value
381
+ val error = fatErrorCodec(errorPacket.length.toInt, hopPayloadLength, maxNumHop) .decode((errorPacket xor stream).bits).require.value
382
382
val um = generateKey(" um" , sharedSecret)
383
383
val shiftedHmacs = error.hmacs.tail.map(ByteVector32 .Zeroes +: _) :+ Seq (ByteVector32 .Zeroes )
384
- val hmacs = computeHmacs(Hmac256 (um), error.failurePayload, error.hopPayloads, shiftedHmacs , minNumHop)
384
+ val hmacs = computeHmacs(Hmac256 (um), error.failurePayload, error.hopPayloads, error.hmacs.tail , minNumHop)
385
385
require(hmacs == error.hmacs.head.drop(minNumHop), " Invalid HMAC" )
386
386
val shiftedHopPayloads = error.hopPayloads.tail :+ ByteVector .fill(hopPayloadLength)(0 )
387
387
val unwrapedError = FatError (error.failurePayload, shiftedHopPayloads, shiftedHmacs)
388
- (codec .encode(unwrapedError).require.bytes,
388
+ (fatErrorCodec(errorPacket.length.toInt, hopPayloadLength, maxNumHop) .encode(unwrapedError).require.bytes,
389
389
hopPayloadCodec.decode(error.hopPayloads.head.bits).require.value)
390
390
}
391
391
392
392
def decrypt (errorPacket : ByteVector , sharedSecrets : Seq [(ByteVector32 , PublicKey )]): Either [InvalidFatErrorPacket , DecryptedFailurePacket ] = {
393
393
var packet = errorPacket
394
- var minNumHop = 1
394
+ var minNumHop = 0
395
395
val hopPayloads = ArrayBuffer .empty[(PublicKey , HopPayload )]
396
396
for ((sharedSecret, nodeId) <- sharedSecrets) {
397
397
unwrap(packet, sharedSecret, minNumHop) match {
@@ -403,7 +403,7 @@ object Sphinx extends Logging {
403
403
minNumHop += 1
404
404
hopPayloads += ((nodeId, hopPayload))
405
405
case FatError .ErrorSource =>
406
- val failurePayload = codec .decode(unwrapedPacket.bits).require.value.failurePayload
406
+ val failurePayload = fatErrorCodec(errorPacket.length.toInt, hopPayloadLength, maxNumHop) .decode(unwrapedPacket.bits).require.value.failurePayload
407
407
FailureMessageCodecs .failureOnionPayload(payloadAndPadLength).decode(failurePayload.bits) match {
408
408
case Attempt .Successful (failureMessage) =>
409
409
return Right (DecryptedFailurePacket (nodeId, failureMessage.value))
0 commit comments