diff --git a/constraint/gkr.go b/constraint/gkr.go new file mode 100644 index 0000000000..f295f82de4 --- /dev/null +++ b/constraint/gkr.go @@ -0,0 +1,80 @@ +package constraint + +type ( + // GkrClaimSource identifies an incoming evaluation claim for a wire. + // Level is the level that produced the claim. + // OutgoingClaimIndex selects which of that level's outgoing evaluation points is referenced; + // always 0 for SumcheckLevels, 0..M-1 for SkipLevels with M inherited evaluation points. + // The initial verifier challenge is represented as {Level: len(schedule), OutgoingClaimIndex: 0}. + GkrClaimSource struct { + Level int `json:"level"` + OutgoingClaimIndex int `json:"outgoingClaimIndex"` + } + + // GkrClaimGroup represents a set of wires sharing identical claim sources. + // finalEvalProof index = pos(wire, srcLevel) * NbOutgoingEvalPoints(srcLevel) + ClaimSources[claimI].OutgoingClaimIndex, + // where pos(wire, srcLevel) is the wire's position in srcLevel's UniqueGateInputs list. + GkrClaimGroup struct { + Wires []int `json:"wires"` + ClaimSources []GkrClaimSource `json:"claimSources"` + } + + // GkrProvingLevel is a single level in the proving schedule. + GkrProvingLevel interface { + NbOutgoingEvalPoints() int + // NbClaims returns the total number of claims at this level. + NbClaims() int + ClaimGroups() []GkrClaimGroup + // FinalEvalProofIndex returns where to find the evaluationPointI'th evaluation claim for the wireI'th input wire to the layer, + // in the layer's final evaluation proof. + FinalEvalProofIndex(wireI, evaluationPointI int) int + } + + // GkrSkipLevel represents a level where zerocheck is skipped. + // Claims propagate through at their existing evaluation points. + GkrSkipLevel GkrClaimGroup + + // GkrSumcheckLevel represents a level where one or more zerochecks are batched + // together in a single sumcheck. Each GkrClaimGroup within may have different + // claim sources (sumcheck-level batching), or the same source (enabling + // zerocheck-level batching with shared eq tables). + GkrSumcheckLevel []GkrClaimGroup + + // GkrProvingSchedule is a sequence of levels defining how to prove a GKR circuit. + GkrProvingSchedule []GkrProvingLevel +) + +func (g GkrClaimGroup) NbClaims() int { return len(g.Wires) * len(g.ClaimSources) } + +func (l GkrSumcheckLevel) NbOutgoingEvalPoints() int { return 1 } +func (l GkrSumcheckLevel) NbClaims() int { + n := 0 + for _, g := range l { + n += len(g.Wires) * len(g.ClaimSources) + } + return n +} +func (l GkrSumcheckLevel) ClaimGroups() []GkrClaimGroup { return l } +func (l GkrSumcheckLevel) FinalEvalProofIndex(wireI, _ int) int { return wireI } + +func (l GkrSkipLevel) NbOutgoingEvalPoints() int { return len(l.ClaimSources) } +func (l GkrSkipLevel) NbClaims() int { + return GkrClaimGroup(l).NbClaims() +} +func (l GkrSkipLevel) ClaimGroups() []GkrClaimGroup { return []GkrClaimGroup{GkrClaimGroup(l)} } +func (l GkrSkipLevel) FinalEvalProofIndex(wireI, evaluationPointI int) int { + return wireI*l.NbOutgoingEvalPoints() + evaluationPointI +} + +// BindGkrFinalEvalProof binds the non-input-wire entries of finalEvalProof into the transcript. +// Input-wire evaluations are fully determined by the public assignment (and by evaluation points +// already committed to the transcript), so hashing them contributes nothing to Fiat-Shamir security. +// uniqueGateInputs is the deduplicated list of gate-input wire indices for the level in the same +// order as the finalEvalProof entries (i.e. the order returned by UniqueGateInputs). +func BindGkrFinalEvalProof[F any](transcript interface{ Bind(...F) }, finalEvalProof []F, uniqueGateInputs []int, isInput func(wireI int) bool, level GkrProvingLevel) { + for i, inputWireI := range uniqueGateInputs { + if !isInput(inputWireI) { + transcript.Bind(finalEvalProof[level.FinalEvalProofIndex(i, 0):level.FinalEvalProofIndex(i+1, 0)]...) + } + } +} diff --git a/constraint/marshal.go b/constraint/marshal.go index d642f62775..bce070b5bb 100644 --- a/constraint/marshal.go +++ b/constraint/marshal.go @@ -398,6 +398,9 @@ func getTagSet() cbor.TagSet { addType(reflect.TypeOf(BlueprintBatchInverse[U32]{})) addType(reflect.TypeOf(BlueprintBatchInverse[U64]{})) + addType(reflect.TypeOf(GkrSkipLevel{})) + addType(reflect.TypeOf(GkrSumcheckLevel{})) + // Add types registered by external packages (e.g., GKR blueprints) // These use explicit tag numbers to ensure stability regardless of init() order for _, rt := range registeredBlueprintTypes { diff --git a/constraint/solver/gkrgates/registry.go b/constraint/solver/gkrgates/registry.go index 71f7de9122..6178eeb6ab 100644 --- a/constraint/solver/gkrgates/registry.go +++ b/constraint/solver/gkrgates/registry.go @@ -10,7 +10,7 @@ import ( "runtime" "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/internal/gkr/gkrtypes" + "github.com/consensys/gnark/internal/gkr/gkrcore" "github.com/consensys/gnark/std/gkrapi/gkr" ) @@ -119,7 +119,7 @@ func Register(f gkr.GateFunction, nbIn int, options ...RegisterOption) error { } for _, curve := range s.curves { - compiled, err := gkrtypes.CompileGateFunction(f, nbIn, curve.ScalarField()) + compiled, err := gkrcore.CompileGateFunction(f, nbIn, curve.ScalarField()) if err != nil { return err } diff --git a/internal/generator/backend/template/gkr/blueprint.go.tmpl b/internal/generator/backend/template/gkr/blueprint.go.tmpl index 3b148b4ee2..7c5492d172 100644 --- a/internal/generator/backend/template/gkr/blueprint.go.tmpl +++ b/internal/generator/backend/template/gkr/blueprint.go.tmpl @@ -8,10 +8,9 @@ import ( "github.com/consensys/gnark-crypto/ecc" "{{ .FieldPackagePath }}" "{{ .FieldPackagePath }}/polynomial" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" "github.com/consensys/gnark-crypto/hash" "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/internal/gkr/gkrtypes" + "github.com/consensys/gnark/internal/gkr/gkrcore" ) func init() { @@ -27,7 +26,7 @@ type circuitEvaluator struct { // BlueprintSolve is a {{.FieldID}}-specific blueprint for solving GKR circuit instances. type BlueprintSolve struct { // Circuit structure (serialized) - Circuit gkrtypes.SerializableCircuit + Circuit gkrcore.SerializableCircuit NbInstances uint32 // Not serialized - recreated lazily at solve time @@ -41,10 +40,13 @@ type BlueprintSolve struct { // Ensures BlueprintSolve implements BlueprintStateful var _ constraint.BlueprintStateful[constraint.U64] = (*BlueprintSolve)(nil) + // Equal returns true if the serialized fields of two BlueprintSolve are equal. // Used for testing serialization round-trips. func (b *BlueprintSolve) Equal(other constraint.BlueprintComparable) bool { - if other == nil { return false } + if other == nil { + return false + } o, ok := other.(*BlueprintSolve) if !ok { return false @@ -107,7 +109,7 @@ func (b *BlueprintSolve) Solve(s constraint.Solver[constraint.U64], inst constra if w.IsInput() { val, delta := s.Read(calldata) calldata = calldata[delta:] - // Copy directly from constraint.U64 to fr.Element (both in Montgomery form) + // Copy directly from constraint.U64 to {{ .ElementType }} (both in Montgomery form) copy(b.assignments[wI][instanceI][:], val[:]) } else { // Get evaluator for this wire from the circuit evaluator @@ -123,7 +125,7 @@ func (b *BlueprintSolve) Solve(s constraint.Solver[constraint.U64], inst constra } } - // Set output wires (copy fr.Element to U64 in Montgomery form) + // Set output wires (copy {{ .ElementType }} to U64 in Montgomery form) for outI, outWI := range b.outputWires { var val constraint.U64 copy(val[:], b.assignments[outWI][instanceI][:]) @@ -150,9 +152,9 @@ func (b *BlueprintSolve) NbConstraints() int { // NbOutputs implements Blueprint func (b *BlueprintSolve) NbOutputs(inst constraint.Instruction) int { - if b.outputWires == nil { - b.outputWires = b.Circuit.Outputs() - } + if b.outputWires == nil { + b.outputWires = b.Circuit.Outputs() + } return len(b.outputWires) } @@ -194,6 +196,7 @@ func (b *BlueprintSolve) UpdateInstructionTree(inst constraint.Instruction, tree type BlueprintProve struct { SolveBlueprintID constraint.BlueprintID SolveBlueprint *BlueprintSolve `cbor:"-"` // not serialized, set at compile time + Schedule constraint.GkrProvingSchedule HashName string lock sync.Mutex @@ -201,9 +204,12 @@ type BlueprintProve struct { // Ensures BlueprintProve implements BlueprintSolvable var _ constraint.BlueprintSolvable[constraint.U64] = (*BlueprintProve)(nil) + // Equal returns true if the serialized fields of two BlueprintProve are equal. func (b *BlueprintProve) Equal(other constraint.BlueprintComparable) bool { - if other == nil { return false } + if other == nil { + return false + } o, ok := other.(*BlueprintProve) if !ok { return false @@ -243,28 +249,27 @@ func (b *BlueprintProve) Solve(s constraint.Solver[constraint.U64], inst constra } } + // Create hasher and write base challenges + hsh := hash.NewHash(b.HashName + "_{{.FieldID}}") + // Read initial challenges from instruction calldata (parse dynamically, no metadata) // Format: [0]=totalSize, [1...]=challenge linear expressions - insBytes := make([][]byte, 0) // first challenges calldata := inst.Calldata[1:] // skip size prefix for len(calldata) != 0 { val, delta := s.Read(calldata) calldata = calldata[delta:] - // Copy directly from constraint.U64 to fr.Element (both in Montgomery form) + // Copy directly from constraint.U64 to {{ .ElementType }} (both in Montgomery form) var challenge {{ .ElementType }} copy(challenge[:], val[:]) - insBytes = append(insBytes, challenge.Marshal()) + challengeBytes := challenge.Bytes() + hsh.Write(challengeBytes[:]) } - // Create Fiat-Shamir settings - hsh := hash.NewHash(b.HashName + "_{{.FieldID}}") - fsSettings := fiatshamir.WithHash(hsh, insBytes...) - // Call the {{.FieldID}}-specific Prove function (assignments already WireAssignment type) - proof, err := Prove(solveBlueprint.Circuit, assignments, fsSettings) + proof, err := Prove(solveBlueprint.Circuit, b.Schedule, assignments, hsh) if err != nil { - return fmt.Errorf("{{toLower .FieldID}} prove failed: %w", err) + return fmt.Errorf("{{.FieldID}} prove failed: %w", err) } for i, elem := range proof.flatten() { @@ -292,7 +297,7 @@ func (b *BlueprintProve) proofSize() int { } nbPaddedInstances := ecc.NextPowerOfTwo(uint64(b.SolveBlueprint.NbInstances)) logNbInstances := bits.TrailingZeros64(nbPaddedInstances) - return b.SolveBlueprint.Circuit.ProofSize(logNbInstances) + return b.SolveBlueprint.Circuit.ProofSize(b.Schedule, logNbInstances) } // NbOutputs implements Blueprint @@ -344,9 +349,12 @@ type BlueprintGetAssignment struct { // Ensures BlueprintGetAssignment implements BlueprintSolvable var _ constraint.BlueprintSolvable[constraint.U64] = (*BlueprintGetAssignment)(nil) + // Equal returns true if the serialized fields of two BlueprintGetAssignment are equal. func (b *BlueprintGetAssignment) Equal(other constraint.BlueprintComparable) bool { - if other == nil { return false } + if other == nil { + return false + } o, ok := other.(*BlueprintGetAssignment) if !ok { return false @@ -418,7 +426,7 @@ func (b *BlueprintGetAssignment) UpdateInstructionTree(inst constraint.Instructi } // NewBlueprints creates and registers all GKR blueprints for {{.FieldID}} -func NewBlueprints(circuit gkrtypes.SerializableCircuit, hashName string, compiler constraint.CustomizableSystem) gkrtypes.Blueprints { +func NewBlueprints(circuit gkrcore.SerializableCircuit, schedule constraint.GkrProvingSchedule, hashName string, compiler constraint.CustomizableSystem) gkrcore.Blueprints { // Create and register solve blueprint solve := &BlueprintSolve{Circuit: circuit} solveID := compiler.AddBlueprint(solve) @@ -427,6 +435,7 @@ func NewBlueprints(circuit gkrtypes.SerializableCircuit, hashName string, compil prove := &BlueprintProve{ SolveBlueprintID: solveID, SolveBlueprint: solve, + Schedule: schedule, HashName: hashName, } proveID := compiler.AddBlueprint(prove) @@ -437,7 +446,7 @@ func NewBlueprints(circuit gkrtypes.SerializableCircuit, hashName string, compil } getAssignmentID := compiler.AddBlueprint(getAssignment) - return gkrtypes.Blueprints{ + return gkrcore.Blueprints{ SolveID: solveID, Solve: solve, ProveID: proveID, diff --git a/internal/generator/backend/template/gkr/gkr.go.tmpl b/internal/generator/backend/template/gkr/gkr.go.tmpl index bcf85cecdf..9dfd6ce8cb 100644 --- a/internal/generator/backend/template/gkr/gkr.go.tmpl +++ b/internal/generator/backend/template/gkr/gkr.go.tmpl @@ -1,657 +1,557 @@ import ( "errors" "fmt" + "hash" "iter" + "sync" + "{{.FieldPackagePath}}" "{{.FieldPackagePath}}/polynomial" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" "github.com/consensys/gnark-crypto/utils" - "strconv" - "sync" - "github.com/consensys/gnark/internal/gkr/gkrtypes" + "github.com/consensys/gnark/constraint" + "github.com/consensys/gnark/internal/gkr/gkrcore" ) // Type aliases for bytecode-based GKR types type ( - Wire = gkrtypes.SerializableWire - Circuit = gkrtypes.SerializableCircuit + Wire = gkrcore.SerializableWire + Circuit = gkrcore.SerializableCircuit ) // The goal is to prove/verify evaluations of many instances of the same circuit - -// WireAssignment is assignment of values to the same wire across many instances of the circuit +// WireAssignment is the assignment of values to the same wire across many instances of the circuit type WireAssignment []polynomial.MultiLin type Proof []sumcheckProof // for each layer, for each wire, a sumcheck (for each variable, a polynomial) // zeroCheckLazyClaims is a lazy claim for sumcheck (verifier side). -// It checks that the polynomial ∑ᵢ cⁱ eq(-, xᵢ) w(-) sums up to the expected multilinear -// extension of the values of w across all instances. -// Its purpose is to batch the checking of multiple evaluations of the same wire. +// It checks that the polynomial ∑ᵢ cⁱ eq(-, xᵢ) wᵢ(-) sums to the expected value, +// where the sum runs over all wᵢ and evaluation point xᵢ in the level. +// Its purpose is to batch the checking of multiple wire evaluations at evaluation points. type zeroCheckLazyClaims struct { - wireI int // the wire for which we are making the claim, with value w - evaluationPoints [][]{{ .ElementType }} // xᵢ: the points at which the prover has made claims about the evaluation of w - claimedEvaluations []{{ .ElementType }} // yᵢ = w(xᵢ), allegedly - manager *claimsManager // WARNING: Circular references -} - -func (e *zeroCheckLazyClaims) getWire() Wire { - return e.manager.circuit[e.wireI] -} - -func (e *zeroCheckLazyClaims) claimsNum() int { - return len(e.evaluationPoints) + foldingCoeff {{ .ElementType }} // the coefficient used to fold claims, conventionally 0 if there is only one claim + resources *resources + levelI int } func (e *zeroCheckLazyClaims) varsNum() int { - return len(e.evaluationPoints[0]) -} - -// foldedSum returns ∑ᵢ aⁱ yᵢ -func (e *zeroCheckLazyClaims) foldedSum(a {{ .ElementType }}) {{ .ElementType }} { - evalsAsPoly := polynomial.Polynomial(e.claimedEvaluations) - return evalsAsPoly.Eval(&a) + return e.resources.nbVars } func (e *zeroCheckLazyClaims) degree(int) int { - return e.manager.circuit[e.wireI].ZeroCheckDegree() -} - -// verifyFinalEval finalizes the verification of w. -// The prover's claims w(xᵢ) = yᵢ have already been reduced to verifying -// ∑ cⁱ eq(xᵢ, r) w(r) = purportedValue. (c is foldingCoeff) -// Both purportedValue and the vector r have been randomized during the sumcheck protocol. -// By taking the w term out of the sum we get the equivalent claim that -// for E := ∑ eq(xᵢ, r), it must be that E w(r) = purportedValue. -// If w is an input wire, the verifier can directly check its evaluation at r. -// Otherwise, the prover makes claims about the evaluation of w's input wires, -// wᵢ, at r, to be verified later. -// The claims are communicated through the proof parameter. -// The verifier checks here if the claimed evaluations of wᵢ(r) are consistent with -// the main claim, by checking E w(wᵢ(r)...) = purportedValue. -func (e *zeroCheckLazyClaims) verifyFinalEval(r []{{ .ElementType }}, foldingCoeff, purportedValue {{ .ElementType }}, uniqueInputEvaluations []{{ .ElementType }}) error { - // the eq terms ( E ) - numClaims := len(e.evaluationPoints) - evaluation := polynomial.EvalEq(e.evaluationPoints[numClaims-1], r) - for i := numClaims - 2; i >= 0; i-- { - evaluation.Mul(&evaluation, &foldingCoeff) - eq := polynomial.EvalEq(e.evaluationPoints[i], r) - evaluation.Add(&evaluation, &eq) - } - - wire := e.manager.circuit[e.wireI] - - // the w(...) term - var gateEvaluation {{ .ElementType }} - if wire.IsInput() { // just compute w(r) - gateEvaluation = e.manager.assignment[e.wireI].Evaluate(r, e.manager.memPool) - } else { // proof contains the evaluations of the inputs, but avoids repetition in case multiple inputs come from the same wire - injection, injectionLeftInv := - e.manager.circuit.ClaimPropagationInfo(e.wireI) - - if len(injection) != len(uniqueInputEvaluations) { - return fmt.Errorf("%d input wire evaluations given, %d expected", len(uniqueInputEvaluations), len(injection)) - } - - for uniqueI, i := range injection { // map from unique to all - e.manager.add(wire.Inputs[i], r, uniqueInputEvaluations[uniqueI]) - } + return e.resources.circuit.ZeroCheckDegree(e.resources.schedule[e.levelI].(constraint.GkrSumcheckLevel)) +} + +// verifyFinalEval finalizes the verification of a level at the sumcheck evaluation point r. +// The sumcheck protocol has already reduced the per-wire claims w(xᵢ) = yᵢ to verifying +// ∑ᵢ cⁱ eq(xᵢ, r) · wᵢ(r) = purportedValue, where the sum runs over all +// claims on each wire and c is foldingCoeff. +// Both purportedValue and the vector r have been randomized during sumcheck. +// +// For input wires, w(r) is computed directly from the assignment and the claimed +// evaluation in uniqueInputEvaluations is checked equal to it. +// For non-input wires, the prover claims evaluations of their gate inputs at r via +// uniqueInputEvaluations; those claims are verified by lower levels' sumchecks. +// The verifier checks consistency by evaluating gateᵥ(inputEvals...) and confirming +// that the full sum matches purportedValue. +func (e *zeroCheckLazyClaims) verifyFinalEval(r []{{ .ElementType }}, purportedValue {{ .ElementType }}, uniqueInputEvaluations []{{ .ElementType }}) error { + e.resources.outgoingEvalPoints[e.levelI] = [][]{{ .ElementType }}{r} + level := e.resources.schedule[e.levelI] + gateInputEvals := gkrcore.ReduplicateInputs(level, e.resources.circuit, uniqueInputEvaluations) + + var claimedEvals polynomial.Polynomial + levelWireI := 0 + for _, group := range level.ClaimGroups() { + for _, wI := range group.Wires { + wire := e.resources.circuit[wI] + + var gateEval {{ .ElementType }} + if wire.IsInput() { + gateEval = e.resources.assignment[wI].Evaluate(r, &e.resources.memPool) + if !gateInputEvals[levelWireI][0].Equal(&gateEval) { + return errors.New("incompatible evaluations") + } + } else { + evaluator := newGateEvaluator(wire.Gate.Evaluate, len(wire.Inputs)) + for _, v := range gateInputEvals[levelWireI] { + evaluator.pushInput(v) + } + gateEval.Set(evaluator.evaluate()) + } - evaluator := newGateEvaluator(wire.Gate.Evaluate, len(wire.Inputs)) - for _, uniqueI := range injectionLeftInv { // map from all to unique - evaluator.pushInput(uniqueInputEvaluations[uniqueI]) + for _, src := range group.ClaimSources { + eq := polynomial.EvalEq(e.resources.outgoingEvalPoints[src.Level][src.OutgoingClaimIndex], r) + var term {{ .ElementType }} + term.Mul(&eq, &gateEval) + claimedEvals = append(claimedEvals, term) + } + levelWireI++ } - - gateEvaluation.Set(evaluator.evaluate()) } - evaluation.Mul(&evaluation, &gateEvaluation) - - if evaluation.Equal(&purportedValue) { - return nil + if total := claimedEvals.Eval(&e.foldingCoeff); !total.Equal(&purportedValue) { + return errors.New("incompatible evaluations") } - return errors.New("incompatible evaluations") + return nil } // zeroCheckClaims is a claim for sumcheck (prover side). -// It checks that the polynomial ∑ᵢ cⁱ eq(-, xᵢ) w(-) sums up to the expected multilinear -// extension of the values of w across all instances. -// Its purpose is to batch the proving of multiple evaluations of the same wire. +// It checks that the polynomial ∑ᵢ cⁱ eq(-, xᵢ) wᵢ(-) sums to the expected value, +// where the sum runs over all (wire v, claim source s) pairs in the level. +// Each wire has its own eq table with the batching coefficients baked in. type zeroCheckClaims struct { - wireI int // the wire for which we are making the claim, with value w - evaluationPoints [][]{{ .ElementType }} // xᵢ: the points at which the prover has made claims about the evaluation of w - claimedEvaluations []{{ .ElementType }} // yᵢ = w(xᵢ) - manager *claimsManager - - input []polynomial.MultiLin // input[i](h₁, ..., hₘ₋ⱼ) = wᵢ(r₁, r₂, ..., rⱼ₋₁, h₁, ..., hₘ₋ⱼ) - - eq polynomial.MultiLin // E := ∑ᵢ cⁱ eq(xᵢ, -) - - gateEvaluatorPool *gateEvaluatorPool + levelI int + resources *resources + input []polynomial.MultiLin // UniqueGateInputs order + inputIndices [][]int // [wireInLevel][gateInputJ] → index in input + eqs []polynomial.MultiLin // per-wire interpolation bases for evaluating wire assignments at challenge points + gateEvaluatorPools []*gateEvaluatorPool } -func (c *zeroCheckClaims) getWire() Wire { - return c.manager.circuit[c.wireI] -} - -// fold the multiple claims into one claim using a random combination (foldingCoeff or c). -// From the original multiple claims of w(xᵢ) = yᵢ, we get a single claim -// ∑ᵢ,ₕ cⁱ eq(xᵢ, h) w(h) = ∑ᵢ cⁱ yᵢ, where h iterates over the hypercube (circuit instances) and -// i iterates over the claims. -// Equivalently, we could say ∑ᵢ cⁱ yᵢ = ∑ₕ,ᵢ cⁱ eq(xᵢ, h) w(h) = ∑ₕ w(h) ∑ᵢ cⁱ eq(xᵢ, h). -// Thus if we initially compute E := ∑ᵢ cⁱ eq(xᵢ, -), our claim will find the simpler form -// ∑ᵢ cⁱ yᵢ = ∑ₕ w(h) E(h), where the sum-checked polynomial is of degree deg(g) + 1, -// and deg(g) is the total degree of the polynomial defining the gate g of which w is the output. -// The output of fold is the first sumcheck claim, i.e. ∑₍ₕ₁,ₕ₂,...₎ w(X, h₁, h₂, ...) E(X, h₁, h₂, ...).. -func (c *zeroCheckClaims) fold(foldingCoeff {{ .ElementType }}) polynomial.Polynomial { -varsNum := c.varsNum() - eqLength := 1 << varsNum -claimsNum := c.claimsNum() -// initialize the eq tables ( E ) -c.eq = c.manager.memPool.Make(eqLength) - -c.eq[0].SetOne() -c.eq.Eq(c.evaluationPoints[0]) - -// E := eq(x₀, -) -newEq := polynomial.MultiLin(c.manager.memPool.Make(eqLength)) -aI := foldingCoeff - -// E += cⁱ eq(xᵢ, -) -for k := 1; k < claimsNum; k++ { -newEq[0].Set(&aI) - -c.eqAcc(c.eq, newEq,c.evaluationPoints[k]) - -if k+1 < claimsNum { -aI.Mul(&aI, &foldingCoeff) -} -} - -c.manager.memPool.Dump(newEq) - -return c.computeGJ() -} - -// eqAcc sets m to an eq table at q and then adds it to e. -// m <- eq(q, -). -// e <- e + m -func (c *zeroCheckClaims) eqAcc(e, m polynomial.MultiLin, q []{{ .ElementType }}) { - n := len(q) - - //At the end of each iteration, m(h₁, ..., hₙ) = eq(q₁, ..., qᵢ₊₁, h₁, ..., hᵢ₊₁) - for i := range q { // In the comments we use a 1-based index so q[i] = qᵢ₊₁ - // go through all assignments of (b₁, ..., bᵢ) ∈ {0,1}ⁱ - const threshold = 1 << 6 - k := 1 << i - if k < threshold { - for j := 0; j < k; j++ { - j0 := j << (n - i) // bᵢ₊₁ = 0 - j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 - - m[j1].Mul(&q[i], &m[j0]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ - m[j0].Sub(&m[j0], &m[j1]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) - } - } else { - c.manager.workers.Submit(k, func(start, end int) { - for j := start; j < end; j++ { - j0 := j << (n - i) // bᵢ₊₁ = 0 - j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 - - m[j1].Mul(&q[i], &m[j0]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ - m[j0].Sub(&m[j0], &m[j1]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) - } - }, 1024).Wait() - } - - } - c.manager.workers.Submit(len(e), func(start, end int) { - for i := start; i < end; i++ { - e[i].Add(&e[i], &m[i]) - } - }, 512).Wait() -} - - -// computeGJ: gⱼ = ∑_{0≤h<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, h...) = ∑_{0≤i<2ⁿ⁻ʲ} E(r₁, ..., Xⱼ, h...) g( w₀(r₁, ..., Xⱼ, h...), ... ). -// the polynomial is represented by the evaluations gⱼ(1), gⱼ(2), ..., gⱼ(deg(gⱼ)). -// The value gⱼ(0) is inferred from the equation gⱼ(0) + gⱼ(1) = gⱼ₋₁(rⱼ₋₁). By convention, g₀ is a constant polynomial equal to the claimed sum. -func (c *zeroCheckClaims) computeGJ() polynomial.Polynomial { - - wire := c.getWire() - degGJ := wire.ZeroCheckDegree() // guaranteed to be no smaller than the actual deg(gⱼ) - nbGateIn := len(c.input) - - // Both E and wᵢ (the input wires and the eq table) are multilinear, thus - // they are linear in Xⱼ. - // So for f ∈ { E(r₁, ..., Xⱼ, h...) } ∪ {wᵢ(r₁, ..., Xⱼ, h...) }, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables. - // ml are such multilinear polynomials the evaluations of which over different values of Xⱼ are computed in this stepwise manner. - ml := make([]polynomial.MultiLin, nbGateIn+1) // shortcut to the evaluations of the multilinear polynomials over the hypercube - ml[0] = c.eq - copy(ml[1:], c.input) - - sumSize := len(c.eq) / 2; // the range of h, over which we sum - - // Perf-TODO: Collate once at claim "folding" time and not again. then, even folding can be done in one operation every time "next" is called - - gJ := make([]{{ .ElementType }}, degGJ) +func (c *zeroCheckClaims) varsNum() int { + return c.resources.nbVars +} + +// roundPolynomial computes gⱼ = ∑ₕ ∑ᵥ eqs[v](Xⱼ, h...) · gateᵥ(inputs(Xⱼ, h...)). +// The polynomial is represented by the evaluations gⱼ(1), gⱼ(2), ..., gⱼ(deg(gⱼ)). +// The value gⱼ(0) is inferred from the equation gⱼ(0) + gⱼ(1) = gⱼ₋₁(rⱼ₋₁). +// By convention, g₀ is a constant polynomial equal to the claimed sum. +func (c *zeroCheckClaims) roundPolynomial() polynomial.Polynomial { + level := c.resources.schedule[c.levelI].(constraint.GkrSumcheckLevel) + degree := c.resources.circuit.ZeroCheckDegree(level) + nbUniqueInputs := len(c.input) + nbWires := len(c.eqs) + + // Both eqs and input are multilinear, thus linear in Xⱼ. + // For any such f, f(m) = m·(f(1) - f(0)) + f(0), and f(0), f(1) are read directly + // from the bookkeeping tables. This allows stepwise evaluation at Xⱼ = 1, 2, ..., degree. + // Layout: [eq₀, eq₁, ..., eq_{nbWires-1}, input₀, input₁, ..., input_{nbUniqueInputs-1}] + ml := make([]polynomial.MultiLin, nbWires+nbUniqueInputs) + copy(ml, c.eqs) + copy(ml[nbWires:], c.input) + + sumSize := len(c.eqs[0]) / 2 + + p := make([]{{ .ElementType }}, degree) var mu sync.Mutex - computeAll := func(start, end int) { // compute method to allow parallelization across instances + computeAll := func(start, end int) { var step {{ .ElementType }} - evaluator := c.gateEvaluatorPool.get() - defer c.gateEvaluatorPool.put(evaluator) + evaluators := make([]*gateEvaluator, nbWires) + for w := range nbWires { + evaluators[w] = c.gateEvaluatorPools[w].get() + } + defer func() { + for w := range nbWires { + c.gateEvaluatorPools[w].put(evaluators[w]) + } + }() - res := make([]{{ .ElementType }}, degGJ) + res := make([]{{ .ElementType }}, degree) // evaluations of ml, laid out as: // ml[0](1, h...), ml[1](1, h...), ..., ml[len(ml)-1](1, h...), // ml[0](2, h...), ml[1](2, h...), ..., ml[len(ml)-1](2, h...), // ... - // ml[0](degGJ, h...), ml[2](degGJ, h...), ..., ml[len(ml)-1](degGJ, h...) - mlEvals := make([]{{ .ElementType }}, degGJ*len(ml)) - - for h := start; h < end; h++ { // h counts across instances + // ml[0](degree, h...), ml[1](degree, h...), ..., ml[len(ml)-1](degree, h...) + mlEvals := make([]{{ .ElementType }}, degree*len(ml)) + for h := start; h < end; h++ { evalAt1Index := sumSize + h for k := range ml { - // d = 0 - mlEvals[k].Set(&ml[k][evalAt1Index]) // evaluation at Xⱼ = 1. Can be taken directly from the table. - step.Sub(&mlEvals[k], &ml[k][h])// step = ml[k](1) - ml[k](0) - for d := 1; d < degGJ; d++ { + mlEvals[k].Set(&ml[k][evalAt1Index]) // evaluation at Xⱼ = 1, taken directly from the table + step.Sub(&mlEvals[k], &ml[k][h]) // step = ml[k](1) - ml[k](0) + for d := 1; d < degree; d++ { mlEvals[d*len(ml)+k].Add(&mlEvals[(d-1)*len(ml)+k], &step) } } - eIndex := 0 // index for where the current eq term is + eIndex := 0 // start of the current row's eq evaluations nextEIndex := len(ml) - for d := range degGJ { - // Push gate inputs - for i := range nbGateIn { - evaluator.pushInput(mlEvals[eIndex+1+i]) + for d := range degree { + for w := range nbWires { + for _, inputI := range c.inputIndices[w] { + evaluators[w].pushInput(mlEvals[eIndex+nbWires+inputI]) + } + summand := evaluators[w].evaluate() + summand.Mul(summand, &mlEvals[eIndex+w]) + res[d].Add(&res[d], summand) // collect contributions into the sum from start to end } - summand := evaluator.evaluate() - summand.Mul(summand, &mlEvals[eIndex]) - res[d].Add(&res[d], summand) // collect contributions into the sum from start to end eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) } } mu.Lock() - for i := range gJ { - gJ[i].Add(&gJ[i], &res[i]) // collect into the complete sum + for i := range p { + p[i].Add(&p[i], &res[i]) // collect into the complete sum } mu.Unlock() } const minBlockSize = 64 - if sumSize < minBlockSize { - // no parallelization computeAll(0, sumSize) } else { - c.manager.workers.Submit(sumSize, computeAll, minBlockSize).Wait() + c.resources.workers.Submit(sumSize, computeAll, minBlockSize).Wait() } - return gJ + return p } -// next first folds the input and E polynomials at the given verifier challenge then computes the new gⱼ. -// Thus, j <- j+1 and rⱼ = challenge. -func (c *zeroCheckClaims) next(challenge {{ .ElementType }}) polynomial.Polynomial { +// roundFold folds all input and eq polynomials at the verifier challenge r. +// After this call, j ← j+1 and rⱼ = r. +func (c *zeroCheckClaims) roundFold(r {{ .ElementType }}) { const minBlockSize = 512 - n := len(c.eq) / 2 + n := len(c.eqs[0]) / 2 if n < minBlockSize { - // no parallelization for i := range c.input { - c.input[i].Fold(challenge) + c.input[i].Fold(r) + } + for i := range c.eqs { + c.eqs[i].Fold(r) } - c.eq.Fold(challenge) } else { - wgs := make([]*sync.WaitGroup, len(c.input)) + wgs := make([]*sync.WaitGroup, len(c.input)+len(c.eqs)) for i := range c.input { - wgs[i] = c.manager.workers.Submit(n, c.input[i].FoldParallel(challenge), minBlockSize) + wgs[i] = c.resources.workers.Submit(n, c.input[i].FoldParallel(r), minBlockSize) + } + for i := range c.eqs { + wgs[len(c.input)+i] = c.resources.workers.Submit(n, c.eqs[i].FoldParallel(r), minBlockSize) } - c.manager.workers.Submit(n, c.eq.FoldParallel(challenge), minBlockSize).Wait() for _, wg := range wgs { wg.Wait() } } - - return c.computeGJ() } -func (c *zeroCheckClaims) varsNum() int { - return len(c.evaluationPoints[0]) -} - -func (c *zeroCheckClaims) claimsNum() int { - return len(c.claimedEvaluations) -} - -// proveFinalEval provides the values wᵢ(r₁, ..., rₙ) +// proveFinalEval provides the unique input wire values wᵢ(r₁, ..., rₙ). func (c *zeroCheckClaims) proveFinalEval(r []{{ .ElementType }}) []{{ .ElementType }} { - //defer the proof, return list of claims - - injection, _ := c.manager.circuit.ClaimPropagationInfo(c.wireI) // TODO @Tabaie: Instead of doing this last, we could just have fewer input in the first place; not that likely to happen with single gates, but more so with layers. - evaluations := make([]{{ .ElementType }}, len(injection)) - for i, gateInputI := range injection { - wI := c.input[gateInputI] - wI.Fold(r[len(r)-1]) // We already have wᵢ(r₁, ..., rₙ₋₁, hₙ) in a table. Only one more fold required. - c.manager.add(c.getWire().Inputs[gateInputI], r, wI[0]) - evaluations[i] = wI[0] + c.resources.outgoingEvalPoints[c.levelI] = [][]{{ .ElementType }}{r} + evaluations := make([]{{ .ElementType }}, len(c.input)) + for i := range c.input { + c.input[i].Fold(r[len(r)-1]) + evaluations[i] = c.input[i][0] + } + for i := range c.input { + c.resources.memPool.Dump(c.input[i]) + } + for i := range c.eqs { + c.resources.memPool.Dump(c.eqs[i]) + } + for _, pool := range c.gateEvaluatorPools { + pool.dumpAll() } - - c.manager.memPool.Dump(c.claimedEvaluations, c.eq) - c.gateEvaluatorPool.dumpAll() - return evaluations } -type claimsManager struct { - claims []*zeroCheckLazyClaims - assignment WireAssignment - memPool *polynomial.Pool - workers *utils.WorkerPool - circuit Circuit -} +// eqAcc sets m to an eq table at q and then adds it to e. +// m <- m[0] · eq(q, -). +// e <- e + m +func (r *resources) eqAcc(e, m polynomial.MultiLin, q []{{ .ElementType }}) { + n := len(q) -func newClaimsManager(circuit Circuit, assignment WireAssignment, o settings) (manager claimsManager) { - manager.assignment = assignment - manager.claims = make([]*zeroCheckLazyClaims, len(circuit)) - manager.memPool = o.pool - manager.workers = o.workers - manager.circuit = circuit + // At the end of each iteration, m(h₁, ..., hₙ) = m[0] · eq(q₁, ..., qᵢ₊₁, h₁, ..., hᵢ₊₁) + for i := range q { // 1-based in comments: q[i] = qᵢ₊₁ + // go through all assignments of (b₁, ..., bᵢ) ∈ {0,1}ⁱ + const threshold = 1 << 6 + k := 1 << i + if k < threshold { + for j := 0; j < k; j++ { + j0 := j << (n - i) // bᵢ₊₁ = 0 + j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 - for i := range circuit { - manager.claims[i] = &zeroCheckLazyClaims{ - wireI: i, - evaluationPoints: make([][]{{ .ElementType }}, 0, circuit[i].NbClaims()), - claimedEvaluations: manager.memPool.Make(circuit[i].NbClaims()), - manager: &manager, + m[j1].Mul(&q[i], &m[j0]) // m(b₁,...,bᵢ,1) = m(b₁,...,bᵢ) · qᵢ₊₁ + m[j0].Sub(&m[j0], &m[j1]) // m(b₁,...,bᵢ,0) = m(b₁,...,bᵢ) · (1 - qᵢ₊₁) + } + } else { + r.workers.Submit(k, func(start, end int) { + for j := start; j < end; j++ { + j0 := j << (n - i) // bᵢ₊₁ = 0 + j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 + + m[j1].Mul(&q[i], &m[j0]) // m(b₁,...,bᵢ,1) = m(b₁,...,bᵢ) · qᵢ₊₁ + m[j0].Sub(&m[j0], &m[j1]) // m(b₁,...,bᵢ,0) = m(b₁,...,bᵢ) · (1 - qᵢ₊₁) + } + }, 1024).Wait() } } - return -} - -func (m *claimsManager) add(wire int, evaluationPoint []{{ .ElementType }}, evaluation {{ .ElementType }}) { - claim := m.claims[wire] - i := len(claim.evaluationPoints) - claim.claimedEvaluations[i] = evaluation - claim.evaluationPoints = append(claim.evaluationPoints, evaluationPoint) + r.workers.Submit(len(e), func(start, end int) { + for i := start; i < end; i++ { + e[i].Add(&e[i], &m[i]) + } + }, 512).Wait() } -func (m *claimsManager) getLazyClaim(wire int) *zeroCheckLazyClaims { - return m.claims[wire] +type resources struct { + // outgoingEvalPoints[i][k] is the k-th outgoing evaluation point (evaluation challenge) produced at schedule level i. + // outgoingEvalPoints[len(schedule)][0] holds the initial challenge (firstChallenge / rho). + // SumcheckLevels produce one point (k=0). SkipLevels pass on all their evaluation points. + outgoingEvalPoints [][][]{{ .ElementType }} + nbVars int + assignment WireAssignment + memPool polynomial.Pool + workers *utils.WorkerPool + circuit Circuit + schedule constraint.GkrProvingSchedule + transcript transcript + uniqueInputIndices [][]int // uniqueInputIndices[wI][claimI]: w's unique-input index in the layer its claimI-th evaluation is coming from } -func (m *claimsManager) getClaim(wireI int) *zeroCheckClaims { - lazy := m.claims[wireI] - wire := m.circuit[wireI] - res := &zeroCheckClaims{ - wireI: wireI, - evaluationPoints: lazy.evaluationPoints, - claimedEvaluations: lazy.claimedEvaluations, - manager: m, - } - - if wire.IsInput() { - res.input = []polynomial.MultiLin{m.memPool.Clone(m.assignment[wireI])} - } else { - res.input = make([]polynomial.MultiLin, len(wire.Inputs)) - - for inputI, inputW := range wire.Inputs { - res.input[inputI] = m.memPool.Clone(m.assignment[inputW]) //will be edited later, so must be deep copied +func newResources(c Circuit, schedule constraint.GkrProvingSchedule, assignment WireAssignment, hasher hash.Hash) (resources, error) { + nbVars := assignment.NumVars() + nbInstances := assignment.NumInstances() + if 1<= 2 { + foldingCoeff = r.transcript.getChallenge() } - if o.pool == nil { - pool := polynomial.NewPool(c.MemoryRequirements(nbInstances)...) - o.pool = &pool + uniqueInputs, inputIndices := r.circuit.InputMapping(level) + input := make([]polynomial.MultiLin, len(uniqueInputs)) + for i, inW := range uniqueInputs { + input[i] = r.memPool.Clone(r.assignment[inW]) } - if o.workers == nil { - o.workers = utils.NewWorkerPool() + nbWires := 0 + for _, group := range level.ClaimGroups() { + nbWires += len(group.Wires) } - if transcriptSettings.Transcript == nil { - challengeNames := ChallengeNames(c, o.nbVars, transcriptSettings.Prefix) - o.transcript = fiatshamir.NewTranscript(transcriptSettings.Hash, challengeNames...) - for i := range transcriptSettings.BaseChallenges { - if err = o.transcript.Bind(challengeNames[0], transcriptSettings.BaseChallenges[i]); err != nil { - return o, err + pools := make([]*gateEvaluatorPool, nbWires) + levelWireI := 0 + for _, group := range level.ClaimGroups() { + for _, wI := range group.Wires { + wire := r.circuit[wI] + gate := wire.Gate.Evaluate + if wire.IsInput() { + gate = gkrcore.IdentityBytecode() } + pools[levelWireI] = newGateEvaluatorPool(gate, len(inputIndices[levelWireI]), &r.memPool) + levelWireI++ } - } else { - o.transcript, o.transcriptPrefix = transcriptSettings.Transcript, transcriptSettings.Prefix } - return o, err -} + eqLength := 1 << r.nbVars + eqs := make([]polynomial.MultiLin, nbWires) + var alpha {{ .ElementType }} + alpha.SetOne() + levelWireI = 0 + for _, group := range level.ClaimGroups() { + nbSources := len(group.ClaimSources) + + groupEq := polynomial.MultiLin(r.memPool.Make(eqLength)) + groupEq[0].Set(&alpha) + groupEq.Eq(r.outgoingEvalPoints[group.ClaimSources[0].Level][group.ClaimSources[0].OutgoingClaimIndex]) + + if nbSources > 1 { + newEq := polynomial.MultiLin(r.memPool.Make(eqLength)) + aI := alpha + for k := 1; k < nbSources; k++ { + aI.Mul(&aI, &foldingCoeff) + newEq[0].Set(&aI) + r.eqAcc(groupEq, newEq, r.outgoingEvalPoints[group.ClaimSources[k].Level][group.ClaimSources[k].OutgoingClaimIndex]) + } + r.memPool.Dump(newEq) + } -func ChallengeNames(c Circuit, logNbInstances int, prefix string) []string { + var stride {{ .ElementType }} + stride.Set(&foldingCoeff) + for range nbSources - 1 { + stride.Mul(&stride, &foldingCoeff) + } - // Pre-compute the size TODO: Consider not doing this and just grow the list by appending - size := logNbInstances // first challenge + eqs[levelWireI] = groupEq + levelWireI++ + alpha.Mul(&alpha, &stride) - for i := range c { - if c[i].NoProof() { // no proof, no challenge - continue - } - if c[i].NbClaims() > 1 { //fold the claims - size++ + for w := 1; w < len(group.Wires); w++ { + eqs[levelWireI] = polynomial.MultiLin(r.memPool.Make(eqLength)) + r.workers.Submit(eqLength, func(start, end int) { + for i := start; i < end; i++ { + eqs[levelWireI][i].Mul(&eqs[levelWireI-1][i], &stride) + } + }, 512).Wait() + levelWireI++ + alpha.Mul(&alpha, &stride) } - size += logNbInstances // full run of sumcheck on logNbInstances variables } - nums := make([]string, max(len(c), logNbInstances)) - for i := range nums { - nums[i] = strconv.Itoa(i) + claims := &zeroCheckClaims{ + levelI: levelI, + resources: r, + input: input, + inputIndices: inputIndices, + eqs: eqs, + gateEvaluatorPools: pools, } + return sumcheckProve(claims, &r.transcript) +} - challenges := make([]string, size) - - // output wire claims - firstChallengePrefix := prefix + "fC." - for j := 0; j < logNbInstances; j++ { - challenges[j] = firstChallengePrefix + nums[j] +func (r *resources) verifySumcheckLevel(levelI int, proof Proof) error { + level := r.schedule[levelI] + nbClaims := level.NbClaims() + var foldingCoeff {{ .ElementType }} + if nbClaims >= 2 { + foldingCoeff = r.transcript.getChallenge() } - j := logNbInstances - for i := len(c) - 1; i >= 0; i-- { - if c[i].NoProof() { - continue - } - wirePrefix := prefix + "w" + nums[i] + "." - if c[i].NbClaims() > 1 { - challenges[j] = wirePrefix + "fold" - j++ - } + initialChallengeI := len(r.schedule) + claimedEvals := make(polynomial.Polynomial, 0, level.NbClaims()) - partialSumPrefix := wirePrefix + "pSP." - for k := 0; k < logNbInstances; k++ { - challenges[j] = partialSumPrefix + nums[k] - j++ + for _, group := range level.ClaimGroups() { + for _, wI := range group.Wires { + for claimI, src := range group.ClaimSources { + if src.Level == initialChallengeI { + claimedEvals = append(claimedEvals, r.assignment[wI].Evaluate(r.outgoingEvalPoints[src.Level][src.OutgoingClaimIndex], &r.memPool)) + } else { + claimedEvals = append(claimedEvals, proof[src.Level].finalEvalProof[r.schedule[src.Level].FinalEvalProofIndex(r.uniqueInputIndices[wI][claimI], src.OutgoingClaimIndex)]) + } + } } } - return challenges -} -func getFirstChallengeNames(logNbInstances int, prefix string) []string { - res := make([]string, logNbInstances) - firstChallengePrefix := prefix + "fC." - for i := 0; i < logNbInstances; i++ { - res[i] = firstChallengePrefix + strconv.Itoa(i) - } - return res -} + claimedSum := claimedEvals.Eval(&foldingCoeff) -func getChallenges(transcript *fiatshamir.Transcript, names []string) ([]{{ .ElementType }}, error) { - res := make([]{{ .ElementType }}, len(names)) - for i, name := range names { - if bytes, err := transcript.ComputeChallenge(name); err != nil { - return nil, err - } else if err = res[i].SetBytesCanonical(bytes); err != nil { - return nil, err - } + lazyClaims := &zeroCheckLazyClaims{ + foldingCoeff: foldingCoeff, + resources: r, + levelI: levelI, } - return res, nil + return sumcheckVerify(lazyClaims, proof[levelI], claimedSum, r.circuit.ZeroCheckDegree(level.(constraint.GkrSumcheckLevel)), &r.transcript) } // Prove consistency of the claimed assignment -func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (Proof, error) { - o, err := setup(c, assignment, transcriptSettings, options...) +func Prove(c Circuit, schedule constraint.GkrProvingSchedule, assignment WireAssignment, hasher hash.Hash) (Proof, error) { + r, err := newResources(c, schedule, assignment, hasher) if err != nil { return nil, err } - defer o.workers.Stop() + defer r.workers.Stop() - claims := newClaimsManager(c, assignment, o) + proof := make(Proof, len(schedule)) - proof := make(Proof, len(c)) - // firstChallenge called rho in the paper - var firstChallenge []{{ .ElementType }} - firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) - if err != nil { - return nil, err + // Derive the initial challenge point + firstChallenge := make([]{{ .ElementType }}, r.nbVars) + for j := range r.nbVars { + firstChallenge[j] = r.transcript.getChallenge() } + r.outgoingEvalPoints[len(schedule)] = [][]{{ .ElementType }}{firstChallenge} - wirePrefix := o.transcriptPrefix + "w" - var baseChallenge [][]byte - for i := len(c) - 1; i >= 0; i-- { - - wire := c[i] - - if wire.IsOutput() { - claims.add(i, firstChallenge, assignment[i].Evaluate(firstChallenge, claims.memPool)) - } - - claim := claims.getClaim(i) - if wire.NoProof() { // input wires with one claim only - proof[i] = sumcheckProof{ - partialSumPolys: []polynomial.Polynomial{}, - finalEvalProof: []{{ .ElementType }}{}, - } + for levelI := len(schedule) - 1; levelI >= 0; levelI-- { + if _, isSkip := r.schedule[levelI].(constraint.GkrSkipLevel); isSkip { + proof[levelI] = r.proveSkipLevel(levelI) } else { - if proof[i], err = sumcheckProve( - claim, fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), - ); err != nil { - return proof, err - } - - baseChallenge = make([][]byte, len(proof[i].finalEvalProof)) - for j := range proof[i].finalEvalProof { - baseChallenge[j] = proof[i].finalEvalProof[j].Marshal() - } + proof[levelI] = r.proveSumcheckLevel(levelI) } - // the verifier checks a single claim about input wires itself - claims.deleteClaim(i) + constraint.BindGkrFinalEvalProof(&r.transcript, proof[levelI].finalEvalProof, c.UniqueGateInputs(r.schedule[levelI]), c.IsInput, r.schedule[levelI]) } return proof, nil } -// Verify the consistency of the claimed output with the claimed input -// Unlike in Prove, the assignment argument need not be complete -func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSettings fiatshamir.Settings, options ...Option) error { - o, err := setup(c, assignment, transcriptSettings, options...) +// Verify the consistency of the claimed output with the claimed input. +// Unlike in Prove, the assignment argument need not be complete. +func Verify(c Circuit, schedule constraint.GkrProvingSchedule, assignment WireAssignment, proof Proof, hasher hash.Hash) error { + r, err := newResources(c, schedule, assignment, hasher) if err != nil { return err } - defer o.workers.Stop() - - claims := newClaimsManager(c, assignment, o) + defer r.workers.Stop() - var firstChallenge []{{ .ElementType }} - firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) - if err != nil { - return err + // Derive the initial challenge point + firstChallenge := make([]{{ .ElementType }}, r.nbVars) + for j := range r.nbVars { + firstChallenge[j] = r.transcript.getChallenge() } + r.outgoingEvalPoints[len(schedule)] = [][]{{ .ElementType }}{firstChallenge} - wirePrefix := o.transcriptPrefix + "w" - var baseChallenge [][]byte - for i := len(c) - 1; i >= 0; i-- { - wire := c[i] - - if wire.IsOutput() { - claims.add(i, firstChallenge, assignment[i].Evaluate(firstChallenge, claims.memPool)) - } - - proofW := proof[i] - claim := claims.getLazyClaim(i) - if wire.NoProof() { // input wires with one claim only - // make sure the proof is empty - if len(proofW.finalEvalProof) != 0 || len(proofW.partialSumPolys) != 0 { - return errors.New("no proof allowed for input wire with a single claim") - } - - if wire.NbClaims() == 1 { // input wire - // simply evaluate and see if it matches - if len(claim.evaluationPoints) == 0 || len(claim.claimedEvaluations) == 0 { - return errors.New("missing input wire claim") - } - evaluation := assignment[i].Evaluate(claim.evaluationPoints[0], claims.memPool) - if !claim.claimedEvaluations[0].Equal(&evaluation) { - return errors.New("incorrect input wire claim") - } - } - } else if err = sumcheckVerify( - claim, proof[i], fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), - ); err == nil { // incorporate prover claims about w's input into the transcript - baseChallenge = make([][]byte, len(proofW.finalEvalProof)) - for j := range baseChallenge { - baseChallenge[j] = proofW.finalEvalProof[j].Marshal() - } + for levelI := len(schedule) - 1; levelI >= 0; levelI-- { + if _, isSkip := r.schedule[levelI].(constraint.GkrSkipLevel); isSkip { + err = r.verifySkipLevel(levelI, proof) } else { - return fmt.Errorf("sumcheck proof rejected: %v", err) //TODO: Any polynomials to dump? + err = r.verifySumcheckLevel(levelI, proof) + } + if err != nil { + return fmt.Errorf("level %d: %v", levelI, err) } - claims.deleteClaim(i) + constraint.BindGkrFinalEvalProof(&r.transcript, proof[levelI].finalEvalProof, c.UniqueGateInputs(r.schedule[levelI]), c.IsInput, r.schedule[levelI]) } return nil } @@ -711,15 +611,15 @@ func iterateElems(elems []{{ .ElementType }}, counter *int, yield func(int, *{{ func (p Proof) flatten() iter.Seq2[int, *{{ .ElementType }}] { return func(yield func(int, *{{ .ElementType }}) bool) { - var counter int + var counter int for i := range p { for _, poly := range p[i].partialSumPolys { if !iterateElems(poly, &counter, yield) { - return + return } } if !iterateElems(p[i].finalEvalProof, &counter, yield) { - return + return } } } @@ -729,14 +629,14 @@ func (p Proof) flatten() iter.Seq2[int, *{{ .ElementType }}] { // It manages the stack internally and handles input buffering, making it easy to // evaluate the same gate multiple times with different inputs. type gateEvaluator struct { - gate gkrtypes.GateBytecode + gate gkrcore.GateBytecode vars []{{ .ElementType }} nbIn int // number of inputs expected } // newGateEvaluator creates an evaluator for the given compiled gate. // The stack is preloaded with constants and ready for evaluation. -func newGateEvaluator(gate gkrtypes.GateBytecode, nbIn int, elementPool ...*polynomial.Pool) gateEvaluator { +func newGateEvaluator(gate gkrcore.GateBytecode, nbIn int, elementPool ...*polynomial.Pool) gateEvaluator { e := gateEvaluator{ gate: gate, nbIn: nbIn, @@ -744,7 +644,7 @@ func newGateEvaluator(gate gkrtypes.GateBytecode, nbIn int, elementPool ...*poly if len(elementPool) > 0 { e.vars = elementPool[0].Make(gate.NbConstants() + nbIn + len(gate.Instructions)) } else { - e.vars = make([]{{.ElementType}}, gate.NbConstants()+nbIn+len(gate.Instructions)) + e.vars = make([]{{ .ElementType }}, gate.NbConstants()+nbIn+len(gate.Instructions)) } e.vars = e.vars[:gate.NbConstants()] for i, constVal := range gate.Constants { @@ -780,28 +680,28 @@ func (e *gateEvaluator) evaluate(top ...{{ .ElementType }}) *{{ .ElementType }} // Use switch instead of function pointer for better inlining switch inst.Op { - case gkrtypes.OpAdd: + case gkrcore.OpAdd: dst.Add(&e.vars[inst.Inputs[0]], &e.vars[inst.Inputs[1]]) for j := 2; j < len(inst.Inputs); j++ { dst.Add(dst, &e.vars[inst.Inputs[j]]) } - case gkrtypes.OpMul: + case gkrcore.OpMul: dst.Mul(&e.vars[inst.Inputs[0]], &e.vars[inst.Inputs[1]]) for j := 2; j < len(inst.Inputs); j++ { dst.Mul(dst, &e.vars[inst.Inputs[j]]) } - case gkrtypes.OpSub: + case gkrcore.OpSub: dst.Sub(&e.vars[inst.Inputs[0]], &e.vars[inst.Inputs[1]]) for j := 2; j < len(inst.Inputs); j++ { dst.Sub(dst, &e.vars[inst.Inputs[j]]) } - case gkrtypes.OpNeg: + case gkrcore.OpNeg: dst.Neg(&e.vars[inst.Inputs[0]]) - case gkrtypes.OpMulAcc: + case gkrcore.OpMulAcc: var prod {{ .ElementType }} prod.Mul(&e.vars[inst.Inputs[1]], &e.vars[inst.Inputs[2]]) dst.Add(&e.vars[inst.Inputs[0]], &prod) - case gkrtypes.OpSumExp17: + case gkrcore.OpSumExp17: // result = (x[0] + x[1] + x[2])^17 var sum {{ .ElementType }} sum.Add(&e.vars[inst.Inputs[0]], &e.vars[inst.Inputs[1]]) @@ -827,19 +727,19 @@ func (e *gateEvaluator) evaluate(top ...{{ .ElementType }}) *{{ .ElementType }} // gateEvaluatorPool manages a pool of gate evaluators for a specific gate type // All evaluators share the same underlying polynomial.Pool for element slices type gateEvaluatorPool struct { - gate gkrtypes.GateBytecode - nbIn int - lock sync.Mutex - available map[*gateEvaluator]struct{} - elementPool *polynomial.Pool + gate gkrcore.GateBytecode + nbIn int + lock sync.Mutex + available map[*gateEvaluator]struct{} + elementPool *polynomial.Pool } -func newGateEvaluatorPool(gate gkrtypes.GateBytecode, nbIn int, elementPool *polynomial.Pool) *gateEvaluatorPool { +func newGateEvaluatorPool(gate gkrcore.GateBytecode, nbIn int, elementPool *polynomial.Pool) *gateEvaluatorPool { gep := &gateEvaluatorPool{ - gate: gate, - nbIn: nbIn, - elementPool: elementPool, - available: make(map[*gateEvaluator]struct{}), + gate: gate, + nbIn: nbIn, + elementPool: elementPool, + available: make(map[*gateEvaluator]struct{}), } return gep } @@ -862,7 +762,7 @@ func (gep *gateEvaluatorPool) put(e *gateEvaluator) { gep.lock.Lock() defer gep.lock.Unlock() - // Return evaluator to pool (it keeps its vars slice from polynomial pool) + // Return evaluator to pool (it keeps its vars slice from the polynomial pool) gep.available[e] = struct{}{} } diff --git a/internal/generator/backend/template/gkr/gkr.test.go.tmpl b/internal/generator/backend/template/gkr/gkr.test.go.tmpl index 729c28948c..5977893dbf 100644 --- a/internal/generator/backend/template/gkr/gkr.test.go.tmpl +++ b/internal/generator/backend/template/gkr/gkr.test.go.tmpl @@ -1,25 +1,24 @@ import ( - "{{.FieldPackagePath}}" - "{{.FieldPackagePath}}/mimc" - "{{.FieldPackagePath}}/polynomial" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" - "github.com/consensys/gnark-crypto/ecc" - gcUtils "github.com/consensys/gnark-crypto/utils" - "github.com/consensys/gnark/internal/gkr/gkrtesting" - "github.com/consensys/gnark/internal/gkr/gkrtypes" - "github.com/stretchr/testify/assert" "fmt" "hash" "os" - "strconv" - "testing" "path/filepath" "reflect" + "testing" "time" + + "github.com/consensys/gnark-crypto/ecc" + "{{ .FieldPackagePath }}" + "{{ .FieldPackagePath }}/mimc" + "{{ .FieldPackagePath }}/polynomial" + "github.com/consensys/gnark/constraint" + "github.com/consensys/gnark/internal/gkr/gkrcore" + "github.com/consensys/gnark/internal/gkr/gkrtesting" + "github.com/stretchr/testify/assert" ) -var cache = gkrtesting.NewCache(ecc.{{.FieldID}}.ScalarField()) +var cache = gkrtesting.NewCache(ecc.{{ .FieldID }}.ScalarField()) func TestNoGateTwoInstances(t *testing.T) { // Testing a single instance is not possible because the sumcheck implementation doesn't cover the trivial 0-variate case @@ -62,36 +61,164 @@ func TestMimc(t *testing.T) { test(t, gkrtesting.MiMCCircuit(93)) } -func TestSumcheckFromSingleInputTwoIdentityGatesGateTwoInstances(t *testing.T) { - // Construct SerializableCircuit directly, bypassing CompileCircuit - // which would reset NbUniqueOutputs based on actual topology - circuit := gkrtypes.SerializableCircuit{ +func TestPoseidon2(t *testing.T) { + test(t, gkrtesting.Poseidon2Circuit(4, 2)) +} + +// testSumcheckLevel exercises proveSumcheckLevel/verifySumcheckLevel for a single sumcheck level. +func testSumcheckLevel(t *testing.T, circuit gkrcore.RawCircuit, level constraint.GkrProvingLevel) { + t.Helper() + _, sCircuit := cache.Compile(t, circuit) + + ins := sCircuit.Inputs() + assignment := make(WireAssignment, len(sCircuit)) + for _, i := range ins { + assignment[i] = make([]{{ .ElementType }}, 2) + {{ .FieldPackageName }}.Vector(assignment[i]).MustSetRandom() + } + + assignment.Complete(sCircuit) + + schedule := constraint.GkrProvingSchedule{level} + initEvalPoint := [][]{{ .ElementType }}{ {one} } + + // Prove + proveR, err := newResources(sCircuit, schedule, assignment, newMessageCounter(1, 1)) + assert.NoError(t, err) + defer proveR.workers.Stop() + + proveR.outgoingEvalPoints[len(schedule)] = initEvalPoint + proof := Proof{proveR.proveSumcheckLevel(0)} + + // Verify + verifyR, err := newResources(sCircuit, schedule, assignment, newMessageCounter(1, 1)) + assert.NoError(t, err) + defer verifyR.workers.Stop() + + verifyR.outgoingEvalPoints[len(schedule)] = initEvalPoint + assert.NoError(t, verifyR.verifySumcheckLevel(0, proof)) +} + +func TestSumcheckLevel(t *testing.T) { + // Wires 0,1 = inputs; wires 2,3,4 = mul(0,1). All gates are independent outputs. + circuit := gkrcore.RawCircuit{ + {}, + {}, + {Gate: gkrcore.Mul2, Inputs: []int{0, 1}}, + {Gate: gkrcore.Mul2, Inputs: []int{0, 1}}, + {Gate: gkrcore.Mul2, Inputs: []int{0, 1}}, + } + // Each level has an initial challenge at index 1 (len(schedule) = 1). + // GkrClaimSource{Level:1} is the initial-challenge sentinel. + tests := []struct { + name string + level constraint.GkrProvingLevel + }{ + { + name: "single wire", + level: constraint.GkrSumcheckLevel{ + {Wires: []int{4}, ClaimSources: []constraint.GkrClaimSource{ {Level: 1} }}, + }, + }, + { + name: "two groups", + level: constraint.GkrSumcheckLevel{ + {Wires: []int{4}, ClaimSources: []constraint.GkrClaimSource{ {Level: 1} }}, + {Wires: []int{3}, ClaimSources: []constraint.GkrClaimSource{ {Level: 1} }}, + }, + }, { - NbUniqueOutputs: 2, - Gate: gkrtypes.SerializableGate{Degree: 1}, + name: "one group with two wires", + level: constraint.GkrSumcheckLevel{ + {Wires: []int{4, 3}, ClaimSources: []constraint.GkrClaimSource{ {Level: 1} }}, + }, }, + { + name: "mixed: single + multi-wire group", + level: constraint.GkrSumcheckLevel{ + {Wires: []int{4}, ClaimSources: []constraint.GkrClaimSource{ {Level: 1} }}, + {Wires: []int{3, 2}, ClaimSources: []constraint.GkrClaimSource{ {Level: 1} }}, + }, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + testSumcheckLevel(t, circuit, tc.level) + }) } +} + +// testSkipLevel exercises proveSkipLevel/verifySkipLevel for a single skip level. +func testSkipLevel(t *testing.T, circuit gkrcore.RawCircuit, level constraint.GkrProvingLevel) { + t.Helper() + _, sCircuit := cache.Compile(t, circuit) - assignment := WireAssignment{[]{{ .ElementType }}{two, three}} - var o settings - pool := polynomial.NewPool(256, 1<<11) - workers := gcUtils.NewWorkerPool() - o.pool = &pool - o.workers = workers - - claimsManagerGen := func() *claimsManager { - manager := newClaimsManager(circuit, assignment, o) - manager.add(0, []{{ .ElementType }}{three}, five) - manager.add(0, []{{ .ElementType }}{four}, six) - return &manager + ins := sCircuit.Inputs() + assignment := make(WireAssignment, len(sCircuit)) + for _, i := range ins { + assignment[i] = make([]{{ .ElementType }}, 2) + {{ .FieldPackageName }}.Vector(assignment[i]).MustSetRandom() } - transcriptGen := newMessageCounterGenerator(4, 1) + assignment.Complete(sCircuit) - proof, err := sumcheckProve(claimsManagerGen().getClaim(0), fiatshamir.WithHash(transcriptGen(), nil)) + schedule := constraint.GkrProvingSchedule{level} + initEvalPoint := [][]{{ .ElementType }}{ {one} } + + // Prove + proveR, err := newResources(sCircuit, schedule, assignment, newMessageCounter(1, 1)) assert.NoError(t, err) - err = sumcheckVerify(claimsManagerGen().getLazyClaim(0), proof, fiatshamir.WithHash(transcriptGen(), nil)) + defer proveR.workers.Stop() + + proveR.outgoingEvalPoints[len(schedule)] = initEvalPoint + proof := Proof{proveR.proveSkipLevel(0)} + + // Verify + verifyR, err := newResources(sCircuit, schedule, assignment, newMessageCounter(1, 1)) assert.NoError(t, err) + defer verifyR.workers.Stop() + + verifyR.outgoingEvalPoints[len(schedule)] = initEvalPoint + assert.NoError(t, verifyR.verifySkipLevel(0, proof)) +} + +func TestSkipLevel(t *testing.T) { + // Wires 0,1 = inputs; wires 2,3 = identity(0); wire 4 = add(0,1). All degree-1 outputs. + circuit := gkrcore.RawCircuit{ + {}, + {}, + {Gate: gkrcore.Identity, Inputs: []int{0}}, + {Gate: gkrcore.Identity, Inputs: []int{0}}, + {Gate: gkrcore.Add2, Inputs: []int{0, 1}}, + } + + // Single-claim cases: one inherited evaluation point (OutgoingClaimIndex always 0). + singleClaim := []struct { + name string + level constraint.GkrProvingLevel + }{ + { + name: "single input wire", + level: constraint.GkrSkipLevel{Wires: []int{0}, ClaimSources: []constraint.GkrClaimSource{ {Level: 1} }}, + }, + { + name: "single identity gate", + level: constraint.GkrSkipLevel{Wires: []int{2}, ClaimSources: []constraint.GkrClaimSource{ {Level: 1} }}, + }, + { + name: "add gate", + level: constraint.GkrSkipLevel{Wires: []int{4}, ClaimSources: []constraint.GkrClaimSource{ {Level: 1} }}, + }, + { + name: "two identity gates one group", + level: constraint.GkrSkipLevel{Wires: []int{2, 3}, ClaimSources: []constraint.GkrClaimSource{ {Level: 1} }}, + }, + } + for _, tc := range singleClaim { + t.Run(tc.name, func(t *testing.T) { + testSkipLevel(t, circuit, tc.level) + }) + } } var one, two, three, four, five, six {{ .ElementType }} @@ -105,35 +232,24 @@ func init() { six.Double(&three) } -var testManyInstancesLogMaxInstances = -1 - -func getLogMaxInstances(t *testing.T) int { - if testManyInstancesLogMaxInstances == -1 { - - s := os.Getenv("GKR_LOG_INSTANCES") - if s == "" { - testManyInstancesLogMaxInstances = 5 - } else { - var err error - testManyInstancesLogMaxInstances, err = strconv.Atoi(s) - if err != nil { - t.Error(err) - } - } - - } - return testManyInstancesLogMaxInstances +func test(t *testing.T, circuit gkrcore.RawCircuit) { + testWithSchedule(t, circuit, nil) } -func test(t *testing.T, circuit gkrtypes.GadgetCircuit) { - sCircuit := cache.Compile(t, circuit) - ins := circuit.Inputs() +func testWithSchedule(t *testing.T, circuit gkrcore.RawCircuit, schedule constraint.GkrProvingSchedule) { + gCircuit, sCircuit := cache.Compile(t, circuit) + if schedule == nil { + var err error + schedule, err = gkrcore.DefaultProvingSchedule(sCircuit) + assert.NoError(t, err) + } + ins := gCircuit.Inputs() insAssignment := make(WireAssignment, len(ins)) - maxSize := 1 << getLogMaxInstances(t) + maxSize := 1 << gkrtesting.GetLogMaxInstances(t) for i := range ins { insAssignment[i] = make([]{{ .ElementType }}, maxSize) - fr.Vector(insAssignment[i]).MustSetRandom() + {{ .FieldPackageName }}.Vector(insAssignment[i]).MustSetRandom() } fullAssignment := make(WireAssignment, len(circuit)) @@ -144,51 +260,33 @@ func test(t *testing.T, circuit gkrtypes.GadgetCircuit) { fullAssignment.Complete(sCircuit) - t.Log("Selected inputs for test") - - proof, err := Prove(sCircuit, fullAssignment, fiatshamir.WithHash(newMessageCounter(1, 1))) + proof, err := Prove(sCircuit, schedule, fullAssignment, newMessageCounter(1, 1)) assert.NoError(t, err) // Even though a hash is called here, the proof is empty - err = Verify(sCircuit, fullAssignment, proof, fiatshamir.WithHash(newMessageCounter(1, 1))) + err = Verify(sCircuit, schedule, fullAssignment, proof, newMessageCounter(1, 1)) assert.NoError(t, err, "proof rejected") - if proof.isEmpty() { // special case for TestNoGate: - continue // there's no way to make a trivial proof fail - } - - err = Verify(sCircuit, fullAssignment, proof, fiatshamir.WithHash(newMessageCounter(0, 1))) + err = Verify(sCircuit, schedule, fullAssignment, proof, newMessageCounter(0, 1)) assert.NotNil(t, err, "bad proof accepted") } - -} - -func (p Proof) isEmpty() bool { - for i := range p { - if len(p[i].finalEvalProof) != 0 { - return false - } - for j := range p[i].partialSumPolys { - if len(p[i].partialSumPolys[j]) != 0 { - return false - } - } - } - return true } func testNoGate(t *testing.T, inputAssignments ...[]{{ .ElementType }}) { - c := cache.Compile(t, gkrtesting.NoGateCircuit()) + _, c := cache.Compile(t, gkrtesting.NoGateCircuit()) + + schedule, err := gkrcore.DefaultProvingSchedule(c) + assert.NoError(t, err) assignment := WireAssignment{0: inputAssignments[0]} - proof, err := Prove(c, assignment, fiatshamir.WithHash(newMessageCounter(1, 1))) + proof, err := Prove(c, schedule, assignment, newMessageCounter(1, 1)) assert.NoError(t, err) // Even though a hash is called here, the proof is empty - err = Verify(c, assignment, proof, fiatshamir.WithHash(newMessageCounter(1, 1))) + err = Verify(c, schedule, assignment, proof, newMessageCounter(1, 1)) assert.NoError(t, err, "proof rejected") } @@ -196,7 +294,7 @@ func generateTestProver(path string) func(t *testing.T) { return func(t *testing.T) { testCase, err := newTestCase(path) assert.NoError(t, err) - proof, err := Prove(testCase.Circuit, testCase.FullAssignment, fiatshamir.WithHash(testCase.Hash)) + proof, err := Prove(testCase.Circuit, testCase.Schedule, testCase.FullAssignment, testCase.Hash) assert.NoError(t, err) assert.NoError(t, proofEquals(testCase.Proof, proof)) } @@ -206,17 +304,29 @@ func generateTestVerifier(path string) func(t *testing.T) { return func(t *testing.T) { testCase, err := newTestCase(path) assert.NoError(t, err) - err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(testCase.Hash)) + err = Verify(testCase.Circuit, testCase.Schedule, testCase.InOutAssignment, testCase.Proof, testCase.Hash) assert.NoError(t, err, "proof rejected") + testCase, err = newTestCase(path) assert.NoError(t, err) - err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(newMessageCounter(2, 0))) + err = Verify(testCase.Circuit, testCase.Schedule, testCase.InOutAssignment, testCase.Proof, newMessageCounter(2, 0)) assert.NotNil(t, err, "bad proof accepted") + + testCase, err = newTestCase(path) + assert.NoError(t, err) + testCase.InOutAssignment[len(testCase.InOutAssignment)-1][0].Add(&testCase.InOutAssignment[len(testCase.InOutAssignment)-1][0], &one) + err = Verify(testCase.Circuit, testCase.Schedule, testCase.InOutAssignment, testCase.Proof, testCase.Hash) + assert.NotNil(t, err, "tampered output accepted") + + testCase, err = newTestCase(path) + assert.NoError(t, err) + testCase.InOutAssignment[0][0].Add(&testCase.InOutAssignment[0][0], &one) + err = Verify(testCase.Circuit, testCase.Schedule, testCase.InOutAssignment, testCase.Proof, testCase.Hash) + assert.NotNil(t, err, "tampered input accepted") } } func TestGkrVectors(t *testing.T) { - const testDirPath = "../test_vectors/" dirEntries, err := os.ReadDir(testDirPath) assert.NoError(t, err) @@ -260,12 +370,15 @@ func proofEquals(expected Proof, seen Proof) error { func benchmarkGkrMiMC(b *testing.B, nbInstances, mimcDepth int) { fmt.Println("creating circuit structure") - c := cache.Compile(b, gkrtesting.MiMCCircuit(mimcDepth)) + _, c := cache.Compile(b, gkrtesting.MiMCCircuit(mimcDepth)) + + schedule, err := gkrcore.DefaultProvingSchedule(c) + assert.NoError(b, err) in0 := make([]{{ .ElementType }}, nbInstances) in1 := make([]{{ .ElementType }}, nbInstances) - {{.FieldPackageName}}.Vector(in0).MustSetRandom() - {{.FieldPackageName}}.Vector(in1).MustSetRandom() + {{ .FieldPackageName }}.Vector(in0).MustSetRandom() + {{ .FieldPackageName }}.Vector(in1).MustSetRandom() fmt.Println("evaluating circuit") start := time.Now().UnixMicro() @@ -276,12 +389,30 @@ func benchmarkGkrMiMC(b *testing.B, nbInstances, mimcDepth int) { //b.ResetTimer() fmt.Println("constructing proof") start = time.Now().UnixMicro() - _, err := Prove(c, assignment, fiatshamir.WithHash(mimc.NewMiMC())) + _, err = Prove(c, schedule, assignment, mimc.NewMiMC()) proved := time.Now().UnixMicro() - start fmt.Println("proved in", proved, "μs") assert.NoError(b, err) } +// TestSingleMulGateExplicitSchedule tests a single mul gate with an explicit single-step schedule, +// equivalent to the default but constructed manually to exercise the schedule path. +func TestSingleMulGateExplicitSchedule(t *testing.T) { + circuit := gkrtesting.SingleMulGateCircuit() + _, sCircuit := cache.Compile(t, circuit) + + // Wire 2 is the mul gate output (inputs: 0, 1). + // Explicit schedule: one GkrProvingLevel for wire 2. + // GkrClaimSource{Level:1} is the initial-challenge sentinel (len(schedule)=1). + schedule := constraint.GkrProvingSchedule{ + constraint.GkrSumcheckLevel{ + {Wires: []int{2}, ClaimSources: []constraint.GkrClaimSource{ {Level: 1} }}, + }, + } + testWithSchedule(t, circuit, schedule) + _ = sCircuit +} + func BenchmarkGkrMimc19(b *testing.B) { benchmarkGkrMiMC(b, 1<<19, 91) } @@ -290,4 +421,4 @@ func BenchmarkGkrMimc17(b *testing.B) { benchmarkGkrMiMC(b, 1<<17, 91) } -{{template "gkrTestVectors" .}} +{{ template "gkrTestVectors" .}} diff --git a/internal/generator/backend/template/gkr/gkr.test.vectors.gen.go.tmpl b/internal/generator/backend/template/gkr/gkr.test.vectors.gen.go.tmpl index 49f2d3eb7c..7cbc8ed2f3 100644 --- a/internal/generator/backend/template/gkr/gkr.test.vectors.gen.go.tmpl +++ b/internal/generator/backend/template/gkr/gkr.test.vectors.gen.go.tmpl @@ -3,11 +3,11 @@ import ( "fmt" "github.com/consensys/bavard" "github.com/consensys/gnark-crypto/ecc" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark/constraint" "github.com/consensys/gnark/internal/small_rational" "github.com/consensys/gnark/internal/small_rational/polynomial" "github.com/consensys/gnark/internal/gkr/gkrtesting" - "github.com/consensys/gnark/internal/gkr/gkrtypes" + "github.com/consensys/gnark/internal/gkr/gkrcore" "hash" "os" "path/filepath" @@ -56,10 +56,8 @@ func run(absPath string) error { return err } - transcriptSetting := fiatshamir.WithHash(testCase.Hash) - var proof Proof - proof, err = Prove(testCase.Circuit, testCase.FullAssignment, transcriptSetting) + proof, err = Prove(testCase.Circuit, testCase.Schedule, testCase.FullAssignment, testCase.Hash) if err != nil { return err } @@ -81,7 +79,7 @@ func run(absPath string) error { return err } - err = Verify(testCase.Circuit, testCase.InOutAssignment, proof, transcriptSetting) + err = Verify(testCase.Circuit, testCase.Schedule, testCase.InOutAssignment, proof, testCase.Hash) if err != nil { return err } @@ -91,7 +89,7 @@ func run(absPath string) error { return err } - err = Verify(testCase.Circuit, testCase.InOutAssignment, proof, fiatshamir.WithHash(newMessageCounter(2, 0))) + err = Verify(testCase.Circuit, testCase.Schedule, testCase.InOutAssignment, proof, newMessageCounter(2, 0)) if err == nil { return fmt.Errorf("bad proof accepted") } @@ -116,7 +114,7 @@ func toPrintableProof(proof Proof) (gkrtesting.PrintableProof, error) { return res, nil } -func elementToInterface(x *{{.ElementType}}) interface{} { +func elementToInterface(x *{{ .ElementType }}) interface{} { if i := x.BigInt(nil); i != nil { return i } @@ -132,7 +130,7 @@ func elementSliceToInterfaceSlice(x interface{}) []interface{} { res := make([]interface{}, X.Len()) for i := range res { - xI := X.Index(i).Interface().({{.ElementType}}) + xI := X.Index(i).Interface().({{ .ElementType }}) res[i] = elementToInterface(&xI) } return res @@ -153,4 +151,4 @@ func elementSliceSliceToInterfaceSliceSlice(x interface{}) [][]interface{} { return res } -{{template "gkrTestVectors" .}} \ No newline at end of file +{{ template "gkrTestVectors" .}} diff --git a/internal/generator/backend/template/gkr/gkr.test.vectors.go.tmpl b/internal/generator/backend/template/gkr/gkr.test.vectors.go.tmpl index e0778641bd..92370ece26 100644 --- a/internal/generator/backend/template/gkr/gkr.test.vectors.go.tmpl +++ b/internal/generator/backend/template/gkr/gkr.test.vectors.go.tmpl @@ -1,15 +1,13 @@ -{{define "gkrTestVectors"}} - -{{$CheckOutputCorrectness := true}} +{{ define "gkrTestVectors" }} func unmarshalProof(printable gkrtesting.PrintableProof) (Proof, error) { proof := make(Proof, len(printable)) for i := range printable { - finalEvalProof := []{{.ElementType}}(nil) + finalEvalProof := []{{ .ElementType }}(nil) if printable[i].FinalEvalProof != nil { finalEvalSlice := reflect.ValueOf(printable[i].FinalEvalProof) - finalEvalProof = make([]{{.ElementType}}, finalEvalSlice.Len()) + finalEvalProof = make([]{{ .ElementType }}, finalEvalSlice.Len()) for k := range finalEvalProof { if _, err := {{ setElement "finalEvalProof[k]" "finalEvalSlice.Index(k).Interface()" .ElementType}}; err != nil { return nil, err @@ -32,12 +30,13 @@ func unmarshalProof(printable gkrtesting.PrintableProof) (Proof, error) { } type TestCase struct { - Circuit gkrtypes.SerializableCircuit + Circuit gkrcore.SerializableCircuit Hash hash.Hash Proof Proof FullAssignment WireAssignment InOutAssignment WireAssignment - {{if .GenerateTestVectors}}Info gkrtesting.TestCaseInfo // we are generating the test vectors, so we need to keep the circuit instance info to ADD the proof to it and resave it{{end}} + Schedule constraint.GkrProvingSchedule + {{ if .GenerateTestVectors }}Info gkrtesting.TestCaseInfo // we are generating the test vectors, so we need to keep the circuit instance info to ADD the proof to it and resave it{{ end }} } var testCases = make(map[string]*TestCase) @@ -69,6 +68,20 @@ func newTestCase(path string) (*TestCase, error) { if proof, err = unmarshalProof(info.Proof); err != nil { return nil, err } + var schedule constraint.GkrProvingSchedule + if schedule, err = info.Schedule.ToProvingSchedule(); err != nil { + return nil, err + } + if schedule == nil { + if schedule, err = gkrcore.DefaultProvingSchedule(circuit); err != nil { + return nil, err + } + } + + outputSet := make(map[int]bool, len(circuit)) + for _, o := range circuit.Outputs() { + outputSet[o] = true + } fullAssignment := make(WireAssignment, len(circuit)) inOutAssignment := make(WireAssignment, len(circuit)) @@ -82,7 +95,7 @@ func newTestCase(path string) (*TestCase, error) { } assignmentRaw = info.Input[inI] inI++ - } else if circuit[i].IsOutput() { + } else if outputSet[i] { if outI == len(info.Output) { return nil, fmt.Errorf("fewer output in vector than in circuit") } @@ -103,7 +116,7 @@ func newTestCase(path string) (*TestCase, error) { fullAssignment.Complete(circuit) for i := range circuit { - if circuit[i].IsOutput() { + if outputSet[i] { if err = sliceEquals(inOutAssignment[i], fullAssignment[i]); err != nil { return nil, fmt.Errorf("assignment mismatch: %v", err) } @@ -116,6 +129,7 @@ func newTestCase(path string) (*TestCase, error) { Proof: proof, Hash: _hash, Circuit: circuit, + Schedule: schedule, {{ if .GenerateTestVectors }} Info: info, {{ end }} } @@ -124,7 +138,7 @@ func newTestCase(path string) (*TestCase, error) { return tCase, nil } -{{end}} +{{ end }} {{- define "setElement element value elementType"}} {{- if eq .elementType "fr.Element"}} setElement(&{{.element}}, {{.value}}) @@ -132,4 +146,4 @@ func newTestCase(path string) (*TestCase, error) { {{- else}} {{print "\"UNEXPECTED TYPE" .elementType "\""}} {{- end}} -{{- end}} \ No newline at end of file +{{- end}} diff --git a/internal/generator/backend/template/gkr/sumcheck.go.tmpl b/internal/generator/backend/template/gkr/sumcheck.go.tmpl index 41ff58b325..bd5c2c2c28 100644 --- a/internal/generator/backend/template/gkr/sumcheck.go.tmpl +++ b/internal/generator/backend/template/gkr/sumcheck.go.tmpl @@ -1,163 +1,109 @@ import ( "errors" + "hash" + "{{.FieldPackagePath}}" "{{.FieldPackagePath}}/polynomial" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" - "strconv" ) -// This does not make use of parallelism and represents polynomials as lists of coefficients -// It is currently geared towards arithmetic hashes. Once we have a more unified hash function interface, this can be generified. +// This does not make use of parallelism and represents polynomials as lists of coefficients. + +// transcript is a Fiat-Shamir transcript backed by a running hash. +// Field elements are written via Bind; challenges are derived via getChallenge. +// The hash is never reset — all previous data is implicitly part of future challenges. +type transcript struct { + h hash.Hash + bound bool // whether Bind was called since the last getChallenge +} + +// Bind writes field elements to the transcript as bindings for the next challenge. +func (t *transcript) Bind(elements ...{{ .ElementType }}) { + if len(elements) == 0 { + return + } + for i := range elements { + bytes := elements[i].Bytes() + t.h.Write(bytes[:]) + } + t.bound = true +} + +// getChallenge binds optional elements, then squeezes a challenge from the current hash state. +// If no bindings were added since the last squeeze, a separator byte is written first +// to advance the state and prevent repeated values. +func (t *transcript) getChallenge(bindings ...{{ .ElementType }}) {{ .ElementType }} { + t.Bind(bindings...) + if !t.bound { + t.h.Write([]byte{0}) + } + t.bound = false + var res {{ .ElementType }} + res.SetBytes(t.h.Sum(nil)) + return res +} // sumcheckClaims to a multi-sumcheck statement. i.e. one of the form ∑_{0≤i<2ⁿ} fⱼ(i) = cⱼ for 1 ≤ j ≤ m. // Later evolving into a claim of the form gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) type sumcheckClaims interface { - fold(a {{.ElementType}}) polynomial.Polynomial // fold into the 0ᵗʰ sumcheck subclaim. Create g := ∑_{1≤j≤m} aʲ⁻¹fⱼ for which now we seek to prove ∑_{0≤i<2ⁿ} g(i) = c := ∑_{1≤j≤m} aʲ⁻¹cⱼ. Return g₁. - next({{.ElementType}}) polynomial.Polynomial // Return the evaluations gⱼ(k) for 1 ≤ k < degⱼ(g). Update the claim to gⱼ₊₁ for the input value as rⱼ + roundPolynomial() polynomial.Polynomial // compute gⱼ polynomial for current round + roundFold(r {{ .ElementType }}) // fold inputs and eq at challenge r varsNum() int // number of variables - claimsNum() int // number of claims - proveFinalEval(r []{{.ElementType}}) []{{.ElementType}} // in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof + proveFinalEval(r []{{ .ElementType }}) []{{ .ElementType }} // in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof } // sumcheckLazyClaims is the sumcheckClaims data structure on the verifier side. It is "lazy" in that it has to compute fewer things. type sumcheckLazyClaims interface { - claimsNum() int // claimsNum = m - varsNum() int // varsNum = n - foldedSum(a {{.ElementType}}) {{.ElementType}} // foldedSum returns c = ∑_{1≤j≤m} aʲ⁻¹cⱼ - degree(i int) int // degree of the total claim in the i'th variable - verifyFinalEval(r []{{.ElementType}}, foldingCoeff {{.ElementType}}, purportedValue {{.ElementType}}, proof []{{.ElementType}}) error + varsNum() int // varsNum = n + degree(i int) int // degree of the total claim in the i'th variable + verifyFinalEval(r []{{ .ElementType }}, purportedValue {{ .ElementType }}, proof []{{ .ElementType }}) error } // sumcheckProof of a multi-statement. type sumcheckProof struct { partialSumPolys []polynomial.Polynomial - finalEvalProof []{{.ElementType}} //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof -} - -func setupTranscript(claimsNum int, varsNum int, settings *fiatshamir.Settings) (challengeNames []string, err error) { - numChallenges := varsNum - if claimsNum >= 2 { - numChallenges++ - } - challengeNames = make([]string, numChallenges) - if claimsNum >= 2 { - challengeNames[0] = settings.Prefix + "fold" - } - prefix := settings.Prefix + "pSP." - for i := 0; i < varsNum; i++ { - challengeNames[i+numChallenges-varsNum] = prefix + strconv.Itoa(i) - } - if settings.Transcript == nil { - transcript := fiatshamir.NewTranscript(settings.Hash, challengeNames...) - settings.Transcript = transcript - } - - for i := range settings.BaseChallenges { - if err = settings.Transcript.Bind(challengeNames[0], settings.BaseChallenges[i]); err != nil { - return - } - } - return -} - -func next(transcript *fiatshamir.Transcript, bindings []{{.ElementType}}, remainingChallengeNames *[]string) ({{.ElementType}}, error) { - challengeName := (*remainingChallengeNames)[0] - for i := range bindings { - bytes := bindings[i].Bytes() - if err := transcript.Bind(challengeName, bytes[:]); err != nil { - return {{.ElementType}}{}, err - } - } - var res {{.ElementType}} - bytes, err := transcript.ComputeChallenge(challengeName) - res.SetBytes(bytes) - - *remainingChallengeNames = (*remainingChallengeNames)[1:] - - return res, err + finalEvalProof []{{ .ElementType }} //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof } -// sumcheckProve create a non-interactive proof -func sumcheckProve(claims sumcheckClaims, transcriptSettings fiatshamir.Settings) (sumcheckProof, error) { - - var proof sumcheckProof - remainingChallengeNames, err := setupTranscript(claims.claimsNum(), claims.varsNum(), &transcriptSettings) - transcript := transcriptSettings.Transcript - if err != nil { - return proof, err - } - - var foldingCoeff {{.ElementType}} - if claims.claimsNum() >= 2 { - if foldingCoeff, err = next(transcript, []{{.ElementType}}{}, &remainingChallengeNames); err != nil { - return proof, err - } - } - +// sumcheckProve creates a non-interactive sumcheck proof. +// The fold challenge is derived by the caller (proveLevel). +// Pattern: roundPolynomial, [roundFold, roundPolynomial]*, proveFinalEval. +func sumcheckProve(claims sumcheckClaims, t *transcript) sumcheckProof { varsNum := claims.varsNum() - proof.partialSumPolys = make([]polynomial.Polynomial, varsNum) - proof.partialSumPolys[0] = claims.fold(foldingCoeff) - challenges := make([]{{.ElementType}}, varsNum) - - for j := 0; j+1 < varsNum; j++ { - if challenges[j], err = next(transcript, proof.partialSumPolys[j], &remainingChallengeNames); err != nil { - return proof, err - } - proof.partialSumPolys[j+1] = claims.next(challenges[j]) - } - - if challenges[varsNum-1], err = next(transcript, proof.partialSumPolys[varsNum-1], &remainingChallengeNames); err != nil { - return proof, err + proof := sumcheckProof{partialSumPolys: make([]polynomial.Polynomial, varsNum)} + proof.partialSumPolys[0] = claims.roundPolynomial() + challenges := make([]{{ .ElementType }}, varsNum) + + for j := range varsNum - 1 { + challenges[j] = t.getChallenge(proof.partialSumPolys[j]...) + claims.roundFold(challenges[j]) + proof.partialSumPolys[j+1] = claims.roundPolynomial() } + challenges[varsNum-1] = t.getChallenge(proof.partialSumPolys[varsNum-1]...) proof.finalEvalProof = claims.proveFinalEval(challenges) - - return proof, nil + return proof } -func sumcheckVerify(claims sumcheckLazyClaims, proof sumcheckProof, transcriptSettings fiatshamir.Settings) error { - remainingChallengeNames, err := setupTranscript(claims.claimsNum(), claims.varsNum(), &transcriptSettings) - transcript := transcriptSettings.Transcript - if err != nil { - return err - } - - var foldingCoeff {{.ElementType}} +// sumcheckVerify verifies a non-interactive sumcheck proof. +// The fold challenge is derived by the caller (verifyLevel). +// claimedSum is the expected sum; degree is the polynomial's degree in each variable. +func sumcheckVerify(claims sumcheckLazyClaims, proof sumcheckProof, claimedSum {{ .ElementType }}, degree int, t *transcript) error { + r := make([]{{ .ElementType }}, claims.varsNum()) - if claims.claimsNum() >= 2 { - if foldingCoeff, err = next(transcript, []{{.ElementType}}{}, &remainingChallengeNames); err != nil { - return err - } - } - - r := make([]{{.ElementType}}, claims.varsNum()) - - // Just so that there is enough room for gJ to be reused - maxDegree := claims.degree(0) - for j := 1; j < claims.varsNum(); j++ { - if d := claims.degree(j); d > maxDegree { - maxDegree = d - } - } - gJ := make(polynomial.Polynomial, maxDegree+1) //At the end of iteration j, gJ = ∑_{i < 2ⁿ⁻ʲ⁻¹} g(X₁, ..., Xⱼ₊₁, i...) NOTE: n is shorthand for claims.varsNum() - gJR := claims.foldedSum(foldingCoeff) // At the beginning of iteration j, gJR = ∑_{i < 2ⁿ⁻ʲ} g(r₁, ..., rⱼ, i...) + gJ := make(polynomial.Polynomial, degree+1) + gJR := claimedSum for j := range claims.varsNum() { - if len(proof.partialSumPolys[j]) != claims.degree(j) { + if len(proof.partialSumPolys[j]) != degree { return errors.New("malformed proof") } copy(gJ[1:], proof.partialSumPolys[j]) - gJ[0].Sub(&gJR, &proof.partialSumPolys[j][0]) // Requirement that gⱼ(0) + gⱼ(1) = gⱼ₋₁(r) - // gJ is ready + gJ[0].Sub(&gJR, &proof.partialSumPolys[j][0]) - //Prepare for the next iteration - if r[j], err = next(transcript, proof.partialSumPolys[j], &remainingChallengeNames); err != nil { - return err - } - // This is an extremely inefficient way of interpolating. TODO: Interpolate without symbolically computing a polynomial - gJCoeffs := polynomial.InterpolateOnRange(gJ[:(claims.degree(j) + 1)]) + r[j] = t.getChallenge(proof.partialSumPolys[j]...) + gJCoeffs := polynomial.InterpolateOnRange(gJ[:(degree + 1)]) gJR = gJCoeffs.Eval(&r[j]) } - return claims.verifyFinalEval(r, foldingCoeff, gJR, proof.finalEvalProof) + return claims.verifyFinalEval(r, gJR, proof.finalEvalProof) } diff --git a/internal/generator/backend/template/gkr/sumcheck.test.defs.go.tmpl b/internal/generator/backend/template/gkr/sumcheck.test.defs.go.tmpl index 472898669c..54a5c00b85 100644 --- a/internal/generator/backend/template/gkr/sumcheck.test.defs.go.tmpl +++ b/internal/generator/backend/template/gkr/sumcheck.test.defs.go.tmpl @@ -4,41 +4,36 @@ type singleMultilinClaim struct { g polynomial.MultiLin } -func (c singleMultilinClaim) proveFinalEval(r []{{.ElementType}}) []{{.ElementType}} { +func (c *singleMultilinClaim) proveFinalEval(r []{{ .ElementType }}) []{{ .ElementType }} { return nil // verifier can compute the final eval itself } -func (c singleMultilinClaim) varsNum() int { +func (c *singleMultilinClaim) varsNum() int { return bits.TrailingZeros(uint(len(c.g))) } -func (c singleMultilinClaim) claimsNum() int { - return 1 -} - func sumForX1One(g polynomial.MultiLin) polynomial.Polynomial { sum := g[len(g)/2] for i := len(g)/2 + 1; i < len(g); i++ { sum.Add(&sum, &g[i]) } - return []{{.ElementType}}{sum} + return []{{ .ElementType }}{sum} } -func (c singleMultilinClaim) fold({{.ElementType}}) polynomial.Polynomial { +func (c *singleMultilinClaim) roundPolynomial() polynomial.Polynomial { return sumForX1One(c.g) } -func (c *singleMultilinClaim) next(r {{.ElementType}}) polynomial.Polynomial { +func (c *singleMultilinClaim) roundFold(r {{ .ElementType }}) { c.g.Fold(r) - return sumForX1One(c.g) } type singleMultilinLazyClaim struct { g polynomial.MultiLin - claimedSum {{.ElementType}} + claimedSum {{ .ElementType }} } -func (c singleMultilinLazyClaim) verifyFinalEval(r []{{.ElementType}}, _ {{.ElementType}}, purportedValue {{.ElementType}}, proof []{{.ElementType}}) error { +func (c singleMultilinLazyClaim) verifyFinalEval(r []{{ .ElementType }}, purportedValue {{ .ElementType }}, proof []{{ .ElementType }}) error { val := c.g.Evaluate(r, nil) if val.Equal(&purportedValue) { return nil @@ -46,15 +41,7 @@ func (c singleMultilinLazyClaim) verifyFinalEval(r []{{.ElementType}}, _ {{.Elem return fmt.Errorf("mismatch") } -func (c singleMultilinLazyClaim) foldedSum(_ {{.ElementType}}) {{.ElementType}} { - return c.claimedSum -} - -func (c singleMultilinLazyClaim) degree(i int) int { - return 1 -} - -func (c singleMultilinLazyClaim) claimsNum() int { +func (c singleMultilinLazyClaim) degree(int) int { return 1 } @@ -62,4 +49,4 @@ func (c singleMultilinLazyClaim) varsNum() int { return bits.TrailingZeros(uint(len(c.g))) } -{{ end }} \ No newline at end of file +{{ end }} diff --git a/internal/generator/backend/template/gkr/sumcheck.test.go.tmpl b/internal/generator/backend/template/gkr/sumcheck.test.go.tmpl index 72e0f76326..146a67727d 100644 --- a/internal/generator/backend/template/gkr/sumcheck.test.go.tmpl +++ b/internal/generator/backend/template/gkr/sumcheck.test.go.tmpl @@ -1,11 +1,10 @@ import ( "fmt" - "{{.FieldPackagePath}}/polynomial" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "{{ .FieldPackagePath }}/polynomial" "github.com/stretchr/testify/assert" "hash" {{ if not .GenerateTestVectors}} - "{{.FieldPackagePath}}" + "{{ .FieldPackagePath }}" "math/bits" {{ end }} "strings" @@ -19,11 +18,9 @@ func testSumcheckSingleClaimMultilin(polyInt []uint64, hashGenerator func() hash } claim := singleMultilinClaim{g: poly.Clone()} + t := transcript{h: hashGenerator()} - proof, err := sumcheckProve(&claim, fiatshamir.WithHash(hashGenerator())) - if err != nil { - return err - } + proof := sumcheckProve(&claim, &t) var sb strings.Builder for _, p := range proof.partialSumPolys { @@ -39,13 +36,15 @@ func testSumcheckSingleClaimMultilin(polyInt []uint64, hashGenerator func() hash } lazyClaim := singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} - if err = sumcheckVerify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())); err != nil { + t = transcript{h: hashGenerator()} + if err := sumcheckVerify(lazyClaim, proof, lazyClaim.claimedSum, 1, &t); err != nil { return err } proof.partialSumPolys[0][0].Add(&proof.partialSumPolys[0][0], toElement(1)) lazyClaim = singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} - if sumcheckVerify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())) == nil { + t = transcript{h: hashGenerator()} + if sumcheckVerify(lazyClaim, proof, lazyClaim.claimedSum, 1, &t) == nil { return fmt.Errorf("bad proof accepted") } return nil @@ -82,4 +81,4 @@ func TestSumcheckDeterministicHashSingleClaimMultilin(t *testing.T) { {{ if not .GenerateTestVectors }} {{ template "sumcheckTestDefs" .}} -{{ end }} \ No newline at end of file +{{ end }} diff --git a/internal/generator/backend/template/gkr/sumcheck.test.vectors.gen.go.tmpl b/internal/generator/backend/template/gkr/sumcheck.test.vectors.gen.go.tmpl index d47ce5bc3f..de9055e2f4 100644 --- a/internal/generator/backend/template/gkr/sumcheck.test.vectors.gen.go.tmpl +++ b/internal/generator/backend/template/gkr/sumcheck.test.vectors.gen.go.tmpl @@ -1,7 +1,6 @@ import ( "encoding/json" "fmt" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" "github.com/consensys/gnark/internal/small_rational" "github.com/consensys/gnark/internal/small_rational/polynomial" "github.com/consensys/gnark/internal/gkr/gkrtesting" @@ -30,11 +29,9 @@ func runMultilin(testCaseInfo *sumcheckTestCaseInfo) error { return err } - proof, err := sumcheckProve( - &singleMultilinClaim{poly}, fiatshamir.WithHash(hsh)) - if err != nil { - return err - } + claim := singleMultilinClaim{poly} + t := transcript{h: hsh} + proof := sumcheckProve(&claim, &t) testCaseInfo.Proof = sumcheckToPrintableProof(proof) // Verification @@ -48,12 +45,20 @@ func runMultilin(testCaseInfo *sumcheckTestCaseInfo) error { return err } - if err = sumcheckVerify(singleMultilinLazyClaim{g: poly, claimedSum: claimedSum}, proof, fiatshamir.WithHash(hsh)); err != nil { + if hsh, err = hashFromDescription(testCaseInfo.Hash); err != nil { + return err + } + t = transcript{h: hsh} + if err = sumcheckVerify(singleMultilinLazyClaim{g: poly, claimedSum: claimedSum}, proof, claimedSum, 1, &t); err != nil { return fmt.Errorf("proof rejected: %v", err) } proof.partialSumPolys[0][0].Add(&proof.partialSumPolys[0][0], toElement(1)) - if err = sumcheckVerify(singleMultilinLazyClaim{g: poly, claimedSum: claimedSum}, proof, fiatshamir.WithHash(hsh)); err == nil { + if hsh, err = hashFromDescription(testCaseInfo.Hash); err != nil { + return err + } + t = transcript{h: hsh} + if err = sumcheckVerify(singleMultilinLazyClaim{g: poly, claimedSum: claimedSum}, proof, claimedSum, 1, &t); err == nil { return fmt.Errorf("bad proof accepted") } @@ -138,4 +143,4 @@ func sumcheckToPrintableProof(proof sumcheckProof) (printable SumcheckPrintableP return } -{{ template "sumcheckTestDefs" .}} \ No newline at end of file +{{ template "sumcheckTestDefs" .}} diff --git a/internal/gkr/bls12-377/blueprint.go b/internal/gkr/bls12-377/blueprint.go index 67b337ea1e..f5d004f0a1 100644 --- a/internal/gkr/bls12-377/blueprint.go +++ b/internal/gkr/bls12-377/blueprint.go @@ -15,10 +15,9 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/polynomial" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" "github.com/consensys/gnark-crypto/hash" "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/internal/gkr/gkrtypes" + "github.com/consensys/gnark/internal/gkr/gkrcore" ) func init() { @@ -34,7 +33,7 @@ type circuitEvaluator struct { // BlueprintSolve is a BLS12_377-specific blueprint for solving GKR circuit instances. type BlueprintSolve struct { // Circuit structure (serialized) - Circuit gkrtypes.SerializableCircuit + Circuit gkrcore.SerializableCircuit NbInstances uint32 // Not serialized - recreated lazily at solve time @@ -204,6 +203,7 @@ func (b *BlueprintSolve) UpdateInstructionTree(inst constraint.Instruction, tree type BlueprintProve struct { SolveBlueprintID constraint.BlueprintID SolveBlueprint *BlueprintSolve `cbor:"-"` // not serialized, set at compile time + Schedule constraint.GkrProvingSchedule HashName string lock sync.Mutex @@ -256,9 +256,11 @@ func (b *BlueprintProve) Solve(s constraint.Solver[constraint.U64], inst constra } } + // Create hasher and write base challenges + hsh := hash.NewHash(b.HashName + "_BLS12_377") + // Read initial challenges from instruction calldata (parse dynamically, no metadata) // Format: [0]=totalSize, [1...]=challenge linear expressions - insBytes := make([][]byte, 0) // first challenges calldata := inst.Calldata[1:] // skip size prefix for len(calldata) != 0 { val, delta := s.Read(calldata) @@ -267,17 +269,14 @@ func (b *BlueprintProve) Solve(s constraint.Solver[constraint.U64], inst constra // Copy directly from constraint.U64 to fr.Element (both in Montgomery form) var challenge fr.Element copy(challenge[:], val[:]) - insBytes = append(insBytes, challenge.Marshal()) + challengeBytes := challenge.Bytes() + hsh.Write(challengeBytes[:]) } - // Create Fiat-Shamir settings - hsh := hash.NewHash(b.HashName + "_BLS12_377") - fsSettings := fiatshamir.WithHash(hsh, insBytes...) - // Call the BLS12_377-specific Prove function (assignments already WireAssignment type) - proof, err := Prove(solveBlueprint.Circuit, assignments, fsSettings) + proof, err := Prove(solveBlueprint.Circuit, b.Schedule, assignments, hsh) if err != nil { - return fmt.Errorf("bls12_377 prove failed: %w", err) + return fmt.Errorf("BLS12_377 prove failed: %w", err) } for i, elem := range proof.flatten() { @@ -305,7 +304,7 @@ func (b *BlueprintProve) proofSize() int { } nbPaddedInstances := ecc.NextPowerOfTwo(uint64(b.SolveBlueprint.NbInstances)) logNbInstances := bits.TrailingZeros64(nbPaddedInstances) - return b.SolveBlueprint.Circuit.ProofSize(logNbInstances) + return b.SolveBlueprint.Circuit.ProofSize(b.Schedule, logNbInstances) } // NbOutputs implements Blueprint @@ -434,7 +433,7 @@ func (b *BlueprintGetAssignment) UpdateInstructionTree(inst constraint.Instructi } // NewBlueprints creates and registers all GKR blueprints for BLS12_377 -func NewBlueprints(circuit gkrtypes.SerializableCircuit, hashName string, compiler constraint.CustomizableSystem) gkrtypes.Blueprints { +func NewBlueprints(circuit gkrcore.SerializableCircuit, schedule constraint.GkrProvingSchedule, hashName string, compiler constraint.CustomizableSystem) gkrcore.Blueprints { // Create and register solve blueprint solve := &BlueprintSolve{Circuit: circuit} solveID := compiler.AddBlueprint(solve) @@ -443,6 +442,7 @@ func NewBlueprints(circuit gkrtypes.SerializableCircuit, hashName string, compil prove := &BlueprintProve{ SolveBlueprintID: solveID, SolveBlueprint: solve, + Schedule: schedule, HashName: hashName, } proveID := compiler.AddBlueprint(prove) @@ -453,7 +453,7 @@ func NewBlueprints(circuit gkrtypes.SerializableCircuit, hashName string, compil } getAssignmentID := compiler.AddBlueprint(getAssignment) - return gkrtypes.Blueprints{ + return gkrcore.Blueprints{ SolveID: solveID, Solve: solve, ProveID: proveID, diff --git a/internal/gkr/bls12-377/gkr.go b/internal/gkr/bls12-377/gkr.go index e9cc5a10c9..07dd53baca 100644 --- a/internal/gkr/bls12-377/gkr.go +++ b/internal/gkr/bls12-377/gkr.go @@ -8,655 +8,557 @@ package gkr import ( "errors" "fmt" + "hash" "iter" - "strconv" "sync" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/polynomial" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" "github.com/consensys/gnark-crypto/utils" - "github.com/consensys/gnark/internal/gkr/gkrtypes" + "github.com/consensys/gnark/constraint" + "github.com/consensys/gnark/internal/gkr/gkrcore" ) // Type aliases for bytecode-based GKR types type ( - Wire = gkrtypes.SerializableWire - Circuit = gkrtypes.SerializableCircuit + Wire = gkrcore.SerializableWire + Circuit = gkrcore.SerializableCircuit ) // The goal is to prove/verify evaluations of many instances of the same circuit -// WireAssignment is assignment of values to the same wire across many instances of the circuit +// WireAssignment is the assignment of values to the same wire across many instances of the circuit type WireAssignment []polynomial.MultiLin type Proof []sumcheckProof // for each layer, for each wire, a sumcheck (for each variable, a polynomial) // zeroCheckLazyClaims is a lazy claim for sumcheck (verifier side). -// It checks that the polynomial ∑ᵢ cⁱ eq(-, xᵢ) w(-) sums up to the expected multilinear -// extension of the values of w across all instances. -// Its purpose is to batch the checking of multiple evaluations of the same wire. +// It checks that the polynomial ∑ᵢ cⁱ eq(-, xᵢ) wᵢ(-) sums to the expected value, +// where the sum runs over all wᵢ and evaluation point xᵢ in the level. +// Its purpose is to batch the checking of multiple wire evaluations at evaluation points. type zeroCheckLazyClaims struct { - wireI int // the wire for which we are making the claim, with value w - evaluationPoints [][]fr.Element // xᵢ: the points at which the prover has made claims about the evaluation of w - claimedEvaluations []fr.Element // yᵢ = w(xᵢ), allegedly - manager *claimsManager // WARNING: Circular references -} - -func (e *zeroCheckLazyClaims) getWire() Wire { - return e.manager.circuit[e.wireI] -} - -func (e *zeroCheckLazyClaims) claimsNum() int { - return len(e.evaluationPoints) + foldingCoeff fr.Element // the coefficient used to fold claims, conventionally 0 if there is only one claim + resources *resources + levelI int } func (e *zeroCheckLazyClaims) varsNum() int { - return len(e.evaluationPoints[0]) -} - -// foldedSum returns ∑ᵢ aⁱ yᵢ -func (e *zeroCheckLazyClaims) foldedSum(a fr.Element) fr.Element { - evalsAsPoly := polynomial.Polynomial(e.claimedEvaluations) - return evalsAsPoly.Eval(&a) + return e.resources.nbVars } func (e *zeroCheckLazyClaims) degree(int) int { - return e.manager.circuit[e.wireI].ZeroCheckDegree() -} - -// verifyFinalEval finalizes the verification of w. -// The prover's claims w(xᵢ) = yᵢ have already been reduced to verifying -// ∑ cⁱ eq(xᵢ, r) w(r) = purportedValue. (c is foldingCoeff) -// Both purportedValue and the vector r have been randomized during the sumcheck protocol. -// By taking the w term out of the sum we get the equivalent claim that -// for E := ∑ eq(xᵢ, r), it must be that E w(r) = purportedValue. -// If w is an input wire, the verifier can directly check its evaluation at r. -// Otherwise, the prover makes claims about the evaluation of w's input wires, -// wᵢ, at r, to be verified later. -// The claims are communicated through the proof parameter. -// The verifier checks here if the claimed evaluations of wᵢ(r) are consistent with -// the main claim, by checking E w(wᵢ(r)...) = purportedValue. -func (e *zeroCheckLazyClaims) verifyFinalEval(r []fr.Element, foldingCoeff, purportedValue fr.Element, uniqueInputEvaluations []fr.Element) error { - // the eq terms ( E ) - numClaims := len(e.evaluationPoints) - evaluation := polynomial.EvalEq(e.evaluationPoints[numClaims-1], r) - for i := numClaims - 2; i >= 0; i-- { - evaluation.Mul(&evaluation, &foldingCoeff) - eq := polynomial.EvalEq(e.evaluationPoints[i], r) - evaluation.Add(&evaluation, &eq) - } - - wire := e.manager.circuit[e.wireI] - - // the w(...) term - var gateEvaluation fr.Element - if wire.IsInput() { // just compute w(r) - gateEvaluation = e.manager.assignment[e.wireI].Evaluate(r, e.manager.memPool) - } else { // proof contains the evaluations of the inputs, but avoids repetition in case multiple inputs come from the same wire - injection, injectionLeftInv := - e.manager.circuit.ClaimPropagationInfo(e.wireI) - - if len(injection) != len(uniqueInputEvaluations) { - return fmt.Errorf("%d input wire evaluations given, %d expected", len(uniqueInputEvaluations), len(injection)) - } - - for uniqueI, i := range injection { // map from unique to all - e.manager.add(wire.Inputs[i], r, uniqueInputEvaluations[uniqueI]) - } + return e.resources.circuit.ZeroCheckDegree(e.resources.schedule[e.levelI].(constraint.GkrSumcheckLevel)) +} + +// verifyFinalEval finalizes the verification of a level at the sumcheck evaluation point r. +// The sumcheck protocol has already reduced the per-wire claims w(xᵢ) = yᵢ to verifying +// ∑ᵢ cⁱ eq(xᵢ, r) · wᵢ(r) = purportedValue, where the sum runs over all +// claims on each wire and c is foldingCoeff. +// Both purportedValue and the vector r have been randomized during sumcheck. +// +// For input wires, w(r) is computed directly from the assignment and the claimed +// evaluation in uniqueInputEvaluations is checked equal to it. +// For non-input wires, the prover claims evaluations of their gate inputs at r via +// uniqueInputEvaluations; those claims are verified by lower levels' sumchecks. +// The verifier checks consistency by evaluating gateᵥ(inputEvals...) and confirming +// that the full sum matches purportedValue. +func (e *zeroCheckLazyClaims) verifyFinalEval(r []fr.Element, purportedValue fr.Element, uniqueInputEvaluations []fr.Element) error { + e.resources.outgoingEvalPoints[e.levelI] = [][]fr.Element{r} + level := e.resources.schedule[e.levelI] + gateInputEvals := gkrcore.ReduplicateInputs(level, e.resources.circuit, uniqueInputEvaluations) + + var claimedEvals polynomial.Polynomial + levelWireI := 0 + for _, group := range level.ClaimGroups() { + for _, wI := range group.Wires { + wire := e.resources.circuit[wI] + + var gateEval fr.Element + if wire.IsInput() { + gateEval = e.resources.assignment[wI].Evaluate(r, &e.resources.memPool) + if !gateInputEvals[levelWireI][0].Equal(&gateEval) { + return errors.New("incompatible evaluations") + } + } else { + evaluator := newGateEvaluator(wire.Gate.Evaluate, len(wire.Inputs)) + for _, v := range gateInputEvals[levelWireI] { + evaluator.pushInput(v) + } + gateEval.Set(evaluator.evaluate()) + } - evaluator := newGateEvaluator(wire.Gate.Evaluate, len(wire.Inputs)) - for _, uniqueI := range injectionLeftInv { // map from all to unique - evaluator.pushInput(uniqueInputEvaluations[uniqueI]) + for _, src := range group.ClaimSources { + eq := polynomial.EvalEq(e.resources.outgoingEvalPoints[src.Level][src.OutgoingClaimIndex], r) + var term fr.Element + term.Mul(&eq, &gateEval) + claimedEvals = append(claimedEvals, term) + } + levelWireI++ } - - gateEvaluation.Set(evaluator.evaluate()) } - evaluation.Mul(&evaluation, &gateEvaluation) - - if evaluation.Equal(&purportedValue) { - return nil + if total := claimedEvals.Eval(&e.foldingCoeff); !total.Equal(&purportedValue) { + return errors.New("incompatible evaluations") } - return errors.New("incompatible evaluations") + return nil } // zeroCheckClaims is a claim for sumcheck (prover side). -// It checks that the polynomial ∑ᵢ cⁱ eq(-, xᵢ) w(-) sums up to the expected multilinear -// extension of the values of w across all instances. -// Its purpose is to batch the proving of multiple evaluations of the same wire. +// It checks that the polynomial ∑ᵢ cⁱ eq(-, xᵢ) wᵢ(-) sums to the expected value, +// where the sum runs over all (wire v, claim source s) pairs in the level. +// Each wire has its own eq table with the batching coefficients baked in. type zeroCheckClaims struct { - wireI int // the wire for which we are making the claim, with value w - evaluationPoints [][]fr.Element // xᵢ: the points at which the prover has made claims about the evaluation of w - claimedEvaluations []fr.Element // yᵢ = w(xᵢ) - manager *claimsManager - - input []polynomial.MultiLin // input[i](h₁, ..., hₘ₋ⱼ) = wᵢ(r₁, r₂, ..., rⱼ₋₁, h₁, ..., hₘ₋ⱼ) - - eq polynomial.MultiLin // E := ∑ᵢ cⁱ eq(xᵢ, -) - - gateEvaluatorPool *gateEvaluatorPool -} - -func (c *zeroCheckClaims) getWire() Wire { - return c.manager.circuit[c.wireI] -} - -// fold the multiple claims into one claim using a random combination (foldingCoeff or c). -// From the original multiple claims of w(xᵢ) = yᵢ, we get a single claim -// ∑ᵢ,ₕ cⁱ eq(xᵢ, h) w(h) = ∑ᵢ cⁱ yᵢ, where h iterates over the hypercube (circuit instances) and -// i iterates over the claims. -// Equivalently, we could say ∑ᵢ cⁱ yᵢ = ∑ₕ,ᵢ cⁱ eq(xᵢ, h) w(h) = ∑ₕ w(h) ∑ᵢ cⁱ eq(xᵢ, h). -// Thus if we initially compute E := ∑ᵢ cⁱ eq(xᵢ, -), our claim will find the simpler form -// ∑ᵢ cⁱ yᵢ = ∑ₕ w(h) E(h), where the sum-checked polynomial is of degree deg(g) + 1, -// and deg(g) is the total degree of the polynomial defining the gate g of which w is the output. -// The output of fold is the first sumcheck claim, i.e. ∑₍ₕ₁,ₕ₂,...₎ w(X, h₁, h₂, ...) E(X, h₁, h₂, ...).. -func (c *zeroCheckClaims) fold(foldingCoeff fr.Element) polynomial.Polynomial { - varsNum := c.varsNum() - eqLength := 1 << varsNum - claimsNum := c.claimsNum() - // initialize the eq tables ( E ) - c.eq = c.manager.memPool.Make(eqLength) - - c.eq[0].SetOne() - c.eq.Eq(c.evaluationPoints[0]) - - // E := eq(x₀, -) - newEq := polynomial.MultiLin(c.manager.memPool.Make(eqLength)) - aI := foldingCoeff - - // E += cⁱ eq(xᵢ, -) - for k := 1; k < claimsNum; k++ { - newEq[0].Set(&aI) - - c.eqAcc(c.eq, newEq, c.evaluationPoints[k]) - - if k+1 < claimsNum { - aI.Mul(&aI, &foldingCoeff) - } - } - - c.manager.memPool.Dump(newEq) - - return c.computeGJ() -} - -// eqAcc sets m to an eq table at q and then adds it to e. -// m <- eq(q, -). -// e <- e + m -func (c *zeroCheckClaims) eqAcc(e, m polynomial.MultiLin, q []fr.Element) { - n := len(q) - - //At the end of each iteration, m(h₁, ..., hₙ) = eq(q₁, ..., qᵢ₊₁, h₁, ..., hᵢ₊₁) - for i := range q { // In the comments we use a 1-based index so q[i] = qᵢ₊₁ - // go through all assignments of (b₁, ..., bᵢ) ∈ {0,1}ⁱ - const threshold = 1 << 6 - k := 1 << i - if k < threshold { - for j := 0; j < k; j++ { - j0 := j << (n - i) // bᵢ₊₁ = 0 - j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 - - m[j1].Mul(&q[i], &m[j0]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ - m[j0].Sub(&m[j0], &m[j1]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) - } - } else { - c.manager.workers.Submit(k, func(start, end int) { - for j := start; j < end; j++ { - j0 := j << (n - i) // bᵢ₊₁ = 0 - j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 - - m[j1].Mul(&q[i], &m[j0]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ - m[j0].Sub(&m[j0], &m[j1]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) - } - }, 1024).Wait() - } - - } - c.manager.workers.Submit(len(e), func(start, end int) { - for i := start; i < end; i++ { - e[i].Add(&e[i], &m[i]) - } - }, 512).Wait() + levelI int + resources *resources + input []polynomial.MultiLin // UniqueGateInputs order + inputIndices [][]int // [wireInLevel][gateInputJ] → index in input + eqs []polynomial.MultiLin // per-wire interpolation bases for evaluating wire assignments at challenge points + gateEvaluatorPools []*gateEvaluatorPool } -// computeGJ: gⱼ = ∑_{0≤h<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, h...) = ∑_{0≤i<2ⁿ⁻ʲ} E(r₁, ..., Xⱼ, h...) g( w₀(r₁, ..., Xⱼ, h...), ... ). -// the polynomial is represented by the evaluations gⱼ(1), gⱼ(2), ..., gⱼ(deg(gⱼ)). -// The value gⱼ(0) is inferred from the equation gⱼ(0) + gⱼ(1) = gⱼ₋₁(rⱼ₋₁). By convention, g₀ is a constant polynomial equal to the claimed sum. -func (c *zeroCheckClaims) computeGJ() polynomial.Polynomial { - - wire := c.getWire() - degGJ := wire.ZeroCheckDegree() // guaranteed to be no smaller than the actual deg(gⱼ) - nbGateIn := len(c.input) - - // Both E and wᵢ (the input wires and the eq table) are multilinear, thus - // they are linear in Xⱼ. - // So for f ∈ { E(r₁, ..., Xⱼ, h...) } ∪ {wᵢ(r₁, ..., Xⱼ, h...) }, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables. - // ml are such multilinear polynomials the evaluations of which over different values of Xⱼ are computed in this stepwise manner. - ml := make([]polynomial.MultiLin, nbGateIn+1) // shortcut to the evaluations of the multilinear polynomials over the hypercube - ml[0] = c.eq - copy(ml[1:], c.input) - - sumSize := len(c.eq) / 2 // the range of h, over which we sum - - // Perf-TODO: Collate once at claim "folding" time and not again. then, even folding can be done in one operation every time "next" is called - - gJ := make([]fr.Element, degGJ) +func (c *zeroCheckClaims) varsNum() int { + return c.resources.nbVars +} + +// roundPolynomial computes gⱼ = ∑ₕ ∑ᵥ eqs[v](Xⱼ, h...) · gateᵥ(inputs(Xⱼ, h...)). +// The polynomial is represented by the evaluations gⱼ(1), gⱼ(2), ..., gⱼ(deg(gⱼ)). +// The value gⱼ(0) is inferred from the equation gⱼ(0) + gⱼ(1) = gⱼ₋₁(rⱼ₋₁). +// By convention, g₀ is a constant polynomial equal to the claimed sum. +func (c *zeroCheckClaims) roundPolynomial() polynomial.Polynomial { + level := c.resources.schedule[c.levelI].(constraint.GkrSumcheckLevel) + degree := c.resources.circuit.ZeroCheckDegree(level) + nbUniqueInputs := len(c.input) + nbWires := len(c.eqs) + + // Both eqs and input are multilinear, thus linear in Xⱼ. + // For any such f, f(m) = m·(f(1) - f(0)) + f(0), and f(0), f(1) are read directly + // from the bookkeeping tables. This allows stepwise evaluation at Xⱼ = 1, 2, ..., degree. + // Layout: [eq₀, eq₁, ..., eq_{nbWires-1}, input₀, input₁, ..., input_{nbUniqueInputs-1}] + ml := make([]polynomial.MultiLin, nbWires+nbUniqueInputs) + copy(ml, c.eqs) + copy(ml[nbWires:], c.input) + + sumSize := len(c.eqs[0]) / 2 + + p := make([]fr.Element, degree) var mu sync.Mutex - computeAll := func(start, end int) { // compute method to allow parallelization across instances + computeAll := func(start, end int) { var step fr.Element - evaluator := c.gateEvaluatorPool.get() - defer c.gateEvaluatorPool.put(evaluator) + evaluators := make([]*gateEvaluator, nbWires) + for w := range nbWires { + evaluators[w] = c.gateEvaluatorPools[w].get() + } + defer func() { + for w := range nbWires { + c.gateEvaluatorPools[w].put(evaluators[w]) + } + }() - res := make([]fr.Element, degGJ) + res := make([]fr.Element, degree) // evaluations of ml, laid out as: // ml[0](1, h...), ml[1](1, h...), ..., ml[len(ml)-1](1, h...), // ml[0](2, h...), ml[1](2, h...), ..., ml[len(ml)-1](2, h...), // ... - // ml[0](degGJ, h...), ml[2](degGJ, h...), ..., ml[len(ml)-1](degGJ, h...) - mlEvals := make([]fr.Element, degGJ*len(ml)) - - for h := start; h < end; h++ { // h counts across instances + // ml[0](degree, h...), ml[1](degree, h...), ..., ml[len(ml)-1](degree, h...) + mlEvals := make([]fr.Element, degree*len(ml)) + for h := start; h < end; h++ { evalAt1Index := sumSize + h for k := range ml { - // d = 0 - mlEvals[k].Set(&ml[k][evalAt1Index]) // evaluation at Xⱼ = 1. Can be taken directly from the table. + mlEvals[k].Set(&ml[k][evalAt1Index]) // evaluation at Xⱼ = 1, taken directly from the table step.Sub(&mlEvals[k], &ml[k][h]) // step = ml[k](1) - ml[k](0) - for d := 1; d < degGJ; d++ { + for d := 1; d < degree; d++ { mlEvals[d*len(ml)+k].Add(&mlEvals[(d-1)*len(ml)+k], &step) } } - eIndex := 0 // index for where the current eq term is + eIndex := 0 // start of the current row's eq evaluations nextEIndex := len(ml) - for d := range degGJ { - // Push gate inputs - for i := range nbGateIn { - evaluator.pushInput(mlEvals[eIndex+1+i]) + for d := range degree { + for w := range nbWires { + for _, inputI := range c.inputIndices[w] { + evaluators[w].pushInput(mlEvals[eIndex+nbWires+inputI]) + } + summand := evaluators[w].evaluate() + summand.Mul(summand, &mlEvals[eIndex+w]) + res[d].Add(&res[d], summand) // collect contributions into the sum from start to end } - summand := evaluator.evaluate() - summand.Mul(summand, &mlEvals[eIndex]) - res[d].Add(&res[d], summand) // collect contributions into the sum from start to end eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) } } mu.Lock() - for i := range gJ { - gJ[i].Add(&gJ[i], &res[i]) // collect into the complete sum + for i := range p { + p[i].Add(&p[i], &res[i]) // collect into the complete sum } mu.Unlock() } const minBlockSize = 64 - if sumSize < minBlockSize { - // no parallelization computeAll(0, sumSize) } else { - c.manager.workers.Submit(sumSize, computeAll, minBlockSize).Wait() + c.resources.workers.Submit(sumSize, computeAll, minBlockSize).Wait() } - return gJ + return p } -// next first folds the input and E polynomials at the given verifier challenge then computes the new gⱼ. -// Thus, j <- j+1 and rⱼ = challenge. -func (c *zeroCheckClaims) next(challenge fr.Element) polynomial.Polynomial { +// roundFold folds all input and eq polynomials at the verifier challenge r. +// After this call, j ← j+1 and rⱼ = r. +func (c *zeroCheckClaims) roundFold(r fr.Element) { const minBlockSize = 512 - n := len(c.eq) / 2 + n := len(c.eqs[0]) / 2 if n < minBlockSize { - // no parallelization for i := range c.input { - c.input[i].Fold(challenge) + c.input[i].Fold(r) + } + for i := range c.eqs { + c.eqs[i].Fold(r) } - c.eq.Fold(challenge) } else { - wgs := make([]*sync.WaitGroup, len(c.input)) + wgs := make([]*sync.WaitGroup, len(c.input)+len(c.eqs)) for i := range c.input { - wgs[i] = c.manager.workers.Submit(n, c.input[i].FoldParallel(challenge), minBlockSize) + wgs[i] = c.resources.workers.Submit(n, c.input[i].FoldParallel(r), minBlockSize) + } + for i := range c.eqs { + wgs[len(c.input)+i] = c.resources.workers.Submit(n, c.eqs[i].FoldParallel(r), minBlockSize) } - c.manager.workers.Submit(n, c.eq.FoldParallel(challenge), minBlockSize).Wait() for _, wg := range wgs { wg.Wait() } } - - return c.computeGJ() -} - -func (c *zeroCheckClaims) varsNum() int { - return len(c.evaluationPoints[0]) } -func (c *zeroCheckClaims) claimsNum() int { - return len(c.claimedEvaluations) -} - -// proveFinalEval provides the values wᵢ(r₁, ..., rₙ) +// proveFinalEval provides the unique input wire values wᵢ(r₁, ..., rₙ). func (c *zeroCheckClaims) proveFinalEval(r []fr.Element) []fr.Element { - //defer the proof, return list of claims - - injection, _ := c.manager.circuit.ClaimPropagationInfo(c.wireI) // TODO @Tabaie: Instead of doing this last, we could just have fewer input in the first place; not that likely to happen with single gates, but more so with layers. - evaluations := make([]fr.Element, len(injection)) - for i, gateInputI := range injection { - wI := c.input[gateInputI] - wI.Fold(r[len(r)-1]) // We already have wᵢ(r₁, ..., rₙ₋₁, hₙ) in a table. Only one more fold required. - c.manager.add(c.getWire().Inputs[gateInputI], r, wI[0]) - evaluations[i] = wI[0] + c.resources.outgoingEvalPoints[c.levelI] = [][]fr.Element{r} + evaluations := make([]fr.Element, len(c.input)) + for i := range c.input { + c.input[i].Fold(r[len(r)-1]) + evaluations[i] = c.input[i][0] + } + for i := range c.input { + c.resources.memPool.Dump(c.input[i]) + } + for i := range c.eqs { + c.resources.memPool.Dump(c.eqs[i]) + } + for _, pool := range c.gateEvaluatorPools { + pool.dumpAll() } - - c.manager.memPool.Dump(c.claimedEvaluations, c.eq) - c.gateEvaluatorPool.dumpAll() - return evaluations } -type claimsManager struct { - claims []*zeroCheckLazyClaims - assignment WireAssignment - memPool *polynomial.Pool - workers *utils.WorkerPool - circuit Circuit -} +// eqAcc sets m to an eq table at q and then adds it to e. +// m <- m[0] · eq(q, -). +// e <- e + m +func (r *resources) eqAcc(e, m polynomial.MultiLin, q []fr.Element) { + n := len(q) -func newClaimsManager(circuit Circuit, assignment WireAssignment, o settings) (manager claimsManager) { - manager.assignment = assignment - manager.claims = make([]*zeroCheckLazyClaims, len(circuit)) - manager.memPool = o.pool - manager.workers = o.workers - manager.circuit = circuit + // At the end of each iteration, m(h₁, ..., hₙ) = m[0] · eq(q₁, ..., qᵢ₊₁, h₁, ..., hᵢ₊₁) + for i := range q { // 1-based in comments: q[i] = qᵢ₊₁ + // go through all assignments of (b₁, ..., bᵢ) ∈ {0,1}ⁱ + const threshold = 1 << 6 + k := 1 << i + if k < threshold { + for j := 0; j < k; j++ { + j0 := j << (n - i) // bᵢ₊₁ = 0 + j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 - for i := range circuit { - manager.claims[i] = &zeroCheckLazyClaims{ - wireI: i, - evaluationPoints: make([][]fr.Element, 0, circuit[i].NbClaims()), - claimedEvaluations: manager.memPool.Make(circuit[i].NbClaims()), - manager: &manager, + m[j1].Mul(&q[i], &m[j0]) // m(b₁,...,bᵢ,1) = m(b₁,...,bᵢ) · qᵢ₊₁ + m[j0].Sub(&m[j0], &m[j1]) // m(b₁,...,bᵢ,0) = m(b₁,...,bᵢ) · (1 - qᵢ₊₁) + } + } else { + r.workers.Submit(k, func(start, end int) { + for j := start; j < end; j++ { + j0 := j << (n - i) // bᵢ₊₁ = 0 + j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 + + m[j1].Mul(&q[i], &m[j0]) // m(b₁,...,bᵢ,1) = m(b₁,...,bᵢ) · qᵢ₊₁ + m[j0].Sub(&m[j0], &m[j1]) // m(b₁,...,bᵢ,0) = m(b₁,...,bᵢ) · (1 - qᵢ₊₁) + } + }, 1024).Wait() } } - return -} - -func (m *claimsManager) add(wire int, evaluationPoint []fr.Element, evaluation fr.Element) { - claim := m.claims[wire] - i := len(claim.evaluationPoints) - claim.claimedEvaluations[i] = evaluation - claim.evaluationPoints = append(claim.evaluationPoints, evaluationPoint) + r.workers.Submit(len(e), func(start, end int) { + for i := start; i < end; i++ { + e[i].Add(&e[i], &m[i]) + } + }, 512).Wait() } -func (m *claimsManager) getLazyClaim(wire int) *zeroCheckLazyClaims { - return m.claims[wire] +type resources struct { + // outgoingEvalPoints[i][k] is the k-th outgoing evaluation point (evaluation challenge) produced at schedule level i. + // outgoingEvalPoints[len(schedule)][0] holds the initial challenge (firstChallenge / rho). + // SumcheckLevels produce one point (k=0). SkipLevels pass on all their evaluation points. + outgoingEvalPoints [][][]fr.Element + nbVars int + assignment WireAssignment + memPool polynomial.Pool + workers *utils.WorkerPool + circuit Circuit + schedule constraint.GkrProvingSchedule + transcript transcript + uniqueInputIndices [][]int // uniqueInputIndices[wI][claimI]: w's unique-input index in the layer its claimI-th evaluation is coming from } -func (m *claimsManager) getClaim(wireI int) *zeroCheckClaims { - lazy := m.claims[wireI] - wire := m.circuit[wireI] - res := &zeroCheckClaims{ - wireI: wireI, - evaluationPoints: lazy.evaluationPoints, - claimedEvaluations: lazy.claimedEvaluations, - manager: m, - } - - if wire.IsInput() { - res.input = []polynomial.MultiLin{m.memPool.Clone(m.assignment[wireI])} - } else { - res.input = make([]polynomial.MultiLin, len(wire.Inputs)) - - for inputI, inputW := range wire.Inputs { - res.input[inputI] = m.memPool.Clone(m.assignment[inputW]) //will be edited later, so must be deep copied +func newResources(c Circuit, schedule constraint.GkrProvingSchedule, assignment WireAssignment, hasher hash.Hash) (resources, error) { + nbVars := assignment.NumVars() + nbInstances := assignment.NumInstances() + if 1<= 2 { + foldingCoeff = r.transcript.getChallenge() } -} -func setup(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (settings, error) { - var o settings - var err error - for _, option := range options { - option(&o) + uniqueInputs, inputIndices := r.circuit.InputMapping(level) + input := make([]polynomial.MultiLin, len(uniqueInputs)) + for i, inW := range uniqueInputs { + input[i] = r.memPool.Clone(r.assignment[inW]) } - o.nbVars = assignment.NumVars() - nbInstances := assignment.NumInstances() - if 1< 1 { + newEq := polynomial.MultiLin(r.memPool.Make(eqLength)) + aI := alpha + for k := 1; k < nbSources; k++ { + aI.Mul(&aI, &foldingCoeff) + newEq[0].Set(&aI) + r.eqAcc(groupEq, newEq, r.outgoingEvalPoints[group.ClaimSources[k].Level][group.ClaimSources[k].OutgoingClaimIndex]) + } + r.memPool.Dump(newEq) + } -func ChallengeNames(c Circuit, logNbInstances int, prefix string) []string { + var stride fr.Element + stride.Set(&foldingCoeff) + for range nbSources - 1 { + stride.Mul(&stride, &foldingCoeff) + } - // Pre-compute the size TODO: Consider not doing this and just grow the list by appending - size := logNbInstances // first challenge + eqs[levelWireI] = groupEq + levelWireI++ + alpha.Mul(&alpha, &stride) - for i := range c { - if c[i].NoProof() { // no proof, no challenge - continue - } - if c[i].NbClaims() > 1 { //fold the claims - size++ + for w := 1; w < len(group.Wires); w++ { + eqs[levelWireI] = polynomial.MultiLin(r.memPool.Make(eqLength)) + r.workers.Submit(eqLength, func(start, end int) { + for i := start; i < end; i++ { + eqs[levelWireI][i].Mul(&eqs[levelWireI-1][i], &stride) + } + }, 512).Wait() + levelWireI++ + alpha.Mul(&alpha, &stride) } - size += logNbInstances // full run of sumcheck on logNbInstances variables } - nums := make([]string, max(len(c), logNbInstances)) - for i := range nums { - nums[i] = strconv.Itoa(i) + claims := &zeroCheckClaims{ + levelI: levelI, + resources: r, + input: input, + inputIndices: inputIndices, + eqs: eqs, + gateEvaluatorPools: pools, } + return sumcheckProve(claims, &r.transcript) +} - challenges := make([]string, size) - - // output wire claims - firstChallengePrefix := prefix + "fC." - for j := 0; j < logNbInstances; j++ { - challenges[j] = firstChallengePrefix + nums[j] +func (r *resources) verifySumcheckLevel(levelI int, proof Proof) error { + level := r.schedule[levelI] + nbClaims := level.NbClaims() + var foldingCoeff fr.Element + if nbClaims >= 2 { + foldingCoeff = r.transcript.getChallenge() } - j := logNbInstances - for i := len(c) - 1; i >= 0; i-- { - if c[i].NoProof() { - continue - } - wirePrefix := prefix + "w" + nums[i] + "." - if c[i].NbClaims() > 1 { - challenges[j] = wirePrefix + "fold" - j++ - } + initialChallengeI := len(r.schedule) + claimedEvals := make(polynomial.Polynomial, 0, level.NbClaims()) - partialSumPrefix := wirePrefix + "pSP." - for k := 0; k < logNbInstances; k++ { - challenges[j] = partialSumPrefix + nums[k] - j++ + for _, group := range level.ClaimGroups() { + for _, wI := range group.Wires { + for claimI, src := range group.ClaimSources { + if src.Level == initialChallengeI { + claimedEvals = append(claimedEvals, r.assignment[wI].Evaluate(r.outgoingEvalPoints[src.Level][src.OutgoingClaimIndex], &r.memPool)) + } else { + claimedEvals = append(claimedEvals, proof[src.Level].finalEvalProof[r.schedule[src.Level].FinalEvalProofIndex(r.uniqueInputIndices[wI][claimI], src.OutgoingClaimIndex)]) + } + } } } - return challenges -} -func getFirstChallengeNames(logNbInstances int, prefix string) []string { - res := make([]string, logNbInstances) - firstChallengePrefix := prefix + "fC." - for i := 0; i < logNbInstances; i++ { - res[i] = firstChallengePrefix + strconv.Itoa(i) - } - return res -} + claimedSum := claimedEvals.Eval(&foldingCoeff) -func getChallenges(transcript *fiatshamir.Transcript, names []string) ([]fr.Element, error) { - res := make([]fr.Element, len(names)) - for i, name := range names { - if bytes, err := transcript.ComputeChallenge(name); err != nil { - return nil, err - } else if err = res[i].SetBytesCanonical(bytes); err != nil { - return nil, err - } + lazyClaims := &zeroCheckLazyClaims{ + foldingCoeff: foldingCoeff, + resources: r, + levelI: levelI, } - return res, nil + return sumcheckVerify(lazyClaims, proof[levelI], claimedSum, r.circuit.ZeroCheckDegree(level.(constraint.GkrSumcheckLevel)), &r.transcript) } // Prove consistency of the claimed assignment -func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (Proof, error) { - o, err := setup(c, assignment, transcriptSettings, options...) +func Prove(c Circuit, schedule constraint.GkrProvingSchedule, assignment WireAssignment, hasher hash.Hash) (Proof, error) { + r, err := newResources(c, schedule, assignment, hasher) if err != nil { return nil, err } - defer o.workers.Stop() + defer r.workers.Stop() - claims := newClaimsManager(c, assignment, o) + proof := make(Proof, len(schedule)) - proof := make(Proof, len(c)) - // firstChallenge called rho in the paper - var firstChallenge []fr.Element - firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) - if err != nil { - return nil, err + // Derive the initial challenge point + firstChallenge := make([]fr.Element, r.nbVars) + for j := range r.nbVars { + firstChallenge[j] = r.transcript.getChallenge() } + r.outgoingEvalPoints[len(schedule)] = [][]fr.Element{firstChallenge} - wirePrefix := o.transcriptPrefix + "w" - var baseChallenge [][]byte - for i := len(c) - 1; i >= 0; i-- { - - wire := c[i] - - if wire.IsOutput() { - claims.add(i, firstChallenge, assignment[i].Evaluate(firstChallenge, claims.memPool)) - } - - claim := claims.getClaim(i) - if wire.NoProof() { // input wires with one claim only - proof[i] = sumcheckProof{ - partialSumPolys: []polynomial.Polynomial{}, - finalEvalProof: []fr.Element{}, - } + for levelI := len(schedule) - 1; levelI >= 0; levelI-- { + if _, isSkip := r.schedule[levelI].(constraint.GkrSkipLevel); isSkip { + proof[levelI] = r.proveSkipLevel(levelI) } else { - if proof[i], err = sumcheckProve( - claim, fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), - ); err != nil { - return proof, err - } - - baseChallenge = make([][]byte, len(proof[i].finalEvalProof)) - for j := range proof[i].finalEvalProof { - baseChallenge[j] = proof[i].finalEvalProof[j].Marshal() - } + proof[levelI] = r.proveSumcheckLevel(levelI) } - // the verifier checks a single claim about input wires itself - claims.deleteClaim(i) + constraint.BindGkrFinalEvalProof(&r.transcript, proof[levelI].finalEvalProof, c.UniqueGateInputs(r.schedule[levelI]), c.IsInput, r.schedule[levelI]) } return proof, nil } -// Verify the consistency of the claimed output with the claimed input -// Unlike in Prove, the assignment argument need not be complete -func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSettings fiatshamir.Settings, options ...Option) error { - o, err := setup(c, assignment, transcriptSettings, options...) +// Verify the consistency of the claimed output with the claimed input. +// Unlike in Prove, the assignment argument need not be complete. +func Verify(c Circuit, schedule constraint.GkrProvingSchedule, assignment WireAssignment, proof Proof, hasher hash.Hash) error { + r, err := newResources(c, schedule, assignment, hasher) if err != nil { return err } - defer o.workers.Stop() + defer r.workers.Stop() - claims := newClaimsManager(c, assignment, o) - - var firstChallenge []fr.Element - firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) - if err != nil { - return err + // Derive the initial challenge point + firstChallenge := make([]fr.Element, r.nbVars) + for j := range r.nbVars { + firstChallenge[j] = r.transcript.getChallenge() } + r.outgoingEvalPoints[len(schedule)] = [][]fr.Element{firstChallenge} - wirePrefix := o.transcriptPrefix + "w" - var baseChallenge [][]byte - for i := len(c) - 1; i >= 0; i-- { - wire := c[i] - - if wire.IsOutput() { - claims.add(i, firstChallenge, assignment[i].Evaluate(firstChallenge, claims.memPool)) - } - - proofW := proof[i] - claim := claims.getLazyClaim(i) - if wire.NoProof() { // input wires with one claim only - // make sure the proof is empty - if len(proofW.finalEvalProof) != 0 || len(proofW.partialSumPolys) != 0 { - return errors.New("no proof allowed for input wire with a single claim") - } - - if wire.NbClaims() == 1 { // input wire - // simply evaluate and see if it matches - if len(claim.evaluationPoints) == 0 || len(claim.claimedEvaluations) == 0 { - return errors.New("missing input wire claim") - } - evaluation := assignment[i].Evaluate(claim.evaluationPoints[0], claims.memPool) - if !claim.claimedEvaluations[0].Equal(&evaluation) { - return errors.New("incorrect input wire claim") - } - } - } else if err = sumcheckVerify( - claim, proof[i], fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), - ); err == nil { // incorporate prover claims about w's input into the transcript - baseChallenge = make([][]byte, len(proofW.finalEvalProof)) - for j := range baseChallenge { - baseChallenge[j] = proofW.finalEvalProof[j].Marshal() - } + for levelI := len(schedule) - 1; levelI >= 0; levelI-- { + if _, isSkip := r.schedule[levelI].(constraint.GkrSkipLevel); isSkip { + err = r.verifySkipLevel(levelI, proof) } else { - return fmt.Errorf("sumcheck proof rejected: %v", err) //TODO: Any polynomials to dump? + err = r.verifySumcheckLevel(levelI, proof) + } + if err != nil { + return fmt.Errorf("level %d: %v", levelI, err) } - claims.deleteClaim(i) + constraint.BindGkrFinalEvalProof(&r.transcript, proof[levelI].finalEvalProof, c.UniqueGateInputs(r.schedule[levelI]), c.IsInput, r.schedule[levelI]) } return nil } @@ -734,14 +636,14 @@ func (p Proof) flatten() iter.Seq2[int, *fr.Element] { // It manages the stack internally and handles input buffering, making it easy to // evaluate the same gate multiple times with different inputs. type gateEvaluator struct { - gate gkrtypes.GateBytecode + gate gkrcore.GateBytecode vars []fr.Element nbIn int // number of inputs expected } // newGateEvaluator creates an evaluator for the given compiled gate. // The stack is preloaded with constants and ready for evaluation. -func newGateEvaluator(gate gkrtypes.GateBytecode, nbIn int, elementPool ...*polynomial.Pool) gateEvaluator { +func newGateEvaluator(gate gkrcore.GateBytecode, nbIn int, elementPool ...*polynomial.Pool) gateEvaluator { e := gateEvaluator{ gate: gate, nbIn: nbIn, @@ -785,28 +687,28 @@ func (e *gateEvaluator) evaluate(top ...fr.Element) *fr.Element { // Use switch instead of function pointer for better inlining switch inst.Op { - case gkrtypes.OpAdd: + case gkrcore.OpAdd: dst.Add(&e.vars[inst.Inputs[0]], &e.vars[inst.Inputs[1]]) for j := 2; j < len(inst.Inputs); j++ { dst.Add(dst, &e.vars[inst.Inputs[j]]) } - case gkrtypes.OpMul: + case gkrcore.OpMul: dst.Mul(&e.vars[inst.Inputs[0]], &e.vars[inst.Inputs[1]]) for j := 2; j < len(inst.Inputs); j++ { dst.Mul(dst, &e.vars[inst.Inputs[j]]) } - case gkrtypes.OpSub: + case gkrcore.OpSub: dst.Sub(&e.vars[inst.Inputs[0]], &e.vars[inst.Inputs[1]]) for j := 2; j < len(inst.Inputs); j++ { dst.Sub(dst, &e.vars[inst.Inputs[j]]) } - case gkrtypes.OpNeg: + case gkrcore.OpNeg: dst.Neg(&e.vars[inst.Inputs[0]]) - case gkrtypes.OpMulAcc: + case gkrcore.OpMulAcc: var prod fr.Element prod.Mul(&e.vars[inst.Inputs[1]], &e.vars[inst.Inputs[2]]) dst.Add(&e.vars[inst.Inputs[0]], &prod) - case gkrtypes.OpSumExp17: + case gkrcore.OpSumExp17: // result = (x[0] + x[1] + x[2])^17 var sum fr.Element sum.Add(&e.vars[inst.Inputs[0]], &e.vars[inst.Inputs[1]]) @@ -832,14 +734,14 @@ func (e *gateEvaluator) evaluate(top ...fr.Element) *fr.Element { // gateEvaluatorPool manages a pool of gate evaluators for a specific gate type // All evaluators share the same underlying polynomial.Pool for element slices type gateEvaluatorPool struct { - gate gkrtypes.GateBytecode + gate gkrcore.GateBytecode nbIn int lock sync.Mutex available map[*gateEvaluator]struct{} elementPool *polynomial.Pool } -func newGateEvaluatorPool(gate gkrtypes.GateBytecode, nbIn int, elementPool *polynomial.Pool) *gateEvaluatorPool { +func newGateEvaluatorPool(gate gkrcore.GateBytecode, nbIn int, elementPool *polynomial.Pool) *gateEvaluatorPool { gep := &gateEvaluatorPool{ gate: gate, nbIn: nbIn, @@ -867,7 +769,7 @@ func (gep *gateEvaluatorPool) put(e *gateEvaluator) { gep.lock.Lock() defer gep.lock.Unlock() - // Return evaluator to pool (it keeps its vars slice from polynomial pool) + // Return evaluator to pool (it keeps its vars slice from the polynomial pool) gep.available[e] = struct{}{} } diff --git a/internal/gkr/bls12-377/gkr_test.go b/internal/gkr/bls12-377/gkr_test.go index 95b45cc654..57f5c1e664 100644 --- a/internal/gkr/bls12-377/gkr_test.go +++ b/internal/gkr/bls12-377/gkr_test.go @@ -11,7 +11,6 @@ import ( "os" "path/filepath" "reflect" - "strconv" "testing" "time" @@ -19,10 +18,9 @@ import ( "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/mimc" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/polynomial" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" - gcUtils "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark/constraint" + "github.com/consensys/gnark/internal/gkr/gkrcore" "github.com/consensys/gnark/internal/gkr/gkrtesting" - "github.com/consensys/gnark/internal/gkr/gkrtypes" "github.com/stretchr/testify/assert" ) @@ -69,36 +67,164 @@ func TestMimc(t *testing.T) { test(t, gkrtesting.MiMCCircuit(93)) } -func TestSumcheckFromSingleInputTwoIdentityGatesGateTwoInstances(t *testing.T) { - // Construct SerializableCircuit directly, bypassing CompileCircuit - // which would reset NbUniqueOutputs based on actual topology - circuit := gkrtypes.SerializableCircuit{ +func TestPoseidon2(t *testing.T) { + test(t, gkrtesting.Poseidon2Circuit(4, 2)) +} + +// testSumcheckLevel exercises proveSumcheckLevel/verifySumcheckLevel for a single sumcheck level. +func testSumcheckLevel(t *testing.T, circuit gkrcore.RawCircuit, level constraint.GkrProvingLevel) { + t.Helper() + _, sCircuit := cache.Compile(t, circuit) + + ins := sCircuit.Inputs() + assignment := make(WireAssignment, len(sCircuit)) + for _, i := range ins { + assignment[i] = make([]fr.Element, 2) + fr.Vector(assignment[i]).MustSetRandom() + } + + assignment.Complete(sCircuit) + + schedule := constraint.GkrProvingSchedule{level} + initEvalPoint := [][]fr.Element{{one}} + + // Prove + proveR, err := newResources(sCircuit, schedule, assignment, newMessageCounter(1, 1)) + assert.NoError(t, err) + defer proveR.workers.Stop() + + proveR.outgoingEvalPoints[len(schedule)] = initEvalPoint + proof := Proof{proveR.proveSumcheckLevel(0)} + + // Verify + verifyR, err := newResources(sCircuit, schedule, assignment, newMessageCounter(1, 1)) + assert.NoError(t, err) + defer verifyR.workers.Stop() + + verifyR.outgoingEvalPoints[len(schedule)] = initEvalPoint + assert.NoError(t, verifyR.verifySumcheckLevel(0, proof)) +} + +func TestSumcheckLevel(t *testing.T) { + // Wires 0,1 = inputs; wires 2,3,4 = mul(0,1). All gates are independent outputs. + circuit := gkrcore.RawCircuit{ + {}, + {}, + {Gate: gkrcore.Mul2, Inputs: []int{0, 1}}, + {Gate: gkrcore.Mul2, Inputs: []int{0, 1}}, + {Gate: gkrcore.Mul2, Inputs: []int{0, 1}}, + } + // Each level has an initial challenge at index 1 (len(schedule) = 1). + // GkrClaimSource{Level:1} is the initial-challenge sentinel. + tests := []struct { + name string + level constraint.GkrProvingLevel + }{ + { + name: "single wire", + level: constraint.GkrSumcheckLevel{ + {Wires: []int{4}, ClaimSources: []constraint.GkrClaimSource{{Level: 1}}}, + }, + }, + { + name: "two groups", + level: constraint.GkrSumcheckLevel{ + {Wires: []int{4}, ClaimSources: []constraint.GkrClaimSource{{Level: 1}}}, + {Wires: []int{3}, ClaimSources: []constraint.GkrClaimSource{{Level: 1}}}, + }, + }, + { + name: "one group with two wires", + level: constraint.GkrSumcheckLevel{ + {Wires: []int{4, 3}, ClaimSources: []constraint.GkrClaimSource{{Level: 1}}}, + }, + }, { - NbUniqueOutputs: 2, - Gate: gkrtypes.SerializableGate{Degree: 1}, + name: "mixed: single + multi-wire group", + level: constraint.GkrSumcheckLevel{ + {Wires: []int{4}, ClaimSources: []constraint.GkrClaimSource{{Level: 1}}}, + {Wires: []int{3, 2}, ClaimSources: []constraint.GkrClaimSource{{Level: 1}}}, + }, }, } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + testSumcheckLevel(t, circuit, tc.level) + }) + } +} + +// testSkipLevel exercises proveSkipLevel/verifySkipLevel for a single skip level. +func testSkipLevel(t *testing.T, circuit gkrcore.RawCircuit, level constraint.GkrProvingLevel) { + t.Helper() + _, sCircuit := cache.Compile(t, circuit) - assignment := WireAssignment{[]fr.Element{two, three}} - var o settings - pool := polynomial.NewPool(256, 1<<11) - workers := gcUtils.NewWorkerPool() - o.pool = &pool - o.workers = workers - - claimsManagerGen := func() *claimsManager { - manager := newClaimsManager(circuit, assignment, o) - manager.add(0, []fr.Element{three}, five) - manager.add(0, []fr.Element{four}, six) - return &manager + ins := sCircuit.Inputs() + assignment := make(WireAssignment, len(sCircuit)) + for _, i := range ins { + assignment[i] = make([]fr.Element, 2) + fr.Vector(assignment[i]).MustSetRandom() } - transcriptGen := newMessageCounterGenerator(4, 1) + assignment.Complete(sCircuit) - proof, err := sumcheckProve(claimsManagerGen().getClaim(0), fiatshamir.WithHash(transcriptGen(), nil)) + schedule := constraint.GkrProvingSchedule{level} + initEvalPoint := [][]fr.Element{{one}} + + // Prove + proveR, err := newResources(sCircuit, schedule, assignment, newMessageCounter(1, 1)) assert.NoError(t, err) - err = sumcheckVerify(claimsManagerGen().getLazyClaim(0), proof, fiatshamir.WithHash(transcriptGen(), nil)) + defer proveR.workers.Stop() + + proveR.outgoingEvalPoints[len(schedule)] = initEvalPoint + proof := Proof{proveR.proveSkipLevel(0)} + + // Verify + verifyR, err := newResources(sCircuit, schedule, assignment, newMessageCounter(1, 1)) assert.NoError(t, err) + defer verifyR.workers.Stop() + + verifyR.outgoingEvalPoints[len(schedule)] = initEvalPoint + assert.NoError(t, verifyR.verifySkipLevel(0, proof)) +} + +func TestSkipLevel(t *testing.T) { + // Wires 0,1 = inputs; wires 2,3 = identity(0); wire 4 = add(0,1). All degree-1 outputs. + circuit := gkrcore.RawCircuit{ + {}, + {}, + {Gate: gkrcore.Identity, Inputs: []int{0}}, + {Gate: gkrcore.Identity, Inputs: []int{0}}, + {Gate: gkrcore.Add2, Inputs: []int{0, 1}}, + } + + // Single-claim cases: one inherited evaluation point (OutgoingClaimIndex always 0). + singleClaim := []struct { + name string + level constraint.GkrProvingLevel + }{ + { + name: "single input wire", + level: constraint.GkrSkipLevel{Wires: []int{0}, ClaimSources: []constraint.GkrClaimSource{{Level: 1}}}, + }, + { + name: "single identity gate", + level: constraint.GkrSkipLevel{Wires: []int{2}, ClaimSources: []constraint.GkrClaimSource{{Level: 1}}}, + }, + { + name: "add gate", + level: constraint.GkrSkipLevel{Wires: []int{4}, ClaimSources: []constraint.GkrClaimSource{{Level: 1}}}, + }, + { + name: "two identity gates one group", + level: constraint.GkrSkipLevel{Wires: []int{2, 3}, ClaimSources: []constraint.GkrClaimSource{{Level: 1}}}, + }, + } + for _, tc := range singleClaim { + t.Run(tc.name, func(t *testing.T) { + testSkipLevel(t, circuit, tc.level) + }) + } } var one, two, three, four, five, six fr.Element @@ -112,31 +238,20 @@ func init() { six.Double(&three) } -var testManyInstancesLogMaxInstances = -1 - -func getLogMaxInstances(t *testing.T) int { - if testManyInstancesLogMaxInstances == -1 { - - s := os.Getenv("GKR_LOG_INSTANCES") - if s == "" { - testManyInstancesLogMaxInstances = 5 - } else { - var err error - testManyInstancesLogMaxInstances, err = strconv.Atoi(s) - if err != nil { - t.Error(err) - } - } - - } - return testManyInstancesLogMaxInstances +func test(t *testing.T, circuit gkrcore.RawCircuit) { + testWithSchedule(t, circuit, nil) } -func test(t *testing.T, circuit gkrtypes.GadgetCircuit) { - sCircuit := cache.Compile(t, circuit) - ins := circuit.Inputs() +func testWithSchedule(t *testing.T, circuit gkrcore.RawCircuit, schedule constraint.GkrProvingSchedule) { + gCircuit, sCircuit := cache.Compile(t, circuit) + if schedule == nil { + var err error + schedule, err = gkrcore.DefaultProvingSchedule(sCircuit) + assert.NoError(t, err) + } + ins := gCircuit.Inputs() insAssignment := make(WireAssignment, len(ins)) - maxSize := 1 << getLogMaxInstances(t) + maxSize := 1 << gkrtesting.GetLogMaxInstances(t) for i := range ins { insAssignment[i] = make([]fr.Element, maxSize) @@ -151,51 +266,33 @@ func test(t *testing.T, circuit gkrtypes.GadgetCircuit) { fullAssignment.Complete(sCircuit) - t.Log("Selected inputs for test") - - proof, err := Prove(sCircuit, fullAssignment, fiatshamir.WithHash(newMessageCounter(1, 1))) + proof, err := Prove(sCircuit, schedule, fullAssignment, newMessageCounter(1, 1)) assert.NoError(t, err) // Even though a hash is called here, the proof is empty - err = Verify(sCircuit, fullAssignment, proof, fiatshamir.WithHash(newMessageCounter(1, 1))) + err = Verify(sCircuit, schedule, fullAssignment, proof, newMessageCounter(1, 1)) assert.NoError(t, err, "proof rejected") - if proof.isEmpty() { // special case for TestNoGate: - continue // there's no way to make a trivial proof fail - } - - err = Verify(sCircuit, fullAssignment, proof, fiatshamir.WithHash(newMessageCounter(0, 1))) + err = Verify(sCircuit, schedule, fullAssignment, proof, newMessageCounter(0, 1)) assert.NotNil(t, err, "bad proof accepted") } - -} - -func (p Proof) isEmpty() bool { - for i := range p { - if len(p[i].finalEvalProof) != 0 { - return false - } - for j := range p[i].partialSumPolys { - if len(p[i].partialSumPolys[j]) != 0 { - return false - } - } - } - return true } func testNoGate(t *testing.T, inputAssignments ...[]fr.Element) { - c := cache.Compile(t, gkrtesting.NoGateCircuit()) + _, c := cache.Compile(t, gkrtesting.NoGateCircuit()) + + schedule, err := gkrcore.DefaultProvingSchedule(c) + assert.NoError(t, err) assignment := WireAssignment{0: inputAssignments[0]} - proof, err := Prove(c, assignment, fiatshamir.WithHash(newMessageCounter(1, 1))) + proof, err := Prove(c, schedule, assignment, newMessageCounter(1, 1)) assert.NoError(t, err) // Even though a hash is called here, the proof is empty - err = Verify(c, assignment, proof, fiatshamir.WithHash(newMessageCounter(1, 1))) + err = Verify(c, schedule, assignment, proof, newMessageCounter(1, 1)) assert.NoError(t, err, "proof rejected") } @@ -203,7 +300,7 @@ func generateTestProver(path string) func(t *testing.T) { return func(t *testing.T) { testCase, err := newTestCase(path) assert.NoError(t, err) - proof, err := Prove(testCase.Circuit, testCase.FullAssignment, fiatshamir.WithHash(testCase.Hash)) + proof, err := Prove(testCase.Circuit, testCase.Schedule, testCase.FullAssignment, testCase.Hash) assert.NoError(t, err) assert.NoError(t, proofEquals(testCase.Proof, proof)) } @@ -213,17 +310,29 @@ func generateTestVerifier(path string) func(t *testing.T) { return func(t *testing.T) { testCase, err := newTestCase(path) assert.NoError(t, err) - err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(testCase.Hash)) + err = Verify(testCase.Circuit, testCase.Schedule, testCase.InOutAssignment, testCase.Proof, testCase.Hash) assert.NoError(t, err, "proof rejected") + testCase, err = newTestCase(path) assert.NoError(t, err) - err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(newMessageCounter(2, 0))) + err = Verify(testCase.Circuit, testCase.Schedule, testCase.InOutAssignment, testCase.Proof, newMessageCounter(2, 0)) assert.NotNil(t, err, "bad proof accepted") + + testCase, err = newTestCase(path) + assert.NoError(t, err) + testCase.InOutAssignment[len(testCase.InOutAssignment)-1][0].Add(&testCase.InOutAssignment[len(testCase.InOutAssignment)-1][0], &one) + err = Verify(testCase.Circuit, testCase.Schedule, testCase.InOutAssignment, testCase.Proof, testCase.Hash) + assert.NotNil(t, err, "tampered output accepted") + + testCase, err = newTestCase(path) + assert.NoError(t, err) + testCase.InOutAssignment[0][0].Add(&testCase.InOutAssignment[0][0], &one) + err = Verify(testCase.Circuit, testCase.Schedule, testCase.InOutAssignment, testCase.Proof, testCase.Hash) + assert.NotNil(t, err, "tampered input accepted") } } func TestGkrVectors(t *testing.T) { - const testDirPath = "../test_vectors/" dirEntries, err := os.ReadDir(testDirPath) assert.NoError(t, err) @@ -267,7 +376,10 @@ func proofEquals(expected Proof, seen Proof) error { func benchmarkGkrMiMC(b *testing.B, nbInstances, mimcDepth int) { fmt.Println("creating circuit structure") - c := cache.Compile(b, gkrtesting.MiMCCircuit(mimcDepth)) + _, c := cache.Compile(b, gkrtesting.MiMCCircuit(mimcDepth)) + + schedule, err := gkrcore.DefaultProvingSchedule(c) + assert.NoError(b, err) in0 := make([]fr.Element, nbInstances) in1 := make([]fr.Element, nbInstances) @@ -283,12 +395,30 @@ func benchmarkGkrMiMC(b *testing.B, nbInstances, mimcDepth int) { //b.ResetTimer() fmt.Println("constructing proof") start = time.Now().UnixMicro() - _, err := Prove(c, assignment, fiatshamir.WithHash(mimc.NewMiMC())) + _, err = Prove(c, schedule, assignment, mimc.NewMiMC()) proved := time.Now().UnixMicro() - start fmt.Println("proved in", proved, "μs") assert.NoError(b, err) } +// TestSingleMulGateExplicitSchedule tests a single mul gate with an explicit single-step schedule, +// equivalent to the default but constructed manually to exercise the schedule path. +func TestSingleMulGateExplicitSchedule(t *testing.T) { + circuit := gkrtesting.SingleMulGateCircuit() + _, sCircuit := cache.Compile(t, circuit) + + // Wire 2 is the mul gate output (inputs: 0, 1). + // Explicit schedule: one GkrProvingLevel for wire 2. + // GkrClaimSource{Level:1} is the initial-challenge sentinel (len(schedule)=1). + schedule := constraint.GkrProvingSchedule{ + constraint.GkrSumcheckLevel{ + {Wires: []int{2}, ClaimSources: []constraint.GkrClaimSource{{Level: 1}}}, + }, + } + testWithSchedule(t, circuit, schedule) + _ = sCircuit +} + func BenchmarkGkrMimc19(b *testing.B) { benchmarkGkrMiMC(b, 1<<19, 91) } @@ -327,11 +457,12 @@ func unmarshalProof(printable gkrtesting.PrintableProof) (Proof, error) { } type TestCase struct { - Circuit gkrtypes.SerializableCircuit + Circuit gkrcore.SerializableCircuit Hash hash.Hash Proof Proof FullAssignment WireAssignment InOutAssignment WireAssignment + Schedule constraint.GkrProvingSchedule } var testCases = make(map[string]*TestCase) @@ -362,6 +493,20 @@ func newTestCase(path string) (*TestCase, error) { if proof, err = unmarshalProof(info.Proof); err != nil { return nil, err } + var schedule constraint.GkrProvingSchedule + if schedule, err = info.Schedule.ToProvingSchedule(); err != nil { + return nil, err + } + if schedule == nil { + if schedule, err = gkrcore.DefaultProvingSchedule(circuit); err != nil { + return nil, err + } + } + + outputSet := make(map[int]bool, len(circuit)) + for _, o := range circuit.Outputs() { + outputSet[o] = true + } fullAssignment := make(WireAssignment, len(circuit)) inOutAssignment := make(WireAssignment, len(circuit)) @@ -375,7 +520,7 @@ func newTestCase(path string) (*TestCase, error) { } assignmentRaw = info.Input[inI] inI++ - } else if circuit[i].IsOutput() { + } else if outputSet[i] { if outI == len(info.Output) { return nil, fmt.Errorf("fewer output in vector than in circuit") } @@ -396,7 +541,7 @@ func newTestCase(path string) (*TestCase, error) { fullAssignment.Complete(circuit) for i := range circuit { - if circuit[i].IsOutput() { + if outputSet[i] { if err = sliceEquals(inOutAssignment[i], fullAssignment[i]); err != nil { return nil, fmt.Errorf("assignment mismatch: %v", err) } @@ -409,6 +554,7 @@ func newTestCase(path string) (*TestCase, error) { Proof: proof, Hash: _hash, Circuit: circuit, + Schedule: schedule, } testCases[path] = tCase diff --git a/internal/gkr/bls12-377/sumcheck.go b/internal/gkr/bls12-377/sumcheck.go index b05e319932..c1700e3db3 100644 --- a/internal/gkr/bls12-377/sumcheck.go +++ b/internal/gkr/bls12-377/sumcheck.go @@ -7,33 +7,62 @@ package gkr import ( "errors" - "strconv" + "hash" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/polynomial" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" ) -// This does not make use of parallelism and represents polynomials as lists of coefficients -// It is currently geared towards arithmetic hashes. Once we have a more unified hash function interface, this can be generified. +// This does not make use of parallelism and represents polynomials as lists of coefficients. + +// transcript is a Fiat-Shamir transcript backed by a running hash. +// Field elements are written via Bind; challenges are derived via getChallenge. +// The hash is never reset — all previous data is implicitly part of future challenges. +type transcript struct { + h hash.Hash + bound bool // whether Bind was called since the last getChallenge +} + +// Bind writes field elements to the transcript as bindings for the next challenge. +func (t *transcript) Bind(elements ...fr.Element) { + if len(elements) == 0 { + return + } + for i := range elements { + bytes := elements[i].Bytes() + t.h.Write(bytes[:]) + } + t.bound = true +} + +// getChallenge binds optional elements, then squeezes a challenge from the current hash state. +// If no bindings were added since the last squeeze, a separator byte is written first +// to advance the state and prevent repeated values. +func (t *transcript) getChallenge(bindings ...fr.Element) fr.Element { + t.Bind(bindings...) + if !t.bound { + t.h.Write([]byte{0}) + } + t.bound = false + var res fr.Element + res.SetBytes(t.h.Sum(nil)) + return res +} // sumcheckClaims to a multi-sumcheck statement. i.e. one of the form ∑_{0≤i<2ⁿ} fⱼ(i) = cⱼ for 1 ≤ j ≤ m. // Later evolving into a claim of the form gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) type sumcheckClaims interface { - fold(a fr.Element) polynomial.Polynomial // fold into the 0ᵗʰ sumcheck subclaim. Create g := ∑_{1≤j≤m} aʲ⁻¹fⱼ for which now we seek to prove ∑_{0≤i<2ⁿ} g(i) = c := ∑_{1≤j≤m} aʲ⁻¹cⱼ. Return g₁. - next(fr.Element) polynomial.Polynomial // Return the evaluations gⱼ(k) for 1 ≤ k < degⱼ(g). Update the claim to gⱼ₊₁ for the input value as rⱼ + roundPolynomial() polynomial.Polynomial // compute gⱼ polynomial for current round + roundFold(r fr.Element) // fold inputs and eq at challenge r varsNum() int // number of variables - claimsNum() int // number of claims proveFinalEval(r []fr.Element) []fr.Element // in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof } // sumcheckLazyClaims is the sumcheckClaims data structure on the verifier side. It is "lazy" in that it has to compute fewer things. type sumcheckLazyClaims interface { - claimsNum() int // claimsNum = m - varsNum() int // varsNum = n - foldedSum(a fr.Element) fr.Element // foldedSum returns c = ∑_{1≤j≤m} aʲ⁻¹cⱼ - degree(i int) int // degree of the total claim in the i'th variable - verifyFinalEval(r []fr.Element, foldingCoeff fr.Element, purportedValue fr.Element, proof []fr.Element) error + varsNum() int // varsNum = n + degree(i int) int // degree of the total claim in the i'th variable + verifyFinalEval(r []fr.Element, purportedValue fr.Element, proof []fr.Element) error } // sumcheckProof of a multi-statement. @@ -42,130 +71,46 @@ type sumcheckProof struct { finalEvalProof []fr.Element //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof } -func setupTranscript(claimsNum int, varsNum int, settings *fiatshamir.Settings) (challengeNames []string, err error) { - numChallenges := varsNum - if claimsNum >= 2 { - numChallenges++ - } - challengeNames = make([]string, numChallenges) - if claimsNum >= 2 { - challengeNames[0] = settings.Prefix + "fold" - } - prefix := settings.Prefix + "pSP." - for i := 0; i < varsNum; i++ { - challengeNames[i+numChallenges-varsNum] = prefix + strconv.Itoa(i) - } - if settings.Transcript == nil { - transcript := fiatshamir.NewTranscript(settings.Hash, challengeNames...) - settings.Transcript = transcript - } - - for i := range settings.BaseChallenges { - if err = settings.Transcript.Bind(challengeNames[0], settings.BaseChallenges[i]); err != nil { - return - } - } - return -} - -func next(transcript *fiatshamir.Transcript, bindings []fr.Element, remainingChallengeNames *[]string) (fr.Element, error) { - challengeName := (*remainingChallengeNames)[0] - for i := range bindings { - bytes := bindings[i].Bytes() - if err := transcript.Bind(challengeName, bytes[:]); err != nil { - return fr.Element{}, err - } - } - var res fr.Element - bytes, err := transcript.ComputeChallenge(challengeName) - res.SetBytes(bytes) - - *remainingChallengeNames = (*remainingChallengeNames)[1:] - - return res, err -} - -// sumcheckProve create a non-interactive proof -func sumcheckProve(claims sumcheckClaims, transcriptSettings fiatshamir.Settings) (sumcheckProof, error) { - - var proof sumcheckProof - remainingChallengeNames, err := setupTranscript(claims.claimsNum(), claims.varsNum(), &transcriptSettings) - transcript := transcriptSettings.Transcript - if err != nil { - return proof, err - } - - var foldingCoeff fr.Element - if claims.claimsNum() >= 2 { - if foldingCoeff, err = next(transcript, []fr.Element{}, &remainingChallengeNames); err != nil { - return proof, err - } - } - +// sumcheckProve creates a non-interactive sumcheck proof. +// The fold challenge is derived by the caller (proveLevel). +// Pattern: roundPolynomial, [roundFold, roundPolynomial]*, proveFinalEval. +func sumcheckProve(claims sumcheckClaims, t *transcript) sumcheckProof { varsNum := claims.varsNum() - proof.partialSumPolys = make([]polynomial.Polynomial, varsNum) - proof.partialSumPolys[0] = claims.fold(foldingCoeff) + proof := sumcheckProof{partialSumPolys: make([]polynomial.Polynomial, varsNum)} + proof.partialSumPolys[0] = claims.roundPolynomial() challenges := make([]fr.Element, varsNum) - for j := 0; j+1 < varsNum; j++ { - if challenges[j], err = next(transcript, proof.partialSumPolys[j], &remainingChallengeNames); err != nil { - return proof, err - } - proof.partialSumPolys[j+1] = claims.next(challenges[j]) - } - - if challenges[varsNum-1], err = next(transcript, proof.partialSumPolys[varsNum-1], &remainingChallengeNames); err != nil { - return proof, err + for j := range varsNum - 1 { + challenges[j] = t.getChallenge(proof.partialSumPolys[j]...) + claims.roundFold(challenges[j]) + proof.partialSumPolys[j+1] = claims.roundPolynomial() } + challenges[varsNum-1] = t.getChallenge(proof.partialSumPolys[varsNum-1]...) proof.finalEvalProof = claims.proveFinalEval(challenges) - - return proof, nil + return proof } -func sumcheckVerify(claims sumcheckLazyClaims, proof sumcheckProof, transcriptSettings fiatshamir.Settings) error { - remainingChallengeNames, err := setupTranscript(claims.claimsNum(), claims.varsNum(), &transcriptSettings) - transcript := transcriptSettings.Transcript - if err != nil { - return err - } - - var foldingCoeff fr.Element - - if claims.claimsNum() >= 2 { - if foldingCoeff, err = next(transcript, []fr.Element{}, &remainingChallengeNames); err != nil { - return err - } - } - +// sumcheckVerify verifies a non-interactive sumcheck proof. +// The fold challenge is derived by the caller (verifyLevel). +// claimedSum is the expected sum; degree is the polynomial's degree in each variable. +func sumcheckVerify(claims sumcheckLazyClaims, proof sumcheckProof, claimedSum fr.Element, degree int, t *transcript) error { r := make([]fr.Element, claims.varsNum()) - // Just so that there is enough room for gJ to be reused - maxDegree := claims.degree(0) - for j := 1; j < claims.varsNum(); j++ { - if d := claims.degree(j); d > maxDegree { - maxDegree = d - } - } - gJ := make(polynomial.Polynomial, maxDegree+1) //At the end of iteration j, gJ = ∑_{i < 2ⁿ⁻ʲ⁻¹} g(X₁, ..., Xⱼ₊₁, i...) NOTE: n is shorthand for claims.varsNum() - gJR := claims.foldedSum(foldingCoeff) // At the beginning of iteration j, gJR = ∑_{i < 2ⁿ⁻ʲ} g(r₁, ..., rⱼ, i...) + gJ := make(polynomial.Polynomial, degree+1) + gJR := claimedSum for j := range claims.varsNum() { - if len(proof.partialSumPolys[j]) != claims.degree(j) { + if len(proof.partialSumPolys[j]) != degree { return errors.New("malformed proof") } copy(gJ[1:], proof.partialSumPolys[j]) - gJ[0].Sub(&gJR, &proof.partialSumPolys[j][0]) // Requirement that gⱼ(0) + gⱼ(1) = gⱼ₋₁(r) - // gJ is ready + gJ[0].Sub(&gJR, &proof.partialSumPolys[j][0]) - //Prepare for the next iteration - if r[j], err = next(transcript, proof.partialSumPolys[j], &remainingChallengeNames); err != nil { - return err - } - // This is an extremely inefficient way of interpolating. TODO: Interpolate without symbolically computing a polynomial - gJCoeffs := polynomial.InterpolateOnRange(gJ[:(claims.degree(j) + 1)]) + r[j] = t.getChallenge(proof.partialSumPolys[j]...) + gJCoeffs := polynomial.InterpolateOnRange(gJ[:(degree + 1)]) gJR = gJCoeffs.Eval(&r[j]) } - return claims.verifyFinalEval(r, foldingCoeff, gJR, proof.finalEvalProof) + return claims.verifyFinalEval(r, gJR, proof.finalEvalProof) } diff --git a/internal/gkr/bls12-377/sumcheck_test.go b/internal/gkr/bls12-377/sumcheck_test.go index f4b8d524b3..395a75bff3 100644 --- a/internal/gkr/bls12-377/sumcheck_test.go +++ b/internal/gkr/bls12-377/sumcheck_test.go @@ -10,7 +10,6 @@ import ( "hash" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/polynomial" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" "github.com/stretchr/testify/assert" "math/bits" @@ -28,11 +27,9 @@ func testSumcheckSingleClaimMultilin(polyInt []uint64, hashGenerator func() hash } claim := singleMultilinClaim{g: poly.Clone()} + t := transcript{h: hashGenerator()} - proof, err := sumcheckProve(&claim, fiatshamir.WithHash(hashGenerator())) - if err != nil { - return err - } + proof := sumcheckProve(&claim, &t) var sb strings.Builder for _, p := range proof.partialSumPolys { @@ -48,13 +45,15 @@ func testSumcheckSingleClaimMultilin(polyInt []uint64, hashGenerator func() hash } lazyClaim := singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} - if err = sumcheckVerify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())); err != nil { + t = transcript{h: hashGenerator()} + if err := sumcheckVerify(lazyClaim, proof, lazyClaim.claimedSum, 1, &t); err != nil { return err } proof.partialSumPolys[0][0].Add(&proof.partialSumPolys[0][0], toElement(1)) lazyClaim = singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} - if sumcheckVerify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())) == nil { + t = transcript{h: hashGenerator()} + if sumcheckVerify(lazyClaim, proof, lazyClaim.claimedSum, 1, &t) == nil { return fmt.Errorf("bad proof accepted") } return nil @@ -93,18 +92,14 @@ type singleMultilinClaim struct { g polynomial.MultiLin } -func (c singleMultilinClaim) proveFinalEval(r []fr.Element) []fr.Element { +func (c *singleMultilinClaim) proveFinalEval(r []fr.Element) []fr.Element { return nil // verifier can compute the final eval itself } -func (c singleMultilinClaim) varsNum() int { +func (c *singleMultilinClaim) varsNum() int { return bits.TrailingZeros(uint(len(c.g))) } -func (c singleMultilinClaim) claimsNum() int { - return 1 -} - func sumForX1One(g polynomial.MultiLin) polynomial.Polynomial { sum := g[len(g)/2] for i := len(g)/2 + 1; i < len(g); i++ { @@ -113,13 +108,12 @@ func sumForX1One(g polynomial.MultiLin) polynomial.Polynomial { return []fr.Element{sum} } -func (c singleMultilinClaim) fold(fr.Element) polynomial.Polynomial { +func (c *singleMultilinClaim) roundPolynomial() polynomial.Polynomial { return sumForX1One(c.g) } -func (c *singleMultilinClaim) next(r fr.Element) polynomial.Polynomial { +func (c *singleMultilinClaim) roundFold(r fr.Element) { c.g.Fold(r) - return sumForX1One(c.g) } type singleMultilinLazyClaim struct { @@ -127,7 +121,7 @@ type singleMultilinLazyClaim struct { claimedSum fr.Element } -func (c singleMultilinLazyClaim) verifyFinalEval(r []fr.Element, _ fr.Element, purportedValue fr.Element, proof []fr.Element) error { +func (c singleMultilinLazyClaim) verifyFinalEval(r []fr.Element, purportedValue fr.Element, proof []fr.Element) error { val := c.g.Evaluate(r, nil) if val.Equal(&purportedValue) { return nil @@ -135,15 +129,7 @@ func (c singleMultilinLazyClaim) verifyFinalEval(r []fr.Element, _ fr.Element, p return fmt.Errorf("mismatch") } -func (c singleMultilinLazyClaim) foldedSum(_ fr.Element) fr.Element { - return c.claimedSum -} - -func (c singleMultilinLazyClaim) degree(i int) int { - return 1 -} - -func (c singleMultilinLazyClaim) claimsNum() int { +func (c singleMultilinLazyClaim) degree(int) int { return 1 } diff --git a/internal/gkr/bls12-381/blueprint.go b/internal/gkr/bls12-381/blueprint.go index 6ab93347a9..2257585bf5 100644 --- a/internal/gkr/bls12-381/blueprint.go +++ b/internal/gkr/bls12-381/blueprint.go @@ -15,10 +15,9 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/polynomial" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" "github.com/consensys/gnark-crypto/hash" "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/internal/gkr/gkrtypes" + "github.com/consensys/gnark/internal/gkr/gkrcore" ) func init() { @@ -34,7 +33,7 @@ type circuitEvaluator struct { // BlueprintSolve is a BLS12_381-specific blueprint for solving GKR circuit instances. type BlueprintSolve struct { // Circuit structure (serialized) - Circuit gkrtypes.SerializableCircuit + Circuit gkrcore.SerializableCircuit NbInstances uint32 // Not serialized - recreated lazily at solve time @@ -204,6 +203,7 @@ func (b *BlueprintSolve) UpdateInstructionTree(inst constraint.Instruction, tree type BlueprintProve struct { SolveBlueprintID constraint.BlueprintID SolveBlueprint *BlueprintSolve `cbor:"-"` // not serialized, set at compile time + Schedule constraint.GkrProvingSchedule HashName string lock sync.Mutex @@ -256,9 +256,11 @@ func (b *BlueprintProve) Solve(s constraint.Solver[constraint.U64], inst constra } } + // Create hasher and write base challenges + hsh := hash.NewHash(b.HashName + "_BLS12_381") + // Read initial challenges from instruction calldata (parse dynamically, no metadata) // Format: [0]=totalSize, [1...]=challenge linear expressions - insBytes := make([][]byte, 0) // first challenges calldata := inst.Calldata[1:] // skip size prefix for len(calldata) != 0 { val, delta := s.Read(calldata) @@ -267,17 +269,14 @@ func (b *BlueprintProve) Solve(s constraint.Solver[constraint.U64], inst constra // Copy directly from constraint.U64 to fr.Element (both in Montgomery form) var challenge fr.Element copy(challenge[:], val[:]) - insBytes = append(insBytes, challenge.Marshal()) + challengeBytes := challenge.Bytes() + hsh.Write(challengeBytes[:]) } - // Create Fiat-Shamir settings - hsh := hash.NewHash(b.HashName + "_BLS12_381") - fsSettings := fiatshamir.WithHash(hsh, insBytes...) - // Call the BLS12_381-specific Prove function (assignments already WireAssignment type) - proof, err := Prove(solveBlueprint.Circuit, assignments, fsSettings) + proof, err := Prove(solveBlueprint.Circuit, b.Schedule, assignments, hsh) if err != nil { - return fmt.Errorf("bls12_381 prove failed: %w", err) + return fmt.Errorf("BLS12_381 prove failed: %w", err) } for i, elem := range proof.flatten() { @@ -305,7 +304,7 @@ func (b *BlueprintProve) proofSize() int { } nbPaddedInstances := ecc.NextPowerOfTwo(uint64(b.SolveBlueprint.NbInstances)) logNbInstances := bits.TrailingZeros64(nbPaddedInstances) - return b.SolveBlueprint.Circuit.ProofSize(logNbInstances) + return b.SolveBlueprint.Circuit.ProofSize(b.Schedule, logNbInstances) } // NbOutputs implements Blueprint @@ -434,7 +433,7 @@ func (b *BlueprintGetAssignment) UpdateInstructionTree(inst constraint.Instructi } // NewBlueprints creates and registers all GKR blueprints for BLS12_381 -func NewBlueprints(circuit gkrtypes.SerializableCircuit, hashName string, compiler constraint.CustomizableSystem) gkrtypes.Blueprints { +func NewBlueprints(circuit gkrcore.SerializableCircuit, schedule constraint.GkrProvingSchedule, hashName string, compiler constraint.CustomizableSystem) gkrcore.Blueprints { // Create and register solve blueprint solve := &BlueprintSolve{Circuit: circuit} solveID := compiler.AddBlueprint(solve) @@ -443,6 +442,7 @@ func NewBlueprints(circuit gkrtypes.SerializableCircuit, hashName string, compil prove := &BlueprintProve{ SolveBlueprintID: solveID, SolveBlueprint: solve, + Schedule: schedule, HashName: hashName, } proveID := compiler.AddBlueprint(prove) @@ -453,7 +453,7 @@ func NewBlueprints(circuit gkrtypes.SerializableCircuit, hashName string, compil } getAssignmentID := compiler.AddBlueprint(getAssignment) - return gkrtypes.Blueprints{ + return gkrcore.Blueprints{ SolveID: solveID, Solve: solve, ProveID: proveID, diff --git a/internal/gkr/bls12-381/gkr.go b/internal/gkr/bls12-381/gkr.go index dad2117fc3..bb240ddab0 100644 --- a/internal/gkr/bls12-381/gkr.go +++ b/internal/gkr/bls12-381/gkr.go @@ -8,655 +8,557 @@ package gkr import ( "errors" "fmt" + "hash" "iter" - "strconv" "sync" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/polynomial" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" "github.com/consensys/gnark-crypto/utils" - "github.com/consensys/gnark/internal/gkr/gkrtypes" + "github.com/consensys/gnark/constraint" + "github.com/consensys/gnark/internal/gkr/gkrcore" ) // Type aliases for bytecode-based GKR types type ( - Wire = gkrtypes.SerializableWire - Circuit = gkrtypes.SerializableCircuit + Wire = gkrcore.SerializableWire + Circuit = gkrcore.SerializableCircuit ) // The goal is to prove/verify evaluations of many instances of the same circuit -// WireAssignment is assignment of values to the same wire across many instances of the circuit +// WireAssignment is the assignment of values to the same wire across many instances of the circuit type WireAssignment []polynomial.MultiLin type Proof []sumcheckProof // for each layer, for each wire, a sumcheck (for each variable, a polynomial) // zeroCheckLazyClaims is a lazy claim for sumcheck (verifier side). -// It checks that the polynomial ∑ᵢ cⁱ eq(-, xᵢ) w(-) sums up to the expected multilinear -// extension of the values of w across all instances. -// Its purpose is to batch the checking of multiple evaluations of the same wire. +// It checks that the polynomial ∑ᵢ cⁱ eq(-, xᵢ) wᵢ(-) sums to the expected value, +// where the sum runs over all wᵢ and evaluation point xᵢ in the level. +// Its purpose is to batch the checking of multiple wire evaluations at evaluation points. type zeroCheckLazyClaims struct { - wireI int // the wire for which we are making the claim, with value w - evaluationPoints [][]fr.Element // xᵢ: the points at which the prover has made claims about the evaluation of w - claimedEvaluations []fr.Element // yᵢ = w(xᵢ), allegedly - manager *claimsManager // WARNING: Circular references -} - -func (e *zeroCheckLazyClaims) getWire() Wire { - return e.manager.circuit[e.wireI] -} - -func (e *zeroCheckLazyClaims) claimsNum() int { - return len(e.evaluationPoints) + foldingCoeff fr.Element // the coefficient used to fold claims, conventionally 0 if there is only one claim + resources *resources + levelI int } func (e *zeroCheckLazyClaims) varsNum() int { - return len(e.evaluationPoints[0]) -} - -// foldedSum returns ∑ᵢ aⁱ yᵢ -func (e *zeroCheckLazyClaims) foldedSum(a fr.Element) fr.Element { - evalsAsPoly := polynomial.Polynomial(e.claimedEvaluations) - return evalsAsPoly.Eval(&a) + return e.resources.nbVars } func (e *zeroCheckLazyClaims) degree(int) int { - return e.manager.circuit[e.wireI].ZeroCheckDegree() -} - -// verifyFinalEval finalizes the verification of w. -// The prover's claims w(xᵢ) = yᵢ have already been reduced to verifying -// ∑ cⁱ eq(xᵢ, r) w(r) = purportedValue. (c is foldingCoeff) -// Both purportedValue and the vector r have been randomized during the sumcheck protocol. -// By taking the w term out of the sum we get the equivalent claim that -// for E := ∑ eq(xᵢ, r), it must be that E w(r) = purportedValue. -// If w is an input wire, the verifier can directly check its evaluation at r. -// Otherwise, the prover makes claims about the evaluation of w's input wires, -// wᵢ, at r, to be verified later. -// The claims are communicated through the proof parameter. -// The verifier checks here if the claimed evaluations of wᵢ(r) are consistent with -// the main claim, by checking E w(wᵢ(r)...) = purportedValue. -func (e *zeroCheckLazyClaims) verifyFinalEval(r []fr.Element, foldingCoeff, purportedValue fr.Element, uniqueInputEvaluations []fr.Element) error { - // the eq terms ( E ) - numClaims := len(e.evaluationPoints) - evaluation := polynomial.EvalEq(e.evaluationPoints[numClaims-1], r) - for i := numClaims - 2; i >= 0; i-- { - evaluation.Mul(&evaluation, &foldingCoeff) - eq := polynomial.EvalEq(e.evaluationPoints[i], r) - evaluation.Add(&evaluation, &eq) - } - - wire := e.manager.circuit[e.wireI] - - // the w(...) term - var gateEvaluation fr.Element - if wire.IsInput() { // just compute w(r) - gateEvaluation = e.manager.assignment[e.wireI].Evaluate(r, e.manager.memPool) - } else { // proof contains the evaluations of the inputs, but avoids repetition in case multiple inputs come from the same wire - injection, injectionLeftInv := - e.manager.circuit.ClaimPropagationInfo(e.wireI) - - if len(injection) != len(uniqueInputEvaluations) { - return fmt.Errorf("%d input wire evaluations given, %d expected", len(uniqueInputEvaluations), len(injection)) - } - - for uniqueI, i := range injection { // map from unique to all - e.manager.add(wire.Inputs[i], r, uniqueInputEvaluations[uniqueI]) - } + return e.resources.circuit.ZeroCheckDegree(e.resources.schedule[e.levelI].(constraint.GkrSumcheckLevel)) +} + +// verifyFinalEval finalizes the verification of a level at the sumcheck evaluation point r. +// The sumcheck protocol has already reduced the per-wire claims w(xᵢ) = yᵢ to verifying +// ∑ᵢ cⁱ eq(xᵢ, r) · wᵢ(r) = purportedValue, where the sum runs over all +// claims on each wire and c is foldingCoeff. +// Both purportedValue and the vector r have been randomized during sumcheck. +// +// For input wires, w(r) is computed directly from the assignment and the claimed +// evaluation in uniqueInputEvaluations is checked equal to it. +// For non-input wires, the prover claims evaluations of their gate inputs at r via +// uniqueInputEvaluations; those claims are verified by lower levels' sumchecks. +// The verifier checks consistency by evaluating gateᵥ(inputEvals...) and confirming +// that the full sum matches purportedValue. +func (e *zeroCheckLazyClaims) verifyFinalEval(r []fr.Element, purportedValue fr.Element, uniqueInputEvaluations []fr.Element) error { + e.resources.outgoingEvalPoints[e.levelI] = [][]fr.Element{r} + level := e.resources.schedule[e.levelI] + gateInputEvals := gkrcore.ReduplicateInputs(level, e.resources.circuit, uniqueInputEvaluations) + + var claimedEvals polynomial.Polynomial + levelWireI := 0 + for _, group := range level.ClaimGroups() { + for _, wI := range group.Wires { + wire := e.resources.circuit[wI] + + var gateEval fr.Element + if wire.IsInput() { + gateEval = e.resources.assignment[wI].Evaluate(r, &e.resources.memPool) + if !gateInputEvals[levelWireI][0].Equal(&gateEval) { + return errors.New("incompatible evaluations") + } + } else { + evaluator := newGateEvaluator(wire.Gate.Evaluate, len(wire.Inputs)) + for _, v := range gateInputEvals[levelWireI] { + evaluator.pushInput(v) + } + gateEval.Set(evaluator.evaluate()) + } - evaluator := newGateEvaluator(wire.Gate.Evaluate, len(wire.Inputs)) - for _, uniqueI := range injectionLeftInv { // map from all to unique - evaluator.pushInput(uniqueInputEvaluations[uniqueI]) + for _, src := range group.ClaimSources { + eq := polynomial.EvalEq(e.resources.outgoingEvalPoints[src.Level][src.OutgoingClaimIndex], r) + var term fr.Element + term.Mul(&eq, &gateEval) + claimedEvals = append(claimedEvals, term) + } + levelWireI++ } - - gateEvaluation.Set(evaluator.evaluate()) } - evaluation.Mul(&evaluation, &gateEvaluation) - - if evaluation.Equal(&purportedValue) { - return nil + if total := claimedEvals.Eval(&e.foldingCoeff); !total.Equal(&purportedValue) { + return errors.New("incompatible evaluations") } - return errors.New("incompatible evaluations") + return nil } // zeroCheckClaims is a claim for sumcheck (prover side). -// It checks that the polynomial ∑ᵢ cⁱ eq(-, xᵢ) w(-) sums up to the expected multilinear -// extension of the values of w across all instances. -// Its purpose is to batch the proving of multiple evaluations of the same wire. +// It checks that the polynomial ∑ᵢ cⁱ eq(-, xᵢ) wᵢ(-) sums to the expected value, +// where the sum runs over all (wire v, claim source s) pairs in the level. +// Each wire has its own eq table with the batching coefficients baked in. type zeroCheckClaims struct { - wireI int // the wire for which we are making the claim, with value w - evaluationPoints [][]fr.Element // xᵢ: the points at which the prover has made claims about the evaluation of w - claimedEvaluations []fr.Element // yᵢ = w(xᵢ) - manager *claimsManager - - input []polynomial.MultiLin // input[i](h₁, ..., hₘ₋ⱼ) = wᵢ(r₁, r₂, ..., rⱼ₋₁, h₁, ..., hₘ₋ⱼ) - - eq polynomial.MultiLin // E := ∑ᵢ cⁱ eq(xᵢ, -) - - gateEvaluatorPool *gateEvaluatorPool -} - -func (c *zeroCheckClaims) getWire() Wire { - return c.manager.circuit[c.wireI] -} - -// fold the multiple claims into one claim using a random combination (foldingCoeff or c). -// From the original multiple claims of w(xᵢ) = yᵢ, we get a single claim -// ∑ᵢ,ₕ cⁱ eq(xᵢ, h) w(h) = ∑ᵢ cⁱ yᵢ, where h iterates over the hypercube (circuit instances) and -// i iterates over the claims. -// Equivalently, we could say ∑ᵢ cⁱ yᵢ = ∑ₕ,ᵢ cⁱ eq(xᵢ, h) w(h) = ∑ₕ w(h) ∑ᵢ cⁱ eq(xᵢ, h). -// Thus if we initially compute E := ∑ᵢ cⁱ eq(xᵢ, -), our claim will find the simpler form -// ∑ᵢ cⁱ yᵢ = ∑ₕ w(h) E(h), where the sum-checked polynomial is of degree deg(g) + 1, -// and deg(g) is the total degree of the polynomial defining the gate g of which w is the output. -// The output of fold is the first sumcheck claim, i.e. ∑₍ₕ₁,ₕ₂,...₎ w(X, h₁, h₂, ...) E(X, h₁, h₂, ...).. -func (c *zeroCheckClaims) fold(foldingCoeff fr.Element) polynomial.Polynomial { - varsNum := c.varsNum() - eqLength := 1 << varsNum - claimsNum := c.claimsNum() - // initialize the eq tables ( E ) - c.eq = c.manager.memPool.Make(eqLength) - - c.eq[0].SetOne() - c.eq.Eq(c.evaluationPoints[0]) - - // E := eq(x₀, -) - newEq := polynomial.MultiLin(c.manager.memPool.Make(eqLength)) - aI := foldingCoeff - - // E += cⁱ eq(xᵢ, -) - for k := 1; k < claimsNum; k++ { - newEq[0].Set(&aI) - - c.eqAcc(c.eq, newEq, c.evaluationPoints[k]) - - if k+1 < claimsNum { - aI.Mul(&aI, &foldingCoeff) - } - } - - c.manager.memPool.Dump(newEq) - - return c.computeGJ() -} - -// eqAcc sets m to an eq table at q and then adds it to e. -// m <- eq(q, -). -// e <- e + m -func (c *zeroCheckClaims) eqAcc(e, m polynomial.MultiLin, q []fr.Element) { - n := len(q) - - //At the end of each iteration, m(h₁, ..., hₙ) = eq(q₁, ..., qᵢ₊₁, h₁, ..., hᵢ₊₁) - for i := range q { // In the comments we use a 1-based index so q[i] = qᵢ₊₁ - // go through all assignments of (b₁, ..., bᵢ) ∈ {0,1}ⁱ - const threshold = 1 << 6 - k := 1 << i - if k < threshold { - for j := 0; j < k; j++ { - j0 := j << (n - i) // bᵢ₊₁ = 0 - j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 - - m[j1].Mul(&q[i], &m[j0]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ - m[j0].Sub(&m[j0], &m[j1]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) - } - } else { - c.manager.workers.Submit(k, func(start, end int) { - for j := start; j < end; j++ { - j0 := j << (n - i) // bᵢ₊₁ = 0 - j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 - - m[j1].Mul(&q[i], &m[j0]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ - m[j0].Sub(&m[j0], &m[j1]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) - } - }, 1024).Wait() - } - - } - c.manager.workers.Submit(len(e), func(start, end int) { - for i := start; i < end; i++ { - e[i].Add(&e[i], &m[i]) - } - }, 512).Wait() + levelI int + resources *resources + input []polynomial.MultiLin // UniqueGateInputs order + inputIndices [][]int // [wireInLevel][gateInputJ] → index in input + eqs []polynomial.MultiLin // per-wire interpolation bases for evaluating wire assignments at challenge points + gateEvaluatorPools []*gateEvaluatorPool } -// computeGJ: gⱼ = ∑_{0≤h<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, h...) = ∑_{0≤i<2ⁿ⁻ʲ} E(r₁, ..., Xⱼ, h...) g( w₀(r₁, ..., Xⱼ, h...), ... ). -// the polynomial is represented by the evaluations gⱼ(1), gⱼ(2), ..., gⱼ(deg(gⱼ)). -// The value gⱼ(0) is inferred from the equation gⱼ(0) + gⱼ(1) = gⱼ₋₁(rⱼ₋₁). By convention, g₀ is a constant polynomial equal to the claimed sum. -func (c *zeroCheckClaims) computeGJ() polynomial.Polynomial { - - wire := c.getWire() - degGJ := wire.ZeroCheckDegree() // guaranteed to be no smaller than the actual deg(gⱼ) - nbGateIn := len(c.input) - - // Both E and wᵢ (the input wires and the eq table) are multilinear, thus - // they are linear in Xⱼ. - // So for f ∈ { E(r₁, ..., Xⱼ, h...) } ∪ {wᵢ(r₁, ..., Xⱼ, h...) }, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables. - // ml are such multilinear polynomials the evaluations of which over different values of Xⱼ are computed in this stepwise manner. - ml := make([]polynomial.MultiLin, nbGateIn+1) // shortcut to the evaluations of the multilinear polynomials over the hypercube - ml[0] = c.eq - copy(ml[1:], c.input) - - sumSize := len(c.eq) / 2 // the range of h, over which we sum - - // Perf-TODO: Collate once at claim "folding" time and not again. then, even folding can be done in one operation every time "next" is called - - gJ := make([]fr.Element, degGJ) +func (c *zeroCheckClaims) varsNum() int { + return c.resources.nbVars +} + +// roundPolynomial computes gⱼ = ∑ₕ ∑ᵥ eqs[v](Xⱼ, h...) · gateᵥ(inputs(Xⱼ, h...)). +// The polynomial is represented by the evaluations gⱼ(1), gⱼ(2), ..., gⱼ(deg(gⱼ)). +// The value gⱼ(0) is inferred from the equation gⱼ(0) + gⱼ(1) = gⱼ₋₁(rⱼ₋₁). +// By convention, g₀ is a constant polynomial equal to the claimed sum. +func (c *zeroCheckClaims) roundPolynomial() polynomial.Polynomial { + level := c.resources.schedule[c.levelI].(constraint.GkrSumcheckLevel) + degree := c.resources.circuit.ZeroCheckDegree(level) + nbUniqueInputs := len(c.input) + nbWires := len(c.eqs) + + // Both eqs and input are multilinear, thus linear in Xⱼ. + // For any such f, f(m) = m·(f(1) - f(0)) + f(0), and f(0), f(1) are read directly + // from the bookkeeping tables. This allows stepwise evaluation at Xⱼ = 1, 2, ..., degree. + // Layout: [eq₀, eq₁, ..., eq_{nbWires-1}, input₀, input₁, ..., input_{nbUniqueInputs-1}] + ml := make([]polynomial.MultiLin, nbWires+nbUniqueInputs) + copy(ml, c.eqs) + copy(ml[nbWires:], c.input) + + sumSize := len(c.eqs[0]) / 2 + + p := make([]fr.Element, degree) var mu sync.Mutex - computeAll := func(start, end int) { // compute method to allow parallelization across instances + computeAll := func(start, end int) { var step fr.Element - evaluator := c.gateEvaluatorPool.get() - defer c.gateEvaluatorPool.put(evaluator) + evaluators := make([]*gateEvaluator, nbWires) + for w := range nbWires { + evaluators[w] = c.gateEvaluatorPools[w].get() + } + defer func() { + for w := range nbWires { + c.gateEvaluatorPools[w].put(evaluators[w]) + } + }() - res := make([]fr.Element, degGJ) + res := make([]fr.Element, degree) // evaluations of ml, laid out as: // ml[0](1, h...), ml[1](1, h...), ..., ml[len(ml)-1](1, h...), // ml[0](2, h...), ml[1](2, h...), ..., ml[len(ml)-1](2, h...), // ... - // ml[0](degGJ, h...), ml[2](degGJ, h...), ..., ml[len(ml)-1](degGJ, h...) - mlEvals := make([]fr.Element, degGJ*len(ml)) - - for h := start; h < end; h++ { // h counts across instances + // ml[0](degree, h...), ml[1](degree, h...), ..., ml[len(ml)-1](degree, h...) + mlEvals := make([]fr.Element, degree*len(ml)) + for h := start; h < end; h++ { evalAt1Index := sumSize + h for k := range ml { - // d = 0 - mlEvals[k].Set(&ml[k][evalAt1Index]) // evaluation at Xⱼ = 1. Can be taken directly from the table. + mlEvals[k].Set(&ml[k][evalAt1Index]) // evaluation at Xⱼ = 1, taken directly from the table step.Sub(&mlEvals[k], &ml[k][h]) // step = ml[k](1) - ml[k](0) - for d := 1; d < degGJ; d++ { + for d := 1; d < degree; d++ { mlEvals[d*len(ml)+k].Add(&mlEvals[(d-1)*len(ml)+k], &step) } } - eIndex := 0 // index for where the current eq term is + eIndex := 0 // start of the current row's eq evaluations nextEIndex := len(ml) - for d := range degGJ { - // Push gate inputs - for i := range nbGateIn { - evaluator.pushInput(mlEvals[eIndex+1+i]) + for d := range degree { + for w := range nbWires { + for _, inputI := range c.inputIndices[w] { + evaluators[w].pushInput(mlEvals[eIndex+nbWires+inputI]) + } + summand := evaluators[w].evaluate() + summand.Mul(summand, &mlEvals[eIndex+w]) + res[d].Add(&res[d], summand) // collect contributions into the sum from start to end } - summand := evaluator.evaluate() - summand.Mul(summand, &mlEvals[eIndex]) - res[d].Add(&res[d], summand) // collect contributions into the sum from start to end eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) } } mu.Lock() - for i := range gJ { - gJ[i].Add(&gJ[i], &res[i]) // collect into the complete sum + for i := range p { + p[i].Add(&p[i], &res[i]) // collect into the complete sum } mu.Unlock() } const minBlockSize = 64 - if sumSize < minBlockSize { - // no parallelization computeAll(0, sumSize) } else { - c.manager.workers.Submit(sumSize, computeAll, minBlockSize).Wait() + c.resources.workers.Submit(sumSize, computeAll, minBlockSize).Wait() } - return gJ + return p } -// next first folds the input and E polynomials at the given verifier challenge then computes the new gⱼ. -// Thus, j <- j+1 and rⱼ = challenge. -func (c *zeroCheckClaims) next(challenge fr.Element) polynomial.Polynomial { +// roundFold folds all input and eq polynomials at the verifier challenge r. +// After this call, j ← j+1 and rⱼ = r. +func (c *zeroCheckClaims) roundFold(r fr.Element) { const minBlockSize = 512 - n := len(c.eq) / 2 + n := len(c.eqs[0]) / 2 if n < minBlockSize { - // no parallelization for i := range c.input { - c.input[i].Fold(challenge) + c.input[i].Fold(r) + } + for i := range c.eqs { + c.eqs[i].Fold(r) } - c.eq.Fold(challenge) } else { - wgs := make([]*sync.WaitGroup, len(c.input)) + wgs := make([]*sync.WaitGroup, len(c.input)+len(c.eqs)) for i := range c.input { - wgs[i] = c.manager.workers.Submit(n, c.input[i].FoldParallel(challenge), minBlockSize) + wgs[i] = c.resources.workers.Submit(n, c.input[i].FoldParallel(r), minBlockSize) + } + for i := range c.eqs { + wgs[len(c.input)+i] = c.resources.workers.Submit(n, c.eqs[i].FoldParallel(r), minBlockSize) } - c.manager.workers.Submit(n, c.eq.FoldParallel(challenge), minBlockSize).Wait() for _, wg := range wgs { wg.Wait() } } - - return c.computeGJ() -} - -func (c *zeroCheckClaims) varsNum() int { - return len(c.evaluationPoints[0]) } -func (c *zeroCheckClaims) claimsNum() int { - return len(c.claimedEvaluations) -} - -// proveFinalEval provides the values wᵢ(r₁, ..., rₙ) +// proveFinalEval provides the unique input wire values wᵢ(r₁, ..., rₙ). func (c *zeroCheckClaims) proveFinalEval(r []fr.Element) []fr.Element { - //defer the proof, return list of claims - - injection, _ := c.manager.circuit.ClaimPropagationInfo(c.wireI) // TODO @Tabaie: Instead of doing this last, we could just have fewer input in the first place; not that likely to happen with single gates, but more so with layers. - evaluations := make([]fr.Element, len(injection)) - for i, gateInputI := range injection { - wI := c.input[gateInputI] - wI.Fold(r[len(r)-1]) // We already have wᵢ(r₁, ..., rₙ₋₁, hₙ) in a table. Only one more fold required. - c.manager.add(c.getWire().Inputs[gateInputI], r, wI[0]) - evaluations[i] = wI[0] + c.resources.outgoingEvalPoints[c.levelI] = [][]fr.Element{r} + evaluations := make([]fr.Element, len(c.input)) + for i := range c.input { + c.input[i].Fold(r[len(r)-1]) + evaluations[i] = c.input[i][0] + } + for i := range c.input { + c.resources.memPool.Dump(c.input[i]) + } + for i := range c.eqs { + c.resources.memPool.Dump(c.eqs[i]) + } + for _, pool := range c.gateEvaluatorPools { + pool.dumpAll() } - - c.manager.memPool.Dump(c.claimedEvaluations, c.eq) - c.gateEvaluatorPool.dumpAll() - return evaluations } -type claimsManager struct { - claims []*zeroCheckLazyClaims - assignment WireAssignment - memPool *polynomial.Pool - workers *utils.WorkerPool - circuit Circuit -} +// eqAcc sets m to an eq table at q and then adds it to e. +// m <- m[0] · eq(q, -). +// e <- e + m +func (r *resources) eqAcc(e, m polynomial.MultiLin, q []fr.Element) { + n := len(q) -func newClaimsManager(circuit Circuit, assignment WireAssignment, o settings) (manager claimsManager) { - manager.assignment = assignment - manager.claims = make([]*zeroCheckLazyClaims, len(circuit)) - manager.memPool = o.pool - manager.workers = o.workers - manager.circuit = circuit + // At the end of each iteration, m(h₁, ..., hₙ) = m[0] · eq(q₁, ..., qᵢ₊₁, h₁, ..., hᵢ₊₁) + for i := range q { // 1-based in comments: q[i] = qᵢ₊₁ + // go through all assignments of (b₁, ..., bᵢ) ∈ {0,1}ⁱ + const threshold = 1 << 6 + k := 1 << i + if k < threshold { + for j := 0; j < k; j++ { + j0 := j << (n - i) // bᵢ₊₁ = 0 + j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 - for i := range circuit { - manager.claims[i] = &zeroCheckLazyClaims{ - wireI: i, - evaluationPoints: make([][]fr.Element, 0, circuit[i].NbClaims()), - claimedEvaluations: manager.memPool.Make(circuit[i].NbClaims()), - manager: &manager, + m[j1].Mul(&q[i], &m[j0]) // m(b₁,...,bᵢ,1) = m(b₁,...,bᵢ) · qᵢ₊₁ + m[j0].Sub(&m[j0], &m[j1]) // m(b₁,...,bᵢ,0) = m(b₁,...,bᵢ) · (1 - qᵢ₊₁) + } + } else { + r.workers.Submit(k, func(start, end int) { + for j := start; j < end; j++ { + j0 := j << (n - i) // bᵢ₊₁ = 0 + j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 + + m[j1].Mul(&q[i], &m[j0]) // m(b₁,...,bᵢ,1) = m(b₁,...,bᵢ) · qᵢ₊₁ + m[j0].Sub(&m[j0], &m[j1]) // m(b₁,...,bᵢ,0) = m(b₁,...,bᵢ) · (1 - qᵢ₊₁) + } + }, 1024).Wait() } } - return -} - -func (m *claimsManager) add(wire int, evaluationPoint []fr.Element, evaluation fr.Element) { - claim := m.claims[wire] - i := len(claim.evaluationPoints) - claim.claimedEvaluations[i] = evaluation - claim.evaluationPoints = append(claim.evaluationPoints, evaluationPoint) + r.workers.Submit(len(e), func(start, end int) { + for i := start; i < end; i++ { + e[i].Add(&e[i], &m[i]) + } + }, 512).Wait() } -func (m *claimsManager) getLazyClaim(wire int) *zeroCheckLazyClaims { - return m.claims[wire] +type resources struct { + // outgoingEvalPoints[i][k] is the k-th outgoing evaluation point (evaluation challenge) produced at schedule level i. + // outgoingEvalPoints[len(schedule)][0] holds the initial challenge (firstChallenge / rho). + // SumcheckLevels produce one point (k=0). SkipLevels pass on all their evaluation points. + outgoingEvalPoints [][][]fr.Element + nbVars int + assignment WireAssignment + memPool polynomial.Pool + workers *utils.WorkerPool + circuit Circuit + schedule constraint.GkrProvingSchedule + transcript transcript + uniqueInputIndices [][]int // uniqueInputIndices[wI][claimI]: w's unique-input index in the layer its claimI-th evaluation is coming from } -func (m *claimsManager) getClaim(wireI int) *zeroCheckClaims { - lazy := m.claims[wireI] - wire := m.circuit[wireI] - res := &zeroCheckClaims{ - wireI: wireI, - evaluationPoints: lazy.evaluationPoints, - claimedEvaluations: lazy.claimedEvaluations, - manager: m, - } - - if wire.IsInput() { - res.input = []polynomial.MultiLin{m.memPool.Clone(m.assignment[wireI])} - } else { - res.input = make([]polynomial.MultiLin, len(wire.Inputs)) - - for inputI, inputW := range wire.Inputs { - res.input[inputI] = m.memPool.Clone(m.assignment[inputW]) //will be edited later, so must be deep copied +func newResources(c Circuit, schedule constraint.GkrProvingSchedule, assignment WireAssignment, hasher hash.Hash) (resources, error) { + nbVars := assignment.NumVars() + nbInstances := assignment.NumInstances() + if 1<= 2 { + foldingCoeff = r.transcript.getChallenge() } -} -func setup(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (settings, error) { - var o settings - var err error - for _, option := range options { - option(&o) + uniqueInputs, inputIndices := r.circuit.InputMapping(level) + input := make([]polynomial.MultiLin, len(uniqueInputs)) + for i, inW := range uniqueInputs { + input[i] = r.memPool.Clone(r.assignment[inW]) } - o.nbVars = assignment.NumVars() - nbInstances := assignment.NumInstances() - if 1< 1 { + newEq := polynomial.MultiLin(r.memPool.Make(eqLength)) + aI := alpha + for k := 1; k < nbSources; k++ { + aI.Mul(&aI, &foldingCoeff) + newEq[0].Set(&aI) + r.eqAcc(groupEq, newEq, r.outgoingEvalPoints[group.ClaimSources[k].Level][group.ClaimSources[k].OutgoingClaimIndex]) + } + r.memPool.Dump(newEq) + } -func ChallengeNames(c Circuit, logNbInstances int, prefix string) []string { + var stride fr.Element + stride.Set(&foldingCoeff) + for range nbSources - 1 { + stride.Mul(&stride, &foldingCoeff) + } - // Pre-compute the size TODO: Consider not doing this and just grow the list by appending - size := logNbInstances // first challenge + eqs[levelWireI] = groupEq + levelWireI++ + alpha.Mul(&alpha, &stride) - for i := range c { - if c[i].NoProof() { // no proof, no challenge - continue - } - if c[i].NbClaims() > 1 { //fold the claims - size++ + for w := 1; w < len(group.Wires); w++ { + eqs[levelWireI] = polynomial.MultiLin(r.memPool.Make(eqLength)) + r.workers.Submit(eqLength, func(start, end int) { + for i := start; i < end; i++ { + eqs[levelWireI][i].Mul(&eqs[levelWireI-1][i], &stride) + } + }, 512).Wait() + levelWireI++ + alpha.Mul(&alpha, &stride) } - size += logNbInstances // full run of sumcheck on logNbInstances variables } - nums := make([]string, max(len(c), logNbInstances)) - for i := range nums { - nums[i] = strconv.Itoa(i) + claims := &zeroCheckClaims{ + levelI: levelI, + resources: r, + input: input, + inputIndices: inputIndices, + eqs: eqs, + gateEvaluatorPools: pools, } + return sumcheckProve(claims, &r.transcript) +} - challenges := make([]string, size) - - // output wire claims - firstChallengePrefix := prefix + "fC." - for j := 0; j < logNbInstances; j++ { - challenges[j] = firstChallengePrefix + nums[j] +func (r *resources) verifySumcheckLevel(levelI int, proof Proof) error { + level := r.schedule[levelI] + nbClaims := level.NbClaims() + var foldingCoeff fr.Element + if nbClaims >= 2 { + foldingCoeff = r.transcript.getChallenge() } - j := logNbInstances - for i := len(c) - 1; i >= 0; i-- { - if c[i].NoProof() { - continue - } - wirePrefix := prefix + "w" + nums[i] + "." - if c[i].NbClaims() > 1 { - challenges[j] = wirePrefix + "fold" - j++ - } + initialChallengeI := len(r.schedule) + claimedEvals := make(polynomial.Polynomial, 0, level.NbClaims()) - partialSumPrefix := wirePrefix + "pSP." - for k := 0; k < logNbInstances; k++ { - challenges[j] = partialSumPrefix + nums[k] - j++ + for _, group := range level.ClaimGroups() { + for _, wI := range group.Wires { + for claimI, src := range group.ClaimSources { + if src.Level == initialChallengeI { + claimedEvals = append(claimedEvals, r.assignment[wI].Evaluate(r.outgoingEvalPoints[src.Level][src.OutgoingClaimIndex], &r.memPool)) + } else { + claimedEvals = append(claimedEvals, proof[src.Level].finalEvalProof[r.schedule[src.Level].FinalEvalProofIndex(r.uniqueInputIndices[wI][claimI], src.OutgoingClaimIndex)]) + } + } } } - return challenges -} -func getFirstChallengeNames(logNbInstances int, prefix string) []string { - res := make([]string, logNbInstances) - firstChallengePrefix := prefix + "fC." - for i := 0; i < logNbInstances; i++ { - res[i] = firstChallengePrefix + strconv.Itoa(i) - } - return res -} + claimedSum := claimedEvals.Eval(&foldingCoeff) -func getChallenges(transcript *fiatshamir.Transcript, names []string) ([]fr.Element, error) { - res := make([]fr.Element, len(names)) - for i, name := range names { - if bytes, err := transcript.ComputeChallenge(name); err != nil { - return nil, err - } else if err = res[i].SetBytesCanonical(bytes); err != nil { - return nil, err - } + lazyClaims := &zeroCheckLazyClaims{ + foldingCoeff: foldingCoeff, + resources: r, + levelI: levelI, } - return res, nil + return sumcheckVerify(lazyClaims, proof[levelI], claimedSum, r.circuit.ZeroCheckDegree(level.(constraint.GkrSumcheckLevel)), &r.transcript) } // Prove consistency of the claimed assignment -func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (Proof, error) { - o, err := setup(c, assignment, transcriptSettings, options...) +func Prove(c Circuit, schedule constraint.GkrProvingSchedule, assignment WireAssignment, hasher hash.Hash) (Proof, error) { + r, err := newResources(c, schedule, assignment, hasher) if err != nil { return nil, err } - defer o.workers.Stop() + defer r.workers.Stop() - claims := newClaimsManager(c, assignment, o) + proof := make(Proof, len(schedule)) - proof := make(Proof, len(c)) - // firstChallenge called rho in the paper - var firstChallenge []fr.Element - firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) - if err != nil { - return nil, err + // Derive the initial challenge point + firstChallenge := make([]fr.Element, r.nbVars) + for j := range r.nbVars { + firstChallenge[j] = r.transcript.getChallenge() } + r.outgoingEvalPoints[len(schedule)] = [][]fr.Element{firstChallenge} - wirePrefix := o.transcriptPrefix + "w" - var baseChallenge [][]byte - for i := len(c) - 1; i >= 0; i-- { - - wire := c[i] - - if wire.IsOutput() { - claims.add(i, firstChallenge, assignment[i].Evaluate(firstChallenge, claims.memPool)) - } - - claim := claims.getClaim(i) - if wire.NoProof() { // input wires with one claim only - proof[i] = sumcheckProof{ - partialSumPolys: []polynomial.Polynomial{}, - finalEvalProof: []fr.Element{}, - } + for levelI := len(schedule) - 1; levelI >= 0; levelI-- { + if _, isSkip := r.schedule[levelI].(constraint.GkrSkipLevel); isSkip { + proof[levelI] = r.proveSkipLevel(levelI) } else { - if proof[i], err = sumcheckProve( - claim, fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), - ); err != nil { - return proof, err - } - - baseChallenge = make([][]byte, len(proof[i].finalEvalProof)) - for j := range proof[i].finalEvalProof { - baseChallenge[j] = proof[i].finalEvalProof[j].Marshal() - } + proof[levelI] = r.proveSumcheckLevel(levelI) } - // the verifier checks a single claim about input wires itself - claims.deleteClaim(i) + constraint.BindGkrFinalEvalProof(&r.transcript, proof[levelI].finalEvalProof, c.UniqueGateInputs(r.schedule[levelI]), c.IsInput, r.schedule[levelI]) } return proof, nil } -// Verify the consistency of the claimed output with the claimed input -// Unlike in Prove, the assignment argument need not be complete -func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSettings fiatshamir.Settings, options ...Option) error { - o, err := setup(c, assignment, transcriptSettings, options...) +// Verify the consistency of the claimed output with the claimed input. +// Unlike in Prove, the assignment argument need not be complete. +func Verify(c Circuit, schedule constraint.GkrProvingSchedule, assignment WireAssignment, proof Proof, hasher hash.Hash) error { + r, err := newResources(c, schedule, assignment, hasher) if err != nil { return err } - defer o.workers.Stop() + defer r.workers.Stop() - claims := newClaimsManager(c, assignment, o) - - var firstChallenge []fr.Element - firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) - if err != nil { - return err + // Derive the initial challenge point + firstChallenge := make([]fr.Element, r.nbVars) + for j := range r.nbVars { + firstChallenge[j] = r.transcript.getChallenge() } + r.outgoingEvalPoints[len(schedule)] = [][]fr.Element{firstChallenge} - wirePrefix := o.transcriptPrefix + "w" - var baseChallenge [][]byte - for i := len(c) - 1; i >= 0; i-- { - wire := c[i] - - if wire.IsOutput() { - claims.add(i, firstChallenge, assignment[i].Evaluate(firstChallenge, claims.memPool)) - } - - proofW := proof[i] - claim := claims.getLazyClaim(i) - if wire.NoProof() { // input wires with one claim only - // make sure the proof is empty - if len(proofW.finalEvalProof) != 0 || len(proofW.partialSumPolys) != 0 { - return errors.New("no proof allowed for input wire with a single claim") - } - - if wire.NbClaims() == 1 { // input wire - // simply evaluate and see if it matches - if len(claim.evaluationPoints) == 0 || len(claim.claimedEvaluations) == 0 { - return errors.New("missing input wire claim") - } - evaluation := assignment[i].Evaluate(claim.evaluationPoints[0], claims.memPool) - if !claim.claimedEvaluations[0].Equal(&evaluation) { - return errors.New("incorrect input wire claim") - } - } - } else if err = sumcheckVerify( - claim, proof[i], fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), - ); err == nil { // incorporate prover claims about w's input into the transcript - baseChallenge = make([][]byte, len(proofW.finalEvalProof)) - for j := range baseChallenge { - baseChallenge[j] = proofW.finalEvalProof[j].Marshal() - } + for levelI := len(schedule) - 1; levelI >= 0; levelI-- { + if _, isSkip := r.schedule[levelI].(constraint.GkrSkipLevel); isSkip { + err = r.verifySkipLevel(levelI, proof) } else { - return fmt.Errorf("sumcheck proof rejected: %v", err) //TODO: Any polynomials to dump? + err = r.verifySumcheckLevel(levelI, proof) + } + if err != nil { + return fmt.Errorf("level %d: %v", levelI, err) } - claims.deleteClaim(i) + constraint.BindGkrFinalEvalProof(&r.transcript, proof[levelI].finalEvalProof, c.UniqueGateInputs(r.schedule[levelI]), c.IsInput, r.schedule[levelI]) } return nil } @@ -734,14 +636,14 @@ func (p Proof) flatten() iter.Seq2[int, *fr.Element] { // It manages the stack internally and handles input buffering, making it easy to // evaluate the same gate multiple times with different inputs. type gateEvaluator struct { - gate gkrtypes.GateBytecode + gate gkrcore.GateBytecode vars []fr.Element nbIn int // number of inputs expected } // newGateEvaluator creates an evaluator for the given compiled gate. // The stack is preloaded with constants and ready for evaluation. -func newGateEvaluator(gate gkrtypes.GateBytecode, nbIn int, elementPool ...*polynomial.Pool) gateEvaluator { +func newGateEvaluator(gate gkrcore.GateBytecode, nbIn int, elementPool ...*polynomial.Pool) gateEvaluator { e := gateEvaluator{ gate: gate, nbIn: nbIn, @@ -785,28 +687,28 @@ func (e *gateEvaluator) evaluate(top ...fr.Element) *fr.Element { // Use switch instead of function pointer for better inlining switch inst.Op { - case gkrtypes.OpAdd: + case gkrcore.OpAdd: dst.Add(&e.vars[inst.Inputs[0]], &e.vars[inst.Inputs[1]]) for j := 2; j < len(inst.Inputs); j++ { dst.Add(dst, &e.vars[inst.Inputs[j]]) } - case gkrtypes.OpMul: + case gkrcore.OpMul: dst.Mul(&e.vars[inst.Inputs[0]], &e.vars[inst.Inputs[1]]) for j := 2; j < len(inst.Inputs); j++ { dst.Mul(dst, &e.vars[inst.Inputs[j]]) } - case gkrtypes.OpSub: + case gkrcore.OpSub: dst.Sub(&e.vars[inst.Inputs[0]], &e.vars[inst.Inputs[1]]) for j := 2; j < len(inst.Inputs); j++ { dst.Sub(dst, &e.vars[inst.Inputs[j]]) } - case gkrtypes.OpNeg: + case gkrcore.OpNeg: dst.Neg(&e.vars[inst.Inputs[0]]) - case gkrtypes.OpMulAcc: + case gkrcore.OpMulAcc: var prod fr.Element prod.Mul(&e.vars[inst.Inputs[1]], &e.vars[inst.Inputs[2]]) dst.Add(&e.vars[inst.Inputs[0]], &prod) - case gkrtypes.OpSumExp17: + case gkrcore.OpSumExp17: // result = (x[0] + x[1] + x[2])^17 var sum fr.Element sum.Add(&e.vars[inst.Inputs[0]], &e.vars[inst.Inputs[1]]) @@ -832,14 +734,14 @@ func (e *gateEvaluator) evaluate(top ...fr.Element) *fr.Element { // gateEvaluatorPool manages a pool of gate evaluators for a specific gate type // All evaluators share the same underlying polynomial.Pool for element slices type gateEvaluatorPool struct { - gate gkrtypes.GateBytecode + gate gkrcore.GateBytecode nbIn int lock sync.Mutex available map[*gateEvaluator]struct{} elementPool *polynomial.Pool } -func newGateEvaluatorPool(gate gkrtypes.GateBytecode, nbIn int, elementPool *polynomial.Pool) *gateEvaluatorPool { +func newGateEvaluatorPool(gate gkrcore.GateBytecode, nbIn int, elementPool *polynomial.Pool) *gateEvaluatorPool { gep := &gateEvaluatorPool{ gate: gate, nbIn: nbIn, @@ -867,7 +769,7 @@ func (gep *gateEvaluatorPool) put(e *gateEvaluator) { gep.lock.Lock() defer gep.lock.Unlock() - // Return evaluator to pool (it keeps its vars slice from polynomial pool) + // Return evaluator to pool (it keeps its vars slice from the polynomial pool) gep.available[e] = struct{}{} } diff --git a/internal/gkr/bls12-381/gkr_test.go b/internal/gkr/bls12-381/gkr_test.go index 012683657a..5f9fa833f8 100644 --- a/internal/gkr/bls12-381/gkr_test.go +++ b/internal/gkr/bls12-381/gkr_test.go @@ -11,7 +11,6 @@ import ( "os" "path/filepath" "reflect" - "strconv" "testing" "time" @@ -19,10 +18,9 @@ import ( "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/mimc" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/polynomial" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" - gcUtils "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark/constraint" + "github.com/consensys/gnark/internal/gkr/gkrcore" "github.com/consensys/gnark/internal/gkr/gkrtesting" - "github.com/consensys/gnark/internal/gkr/gkrtypes" "github.com/stretchr/testify/assert" ) @@ -69,36 +67,164 @@ func TestMimc(t *testing.T) { test(t, gkrtesting.MiMCCircuit(93)) } -func TestSumcheckFromSingleInputTwoIdentityGatesGateTwoInstances(t *testing.T) { - // Construct SerializableCircuit directly, bypassing CompileCircuit - // which would reset NbUniqueOutputs based on actual topology - circuit := gkrtypes.SerializableCircuit{ +func TestPoseidon2(t *testing.T) { + test(t, gkrtesting.Poseidon2Circuit(4, 2)) +} + +// testSumcheckLevel exercises proveSumcheckLevel/verifySumcheckLevel for a single sumcheck level. +func testSumcheckLevel(t *testing.T, circuit gkrcore.RawCircuit, level constraint.GkrProvingLevel) { + t.Helper() + _, sCircuit := cache.Compile(t, circuit) + + ins := sCircuit.Inputs() + assignment := make(WireAssignment, len(sCircuit)) + for _, i := range ins { + assignment[i] = make([]fr.Element, 2) + fr.Vector(assignment[i]).MustSetRandom() + } + + assignment.Complete(sCircuit) + + schedule := constraint.GkrProvingSchedule{level} + initEvalPoint := [][]fr.Element{{one}} + + // Prove + proveR, err := newResources(sCircuit, schedule, assignment, newMessageCounter(1, 1)) + assert.NoError(t, err) + defer proveR.workers.Stop() + + proveR.outgoingEvalPoints[len(schedule)] = initEvalPoint + proof := Proof{proveR.proveSumcheckLevel(0)} + + // Verify + verifyR, err := newResources(sCircuit, schedule, assignment, newMessageCounter(1, 1)) + assert.NoError(t, err) + defer verifyR.workers.Stop() + + verifyR.outgoingEvalPoints[len(schedule)] = initEvalPoint + assert.NoError(t, verifyR.verifySumcheckLevel(0, proof)) +} + +func TestSumcheckLevel(t *testing.T) { + // Wires 0,1 = inputs; wires 2,3,4 = mul(0,1). All gates are independent outputs. + circuit := gkrcore.RawCircuit{ + {}, + {}, + {Gate: gkrcore.Mul2, Inputs: []int{0, 1}}, + {Gate: gkrcore.Mul2, Inputs: []int{0, 1}}, + {Gate: gkrcore.Mul2, Inputs: []int{0, 1}}, + } + // Each level has an initial challenge at index 1 (len(schedule) = 1). + // GkrClaimSource{Level:1} is the initial-challenge sentinel. + tests := []struct { + name string + level constraint.GkrProvingLevel + }{ + { + name: "single wire", + level: constraint.GkrSumcheckLevel{ + {Wires: []int{4}, ClaimSources: []constraint.GkrClaimSource{{Level: 1}}}, + }, + }, + { + name: "two groups", + level: constraint.GkrSumcheckLevel{ + {Wires: []int{4}, ClaimSources: []constraint.GkrClaimSource{{Level: 1}}}, + {Wires: []int{3}, ClaimSources: []constraint.GkrClaimSource{{Level: 1}}}, + }, + }, + { + name: "one group with two wires", + level: constraint.GkrSumcheckLevel{ + {Wires: []int{4, 3}, ClaimSources: []constraint.GkrClaimSource{{Level: 1}}}, + }, + }, { - NbUniqueOutputs: 2, - Gate: gkrtypes.SerializableGate{Degree: 1}, + name: "mixed: single + multi-wire group", + level: constraint.GkrSumcheckLevel{ + {Wires: []int{4}, ClaimSources: []constraint.GkrClaimSource{{Level: 1}}}, + {Wires: []int{3, 2}, ClaimSources: []constraint.GkrClaimSource{{Level: 1}}}, + }, }, } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + testSumcheckLevel(t, circuit, tc.level) + }) + } +} + +// testSkipLevel exercises proveSkipLevel/verifySkipLevel for a single skip level. +func testSkipLevel(t *testing.T, circuit gkrcore.RawCircuit, level constraint.GkrProvingLevel) { + t.Helper() + _, sCircuit := cache.Compile(t, circuit) - assignment := WireAssignment{[]fr.Element{two, three}} - var o settings - pool := polynomial.NewPool(256, 1<<11) - workers := gcUtils.NewWorkerPool() - o.pool = &pool - o.workers = workers - - claimsManagerGen := func() *claimsManager { - manager := newClaimsManager(circuit, assignment, o) - manager.add(0, []fr.Element{three}, five) - manager.add(0, []fr.Element{four}, six) - return &manager + ins := sCircuit.Inputs() + assignment := make(WireAssignment, len(sCircuit)) + for _, i := range ins { + assignment[i] = make([]fr.Element, 2) + fr.Vector(assignment[i]).MustSetRandom() } - transcriptGen := newMessageCounterGenerator(4, 1) + assignment.Complete(sCircuit) - proof, err := sumcheckProve(claimsManagerGen().getClaim(0), fiatshamir.WithHash(transcriptGen(), nil)) + schedule := constraint.GkrProvingSchedule{level} + initEvalPoint := [][]fr.Element{{one}} + + // Prove + proveR, err := newResources(sCircuit, schedule, assignment, newMessageCounter(1, 1)) assert.NoError(t, err) - err = sumcheckVerify(claimsManagerGen().getLazyClaim(0), proof, fiatshamir.WithHash(transcriptGen(), nil)) + defer proveR.workers.Stop() + + proveR.outgoingEvalPoints[len(schedule)] = initEvalPoint + proof := Proof{proveR.proveSkipLevel(0)} + + // Verify + verifyR, err := newResources(sCircuit, schedule, assignment, newMessageCounter(1, 1)) assert.NoError(t, err) + defer verifyR.workers.Stop() + + verifyR.outgoingEvalPoints[len(schedule)] = initEvalPoint + assert.NoError(t, verifyR.verifySkipLevel(0, proof)) +} + +func TestSkipLevel(t *testing.T) { + // Wires 0,1 = inputs; wires 2,3 = identity(0); wire 4 = add(0,1). All degree-1 outputs. + circuit := gkrcore.RawCircuit{ + {}, + {}, + {Gate: gkrcore.Identity, Inputs: []int{0}}, + {Gate: gkrcore.Identity, Inputs: []int{0}}, + {Gate: gkrcore.Add2, Inputs: []int{0, 1}}, + } + + // Single-claim cases: one inherited evaluation point (OutgoingClaimIndex always 0). + singleClaim := []struct { + name string + level constraint.GkrProvingLevel + }{ + { + name: "single input wire", + level: constraint.GkrSkipLevel{Wires: []int{0}, ClaimSources: []constraint.GkrClaimSource{{Level: 1}}}, + }, + { + name: "single identity gate", + level: constraint.GkrSkipLevel{Wires: []int{2}, ClaimSources: []constraint.GkrClaimSource{{Level: 1}}}, + }, + { + name: "add gate", + level: constraint.GkrSkipLevel{Wires: []int{4}, ClaimSources: []constraint.GkrClaimSource{{Level: 1}}}, + }, + { + name: "two identity gates one group", + level: constraint.GkrSkipLevel{Wires: []int{2, 3}, ClaimSources: []constraint.GkrClaimSource{{Level: 1}}}, + }, + } + for _, tc := range singleClaim { + t.Run(tc.name, func(t *testing.T) { + testSkipLevel(t, circuit, tc.level) + }) + } } var one, two, three, four, five, six fr.Element @@ -112,31 +238,20 @@ func init() { six.Double(&three) } -var testManyInstancesLogMaxInstances = -1 - -func getLogMaxInstances(t *testing.T) int { - if testManyInstancesLogMaxInstances == -1 { - - s := os.Getenv("GKR_LOG_INSTANCES") - if s == "" { - testManyInstancesLogMaxInstances = 5 - } else { - var err error - testManyInstancesLogMaxInstances, err = strconv.Atoi(s) - if err != nil { - t.Error(err) - } - } - - } - return testManyInstancesLogMaxInstances +func test(t *testing.T, circuit gkrcore.RawCircuit) { + testWithSchedule(t, circuit, nil) } -func test(t *testing.T, circuit gkrtypes.GadgetCircuit) { - sCircuit := cache.Compile(t, circuit) - ins := circuit.Inputs() +func testWithSchedule(t *testing.T, circuit gkrcore.RawCircuit, schedule constraint.GkrProvingSchedule) { + gCircuit, sCircuit := cache.Compile(t, circuit) + if schedule == nil { + var err error + schedule, err = gkrcore.DefaultProvingSchedule(sCircuit) + assert.NoError(t, err) + } + ins := gCircuit.Inputs() insAssignment := make(WireAssignment, len(ins)) - maxSize := 1 << getLogMaxInstances(t) + maxSize := 1 << gkrtesting.GetLogMaxInstances(t) for i := range ins { insAssignment[i] = make([]fr.Element, maxSize) @@ -151,51 +266,33 @@ func test(t *testing.T, circuit gkrtypes.GadgetCircuit) { fullAssignment.Complete(sCircuit) - t.Log("Selected inputs for test") - - proof, err := Prove(sCircuit, fullAssignment, fiatshamir.WithHash(newMessageCounter(1, 1))) + proof, err := Prove(sCircuit, schedule, fullAssignment, newMessageCounter(1, 1)) assert.NoError(t, err) // Even though a hash is called here, the proof is empty - err = Verify(sCircuit, fullAssignment, proof, fiatshamir.WithHash(newMessageCounter(1, 1))) + err = Verify(sCircuit, schedule, fullAssignment, proof, newMessageCounter(1, 1)) assert.NoError(t, err, "proof rejected") - if proof.isEmpty() { // special case for TestNoGate: - continue // there's no way to make a trivial proof fail - } - - err = Verify(sCircuit, fullAssignment, proof, fiatshamir.WithHash(newMessageCounter(0, 1))) + err = Verify(sCircuit, schedule, fullAssignment, proof, newMessageCounter(0, 1)) assert.NotNil(t, err, "bad proof accepted") } - -} - -func (p Proof) isEmpty() bool { - for i := range p { - if len(p[i].finalEvalProof) != 0 { - return false - } - for j := range p[i].partialSumPolys { - if len(p[i].partialSumPolys[j]) != 0 { - return false - } - } - } - return true } func testNoGate(t *testing.T, inputAssignments ...[]fr.Element) { - c := cache.Compile(t, gkrtesting.NoGateCircuit()) + _, c := cache.Compile(t, gkrtesting.NoGateCircuit()) + + schedule, err := gkrcore.DefaultProvingSchedule(c) + assert.NoError(t, err) assignment := WireAssignment{0: inputAssignments[0]} - proof, err := Prove(c, assignment, fiatshamir.WithHash(newMessageCounter(1, 1))) + proof, err := Prove(c, schedule, assignment, newMessageCounter(1, 1)) assert.NoError(t, err) // Even though a hash is called here, the proof is empty - err = Verify(c, assignment, proof, fiatshamir.WithHash(newMessageCounter(1, 1))) + err = Verify(c, schedule, assignment, proof, newMessageCounter(1, 1)) assert.NoError(t, err, "proof rejected") } @@ -203,7 +300,7 @@ func generateTestProver(path string) func(t *testing.T) { return func(t *testing.T) { testCase, err := newTestCase(path) assert.NoError(t, err) - proof, err := Prove(testCase.Circuit, testCase.FullAssignment, fiatshamir.WithHash(testCase.Hash)) + proof, err := Prove(testCase.Circuit, testCase.Schedule, testCase.FullAssignment, testCase.Hash) assert.NoError(t, err) assert.NoError(t, proofEquals(testCase.Proof, proof)) } @@ -213,17 +310,29 @@ func generateTestVerifier(path string) func(t *testing.T) { return func(t *testing.T) { testCase, err := newTestCase(path) assert.NoError(t, err) - err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(testCase.Hash)) + err = Verify(testCase.Circuit, testCase.Schedule, testCase.InOutAssignment, testCase.Proof, testCase.Hash) assert.NoError(t, err, "proof rejected") + testCase, err = newTestCase(path) assert.NoError(t, err) - err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(newMessageCounter(2, 0))) + err = Verify(testCase.Circuit, testCase.Schedule, testCase.InOutAssignment, testCase.Proof, newMessageCounter(2, 0)) assert.NotNil(t, err, "bad proof accepted") + + testCase, err = newTestCase(path) + assert.NoError(t, err) + testCase.InOutAssignment[len(testCase.InOutAssignment)-1][0].Add(&testCase.InOutAssignment[len(testCase.InOutAssignment)-1][0], &one) + err = Verify(testCase.Circuit, testCase.Schedule, testCase.InOutAssignment, testCase.Proof, testCase.Hash) + assert.NotNil(t, err, "tampered output accepted") + + testCase, err = newTestCase(path) + assert.NoError(t, err) + testCase.InOutAssignment[0][0].Add(&testCase.InOutAssignment[0][0], &one) + err = Verify(testCase.Circuit, testCase.Schedule, testCase.InOutAssignment, testCase.Proof, testCase.Hash) + assert.NotNil(t, err, "tampered input accepted") } } func TestGkrVectors(t *testing.T) { - const testDirPath = "../test_vectors/" dirEntries, err := os.ReadDir(testDirPath) assert.NoError(t, err) @@ -267,7 +376,10 @@ func proofEquals(expected Proof, seen Proof) error { func benchmarkGkrMiMC(b *testing.B, nbInstances, mimcDepth int) { fmt.Println("creating circuit structure") - c := cache.Compile(b, gkrtesting.MiMCCircuit(mimcDepth)) + _, c := cache.Compile(b, gkrtesting.MiMCCircuit(mimcDepth)) + + schedule, err := gkrcore.DefaultProvingSchedule(c) + assert.NoError(b, err) in0 := make([]fr.Element, nbInstances) in1 := make([]fr.Element, nbInstances) @@ -283,12 +395,30 @@ func benchmarkGkrMiMC(b *testing.B, nbInstances, mimcDepth int) { //b.ResetTimer() fmt.Println("constructing proof") start = time.Now().UnixMicro() - _, err := Prove(c, assignment, fiatshamir.WithHash(mimc.NewMiMC())) + _, err = Prove(c, schedule, assignment, mimc.NewMiMC()) proved := time.Now().UnixMicro() - start fmt.Println("proved in", proved, "μs") assert.NoError(b, err) } +// TestSingleMulGateExplicitSchedule tests a single mul gate with an explicit single-step schedule, +// equivalent to the default but constructed manually to exercise the schedule path. +func TestSingleMulGateExplicitSchedule(t *testing.T) { + circuit := gkrtesting.SingleMulGateCircuit() + _, sCircuit := cache.Compile(t, circuit) + + // Wire 2 is the mul gate output (inputs: 0, 1). + // Explicit schedule: one GkrProvingLevel for wire 2. + // GkrClaimSource{Level:1} is the initial-challenge sentinel (len(schedule)=1). + schedule := constraint.GkrProvingSchedule{ + constraint.GkrSumcheckLevel{ + {Wires: []int{2}, ClaimSources: []constraint.GkrClaimSource{{Level: 1}}}, + }, + } + testWithSchedule(t, circuit, schedule) + _ = sCircuit +} + func BenchmarkGkrMimc19(b *testing.B) { benchmarkGkrMiMC(b, 1<<19, 91) } @@ -327,11 +457,12 @@ func unmarshalProof(printable gkrtesting.PrintableProof) (Proof, error) { } type TestCase struct { - Circuit gkrtypes.SerializableCircuit + Circuit gkrcore.SerializableCircuit Hash hash.Hash Proof Proof FullAssignment WireAssignment InOutAssignment WireAssignment + Schedule constraint.GkrProvingSchedule } var testCases = make(map[string]*TestCase) @@ -362,6 +493,20 @@ func newTestCase(path string) (*TestCase, error) { if proof, err = unmarshalProof(info.Proof); err != nil { return nil, err } + var schedule constraint.GkrProvingSchedule + if schedule, err = info.Schedule.ToProvingSchedule(); err != nil { + return nil, err + } + if schedule == nil { + if schedule, err = gkrcore.DefaultProvingSchedule(circuit); err != nil { + return nil, err + } + } + + outputSet := make(map[int]bool, len(circuit)) + for _, o := range circuit.Outputs() { + outputSet[o] = true + } fullAssignment := make(WireAssignment, len(circuit)) inOutAssignment := make(WireAssignment, len(circuit)) @@ -375,7 +520,7 @@ func newTestCase(path string) (*TestCase, error) { } assignmentRaw = info.Input[inI] inI++ - } else if circuit[i].IsOutput() { + } else if outputSet[i] { if outI == len(info.Output) { return nil, fmt.Errorf("fewer output in vector than in circuit") } @@ -396,7 +541,7 @@ func newTestCase(path string) (*TestCase, error) { fullAssignment.Complete(circuit) for i := range circuit { - if circuit[i].IsOutput() { + if outputSet[i] { if err = sliceEquals(inOutAssignment[i], fullAssignment[i]); err != nil { return nil, fmt.Errorf("assignment mismatch: %v", err) } @@ -409,6 +554,7 @@ func newTestCase(path string) (*TestCase, error) { Proof: proof, Hash: _hash, Circuit: circuit, + Schedule: schedule, } testCases[path] = tCase diff --git a/internal/gkr/bls12-381/sumcheck.go b/internal/gkr/bls12-381/sumcheck.go index 4ff812a8ce..0b6196a393 100644 --- a/internal/gkr/bls12-381/sumcheck.go +++ b/internal/gkr/bls12-381/sumcheck.go @@ -7,33 +7,62 @@ package gkr import ( "errors" - "strconv" + "hash" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/polynomial" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" ) -// This does not make use of parallelism and represents polynomials as lists of coefficients -// It is currently geared towards arithmetic hashes. Once we have a more unified hash function interface, this can be generified. +// This does not make use of parallelism and represents polynomials as lists of coefficients. + +// transcript is a Fiat-Shamir transcript backed by a running hash. +// Field elements are written via Bind; challenges are derived via getChallenge. +// The hash is never reset — all previous data is implicitly part of future challenges. +type transcript struct { + h hash.Hash + bound bool // whether Bind was called since the last getChallenge +} + +// Bind writes field elements to the transcript as bindings for the next challenge. +func (t *transcript) Bind(elements ...fr.Element) { + if len(elements) == 0 { + return + } + for i := range elements { + bytes := elements[i].Bytes() + t.h.Write(bytes[:]) + } + t.bound = true +} + +// getChallenge binds optional elements, then squeezes a challenge from the current hash state. +// If no bindings were added since the last squeeze, a separator byte is written first +// to advance the state and prevent repeated values. +func (t *transcript) getChallenge(bindings ...fr.Element) fr.Element { + t.Bind(bindings...) + if !t.bound { + t.h.Write([]byte{0}) + } + t.bound = false + var res fr.Element + res.SetBytes(t.h.Sum(nil)) + return res +} // sumcheckClaims to a multi-sumcheck statement. i.e. one of the form ∑_{0≤i<2ⁿ} fⱼ(i) = cⱼ for 1 ≤ j ≤ m. // Later evolving into a claim of the form gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) type sumcheckClaims interface { - fold(a fr.Element) polynomial.Polynomial // fold into the 0ᵗʰ sumcheck subclaim. Create g := ∑_{1≤j≤m} aʲ⁻¹fⱼ for which now we seek to prove ∑_{0≤i<2ⁿ} g(i) = c := ∑_{1≤j≤m} aʲ⁻¹cⱼ. Return g₁. - next(fr.Element) polynomial.Polynomial // Return the evaluations gⱼ(k) for 1 ≤ k < degⱼ(g). Update the claim to gⱼ₊₁ for the input value as rⱼ + roundPolynomial() polynomial.Polynomial // compute gⱼ polynomial for current round + roundFold(r fr.Element) // fold inputs and eq at challenge r varsNum() int // number of variables - claimsNum() int // number of claims proveFinalEval(r []fr.Element) []fr.Element // in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof } // sumcheckLazyClaims is the sumcheckClaims data structure on the verifier side. It is "lazy" in that it has to compute fewer things. type sumcheckLazyClaims interface { - claimsNum() int // claimsNum = m - varsNum() int // varsNum = n - foldedSum(a fr.Element) fr.Element // foldedSum returns c = ∑_{1≤j≤m} aʲ⁻¹cⱼ - degree(i int) int // degree of the total claim in the i'th variable - verifyFinalEval(r []fr.Element, foldingCoeff fr.Element, purportedValue fr.Element, proof []fr.Element) error + varsNum() int // varsNum = n + degree(i int) int // degree of the total claim in the i'th variable + verifyFinalEval(r []fr.Element, purportedValue fr.Element, proof []fr.Element) error } // sumcheckProof of a multi-statement. @@ -42,130 +71,46 @@ type sumcheckProof struct { finalEvalProof []fr.Element //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof } -func setupTranscript(claimsNum int, varsNum int, settings *fiatshamir.Settings) (challengeNames []string, err error) { - numChallenges := varsNum - if claimsNum >= 2 { - numChallenges++ - } - challengeNames = make([]string, numChallenges) - if claimsNum >= 2 { - challengeNames[0] = settings.Prefix + "fold" - } - prefix := settings.Prefix + "pSP." - for i := 0; i < varsNum; i++ { - challengeNames[i+numChallenges-varsNum] = prefix + strconv.Itoa(i) - } - if settings.Transcript == nil { - transcript := fiatshamir.NewTranscript(settings.Hash, challengeNames...) - settings.Transcript = transcript - } - - for i := range settings.BaseChallenges { - if err = settings.Transcript.Bind(challengeNames[0], settings.BaseChallenges[i]); err != nil { - return - } - } - return -} - -func next(transcript *fiatshamir.Transcript, bindings []fr.Element, remainingChallengeNames *[]string) (fr.Element, error) { - challengeName := (*remainingChallengeNames)[0] - for i := range bindings { - bytes := bindings[i].Bytes() - if err := transcript.Bind(challengeName, bytes[:]); err != nil { - return fr.Element{}, err - } - } - var res fr.Element - bytes, err := transcript.ComputeChallenge(challengeName) - res.SetBytes(bytes) - - *remainingChallengeNames = (*remainingChallengeNames)[1:] - - return res, err -} - -// sumcheckProve create a non-interactive proof -func sumcheckProve(claims sumcheckClaims, transcriptSettings fiatshamir.Settings) (sumcheckProof, error) { - - var proof sumcheckProof - remainingChallengeNames, err := setupTranscript(claims.claimsNum(), claims.varsNum(), &transcriptSettings) - transcript := transcriptSettings.Transcript - if err != nil { - return proof, err - } - - var foldingCoeff fr.Element - if claims.claimsNum() >= 2 { - if foldingCoeff, err = next(transcript, []fr.Element{}, &remainingChallengeNames); err != nil { - return proof, err - } - } - +// sumcheckProve creates a non-interactive sumcheck proof. +// The fold challenge is derived by the caller (proveLevel). +// Pattern: roundPolynomial, [roundFold, roundPolynomial]*, proveFinalEval. +func sumcheckProve(claims sumcheckClaims, t *transcript) sumcheckProof { varsNum := claims.varsNum() - proof.partialSumPolys = make([]polynomial.Polynomial, varsNum) - proof.partialSumPolys[0] = claims.fold(foldingCoeff) + proof := sumcheckProof{partialSumPolys: make([]polynomial.Polynomial, varsNum)} + proof.partialSumPolys[0] = claims.roundPolynomial() challenges := make([]fr.Element, varsNum) - for j := 0; j+1 < varsNum; j++ { - if challenges[j], err = next(transcript, proof.partialSumPolys[j], &remainingChallengeNames); err != nil { - return proof, err - } - proof.partialSumPolys[j+1] = claims.next(challenges[j]) - } - - if challenges[varsNum-1], err = next(transcript, proof.partialSumPolys[varsNum-1], &remainingChallengeNames); err != nil { - return proof, err + for j := range varsNum - 1 { + challenges[j] = t.getChallenge(proof.partialSumPolys[j]...) + claims.roundFold(challenges[j]) + proof.partialSumPolys[j+1] = claims.roundPolynomial() } + challenges[varsNum-1] = t.getChallenge(proof.partialSumPolys[varsNum-1]...) proof.finalEvalProof = claims.proveFinalEval(challenges) - - return proof, nil + return proof } -func sumcheckVerify(claims sumcheckLazyClaims, proof sumcheckProof, transcriptSettings fiatshamir.Settings) error { - remainingChallengeNames, err := setupTranscript(claims.claimsNum(), claims.varsNum(), &transcriptSettings) - transcript := transcriptSettings.Transcript - if err != nil { - return err - } - - var foldingCoeff fr.Element - - if claims.claimsNum() >= 2 { - if foldingCoeff, err = next(transcript, []fr.Element{}, &remainingChallengeNames); err != nil { - return err - } - } - +// sumcheckVerify verifies a non-interactive sumcheck proof. +// The fold challenge is derived by the caller (verifyLevel). +// claimedSum is the expected sum; degree is the polynomial's degree in each variable. +func sumcheckVerify(claims sumcheckLazyClaims, proof sumcheckProof, claimedSum fr.Element, degree int, t *transcript) error { r := make([]fr.Element, claims.varsNum()) - // Just so that there is enough room for gJ to be reused - maxDegree := claims.degree(0) - for j := 1; j < claims.varsNum(); j++ { - if d := claims.degree(j); d > maxDegree { - maxDegree = d - } - } - gJ := make(polynomial.Polynomial, maxDegree+1) //At the end of iteration j, gJ = ∑_{i < 2ⁿ⁻ʲ⁻¹} g(X₁, ..., Xⱼ₊₁, i...) NOTE: n is shorthand for claims.varsNum() - gJR := claims.foldedSum(foldingCoeff) // At the beginning of iteration j, gJR = ∑_{i < 2ⁿ⁻ʲ} g(r₁, ..., rⱼ, i...) + gJ := make(polynomial.Polynomial, degree+1) + gJR := claimedSum for j := range claims.varsNum() { - if len(proof.partialSumPolys[j]) != claims.degree(j) { + if len(proof.partialSumPolys[j]) != degree { return errors.New("malformed proof") } copy(gJ[1:], proof.partialSumPolys[j]) - gJ[0].Sub(&gJR, &proof.partialSumPolys[j][0]) // Requirement that gⱼ(0) + gⱼ(1) = gⱼ₋₁(r) - // gJ is ready + gJ[0].Sub(&gJR, &proof.partialSumPolys[j][0]) - //Prepare for the next iteration - if r[j], err = next(transcript, proof.partialSumPolys[j], &remainingChallengeNames); err != nil { - return err - } - // This is an extremely inefficient way of interpolating. TODO: Interpolate without symbolically computing a polynomial - gJCoeffs := polynomial.InterpolateOnRange(gJ[:(claims.degree(j) + 1)]) + r[j] = t.getChallenge(proof.partialSumPolys[j]...) + gJCoeffs := polynomial.InterpolateOnRange(gJ[:(degree + 1)]) gJR = gJCoeffs.Eval(&r[j]) } - return claims.verifyFinalEval(r, foldingCoeff, gJR, proof.finalEvalProof) + return claims.verifyFinalEval(r, gJR, proof.finalEvalProof) } diff --git a/internal/gkr/bls12-381/sumcheck_test.go b/internal/gkr/bls12-381/sumcheck_test.go index 70574588a9..5f8a12fc5a 100644 --- a/internal/gkr/bls12-381/sumcheck_test.go +++ b/internal/gkr/bls12-381/sumcheck_test.go @@ -10,7 +10,6 @@ import ( "hash" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/polynomial" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" "github.com/stretchr/testify/assert" "math/bits" @@ -28,11 +27,9 @@ func testSumcheckSingleClaimMultilin(polyInt []uint64, hashGenerator func() hash } claim := singleMultilinClaim{g: poly.Clone()} + t := transcript{h: hashGenerator()} - proof, err := sumcheckProve(&claim, fiatshamir.WithHash(hashGenerator())) - if err != nil { - return err - } + proof := sumcheckProve(&claim, &t) var sb strings.Builder for _, p := range proof.partialSumPolys { @@ -48,13 +45,15 @@ func testSumcheckSingleClaimMultilin(polyInt []uint64, hashGenerator func() hash } lazyClaim := singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} - if err = sumcheckVerify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())); err != nil { + t = transcript{h: hashGenerator()} + if err := sumcheckVerify(lazyClaim, proof, lazyClaim.claimedSum, 1, &t); err != nil { return err } proof.partialSumPolys[0][0].Add(&proof.partialSumPolys[0][0], toElement(1)) lazyClaim = singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} - if sumcheckVerify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())) == nil { + t = transcript{h: hashGenerator()} + if sumcheckVerify(lazyClaim, proof, lazyClaim.claimedSum, 1, &t) == nil { return fmt.Errorf("bad proof accepted") } return nil @@ -93,18 +92,14 @@ type singleMultilinClaim struct { g polynomial.MultiLin } -func (c singleMultilinClaim) proveFinalEval(r []fr.Element) []fr.Element { +func (c *singleMultilinClaim) proveFinalEval(r []fr.Element) []fr.Element { return nil // verifier can compute the final eval itself } -func (c singleMultilinClaim) varsNum() int { +func (c *singleMultilinClaim) varsNum() int { return bits.TrailingZeros(uint(len(c.g))) } -func (c singleMultilinClaim) claimsNum() int { - return 1 -} - func sumForX1One(g polynomial.MultiLin) polynomial.Polynomial { sum := g[len(g)/2] for i := len(g)/2 + 1; i < len(g); i++ { @@ -113,13 +108,12 @@ func sumForX1One(g polynomial.MultiLin) polynomial.Polynomial { return []fr.Element{sum} } -func (c singleMultilinClaim) fold(fr.Element) polynomial.Polynomial { +func (c *singleMultilinClaim) roundPolynomial() polynomial.Polynomial { return sumForX1One(c.g) } -func (c *singleMultilinClaim) next(r fr.Element) polynomial.Polynomial { +func (c *singleMultilinClaim) roundFold(r fr.Element) { c.g.Fold(r) - return sumForX1One(c.g) } type singleMultilinLazyClaim struct { @@ -127,7 +121,7 @@ type singleMultilinLazyClaim struct { claimedSum fr.Element } -func (c singleMultilinLazyClaim) verifyFinalEval(r []fr.Element, _ fr.Element, purportedValue fr.Element, proof []fr.Element) error { +func (c singleMultilinLazyClaim) verifyFinalEval(r []fr.Element, purportedValue fr.Element, proof []fr.Element) error { val := c.g.Evaluate(r, nil) if val.Equal(&purportedValue) { return nil @@ -135,15 +129,7 @@ func (c singleMultilinLazyClaim) verifyFinalEval(r []fr.Element, _ fr.Element, p return fmt.Errorf("mismatch") } -func (c singleMultilinLazyClaim) foldedSum(_ fr.Element) fr.Element { - return c.claimedSum -} - -func (c singleMultilinLazyClaim) degree(i int) int { - return 1 -} - -func (c singleMultilinLazyClaim) claimsNum() int { +func (c singleMultilinLazyClaim) degree(int) int { return 1 } diff --git a/internal/gkr/bn254/blueprint.go b/internal/gkr/bn254/blueprint.go index eefb3dee8b..a9110c5b30 100644 --- a/internal/gkr/bn254/blueprint.go +++ b/internal/gkr/bn254/blueprint.go @@ -15,10 +15,9 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bn254/fr" "github.com/consensys/gnark-crypto/ecc/bn254/fr/polynomial" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" "github.com/consensys/gnark-crypto/hash" "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/internal/gkr/gkrtypes" + "github.com/consensys/gnark/internal/gkr/gkrcore" ) func init() { @@ -34,7 +33,7 @@ type circuitEvaluator struct { // BlueprintSolve is a BN254-specific blueprint for solving GKR circuit instances. type BlueprintSolve struct { // Circuit structure (serialized) - Circuit gkrtypes.SerializableCircuit + Circuit gkrcore.SerializableCircuit NbInstances uint32 // Not serialized - recreated lazily at solve time @@ -204,6 +203,7 @@ func (b *BlueprintSolve) UpdateInstructionTree(inst constraint.Instruction, tree type BlueprintProve struct { SolveBlueprintID constraint.BlueprintID SolveBlueprint *BlueprintSolve `cbor:"-"` // not serialized, set at compile time + Schedule constraint.GkrProvingSchedule HashName string lock sync.Mutex @@ -256,9 +256,11 @@ func (b *BlueprintProve) Solve(s constraint.Solver[constraint.U64], inst constra } } + // Create hasher and write base challenges + hsh := hash.NewHash(b.HashName + "_BN254") + // Read initial challenges from instruction calldata (parse dynamically, no metadata) // Format: [0]=totalSize, [1...]=challenge linear expressions - insBytes := make([][]byte, 0) // first challenges calldata := inst.Calldata[1:] // skip size prefix for len(calldata) != 0 { val, delta := s.Read(calldata) @@ -267,17 +269,14 @@ func (b *BlueprintProve) Solve(s constraint.Solver[constraint.U64], inst constra // Copy directly from constraint.U64 to fr.Element (both in Montgomery form) var challenge fr.Element copy(challenge[:], val[:]) - insBytes = append(insBytes, challenge.Marshal()) + challengeBytes := challenge.Bytes() + hsh.Write(challengeBytes[:]) } - // Create Fiat-Shamir settings - hsh := hash.NewHash(b.HashName + "_BN254") - fsSettings := fiatshamir.WithHash(hsh, insBytes...) - // Call the BN254-specific Prove function (assignments already WireAssignment type) - proof, err := Prove(solveBlueprint.Circuit, assignments, fsSettings) + proof, err := Prove(solveBlueprint.Circuit, b.Schedule, assignments, hsh) if err != nil { - return fmt.Errorf("bn254 prove failed: %w", err) + return fmt.Errorf("BN254 prove failed: %w", err) } for i, elem := range proof.flatten() { @@ -305,7 +304,7 @@ func (b *BlueprintProve) proofSize() int { } nbPaddedInstances := ecc.NextPowerOfTwo(uint64(b.SolveBlueprint.NbInstances)) logNbInstances := bits.TrailingZeros64(nbPaddedInstances) - return b.SolveBlueprint.Circuit.ProofSize(logNbInstances) + return b.SolveBlueprint.Circuit.ProofSize(b.Schedule, logNbInstances) } // NbOutputs implements Blueprint @@ -434,7 +433,7 @@ func (b *BlueprintGetAssignment) UpdateInstructionTree(inst constraint.Instructi } // NewBlueprints creates and registers all GKR blueprints for BN254 -func NewBlueprints(circuit gkrtypes.SerializableCircuit, hashName string, compiler constraint.CustomizableSystem) gkrtypes.Blueprints { +func NewBlueprints(circuit gkrcore.SerializableCircuit, schedule constraint.GkrProvingSchedule, hashName string, compiler constraint.CustomizableSystem) gkrcore.Blueprints { // Create and register solve blueprint solve := &BlueprintSolve{Circuit: circuit} solveID := compiler.AddBlueprint(solve) @@ -443,6 +442,7 @@ func NewBlueprints(circuit gkrtypes.SerializableCircuit, hashName string, compil prove := &BlueprintProve{ SolveBlueprintID: solveID, SolveBlueprint: solve, + Schedule: schedule, HashName: hashName, } proveID := compiler.AddBlueprint(prove) @@ -453,7 +453,7 @@ func NewBlueprints(circuit gkrtypes.SerializableCircuit, hashName string, compil } getAssignmentID := compiler.AddBlueprint(getAssignment) - return gkrtypes.Blueprints{ + return gkrcore.Blueprints{ SolveID: solveID, Solve: solve, ProveID: proveID, diff --git a/internal/gkr/bn254/gkr.go b/internal/gkr/bn254/gkr.go index 1658dced5d..617b104f8e 100644 --- a/internal/gkr/bn254/gkr.go +++ b/internal/gkr/bn254/gkr.go @@ -8,655 +8,557 @@ package gkr import ( "errors" "fmt" + "hash" "iter" - "strconv" "sync" "github.com/consensys/gnark-crypto/ecc/bn254/fr" "github.com/consensys/gnark-crypto/ecc/bn254/fr/polynomial" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" "github.com/consensys/gnark-crypto/utils" - "github.com/consensys/gnark/internal/gkr/gkrtypes" + "github.com/consensys/gnark/constraint" + "github.com/consensys/gnark/internal/gkr/gkrcore" ) // Type aliases for bytecode-based GKR types type ( - Wire = gkrtypes.SerializableWire - Circuit = gkrtypes.SerializableCircuit + Wire = gkrcore.SerializableWire + Circuit = gkrcore.SerializableCircuit ) // The goal is to prove/verify evaluations of many instances of the same circuit -// WireAssignment is assignment of values to the same wire across many instances of the circuit +// WireAssignment is the assignment of values to the same wire across many instances of the circuit type WireAssignment []polynomial.MultiLin type Proof []sumcheckProof // for each layer, for each wire, a sumcheck (for each variable, a polynomial) // zeroCheckLazyClaims is a lazy claim for sumcheck (verifier side). -// It checks that the polynomial ∑ᵢ cⁱ eq(-, xᵢ) w(-) sums up to the expected multilinear -// extension of the values of w across all instances. -// Its purpose is to batch the checking of multiple evaluations of the same wire. +// It checks that the polynomial ∑ᵢ cⁱ eq(-, xᵢ) wᵢ(-) sums to the expected value, +// where the sum runs over all wᵢ and evaluation point xᵢ in the level. +// Its purpose is to batch the checking of multiple wire evaluations at evaluation points. type zeroCheckLazyClaims struct { - wireI int // the wire for which we are making the claim, with value w - evaluationPoints [][]fr.Element // xᵢ: the points at which the prover has made claims about the evaluation of w - claimedEvaluations []fr.Element // yᵢ = w(xᵢ), allegedly - manager *claimsManager // WARNING: Circular references -} - -func (e *zeroCheckLazyClaims) getWire() Wire { - return e.manager.circuit[e.wireI] -} - -func (e *zeroCheckLazyClaims) claimsNum() int { - return len(e.evaluationPoints) + foldingCoeff fr.Element // the coefficient used to fold claims, conventionally 0 if there is only one claim + resources *resources + levelI int } func (e *zeroCheckLazyClaims) varsNum() int { - return len(e.evaluationPoints[0]) -} - -// foldedSum returns ∑ᵢ aⁱ yᵢ -func (e *zeroCheckLazyClaims) foldedSum(a fr.Element) fr.Element { - evalsAsPoly := polynomial.Polynomial(e.claimedEvaluations) - return evalsAsPoly.Eval(&a) + return e.resources.nbVars } func (e *zeroCheckLazyClaims) degree(int) int { - return e.manager.circuit[e.wireI].ZeroCheckDegree() -} - -// verifyFinalEval finalizes the verification of w. -// The prover's claims w(xᵢ) = yᵢ have already been reduced to verifying -// ∑ cⁱ eq(xᵢ, r) w(r) = purportedValue. (c is foldingCoeff) -// Both purportedValue and the vector r have been randomized during the sumcheck protocol. -// By taking the w term out of the sum we get the equivalent claim that -// for E := ∑ eq(xᵢ, r), it must be that E w(r) = purportedValue. -// If w is an input wire, the verifier can directly check its evaluation at r. -// Otherwise, the prover makes claims about the evaluation of w's input wires, -// wᵢ, at r, to be verified later. -// The claims are communicated through the proof parameter. -// The verifier checks here if the claimed evaluations of wᵢ(r) are consistent with -// the main claim, by checking E w(wᵢ(r)...) = purportedValue. -func (e *zeroCheckLazyClaims) verifyFinalEval(r []fr.Element, foldingCoeff, purportedValue fr.Element, uniqueInputEvaluations []fr.Element) error { - // the eq terms ( E ) - numClaims := len(e.evaluationPoints) - evaluation := polynomial.EvalEq(e.evaluationPoints[numClaims-1], r) - for i := numClaims - 2; i >= 0; i-- { - evaluation.Mul(&evaluation, &foldingCoeff) - eq := polynomial.EvalEq(e.evaluationPoints[i], r) - evaluation.Add(&evaluation, &eq) - } - - wire := e.manager.circuit[e.wireI] - - // the w(...) term - var gateEvaluation fr.Element - if wire.IsInput() { // just compute w(r) - gateEvaluation = e.manager.assignment[e.wireI].Evaluate(r, e.manager.memPool) - } else { // proof contains the evaluations of the inputs, but avoids repetition in case multiple inputs come from the same wire - injection, injectionLeftInv := - e.manager.circuit.ClaimPropagationInfo(e.wireI) - - if len(injection) != len(uniqueInputEvaluations) { - return fmt.Errorf("%d input wire evaluations given, %d expected", len(uniqueInputEvaluations), len(injection)) - } - - for uniqueI, i := range injection { // map from unique to all - e.manager.add(wire.Inputs[i], r, uniqueInputEvaluations[uniqueI]) - } + return e.resources.circuit.ZeroCheckDegree(e.resources.schedule[e.levelI].(constraint.GkrSumcheckLevel)) +} + +// verifyFinalEval finalizes the verification of a level at the sumcheck evaluation point r. +// The sumcheck protocol has already reduced the per-wire claims w(xᵢ) = yᵢ to verifying +// ∑ᵢ cⁱ eq(xᵢ, r) · wᵢ(r) = purportedValue, where the sum runs over all +// claims on each wire and c is foldingCoeff. +// Both purportedValue and the vector r have been randomized during sumcheck. +// +// For input wires, w(r) is computed directly from the assignment and the claimed +// evaluation in uniqueInputEvaluations is checked equal to it. +// For non-input wires, the prover claims evaluations of their gate inputs at r via +// uniqueInputEvaluations; those claims are verified by lower levels' sumchecks. +// The verifier checks consistency by evaluating gateᵥ(inputEvals...) and confirming +// that the full sum matches purportedValue. +func (e *zeroCheckLazyClaims) verifyFinalEval(r []fr.Element, purportedValue fr.Element, uniqueInputEvaluations []fr.Element) error { + e.resources.outgoingEvalPoints[e.levelI] = [][]fr.Element{r} + level := e.resources.schedule[e.levelI] + gateInputEvals := gkrcore.ReduplicateInputs(level, e.resources.circuit, uniqueInputEvaluations) + + var claimedEvals polynomial.Polynomial + levelWireI := 0 + for _, group := range level.ClaimGroups() { + for _, wI := range group.Wires { + wire := e.resources.circuit[wI] + + var gateEval fr.Element + if wire.IsInput() { + gateEval = e.resources.assignment[wI].Evaluate(r, &e.resources.memPool) + if !gateInputEvals[levelWireI][0].Equal(&gateEval) { + return errors.New("incompatible evaluations") + } + } else { + evaluator := newGateEvaluator(wire.Gate.Evaluate, len(wire.Inputs)) + for _, v := range gateInputEvals[levelWireI] { + evaluator.pushInput(v) + } + gateEval.Set(evaluator.evaluate()) + } - evaluator := newGateEvaluator(wire.Gate.Evaluate, len(wire.Inputs)) - for _, uniqueI := range injectionLeftInv { // map from all to unique - evaluator.pushInput(uniqueInputEvaluations[uniqueI]) + for _, src := range group.ClaimSources { + eq := polynomial.EvalEq(e.resources.outgoingEvalPoints[src.Level][src.OutgoingClaimIndex], r) + var term fr.Element + term.Mul(&eq, &gateEval) + claimedEvals = append(claimedEvals, term) + } + levelWireI++ } - - gateEvaluation.Set(evaluator.evaluate()) } - evaluation.Mul(&evaluation, &gateEvaluation) - - if evaluation.Equal(&purportedValue) { - return nil + if total := claimedEvals.Eval(&e.foldingCoeff); !total.Equal(&purportedValue) { + return errors.New("incompatible evaluations") } - return errors.New("incompatible evaluations") + return nil } // zeroCheckClaims is a claim for sumcheck (prover side). -// It checks that the polynomial ∑ᵢ cⁱ eq(-, xᵢ) w(-) sums up to the expected multilinear -// extension of the values of w across all instances. -// Its purpose is to batch the proving of multiple evaluations of the same wire. +// It checks that the polynomial ∑ᵢ cⁱ eq(-, xᵢ) wᵢ(-) sums to the expected value, +// where the sum runs over all (wire v, claim source s) pairs in the level. +// Each wire has its own eq table with the batching coefficients baked in. type zeroCheckClaims struct { - wireI int // the wire for which we are making the claim, with value w - evaluationPoints [][]fr.Element // xᵢ: the points at which the prover has made claims about the evaluation of w - claimedEvaluations []fr.Element // yᵢ = w(xᵢ) - manager *claimsManager - - input []polynomial.MultiLin // input[i](h₁, ..., hₘ₋ⱼ) = wᵢ(r₁, r₂, ..., rⱼ₋₁, h₁, ..., hₘ₋ⱼ) - - eq polynomial.MultiLin // E := ∑ᵢ cⁱ eq(xᵢ, -) - - gateEvaluatorPool *gateEvaluatorPool -} - -func (c *zeroCheckClaims) getWire() Wire { - return c.manager.circuit[c.wireI] -} - -// fold the multiple claims into one claim using a random combination (foldingCoeff or c). -// From the original multiple claims of w(xᵢ) = yᵢ, we get a single claim -// ∑ᵢ,ₕ cⁱ eq(xᵢ, h) w(h) = ∑ᵢ cⁱ yᵢ, where h iterates over the hypercube (circuit instances) and -// i iterates over the claims. -// Equivalently, we could say ∑ᵢ cⁱ yᵢ = ∑ₕ,ᵢ cⁱ eq(xᵢ, h) w(h) = ∑ₕ w(h) ∑ᵢ cⁱ eq(xᵢ, h). -// Thus if we initially compute E := ∑ᵢ cⁱ eq(xᵢ, -), our claim will find the simpler form -// ∑ᵢ cⁱ yᵢ = ∑ₕ w(h) E(h), where the sum-checked polynomial is of degree deg(g) + 1, -// and deg(g) is the total degree of the polynomial defining the gate g of which w is the output. -// The output of fold is the first sumcheck claim, i.e. ∑₍ₕ₁,ₕ₂,...₎ w(X, h₁, h₂, ...) E(X, h₁, h₂, ...).. -func (c *zeroCheckClaims) fold(foldingCoeff fr.Element) polynomial.Polynomial { - varsNum := c.varsNum() - eqLength := 1 << varsNum - claimsNum := c.claimsNum() - // initialize the eq tables ( E ) - c.eq = c.manager.memPool.Make(eqLength) - - c.eq[0].SetOne() - c.eq.Eq(c.evaluationPoints[0]) - - // E := eq(x₀, -) - newEq := polynomial.MultiLin(c.manager.memPool.Make(eqLength)) - aI := foldingCoeff - - // E += cⁱ eq(xᵢ, -) - for k := 1; k < claimsNum; k++ { - newEq[0].Set(&aI) - - c.eqAcc(c.eq, newEq, c.evaluationPoints[k]) - - if k+1 < claimsNum { - aI.Mul(&aI, &foldingCoeff) - } - } - - c.manager.memPool.Dump(newEq) - - return c.computeGJ() -} - -// eqAcc sets m to an eq table at q and then adds it to e. -// m <- eq(q, -). -// e <- e + m -func (c *zeroCheckClaims) eqAcc(e, m polynomial.MultiLin, q []fr.Element) { - n := len(q) - - //At the end of each iteration, m(h₁, ..., hₙ) = eq(q₁, ..., qᵢ₊₁, h₁, ..., hᵢ₊₁) - for i := range q { // In the comments we use a 1-based index so q[i] = qᵢ₊₁ - // go through all assignments of (b₁, ..., bᵢ) ∈ {0,1}ⁱ - const threshold = 1 << 6 - k := 1 << i - if k < threshold { - for j := 0; j < k; j++ { - j0 := j << (n - i) // bᵢ₊₁ = 0 - j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 - - m[j1].Mul(&q[i], &m[j0]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ - m[j0].Sub(&m[j0], &m[j1]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) - } - } else { - c.manager.workers.Submit(k, func(start, end int) { - for j := start; j < end; j++ { - j0 := j << (n - i) // bᵢ₊₁ = 0 - j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 - - m[j1].Mul(&q[i], &m[j0]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ - m[j0].Sub(&m[j0], &m[j1]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) - } - }, 1024).Wait() - } - - } - c.manager.workers.Submit(len(e), func(start, end int) { - for i := start; i < end; i++ { - e[i].Add(&e[i], &m[i]) - } - }, 512).Wait() + levelI int + resources *resources + input []polynomial.MultiLin // UniqueGateInputs order + inputIndices [][]int // [wireInLevel][gateInputJ] → index in input + eqs []polynomial.MultiLin // per-wire interpolation bases for evaluating wire assignments at challenge points + gateEvaluatorPools []*gateEvaluatorPool } -// computeGJ: gⱼ = ∑_{0≤h<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, h...) = ∑_{0≤i<2ⁿ⁻ʲ} E(r₁, ..., Xⱼ, h...) g( w₀(r₁, ..., Xⱼ, h...), ... ). -// the polynomial is represented by the evaluations gⱼ(1), gⱼ(2), ..., gⱼ(deg(gⱼ)). -// The value gⱼ(0) is inferred from the equation gⱼ(0) + gⱼ(1) = gⱼ₋₁(rⱼ₋₁). By convention, g₀ is a constant polynomial equal to the claimed sum. -func (c *zeroCheckClaims) computeGJ() polynomial.Polynomial { - - wire := c.getWire() - degGJ := wire.ZeroCheckDegree() // guaranteed to be no smaller than the actual deg(gⱼ) - nbGateIn := len(c.input) - - // Both E and wᵢ (the input wires and the eq table) are multilinear, thus - // they are linear in Xⱼ. - // So for f ∈ { E(r₁, ..., Xⱼ, h...) } ∪ {wᵢ(r₁, ..., Xⱼ, h...) }, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables. - // ml are such multilinear polynomials the evaluations of which over different values of Xⱼ are computed in this stepwise manner. - ml := make([]polynomial.MultiLin, nbGateIn+1) // shortcut to the evaluations of the multilinear polynomials over the hypercube - ml[0] = c.eq - copy(ml[1:], c.input) - - sumSize := len(c.eq) / 2 // the range of h, over which we sum - - // Perf-TODO: Collate once at claim "folding" time and not again. then, even folding can be done in one operation every time "next" is called - - gJ := make([]fr.Element, degGJ) +func (c *zeroCheckClaims) varsNum() int { + return c.resources.nbVars +} + +// roundPolynomial computes gⱼ = ∑ₕ ∑ᵥ eqs[v](Xⱼ, h...) · gateᵥ(inputs(Xⱼ, h...)). +// The polynomial is represented by the evaluations gⱼ(1), gⱼ(2), ..., gⱼ(deg(gⱼ)). +// The value gⱼ(0) is inferred from the equation gⱼ(0) + gⱼ(1) = gⱼ₋₁(rⱼ₋₁). +// By convention, g₀ is a constant polynomial equal to the claimed sum. +func (c *zeroCheckClaims) roundPolynomial() polynomial.Polynomial { + level := c.resources.schedule[c.levelI].(constraint.GkrSumcheckLevel) + degree := c.resources.circuit.ZeroCheckDegree(level) + nbUniqueInputs := len(c.input) + nbWires := len(c.eqs) + + // Both eqs and input are multilinear, thus linear in Xⱼ. + // For any such f, f(m) = m·(f(1) - f(0)) + f(0), and f(0), f(1) are read directly + // from the bookkeeping tables. This allows stepwise evaluation at Xⱼ = 1, 2, ..., degree. + // Layout: [eq₀, eq₁, ..., eq_{nbWires-1}, input₀, input₁, ..., input_{nbUniqueInputs-1}] + ml := make([]polynomial.MultiLin, nbWires+nbUniqueInputs) + copy(ml, c.eqs) + copy(ml[nbWires:], c.input) + + sumSize := len(c.eqs[0]) / 2 + + p := make([]fr.Element, degree) var mu sync.Mutex - computeAll := func(start, end int) { // compute method to allow parallelization across instances + computeAll := func(start, end int) { var step fr.Element - evaluator := c.gateEvaluatorPool.get() - defer c.gateEvaluatorPool.put(evaluator) + evaluators := make([]*gateEvaluator, nbWires) + for w := range nbWires { + evaluators[w] = c.gateEvaluatorPools[w].get() + } + defer func() { + for w := range nbWires { + c.gateEvaluatorPools[w].put(evaluators[w]) + } + }() - res := make([]fr.Element, degGJ) + res := make([]fr.Element, degree) // evaluations of ml, laid out as: // ml[0](1, h...), ml[1](1, h...), ..., ml[len(ml)-1](1, h...), // ml[0](2, h...), ml[1](2, h...), ..., ml[len(ml)-1](2, h...), // ... - // ml[0](degGJ, h...), ml[2](degGJ, h...), ..., ml[len(ml)-1](degGJ, h...) - mlEvals := make([]fr.Element, degGJ*len(ml)) - - for h := start; h < end; h++ { // h counts across instances + // ml[0](degree, h...), ml[1](degree, h...), ..., ml[len(ml)-1](degree, h...) + mlEvals := make([]fr.Element, degree*len(ml)) + for h := start; h < end; h++ { evalAt1Index := sumSize + h for k := range ml { - // d = 0 - mlEvals[k].Set(&ml[k][evalAt1Index]) // evaluation at Xⱼ = 1. Can be taken directly from the table. + mlEvals[k].Set(&ml[k][evalAt1Index]) // evaluation at Xⱼ = 1, taken directly from the table step.Sub(&mlEvals[k], &ml[k][h]) // step = ml[k](1) - ml[k](0) - for d := 1; d < degGJ; d++ { + for d := 1; d < degree; d++ { mlEvals[d*len(ml)+k].Add(&mlEvals[(d-1)*len(ml)+k], &step) } } - eIndex := 0 // index for where the current eq term is + eIndex := 0 // start of the current row's eq evaluations nextEIndex := len(ml) - for d := range degGJ { - // Push gate inputs - for i := range nbGateIn { - evaluator.pushInput(mlEvals[eIndex+1+i]) + for d := range degree { + for w := range nbWires { + for _, inputI := range c.inputIndices[w] { + evaluators[w].pushInput(mlEvals[eIndex+nbWires+inputI]) + } + summand := evaluators[w].evaluate() + summand.Mul(summand, &mlEvals[eIndex+w]) + res[d].Add(&res[d], summand) // collect contributions into the sum from start to end } - summand := evaluator.evaluate() - summand.Mul(summand, &mlEvals[eIndex]) - res[d].Add(&res[d], summand) // collect contributions into the sum from start to end eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) } } mu.Lock() - for i := range gJ { - gJ[i].Add(&gJ[i], &res[i]) // collect into the complete sum + for i := range p { + p[i].Add(&p[i], &res[i]) // collect into the complete sum } mu.Unlock() } const minBlockSize = 64 - if sumSize < minBlockSize { - // no parallelization computeAll(0, sumSize) } else { - c.manager.workers.Submit(sumSize, computeAll, minBlockSize).Wait() + c.resources.workers.Submit(sumSize, computeAll, minBlockSize).Wait() } - return gJ + return p } -// next first folds the input and E polynomials at the given verifier challenge then computes the new gⱼ. -// Thus, j <- j+1 and rⱼ = challenge. -func (c *zeroCheckClaims) next(challenge fr.Element) polynomial.Polynomial { +// roundFold folds all input and eq polynomials at the verifier challenge r. +// After this call, j ← j+1 and rⱼ = r. +func (c *zeroCheckClaims) roundFold(r fr.Element) { const minBlockSize = 512 - n := len(c.eq) / 2 + n := len(c.eqs[0]) / 2 if n < minBlockSize { - // no parallelization for i := range c.input { - c.input[i].Fold(challenge) + c.input[i].Fold(r) + } + for i := range c.eqs { + c.eqs[i].Fold(r) } - c.eq.Fold(challenge) } else { - wgs := make([]*sync.WaitGroup, len(c.input)) + wgs := make([]*sync.WaitGroup, len(c.input)+len(c.eqs)) for i := range c.input { - wgs[i] = c.manager.workers.Submit(n, c.input[i].FoldParallel(challenge), minBlockSize) + wgs[i] = c.resources.workers.Submit(n, c.input[i].FoldParallel(r), minBlockSize) + } + for i := range c.eqs { + wgs[len(c.input)+i] = c.resources.workers.Submit(n, c.eqs[i].FoldParallel(r), minBlockSize) } - c.manager.workers.Submit(n, c.eq.FoldParallel(challenge), minBlockSize).Wait() for _, wg := range wgs { wg.Wait() } } - - return c.computeGJ() -} - -func (c *zeroCheckClaims) varsNum() int { - return len(c.evaluationPoints[0]) } -func (c *zeroCheckClaims) claimsNum() int { - return len(c.claimedEvaluations) -} - -// proveFinalEval provides the values wᵢ(r₁, ..., rₙ) +// proveFinalEval provides the unique input wire values wᵢ(r₁, ..., rₙ). func (c *zeroCheckClaims) proveFinalEval(r []fr.Element) []fr.Element { - //defer the proof, return list of claims - - injection, _ := c.manager.circuit.ClaimPropagationInfo(c.wireI) // TODO @Tabaie: Instead of doing this last, we could just have fewer input in the first place; not that likely to happen with single gates, but more so with layers. - evaluations := make([]fr.Element, len(injection)) - for i, gateInputI := range injection { - wI := c.input[gateInputI] - wI.Fold(r[len(r)-1]) // We already have wᵢ(r₁, ..., rₙ₋₁, hₙ) in a table. Only one more fold required. - c.manager.add(c.getWire().Inputs[gateInputI], r, wI[0]) - evaluations[i] = wI[0] + c.resources.outgoingEvalPoints[c.levelI] = [][]fr.Element{r} + evaluations := make([]fr.Element, len(c.input)) + for i := range c.input { + c.input[i].Fold(r[len(r)-1]) + evaluations[i] = c.input[i][0] + } + for i := range c.input { + c.resources.memPool.Dump(c.input[i]) + } + for i := range c.eqs { + c.resources.memPool.Dump(c.eqs[i]) + } + for _, pool := range c.gateEvaluatorPools { + pool.dumpAll() } - - c.manager.memPool.Dump(c.claimedEvaluations, c.eq) - c.gateEvaluatorPool.dumpAll() - return evaluations } -type claimsManager struct { - claims []*zeroCheckLazyClaims - assignment WireAssignment - memPool *polynomial.Pool - workers *utils.WorkerPool - circuit Circuit -} +// eqAcc sets m to an eq table at q and then adds it to e. +// m <- m[0] · eq(q, -). +// e <- e + m +func (r *resources) eqAcc(e, m polynomial.MultiLin, q []fr.Element) { + n := len(q) -func newClaimsManager(circuit Circuit, assignment WireAssignment, o settings) (manager claimsManager) { - manager.assignment = assignment - manager.claims = make([]*zeroCheckLazyClaims, len(circuit)) - manager.memPool = o.pool - manager.workers = o.workers - manager.circuit = circuit + // At the end of each iteration, m(h₁, ..., hₙ) = m[0] · eq(q₁, ..., qᵢ₊₁, h₁, ..., hᵢ₊₁) + for i := range q { // 1-based in comments: q[i] = qᵢ₊₁ + // go through all assignments of (b₁, ..., bᵢ) ∈ {0,1}ⁱ + const threshold = 1 << 6 + k := 1 << i + if k < threshold { + for j := 0; j < k; j++ { + j0 := j << (n - i) // bᵢ₊₁ = 0 + j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 - for i := range circuit { - manager.claims[i] = &zeroCheckLazyClaims{ - wireI: i, - evaluationPoints: make([][]fr.Element, 0, circuit[i].NbClaims()), - claimedEvaluations: manager.memPool.Make(circuit[i].NbClaims()), - manager: &manager, + m[j1].Mul(&q[i], &m[j0]) // m(b₁,...,bᵢ,1) = m(b₁,...,bᵢ) · qᵢ₊₁ + m[j0].Sub(&m[j0], &m[j1]) // m(b₁,...,bᵢ,0) = m(b₁,...,bᵢ) · (1 - qᵢ₊₁) + } + } else { + r.workers.Submit(k, func(start, end int) { + for j := start; j < end; j++ { + j0 := j << (n - i) // bᵢ₊₁ = 0 + j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 + + m[j1].Mul(&q[i], &m[j0]) // m(b₁,...,bᵢ,1) = m(b₁,...,bᵢ) · qᵢ₊₁ + m[j0].Sub(&m[j0], &m[j1]) // m(b₁,...,bᵢ,0) = m(b₁,...,bᵢ) · (1 - qᵢ₊₁) + } + }, 1024).Wait() } } - return -} - -func (m *claimsManager) add(wire int, evaluationPoint []fr.Element, evaluation fr.Element) { - claim := m.claims[wire] - i := len(claim.evaluationPoints) - claim.claimedEvaluations[i] = evaluation - claim.evaluationPoints = append(claim.evaluationPoints, evaluationPoint) + r.workers.Submit(len(e), func(start, end int) { + for i := start; i < end; i++ { + e[i].Add(&e[i], &m[i]) + } + }, 512).Wait() } -func (m *claimsManager) getLazyClaim(wire int) *zeroCheckLazyClaims { - return m.claims[wire] +type resources struct { + // outgoingEvalPoints[i][k] is the k-th outgoing evaluation point (evaluation challenge) produced at schedule level i. + // outgoingEvalPoints[len(schedule)][0] holds the initial challenge (firstChallenge / rho). + // SumcheckLevels produce one point (k=0). SkipLevels pass on all their evaluation points. + outgoingEvalPoints [][][]fr.Element + nbVars int + assignment WireAssignment + memPool polynomial.Pool + workers *utils.WorkerPool + circuit Circuit + schedule constraint.GkrProvingSchedule + transcript transcript + uniqueInputIndices [][]int // uniqueInputIndices[wI][claimI]: w's unique-input index in the layer its claimI-th evaluation is coming from } -func (m *claimsManager) getClaim(wireI int) *zeroCheckClaims { - lazy := m.claims[wireI] - wire := m.circuit[wireI] - res := &zeroCheckClaims{ - wireI: wireI, - evaluationPoints: lazy.evaluationPoints, - claimedEvaluations: lazy.claimedEvaluations, - manager: m, - } - - if wire.IsInput() { - res.input = []polynomial.MultiLin{m.memPool.Clone(m.assignment[wireI])} - } else { - res.input = make([]polynomial.MultiLin, len(wire.Inputs)) - - for inputI, inputW := range wire.Inputs { - res.input[inputI] = m.memPool.Clone(m.assignment[inputW]) //will be edited later, so must be deep copied +func newResources(c Circuit, schedule constraint.GkrProvingSchedule, assignment WireAssignment, hasher hash.Hash) (resources, error) { + nbVars := assignment.NumVars() + nbInstances := assignment.NumInstances() + if 1<= 2 { + foldingCoeff = r.transcript.getChallenge() } -} -func setup(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (settings, error) { - var o settings - var err error - for _, option := range options { - option(&o) + uniqueInputs, inputIndices := r.circuit.InputMapping(level) + input := make([]polynomial.MultiLin, len(uniqueInputs)) + for i, inW := range uniqueInputs { + input[i] = r.memPool.Clone(r.assignment[inW]) } - o.nbVars = assignment.NumVars() - nbInstances := assignment.NumInstances() - if 1< 1 { + newEq := polynomial.MultiLin(r.memPool.Make(eqLength)) + aI := alpha + for k := 1; k < nbSources; k++ { + aI.Mul(&aI, &foldingCoeff) + newEq[0].Set(&aI) + r.eqAcc(groupEq, newEq, r.outgoingEvalPoints[group.ClaimSources[k].Level][group.ClaimSources[k].OutgoingClaimIndex]) + } + r.memPool.Dump(newEq) + } -func ChallengeNames(c Circuit, logNbInstances int, prefix string) []string { + var stride fr.Element + stride.Set(&foldingCoeff) + for range nbSources - 1 { + stride.Mul(&stride, &foldingCoeff) + } - // Pre-compute the size TODO: Consider not doing this and just grow the list by appending - size := logNbInstances // first challenge + eqs[levelWireI] = groupEq + levelWireI++ + alpha.Mul(&alpha, &stride) - for i := range c { - if c[i].NoProof() { // no proof, no challenge - continue - } - if c[i].NbClaims() > 1 { //fold the claims - size++ + for w := 1; w < len(group.Wires); w++ { + eqs[levelWireI] = polynomial.MultiLin(r.memPool.Make(eqLength)) + r.workers.Submit(eqLength, func(start, end int) { + for i := start; i < end; i++ { + eqs[levelWireI][i].Mul(&eqs[levelWireI-1][i], &stride) + } + }, 512).Wait() + levelWireI++ + alpha.Mul(&alpha, &stride) } - size += logNbInstances // full run of sumcheck on logNbInstances variables } - nums := make([]string, max(len(c), logNbInstances)) - for i := range nums { - nums[i] = strconv.Itoa(i) + claims := &zeroCheckClaims{ + levelI: levelI, + resources: r, + input: input, + inputIndices: inputIndices, + eqs: eqs, + gateEvaluatorPools: pools, } + return sumcheckProve(claims, &r.transcript) +} - challenges := make([]string, size) - - // output wire claims - firstChallengePrefix := prefix + "fC." - for j := 0; j < logNbInstances; j++ { - challenges[j] = firstChallengePrefix + nums[j] +func (r *resources) verifySumcheckLevel(levelI int, proof Proof) error { + level := r.schedule[levelI] + nbClaims := level.NbClaims() + var foldingCoeff fr.Element + if nbClaims >= 2 { + foldingCoeff = r.transcript.getChallenge() } - j := logNbInstances - for i := len(c) - 1; i >= 0; i-- { - if c[i].NoProof() { - continue - } - wirePrefix := prefix + "w" + nums[i] + "." - if c[i].NbClaims() > 1 { - challenges[j] = wirePrefix + "fold" - j++ - } + initialChallengeI := len(r.schedule) + claimedEvals := make(polynomial.Polynomial, 0, level.NbClaims()) - partialSumPrefix := wirePrefix + "pSP." - for k := 0; k < logNbInstances; k++ { - challenges[j] = partialSumPrefix + nums[k] - j++ + for _, group := range level.ClaimGroups() { + for _, wI := range group.Wires { + for claimI, src := range group.ClaimSources { + if src.Level == initialChallengeI { + claimedEvals = append(claimedEvals, r.assignment[wI].Evaluate(r.outgoingEvalPoints[src.Level][src.OutgoingClaimIndex], &r.memPool)) + } else { + claimedEvals = append(claimedEvals, proof[src.Level].finalEvalProof[r.schedule[src.Level].FinalEvalProofIndex(r.uniqueInputIndices[wI][claimI], src.OutgoingClaimIndex)]) + } + } } } - return challenges -} -func getFirstChallengeNames(logNbInstances int, prefix string) []string { - res := make([]string, logNbInstances) - firstChallengePrefix := prefix + "fC." - for i := 0; i < logNbInstances; i++ { - res[i] = firstChallengePrefix + strconv.Itoa(i) - } - return res -} + claimedSum := claimedEvals.Eval(&foldingCoeff) -func getChallenges(transcript *fiatshamir.Transcript, names []string) ([]fr.Element, error) { - res := make([]fr.Element, len(names)) - for i, name := range names { - if bytes, err := transcript.ComputeChallenge(name); err != nil { - return nil, err - } else if err = res[i].SetBytesCanonical(bytes); err != nil { - return nil, err - } + lazyClaims := &zeroCheckLazyClaims{ + foldingCoeff: foldingCoeff, + resources: r, + levelI: levelI, } - return res, nil + return sumcheckVerify(lazyClaims, proof[levelI], claimedSum, r.circuit.ZeroCheckDegree(level.(constraint.GkrSumcheckLevel)), &r.transcript) } // Prove consistency of the claimed assignment -func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (Proof, error) { - o, err := setup(c, assignment, transcriptSettings, options...) +func Prove(c Circuit, schedule constraint.GkrProvingSchedule, assignment WireAssignment, hasher hash.Hash) (Proof, error) { + r, err := newResources(c, schedule, assignment, hasher) if err != nil { return nil, err } - defer o.workers.Stop() + defer r.workers.Stop() - claims := newClaimsManager(c, assignment, o) + proof := make(Proof, len(schedule)) - proof := make(Proof, len(c)) - // firstChallenge called rho in the paper - var firstChallenge []fr.Element - firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) - if err != nil { - return nil, err + // Derive the initial challenge point + firstChallenge := make([]fr.Element, r.nbVars) + for j := range r.nbVars { + firstChallenge[j] = r.transcript.getChallenge() } + r.outgoingEvalPoints[len(schedule)] = [][]fr.Element{firstChallenge} - wirePrefix := o.transcriptPrefix + "w" - var baseChallenge [][]byte - for i := len(c) - 1; i >= 0; i-- { - - wire := c[i] - - if wire.IsOutput() { - claims.add(i, firstChallenge, assignment[i].Evaluate(firstChallenge, claims.memPool)) - } - - claim := claims.getClaim(i) - if wire.NoProof() { // input wires with one claim only - proof[i] = sumcheckProof{ - partialSumPolys: []polynomial.Polynomial{}, - finalEvalProof: []fr.Element{}, - } + for levelI := len(schedule) - 1; levelI >= 0; levelI-- { + if _, isSkip := r.schedule[levelI].(constraint.GkrSkipLevel); isSkip { + proof[levelI] = r.proveSkipLevel(levelI) } else { - if proof[i], err = sumcheckProve( - claim, fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), - ); err != nil { - return proof, err - } - - baseChallenge = make([][]byte, len(proof[i].finalEvalProof)) - for j := range proof[i].finalEvalProof { - baseChallenge[j] = proof[i].finalEvalProof[j].Marshal() - } + proof[levelI] = r.proveSumcheckLevel(levelI) } - // the verifier checks a single claim about input wires itself - claims.deleteClaim(i) + constraint.BindGkrFinalEvalProof(&r.transcript, proof[levelI].finalEvalProof, c.UniqueGateInputs(r.schedule[levelI]), c.IsInput, r.schedule[levelI]) } return proof, nil } -// Verify the consistency of the claimed output with the claimed input -// Unlike in Prove, the assignment argument need not be complete -func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSettings fiatshamir.Settings, options ...Option) error { - o, err := setup(c, assignment, transcriptSettings, options...) +// Verify the consistency of the claimed output with the claimed input. +// Unlike in Prove, the assignment argument need not be complete. +func Verify(c Circuit, schedule constraint.GkrProvingSchedule, assignment WireAssignment, proof Proof, hasher hash.Hash) error { + r, err := newResources(c, schedule, assignment, hasher) if err != nil { return err } - defer o.workers.Stop() + defer r.workers.Stop() - claims := newClaimsManager(c, assignment, o) - - var firstChallenge []fr.Element - firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) - if err != nil { - return err + // Derive the initial challenge point + firstChallenge := make([]fr.Element, r.nbVars) + for j := range r.nbVars { + firstChallenge[j] = r.transcript.getChallenge() } + r.outgoingEvalPoints[len(schedule)] = [][]fr.Element{firstChallenge} - wirePrefix := o.transcriptPrefix + "w" - var baseChallenge [][]byte - for i := len(c) - 1; i >= 0; i-- { - wire := c[i] - - if wire.IsOutput() { - claims.add(i, firstChallenge, assignment[i].Evaluate(firstChallenge, claims.memPool)) - } - - proofW := proof[i] - claim := claims.getLazyClaim(i) - if wire.NoProof() { // input wires with one claim only - // make sure the proof is empty - if len(proofW.finalEvalProof) != 0 || len(proofW.partialSumPolys) != 0 { - return errors.New("no proof allowed for input wire with a single claim") - } - - if wire.NbClaims() == 1 { // input wire - // simply evaluate and see if it matches - if len(claim.evaluationPoints) == 0 || len(claim.claimedEvaluations) == 0 { - return errors.New("missing input wire claim") - } - evaluation := assignment[i].Evaluate(claim.evaluationPoints[0], claims.memPool) - if !claim.claimedEvaluations[0].Equal(&evaluation) { - return errors.New("incorrect input wire claim") - } - } - } else if err = sumcheckVerify( - claim, proof[i], fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), - ); err == nil { // incorporate prover claims about w's input into the transcript - baseChallenge = make([][]byte, len(proofW.finalEvalProof)) - for j := range baseChallenge { - baseChallenge[j] = proofW.finalEvalProof[j].Marshal() - } + for levelI := len(schedule) - 1; levelI >= 0; levelI-- { + if _, isSkip := r.schedule[levelI].(constraint.GkrSkipLevel); isSkip { + err = r.verifySkipLevel(levelI, proof) } else { - return fmt.Errorf("sumcheck proof rejected: %v", err) //TODO: Any polynomials to dump? + err = r.verifySumcheckLevel(levelI, proof) + } + if err != nil { + return fmt.Errorf("level %d: %v", levelI, err) } - claims.deleteClaim(i) + constraint.BindGkrFinalEvalProof(&r.transcript, proof[levelI].finalEvalProof, c.UniqueGateInputs(r.schedule[levelI]), c.IsInput, r.schedule[levelI]) } return nil } @@ -734,14 +636,14 @@ func (p Proof) flatten() iter.Seq2[int, *fr.Element] { // It manages the stack internally and handles input buffering, making it easy to // evaluate the same gate multiple times with different inputs. type gateEvaluator struct { - gate gkrtypes.GateBytecode + gate gkrcore.GateBytecode vars []fr.Element nbIn int // number of inputs expected } // newGateEvaluator creates an evaluator for the given compiled gate. // The stack is preloaded with constants and ready for evaluation. -func newGateEvaluator(gate gkrtypes.GateBytecode, nbIn int, elementPool ...*polynomial.Pool) gateEvaluator { +func newGateEvaluator(gate gkrcore.GateBytecode, nbIn int, elementPool ...*polynomial.Pool) gateEvaluator { e := gateEvaluator{ gate: gate, nbIn: nbIn, @@ -785,28 +687,28 @@ func (e *gateEvaluator) evaluate(top ...fr.Element) *fr.Element { // Use switch instead of function pointer for better inlining switch inst.Op { - case gkrtypes.OpAdd: + case gkrcore.OpAdd: dst.Add(&e.vars[inst.Inputs[0]], &e.vars[inst.Inputs[1]]) for j := 2; j < len(inst.Inputs); j++ { dst.Add(dst, &e.vars[inst.Inputs[j]]) } - case gkrtypes.OpMul: + case gkrcore.OpMul: dst.Mul(&e.vars[inst.Inputs[0]], &e.vars[inst.Inputs[1]]) for j := 2; j < len(inst.Inputs); j++ { dst.Mul(dst, &e.vars[inst.Inputs[j]]) } - case gkrtypes.OpSub: + case gkrcore.OpSub: dst.Sub(&e.vars[inst.Inputs[0]], &e.vars[inst.Inputs[1]]) for j := 2; j < len(inst.Inputs); j++ { dst.Sub(dst, &e.vars[inst.Inputs[j]]) } - case gkrtypes.OpNeg: + case gkrcore.OpNeg: dst.Neg(&e.vars[inst.Inputs[0]]) - case gkrtypes.OpMulAcc: + case gkrcore.OpMulAcc: var prod fr.Element prod.Mul(&e.vars[inst.Inputs[1]], &e.vars[inst.Inputs[2]]) dst.Add(&e.vars[inst.Inputs[0]], &prod) - case gkrtypes.OpSumExp17: + case gkrcore.OpSumExp17: // result = (x[0] + x[1] + x[2])^17 var sum fr.Element sum.Add(&e.vars[inst.Inputs[0]], &e.vars[inst.Inputs[1]]) @@ -832,14 +734,14 @@ func (e *gateEvaluator) evaluate(top ...fr.Element) *fr.Element { // gateEvaluatorPool manages a pool of gate evaluators for a specific gate type // All evaluators share the same underlying polynomial.Pool for element slices type gateEvaluatorPool struct { - gate gkrtypes.GateBytecode + gate gkrcore.GateBytecode nbIn int lock sync.Mutex available map[*gateEvaluator]struct{} elementPool *polynomial.Pool } -func newGateEvaluatorPool(gate gkrtypes.GateBytecode, nbIn int, elementPool *polynomial.Pool) *gateEvaluatorPool { +func newGateEvaluatorPool(gate gkrcore.GateBytecode, nbIn int, elementPool *polynomial.Pool) *gateEvaluatorPool { gep := &gateEvaluatorPool{ gate: gate, nbIn: nbIn, @@ -867,7 +769,7 @@ func (gep *gateEvaluatorPool) put(e *gateEvaluator) { gep.lock.Lock() defer gep.lock.Unlock() - // Return evaluator to pool (it keeps its vars slice from polynomial pool) + // Return evaluator to pool (it keeps its vars slice from the polynomial pool) gep.available[e] = struct{}{} } diff --git a/internal/gkr/bn254/gkr_test.go b/internal/gkr/bn254/gkr_test.go index 6f6f0b5432..3d00b93430 100644 --- a/internal/gkr/bn254/gkr_test.go +++ b/internal/gkr/bn254/gkr_test.go @@ -11,7 +11,6 @@ import ( "os" "path/filepath" "reflect" - "strconv" "testing" "time" @@ -19,10 +18,9 @@ import ( "github.com/consensys/gnark-crypto/ecc/bn254/fr" "github.com/consensys/gnark-crypto/ecc/bn254/fr/mimc" "github.com/consensys/gnark-crypto/ecc/bn254/fr/polynomial" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" - gcUtils "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark/constraint" + "github.com/consensys/gnark/internal/gkr/gkrcore" "github.com/consensys/gnark/internal/gkr/gkrtesting" - "github.com/consensys/gnark/internal/gkr/gkrtypes" "github.com/stretchr/testify/assert" ) @@ -69,36 +67,164 @@ func TestMimc(t *testing.T) { test(t, gkrtesting.MiMCCircuit(93)) } -func TestSumcheckFromSingleInputTwoIdentityGatesGateTwoInstances(t *testing.T) { - // Construct SerializableCircuit directly, bypassing CompileCircuit - // which would reset NbUniqueOutputs based on actual topology - circuit := gkrtypes.SerializableCircuit{ +func TestPoseidon2(t *testing.T) { + test(t, gkrtesting.Poseidon2Circuit(4, 2)) +} + +// testSumcheckLevel exercises proveSumcheckLevel/verifySumcheckLevel for a single sumcheck level. +func testSumcheckLevel(t *testing.T, circuit gkrcore.RawCircuit, level constraint.GkrProvingLevel) { + t.Helper() + _, sCircuit := cache.Compile(t, circuit) + + ins := sCircuit.Inputs() + assignment := make(WireAssignment, len(sCircuit)) + for _, i := range ins { + assignment[i] = make([]fr.Element, 2) + fr.Vector(assignment[i]).MustSetRandom() + } + + assignment.Complete(sCircuit) + + schedule := constraint.GkrProvingSchedule{level} + initEvalPoint := [][]fr.Element{{one}} + + // Prove + proveR, err := newResources(sCircuit, schedule, assignment, newMessageCounter(1, 1)) + assert.NoError(t, err) + defer proveR.workers.Stop() + + proveR.outgoingEvalPoints[len(schedule)] = initEvalPoint + proof := Proof{proveR.proveSumcheckLevel(0)} + + // Verify + verifyR, err := newResources(sCircuit, schedule, assignment, newMessageCounter(1, 1)) + assert.NoError(t, err) + defer verifyR.workers.Stop() + + verifyR.outgoingEvalPoints[len(schedule)] = initEvalPoint + assert.NoError(t, verifyR.verifySumcheckLevel(0, proof)) +} + +func TestSumcheckLevel(t *testing.T) { + // Wires 0,1 = inputs; wires 2,3,4 = mul(0,1). All gates are independent outputs. + circuit := gkrcore.RawCircuit{ + {}, + {}, + {Gate: gkrcore.Mul2, Inputs: []int{0, 1}}, + {Gate: gkrcore.Mul2, Inputs: []int{0, 1}}, + {Gate: gkrcore.Mul2, Inputs: []int{0, 1}}, + } + // Each level has an initial challenge at index 1 (len(schedule) = 1). + // GkrClaimSource{Level:1} is the initial-challenge sentinel. + tests := []struct { + name string + level constraint.GkrProvingLevel + }{ + { + name: "single wire", + level: constraint.GkrSumcheckLevel{ + {Wires: []int{4}, ClaimSources: []constraint.GkrClaimSource{{Level: 1}}}, + }, + }, + { + name: "two groups", + level: constraint.GkrSumcheckLevel{ + {Wires: []int{4}, ClaimSources: []constraint.GkrClaimSource{{Level: 1}}}, + {Wires: []int{3}, ClaimSources: []constraint.GkrClaimSource{{Level: 1}}}, + }, + }, + { + name: "one group with two wires", + level: constraint.GkrSumcheckLevel{ + {Wires: []int{4, 3}, ClaimSources: []constraint.GkrClaimSource{{Level: 1}}}, + }, + }, { - NbUniqueOutputs: 2, - Gate: gkrtypes.SerializableGate{Degree: 1}, + name: "mixed: single + multi-wire group", + level: constraint.GkrSumcheckLevel{ + {Wires: []int{4}, ClaimSources: []constraint.GkrClaimSource{{Level: 1}}}, + {Wires: []int{3, 2}, ClaimSources: []constraint.GkrClaimSource{{Level: 1}}}, + }, }, } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + testSumcheckLevel(t, circuit, tc.level) + }) + } +} + +// testSkipLevel exercises proveSkipLevel/verifySkipLevel for a single skip level. +func testSkipLevel(t *testing.T, circuit gkrcore.RawCircuit, level constraint.GkrProvingLevel) { + t.Helper() + _, sCircuit := cache.Compile(t, circuit) - assignment := WireAssignment{[]fr.Element{two, three}} - var o settings - pool := polynomial.NewPool(256, 1<<11) - workers := gcUtils.NewWorkerPool() - o.pool = &pool - o.workers = workers - - claimsManagerGen := func() *claimsManager { - manager := newClaimsManager(circuit, assignment, o) - manager.add(0, []fr.Element{three}, five) - manager.add(0, []fr.Element{four}, six) - return &manager + ins := sCircuit.Inputs() + assignment := make(WireAssignment, len(sCircuit)) + for _, i := range ins { + assignment[i] = make([]fr.Element, 2) + fr.Vector(assignment[i]).MustSetRandom() } - transcriptGen := newMessageCounterGenerator(4, 1) + assignment.Complete(sCircuit) - proof, err := sumcheckProve(claimsManagerGen().getClaim(0), fiatshamir.WithHash(transcriptGen(), nil)) + schedule := constraint.GkrProvingSchedule{level} + initEvalPoint := [][]fr.Element{{one}} + + // Prove + proveR, err := newResources(sCircuit, schedule, assignment, newMessageCounter(1, 1)) assert.NoError(t, err) - err = sumcheckVerify(claimsManagerGen().getLazyClaim(0), proof, fiatshamir.WithHash(transcriptGen(), nil)) + defer proveR.workers.Stop() + + proveR.outgoingEvalPoints[len(schedule)] = initEvalPoint + proof := Proof{proveR.proveSkipLevel(0)} + + // Verify + verifyR, err := newResources(sCircuit, schedule, assignment, newMessageCounter(1, 1)) assert.NoError(t, err) + defer verifyR.workers.Stop() + + verifyR.outgoingEvalPoints[len(schedule)] = initEvalPoint + assert.NoError(t, verifyR.verifySkipLevel(0, proof)) +} + +func TestSkipLevel(t *testing.T) { + // Wires 0,1 = inputs; wires 2,3 = identity(0); wire 4 = add(0,1). All degree-1 outputs. + circuit := gkrcore.RawCircuit{ + {}, + {}, + {Gate: gkrcore.Identity, Inputs: []int{0}}, + {Gate: gkrcore.Identity, Inputs: []int{0}}, + {Gate: gkrcore.Add2, Inputs: []int{0, 1}}, + } + + // Single-claim cases: one inherited evaluation point (OutgoingClaimIndex always 0). + singleClaim := []struct { + name string + level constraint.GkrProvingLevel + }{ + { + name: "single input wire", + level: constraint.GkrSkipLevel{Wires: []int{0}, ClaimSources: []constraint.GkrClaimSource{{Level: 1}}}, + }, + { + name: "single identity gate", + level: constraint.GkrSkipLevel{Wires: []int{2}, ClaimSources: []constraint.GkrClaimSource{{Level: 1}}}, + }, + { + name: "add gate", + level: constraint.GkrSkipLevel{Wires: []int{4}, ClaimSources: []constraint.GkrClaimSource{{Level: 1}}}, + }, + { + name: "two identity gates one group", + level: constraint.GkrSkipLevel{Wires: []int{2, 3}, ClaimSources: []constraint.GkrClaimSource{{Level: 1}}}, + }, + } + for _, tc := range singleClaim { + t.Run(tc.name, func(t *testing.T) { + testSkipLevel(t, circuit, tc.level) + }) + } } var one, two, three, four, five, six fr.Element @@ -112,31 +238,20 @@ func init() { six.Double(&three) } -var testManyInstancesLogMaxInstances = -1 - -func getLogMaxInstances(t *testing.T) int { - if testManyInstancesLogMaxInstances == -1 { - - s := os.Getenv("GKR_LOG_INSTANCES") - if s == "" { - testManyInstancesLogMaxInstances = 5 - } else { - var err error - testManyInstancesLogMaxInstances, err = strconv.Atoi(s) - if err != nil { - t.Error(err) - } - } - - } - return testManyInstancesLogMaxInstances +func test(t *testing.T, circuit gkrcore.RawCircuit) { + testWithSchedule(t, circuit, nil) } -func test(t *testing.T, circuit gkrtypes.GadgetCircuit) { - sCircuit := cache.Compile(t, circuit) - ins := circuit.Inputs() +func testWithSchedule(t *testing.T, circuit gkrcore.RawCircuit, schedule constraint.GkrProvingSchedule) { + gCircuit, sCircuit := cache.Compile(t, circuit) + if schedule == nil { + var err error + schedule, err = gkrcore.DefaultProvingSchedule(sCircuit) + assert.NoError(t, err) + } + ins := gCircuit.Inputs() insAssignment := make(WireAssignment, len(ins)) - maxSize := 1 << getLogMaxInstances(t) + maxSize := 1 << gkrtesting.GetLogMaxInstances(t) for i := range ins { insAssignment[i] = make([]fr.Element, maxSize) @@ -151,51 +266,33 @@ func test(t *testing.T, circuit gkrtypes.GadgetCircuit) { fullAssignment.Complete(sCircuit) - t.Log("Selected inputs for test") - - proof, err := Prove(sCircuit, fullAssignment, fiatshamir.WithHash(newMessageCounter(1, 1))) + proof, err := Prove(sCircuit, schedule, fullAssignment, newMessageCounter(1, 1)) assert.NoError(t, err) // Even though a hash is called here, the proof is empty - err = Verify(sCircuit, fullAssignment, proof, fiatshamir.WithHash(newMessageCounter(1, 1))) + err = Verify(sCircuit, schedule, fullAssignment, proof, newMessageCounter(1, 1)) assert.NoError(t, err, "proof rejected") - if proof.isEmpty() { // special case for TestNoGate: - continue // there's no way to make a trivial proof fail - } - - err = Verify(sCircuit, fullAssignment, proof, fiatshamir.WithHash(newMessageCounter(0, 1))) + err = Verify(sCircuit, schedule, fullAssignment, proof, newMessageCounter(0, 1)) assert.NotNil(t, err, "bad proof accepted") } - -} - -func (p Proof) isEmpty() bool { - for i := range p { - if len(p[i].finalEvalProof) != 0 { - return false - } - for j := range p[i].partialSumPolys { - if len(p[i].partialSumPolys[j]) != 0 { - return false - } - } - } - return true } func testNoGate(t *testing.T, inputAssignments ...[]fr.Element) { - c := cache.Compile(t, gkrtesting.NoGateCircuit()) + _, c := cache.Compile(t, gkrtesting.NoGateCircuit()) + + schedule, err := gkrcore.DefaultProvingSchedule(c) + assert.NoError(t, err) assignment := WireAssignment{0: inputAssignments[0]} - proof, err := Prove(c, assignment, fiatshamir.WithHash(newMessageCounter(1, 1))) + proof, err := Prove(c, schedule, assignment, newMessageCounter(1, 1)) assert.NoError(t, err) // Even though a hash is called here, the proof is empty - err = Verify(c, assignment, proof, fiatshamir.WithHash(newMessageCounter(1, 1))) + err = Verify(c, schedule, assignment, proof, newMessageCounter(1, 1)) assert.NoError(t, err, "proof rejected") } @@ -203,7 +300,7 @@ func generateTestProver(path string) func(t *testing.T) { return func(t *testing.T) { testCase, err := newTestCase(path) assert.NoError(t, err) - proof, err := Prove(testCase.Circuit, testCase.FullAssignment, fiatshamir.WithHash(testCase.Hash)) + proof, err := Prove(testCase.Circuit, testCase.Schedule, testCase.FullAssignment, testCase.Hash) assert.NoError(t, err) assert.NoError(t, proofEquals(testCase.Proof, proof)) } @@ -213,17 +310,29 @@ func generateTestVerifier(path string) func(t *testing.T) { return func(t *testing.T) { testCase, err := newTestCase(path) assert.NoError(t, err) - err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(testCase.Hash)) + err = Verify(testCase.Circuit, testCase.Schedule, testCase.InOutAssignment, testCase.Proof, testCase.Hash) assert.NoError(t, err, "proof rejected") + testCase, err = newTestCase(path) assert.NoError(t, err) - err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(newMessageCounter(2, 0))) + err = Verify(testCase.Circuit, testCase.Schedule, testCase.InOutAssignment, testCase.Proof, newMessageCounter(2, 0)) assert.NotNil(t, err, "bad proof accepted") + + testCase, err = newTestCase(path) + assert.NoError(t, err) + testCase.InOutAssignment[len(testCase.InOutAssignment)-1][0].Add(&testCase.InOutAssignment[len(testCase.InOutAssignment)-1][0], &one) + err = Verify(testCase.Circuit, testCase.Schedule, testCase.InOutAssignment, testCase.Proof, testCase.Hash) + assert.NotNil(t, err, "tampered output accepted") + + testCase, err = newTestCase(path) + assert.NoError(t, err) + testCase.InOutAssignment[0][0].Add(&testCase.InOutAssignment[0][0], &one) + err = Verify(testCase.Circuit, testCase.Schedule, testCase.InOutAssignment, testCase.Proof, testCase.Hash) + assert.NotNil(t, err, "tampered input accepted") } } func TestGkrVectors(t *testing.T) { - const testDirPath = "../test_vectors/" dirEntries, err := os.ReadDir(testDirPath) assert.NoError(t, err) @@ -267,7 +376,10 @@ func proofEquals(expected Proof, seen Proof) error { func benchmarkGkrMiMC(b *testing.B, nbInstances, mimcDepth int) { fmt.Println("creating circuit structure") - c := cache.Compile(b, gkrtesting.MiMCCircuit(mimcDepth)) + _, c := cache.Compile(b, gkrtesting.MiMCCircuit(mimcDepth)) + + schedule, err := gkrcore.DefaultProvingSchedule(c) + assert.NoError(b, err) in0 := make([]fr.Element, nbInstances) in1 := make([]fr.Element, nbInstances) @@ -283,12 +395,30 @@ func benchmarkGkrMiMC(b *testing.B, nbInstances, mimcDepth int) { //b.ResetTimer() fmt.Println("constructing proof") start = time.Now().UnixMicro() - _, err := Prove(c, assignment, fiatshamir.WithHash(mimc.NewMiMC())) + _, err = Prove(c, schedule, assignment, mimc.NewMiMC()) proved := time.Now().UnixMicro() - start fmt.Println("proved in", proved, "μs") assert.NoError(b, err) } +// TestSingleMulGateExplicitSchedule tests a single mul gate with an explicit single-step schedule, +// equivalent to the default but constructed manually to exercise the schedule path. +func TestSingleMulGateExplicitSchedule(t *testing.T) { + circuit := gkrtesting.SingleMulGateCircuit() + _, sCircuit := cache.Compile(t, circuit) + + // Wire 2 is the mul gate output (inputs: 0, 1). + // Explicit schedule: one GkrProvingLevel for wire 2. + // GkrClaimSource{Level:1} is the initial-challenge sentinel (len(schedule)=1). + schedule := constraint.GkrProvingSchedule{ + constraint.GkrSumcheckLevel{ + {Wires: []int{2}, ClaimSources: []constraint.GkrClaimSource{{Level: 1}}}, + }, + } + testWithSchedule(t, circuit, schedule) + _ = sCircuit +} + func BenchmarkGkrMimc19(b *testing.B) { benchmarkGkrMiMC(b, 1<<19, 91) } @@ -327,11 +457,12 @@ func unmarshalProof(printable gkrtesting.PrintableProof) (Proof, error) { } type TestCase struct { - Circuit gkrtypes.SerializableCircuit + Circuit gkrcore.SerializableCircuit Hash hash.Hash Proof Proof FullAssignment WireAssignment InOutAssignment WireAssignment + Schedule constraint.GkrProvingSchedule } var testCases = make(map[string]*TestCase) @@ -362,6 +493,20 @@ func newTestCase(path string) (*TestCase, error) { if proof, err = unmarshalProof(info.Proof); err != nil { return nil, err } + var schedule constraint.GkrProvingSchedule + if schedule, err = info.Schedule.ToProvingSchedule(); err != nil { + return nil, err + } + if schedule == nil { + if schedule, err = gkrcore.DefaultProvingSchedule(circuit); err != nil { + return nil, err + } + } + + outputSet := make(map[int]bool, len(circuit)) + for _, o := range circuit.Outputs() { + outputSet[o] = true + } fullAssignment := make(WireAssignment, len(circuit)) inOutAssignment := make(WireAssignment, len(circuit)) @@ -375,7 +520,7 @@ func newTestCase(path string) (*TestCase, error) { } assignmentRaw = info.Input[inI] inI++ - } else if circuit[i].IsOutput() { + } else if outputSet[i] { if outI == len(info.Output) { return nil, fmt.Errorf("fewer output in vector than in circuit") } @@ -396,7 +541,7 @@ func newTestCase(path string) (*TestCase, error) { fullAssignment.Complete(circuit) for i := range circuit { - if circuit[i].IsOutput() { + if outputSet[i] { if err = sliceEquals(inOutAssignment[i], fullAssignment[i]); err != nil { return nil, fmt.Errorf("assignment mismatch: %v", err) } @@ -409,6 +554,7 @@ func newTestCase(path string) (*TestCase, error) { Proof: proof, Hash: _hash, Circuit: circuit, + Schedule: schedule, } testCases[path] = tCase diff --git a/internal/gkr/bn254/sumcheck.go b/internal/gkr/bn254/sumcheck.go index 4c9dc2ca3c..be4d3bf5c4 100644 --- a/internal/gkr/bn254/sumcheck.go +++ b/internal/gkr/bn254/sumcheck.go @@ -7,33 +7,62 @@ package gkr import ( "errors" - "strconv" + "hash" "github.com/consensys/gnark-crypto/ecc/bn254/fr" "github.com/consensys/gnark-crypto/ecc/bn254/fr/polynomial" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" ) -// This does not make use of parallelism and represents polynomials as lists of coefficients -// It is currently geared towards arithmetic hashes. Once we have a more unified hash function interface, this can be generified. +// This does not make use of parallelism and represents polynomials as lists of coefficients. + +// transcript is a Fiat-Shamir transcript backed by a running hash. +// Field elements are written via Bind; challenges are derived via getChallenge. +// The hash is never reset — all previous data is implicitly part of future challenges. +type transcript struct { + h hash.Hash + bound bool // whether Bind was called since the last getChallenge +} + +// Bind writes field elements to the transcript as bindings for the next challenge. +func (t *transcript) Bind(elements ...fr.Element) { + if len(elements) == 0 { + return + } + for i := range elements { + bytes := elements[i].Bytes() + t.h.Write(bytes[:]) + } + t.bound = true +} + +// getChallenge binds optional elements, then squeezes a challenge from the current hash state. +// If no bindings were added since the last squeeze, a separator byte is written first +// to advance the state and prevent repeated values. +func (t *transcript) getChallenge(bindings ...fr.Element) fr.Element { + t.Bind(bindings...) + if !t.bound { + t.h.Write([]byte{0}) + } + t.bound = false + var res fr.Element + res.SetBytes(t.h.Sum(nil)) + return res +} // sumcheckClaims to a multi-sumcheck statement. i.e. one of the form ∑_{0≤i<2ⁿ} fⱼ(i) = cⱼ for 1 ≤ j ≤ m. // Later evolving into a claim of the form gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) type sumcheckClaims interface { - fold(a fr.Element) polynomial.Polynomial // fold into the 0ᵗʰ sumcheck subclaim. Create g := ∑_{1≤j≤m} aʲ⁻¹fⱼ for which now we seek to prove ∑_{0≤i<2ⁿ} g(i) = c := ∑_{1≤j≤m} aʲ⁻¹cⱼ. Return g₁. - next(fr.Element) polynomial.Polynomial // Return the evaluations gⱼ(k) for 1 ≤ k < degⱼ(g). Update the claim to gⱼ₊₁ for the input value as rⱼ + roundPolynomial() polynomial.Polynomial // compute gⱼ polynomial for current round + roundFold(r fr.Element) // fold inputs and eq at challenge r varsNum() int // number of variables - claimsNum() int // number of claims proveFinalEval(r []fr.Element) []fr.Element // in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof } // sumcheckLazyClaims is the sumcheckClaims data structure on the verifier side. It is "lazy" in that it has to compute fewer things. type sumcheckLazyClaims interface { - claimsNum() int // claimsNum = m - varsNum() int // varsNum = n - foldedSum(a fr.Element) fr.Element // foldedSum returns c = ∑_{1≤j≤m} aʲ⁻¹cⱼ - degree(i int) int // degree of the total claim in the i'th variable - verifyFinalEval(r []fr.Element, foldingCoeff fr.Element, purportedValue fr.Element, proof []fr.Element) error + varsNum() int // varsNum = n + degree(i int) int // degree of the total claim in the i'th variable + verifyFinalEval(r []fr.Element, purportedValue fr.Element, proof []fr.Element) error } // sumcheckProof of a multi-statement. @@ -42,130 +71,46 @@ type sumcheckProof struct { finalEvalProof []fr.Element //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof } -func setupTranscript(claimsNum int, varsNum int, settings *fiatshamir.Settings) (challengeNames []string, err error) { - numChallenges := varsNum - if claimsNum >= 2 { - numChallenges++ - } - challengeNames = make([]string, numChallenges) - if claimsNum >= 2 { - challengeNames[0] = settings.Prefix + "fold" - } - prefix := settings.Prefix + "pSP." - for i := 0; i < varsNum; i++ { - challengeNames[i+numChallenges-varsNum] = prefix + strconv.Itoa(i) - } - if settings.Transcript == nil { - transcript := fiatshamir.NewTranscript(settings.Hash, challengeNames...) - settings.Transcript = transcript - } - - for i := range settings.BaseChallenges { - if err = settings.Transcript.Bind(challengeNames[0], settings.BaseChallenges[i]); err != nil { - return - } - } - return -} - -func next(transcript *fiatshamir.Transcript, bindings []fr.Element, remainingChallengeNames *[]string) (fr.Element, error) { - challengeName := (*remainingChallengeNames)[0] - for i := range bindings { - bytes := bindings[i].Bytes() - if err := transcript.Bind(challengeName, bytes[:]); err != nil { - return fr.Element{}, err - } - } - var res fr.Element - bytes, err := transcript.ComputeChallenge(challengeName) - res.SetBytes(bytes) - - *remainingChallengeNames = (*remainingChallengeNames)[1:] - - return res, err -} - -// sumcheckProve create a non-interactive proof -func sumcheckProve(claims sumcheckClaims, transcriptSettings fiatshamir.Settings) (sumcheckProof, error) { - - var proof sumcheckProof - remainingChallengeNames, err := setupTranscript(claims.claimsNum(), claims.varsNum(), &transcriptSettings) - transcript := transcriptSettings.Transcript - if err != nil { - return proof, err - } - - var foldingCoeff fr.Element - if claims.claimsNum() >= 2 { - if foldingCoeff, err = next(transcript, []fr.Element{}, &remainingChallengeNames); err != nil { - return proof, err - } - } - +// sumcheckProve creates a non-interactive sumcheck proof. +// The fold challenge is derived by the caller (proveLevel). +// Pattern: roundPolynomial, [roundFold, roundPolynomial]*, proveFinalEval. +func sumcheckProve(claims sumcheckClaims, t *transcript) sumcheckProof { varsNum := claims.varsNum() - proof.partialSumPolys = make([]polynomial.Polynomial, varsNum) - proof.partialSumPolys[0] = claims.fold(foldingCoeff) + proof := sumcheckProof{partialSumPolys: make([]polynomial.Polynomial, varsNum)} + proof.partialSumPolys[0] = claims.roundPolynomial() challenges := make([]fr.Element, varsNum) - for j := 0; j+1 < varsNum; j++ { - if challenges[j], err = next(transcript, proof.partialSumPolys[j], &remainingChallengeNames); err != nil { - return proof, err - } - proof.partialSumPolys[j+1] = claims.next(challenges[j]) - } - - if challenges[varsNum-1], err = next(transcript, proof.partialSumPolys[varsNum-1], &remainingChallengeNames); err != nil { - return proof, err + for j := range varsNum - 1 { + challenges[j] = t.getChallenge(proof.partialSumPolys[j]...) + claims.roundFold(challenges[j]) + proof.partialSumPolys[j+1] = claims.roundPolynomial() } + challenges[varsNum-1] = t.getChallenge(proof.partialSumPolys[varsNum-1]...) proof.finalEvalProof = claims.proveFinalEval(challenges) - - return proof, nil + return proof } -func sumcheckVerify(claims sumcheckLazyClaims, proof sumcheckProof, transcriptSettings fiatshamir.Settings) error { - remainingChallengeNames, err := setupTranscript(claims.claimsNum(), claims.varsNum(), &transcriptSettings) - transcript := transcriptSettings.Transcript - if err != nil { - return err - } - - var foldingCoeff fr.Element - - if claims.claimsNum() >= 2 { - if foldingCoeff, err = next(transcript, []fr.Element{}, &remainingChallengeNames); err != nil { - return err - } - } - +// sumcheckVerify verifies a non-interactive sumcheck proof. +// The fold challenge is derived by the caller (verifyLevel). +// claimedSum is the expected sum; degree is the polynomial's degree in each variable. +func sumcheckVerify(claims sumcheckLazyClaims, proof sumcheckProof, claimedSum fr.Element, degree int, t *transcript) error { r := make([]fr.Element, claims.varsNum()) - // Just so that there is enough room for gJ to be reused - maxDegree := claims.degree(0) - for j := 1; j < claims.varsNum(); j++ { - if d := claims.degree(j); d > maxDegree { - maxDegree = d - } - } - gJ := make(polynomial.Polynomial, maxDegree+1) //At the end of iteration j, gJ = ∑_{i < 2ⁿ⁻ʲ⁻¹} g(X₁, ..., Xⱼ₊₁, i...) NOTE: n is shorthand for claims.varsNum() - gJR := claims.foldedSum(foldingCoeff) // At the beginning of iteration j, gJR = ∑_{i < 2ⁿ⁻ʲ} g(r₁, ..., rⱼ, i...) + gJ := make(polynomial.Polynomial, degree+1) + gJR := claimedSum for j := range claims.varsNum() { - if len(proof.partialSumPolys[j]) != claims.degree(j) { + if len(proof.partialSumPolys[j]) != degree { return errors.New("malformed proof") } copy(gJ[1:], proof.partialSumPolys[j]) - gJ[0].Sub(&gJR, &proof.partialSumPolys[j][0]) // Requirement that gⱼ(0) + gⱼ(1) = gⱼ₋₁(r) - // gJ is ready + gJ[0].Sub(&gJR, &proof.partialSumPolys[j][0]) - //Prepare for the next iteration - if r[j], err = next(transcript, proof.partialSumPolys[j], &remainingChallengeNames); err != nil { - return err - } - // This is an extremely inefficient way of interpolating. TODO: Interpolate without symbolically computing a polynomial - gJCoeffs := polynomial.InterpolateOnRange(gJ[:(claims.degree(j) + 1)]) + r[j] = t.getChallenge(proof.partialSumPolys[j]...) + gJCoeffs := polynomial.InterpolateOnRange(gJ[:(degree + 1)]) gJR = gJCoeffs.Eval(&r[j]) } - return claims.verifyFinalEval(r, foldingCoeff, gJR, proof.finalEvalProof) + return claims.verifyFinalEval(r, gJR, proof.finalEvalProof) } diff --git a/internal/gkr/bn254/sumcheck_test.go b/internal/gkr/bn254/sumcheck_test.go index 5f68f7c039..15cff2d307 100644 --- a/internal/gkr/bn254/sumcheck_test.go +++ b/internal/gkr/bn254/sumcheck_test.go @@ -10,7 +10,6 @@ import ( "hash" "github.com/consensys/gnark-crypto/ecc/bn254/fr/polynomial" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" "github.com/stretchr/testify/assert" "math/bits" @@ -28,11 +27,9 @@ func testSumcheckSingleClaimMultilin(polyInt []uint64, hashGenerator func() hash } claim := singleMultilinClaim{g: poly.Clone()} + t := transcript{h: hashGenerator()} - proof, err := sumcheckProve(&claim, fiatshamir.WithHash(hashGenerator())) - if err != nil { - return err - } + proof := sumcheckProve(&claim, &t) var sb strings.Builder for _, p := range proof.partialSumPolys { @@ -48,13 +45,15 @@ func testSumcheckSingleClaimMultilin(polyInt []uint64, hashGenerator func() hash } lazyClaim := singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} - if err = sumcheckVerify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())); err != nil { + t = transcript{h: hashGenerator()} + if err := sumcheckVerify(lazyClaim, proof, lazyClaim.claimedSum, 1, &t); err != nil { return err } proof.partialSumPolys[0][0].Add(&proof.partialSumPolys[0][0], toElement(1)) lazyClaim = singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} - if sumcheckVerify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())) == nil { + t = transcript{h: hashGenerator()} + if sumcheckVerify(lazyClaim, proof, lazyClaim.claimedSum, 1, &t) == nil { return fmt.Errorf("bad proof accepted") } return nil @@ -93,18 +92,14 @@ type singleMultilinClaim struct { g polynomial.MultiLin } -func (c singleMultilinClaim) proveFinalEval(r []fr.Element) []fr.Element { +func (c *singleMultilinClaim) proveFinalEval(r []fr.Element) []fr.Element { return nil // verifier can compute the final eval itself } -func (c singleMultilinClaim) varsNum() int { +func (c *singleMultilinClaim) varsNum() int { return bits.TrailingZeros(uint(len(c.g))) } -func (c singleMultilinClaim) claimsNum() int { - return 1 -} - func sumForX1One(g polynomial.MultiLin) polynomial.Polynomial { sum := g[len(g)/2] for i := len(g)/2 + 1; i < len(g); i++ { @@ -113,13 +108,12 @@ func sumForX1One(g polynomial.MultiLin) polynomial.Polynomial { return []fr.Element{sum} } -func (c singleMultilinClaim) fold(fr.Element) polynomial.Polynomial { +func (c *singleMultilinClaim) roundPolynomial() polynomial.Polynomial { return sumForX1One(c.g) } -func (c *singleMultilinClaim) next(r fr.Element) polynomial.Polynomial { +func (c *singleMultilinClaim) roundFold(r fr.Element) { c.g.Fold(r) - return sumForX1One(c.g) } type singleMultilinLazyClaim struct { @@ -127,7 +121,7 @@ type singleMultilinLazyClaim struct { claimedSum fr.Element } -func (c singleMultilinLazyClaim) verifyFinalEval(r []fr.Element, _ fr.Element, purportedValue fr.Element, proof []fr.Element) error { +func (c singleMultilinLazyClaim) verifyFinalEval(r []fr.Element, purportedValue fr.Element, proof []fr.Element) error { val := c.g.Evaluate(r, nil) if val.Equal(&purportedValue) { return nil @@ -135,15 +129,7 @@ func (c singleMultilinLazyClaim) verifyFinalEval(r []fr.Element, _ fr.Element, p return fmt.Errorf("mismatch") } -func (c singleMultilinLazyClaim) foldedSum(_ fr.Element) fr.Element { - return c.claimedSum -} - -func (c singleMultilinLazyClaim) degree(i int) int { - return 1 -} - -func (c singleMultilinLazyClaim) claimsNum() int { +func (c singleMultilinLazyClaim) degree(int) int { return 1 } diff --git a/internal/gkr/bw6-761/blueprint.go b/internal/gkr/bw6-761/blueprint.go index c760545c78..f90c542d19 100644 --- a/internal/gkr/bw6-761/blueprint.go +++ b/internal/gkr/bw6-761/blueprint.go @@ -15,10 +15,9 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/polynomial" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" "github.com/consensys/gnark-crypto/hash" "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/internal/gkr/gkrtypes" + "github.com/consensys/gnark/internal/gkr/gkrcore" ) func init() { @@ -34,7 +33,7 @@ type circuitEvaluator struct { // BlueprintSolve is a BW6_761-specific blueprint for solving GKR circuit instances. type BlueprintSolve struct { // Circuit structure (serialized) - Circuit gkrtypes.SerializableCircuit + Circuit gkrcore.SerializableCircuit NbInstances uint32 // Not serialized - recreated lazily at solve time @@ -204,6 +203,7 @@ func (b *BlueprintSolve) UpdateInstructionTree(inst constraint.Instruction, tree type BlueprintProve struct { SolveBlueprintID constraint.BlueprintID SolveBlueprint *BlueprintSolve `cbor:"-"` // not serialized, set at compile time + Schedule constraint.GkrProvingSchedule HashName string lock sync.Mutex @@ -256,9 +256,11 @@ func (b *BlueprintProve) Solve(s constraint.Solver[constraint.U64], inst constra } } + // Create hasher and write base challenges + hsh := hash.NewHash(b.HashName + "_BW6_761") + // Read initial challenges from instruction calldata (parse dynamically, no metadata) // Format: [0]=totalSize, [1...]=challenge linear expressions - insBytes := make([][]byte, 0) // first challenges calldata := inst.Calldata[1:] // skip size prefix for len(calldata) != 0 { val, delta := s.Read(calldata) @@ -267,17 +269,14 @@ func (b *BlueprintProve) Solve(s constraint.Solver[constraint.U64], inst constra // Copy directly from constraint.U64 to fr.Element (both in Montgomery form) var challenge fr.Element copy(challenge[:], val[:]) - insBytes = append(insBytes, challenge.Marshal()) + challengeBytes := challenge.Bytes() + hsh.Write(challengeBytes[:]) } - // Create Fiat-Shamir settings - hsh := hash.NewHash(b.HashName + "_BW6_761") - fsSettings := fiatshamir.WithHash(hsh, insBytes...) - // Call the BW6_761-specific Prove function (assignments already WireAssignment type) - proof, err := Prove(solveBlueprint.Circuit, assignments, fsSettings) + proof, err := Prove(solveBlueprint.Circuit, b.Schedule, assignments, hsh) if err != nil { - return fmt.Errorf("bw6_761 prove failed: %w", err) + return fmt.Errorf("BW6_761 prove failed: %w", err) } for i, elem := range proof.flatten() { @@ -305,7 +304,7 @@ func (b *BlueprintProve) proofSize() int { } nbPaddedInstances := ecc.NextPowerOfTwo(uint64(b.SolveBlueprint.NbInstances)) logNbInstances := bits.TrailingZeros64(nbPaddedInstances) - return b.SolveBlueprint.Circuit.ProofSize(logNbInstances) + return b.SolveBlueprint.Circuit.ProofSize(b.Schedule, logNbInstances) } // NbOutputs implements Blueprint @@ -434,7 +433,7 @@ func (b *BlueprintGetAssignment) UpdateInstructionTree(inst constraint.Instructi } // NewBlueprints creates and registers all GKR blueprints for BW6_761 -func NewBlueprints(circuit gkrtypes.SerializableCircuit, hashName string, compiler constraint.CustomizableSystem) gkrtypes.Blueprints { +func NewBlueprints(circuit gkrcore.SerializableCircuit, schedule constraint.GkrProvingSchedule, hashName string, compiler constraint.CustomizableSystem) gkrcore.Blueprints { // Create and register solve blueprint solve := &BlueprintSolve{Circuit: circuit} solveID := compiler.AddBlueprint(solve) @@ -443,6 +442,7 @@ func NewBlueprints(circuit gkrtypes.SerializableCircuit, hashName string, compil prove := &BlueprintProve{ SolveBlueprintID: solveID, SolveBlueprint: solve, + Schedule: schedule, HashName: hashName, } proveID := compiler.AddBlueprint(prove) @@ -453,7 +453,7 @@ func NewBlueprints(circuit gkrtypes.SerializableCircuit, hashName string, compil } getAssignmentID := compiler.AddBlueprint(getAssignment) - return gkrtypes.Blueprints{ + return gkrcore.Blueprints{ SolveID: solveID, Solve: solve, ProveID: proveID, diff --git a/internal/gkr/bw6-761/gkr.go b/internal/gkr/bw6-761/gkr.go index df503a96c7..8d07fe0310 100644 --- a/internal/gkr/bw6-761/gkr.go +++ b/internal/gkr/bw6-761/gkr.go @@ -8,655 +8,557 @@ package gkr import ( "errors" "fmt" + "hash" "iter" - "strconv" "sync" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/polynomial" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" "github.com/consensys/gnark-crypto/utils" - "github.com/consensys/gnark/internal/gkr/gkrtypes" + "github.com/consensys/gnark/constraint" + "github.com/consensys/gnark/internal/gkr/gkrcore" ) // Type aliases for bytecode-based GKR types type ( - Wire = gkrtypes.SerializableWire - Circuit = gkrtypes.SerializableCircuit + Wire = gkrcore.SerializableWire + Circuit = gkrcore.SerializableCircuit ) // The goal is to prove/verify evaluations of many instances of the same circuit -// WireAssignment is assignment of values to the same wire across many instances of the circuit +// WireAssignment is the assignment of values to the same wire across many instances of the circuit type WireAssignment []polynomial.MultiLin type Proof []sumcheckProof // for each layer, for each wire, a sumcheck (for each variable, a polynomial) // zeroCheckLazyClaims is a lazy claim for sumcheck (verifier side). -// It checks that the polynomial ∑ᵢ cⁱ eq(-, xᵢ) w(-) sums up to the expected multilinear -// extension of the values of w across all instances. -// Its purpose is to batch the checking of multiple evaluations of the same wire. +// It checks that the polynomial ∑ᵢ cⁱ eq(-, xᵢ) wᵢ(-) sums to the expected value, +// where the sum runs over all wᵢ and evaluation point xᵢ in the level. +// Its purpose is to batch the checking of multiple wire evaluations at evaluation points. type zeroCheckLazyClaims struct { - wireI int // the wire for which we are making the claim, with value w - evaluationPoints [][]fr.Element // xᵢ: the points at which the prover has made claims about the evaluation of w - claimedEvaluations []fr.Element // yᵢ = w(xᵢ), allegedly - manager *claimsManager // WARNING: Circular references -} - -func (e *zeroCheckLazyClaims) getWire() Wire { - return e.manager.circuit[e.wireI] -} - -func (e *zeroCheckLazyClaims) claimsNum() int { - return len(e.evaluationPoints) + foldingCoeff fr.Element // the coefficient used to fold claims, conventionally 0 if there is only one claim + resources *resources + levelI int } func (e *zeroCheckLazyClaims) varsNum() int { - return len(e.evaluationPoints[0]) -} - -// foldedSum returns ∑ᵢ aⁱ yᵢ -func (e *zeroCheckLazyClaims) foldedSum(a fr.Element) fr.Element { - evalsAsPoly := polynomial.Polynomial(e.claimedEvaluations) - return evalsAsPoly.Eval(&a) + return e.resources.nbVars } func (e *zeroCheckLazyClaims) degree(int) int { - return e.manager.circuit[e.wireI].ZeroCheckDegree() -} - -// verifyFinalEval finalizes the verification of w. -// The prover's claims w(xᵢ) = yᵢ have already been reduced to verifying -// ∑ cⁱ eq(xᵢ, r) w(r) = purportedValue. (c is foldingCoeff) -// Both purportedValue and the vector r have been randomized during the sumcheck protocol. -// By taking the w term out of the sum we get the equivalent claim that -// for E := ∑ eq(xᵢ, r), it must be that E w(r) = purportedValue. -// If w is an input wire, the verifier can directly check its evaluation at r. -// Otherwise, the prover makes claims about the evaluation of w's input wires, -// wᵢ, at r, to be verified later. -// The claims are communicated through the proof parameter. -// The verifier checks here if the claimed evaluations of wᵢ(r) are consistent with -// the main claim, by checking E w(wᵢ(r)...) = purportedValue. -func (e *zeroCheckLazyClaims) verifyFinalEval(r []fr.Element, foldingCoeff, purportedValue fr.Element, uniqueInputEvaluations []fr.Element) error { - // the eq terms ( E ) - numClaims := len(e.evaluationPoints) - evaluation := polynomial.EvalEq(e.evaluationPoints[numClaims-1], r) - for i := numClaims - 2; i >= 0; i-- { - evaluation.Mul(&evaluation, &foldingCoeff) - eq := polynomial.EvalEq(e.evaluationPoints[i], r) - evaluation.Add(&evaluation, &eq) - } - - wire := e.manager.circuit[e.wireI] - - // the w(...) term - var gateEvaluation fr.Element - if wire.IsInput() { // just compute w(r) - gateEvaluation = e.manager.assignment[e.wireI].Evaluate(r, e.manager.memPool) - } else { // proof contains the evaluations of the inputs, but avoids repetition in case multiple inputs come from the same wire - injection, injectionLeftInv := - e.manager.circuit.ClaimPropagationInfo(e.wireI) - - if len(injection) != len(uniqueInputEvaluations) { - return fmt.Errorf("%d input wire evaluations given, %d expected", len(uniqueInputEvaluations), len(injection)) - } - - for uniqueI, i := range injection { // map from unique to all - e.manager.add(wire.Inputs[i], r, uniqueInputEvaluations[uniqueI]) - } + return e.resources.circuit.ZeroCheckDegree(e.resources.schedule[e.levelI].(constraint.GkrSumcheckLevel)) +} + +// verifyFinalEval finalizes the verification of a level at the sumcheck evaluation point r. +// The sumcheck protocol has already reduced the per-wire claims w(xᵢ) = yᵢ to verifying +// ∑ᵢ cⁱ eq(xᵢ, r) · wᵢ(r) = purportedValue, where the sum runs over all +// claims on each wire and c is foldingCoeff. +// Both purportedValue and the vector r have been randomized during sumcheck. +// +// For input wires, w(r) is computed directly from the assignment and the claimed +// evaluation in uniqueInputEvaluations is checked equal to it. +// For non-input wires, the prover claims evaluations of their gate inputs at r via +// uniqueInputEvaluations; those claims are verified by lower levels' sumchecks. +// The verifier checks consistency by evaluating gateᵥ(inputEvals...) and confirming +// that the full sum matches purportedValue. +func (e *zeroCheckLazyClaims) verifyFinalEval(r []fr.Element, purportedValue fr.Element, uniqueInputEvaluations []fr.Element) error { + e.resources.outgoingEvalPoints[e.levelI] = [][]fr.Element{r} + level := e.resources.schedule[e.levelI] + gateInputEvals := gkrcore.ReduplicateInputs(level, e.resources.circuit, uniqueInputEvaluations) + + var claimedEvals polynomial.Polynomial + levelWireI := 0 + for _, group := range level.ClaimGroups() { + for _, wI := range group.Wires { + wire := e.resources.circuit[wI] + + var gateEval fr.Element + if wire.IsInput() { + gateEval = e.resources.assignment[wI].Evaluate(r, &e.resources.memPool) + if !gateInputEvals[levelWireI][0].Equal(&gateEval) { + return errors.New("incompatible evaluations") + } + } else { + evaluator := newGateEvaluator(wire.Gate.Evaluate, len(wire.Inputs)) + for _, v := range gateInputEvals[levelWireI] { + evaluator.pushInput(v) + } + gateEval.Set(evaluator.evaluate()) + } - evaluator := newGateEvaluator(wire.Gate.Evaluate, len(wire.Inputs)) - for _, uniqueI := range injectionLeftInv { // map from all to unique - evaluator.pushInput(uniqueInputEvaluations[uniqueI]) + for _, src := range group.ClaimSources { + eq := polynomial.EvalEq(e.resources.outgoingEvalPoints[src.Level][src.OutgoingClaimIndex], r) + var term fr.Element + term.Mul(&eq, &gateEval) + claimedEvals = append(claimedEvals, term) + } + levelWireI++ } - - gateEvaluation.Set(evaluator.evaluate()) } - evaluation.Mul(&evaluation, &gateEvaluation) - - if evaluation.Equal(&purportedValue) { - return nil + if total := claimedEvals.Eval(&e.foldingCoeff); !total.Equal(&purportedValue) { + return errors.New("incompatible evaluations") } - return errors.New("incompatible evaluations") + return nil } // zeroCheckClaims is a claim for sumcheck (prover side). -// It checks that the polynomial ∑ᵢ cⁱ eq(-, xᵢ) w(-) sums up to the expected multilinear -// extension of the values of w across all instances. -// Its purpose is to batch the proving of multiple evaluations of the same wire. +// It checks that the polynomial ∑ᵢ cⁱ eq(-, xᵢ) wᵢ(-) sums to the expected value, +// where the sum runs over all (wire v, claim source s) pairs in the level. +// Each wire has its own eq table with the batching coefficients baked in. type zeroCheckClaims struct { - wireI int // the wire for which we are making the claim, with value w - evaluationPoints [][]fr.Element // xᵢ: the points at which the prover has made claims about the evaluation of w - claimedEvaluations []fr.Element // yᵢ = w(xᵢ) - manager *claimsManager - - input []polynomial.MultiLin // input[i](h₁, ..., hₘ₋ⱼ) = wᵢ(r₁, r₂, ..., rⱼ₋₁, h₁, ..., hₘ₋ⱼ) - - eq polynomial.MultiLin // E := ∑ᵢ cⁱ eq(xᵢ, -) - - gateEvaluatorPool *gateEvaluatorPool -} - -func (c *zeroCheckClaims) getWire() Wire { - return c.manager.circuit[c.wireI] -} - -// fold the multiple claims into one claim using a random combination (foldingCoeff or c). -// From the original multiple claims of w(xᵢ) = yᵢ, we get a single claim -// ∑ᵢ,ₕ cⁱ eq(xᵢ, h) w(h) = ∑ᵢ cⁱ yᵢ, where h iterates over the hypercube (circuit instances) and -// i iterates over the claims. -// Equivalently, we could say ∑ᵢ cⁱ yᵢ = ∑ₕ,ᵢ cⁱ eq(xᵢ, h) w(h) = ∑ₕ w(h) ∑ᵢ cⁱ eq(xᵢ, h). -// Thus if we initially compute E := ∑ᵢ cⁱ eq(xᵢ, -), our claim will find the simpler form -// ∑ᵢ cⁱ yᵢ = ∑ₕ w(h) E(h), where the sum-checked polynomial is of degree deg(g) + 1, -// and deg(g) is the total degree of the polynomial defining the gate g of which w is the output. -// The output of fold is the first sumcheck claim, i.e. ∑₍ₕ₁,ₕ₂,...₎ w(X, h₁, h₂, ...) E(X, h₁, h₂, ...).. -func (c *zeroCheckClaims) fold(foldingCoeff fr.Element) polynomial.Polynomial { - varsNum := c.varsNum() - eqLength := 1 << varsNum - claimsNum := c.claimsNum() - // initialize the eq tables ( E ) - c.eq = c.manager.memPool.Make(eqLength) - - c.eq[0].SetOne() - c.eq.Eq(c.evaluationPoints[0]) - - // E := eq(x₀, -) - newEq := polynomial.MultiLin(c.manager.memPool.Make(eqLength)) - aI := foldingCoeff - - // E += cⁱ eq(xᵢ, -) - for k := 1; k < claimsNum; k++ { - newEq[0].Set(&aI) - - c.eqAcc(c.eq, newEq, c.evaluationPoints[k]) - - if k+1 < claimsNum { - aI.Mul(&aI, &foldingCoeff) - } - } - - c.manager.memPool.Dump(newEq) - - return c.computeGJ() -} - -// eqAcc sets m to an eq table at q and then adds it to e. -// m <- eq(q, -). -// e <- e + m -func (c *zeroCheckClaims) eqAcc(e, m polynomial.MultiLin, q []fr.Element) { - n := len(q) - - //At the end of each iteration, m(h₁, ..., hₙ) = eq(q₁, ..., qᵢ₊₁, h₁, ..., hᵢ₊₁) - for i := range q { // In the comments we use a 1-based index so q[i] = qᵢ₊₁ - // go through all assignments of (b₁, ..., bᵢ) ∈ {0,1}ⁱ - const threshold = 1 << 6 - k := 1 << i - if k < threshold { - for j := 0; j < k; j++ { - j0 := j << (n - i) // bᵢ₊₁ = 0 - j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 - - m[j1].Mul(&q[i], &m[j0]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ - m[j0].Sub(&m[j0], &m[j1]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) - } - } else { - c.manager.workers.Submit(k, func(start, end int) { - for j := start; j < end; j++ { - j0 := j << (n - i) // bᵢ₊₁ = 0 - j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 - - m[j1].Mul(&q[i], &m[j0]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ - m[j0].Sub(&m[j0], &m[j1]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) - } - }, 1024).Wait() - } - - } - c.manager.workers.Submit(len(e), func(start, end int) { - for i := start; i < end; i++ { - e[i].Add(&e[i], &m[i]) - } - }, 512).Wait() + levelI int + resources *resources + input []polynomial.MultiLin // UniqueGateInputs order + inputIndices [][]int // [wireInLevel][gateInputJ] → index in input + eqs []polynomial.MultiLin // per-wire interpolation bases for evaluating wire assignments at challenge points + gateEvaluatorPools []*gateEvaluatorPool } -// computeGJ: gⱼ = ∑_{0≤h<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, h...) = ∑_{0≤i<2ⁿ⁻ʲ} E(r₁, ..., Xⱼ, h...) g( w₀(r₁, ..., Xⱼ, h...), ... ). -// the polynomial is represented by the evaluations gⱼ(1), gⱼ(2), ..., gⱼ(deg(gⱼ)). -// The value gⱼ(0) is inferred from the equation gⱼ(0) + gⱼ(1) = gⱼ₋₁(rⱼ₋₁). By convention, g₀ is a constant polynomial equal to the claimed sum. -func (c *zeroCheckClaims) computeGJ() polynomial.Polynomial { - - wire := c.getWire() - degGJ := wire.ZeroCheckDegree() // guaranteed to be no smaller than the actual deg(gⱼ) - nbGateIn := len(c.input) - - // Both E and wᵢ (the input wires and the eq table) are multilinear, thus - // they are linear in Xⱼ. - // So for f ∈ { E(r₁, ..., Xⱼ, h...) } ∪ {wᵢ(r₁, ..., Xⱼ, h...) }, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables. - // ml are such multilinear polynomials the evaluations of which over different values of Xⱼ are computed in this stepwise manner. - ml := make([]polynomial.MultiLin, nbGateIn+1) // shortcut to the evaluations of the multilinear polynomials over the hypercube - ml[0] = c.eq - copy(ml[1:], c.input) - - sumSize := len(c.eq) / 2 // the range of h, over which we sum - - // Perf-TODO: Collate once at claim "folding" time and not again. then, even folding can be done in one operation every time "next" is called - - gJ := make([]fr.Element, degGJ) +func (c *zeroCheckClaims) varsNum() int { + return c.resources.nbVars +} + +// roundPolynomial computes gⱼ = ∑ₕ ∑ᵥ eqs[v](Xⱼ, h...) · gateᵥ(inputs(Xⱼ, h...)). +// The polynomial is represented by the evaluations gⱼ(1), gⱼ(2), ..., gⱼ(deg(gⱼ)). +// The value gⱼ(0) is inferred from the equation gⱼ(0) + gⱼ(1) = gⱼ₋₁(rⱼ₋₁). +// By convention, g₀ is a constant polynomial equal to the claimed sum. +func (c *zeroCheckClaims) roundPolynomial() polynomial.Polynomial { + level := c.resources.schedule[c.levelI].(constraint.GkrSumcheckLevel) + degree := c.resources.circuit.ZeroCheckDegree(level) + nbUniqueInputs := len(c.input) + nbWires := len(c.eqs) + + // Both eqs and input are multilinear, thus linear in Xⱼ. + // For any such f, f(m) = m·(f(1) - f(0)) + f(0), and f(0), f(1) are read directly + // from the bookkeeping tables. This allows stepwise evaluation at Xⱼ = 1, 2, ..., degree. + // Layout: [eq₀, eq₁, ..., eq_{nbWires-1}, input₀, input₁, ..., input_{nbUniqueInputs-1}] + ml := make([]polynomial.MultiLin, nbWires+nbUniqueInputs) + copy(ml, c.eqs) + copy(ml[nbWires:], c.input) + + sumSize := len(c.eqs[0]) / 2 + + p := make([]fr.Element, degree) var mu sync.Mutex - computeAll := func(start, end int) { // compute method to allow parallelization across instances + computeAll := func(start, end int) { var step fr.Element - evaluator := c.gateEvaluatorPool.get() - defer c.gateEvaluatorPool.put(evaluator) + evaluators := make([]*gateEvaluator, nbWires) + for w := range nbWires { + evaluators[w] = c.gateEvaluatorPools[w].get() + } + defer func() { + for w := range nbWires { + c.gateEvaluatorPools[w].put(evaluators[w]) + } + }() - res := make([]fr.Element, degGJ) + res := make([]fr.Element, degree) // evaluations of ml, laid out as: // ml[0](1, h...), ml[1](1, h...), ..., ml[len(ml)-1](1, h...), // ml[0](2, h...), ml[1](2, h...), ..., ml[len(ml)-1](2, h...), // ... - // ml[0](degGJ, h...), ml[2](degGJ, h...), ..., ml[len(ml)-1](degGJ, h...) - mlEvals := make([]fr.Element, degGJ*len(ml)) - - for h := start; h < end; h++ { // h counts across instances + // ml[0](degree, h...), ml[1](degree, h...), ..., ml[len(ml)-1](degree, h...) + mlEvals := make([]fr.Element, degree*len(ml)) + for h := start; h < end; h++ { evalAt1Index := sumSize + h for k := range ml { - // d = 0 - mlEvals[k].Set(&ml[k][evalAt1Index]) // evaluation at Xⱼ = 1. Can be taken directly from the table. + mlEvals[k].Set(&ml[k][evalAt1Index]) // evaluation at Xⱼ = 1, taken directly from the table step.Sub(&mlEvals[k], &ml[k][h]) // step = ml[k](1) - ml[k](0) - for d := 1; d < degGJ; d++ { + for d := 1; d < degree; d++ { mlEvals[d*len(ml)+k].Add(&mlEvals[(d-1)*len(ml)+k], &step) } } - eIndex := 0 // index for where the current eq term is + eIndex := 0 // start of the current row's eq evaluations nextEIndex := len(ml) - for d := range degGJ { - // Push gate inputs - for i := range nbGateIn { - evaluator.pushInput(mlEvals[eIndex+1+i]) + for d := range degree { + for w := range nbWires { + for _, inputI := range c.inputIndices[w] { + evaluators[w].pushInput(mlEvals[eIndex+nbWires+inputI]) + } + summand := evaluators[w].evaluate() + summand.Mul(summand, &mlEvals[eIndex+w]) + res[d].Add(&res[d], summand) // collect contributions into the sum from start to end } - summand := evaluator.evaluate() - summand.Mul(summand, &mlEvals[eIndex]) - res[d].Add(&res[d], summand) // collect contributions into the sum from start to end eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) } } mu.Lock() - for i := range gJ { - gJ[i].Add(&gJ[i], &res[i]) // collect into the complete sum + for i := range p { + p[i].Add(&p[i], &res[i]) // collect into the complete sum } mu.Unlock() } const minBlockSize = 64 - if sumSize < minBlockSize { - // no parallelization computeAll(0, sumSize) } else { - c.manager.workers.Submit(sumSize, computeAll, minBlockSize).Wait() + c.resources.workers.Submit(sumSize, computeAll, minBlockSize).Wait() } - return gJ + return p } -// next first folds the input and E polynomials at the given verifier challenge then computes the new gⱼ. -// Thus, j <- j+1 and rⱼ = challenge. -func (c *zeroCheckClaims) next(challenge fr.Element) polynomial.Polynomial { +// roundFold folds all input and eq polynomials at the verifier challenge r. +// After this call, j ← j+1 and rⱼ = r. +func (c *zeroCheckClaims) roundFold(r fr.Element) { const minBlockSize = 512 - n := len(c.eq) / 2 + n := len(c.eqs[0]) / 2 if n < minBlockSize { - // no parallelization for i := range c.input { - c.input[i].Fold(challenge) + c.input[i].Fold(r) + } + for i := range c.eqs { + c.eqs[i].Fold(r) } - c.eq.Fold(challenge) } else { - wgs := make([]*sync.WaitGroup, len(c.input)) + wgs := make([]*sync.WaitGroup, len(c.input)+len(c.eqs)) for i := range c.input { - wgs[i] = c.manager.workers.Submit(n, c.input[i].FoldParallel(challenge), minBlockSize) + wgs[i] = c.resources.workers.Submit(n, c.input[i].FoldParallel(r), minBlockSize) + } + for i := range c.eqs { + wgs[len(c.input)+i] = c.resources.workers.Submit(n, c.eqs[i].FoldParallel(r), minBlockSize) } - c.manager.workers.Submit(n, c.eq.FoldParallel(challenge), minBlockSize).Wait() for _, wg := range wgs { wg.Wait() } } - - return c.computeGJ() -} - -func (c *zeroCheckClaims) varsNum() int { - return len(c.evaluationPoints[0]) } -func (c *zeroCheckClaims) claimsNum() int { - return len(c.claimedEvaluations) -} - -// proveFinalEval provides the values wᵢ(r₁, ..., rₙ) +// proveFinalEval provides the unique input wire values wᵢ(r₁, ..., rₙ). func (c *zeroCheckClaims) proveFinalEval(r []fr.Element) []fr.Element { - //defer the proof, return list of claims - - injection, _ := c.manager.circuit.ClaimPropagationInfo(c.wireI) // TODO @Tabaie: Instead of doing this last, we could just have fewer input in the first place; not that likely to happen with single gates, but more so with layers. - evaluations := make([]fr.Element, len(injection)) - for i, gateInputI := range injection { - wI := c.input[gateInputI] - wI.Fold(r[len(r)-1]) // We already have wᵢ(r₁, ..., rₙ₋₁, hₙ) in a table. Only one more fold required. - c.manager.add(c.getWire().Inputs[gateInputI], r, wI[0]) - evaluations[i] = wI[0] + c.resources.outgoingEvalPoints[c.levelI] = [][]fr.Element{r} + evaluations := make([]fr.Element, len(c.input)) + for i := range c.input { + c.input[i].Fold(r[len(r)-1]) + evaluations[i] = c.input[i][0] + } + for i := range c.input { + c.resources.memPool.Dump(c.input[i]) + } + for i := range c.eqs { + c.resources.memPool.Dump(c.eqs[i]) + } + for _, pool := range c.gateEvaluatorPools { + pool.dumpAll() } - - c.manager.memPool.Dump(c.claimedEvaluations, c.eq) - c.gateEvaluatorPool.dumpAll() - return evaluations } -type claimsManager struct { - claims []*zeroCheckLazyClaims - assignment WireAssignment - memPool *polynomial.Pool - workers *utils.WorkerPool - circuit Circuit -} +// eqAcc sets m to an eq table at q and then adds it to e. +// m <- m[0] · eq(q, -). +// e <- e + m +func (r *resources) eqAcc(e, m polynomial.MultiLin, q []fr.Element) { + n := len(q) -func newClaimsManager(circuit Circuit, assignment WireAssignment, o settings) (manager claimsManager) { - manager.assignment = assignment - manager.claims = make([]*zeroCheckLazyClaims, len(circuit)) - manager.memPool = o.pool - manager.workers = o.workers - manager.circuit = circuit + // At the end of each iteration, m(h₁, ..., hₙ) = m[0] · eq(q₁, ..., qᵢ₊₁, h₁, ..., hᵢ₊₁) + for i := range q { // 1-based in comments: q[i] = qᵢ₊₁ + // go through all assignments of (b₁, ..., bᵢ) ∈ {0,1}ⁱ + const threshold = 1 << 6 + k := 1 << i + if k < threshold { + for j := 0; j < k; j++ { + j0 := j << (n - i) // bᵢ₊₁ = 0 + j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 - for i := range circuit { - manager.claims[i] = &zeroCheckLazyClaims{ - wireI: i, - evaluationPoints: make([][]fr.Element, 0, circuit[i].NbClaims()), - claimedEvaluations: manager.memPool.Make(circuit[i].NbClaims()), - manager: &manager, + m[j1].Mul(&q[i], &m[j0]) // m(b₁,...,bᵢ,1) = m(b₁,...,bᵢ) · qᵢ₊₁ + m[j0].Sub(&m[j0], &m[j1]) // m(b₁,...,bᵢ,0) = m(b₁,...,bᵢ) · (1 - qᵢ₊₁) + } + } else { + r.workers.Submit(k, func(start, end int) { + for j := start; j < end; j++ { + j0 := j << (n - i) // bᵢ₊₁ = 0 + j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 + + m[j1].Mul(&q[i], &m[j0]) // m(b₁,...,bᵢ,1) = m(b₁,...,bᵢ) · qᵢ₊₁ + m[j0].Sub(&m[j0], &m[j1]) // m(b₁,...,bᵢ,0) = m(b₁,...,bᵢ) · (1 - qᵢ₊₁) + } + }, 1024).Wait() } } - return -} - -func (m *claimsManager) add(wire int, evaluationPoint []fr.Element, evaluation fr.Element) { - claim := m.claims[wire] - i := len(claim.evaluationPoints) - claim.claimedEvaluations[i] = evaluation - claim.evaluationPoints = append(claim.evaluationPoints, evaluationPoint) + r.workers.Submit(len(e), func(start, end int) { + for i := start; i < end; i++ { + e[i].Add(&e[i], &m[i]) + } + }, 512).Wait() } -func (m *claimsManager) getLazyClaim(wire int) *zeroCheckLazyClaims { - return m.claims[wire] +type resources struct { + // outgoingEvalPoints[i][k] is the k-th outgoing evaluation point (evaluation challenge) produced at schedule level i. + // outgoingEvalPoints[len(schedule)][0] holds the initial challenge (firstChallenge / rho). + // SumcheckLevels produce one point (k=0). SkipLevels pass on all their evaluation points. + outgoingEvalPoints [][][]fr.Element + nbVars int + assignment WireAssignment + memPool polynomial.Pool + workers *utils.WorkerPool + circuit Circuit + schedule constraint.GkrProvingSchedule + transcript transcript + uniqueInputIndices [][]int // uniqueInputIndices[wI][claimI]: w's unique-input index in the layer its claimI-th evaluation is coming from } -func (m *claimsManager) getClaim(wireI int) *zeroCheckClaims { - lazy := m.claims[wireI] - wire := m.circuit[wireI] - res := &zeroCheckClaims{ - wireI: wireI, - evaluationPoints: lazy.evaluationPoints, - claimedEvaluations: lazy.claimedEvaluations, - manager: m, - } - - if wire.IsInput() { - res.input = []polynomial.MultiLin{m.memPool.Clone(m.assignment[wireI])} - } else { - res.input = make([]polynomial.MultiLin, len(wire.Inputs)) - - for inputI, inputW := range wire.Inputs { - res.input[inputI] = m.memPool.Clone(m.assignment[inputW]) //will be edited later, so must be deep copied +func newResources(c Circuit, schedule constraint.GkrProvingSchedule, assignment WireAssignment, hasher hash.Hash) (resources, error) { + nbVars := assignment.NumVars() + nbInstances := assignment.NumInstances() + if 1<= 2 { + foldingCoeff = r.transcript.getChallenge() } -} -func setup(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (settings, error) { - var o settings - var err error - for _, option := range options { - option(&o) + uniqueInputs, inputIndices := r.circuit.InputMapping(level) + input := make([]polynomial.MultiLin, len(uniqueInputs)) + for i, inW := range uniqueInputs { + input[i] = r.memPool.Clone(r.assignment[inW]) } - o.nbVars = assignment.NumVars() - nbInstances := assignment.NumInstances() - if 1< 1 { + newEq := polynomial.MultiLin(r.memPool.Make(eqLength)) + aI := alpha + for k := 1; k < nbSources; k++ { + aI.Mul(&aI, &foldingCoeff) + newEq[0].Set(&aI) + r.eqAcc(groupEq, newEq, r.outgoingEvalPoints[group.ClaimSources[k].Level][group.ClaimSources[k].OutgoingClaimIndex]) + } + r.memPool.Dump(newEq) + } -func ChallengeNames(c Circuit, logNbInstances int, prefix string) []string { + var stride fr.Element + stride.Set(&foldingCoeff) + for range nbSources - 1 { + stride.Mul(&stride, &foldingCoeff) + } - // Pre-compute the size TODO: Consider not doing this and just grow the list by appending - size := logNbInstances // first challenge + eqs[levelWireI] = groupEq + levelWireI++ + alpha.Mul(&alpha, &stride) - for i := range c { - if c[i].NoProof() { // no proof, no challenge - continue - } - if c[i].NbClaims() > 1 { //fold the claims - size++ + for w := 1; w < len(group.Wires); w++ { + eqs[levelWireI] = polynomial.MultiLin(r.memPool.Make(eqLength)) + r.workers.Submit(eqLength, func(start, end int) { + for i := start; i < end; i++ { + eqs[levelWireI][i].Mul(&eqs[levelWireI-1][i], &stride) + } + }, 512).Wait() + levelWireI++ + alpha.Mul(&alpha, &stride) } - size += logNbInstances // full run of sumcheck on logNbInstances variables } - nums := make([]string, max(len(c), logNbInstances)) - for i := range nums { - nums[i] = strconv.Itoa(i) + claims := &zeroCheckClaims{ + levelI: levelI, + resources: r, + input: input, + inputIndices: inputIndices, + eqs: eqs, + gateEvaluatorPools: pools, } + return sumcheckProve(claims, &r.transcript) +} - challenges := make([]string, size) - - // output wire claims - firstChallengePrefix := prefix + "fC." - for j := 0; j < logNbInstances; j++ { - challenges[j] = firstChallengePrefix + nums[j] +func (r *resources) verifySumcheckLevel(levelI int, proof Proof) error { + level := r.schedule[levelI] + nbClaims := level.NbClaims() + var foldingCoeff fr.Element + if nbClaims >= 2 { + foldingCoeff = r.transcript.getChallenge() } - j := logNbInstances - for i := len(c) - 1; i >= 0; i-- { - if c[i].NoProof() { - continue - } - wirePrefix := prefix + "w" + nums[i] + "." - if c[i].NbClaims() > 1 { - challenges[j] = wirePrefix + "fold" - j++ - } + initialChallengeI := len(r.schedule) + claimedEvals := make(polynomial.Polynomial, 0, level.NbClaims()) - partialSumPrefix := wirePrefix + "pSP." - for k := 0; k < logNbInstances; k++ { - challenges[j] = partialSumPrefix + nums[k] - j++ + for _, group := range level.ClaimGroups() { + for _, wI := range group.Wires { + for claimI, src := range group.ClaimSources { + if src.Level == initialChallengeI { + claimedEvals = append(claimedEvals, r.assignment[wI].Evaluate(r.outgoingEvalPoints[src.Level][src.OutgoingClaimIndex], &r.memPool)) + } else { + claimedEvals = append(claimedEvals, proof[src.Level].finalEvalProof[r.schedule[src.Level].FinalEvalProofIndex(r.uniqueInputIndices[wI][claimI], src.OutgoingClaimIndex)]) + } + } } } - return challenges -} -func getFirstChallengeNames(logNbInstances int, prefix string) []string { - res := make([]string, logNbInstances) - firstChallengePrefix := prefix + "fC." - for i := 0; i < logNbInstances; i++ { - res[i] = firstChallengePrefix + strconv.Itoa(i) - } - return res -} + claimedSum := claimedEvals.Eval(&foldingCoeff) -func getChallenges(transcript *fiatshamir.Transcript, names []string) ([]fr.Element, error) { - res := make([]fr.Element, len(names)) - for i, name := range names { - if bytes, err := transcript.ComputeChallenge(name); err != nil { - return nil, err - } else if err = res[i].SetBytesCanonical(bytes); err != nil { - return nil, err - } + lazyClaims := &zeroCheckLazyClaims{ + foldingCoeff: foldingCoeff, + resources: r, + levelI: levelI, } - return res, nil + return sumcheckVerify(lazyClaims, proof[levelI], claimedSum, r.circuit.ZeroCheckDegree(level.(constraint.GkrSumcheckLevel)), &r.transcript) } // Prove consistency of the claimed assignment -func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (Proof, error) { - o, err := setup(c, assignment, transcriptSettings, options...) +func Prove(c Circuit, schedule constraint.GkrProvingSchedule, assignment WireAssignment, hasher hash.Hash) (Proof, error) { + r, err := newResources(c, schedule, assignment, hasher) if err != nil { return nil, err } - defer o.workers.Stop() + defer r.workers.Stop() - claims := newClaimsManager(c, assignment, o) + proof := make(Proof, len(schedule)) - proof := make(Proof, len(c)) - // firstChallenge called rho in the paper - var firstChallenge []fr.Element - firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) - if err != nil { - return nil, err + // Derive the initial challenge point + firstChallenge := make([]fr.Element, r.nbVars) + for j := range r.nbVars { + firstChallenge[j] = r.transcript.getChallenge() } + r.outgoingEvalPoints[len(schedule)] = [][]fr.Element{firstChallenge} - wirePrefix := o.transcriptPrefix + "w" - var baseChallenge [][]byte - for i := len(c) - 1; i >= 0; i-- { - - wire := c[i] - - if wire.IsOutput() { - claims.add(i, firstChallenge, assignment[i].Evaluate(firstChallenge, claims.memPool)) - } - - claim := claims.getClaim(i) - if wire.NoProof() { // input wires with one claim only - proof[i] = sumcheckProof{ - partialSumPolys: []polynomial.Polynomial{}, - finalEvalProof: []fr.Element{}, - } + for levelI := len(schedule) - 1; levelI >= 0; levelI-- { + if _, isSkip := r.schedule[levelI].(constraint.GkrSkipLevel); isSkip { + proof[levelI] = r.proveSkipLevel(levelI) } else { - if proof[i], err = sumcheckProve( - claim, fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), - ); err != nil { - return proof, err - } - - baseChallenge = make([][]byte, len(proof[i].finalEvalProof)) - for j := range proof[i].finalEvalProof { - baseChallenge[j] = proof[i].finalEvalProof[j].Marshal() - } + proof[levelI] = r.proveSumcheckLevel(levelI) } - // the verifier checks a single claim about input wires itself - claims.deleteClaim(i) + constraint.BindGkrFinalEvalProof(&r.transcript, proof[levelI].finalEvalProof, c.UniqueGateInputs(r.schedule[levelI]), c.IsInput, r.schedule[levelI]) } return proof, nil } -// Verify the consistency of the claimed output with the claimed input -// Unlike in Prove, the assignment argument need not be complete -func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSettings fiatshamir.Settings, options ...Option) error { - o, err := setup(c, assignment, transcriptSettings, options...) +// Verify the consistency of the claimed output with the claimed input. +// Unlike in Prove, the assignment argument need not be complete. +func Verify(c Circuit, schedule constraint.GkrProvingSchedule, assignment WireAssignment, proof Proof, hasher hash.Hash) error { + r, err := newResources(c, schedule, assignment, hasher) if err != nil { return err } - defer o.workers.Stop() + defer r.workers.Stop() - claims := newClaimsManager(c, assignment, o) - - var firstChallenge []fr.Element - firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) - if err != nil { - return err + // Derive the initial challenge point + firstChallenge := make([]fr.Element, r.nbVars) + for j := range r.nbVars { + firstChallenge[j] = r.transcript.getChallenge() } + r.outgoingEvalPoints[len(schedule)] = [][]fr.Element{firstChallenge} - wirePrefix := o.transcriptPrefix + "w" - var baseChallenge [][]byte - for i := len(c) - 1; i >= 0; i-- { - wire := c[i] - - if wire.IsOutput() { - claims.add(i, firstChallenge, assignment[i].Evaluate(firstChallenge, claims.memPool)) - } - - proofW := proof[i] - claim := claims.getLazyClaim(i) - if wire.NoProof() { // input wires with one claim only - // make sure the proof is empty - if len(proofW.finalEvalProof) != 0 || len(proofW.partialSumPolys) != 0 { - return errors.New("no proof allowed for input wire with a single claim") - } - - if wire.NbClaims() == 1 { // input wire - // simply evaluate and see if it matches - if len(claim.evaluationPoints) == 0 || len(claim.claimedEvaluations) == 0 { - return errors.New("missing input wire claim") - } - evaluation := assignment[i].Evaluate(claim.evaluationPoints[0], claims.memPool) - if !claim.claimedEvaluations[0].Equal(&evaluation) { - return errors.New("incorrect input wire claim") - } - } - } else if err = sumcheckVerify( - claim, proof[i], fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), - ); err == nil { // incorporate prover claims about w's input into the transcript - baseChallenge = make([][]byte, len(proofW.finalEvalProof)) - for j := range baseChallenge { - baseChallenge[j] = proofW.finalEvalProof[j].Marshal() - } + for levelI := len(schedule) - 1; levelI >= 0; levelI-- { + if _, isSkip := r.schedule[levelI].(constraint.GkrSkipLevel); isSkip { + err = r.verifySkipLevel(levelI, proof) } else { - return fmt.Errorf("sumcheck proof rejected: %v", err) //TODO: Any polynomials to dump? + err = r.verifySumcheckLevel(levelI, proof) + } + if err != nil { + return fmt.Errorf("level %d: %v", levelI, err) } - claims.deleteClaim(i) + constraint.BindGkrFinalEvalProof(&r.transcript, proof[levelI].finalEvalProof, c.UniqueGateInputs(r.schedule[levelI]), c.IsInput, r.schedule[levelI]) } return nil } @@ -734,14 +636,14 @@ func (p Proof) flatten() iter.Seq2[int, *fr.Element] { // It manages the stack internally and handles input buffering, making it easy to // evaluate the same gate multiple times with different inputs. type gateEvaluator struct { - gate gkrtypes.GateBytecode + gate gkrcore.GateBytecode vars []fr.Element nbIn int // number of inputs expected } // newGateEvaluator creates an evaluator for the given compiled gate. // The stack is preloaded with constants and ready for evaluation. -func newGateEvaluator(gate gkrtypes.GateBytecode, nbIn int, elementPool ...*polynomial.Pool) gateEvaluator { +func newGateEvaluator(gate gkrcore.GateBytecode, nbIn int, elementPool ...*polynomial.Pool) gateEvaluator { e := gateEvaluator{ gate: gate, nbIn: nbIn, @@ -785,28 +687,28 @@ func (e *gateEvaluator) evaluate(top ...fr.Element) *fr.Element { // Use switch instead of function pointer for better inlining switch inst.Op { - case gkrtypes.OpAdd: + case gkrcore.OpAdd: dst.Add(&e.vars[inst.Inputs[0]], &e.vars[inst.Inputs[1]]) for j := 2; j < len(inst.Inputs); j++ { dst.Add(dst, &e.vars[inst.Inputs[j]]) } - case gkrtypes.OpMul: + case gkrcore.OpMul: dst.Mul(&e.vars[inst.Inputs[0]], &e.vars[inst.Inputs[1]]) for j := 2; j < len(inst.Inputs); j++ { dst.Mul(dst, &e.vars[inst.Inputs[j]]) } - case gkrtypes.OpSub: + case gkrcore.OpSub: dst.Sub(&e.vars[inst.Inputs[0]], &e.vars[inst.Inputs[1]]) for j := 2; j < len(inst.Inputs); j++ { dst.Sub(dst, &e.vars[inst.Inputs[j]]) } - case gkrtypes.OpNeg: + case gkrcore.OpNeg: dst.Neg(&e.vars[inst.Inputs[0]]) - case gkrtypes.OpMulAcc: + case gkrcore.OpMulAcc: var prod fr.Element prod.Mul(&e.vars[inst.Inputs[1]], &e.vars[inst.Inputs[2]]) dst.Add(&e.vars[inst.Inputs[0]], &prod) - case gkrtypes.OpSumExp17: + case gkrcore.OpSumExp17: // result = (x[0] + x[1] + x[2])^17 var sum fr.Element sum.Add(&e.vars[inst.Inputs[0]], &e.vars[inst.Inputs[1]]) @@ -832,14 +734,14 @@ func (e *gateEvaluator) evaluate(top ...fr.Element) *fr.Element { // gateEvaluatorPool manages a pool of gate evaluators for a specific gate type // All evaluators share the same underlying polynomial.Pool for element slices type gateEvaluatorPool struct { - gate gkrtypes.GateBytecode + gate gkrcore.GateBytecode nbIn int lock sync.Mutex available map[*gateEvaluator]struct{} elementPool *polynomial.Pool } -func newGateEvaluatorPool(gate gkrtypes.GateBytecode, nbIn int, elementPool *polynomial.Pool) *gateEvaluatorPool { +func newGateEvaluatorPool(gate gkrcore.GateBytecode, nbIn int, elementPool *polynomial.Pool) *gateEvaluatorPool { gep := &gateEvaluatorPool{ gate: gate, nbIn: nbIn, @@ -867,7 +769,7 @@ func (gep *gateEvaluatorPool) put(e *gateEvaluator) { gep.lock.Lock() defer gep.lock.Unlock() - // Return evaluator to pool (it keeps its vars slice from polynomial pool) + // Return evaluator to pool (it keeps its vars slice from the polynomial pool) gep.available[e] = struct{}{} } diff --git a/internal/gkr/bw6-761/gkr_test.go b/internal/gkr/bw6-761/gkr_test.go index b8c4d5a422..630dfe6fd7 100644 --- a/internal/gkr/bw6-761/gkr_test.go +++ b/internal/gkr/bw6-761/gkr_test.go @@ -11,7 +11,6 @@ import ( "os" "path/filepath" "reflect" - "strconv" "testing" "time" @@ -19,10 +18,9 @@ import ( "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/mimc" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/polynomial" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" - gcUtils "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark/constraint" + "github.com/consensys/gnark/internal/gkr/gkrcore" "github.com/consensys/gnark/internal/gkr/gkrtesting" - "github.com/consensys/gnark/internal/gkr/gkrtypes" "github.com/stretchr/testify/assert" ) @@ -69,36 +67,164 @@ func TestMimc(t *testing.T) { test(t, gkrtesting.MiMCCircuit(93)) } -func TestSumcheckFromSingleInputTwoIdentityGatesGateTwoInstances(t *testing.T) { - // Construct SerializableCircuit directly, bypassing CompileCircuit - // which would reset NbUniqueOutputs based on actual topology - circuit := gkrtypes.SerializableCircuit{ +func TestPoseidon2(t *testing.T) { + test(t, gkrtesting.Poseidon2Circuit(4, 2)) +} + +// testSumcheckLevel exercises proveSumcheckLevel/verifySumcheckLevel for a single sumcheck level. +func testSumcheckLevel(t *testing.T, circuit gkrcore.RawCircuit, level constraint.GkrProvingLevel) { + t.Helper() + _, sCircuit := cache.Compile(t, circuit) + + ins := sCircuit.Inputs() + assignment := make(WireAssignment, len(sCircuit)) + for _, i := range ins { + assignment[i] = make([]fr.Element, 2) + fr.Vector(assignment[i]).MustSetRandom() + } + + assignment.Complete(sCircuit) + + schedule := constraint.GkrProvingSchedule{level} + initEvalPoint := [][]fr.Element{{one}} + + // Prove + proveR, err := newResources(sCircuit, schedule, assignment, newMessageCounter(1, 1)) + assert.NoError(t, err) + defer proveR.workers.Stop() + + proveR.outgoingEvalPoints[len(schedule)] = initEvalPoint + proof := Proof{proveR.proveSumcheckLevel(0)} + + // Verify + verifyR, err := newResources(sCircuit, schedule, assignment, newMessageCounter(1, 1)) + assert.NoError(t, err) + defer verifyR.workers.Stop() + + verifyR.outgoingEvalPoints[len(schedule)] = initEvalPoint + assert.NoError(t, verifyR.verifySumcheckLevel(0, proof)) +} + +func TestSumcheckLevel(t *testing.T) { + // Wires 0,1 = inputs; wires 2,3,4 = mul(0,1). All gates are independent outputs. + circuit := gkrcore.RawCircuit{ + {}, + {}, + {Gate: gkrcore.Mul2, Inputs: []int{0, 1}}, + {Gate: gkrcore.Mul2, Inputs: []int{0, 1}}, + {Gate: gkrcore.Mul2, Inputs: []int{0, 1}}, + } + // Each level has an initial challenge at index 1 (len(schedule) = 1). + // GkrClaimSource{Level:1} is the initial-challenge sentinel. + tests := []struct { + name string + level constraint.GkrProvingLevel + }{ + { + name: "single wire", + level: constraint.GkrSumcheckLevel{ + {Wires: []int{4}, ClaimSources: []constraint.GkrClaimSource{{Level: 1}}}, + }, + }, + { + name: "two groups", + level: constraint.GkrSumcheckLevel{ + {Wires: []int{4}, ClaimSources: []constraint.GkrClaimSource{{Level: 1}}}, + {Wires: []int{3}, ClaimSources: []constraint.GkrClaimSource{{Level: 1}}}, + }, + }, + { + name: "one group with two wires", + level: constraint.GkrSumcheckLevel{ + {Wires: []int{4, 3}, ClaimSources: []constraint.GkrClaimSource{{Level: 1}}}, + }, + }, { - NbUniqueOutputs: 2, - Gate: gkrtypes.SerializableGate{Degree: 1}, + name: "mixed: single + multi-wire group", + level: constraint.GkrSumcheckLevel{ + {Wires: []int{4}, ClaimSources: []constraint.GkrClaimSource{{Level: 1}}}, + {Wires: []int{3, 2}, ClaimSources: []constraint.GkrClaimSource{{Level: 1}}}, + }, }, } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + testSumcheckLevel(t, circuit, tc.level) + }) + } +} + +// testSkipLevel exercises proveSkipLevel/verifySkipLevel for a single skip level. +func testSkipLevel(t *testing.T, circuit gkrcore.RawCircuit, level constraint.GkrProvingLevel) { + t.Helper() + _, sCircuit := cache.Compile(t, circuit) - assignment := WireAssignment{[]fr.Element{two, three}} - var o settings - pool := polynomial.NewPool(256, 1<<11) - workers := gcUtils.NewWorkerPool() - o.pool = &pool - o.workers = workers - - claimsManagerGen := func() *claimsManager { - manager := newClaimsManager(circuit, assignment, o) - manager.add(0, []fr.Element{three}, five) - manager.add(0, []fr.Element{four}, six) - return &manager + ins := sCircuit.Inputs() + assignment := make(WireAssignment, len(sCircuit)) + for _, i := range ins { + assignment[i] = make([]fr.Element, 2) + fr.Vector(assignment[i]).MustSetRandom() } - transcriptGen := newMessageCounterGenerator(4, 1) + assignment.Complete(sCircuit) - proof, err := sumcheckProve(claimsManagerGen().getClaim(0), fiatshamir.WithHash(transcriptGen(), nil)) + schedule := constraint.GkrProvingSchedule{level} + initEvalPoint := [][]fr.Element{{one}} + + // Prove + proveR, err := newResources(sCircuit, schedule, assignment, newMessageCounter(1, 1)) assert.NoError(t, err) - err = sumcheckVerify(claimsManagerGen().getLazyClaim(0), proof, fiatshamir.WithHash(transcriptGen(), nil)) + defer proveR.workers.Stop() + + proveR.outgoingEvalPoints[len(schedule)] = initEvalPoint + proof := Proof{proveR.proveSkipLevel(0)} + + // Verify + verifyR, err := newResources(sCircuit, schedule, assignment, newMessageCounter(1, 1)) assert.NoError(t, err) + defer verifyR.workers.Stop() + + verifyR.outgoingEvalPoints[len(schedule)] = initEvalPoint + assert.NoError(t, verifyR.verifySkipLevel(0, proof)) +} + +func TestSkipLevel(t *testing.T) { + // Wires 0,1 = inputs; wires 2,3 = identity(0); wire 4 = add(0,1). All degree-1 outputs. + circuit := gkrcore.RawCircuit{ + {}, + {}, + {Gate: gkrcore.Identity, Inputs: []int{0}}, + {Gate: gkrcore.Identity, Inputs: []int{0}}, + {Gate: gkrcore.Add2, Inputs: []int{0, 1}}, + } + + // Single-claim cases: one inherited evaluation point (OutgoingClaimIndex always 0). + singleClaim := []struct { + name string + level constraint.GkrProvingLevel + }{ + { + name: "single input wire", + level: constraint.GkrSkipLevel{Wires: []int{0}, ClaimSources: []constraint.GkrClaimSource{{Level: 1}}}, + }, + { + name: "single identity gate", + level: constraint.GkrSkipLevel{Wires: []int{2}, ClaimSources: []constraint.GkrClaimSource{{Level: 1}}}, + }, + { + name: "add gate", + level: constraint.GkrSkipLevel{Wires: []int{4}, ClaimSources: []constraint.GkrClaimSource{{Level: 1}}}, + }, + { + name: "two identity gates one group", + level: constraint.GkrSkipLevel{Wires: []int{2, 3}, ClaimSources: []constraint.GkrClaimSource{{Level: 1}}}, + }, + } + for _, tc := range singleClaim { + t.Run(tc.name, func(t *testing.T) { + testSkipLevel(t, circuit, tc.level) + }) + } } var one, two, three, four, five, six fr.Element @@ -112,31 +238,20 @@ func init() { six.Double(&three) } -var testManyInstancesLogMaxInstances = -1 - -func getLogMaxInstances(t *testing.T) int { - if testManyInstancesLogMaxInstances == -1 { - - s := os.Getenv("GKR_LOG_INSTANCES") - if s == "" { - testManyInstancesLogMaxInstances = 5 - } else { - var err error - testManyInstancesLogMaxInstances, err = strconv.Atoi(s) - if err != nil { - t.Error(err) - } - } - - } - return testManyInstancesLogMaxInstances +func test(t *testing.T, circuit gkrcore.RawCircuit) { + testWithSchedule(t, circuit, nil) } -func test(t *testing.T, circuit gkrtypes.GadgetCircuit) { - sCircuit := cache.Compile(t, circuit) - ins := circuit.Inputs() +func testWithSchedule(t *testing.T, circuit gkrcore.RawCircuit, schedule constraint.GkrProvingSchedule) { + gCircuit, sCircuit := cache.Compile(t, circuit) + if schedule == nil { + var err error + schedule, err = gkrcore.DefaultProvingSchedule(sCircuit) + assert.NoError(t, err) + } + ins := gCircuit.Inputs() insAssignment := make(WireAssignment, len(ins)) - maxSize := 1 << getLogMaxInstances(t) + maxSize := 1 << gkrtesting.GetLogMaxInstances(t) for i := range ins { insAssignment[i] = make([]fr.Element, maxSize) @@ -151,51 +266,33 @@ func test(t *testing.T, circuit gkrtypes.GadgetCircuit) { fullAssignment.Complete(sCircuit) - t.Log("Selected inputs for test") - - proof, err := Prove(sCircuit, fullAssignment, fiatshamir.WithHash(newMessageCounter(1, 1))) + proof, err := Prove(sCircuit, schedule, fullAssignment, newMessageCounter(1, 1)) assert.NoError(t, err) // Even though a hash is called here, the proof is empty - err = Verify(sCircuit, fullAssignment, proof, fiatshamir.WithHash(newMessageCounter(1, 1))) + err = Verify(sCircuit, schedule, fullAssignment, proof, newMessageCounter(1, 1)) assert.NoError(t, err, "proof rejected") - if proof.isEmpty() { // special case for TestNoGate: - continue // there's no way to make a trivial proof fail - } - - err = Verify(sCircuit, fullAssignment, proof, fiatshamir.WithHash(newMessageCounter(0, 1))) + err = Verify(sCircuit, schedule, fullAssignment, proof, newMessageCounter(0, 1)) assert.NotNil(t, err, "bad proof accepted") } - -} - -func (p Proof) isEmpty() bool { - for i := range p { - if len(p[i].finalEvalProof) != 0 { - return false - } - for j := range p[i].partialSumPolys { - if len(p[i].partialSumPolys[j]) != 0 { - return false - } - } - } - return true } func testNoGate(t *testing.T, inputAssignments ...[]fr.Element) { - c := cache.Compile(t, gkrtesting.NoGateCircuit()) + _, c := cache.Compile(t, gkrtesting.NoGateCircuit()) + + schedule, err := gkrcore.DefaultProvingSchedule(c) + assert.NoError(t, err) assignment := WireAssignment{0: inputAssignments[0]} - proof, err := Prove(c, assignment, fiatshamir.WithHash(newMessageCounter(1, 1))) + proof, err := Prove(c, schedule, assignment, newMessageCounter(1, 1)) assert.NoError(t, err) // Even though a hash is called here, the proof is empty - err = Verify(c, assignment, proof, fiatshamir.WithHash(newMessageCounter(1, 1))) + err = Verify(c, schedule, assignment, proof, newMessageCounter(1, 1)) assert.NoError(t, err, "proof rejected") } @@ -203,7 +300,7 @@ func generateTestProver(path string) func(t *testing.T) { return func(t *testing.T) { testCase, err := newTestCase(path) assert.NoError(t, err) - proof, err := Prove(testCase.Circuit, testCase.FullAssignment, fiatshamir.WithHash(testCase.Hash)) + proof, err := Prove(testCase.Circuit, testCase.Schedule, testCase.FullAssignment, testCase.Hash) assert.NoError(t, err) assert.NoError(t, proofEquals(testCase.Proof, proof)) } @@ -213,17 +310,29 @@ func generateTestVerifier(path string) func(t *testing.T) { return func(t *testing.T) { testCase, err := newTestCase(path) assert.NoError(t, err) - err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(testCase.Hash)) + err = Verify(testCase.Circuit, testCase.Schedule, testCase.InOutAssignment, testCase.Proof, testCase.Hash) assert.NoError(t, err, "proof rejected") + testCase, err = newTestCase(path) assert.NoError(t, err) - err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(newMessageCounter(2, 0))) + err = Verify(testCase.Circuit, testCase.Schedule, testCase.InOutAssignment, testCase.Proof, newMessageCounter(2, 0)) assert.NotNil(t, err, "bad proof accepted") + + testCase, err = newTestCase(path) + assert.NoError(t, err) + testCase.InOutAssignment[len(testCase.InOutAssignment)-1][0].Add(&testCase.InOutAssignment[len(testCase.InOutAssignment)-1][0], &one) + err = Verify(testCase.Circuit, testCase.Schedule, testCase.InOutAssignment, testCase.Proof, testCase.Hash) + assert.NotNil(t, err, "tampered output accepted") + + testCase, err = newTestCase(path) + assert.NoError(t, err) + testCase.InOutAssignment[0][0].Add(&testCase.InOutAssignment[0][0], &one) + err = Verify(testCase.Circuit, testCase.Schedule, testCase.InOutAssignment, testCase.Proof, testCase.Hash) + assert.NotNil(t, err, "tampered input accepted") } } func TestGkrVectors(t *testing.T) { - const testDirPath = "../test_vectors/" dirEntries, err := os.ReadDir(testDirPath) assert.NoError(t, err) @@ -267,7 +376,10 @@ func proofEquals(expected Proof, seen Proof) error { func benchmarkGkrMiMC(b *testing.B, nbInstances, mimcDepth int) { fmt.Println("creating circuit structure") - c := cache.Compile(b, gkrtesting.MiMCCircuit(mimcDepth)) + _, c := cache.Compile(b, gkrtesting.MiMCCircuit(mimcDepth)) + + schedule, err := gkrcore.DefaultProvingSchedule(c) + assert.NoError(b, err) in0 := make([]fr.Element, nbInstances) in1 := make([]fr.Element, nbInstances) @@ -283,12 +395,30 @@ func benchmarkGkrMiMC(b *testing.B, nbInstances, mimcDepth int) { //b.ResetTimer() fmt.Println("constructing proof") start = time.Now().UnixMicro() - _, err := Prove(c, assignment, fiatshamir.WithHash(mimc.NewMiMC())) + _, err = Prove(c, schedule, assignment, mimc.NewMiMC()) proved := time.Now().UnixMicro() - start fmt.Println("proved in", proved, "μs") assert.NoError(b, err) } +// TestSingleMulGateExplicitSchedule tests a single mul gate with an explicit single-step schedule, +// equivalent to the default but constructed manually to exercise the schedule path. +func TestSingleMulGateExplicitSchedule(t *testing.T) { + circuit := gkrtesting.SingleMulGateCircuit() + _, sCircuit := cache.Compile(t, circuit) + + // Wire 2 is the mul gate output (inputs: 0, 1). + // Explicit schedule: one GkrProvingLevel for wire 2. + // GkrClaimSource{Level:1} is the initial-challenge sentinel (len(schedule)=1). + schedule := constraint.GkrProvingSchedule{ + constraint.GkrSumcheckLevel{ + {Wires: []int{2}, ClaimSources: []constraint.GkrClaimSource{{Level: 1}}}, + }, + } + testWithSchedule(t, circuit, schedule) + _ = sCircuit +} + func BenchmarkGkrMimc19(b *testing.B) { benchmarkGkrMiMC(b, 1<<19, 91) } @@ -327,11 +457,12 @@ func unmarshalProof(printable gkrtesting.PrintableProof) (Proof, error) { } type TestCase struct { - Circuit gkrtypes.SerializableCircuit + Circuit gkrcore.SerializableCircuit Hash hash.Hash Proof Proof FullAssignment WireAssignment InOutAssignment WireAssignment + Schedule constraint.GkrProvingSchedule } var testCases = make(map[string]*TestCase) @@ -362,6 +493,20 @@ func newTestCase(path string) (*TestCase, error) { if proof, err = unmarshalProof(info.Proof); err != nil { return nil, err } + var schedule constraint.GkrProvingSchedule + if schedule, err = info.Schedule.ToProvingSchedule(); err != nil { + return nil, err + } + if schedule == nil { + if schedule, err = gkrcore.DefaultProvingSchedule(circuit); err != nil { + return nil, err + } + } + + outputSet := make(map[int]bool, len(circuit)) + for _, o := range circuit.Outputs() { + outputSet[o] = true + } fullAssignment := make(WireAssignment, len(circuit)) inOutAssignment := make(WireAssignment, len(circuit)) @@ -375,7 +520,7 @@ func newTestCase(path string) (*TestCase, error) { } assignmentRaw = info.Input[inI] inI++ - } else if circuit[i].IsOutput() { + } else if outputSet[i] { if outI == len(info.Output) { return nil, fmt.Errorf("fewer output in vector than in circuit") } @@ -396,7 +541,7 @@ func newTestCase(path string) (*TestCase, error) { fullAssignment.Complete(circuit) for i := range circuit { - if circuit[i].IsOutput() { + if outputSet[i] { if err = sliceEquals(inOutAssignment[i], fullAssignment[i]); err != nil { return nil, fmt.Errorf("assignment mismatch: %v", err) } @@ -409,6 +554,7 @@ func newTestCase(path string) (*TestCase, error) { Proof: proof, Hash: _hash, Circuit: circuit, + Schedule: schedule, } testCases[path] = tCase diff --git a/internal/gkr/bw6-761/sumcheck.go b/internal/gkr/bw6-761/sumcheck.go index f083550198..87e6d425ae 100644 --- a/internal/gkr/bw6-761/sumcheck.go +++ b/internal/gkr/bw6-761/sumcheck.go @@ -7,33 +7,62 @@ package gkr import ( "errors" - "strconv" + "hash" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/polynomial" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" ) -// This does not make use of parallelism and represents polynomials as lists of coefficients -// It is currently geared towards arithmetic hashes. Once we have a more unified hash function interface, this can be generified. +// This does not make use of parallelism and represents polynomials as lists of coefficients. + +// transcript is a Fiat-Shamir transcript backed by a running hash. +// Field elements are written via Bind; challenges are derived via getChallenge. +// The hash is never reset — all previous data is implicitly part of future challenges. +type transcript struct { + h hash.Hash + bound bool // whether Bind was called since the last getChallenge +} + +// Bind writes field elements to the transcript as bindings for the next challenge. +func (t *transcript) Bind(elements ...fr.Element) { + if len(elements) == 0 { + return + } + for i := range elements { + bytes := elements[i].Bytes() + t.h.Write(bytes[:]) + } + t.bound = true +} + +// getChallenge binds optional elements, then squeezes a challenge from the current hash state. +// If no bindings were added since the last squeeze, a separator byte is written first +// to advance the state and prevent repeated values. +func (t *transcript) getChallenge(bindings ...fr.Element) fr.Element { + t.Bind(bindings...) + if !t.bound { + t.h.Write([]byte{0}) + } + t.bound = false + var res fr.Element + res.SetBytes(t.h.Sum(nil)) + return res +} // sumcheckClaims to a multi-sumcheck statement. i.e. one of the form ∑_{0≤i<2ⁿ} fⱼ(i) = cⱼ for 1 ≤ j ≤ m. // Later evolving into a claim of the form gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) type sumcheckClaims interface { - fold(a fr.Element) polynomial.Polynomial // fold into the 0ᵗʰ sumcheck subclaim. Create g := ∑_{1≤j≤m} aʲ⁻¹fⱼ for which now we seek to prove ∑_{0≤i<2ⁿ} g(i) = c := ∑_{1≤j≤m} aʲ⁻¹cⱼ. Return g₁. - next(fr.Element) polynomial.Polynomial // Return the evaluations gⱼ(k) for 1 ≤ k < degⱼ(g). Update the claim to gⱼ₊₁ for the input value as rⱼ + roundPolynomial() polynomial.Polynomial // compute gⱼ polynomial for current round + roundFold(r fr.Element) // fold inputs and eq at challenge r varsNum() int // number of variables - claimsNum() int // number of claims proveFinalEval(r []fr.Element) []fr.Element // in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof } // sumcheckLazyClaims is the sumcheckClaims data structure on the verifier side. It is "lazy" in that it has to compute fewer things. type sumcheckLazyClaims interface { - claimsNum() int // claimsNum = m - varsNum() int // varsNum = n - foldedSum(a fr.Element) fr.Element // foldedSum returns c = ∑_{1≤j≤m} aʲ⁻¹cⱼ - degree(i int) int // degree of the total claim in the i'th variable - verifyFinalEval(r []fr.Element, foldingCoeff fr.Element, purportedValue fr.Element, proof []fr.Element) error + varsNum() int // varsNum = n + degree(i int) int // degree of the total claim in the i'th variable + verifyFinalEval(r []fr.Element, purportedValue fr.Element, proof []fr.Element) error } // sumcheckProof of a multi-statement. @@ -42,130 +71,46 @@ type sumcheckProof struct { finalEvalProof []fr.Element //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof } -func setupTranscript(claimsNum int, varsNum int, settings *fiatshamir.Settings) (challengeNames []string, err error) { - numChallenges := varsNum - if claimsNum >= 2 { - numChallenges++ - } - challengeNames = make([]string, numChallenges) - if claimsNum >= 2 { - challengeNames[0] = settings.Prefix + "fold" - } - prefix := settings.Prefix + "pSP." - for i := 0; i < varsNum; i++ { - challengeNames[i+numChallenges-varsNum] = prefix + strconv.Itoa(i) - } - if settings.Transcript == nil { - transcript := fiatshamir.NewTranscript(settings.Hash, challengeNames...) - settings.Transcript = transcript - } - - for i := range settings.BaseChallenges { - if err = settings.Transcript.Bind(challengeNames[0], settings.BaseChallenges[i]); err != nil { - return - } - } - return -} - -func next(transcript *fiatshamir.Transcript, bindings []fr.Element, remainingChallengeNames *[]string) (fr.Element, error) { - challengeName := (*remainingChallengeNames)[0] - for i := range bindings { - bytes := bindings[i].Bytes() - if err := transcript.Bind(challengeName, bytes[:]); err != nil { - return fr.Element{}, err - } - } - var res fr.Element - bytes, err := transcript.ComputeChallenge(challengeName) - res.SetBytes(bytes) - - *remainingChallengeNames = (*remainingChallengeNames)[1:] - - return res, err -} - -// sumcheckProve create a non-interactive proof -func sumcheckProve(claims sumcheckClaims, transcriptSettings fiatshamir.Settings) (sumcheckProof, error) { - - var proof sumcheckProof - remainingChallengeNames, err := setupTranscript(claims.claimsNum(), claims.varsNum(), &transcriptSettings) - transcript := transcriptSettings.Transcript - if err != nil { - return proof, err - } - - var foldingCoeff fr.Element - if claims.claimsNum() >= 2 { - if foldingCoeff, err = next(transcript, []fr.Element{}, &remainingChallengeNames); err != nil { - return proof, err - } - } - +// sumcheckProve creates a non-interactive sumcheck proof. +// The fold challenge is derived by the caller (proveLevel). +// Pattern: roundPolynomial, [roundFold, roundPolynomial]*, proveFinalEval. +func sumcheckProve(claims sumcheckClaims, t *transcript) sumcheckProof { varsNum := claims.varsNum() - proof.partialSumPolys = make([]polynomial.Polynomial, varsNum) - proof.partialSumPolys[0] = claims.fold(foldingCoeff) + proof := sumcheckProof{partialSumPolys: make([]polynomial.Polynomial, varsNum)} + proof.partialSumPolys[0] = claims.roundPolynomial() challenges := make([]fr.Element, varsNum) - for j := 0; j+1 < varsNum; j++ { - if challenges[j], err = next(transcript, proof.partialSumPolys[j], &remainingChallengeNames); err != nil { - return proof, err - } - proof.partialSumPolys[j+1] = claims.next(challenges[j]) - } - - if challenges[varsNum-1], err = next(transcript, proof.partialSumPolys[varsNum-1], &remainingChallengeNames); err != nil { - return proof, err + for j := range varsNum - 1 { + challenges[j] = t.getChallenge(proof.partialSumPolys[j]...) + claims.roundFold(challenges[j]) + proof.partialSumPolys[j+1] = claims.roundPolynomial() } + challenges[varsNum-1] = t.getChallenge(proof.partialSumPolys[varsNum-1]...) proof.finalEvalProof = claims.proveFinalEval(challenges) - - return proof, nil + return proof } -func sumcheckVerify(claims sumcheckLazyClaims, proof sumcheckProof, transcriptSettings fiatshamir.Settings) error { - remainingChallengeNames, err := setupTranscript(claims.claimsNum(), claims.varsNum(), &transcriptSettings) - transcript := transcriptSettings.Transcript - if err != nil { - return err - } - - var foldingCoeff fr.Element - - if claims.claimsNum() >= 2 { - if foldingCoeff, err = next(transcript, []fr.Element{}, &remainingChallengeNames); err != nil { - return err - } - } - +// sumcheckVerify verifies a non-interactive sumcheck proof. +// The fold challenge is derived by the caller (verifyLevel). +// claimedSum is the expected sum; degree is the polynomial's degree in each variable. +func sumcheckVerify(claims sumcheckLazyClaims, proof sumcheckProof, claimedSum fr.Element, degree int, t *transcript) error { r := make([]fr.Element, claims.varsNum()) - // Just so that there is enough room for gJ to be reused - maxDegree := claims.degree(0) - for j := 1; j < claims.varsNum(); j++ { - if d := claims.degree(j); d > maxDegree { - maxDegree = d - } - } - gJ := make(polynomial.Polynomial, maxDegree+1) //At the end of iteration j, gJ = ∑_{i < 2ⁿ⁻ʲ⁻¹} g(X₁, ..., Xⱼ₊₁, i...) NOTE: n is shorthand for claims.varsNum() - gJR := claims.foldedSum(foldingCoeff) // At the beginning of iteration j, gJR = ∑_{i < 2ⁿ⁻ʲ} g(r₁, ..., rⱼ, i...) + gJ := make(polynomial.Polynomial, degree+1) + gJR := claimedSum for j := range claims.varsNum() { - if len(proof.partialSumPolys[j]) != claims.degree(j) { + if len(proof.partialSumPolys[j]) != degree { return errors.New("malformed proof") } copy(gJ[1:], proof.partialSumPolys[j]) - gJ[0].Sub(&gJR, &proof.partialSumPolys[j][0]) // Requirement that gⱼ(0) + gⱼ(1) = gⱼ₋₁(r) - // gJ is ready + gJ[0].Sub(&gJR, &proof.partialSumPolys[j][0]) - //Prepare for the next iteration - if r[j], err = next(transcript, proof.partialSumPolys[j], &remainingChallengeNames); err != nil { - return err - } - // This is an extremely inefficient way of interpolating. TODO: Interpolate without symbolically computing a polynomial - gJCoeffs := polynomial.InterpolateOnRange(gJ[:(claims.degree(j) + 1)]) + r[j] = t.getChallenge(proof.partialSumPolys[j]...) + gJCoeffs := polynomial.InterpolateOnRange(gJ[:(degree + 1)]) gJR = gJCoeffs.Eval(&r[j]) } - return claims.verifyFinalEval(r, foldingCoeff, gJR, proof.finalEvalProof) + return claims.verifyFinalEval(r, gJR, proof.finalEvalProof) } diff --git a/internal/gkr/bw6-761/sumcheck_test.go b/internal/gkr/bw6-761/sumcheck_test.go index 375652e4ec..d5fd5d305a 100644 --- a/internal/gkr/bw6-761/sumcheck_test.go +++ b/internal/gkr/bw6-761/sumcheck_test.go @@ -10,7 +10,6 @@ import ( "hash" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/polynomial" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" "github.com/stretchr/testify/assert" "math/bits" @@ -28,11 +27,9 @@ func testSumcheckSingleClaimMultilin(polyInt []uint64, hashGenerator func() hash } claim := singleMultilinClaim{g: poly.Clone()} + t := transcript{h: hashGenerator()} - proof, err := sumcheckProve(&claim, fiatshamir.WithHash(hashGenerator())) - if err != nil { - return err - } + proof := sumcheckProve(&claim, &t) var sb strings.Builder for _, p := range proof.partialSumPolys { @@ -48,13 +45,15 @@ func testSumcheckSingleClaimMultilin(polyInt []uint64, hashGenerator func() hash } lazyClaim := singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} - if err = sumcheckVerify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())); err != nil { + t = transcript{h: hashGenerator()} + if err := sumcheckVerify(lazyClaim, proof, lazyClaim.claimedSum, 1, &t); err != nil { return err } proof.partialSumPolys[0][0].Add(&proof.partialSumPolys[0][0], toElement(1)) lazyClaim = singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} - if sumcheckVerify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())) == nil { + t = transcript{h: hashGenerator()} + if sumcheckVerify(lazyClaim, proof, lazyClaim.claimedSum, 1, &t) == nil { return fmt.Errorf("bad proof accepted") } return nil @@ -93,18 +92,14 @@ type singleMultilinClaim struct { g polynomial.MultiLin } -func (c singleMultilinClaim) proveFinalEval(r []fr.Element) []fr.Element { +func (c *singleMultilinClaim) proveFinalEval(r []fr.Element) []fr.Element { return nil // verifier can compute the final eval itself } -func (c singleMultilinClaim) varsNum() int { +func (c *singleMultilinClaim) varsNum() int { return bits.TrailingZeros(uint(len(c.g))) } -func (c singleMultilinClaim) claimsNum() int { - return 1 -} - func sumForX1One(g polynomial.MultiLin) polynomial.Polynomial { sum := g[len(g)/2] for i := len(g)/2 + 1; i < len(g); i++ { @@ -113,13 +108,12 @@ func sumForX1One(g polynomial.MultiLin) polynomial.Polynomial { return []fr.Element{sum} } -func (c singleMultilinClaim) fold(fr.Element) polynomial.Polynomial { +func (c *singleMultilinClaim) roundPolynomial() polynomial.Polynomial { return sumForX1One(c.g) } -func (c *singleMultilinClaim) next(r fr.Element) polynomial.Polynomial { +func (c *singleMultilinClaim) roundFold(r fr.Element) { c.g.Fold(r) - return sumForX1One(c.g) } type singleMultilinLazyClaim struct { @@ -127,7 +121,7 @@ type singleMultilinLazyClaim struct { claimedSum fr.Element } -func (c singleMultilinLazyClaim) verifyFinalEval(r []fr.Element, _ fr.Element, purportedValue fr.Element, proof []fr.Element) error { +func (c singleMultilinLazyClaim) verifyFinalEval(r []fr.Element, purportedValue fr.Element, proof []fr.Element) error { val := c.g.Evaluate(r, nil) if val.Equal(&purportedValue) { return nil @@ -135,15 +129,7 @@ func (c singleMultilinLazyClaim) verifyFinalEval(r []fr.Element, _ fr.Element, p return fmt.Errorf("mismatch") } -func (c singleMultilinLazyClaim) foldedSum(_ fr.Element) fr.Element { - return c.claimedSum -} - -func (c singleMultilinLazyClaim) degree(i int) int { - return 1 -} - -func (c singleMultilinLazyClaim) claimsNum() int { +func (c singleMultilinLazyClaim) degree(int) int { return 1 } diff --git a/internal/gkr/gkr.go b/internal/gkr/gkr.go index b19cd2f478..9009ca71e5 100644 --- a/internal/gkr/gkr.go +++ b/internal/gkr/gkr.go @@ -3,18 +3,18 @@ package gkr import ( "errors" "fmt" - "strconv" + "github.com/consensys/gnark/constraint" "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/internal/gkr/gkrtypes" - fiatshamir "github.com/consensys/gnark/std/fiat-shamir" + "github.com/consensys/gnark/internal/gkr/gkrcore" + "github.com/consensys/gnark/std/hash" "github.com/consensys/gnark/std/polynomial" ) // Type aliases for gadget circuit types type ( - Wire = gkrtypes.GadgetWire - Circuit = gkrtypes.GadgetCircuit + Wire = gkrcore.GadgetWire + Circuit = gkrcore.GadgetCircuit ) // WireAssignment is an assignment of values to the same wire across many instances of the circuit @@ -41,283 +41,191 @@ func (a WireAssignment) NbVars() int { // A SNARK gadget capable of verifying a GKR proof // The goal is to prove/verify evaluations of many instances of the same circuit. -type Proof []sumcheckProof // for each layer, for each wire, a sumcheck (for each variable, a polynomial) +type Proof []sumcheckProof // for each schedule level, a sumcheck proof + +// resources holds all shared state for gadget GKR verification. +type resources struct { + api frontend.API + t *transcript + circuit Circuit + schedule constraint.GkrProvingSchedule + assignment WireAssignment + outgoingEvalPoints [][][]frontend.Variable // [levelI][outgoingClaimI] → eval point + nbVars int + uniqueInputIndices [][]int // [wI][claimI]: w's unique-input index in the layer its claimI-th evaluation is coming from +} // zeroCheckLazyClaims is a lazy claim for sumcheck (verifier side). -// It checks that the polynomial ∑ᵢ cⁱ eq(-, xᵢ) w(-) sums up to the expected multilinear -// extension of the values of w across all instances. -// Its purpose is to batch the checking of multiple evaluations of the same wire. +// It checks that the polynomial ∑ᵢ cⁱ eq(-, xᵢ) wᵢ(-) sums to the expected value, +// where the sum runs over all (wire v, claim source s) pairs in the level. type zeroCheckLazyClaims struct { - wireI int - evaluationPoints [][]frontend.Variable - claimedEvaluations []frontend.Variable - manager *claimsManager // WARNING: Circular references -} - -func (e *zeroCheckLazyClaims) getWire() Wire { - return e.manager.circuit[e.wireI] -} - -// verifyFinalEval finalizes the verification of w. -// The prover's claims w(xᵢ) = yᵢ have already been reduced to verifying -// ∑ cⁱ eq(xᵢ, r) w(r) = purportedValue. ( c is foldingCoeff ) -// Both purportedValue and the vector r have been randomized during the sumcheck protocol. -// By taking the w term out of the sum we get the equivalent claim that -// for E := ∑ eq(xᵢ, r), it must be that E w(r) = purportedValue. -// If w is an input wire, the verifier can directly check its evaluation at r. -// Otherwise, the prover makes claims about the evaluation of w's input wires, -// wᵢ, at r, to be verified later. -// The claims are communicated through the proof parameter. -// The verifier checks here if the claimed evaluations of wᵢ(r) are consistent with -// the main claim, by checking E w(wᵢ(r)...) = purportedValue. -func (e *zeroCheckLazyClaims) verifyFinalEval(api frontend.API, r []frontend.Variable, foldingCoeff, purportedValue frontend.Variable, uniqueInputEvaluations []frontend.Variable) error { - // the eq terms ( E ) - numClaims := len(e.evaluationPoints) - evaluation := polynomial.EvalEq(api, e.evaluationPoints[numClaims-1], r) - for i := numClaims - 2; i >= 0; i-- { - evaluation = api.Mul(evaluation, foldingCoeff) - eq := polynomial.EvalEq(api, e.evaluationPoints[i], r) - evaluation = api.Add(evaluation, eq) - } - - wire := e.getWire() - - // the g(...) term - var gateEvaluation frontend.Variable - if wire.IsInput() { - gateEvaluation = e.manager.assignment[e.wireI].Evaluate(api, r) - } else { - - injection, injectionLeftInv := - e.manager.circuit.ClaimPropagationInfo(e.wireI) - - if len(injection) != len(uniqueInputEvaluations) { - return fmt.Errorf("%d input wire evaluations given, %d expected", len(uniqueInputEvaluations), len(injection)) - } - - for uniqueI, i := range injection { // map from unique to all - e.manager.add(wire.Inputs[i], r, uniqueInputEvaluations[uniqueI]) - } - - inputEvaluations := make([]frontend.Variable, len(wire.Inputs)) - for i, uniqueI := range injectionLeftInv { // map from all to unique - inputEvaluations[i] = uniqueInputEvaluations[uniqueI] - } - - gateEvaluation = wire.Gate.Evaluate(FrontendAPIWrapper{api}, inputEvaluations...) - } - evaluation = api.Mul(evaluation, gateEvaluation) - - api.AssertIsEqual(evaluation, purportedValue) - return nil -} - -func (e *zeroCheckLazyClaims) claimsNum() int { - return len(e.evaluationPoints) + foldingCoeff frontend.Variable + r *resources + levelI int } func (e *zeroCheckLazyClaims) varsNum() int { - return len(e.evaluationPoints[0]) -} - -func (e *zeroCheckLazyClaims) foldedSum(api frontend.API, a frontend.Variable) frontend.Variable { - evalsAsPoly := polynomial.Polynomial(e.claimedEvaluations) - return evalsAsPoly.Eval(api, a) + return e.r.nbVars } func (e *zeroCheckLazyClaims) degree(int) int { - return 1 + e.getWire().Gate.Degree -} - -type claimsManager struct { - claims []*zeroCheckLazyClaims - assignment WireAssignment - circuit Circuit + return e.r.circuit.ZeroCheckDegree(e.r.schedule[e.levelI].(constraint.GkrSumcheckLevel)) } -func newClaimsManager(circuit Circuit, assignment WireAssignment) (claims claimsManager) { - claims.assignment = assignment - claims.claims = make([]*zeroCheckLazyClaims, len(circuit)) - claims.circuit = circuit +// verifyFinalEval finalizes the verification of a level at the sumcheck evaluation point r. +// The sumcheck protocol has already reduced the per-wire claims to verifying +// ∑ᵢ cⁱ eq(xᵢ, r) · wᵢ(r) = purportedValue, where the sum runs over all +// claims on each wire and c is foldingCoeff. +// Both purportedValue and the vector r have been randomized during sumcheck. +// +// For input wires, w(r) is computed directly from the assignment and the claimed +// evaluation in uniqueInputEvaluations is asserted equal to it. +// For non-input wires, the prover claims evaluations of their gate inputs at r via +// uniqueInputEvaluations; those claims are verified by lower levels' sumchecks. +func (e *zeroCheckLazyClaims) verifyFinalEval(api frontend.API, r []frontend.Variable, purportedValue frontend.Variable, uniqueInputEvaluations []frontend.Variable) error { + e.r.outgoingEvalPoints[e.levelI] = [][]frontend.Variable{r} + level := e.r.schedule[e.levelI] + perWireInputEvals := gkrcore.ReduplicateInputs(level, e.r.circuit, uniqueInputEvaluations) + + var terms []frontend.Variable + levelWireI := 0 + for _, group := range level.ClaimGroups() { + for _, wI := range group.Wires { + wire := e.r.circuit[wI] + + var gateEval frontend.Variable + if wire.IsInput() { + gateEval = e.r.assignment[wI].Evaluate(api, r) + api.AssertIsEqual(perWireInputEvals[levelWireI][0], gateEval) + } else { + gateEval = wire.Gate.Evaluate(FrontendAPIWrapper{api}, perWireInputEvals[levelWireI]...) + } - for i := range circuit { - if circuit[i].IsInput() { - circuit[i].Gate.Degree = 1 - circuit[i].Gate.Evaluate = gkrtypes.Identity - } - claims.claims[i] = &zeroCheckLazyClaims{ - wireI: i, - evaluationPoints: make([][]frontend.Variable, 0, circuit[i].NbClaims()), - claimedEvaluations: make(polynomial.Polynomial, circuit[i].NbClaims()), - manager: &claims, + for _, src := range group.ClaimSources { + eq := polynomial.EvalEq(api, e.r.outgoingEvalPoints[src.Level][src.OutgoingClaimIndex], r) + term := api.Mul(eq, gateEval) + terms = append(terms, term) + } + levelWireI++ } } - return -} -func (m *claimsManager) add(wire int, evaluationPoint []frontend.Variable, evaluation frontend.Variable) { - claim := m.claims[wire] - i := len(claim.evaluationPoints) - claim.claimedEvaluations[i] = evaluation - claim.evaluationPoints = append(claim.evaluationPoints, evaluationPoint) -} - -func (m *claimsManager) getLazyClaim(wire int) *zeroCheckLazyClaims { - return m.claims[wire] -} - -func (m *claimsManager) deleteClaim(wire int) { - m.claims[wire].manager = nil - m.claims[wire] = nil -} - -type settings struct { - transcript *fiatshamir.Transcript - transcriptPrefix string - nbVars int + claimedEvals := polynomial.Polynomial(terms) + total := claimedEvals.Eval(api, e.foldingCoeff) + api.AssertIsEqual(total, purportedValue) + return nil } -type Option func(*settings) - -func setup(api frontend.API, c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (settings, error) { - var o settings - var err error - for _, option := range options { - option(&o) - } - - o.nbVars = assignment.NbVars() - nbInstances := assignment.NbInstances() - if 1< 1 { //fold the claims - size++ - } - size += logNbInstances // full run of sumcheck on logNbInstances variables - } - - nums := make([]string, max(len(c), logNbInstances)) - for i := range nums { - nums[i] = strconv.Itoa(i) - } +func (r *resources) verifySumcheckLevel(levelI int, proof Proof) error { + level := r.schedule[levelI] + nbClaims := level.NbClaims() + initialChallengeI := len(r.schedule) - challenges := make([]string, size) - - // output wire claims - firstChallengePrefix := prefix + "fC." - for j := 0; j < logNbInstances; j++ { - challenges[j] = firstChallengePrefix + nums[j] + foldingCoeff := frontend.Variable(0) + if nbClaims >= 2 { + foldingCoeff = r.t.getChallenge() } - j := logNbInstances - for i := len(c) - 1; i >= 0; i-- { - if c[i].NoProof() { - continue - } - wirePrefix := prefix + "w" + nums[i] + "." - if c[i].NbClaims() > 1 { - challenges[j] = wirePrefix + "fold" - j++ - } - - partialSumPrefix := wirePrefix + "pSP." - for k := 0; k < logNbInstances; k++ { - challenges[j] = partialSumPrefix + nums[k] - j++ + var claimedEvals []frontend.Variable + for _, group := range level.ClaimGroups() { + for _, wI := range group.Wires { + for claimI, src := range group.ClaimSources { + var claimedEval frontend.Variable + if src.Level == initialChallengeI { + claimedEval = r.assignment[wI].Evaluate(r.api, r.outgoingEvalPoints[src.Level][src.OutgoingClaimIndex]) + } else { + i := r.schedule[src.Level].FinalEvalProofIndex(r.uniqueInputIndices[wI][claimI], src.OutgoingClaimIndex) + claimedEval = proof[src.Level].FinalEvalProof[i] + } + claimedEvals = append(claimedEvals, claimedEval) + } } } - return challenges -} + claimedSum := polynomial.Polynomial(claimedEvals).Eval(r.api, foldingCoeff) -func getFirstChallengeNames(logNbInstances int, prefix string) []string { - res := make([]string, logNbInstances) - firstChallengePrefix := prefix + "fC." - for i := 0; i < logNbInstances; i++ { - res[i] = firstChallengePrefix + strconv.Itoa(i) + lazyClaims := &zeroCheckLazyClaims{ + foldingCoeff: foldingCoeff, + r: r, + levelI: levelI, } - return res -} - -func getChallenges(transcript *fiatshamir.Transcript, names []string) (challenges []frontend.Variable, err error) { - challenges = make([]frontend.Variable, len(names)) - for i, name := range names { - if challenges[i], err = transcript.ComputeChallenge(name); err != nil { - return - } + if err := verifySumcheck(r.api, lazyClaims, proof[levelI], claimedSum, + r.circuit.ZeroCheckDegree(level.(constraint.GkrSumcheckLevel)), r.t); err != nil { + return fmt.Errorf("sumcheck proof rejected at level %d: %v", levelI, err) } - return + return nil } -// Verify the consistency of the claimed output with the claimed input -// Unlike in Prove, the assignment argument need not be complete -func Verify(api frontend.API, c Circuit, assignment WireAssignment, proof Proof, transcriptSettings fiatshamir.Settings, options ...Option) error { - o, err := setup(api, c, assignment, transcriptSettings, options...) - if err != nil { - return err +// Verify the consistency of the claimed output with the claimed input. +func Verify(api frontend.API, c Circuit, schedule constraint.GkrProvingSchedule, assignment WireAssignment, proof Proof, h hash.FieldHasher) error { + nbVars := assignment.NbVars() + if 1<= 0; i-- { - wire := c[i] - - if wire.IsOutput() { - claims.add(i, firstChallenge, assignment[i].Evaluate(api, firstChallenge)) - } - - proofW := proof[i] - claim := claims.getLazyClaim(i) - if wire.NoProof() { // input wires with one claim only - // make sure the proof is empty - if len(proofW.FinalEvalProof) != 0 || len(proofW.PartialSumPolys) != 0 { - return errors.New("no proof allowed for input wire with a single claim") - } + initialChallengeI := len(schedule) + firstChallenge := make([]frontend.Variable, nbVars) + for j := range nbVars { + firstChallenge[j] = r.t.getChallenge() + } + r.outgoingEvalPoints[initialChallengeI] = [][]frontend.Variable{firstChallenge} - if wire.NbClaims() == 1 { // input wire - // simply evaluate and see if it matches - evaluation := assignment[i].Evaluate(api, claim.evaluationPoints[0]) - api.AssertIsEqual(claim.claimedEvaluations[0], evaluation) - } - } else if err = verifySumcheck( - api, claim, proof[i], fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), - ); err == nil { - baseChallenge = proofW.FinalEvalProof + for levelI := len(schedule) - 1; levelI >= 0; levelI-- { + if _, isSkip := schedule[levelI].(constraint.GkrSkipLevel); isSkip { + r.verifySkipLevel(levelI, proof) } else { - return err + if err := r.verifySumcheckLevel(levelI, proof); err != nil { + return err + } } - claims.deleteClaim(i) + constraint.BindGkrFinalEvalProof(r.t, proof[levelI].FinalEvalProof, + c.UniqueGateInputs(schedule[levelI]), c.IsInput, schedule[levelI]) } return nil } @@ -339,27 +247,36 @@ func (p Proof) Serialize() []frontend.Variable { res = append(res, p[i].FinalEvalProof...) } if len(res) != size { - panic("bug") // TODO: Remove + panic("bug") } return res } // ComputeLogNbInstances derives n such that the number of instances is 2ⁿ -// from the size of the proof and the circuit structure. -// The number actually computed is that of rounds in each ZeroCheck instance, which is equal -// to the desired result. -// It is used in proof deserialization. -func ComputeLogNbInstances(circuit Circuit, serializedProofLen int) int { - partialEvalElemsPerVar := 0 - for _, w := range circuit { - partialEvalElemsPerVar += w.ZeroCheckDegree() - serializedProofLen -= w.NbUniqueOutputs +// from the size of the proof and the circuit/schedule structure. +func ComputeLogNbInstances(circuit Circuit, schedule constraint.GkrProvingSchedule, serializedProofLen int) int { + perVar := 0 + for _, level := range schedule { + nbUniqueInputs := len(circuit.UniqueGateInputs(level)) + if _, isSkip := level.(constraint.GkrSkipLevel); isSkip { + serializedProofLen -= nbUniqueInputs * level.NbOutgoingEvalPoints() + } else { + perVar += circuit.ZeroCheckDegree(level.(constraint.GkrSumcheckLevel)) + serializedProofLen -= nbUniqueInputs + } } - res := serializedProofLen / partialEvalElemsPerVar - if res*partialEvalElemsPerVar != serializedProofLen { - panic("cannot compute logNbInstances") + if perVar == 0 { + if serializedProofLen == 0 { + return -1 + } + } else { + res := serializedProofLen / perVar + if res*perVar == serializedProofLen { + return res + } } - return res + + panic("cannot compute logNbInstances") } type variablesReader []frontend.Variable @@ -374,19 +291,23 @@ func (r *variablesReader) hasNextN(n int) bool { return len(*r) >= n } -func DeserializeProof(circuit Circuit, serializedProof []frontend.Variable) (Proof, error) { - proof := make(Proof, len(circuit)) - logNbInstances := ComputeLogNbInstances(circuit, len(serializedProof)) +func DeserializeProof(circuit Circuit, schedule constraint.GkrProvingSchedule, serializedProof []frontend.Variable) (Proof, error) { + proof := make(Proof, len(schedule)) + logNbInstances := ComputeLogNbInstances(circuit, schedule, len(serializedProof)) reader := variablesReader(serializedProof) - for i, wI := range circuit { - if !wI.NoProof() { - proof[i].PartialSumPolys = make([]polynomial.Polynomial, logNbInstances) - for j := range proof[i].PartialSumPolys { - proof[i].PartialSumPolys[j] = reader.nextN(wI.ZeroCheckDegree()) + for levelI, level := range schedule { + nbUniqueInputs := len(circuit.UniqueGateInputs(level)) + if _, isSkip := level.(constraint.GkrSkipLevel); isSkip { + proof[levelI].FinalEvalProof = reader.nextN(nbUniqueInputs * level.NbOutgoingEvalPoints()) + } else { + degree := circuit.ZeroCheckDegree(level.(constraint.GkrSumcheckLevel)) + proof[levelI].PartialSumPolys = make([]polynomial.Polynomial, logNbInstances) + for j := range proof[levelI].PartialSumPolys { + proof[levelI].PartialSumPolys[j] = reader.nextN(degree) } + proof[levelI].FinalEvalProof = reader.nextN(nbUniqueInputs) } - proof[i].FinalEvalProof = reader.nextN(wI.NbUniqueInputs()) } if reader.hasNextN(1) { return nil, fmt.Errorf("proof too long: expected %d encountered %d", len(serializedProof)-len(reader), len(serializedProof)) diff --git a/internal/gkr/gkr_test.go b/internal/gkr/gkr_test.go index 6b0ecd1b67..93fb91697c 100644 --- a/internal/gkr/gkr_test.go +++ b/internal/gkr/gkr_test.go @@ -10,9 +10,10 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/backend" + "github.com/consensys/gnark/constraint" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/internal/gkr/gkrcore" "github.com/consensys/gnark/internal/gkr/gkrtesting" - fiatshamir "github.com/consensys/gnark/std/fiat-shamir" "github.com/consensys/gnark/std/hash" "github.com/consensys/gnark/std/polynomial" "github.com/consensys/gnark/test" @@ -114,7 +115,7 @@ func (c *GkrVerifierCircuit) Define(api frontend.API) error { return err } - if proof, err = DeserializeProof(testCase.Circuit, c.SerializedProof); err != nil { + if proof, err = DeserializeProof(testCase.Circuit, testCase.Schedule, c.SerializedProof); err != nil { return err } assignment := makeInOutAssignment(testCase.Circuit, c.Input, c.Output) @@ -128,20 +129,18 @@ func (c *GkrVerifierCircuit) Define(api frontend.API) error { } } - return Verify(api, testCase.Circuit, assignment, proof, fiatshamir.WithHash(hsh)) + return Verify(api, testCase.Circuit, testCase.Schedule, assignment, proof, hsh) } func makeInOutAssignment(c Circuit, inputValues [][]frontend.Variable, outputValues [][]frontend.Variable) WireAssignment { res := make(WireAssignment, len(c)) - inI, outI := 0, 0 - for i := range c { - if c[i].IsInput() { - res[i] = inputValues[inI] - inI++ - } else if c[i].IsOutput() { - res[i] = outputValues[outI] - outI++ - } + inputs := c.Inputs() + outputs := c.Outputs() + for i, wI := range inputs { + res[wI] = inputValues[i] + } + for i, wI := range outputs { + res[wI] = outputValues[i] } return res } @@ -153,12 +152,13 @@ func fillWithBlanks(slice [][]frontend.Variable, size int) { } type TestCase struct { - Circuit Circuit - Hash HashDescription - Proof Proof - Input [][]frontend.Variable - Output [][]frontend.Variable - Name string + Circuit Circuit + Schedule constraint.GkrProvingSchedule + Hash HashDescription + Proof Proof + Input [][]frontend.Variable + Output [][]frontend.Variable + Name string } type TestCaseInfo struct { Hash HashDescription `json:"hash"` @@ -188,7 +188,14 @@ func getTestCase(path string) (*TestCase, error) { return nil, err } - _, cse.Circuit = cache.GetCircuit(filepath.Join(dir, info.Circuit)) + serializableCircuit, gadgetCircuit := cache.GetCircuit(filepath.Join(dir, info.Circuit)) + cse.Circuit = gadgetCircuit + + schedule, schedErr := gkrcore.DefaultProvingSchedule(serializableCircuit) + if schedErr != nil { + return nil, schedErr + } + cse.Schedule = schedule cse.Proof = unmarshalProof(info.Proof) @@ -241,7 +248,7 @@ func TestLogNbInstances(t *testing.T) { testCase, err := getTestCase(path) assert.NoError(t, err) serializedProof := testCase.Proof.Serialize() - logNbInstances := ComputeLogNbInstances(testCase.Circuit, len(serializedProof)) + logNbInstances := ComputeLogNbInstances(testCase.Circuit, testCase.Schedule, len(serializedProof)) assert.Equal(t, 1, logNbInstances) } } diff --git a/internal/gkr/gkrtypes/gate.go b/internal/gkr/gkrcore/gate.go similarity index 98% rename from internal/gkr/gkrtypes/gate.go rename to internal/gkr/gkrcore/gate.go index 2b6100db9a..8e3be2a773 100644 --- a/internal/gkr/gkrtypes/gate.go +++ b/internal/gkr/gkrcore/gate.go @@ -1,4 +1,4 @@ -package gkrtypes +package gkrcore import ( "crypto/rand" @@ -43,6 +43,12 @@ type GateBytecode struct { Constants []*big.Int // constant values at indices [0, nbConsts) } +// IdentityBytecode returns the compiled form of the identity gate (x → x). +// A GateBytecode with no instructions returns its sole input directly. +func IdentityBytecode() GateBytecode { + return GateBytecode{} +} + // NbConstants returns the number of constants in the gate func (g *GateBytecode) NbConstants() int { return len(g.Constants) diff --git a/internal/gkr/gkrtypes/gate_test.go b/internal/gkr/gkrcore/gate_test.go similarity index 99% rename from internal/gkr/gkrtypes/gate_test.go rename to internal/gkr/gkrcore/gate_test.go index 3bc044959d..df11f6a44a 100644 --- a/internal/gkr/gkrtypes/gate_test.go +++ b/internal/gkr/gkrcore/gate_test.go @@ -1,4 +1,4 @@ -package gkrtypes +package gkrcore import ( "math/big" diff --git a/internal/gkr/gkrcore/schedule.go b/internal/gkr/gkrcore/schedule.go new file mode 100644 index 0000000000..7fecf28aab --- /dev/null +++ b/internal/gkr/gkrcore/schedule.go @@ -0,0 +1,356 @@ +package gkrcore + +import ( + "fmt" + "slices" + + "github.com/consensys/gnark/constraint" +) + +// InputMapping returns as uniqueInputs the deduplicated list of inputs to the level, +// and as inputIndices for every wire in the level the list of positions for each of its +// inputs in the uniqueInputs list. +// Input wires of the circuit are considered self-input, as a convenience for the sumcheck protocol. +func (c Circuit[G]) InputMapping(level constraint.GkrProvingLevel) (uniqueInputs []int, inputIndices [][]int) { + seen := make(map[int]int) // wire index → position in uniqueInputs + for _, group := range level.ClaimGroups() { + for _, wI := range group.Wires { + wire := c[wI] + inputs := wire.Inputs + if wire.IsInput() { + inputs = []int{wI} + } + + indices := make([]int, len(inputs)) + for inWI, inW := range inputs { + pos, ok := seen[inW] + if !ok { + pos = len(uniqueInputs) + seen[inW] = pos + uniqueInputs = append(uniqueInputs, inW) + } + indices[inWI] = pos + } + inputIndices = append(inputIndices, indices) + } + } + return +} + +// UniqueGateInputs returns the unique gate input wire indices for all wires in the level, +// deduplicated in batch-then-wire-then-input order (first occurrence wins). +// For circuit input wires (no gate inputs), the wire itself is returned. +func (c Circuit[G]) UniqueGateInputs(level constraint.GkrProvingLevel) []int { + uniqueInputs, _ := c.InputMapping(level) + return uniqueInputs +} + +func (c Circuit[G]) ZeroCheckDegree(level constraint.GkrSumcheckLevel) int { + maxDeg := 0 + for _, group := range level.ClaimGroups() { + for _, wI := range group.Wires { + w := &c[wI] + curr := 1 + if !w.IsInput() { + curr = w.Gate.Degree + } + maxDeg = max(maxDeg, curr) + } + } + return maxDeg + 1 +} + +// ProofSize returns the total number of field elements in a GKR proof. +func (c Circuit[G]) ProofSize(schedule constraint.GkrProvingSchedule, logNbInstances int) int { + size := 0 + for _, level := range schedule { + // For every outgoing claim and unique input wire, there will be + // an outgoing evaluation claim included in finalEvalProof. + size += len(c.UniqueGateInputs(level)) * level.NbOutgoingEvalPoints() + if sc, ok := level.(constraint.GkrSumcheckLevel); ok { + // ZeroCheckDegree is the degree of each sumcheck polynomial. + // logNbInstances is the number of rounds in each sumcheck. + size += c.ZeroCheckDegree(sc) * logNbInstances + } + } + return size +} + +// ReduplicateInputs expands unique evaluations to per-wire gate input evaluation lists. +func ReduplicateInputs[F any, G any](level constraint.GkrProvingLevel, c Circuit[G], uniqueEvals []F) [][]F { + _, inputIndices := c.InputMapping(level) + result := make([][]F, len(inputIndices)) + for wireInLevel := range inputIndices { + wireInputs := make([]F, len(inputIndices[wireInLevel])) + for gateInputJ, uniqueI := range inputIndices[wireInLevel] { + wireInputs[gateInputJ] = uniqueEvals[uniqueI] + } + result[wireInLevel] = wireInputs + } + return result +} + +// scheduleBuilder accumulates topology and per-wire claim sources while a schedule is being built. +// Steps are appended in out-to-in (topological) order and reversed by finalize. +// Claim source level values are stored as their index in the levels slice, with -1 as the +// sentinel for the initial challenge. finalize will map each src.Level to its final absolute index via +// n-1-src.level, where n = len(levels), so -1 → n (initial challenge) and i → n-1-i (real levels). +type scheduleBuilder[G any] struct { + circuit Circuit[G] + wireOutputs [][]int // wireOutputs[i] indices of wires that wire i feeds into, in increasing order and deduplicated. + wireLevels []int // wireLevels[i] which level wire i has been put in + wireProcessed []bool + claimSourcesCache [][]constraint.GkrClaimSource // claimSourcesCache[i] is the result of claimSources(i), or nil if not yet computed. + firstUnprocessedWire int + levels constraint.GkrProvingSchedule +} + +// newScheduleBuilder initialises a builder for the given circuit. +// It computes the outputs inverse-adjacency list. +func newScheduleBuilder[G any](c Circuit[G]) scheduleBuilder[G] { + b := scheduleBuilder[G]{ + circuit: c, + wireOutputs: make([][]int, len(c)), + wireLevels: make([]int, len(c)), + wireProcessed: make([]bool, len(c)), + claimSourcesCache: make([][]constraint.GkrClaimSource, len(c)), + firstUnprocessedWire: len(c) - 1, + } + seen := make(map[int]bool, len(c)) + for i := range c { + for k := range seen { + delete(seen, k) + } + for _, in := range c[i].Inputs { + b.wireOutputs[in] = append(b.wireOutputs[in], i) + seen[in] = true + } + } + return b +} + +// addSumcheckLevel appends a GkrSumcheckLevel to the schedule. Each batch is a set of wire indices +// to be proven together in a single zerocheck; all wires in a batch must share the same claim sources. +// All wires across all batches must be ready. +func (b *scheduleBuilder[G]) addSumcheckLevel(batches ...[]int) error { + claimGroups, err := b.buildClaimGroups(batches) + if err != nil { + return err + } + b.levels = append(b.levels, constraint.GkrSumcheckLevel(claimGroups)) + return nil +} + +// addSkipLevel appends a GkrSkipLevel to the schedule for a single set of wire indices. +// All wires in the batch must share the same claim sources and must be ready. +func (b *scheduleBuilder[G]) addSkipLevel(wireIndices []int) error { + claimGroups, err := b.buildClaimGroups([][]int{wireIndices}) + if err != nil { + return err + } + b.levels = append(b.levels, constraint.GkrSkipLevel(claimGroups[0])) + return nil +} + +// buildClaimGroups processes a set of batches, validates claim source consistency within each +// batch, updates wireLevels and wireProcessed, and returns the resulting GkrClaimGroups. +// Every ClaimSources slice is sorted. The user may reorder it to optimize eq handling. +func (b *scheduleBuilder[G]) buildClaimGroups(batches [][]int) ([]constraint.GkrClaimGroup, error) { + levelIdx := len(b.levels) + claimGroups := make([]constraint.GkrClaimGroup, len(batches)) + for i, wireIndices := range batches { + var claimSources []constraint.GkrClaimSource + for j, wI := range wireIndices { + wireClaims, ok := b.claimSources(wI) + if !ok { + return nil, fmt.Errorf("wire %d is not ready", wI) + } + if j == 0 { + claimSources = wireClaims + } else if !slices.Equal(claimSources, wireClaims) { + return nil, fmt.Errorf("wires %d and %d in the same batch have different claim sources", wireIndices[0], wI) + } + b.wireLevels[wI] = levelIdx + b.wireProcessed[wI] = true + if wI == b.firstUnprocessedWire { + for b.firstUnprocessedWire--; b.firstUnprocessedWire >= 0 && b.wireProcessed[b.firstUnprocessedWire]; b.firstUnprocessedWire-- { + } + } + } + claimGroups[i] = constraint.GkrClaimGroup{Wires: slices.Clone(wireIndices), ClaimSources: claimSources} + } + return claimGroups, nil +} + +// nextReady returns the highest wire index in the contiguous ready suffix starting at +// firstUnprocessedWire, along with each wire's claim sources in wire-index order. +// Returns firstUnprocessedWire, nil if no wires are ready. +func (b *scheduleBuilder[G]) nextReady() (highestWireI int, sources [][]constraint.GkrClaimSource) { + for lowestWireI := b.firstUnprocessedWire; lowestWireI >= 0; lowestWireI-- { + if b.wireProcessed[lowestWireI] { + break + } + src, ok := b.claimSources(lowestWireI) + if !ok { + break + } + sources = append(sources, src) + } + slices.Reverse(sources) + return b.firstUnprocessedWire, sources +} + +// claimSources checks whether all consumers of wire wI have already been processed. +// If so, it returns the deduplicated claim sources for wI and true. +// If not, it returns nil and false. Results are cached. +// SkipLevels are proper claim targets: a wire feeding into a SkipLevel L with M inherited +// evaluation points gets M claim sources {L, 0}, {L, 1}, ..., {L, M-1}. +func (b *scheduleBuilder[G]) claimSources(wI int) ([]constraint.GkrClaimSource, bool) { + if b.claimSourcesCache[wI] != nil { + return b.claimSourcesCache[wI], true + } + var wireClaims []constraint.GkrClaimSource + if b.circuit[wI].Exported || len(b.wireOutputs[wI]) == 0 { + wireClaims = append(wireClaims, constraint.GkrClaimSource{Level: -1, OutgoingClaimIndex: 0}) + } + for _, consumerWI := range b.wireOutputs[wI] { + if !b.wireProcessed[consumerWI] { + return nil, false + } + consumerLevel := b.wireLevels[consumerWI] + if _, isSkip := b.levels[consumerLevel].(constraint.GkrSkipLevel); isSkip { + // SkipLevel inherits M evaluation points from its own claim sources. + M := b.levels[consumerLevel].NbOutgoingEvalPoints() + for k := range M { + wireClaims = append(wireClaims, constraint.GkrClaimSource{Level: consumerLevel, OutgoingClaimIndex: k}) + } + } else { + wireClaims = append(wireClaims, constraint.GkrClaimSource{Level: consumerLevel, OutgoingClaimIndex: 0}) + } + } + // Deduplicate while preserving order. + seen := make(map[constraint.GkrClaimSource]bool, len(wireClaims)) + out := wireClaims[:0] + for _, cs := range wireClaims { + if !seen[cs] { + seen[cs] = true + out = append(out, cs) + } + } + b.claimSourcesCache[wI] = out + return out, true +} + +// finalize reverses the schedule into in-to-out order and fixes up Level indices in all +// ClaimSources. It errors if any wire has not been processed. +func (b *scheduleBuilder[G]) finalize() (constraint.GkrProvingSchedule, error) { + for i, processed := range b.wireProcessed { + if !processed { + return nil, fmt.Errorf("wire %d has not been processed", i) + } + } + + n := len(b.levels) + slices.Reverse(b.levels) + // Fix up ClaimSources: pre-reversal Level index src maps to n-1-src, + // and the initial-challenge sentinel -1 maps to n. + for _, level := range b.levels { + for _, group := range level.ClaimGroups() { + mirrorClaimSources(group.ClaimSources, n) + } + } + + return b.levels, nil +} + +// mirrorClaimSources maps each pre-reversal Level index src.Level in-place to its post-reversal +// absolute index n-1-src.Level. The initial-challenge sentinel -1 maps to n. +func mirrorClaimSources(s []constraint.GkrClaimSource, n int) { + n-- + for j := range s { + s[j].Level = n - s[j].Level + } +} + +// DefaultProvingSchedule generates a schedule that greedily batches input wires with the same +// single claim source into the same GkrSkipLevel. Non-input wires, and input wires with multiple +// claim sources, each get their own GkrSumcheckLevel. +func DefaultProvingSchedule[G any](c Circuit[G]) (constraint.GkrProvingSchedule, error) { + b := newScheduleBuilder(c) + + for b.firstUnprocessedWire >= 0 { + highWI, claimSources := b.nextReady() + // try and make a homogenous (same degree, same claims) batch + w := c[highWI] + batchClaimSources := claimSources[0] + if w.IsInput() && len(batchClaimSources) == 1 { + if err := b.addSkipLevel([]int{highWI}); err != nil { + return nil, err + } + continue + } + + // there is an actual "gate" in question + batch := []int{highWI} + for len(batch) < len(claimSources) { + if w.Gate.Degree != c[highWI-len(batch)].Gate.Degree || !slices.Equal(claimSources[0], claimSources[len(batch)]) { + break + } + batch = append(batch, highWI-len(batch)) + } + if w.Gate.Degree == 1 && len(batchClaimSources) == 1 { // certain that skipping won't cause a claim blowup + if err := b.addSkipLevel(batch); err != nil { + return nil, err + } + } else { + if err := b.addSumcheckLevel(batch); err != nil { + return nil, err + } + } + } + return b.finalize() +} + +// UniqueInputIndices returns uniqueInputIndices[wI][claimI], the position of wire wI +// in the UniqueGateInputs list of the source level for its claimI-th claim source. +// The sentinel initial-challenge claim maps to 0 (unused at call sites). +func (c Circuit[G]) UniqueInputIndices(schedule constraint.GkrProvingSchedule) [][]int { + cache := make([]map[int]int, len(schedule)) // cache[levelI][wireI] is the unique input index of wireI in levelI. + res := make([][]int, len(c)) + + // This loop weaves the level's treatment both as a claim source and as the collection of input wires + for levelI := len(schedule) - 1; levelI >= 0; levelI-- { + level := schedule[levelI] + cache[levelI] = make(map[int]int) + + for _, group := range level.ClaimGroups() { + for _, wI := range group.Wires { + + for _, inputWI := range c[wI].Inputs { + if _, ok := cache[levelI][inputWI]; !ok { + cache[levelI][inputWI] = len(cache[levelI]) + } + } + + for _, claimSource := range group.ClaimSources { + if claimSource.Level == len(schedule) { // output + res[wI] = append(res[wI], 0) // zero by convention + } else { + res[wI] = append(res[wI], cache[claimSource.Level][wI]) + } + } + } + } + } + return res +} + +// CollectOutgoingEvalPoints sets the outgoing evaluation points of a skip level, equal to its incoming ones. +func CollectOutgoingEvalPoints[F any](level constraint.GkrSkipLevel, levelI int, outgoingEvalPoints [][][]F) [][]F { + outPoints := make([][]F, level.NbOutgoingEvalPoints()) + for k, src := range level.ClaimSources { + outPoints[k] = outgoingEvalPoints[src.Level][src.OutgoingClaimIndex] + } + outgoingEvalPoints[levelI] = outPoints + return outPoints +} diff --git a/internal/gkr/gkrcore/schedule_test.go b/internal/gkr/gkrcore/schedule_test.go new file mode 100644 index 0000000000..55241be576 --- /dev/null +++ b/internal/gkr/gkrcore/schedule_test.go @@ -0,0 +1,108 @@ +package gkrcore_test + +import ( + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/constraint" + "github.com/consensys/gnark/internal/gkr/gkrcore" + "github.com/consensys/gnark/internal/gkr/gkrtesting" + "github.com/stretchr/testify/require" +) + +var scheduleTestCache = gkrtesting.NewCache(ecc.BN254.ScalarField()) + +func TestDefaultProvingSchedule(t *testing.T) { + _, c := scheduleTestCache.Compile(t, gkrtesting.SingleMulGateCircuit()) + schedule, err := gkrcore.DefaultProvingSchedule(c) + require.NoError(t, err) + + // SingleMulGateCircuit: wires 0, 1 (inputs), 2 (mul gate with inputs 0, 1). + // UniqueGateInputs for level 2 (wire 2) = [0, 1]. + // 3 = len(schedule) = initial challenge sentinel. + require.Equal(t, constraint.GkrProvingSchedule{ + // Level 0: input wire 0, position 0 in mul gate's UniqueGateInputs [0, 1] + constraint.GkrSkipLevel{Wires: []int{0}, ClaimSources: []constraint.GkrClaimSource{{Level: 2}}}, + // Level 1: input wire 1, position 1 in mul gate's UniqueGateInputs [0, 1] + constraint.GkrSkipLevel{Wires: []int{1}, ClaimSources: []constraint.GkrClaimSource{{Level: 2}}}, + // Level 2: mul gate output, claimed by initial challenge (sentinel) + constraint.GkrSumcheckLevel{{Wires: []int{2}, ClaimSources: []constraint.GkrClaimSource{{Level: 3}}}}, + }, schedule) +} + +func TestDefaultProvingSchedulePoseidon2(t *testing.T) { + _, c := scheduleTestCache.Compile(t, gkrtesting.Poseidon2Circuit(4, 2)) + schedule, err := gkrcore.DefaultProvingSchedule(c) + require.NoError(t, err) + + // Wire layout for Poseidon2Circuit(4, 2) — 25 wires total: + // 0, 1 inputs + // 2–3 full-round 0 lin (lin0=2, lin1=3) + // 4–5 full-round 0 sBox (sBox0=4, sBox1=5) + // 6–7 full-round 1 lin (lin0=6, lin1=7) + // 8–9 full-round 1 sBox (sBox0=8, sBox1=9) + // 10–11 partial-round 0 lin (lin0=10, lin1=11) + // 12 partial-round 0 sBox0 + // 13–14 partial-round 1 lin (lin0=13, lin1=14) + // 15 partial-round 1 sBox0 + // 16–17 full-round 2 lin (lin0=16, lin1=17) + // 18–19 full-round 2 sBox (sBox0=18, sBox1=19) + // 20–21 full-round 3 lin (lin0=20, lin1=21) + // 22–23 full-round 3 sBox (sBox0=22, sBox1=23) + // 24 feed-forward output + // + // 17 = len(schedule) = initial challenge sentinel. + require.Equal(t, constraint.GkrProvingSchedule{ + // Level 0: input wire 0 — claimed by level 2 (full-round 0 lin1+lin0 skip). + constraint.GkrSkipLevel{Wires: []int{0}, ClaimSources: []constraint.GkrClaimSource{{Level: 2}}}, + + // Level 1: input wire 1 — claimed by level 2 and level 16 (feed-forward skip). + constraint.GkrSkipLevel{Wires: []int{1}, ClaimSources: []constraint.GkrClaimSource{{Level: 2}, {Level: 16}}}, + + // Level 2: full-round 0 lin1+lin0 (skip, inputs from wires 0 and 1). + constraint.GkrSkipLevel{Wires: []int{3, 2}, ClaimSources: []constraint.GkrClaimSource{{Level: 3}}}, + + // Level 3: full-round 0 sBox1+sBox0 (sumcheck). + constraint.GkrSumcheckLevel{{Wires: []int{5, 4}, ClaimSources: []constraint.GkrClaimSource{{Level: 4}}}}, + + // Level 4: full-round 1 lin1+lin0 (skip, inputs [4, 5]). + constraint.GkrSkipLevel{Wires: []int{7, 6}, ClaimSources: []constraint.GkrClaimSource{{Level: 5}}}, + + // Level 5: full-round 1 sBox1+sBox0 (sumcheck, inputs lin1=7 and lin0=6). + // Feeds into level 6 (partial-round 0 lin0, M=1) and level 7 (partial-round 0 lin1, M=2). + constraint.GkrSumcheckLevel{{Wires: []int{9, 8}, ClaimSources: []constraint.GkrClaimSource{{Level: 6}, {Level: 7}, {Level: 7, OutgoingClaimIndex: 1}}}}, + + // Level 6: partial-round 0 lin0 (skip, inputs [8, 9]). + constraint.GkrSkipLevel{Wires: []int{10}, ClaimSources: []constraint.GkrClaimSource{{Level: 8}}}, + + // Level 7: partial-round 0 lin1 (skip, inputs [8, 9]). M=2 (two claim sources). + constraint.GkrSkipLevel{Wires: []int{11}, ClaimSources: []constraint.GkrClaimSource{{Level: 9}, {Level: 10}}}, + + // Level 8: partial-round 0 sBox0 (sumcheck, input lin0=10). + constraint.GkrSumcheckLevel{{Wires: []int{12}, ClaimSources: []constraint.GkrClaimSource{{Level: 9}, {Level: 10}}}}, + + // Level 9: partial-round 1 lin0 (skip, inputs [12, 11]). + constraint.GkrSkipLevel{Wires: []int{13}, ClaimSources: []constraint.GkrClaimSource{{Level: 11}}}, + + // Level 10: partial-round 1 lin1 (skip, inputs [12, 11]). + constraint.GkrSkipLevel{Wires: []int{14}, ClaimSources: []constraint.GkrClaimSource{{Level: 12}}}, + + // Level 11: partial-round 1 sBox0 (sumcheck, input lin0=13). + constraint.GkrSumcheckLevel{{Wires: []int{15}, ClaimSources: []constraint.GkrClaimSource{{Level: 12}}}}, + + // Level 12: full-round 2 lin1+lin0 (skip, inputs [15, 14]). + constraint.GkrSkipLevel{Wires: []int{17, 16}, ClaimSources: []constraint.GkrClaimSource{{Level: 13}}}, + + // Level 13: full-round 2 sBox1+sBox0 (sumcheck, inputs lin1=17 and lin0=16). + constraint.GkrSumcheckLevel{{Wires: []int{19, 18}, ClaimSources: []constraint.GkrClaimSource{{Level: 14}}}}, + + // Level 14: full-round 3 lin1+lin0 (skip, inputs [18, 19]). + constraint.GkrSkipLevel{Wires: []int{21, 20}, ClaimSources: []constraint.GkrClaimSource{{Level: 15}}}, + + // Level 15: full-round 3 sBox1+sBox0 (sumcheck, inputs lin1=21 and lin0=20). + constraint.GkrSumcheckLevel{{Wires: []int{23, 22}, ClaimSources: []constraint.GkrClaimSource{{Level: 16}}}}, + + // Level 16: feed-forward output (skip, inputs [22, 23, 1]). Claimed by initial challenge (17). + constraint.GkrSkipLevel{Wires: []int{24}, ClaimSources: []constraint.GkrClaimSource{{Level: 17}}}, + }, schedule) +} diff --git a/internal/gkr/gkrcore/serialize.go b/internal/gkr/gkrcore/serialize.go new file mode 100644 index 0000000000..d62bbeaae0 --- /dev/null +++ b/internal/gkr/gkrcore/serialize.go @@ -0,0 +1,197 @@ +package gkrcore + +import ( + "encoding/binary" + "fmt" + "io" + "math/big" + + "github.com/consensys/gnark/constraint" +) + +// Common serialization utilities + +func writeUint8(w io.Writer, x int) error { + if x >= 256 || x < 0 { + return fmt.Errorf("%d out of range", x) + } + _, err := w.Write([]byte{byte(x)}) + return err +} + +func writeUint16[T int | uint16](w io.Writer, x T) error { + var buf [2]byte + if int(x) >= 65536 || x < 0 { + return fmt.Errorf("%d out of range", x) + } + binary.LittleEndian.PutUint16(buf[:], uint16(x)) + _, err := w.Write(buf[:]) + return err +} + +func writeBool(w io.Writer, b bool) error { + var v byte + if b { + v = 1 + } + _, err := w.Write([]byte{v}) + return err +} + +func writeUint16Slice[T int | uint16](w io.Writer, s []T) error { + if err := writeUint16(w, len(s)); err != nil { + return err + } + for _, v := range s { + if err := writeUint16(w, v); err != nil { + return err + } + } + return nil +} + +func writeBigInt(w io.Writer, x *big.Int) error { + bytes := x.Bytes() + if err := writeUint8(w, len(bytes)); err != nil { + return err + } + _, err := w.Write(bytes) + return err +} + +// SerializeCircuit writes a SerializableCircuit to w in deterministic binary format, +// primarily for hashing circuits to create unique identifiers. +// +// The encoding is compact (uint16 for counts/indices, uint8 for bigint byte lengths) and +// uses little-endian throughout. Gate metadata (NbIn, Degree, SolvableVar) is omitted +// since it can be recomputed from bytecode on deserialization. +// +// Format: +// +// Circuit: [wire_count:u16] [wire...] +// Wire: [input_count:u16] [input_indices:u16...] [exported:bool] [gate] +// Gate (non-input only): [const_count:u16] [constants...] [inst_count:u16] [instructions...] +// Constant: [byte_len:u8] [bytes...] +// Instruction: [op:u8] [input_count:u16] [input_indices:u16...] +func SerializeCircuit(w io.Writer, c SerializableCircuit) error { + if len(c) >= 1<<32 { + return fmt.Errorf("circuit length too large: %d", len(c)) + } + + // Write the number of wires + if err := writeUint16(w, len(c)); err != nil { + return err + } + + // Write each wire + for i := range c { + wire := &c[i] + + if err := writeUint16Slice(w, wire.Inputs); err != nil { + return err + } + + // Write the exported flag + if err := writeBool(w, wire.Exported); err != nil { + return err + } + + // If not an input wire, write gate information + if !wire.IsInput() { + gate := &wire.Gate + + // Write the bytecode + bytecode := &gate.Evaluate + + // Write the constants + if err := writeUint16(w, len(bytecode.Constants)); err != nil { + return err + } + for _, constant := range bytecode.Constants { + if err := writeBigInt(w, constant); err != nil { + return err + } + } + + // Write the instructions + if err := writeUint16(w, len(bytecode.Instructions)); err != nil { + return err + } + for _, inst := range bytecode.Instructions { + // Write the operation + if _, err := w.Write([]byte{byte(inst.Op)}); err != nil { + return err + } + + // Write the instruction inputs + if err := writeUint16Slice(w, inst.Inputs); err != nil { + return err + } + } + } + } + + return nil +} + +// SerializeSchedule writes a GkrProvingSchedule to w in deterministic binary format, +// primarily for hashing schedules to create unique identifiers. +// +// The encoding uses uint16 for counts/indices and little-endian throughout. +// +// Format: +// +// Schedule: [level_count:u16] [level...] +// Level: [skip_sumcheck:bool] [group_count:u16] [claim_group...] +// GkrClaimGroup: [wire_count:u16] [wire_indices:u16...] [source_count:u16] [claim_source...] +// GkrClaimSource: [level:u16] [outgoing_claim_index:u16] +func SerializeSchedule(w io.Writer, s constraint.GkrProvingSchedule) error { + if len(s) >= 65536 { + return fmt.Errorf("schedule length too large: %d", len(s)) + } + + writeClaimGroup := func(cg constraint.GkrClaimGroup) error { + if err := writeUint16Slice(w, cg.Wires); err != nil { + return err + } + + // Write claim source count; each source is two uint16s: Level and OutgoingClaimIndex + if err := writeUint16(w, len(cg.ClaimSources)); err != nil { + return err + } + for _, src := range cg.ClaimSources { + if err := writeUint16(w, src.Level); err != nil { + return err + } + if err := writeUint16(w, src.OutgoingClaimIndex); err != nil { + return err + } + } + + return nil + } + + // Write number of levels + if err := writeUint16(w, len(s)); err != nil { + return err + } + + // Write each level + for _, level := range s { + _, isSkip := level.(constraint.GkrSkipLevel) + if err := writeBool(w, isSkip); err != nil { + return err + } + groups := level.ClaimGroups() + if err := writeUint16(w, len(groups)); err != nil { + return err + } + for _, cg := range groups { + if err := writeClaimGroup(cg); err != nil { + return err + } + } + } + + return nil +} diff --git a/internal/gkr/gkrtypes/types.go b/internal/gkr/gkrcore/types.go similarity index 53% rename from internal/gkr/gkrtypes/types.go rename to internal/gkr/gkrcore/types.go index 9089de41b3..83bf399059 100644 --- a/internal/gkr/gkrtypes/types.go +++ b/internal/gkr/gkrcore/types.go @@ -1,9 +1,8 @@ -package gkrtypes +package gkrcore import ( "errors" "math/big" - "reflect" "github.com/consensys/gnark/constraint" "github.com/consensys/gnark/frontend" @@ -17,6 +16,17 @@ type ( InputInstance int } + // RawWire is a minimal wire representation with only inputs and gate function. + RawWire struct { + Gate gkr.GateFunction + Inputs []int + Exported bool + } + + // RawCircuit is a minimal circuit representation for API-level circuit construction. + // It contains only the essential topology (inputs) and gate functions. + RawCircuit []RawWire + // A Gate is a low-degree multivariate polynomial Gate[GateExecutable any] struct { Evaluate GateExecutable @@ -26,10 +36,9 @@ type ( } Wire[GateExecutable any] struct { - Gate Gate[GateExecutable] - Inputs []int - NbUniqueOutputs int - Exported bool + Gate Gate[GateExecutable] + Inputs []int + Exported bool } Circuit[GateExecutable any] []Wire[GateExecutable] @@ -54,59 +63,14 @@ func (w Wire[GateExecutable]) IsInput() bool { return len(w.Inputs) == 0 } -// IsOutput returns whether the wire is an output wire. A wire is an output wire -// if it is not input to any other wire. -func (w Wire[GateExecutable]) IsOutput() bool { - return w.NbUniqueOutputs == 0 || w.Exported -} - -// NbClaims returns the number of claims to be proven about this wire. The number -// of claims is the number of Wires it is input to, except for an output wire, which -// has an extra claim. -func (w Wire[GateExecutable]) NbClaims() int { - res := w.NbUniqueOutputs - if w.IsOutput() { - res++ - } - return res -} - -// NoProof returns whether no proof is needed for this wire. This corresponds -// to input wires without any claims to be made about them. -func (w Wire[GateExecutable]) NoProof() bool { - return w.IsInput() && w.NbClaims() == 1 -} - -// NbUniqueInputs returns the number of unique input wires to this wire. -func (w Wire[GateExecutable]) NbUniqueInputs() int { - set := make(map[int]struct{}, len(w.Inputs)) - for _, in := range w.Inputs { - set[in] = struct{}{} - } - return len(set) -} - -// ZeroCheckDegree returns the degree in each variable of the zero-check polynomial -// associated with this gate, if any. If this wire is not subject to zero-check, it will return 0. -func (w Wire[GateExecutable]) ZeroCheckDegree() int { - if w.IsInput() { - switch w.NbClaims() { - case 0: - panic("should be unreachable") - case 1: - return 0 - default: - // Input gate with multiple claims treated as a degree 1 gate. - return 2 - } - } - return w.Gate.Degree + 1 +func (c Circuit[GateExecutable]) IsInput(wireIndex int) bool { + return c[wireIndex].IsInput() } // ClaimPropagationInfo returns sets of indices describing the pruning of claim propagation. // At the end of sumcheck for wire #wireIndex, we end up with sequences "uniqueEvaluations" and "evaluations", // the former a subsequence of the latter. -// injection are the indices of the unique evaluations in the original evaluation list. +// injection consists of the indices of the unique evaluations in the original evaluation list. // injectionRightInverse are the indices of the original evaluations in the unique evaluations list. // There are no guarantees on the non-unique choice of the semi-inverse map. func (c Circuit[GateExecutable]) ClaimPropagationInfo(wireIndex int) (injection, injectionLeftInverse []int) { @@ -150,30 +114,6 @@ func (c Circuit[GateExecutable]) MemoryRequirements(nbInstances int) []int { return res } -// OutputsList for each wire, returns the set of indexes of wires it is input to. -// It also sets the NbUniqueOutputs fields. -func (c Circuit[GateExecutable]) OutputsList() [][]int { - res := make([][]int, len(c)) - for i := range c { - res[i] = make([]int, 0) - c[i].NbUniqueOutputs = 0 - } - ins := make(map[int]struct{}, len(c)) - for i := range c { - for k := range ins { // clear map - delete(ins, k) - } - for _, in := range c[i].Inputs { - res[in] = append(res[in], i) - if _, ok := ins[in]; !ok { - c[in].NbUniqueOutputs++ - ins[in] = struct{}{} - } - } - } - return res -} - // Inputs returns the list of input wire indices. func (c Circuit[GateExecutable]) Inputs() []int { res := make([]int, 0, len(c)) @@ -186,11 +126,16 @@ func (c Circuit[GateExecutable]) Inputs() []int { } // Outputs returns the list of output wire indices. -// It requires the NbUniqueOutput values to have been set. func (c Circuit[GateExecutable]) Outputs() []int { + isOutputTo := make([]bool, len(c)) + for i := range c { + for _, in := range c[i].Inputs { + isOutputTo[in] = true + } + } res := make([]int, 0, len(c)) for i := range c { - if c[i].IsOutput() { + if !isOutputTo[i] || c[i].Exported { res = append(res, i) } } @@ -206,17 +151,6 @@ func (c Circuit[GateExecutable]) MaxGateNbIn() int { return res } -// ProofSize computes how large the proof for a circuit would be. It needs NbUniqueOutputs to be set. -func (c Circuit[GateExecutable]) ProofSize(logNbInstances int) int { - nbUniqueInputs := 0 - nbPartialEvalPolys := 0 - for i := range c { - nbUniqueInputs += c[i].NbUniqueOutputs // each unique output is manifest in a finalEvalProof entry - nbPartialEvalPolys += c[i].ZeroCheckDegree() - } - return nbUniqueInputs + nbPartialEvalPolys*logNbInstances -} - // makeNeg1Slice returns a slice of size n with all elements set to -1. func makeNeg1Slice(n int) []int { res := make([]int, n) @@ -267,55 +201,40 @@ type Blueprints struct { GetAssignmentID constraint.BlueprintID } -// CompileCircuit converts a gadget circuit to a serializable circuit by compiling the gate functions. -// It also sets wire and gate metadata (Degree, SolvableVar, NbUniqueOutputs) for both the input and output circuits. -func CompileCircuit(c GadgetCircuit, mod *big.Int) (SerializableCircuit, error) { +// Compile compiles a raw circuit into both a gadget circuit and a serializable circuit. +// It computes all wire and gate metadata (Degree, SolvableVar). +func (c RawCircuit) Compile(mod *big.Int) (GadgetCircuit, SerializableCircuit, error) { + gadget := make(GadgetCircuit, len(c)) + serializable := make(SerializableCircuit, len(c)) for i := range c { - c[i].NbUniqueOutputs = 0 - } + gadget[i].Inputs = c[i].Inputs + gadget[i].Exported = c[i].Exported + serializable[i].Inputs = c[i].Inputs + serializable[i].Exported = c[i].Exported - // compile the gate and compute metadata - curWireIn := make([]bool, len(c)) // curWireIn[j] = true iff i takes j as input. - res := make(SerializableCircuit, len(c)) - var err error - for i := range c { - // Compute NbUniqueOutputs as we go. - for _, in := range c[i].Inputs { - if !curWireIn[in] { - c[in].NbUniqueOutputs++ - curWireIn[in] = true - } - } - // clear curWireIn - for _, in := range c[i].Inputs { - curWireIn[in] = false + if gadget[i].IsInput() { + continue } - if c[i].IsInput() { - if !reflect.DeepEqual(c[i].Gate, GadgetGate{}) { - return nil, errors.New("empty gate expected for input wire") - } - continue + if c[i].Gate == nil { + return nil, nil, errors.New("gate function required for non-input wire") } - c[i].Gate.NbIn = len(c[i].Inputs) - if res[i].Gate, err = CompileGateFunction(c[i].Gate.Evaluate, c[i].Gate.NbIn, mod); err != nil { - return nil, err + nbIn := len(c[i].Inputs) + compiledGate, err := CompileGateFunction(c[i].Gate, nbIn, mod) + if err != nil { + return nil, nil, err } - c[i].Gate.Degree = res[i].Gate.Degree - c[i].Gate.SolvableVar = res[i].Gate.SolvableVar - } - // copy metadata from c to res - for i := range c { - res[i].Inputs = c[i].Inputs - res[i].Exported = c[i].Exported - res[i].NbUniqueOutputs = c[i].NbUniqueOutputs - res[i].Gate.Degree = c[i].Gate.Degree - res[i].Gate.NbIn = c[i].Gate.NbIn - res[i].Gate.SolvableVar = c[i].Gate.SolvableVar + gadget[i].Gate = GadgetGate{ + Evaluate: c[i].Gate, + NbIn: nbIn, + Degree: compiledGate.Degree, + SolvableVar: compiledGate.SolvableVar, + } + serializable[i].Gate = compiledGate } - return res, nil + return gadget, serializable, nil } diff --git a/internal/gkr/gkrtesting/gkrtesting.go b/internal/gkr/gkrtesting/gkrtesting.go index 56c2c16bd1..329fb3475b 100644 --- a/internal/gkr/gkrtesting/gkrtesting.go +++ b/internal/gkr/gkrtesting/gkrtesting.go @@ -3,13 +3,17 @@ package gkrtesting import ( "encoding/json" "errors" + "fmt" "math/big" "os" "path/filepath" + "strconv" "sync" + "testing" + "github.com/consensys/gnark/constraint" "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/internal/gkr/gkrtypes" + "github.com/consensys/gnark/internal/gkr/gkrcore" "github.com/consensys/gnark/std/gkrapi/gkr" "github.com/stretchr/testify/require" ) @@ -25,8 +29,8 @@ type Cache struct { } type circuits struct { - serializable gkrtypes.SerializableCircuit - gadget gkrtypes.GadgetCircuit + serializable gkrcore.SerializableCircuit + gadget gkrcore.GadgetCircuit } func mimcGate(api gkr.GateAPI, input ...frontend.Variable) frontend.Variable { @@ -46,11 +50,11 @@ func selectInput3Gate(_ gkr.GateAPI, in ...frontend.Variable) frontend.Variable func NewCache(field *big.Int) *Cache { gates := make(map[string]gkr.GateFunction, 7) gates[""] = nil - gates["identity"] = gkrtypes.Identity - gates["add2"] = gkrtypes.Add2 - gates["sub2"] = gkrtypes.Sub2 - gates["neg"] = gkrtypes.Neg - gates["mul2"] = gkrtypes.Mul2 + gates["identity"] = gkrcore.Identity + gates["add2"] = gkrcore.Add2 + gates["sub2"] = gkrcore.Sub2 + gates["neg"] = gkrcore.Neg + gates["mul2"] = gkrcore.Mul2 gates["mimc"] = mimcGate gates["select-input-3"] = selectInput3Gate @@ -70,14 +74,14 @@ type JSONWire struct { // JSONCircuit is the JSON serialization format for circuits type JSONCircuit []JSONWire -// Compile compiles a programmatic GadgetCircuit into a SerializableCircuit. -func (c *Cache) Compile(t require.TestingT, circuit gkrtypes.GadgetCircuit) gkrtypes.SerializableCircuit { - res, err := gkrtypes.CompileCircuit(circuit, c.field) +// Compile compiles a RawCircuit into a SerializableCircuit. +func (c *Cache) Compile(t require.TestingT, circuit gkrcore.RawCircuit) (gkrcore.GadgetCircuit, gkrcore.SerializableCircuit) { + gadget, serializable, err := circuit.Compile(c.field) require.NoError(t, err) - return res + return gadget, serializable } -func (c *Cache) GetCircuit(path string) (gkrtypes.SerializableCircuit, gkrtypes.GadgetCircuit) { +func (c *Cache) GetCircuit(path string) (gkrcore.SerializableCircuit, gkrcore.GadgetCircuit) { c.lock.Lock() defer c.lock.Unlock() @@ -100,17 +104,15 @@ func (c *Cache) GetCircuit(path string) (gkrtypes.SerializableCircuit, gkrtypes. panic(err) } - // Convert JSON format to GadgetCircuit - gCircuit := make(gkrtypes.GadgetCircuit, len(jsonCircuit)) + // Convert JSON format to RawCircuit + rawCircuit := make(gkrcore.RawCircuit, len(jsonCircuit)) for i, wJSON := range jsonCircuit { - gate := c.GetGate(wJSON.Gate) - - gCircuit[i] = gkrtypes.GadgetWire{ - Gate: gkrtypes.Gate[gkr.GateFunction]{Evaluate: gate}, + rawCircuit[i] = gkrcore.RawWire{ + Gate: c.GetGate(wJSON.Gate), Inputs: wJSON.Inputs, } } - sCircuit, err := gkrtypes.CompileCircuit(gCircuit, c.field) + gCircuit, sCircuit, err := rawCircuit.Compile(c.field) if err != nil { panic(err) } @@ -137,12 +139,10 @@ func (c *Cache) GetGate(name string) gkr.GateFunction { panic("gate not found: " + name) } -func MiMCCircuit(numRounds int) gkrtypes.GadgetCircuit { - c := make(gkrtypes.GadgetCircuit, numRounds+2) - mimc := gkrtypes.GadgetGate{Evaluate: mimcGate} - +func MiMCCircuit(numRounds int) gkrcore.RawCircuit { + c := make(gkrcore.RawCircuit, numRounds+2) for i := 2; i < len(c); i++ { - c[i] = gkrtypes.GadgetWire{Gate: mimc, Inputs: []int{i - 1, 0}} + c[i] = gkrcore.RawWire{Gate: mimcGate, Inputs: []int{i - 1, 0}} } return c } @@ -155,12 +155,54 @@ type PrintableSumcheckProof struct { } type HashDescription map[string]interface{} + +// TestCaseInfo is the serializable form of a GKR test case, matching the JSON file format. +// Schedule is nil when absent from JSON, which means DefaultProvingSchedule. type TestCaseInfo struct { - Hash HashDescription `json:"hash"` - Circuit string `json:"circuit"` - Input [][]interface{} `json:"input"` - Output [][]interface{} `json:"output"` - Proof PrintableProof `json:"proof"` + Hash HashDescription `json:"hash"` + Circuit string `json:"circuit"` + Input [][]interface{} `json:"input"` + Output [][]interface{} `json:"output"` + Proof PrintableProof `json:"proof"` + Schedule ScheduleInfo `json:"schedule,omitempty"` +} + +// ScheduleStepInfo is the JSON representation of a single proving level. +// Type is "sumcheck" or "skip". +type ScheduleStepInfo struct { + Type string `json:"type"` + ClaimGroups []constraint.GkrClaimGroup `json:"claimGroups,omitempty"` // for "sumcheck" + ClaimGroup *constraint.GkrClaimGroup `json:"claimGroup,omitempty"` // for "skip" +} + +// ScheduleInfo is the JSON-serializable form of a ProvingSchedule. +type ScheduleInfo []ScheduleStepInfo + +// ToProvingSchedule converts a ScheduleInfo to a constraint.GkrProvingSchedule. +// A nil ScheduleInfo returns nil, which callers should interpret as DefaultProvingSchedule. +func (p ScheduleInfo) ToProvingSchedule() (constraint.GkrProvingSchedule, error) { + if p == nil { + return nil, nil + } + s := make(constraint.GkrProvingSchedule, len(p)) + for i, step := range p { + switch step.Type { + case "sumcheck": + groups := step.ClaimGroups + if groups == nil { + groups = []constraint.GkrClaimGroup{} + } + s[i] = constraint.GkrSumcheckLevel(groups) + case "skip": + if step.ClaimGroup == nil { + return nil, fmt.Errorf("level %d: type=skip but claimGroup is absent", i) + } + s[i] = constraint.GkrSkipLevel(*step.ClaimGroup) + default: + return nil, errors.New("unknown ProvingLevel type: " + step.Type) + } + } + return s, nil } func (c *Cache) ReadTestCaseInfo(filePath string) (info TestCaseInfo, err error) { @@ -175,60 +217,163 @@ func (c *Cache) ReadTestCaseInfo(filePath string) (info TestCaseInfo, err error) return } -func NoGateCircuit() gkrtypes.GadgetCircuit { - return gkrtypes.GadgetCircuit{ +func NoGateCircuit() gkrcore.RawCircuit { + return gkrcore.RawCircuit{ {}, } } -func SingleAddGateCircuit() gkrtypes.GadgetCircuit { - return gkrtypes.GadgetCircuit{ +func SingleAddGateCircuit() gkrcore.RawCircuit { + return gkrcore.RawCircuit{ {}, {}, - {Gate: gkrtypes.GadgetGate{Evaluate: gkrtypes.Add2}, Inputs: []int{0, 1}}, + {Gate: gkrcore.Add2, Inputs: []int{0, 1}}, } } -func SingleMulGateCircuit() gkrtypes.GadgetCircuit { - return gkrtypes.GadgetCircuit{ +func SingleMulGateCircuit() gkrcore.RawCircuit { + return gkrcore.RawCircuit{ {}, {}, - {Gate: gkrtypes.GadgetGate{Evaluate: gkrtypes.Mul2}, Inputs: []int{0, 1}}, + {Gate: gkrcore.Mul2, Inputs: []int{0, 1}}, } } -func SingleInputTwoIdentityGatesCircuit() gkrtypes.GadgetCircuit { - idGate := gkrtypes.GadgetGate{Evaluate: gkrtypes.Identity} - return gkrtypes.GadgetCircuit{ +func SingleInputTwoIdentityGatesCircuit() gkrcore.RawCircuit { + return gkrcore.RawCircuit{ {}, - {Gate: idGate, Inputs: []int{0}}, - {Gate: idGate, Inputs: []int{0}}, + {Gate: gkrcore.Identity, Inputs: []int{0}}, + {Gate: gkrcore.Identity, Inputs: []int{0}}, } } -func SingleInputTwoIdentityGatesComposedCircuit() gkrtypes.GadgetCircuit { - idGate := gkrtypes.GadgetGate{Evaluate: gkrtypes.Identity} - return gkrtypes.GadgetCircuit{ +func SingleInputTwoIdentityGatesComposedCircuit() gkrcore.RawCircuit { + return gkrcore.RawCircuit{ {}, - {Gate: idGate, Inputs: []int{0}}, - {Gate: idGate, Inputs: []int{1}}, + {Gate: gkrcore.Identity, Inputs: []int{0}}, + {Gate: gkrcore.Identity, Inputs: []int{1}}, } } -func APowNTimesBCircuit(n int) gkrtypes.GadgetCircuit { - c := make(gkrtypes.GadgetCircuit, n+2) - mulGate := gkrtypes.GadgetGate{Evaluate: gkrtypes.Mul2} - +func APowNTimesBCircuit(n int) gkrcore.RawCircuit { + c := make(gkrcore.RawCircuit, n+2) for i := 2; i < len(c); i++ { - c[i] = gkrtypes.GadgetWire{Gate: mulGate, Inputs: []int{i - 1, 0}} + c[i] = gkrcore.RawWire{Gate: gkrcore.Mul2, Inputs: []int{i - 1, 0}} } return c } -func SingleMimcCipherGateCircuit() gkrtypes.GadgetCircuit { - return gkrtypes.GadgetCircuit{ +func SingleMimcCipherGateCircuit() gkrcore.RawCircuit { + return gkrcore.RawCircuit{ {}, {}, - {Gate: gkrtypes.GadgetGate{Evaluate: mimcGate}, Inputs: []int{0, 1}}, + {Gate: mimcGate, Inputs: []int{0, 1}}, + } +} + +// poseidon2ExtLinear0 computes 2*x[0] + x[1] (external matrix, state[0] row). +func poseidon2ExtLinear0(api gkr.GateAPI, x ...frontend.Variable) frontend.Variable { + return api.Add(x[0], x[0], x[1]) +} + +// poseidon2ExtLinear1 computes x[0] + 2*x[1] (external matrix, state[1] row). +func poseidon2ExtLinear1(api gkr.GateAPI, x ...frontend.Variable) frontend.Variable { + return api.Add(x[0], x[1], x[1]) +} + +// poseidon2IntLinear1 computes x[0] + 3*x[1] (internal matrix, state[1] row; state[0] row = external). +func poseidon2IntLinear1(api gkr.GateAPI, x ...frontend.Variable) frontend.Variable { + return api.Add(x[0], x[1], x[1], x[1]) +} + +// poseidon2SBox computes x[0]^2 (simplified s-box). +func poseidon2SBox(api gkr.GateAPI, x ...frontend.Variable) frontend.Variable { + return api.Mul(x[0], x[0]) +} + +// poseidon2FeedForward computes 2*x[0] + x[1] + x[2] (external matrix row with feed-forward). +func poseidon2FeedForward(api gkr.GateAPI, x ...frontend.Variable) frontend.Variable { + return api.Add(x[0], x[0], x[1], x[2]) +} + +// Poseidon2Circuit returns a 2-state Poseidon2-like GKR circuit with the given number of +// full and partial rounds, followed by a feed-forward output wire. +// Each full round applies the external linear layer (degree 1, skip) to both state elements +// followed by the s-box x^2 (degree 2, sumcheck) to both. +// Each partial round applies the external linear layer to state[0] and the internal linear +// layer to state[1] (both skip), then the s-box to state[0] only (sumcheck). +// The final output wire is 2*s0 + s1 + in1 (external matrix row with the second input fed forward). +// +// Wire layout per full round (wires s0, s1 are the current state): +// +// +0 = 2*s0 + s1 external linear, state[0] (skip) +// +1 = s0 + 2*s1 external linear, state[1] (skip) +// +2 = lin0^2 s-box, state[0] (sumcheck) +// +3 = lin1^2 s-box, state[1] (sumcheck) +// +// Wire layout per partial round: +// +// +0 = 2*s0 + s1 external linear, state[0] (skip) +// +1 = s0 + 3*s1 internal linear, state[1] (skip) +// +2 = lin0^2 s-box, state[0] only (sumcheck) +// +// Final wire: 2*s0 + s1 + in1 (feed-forward, sumcheck output) +func Poseidon2Circuit(nbFullRounds, nbPartialRounds int) gkrcore.RawCircuit { + // 2 inputs + 4 wires per full round + 3 wires per partial round + 1 feed-forward output + nbWires := 2 + 4*nbFullRounds + 3*nbPartialRounds + 1 + c := make(gkrcore.RawCircuit, nbWires) + // wires 0, 1 are inputs + s0, s1 := 0, 1 + + w := 2 + appendFullRound := func() { + c[w] = gkrcore.RawWire{Gate: poseidon2ExtLinear0, Inputs: []int{s0, s1}} + c[w+1] = gkrcore.RawWire{Gate: poseidon2ExtLinear1, Inputs: []int{s0, s1}} + c[w+2] = gkrcore.RawWire{Gate: poseidon2SBox, Inputs: []int{w}} + c[w+3] = gkrcore.RawWire{Gate: poseidon2SBox, Inputs: []int{w + 1}} + s0, s1 = w+2, w+3 + w += 4 + } + appendPartialRound := func() { + c[w] = gkrcore.RawWire{Gate: poseidon2ExtLinear0, Inputs: []int{s0, s1}} + c[w+1] = gkrcore.RawWire{Gate: poseidon2IntLinear1, Inputs: []int{s0, s1}} + c[w+2] = gkrcore.RawWire{Gate: poseidon2SBox, Inputs: []int{w}} + s0, s1 = w+2, w+1 + w += 3 + } + + for range nbFullRounds / 2 { + appendFullRound() + } + for range nbPartialRounds { + appendPartialRound() + } + for range nbFullRounds - nbFullRounds/2 { + appendFullRound() + } + + // feed-forward: 2*s0 + s1 + in1 + c[w] = gkrcore.RawWire{Gate: poseidon2FeedForward, Inputs: []int{s0, s1, 1}} + + return c +} + +var testManyInstancesLogMaxInstances = -1 + +func GetLogMaxInstances(t *testing.T) int { + if testManyInstancesLogMaxInstances == -1 { + + s := os.Getenv("GKR_LOG_INSTANCES") + if s == "" { + testManyInstancesLogMaxInstances = 5 + } else { + var err error + testManyInstancesLogMaxInstances, err = strconv.Atoi(s) + if err != nil { + t.Error(err) + } + } + } + return testManyInstancesLogMaxInstances } diff --git a/internal/gkr/small_rational/gkr.go b/internal/gkr/small_rational/gkr.go index c21d26c21f..1e4882d272 100644 --- a/internal/gkr/small_rational/gkr.go +++ b/internal/gkr/small_rational/gkr.go @@ -8,655 +8,557 @@ package gkr import ( "errors" "fmt" + "hash" "iter" - "strconv" "sync" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" "github.com/consensys/gnark-crypto/utils" - "github.com/consensys/gnark/internal/gkr/gkrtypes" + "github.com/consensys/gnark/constraint" + "github.com/consensys/gnark/internal/gkr/gkrcore" "github.com/consensys/gnark/internal/small_rational" "github.com/consensys/gnark/internal/small_rational/polynomial" ) // Type aliases for bytecode-based GKR types type ( - Wire = gkrtypes.SerializableWire - Circuit = gkrtypes.SerializableCircuit + Wire = gkrcore.SerializableWire + Circuit = gkrcore.SerializableCircuit ) // The goal is to prove/verify evaluations of many instances of the same circuit -// WireAssignment is assignment of values to the same wire across many instances of the circuit +// WireAssignment is the assignment of values to the same wire across many instances of the circuit type WireAssignment []polynomial.MultiLin type Proof []sumcheckProof // for each layer, for each wire, a sumcheck (for each variable, a polynomial) // zeroCheckLazyClaims is a lazy claim for sumcheck (verifier side). -// It checks that the polynomial ∑ᵢ cⁱ eq(-, xᵢ) w(-) sums up to the expected multilinear -// extension of the values of w across all instances. -// Its purpose is to batch the checking of multiple evaluations of the same wire. +// It checks that the polynomial ∑ᵢ cⁱ eq(-, xᵢ) wᵢ(-) sums to the expected value, +// where the sum runs over all wᵢ and evaluation point xᵢ in the level. +// Its purpose is to batch the checking of multiple wire evaluations at evaluation points. type zeroCheckLazyClaims struct { - wireI int // the wire for which we are making the claim, with value w - evaluationPoints [][]small_rational.SmallRational // xᵢ: the points at which the prover has made claims about the evaluation of w - claimedEvaluations []small_rational.SmallRational // yᵢ = w(xᵢ), allegedly - manager *claimsManager // WARNING: Circular references -} - -func (e *zeroCheckLazyClaims) getWire() Wire { - return e.manager.circuit[e.wireI] -} - -func (e *zeroCheckLazyClaims) claimsNum() int { - return len(e.evaluationPoints) + foldingCoeff small_rational.SmallRational // the coefficient used to fold claims, conventionally 0 if there is only one claim + resources *resources + levelI int } func (e *zeroCheckLazyClaims) varsNum() int { - return len(e.evaluationPoints[0]) -} - -// foldedSum returns ∑ᵢ aⁱ yᵢ -func (e *zeroCheckLazyClaims) foldedSum(a small_rational.SmallRational) small_rational.SmallRational { - evalsAsPoly := polynomial.Polynomial(e.claimedEvaluations) - return evalsAsPoly.Eval(&a) + return e.resources.nbVars } func (e *zeroCheckLazyClaims) degree(int) int { - return e.manager.circuit[e.wireI].ZeroCheckDegree() -} - -// verifyFinalEval finalizes the verification of w. -// The prover's claims w(xᵢ) = yᵢ have already been reduced to verifying -// ∑ cⁱ eq(xᵢ, r) w(r) = purportedValue. (c is foldingCoeff) -// Both purportedValue and the vector r have been randomized during the sumcheck protocol. -// By taking the w term out of the sum we get the equivalent claim that -// for E := ∑ eq(xᵢ, r), it must be that E w(r) = purportedValue. -// If w is an input wire, the verifier can directly check its evaluation at r. -// Otherwise, the prover makes claims about the evaluation of w's input wires, -// wᵢ, at r, to be verified later. -// The claims are communicated through the proof parameter. -// The verifier checks here if the claimed evaluations of wᵢ(r) are consistent with -// the main claim, by checking E w(wᵢ(r)...) = purportedValue. -func (e *zeroCheckLazyClaims) verifyFinalEval(r []small_rational.SmallRational, foldingCoeff, purportedValue small_rational.SmallRational, uniqueInputEvaluations []small_rational.SmallRational) error { - // the eq terms ( E ) - numClaims := len(e.evaluationPoints) - evaluation := polynomial.EvalEq(e.evaluationPoints[numClaims-1], r) - for i := numClaims - 2; i >= 0; i-- { - evaluation.Mul(&evaluation, &foldingCoeff) - eq := polynomial.EvalEq(e.evaluationPoints[i], r) - evaluation.Add(&evaluation, &eq) - } - - wire := e.manager.circuit[e.wireI] - - // the w(...) term - var gateEvaluation small_rational.SmallRational - if wire.IsInput() { // just compute w(r) - gateEvaluation = e.manager.assignment[e.wireI].Evaluate(r, e.manager.memPool) - } else { // proof contains the evaluations of the inputs, but avoids repetition in case multiple inputs come from the same wire - injection, injectionLeftInv := - e.manager.circuit.ClaimPropagationInfo(e.wireI) - - if len(injection) != len(uniqueInputEvaluations) { - return fmt.Errorf("%d input wire evaluations given, %d expected", len(uniqueInputEvaluations), len(injection)) - } - - for uniqueI, i := range injection { // map from unique to all - e.manager.add(wire.Inputs[i], r, uniqueInputEvaluations[uniqueI]) - } + return e.resources.circuit.ZeroCheckDegree(e.resources.schedule[e.levelI].(constraint.GkrSumcheckLevel)) +} + +// verifyFinalEval finalizes the verification of a level at the sumcheck evaluation point r. +// The sumcheck protocol has already reduced the per-wire claims w(xᵢ) = yᵢ to verifying +// ∑ᵢ cⁱ eq(xᵢ, r) · wᵢ(r) = purportedValue, where the sum runs over all +// claims on each wire and c is foldingCoeff. +// Both purportedValue and the vector r have been randomized during sumcheck. +// +// For input wires, w(r) is computed directly from the assignment and the claimed +// evaluation in uniqueInputEvaluations is checked equal to it. +// For non-input wires, the prover claims evaluations of their gate inputs at r via +// uniqueInputEvaluations; those claims are verified by lower levels' sumchecks. +// The verifier checks consistency by evaluating gateᵥ(inputEvals...) and confirming +// that the full sum matches purportedValue. +func (e *zeroCheckLazyClaims) verifyFinalEval(r []small_rational.SmallRational, purportedValue small_rational.SmallRational, uniqueInputEvaluations []small_rational.SmallRational) error { + e.resources.outgoingEvalPoints[e.levelI] = [][]small_rational.SmallRational{r} + level := e.resources.schedule[e.levelI] + gateInputEvals := gkrcore.ReduplicateInputs(level, e.resources.circuit, uniqueInputEvaluations) + + var claimedEvals polynomial.Polynomial + levelWireI := 0 + for _, group := range level.ClaimGroups() { + for _, wI := range group.Wires { + wire := e.resources.circuit[wI] + + var gateEval small_rational.SmallRational + if wire.IsInput() { + gateEval = e.resources.assignment[wI].Evaluate(r, &e.resources.memPool) + if !gateInputEvals[levelWireI][0].Equal(&gateEval) { + return errors.New("incompatible evaluations") + } + } else { + evaluator := newGateEvaluator(wire.Gate.Evaluate, len(wire.Inputs)) + for _, v := range gateInputEvals[levelWireI] { + evaluator.pushInput(v) + } + gateEval.Set(evaluator.evaluate()) + } - evaluator := newGateEvaluator(wire.Gate.Evaluate, len(wire.Inputs)) - for _, uniqueI := range injectionLeftInv { // map from all to unique - evaluator.pushInput(uniqueInputEvaluations[uniqueI]) + for _, src := range group.ClaimSources { + eq := polynomial.EvalEq(e.resources.outgoingEvalPoints[src.Level][src.OutgoingClaimIndex], r) + var term small_rational.SmallRational + term.Mul(&eq, &gateEval) + claimedEvals = append(claimedEvals, term) + } + levelWireI++ } - - gateEvaluation.Set(evaluator.evaluate()) } - evaluation.Mul(&evaluation, &gateEvaluation) - - if evaluation.Equal(&purportedValue) { - return nil + if total := claimedEvals.Eval(&e.foldingCoeff); !total.Equal(&purportedValue) { + return errors.New("incompatible evaluations") } - return errors.New("incompatible evaluations") + return nil } // zeroCheckClaims is a claim for sumcheck (prover side). -// It checks that the polynomial ∑ᵢ cⁱ eq(-, xᵢ) w(-) sums up to the expected multilinear -// extension of the values of w across all instances. -// Its purpose is to batch the proving of multiple evaluations of the same wire. +// It checks that the polynomial ∑ᵢ cⁱ eq(-, xᵢ) wᵢ(-) sums to the expected value, +// where the sum runs over all (wire v, claim source s) pairs in the level. +// Each wire has its own eq table with the batching coefficients baked in. type zeroCheckClaims struct { - wireI int // the wire for which we are making the claim, with value w - evaluationPoints [][]small_rational.SmallRational // xᵢ: the points at which the prover has made claims about the evaluation of w - claimedEvaluations []small_rational.SmallRational // yᵢ = w(xᵢ) - manager *claimsManager - - input []polynomial.MultiLin // input[i](h₁, ..., hₘ₋ⱼ) = wᵢ(r₁, r₂, ..., rⱼ₋₁, h₁, ..., hₘ₋ⱼ) - - eq polynomial.MultiLin // E := ∑ᵢ cⁱ eq(xᵢ, -) - - gateEvaluatorPool *gateEvaluatorPool -} - -func (c *zeroCheckClaims) getWire() Wire { - return c.manager.circuit[c.wireI] -} - -// fold the multiple claims into one claim using a random combination (foldingCoeff or c). -// From the original multiple claims of w(xᵢ) = yᵢ, we get a single claim -// ∑ᵢ,ₕ cⁱ eq(xᵢ, h) w(h) = ∑ᵢ cⁱ yᵢ, where h iterates over the hypercube (circuit instances) and -// i iterates over the claims. -// Equivalently, we could say ∑ᵢ cⁱ yᵢ = ∑ₕ,ᵢ cⁱ eq(xᵢ, h) w(h) = ∑ₕ w(h) ∑ᵢ cⁱ eq(xᵢ, h). -// Thus if we initially compute E := ∑ᵢ cⁱ eq(xᵢ, -), our claim will find the simpler form -// ∑ᵢ cⁱ yᵢ = ∑ₕ w(h) E(h), where the sum-checked polynomial is of degree deg(g) + 1, -// and deg(g) is the total degree of the polynomial defining the gate g of which w is the output. -// The output of fold is the first sumcheck claim, i.e. ∑₍ₕ₁,ₕ₂,...₎ w(X, h₁, h₂, ...) E(X, h₁, h₂, ...).. -func (c *zeroCheckClaims) fold(foldingCoeff small_rational.SmallRational) polynomial.Polynomial { - varsNum := c.varsNum() - eqLength := 1 << varsNum - claimsNum := c.claimsNum() - // initialize the eq tables ( E ) - c.eq = c.manager.memPool.Make(eqLength) - - c.eq[0].SetOne() - c.eq.Eq(c.evaluationPoints[0]) - - // E := eq(x₀, -) - newEq := polynomial.MultiLin(c.manager.memPool.Make(eqLength)) - aI := foldingCoeff - - // E += cⁱ eq(xᵢ, -) - for k := 1; k < claimsNum; k++ { - newEq[0].Set(&aI) - - c.eqAcc(c.eq, newEq, c.evaluationPoints[k]) - - if k+1 < claimsNum { - aI.Mul(&aI, &foldingCoeff) - } - } - - c.manager.memPool.Dump(newEq) - - return c.computeGJ() -} - -// eqAcc sets m to an eq table at q and then adds it to e. -// m <- eq(q, -). -// e <- e + m -func (c *zeroCheckClaims) eqAcc(e, m polynomial.MultiLin, q []small_rational.SmallRational) { - n := len(q) - - //At the end of each iteration, m(h₁, ..., hₙ) = eq(q₁, ..., qᵢ₊₁, h₁, ..., hᵢ₊₁) - for i := range q { // In the comments we use a 1-based index so q[i] = qᵢ₊₁ - // go through all assignments of (b₁, ..., bᵢ) ∈ {0,1}ⁱ - const threshold = 1 << 6 - k := 1 << i - if k < threshold { - for j := 0; j < k; j++ { - j0 := j << (n - i) // bᵢ₊₁ = 0 - j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 - - m[j1].Mul(&q[i], &m[j0]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ - m[j0].Sub(&m[j0], &m[j1]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) - } - } else { - c.manager.workers.Submit(k, func(start, end int) { - for j := start; j < end; j++ { - j0 := j << (n - i) // bᵢ₊₁ = 0 - j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 - - m[j1].Mul(&q[i], &m[j0]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ - m[j0].Sub(&m[j0], &m[j1]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) - } - }, 1024).Wait() - } - - } - c.manager.workers.Submit(len(e), func(start, end int) { - for i := start; i < end; i++ { - e[i].Add(&e[i], &m[i]) - } - }, 512).Wait() + levelI int + resources *resources + input []polynomial.MultiLin // UniqueGateInputs order + inputIndices [][]int // [wireInLevel][gateInputJ] → index in input + eqs []polynomial.MultiLin // per-wire interpolation bases for evaluating wire assignments at challenge points + gateEvaluatorPools []*gateEvaluatorPool } -// computeGJ: gⱼ = ∑_{0≤h<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, h...) = ∑_{0≤i<2ⁿ⁻ʲ} E(r₁, ..., Xⱼ, h...) g( w₀(r₁, ..., Xⱼ, h...), ... ). -// the polynomial is represented by the evaluations gⱼ(1), gⱼ(2), ..., gⱼ(deg(gⱼ)). -// The value gⱼ(0) is inferred from the equation gⱼ(0) + gⱼ(1) = gⱼ₋₁(rⱼ₋₁). By convention, g₀ is a constant polynomial equal to the claimed sum. -func (c *zeroCheckClaims) computeGJ() polynomial.Polynomial { - - wire := c.getWire() - degGJ := wire.ZeroCheckDegree() // guaranteed to be no smaller than the actual deg(gⱼ) - nbGateIn := len(c.input) - - // Both E and wᵢ (the input wires and the eq table) are multilinear, thus - // they are linear in Xⱼ. - // So for f ∈ { E(r₁, ..., Xⱼ, h...) } ∪ {wᵢ(r₁, ..., Xⱼ, h...) }, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables. - // ml are such multilinear polynomials the evaluations of which over different values of Xⱼ are computed in this stepwise manner. - ml := make([]polynomial.MultiLin, nbGateIn+1) // shortcut to the evaluations of the multilinear polynomials over the hypercube - ml[0] = c.eq - copy(ml[1:], c.input) - - sumSize := len(c.eq) / 2 // the range of h, over which we sum - - // Perf-TODO: Collate once at claim "folding" time and not again. then, even folding can be done in one operation every time "next" is called - - gJ := make([]small_rational.SmallRational, degGJ) +func (c *zeroCheckClaims) varsNum() int { + return c.resources.nbVars +} + +// roundPolynomial computes gⱼ = ∑ₕ ∑ᵥ eqs[v](Xⱼ, h...) · gateᵥ(inputs(Xⱼ, h...)). +// The polynomial is represented by the evaluations gⱼ(1), gⱼ(2), ..., gⱼ(deg(gⱼ)). +// The value gⱼ(0) is inferred from the equation gⱼ(0) + gⱼ(1) = gⱼ₋₁(rⱼ₋₁). +// By convention, g₀ is a constant polynomial equal to the claimed sum. +func (c *zeroCheckClaims) roundPolynomial() polynomial.Polynomial { + level := c.resources.schedule[c.levelI].(constraint.GkrSumcheckLevel) + degree := c.resources.circuit.ZeroCheckDegree(level) + nbUniqueInputs := len(c.input) + nbWires := len(c.eqs) + + // Both eqs and input are multilinear, thus linear in Xⱼ. + // For any such f, f(m) = m·(f(1) - f(0)) + f(0), and f(0), f(1) are read directly + // from the bookkeeping tables. This allows stepwise evaluation at Xⱼ = 1, 2, ..., degree. + // Layout: [eq₀, eq₁, ..., eq_{nbWires-1}, input₀, input₁, ..., input_{nbUniqueInputs-1}] + ml := make([]polynomial.MultiLin, nbWires+nbUniqueInputs) + copy(ml, c.eqs) + copy(ml[nbWires:], c.input) + + sumSize := len(c.eqs[0]) / 2 + + p := make([]small_rational.SmallRational, degree) var mu sync.Mutex - computeAll := func(start, end int) { // compute method to allow parallelization across instances + computeAll := func(start, end int) { var step small_rational.SmallRational - evaluator := c.gateEvaluatorPool.get() - defer c.gateEvaluatorPool.put(evaluator) + evaluators := make([]*gateEvaluator, nbWires) + for w := range nbWires { + evaluators[w] = c.gateEvaluatorPools[w].get() + } + defer func() { + for w := range nbWires { + c.gateEvaluatorPools[w].put(evaluators[w]) + } + }() - res := make([]small_rational.SmallRational, degGJ) + res := make([]small_rational.SmallRational, degree) // evaluations of ml, laid out as: // ml[0](1, h...), ml[1](1, h...), ..., ml[len(ml)-1](1, h...), // ml[0](2, h...), ml[1](2, h...), ..., ml[len(ml)-1](2, h...), // ... - // ml[0](degGJ, h...), ml[2](degGJ, h...), ..., ml[len(ml)-1](degGJ, h...) - mlEvals := make([]small_rational.SmallRational, degGJ*len(ml)) - - for h := start; h < end; h++ { // h counts across instances + // ml[0](degree, h...), ml[1](degree, h...), ..., ml[len(ml)-1](degree, h...) + mlEvals := make([]small_rational.SmallRational, degree*len(ml)) + for h := start; h < end; h++ { evalAt1Index := sumSize + h for k := range ml { - // d = 0 - mlEvals[k].Set(&ml[k][evalAt1Index]) // evaluation at Xⱼ = 1. Can be taken directly from the table. + mlEvals[k].Set(&ml[k][evalAt1Index]) // evaluation at Xⱼ = 1, taken directly from the table step.Sub(&mlEvals[k], &ml[k][h]) // step = ml[k](1) - ml[k](0) - for d := 1; d < degGJ; d++ { + for d := 1; d < degree; d++ { mlEvals[d*len(ml)+k].Add(&mlEvals[(d-1)*len(ml)+k], &step) } } - eIndex := 0 // index for where the current eq term is + eIndex := 0 // start of the current row's eq evaluations nextEIndex := len(ml) - for d := range degGJ { - // Push gate inputs - for i := range nbGateIn { - evaluator.pushInput(mlEvals[eIndex+1+i]) + for d := range degree { + for w := range nbWires { + for _, inputI := range c.inputIndices[w] { + evaluators[w].pushInput(mlEvals[eIndex+nbWires+inputI]) + } + summand := evaluators[w].evaluate() + summand.Mul(summand, &mlEvals[eIndex+w]) + res[d].Add(&res[d], summand) // collect contributions into the sum from start to end } - summand := evaluator.evaluate() - summand.Mul(summand, &mlEvals[eIndex]) - res[d].Add(&res[d], summand) // collect contributions into the sum from start to end eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) } } mu.Lock() - for i := range gJ { - gJ[i].Add(&gJ[i], &res[i]) // collect into the complete sum + for i := range p { + p[i].Add(&p[i], &res[i]) // collect into the complete sum } mu.Unlock() } const minBlockSize = 64 - if sumSize < minBlockSize { - // no parallelization computeAll(0, sumSize) } else { - c.manager.workers.Submit(sumSize, computeAll, minBlockSize).Wait() + c.resources.workers.Submit(sumSize, computeAll, minBlockSize).Wait() } - return gJ + return p } -// next first folds the input and E polynomials at the given verifier challenge then computes the new gⱼ. -// Thus, j <- j+1 and rⱼ = challenge. -func (c *zeroCheckClaims) next(challenge small_rational.SmallRational) polynomial.Polynomial { +// roundFold folds all input and eq polynomials at the verifier challenge r. +// After this call, j ← j+1 and rⱼ = r. +func (c *zeroCheckClaims) roundFold(r small_rational.SmallRational) { const minBlockSize = 512 - n := len(c.eq) / 2 + n := len(c.eqs[0]) / 2 if n < minBlockSize { - // no parallelization for i := range c.input { - c.input[i].Fold(challenge) + c.input[i].Fold(r) + } + for i := range c.eqs { + c.eqs[i].Fold(r) } - c.eq.Fold(challenge) } else { - wgs := make([]*sync.WaitGroup, len(c.input)) + wgs := make([]*sync.WaitGroup, len(c.input)+len(c.eqs)) for i := range c.input { - wgs[i] = c.manager.workers.Submit(n, c.input[i].FoldParallel(challenge), minBlockSize) + wgs[i] = c.resources.workers.Submit(n, c.input[i].FoldParallel(r), minBlockSize) + } + for i := range c.eqs { + wgs[len(c.input)+i] = c.resources.workers.Submit(n, c.eqs[i].FoldParallel(r), minBlockSize) } - c.manager.workers.Submit(n, c.eq.FoldParallel(challenge), minBlockSize).Wait() for _, wg := range wgs { wg.Wait() } } - - return c.computeGJ() -} - -func (c *zeroCheckClaims) varsNum() int { - return len(c.evaluationPoints[0]) } -func (c *zeroCheckClaims) claimsNum() int { - return len(c.claimedEvaluations) -} - -// proveFinalEval provides the values wᵢ(r₁, ..., rₙ) +// proveFinalEval provides the unique input wire values wᵢ(r₁, ..., rₙ). func (c *zeroCheckClaims) proveFinalEval(r []small_rational.SmallRational) []small_rational.SmallRational { - //defer the proof, return list of claims - - injection, _ := c.manager.circuit.ClaimPropagationInfo(c.wireI) // TODO @Tabaie: Instead of doing this last, we could just have fewer input in the first place; not that likely to happen with single gates, but more so with layers. - evaluations := make([]small_rational.SmallRational, len(injection)) - for i, gateInputI := range injection { - wI := c.input[gateInputI] - wI.Fold(r[len(r)-1]) // We already have wᵢ(r₁, ..., rₙ₋₁, hₙ) in a table. Only one more fold required. - c.manager.add(c.getWire().Inputs[gateInputI], r, wI[0]) - evaluations[i] = wI[0] + c.resources.outgoingEvalPoints[c.levelI] = [][]small_rational.SmallRational{r} + evaluations := make([]small_rational.SmallRational, len(c.input)) + for i := range c.input { + c.input[i].Fold(r[len(r)-1]) + evaluations[i] = c.input[i][0] + } + for i := range c.input { + c.resources.memPool.Dump(c.input[i]) + } + for i := range c.eqs { + c.resources.memPool.Dump(c.eqs[i]) + } + for _, pool := range c.gateEvaluatorPools { + pool.dumpAll() } - - c.manager.memPool.Dump(c.claimedEvaluations, c.eq) - c.gateEvaluatorPool.dumpAll() - return evaluations } -type claimsManager struct { - claims []*zeroCheckLazyClaims - assignment WireAssignment - memPool *polynomial.Pool - workers *utils.WorkerPool - circuit Circuit -} +// eqAcc sets m to an eq table at q and then adds it to e. +// m <- m[0] · eq(q, -). +// e <- e + m +func (r *resources) eqAcc(e, m polynomial.MultiLin, q []small_rational.SmallRational) { + n := len(q) -func newClaimsManager(circuit Circuit, assignment WireAssignment, o settings) (manager claimsManager) { - manager.assignment = assignment - manager.claims = make([]*zeroCheckLazyClaims, len(circuit)) - manager.memPool = o.pool - manager.workers = o.workers - manager.circuit = circuit + // At the end of each iteration, m(h₁, ..., hₙ) = m[0] · eq(q₁, ..., qᵢ₊₁, h₁, ..., hᵢ₊₁) + for i := range q { // 1-based in comments: q[i] = qᵢ₊₁ + // go through all assignments of (b₁, ..., bᵢ) ∈ {0,1}ⁱ + const threshold = 1 << 6 + k := 1 << i + if k < threshold { + for j := 0; j < k; j++ { + j0 := j << (n - i) // bᵢ₊₁ = 0 + j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 - for i := range circuit { - manager.claims[i] = &zeroCheckLazyClaims{ - wireI: i, - evaluationPoints: make([][]small_rational.SmallRational, 0, circuit[i].NbClaims()), - claimedEvaluations: manager.memPool.Make(circuit[i].NbClaims()), - manager: &manager, + m[j1].Mul(&q[i], &m[j0]) // m(b₁,...,bᵢ,1) = m(b₁,...,bᵢ) · qᵢ₊₁ + m[j0].Sub(&m[j0], &m[j1]) // m(b₁,...,bᵢ,0) = m(b₁,...,bᵢ) · (1 - qᵢ₊₁) + } + } else { + r.workers.Submit(k, func(start, end int) { + for j := start; j < end; j++ { + j0 := j << (n - i) // bᵢ₊₁ = 0 + j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 + + m[j1].Mul(&q[i], &m[j0]) // m(b₁,...,bᵢ,1) = m(b₁,...,bᵢ) · qᵢ₊₁ + m[j0].Sub(&m[j0], &m[j1]) // m(b₁,...,bᵢ,0) = m(b₁,...,bᵢ) · (1 - qᵢ₊₁) + } + }, 1024).Wait() } } - return -} - -func (m *claimsManager) add(wire int, evaluationPoint []small_rational.SmallRational, evaluation small_rational.SmallRational) { - claim := m.claims[wire] - i := len(claim.evaluationPoints) - claim.claimedEvaluations[i] = evaluation - claim.evaluationPoints = append(claim.evaluationPoints, evaluationPoint) + r.workers.Submit(len(e), func(start, end int) { + for i := start; i < end; i++ { + e[i].Add(&e[i], &m[i]) + } + }, 512).Wait() } -func (m *claimsManager) getLazyClaim(wire int) *zeroCheckLazyClaims { - return m.claims[wire] +type resources struct { + // outgoingEvalPoints[i][k] is the k-th outgoing evaluation point (evaluation challenge) produced at schedule level i. + // outgoingEvalPoints[len(schedule)][0] holds the initial challenge (firstChallenge / rho). + // SumcheckLevels produce one point (k=0). SkipLevels pass on all their evaluation points. + outgoingEvalPoints [][][]small_rational.SmallRational + nbVars int + assignment WireAssignment + memPool polynomial.Pool + workers *utils.WorkerPool + circuit Circuit + schedule constraint.GkrProvingSchedule + transcript transcript + uniqueInputIndices [][]int // uniqueInputIndices[wI][claimI]: w's unique-input index in the layer its claimI-th evaluation is coming from } -func (m *claimsManager) getClaim(wireI int) *zeroCheckClaims { - lazy := m.claims[wireI] - wire := m.circuit[wireI] - res := &zeroCheckClaims{ - wireI: wireI, - evaluationPoints: lazy.evaluationPoints, - claimedEvaluations: lazy.claimedEvaluations, - manager: m, - } - - if wire.IsInput() { - res.input = []polynomial.MultiLin{m.memPool.Clone(m.assignment[wireI])} - } else { - res.input = make([]polynomial.MultiLin, len(wire.Inputs)) - - for inputI, inputW := range wire.Inputs { - res.input[inputI] = m.memPool.Clone(m.assignment[inputW]) //will be edited later, so must be deep copied +func newResources(c Circuit, schedule constraint.GkrProvingSchedule, assignment WireAssignment, hasher hash.Hash) (resources, error) { + nbVars := assignment.NumVars() + nbInstances := assignment.NumInstances() + if 1<= 2 { + foldingCoeff = r.transcript.getChallenge() } -} -func setup(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (settings, error) { - var o settings - var err error - for _, option := range options { - option(&o) + uniqueInputs, inputIndices := r.circuit.InputMapping(level) + input := make([]polynomial.MultiLin, len(uniqueInputs)) + for i, inW := range uniqueInputs { + input[i] = r.memPool.Clone(r.assignment[inW]) } - o.nbVars = assignment.NumVars() - nbInstances := assignment.NumInstances() - if 1< 1 { + newEq := polynomial.MultiLin(r.memPool.Make(eqLength)) + aI := alpha + for k := 1; k < nbSources; k++ { + aI.Mul(&aI, &foldingCoeff) + newEq[0].Set(&aI) + r.eqAcc(groupEq, newEq, r.outgoingEvalPoints[group.ClaimSources[k].Level][group.ClaimSources[k].OutgoingClaimIndex]) + } + r.memPool.Dump(newEq) + } -func ChallengeNames(c Circuit, logNbInstances int, prefix string) []string { + var stride small_rational.SmallRational + stride.Set(&foldingCoeff) + for range nbSources - 1 { + stride.Mul(&stride, &foldingCoeff) + } - // Pre-compute the size TODO: Consider not doing this and just grow the list by appending - size := logNbInstances // first challenge + eqs[levelWireI] = groupEq + levelWireI++ + alpha.Mul(&alpha, &stride) - for i := range c { - if c[i].NoProof() { // no proof, no challenge - continue - } - if c[i].NbClaims() > 1 { //fold the claims - size++ + for w := 1; w < len(group.Wires); w++ { + eqs[levelWireI] = polynomial.MultiLin(r.memPool.Make(eqLength)) + r.workers.Submit(eqLength, func(start, end int) { + for i := start; i < end; i++ { + eqs[levelWireI][i].Mul(&eqs[levelWireI-1][i], &stride) + } + }, 512).Wait() + levelWireI++ + alpha.Mul(&alpha, &stride) } - size += logNbInstances // full run of sumcheck on logNbInstances variables } - nums := make([]string, max(len(c), logNbInstances)) - for i := range nums { - nums[i] = strconv.Itoa(i) + claims := &zeroCheckClaims{ + levelI: levelI, + resources: r, + input: input, + inputIndices: inputIndices, + eqs: eqs, + gateEvaluatorPools: pools, } + return sumcheckProve(claims, &r.transcript) +} - challenges := make([]string, size) - - // output wire claims - firstChallengePrefix := prefix + "fC." - for j := 0; j < logNbInstances; j++ { - challenges[j] = firstChallengePrefix + nums[j] +func (r *resources) verifySumcheckLevel(levelI int, proof Proof) error { + level := r.schedule[levelI] + nbClaims := level.NbClaims() + var foldingCoeff small_rational.SmallRational + if nbClaims >= 2 { + foldingCoeff = r.transcript.getChallenge() } - j := logNbInstances - for i := len(c) - 1; i >= 0; i-- { - if c[i].NoProof() { - continue - } - wirePrefix := prefix + "w" + nums[i] + "." - if c[i].NbClaims() > 1 { - challenges[j] = wirePrefix + "fold" - j++ - } + initialChallengeI := len(r.schedule) + claimedEvals := make(polynomial.Polynomial, 0, level.NbClaims()) - partialSumPrefix := wirePrefix + "pSP." - for k := 0; k < logNbInstances; k++ { - challenges[j] = partialSumPrefix + nums[k] - j++ + for _, group := range level.ClaimGroups() { + for _, wI := range group.Wires { + for claimI, src := range group.ClaimSources { + if src.Level == initialChallengeI { + claimedEvals = append(claimedEvals, r.assignment[wI].Evaluate(r.outgoingEvalPoints[src.Level][src.OutgoingClaimIndex], &r.memPool)) + } else { + claimedEvals = append(claimedEvals, proof[src.Level].finalEvalProof[r.schedule[src.Level].FinalEvalProofIndex(r.uniqueInputIndices[wI][claimI], src.OutgoingClaimIndex)]) + } + } } } - return challenges -} -func getFirstChallengeNames(logNbInstances int, prefix string) []string { - res := make([]string, logNbInstances) - firstChallengePrefix := prefix + "fC." - for i := 0; i < logNbInstances; i++ { - res[i] = firstChallengePrefix + strconv.Itoa(i) - } - return res -} + claimedSum := claimedEvals.Eval(&foldingCoeff) -func getChallenges(transcript *fiatshamir.Transcript, names []string) ([]small_rational.SmallRational, error) { - res := make([]small_rational.SmallRational, len(names)) - for i, name := range names { - if bytes, err := transcript.ComputeChallenge(name); err != nil { - return nil, err - } else if err = res[i].SetBytesCanonical(bytes); err != nil { - return nil, err - } + lazyClaims := &zeroCheckLazyClaims{ + foldingCoeff: foldingCoeff, + resources: r, + levelI: levelI, } - return res, nil + return sumcheckVerify(lazyClaims, proof[levelI], claimedSum, r.circuit.ZeroCheckDegree(level.(constraint.GkrSumcheckLevel)), &r.transcript) } // Prove consistency of the claimed assignment -func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (Proof, error) { - o, err := setup(c, assignment, transcriptSettings, options...) +func Prove(c Circuit, schedule constraint.GkrProvingSchedule, assignment WireAssignment, hasher hash.Hash) (Proof, error) { + r, err := newResources(c, schedule, assignment, hasher) if err != nil { return nil, err } - defer o.workers.Stop() + defer r.workers.Stop() - claims := newClaimsManager(c, assignment, o) + proof := make(Proof, len(schedule)) - proof := make(Proof, len(c)) - // firstChallenge called rho in the paper - var firstChallenge []small_rational.SmallRational - firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) - if err != nil { - return nil, err + // Derive the initial challenge point + firstChallenge := make([]small_rational.SmallRational, r.nbVars) + for j := range r.nbVars { + firstChallenge[j] = r.transcript.getChallenge() } + r.outgoingEvalPoints[len(schedule)] = [][]small_rational.SmallRational{firstChallenge} - wirePrefix := o.transcriptPrefix + "w" - var baseChallenge [][]byte - for i := len(c) - 1; i >= 0; i-- { - - wire := c[i] - - if wire.IsOutput() { - claims.add(i, firstChallenge, assignment[i].Evaluate(firstChallenge, claims.memPool)) - } - - claim := claims.getClaim(i) - if wire.NoProof() { // input wires with one claim only - proof[i] = sumcheckProof{ - partialSumPolys: []polynomial.Polynomial{}, - finalEvalProof: []small_rational.SmallRational{}, - } + for levelI := len(schedule) - 1; levelI >= 0; levelI-- { + if _, isSkip := r.schedule[levelI].(constraint.GkrSkipLevel); isSkip { + proof[levelI] = r.proveSkipLevel(levelI) } else { - if proof[i], err = sumcheckProve( - claim, fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), - ); err != nil { - return proof, err - } - - baseChallenge = make([][]byte, len(proof[i].finalEvalProof)) - for j := range proof[i].finalEvalProof { - baseChallenge[j] = proof[i].finalEvalProof[j].Marshal() - } + proof[levelI] = r.proveSumcheckLevel(levelI) } - // the verifier checks a single claim about input wires itself - claims.deleteClaim(i) + constraint.BindGkrFinalEvalProof(&r.transcript, proof[levelI].finalEvalProof, c.UniqueGateInputs(r.schedule[levelI]), c.IsInput, r.schedule[levelI]) } return proof, nil } -// Verify the consistency of the claimed output with the claimed input -// Unlike in Prove, the assignment argument need not be complete -func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSettings fiatshamir.Settings, options ...Option) error { - o, err := setup(c, assignment, transcriptSettings, options...) +// Verify the consistency of the claimed output with the claimed input. +// Unlike in Prove, the assignment argument need not be complete. +func Verify(c Circuit, schedule constraint.GkrProvingSchedule, assignment WireAssignment, proof Proof, hasher hash.Hash) error { + r, err := newResources(c, schedule, assignment, hasher) if err != nil { return err } - defer o.workers.Stop() + defer r.workers.Stop() - claims := newClaimsManager(c, assignment, o) - - var firstChallenge []small_rational.SmallRational - firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) - if err != nil { - return err + // Derive the initial challenge point + firstChallenge := make([]small_rational.SmallRational, r.nbVars) + for j := range r.nbVars { + firstChallenge[j] = r.transcript.getChallenge() } + r.outgoingEvalPoints[len(schedule)] = [][]small_rational.SmallRational{firstChallenge} - wirePrefix := o.transcriptPrefix + "w" - var baseChallenge [][]byte - for i := len(c) - 1; i >= 0; i-- { - wire := c[i] - - if wire.IsOutput() { - claims.add(i, firstChallenge, assignment[i].Evaluate(firstChallenge, claims.memPool)) - } - - proofW := proof[i] - claim := claims.getLazyClaim(i) - if wire.NoProof() { // input wires with one claim only - // make sure the proof is empty - if len(proofW.finalEvalProof) != 0 || len(proofW.partialSumPolys) != 0 { - return errors.New("no proof allowed for input wire with a single claim") - } - - if wire.NbClaims() == 1 { // input wire - // simply evaluate and see if it matches - if len(claim.evaluationPoints) == 0 || len(claim.claimedEvaluations) == 0 { - return errors.New("missing input wire claim") - } - evaluation := assignment[i].Evaluate(claim.evaluationPoints[0], claims.memPool) - if !claim.claimedEvaluations[0].Equal(&evaluation) { - return errors.New("incorrect input wire claim") - } - } - } else if err = sumcheckVerify( - claim, proof[i], fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), - ); err == nil { // incorporate prover claims about w's input into the transcript - baseChallenge = make([][]byte, len(proofW.finalEvalProof)) - for j := range baseChallenge { - baseChallenge[j] = proofW.finalEvalProof[j].Marshal() - } + for levelI := len(schedule) - 1; levelI >= 0; levelI-- { + if _, isSkip := r.schedule[levelI].(constraint.GkrSkipLevel); isSkip { + err = r.verifySkipLevel(levelI, proof) } else { - return fmt.Errorf("sumcheck proof rejected: %v", err) //TODO: Any polynomials to dump? + err = r.verifySumcheckLevel(levelI, proof) + } + if err != nil { + return fmt.Errorf("level %d: %v", levelI, err) } - claims.deleteClaim(i) + constraint.BindGkrFinalEvalProof(&r.transcript, proof[levelI].finalEvalProof, c.UniqueGateInputs(r.schedule[levelI]), c.IsInput, r.schedule[levelI]) } return nil } @@ -734,14 +636,14 @@ func (p Proof) flatten() iter.Seq2[int, *small_rational.SmallRational] { // It manages the stack internally and handles input buffering, making it easy to // evaluate the same gate multiple times with different inputs. type gateEvaluator struct { - gate gkrtypes.GateBytecode + gate gkrcore.GateBytecode vars []small_rational.SmallRational nbIn int // number of inputs expected } // newGateEvaluator creates an evaluator for the given compiled gate. // The stack is preloaded with constants and ready for evaluation. -func newGateEvaluator(gate gkrtypes.GateBytecode, nbIn int, elementPool ...*polynomial.Pool) gateEvaluator { +func newGateEvaluator(gate gkrcore.GateBytecode, nbIn int, elementPool ...*polynomial.Pool) gateEvaluator { e := gateEvaluator{ gate: gate, nbIn: nbIn, @@ -785,28 +687,28 @@ func (e *gateEvaluator) evaluate(top ...small_rational.SmallRational) *small_rat // Use switch instead of function pointer for better inlining switch inst.Op { - case gkrtypes.OpAdd: + case gkrcore.OpAdd: dst.Add(&e.vars[inst.Inputs[0]], &e.vars[inst.Inputs[1]]) for j := 2; j < len(inst.Inputs); j++ { dst.Add(dst, &e.vars[inst.Inputs[j]]) } - case gkrtypes.OpMul: + case gkrcore.OpMul: dst.Mul(&e.vars[inst.Inputs[0]], &e.vars[inst.Inputs[1]]) for j := 2; j < len(inst.Inputs); j++ { dst.Mul(dst, &e.vars[inst.Inputs[j]]) } - case gkrtypes.OpSub: + case gkrcore.OpSub: dst.Sub(&e.vars[inst.Inputs[0]], &e.vars[inst.Inputs[1]]) for j := 2; j < len(inst.Inputs); j++ { dst.Sub(dst, &e.vars[inst.Inputs[j]]) } - case gkrtypes.OpNeg: + case gkrcore.OpNeg: dst.Neg(&e.vars[inst.Inputs[0]]) - case gkrtypes.OpMulAcc: + case gkrcore.OpMulAcc: var prod small_rational.SmallRational prod.Mul(&e.vars[inst.Inputs[1]], &e.vars[inst.Inputs[2]]) dst.Add(&e.vars[inst.Inputs[0]], &prod) - case gkrtypes.OpSumExp17: + case gkrcore.OpSumExp17: // result = (x[0] + x[1] + x[2])^17 var sum small_rational.SmallRational sum.Add(&e.vars[inst.Inputs[0]], &e.vars[inst.Inputs[1]]) @@ -832,14 +734,14 @@ func (e *gateEvaluator) evaluate(top ...small_rational.SmallRational) *small_rat // gateEvaluatorPool manages a pool of gate evaluators for a specific gate type // All evaluators share the same underlying polynomial.Pool for element slices type gateEvaluatorPool struct { - gate gkrtypes.GateBytecode + gate gkrcore.GateBytecode nbIn int lock sync.Mutex available map[*gateEvaluator]struct{} elementPool *polynomial.Pool } -func newGateEvaluatorPool(gate gkrtypes.GateBytecode, nbIn int, elementPool *polynomial.Pool) *gateEvaluatorPool { +func newGateEvaluatorPool(gate gkrcore.GateBytecode, nbIn int, elementPool *polynomial.Pool) *gateEvaluatorPool { gep := &gateEvaluatorPool{ gate: gate, nbIn: nbIn, @@ -867,7 +769,7 @@ func (gep *gateEvaluatorPool) put(e *gateEvaluator) { gep.lock.Lock() defer gep.lock.Unlock() - // Return evaluator to pool (it keeps its vars slice from polynomial pool) + // Return evaluator to pool (it keeps its vars slice from the polynomial pool) gep.available[e] = struct{}{} } diff --git a/internal/gkr/small_rational/sumcheck.go b/internal/gkr/small_rational/sumcheck.go index 19d2c86c77..375c402639 100644 --- a/internal/gkr/small_rational/sumcheck.go +++ b/internal/gkr/small_rational/sumcheck.go @@ -7,33 +7,62 @@ package gkr import ( "errors" - "strconv" + "hash" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" "github.com/consensys/gnark/internal/small_rational" "github.com/consensys/gnark/internal/small_rational/polynomial" ) -// This does not make use of parallelism and represents polynomials as lists of coefficients -// It is currently geared towards arithmetic hashes. Once we have a more unified hash function interface, this can be generified. +// This does not make use of parallelism and represents polynomials as lists of coefficients. + +// transcript is a Fiat-Shamir transcript backed by a running hash. +// Field elements are written via Bind; challenges are derived via getChallenge. +// The hash is never reset — all previous data is implicitly part of future challenges. +type transcript struct { + h hash.Hash + bound bool // whether Bind was called since the last getChallenge +} + +// Bind writes field elements to the transcript as bindings for the next challenge. +func (t *transcript) Bind(elements ...small_rational.SmallRational) { + if len(elements) == 0 { + return + } + for i := range elements { + bytes := elements[i].Bytes() + t.h.Write(bytes[:]) + } + t.bound = true +} + +// getChallenge binds optional elements, then squeezes a challenge from the current hash state. +// If no bindings were added since the last squeeze, a separator byte is written first +// to advance the state and prevent repeated values. +func (t *transcript) getChallenge(bindings ...small_rational.SmallRational) small_rational.SmallRational { + t.Bind(bindings...) + if !t.bound { + t.h.Write([]byte{0}) + } + t.bound = false + var res small_rational.SmallRational + res.SetBytes(t.h.Sum(nil)) + return res +} // sumcheckClaims to a multi-sumcheck statement. i.e. one of the form ∑_{0≤i<2ⁿ} fⱼ(i) = cⱼ for 1 ≤ j ≤ m. // Later evolving into a claim of the form gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) type sumcheckClaims interface { - fold(a small_rational.SmallRational) polynomial.Polynomial // fold into the 0ᵗʰ sumcheck subclaim. Create g := ∑_{1≤j≤m} aʲ⁻¹fⱼ for which now we seek to prove ∑_{0≤i<2ⁿ} g(i) = c := ∑_{1≤j≤m} aʲ⁻¹cⱼ. Return g₁. - next(small_rational.SmallRational) polynomial.Polynomial // Return the evaluations gⱼ(k) for 1 ≤ k < degⱼ(g). Update the claim to gⱼ₊₁ for the input value as rⱼ + roundPolynomial() polynomial.Polynomial // compute gⱼ polynomial for current round + roundFold(r small_rational.SmallRational) // fold inputs and eq at challenge r varsNum() int // number of variables - claimsNum() int // number of claims proveFinalEval(r []small_rational.SmallRational) []small_rational.SmallRational // in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof } // sumcheckLazyClaims is the sumcheckClaims data structure on the verifier side. It is "lazy" in that it has to compute fewer things. type sumcheckLazyClaims interface { - claimsNum() int // claimsNum = m - varsNum() int // varsNum = n - foldedSum(a small_rational.SmallRational) small_rational.SmallRational // foldedSum returns c = ∑_{1≤j≤m} aʲ⁻¹cⱼ - degree(i int) int // degree of the total claim in the i'th variable - verifyFinalEval(r []small_rational.SmallRational, foldingCoeff small_rational.SmallRational, purportedValue small_rational.SmallRational, proof []small_rational.SmallRational) error + varsNum() int // varsNum = n + degree(i int) int // degree of the total claim in the i'th variable + verifyFinalEval(r []small_rational.SmallRational, purportedValue small_rational.SmallRational, proof []small_rational.SmallRational) error } // sumcheckProof of a multi-statement. @@ -42,130 +71,46 @@ type sumcheckProof struct { finalEvalProof []small_rational.SmallRational //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof } -func setupTranscript(claimsNum int, varsNum int, settings *fiatshamir.Settings) (challengeNames []string, err error) { - numChallenges := varsNum - if claimsNum >= 2 { - numChallenges++ - } - challengeNames = make([]string, numChallenges) - if claimsNum >= 2 { - challengeNames[0] = settings.Prefix + "fold" - } - prefix := settings.Prefix + "pSP." - for i := 0; i < varsNum; i++ { - challengeNames[i+numChallenges-varsNum] = prefix + strconv.Itoa(i) - } - if settings.Transcript == nil { - transcript := fiatshamir.NewTranscript(settings.Hash, challengeNames...) - settings.Transcript = transcript - } - - for i := range settings.BaseChallenges { - if err = settings.Transcript.Bind(challengeNames[0], settings.BaseChallenges[i]); err != nil { - return - } - } - return -} - -func next(transcript *fiatshamir.Transcript, bindings []small_rational.SmallRational, remainingChallengeNames *[]string) (small_rational.SmallRational, error) { - challengeName := (*remainingChallengeNames)[0] - for i := range bindings { - bytes := bindings[i].Bytes() - if err := transcript.Bind(challengeName, bytes[:]); err != nil { - return small_rational.SmallRational{}, err - } - } - var res small_rational.SmallRational - bytes, err := transcript.ComputeChallenge(challengeName) - res.SetBytes(bytes) - - *remainingChallengeNames = (*remainingChallengeNames)[1:] - - return res, err -} - -// sumcheckProve create a non-interactive proof -func sumcheckProve(claims sumcheckClaims, transcriptSettings fiatshamir.Settings) (sumcheckProof, error) { - - var proof sumcheckProof - remainingChallengeNames, err := setupTranscript(claims.claimsNum(), claims.varsNum(), &transcriptSettings) - transcript := transcriptSettings.Transcript - if err != nil { - return proof, err - } - - var foldingCoeff small_rational.SmallRational - if claims.claimsNum() >= 2 { - if foldingCoeff, err = next(transcript, []small_rational.SmallRational{}, &remainingChallengeNames); err != nil { - return proof, err - } - } - +// sumcheckProve creates a non-interactive sumcheck proof. +// The fold challenge is derived by the caller (proveLevel). +// Pattern: roundPolynomial, [roundFold, roundPolynomial]*, proveFinalEval. +func sumcheckProve(claims sumcheckClaims, t *transcript) sumcheckProof { varsNum := claims.varsNum() - proof.partialSumPolys = make([]polynomial.Polynomial, varsNum) - proof.partialSumPolys[0] = claims.fold(foldingCoeff) + proof := sumcheckProof{partialSumPolys: make([]polynomial.Polynomial, varsNum)} + proof.partialSumPolys[0] = claims.roundPolynomial() challenges := make([]small_rational.SmallRational, varsNum) - for j := 0; j+1 < varsNum; j++ { - if challenges[j], err = next(transcript, proof.partialSumPolys[j], &remainingChallengeNames); err != nil { - return proof, err - } - proof.partialSumPolys[j+1] = claims.next(challenges[j]) - } - - if challenges[varsNum-1], err = next(transcript, proof.partialSumPolys[varsNum-1], &remainingChallengeNames); err != nil { - return proof, err + for j := range varsNum - 1 { + challenges[j] = t.getChallenge(proof.partialSumPolys[j]...) + claims.roundFold(challenges[j]) + proof.partialSumPolys[j+1] = claims.roundPolynomial() } + challenges[varsNum-1] = t.getChallenge(proof.partialSumPolys[varsNum-1]...) proof.finalEvalProof = claims.proveFinalEval(challenges) - - return proof, nil + return proof } -func sumcheckVerify(claims sumcheckLazyClaims, proof sumcheckProof, transcriptSettings fiatshamir.Settings) error { - remainingChallengeNames, err := setupTranscript(claims.claimsNum(), claims.varsNum(), &transcriptSettings) - transcript := transcriptSettings.Transcript - if err != nil { - return err - } - - var foldingCoeff small_rational.SmallRational - - if claims.claimsNum() >= 2 { - if foldingCoeff, err = next(transcript, []small_rational.SmallRational{}, &remainingChallengeNames); err != nil { - return err - } - } - +// sumcheckVerify verifies a non-interactive sumcheck proof. +// The fold challenge is derived by the caller (verifyLevel). +// claimedSum is the expected sum; degree is the polynomial's degree in each variable. +func sumcheckVerify(claims sumcheckLazyClaims, proof sumcheckProof, claimedSum small_rational.SmallRational, degree int, t *transcript) error { r := make([]small_rational.SmallRational, claims.varsNum()) - // Just so that there is enough room for gJ to be reused - maxDegree := claims.degree(0) - for j := 1; j < claims.varsNum(); j++ { - if d := claims.degree(j); d > maxDegree { - maxDegree = d - } - } - gJ := make(polynomial.Polynomial, maxDegree+1) //At the end of iteration j, gJ = ∑_{i < 2ⁿ⁻ʲ⁻¹} g(X₁, ..., Xⱼ₊₁, i...) NOTE: n is shorthand for claims.varsNum() - gJR := claims.foldedSum(foldingCoeff) // At the beginning of iteration j, gJR = ∑_{i < 2ⁿ⁻ʲ} g(r₁, ..., rⱼ, i...) + gJ := make(polynomial.Polynomial, degree+1) + gJR := claimedSum for j := range claims.varsNum() { - if len(proof.partialSumPolys[j]) != claims.degree(j) { + if len(proof.partialSumPolys[j]) != degree { return errors.New("malformed proof") } copy(gJ[1:], proof.partialSumPolys[j]) - gJ[0].Sub(&gJR, &proof.partialSumPolys[j][0]) // Requirement that gⱼ(0) + gⱼ(1) = gⱼ₋₁(r) - // gJ is ready + gJ[0].Sub(&gJR, &proof.partialSumPolys[j][0]) - //Prepare for the next iteration - if r[j], err = next(transcript, proof.partialSumPolys[j], &remainingChallengeNames); err != nil { - return err - } - // This is an extremely inefficient way of interpolating. TODO: Interpolate without symbolically computing a polynomial - gJCoeffs := polynomial.InterpolateOnRange(gJ[:(claims.degree(j) + 1)]) + r[j] = t.getChallenge(proof.partialSumPolys[j]...) + gJCoeffs := polynomial.InterpolateOnRange(gJ[:(degree + 1)]) gJR = gJCoeffs.Eval(&r[j]) } - return claims.verifyFinalEval(r, foldingCoeff, gJR, proof.finalEvalProof) + return claims.verifyFinalEval(r, gJR, proof.finalEvalProof) } diff --git a/internal/gkr/small_rational/sumcheck_test.go b/internal/gkr/small_rational/sumcheck_test.go index 9cb6eabd96..b364291f7b 100644 --- a/internal/gkr/small_rational/sumcheck_test.go +++ b/internal/gkr/small_rational/sumcheck_test.go @@ -9,7 +9,6 @@ import ( "fmt" "hash" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" "github.com/consensys/gnark/internal/small_rational/polynomial" "github.com/stretchr/testify/assert" @@ -24,11 +23,9 @@ func testSumcheckSingleClaimMultilin(polyInt []uint64, hashGenerator func() hash } claim := singleMultilinClaim{g: poly.Clone()} + t := transcript{h: hashGenerator()} - proof, err := sumcheckProve(&claim, fiatshamir.WithHash(hashGenerator())) - if err != nil { - return err - } + proof := sumcheckProve(&claim, &t) var sb strings.Builder for _, p := range proof.partialSumPolys { @@ -44,13 +41,15 @@ func testSumcheckSingleClaimMultilin(polyInt []uint64, hashGenerator func() hash } lazyClaim := singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} - if err = sumcheckVerify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())); err != nil { + t = transcript{h: hashGenerator()} + if err := sumcheckVerify(lazyClaim, proof, lazyClaim.claimedSum, 1, &t); err != nil { return err } proof.partialSumPolys[0][0].Add(&proof.partialSumPolys[0][0], toElement(1)) lazyClaim = singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} - if sumcheckVerify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())) == nil { + t = transcript{h: hashGenerator()} + if sumcheckVerify(lazyClaim, proof, lazyClaim.claimedSum, 1, &t) == nil { return fmt.Errorf("bad proof accepted") } return nil diff --git a/internal/gkr/small_rational/sumcheck_test_vector_gen.go b/internal/gkr/small_rational/sumcheck_test_vector_gen.go index ea969d171e..89ad304949 100644 --- a/internal/gkr/small_rational/sumcheck_test_vector_gen.go +++ b/internal/gkr/small_rational/sumcheck_test_vector_gen.go @@ -14,7 +14,6 @@ import ( "path/filepath" "runtime/pprof" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" "github.com/consensys/gnark/internal/gkr/gkrtesting" "github.com/consensys/gnark/internal/small_rational" "github.com/consensys/gnark/internal/small_rational/polynomial" @@ -38,11 +37,9 @@ func runMultilin(testCaseInfo *sumcheckTestCaseInfo) error { return err } - proof, err := sumcheckProve( - &singleMultilinClaim{poly}, fiatshamir.WithHash(hsh)) - if err != nil { - return err - } + claim := singleMultilinClaim{poly} + t := transcript{h: hsh} + proof := sumcheckProve(&claim, &t) testCaseInfo.Proof = sumcheckToPrintableProof(proof) // Verification @@ -56,12 +53,20 @@ func runMultilin(testCaseInfo *sumcheckTestCaseInfo) error { return err } - if err = sumcheckVerify(singleMultilinLazyClaim{g: poly, claimedSum: claimedSum}, proof, fiatshamir.WithHash(hsh)); err != nil { + if hsh, err = hashFromDescription(testCaseInfo.Hash); err != nil { + return err + } + t = transcript{h: hsh} + if err = sumcheckVerify(singleMultilinLazyClaim{g: poly, claimedSum: claimedSum}, proof, claimedSum, 1, &t); err != nil { return fmt.Errorf("proof rejected: %v", err) } proof.partialSumPolys[0][0].Add(&proof.partialSumPolys[0][0], toElement(1)) - if err = sumcheckVerify(singleMultilinLazyClaim{g: poly, claimedSum: claimedSum}, proof, fiatshamir.WithHash(hsh)); err == nil { + if hsh, err = hashFromDescription(testCaseInfo.Hash); err != nil { + return err + } + t = transcript{h: hsh} + if err = sumcheckVerify(singleMultilinLazyClaim{g: poly, claimedSum: claimedSum}, proof, claimedSum, 1, &t); err == nil { return fmt.Errorf("bad proof accepted") } @@ -150,18 +155,14 @@ type singleMultilinClaim struct { g polynomial.MultiLin } -func (c singleMultilinClaim) proveFinalEval(r []small_rational.SmallRational) []small_rational.SmallRational { +func (c *singleMultilinClaim) proveFinalEval(r []small_rational.SmallRational) []small_rational.SmallRational { return nil // verifier can compute the final eval itself } -func (c singleMultilinClaim) varsNum() int { +func (c *singleMultilinClaim) varsNum() int { return bits.TrailingZeros(uint(len(c.g))) } -func (c singleMultilinClaim) claimsNum() int { - return 1 -} - func sumForX1One(g polynomial.MultiLin) polynomial.Polynomial { sum := g[len(g)/2] for i := len(g)/2 + 1; i < len(g); i++ { @@ -170,13 +171,12 @@ func sumForX1One(g polynomial.MultiLin) polynomial.Polynomial { return []small_rational.SmallRational{sum} } -func (c singleMultilinClaim) fold(small_rational.SmallRational) polynomial.Polynomial { +func (c *singleMultilinClaim) roundPolynomial() polynomial.Polynomial { return sumForX1One(c.g) } -func (c *singleMultilinClaim) next(r small_rational.SmallRational) polynomial.Polynomial { +func (c *singleMultilinClaim) roundFold(r small_rational.SmallRational) { c.g.Fold(r) - return sumForX1One(c.g) } type singleMultilinLazyClaim struct { @@ -184,7 +184,7 @@ type singleMultilinLazyClaim struct { claimedSum small_rational.SmallRational } -func (c singleMultilinLazyClaim) verifyFinalEval(r []small_rational.SmallRational, _ small_rational.SmallRational, purportedValue small_rational.SmallRational, proof []small_rational.SmallRational) error { +func (c singleMultilinLazyClaim) verifyFinalEval(r []small_rational.SmallRational, purportedValue small_rational.SmallRational, proof []small_rational.SmallRational) error { val := c.g.Evaluate(r, nil) if val.Equal(&purportedValue) { return nil @@ -192,15 +192,7 @@ func (c singleMultilinLazyClaim) verifyFinalEval(r []small_rational.SmallRationa return fmt.Errorf("mismatch") } -func (c singleMultilinLazyClaim) foldedSum(_ small_rational.SmallRational) small_rational.SmallRational { - return c.claimedSum -} - -func (c singleMultilinLazyClaim) degree(i int) int { - return 1 -} - -func (c singleMultilinLazyClaim) claimsNum() int { +func (c singleMultilinLazyClaim) degree(int) int { return 1 } diff --git a/internal/gkr/small_rational/test_vector_gen.go b/internal/gkr/small_rational/test_vector_gen.go index 40f4beaf0e..327a18ee98 100644 --- a/internal/gkr/small_rational/test_vector_gen.go +++ b/internal/gkr/small_rational/test_vector_gen.go @@ -15,9 +15,9 @@ import ( "github.com/consensys/bavard" "github.com/consensys/gnark-crypto/ecc" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark/constraint" + "github.com/consensys/gnark/internal/gkr/gkrcore" "github.com/consensys/gnark/internal/gkr/gkrtesting" - "github.com/consensys/gnark/internal/gkr/gkrtypes" "github.com/consensys/gnark/internal/small_rational" "github.com/consensys/gnark/internal/small_rational/polynomial" ) @@ -64,10 +64,8 @@ func run(absPath string) error { return err } - transcriptSetting := fiatshamir.WithHash(testCase.Hash) - var proof Proof - proof, err = Prove(testCase.Circuit, testCase.FullAssignment, transcriptSetting) + proof, err = Prove(testCase.Circuit, testCase.Schedule, testCase.FullAssignment, testCase.Hash) if err != nil { return err } @@ -89,7 +87,7 @@ func run(absPath string) error { return err } - err = Verify(testCase.Circuit, testCase.InOutAssignment, proof, transcriptSetting) + err = Verify(testCase.Circuit, testCase.Schedule, testCase.InOutAssignment, proof, testCase.Hash) if err != nil { return err } @@ -99,7 +97,7 @@ func run(absPath string) error { return err } - err = Verify(testCase.Circuit, testCase.InOutAssignment, proof, fiatshamir.WithHash(newMessageCounter(2, 0))) + err = Verify(testCase.Circuit, testCase.Schedule, testCase.InOutAssignment, proof, newMessageCounter(2, 0)) if err == nil { return fmt.Errorf("bad proof accepted") } @@ -191,11 +189,12 @@ func unmarshalProof(printable gkrtesting.PrintableProof) (Proof, error) { } type TestCase struct { - Circuit gkrtypes.SerializableCircuit + Circuit gkrcore.SerializableCircuit Hash hash.Hash Proof Proof FullAssignment WireAssignment InOutAssignment WireAssignment + Schedule constraint.GkrProvingSchedule Info gkrtesting.TestCaseInfo // we are generating the test vectors, so we need to keep the circuit instance info to ADD the proof to it and resave it } @@ -227,6 +226,20 @@ func newTestCase(path string) (*TestCase, error) { if proof, err = unmarshalProof(info.Proof); err != nil { return nil, err } + var schedule constraint.GkrProvingSchedule + if schedule, err = info.Schedule.ToProvingSchedule(); err != nil { + return nil, err + } + if schedule == nil { + if schedule, err = gkrcore.DefaultProvingSchedule(circuit); err != nil { + return nil, err + } + } + + outputSet := make(map[int]bool, len(circuit)) + for _, o := range circuit.Outputs() { + outputSet[o] = true + } fullAssignment := make(WireAssignment, len(circuit)) inOutAssignment := make(WireAssignment, len(circuit)) @@ -240,7 +253,7 @@ func newTestCase(path string) (*TestCase, error) { } assignmentRaw = info.Input[inI] inI++ - } else if circuit[i].IsOutput() { + } else if outputSet[i] { if outI == len(info.Output) { return nil, fmt.Errorf("fewer output in vector than in circuit") } @@ -261,7 +274,7 @@ func newTestCase(path string) (*TestCase, error) { fullAssignment.Complete(circuit) for i := range circuit { - if circuit[i].IsOutput() { + if outputSet[i] { if err = sliceEquals(inOutAssignment[i], fullAssignment[i]); err != nil { return nil, fmt.Errorf("assignment mismatch: %v", err) } @@ -274,6 +287,7 @@ func newTestCase(path string) (*TestCase, error) { Proof: proof, Hash: _hash, Circuit: circuit, + Schedule: schedule, Info: info, } diff --git a/internal/gkr/sumcheck.go b/internal/gkr/sumcheck.go index e963509686..a3deee14f7 100644 --- a/internal/gkr/sumcheck.go +++ b/internal/gkr/sumcheck.go @@ -2,10 +2,9 @@ package gkr import ( "errors" - "strconv" "github.com/consensys/gnark/frontend" - fiatshamir "github.com/consensys/gnark/std/fiat-shamir" + "github.com/consensys/gnark/std/hash" "github.com/consensys/gnark/std/polynomial" ) @@ -13,11 +12,9 @@ import ( // sumcheckLazyClaims is the Claims data structure on the verifier side. It is "lazy" in that it has to compute fewer things. type sumcheckLazyClaims interface { - claimsNum() int // claimsNum = m - varsNum() int // varsNum = n - foldedSum(api frontend.API, a frontend.Variable) frontend.Variable // foldedSum returns c = ∑_{1≤j≤m} aʲ⁻¹cⱼ - degree(i int) int // degree of the total claim in the i'th variable - verifyFinalEval(api frontend.API, r []frontend.Variable, foldingCoeff, purportedValue frontend.Variable, proof []frontend.Variable) error + varsNum() int + degree(i int) int + verifyFinalEval(api frontend.API, r []frontend.Variable, purportedValue frontend.Variable, proof []frontend.Variable) error } // sumcheckProof of a multi-sumcheck statement. @@ -26,83 +23,49 @@ type sumcheckProof struct { FinalEvalProof []frontend.Variable } -func setupTranscript(api frontend.API, claimsNum int, varsNum int, settings *fiatshamir.Settings) ([]string, error) { - numChallenges := varsNum - if claimsNum >= 2 { - numChallenges++ - } - challengeNames := make([]string, numChallenges) - if claimsNum >= 2 { - challengeNames[0] = settings.Prefix + "fold" - } - prefix := settings.Prefix + "pSP." - for i := 0; i < varsNum; i++ { - challengeNames[i+numChallenges-varsNum] = prefix + strconv.Itoa(i) - } - if settings.Transcript == nil { - settings.Transcript = fiatshamir.NewTranscript(api, settings.Hash, challengeNames) - } - - return challengeNames, settings.Transcript.Bind(challengeNames[0], settings.BaseChallenges) +// transcript is a Fiat-Shamir transcript backed by a running hash. +// Field elements are written via Bind; challenges are derived via getChallenge. +// The hash is never reset — all previous data is implicitly part of future challenges. +type transcript struct { + h hash.FieldHasher + bound bool } -func next(transcript *fiatshamir.Transcript, bindings []frontend.Variable, remainingChallengeNames *[]string) (frontend.Variable, error) { - challengeName := (*remainingChallengeNames)[0] - if err := transcript.Bind(challengeName, bindings); err != nil { - return nil, err +func (t *transcript) Bind(elements ...frontend.Variable) { + if len(elements) == 0 { + return } - - res, err := transcript.ComputeChallenge(challengeName) - *remainingChallengeNames = (*remainingChallengeNames)[1:] - return res, err + t.h.Write(elements...) + t.bound = true } -func verifySumcheck(api frontend.API, claims sumcheckLazyClaims, proof sumcheckProof, transcriptSettings fiatshamir.Settings) error { - - remainingChallengeNames, err := setupTranscript(api, claims.claimsNum(), claims.varsNum(), &transcriptSettings) - transcript := transcriptSettings.Transcript - if err != nil { - return err - } - - var foldingCoeff frontend.Variable - - if claims.claimsNum() >= 2 { - if foldingCoeff, err = next(transcript, []frontend.Variable{}, &remainingChallengeNames); err != nil { - return err - } +func (t *transcript) getChallenge(bindings ...frontend.Variable) frontend.Variable { + t.Bind(bindings...) + if !t.bound { + t.h.Write(0) // separator to prevent repeated values } + t.bound = false + return t.h.Sum() +} +func verifySumcheck(api frontend.API, claims sumcheckLazyClaims, proof sumcheckProof, claimedSum frontend.Variable, degree int, t *transcript) error { r := make([]frontend.Variable, claims.varsNum()) - // Just so that there is enough room for gJ to be reused - maxDegree := claims.degree(0) - for j := 1; j < claims.varsNum(); j++ { - if d := claims.degree(j); d > maxDegree { - maxDegree = d - } - } - - gJ := make(polynomial.Polynomial, maxDegree+1) //At the end of iteration j, gJ = ∑_{i < 2ⁿ⁻ʲ⁻¹} g(X₁, ..., Xⱼ₊₁, i...) NOTE: n is shorthand for claims.varsNum() - gJR := claims.foldedSum(api, foldingCoeff) // At the beginning of iteration j, gJR = ∑_{i < 2ⁿ⁻ʲ} g(r₁, ..., rⱼ, i...) + gJ := make(polynomial.Polynomial, degree+1) + gJR := claimedSum - for j := 0; j < claims.varsNum(); j++ { - partialSumPoly := proof.PartialSumPolys[j] //proof.PartialSumPolys(j) - if len(partialSumPoly) != claims.degree(j) { - return errors.New("malformed proof") //Malformed proof + for j := range claims.varsNum() { + partialSumPoly := proof.PartialSumPolys[j] + if len(partialSumPoly) != degree { + return errors.New("malformed proof") } copy(gJ[1:], partialSumPoly) gJ[0] = api.Sub(gJR, partialSumPoly[0]) // Requirement that gⱼ(0) + gⱼ(1) = gⱼ₋₁(r) - // gJ is ready - //Prepare for the next iteration - if r[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { - return err - } + r[j] = t.getChallenge(proof.PartialSumPolys[j]...) - gJR = polynomial.InterpolateLDE(api, r[j], gJ[:(claims.degree(j)+1)]) + gJR = polynomial.InterpolateLDE(api, r[j], gJ[:(degree+1)]) } - return claims.verifyFinalEval(api, r, foldingCoeff, gJR, proof.FinalEvalProof) - + return claims.verifyFinalEval(api, r, gJR, proof.FinalEvalProof) } diff --git a/internal/gkr/test_vectors/generate.go b/internal/gkr/test_vectors/generate.go index 5aeabe5d53..b2f1d57ee9 100644 --- a/internal/gkr/test_vectors/generate.go +++ b/internal/gkr/test_vectors/generate.go @@ -16,13 +16,15 @@ import ( ) func main() { - var wg sync.WaitGroup - wg.Add(3) - for _, f := range []func() error{ + tasks := []func() error{ gkr.GenerateSumcheckVectors, gkr.GenerateVectors, generateSerializationTestData, - } { + } + + var wg sync.WaitGroup + wg.Add(len(tasks)) + for _, f := range tasks { go func() { assertNoError(f()) wg.Done() diff --git a/internal/gkr/test_vectors/single_identity_gate_two_instances.json b/internal/gkr/test_vectors/single_identity_gate_two_instances.json index ba28e35961..97d6fdd62c 100644 --- a/internal/gkr/test_vectors/single_identity_gate_two_instances.json +++ b/internal/gkr/test_vectors/single_identity_gate_two_instances.json @@ -18,19 +18,16 @@ ], "proof": [ { - "finalEvalProof": [], + "finalEvalProof": [ + 5 + ], "partialSumPolys": [] }, { "finalEvalProof": [ 5 ], - "partialSumPolys": [ - [ - -3, - -8 - ] - ] + "partialSumPolys": [] } ] } \ No newline at end of file diff --git a/internal/gkr/test_vectors/single_input_two_identity_gates_two_instances.json b/internal/gkr/test_vectors/single_input_two_identity_gates_two_instances.json index 1451b332c2..052e08d08f 100644 --- a/internal/gkr/test_vectors/single_input_two_identity_gates_two_instances.json +++ b/internal/gkr/test_vectors/single_input_two_identity_gates_two_instances.json @@ -21,36 +21,17 @@ ] ], "proof": [ - { - "finalEvalProof": [], - "partialSumPolys": [ - [ - 0, - 0 - ] - ] - }, { "finalEvalProof": [ 1 ], - "partialSumPolys": [ - [ - -3, - -16 - ] - ] + "partialSumPolys": [] }, { "finalEvalProof": [ 1 ], - "partialSumPolys": [ - [ - -3, - -16 - ] - ] + "partialSumPolys": [] } ] } \ No newline at end of file diff --git a/internal/gkr/test_vectors/single_input_two_outs_two_instances.json b/internal/gkr/test_vectors/single_input_two_outs_two_instances.json index 897aea7ee5..9f39dc5cdd 100644 --- a/internal/gkr/test_vectors/single_input_two_outs_two_instances.json +++ b/internal/gkr/test_vectors/single_input_two_outs_two_instances.json @@ -22,7 +22,9 @@ ], "proof": [ { - "finalEvalProof": [], + "finalEvalProof": [ + 0 + ], "partialSumPolys": [ [ 0, @@ -46,12 +48,7 @@ "finalEvalProof": [ 0 ], - "partialSumPolys": [ - [ - -2, - -12 - ] - ] + "partialSumPolys": [] } ] } \ No newline at end of file diff --git a/internal/gkr/test_vectors/single_mimc_gate_four_instances.json b/internal/gkr/test_vectors/single_mimc_gate_four_instances.json index a724ba5a7b..3907e92c92 100644 --- a/internal/gkr/test_vectors/single_mimc_gate_four_instances.json +++ b/internal/gkr/test_vectors/single_mimc_gate_four_instances.json @@ -28,11 +28,15 @@ ], "proof": [ { - "finalEvalProof": [], + "finalEvalProof": [ + -1 + ], "partialSumPolys": [] }, { - "finalEvalProof": [], + "finalEvalProof": [ + -3 + ], "partialSumPolys": [] }, { diff --git a/internal/gkr/test_vectors/single_mimc_gate_two_instances.json b/internal/gkr/test_vectors/single_mimc_gate_two_instances.json index 901db48692..80a585dced 100644 --- a/internal/gkr/test_vectors/single_mimc_gate_two_instances.json +++ b/internal/gkr/test_vectors/single_mimc_gate_two_instances.json @@ -22,11 +22,15 @@ ], "proof": [ { - "finalEvalProof": [], + "finalEvalProof": [ + 1 + ], "partialSumPolys": [] }, { - "finalEvalProof": [], + "finalEvalProof": [ + 0 + ], "partialSumPolys": [] }, { diff --git a/internal/gkr/test_vectors/single_mul_gate_two_instances.json b/internal/gkr/test_vectors/single_mul_gate_two_instances.json index b85a6df42c..390f3ef9ba 100644 --- a/internal/gkr/test_vectors/single_mul_gate_two_instances.json +++ b/internal/gkr/test_vectors/single_mul_gate_two_instances.json @@ -22,11 +22,15 @@ ], "proof": [ { - "finalEvalProof": [], + "finalEvalProof": [ + 5 + ], "partialSumPolys": [] }, { - "finalEvalProof": [], + "finalEvalProof": [ + 1 + ], "partialSumPolys": [] }, { diff --git a/internal/gkr/test_vectors/testdata/gkr_circuit_bn254.scs b/internal/gkr/test_vectors/testdata/gkr_circuit_bn254.scs index 39f6c3ecd3..f1cb52784c 100644 Binary files a/internal/gkr/test_vectors/testdata/gkr_circuit_bn254.scs and b/internal/gkr/test_vectors/testdata/gkr_circuit_bn254.scs differ diff --git a/internal/gkr/test_vectors/two_identity_gates_composed_single_input_two_instances.json b/internal/gkr/test_vectors/two_identity_gates_composed_single_input_two_instances.json index 69a2038a75..d4c277e246 100644 --- a/internal/gkr/test_vectors/two_identity_gates_composed_single_input_two_instances.json +++ b/internal/gkr/test_vectors/two_identity_gates_composed_single_input_two_instances.json @@ -18,30 +18,22 @@ ], "proof": [ { - "finalEvalProof": [], + "finalEvalProof": [ + 3 + ], "partialSumPolys": [] }, { "finalEvalProof": [ 3 ], - "partialSumPolys": [ - [ - -1, - 0 - ] - ] + "partialSumPolys": [] }, { "finalEvalProof": [ 3 ], - "partialSumPolys": [ - [ - -1, - 0 - ] - ] + "partialSumPolys": [] } ] } \ No newline at end of file diff --git a/internal/gkr/test_vectors/two_inputs_select-input-3_gate_two_instances.json b/internal/gkr/test_vectors/two_inputs_select-input-3_gate_two_instances.json index 2dca0746a2..731b3c6f35 100644 --- a/internal/gkr/test_vectors/two_inputs_select-input-3_gate_two_instances.json +++ b/internal/gkr/test_vectors/two_inputs_select-input-3_gate_two_instances.json @@ -22,11 +22,15 @@ ], "proof": [ { - "finalEvalProof": [], + "finalEvalProof": [ + -1 + ], "partialSumPolys": [] }, { - "finalEvalProof": [], + "finalEvalProof": [ + 1 + ], "partialSumPolys": [] }, { @@ -34,12 +38,7 @@ -1, 1 ], - "partialSumPolys": [ - [ - -3, - -16 - ] - ] + "partialSumPolys": [] } ] } \ No newline at end of file diff --git a/internal/utils/slices.go b/internal/utils/slices.go index bdd86119fa..f6fcd943d6 100644 --- a/internal/utils/slices.go +++ b/internal/utils/slices.go @@ -8,15 +8,6 @@ func AppendRefs[T any](s []any, v []T) []any { return s } -// References returns a slice of references to the elements of v. -func References[T any](v []T) []*T { - res := make([]*T, len(v)) - for i := range v { - res[i] = &v[i] - } - return res -} - // ExtendRepeatLast extends a non-empty slice s by repeating the last element until it reaches the length n. func ExtendRepeatLast[T any](s []T, n int) []T { if n <= len(s) { diff --git a/std/gkrapi/api.go b/std/gkrapi/api.go index 68c481e343..c202be7862 100644 --- a/std/gkrapi/api.go +++ b/std/gkrapi/api.go @@ -4,14 +4,14 @@ import ( "github.com/consensys/gnark/constraint/solver/gkrgates" // nolint SA1019 "github.com/consensys/gnark/frontend" gadget "github.com/consensys/gnark/internal/gkr" - "github.com/consensys/gnark/internal/gkr/gkrtypes" + "github.com/consensys/gnark/internal/gkr/gkrcore" "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark/std/gkrapi/gkr" ) type ( API struct { - circuit gkrtypes.GadgetCircuit + circuit gkrcore.RawCircuit assignments gadget.WireAssignment parentApi frontend.API } @@ -23,8 +23,8 @@ func frontendVarToInt(a gkr.Variable) int { // Gate adds the given gate with the given inputs and returns its output wire. func (api *API) Gate(gate gkr.GateFunction, inputs ...gkr.Variable) gkr.Variable { - api.circuit = append(api.circuit, gkrtypes.GadgetWire{ - Gate: gkrtypes.GadgetGate{Evaluate: gate}, + api.circuit = append(api.circuit, gkrcore.RawWire{ + Gate: gate, Inputs: utils.Map(inputs, frontendVarToInt), }) api.assignments = append(api.assignments, nil) @@ -49,19 +49,19 @@ func (api *API) gate2PlusIn(gate gkr.GateFunction, in1, in2 gkr.Variable, in ... } func (api *API) Add(i1, i2 gkr.Variable) gkr.Variable { - return api.gate2PlusIn(gkrtypes.Add2, i1, i2) + return api.gate2PlusIn(gkrcore.Add2, i1, i2) } func (api *API) Neg(i1 gkr.Variable) gkr.Variable { - return api.Gate(gkrtypes.Neg, i1) + return api.Gate(gkrcore.Neg, i1) } func (api *API) Sub(i1, i2 gkr.Variable) gkr.Variable { - return api.gate2PlusIn(gkrtypes.Sub2, i1, i2) + return api.gate2PlusIn(gkrcore.Sub2, i1, i2) } func (api *API) Mul(i1, i2 gkr.Variable) gkr.Variable { - return api.gate2PlusIn(gkrtypes.Mul2, i1, i2) + return api.gate2PlusIn(gkrcore.Mul2, i1, i2) } // Export explicitly designates a wire as output. diff --git a/std/gkrapi/api_test.go b/std/gkrapi/api_test.go index 5f839f6f6a..0b83fd81e0 100644 --- a/std/gkrapi/api_test.go +++ b/std/gkrapi/api_test.go @@ -778,3 +778,97 @@ func TestMulti(t *testing.T) { test.NewAssert(t).CheckCircuit(new(testMultiCircuit), test.WithValidAssignment(&assignment)) } + +// Poseidon2 gate functions — match gkrtesting.Poseidon2Circuit(4, 2) exactly. + +func p2ExtLinear0(api gkr.GateAPI, x ...frontend.Variable) frontend.Variable { + return api.Add(x[0], x[0], x[1]) // 2*x[0] + x[1] +} + +func p2ExtLinear1(api gkr.GateAPI, x ...frontend.Variable) frontend.Variable { + return api.Add(x[0], x[1], x[1]) // x[0] + 2*x[1] +} + +func p2IntLinear1(api gkr.GateAPI, x ...frontend.Variable) frontend.Variable { + return api.Add(x[0], x[1], x[1], x[1]) // x[0] + 3*x[1] +} + +func p2SBox(api gkr.GateAPI, x ...frontend.Variable) frontend.Variable { + return api.Mul(x[0], x[0]) // x^2 +} + +func p2FeedForward(api gkr.GateAPI, x ...frontend.Variable) frontend.Variable { + return api.Add(x[0], x[0], x[1], x[2]) // 2*x[0] + x[1] + x[2] +} + +type poseidon2GadgetCircuit struct { + In0 []frontend.Variable + In1 []frontend.Variable + hashName string +} + +func (c *poseidon2GadgetCircuit) Define(api frontend.API) error { + gkrApi, err := New(api) + if err != nil { + return err + } + in0 := gkrApi.NewInput() + in1 := gkrApi.NewInput() + s0, s1 := in0, in1 + + appendFullRound := func() { + lin0 := gkrApi.Gate(p2ExtLinear0, s0, s1) + lin1 := gkrApi.Gate(p2ExtLinear1, s0, s1) + s0 = gkrApi.Gate(p2SBox, lin0) + s1 = gkrApi.Gate(p2SBox, lin1) + } + appendPartialRound := func() { + lin0 := gkrApi.Gate(p2ExtLinear0, s0, s1) + lin1 := gkrApi.Gate(p2IntLinear1, s0, s1) + s0 = gkrApi.Gate(p2SBox, lin0) + s1 = lin1 + } + + // 4 full rounds, 2 partial rounds: 2 + 2 structure + appendFullRound() + appendFullRound() + appendPartialRound() + appendPartialRound() + appendFullRound() + appendFullRound() + + _ = gkrApi.Gate(p2FeedForward, s0, s1, in1) + + gkrCircuit, err := gkrApi.Compile(c.hashName) + if err != nil { + return err + } + + instanceIn := make(map[gkr.Variable]frontend.Variable) + for i := range c.In0 { + instanceIn[in0] = c.In0[i] + instanceIn[in1] = c.In1[i] + if _, err := gkrCircuit.AddInstance(instanceIn); err != nil { + return fmt.Errorf("failed to add instance: %w", err) + } + } + return nil +} + +func TestPoseidon2Gadget(t *testing.T) { + assert := test.NewAssert(t) + nbInstances := 2 + In0 := make([]frontend.Variable, nbInstances) + In1 := make([]frontend.Variable, nbInstances) + for i := range In0 { + In0[i] = i + 1 + In1[i] = 2 * (i + 1) + } + assignment := &poseidon2GadgetCircuit{In0: In0, In1: In1} + circuit := &poseidon2GadgetCircuit{ + In0: make([]frontend.Variable, nbInstances), + In1: make([]frontend.Variable, nbInstances), + hashName: "-1", + } + assert.CheckCircuit(circuit, test.WithValidAssignment(assignment), test.WithCurves(ecc.BN254)) +} diff --git a/std/gkrapi/compile.go b/std/gkrapi/compile.go index 3132493baa..9a43e35a61 100644 --- a/std/gkrapi/compile.go +++ b/std/gkrapi/compile.go @@ -1,21 +1,23 @@ package gkrapi import ( + "crypto/sha256" "errors" "fmt" "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/constraint" "github.com/consensys/gnark/frontend" gadget "github.com/consensys/gnark/internal/gkr" gkrbls12377 "github.com/consensys/gnark/internal/gkr/bls12-377" gkrbls12381 "github.com/consensys/gnark/internal/gkr/bls12-381" gkrbn254 "github.com/consensys/gnark/internal/gkr/bn254" gkrbw6761 "github.com/consensys/gnark/internal/gkr/bw6-761" - "github.com/consensys/gnark/internal/gkr/gkrtypes" + "github.com/consensys/gnark/internal/gkr/gkrcore" "github.com/consensys/gnark/internal/utils" - fiatshamir "github.com/consensys/gnark/std/fiat-shamir" "github.com/consensys/gnark/std/gkrapi/gkr" "github.com/consensys/gnark/std/hash" + _ "github.com/consensys/gnark/std/hash/all" "github.com/consensys/gnark/std/multicommit" ) @@ -25,16 +27,20 @@ type InitialChallengeGetter func() []frontend.Variable // Circuit represents a GKR circuit. type Circuit struct { - circuit gkrtypes.GadgetCircuit - gates []gkrtypes.GateBytecode - assignments gadget.WireAssignment - getInitialChallenges InitialChallengeGetter // optional getter for the initial Fiat-Shamir challenge - ins []gkr.Variable - outs []gkr.Variable - api frontend.API // the parent API + circuit gkrcore.GadgetCircuit + schedule constraint.GkrProvingSchedule + gates []gkrcore.GateBytecode + assignments gadget.WireAssignment + ins []gkr.Variable + outs []gkr.Variable + api frontend.API // the parent API + + // Fiat-Shamir bootstrapping + getInitialChallenges InitialChallengeGetter // optional getter for the I/O related portion of the initial Fiat-Shamir challenge + statementHash []byte // hash of the circuit and schedule // Blueprint-based fields - blueprints gkrtypes.Blueprints + blueprints gkrcore.Blueprints // Metadata hashName string @@ -51,15 +57,15 @@ func New(api frontend.API) (*API, error) { // NewInput creates a new input variable. func (api *API) NewInput() gkr.Variable { i := len(api.circuit) - api.circuit = append(api.circuit, gkrtypes.GadgetWire{}) + api.circuit = append(api.circuit, gkrcore.RawWire{}) api.assignments = append(api.assignments, nil) return gkr.Variable(i) } type CompileOption func(*Circuit) -// WithInitialChallenge provides a getter for the initial Fiat-Shamir challenge. -// If not provided, the initial challenge will be a commitment to all the input and output values of the circuit. +// WithInitialChallenge provides a getter for the I/O portion of the initial Fiat-Shamir challenge. +// If not provided, the I/O initial challenge will be a commitment to all the input and output values of the circuit. func WithInitialChallenge(getInitialChallenge InitialChallengeGetter) CompileOption { return func(c *Circuit) { c.getInitialChallenges = getInitialChallenge @@ -70,31 +76,47 @@ func WithInitialChallenge(getInitialChallenge InitialChallengeGetter) CompileOpt // From this point on, the circuit cannot be modified, // but instances can be added to it. func (api *API) Compile(fiatshamirHashName string, options ...CompileOption) (*Circuit, error) { - res := Circuit{ - circuit: api.circuit, - assignments: make(gadget.WireAssignment, len(api.circuit)), - api: api.parentApi, - hashName: fiatshamirHashName, - } - - // Dispatch to curve-specific factory + // Dispatch to a curve-specific factory compiler := api.parentApi.Compiler() field := compiler.Field() curveID := utils.FieldToCurve(field) - serializableCircuit, err := gkrtypes.CompileCircuit(api.circuit, field) + + gadgetCircuit, serializableCircuit, err := api.circuit.Compile(field) if err != nil { return nil, err } + schedule, err := gkrcore.DefaultProvingSchedule(serializableCircuit) + if err != nil { + return nil, fmt.Errorf("failed to compute proving schedule: %w", err) + } + + hsh := sha256.New() + if err = gkrcore.SerializeCircuit(hsh, serializableCircuit); err != nil { + return nil, fmt.Errorf("failed to serialize circuit: %w", err) + } + if err = gkrcore.SerializeSchedule(hsh, schedule); err != nil { + return nil, fmt.Errorf("failed to serialize schedule: %w", err) + } + + res := Circuit{ + circuit: gadgetCircuit, + schedule: schedule, + assignments: make(gadget.WireAssignment, len(api.circuit)), + api: api.parentApi, + hashName: fiatshamirHashName, + statementHash: hsh.Sum(nil), + } + switch curveID { case ecc.BN254: - res.blueprints = gkrbn254.NewBlueprints(serializableCircuit, fiatshamirHashName, compiler) + res.blueprints = gkrbn254.NewBlueprints(serializableCircuit, schedule, fiatshamirHashName, compiler) case ecc.BLS12_377: - res.blueprints = gkrbls12377.NewBlueprints(serializableCircuit, fiatshamirHashName, compiler) + res.blueprints = gkrbls12377.NewBlueprints(serializableCircuit, schedule, fiatshamirHashName, compiler) case ecc.BLS12_381: - res.blueprints = gkrbls12381.NewBlueprints(serializableCircuit, fiatshamirHashName, compiler) + res.blueprints = gkrbls12381.NewBlueprints(serializableCircuit, schedule, fiatshamirHashName, compiler) case ecc.BW6_761: - res.blueprints = gkrbw6761.NewBlueprints(serializableCircuit, fiatshamirHashName, compiler) + res.blueprints = gkrbw6761.NewBlueprints(serializableCircuit, schedule, fiatshamirHashName, compiler) default: return nil, fmt.Errorf("unsupported curve: %s", curveID) } @@ -203,6 +225,10 @@ func (c *Circuit) finalize(api frontend.API) error { // if the circuit consists of only one instance, directly solve the circuit if len(c.assignments[c.ins[0]]) == 1 { + outputSet := make(map[int]bool, len(c.outs)) + for _, wI := range c.outs { + outputSet[int(wI)] = true + } gateIn := make([]frontend.Variable, c.circuit.MaxGateNbIn()) for wI, w := range c.circuit { if w.IsInput() { @@ -212,7 +238,7 @@ func (c *Circuit) finalize(api frontend.API) error { gateIn[inI] = c.assignments[inWI][0] // take the first (only) instance } res := w.Gate.Evaluate(gadget.FrontendAPIWrapper{API: api}, gateIn[:len(w.Inputs)]...) - if w.IsOutput() { + if outputSet[wI] { api.AssertIsEqual(res, c.assignments[wI][0]) } else { c.assignments[wI] = append(c.assignments[wI], res) @@ -222,26 +248,27 @@ func (c *Circuit) finalize(api frontend.API) error { } if c.getInitialChallenges != nil { - return c.verify(api, c.circuit, c.getInitialChallenges()) + return c.verify(api, c.circuit, append([]frontend.Variable{c.statementHash}, c.getInitialChallenges()...)) } - // default initial challenge is a commitment to all input and output values - insOuts := make([]frontend.Variable, 0, (len(c.ins)+len(c.outs))*len(c.assignments[c.ins[0]])) + // The default initial challenge is a commitment to the circuit, solving schedule, and all input and output values. + challenges := make([]frontend.Variable, 1, (len(c.ins)+len(c.outs))*len(c.assignments[c.ins[0]])+1) + challenges[0] = c.statementHash for _, in := range c.ins { - insOuts = append(insOuts, c.assignments[in]...) + challenges = append(challenges, c.assignments[in]...) } for _, out := range c.outs { - insOuts = append(insOuts, c.assignments[out]...) + challenges = append(challenges, c.assignments[out]...) } multicommit.WithCommitment(api, func(api frontend.API, commitment frontend.Variable) error { return c.verify(api, c.circuit, []frontend.Variable{commitment}) - }, insOuts...) + }, challenges...) return nil } -func (c *Circuit) verify(api frontend.API, circuit gkrtypes.GadgetCircuit, initialChallenges []frontend.Variable) error { +func (c *Circuit) verify(api frontend.API, circuit gkrcore.GadgetCircuit, initialChallenges []frontend.Variable) error { compiler := api.Compiler() @@ -272,7 +299,7 @@ func (c *Circuit) verify(api frontend.API, circuit gkrtypes.GadgetCircuit, initi err error ) - if proof, err = gadget.DeserializeProof(circuit, proofSerialized); err != nil { + if proof, err = gadget.DeserializeProof(circuit, c.schedule, proofSerialized); err != nil { return err } @@ -281,23 +308,24 @@ func (c *Circuit) verify(api frontend.API, circuit gkrtypes.GadgetCircuit, initi return err } - return gadget.Verify(api, circuit, c.assignments, proof, fiatshamir.WithHash(hsh, initialChallenges...)) + hsh.Write(initialChallenges...) + return gadget.Verify(api, circuit, c.schedule, c.assignments, proof, hsh) } // GetValue is a debugging utility returning the value of variable v at instance i. // While v can be an input or output variable, GetValue is most useful for querying intermediate values in the circuit. func (c *Circuit) GetValue(v gkr.Variable, i int) frontend.Variable { - // Create an instruction that will retrieve the assignment at solve time + // Create an instruction that will retrieve the assignment at solving time compiler := c.api.Compiler() // Build calldata: [0]=totalSize, [1]=wireI, [2]=instanceI, [3...]=dependency_wire_as_linear_expression - // The dependency ensures this instruction runs after the solve instruction for instance i + // The dependency ensures this instruction runs after the solving instruction for instance i calldata := make([]uint32, 3, 6) // pre-allocate: size + wireI + instanceI + dependency linear expression (typically 3) calldata[1] = uint32(v) calldata[2] = uint32(i) // Use the first output variable from instance i as a dependency - // This ensures the solve instruction for this instance has completed + // This ensures the solving instruction for this instance has completed if len(c.outs) == 0 || i >= len(c.assignments[c.outs[0]]) { panic("GetValue called with invalid instance or before instance was added") } diff --git a/std/hash/poseidon2/gkr-poseidon2/gkr-poseidon2.go b/std/hash/poseidon2/gkr-poseidon2/gkr-poseidon2.go index bbbef1f87c..f21222d92b 100644 --- a/std/hash/poseidon2/gkr-poseidon2/gkr-poseidon2.go +++ b/std/hash/poseidon2/gkr-poseidon2/gkr-poseidon2.go @@ -5,7 +5,6 @@ import ( "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/std/hash" - _ "github.com/consensys/gnark/std/hash/all" gkr_poseidon2 "github.com/consensys/gnark/std/permutation/poseidon2/gkr-poseidon2" ) diff --git a/std/permutation/gkr-mimc/gkr-mimc.go b/std/permutation/gkr-mimc/gkr-mimc.go index 14270b8135..679d618ad7 100644 --- a/std/permutation/gkr-mimc/gkr-mimc.go +++ b/std/permutation/gkr-mimc/gkr-mimc.go @@ -15,7 +15,6 @@ import ( "github.com/consensys/gnark/std/gkrapi" "github.com/consensys/gnark/std/gkrapi/gkr" "github.com/consensys/gnark/std/hash" - _ "github.com/consensys/gnark/std/hash/all" ) // compressor implements a compression function by applying