|
| 1 | +""" |
| 2 | +Python wrapper for the Rust prove_chunk CLI binary. |
| 3 | +
|
| 4 | +This module provides a Python interface to the zkml proof generation system, |
| 5 | +allowing Ray workers to generate ZK proofs for model chunks. |
| 6 | +""" |
| 7 | + |
| 8 | +import json |
| 9 | +import os |
| 10 | +import subprocess |
| 11 | +import tempfile |
| 12 | +from dataclasses import dataclass |
| 13 | +from pathlib import Path |
| 14 | +from typing import Optional |
| 15 | + |
| 16 | + |
@dataclass
class ProofResult:
    """Metadata and artifact locations for one generated chunk proof.

    The numeric and Merkle fields are copied from the ``result.json`` file
    the prover writes; the path fields point at files inside ``output_dir``.
    """

    # Layer range covered by the proof: start is inclusive, end is exclusive.
    chunk_start: int
    chunk_end: int

    # Merkle settings and roots (hex strings); roots may be absent.
    use_merkle: bool
    prev_merkle_root: Optional[str]
    merkle_root: Optional[str]

    # Timings and sizes reported by the prover.
    proving_time_ms: int
    verify_time_ms: int
    proof_size_bytes: int
    public_vals_count: int

    # Locations of the generated artifacts.
    proof_path: str
    public_vals_path: str
    output_dir: str
| 32 | + |
| 33 | + |
def find_prove_chunk_binary() -> str:
    """Locate the ``prove_chunk`` executable.

    Search order (first hit wins):
      1. ``zkml/target/release`` then ``zkml/target/debug`` under the parent
         of this module's directory.
      2. The same two build directories relative to the current working
         directory.
      3. The directories listed in ``PATH``.

    Returns:
        Absolute path to the binary.

    Raises:
        FileNotFoundError: If no executable named ``prove_chunk`` is found.
    """
    # Local import: ``shutil`` is not among this module's top-level imports.
    import shutil

    binary_name = "prove_chunk"
    search_roots = (Path(__file__).parent.parent / "zkml", Path("zkml"))
    build_profiles = ("release", "debug")

    for root in search_roots:
        for profile in build_profiles:
            candidate = root / "target" / profile / binary_name
            # Require a regular, executable file: a directory (or a
            # non-executable artifact) with this name cannot be run.
            if candidate.is_file() and os.access(candidate, os.X_OK):
                return str(candidate.resolve())

    # shutil.which is portable and does not depend on an external ``which``
    # program being installed (its absence would otherwise surface as a
    # misleading FileNotFoundError about ``which`` itself).
    on_path = shutil.which(binary_name)
    if on_path is not None:
        return os.path.abspath(on_path)

    raise FileNotFoundError(
        "Could not find prove_chunk binary. "
        "Build it with: cd zkml && cargo build --bin prove_chunk --release"
    )
| 65 | + |
| 66 | + |
| 67 | +def prove_chunk( |
| 68 | + config_path: str, |
| 69 | + input_path: str, |
| 70 | + chunk_start: int, |
| 71 | + chunk_end: int, |
| 72 | + use_merkle: bool = False, |
| 73 | + prev_merkle_root: Optional[str] = None, |
| 74 | + params_dir: str = "./params_kzg", |
| 75 | + output_dir: Optional[str] = None, |
| 76 | + binary_path: Optional[str] = None, |
| 77 | +) -> ProofResult: |
| 78 | + """ |
| 79 | + Generate a ZK proof for a model chunk using the Rust prover. |
| 80 | + |
| 81 | + Args: |
| 82 | + config_path: Path to model config/weights (msgpack) |
| 83 | + input_path: Path to input data (msgpack) |
| 84 | + chunk_start: Start layer index (inclusive) |
| 85 | + chunk_end: End layer index (exclusive) |
| 86 | + use_merkle: Enable Merkle tree for intermediate values |
| 87 | + prev_merkle_root: Previous chunk's Merkle root (hex string) |
| 88 | + params_dir: Directory for KZG params |
| 89 | + output_dir: Output directory for proof files (temp dir if None) |
| 90 | + binary_path: Path to prove_chunk binary (auto-detect if None) |
| 91 | + |
| 92 | + Returns: |
| 93 | + ProofResult with proof metadata and file paths |
| 94 | + |
| 95 | + Raises: |
| 96 | + FileNotFoundError: If binary or input files not found |
| 97 | + subprocess.CalledProcessError: If proof generation fails |
| 98 | + ValueError: If result.json is invalid |
| 99 | + """ |
| 100 | + # Find binary |
| 101 | + if binary_path is None: |
| 102 | + binary_path = find_prove_chunk_binary() |
| 103 | + |
| 104 | + # Validate inputs exist |
| 105 | + if not os.path.exists(config_path): |
| 106 | + raise FileNotFoundError(f"Config file not found: {config_path}") |
| 107 | + if not os.path.exists(input_path): |
| 108 | + raise FileNotFoundError(f"Input file not found: {input_path}") |
| 109 | + |
| 110 | + # Create output directory |
| 111 | + if output_dir is None: |
| 112 | + output_dir = tempfile.mkdtemp(prefix="zkml_proof_") |
| 113 | + else: |
| 114 | + os.makedirs(output_dir, exist_ok=True) |
| 115 | + |
| 116 | + # Ensure params directory exists |
| 117 | + os.makedirs(params_dir, exist_ok=True) |
| 118 | + |
| 119 | + # Build command |
| 120 | + cmd = [ |
| 121 | + binary_path, |
| 122 | + "--config", config_path, |
| 123 | + "--input", input_path, |
| 124 | + "--start", str(chunk_start), |
| 125 | + "--end", str(chunk_end), |
| 126 | + "--params-dir", params_dir, |
| 127 | + "--output-dir", output_dir, |
| 128 | + ] |
| 129 | + |
| 130 | + if use_merkle: |
| 131 | + cmd.append("--use-merkle") |
| 132 | + |
| 133 | + if prev_merkle_root is not None: |
| 134 | + cmd.extend(["--prev-root", prev_merkle_root]) |
| 135 | + |
| 136 | + # Run prover |
| 137 | + result = subprocess.run( |
| 138 | + cmd, |
| 139 | + capture_output=True, |
| 140 | + text=True, |
| 141 | + ) |
| 142 | + |
| 143 | + if result.returncode != 0: |
| 144 | + raise subprocess.CalledProcessError( |
| 145 | + result.returncode, |
| 146 | + cmd, |
| 147 | + output=result.stdout, |
| 148 | + stderr=result.stderr, |
| 149 | + ) |
| 150 | + |
| 151 | + # Parse result |
| 152 | + result_path = os.path.join(output_dir, "result.json") |
| 153 | + if not os.path.exists(result_path): |
| 154 | + raise ValueError(f"result.json not found in {output_dir}") |
| 155 | + |
| 156 | + with open(result_path, "r") as f: |
| 157 | + data = json.load(f) |
| 158 | + |
| 159 | + return ProofResult( |
| 160 | + chunk_start=data["chunk_start"], |
| 161 | + chunk_end=data["chunk_end"], |
| 162 | + use_merkle=data["use_merkle"], |
| 163 | + prev_merkle_root=data.get("prev_merkle_root"), |
| 164 | + merkle_root=data.get("merkle_root"), |
| 165 | + proving_time_ms=data["proving_time_ms"], |
| 166 | + verify_time_ms=data["verify_time_ms"], |
| 167 | + proof_size_bytes=data["proof_size_bytes"], |
| 168 | + public_vals_count=data["public_vals_count"], |
| 169 | + proof_path=os.path.join(output_dir, "proof.bin"), |
| 170 | + public_vals_path=os.path.join(output_dir, "public_vals.bin"), |
| 171 | + output_dir=output_dir, |
| 172 | + ) |
| 173 | + |
| 174 | + |
def main(argv: Optional[list] = None) -> int:
    """Smoke-test ``prove_chunk`` from the command line.

    Proves layers ``[0, 2)`` of the given model without Merkle commitments
    and prints the reported metrics.

    Args:
        argv: Argument vector including the program name; defaults to
            ``sys.argv``.

    Returns:
        Process exit code: 0 on success, 1 on bad usage or any failure.
    """
    import sys

    args = sys.argv if argv is None else list(argv)
    if len(args) < 3:
        print("Usage: python rust_prover.py <config.msgpack> <input.msgpack>")
        print("\nExample:")
        print(" python rust_prover.py zkml/examples/mnist/model.msgpack zkml/examples/mnist/inp.msgpack")
        return 1

    config, inp = args[1], args[2]
    print(f"Testing prove_chunk with config={config}, input={inp}")

    try:
        result = prove_chunk(
            config_path=config,
            input_path=inp,
            chunk_start=0,
            chunk_end=2,
            use_merkle=False,
        )
    except subprocess.CalledProcessError as e:
        # str(e) only carries the exit status; the prover's own diagnostics
        # live in the captured stderr, so surface them too.
        print(f"Error: {e}", file=sys.stderr)
        if e.stderr:
            print(e.stderr.rstrip(), file=sys.stderr)
        return 1
    except Exception as e:
        # Top-level boundary of a test script: report and exit non-zero.
        print(f"Error: {e}", file=sys.stderr)
        return 1

    print("\nSuccess!")
    print(f" Chunk: [{result.chunk_start}, {result.chunk_end})")
    print(f" Proving time: {result.proving_time_ms}ms")
    print(f" Verify time: {result.verify_time_ms}ms")
    print(f" Proof size: {result.proof_size_bytes} bytes")
    print(f" Public values: {result.public_vals_count}")
    print(f" Output dir: {result.output_dir}")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
| 208 | + |
0 commit comments