@@ -5,7 +5,7 @@ use ff_ext::ExtensionField;
55use itertools:: Itertools ;
66use p3:: field:: Field ;
77use serde:: { Deserialize , Serialize } ;
8- use std:: collections:: HashMap ;
8+ use std:: collections:: { BTreeMap , HashMap } ;
99
1010impl WitIn {
1111 pub fn assign < E : ExtensionField > ( & self , instance : & mut [ E :: BaseField ] , value : E :: BaseField ) {
@@ -585,7 +585,10 @@ fn expr_compression_to_dag_helper<E: ExtensionField>(
585585 }
586586 }
587587 }
588- c @ Expression :: Challenge ( ..) => {
588+ c @ Expression :: Challenge ( challenge_id, _power, scalar, offset) => {
589+ if * scalar == E :: ZERO && * offset == E :: ZERO {
590+ return None
591+ }
589592 let challenge_id = * challenges_dedup. entry ( c. clone ( ) ) . or_insert_with ( || {
590593 challenges. push ( c. clone ( ) ) ;
591594 ( challenges_offset + challenges. len ( ) - 1 ) as u32
@@ -603,16 +606,168 @@ fn expr_compression_to_dag_helper<E: ExtensionField>(
603606 }
604607}
605608
609+ // trie
610+ #[ derive( Default ) ]
611+ struct TrieNode {
612+ children : BTreeMap < u16 , TrieNode > , // Sorted keys: commutative grouping
613+ scalar_indices : Vec < usize > ,
614+ }
615+ pub fn build_factored_dag_commutative < E : ExtensionField > (
616+ terms : & [ Term < Expression < E > , Expression < E > > ] ,
617+ ) -> ( Vec < Node > , Vec < Expression < E > > , Option < u32 > , u32 ) {
618+ let mut root = TrieNode :: default ( ) ;
619+ let mut scalars: Vec < Expression < E > > = Vec :: new ( ) ;
620+
621+ // ---- Step 1: canonicalize products (commutative) ----
622+ for term in terms {
623+ let mut ids: Vec < u16 > = term
624+ . product
625+ . iter ( )
626+ . filter_map ( |e| match e {
627+ Expression :: WitIn ( id) => Some ( * id) ,
628+ e => unimplemented ! ( "unknown expression {e}" ) ,
629+ } )
630+ . collect ( ) ;
631+ ids. sort ( ) ; // ensure a*b == b*a
632+ // we assume witiness being shared will be made with larger id
633+ // so we build the prefix tree with larger id go first
634+ ids. reverse ( ) ;
635+
636+ let mut cur = & mut root;
637+ for wid in ids {
638+ cur = cur. children . entry ( wid) . or_default ( ) ;
639+ }
640+
641+ let idx = scalars. len ( ) ;
642+ scalars. push ( term. scalar . clone ( ) ) ;
643+ cur. scalar_indices . push ( idx) ;
644+ }
645+
646+ // ---- Step 2: emit DAG (stack semantics) ----
647+ let mut dag = Vec :: new ( ) ;
648+ let mut stack_top: u32 = 0 ;
649+ let mut max_stack_depth: u32 = 0 ;
650+
651+ fn push ( stack_top : & mut u32 , max_depth : & mut u32 ) -> u32 {
652+ let out = * stack_top;
653+ * stack_top += 1 ;
654+ * max_depth = ( * max_depth) . max ( * stack_top) ;
655+ out
656+ }
657+
658+ fn pop2_push1 ( stack_top : & mut u32 ) -> ( u32 , u32 , u32 ) {
659+ let left = * stack_top - 2 ;
660+ let right = * stack_top - 1 ;
661+ let out = left;
662+ * stack_top -= 1 ;
663+ ( left, right, out)
664+ }
665+
666+ fn emit < E : ExtensionField > (
667+ node : & TrieNode ,
668+ dag : & mut Vec < Node > ,
669+ stack_top : & mut u32 ,
670+ max_depth : & mut u32 ,
671+ ) -> Option < u32 > {
672+ let mut acc_child: Option < u32 > = None ;
673+
674+ // Recurse into children (witness factors)
675+ for ( & wid, child) in & node. children {
676+ let child_out = emit :: < E > ( child, dag, stack_top, max_depth) ;
677+
678+ // LOAD_WIT: push
679+ let out = push ( stack_top, max_depth) ;
680+ dag. push ( Node {
681+ op : DagLoadWit as u32 ,
682+ left_id : wid as u32 ,
683+ right_id : 0 ,
684+ out,
685+ } ) ;
686+
687+ // If child exists, multiply with it
688+ if let Some ( rhs) = child_out {
689+ let ( left, right, out) = pop2_push1 ( stack_top) ;
690+ dag. push ( Node {
691+ op : DagMul as u32 ,
692+ left_id : left,
693+ right_id : right,
694+ out,
695+ } ) ;
696+ acc_child = Some ( match acc_child {
697+ None => out,
698+ Some ( _) => {
699+ let ( l, r, out) = pop2_push1 ( stack_top) ;
700+ dag. push ( Node {
701+ op : DagAdd as u32 ,
702+ left_id : l,
703+ right_id : r,
704+ out,
705+ } ) ;
706+ out
707+ }
708+ } ) ;
709+ } else {
710+ acc_child = Some ( out) ;
711+ }
712+ }
713+
714+ // Handle scalar accumulation at leaf
715+ let mut acc_scalar: Option < u32 > = None ;
716+ for & idx in & node. scalar_indices {
717+ let out = push ( stack_top, max_depth) ;
718+ dag. push ( Node {
719+ op : DagLoadScalar as u32 ,
720+ left_id : idx as u32 ,
721+ right_id : 0 ,
722+ out,
723+ } ) ;
606724
725+ acc_scalar = Some ( match acc_scalar {
726+ None => out,
727+ Some ( _) => {
728+ let ( l, r, out) = pop2_push1 ( stack_top) ;
729+ dag. push ( Node {
730+ op : DagAdd as u32 ,
731+ left_id : l,
732+ right_id : r,
733+ out,
734+ } ) ;
735+ out
736+ }
737+ } ) ;
738+ }
739+
740+ // Merge both child and scalar accumulations
741+ match ( acc_scalar, acc_child) {
742+ ( Some ( _) , Some ( _) ) => {
743+ let ( l, r, out) = pop2_push1 ( stack_top) ;
744+ dag. push ( Node {
745+ op : DagAdd as u32 ,
746+ left_id : l,
747+ right_id : r,
748+ out,
749+ } ) ;
750+ Some ( out)
751+ }
752+ ( Some ( s) , None ) => Some ( s) ,
753+ ( None , Some ( c) ) => Some ( c) ,
754+ ( None , None ) => None ,
755+ }
756+ }
757+
758+ let final_out = emit :: < E > ( & root, & mut dag, & mut stack_top, & mut max_stack_depth) ;
759+ ( dag, scalars, final_out, max_stack_depth)
760+ }
607761#[ cfg( test) ]
608762mod tests {
763+ use std:: ops:: Neg ;
609764 use either:: Either ;
610765 use itertools:: Itertools ;
611766 use ff_ext:: { BabyBearExt4 , ExtensionField } ;
612767 use p3:: babybear:: BabyBear ;
613768 use p3:: field:: FieldAlgebra ;
614769 use crate :: { power_sequence, Expression , Instance , ToExpr } ;
615- use crate :: utils:: { expr_compression_to_dag, Node } ;
770+ use crate :: utils:: { build_factored_dag_commutative , expr_compression_to_dag, Node } ;
616771
617772 type E = BabyBearExt4 ;
618773 type B = BabyBear ;
@@ -710,4 +865,38 @@ mod tests {
710865 assert_eq ! ( max_degree, 1 ) ;
711866
712867 }
868+
869+ #[ test]
870+ fn test_build_factored_dag_commutative ( ) {
871+ // w1 * (c2 * (2 + w0*c1 -1))
872+ let w0 = Expression :: < E > :: WitIn ( 0 ) ;
873+ let w1 = Expression :: < E > :: WitIn ( 1 ) ;
874+ let c1 = Expression :: < E > :: Challenge ( 0 , 1 , E :: ONE , E :: ZERO ) ;
875+ let c2 = Expression :: < E > :: Challenge ( 2 , 1 , E :: ONE , E :: ZERO ) ;
876+ let constant_2 = Expression :: < E > :: Constant ( Either :: Left ( B :: from_canonical_u32 ( 2 ) ) ) ;
877+ let constant_negative_1 = Expression :: < E > :: Constant ( Either :: Left ( B :: from_canonical_u32 ( 1 ) . neg ( ) ) ) ;
878+
879+ let e: Expression < E > = w1. expr ( ) * ( c2. expr ( ) * ( constant_2. expr ( ) + w0. expr ( ) * c1. expr ( ) - constant_negative_1. expr ( ) ) ) ;
880+ let e_monomials = e. get_monomial_terms ( ) ;
881+ let ( dag, coeffs, final_out, _) = build_factored_dag_commutative ( & e_monomials) ;
882+
883+ let mut num_add = 0 ;
884+ let mut num_mul = 0 ;
885+
886+ for node in & dag {
887+ match node. op {
888+ 0 => ( ) , // skip wit index
889+ 1 => ( ) , // skip scalar index
890+ 2 => {
891+ num_add += 1 ;
892+ }
893+ 3 => {
894+ num_mul += 1 ;
895+ }
896+ op => panic ! ( "unknown op {op}" ) ,
897+ }
898+ }
899+ assert_eq ! ( num_add, 1 ) ;
900+ assert_eq ! ( num_mul, 3 ) ;
901+ }
713902}
0 commit comments