Skip to content

Commit 521ed3c

Browse files
author
masoud@anyscale.com
committed
add prove_chunk_kzg() for real ZK proof generation of model chunks
1 parent aaaed82 commit 521ed3c

File tree

3 files changed

+185
-1
lines changed

3 files changed

+185
-1
lines changed

zkml/Cargo.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,4 +92,8 @@ path = "testing/chunk_execution_test.rs"
9292

9393
[[test]]
9494
name = "test_merkle_root_public"
95-
path = "testing/test_merkle_root_public.rs"
95+
path = "testing/test_merkle_root_public.rs"
96+
97+
[[test]]
98+
name = "chunk_proof_test"
99+
path = "testing/chunk_proof_test.rs"

zkml/src/utils/proving_kzg.rs

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,128 @@ pub fn time_circuit_kzg(circuit: ModelCircuit<Fr>) {
168168
println!("Verifying time: {:?}", verify_duration - proof_duration);
169169
}
170170

171+
/// Result of proving a chunk
172+
pub struct ChunkProofResult {
173+
/// The serialized proof bytes
174+
pub proof: Vec<u8>,
175+
/// Public values including Merkle root (if use_merkle=true)
176+
pub public_vals: Vec<Fr>,
177+
/// The Merkle root (last public value if use_merkle=true)
178+
pub merkle_root: Option<Fr>,
179+
/// Proving time in milliseconds
180+
pub proving_time_ms: u128,
181+
/// Verification time in milliseconds
182+
pub verify_time_ms: u128,
183+
}
184+
185+
/// Generate a KZG proof for a chunk of the model
186+
///
187+
/// # Arguments
188+
/// * `config_path` - Path to model config msgpack file
189+
/// * `input_path` - Path to input msgpack file
190+
/// * `chunk_start` - Starting layer index (inclusive)
191+
/// * `chunk_end` - Ending layer index (exclusive)
192+
/// * `use_merkle` - Whether to compute and include Merkle root in public values
193+
/// * `params_dir` - Directory for KZG params (will be created if needed)
194+
///
195+
/// # Returns
196+
/// ChunkProofResult containing proof, public values, and timing info
197+
pub fn prove_chunk_kzg(
198+
config_path: &str,
199+
input_path: &str,
200+
chunk_start: usize,
201+
chunk_end: usize,
202+
use_merkle: bool,
203+
params_dir: &str,
204+
) -> ChunkProofResult {
205+
use crate::utils::loader::load_model_msgpack;
206+
207+
let start = Instant::now();
208+
209+
// Load and configure circuit for chunk execution
210+
let config = load_model_msgpack(config_path, input_path);
211+
let mut circuit = ModelCircuit::<Fr>::generate_from_file(config_path, input_path);
212+
213+
// If using Merkle, ensure Poseidon hasher is configured
214+
if use_merkle && circuit.commit_after.is_empty() && circuit.commit_before.is_empty() {
215+
// Set commit_after to enable Poseidon hasher
216+
if let Some(layer) = config.layers.get(chunk_end.saturating_sub(1)) {
217+
if !layer.out_idxes.is_empty() {
218+
circuit.commit_after = vec![layer.out_idxes.clone()];
219+
}
220+
}
221+
}
222+
223+
// Configure for chunk execution
224+
circuit.set_chunk_config(chunk_start, chunk_end, use_merkle);
225+
226+
let degree = circuit.k as u32;
227+
let params = get_kzg_params(params_dir, degree);
228+
229+
// Generate keys
230+
let vk = keygen_vk(&params, &circuit).unwrap();
231+
let pk = keygen_pk(&params, vk, &circuit).unwrap();
232+
233+
// First run to get public values
234+
let _mock = MockProver::run(degree, &circuit, vec![vec![]]).unwrap();
235+
let public_vals = get_public_values();
236+
237+
// Extract Merkle root (last public value if use_merkle)
238+
let merkle_root = if use_merkle && !public_vals.is_empty() {
239+
Some(public_vals[public_vals.len() - 1])
240+
} else {
241+
None
242+
};
243+
244+
// Generate proof
245+
let rng = rand::thread_rng();
246+
let prove_start = Instant::now();
247+
let mut transcript = Blake2bWrite::<_, G1Affine, Challenge255<_>>::init(vec![]);
248+
create_proof::<
249+
KZGCommitmentScheme<Bn256>,
250+
ProverSHPLONK<'_, Bn256>,
251+
Challenge255<G1Affine>,
252+
_,
253+
Blake2bWrite<Vec<u8>, G1Affine, Challenge255<G1Affine>>,
254+
ModelCircuit<Fr>,
255+
>(
256+
&params,
257+
&pk,
258+
&[circuit],
259+
&[&[&public_vals]],
260+
rng,
261+
&mut transcript,
262+
)
263+
.unwrap();
264+
let proof = transcript.finalize();
265+
let proving_time_ms = prove_start.elapsed().as_millis();
266+
267+
// Verify the proof
268+
let verify_start = Instant::now();
269+
let strategy = SingleStrategy::new(&params);
270+
let transcript_read = Blake2bRead::<_, _, Challenge255<_>>::init(&proof[..]);
271+
verify_kzg(&params, pk.get_vk(), strategy, &public_vals, transcript_read);
272+
let verify_time_ms = verify_start.elapsed().as_millis();
273+
274+
println!(
275+
"Chunk [{}, {}): proof generated in {}ms, verified in {}ms, {} public vals{}",
276+
chunk_start,
277+
chunk_end,
278+
proving_time_ms,
279+
verify_time_ms,
280+
public_vals.len(),
281+
if use_merkle { ", includes Merkle root" } else { "" }
282+
);
283+
284+
ChunkProofResult {
285+
proof,
286+
public_vals,
287+
merkle_root,
288+
proving_time_ms,
289+
verify_time_ms,
290+
}
291+
}
292+
171293
// Standalone verification
172294
pub fn verify_circuit_kzg(
173295
circuit: ModelCircuit<Fr>,

zkml/testing/chunk_proof_test.rs

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
//! Tests for chunk proof generation
2+
//!
3+
//! These tests verify that:
4+
//! 1. Real KZG proofs can be generated for model chunks
5+
//! 2. Proofs verify correctly
6+
//!
7+
//! Note: These are slow tests (~50s each) as they generate real cryptographic proofs.
8+
//! Run with: cargo test --test chunk_proof_test --release -- --nocapture
9+
10+
#[cfg(test)]
11+
mod tests {
12+
use zkml::utils::proving_kzg::prove_chunk_kzg;
13+
use std::fs;
14+
15+
/// Test: Generate and verify a real KZG proof for a chunk
16+
///
17+
/// This is the main test - it proves that we can generate real ZK proofs
18+
/// for a subset of model layers.
19+
#[test]
20+
fn test_chunk_proof_generation() {
21+
let config_file = "examples/mnist/model.msgpack";
22+
let input_file = "examples/mnist/inp.msgpack";
23+
24+
if !std::path::Path::new(config_file).exists() {
25+
eprintln!("Skipping test: example files not found");
26+
return;
27+
}
28+
29+
// Use unique params directory to avoid race conditions
30+
let params_dir = "./params_kzg_chunk_test";
31+
fs::create_dir_all(params_dir).ok();
32+
33+
// Generate proof for first 2 layers WITHOUT Merkle
34+
// This tests the core proof generation functionality
35+
let result = prove_chunk_kzg(
36+
config_file,
37+
input_file,
38+
0, // chunk_start
39+
2, // chunk_end
40+
false, // use_merkle = false (simpler case)
41+
params_dir,
42+
);
43+
44+
// Verify we got a valid proof
45+
assert!(!result.proof.is_empty(), "Proof should not be empty");
46+
assert!(!result.public_vals.is_empty(), "Public values should not be empty");
47+
assert!(result.merkle_root.is_none(), "Merkle root should be None when use_merkle=false");
48+
assert!(result.proving_time_ms > 0, "Should have non-zero proving time");
49+
assert!(result.verify_time_ms > 0, "Should have non-zero verify time");
50+
51+
println!("✓ Chunk proof generated and verified successfully");
52+
println!(" Proof size: {} bytes", result.proof.len());
53+
println!(" Public values: {}", result.public_vals.len());
54+
println!(" Proving time: {}ms", result.proving_time_ms);
55+
println!(" Verify time: {}ms", result.verify_time_ms);
56+
}
57+
}
58+

0 commit comments

Comments
 (0)