@@ -26,7 +26,7 @@ import scodec.bits.ByteVector
26
26
import scodec .codecs .uint32
27
27
28
28
import scala .annotation .tailrec
29
- import scala .concurrent .duration .FiniteDuration
29
+ import scala .concurrent .duration .{ DurationLong , FiniteDuration }
30
30
import scala .util .{Failure , Success , Try }
31
31
32
32
/**
@@ -284,6 +284,10 @@ object Sphinx extends Logging {
284
284
*/
285
285
case class CannotDecryptFailurePacket (unwrapped : ByteVector )
286
286
287
+ case class HoldTime (duration : FiniteDuration , remoteNodeId : PublicKey )
288
+
289
+ case class HtlcFailure (holdTimes : Seq [HoldTime ], failure : Either [CannotDecryptFailurePacket , DecryptedFailurePacket ])
290
+
287
291
object FailurePacket {
288
292
289
293
/**
@@ -336,41 +340,64 @@ object Sphinx extends Logging {
336
340
* @return failure message if the origin of the packet could be identified and the packet decrypted, the unwrapped
337
341
* failure packet otherwise.
338
342
*/
339
- @ tailrec
340
- def decrypt (packet : ByteVector , sharedSecrets : Seq [SharedSecret ]): Either [CannotDecryptFailurePacket , DecryptedFailurePacket ] = {
343
+ def decrypt (packet : ByteVector , attribution_opt : Option [ByteVector ], sharedSecrets : Seq [SharedSecret ], hopIndex : Int = 0 ): HtlcFailure = {
341
344
sharedSecrets match {
342
- case Nil => Left (CannotDecryptFailurePacket (packet))
345
+ case Nil => HtlcFailure ( Nil , Left (CannotDecryptFailurePacket (packet) ))
343
346
case ss :: tail =>
344
347
val packet1 = wrap(packet, ss.secret)
348
+ val next_opt = attribution_opt.flatMap(Attribution .unwrap(_, packet1, ss.secret, hopIndex))
345
349
val um = generateKey(" um" , ss.secret)
346
- FailureMessageCodecs .failureOnionCodec(Hmac256 (um)).decode(packet1.toBitVector) match {
347
- case Attempt .Successful (value) => Right (DecryptedFailurePacket (ss.remoteNodeId, value.value))
348
- case _ => decrypt(packet1, tail)
350
+ val HtlcFailure (holdTimes, failure) = FailureMessageCodecs .failureOnionCodec(Hmac256 (um)).decode(packet1.toBitVector) match {
351
+ case Attempt .Successful (value) => HtlcFailure ( Nil , Right (DecryptedFailurePacket (ss.remoteNodeId, value.value) ))
352
+ case _ => decrypt(packet1, next_opt.map(_._2), tail, hopIndex + 1 )
349
353
}
354
+ HtlcFailure (next_opt.map(n => HoldTime (n._1, ss.remoteNodeId) +: holdTimes).getOrElse(Nil ), failure)
350
355
}
351
356
}
352
357
353
- def attribution (previousAttribution_opt : Option [ByteVector ], reason : ByteVector , holdTime : FiniteDuration , sharedSecret : ByteVector32 ): ByteVector = {
354
- val previousAttribution = previousAttribution_opt.getOrElse(ByteVector .low(920 ))
355
- val previousHmacs = (0 until 19 ).map(i => (1 until (20 - i)).map(j => {
356
- val start = 80 + (20 * i - (i * (i - 1 )) / 2 + j) * 4
357
- previousAttribution.slice(start, start + 4 )
358
- }))
359
- val mac = Hmac256 (generateKey(" um" , sharedSecret))
360
- val holdTimes = uint32.encode(holdTime.toMillis).require.bytes ++ previousAttribution.take(19 * 4 )
361
- val hmacs = computeHmacs(mac, reason, holdTimes, previousHmacs) +: previousHmacs
362
- val key = generateKey(" ammagext" , sharedSecret)
363
- val stream = generateStream(key, 920 )
364
- (holdTimes ++ ByteVector .concat(hmacs.map(ByteVector .concat(_)))) xor stream
365
- }
358
+ object Attribution {
359
+ private def cipher (bytes : ByteVector , sharedSecret : ByteVector32 ): ByteVector = {
360
+ val key = generateKey(" ammagext" , sharedSecret)
361
+ val stream = generateStream(key, 920 )
362
+ bytes xor stream
363
+ }
364
+
365
+ private def getHmacs (bytes : ByteVector ): Seq [Seq [ByteVector ]] =
366
+ (0 until 20 ).map(i => (0 until (20 - i)).map(j => {
367
+ val start = (20 + 20 * i - (i * (i - 1 )) / 2 + j) * 4
368
+ bytes.slice(start, start + 4 )
369
+ }))
370
+
371
+ private def computeHmacs (mac : Mac32 , reason : ByteVector , holdTimes : ByteVector , hmacs : Seq [Seq [ByteVector ]], minNumHop : Int ): Seq [ByteVector ] = {
372
+ (minNumHop until 20 ).map(i => {
373
+ val y = 20 - i
374
+ mac.mac(reason ++
375
+ holdTimes.take(y * 4 ) ++
376
+ ByteVector .concat((0 until y - 1 ).map(j => hmacs(j)(i)))).bytes.take(4 )
377
+ })
378
+ }
379
+
380
+ def create (previousAttribution_opt : Option [ByteVector ], reason : ByteVector , holdTime : FiniteDuration , sharedSecret : ByteVector32 ): ByteVector = {
381
+ val previousAttribution = previousAttribution_opt.getOrElse(ByteVector .low(920 ))
382
+ val previousHmacs = getHmacs(previousAttribution).dropRight(1 ).map(_.drop(1 ))
383
+ val mac = Hmac256 (generateKey(" um" , sharedSecret))
384
+ val holdTimes = uint32.encode(holdTime.toMillis).require.bytes ++ previousAttribution.take(19 * 4 )
385
+ val hmacs = computeHmacs(mac, reason, holdTimes, previousHmacs, 0 ) +: previousHmacs
386
+ cipher(holdTimes ++ ByteVector .concat(hmacs.map(ByteVector .concat(_))), sharedSecret)
387
+ }
366
388
367
- private def computeHmacs (mac : Mac32 , reason : ByteVector , holdTimes : ByteVector , hmacs : Seq [Seq [ByteVector ]]): Seq [ByteVector ] = {
368
- (0 until 20 ).map(i => {
369
- val y = 20 - i
370
- mac.mac(reason ++
371
- holdTimes.take(y * 4 ) ++
372
- ByteVector .concat((0 until y - 1 ).map(j => hmacs(j)(i)))).bytes.take(4 )
373
- })
389
+ def unwrap (encrypted : ByteVector , reason : ByteVector , sharedSecret : ByteVector32 , minNumHop : Int ): Option [(FiniteDuration , ByteVector )] = {
390
+ val bytes = cipher(encrypted, sharedSecret)
391
+ val holdTime = uint32.decode(bytes.take(4 ).bits).require.value.milliseconds
392
+ val hmacs = getHmacs(bytes)
393
+ val mac = Hmac256 (generateKey(" um" , sharedSecret))
394
+ if (computeHmacs(mac, reason, bytes.take(20 * 4 ), hmacs.drop(1 ), minNumHop) == hmacs.head.drop(minNumHop)) {
395
+ val unwraped = bytes.slice(4 , 20 * 4 ) ++ ByteVector .low(4 ) ++ ByteVector .concat((hmacs.drop(1 ) :+ Seq ()).map(s => ByteVector .low(4 ) ++ ByteVector .concat(s)))
396
+ Some (holdTime, unwraped)
397
+ } else {
398
+ None
399
+ }
400
+ }
374
401
}
375
402
}
376
403
0 commit comments