Skip to content

Commit 0f3fda1

Browse files
author
masoud@anyscale.com
committed
merge dev: resolve conflicts, keep dev changes
2 parents 0127d28 + 6b4ec9d commit 0f3fda1

File tree

6 files changed

+263
-2
lines changed

6 files changed

+263
-2
lines changed

.github/workflows/ci.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@ jobs:
5555
5656
echo "Running chunk execution tests..."
5757
cargo test --test chunk_execution_test -- --nocapture
58+
59+
echo "Running chunk proof generation test (release mode for speed)..."
60+
cargo test --test chunk_proof_test --release -- --nocapture
5861
5962
- name: Run clippy (warnings only)
6063
working-directory: zkml

.github/workflows/slow-tests.yml

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
name: Slow Tests
2+
3+
# Triggers:
4+
# 1. Manual trigger via GitHub UI ("Run workflow" button)
5+
# 2. When 'run-slow-tests' label is added to a PR
6+
# 3. Nightly at 2am UTC (optional, commented out)
7+
8+
on:
9+
workflow_dispatch: # Manual trigger
10+
pull_request:
11+
types: [labeled]
12+
13+
# schedule:
14+
# - cron: '0 2 * * *' # Uncomment for nightly runs
15+
16+
env:
17+
CARGO_TERM_COLOR: always
18+
RUST_BACKTRACE: 1
19+
20+
jobs:
21+
slow-tests:
22+
name: Slow Tests (KZG Proofs)
23+
runs-on: ubuntu-latest
24+
# Only run if manually triggered OR if the 'run-slow-tests' label was added
25+
if: |
26+
github.event_name == 'workflow_dispatch' ||
27+
(github.event_name == 'pull_request' && github.event.label.name == 'run-slow-tests')
28+
29+
steps:
30+
- name: Checkout code
31+
uses: actions/checkout@v4
32+
33+
- name: Install Rust toolchain
34+
uses: dtolnay/rust-toolchain@nightly
35+
with:
36+
components: rustfmt, clippy
37+
38+
- name: Cache cargo registry
39+
uses: actions/cache@v4
40+
with:
41+
path: |
42+
~/.cargo/registry
43+
~/.cargo/git
44+
zkml/target
45+
key: ${{ runner.os }}-cargo-slow-${{ hashFiles('zkml/Cargo.lock') }}
46+
restore-keys: |
47+
${{ runner.os }}-cargo-slow-
48+
${{ runner.os }}-cargo-
49+
50+
- name: Build zkml library (release)
51+
working-directory: zkml
52+
run: cargo build --lib --release
53+
54+
- name: Run chunk proof generation test
55+
working-directory: zkml
56+
run: |
57+
echo "Running KZG proof generation test (this takes ~1-2 minutes)..."
58+
cargo test --test chunk_proof_test --release -- --nocapture
59+
60+
- name: Summary
61+
run: |
62+
echo "## Slow Tests Complete ✅" >> $GITHUB_STEP_SUMMARY
63+
echo "- KZG chunk proof generation: PASSED" >> $GITHUB_STEP_SUMMARY
64+
65+
# Note: AWS/GPU tests are excluded from CI entirely.
66+
# Run them manually: cd tests/aws && python gpu_test.py
67+

README.md

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -342,9 +342,16 @@ Distributed Proving Simulation: PASS
342342
#### "PyTorch CUDA not available"
343343
- Install PyTorch with CUDA: `pip install torch --index-url https://download.pytorch.org/whl/cu118`
344344

345+
## CI
346+
347+
Lightweight CI runs on every PR to `main` and `dev`:
348+
- Builds zkml library (nightly Rust)
349+
- Runs `zkml/testing/` tests (~3-4 min total)
350+
- AWS/GPU tests excluded to save costs
351+
345352
## TODO: Next Steps
346353

347-
1. **Make Merkle root public**: Add root to public values so next chunk can verify it
354+
1. ~~**Make Merkle root public**: Add root to public values so next chunk can verify it~~ ✅ Done
348355
2. **Complete proof generation**: Connect chunk execution to actual proof generation
349356
3. **Ray-Rust integration**: Connect Python Ray workers to Rust proof generation
350357
4. **GPU acceleration**: Current implementation is CPU-based. GPU acceleration for proof generation requires additional work (Halo2 GPU support or custom GPU kernels)

zkml/Cargo.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,4 +92,8 @@ path = "testing/chunk_execution_test.rs"
9292

9393
[[test]]
9494
name = "test_merkle_root_public"
95-
path = "testing/test_merkle_root_public.rs"
95+
path = "testing/test_merkle_root_public.rs"
96+
97+
[[test]]
98+
name = "chunk_proof_test"
99+
path = "testing/chunk_proof_test.rs"

zkml/src/utils/proving_kzg.rs

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,128 @@ pub fn time_circuit_kzg(circuit: ModelCircuit<Fr>) {
168168
println!("Verifying time: {:?}", verify_duration - proof_duration);
169169
}
170170

171+
/// Result of proving a chunk
172+
pub struct ChunkProofResult {
173+
/// The serialized proof bytes
174+
pub proof: Vec<u8>,
175+
/// Public values including Merkle root (if use_merkle=true)
176+
pub public_vals: Vec<Fr>,
177+
/// The Merkle root (last public value if use_merkle=true)
178+
pub merkle_root: Option<Fr>,
179+
/// Proving time in milliseconds
180+
pub proving_time_ms: u128,
181+
/// Verification time in milliseconds
182+
pub verify_time_ms: u128,
183+
}
184+
185+
/// Generate a KZG proof for a chunk of the model
186+
///
187+
/// # Arguments
188+
/// * `config_path` - Path to model config msgpack file
189+
/// * `input_path` - Path to input msgpack file
190+
/// * `chunk_start` - Starting layer index (inclusive)
191+
/// * `chunk_end` - Ending layer index (exclusive)
192+
/// * `use_merkle` - Whether to compute and include Merkle root in public values
193+
/// * `params_dir` - Directory for KZG params (will be created if needed)
194+
///
195+
/// # Returns
196+
/// ChunkProofResult containing proof, public values, and timing info
197+
pub fn prove_chunk_kzg(
198+
config_path: &str,
199+
input_path: &str,
200+
chunk_start: usize,
201+
chunk_end: usize,
202+
use_merkle: bool,
203+
params_dir: &str,
204+
) -> ChunkProofResult {
205+
use crate::utils::loader::load_model_msgpack;
206+
207+
let start = Instant::now();
208+
209+
// Load and configure circuit for chunk execution
210+
let config = load_model_msgpack(config_path, input_path);
211+
let mut circuit = ModelCircuit::<Fr>::generate_from_file(config_path, input_path);
212+
213+
// If using Merkle, ensure Poseidon hasher is configured
214+
if use_merkle && circuit.commit_after.is_empty() && circuit.commit_before.is_empty() {
215+
// Set commit_after to enable Poseidon hasher
216+
if let Some(layer) = config.layers.get(chunk_end.saturating_sub(1)) {
217+
if !layer.out_idxes.is_empty() {
218+
circuit.commit_after = vec![layer.out_idxes.clone()];
219+
}
220+
}
221+
}
222+
223+
// Configure for chunk execution
224+
circuit.set_chunk_config(chunk_start, chunk_end, use_merkle);
225+
226+
let degree = circuit.k as u32;
227+
let params = get_kzg_params(params_dir, degree);
228+
229+
// Generate keys
230+
let vk = keygen_vk(&params, &circuit).unwrap();
231+
let pk = keygen_pk(&params, vk, &circuit).unwrap();
232+
233+
// First run to get public values
234+
let _mock = MockProver::run(degree, &circuit, vec![vec![]]).unwrap();
235+
let public_vals = get_public_values();
236+
237+
// Extract Merkle root (last public value if use_merkle)
238+
let merkle_root = if use_merkle && !public_vals.is_empty() {
239+
Some(public_vals[public_vals.len() - 1])
240+
} else {
241+
None
242+
};
243+
244+
// Generate proof
245+
let rng = rand::thread_rng();
246+
let prove_start = Instant::now();
247+
let mut transcript = Blake2bWrite::<_, G1Affine, Challenge255<_>>::init(vec![]);
248+
create_proof::<
249+
KZGCommitmentScheme<Bn256>,
250+
ProverSHPLONK<'_, Bn256>,
251+
Challenge255<G1Affine>,
252+
_,
253+
Blake2bWrite<Vec<u8>, G1Affine, Challenge255<G1Affine>>,
254+
ModelCircuit<Fr>,
255+
>(
256+
&params,
257+
&pk,
258+
&[circuit],
259+
&[&[&public_vals]],
260+
rng,
261+
&mut transcript,
262+
)
263+
.unwrap();
264+
let proof = transcript.finalize();
265+
let proving_time_ms = prove_start.elapsed().as_millis();
266+
267+
// Verify the proof
268+
let verify_start = Instant::now();
269+
let strategy = SingleStrategy::new(&params);
270+
let transcript_read = Blake2bRead::<_, _, Challenge255<_>>::init(&proof[..]);
271+
verify_kzg(&params, pk.get_vk(), strategy, &public_vals, transcript_read);
272+
let verify_time_ms = verify_start.elapsed().as_millis();
273+
274+
println!(
275+
"Chunk [{}, {}): proof generated in {}ms, verified in {}ms, {} public vals{}",
276+
chunk_start,
277+
chunk_end,
278+
proving_time_ms,
279+
verify_time_ms,
280+
public_vals.len(),
281+
if use_merkle { ", includes Merkle root" } else { "" }
282+
);
283+
284+
ChunkProofResult {
285+
proof,
286+
public_vals,
287+
merkle_root,
288+
proving_time_ms,
289+
verify_time_ms,
290+
}
291+
}
292+
171293
// Standalone verification
172294
pub fn verify_circuit_kzg(
173295
circuit: ModelCircuit<Fr>,

zkml/testing/chunk_proof_test.rs

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
//! Tests for chunk proof generation
2+
//!
3+
//! These tests verify that:
4+
//! 1. Real KZG proofs can be generated for model chunks
5+
//! 2. Proofs verify correctly
6+
//!
7+
//! Note: These are slow tests (~50s each) as they generate real cryptographic proofs.
8+
//! Run with: cargo test --test chunk_proof_test --release -- --nocapture
9+
10+
#[cfg(test)]
11+
mod tests {
12+
use zkml::utils::proving_kzg::prove_chunk_kzg;
13+
use std::fs;
14+
15+
/// Test: Generate and verify a real KZG proof for a chunk
16+
///
17+
/// This is the main test - it proves that we can generate real ZK proofs
18+
/// for a subset of model layers.
19+
#[test]
20+
fn test_chunk_proof_generation() {
21+
let config_file = "examples/mnist/model.msgpack";
22+
let input_file = "examples/mnist/inp.msgpack";
23+
24+
if !std::path::Path::new(config_file).exists() {
25+
eprintln!("Skipping test: example files not found");
26+
return;
27+
}
28+
29+
// Use unique params directory to avoid race conditions
30+
let params_dir = "./params_kzg_chunk_test";
31+
fs::create_dir_all(params_dir).ok();
32+
33+
// Generate proof for first 2 layers WITHOUT Merkle
34+
// This tests the core proof generation functionality
35+
let result = prove_chunk_kzg(
36+
config_file,
37+
input_file,
38+
0, // chunk_start
39+
2, // chunk_end
40+
false, // use_merkle = false (simpler case)
41+
params_dir,
42+
);
43+
44+
// Verify we got a valid proof
45+
assert!(!result.proof.is_empty(), "Proof should not be empty");
46+
assert!(!result.public_vals.is_empty(), "Public values should not be empty");
47+
assert!(result.merkle_root.is_none(), "Merkle root should be None when use_merkle=false");
48+
assert!(result.proving_time_ms > 0, "Should have non-zero proving time");
49+
assert!(result.verify_time_ms > 0, "Should have non-zero verify time");
50+
51+
println!("✓ Chunk proof generated and verified successfully");
52+
println!(" Proof size: {} bytes", result.proof.len());
53+
println!(" Public values: {}", result.public_vals.len());
54+
println!(" Proving time: {}ms", result.proving_time_ms);
55+
println!(" Verify time: {}ms", result.verify_time_ms);
56+
}
57+
}
58+

0 commit comments

Comments
 (0)