@@ -247,10 +247,10 @@ getAnnotation = \case
247
247
UEBundle a _ -> a
248
248
249
249
hashCons :: (Hashable i , Hashable f ) => UntypedExpr () i f -> UntypedExpr Int i f
250
- hashCons = \ case
250
+ hashCons = \ case
251
251
UEVal _ f -> UEVal (hash (hash @ Text " Val" , f)) f
252
252
UEVar _ v -> UEVar (hash (hash @ Text " Var" , v)) v
253
- UEUnOp _ op e ->
253
+ UEUnOp _ op e ->
254
254
let e' = hashCons e
255
255
in UEUnOp (hash (op, getAnnotation e')) op e'
256
256
UEBinOp _ op e1 e2 ->
@@ -473,7 +473,7 @@ data BuilderState f = BuilderState
473
473
{ bsCircuit :: ArithCircuit f ,
474
474
bsNextVar :: Int ,
475
475
bsVars :: CircuitVars Text ,
476
- bsSharedMap :: Map Id (UntypedExpr Int Wire f , V. Vector (SignalSource f ))
476
+ bsSharedMap :: Map Id (V. Vector (SignalSource f ))
477
477
}
478
478
479
479
defaultBuilderState :: BuilderState f
@@ -645,7 +645,7 @@ compileWithWires ::
645
645
m (V. Vector Wire )
646
646
compileWithWires ws expr = do
647
647
let e = hashCons $ unType expr
648
- compileOut <- memoizedCompile $ e
648
+ compileOut <- memoizedCompile $ e
649
649
for (V. zip compileOut ws) $ \ (o, freshWire) -> do
650
650
case o of
651
651
WireSource wire -> do
@@ -675,37 +675,37 @@ assertSameSourceSize l r =
675
675
throwError $
676
676
MismatchedWireTypes l r
677
677
678
+ withCompilerCache ::
679
+ (MonadState (BuilderState f ) m ) =>
680
+ Id ->
681
+ m (V. Vector (SignalSource f )) ->
682
+ m (V. Vector (SignalSource f ))
683
+ withCompilerCache i m = do
684
+ res <- m
685
+ modify $ \ s -> s {bsSharedMap = Map. insert i res (bsSharedMap s)}
686
+ pure res
687
+
678
688
_compile ::
679
689
forall f m .
680
690
(Hashable f , GaloisField f ) =>
681
691
(MonadState (BuilderState f ) m ) =>
682
692
(MonadError (CircuitBuilderError f ) m ) =>
683
693
UntypedExpr Int Wire f ->
684
694
m (V. Vector (SignalSource f ))
685
- _compile expr = case expr of
686
- UEVal i v ->
687
- case v of
688
- f -> do
689
- let res = V. singleton $ AffineSource $ ConstGate f
690
- modify $ \ s -> s {bsSharedMap = Map. insert i (expr, res) (bsSharedMap s)}
691
- pure res
692
- UEVar i (UVar var) -> do
693
- let res = V. singleton $ WireSource var
694
- modify $ \ s -> s {bsSharedMap = Map. insert i (expr, res) (bsSharedMap s)}
695
- pure res
696
- UEUnOp i op e1 -> do
695
+ _compile expr = withCompilerCache (getAnnotation expr) $ case expr of
696
+ UEVal _ f -> pure $ V. singleton $ AffineSource $ ConstGate f
697
+ UEVar _ (UVar var) -> pure . V. singleton $ WireSource var
698
+ UEUnOp _ op e1 -> do
697
699
e1Outs <- memoizedCompile e1
698
- res <- for e1Outs $ \ e1Out ->
700
+ for e1Outs $ \ e1Out ->
699
701
case op of
700
702
UUNeg -> pure . AffineSource $ ScalarMul (- 1 ) (addVar e1Out)
701
703
UUNot -> pure . AffineSource $ Add (ConstGate 1 ) (ScalarMul (- 1 ) (addVar e1Out))
702
- modify $ \ s -> s {bsSharedMap = Map. insert i (expr, res) (bsSharedMap s)}
703
- pure res
704
- UEBinOp i op e1 e2 -> do
704
+ UEBinOp _ op e1 e2 -> do
705
705
e1Outs <- memoizedCompile e1
706
706
e2Outs <- memoizedCompile e2
707
707
assertSameSourceSize e1Outs e2Outs
708
- res <- for (V. zip (addVar <$> e1Outs) (addVar <$> e2Outs)) $ \ (e1Out, e2Out) ->
708
+ for (V. zip (addVar <$> e1Outs) (addVar <$> e2Outs)) $ \ (e1Out, e2Out) ->
709
709
case op of
710
710
UAdd -> pure . AffineSource $ Add e1Out e2Out
711
711
UMul -> do
@@ -733,11 +733,8 @@ _compile expr = case expr of
733
733
tmp1 <- imm
734
734
emit $ Mul e1Out e2Out tmp1
735
735
pure . AffineSource $ Add (Add e1Out e2Out) (ScalarMul (- 2 ) (Var tmp1))
736
- modify $ \ s -> s {bsSharedMap = Map. insert i (expr, res) (bsSharedMap s)}
737
- pure res
738
736
-- IF(cond, true, false) = (cond*true) + ((!cond) * false)
739
- UEIf i cond true false -> do
740
- res <- V. singleton <$> do
737
+ UEIf _ cond true false -> do
741
738
condOut <- addVar <$> (memoizedCompile cond >>= assertSingleSource)
742
739
trueOuts <- memoizedCompile true
743
740
falseOuts <- memoizedCompile false
@@ -748,69 +745,49 @@ _compile expr = case expr of
748
745
tmp2 <- imm
749
746
for_ (addVar <$> falseOuts) $ \ falseOut ->
750
747
emit $ Mul (Add (ConstGate 1 ) (ScalarMul (- 1 ) condOut)) falseOut tmp2
751
- pure . AffineSource $ Add (Var tmp1) (Var tmp2)
752
- modify $ \ s -> s {bsSharedMap = Map. insert i (expr, res) (bsSharedMap s)}
753
- pure res
748
+ pure . V. singleton . AffineSource $ Add (Var tmp1) (Var tmp2)
754
749
-- EQ(lhs, rhs) = (lhs - rhs == 1) only allowed for field comparison
755
- UEEq i lhs rhs -> do
756
- res <- V. singleton <$> do
750
+ UEEq _ lhs rhs -> do
757
751
-- assertSingle is justified as the lhs and rhs must be of type f
758
752
let e = UEBinOp (hash (USub , getAnnotation lhs, getAnnotation rhs)) USub lhs rhs
759
- r <- memoizedCompile e >>= assertSingleSource
760
- modify $ \ s -> s {bsSharedMap = Map. insert (getAnnotation e) (e, V. singleton r) (bsSharedMap s)}
761
- eqInWire <- addWire r
753
+ eqInWire <- do
754
+ eOut <- withCompilerCache (getAnnotation e) (memoizedCompile e)
755
+ assertSingleSource eOut >>= addWire
762
756
eqFreeWire <- imm
763
757
eqOutWire <- imm
764
758
emit $ Equal eqInWire eqFreeWire eqOutWire
765
759
-- eqOutWire == 0 if lhs == rhs, so we need to return 1 -
766
760
-- neqOutWire instead.
767
- pure . AffineSource $ Add (ConstGate 1 ) (ScalarMul (- 1 ) (Var eqOutWire))
768
- modify $ \ s -> s {bsSharedMap = Map. insert i (expr, res) (bsSharedMap s)}
769
- pure res
770
- UESplit _i n input -> do
771
- res <- do
761
+ pure . V. singleton . AffineSource $ Add (ConstGate 1 ) (ScalarMul (- 1 ) (Var eqOutWire))
762
+ UESplit _ n input -> do
772
763
-- assertSingle is justified as the input must be of type f
773
764
i <- memoizedCompile input >>= assertSingleSource >>= addWire
774
765
outputs <- V. generateM n (const $ mkBoolVar =<< imm)
775
766
emit $ Split i (V. toList outputs)
776
- fold <$> ( for outputs $ \ o ->
767
+ fold <$> for outputs ( \ o ->
777
768
let v = UVar o
778
769
e = UEVar (hash v) v
779
770
in memoizedCompile e)
780
- modify $ \ s -> s {bsSharedMap = Map. insert _i (expr, res) (bsSharedMap s)}
781
- pure res
782
771
where
783
772
mkBoolVar w = do
784
773
emit $ Mul (Var w) (Var w) w
785
774
pure w
786
- UEBundle i as -> do
787
- res <- do
775
+ UEBundle _ as -> do
788
776
as' <- traverse memoizedCompile as
789
777
pure $ fold as'
790
- modify $ \ s -> s {bsSharedMap = Map. insert i (expr, res) (bsSharedMap s)}
791
- pure res
792
- UEJoin i bits -> do
793
- res <- V. singleton <$> do
778
+ UEJoin _ bits -> do
794
779
bs <- toList <$> memoizedCompile bits
795
780
ws <- traverse addWire bs
796
- pure . AffineSource $ unsplit ws
797
- modify $ \ s -> s {bsSharedMap = Map. insert i (expr, res) (bsSharedMap s)}
798
- pure res
799
- UEAtIndex i v _ix -> do
800
- res <- V. singleton <$> do
781
+ pure . V. singleton . AffineSource $ unsplit ws
782
+ UEAtIndex _ v _ix -> do
801
783
v' <- memoizedCompile v
802
- pure $ v' V. ! (fromIntegral _ix)
803
- modify $ \ s -> s {bsSharedMap = Map. insert i (expr, res) (bsSharedMap s)}
804
- pure res
805
- UEUpdateIndex i p b v -> do
806
- res <- do
784
+ pure . V. singleton $ v' V. ! (fromIntegral _ix)
785
+ UEUpdateIndex _ p b v -> do
807
786
v' <- memoizedCompile v
808
787
b' <- memoizedCompile b >>= assertSingleSource
809
788
let p' = fromIntegral p
810
789
pure $ V. imap (\ _ix w -> if _ix == p' then b' else w) v'
811
- modify $ \ s -> s {bsSharedMap = Map. insert i (expr, res) (bsSharedMap s)}
812
- pure res
813
-
790
+
814
791
815
792
memoizedCompile ::
816
793
forall f m .
@@ -819,18 +796,10 @@ memoizedCompile ::
819
796
(MonadError (CircuitBuilderError f ) m ) =>
820
797
UntypedExpr Int Wire f ->
821
798
m (V. Vector (SignalSource f ))
822
- memoizedCompile expr = do
799
+ memoizedCompile expr = do
823
800
m <- gets bsSharedMap
824
801
case Map. lookup (getAnnotation expr) m of
825
- Just (e, ws) -> pure ws
826
- -- if (expr /= e)
827
- -- then do
828
- -- traceM $ "COLLISION"
829
- -- traceM $ show expr
830
- -- traceM "with"
831
- -- traceM $ show e
832
- -- panic "Cache is fucked"
833
- -- else pure ws
802
+ Just ws -> pure ws
834
803
Nothing -> _compile expr
835
804
836
805
exprToArithCircuit ::
0 commit comments