diff --git a/.gitignore b/.gitignore index 75acbad..0a72b5a 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ build/ # Idea .idea +.kotlin *.iml # Gradle diff --git a/src/commonMain/kotlin/fr/acinq/bitcoin/ScriptTree.kt b/src/commonMain/kotlin/fr/acinq/bitcoin/ScriptTree.kt index 774b98d..64dab61 100644 --- a/src/commonMain/kotlin/fr/acinq/bitcoin/ScriptTree.kt +++ b/src/commonMain/kotlin/fr/acinq/bitcoin/ScriptTree.kt @@ -15,7 +15,6 @@ */ package fr.acinq.bitcoin -import fr.acinq.bitcoin.io.ByteArrayInput import fr.acinq.bitcoin.io.ByteArrayOutput import fr.acinq.bitcoin.io.Input import fr.acinq.bitcoin.io.Output @@ -23,10 +22,10 @@ import kotlin.jvm.JvmStatic /** Simple binary tree structure containing taproot spending scripts. */ public sealed class ScriptTree { - public abstract fun write(output: Output, level: Int): Unit + public abstract fun write(output: Output, level: Int) /** - * @return the tree serialised with the format defined in BIP 371 + * @return the tree serialized with the format defined in BIP 371 */ public fun write(): ByteArray { val output = ByteArrayOutput() @@ -46,7 +45,7 @@ public sealed class ScriptTree { public constructor(script: List, leafVersion: Int) : this(Script.write(script).byteVector(), leafVersion) public constructor(script: String, leafVersion: Int) : this(ByteVector.fromHex(script), leafVersion) - override fun write(output: Output, level: Int): Unit { + override fun write(output: Output, level: Int) { output.write(level) output.write(leafVersion) BtcSerializer.writeScript(script, output) @@ -54,7 +53,7 @@ public sealed class ScriptTree { } public data class Branch(val left: ScriptTree, val right: ScriptTree) : ScriptTree() { - override fun write(output: Output, level: Int): Unit { + override fun write(output: Output, level: Int) { left.write(output, level + 1) right.write(output, level + 1) } @@ -68,7 +67,6 @@ public sealed class ScriptTree { BtcSerializer.writeScript(this.script, buffer) Crypto.taggedHash(buffer.toByteArray(), "TapLeaf") } - is Branch -> { val h1 = this.left.hash() val h2 = this.right.hash() @@ -83,6 +81,12 @@ public sealed class ScriptTree { is Branch -> this.left.findScript(script) ?: this.right.findScript(script) } + /** Return the first leaf with a matching leaf hash, if any. */ + public fun findScript(leafHash: ByteVector32): Leaf? = when (this) { + is Leaf -> if (this.hash() == leafHash) this else null + is Branch -> this.left.findScript(leafHash) ?: this.right.findScript(leafHash) + } + /** * Compute a merkle proof for the given script leaf. * This merkle proof is encoded for creating control blocks in taproot script path witnesses. @@ -128,7 +132,7 @@ public sealed class ScriptTree { public fun read(input: Input): ScriptTree { val leaves = readLeaves(input) merge(leaves) - require(leaves.size == 1) { "invalid serialised script tree" } + require(leaves.size == 1) { "invalid serialized script tree" } return leaves[0].second } } diff --git a/src/commonTest/kotlin/fr/acinq/bitcoin/TaprootTestsCommon.kt b/src/commonTest/kotlin/fr/acinq/bitcoin/TaprootTestsCommon.kt index 970f2cf..059ccad 100644 --- a/src/commonTest/kotlin/fr/acinq/bitcoin/TaprootTestsCommon.kt +++ b/src/commonTest/kotlin/fr/acinq/bitcoin/TaprootTestsCommon.kt @@ -415,17 +415,28 @@ class TaprootTestsCommon { @Test fun `serialize script tree -- reference test`() { - val tree = - ScriptTree.read(ByteArrayInput(Hex.decode("02c02220736e572900fe1252589a2143c8f3c79f71a0412d2353af755e9701c782694a02ac02c02220631c5f3b5832b8fbdebfb19704ceeb323c21f40f7a24f43d68ef0cc26b125969ac01c0222044faa49a0338de488c8dfffecdfb6f329f380bd566ef20c8df6d813eab1c4273ac"))) + val encoded = "02c02220736e572900fe1252589a2143c8f3c79f71a0412d2353af755e9701c782694a02ac02c02220631c5f3b5832b8fbdebfb19704ceeb323c21f40f7a24f43d68ef0cc26b125969ac01c0222044faa49a0338de488c8dfffecdfb6f329f380bd566ef20c8df6d813eab1c4273ac" + val tree = ScriptTree.read(ByteArrayInput(Hex.decode(encoded))) + val leaves = listOf( + ScriptTree.Leaf("20736e572900fe1252589a2143c8f3c79f71a0412d2353af755e9701c782694a02ac", 0xc0), + ScriptTree.Leaf("20631c5f3b5832b8fbdebfb19704ceeb323c21f40f7a24f43d68ef0cc26b125969ac", 0xc0), + ScriptTree.Leaf("2044faa49a0338de488c8dfffecdfb6f329f380bd566ef20c8df6d813eab1c4273ac", 0xc0), + ) assertEquals( ScriptTree.Branch( ScriptTree.Branch( - ScriptTree.Leaf("20736e572900fe1252589a2143c8f3c79f71a0412d2353af755e9701c782694a02ac", 0xc0), - ScriptTree.Leaf("20631c5f3b5832b8fbdebfb19704ceeb323c21f40f7a24f43d68ef0cc26b125969ac", 0xc0), + leaves[0], + leaves[1], ), - ScriptTree.Leaf("2044faa49a0338de488c8dfffecdfb6f329f380bd566ef20c8df6d813eab1c4273ac", 0xc0) + leaves[2] ), tree ) + // We're able to find leaves in that script tree. + leaves.forEach { l -> + assertEquals(l, tree.findScript(l.script)) + assertEquals(l, tree.findScript(l.hash())) + } + assertNull(tree.findScript(ByteVector.fromHex("deadbeef"))) } @Test