@@ -13,6 +13,7 @@ use polysub::{Coef, Mod, PhaseOptPolynom, VarIndex};
1313use rustc_hash:: { FxHashMap , FxHashSet } ;
1414use smallvec:: SmallVec ;
1515use std:: collections:: VecDeque ;
16+ use std:: fmt:: { Display , Formatter } ;
1617
1718/// Used for backwards substitution
1819type PolyOpt = PhaseOptPolynom < DefaultCoef > ;
@@ -33,14 +34,15 @@ pub fn backwards_sub(
3334 // check to see if there are any half adders we can identify
3435 let mut has = find_half_adders ( ctx, gate_level_expr) ;
3536 has. sort_by_key ( |ha| std:: cmp:: min ( usize:: from ( ha. sum ) , usize:: from ( ha. carry ) ) ) ;
36- let ha_id = FxHashMap :: from_iter ( has. iter ( ) . enumerate ( ) . map ( |( ii, ha) | ( * ha, ii) ) ) ;
3737 println ! ( "{} HAs: {has:?}" , has. len( ) ) ;
3838 let fas = find_full_adders ( ctx, gate_level_expr) ;
3939 println ! ( "{} FAs: {fas:?}" , fas. len( ) ) ;
4040 let output_to_ha = FxHashMap :: from_iter (
41- has. into_iter ( )
42- . flat_map ( |ha| vec ! [ ( ha. sum, ha) , ( ha. carry, ha) ] ) ,
41+ has. iter ( )
42+ . enumerate ( )
43+ . flat_map ( |( ii, ha) | vec ! [ ( ha. sum, ii) , ( ha. carry, ii) ] ) ,
4344 ) ;
45+ let ha_uses = analyze_ha_uses ( ctx, gate_level_expr, & has, & output_to_ha) ;
4446
4547 let mut spec: PolyOpt = spec. into ( ) ;
4648 let mut var_roots: Vec < _ > = todo. iter ( ) . map ( |( v, _) | * v) . collect ( ) ;
@@ -55,7 +57,7 @@ pub fn backwards_sub(
5557 let mut iter_count = 0 ;
5658 while !todo. is_empty ( ) {
5759 iter_count += 1 ;
58- let gate = pick_gate ( & todo, & root_uses, & output_to_ha, & replaced) ;
60+ let gate = pick_gate ( & todo, & root_uses, & output_to_ha, & has , & replaced, & ha_uses ) ;
5961 let output_var = todo. remove ( & gate) . unwrap ( ) ;
6062
6163 debug_assert ! ( !replaced. contains( & output_var) ) ;
@@ -66,7 +68,8 @@ pub fn backwards_sub(
6668 . map ( |ii| format ! ( "{}" , root_vars[ ii] ) )
6769 . collect :: < Vec < _ > > ( )
6870 . join ( ", " ) ;
69- let ha_str = half_adder_info ( gate, & output_to_ha, & replaced, & ha_id) ;
71+ let ha_str = half_adder_info ( gate, & has, & output_to_ha, & replaced) ;
72+
7073 let before_size = spec. size ( ) ;
7174
7275 let add_children = replace_gate ( ctx, input_vars, & mut spec, output_var, gate) ;
@@ -129,12 +132,12 @@ pub fn backwards_sub(
129132
130133fn half_adder_info (
131134 gate : ExprRef ,
132- output_to_ha : & FxHashMap < ExprRef , HalfAdder > ,
135+ has : & [ HalfAdder ] ,
136+ output_to_ha : & FxHashMap < ExprRef , usize > ,
133137 replaced : & FxHashSet < VarIndex > ,
134- ha_id : & FxHashMap < HalfAdder , usize > ,
135138) -> String {
136- if let Some ( ha ) = output_to_ha. get ( & gate) {
137- let id = ha_id [ ha ] ;
139+ if let Some ( & id ) = output_to_ha. get ( & gate) {
140+ let ha = & has [ id ] ;
138141 if ha. sum == gate {
139142 if replaced. contains ( & expr_to_var ( ha. carry ) ) {
140143 format ! ( " (S{id} -> DONE)" )
@@ -238,13 +241,28 @@ fn print_gate_stats(ctx: &Context, input_vars: &FxHashSet<VarIndex>, root: ExprR
238241fn pick_gate (
239242 todo : & FxHashMap < ExprRef , VarIndex > ,
240243 root_uses : & impl ExprMap < BitSet > ,
241- output_to_ha : & FxHashMap < ExprRef , HalfAdder > ,
244+ output_to_ha : & FxHashMap < ExprRef , usize > ,
245+ has : & [ HalfAdder ] ,
242246 replaced : & FxHashSet < VarIndex > ,
247+ ha_uses : & impl ExprMap < BitSet > ,
243248) -> ExprRef {
249+ let ha_uses_str: Vec < _ > = todo
250+ . keys ( )
251+ . map ( |& e| {
252+ let items: Vec < _ > = ha_uses[ e]
253+ . iter ( )
254+ . map ( |i| format ! ( "{}" , HalfAdderOut :: from( i) ) )
255+ . collect ( ) ;
256+ format ! ( "{{{}}}" , items. join( ", " ) )
257+ } )
258+ . collect ( ) ;
259+ println ! ( "HA: [{}]" , ha_uses_str. join( ", " ) ) ;
260+
244261 // check to see if there is a gate that would finish a half adder
245262 let mut other_output_available = vec ! [ ] ;
246263 for & gate in todo. keys ( ) {
247- if let Some ( ha) = output_to_ha. get ( & gate) {
264+ if let Some ( & ha_index) = output_to_ha. get ( & gate) {
265+ let ha = & has[ ha_index] ;
248266 let other = ha. get_other_output ( gate) . unwrap ( ) ;
249267 if replaced. contains ( & expr_to_var ( other) ) {
250268 return gate;
@@ -317,6 +335,86 @@ fn try_exhaustive(
317335 min_with_lowest_use
318336}
319337
338+ #[ derive( Debug , Copy , Clone ) ]
339+ enum HalfAdderOut {
340+ Sum ( usize ) ,
341+ Carry ( usize ) ,
342+ }
343+
344+ impl Display for HalfAdderOut {
345+ fn fmt ( & self , f : & mut Formatter < ' _ > ) -> std:: fmt:: Result {
346+ match self {
347+ HalfAdderOut :: Sum ( i) => write ! ( f, "S{i}" ) ,
348+ HalfAdderOut :: Carry ( i) => write ! ( f, "C{i}" ) ,
349+ }
350+ }
351+ }
352+
353+ impl HalfAdderOut {
354+ pub fn from_index_and_expr ( has : & [ HalfAdder ] , index : usize , e : ExprRef ) -> Option < Self > {
355+ let ha = & has[ index] ;
356+ if ha. sum == e {
357+ Some ( Self :: Sum ( index) )
358+ } else if ha. carry == e {
359+ Some ( Self :: Carry ( index) )
360+ } else {
361+ None
362+ }
363+ }
364+ }
365+
366+ impl From < HalfAdderOut > for usize {
367+ fn from ( value : HalfAdderOut ) -> Self {
368+ match value {
369+ HalfAdderOut :: Sum ( i) => i * 2 ,
370+ HalfAdderOut :: Carry ( i) => i * 2 + 1 ,
371+ }
372+ }
373+ }
374+
375+ impl From < usize > for HalfAdderOut {
376+ fn from ( value : usize ) -> Self {
377+ if value & 1 == 1 {
378+ Self :: Carry ( value / 2 )
379+ } else {
380+ Self :: Sum ( value / 2 )
381+ }
382+ }
383+ }
384+
385+ fn analyze_ha_uses (
386+ ctx : & Context ,
387+ root : ExprRef ,
388+ has : & [ HalfAdder ] ,
389+ output_to_ha : & FxHashMap < ExprRef , usize > ,
390+ ) -> impl ExprMap < BitSet > {
391+ let mut out = DenseExprMetaData :: < BitSet > :: default ( ) ;
392+
393+ traversal:: bottom_up ( ctx, root, |_, e, c : & [ BitSet ] | {
394+ let set = if let Some ( & index) = output_to_ha. get ( & e) {
395+ let out = HalfAdderOut :: from_index_and_expr ( has, index, e) . unwrap ( ) ;
396+ let mut set = BitSet :: new ( ) ;
397+ set. insert ( out. into ( ) ) ;
398+ set
399+ } else {
400+ if c. is_empty ( ) {
401+ BitSet :: new ( )
402+ } else if c. len ( ) == 1 {
403+ c[ 0 ] . clone ( )
404+ } else {
405+ assert_eq ! ( c. len( ) , 2 ) ;
406+ let mut r = c[ 0 ] . clone ( ) ;
407+ r. union_with ( & c[ 1 ] ) ;
408+ r
409+ }
410+ } ;
411+ out[ e] = set. clone ( ) ;
412+ set
413+ } ) ;
414+
415+ out
416+ }
417+
320418/// Calculates for each expression which root depends on it.
321419fn analyze_uses ( ctx : & Context , roots : & [ ExprRef ] ) -> impl ExprMap < BitSet > {
322420 let mut out = DenseExprMetaData :: < BitSet > :: default ( ) ;
0 commit comments