Skip to content

Commit d94dd4d

Browse files
committed
build dag from monomial terms
1 parent 24b4570 commit d94dd4d

File tree

1 file changed

+192
-3
lines changed
  • crates/multilinear_extensions/src/expression

1 file changed

+192
-3
lines changed

crates/multilinear_extensions/src/expression/utils.rs

Lines changed: 192 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use ff_ext::ExtensionField;
55
use itertools::Itertools;
66
use p3::field::Field;
77
use serde::{Deserialize, Serialize};
8-
use std::collections::HashMap;
8+
use std::collections::{BTreeMap, HashMap};
99

1010
impl WitIn {
1111
pub fn assign<E: ExtensionField>(&self, instance: &mut [E::BaseField], value: E::BaseField) {
@@ -585,7 +585,10 @@ fn expr_compression_to_dag_helper<E: ExtensionField>(
585585
}
586586
}
587587
}
588-
c @ Expression::Challenge(..) => {
588+
c @ Expression::Challenge(challenge_id, _power, scalar, offset) => {
589+
if *scalar == E::ZERO && *offset == E::ZERO {
590+
return None
591+
}
589592
let challenge_id = *challenges_dedup.entry(c.clone()).or_insert_with(|| {
590593
challenges.push(c.clone());
591594
(challenges_offset + challenges.len() - 1) as u32
@@ -603,16 +606,168 @@ fn expr_compression_to_dag_helper<E: ExtensionField>(
603606
}
604607
}
605608

609+
// trie
610+
#[derive(Default)]
611+
struct TrieNode {
612+
children: BTreeMap<u16, TrieNode>, // Sorted keys: commutative grouping
613+
scalar_indices: Vec<usize>,
614+
}
615+
pub fn build_factored_dag_commutative<E: ExtensionField>(
616+
terms: &[Term<Expression<E>, Expression<E>>],
617+
) -> (Vec<Node>, Vec<Expression<E>>, Option<u32>, u32) {
618+
let mut root = TrieNode::default();
619+
let mut scalars: Vec<Expression<E>> = Vec::new();
620+
621+
// ---- Step 1: canonicalize products (commutative) ----
622+
for term in terms {
623+
let mut ids: Vec<u16> = term
624+
.product
625+
.iter()
626+
.filter_map(|e| match e {
627+
Expression::WitIn(id) => Some(*id),
628+
e => unimplemented!("unknown expression {e}"),
629+
})
630+
.collect();
631+
ids.sort(); // ensure a*b == b*a
632+
// we assume witiness being shared will be made with larger id
633+
// so we build the prefix tree with larger id go first
634+
ids.reverse();
635+
636+
let mut cur = &mut root;
637+
for wid in ids {
638+
cur = cur.children.entry(wid).or_default();
639+
}
640+
641+
let idx = scalars.len();
642+
scalars.push(term.scalar.clone());
643+
cur.scalar_indices.push(idx);
644+
}
645+
646+
// ---- Step 2: emit DAG (stack semantics) ----
647+
let mut dag = Vec::new();
648+
let mut stack_top: u32 = 0;
649+
let mut max_stack_depth: u32 = 0;
650+
651+
fn push(stack_top: &mut u32, max_depth: &mut u32) -> u32 {
652+
let out = *stack_top;
653+
*stack_top += 1;
654+
*max_depth = (*max_depth).max(*stack_top);
655+
out
656+
}
657+
658+
fn pop2_push1(stack_top: &mut u32) -> (u32, u32, u32) {
659+
let left = *stack_top - 2;
660+
let right = *stack_top - 1;
661+
let out = left;
662+
*stack_top -= 1;
663+
(left, right, out)
664+
}
665+
666+
fn emit<E: ExtensionField>(
667+
node: &TrieNode,
668+
dag: &mut Vec<Node>,
669+
stack_top: &mut u32,
670+
max_depth: &mut u32,
671+
) -> Option<u32> {
672+
let mut acc_child: Option<u32> = None;
673+
674+
// Recurse into children (witness factors)
675+
for (&wid, child) in &node.children {
676+
let child_out = emit::<E>(child, dag, stack_top, max_depth);
677+
678+
// LOAD_WIT: push
679+
let out = push(stack_top, max_depth);
680+
dag.push(Node {
681+
op: DagLoadWit as u32,
682+
left_id: wid as u32,
683+
right_id: 0,
684+
out,
685+
});
686+
687+
// If child exists, multiply with it
688+
if let Some(rhs) = child_out {
689+
let (left, right, out) = pop2_push1(stack_top);
690+
dag.push(Node {
691+
op: DagMul as u32,
692+
left_id: left,
693+
right_id: right,
694+
out,
695+
});
696+
acc_child = Some(match acc_child {
697+
None => out,
698+
Some(_) => {
699+
let (l, r, out) = pop2_push1(stack_top);
700+
dag.push(Node {
701+
op: DagAdd as u32,
702+
left_id: l,
703+
right_id: r,
704+
out,
705+
});
706+
out
707+
}
708+
});
709+
} else {
710+
acc_child = Some(out);
711+
}
712+
}
713+
714+
// Handle scalar accumulation at leaf
715+
let mut acc_scalar: Option<u32> = None;
716+
for &idx in &node.scalar_indices {
717+
let out = push(stack_top, max_depth);
718+
dag.push(Node {
719+
op: DagLoadScalar as u32,
720+
left_id: idx as u32,
721+
right_id: 0,
722+
out,
723+
});
606724

725+
acc_scalar = Some(match acc_scalar {
726+
None => out,
727+
Some(_) => {
728+
let (l, r, out) = pop2_push1(stack_top);
729+
dag.push(Node {
730+
op: DagAdd as u32,
731+
left_id: l,
732+
right_id: r,
733+
out,
734+
});
735+
out
736+
}
737+
});
738+
}
739+
740+
// Merge both child and scalar accumulations
741+
match (acc_scalar, acc_child) {
742+
(Some(_), Some(_)) => {
743+
let (l, r, out) = pop2_push1(stack_top);
744+
dag.push(Node {
745+
op: DagAdd as u32,
746+
left_id: l,
747+
right_id: r,
748+
out,
749+
});
750+
Some(out)
751+
}
752+
(Some(s), None) => Some(s),
753+
(None, Some(c)) => Some(c),
754+
(None, None) => None,
755+
}
756+
}
757+
758+
let final_out = emit::<E>(&root, &mut dag, &mut stack_top, &mut max_stack_depth);
759+
(dag, scalars, final_out, max_stack_depth)
760+
}
607761
#[cfg(test)]
608762
mod tests {
763+
use std::ops::Neg;
609764
use either::Either;
610765
use itertools::Itertools;
611766
use ff_ext::{BabyBearExt4, ExtensionField};
612767
use p3::babybear::BabyBear;
613768
use p3::field::FieldAlgebra;
614769
use crate::{power_sequence, Expression, Instance, ToExpr};
615-
use crate::utils::{expr_compression_to_dag, Node};
770+
use crate::utils::{build_factored_dag_commutative, expr_compression_to_dag, Node};
616771

617772
type E = BabyBearExt4;
618773
type B = BabyBear;
@@ -710,4 +865,38 @@ mod tests {
710865
assert_eq!(max_degree, 1);
711866

712867
}
868+
869+
#[test]
870+
fn test_build_factored_dag_commutative() {
871+
// w1 * (c2 * (2 + w0*c1 -1))
872+
let w0 = Expression::<E>::WitIn(0);
873+
let w1 = Expression::<E>::WitIn(1);
874+
let c1 = Expression::<E>::Challenge(0, 1, E::ONE, E::ZERO);
875+
let c2 = Expression::<E>::Challenge(2, 1, E::ONE, E::ZERO);
876+
let constant_2 = Expression::<E>::Constant(Either::Left(B::from_canonical_u32(2)));
877+
let constant_negative_1 = Expression::<E>::Constant(Either::Left(B::from_canonical_u32(1).neg()));
878+
879+
let e: Expression<E> = w1.expr() * (c2.expr() * (constant_2.expr() + w0.expr() * c1.expr() - constant_negative_1.expr()));
880+
let e_monomials = e.get_monomial_terms();
881+
let (dag, coeffs, final_out, _)= build_factored_dag_commutative(&e_monomials);
882+
883+
let mut num_add = 0;
884+
let mut num_mul = 0;
885+
886+
for node in &dag {
887+
match node.op {
888+
0 => (), // skip wit index
889+
1 => (), // skip scalar index
890+
2 => {
891+
num_add += 1;
892+
}
893+
3 => {
894+
num_mul += 1;
895+
}
896+
op => panic!("unknown op {op}"),
897+
}
898+
}
899+
assert_eq!(num_add, 1);
900+
assert_eq!(num_mul, 3);
901+
}
713902
}

0 commit comments

Comments
 (0)