Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions eclair-core/src/main/scala/fr/acinq/eclair/crypto/Sphinx.scala
Original file line number Diff line number Diff line change
Expand Up @@ -280,9 +280,10 @@ object Sphinx extends Logging {
/**
* The downstream failure could not be decrypted.
*
* @param unwrapped encrypted failure packet after unwrapping using our shared secrets.
* @param unwrapped encrypted failure packet after unwrapping using our shared secrets.
* @param attribution_opt attribution data after unwrapping using our shared secrets
*/
case class CannotDecryptFailurePacket(unwrapped: ByteVector)
case class CannotDecryptFailurePacket(unwrapped: ByteVector, attribution_opt: Option[ByteVector])

case class HoldTime(duration: FiniteDuration, remoteNodeId: PublicKey)

Expand Down Expand Up @@ -336,7 +337,7 @@ object Sphinx extends Logging {
*/
def decrypt(packet: ByteVector, attribution_opt: Option[ByteVector], sharedSecrets: Seq[SharedSecret], hopIndex: Int = 0): HtlcFailure = {
sharedSecrets match {
case Nil => HtlcFailure(Nil, Left(CannotDecryptFailurePacket(packet)))
case Nil => HtlcFailure(Nil, Left(CannotDecryptFailurePacket(packet, attribution_opt)))
case ss :: tail =>
val packet1 = wrap(packet, ss.secret)
val attribution1_opt = attribution_opt.flatMap(Attribution.unwrap(_, packet1, ss.secret, hopIndex))
Expand Down Expand Up @@ -432,17 +433,20 @@ object Sphinx extends Logging {
}
}

case class UnwrappedAttribution(holdTimes: List[HoldTime], remaining_opt: Option[ByteVector])

/**
* Decrypt the hold times from the attribution data of a fulfilled HTLC
*/
def fulfillHoldTimes(attribution: ByteVector, sharedSecrets: Seq[SharedSecret], hopIndex: Int = 0): List[HoldTime] = {
def fulfillHoldTimes(attribution: ByteVector, sharedSecrets: Seq[SharedSecret], hopIndex: Int = 0): UnwrappedAttribution = {
sharedSecrets match {
case Nil => Nil
case Nil => UnwrappedAttribution(Nil, Some(attribution))
case ss :: tail =>
unwrap(attribution, ByteVector.empty, ss.secret, hopIndex) match {
case Some((holdTime, nextAttribution)) =>
HoldTime(holdTime, ss.remoteNodeId) :: fulfillHoldTimes(nextAttribution, tail, hopIndex + 1)
case None => Nil
val UnwrappedAttribution(holdTimes, remaining_opt) = fulfillHoldTimes(nextAttribution, tail, hopIndex + 1)
UnwrappedAttribution(HoldTime(holdTime, ss.remoteNodeId) :: holdTimes, remaining_opt)
case None => UnwrappedAttribution(Nil, None)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,8 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
rs.getByteVector32FromHex("payment_preimage"),
MilliSatoshi(rs.getLong("recipient_amount_msat")),
PublicKey(rs.getByteVectorFromHex("recipient_node_id")),
Seq(part))
Seq(part),
None)
}
sentByParentId + (parentId -> sent)
}.values.toSeq.sortBy(_.timestamp)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,8 @@ class SqliteAuditDb(val sqlite: Connection) extends AuditDb with Logging {
rs.getByteVector32("payment_preimage"),
MilliSatoshi(rs.getLong("recipient_amount_msat")),
PublicKey(rs.getByteVector("recipient_node_id")),
Seq(part))
Seq(part),
None)
}
sentByParentId + (parentId -> sent)
}.values.toSeq.sortBy(_.timestamp)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,16 @@ sealed trait PaymentEvent {
/**
* A payment was successfully sent and fulfilled.
*
* @param id id of the whole payment attempt (if using multi-part, there will be multiple parts, each with
* a different id).
* @param paymentHash payment hash.
* @param paymentPreimage payment preimage (proof of payment).
* @param recipientAmount amount that has been received by the final recipient.
* @param recipientNodeId id of the final recipient.
* @param parts child payments (actual outgoing HTLCs).
* @param id id of the whole payment attempt (if using multi-part, there will be multiple parts,
* each with a different id).
* @param paymentHash payment hash.
* @param paymentPreimage payment preimage (proof of payment).
* @param recipientAmount amount that has been received by the final recipient.
* @param recipientNodeId id of the final recipient.
* @param parts child payments (actual outgoing HTLCs).
* @param remainingAttribution_opt for relayed trampoline payments, the attribution data that needs to be sent upstream
*/
case class PaymentSent(id: UUID, paymentHash: ByteVector32, paymentPreimage: ByteVector32, recipientAmount: MilliSatoshi, recipientNodeId: PublicKey, parts: Seq[PaymentSent.PartialPayment]) extends PaymentEvent {
case class PaymentSent(id: UUID, paymentHash: ByteVector32, paymentPreimage: ByteVector32, recipientAmount: MilliSatoshi, recipientNodeId: PublicKey, parts: Seq[PaymentSent.PartialPayment], remainingAttribution_opt: Option[ByteVector]) extends PaymentEvent {
require(parts.nonEmpty, "must have at least one payment part")
val amountWithFees: MilliSatoshi = parts.map(_.amountWithFees).sum
val feesPaid: MilliSatoshi = amountWithFees - recipientAmount // overall fees for this payment
Expand Down Expand Up @@ -151,7 +152,7 @@ case class LocalFailure(amount: MilliSatoshi, route: Seq[Hop], t: Throwable) ext
case class RemoteFailure(amount: MilliSatoshi, route: Seq[Hop], e: Sphinx.DecryptedFailurePacket) extends PaymentFailure

/** A remote node failed the payment but we couldn't decrypt the failure (e.g. a malicious node tampered with the message). */
case class UnreadableRemoteFailure(amount: MilliSatoshi, route: Seq[Hop], failurePacket: ByteVector, holdTimes: Seq[HoldTime]) extends PaymentFailure
case class UnreadableRemoteFailure(amount: MilliSatoshi, route: Seq[Hop], e: Sphinx.CannotDecryptFailurePacket, holdTimes: Seq[HoldTime]) extends PaymentFailure

object PaymentFailure {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ import fr.acinq.eclair.router.{BalanceTooLow, RouteNotFound}
import fr.acinq.eclair.wire.protocol.PaymentOnion.IntermediatePayload
import fr.acinq.eclair.wire.protocol._
import fr.acinq.eclair.{Alias, CltvExpiry, CltvExpiryDelta, EncodedNodeId, Features, InitFeature, Logs, MilliSatoshi, MilliSatoshiLong, NodeParams, TimestampMilli, UInt64, nodeFee, randomBytes32}
import scodec.bits.ByteVector

import java.util.UUID
import java.util.concurrent.TimeUnit
Expand Down Expand Up @@ -374,11 +375,11 @@ class NodeRelay private(nodeParams: NodeParams,
Behaviors.receiveMessagePartial {
rejectExtraHtlcPartialFunction orElse {
// this is the fulfill that arrives from downstream channels
case WrappedPreimageReceived(PreimageReceived(_, paymentPreimage)) =>
case WrappedPreimageReceived(PreimageReceived(_, paymentPreimage, attribution_opt)) =>
if (!fulfilledUpstream) {
// We want to fulfill upstream as soon as we receive the preimage (even if not all HTLCs have fulfilled downstream).
context.log.debug("got preimage from downstream")
fulfillPayment(upstream, paymentPreimage)
fulfillPayment(upstream, paymentPreimage, attribution_opt)
sending(upstream, recipient, walletNodeId_opt, recipientFeatures_opt, nextPayload, startedAt, fulfilledUpstream = true)
} else {
// we don't want to fulfill multiple times
Expand Down Expand Up @@ -491,16 +492,15 @@ class NodeRelay private(nodeParams: NodeParams,
upstream.received.foreach(r => rejectHtlc(r.add.id, r.add.channelId, upstream.amountIn, r.receivedAt, failure))
}

private def fulfillPayment(upstream: Upstream.Hot.Trampoline, paymentPreimage: ByteVector32): Unit = upstream.received.foreach(r => {
// TODO: process downstream attribution data
val cmd = CMD_FULFILL_HTLC(r.add.id, paymentPreimage, None, Some(r.receivedAt), commit = true)
private def fulfillPayment(upstream: Upstream.Hot.Trampoline, paymentPreimage: ByteVector32, downstreamAttribution_opt: Option[ByteVector]): Unit = upstream.received.foreach(r => {
val cmd = CMD_FULFILL_HTLC(r.add.id, paymentPreimage, downstreamAttribution_opt, Some(r.receivedAt), commit = true)
PendingCommandsDb.safeSend(register, nodeParams.db.pendingCommands, r.add.channelId, cmd)
})

private def success(upstream: Upstream.Hot.Trampoline, fulfilledUpstream: Boolean, paymentSent: PaymentSent): Unit = {
// We may have already fulfilled upstream, but we can now emit an accurate relayed event and clean-up resources.
if (!fulfilledUpstream) {
fulfillPayment(upstream, paymentSent.paymentPreimage)
fulfillPayment(upstream, paymentSent.paymentPreimage, paymentSent.remainingAttribution_opt)
}
val incoming = upstream.received.map(r => PaymentRelayed.IncomingPart(r.add.amountMsat, r.add.channelId, r.receivedAt))
val outgoing = paymentSent.parts.map(part => PaymentRelayed.OutgoingPart(part.amountWithFees, part.toChannelId, part.timestamp))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ object OnTheFlyFunding {
// In the trampoline case, we currently ignore downstream failures: we should add dedicated failures to
// the BOLTs to better handle those cases.
Sphinx.FailurePacket.decrypt(f.packet, f.attribution_opt, onionSharedSecrets).failure match {
case Left(Sphinx.CannotDecryptFailurePacket(_)) =>
case Left(Sphinx.CannotDecryptFailurePacket(_, _)) =>
log.warning("couldn't decrypt downstream on-the-fly funding failure")
case Right(f) =>
log.warning("downstream on-the-fly funding failure: {}", f.failureMessage.message)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,15 +177,15 @@ class PostRestartHtlcCleaner(nodeParams: NodeParams, register: ActorRef, initial
val feesPaid = 0.msat // fees are unknown since we lost the reference to the payment
nodeParams.db.payments.getOutgoingPayment(id) match {
case Some(p) =>
nodeParams.db.payments.updateOutgoingPayment(PaymentSent(p.parentId, fulfilledHtlc.paymentHash, paymentPreimage, p.recipientAmount, p.recipientNodeId, PaymentSent.PartialPayment(id, fulfilledHtlc.amountMsat, feesPaid, fulfilledHtlc.channelId, None) :: Nil))
nodeParams.db.payments.updateOutgoingPayment(PaymentSent(p.parentId, fulfilledHtlc.paymentHash, paymentPreimage, p.recipientAmount, p.recipientNodeId, PaymentSent.PartialPayment(id, fulfilledHtlc.amountMsat, feesPaid, fulfilledHtlc.channelId, None) :: Nil, None))
// If all downstream HTLCs are now resolved, we can emit the payment event.
val payments = nodeParams.db.payments.listOutgoingPayments(p.parentId)
if (!payments.exists(p => p.status == OutgoingPaymentStatus.Pending)) {
val succeeded = payments.collect {
case OutgoingPayment(id, _, _, _, _, amount, _, _, _, _, _, OutgoingPaymentStatus.Succeeded(_, feesPaid, _, completedAt)) =>
PaymentSent.PartialPayment(id, amount, feesPaid, ByteVector32.Zeroes, None, completedAt)
}
val sent = PaymentSent(p.parentId, fulfilledHtlc.paymentHash, paymentPreimage, p.recipientAmount, p.recipientNodeId, succeeded)
val sent = PaymentSent(p.parentId, fulfilledHtlc.paymentHash, paymentPreimage, p.recipientAmount, p.recipientNodeId, succeeded, None)
log.info(s"payment id=${sent.id} paymentHash=${sent.paymentHash} successfully sent (amount=${sent.recipientAmount})")
context.system.eventStream.publish(sent)
}
Expand All @@ -196,7 +196,7 @@ class PostRestartHtlcCleaner(nodeParams: NodeParams, register: ActorRef, initial
val dummyFinalAmount = fulfilledHtlc.amountMsat
val dummyNodeId = nodeParams.nodeId
nodeParams.db.payments.addOutgoingPayment(OutgoingPayment(id, id, None, fulfilledHtlc.paymentHash, PaymentType.Standard, fulfilledHtlc.amountMsat, dummyFinalAmount, dummyNodeId, TimestampMilli.now(), None, None, OutgoingPaymentStatus.Pending))
nodeParams.db.payments.updateOutgoingPayment(PaymentSent(id, fulfilledHtlc.paymentHash, paymentPreimage, dummyFinalAmount, dummyNodeId, PaymentSent.PartialPayment(id, fulfilledHtlc.amountMsat, feesPaid, fulfilledHtlc.channelId, None) :: Nil))
nodeParams.db.payments.updateOutgoingPayment(PaymentSent(id, fulfilledHtlc.paymentHash, paymentPreimage, dummyFinalAmount, dummyNodeId, PaymentSent.PartialPayment(id, fulfilledHtlc.amountMsat, feesPaid, fulfilledHtlc.channelId, None) :: Nil, None))
}
// There can never be more than one pending downstream HTLC for a given local origin (a multi-part payment is
// instead spread across multiple local origins) so we can now forget this origin.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import fr.acinq.eclair.payment.send.PaymentInitiator.SendPaymentConfig
import fr.acinq.eclair.payment.send.PaymentLifecycle.SendPaymentToRoute
import fr.acinq.eclair.router.Router._
import fr.acinq.eclair.{FSMDiagnosticActorLogging, Logs, MilliSatoshiLong, NodeParams, TimestampMilli}
import scodec.bits.ByteVector

import java.util.UUID
import java.util.concurrent.TimeUnit
Expand Down Expand Up @@ -118,7 +119,7 @@ class MultiPartPaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig,
case Event(ps: PaymentSent, d: PaymentProgress) =>
require(ps.parts.length == 1, "child payment must contain only one part")
// As soon as we get the preimage we can consider that the whole payment succeeded (we have a proof of payment).
gotoSucceededOrStop(PaymentSucceeded(d.request, ps.paymentPreimage, ps.parts, d.pending.keySet - ps.parts.head.id))
gotoSucceededOrStop(PaymentSucceeded(d.request, ps.paymentPreimage, ps.parts, d.pending.keySet - ps.parts.head.id, ps.remainingAttribution_opt))
}

when(PAYMENT_IN_PROGRESS) {
Expand All @@ -144,7 +145,7 @@ class MultiPartPaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig,
require(ps.parts.length == 1, "child payment must contain only one part")
// As soon as we get the preimage we can consider that the whole payment succeeded (we have a proof of payment).
Metrics.PaymentAttempt.withTag(Tags.MultiPart, value = true).record(d.request.maxAttempts - d.remainingAttempts)
gotoSucceededOrStop(PaymentSucceeded(d.request, ps.paymentPreimage, ps.parts, d.pending.keySet - ps.parts.head.id))
gotoSucceededOrStop(PaymentSucceeded(d.request, ps.paymentPreimage, ps.parts, d.pending.keySet - ps.parts.head.id, ps.remainingAttribution_opt))
}

when(PAYMENT_ABORTED) {
Expand All @@ -162,7 +163,7 @@ class MultiPartPaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig,
case Event(ps: PaymentSent, d: PaymentAborted) =>
require(ps.parts.length == 1, "child payment must contain only one part")
log.warning(s"payment recipient fulfilled incomplete multi-part payment (id=${ps.parts.head.id})")
gotoSucceededOrStop(PaymentSucceeded(d.request, ps.paymentPreimage, ps.parts, d.pending - ps.parts.head.id))
gotoSucceededOrStop(PaymentSucceeded(d.request, ps.paymentPreimage, ps.parts, d.pending - ps.parts.head.id, ps.remainingAttribution_opt))

case Event(_: RouteResponse, _) => stay()
case Event(_: PaymentRouteNotFound, _) => stay()
Expand All @@ -174,7 +175,7 @@ class MultiPartPaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig,
val parts = d.parts ++ ps.parts
val pending = d.pending - ps.parts.head.id
if (pending.isEmpty) {
myStop(d.request, Right(cfg.createPaymentSent(d.request.recipient, d.preimage, parts)))
myStop(d.request, Right(cfg.createPaymentSent(d.request.recipient, d.preimage, parts, d.remainingAttribution_opt)))
} else {
stay() using d.copy(parts = parts, pending = pending)
}
Expand All @@ -185,7 +186,7 @@ class MultiPartPaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig,
log.warning(s"payment succeeded but partial payment failed (id=${pf.id})")
val pending = d.pending - pf.id
if (pending.isEmpty) {
myStop(d.request, Right(cfg.createPaymentSent(d.request.recipient, d.preimage, d.parts)))
myStop(d.request, Right(cfg.createPaymentSent(d.request.recipient, d.preimage, d.parts, d.remainingAttribution_opt)))
} else {
stay() using d.copy(pending = pending)
}
Expand All @@ -212,10 +213,10 @@ class MultiPartPaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig,

private def gotoSucceededOrStop(d: PaymentSucceeded): State = {
if (publishPreimage) {
d.request.replyTo ! PreimageReceived(paymentHash, d.preimage)
d.request.replyTo ! PreimageReceived(paymentHash, d.preimage, d.remainingAttribution_opt)
}
if (d.pending.isEmpty) {
myStop(d.request, Right(cfg.createPaymentSent(d.request.recipient, d.preimage, d.parts)))
myStop(d.request, Right(cfg.createPaymentSent(d.request.recipient, d.preimage, d.parts, d.remainingAttribution_opt)))
} else
goto(PAYMENT_SUCCEEDED) using d
}
Expand Down Expand Up @@ -310,7 +311,7 @@ object MultiPartPaymentLifecycle {
* The payment FSM will wait for all child payments to settle before emitting payment events, but the preimage will be
* shared as soon as it's received to unblock other actors that may need it.
*/
case class PreimageReceived(paymentHash: ByteVector32, paymentPreimage: ByteVector32)
case class PreimageReceived(paymentHash: ByteVector32, paymentPreimage: ByteVector32, remainingAttribution_opt: Option[ByteVector])

// @formatter:off
sealed trait State
Expand Down Expand Up @@ -367,7 +368,7 @@ object MultiPartPaymentLifecycle {
* @param parts fulfilled child payments.
* @param pending pending child payments (we are waiting for them to be fulfilled downstream).
*/
case class PaymentSucceeded(request: SendMultiPartPayment, preimage: ByteVector32, parts: Seq[PartialPayment], pending: Set[UUID]) extends Data
case class PaymentSucceeded(request: SendMultiPartPayment, preimage: ByteVector32, parts: Seq[PartialPayment], pending: Set[UUID], remainingAttribution_opt: Option[ByteVector]) extends Data

private def createRouteRequest(replyTo: ActorRef, nodeParams: NodeParams, routeParams: RouteParams, d: PaymentProgress, cfg: SendPaymentConfig): RouteRequest = {
RouteRequest(replyTo.toTyped, nodeParams.nodeId, d.request.recipient, routeParams, d.ignore, allowMultiPart = true, d.pending.values.toSeq, Some(cfg.paymentContext))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import fr.acinq.eclair.payment.send.PaymentError._
import fr.acinq.eclair.router.Router._
import fr.acinq.eclair.wire.protocol._
import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, Features, MilliSatoshi, MilliSatoshiLong, NodeParams}
import scodec.bits.ByteVector

import java.util.UUID

Expand Down Expand Up @@ -335,7 +336,7 @@ object PaymentInitiator {
case _ => PaymentType.Standard
}

def createPaymentSent(recipient: Recipient, preimage: ByteVector32, parts: Seq[PaymentSent.PartialPayment]) = PaymentSent(parentId, paymentHash, preimage, recipient.totalAmount, recipient.nodeId, parts)
def createPaymentSent(recipient: Recipient, preimage: ByteVector32, parts: Seq[PaymentSent.PartialPayment], remainingAttribution_opt: Option[ByteVector]) = PaymentSent(parentId, paymentHash, preimage, recipient.totalAmount, recipient.nodeId, parts, remainingAttribution_opt)
}

}
Loading