Skip to content

Commit 4d99060

Browse files
authored
Cleanly separate Circom1 and Circom2 traits (arkworks-rs#60)
1 parent 170b10f commit 4d99060

File tree

3 files changed

+67
-80
lines changed

3 files changed

+67
-80
lines changed

src/witness/circom.rs

+41-46
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,15 @@ pub struct Wasm(Instance);
77
pub trait CircomBase {
88
fn init(&self, sanity_check: bool) -> Result<()>;
99
fn func(&self, name: &str) -> &Function;
10-
fn get_ptr_witness_buffer(&self) -> Result<u32>;
11-
fn get_ptr_witness(&self, w: u32) -> Result<u32>;
1210
fn get_n_vars(&self) -> Result<u32>;
11+
fn get_u32(&self, name: &str) -> Result<u32>;
12+
// Only exists natively in Circom2, hardcoded for Circom
13+
fn get_version(&self) -> Result<u32>;
14+
}
15+
16+
pub trait Circom1 {
17+
fn get_ptr_witness(&self, w: u32) -> Result<u32>;
18+
fn get_fr_len(&self) -> Result<u32>;
1319
fn get_signal_offset32(
1420
&self,
1521
p_sig_offset: u32,
@@ -18,13 +24,6 @@ pub trait CircomBase {
1824
hash_lsb: u32,
1925
) -> Result<()>;
2026
fn set_signal(&self, c_idx: u32, component: u32, signal: u32, p_val: u32) -> Result<()>;
21-
fn get_u32(&self, name: &str) -> Result<u32>;
22-
// Only exists natively in Circom2, hardcoded for Circom
23-
fn get_version(&self) -> Result<u32>;
24-
}
25-
26-
pub trait Circom {
27-
fn get_fr_len(&self) -> Result<u32>;
2827
fn get_ptr_raw_prime(&self) -> Result<u32>;
2928
}
3029

@@ -38,14 +37,46 @@ pub trait Circom2 {
3837
fn get_witness_size(&self) -> Result<u32>;
3938
}
4039

41-
impl Circom for Wasm {
40+
impl Circom1 for Wasm {
4241
fn get_fr_len(&self) -> Result<u32> {
4342
self.get_u32("getFrLen")
4443
}
4544

4645
fn get_ptr_raw_prime(&self) -> Result<u32> {
4746
self.get_u32("getPRawPrime")
4847
}
48+
49+
fn get_ptr_witness(&self, w: u32) -> Result<u32> {
50+
let func = self.func("getPWitness");
51+
let res = func.call(&[w.into()])?;
52+
53+
Ok(res[0].unwrap_i32() as u32)
54+
}
55+
56+
fn get_signal_offset32(
57+
&self,
58+
p_sig_offset: u32,
59+
component: u32,
60+
hash_msb: u32,
61+
hash_lsb: u32,
62+
) -> Result<()> {
63+
let func = self.func("getSignalOffset32");
64+
func.call(&[
65+
p_sig_offset.into(),
66+
component.into(),
67+
hash_msb.into(),
68+
hash_lsb.into(),
69+
])?;
70+
71+
Ok(())
72+
}
73+
74+
fn set_signal(&self, c_idx: u32, component: u32, signal: u32, p_val: u32) -> Result<()> {
75+
let func = self.func("setSignal");
76+
func.call(&[c_idx.into(), component.into(), signal.into(), p_val.into()])?;
77+
78+
Ok(())
79+
}
4980
}
5081

5182
#[cfg(feature = "circom-2")]
@@ -96,46 +127,10 @@ impl CircomBase for Wasm {
96127
Ok(())
97128
}
98129

99-
fn get_ptr_witness_buffer(&self) -> Result<u32> {
100-
self.get_u32("getWitnessBuffer")
101-
}
102-
103-
fn get_ptr_witness(&self, w: u32) -> Result<u32> {
104-
let func = self.func("getPWitness");
105-
let res = func.call(&[w.into()])?;
106-
107-
Ok(res[0].unwrap_i32() as u32)
108-
}
109-
110130
fn get_n_vars(&self) -> Result<u32> {
111131
self.get_u32("getNVars")
112132
}
113133

114-
fn get_signal_offset32(
115-
&self,
116-
p_sig_offset: u32,
117-
component: u32,
118-
hash_msb: u32,
119-
hash_lsb: u32,
120-
) -> Result<()> {
121-
let func = self.func("getSignalOffset32");
122-
func.call(&[
123-
p_sig_offset.into(),
124-
component.into(),
125-
hash_msb.into(),
126-
hash_lsb.into(),
127-
])?;
128-
129-
Ok(())
130-
}
131-
132-
fn set_signal(&self, c_idx: u32, component: u32, signal: u32, p_val: u32) -> Result<()> {
133-
let func = self.func("setSignal");
134-
func.call(&[c_idx.into(), component.into(), signal.into(), p_val.into()])?;
135-
136-
Ok(())
137-
}
138-
139134
// Default to version 1 if it isn't explicitly defined
140135
fn get_version(&self) -> Result<u32> {
141136
match self.0.exports.get_function("getVersion") {

src/witness/mod.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ pub(super) use circom::{CircomBase, Wasm};
1010
#[cfg(feature = "circom-2")]
1111
pub(super) use circom::Circom2;
1212

13-
pub(super) use circom::Circom;
13+
pub(super) use circom::Circom1;
1414

1515
use fnv::FnvHasher;
1616
use std::hash::Hasher;

src/witness/witness_calculator.rs

+25-33
Original file line numberDiff line numberDiff line change
@@ -2,23 +2,22 @@ use super::{fnv, CircomBase, SafeMemory, Wasm};
22
use color_eyre::Result;
33
use num_bigint::BigInt;
44
use num_traits::Zero;
5-
use std::cell::Cell;
65
use wasmer::{imports, Function, Instance, Memory, MemoryType, Module, RuntimeError, Store};
76

87
#[cfg(feature = "circom-2")]
98
use num::ToPrimitive;
109

10+
use super::Circom1;
1111
#[cfg(feature = "circom-2")]
1212
use super::Circom2;
1313

14-
use super::Circom;
15-
1614
#[derive(Clone, Debug)]
1715
pub struct WitnessCalculator {
1816
pub instance: Wasm,
19-
pub memory: SafeMemory,
17+
pub memory: Option<SafeMemory>,
2018
pub n64: u32,
2119
pub circom_version: u32,
20+
pub prime: BigInt,
2221
}
2322

2423
// Error type to signal end of execution.
@@ -92,9 +91,8 @@ impl WitnessCalculator {
9291

9392
// Circom 2 feature flag with version 2
9493
#[cfg(feature = "circom-2")]
95-
fn new_circom2(instance: Wasm, memory: Memory, version: u32) -> Result<WitnessCalculator> {
94+
fn new_circom2(instance: Wasm, version: u32) -> Result<WitnessCalculator> {
9695
let n32 = instance.get_field_num_len32()?;
97-
let mut safe_memory = SafeMemory::new(memory, n32 as usize, BigInt::zero());
9896
instance.get_raw_prime()?;
9997
let mut arr = vec![0; n32 as usize];
10098
for i in 0..n32 {
@@ -104,13 +102,13 @@ impl WitnessCalculator {
104102
let prime = from_array32(arr);
105103

106104
let n64 = ((prime.bits() - 1) / 64 + 1) as u32;
107-
safe_memory.prime = prime;
108105

109106
Ok(WitnessCalculator {
110107
instance,
111-
memory: safe_memory,
108+
memory: None,
112109
n64,
113110
circom_version: version,
111+
prime,
114112
})
115113
}
116114

@@ -122,13 +120,14 @@ impl WitnessCalculator {
122120
let prime = safe_memory.read_big(ptr as usize, n32 as usize)?;
123121

124122
let n64 = ((prime.bits() - 1) / 64 + 1) as u32;
125-
safe_memory.prime = prime;
123+
safe_memory.prime = prime.clone();
126124

127125
Ok(WitnessCalculator {
128126
instance,
129-
memory: safe_memory,
127+
memory: Some(safe_memory),
130128
n64,
131129
circom_version: version,
130+
prime,
132131
})
133132
}
134133

@@ -142,7 +141,7 @@ impl WitnessCalculator {
142141
cfg_if::cfg_if! {
143142
if #[cfg(feature = "circom-2")] {
144143
match version {
145-
2 => new_circom2(instance, memory, version),
144+
2 => new_circom2(instance, version),
146145
1 => new_circom1(instance, memory, version),
147146
_ => panic!("Unknown Circom version")
148147
}
@@ -180,9 +179,9 @@ impl WitnessCalculator {
180179
) -> Result<Vec<BigInt>> {
181180
self.instance.init(sanity_check)?;
182181

183-
let old_mem_free_pos = self.memory.free_pos();
184-
let p_sig_offset = self.memory.alloc_u32();
185-
let p_fr = self.memory.alloc_fr();
182+
let old_mem_free_pos = self.memory.as_ref().unwrap().free_pos();
183+
let p_sig_offset = self.memory.as_mut().unwrap().alloc_u32();
184+
let p_fr = self.memory.as_mut().unwrap().alloc_fr();
186185

187186
// allocate the inputs
188187
for (name, values) in inputs.into_iter() {
@@ -191,10 +190,17 @@ impl WitnessCalculator {
191190
self.instance
192191
.get_signal_offset32(p_sig_offset, 0, msb, lsb)?;
193192

194-
let sig_offset = self.memory.read_u32(p_sig_offset as usize) as usize;
193+
let sig_offset = self
194+
.memory
195+
.as_ref()
196+
.unwrap()
197+
.read_u32(p_sig_offset as usize) as usize;
195198

196199
for (i, value) in values.into_iter().enumerate() {
197-
self.memory.write_fr(p_fr as usize, &value)?;
200+
self.memory
201+
.as_mut()
202+
.unwrap()
203+
.write_fr(p_fr as usize, &value)?;
198204
self.instance
199205
.set_signal(0, 0, (sig_offset + i) as u32, p_fr)?;
200206
}
@@ -205,11 +211,11 @@ impl WitnessCalculator {
205211
let n_vars = self.instance.get_n_vars()?;
206212
for i in 0..n_vars {
207213
let ptr = self.instance.get_ptr_witness(i)? as usize;
208-
let el = self.memory.read_fr(ptr)?;
214+
let el = self.memory.as_ref().unwrap().read_fr(ptr)?;
209215
w.push(el);
210216
}
211217

212-
self.memory.set_free_pos(old_mem_free_pos);
218+
self.memory.as_mut().unwrap().set_free_pos(old_mem_free_pos);
213219

214220
Ok(w)
215221
}
@@ -283,20 +289,6 @@ impl WitnessCalculator {
283289

284290
Ok(witness)
285291
}
286-
287-
pub fn get_witness_buffer(&self) -> Result<Vec<u8>> {
288-
let ptr = self.instance.get_ptr_witness_buffer()? as usize;
289-
290-
let view = self.memory.memory.view::<u8>();
291-
292-
let len = self.instance.get_n_vars()? * self.n64 * 8;
293-
let arr = view[ptr..ptr + len as usize]
294-
.iter()
295-
.map(Cell::get)
296-
.collect::<Vec<_>>();
297-
298-
Ok(arr)
299-
}
300292
}
301293

302294
// callback hooks for debugging
@@ -463,7 +455,7 @@ mod tests {
463455
fn run_test(case: TestCase) {
464456
let mut wtns = WitnessCalculator::new(case.circuit_path).unwrap();
465457
assert_eq!(
466-
wtns.memory.prime.to_str_radix(16),
458+
wtns.prime.to_str_radix(16),
467459
"30644E72E131A029B85045B68181585D2833E84879B9709143E1F593F0000001".to_lowercase()
468460
);
469461
assert_eq!({ wtns.instance.get_n_vars().unwrap() }, case.n_vars);

0 commit comments

Comments
 (0)