@@ -23,8 +23,10 @@ import fr.acinq.eclair.wire.protocol._
23
23
import grizzled .slf4j .Logging
24
24
import scodec .Attempt
25
25
import scodec .bits .ByteVector
26
+ import scodec .codecs .uint32
26
27
27
28
import scala .annotation .tailrec
29
+ import scala .concurrent .duration .{DurationLong , FiniteDuration }
28
30
import scala .util .{Failure , Success , Try }
29
31
30
32
/**
@@ -282,24 +284,28 @@ object Sphinx extends Logging {
282
284
*/
283
285
case class CannotDecryptFailurePacket (unwrapped : ByteVector )
284
286
287
+ case class HoldTime (duration : FiniteDuration , remoteNodeId : PublicKey )
288
+
289
+ case class HtlcFailure (holdTimes : Seq [HoldTime ], failure : Either [CannotDecryptFailurePacket , DecryptedFailurePacket ])
290
+
285
291
object FailurePacket {
286
292
287
293
/**
288
- * Create a failure packet that will be returned to the sender.
294
+ * Create a failure packet that needs to be wrapped before being returned to the sender.
289
295
* Each intermediate hop will add a layer of encryption and forward to the previous hop.
290
296
* Note that malicious intermediate hops may drop the packet or alter it (which breaks the mac).
291
297
*
292
298
* @param sharedSecret destination node's shared secret that was computed when the original onion for the HTLC
293
299
* was created or forwarded: see OnionPacket.create() and OnionPacket.wrap().
294
300
* @param failure failure message.
295
- * @return a failure packet that can be sent to the destination node.
301
+ * @return a failure packet that still needs to be wrapped before being sent to the destination node.
296
302
*/
297
303
def create (sharedSecret : ByteVector32 , failure : FailureMessage ): ByteVector = {
298
304
val um = generateKey(" um" , sharedSecret)
299
305
val packet = FailureMessageCodecs .failureOnionCodec(Hmac256 (um)).encode(failure).require.toByteVector
300
306
logger.debug(s " um key: $um" )
301
307
logger.debug(s " raw error packet: ${packet.toHex}" )
302
- wrap( packet, sharedSecret)
308
+ packet
303
309
}
304
310
305
311
/**
@@ -322,25 +328,108 @@ object Sphinx extends Logging {
322
328
* it was sent by the corresponding node.
323
329
* Note that malicious nodes in the route may have altered the packet, triggering a decryption failure.
324
330
*
325
- * @param packet failure packet.
326
- * @param sharedSecrets nodes shared secrets.
331
+ * @param packet failure packet.
332
+ * @param attribution_opt attribution data for this failure packet.
333
+ * @param sharedSecrets nodes shared secrets.
327
334
* @return failure message if the origin of the packet could be identified and the packet decrypted, the unwrapped
328
335
* failure packet otherwise.
329
336
*/
330
- @ tailrec
331
- def decrypt (packet : ByteVector , sharedSecrets : Seq [SharedSecret ]): Either [CannotDecryptFailurePacket , DecryptedFailurePacket ] = {
337
+ def decrypt (packet : ByteVector , attribution_opt : Option [ByteVector ], sharedSecrets : Seq [SharedSecret ], hopIndex : Int = 0 ): HtlcFailure = {
332
338
sharedSecrets match {
333
- case Nil => Left (CannotDecryptFailurePacket (packet))
339
+ case Nil => HtlcFailure ( Nil , Left (CannotDecryptFailurePacket (packet) ))
334
340
case ss :: tail =>
335
341
val packet1 = wrap(packet, ss.secret)
342
+ val attribution1_opt = attribution_opt.flatMap(Attribution .unwrap(_, packet1, ss.secret, hopIndex))
336
343
val um = generateKey(" um" , ss.secret)
337
- FailureMessageCodecs .failureOnionCodec(Hmac256 (um)).decode(packet1.toBitVector) match {
338
- case Attempt .Successful (value) => Right (DecryptedFailurePacket (ss.remoteNodeId, value.value))
339
- case _ => decrypt(packet1, tail)
344
+ val HtlcFailure (downstreamHoldTimes, failure) = FailureMessageCodecs .failureOnionCodec(Hmac256 (um)).decode(packet1.toBitVector) match {
345
+ case Attempt .Successful (value) => HtlcFailure ( Nil , Right (DecryptedFailurePacket (ss.remoteNodeId, value.value) ))
346
+ case _ => decrypt(packet1, attribution1_opt.map(_._2), tail, hopIndex + 1 )
340
347
}
348
+ HtlcFailure (attribution1_opt.map(n => HoldTime (n._1, ss.remoteNodeId) +: downstreamHoldTimes).getOrElse(Nil ), failure)
341
349
}
342
350
}
343
351
352
+ /**
353
+ * Attribution data is added to the failure packet and prevents a node from evading responsibility for its failures.
354
+ * Nodes that relay attribution data can prove that they are not the erring node and in case the erring node tries
355
+ * to hide, there will only be at most two nodes that can be the erring node (the last one to send attribution data
356
+ * and the one after it).
357
+ * It also adds timing data for each node on the path.
358
+ * https://github.com/lightning/bolts/pull/1044
359
+ */
360
+ object Attribution {
361
+ val maxNumHops = 20
362
+ val holdTimeLength = 4
363
+ val hmacLength = 4 // HMACs are truncated to 4 bytes to save space
364
+ val totalLength = maxNumHops * holdTimeLength + maxNumHops * (maxNumHops + 1 ) / 2 * hmacLength // = 920
365
+
366
+ private def cipher (bytes : ByteVector , sharedSecret : ByteVector32 ): ByteVector = {
367
+ val key = generateKey(" ammagext" , sharedSecret)
368
+ val stream = generateStream(key, totalLength)
369
+ bytes xor stream
370
+ }
371
+
372
+ /**
373
+ * Get the HMACs from the attribution data.
374
+ * The layout of the attribution data is as follows (using maxNumHops = 3 for conciseness):
375
+ * holdTime(0) ++ holdTime(1) ++ holdTime(2) ++
376
+ * hmacs(0)(0) ++ hmacs(0)(1) ++ hmacs(0)(2) ++
377
+ * hmacs(1)(0) ++ hmacs(1)(1) ++
378
+ * hmacs(2)(0)
379
+ *
380
+ * Where `hmac(i)(j)` is the hmac added by node `i` (counted from the node that built the attribution data),
381
+ * assuming it is `maxNumHops - 1 - i - j` hops away from the erring node.
382
+ */
383
+ private def getHmacs (bytes : ByteVector ): Seq [Seq [ByteVector ]] =
384
+ (0 until maxNumHops).map(i => (0 until (maxNumHops - i)).map(j => {
385
+ val start = maxNumHops * holdTimeLength + (maxNumHops * i - (i * (i - 1 )) / 2 + j) * hmacLength
386
+ bytes.slice(start, start + hmacLength)
387
+ }))
388
+
389
+ /**
390
+ * Computes the HMACs for the node that is `minNumHop` hops away from us. Hence we only compute `maxNumHops - minNumHop` HMACs.
391
+ * HMACs are truncated to 4 bytes to save space. An attacker has only one try to guess the HMAC so 4 bytes should be enough.
392
+ */
393
+ private def computeHmacs (mac : Mac32 , failurePacket : ByteVector , holdTimes : ByteVector , hmacs : Seq [Seq [ByteVector ]], minNumHop : Int ): Seq [ByteVector ] = {
394
+ (minNumHop until maxNumHops).map(i => {
395
+ val y = maxNumHops - i
396
+ mac.mac(failurePacket ++
397
+ holdTimes.take(y * holdTimeLength) ++
398
+ ByteVector .concat((0 until y - 1 ).map(j => hmacs(j)(i)))).bytes.take(hmacLength)
399
+ })
400
+ }
401
+
402
+ /**
403
+ * Create attribution data to send with the failure packet
404
+ *
405
+ * @param failurePacket the failure packet before being wrapped
406
+ */
407
+ def create (previousAttribution_opt : Option [ByteVector ], failurePacket : ByteVector , holdTime : FiniteDuration , sharedSecret : ByteVector32 ): ByteVector = {
408
+ val previousAttribution = previousAttribution_opt.getOrElse(ByteVector .low(totalLength))
409
+ val previousHmacs = getHmacs(previousAttribution).dropRight(1 ).map(_.drop(1 ))
410
+ val mac = Hmac256 (generateKey(" um" , sharedSecret))
411
+ val holdTimes = uint32.encode(holdTime.toMillis).require.bytes ++ previousAttribution.take((maxNumHops - 1 ) * holdTimeLength)
412
+ val hmacs = computeHmacs(mac, failurePacket, holdTimes, previousHmacs, 0 ) +: previousHmacs
413
+ cipher(holdTimes ++ ByteVector .concat(hmacs.map(ByteVector .concat(_))), sharedSecret)
414
+ }
415
+
416
+ /**
417
+ * Unwrap one hop of attribution data
418
+ * @return a pair with the hold time for this hop and the attribution data for the next hop, or None if the attribution data was invalid
419
+ */
420
+ def unwrap (encrypted : ByteVector , failurePacket : ByteVector , sharedSecret : ByteVector32 , minNumHop : Int ): Option [(FiniteDuration , ByteVector )] = {
421
+ val bytes = cipher(encrypted, sharedSecret)
422
+ val holdTime = uint32.decode(bytes.take(holdTimeLength).bits).require.value.milliseconds
423
+ val hmacs = getHmacs(bytes)
424
+ val mac = Hmac256 (generateKey(" um" , sharedSecret))
425
+ if (computeHmacs(mac, failurePacket, bytes.take(maxNumHops * holdTimeLength), hmacs.drop(1 ), minNumHop) == hmacs.head.drop(minNumHop)) {
426
+ val unwrapped = bytes.slice(holdTimeLength, maxNumHops * holdTimeLength) ++ ByteVector .low(holdTimeLength) ++ ByteVector .concat((hmacs.drop(1 ) :+ Seq ()).map(s => ByteVector .low(hmacLength) ++ ByteVector .concat(s)))
427
+ Some (holdTime, unwrapped)
428
+ } else {
429
+ None
430
+ }
431
+ }
432
+ }
344
433
}
345
434
346
435
/**
0 commit comments