Skip to content

Commit b9ff8cf

Browse files
authored
Keccakf32Memory: use compute_from instead of provide_value (#2553)
This PR changes the usage of `std::prover::provide_value` in `Keccakf32Memory` to `std::prover::compute_from`, in the hope of being able to use the JIT for this machine in #2541. ~Witgen time goes of `test_data/std/keccakf32_memory_test.asm` from 60s to 76s, probably because too many things are being evaluated.~ This is fixed by 1c15647.
1 parent 2cb6a3a commit b9ff8cf

File tree

1 file changed

+30
-134
lines changed

1 file changed

+30
-134
lines changed

std/machines/hash/keccakf32_memory.asm

+30-134
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use std::convert::fe;
88
use std::prelude::set_hint;
99
use std::prelude::Query;
1010
use std::prover::eval;
11-
use std::prover::provide_value;
11+
use std::prover::compute_from_multi;
1212
use std::machines::large_field::memory::Memory;
1313

1414
machine Keccakf32Memory(mem: Memory) with
@@ -578,6 +578,8 @@ machine Keccakf32Memory(mem: Memory) with
578578
});
579579

580580
// Prover function section (for witness generation).
581+
// Hints are only needed for c and a_prime, the solver is able to figure out the
582+
// rest of the witness.
581583

582584
// // Populate C[x] = xor(A[x, 0], A[x, 1], A[x, 2], A[x, 3], A[x, 4]).
583585
// for x in 0..5 {
@@ -592,49 +594,21 @@ machine Keccakf32Memory(mem: Memory) with
592594
// }
593595
// }
594596

595-
let query_c: int, int, int -> int = query |x, limb, bit_in_limb|
596-
utils::fold(
597-
5,
598-
|y| (int(eval(a[y * 10 + x * 2 + limb])) >> bit_in_limb) & 0x1,
599-
0,
600-
|acc, e| acc ^ e
601-
);
602-
603-
query |row| {
604-
let _ = array::map_enumerated(c, |i, c_i| {
597+
query |row| compute_from_multi(
598+
c, row, a,
599+
|a_fe| array::new(array::len(c), |i| {
605600
let x = i / 64;
606601
let z = i % 64;
607602
let limb = z / 32;
608603
let bit_in_limb = z % 32;
609604

610-
provide_value(c_i, row, fe(query_c(x, limb, bit_in_limb)));
611-
});
612-
};
613-
614-
// // Populate C'[x, z] = xor(C[x, z], C[x - 1, z], C[x + 1, z - 1]).
615-
// for x in 0..5 {
616-
// for z in 0..64 {
617-
// row.c_prime[x][z] = xor([
618-
// row.c[x][z],
619-
// row.c[(x + 4) % 5][z],
620-
// row.c[(x + 1) % 5][(z + 63) % 64],
621-
// ]);
622-
// }
623-
// }
624-
625-
let query_c_prime: int, int -> int = query |x, z|
626-
int(eval(c[x * 64 + z])) ^
627-
int(eval(c[((x + 4) % 5) * 64 + z])) ^
628-
int(eval(c[((x + 1) % 5) * 64 + (z + 63) % 64]));
629-
630-
query |row| {
631-
let _ = array::map_enumerated(c_prime, |i, c_i| {
632-
let x = i / 64;
633-
let z = i % 64;
634-
635-
provide_value(c_i, row, fe(query_c_prime(x, z)));
636-
});
637-
};
605+
fe(utils::fold(
606+
5,
607+
|y| (int(a_fe[y * 10 + x * 2 + limb]) >> bit_in_limb) & 0x1,
608+
0,
609+
|acc, e| acc ^ e
610+
))
611+
}));
638612

639613
// // Populate A'. To avoid shifting indices, we rewrite
640614
// // A'[x, y, z] = xor(A[x, y, z], C[x - 1, z], C[x + 1, z - 1])
@@ -652,110 +626,32 @@ machine Keccakf32Memory(mem: Memory) with
652626
// }
653627
// }
654628

655-
let query_a_prime: int, int, int, int, int -> int = query |x, y, z, limb, bit_in_limb|
656-
((int(eval(a[y * 10 + x * 2 + limb])) >> bit_in_limb) & 0x1) ^
657-
int(eval(c[x * 64 + z])) ^
658-
int(eval(c_prime[x * 64 + z]));
659-
660-
query |row| {
661-
let _ = array::map_enumerated(a_prime, |i, a_i| {
629+
query |row| compute_from_multi(
630+
a_prime, row, a + c + c_prime,
631+
|inputs| array::new(array::len(a_prime), |i| {
662632
let y = i / 320;
663633
let x = (i / 64) % 5;
664634
let z = i % 64;
665635
let limb = z / 32;
666636
let bit_in_limb = z % 32;
667637

668-
provide_value(a_i, row, fe(query_a_prime(x, y, z, limb, bit_in_limb)));
669-
});
670-
};
638+
let a_elem = inputs[y * 10 + x * 2 + limb];
639+
let c_elem = inputs[x * 64 + z + 5 * 5 * 2];
640+
let c_prime_elem = inputs[x * 64 + z + 5 * 5 * 2 + 5 * 64];
671641

672-
// // Populate A''.P
673-
// // A''[x, y] = xor(B[x, y], andn(B[x + 1, y], B[x + 2, y])).
674-
// for y in 0..5 {
675-
// for x in 0..5 {
676-
// for limb in 0..U64_LIMBS {
677-
// row.a_prime_prime[y][x][limb] = (limb * BITS_PER_LIMB..(limb + 1) * BITS_PER_LIMB)
678-
// .rev()
679-
// .fold(F::zero(), |acc, z| {
680-
// let bit = xor([
681-
// row.b(x, y, z),
682-
// andn(row.b((x + 1) % 5, y, z), row.b((x + 2) % 5, y, z)),
683-
// ]);
684-
// acc.double() + bit
685-
// });
686-
// }
687-
// }
688-
// }
689-
690-
let query_a_prime_prime: int, int, int -> int = query |x, y, limb|
691-
utils::fold(
692-
32,
693-
|z|
694-
int(eval(b(x, y, (limb + 1) * 32 - 1 - z))) ^
695-
int(eval(andn(b((x + 1) % 5, y, (limb + 1) * 32 - 1 - z),
696-
b((x + 2) % 5, y, (limb + 1) * 32 - 1 - z)))),
697-
0,
698-
|acc, e| acc * 2 + e
699-
);
700-
701-
query |row| {
702-
let _ = array::map_enumerated(a_prime_prime, |i, a_i| {
703-
let y = i / 10;
704-
let x = (i / 2) % 5;
705-
let limb = i % 2;
706-
707-
provide_value(a_i, row, fe(query_a_prime_prime(x, y, limb)));
708-
});
709-
};
710-
711-
// // For the XOR, we split A''[0, 0] to bits.
712-
// let mut val = 0; // smaller address correspond to less significant limb
713-
// for limb in 0..U64_LIMBS {
714-
// let val_limb = row.a_prime_prime[0][0][limb].as_canonical_u64();
715-
// val |= val_limb << (limb * BITS_PER_LIMB);
716-
// }
717-
// let val_bits: Vec<bool> = (0..64) // smaller address correspond to less significant bit
718-
// .scan(val, |acc, _| {
719-
// let bit = (*acc & 1) != 0;
720-
// *acc >>= 1;
721-
// Some(bit)
722-
// })
723-
// .collect();
724-
// for (i, bit) in row.a_prime_prime_0_0_bits.iter_mut().enumerate() {
725-
// *bit = F::from_bool(val_bits[i]);
726-
// }
642+
fe(((int(a_elem) >> bit_in_limb) & 0x1) ^ int(c_elem) ^ int(c_prime_elem))
643+
}));
727644

645+
// TODO: This hint is correct but not needed (the solver can figure this out).
646+
// We keep it here because it prevents the JIT solver from succeeding (because of the
647+
// use of `provide_value`), because it currently fails when compiling Rust code.
648+
// Once these issues are resolved, we can remove this hint.
728649
query |row| {
729-
let _ = array::map_enumerated(a_prime_prime_0_0_bits, |i, a_i| {
730-
let limb = i / 32;
731-
let bit_in_limb = i % 32;
732-
733-
provide_value(
734-
a_i,
735-
row,
736-
fe((int(eval(a_prime_prime[limb])) >> bit_in_limb) & 0x1)
737-
);
738-
});
650+
std::prover::provide_value(
651+
a_prime_prime_0_0_bits[0],
652+
row,
653+
fe((int(eval(a_prime_prime[0]))) & 0x1)
654+
);
739655
};
740656

741-
// // A''[0, 0] is additionally xor'd with RC.
742-
// for limb in 0..U64_LIMBS {
743-
// let rc_lo = rc_value_limb(round, limb);
744-
// row.a_prime_prime_prime_0_0_limbs[limb] =
745-
// F::from_canonical_u16(row.a_prime_prime[0][0][limb].as_canonical_u64() as u16 ^ rc_lo);
746-
// }
747-
748-
let query_a_prime_prime_prime_0_0_limbs: int, int -> int = query |round, limb|
749-
int(eval(a_prime_prime[limb])) ^
750-
((RC[round] >> (limb * 32)) & 0xffffffff);
751-
752-
query |row| {
753-
let _ = array::new(2, |limb| {
754-
provide_value(
755-
a_prime_prime_prime_0_0_limbs[limb],
756-
row,
757-
fe(query_a_prime_prime_prime_0_0_limbs(row % NUM_ROUNDS, limb)
758-
));
759-
});
760-
};
761657
}

0 commit comments

Comments
 (0)