Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
258 changes: 165 additions & 93 deletions src/scalar_field.nr
Original file line number Diff line number Diff line change
Expand Up @@ -24,138 +24,135 @@ pub struct ScalarField<let N: u32> {
}

// 1, 2, 3, 4
unconstrained fn get_wnaf_slices<let N: u32>(x: Field) -> ([u8; N], bool) {
let mut result: [u8; N] = [0; N];
unconstrained fn get_wnaf_slices<let N: u32>(x: Field) -> ScalarField<N> {
let mut base4_slices: [u8; N] = [0; N];
let bytes = x.to_le_bytes::<32>();

// Extract nibbles from bytes
let mut nibbles: [u8; N] = [0; N];
nibbles = extract_nibbles_from_bytes(bytes);

let skew: bool = nibbles[0] & 1 == 0;
nibbles[0] = nibbles[0] + (skew as u8);
result[N - 1] = (nibbles[0] + 15) / 2;
base4_slices[N - 1] = (nibbles[0] + 15) / 2;

for i in 1..N {
let mut nibble: u8 = nibbles[i];
result[N - 1 - i] = (nibble + 15) / 2;
base4_slices[N - 1 - i] = (nibble + 15) / 2;
if (nibble & 1 == 0) {
result[N - 1 - i] += 1;
result[N - i] -= 8;
base4_slices[N - 1 - i] += 1;
base4_slices[N - i] -= 8;
}
}
(result, skew)

ScalarField { base4_slices, skew }
}

unconstrained fn get_wnaf_slices2<let N: u32, B>(x: B) -> ([u8; N], bool)
unconstrained fn get_wnaf_slices2<let N: u32, B>(x: B) -> ScalarField<N>
where
B: BigNum,
{
let mut result: [u8; N] = [0; N];
let mut base4_slices: [u8; N] = [0; N];
let mut nibbles: [[u8; 30]; (N / 30) + 1] = [[0; 30]; (N / 30) + 1];
let x: [u128] = x.get_limbs().as_slice();
let x: [u128; _] = x.get_limbs();
for i in 0..x.len() {
let bytes = (x[i] as Field).to_le_bytes::<30>();
nibbles[i] = extract_nibbles_from_bytes(bytes);
}

let skew: bool = nibbles[0][0] & 1 == 0;
nibbles[0][0] = nibbles[0][0] + (skew as u8);
result[N - 1] = (nibbles[0][0] + 15) / 2;
base4_slices[N - 1] = (nibbles[0][0] + 15) / 2;

for i in 1..N {
let major_index = i / 30;
let minor_index = i % 30;
let mut nibble: u8 = nibbles[major_index][minor_index];
result[N - 1 - i] = (nibble + 15) / 2;
base4_slices[N - 1 - i] = (nibble + 15) / 2;
if (nibble & 1 == 0) {
result[N - 1 - i] += 1;
result[N - i] -= 8;
base4_slices[N - 1 - i] += 1;
base4_slices[N - i] -= 8;
}
}
(result, skew)

ScalarField { base4_slices, skew }
}

// unconstrained fn get_modulus_slices() -> (Field, Field) {
// let bytes = std::field::modulus_be_bytes();
// let num_bytes = (std::field::modulus_num_bits() / 8) + ((std::field::modulus_num_bits() % 8 != 0) as u64);
// let mut lo: Field = 0;
// let mut hi: Field = 0;
// for i in 0..(num_bytes / 2) {
// hi *= 256;
// hi += bytes[i] as Field;
// lo *= 256;
// lo += bytes[i + (num_bytes/2)] as Field;
// }
// if (num_bytes & 1 == 1) {
// lo *= 256;
// lo += bytes[num_bytes - 1] as Field;
// }
// (lo, hi)
// }

// unconstrained fn get_borrow_flag(lhs_lo: Field, rhs_lo: Field) -> bool {
// lhs_lo.lt(rhs_lo + 1)
// }
fn get_modulus_slices<let N: u32>() -> [u8; N] {
let mut expected_slices: [u8; N] = [0; N];

if N == 64 {
let slice: [u8; 64] = [
9, 8, 3, 2, 2, 7, 3, 9, 7, 0, 9, 8, 13, 0, 1, 4, 13, 12, 2, 8, 2, 2, 13, 11, 4, 0, 12,
0, 10, 12, 2, 14, 9, 4, 1, 9, 15, 4, 2, 4, 3, 12, 13, 12, 11, 8, 4, 8, 10, 1, 15, 0, 15,
10, 12, 9, 15, 8, 0, 0, 0, 0, 0, 0,
];
for i in 0..N {
expected_slices[i] = slice[i];
}
} else if N == 65 {
let slice: [u8; 65] = [
8, 1, 8, 3, 2, 2, 7, 3, 9, 7, 0, 9, 8, 13, 0, 1, 4, 13, 12, 2, 8, 2, 2, 13, 11, 4, 0,
12, 0, 10, 12, 2, 14, 9, 4, 1, 9, 15, 4, 2, 4, 3, 12, 13, 12, 11, 8, 4, 8, 10, 1, 15, 0,
15, 10, 12, 9, 15, 8, 0, 0, 0, 0, 0, 0,
];
for i in 0..N {
expected_slices[i] = slice[i];
}
} else if N > 65 {
// For N > 65, we need to insert zeros at the beginning
let num_zeros = N - 65;
expected_slices[0] = 8;
for i in 1..num_zeros + 1 {
expected_slices[i] = 0;
}
let slice: [u8; 65] = [
8, 1, 8, 3, 2, 2, 7, 3, 9, 7, 0, 9, 8, 13, 0, 1, 4, 13, 12, 2, 8, 2, 2, 13, 11, 4, 0,
12, 0, 10, 12, 2, 14, 9, 4, 1, 9, 15, 4, 2, 4, 3, 12, 13, 12, 11, 8, 4, 8, 10, 1, 15, 0,
15, 10, 12, 9, 15, 8, 0, 0, 0, 0, 0, 0,
];
for i in num_zeros + 1..N {
expected_slices[i] = slice[i - num_zeros];
}
}
expected_slices
}

fn compare_scalar_field_to_bignum<let N: u32>(result: ScalarField<N>) {
let expected_slices: [u8; N] = get_modulus_slices::<N>();

// Lexicographic comparison: stop when we find a strictly smaller number
let mut should_continue: bool = true;
for i in 0..N {
if should_continue {
if result.base4_slices[i] < expected_slices[i] {
// Found a strictly smaller number, we can stop - this is valid
should_continue = false;
} else if result.base4_slices[i] > expected_slices[i] {
// Found a strictly larger number, this is invalid
panic(f"Reconstructed number is greater than modulus");
}
// If equal, continue to the next element (should_continue remains true)
}
}
}

impl<let N: u32> std::convert::From<Field> for ScalarField<N> {

/// Constructs an instance from a field element.
///
/// If `N >= 64`, additional checks are performed to ensure that the slice decomposition
/// accurately represents the same integral value as the input. For example, it verifies
/// that the sum of the slices is not equal to `x + modulus`.
fn from(x: Field) -> Self {
let mut result: Self = ScalarField { base4_slices: [0; N], skew: false };
let (slices, skew): ([u8; N], bool) = unsafe { get_wnaf_slices(x) };
result.base4_slices = slices;
result.skew = skew;
if (N < 64) {
let mut acc: Field = (slices[0] as Field) * 2 - 15;
for i in 1..N {
acc *= 16;
acc += (slices[i] as Field) * 2 - 15;
fn from(input: Field) -> Self {
let result = unsafe { get_wnaf_slices(input) };

if !std::runtime::is_unconstrained() {
// Enforce that limbs are all 4 bits.
for i in 0..N {
(result.base4_slices[i] as Field).assert_max_bit_size::<4>();
}

// Enforce consistency with `input`.
let reconstructed_input: Field = result.into();
assert_eq(reconstructed_input, input);
if N >= 64 {
compare_scalar_field_to_bignum(result);
}
assert(acc - skew as Field == x);
} else {
// TODO fix! this does not work because we are assuming N slices is smaller than 256 bits
// let mut lo: Field = slices[(N / 2)] as Field * 2 - 15;
// let mut hi: Field = slices[0] as Field * 2 - 15;
// let mut borrow_shift = 1;
// for i in 1..(N / 2) {
// borrow_shift *= 16;
// lo *= 16;
// lo += (slices[(N/2) + i] as Field) * 2 - 15;
// hi *= 16;
// hi += (slices[i] as Field) * 2 - 15;
// }
// if ((N & 1) == 1) {
// borrow_shift *= 16;
// lo *= 16;
// lo += (slices[N-1] as Field) * 2 - 15;
// }
// // 0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593efffffff
// // 0x2833e84879b9709143e1f593f0000001
// // 0x8833e84879b9709143e1f593efffffff
// lo -= skew as Field;
// // Validate that the integer represented by (lo, hi) is smaller than the integer represented by (plo, phi)
// let (plo, phi) = unsafe {
// get_modulus_slices()
// };
// let borrow = unsafe {
// get_borrow_flag(plo, lo) as Field
// };
// let rlo = plo - lo + borrow * borrow_shift - 1; // -1 because we are checking a strict <, not <=
// let rhi = phi - hi - borrow;
// let offset = (N & 1 == 1) as u32;
// let hibits = (N / 2) * 4;
// let lobits = hibits + offset * 4 + 1; // 1 extra bit to account for borrow
// // 0x013833e84879b9709143e1f593f0000000
// // rlo.assert_max_bit_size(lobits as u32);
// // rhi.assert_max_bit_size(hibits as u32);
}
for i in 0..N {
(result.base4_slices[i] as Field).assert_max_bit_size::<4>();
}
result
}
Expand All @@ -167,6 +164,12 @@ impl<let N: u32> std::convert::Into<Field> for ScalarField<N> {
* @details use this method instead of `new` if you know x/y is on the curve
**/
fn into(self: Self) -> Field {
// TODO: This is susceptible to overflow when N is large!
if !std::runtime::is_unconstrained() {
if N >= 64 {
compare_scalar_field_to_bignum(self);
}
}
let mut acc: Field = 0;
for i in 0..N {
acc = acc * 16;
Expand Down Expand Up @@ -194,7 +197,10 @@ impl<let N: u32> ScalarField<N> {
B: BigNum,
{
x.validate_in_field();
let mut (slices, skew): ([u8; N], bool) = unsafe { get_wnaf_slices2(x) };
// Safety: separate verification
let mut slices_result = unsafe { get_wnaf_slices2(x) };
let mut slices = slices_result.base4_slices;
let skew = slices_result.skew;
for i in 0..N {
(slices[i] as Field).assert_max_bit_size::<4>();
}
Expand Down Expand Up @@ -322,6 +328,7 @@ fn handle_overflow<B: BigNum>(acc: Field, result: &mut B, limb_index: u32) -> u1
}

mod tests {
use crate::scalar_field::get_modulus_slices;
use crate::scalar_field::ScalarField;
#[test]
// test even number of nibbles
Expand All @@ -340,4 +347,69 @@ mod tests {
let scalar_field2 = scalar_field.into();
assert(val as Field == scalar_field2);
}
#[test]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's add some tests for compare_scalar_field_to_bignum

unconstrained fn test_get_modulus_slices() {
let modulus_slices: [u8; 64] = get_modulus_slices::<64>();
assert(
modulus_slices
== [
9, 8, 3, 2, 2, 7, 3, 9, 7, 0, 9, 8, 13, 0, 1, 4, 13, 12, 2, 8, 2, 2, 13, 11, 4,
0, 12, 0, 10, 12, 2, 14, 9, 4, 1, 9, 15, 4, 2, 4, 3, 12, 13, 12, 11, 8, 4, 8,
10, 1, 15, 0, 15, 10, 12, 9, 15, 8, 0, 0, 0, 0, 0, 0,
],
);
let modulus_slices2: [u8; 65] = get_modulus_slices::<65>();
assert(
modulus_slices2
== [
8, 1, 8, 3, 2, 2, 7, 3, 9, 7, 0, 9, 8, 13, 0, 1, 4, 13, 12, 2, 8, 2, 2, 13, 11,
4, 0, 12, 0, 10, 12, 2, 14, 9, 4, 1, 9, 15, 4, 2, 4, 3, 12, 13, 12, 11, 8, 4, 8,
10, 1, 15, 0, 15, 10, 12, 9, 15, 8, 0, 0, 0, 0, 0, 0,
],
);
let modulus_slices3: [u8; 68] = get_modulus_slices::<68>();
assert(
modulus_slices3
== [
8, 0, 0, 0, 1, 8, 3, 2, 2, 7, 3, 9, 7, 0, 9, 8, 13, 0, 1, 4, 13, 12, 2, 8, 2, 2,
13, 11, 4, 0, 12, 0, 10, 12, 2, 14, 9, 4, 1, 9, 15, 4, 2, 4, 3, 12, 13, 12, 11,
8, 4, 8, 10, 1, 15, 0, 15, 10, 12, 9, 15, 8, 0, 0, 0, 0, 0, 0,
],
);
}
#[test(should_fail_with = "Reconstructed number is greater than modulus")]
fn test_get_modulus_slices_fail_64() {
//13th nibble is 14, which is greater than the modulus
let modulus_slices: [u8; 64] = [
9, 8, 3, 2, 2, 7, 3, 9, 7, 0, 9, 8, 14, 0, 1, 4, 13, 12, 2, 8, 2, 2, 13, 11, 4, 0, 12,
0, 10, 12, 2, 14, 9, 4, 1, 9, 15, 4, 2, 4, 3, 12, 13, 12, 11, 8, 4, 8, 10, 1, 15, 0, 15,
10, 12, 9, 15, 8, 0, 0, 0, 0, 0, 0,
];
let mut result: ScalarField<64> = ScalarField { base4_slices: modulus_slices, skew: true };
let _ = result.into();
}
#[test(should_fail_with = "Reconstructed number is greater than modulus")]
fn test_get_modulus_slices_fail_65() {
//modulus slice represents a number that is greater than the modulus
let modulus_slices2 = [
8, 1, 8, 3, 2, 2, 7, 3, 10, 8, 0, 9, 8, 13, 0, 1, 4, 13, 12, 2, 8, 2, 2, 13, 11, 4, 0,
12, 0, 10, 12, 2, 14, 9, 4, 1, 9, 15, 4, 2, 4, 3, 12, 13, 12, 11, 8, 4, 8, 10, 1, 15, 0,
15, 10, 12, 9, 15, 8, 0, 0, 0, 0, 0, 0,
];
let mut result: ScalarField<65> = ScalarField { base4_slices: modulus_slices2, skew: true };
let _ = result.into();
}

#[test(should_fail_with = "Reconstructed number is greater than modulus")]
fn test_get_modulus_slices_fail_68() {
//modulus slice represents a number that is greater than the modulus
let modulus_slices3 = [
8, 0, 0, 0, 2, 8, 3, 2, 2, 7, 3, 9, 7, 0, 9, 8, 13, 0, 1, 4, 13, 12, 2, 8, 2, 2, 13, 11,
4, 0, 12, 0, 10, 12, 2, 14, 9, 4, 1, 9, 15, 4, 2, 4, 3, 12, 13, 12, 11, 8, 4, 8, 10, 1,
15, 0, 15, 10, 12, 9, 15, 8, 0, 0, 0, 0, 0, 0,
];
let mut result: ScalarField<68> = ScalarField { base4_slices: modulus_slices3, skew: true };
let _ = result.into();
}

}