diff --git a/Cargo.lock b/Cargo.lock index f4830856c..922698236 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -19,9 +19,9 @@ checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" [[package]] name = "anstyle" -version = "1.0.8" +version = "1.0.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1bec1de6f59aedf83baf9ff929c98f2ad654b97c9510f4e70cf6f661d49fd5b1" +checksum = "55cc3b69f167a1ef2e161439aa98aed94e6028e5f9a59be9a6ffb47aef1651f9" [[package]] name = "autocfg" @@ -120,9 +120,9 @@ checksum = "847495c209977a90e8aad588b959d0ca9f5dc228096d29a6bd3defd53f35eaec" [[package]] name = "block-buffer" -version = "0.11.0-rc.2" +version = "0.11.0-rc.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "939c0e62efa052fb0b2db2c0f7c479ad32e364c192c3aab605a7641de265a1a7" +checksum = "3fd016a0ddc7cb13661bf5576073ce07330a693f8608a1320b4e20561cc12cdc" dependencies = [ "hybrid-array", ] @@ -195,7 +195,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "57663b653d948a338bfb3eeba9bb2fd5fcfaecb9e199e87e1eda4d9e8b240fd9" dependencies = [ "ciborium-io", - "half", + "half 2.4.1", ] [[package]] @@ -225,9 +225,9 @@ checksum = "1462739cb27611015575c0c11df5df7601141071f07518d56fcc1be504cbec97" [[package]] name = "const-oid" -version = "0.10.0-rc.2" +version = "0.10.0-rc.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a0d96d207edbe5135e55038e79ab9ad6d75ba83b14cdf62326ce5b12bc46ab5" +checksum = "68ff6be19477a1bd5441f382916a89bc2a0b2c35db6d41e0f6e8538bf6d6463f" [[package]] name = "cpufeatures" @@ -415,9 +415,9 @@ dependencies = [ [[package]] name = "fastrand" -version = "2.1.1" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8c02a5121d4ea3eb16a80748c74f5549a5665e4c21333c6098f283870fbdea6" +checksum = "486f806e73c5707928240ddc295403b1b93c96a02038563881c4a2fd84b81ac4" [[package]] name = "ff" @@ -464,6 +464,12 @@ dependencies = [ "subtle", ] +[[package]] +name = "half" +version = "1.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b43ede17f21864e81be2fa654110bf1e793774238d86ef8555c37e6519c0403" + [[package]] name = "half" version = "2.4.1" @@ -512,9 +518,9 @@ dependencies = [ [[package]] name = "hybrid-array" -version = "0.2.0-rc.11" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a5a41e5b0754cae5aaf7915f1df1147ba8d316fc6e019cfcc00fbaba96d5e030" +checksum = "45a9a965bb102c1c891fb017c09a05c965186b1265a207640f323ddd009f9deb" dependencies = [ "typenum", "zeroize", @@ -569,6 +575,7 @@ dependencies = [ "num-bigint", "num-traits", "once_cell", + "powdr-riscv-runtime", "proptest", "rand_core", "serdect", @@ -594,15 +601,15 @@ checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" [[package]] name = "libc" -version = "0.2.159" +version = "0.2.162" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "561d97a539a36e26a9a5fad1ea11a3039a67714694aaa379433e580854bc3dc5" +checksum = "18d287de67fe55fd7e1581fe933d965a5a9477b38e949cfa9f8574ef01506398" [[package]] name = "libm" -version = "0.2.8" +version = "0.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" +checksum = "8355be11b20d696c8f18f6cc018c4e372165b1fa8126cef092399c9951984ffa" [[package]] name = "linux-raw-sys" @@ -796,6 +803,22 @@ version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cc9c68a3f6da06753e9335d63e27f6b9754dd1920d941135b7ea8224f141adb2" +[[package]] +name = "powdr-riscv-runtime" +version = "0.1.0-alpha.2" +source = "git+https://github.com/powdr-labs/powdr.git?tag=v0.1.1#699b74ac5b032113270a2f419ef192bdb7fc0857" +dependencies = [ + "getrandom", + "powdr-riscv-syscalls", + "serde", + "serde_cbor", +] + +[[package]] +name = "powdr-riscv-syscalls" +version = "0.1.1" +source = "git+https://github.com/powdr-labs/powdr.git?tag=v0.1.1#699b74ac5b032113270a2f419ef192bdb7fc0857" + [[package]] name = "ppv-lite86" version = "0.2.20" @@ -819,9 +842,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.88" +version = "1.0.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7c3a7fc5db1e57d5a779a352c8cdb57b29aa4c40cc69c3a68a7fedc815fbf2f9" +checksum = "f139b0662de085916d1fb67d2b4169d1addddda1919e696f3252b740b629986e" dependencies = [ "unicode-ident", ] @@ -928,9 +951,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.11.0" +version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38200e5ee88914975b69f657f0801b6f6dccafd44fd9326302a4aaeecfacb1d8" +checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" dependencies = [ "aho-corasick", "memchr", @@ -967,9 +990,9 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.37" +version = "0.38.39" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8acb788b847c24f28525660c4d7758620a7210875711f79e7f663cc152726811" +checksum = "375116bee2be9ed569afe2154ea6a99dfdffd257f533f187498c2a8f5feaf4ee" dependencies = [ "bitflags", "errno", @@ -1022,18 +1045,28 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.210" +version = "1.0.214" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8e3592472072e6e22e0a54d5904d9febf8508f65fb8552499a1abc7d1078c3a" +checksum = "f55c3193aca71c12ad7890f1785d2b73e1b9f63a0bbc353c08ef26fe03fc56b5" dependencies = [ "serde_derive", ] +[[package]] +name = "serde_cbor" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2bef2ebfde456fb76bbcf9f59315333decc4fda0b2b44b420243c11e0f5ec1f5" +dependencies = [ + "half 1.8.3", + "serde", +] + [[package]] name = "serde_derive" -version = "1.0.210" +version = "1.0.214" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "243902eda00fad750862fc144cea25caca5e20d615af0a81bee94ca738f1df1f" +checksum = "de523f781f095e28fa605cdce0f8307e451cc0fd14e2eb4cd2e98a355b147766" dependencies = [ "proc-macro2", "quote", @@ -1042,9 +1075,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.128" +version = "1.0.132" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ff5456707a1de34e7e37f2a6fd3d3f808c318259cbd01ab6377795054b483d8" +checksum = "d726bfaff4b320266d395898905d0eba0345aae23b54aee3a737e260fd46db03" dependencies = [ "itoa", "memchr", @@ -1135,9 +1168,9 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "syn" -version = "2.0.79" +version = "2.0.87" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89132cd0bf050864e1d38dc3bbc07a0eb8e7530af26344d3d2bbbef83499f590" +checksum = "25aa4ce346d03a6dcd68dd8b4010bcb74e54e62c90c573f394c46eae99aba32d" dependencies = [ "proc-macro2", "quote", @@ -1152,9 +1185,9 @@ checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" [[package]] name = "tempfile" -version = "3.13.0" +version = "3.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0f2c9fc62d0beef6951ccffd757e241266a2c833136efbe35af6cd2567dca5b" +checksum = "28cce251fcbc87fac86a866eeb0d6c2d536fc16d06f184bb61aeae11aa4cee0c" dependencies = [ "cfg-if", "fastrand", diff --git a/k256/Cargo.toml b/k256/Cargo.toml index d0ebc81f8..721f1d163 100644 --- a/k256/Cargo.toml +++ b/k256/Cargo.toml @@ -20,20 +20,33 @@ rust-version = "1.81" [dependencies] cfg-if = "1.0" -elliptic-curve = { version = "0.14.0-rc.0", default-features = false, features = ["sec1"] } +elliptic-curve = { version = "0.14.0-rc.0", default-features = false, features = [ + "sec1", +] } # optional dependencies once_cell = { version = "1.20", optional = true, default-features = false } -ecdsa-core = { version = "=0.17.0-pre.9", package = "ecdsa", optional = true, default-features = false, features = ["der"] } +ecdsa-core = { version = "=0.17.0-pre.9", package = "ecdsa", optional = true, default-features = false, features = [ + "der", +] } hex-literal = { version = "0.4", optional = true } serdect = { version = "0.3.0-rc.0", optional = true, default-features = false } sha2 = { version = "=0.11.0-pre.4", optional = true, default-features = false } signature = { version = "=2.3.0-pre.4", optional = true } +[target.'cfg(all(target_os = "zkvm", target_arch = "riscv32"))'.dependencies] +powdr-riscv-runtime = { git = "https://github.com/powdr-labs/powdr.git", tag = "v0.1.1", features = [ + "std", + "getrandom", + "allow_fake_rand", +] } + [dev-dependencies] blobby = "0.3" criterion = "0.5" -ecdsa-core = { version = "=0.17.0-pre.9", package = "ecdsa", default-features = false, features = ["dev"] } +ecdsa-core = { version = "=0.17.0-pre.9", package = "ecdsa", default-features = false, features = [ + "dev", +] } hex = "0.4.3" hex-literal = "0.4" num-bigint = "0.4" @@ -43,7 +56,14 @@ rand_core = { version = "0.6", features = ["getrandom"] } sha3 = { version = "=0.11.0-pre.4", default-features = false } [features] -default = ["arithmetic", "ecdsa", "pkcs8", "precomputed-tables", "schnorr", "std"] +default = [ + "arithmetic", + "ecdsa", + "pkcs8", + "precomputed-tables", + "schnorr", + "std", +] alloc = ["ecdsa-core?/alloc", "elliptic-curve/alloc"] std = ["alloc", "ecdsa-core?/std", "elliptic-curve/std", "once_cell?/std"] diff --git a/k256/src/arithmetic/field.rs b/k256/src/arithmetic/field.rs index 9cbe924f8..43390b0d5 100644 --- a/k256/src/arithmetic/field.rs +++ b/k256/src/arithmetic/field.rs @@ -5,7 +5,10 @@ use cfg_if::cfg_if; cfg_if! { - if #[cfg(target_pointer_width = "32")] { + + if #[cfg(all(target_os = "zkvm", target_arch = "riscv32"))] { + mod field_8x32; + } else if #[cfg(target_pointer_width = "32")] { mod field_10x26; } else if #[cfg(target_pointer_width = "64")] { mod field_5x52; @@ -20,7 +23,9 @@ cfg_if! { use field_impl::FieldElementImpl; } else { cfg_if! { - if #[cfg(target_pointer_width = "32")] { + if #[cfg(all(target_os = "zkvm", target_arch = "riscv32"))] { + use field_8x32::FieldElement8x32 as FieldElementImpl; + } else if #[cfg(target_pointer_width = "32")] { use field_10x26::FieldElement10x26 as FieldElementImpl; } else if #[cfg(target_pointer_width = "64")] { use field_5x52::FieldElement5x52 as FieldElementImpl; @@ -99,11 +104,31 @@ impl FieldElement { FieldElementImpl::from_bytes(bytes).map(Self) } + /// Attempts to parse the given byte array as an SEC1-encoded field element (in little-endian!). + /// Does not check the result for being in the correct range. + #[cfg(all(target_os = "zkvm", target_arch = "riscv32"))] + pub(crate) fn from_bytes_unchecked_le(bytes: &[u8; 32]) -> Self { + Self(FieldElementImpl::from_bytes_unchecked_le(bytes)) + } + /// Convert a `u64` to a field element. pub const fn from_u64(w: u64) -> Self { Self(FieldElementImpl::from_u64(w)) } + /// Returns the SEC1 encoding (in little-endian!) of this field element. + #[cfg(all(target_os = "zkvm", target_arch = "riscv32"))] + pub fn to_bytes_le(self) -> FieldBytes { + self.0.normalize().to_bytes_le() + } + + /// Convert a `i64` to a field element. + /// Returned value may be only weakly normalized. + #[cfg(all(target_os = "zkvm", target_arch = "riscv32"))] + pub const fn from_i64(w: i64) -> Self { + Self(FieldElementImpl::from_i64(w)) + } + /// Returns the SEC1 encoding of this field element. pub fn to_bytes(self) -> FieldBytes { self.0.normalize().to_bytes() @@ -140,6 +165,14 @@ impl FieldElement { /// Returns 2*self. /// Doubles the magnitude. + #[cfg(all(target_os = "zkvm", target_arch = "riscv32"))] + pub fn double(&self) -> Self { + self.mul_single(2) + } + + /// Returns 2*self. + /// Doubles the magnitude. + #[cfg(not(all(target_os = "zkvm", target_arch = "riscv32")))] pub fn double(&self) -> Self { Self(self.0.add(&(self.0))) } @@ -361,6 +394,13 @@ impl From for FieldElement { } } +#[cfg(all(target_os = "zkvm", target_arch = "riscv32"))] +impl From for FieldElement { + fn from(k: i64) -> Self { + Self(FieldElementImpl::from_i64(k)) + } +} + impl PartialEq for FieldElement { fn eq(&self, other: &Self) -> bool { self.0.ct_eq(&(other.0)).into() diff --git a/k256/src/arithmetic/field/field_10x26.rs b/k256/src/arithmetic/field/field_10x26.rs index 6ea525a0f..a7d715065 100644 --- a/k256/src/arithmetic/field/field_10x26.rs +++ b/k256/src/arithmetic/field/field_10x26.rs @@ -68,6 +68,86 @@ impl FieldElement10x26 { Self([w0, w1, w2, w3, w4, w5, w6, w7, w8, w9]) } + /// Attempts to parse the given byte array as an SEC1-encoded field element (but little endian!). + /// Does not check the result for being in the correct range. + #[cfg(all(target_os = "zkvm", target_arch = "riscv32"))] + pub(crate) fn from_bytes_unchecked_le(bytes: &[u8; 32]) -> Self { + // TODO: original conversion code is cheaper in some cases, not sure why. + let w0 = u32::from_le_bytes( + bytes[0..4] + .try_into() + .expect("Conversion should have worked"), + ) & 0x03ffffff; + // let w0 = (bytes[0] as u32) + // | ((bytes[1] as u32) << 8) + // | ((bytes[2] as u32) << 16) + // | (((bytes[3] & 0x3) as u32) << 24); + // let w1 = (u32::from_le_bytes(bytes[3..7].try_into().unwrap()) >> 2) & 0x03ffffff; + let w1 = (((bytes[3] >> 2) as u32) & 0x3f) + | ((bytes[4] as u32) << 6) + | ((bytes[5] as u32) << 14) + | (((bytes[6] & 0xf) as u32) << 22); + // let w2 = (u32::from_le_bytes(bytes[6..10].try_into().unwrap()) >> 4) & 0x03ffffff; + let w2 = (((bytes[6] >> 4) as u32) & 0xf) + | ((bytes[7] as u32) << 4) + | ((bytes[8] as u32) << 12) + | (((bytes[9] & 0x3f) as u32) << 20); + let w3 = (u32::from_le_bytes( + bytes[9..13] + .try_into() + .expect("Conversion should have worked"), + ) >> 6) + & 0x03ffffff; + // let w3 = (((bytes[9] >> 6) as u32) & 0x3) + // | ((bytes[10] as u32) << 2) + // | ((bytes[11] as u32) << 10) + // | ((bytes[12] as u32) << 18); + let w4 = u32::from_le_bytes( + bytes[13..17] + .try_into() + .expect("Conversion should have worked"), + ) & 0x03ffffff; + // let w4 = (bytes[13] as u32) + // | ((bytes[14] as u32) << 8) + // | ((bytes[15] as u32) << 16) + // | (((bytes[16] & 0x3) as u32) << 24); + // let w5 = (u32::from_le_bytes(bytes[16..20].try_into().unwrap()) >> 2) & 0x03ffffff; + let w5 = (((bytes[16] >> 2) as u32) & 0x3f) + | ((bytes[17] as u32) << 6) + | ((bytes[18] as u32) << 14) + | (((bytes[19] & 0xf) as u32) << 22); + // let w6 = (u32::from_le_bytes(bytes[19..23].try_into().unwrap()) >> 4) & 0x03ffffff; + let w6 = (((bytes[19] >> 4) as u32) & 0xf) + | ((bytes[20] as u32) << 4) + | ((bytes[21] as u32) << 12) + | (((bytes[22] & 0x3f) as u32) << 20); + let w7 = (u32::from_le_bytes( + bytes[22..26] + .try_into() + .expect("Conversion should have worked"), + ) >> 6) + & 0x03ffffff; + // let w7 = (((bytes[22] >> 6) as u32) & 0x3) + // | ((bytes[23] as u32) << 2) + // | ((bytes[24] as u32) << 10) + // | ((bytes[25] as u32) << 18); + let w8 = u32::from_le_bytes( + bytes[26..30] + .try_into() + .expect("Conversion should have worked"), + ) & 0x03ffffff; + // let w8 = (bytes[26] as u32) + // | ((bytes[27] as u32) << 8) + // | ((bytes[28] as u32) << 16) + // | (((bytes[29] & 0x3) as u32) << 24); + // let w9 = (u32::from_le_bytes(bytes[9..13].try_into().unwrap()) >> 6) & 0x03ffffff; + let w9 = (((bytes[29] >> 2) as u32) & 0x3f) + | ((bytes[30] as u32) << 6) + | ((bytes[31] as u32) << 14); + + Self([w0, w1, w2, w3, w4, w5, w6, w7, w8, w9]) + } + /// Attempts to parse the given byte array as an SEC1-encoded field element. /// /// Returns None if the byte array does not contain a big-endian integer in the range @@ -125,6 +205,45 @@ impl FieldElement10x26 { r } + /// Returns the SEC1 encoding of this field element (in little-endian!). + #[cfg(all(target_os = "zkvm", target_arch = "riscv32"))] + pub fn to_bytes_le(self) -> FieldBytes { + let mut r = FieldBytes::default(); + r[0] = self.0[0] as u8; + r[1] = (self.0[0] >> 8) as u8; + r[2] = (self.0[0] >> 16) as u8; + r[3] = ((self.0[1] as u8 & 0x3fu8) << 2) | ((self.0[0] >> 24) as u8 & 0x3u8); + r[4] = (self.0[1] >> 6) as u8; + r[5] = (self.0[1] >> 14) as u8; + r[6] = ((self.0[2] as u8 & 0xfu8) << 4) | ((self.0[1] >> 22) as u8 & 0xfu8); + r[7] = (self.0[2] >> 4) as u8; + r[8] = (self.0[2] >> 12) as u8; + r[9] = ((self.0[3] as u8 & 0x3u8) << 6) | ((self.0[2] >> 20) as u8 & 0x3fu8); + r[10] = (self.0[3] >> 2) as u8; + r[11] = (self.0[3] >> 10) as u8; + r[12] = (self.0[3] >> 18) as u8; + r[13] = self.0[4] as u8; + r[14] = (self.0[4] >> 8) as u8; + r[15] = (self.0[4] >> 16) as u8; + r[16] = ((self.0[5] as u8 & 0x3fu8) << 2) | ((self.0[4] >> 24) as u8 & 0x3u8); + r[17] = (self.0[5] >> 6) as u8; + r[18] = (self.0[5] >> 14) as u8; + r[19] = ((self.0[6] as u8 & 0xfu8) << 4) | ((self.0[5] >> 22) as u8 & 0xfu8); + r[20] = (self.0[6] >> 4) as u8; + r[21] = (self.0[6] >> 12) as u8; + r[22] = ((self.0[7] as u8 & 0x3u8) << 6) | ((self.0[6] >> 20) as u8 & 0x3fu8); + r[23] = (self.0[7] >> 2) as u8; + r[24] = (self.0[7] >> 10) as u8; + r[25] = (self.0[7] >> 18) as u8; + r[26] = self.0[8] as u8; + r[27] = (self.0[8] >> 8) as u8; + r[28] = (self.0[8] >> 16) as u8; + r[29] = ((self.0[9] as u8 & 0x3Fu8) << 2) | ((self.0[8] >> 24) as u8 & 0x3); + r[30] = (self.0[9] >> 6) as u8; + r[31] = (self.0[9] >> 14) as u8; + r + } + /// Adds `x * (2^256 - modulus)`. fn add_modulus_correction(&self, x: u32) -> Self { // add (2^256 - modulus) * x to the first limb diff --git a/k256/src/arithmetic/field/field_8x32.rs b/k256/src/arithmetic/field/field_8x32.rs new file mode 100644 index 000000000..3cd040c4a --- /dev/null +++ b/k256/src/arithmetic/field/field_8x32.rs @@ -0,0 +1,328 @@ +//! Field element modulo the curve internal modulus using 32-bit limbs. +#![allow(unsafe_code)] + +// This file is the same as the original below except for minor snippets. +// https://github.com/risc0/RustCrypto-elliptic-curves/blob/risczero/k256/src/arithmetic/field/field_8x32_risc0.rs + +use crate::FieldBytes; +use elliptic_curve::{ + bigint::{ArrayEncoding, Integer, Limb, Zero, U256}, + subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption}, + zeroize::Zeroize, +}; +use powdr_riscv_runtime::arith::modmul_256_u32_le; + +/// Base field characteristic for secp256k1 as an 8x32 big integer, least to most significant. +const MODULUS: U256 = + U256::from_be_hex("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F"); + +/// Low two words of 2^256 - MODULUS, used for correcting the value after addition mod 2^256. +const MODULUS_CORRECTION: U256 = U256::ZERO.wrapping_sub(&MODULUS); + +/// Scalars modulo SECP256k1 modulus (2^256 - 2^32 - 2^9 - 2^8 - 2^7 - 2^6 - 2^4 - 1). +/// Uses 8 32-bit limbs (little-endian) and acceleration support from the RISC Zero rv32im impl. +/// Unlike the 10x26 and 8x52 implementations, the values in this implementation are always +/// fully reduced and normalized as there is no extra room in the representation. +/// +/// NOTE: This implementation will only run inside the RISC Zero / powdrVM guests. +/// As a result, the requirements for constant-timeness are different +/// than on a physical platform. +#[derive(Clone, Copy, Debug)] +pub struct FieldElement8x32(pub(crate) U256); + +impl FieldElement8x32 { + /// Zero element. + pub const ZERO: Self = Self(U256::ZERO); + + /// Multiplicative identity. + pub const ONE: Self = Self(U256::ONE); + + /// Attempts to parse the given byte array as an SEC1-encoded field element. + /// Does not check the result for being in the correct range. + pub(crate) const fn from_bytes_unchecked(bytes: &[u8; 32]) -> Self { + Self(U256::from_be_slice(bytes.as_slice())) + } + + /// Attempts to parse the given byte array as an SEC1-encoded field element (but little endian!). + /// Does not check the result for being in the correct range. + pub(crate) fn from_bytes_unchecked_le(bytes: &[u8; 32]) -> Self { + Self(U256::from_le_slice(bytes.as_slice())) + } + + /// Attempts to parse the given byte array as an SEC1-encoded field element. + /// + /// Returns None if the byte array does not contain a big-endian integer in the range + /// [0, p). + pub fn from_bytes(bytes: &FieldBytes) -> CtOption { + let res = Self::from_bytes_unchecked(bytes.as_ref()); + let overflow = res.get_overflow(); + + CtOption::new(res, !overflow) + } + + pub const fn from_u64(val: u64) -> Self { + let w0 = val as u32; + let w1 = (val >> 32) as u32; + Self(U256::from_words([w0, w1, 0, 0, 0, 0, 0, 0])) + } + + pub const fn from_i64(val: i64) -> Self { + // Compute val_abs = |val| + let val_mask = val >> 63; + let val_abs = ((val + val_mask) ^ val_mask) as u64; + + Self::from_u64(val_abs).negate_const() + } + + /// Returns the SEC1 encoding of this field element. + pub fn to_bytes(self) -> FieldBytes { + self.0.to_be_byte_array() + } + + pub fn to_bytes_le(self) -> FieldBytes { + self.0.to_le_byte_array() + } + + /// Checks if the field element is greater or equal to the modulus. + fn get_overflow(&self) -> Choice { + let (_, carry) = self.0.adc(&MODULUS_CORRECTION, Limb(0)); + Choice::from(carry.0 as u8) + } + + /// Brings the field element's magnitude to 1, but does not necessarily normalize it. + /// + /// NOTE: In RISC Zero, this is a no-op since weak normalization is not an operation that + /// needs to be performed between calls to arithmetic routines. + #[inline(always)] + pub const fn normalize_weak(&self) -> Self { + Self(self.0) + } + + /// Returns the fully normalized and canonical representation of the value. + #[inline(always)] + pub fn normalize(&self) -> Self { + // When the prover is cooperative, the value is always normalized. + assert!(!bool::from(self.get_overflow())); + *self + } + + /// Checks if the field element becomes zero if normalized. + pub fn normalizes_to_zero(&self) -> Choice { + self.0.ct_eq(&U256::ZERO) | self.0.ct_eq(&MODULUS) + } + + /// Determine if this `FieldElement8x32` is zero. + /// + /// # Returns + /// + /// If zero, return `Choice(1)`. Otherwise, return `Choice(0)`. + pub fn is_zero(&self) -> Choice { + self.0.is_zero() + } + + /// Determine if this `FieldElement8x32` is odd in the SEC1 sense: `self mod 2 == 1`. + /// + /// Value must be normalized before calling is_odd. + /// + /// # Returns + /// + /// If odd, return `Choice(1)`. Otherwise, return `Choice(0)`. + pub fn is_odd(&self) -> Choice { + self.0.is_odd() + } + + #[cfg(debug_assertions)] + pub const fn max_magnitude() -> u32 { + // Results as always reduced, so this implementation does not need to track magnitude. + u32::MAX + } + + /// Returns -self. + const fn negate_const(&self) -> Self { + let (s, borrow) = MODULUS.sbb(&self.0, Limb(0)); + assert!(borrow.0 == 0); + Self(s) + } + + /// Returns -self. + pub fn negate(&self, _magnitude: u32) -> Self { + self.mul(&Self::ONE.negate_const()) + } + + /// Returns self + rhs mod p. + /// Sums the magnitudes. + pub fn add(&self, rhs: &Self) -> Self { + let self_limbs = self.0.as_limbs(); + let rhs_limbs = rhs.0.as_limbs(); + + // Carrying addition of self and rhs, with the overflow correction added in. + let (a0, carry0) = self_limbs[0].adc(rhs_limbs[0], MODULUS_CORRECTION.as_limbs()[0]); + let (a1, carry1) = self_limbs[1].adc( + rhs_limbs[1], + carry0.wrapping_add(MODULUS_CORRECTION.as_limbs()[1]), + ); + let (a2, carry2) = self_limbs[2].adc(rhs_limbs[2], carry1); + let (a3, carry3) = self_limbs[3].adc(rhs_limbs[3], carry2); + let (a4, carry4) = self_limbs[4].adc(rhs_limbs[4], carry3); + let (a5, carry5) = self_limbs[5].adc(rhs_limbs[5], carry4); + let (a6, carry6) = self_limbs[6].adc(rhs_limbs[6], carry5); + let (a7, carry7) = self_limbs[7].adc(rhs_limbs[7], carry6); + let a = U256::from([a0, a1, a2, a3, a4, a5, a6, a7]); + + // If the inputs are not in the range [0, p), then then carry7 may be greater than 1, + // indicating more than one overflow occurred. In this case, the code below will not + // correct the value. If the host is cooperative, this should never happen. + assert!(carry7.0 <= 1); + + // If a carry occured, then the correction was already added and the result is correct. + // If a carry did not occur, the correction needs to be removed. Result will be in [0, p). + // Wrap and unwrap to prevent the compiler interpreting this as a boolean, potentially + // introducing non-constant time code. + let mask = 1 - Choice::from(carry7.0 as u8).unwrap_u8(); + let c0 = MODULUS_CORRECTION.as_words()[0] * (mask as u32); + let c1 = MODULUS_CORRECTION.as_words()[1] * (mask as u32); + let correction = U256::from_words([c0, c1, 0, 0, 0, 0, 0, 0]); + + // The correction value was either already added to a, or is 0, so this sub will not + // underflow. + Self(a.wrapping_sub(&correction)) + } + + /// Returns self * rhs mod p + pub fn mul(&self, rhs: &Self) -> Self { + // powdr machine is 32 bits, so U256 = Uint<8> + Self(U256::from_words( + modmul_256_u32_le(self.0.to_words(), rhs.0.to_words(), MODULUS.to_words()), // the remainder + )) + } + + /// Multiplies by a single-limb integer. + pub fn mul_single(&self, rhs: u32) -> Self { + // powdr machine is 32 bits, so U256 = Uint<8> + Self(U256::from_words( + modmul_256_u32_le( + self.0.to_words(), + [rhs, 0, 0, 0, 0, 0, 0, 0], + MODULUS.to_words(), + ), // the remainder + )) + } + + /// Returns self * self + pub fn square(&self) -> Self { + // powdr machine is 32 bits, so U256 = Uint<8> + Self(U256::from_words( + modmul_256_u32_le(self.0.to_words(), self.0.to_words(), MODULUS.to_words()), // the remainder + )) + } +} + +impl Default for FieldElement8x32 { + fn default() -> Self { + Self::ZERO + } +} + +impl ConditionallySelectable for FieldElement8x32 { + fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self { + Self(U256::conditional_select(&a.0, &b.0, choice)) + } +} + +impl ConstantTimeEq for FieldElement8x32 { + fn ct_eq(&self, other: &Self) -> Choice { + self.0.ct_eq(&other.0) + } +} + +impl Zeroize for FieldElement8x32 { + fn zeroize(&mut self) { + self.0.zeroize(); + } +} + +#[cfg(test)] +mod tests { + use super::FieldElement8x32 as F; + use hex_literal::hex; + + const VAL_A: F = F::from_bytes_unchecked(&hex!( + "EC08EAC2CBCEFE58E61038DCA45BA2B4A56BDF05A3595EBEE1BCFC488889C1CF" + )); + const VAL_B: F = F::from_bytes_unchecked(&hex!( + "9FC3E90D2FAD03C8669F437A26374FA694CA76A7913C5E016322EBAA5C7616C5" + )); + + extern crate alloc; + + fn as_hex(&elem: &F) -> alloc::string::String { + // Call normalize here simply to assert that the value is normalized. + ::hex::encode_upper(elem.normalize().to_bytes()) + } + + #[test] + fn add() { + let expected: F = F::from_bytes_unchecked(&hex!( + "8BCCD3CFFB7C02214CAF7C56CA92F25B3A3655AD3495BCC044DFE7F3E4FFDC65" + )); + assert_eq!(as_hex(&VAL_A.add(&VAL_B)), as_hex(&expected)); + } + + // Tests the other "code path" returning the reduced or non-reduced result. + #[test] + fn add_negated() { + let expected: F = F::from_bytes_unchecked(&hex!( + "74332C300483FDDEB35083A9356D0DA4C5C9AA52CB6A433FBB20180B1B001FCA" + )); + assert_eq!( + as_hex(&VAL_A.negate(0).add(&VAL_B.negate(0))), + as_hex(&expected) + ); + } + + #[test] + fn negate() { + let expected: F = F::from_bytes_unchecked(&hex!( + "13F7153D343101A719EFC7235BA45D4B5A9420FA5CA6A1411E4303B677763A60" + )); + assert_eq!(as_hex(&VAL_A.negate(0)), as_hex(&expected)); + assert_eq!(as_hex(&VAL_A.add(&VAL_A.negate(0))), as_hex(&F::ZERO)); + } + + #[test] + fn mul() { + let expected: F = F::from_bytes_unchecked(&hex!( + "26B936E25A89EBAF821A46DC6BD8A0B1F0ED329412FA75FADF9A494D6F0EB4DB" + )); + assert_eq!(as_hex(&VAL_A.mul(&VAL_B)), as_hex(&expected)); + } + + #[test] + fn mul_zero() { + assert_eq!(as_hex(&VAL_A.mul(&F::ZERO)), as_hex(&F::ZERO)); + assert_eq!(as_hex(&VAL_B.mul(&F::ZERO)), as_hex(&F::ZERO)); + assert_eq!(as_hex(&F::ZERO.mul(&F::ZERO)), as_hex(&F::ZERO)); + assert_eq!(as_hex(&F::ONE.mul(&F::ZERO)), as_hex(&F::ZERO)); + assert_eq!(as_hex(&F::ONE.negate(0).mul(&F::ZERO)), as_hex(&F::ZERO)); + } + + #[test] + fn mul_one() { + assert_eq!(as_hex(&VAL_A.mul(&F::ONE)), as_hex(&VAL_A)); + assert_eq!(as_hex(&VAL_B.mul(&F::ONE)), as_hex(&VAL_B)); + assert_eq!(as_hex(&F::ZERO.mul(&F::ONE)), as_hex(&F::ZERO)); + assert_eq!(as_hex(&F::ONE.mul(&F::ONE)), as_hex(&F::ONE)); + assert_eq!( + as_hex(&F::ONE.negate(0).mul(&F::ONE)), + as_hex(&F::ONE.negate(0)) + ); + } + + #[test] + fn square() { + let expected: F = F::from_bytes_unchecked(&hex!( + "111671376746955B968F48A94AFBACD243EA840AAE13EF85BC39AAE9552D8EDA" + )); + assert_eq!(as_hex(&VAL_A.square()), as_hex(&expected)); + } +} diff --git a/k256/src/arithmetic/field/field_impl.rs b/k256/src/arithmetic/field/field_impl.rs index 6c7820b1a..1adc4f2a3 100644 --- a/k256/src/arithmetic/field/field_impl.rs +++ b/k256/src/arithmetic/field/field_impl.rs @@ -8,11 +8,19 @@ use elliptic_curve::{ zeroize::Zeroize, }; -#[cfg(target_pointer_width = "32")] -use super::field_10x26::FieldElement10x26 as FieldElementUnsafeImpl; - -#[cfg(target_pointer_width = "64")] -use super::field_5x52::FieldElement5x52 as FieldElementUnsafeImpl; +use cfg_if::cfg_if; + +cfg_if! { + if #[cfg(all(target_os = "zkvm", target_arch = "riscv32"))] { + use super::field_8x32::FieldElement8x32 as FieldElementUnsafeImpl; + } else if #[cfg(target_pointer_width = "32")] { + use super::field_10x26::FieldElement10x26 as FieldElementUnsafeImpl; + } else if #[cfg(target_pointer_width = "64")] { + use super::field_5x52::FieldElement5x52 as FieldElementUnsafeImpl; + } else { + compile_error!("unsupported target word size (i.e. target_pointer_width)"); + } +} #[derive(Clone, Copy, Debug)] pub struct FieldElementImpl { @@ -56,6 +64,9 @@ impl FieldElementImpl { debug_assert!(magnitude <= FieldElementUnsafeImpl::max_magnitude()); Self { value: *value, + #[cfg(all(target_os = "zkvm", target_arch = "riscv32"))] + magnitude: 1, + #[cfg(not(all(target_os = "zkvm", target_arch = "riscv32")))] magnitude, normalized: false, } @@ -66,10 +77,23 @@ impl FieldElementImpl { Self::new_normalized(&value) } + #[cfg(all(target_os = "zkvm", target_arch = "riscv32"))] + pub(crate) fn from_bytes_unchecked_le(bytes: &[u8; 32]) -> Self { + let value = FieldElementUnsafeImpl::from_bytes_unchecked_le(bytes); + Self::new_normalized(&value) + } + pub(crate) const fn from_u64(val: u64) -> Self { Self::new_normalized(&FieldElementUnsafeImpl::from_u64(val)) } + /// Convert a `i64` to a field element. + /// Returned value may be only weakly normalized. + #[cfg(all(target_os = "zkvm", target_arch = "riscv32"))] + pub(crate) const fn from_i64(w: i64) -> Self { + Self::new_weak_normalized(&FieldElementUnsafeImpl::from_i64(w)) + } + pub fn from_bytes(bytes: &FieldBytes) -> CtOption { let value = FieldElementUnsafeImpl::from_bytes(bytes); CtOption::map(value, |x| Self::new_normalized(&x)) @@ -80,6 +104,12 @@ impl FieldElementImpl { self.value.to_bytes() } + #[cfg(all(target_os = "zkvm", target_arch = "riscv32"))] + pub fn to_bytes_le(self) -> FieldBytes { + debug_assert!(self.normalized); + self.value.to_bytes_le() + } + pub fn normalize_weak(&self) -> Self { Self::new_weak_normalized(&self.value.normalize_weak()) } diff --git a/k256/src/arithmetic/hash2curve.rs b/k256/src/arithmetic/hash2curve.rs index 598748b82..fe820ff94 100644 --- a/k256/src/arithmetic/hash2curve.rs +++ b/k256/src/arithmetic/hash2curve.rs @@ -415,7 +415,13 @@ mod tests { Scalar(reduced_scalar) }; - proptest!(ProptestConfig::with_cases(1000), |(b0 in ANY, b1 in ANY, b2 in ANY, b3 in ANY, b4 in ANY, b5 in ANY)| { + proptest!(ProptestConfig::with_cases( + if cfg!(all(target_os = "zkvm", target_arch = "riscv32")) { + 1 + } else { + 1000 + } + ), |(b0 in ANY, b1 in ANY, b2 in ANY, b3 in ANY, b4 in ANY, b5 in ANY)| { let mut data = Array::default(); data[..8].copy_from_slice(&b0.to_be_bytes()); data[8..16].copy_from_slice(&b1.to_be_bytes()); diff --git a/k256/src/arithmetic/mul.rs b/k256/src/arithmetic/mul.rs index 162229f74..198e0fe8f 100644 --- a/k256/src/arithmetic/mul.rs +++ b/k256/src/arithmetic/mul.rs @@ -49,17 +49,36 @@ use core::ops::{Mul, MulAssign}; use elliptic_curve::{ ops::{LinearCombination, MulByGenerator}, scalar::IsHigh, - subtle::{Choice, ConditionallySelectable, ConstantTimeEq}, + subtle::{Choice, ConditionallySelectable}, }; +#[cfg(not(all(target_os = "zkvm", target_arch = "riscv32")))] +use elliptic_curve::subtle::ConstantTimeEq; + #[cfg(feature = "precomputed-tables")] use once_cell::sync::Lazy; -/// Lookup table containing precomputed values `[p, 2p, 3p, ..., 8p]` +#[cfg(all(target_os = "zkvm", target_arch = "riscv32"))] +#[repr(align(1024))] +#[derive(Copy, Clone, Default)] +struct LookupTable([ProjectivePoint; 9]); + +#[cfg(not(all(target_os = "zkvm", target_arch = "riscv32")))] #[derive(Copy, Clone, Default)] struct LookupTable([ProjectivePoint; 8]); impl From<&ProjectivePoint> for LookupTable { + #[cfg(all(target_os = "zkvm", target_arch = "riscv32"))] + fn from(p: &ProjectivePoint) -> Self { + let mut points = [*p; 9]; + points[0] = ProjectivePoint::IDENTITY; + for j in 1..8 { + points[j + 1] = p + &points[j]; + } + LookupTable(points) + } + + #[cfg(not(all(target_os = "zkvm", target_arch = "riscv32")))] fn from(p: &ProjectivePoint) -> Self { let mut points = [*p; 8]; for j in 0..7 { @@ -79,19 +98,35 @@ impl LookupTable { let xmask = x >> 7; let xabs = (x + xmask) ^ xmask; - // Get an array element in constant time - let mut t = ProjectivePoint::IDENTITY; - for j in 1..9 { - let c = (xabs as u8).ct_eq(&(j as u8)); - t.conditional_assign(&self.0[j - 1], c); - } - // Now t == |x| * p. + #[cfg(all(target_os = "zkvm", target_arch = "riscv32"))] + { + // All paged-in memory is constant time to access in RISC Zero. + // LookupTable fits in 864 bytes, which is less than the page size of 1024. Adding the + // repr(align(1024)) attribute above ensure the struct is placed on a page boundary and + // so all accesses within the table will result in the same paging behavior. + let value = self.0[xabs as usize]; - let neg_mask = Choice::from((xmask & 1) as u8); - t.conditional_assign(&-t, neg_mask); - // Now t == x * p. + let neg_mask = Choice::from((xmask & 1) as u8); - t + ProjectivePoint::conditional_select(&value, &-value, neg_mask) + } + + #[cfg(not(all(target_os = "zkvm", target_arch = "riscv32")))] + { + // Get an array element in constant time + let mut t = ProjectivePoint::IDENTITY; + for j in 1..9 { + let c = (xabs as u8).ct_eq(&(j as u8)); + t.conditional_assign(&self.0[j - 1], c); + } + // Now t == |x| * p. + + let neg_mask = Choice::from((xmask & 1) as u8); + t.conditional_assign(&-t, neg_mask); + // Now t == x * p. + + t + } } } diff --git a/k256/src/arithmetic/projective.rs b/k256/src/arithmetic/projective.rs index cf05fa16d..b2492fd9d 100644 --- a/k256/src/arithmetic/projective.rs +++ b/k256/src/arithmetic/projective.rs @@ -25,6 +25,9 @@ use elliptic_curve::{ #[cfg(feature = "alloc")] use alloc::vec::Vec; +#[cfg(all(target_os = "zkvm", target_arch = "riscv32"))] +use powdr_riscv_runtime::ec::{add_u8_le, double_u8_le}; + #[rustfmt::skip] const ENDOMORPHISM_BETA: FieldElement = FieldElement::from_bytes_unchecked(&[ 0x7a, 0xe9, 0x6a, 0x2b, 0x65, 0x7c, 0x07, 0x10, @@ -94,6 +97,30 @@ impl ProjectivePoint { /// Returns `self + other`. fn add(&self, other: &ProjectivePoint) -> ProjectivePoint { + #[cfg(all(target_os = "zkvm", target_arch = "riscv32"))] + { + // call when the values are normalized, into powdr ec operations + if self.z == FieldElement::ONE && other.z == FieldElement::ONE { + // z being ONE means value is not identity + let self_x: [u8; 32] = self.x.to_bytes_le().into(); + let self_y: [u8; 32] = self.y.to_bytes_le().into(); + let other_x: [u8; 32] = other.x.to_bytes_le().into(); + let other_y: [u8; 32] = other.y.to_bytes_le().into(); + + let (res_x, res_y) = add_u8_le(self_x, self_y, other_x, other_y); + let mut res = *self; + res.x = FieldElement::from_bytes_unchecked_le(&res_x); + res.y = FieldElement::from_bytes_unchecked_le(&res_y); + return res; + } + + if self.is_identity().into() { + return *other; + } else if other.is_identity().into() { + return *self; + } + } + // We implement the complete addition formula from Renes-Costello-Batina 2015 // (https://eprint.iacr.org/2015/1060 Algorithm 7). @@ -108,36 +135,83 @@ impl ProjectivePoint { let yz_pairs = ((self.y + &self.z) * &(other.y + &other.z)) + &n_yy_zz; let xz_pairs = ((self.x + &self.z) * &(other.x + &other.z)) + &n_xx_zz; - let bzz = zz.mul_single(CURVE_EQUATION_B_SINGLE); - let bzz3 = (bzz.double() + &bzz).normalize_weak(); + #[cfg(all(target_os = "zkvm", target_arch = "riscv32"))] + { + // The following are from risc0, but in practice, powdr ec_add should have captured most cases + let bzz3 = zz.mul_single(CURVE_EQUATION_B_SINGLE * 3); - let yy_m_bzz3 = yy + &bzz3.negate(1); - let yy_p_bzz3 = yy + &bzz3; + let yy_m_bzz3 = yy + &bzz3.negate(1); + let yy_p_bzz3 = yy + &bzz3; - let byz = &yz_pairs - .mul_single(CURVE_EQUATION_B_SINGLE) - .normalize_weak(); - let byz3 = (byz.double() + byz).normalize_weak(); + let byz3 = &yz_pairs.mul_single(CURVE_EQUATION_B_SINGLE * 3); - let xx3 = xx.double() + &xx; - let bxx9 = (xx3.double() + &xx3) - .normalize_weak() - .mul_single(CURVE_EQUATION_B_SINGLE) - .normalize_weak(); + let xx3 = xx.mul_single(3); + let bxx9 = xx3.mul_single(CURVE_EQUATION_B_SINGLE * 3); - let new_x = ((xy_pairs * &yy_m_bzz3) + &(byz3 * &xz_pairs).negate(1)).normalize_weak(); // m1 - let new_y = ((yy_p_bzz3 * &yy_m_bzz3) + &(bxx9 * &xz_pairs)).normalize_weak(); - let new_z = ((yz_pairs * &yy_p_bzz3) + &(xx3 * &xy_pairs)).normalize_weak(); + let new_x = (xy_pairs * &yy_m_bzz3) + &(byz3 * &xz_pairs).negate(1); // m1 + let new_y = (yy_p_bzz3 * &yy_m_bzz3) + &(bxx9 * &xz_pairs); + let new_z = (yz_pairs * &yy_p_bzz3) + &(xx3 * &xy_pairs); - ProjectivePoint { - x: new_x, - y: new_y, - z: new_z, + ProjectivePoint { + x: new_x, + y: new_y, + z: new_z, + } + // end of risc0 block + } + + #[cfg(not(all(target_os = "zkvm", target_arch = "riscv32")))] + { + let bzz = zz.mul_single(CURVE_EQUATION_B_SINGLE); + let bzz3 = (bzz.double() + &bzz).normalize_weak(); + + let yy_m_bzz3 = yy + &bzz3.negate(1); + let yy_p_bzz3 = yy + &bzz3; + + let byz = &yz_pairs + .mul_single(CURVE_EQUATION_B_SINGLE) + .normalize_weak(); + let byz3 = (byz.double() + byz).normalize_weak(); + + let xx3 = xx.double() + &xx; + let bxx9 = (xx3.double() + &xx3) + .normalize_weak() + .mul_single(CURVE_EQUATION_B_SINGLE) + .normalize_weak(); + + let new_x = ((xy_pairs * &yy_m_bzz3) + &(byz3 * &xz_pairs).negate(1)).normalize_weak(); // m1 + let new_y = ((yy_p_bzz3 * &yy_m_bzz3) + &(bxx9 * &xz_pairs)).normalize_weak(); + let new_z = ((yz_pairs * &yy_p_bzz3) + &(xx3 * &xy_pairs)).normalize_weak(); + + ProjectivePoint { + x: new_x, + y: new_y, + z: new_z, + } } } /// Returns `self + other`. fn add_mixed(&self, other: &AffinePoint) -> ProjectivePoint { + #[cfg(all(target_os = "zkvm", target_arch = "riscv32"))] + { + if other.is_identity().into() { + return *self; + } else if self.z == FieldElement::ONE { + // z being ONE means value is not identity + let self_x: [u8; 32] = self.x.to_bytes_le().into(); + let self_y: [u8; 32] = self.y.to_bytes_le().into(); + let other_x: [u8; 32] = other.x.to_bytes_le().into(); + let other_y: [u8; 32] = other.y.to_bytes_le().into(); + + let (res_x, res_y) = add_u8_le(self_x, self_y, other_x, other_y); + let mut res = *self; + res.x = FieldElement::from_bytes_unchecked_le(&res_x); + res.y = FieldElement::from_bytes_unchecked_le(&res_y); + return res; + } + } + // We implement the complete addition formula from Renes-Costello-Batina 2015 // (https://eprint.iacr.org/2015/1060 Algorithm 8). @@ -147,35 +221,82 @@ impl ProjectivePoint { let yz_pairs = (other.y * &self.z) + &self.y; let xz_pairs = (other.x * &self.z) + &self.x; - let bzz = &self.z.mul_single(CURVE_EQUATION_B_SINGLE); - let bzz3 = (bzz.double() + bzz).normalize_weak(); - - let yy_m_bzz3 = yy + &bzz3.negate(1); - let yy_p_bzz3 = yy + &bzz3; - - let byz = &yz_pairs - .mul_single(CURVE_EQUATION_B_SINGLE) - .normalize_weak(); - let byz3 = (byz.double() + byz).normalize_weak(); - - let xx3 = xx.double() + &xx; - let bxx9 = &(xx3.double() + &xx3) - .normalize_weak() - .mul_single(CURVE_EQUATION_B_SINGLE) - .normalize_weak(); + #[cfg(all(target_os = "zkvm", target_arch = "riscv32"))] + { + // The following are from risc0, but in practice, powdr ec_add should have captured most cases + // Same as below, but using mul_single instead of repeated addition to get small + // multiplications and normalize_weak is removed. + let bzz3 = self.z.mul_single(CURVE_EQUATION_B_SINGLE * 3); + + let yy_m_bzz3 = yy + &bzz3.negate(1); + let yy_p_bzz3 = yy + &bzz3; + + let n_byz3 = + &yz_pairs.mul(&FieldElement::from_i64(CURVE_EQUATION_B_SINGLE as i64 * -3)); + + let xx3 = xx.mul_single(3); + let bxx9 = xx3.mul_single(CURVE_EQUATION_B_SINGLE * 3); + + let mut ret = ProjectivePoint { + x: (xy_pairs * &yy_m_bzz3) + &(n_byz3 * &xz_pairs), + y: (yy_p_bzz3 * &yy_m_bzz3) + &(bxx9 * &xz_pairs), + z: (yz_pairs * &yy_p_bzz3) + &(xx3 * &xy_pairs), + }; + ret.conditional_assign(self, other.is_identity()); + ret + // end of risc0 block + } - let mut ret = ProjectivePoint { - x: ((xy_pairs * &yy_m_bzz3) + &(byz3 * &xz_pairs).negate(1)).normalize_weak(), - y: ((yy_p_bzz3 * &yy_m_bzz3) + &(bxx9 * &xz_pairs)).normalize_weak(), - z: ((yz_pairs * &yy_p_bzz3) + &(xx3 * &xy_pairs)).normalize_weak(), - }; - ret.conditional_assign(self, other.is_identity()); - ret + #[cfg(not(all(target_os = "zkvm", target_arch = "riscv32")))] + { + let bzz = &self.z.mul_single(CURVE_EQUATION_B_SINGLE); + let bzz3 = (bzz.double() + bzz).normalize_weak(); + + let yy_m_bzz3 = yy + &bzz3.negate(1); + let yy_p_bzz3 = yy + &bzz3; + + let byz = &yz_pairs + .mul_single(CURVE_EQUATION_B_SINGLE) + .normalize_weak(); + let byz3 = (byz.double() + byz).normalize_weak(); + + let xx3 = xx.double() + &xx; + let bxx9 = &(xx3.double() + &xx3) + .normalize_weak() + .mul_single(CURVE_EQUATION_B_SINGLE) + .normalize_weak(); + + let mut ret = ProjectivePoint { + x: ((xy_pairs * &yy_m_bzz3) + &(byz3 * &xz_pairs).negate(1)).normalize_weak(), + y: ((yy_p_bzz3 * &yy_m_bzz3) + &(bxx9 * &xz_pairs)).normalize_weak(), + z: ((yz_pairs * &yy_p_bzz3) + &(xx3 * &xy_pairs)).normalize_weak(), + }; + ret.conditional_assign(self, other.is_identity()); + ret + } } /// Doubles this point. #[inline] pub fn double(&self) -> ProjectivePoint { + #[cfg(all(target_os = "zkvm", target_arch = "riscv32"))] + { + if self.z == FieldElement::ONE { + // z being ONE means value is not identity + let self_x: [u8; 32] = self.x.to_bytes_le().into(); + let self_y: [u8; 32] = self.y.to_bytes_le().into(); + let (res_x, res_y) = double_u8_le(self_x, self_y); + let mut res = *self; + res.x = FieldElement::from_bytes_unchecked_le(&res_x); + res.y = FieldElement::from_bytes_unchecked_le(&res_y); + return res; + } + + if self.is_identity().into() { + return *self; + } + } + // We implement the complete addition formula from Renes-Costello-Batina 2015 // (https://eprint.iacr.org/2015/1060 Algorithm 9). @@ -183,27 +304,52 @@ impl ProjectivePoint { let zz = self.z.square(); let xy2 = (self.x * &self.y).double(); - let bzz = &zz.mul_single(CURVE_EQUATION_B_SINGLE); - let bzz3 = (bzz.double() + bzz).normalize_weak(); - let bzz9 = (bzz3.double() + &bzz3).normalize_weak(); - - let yy_m_bzz9 = yy + &bzz9.negate(1); - let yy_p_bzz3 = yy + &bzz3; - - let yy_zz = yy * &zz; - let yy_zz8 = yy_zz.double().double().double(); - let t = (yy_zz8.double() + &yy_zz8) - .normalize_weak() - .mul_single(CURVE_EQUATION_B_SINGLE); + #[cfg(all(target_os = "zkvm", target_arch = "riscv32"))] + { + // The following are from risc0, but in practice, powdr ec_add should have captured most cases + // Same as below, but using mul_single instead of repeated addition to get small + // multiplications and normalize_weak is removed. + let bzz3 = zz.mul_single(CURVE_EQUATION_B_SINGLE * 3); + let n_bzz9 = zz.mul(&FieldElement::from_i64(CURVE_EQUATION_B_SINGLE as i64 * -9)); + + let yy_m_bzz9 = yy + &n_bzz9; + let yy_p_bzz3 = yy + &bzz3; + + let yy_zz = yy * &zz; + let t = yy_zz.mul_single(CURVE_EQUATION_B_SINGLE * 24); + + ProjectivePoint { + x: xy2 * &yy_m_bzz9, + y: ((yy_m_bzz9 * &yy_p_bzz3) + &t), + z: ((yy * &self.y) * &self.z).mul_single(8), + } + // end of risc0 block + } - ProjectivePoint { - x: xy2 * &yy_m_bzz9, - y: ((yy_m_bzz9 * &yy_p_bzz3) + &t).normalize_weak(), - z: ((yy * &self.y) * &self.z) - .double() - .double() - .double() - .normalize_weak(), + #[cfg(not(all(target_os = "zkvm", target_arch = "riscv32")))] + { + let bzz = &zz.mul_single(CURVE_EQUATION_B_SINGLE); + let bzz3 = (bzz.double() + bzz).normalize_weak(); + let bzz9 = (bzz3.double() + &bzz3).normalize_weak(); + + let yy_m_bzz9 = yy + &bzz9.negate(1); + let yy_p_bzz3 = yy + &bzz3; + + let yy_zz = yy * &zz; + let yy_zz8 = yy_zz.double().double().double(); + let t = (yy_zz8.double() + &yy_zz8) + .normalize_weak() + .mul_single(CURVE_EQUATION_B_SINGLE); + + ProjectivePoint { + x: xy2 * &yy_m_bzz9, + y: ((yy_m_bzz9 * &yy_p_bzz3) + &t).normalize_weak(), + z: ((yy * &self.y) * &self.z) + .double() + .double() + .double() + .normalize_weak(), + } } } diff --git a/k256/src/arithmetic/scalar.rs b/k256/src/arithmetic/scalar.rs index dad9288e5..13ae536f7 100644 --- a/k256/src/arithmetic/scalar.rs +++ b/k256/src/arithmetic/scalar.rs @@ -25,6 +25,9 @@ use elliptic_curve::{ Curve, ScalarPrimitive, }; +#[cfg(all(target_os = "zkvm", target_arch = "riscv32"))] +use powdr_riscv_runtime::arith::modmul_256_u32_le; + #[cfg(feature = "bits")] use {crate::ScalarBits, elliptic_curve::group::ff::PrimeFieldBits}; @@ -108,6 +111,19 @@ impl Scalar { } /// Modulo multiplies two scalars. + #[cfg(all(target_os = "zkvm", target_arch = "riscv32"))] + pub fn mul(&self, rhs: &Scalar) -> Scalar { + let result = Self(U256::from_words( + modmul_256_u32_le(self.0.to_words(), rhs.0.to_words(), ORDER.to_words()), // the remainder + )); + // our asm doesn't guarantee full modulus reduction, so it's asserted at the rust level for soundness + // the honest prover should provide a remainder that's reduced already, so this assertion should never fail + assert!(bool::from(result.0.ct_lt(&ORDER))); + result + } + + /// Modulo multiplies two scalars. + #[cfg(not(all(target_os = "zkvm", target_arch = "riscv32")))] pub fn mul(&self, rhs: &Scalar) -> Scalar { WideScalar::mul_wide(self, rhs).reduce() } @@ -117,6 +133,17 @@ impl Scalar { self.mul(self) } + /* + * Steve wrote this function but it's unused. + * Commenting out to stop clippy but keeping the code if needed. + #[inline(always)] + #[cfg(all(target_os = "zkvm", target_arch = "riscv32"))] + fn normalize(&self) -> Self { + assert!(bool::from(self.0.ct_lt(&ORDER))); + *self + } + */ + /// Right shifts the scalar. /// /// Note: not constant-time with respect to the `shift` parameter. @@ -419,6 +446,12 @@ impl Invert for Scalar { self.invert() } + #[allow(non_snake_case)] + #[cfg(all(target_os = "zkvm", target_arch = "riscv32"))] + fn invert_vartime(&self) -> CtOption { + self.invert() + } + /// Fast variable-time inversion using Stein's algorithm. /// /// Returns none if the scalar is zero. @@ -431,6 +464,7 @@ impl Invert for Scalar { /// variable-time operation can potentially leak secrets through /// sidechannels. #[allow(non_snake_case)] + #[cfg(not(all(target_os = "zkvm", target_arch = "riscv32")))] fn invert_vartime(&self) -> CtOption { let mut u = *self; let mut v = Self::from_uint_unchecked(Secp256k1::ORDER); diff --git a/k256/src/lib.rs b/k256/src/lib.rs index 435045630..68e81f500 100644 --- a/k256/src/lib.rs +++ b/k256/src/lib.rs @@ -6,7 +6,6 @@ html_favicon_url = "https://raw.githubusercontent.com/RustCrypto/meta/master/logo.svg" )] #![allow(clippy::needless_range_loop)] -#![forbid(unsafe_code)] #![warn( clippy::mod_module_files, clippy::unwrap_used,