Skip to content

Commit e37a6b2

Browse files
author
masoud@anyscale.com
committed
feat: Ray-Rust integration for distributed proving
- Add prove_chunk CLI binary (zkml/src/bin/prove_chunk.rs)
- Add Python wrapper (python/rust_prover.py)
- Update simple_distributed.py with real prover support
- Add clap and hex dependencies

Refs #9, #12
1 parent 9d23a0d commit e37a6b2

File tree

5 files changed

+576
-50
lines changed

5 files changed

+576
-50
lines changed

python/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# Python utilities for distributed-zkml
2+

python/rust_prover.py

Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
"""
2+
Python wrapper for the Rust prove_chunk CLI binary.
3+
4+
This module provides a Python interface to the zkml proof generation system,
5+
allowing Ray workers to generate ZK proofs for model chunks.
6+
"""
7+
8+
import json
9+
import os
10+
import subprocess
11+
import tempfile
12+
from dataclasses import dataclass
13+
from pathlib import Path
14+
from typing import Optional
15+
16+
17+
@dataclass
class ProofResult:
    """Metadata and file locations for one generated chunk proof.

    The scalar fields mirror the entries of the ``result.json`` file that the
    prover writes into its output directory; the path fields point at the
    artifacts inside that same directory.
    """

    # Layer range covered by this proof.
    chunk_start: int
    chunk_end: int

    # Merkle configuration; the roots are optional and may be None.
    use_merkle: bool
    prev_merkle_root: Optional[str]
    merkle_root: Optional[str]

    # Statistics as reported by the prover.
    proving_time_ms: int
    verify_time_ms: int
    proof_size_bytes: int
    public_vals_count: int

    # Locations of the generated artifacts.
    proof_path: str
    public_vals_path: str
    output_dir: str
32+
33+
34+
def find_prove_chunk_binary() -> str:
    """Locate the ``prove_chunk`` executable.

    Search order:

    1. ``zkml/target/release`` then ``zkml/target/debug`` under the parent of
       this module's directory.
    2. The same two build directories relative to the current working
       directory.
    3. The directories listed in ``PATH``.

    Returns:
        Path to the binary. Build-tree hits are returned as resolved absolute
        paths; a ``PATH`` hit is returned as reported by ``shutil.which``.

    Raises:
        FileNotFoundError: If no ``prove_chunk`` binary can be found.
    """
    # Local import: shutil is not among this module's top-level imports.
    import shutil

    module_dir = Path(__file__).parent
    search_roots = (module_dir.parent / "zkml", Path("zkml"))
    candidates = [
        root / "target" / profile / "prove_chunk"
        for root in search_roots
        for profile in ("release", "debug")
    ]

    for candidate in candidates:
        # is_file() rather than exists(): a directory that happens to be
        # named prove_chunk is not a usable binary.
        if candidate.is_file():
            return str(candidate.resolve())

    # shutil.which avoids spawning the external `which` program, which is not
    # available on every platform and would otherwise surface as a confusing
    # FileNotFoundError about `which` itself.
    on_path = shutil.which("prove_chunk")
    if on_path is not None:
        return on_path

    raise FileNotFoundError(
        "Could not find prove_chunk binary. "
        "Build it with: cd zkml && cargo build --bin prove_chunk --release"
    )
65+
66+
67+
# Keys that result.json must contain; the two Merkle roots are optional.
_REQUIRED_RESULT_KEYS = (
    "chunk_start",
    "chunk_end",
    "use_merkle",
    "proving_time_ms",
    "verify_time_ms",
    "proof_size_bytes",
    "public_vals_count",
)


def _load_proof_result(output_dir: str) -> "ProofResult":
    """Read ``result.json`` from *output_dir* and build a ProofResult from it.

    Raises:
        ValueError: If the file is missing, is not valid JSON, is not a JSON
            object, or lacks any of the required keys.
    """
    result_path = os.path.join(output_dir, "result.json")
    try:
        with open(result_path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except FileNotFoundError:
        raise ValueError(f"result.json not found in {output_dir}") from None
    except json.JSONDecodeError as err:
        raise ValueError(
            f"result.json in {output_dir} is not valid JSON: {err}"
        ) from err

    if not isinstance(data, dict):
        raise ValueError(f"result.json in {output_dir} is not a JSON object")

    # Check up front so that an incomplete file surfaces as the documented
    # ValueError instead of a bare KeyError.
    missing = [key for key in _REQUIRED_RESULT_KEYS if key not in data]
    if missing:
        raise ValueError(
            f"result.json in {output_dir} is missing required fields: "
            + ", ".join(missing)
        )

    return ProofResult(
        chunk_start=data["chunk_start"],
        chunk_end=data["chunk_end"],
        use_merkle=data["use_merkle"],
        prev_merkle_root=data.get("prev_merkle_root"),
        merkle_root=data.get("merkle_root"),
        proving_time_ms=data["proving_time_ms"],
        verify_time_ms=data["verify_time_ms"],
        proof_size_bytes=data["proof_size_bytes"],
        public_vals_count=data["public_vals_count"],
        proof_path=os.path.join(output_dir, "proof.bin"),
        public_vals_path=os.path.join(output_dir, "public_vals.bin"),
        output_dir=output_dir,
    )


def prove_chunk(
    config_path: str,
    input_path: str,
    chunk_start: int,
    chunk_end: int,
    use_merkle: bool = False,
    prev_merkle_root: Optional[str] = None,
    params_dir: str = "./params_kzg",
    output_dir: Optional[str] = None,
    binary_path: Optional[str] = None,
) -> "ProofResult":
    """
    Generate a ZK proof for a model chunk using the Rust prover.

    Runs the ``prove_chunk`` binary as a subprocess and then reads the
    ``result.json`` it leaves in the output directory.

    Args:
        config_path: Path to model config/weights (msgpack)
        input_path: Path to input data (msgpack)
        chunk_start: Start layer index (inclusive)
        chunk_end: End layer index (exclusive)
        use_merkle: Enable Merkle tree for intermediate values
        prev_merkle_root: Previous chunk's Merkle root (hex string)
        params_dir: Directory for KZG params (created if missing)
        output_dir: Output directory for proof files (temp dir if None)
        binary_path: Path to prove_chunk binary (auto-detect if None)

    Returns:
        ProofResult with proof metadata and file paths

    Raises:
        FileNotFoundError: If binary or input files not found
        subprocess.CalledProcessError: If proof generation fails; the
            prover's stdout and stderr are attached to the exception
        ValueError: If result.json is missing, malformed, or incomplete
    """
    # Local import: shutil is not among this module's top-level imports.
    import shutil

    if binary_path is None:
        binary_path = find_prove_chunk_binary()

    # Validate inputs before creating any directories.
    for label, path in (("Config", config_path), ("Input", input_path)):
        if not os.path.exists(path):
            raise FileNotFoundError(f"{label} file not found: {path}")

    # Remember whether the output directory is ours: a temp dir we created
    # must be removed on failure, because the caller never learns its path.
    owns_output_dir = output_dir is None
    if owns_output_dir:
        output_dir = tempfile.mkdtemp(prefix="zkml_proof_")
    else:
        os.makedirs(output_dir, exist_ok=True)

    try:
        os.makedirs(params_dir, exist_ok=True)

        cmd = [
            binary_path,
            "--config", config_path,
            "--input", input_path,
            "--start", str(chunk_start),
            "--end", str(chunk_end),
            "--params-dir", params_dir,
            "--output-dir", output_dir,
        ]
        if use_merkle:
            cmd.append("--use-merkle")
        if prev_merkle_root is not None:
            cmd.extend(["--prev-root", prev_merkle_root])

        # check=True raises CalledProcessError carrying stdout and stderr.
        subprocess.run(cmd, capture_output=True, text=True, check=True)

        return _load_proof_result(output_dir)
    except Exception:
        if owns_output_dir:
            shutil.rmtree(output_dir, ignore_errors=True)
        raise
173+
174+
175+
def main(argv=None) -> int:
    """Run a simple smoke test of prove_chunk from the command line.

    Args:
        argv: Full argument vector including the program name; defaults to
            ``sys.argv`` when None.

    Returns:
        Process exit code: 0 on success, 1 on a usage error or any failure.
    """
    import sys

    args = sys.argv if argv is None else argv

    if len(args) < 3:
        print("Usage: python rust_prover.py <config.msgpack> <input.msgpack>")
        print("\nExample:")
        print("  python rust_prover.py zkml/examples/mnist/model.msgpack zkml/examples/mnist/inp.msgpack")
        return 1

    config = args[1]
    inp = args[2]

    print(f"Testing prove_chunk with config={config}, input={inp}")

    try:
        result = prove_chunk(
            config_path=config,
            input_path=inp,
            chunk_start=0,
            chunk_end=2,
            use_merkle=False,
        )
    except subprocess.CalledProcessError as e:
        # str(e) only reports the exit status; the prover's own diagnostics
        # are in e.stderr, so show them too.
        print(f"Error: {e}", file=sys.stderr)
        if e.stderr:
            print(e.stderr, file=sys.stderr)
        return 1
    except Exception as e:  # top-level boundary: report and exit non-zero
        print(f"Error: {e}", file=sys.stderr)
        return 1

    print("\nSuccess!")
    print(f"  Chunk: [{result.chunk_start}, {result.chunk_end})")
    print(f"  Proving time: {result.proving_time_ms}ms")
    print(f"  Verify time: {result.verify_time_ms}ms")
    print(f"  Proof size: {result.proof_size_bytes} bytes")
    print(f"  Public values: {result.public_vals_count}")
    print(f"  Output dir: {result.output_dir}")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
208+

0 commit comments

Comments
 (0)