@@ -26,6 +26,7 @@ import fr.acinq.eclair._
2626import fr .acinq .eclair .blockchain .fee .{ConfirmationTarget , FeeratePerKw }
2727import fr .acinq .eclair .transactions .CommitmentOutput ._
2828import fr .acinq .eclair .transactions .Scripts ._
29+ import fr .acinq .eclair .transactions .Transactions .InputInfo .SegwitInput
2930import fr .acinq .eclair .wire .protocol .UpdateAddHtlc
3031import scodec .bits .ByteVector
3132
@@ -102,14 +103,18 @@ object Transactions {
102103 val publicKeyScript : ByteVector = Script .write(Script .pay2tr(internalKey, Some (scriptTree)))
103104 }
104105
105- case class InputInfo (outPoint : OutPoint , txOut : TxOut , redeemScriptOrScriptTree : Either [ByteVector , ScriptTreeAndInternalKey ]) {
106- val redeemScriptOrEmptyScript : ByteVector = redeemScriptOrScriptTree.swap.getOrElse(ByteVector .empty) // TODO: use the actual script tree for taproot transactions, once we implement them
106+ sealed trait InputInfo {
107+ val outPoint : OutPoint
108+ val txOut : TxOut
107109 }
108110
109111 object InputInfo {
110- def apply (outPoint : OutPoint , txOut : TxOut , redeemScript : ByteVector ) = new InputInfo (outPoint, txOut, Left (redeemScript))
111- def apply (outPoint : OutPoint , txOut : TxOut , redeemScript : Seq [ScriptElt ]) = new InputInfo (outPoint, txOut, Left (Script .write(redeemScript)))
112- def apply (outPoint : OutPoint , txOut : TxOut , scriptTree : ScriptTreeAndInternalKey ) = new InputInfo (outPoint, txOut, Right (scriptTree))
112+ case class SegwitInput (outPoint : OutPoint , txOut : TxOut , redeemScript : ByteVector ) extends InputInfo
113+ case class TaprootInput (outPoint : OutPoint , txOut : TxOut , scriptTreeAndInternalKey : ScriptTreeAndInternalKey ) extends InputInfo
114+
115+ def apply (outPoint : OutPoint , txOut : TxOut , redeemScript : ByteVector ): SegwitInput = SegwitInput (outPoint, txOut, redeemScript)
116+ def apply (outPoint : OutPoint , txOut : TxOut , redeemScript : Seq [ScriptElt ]): SegwitInput = SegwitInput (outPoint, txOut, Script .write(redeemScript))
117+ def apply (outPoint : OutPoint , txOut : TxOut , scriptTree : ScriptTreeAndInternalKey ): TaprootInput = TaprootInput (outPoint, txOut, scriptTree)
113118 }
114119
115120 /** Owner of a given transaction (local/remote). */
@@ -138,24 +143,29 @@ object Transactions {
138143 sign(key, sighash(txOwner, commitmentFormat))
139144 }
140145
141- def sign (key : PrivateKey , sighashType : Int ): ByteVector64 = {
142- // NB: the tx may have multiple inputs, we will only sign the one provided in txinfo.input. Bear in mind that the
143- // signature will be invalidated if other inputs are added *afterwards* and sighashType was SIGHASH_ALL.
144- val inputIndex = tx.txIn.indexWhere(_.outPoint == input.outPoint)
145- val sigDER = Transaction .signInput(tx, inputIndex, input.redeemScriptOrEmptyScript, sighashType, input.txOut.amount, SIGVERSION_WITNESS_V0 , key)
146- val sig64 = Crypto .der2compact(sigDER)
147- sig64
146+ def sign (key : PrivateKey , sighashType : Int ): ByteVector64 = input match {
147+ case _:InputInfo .TaprootInput => ByteVector64 .Zeroes
148+ case InputInfo .SegwitInput (outPoint, txOut, redeemScript) =>
149+ // NB: the tx may have multiple inputs, we will only sign the one provided in txinfo.input. Bear in mind that the
150+ // signature will be invalidated if other inputs are added *afterwards* and sighashType was SIGHASH_ALL.
151+ val inputIndex = tx.txIn.indexWhere(_.outPoint == outPoint)
152+ val sigDER = Transaction .signInput(tx, inputIndex, redeemScript, sighashType, txOut.amount, SIGVERSION_WITNESS_V0 , key)
153+ val sig64 = Crypto .der2compact(sigDER)
154+ sig64
148155 }
149156
150- def checkSig (sig : ByteVector64 , pubKey : PublicKey , txOwner : TxOwner , commitmentFormat : CommitmentFormat ): Boolean = {
151- val sighash = this .sighash(txOwner, commitmentFormat)
152- val inputIndex = tx.txIn.indexWhere(_.outPoint == input.outPoint)
153- if (inputIndex >= 0 ) {
154- val data = Transaction .hashForSigning(tx, inputIndex, input.redeemScriptOrEmptyScript, sighash, input.txOut.amount, SIGVERSION_WITNESS_V0 )
155- Crypto .verifySignature(data, sig, pubKey)
156- } else {
157- false
158- }
157+ def checkSig (sig : ByteVector64 , pubKey : PublicKey , txOwner : TxOwner , commitmentFormat : CommitmentFormat ): Boolean = input match {
158+
159+ case _:InputInfo .TaprootInput => false
160+ case InputInfo .SegwitInput (outPoint, txOut, redeemScript) =>
161+ val sighash = this .sighash(txOwner, commitmentFormat)
162+ val inputIndex = tx.txIn.indexWhere(_.outPoint == outPoint)
163+ if (inputIndex >= 0 ) {
164+ val data = Transaction .hashForSigning(tx, inputIndex, redeemScript, sighash, txOut.amount, SIGVERSION_WITNESS_V0 )
165+ Crypto .verifySignature(data, sig, pubKey)
166+ } else {
167+ false
168+ }
159169 }
160170 }
161171
@@ -983,64 +993,86 @@ object Transactions {
983993 commitTx.copy(tx = commitTx.tx.updateWitness(0 , witness))
984994 }
985995
986- def addSigs (mainPenaltyTx : MainPenaltyTx , revocationSig : ByteVector64 ): MainPenaltyTx = {
987- val witness = Scripts .witnessToLocalDelayedWithRevocationSig(revocationSig, mainPenaltyTx.input.redeemScriptOrEmptyScript)
988- mainPenaltyTx.copy(tx = mainPenaltyTx.tx.updateWitness(0 , witness))
996+ def addSigs (mainPenaltyTx : MainPenaltyTx , revocationSig : ByteVector64 ): MainPenaltyTx = mainPenaltyTx.input match {
997+ case InputInfo .SegwitInput (_, _, redeemScript) =>
998+ val witness = Scripts .witnessToLocalDelayedWithRevocationSig(revocationSig, redeemScript)
999+ mainPenaltyTx.copy(tx = mainPenaltyTx.tx.updateWitness(0 , witness))
1000+ case _ => mainPenaltyTx
9891001 }
9901002
991- def addSigs (htlcPenaltyTx : HtlcPenaltyTx , revocationSig : ByteVector64 , revocationPubkey : PublicKey ): HtlcPenaltyTx = {
992- val witness = Scripts .witnessHtlcWithRevocationSig(revocationSig, revocationPubkey, htlcPenaltyTx.input.redeemScriptOrEmptyScript)
993- htlcPenaltyTx.copy(tx = htlcPenaltyTx.tx.updateWitness(0 , witness))
1003+ def addSigs (htlcPenaltyTx : HtlcPenaltyTx , revocationSig : ByteVector64 , revocationPubkey : PublicKey ): HtlcPenaltyTx = htlcPenaltyTx.input match {
1004+ case InputInfo .SegwitInput (_, _, redeemScript) =>
1005+ val witness = Scripts .witnessHtlcWithRevocationSig(revocationSig, revocationPubkey, redeemScript)
1006+ htlcPenaltyTx.copy(tx = htlcPenaltyTx.tx.updateWitness(0 , witness))
1007+ case _ => htlcPenaltyTx
9941008 }
9951009
996- def addSigs (htlcSuccessTx : HtlcSuccessTx , localSig : ByteVector64 , remoteSig : ByteVector64 , paymentPreimage : ByteVector32 , commitmentFormat : CommitmentFormat ): HtlcSuccessTx = {
997- val witness = witnessHtlcSuccess(localSig, remoteSig, paymentPreimage, htlcSuccessTx.input.redeemScriptOrEmptyScript, commitmentFormat)
998- htlcSuccessTx.copy(tx = htlcSuccessTx.tx.updateWitness(0 , witness))
1010+ def addSigs (htlcSuccessTx : HtlcSuccessTx , localSig : ByteVector64 , remoteSig : ByteVector64 , paymentPreimage : ByteVector32 , commitmentFormat : CommitmentFormat ): HtlcSuccessTx = htlcSuccessTx.input match {
1011+ case InputInfo .SegwitInput (_, _, redeemScript) =>
1012+ val witness = witnessHtlcSuccess(localSig, remoteSig, paymentPreimage, redeemScript, commitmentFormat)
1013+ htlcSuccessTx.copy(tx = htlcSuccessTx.tx.updateWitness(0 , witness))
1014+ case _ => htlcSuccessTx
9991015 }
10001016
1001- def addSigs (htlcTimeoutTx : HtlcTimeoutTx , localSig : ByteVector64 , remoteSig : ByteVector64 , commitmentFormat : CommitmentFormat ): HtlcTimeoutTx = {
1002- val witness = witnessHtlcTimeout(localSig, remoteSig, htlcTimeoutTx.input.redeemScriptOrEmptyScript, commitmentFormat)
1003- htlcTimeoutTx.copy(tx = htlcTimeoutTx.tx.updateWitness(0 , witness))
1017+ def addSigs (htlcTimeoutTx : HtlcTimeoutTx , localSig : ByteVector64 , remoteSig : ByteVector64 , commitmentFormat : CommitmentFormat ): HtlcTimeoutTx = htlcTimeoutTx.input match {
1018+ case InputInfo .SegwitInput (_, _, redeemScript) =>
1019+ val witness = witnessHtlcTimeout(localSig, remoteSig, redeemScript, commitmentFormat)
1020+ htlcTimeoutTx.copy(tx = htlcTimeoutTx.tx.updateWitness(0 , witness))
1021+ case _ => htlcTimeoutTx
10041022 }
10051023
1006- def addSigs (claimHtlcSuccessTx : ClaimHtlcSuccessTx , localSig : ByteVector64 , paymentPreimage : ByteVector32 ): ClaimHtlcSuccessTx = {
1007- val witness = witnessClaimHtlcSuccessFromCommitTx(localSig, paymentPreimage, claimHtlcSuccessTx.input.redeemScriptOrEmptyScript)
1008- claimHtlcSuccessTx.copy(tx = claimHtlcSuccessTx.tx.updateWitness(0 , witness))
1024+ def addSigs (claimHtlcSuccessTx : ClaimHtlcSuccessTx , localSig : ByteVector64 , paymentPreimage : ByteVector32 ): ClaimHtlcSuccessTx = claimHtlcSuccessTx.input match {
1025+ case InputInfo .SegwitInput (_, _, redeemScript) =>
1026+ val witness = witnessClaimHtlcSuccessFromCommitTx(localSig, paymentPreimage, redeemScript)
1027+ claimHtlcSuccessTx.copy(tx = claimHtlcSuccessTx.tx.updateWitness(0 , witness))
1028+ case _ => claimHtlcSuccessTx
10091029 }
10101030
1011- def addSigs (claimHtlcTimeoutTx : ClaimHtlcTimeoutTx , localSig : ByteVector64 ): ClaimHtlcTimeoutTx = {
1012- val witness = witnessClaimHtlcTimeoutFromCommitTx(localSig, claimHtlcTimeoutTx.input.redeemScriptOrEmptyScript)
1013- claimHtlcTimeoutTx.copy(tx = claimHtlcTimeoutTx.tx.updateWitness(0 , witness))
1031+ def addSigs (claimHtlcTimeoutTx : ClaimHtlcTimeoutTx , localSig : ByteVector64 ): ClaimHtlcTimeoutTx = claimHtlcTimeoutTx.input match {
1032+ case InputInfo .SegwitInput (_, _, redeemScript) =>
1033+ val witness = witnessClaimHtlcTimeoutFromCommitTx(localSig, redeemScript)
1034+ claimHtlcTimeoutTx.copy(tx = claimHtlcTimeoutTx.tx.updateWitness(0 , witness))
1035+ case _ => claimHtlcTimeoutTx
10141036 }
10151037
10161038 def addSigs (claimP2WPKHOutputTx : ClaimP2WPKHOutputTx , localPaymentPubkey : PublicKey , localSig : ByteVector64 ): ClaimP2WPKHOutputTx = {
10171039 val witness = ScriptWitness (Seq (der(localSig), localPaymentPubkey.value))
10181040 claimP2WPKHOutputTx.copy(tx = claimP2WPKHOutputTx.tx.updateWitness(0 , witness))
10191041 }
10201042
1021- def addSigs (claimRemoteDelayedOutputTx : ClaimRemoteDelayedOutputTx , localSig : ByteVector64 ): ClaimRemoteDelayedOutputTx = {
1022- val witness = witnessClaimToRemoteDelayedFromCommitTx(localSig, claimRemoteDelayedOutputTx.input.redeemScriptOrEmptyScript)
1023- claimRemoteDelayedOutputTx.copy(tx = claimRemoteDelayedOutputTx.tx.updateWitness(0 , witness))
1043+ def addSigs (claimRemoteDelayedOutputTx : ClaimRemoteDelayedOutputTx , localSig : ByteVector64 ): ClaimRemoteDelayedOutputTx = claimRemoteDelayedOutputTx.input match {
1044+ case InputInfo .SegwitInput (_, _, redeemScript) =>
1045+ val witness = witnessClaimToRemoteDelayedFromCommitTx(localSig, redeemScript)
1046+ claimRemoteDelayedOutputTx.copy(tx = claimRemoteDelayedOutputTx.tx.updateWitness(0 , witness))
1047+ case _ => claimRemoteDelayedOutputTx
10241048 }
10251049
1026- def addSigs (claimDelayedOutputTx : ClaimLocalDelayedOutputTx , localSig : ByteVector64 ): ClaimLocalDelayedOutputTx = {
1027- val witness = witnessToLocalDelayedAfterDelay(localSig, claimDelayedOutputTx.input.redeemScriptOrEmptyScript)
1028- claimDelayedOutputTx.copy(tx = claimDelayedOutputTx.tx.updateWitness(0 , witness))
1050+ def addSigs (claimDelayedOutputTx : ClaimLocalDelayedOutputTx , localSig : ByteVector64 ): ClaimLocalDelayedOutputTx = claimDelayedOutputTx.input match {
1051+ case InputInfo .SegwitInput (_, _, redeemScript) =>
1052+ val witness = witnessToLocalDelayedAfterDelay(localSig, redeemScript)
1053+ claimDelayedOutputTx.copy(tx = claimDelayedOutputTx.tx.updateWitness(0 , witness))
1054+ case _ => claimDelayedOutputTx
10291055 }
10301056
1031- def addSigs (htlcDelayedTx : HtlcDelayedTx , localSig : ByteVector64 ): HtlcDelayedTx = {
1032- val witness = witnessToLocalDelayedAfterDelay(localSig, htlcDelayedTx.input.redeemScriptOrEmptyScript)
1033- htlcDelayedTx.copy(tx = htlcDelayedTx.tx.updateWitness(0 , witness))
1057+ def addSigs (htlcDelayedTx : HtlcDelayedTx , localSig : ByteVector64 ): HtlcDelayedTx = htlcDelayedTx.input match {
1058+ case InputInfo .SegwitInput (_, _, redeemScript) =>
1059+ val witness = witnessToLocalDelayedAfterDelay(localSig, redeemScript)
1060+ htlcDelayedTx.copy(tx = htlcDelayedTx.tx.updateWitness(0 , witness))
1061+ case _ => htlcDelayedTx
10341062 }
10351063
1036- def addSigs (claimAnchorOutputTx : ClaimLocalAnchorOutputTx , localSig : ByteVector64 ): ClaimLocalAnchorOutputTx = {
1037- val witness = witnessAnchor(localSig, claimAnchorOutputTx.input.redeemScriptOrEmptyScript)
1038- claimAnchorOutputTx.copy(tx = claimAnchorOutputTx.tx.updateWitness(0 , witness))
1064+ def addSigs (claimAnchorOutputTx : ClaimLocalAnchorOutputTx , localSig : ByteVector64 ): ClaimLocalAnchorOutputTx = claimAnchorOutputTx.input match {
1065+ case InputInfo .SegwitInput (_, _, redeemScript) =>
1066+ val witness = witnessAnchor(localSig, redeemScript)
1067+ claimAnchorOutputTx.copy(tx = claimAnchorOutputTx.tx.updateWitness(0 , witness))
1068+ case _ => claimAnchorOutputTx
10391069 }
10401070
1041- def addSigs (claimHtlcDelayedPenalty : ClaimHtlcDelayedOutputPenaltyTx , revocationSig : ByteVector64 ): ClaimHtlcDelayedOutputPenaltyTx = {
1042- val witness = Scripts .witnessToLocalDelayedWithRevocationSig(revocationSig, claimHtlcDelayedPenalty.input.redeemScriptOrEmptyScript)
1043- claimHtlcDelayedPenalty.copy(tx = claimHtlcDelayedPenalty.tx.updateWitness(0 , witness))
1071+ def addSigs (claimHtlcDelayedPenalty : ClaimHtlcDelayedOutputPenaltyTx , revocationSig : ByteVector64 ): ClaimHtlcDelayedOutputPenaltyTx = claimHtlcDelayedPenalty.input match {
1072+ case InputInfo .SegwitInput (_, _, redeemScript) =>
1073+ val witness = Scripts .witnessToLocalDelayedWithRevocationSig(revocationSig, redeemScript)
1074+ claimHtlcDelayedPenalty.copy(tx = claimHtlcDelayedPenalty.tx.updateWitness(0 , witness))
1075+ case _ => claimHtlcDelayedPenalty
10441076 }
10451077
10461078 def addSigs (closingTx : ClosingTx , localFundingPubkey : PublicKey , remoteFundingPubkey : PublicKey , localSig : ByteVector64 , remoteSig : ByteVector64 ): ClosingTx = {
0 commit comments