Skip to content

Commit 1e54029

Browse files
committed
ok really seems to work, time to clean up
1 parent 8ddd296 commit 1e54029

File tree

2 files changed

+45
-74
lines changed

2 files changed

+45
-74
lines changed

Diff for: src/Circuit/Expr.hs

+39-70
Original file line numberDiff line numberDiff line change
@@ -247,10 +247,10 @@ getAnnotation = \case
247247
UEBundle a _ -> a
248248

249249
hashCons :: (Hashable i, Hashable f) => UntypedExpr () i f -> UntypedExpr Int i f
250-
hashCons = \case
250+
hashCons = \case
251251
UEVal _ f -> UEVal (hash (hash @Text "Val", f)) f
252252
UEVar _ v -> UEVar (hash (hash @Text "Var", v)) v
253-
UEUnOp _ op e ->
253+
UEUnOp _ op e ->
254254
let e' = hashCons e
255255
in UEUnOp (hash (op, getAnnotation e')) op e'
256256
UEBinOp _ op e1 e2 ->
@@ -473,7 +473,7 @@ data BuilderState f = BuilderState
473473
{ bsCircuit :: ArithCircuit f,
474474
bsNextVar :: Int,
475475
bsVars :: CircuitVars Text,
476-
bsSharedMap :: Map Id (UntypedExpr Int Wire f, V.Vector (SignalSource f))
476+
bsSharedMap :: Map Id (V.Vector (SignalSource f))
477477
}
478478

479479
defaultBuilderState :: BuilderState f
@@ -645,7 +645,7 @@ compileWithWires ::
645645
m (V.Vector Wire)
646646
compileWithWires ws expr = do
647647
let e = hashCons $ unType expr
648-
compileOut <- memoizedCompile $ e
648+
compileOut <- memoizedCompile $ e
649649
for (V.zip compileOut ws) $ \(o, freshWire) -> do
650650
case o of
651651
WireSource wire -> do
@@ -675,37 +675,37 @@ assertSameSourceSize l r =
675675
throwError $
676676
MismatchedWireTypes l r
677677

678+
withCompilerCache ::
679+
(MonadState (BuilderState f) m) =>
680+
Id ->
681+
m (V.Vector (SignalSource f)) ->
682+
m (V.Vector (SignalSource f))
683+
withCompilerCache i m = do
684+
res <- m
685+
modify $ \s -> s {bsSharedMap = Map.insert i res (bsSharedMap s)}
686+
pure res
687+
678688
_compile ::
679689
forall f m.
680690
(Hashable f, GaloisField f) =>
681691
(MonadState (BuilderState f) m) =>
682692
(MonadError (CircuitBuilderError f) m) =>
683693
UntypedExpr Int Wire f ->
684694
m (V.Vector (SignalSource f))
685-
_compile expr = case expr of
686-
UEVal i v ->
687-
case v of
688-
f -> do
689-
let res = V.singleton $ AffineSource $ ConstGate f
690-
modify $ \s -> s {bsSharedMap = Map.insert i (expr, res) (bsSharedMap s)}
691-
pure res
692-
UEVar i (UVar var) -> do
693-
let res = V.singleton $ WireSource var
694-
modify $ \s -> s {bsSharedMap = Map.insert i (expr, res) (bsSharedMap s)}
695-
pure res
696-
UEUnOp i op e1 -> do
695+
_compile expr = withCompilerCache (getAnnotation expr) $ case expr of
696+
UEVal _ f -> pure $ V.singleton $ AffineSource $ ConstGate f
697+
UEVar _ (UVar var) -> pure . V.singleton $ WireSource var
698+
UEUnOp _ op e1 -> do
697699
e1Outs <- memoizedCompile e1
698-
res <- for e1Outs $ \e1Out ->
700+
for e1Outs $ \e1Out ->
699701
case op of
700702
UUNeg -> pure . AffineSource $ ScalarMul (-1) (addVar e1Out)
701703
UUNot -> pure . AffineSource $ Add (ConstGate 1) (ScalarMul (-1) (addVar e1Out))
702-
modify $ \s -> s {bsSharedMap = Map.insert i (expr, res) (bsSharedMap s)}
703-
pure res
704-
UEBinOp i op e1 e2 -> do
704+
UEBinOp _ op e1 e2 -> do
705705
e1Outs <- memoizedCompile e1
706706
e2Outs <- memoizedCompile e2
707707
assertSameSourceSize e1Outs e2Outs
708-
res <- for (V.zip (addVar <$> e1Outs) (addVar <$> e2Outs)) $ \(e1Out, e2Out) ->
708+
for (V.zip (addVar <$> e1Outs) (addVar <$> e2Outs)) $ \(e1Out, e2Out) ->
709709
case op of
710710
UAdd -> pure . AffineSource $ Add e1Out e2Out
711711
UMul -> do
@@ -733,11 +733,8 @@ _compile expr = case expr of
733733
tmp1 <- imm
734734
emit $ Mul e1Out e2Out tmp1
735735
pure . AffineSource $ Add (Add e1Out e2Out) (ScalarMul (-2) (Var tmp1))
736-
modify $ \s -> s {bsSharedMap = Map.insert i (expr, res) (bsSharedMap s)}
737-
pure res
738736
-- IF(cond, true, false) = (cond*true) + ((!cond) * false)
739-
UEIf i cond true false -> do
740-
res <- V.singleton <$> do
737+
UEIf _ cond true false -> do
741738
condOut <- addVar <$> (memoizedCompile cond >>= assertSingleSource)
742739
trueOuts <- memoizedCompile true
743740
falseOuts <- memoizedCompile false
@@ -748,69 +745,49 @@ _compile expr = case expr of
748745
tmp2 <- imm
749746
for_ (addVar <$> falseOuts) $ \falseOut ->
750747
emit $ Mul (Add (ConstGate 1) (ScalarMul (-1) condOut)) falseOut tmp2
751-
pure . AffineSource $ Add (Var tmp1) (Var tmp2)
752-
modify $ \s -> s {bsSharedMap = Map.insert i (expr, res) (bsSharedMap s)}
753-
pure res
748+
pure . V.singleton . AffineSource $ Add (Var tmp1) (Var tmp2)
754749
-- EQ(lhs, rhs) = (lhs - rhs == 1) only allowed for field comparison
755-
UEEq i lhs rhs -> do
756-
res <- V.singleton <$> do
750+
UEEq _ lhs rhs -> do
757751
-- assertSingle is justified as the lhs and rhs must be of type f
758752
let e = UEBinOp (hash (USub, getAnnotation lhs, getAnnotation rhs)) USub lhs rhs
759-
r <- memoizedCompile e >>= assertSingleSource
760-
modify $ \s -> s {bsSharedMap = Map.insert (getAnnotation e) (e, V.singleton r) (bsSharedMap s)}
761-
eqInWire <- addWire r
753+
eqInWire <- do
754+
eOut <- withCompilerCache (getAnnotation e) (memoizedCompile e)
755+
assertSingleSource eOut >>= addWire
762756
eqFreeWire <- imm
763757
eqOutWire <- imm
764758
emit $ Equal eqInWire eqFreeWire eqOutWire
765759
-- eqOutWire == 0 if lhs == rhs, so we need to return 1 -
766760
-- neqOutWire instead.
767-
pure . AffineSource $ Add (ConstGate 1) (ScalarMul (-1) (Var eqOutWire))
768-
modify $ \s -> s {bsSharedMap = Map.insert i (expr, res) (bsSharedMap s)}
769-
pure res
770-
UESplit _i n input -> do
771-
res <- do
761+
pure . V.singleton . AffineSource $ Add (ConstGate 1) (ScalarMul (-1) (Var eqOutWire))
762+
UESplit _ n input -> do
772763
-- assertSingle is justified as the input must be of type f
773764
i <- memoizedCompile input >>= assertSingleSource >>= addWire
774765
outputs <- V.generateM n (const $ mkBoolVar =<< imm)
775766
emit $ Split i (V.toList outputs)
776-
fold <$> (for outputs $ \o ->
767+
fold <$> for outputs (\o ->
777768
let v = UVar o
778769
e = UEVar (hash v) v
779770
in memoizedCompile e)
780-
modify $ \s -> s {bsSharedMap = Map.insert _i (expr, res) (bsSharedMap s)}
781-
pure res
782771
where
783772
mkBoolVar w = do
784773
emit $ Mul (Var w) (Var w) w
785774
pure w
786-
UEBundle i as -> do
787-
res <- do
775+
UEBundle _ as -> do
788776
as' <- traverse memoizedCompile as
789777
pure $ fold as'
790-
modify $ \s -> s {bsSharedMap = Map.insert i (expr, res) (bsSharedMap s)}
791-
pure res
792-
UEJoin i bits -> do
793-
res <- V.singleton <$> do
778+
UEJoin _ bits -> do
794779
bs <- toList <$> memoizedCompile bits
795780
ws <- traverse addWire bs
796-
pure . AffineSource $ unsplit ws
797-
modify $ \s -> s {bsSharedMap = Map.insert i (expr, res) (bsSharedMap s)}
798-
pure res
799-
UEAtIndex i v _ix -> do
800-
res <- V.singleton <$> do
781+
pure . V.singleton . AffineSource $ unsplit ws
782+
UEAtIndex _ v _ix -> do
801783
v' <- memoizedCompile v
802-
pure $ v' V.! (fromIntegral _ix)
803-
modify $ \s -> s {bsSharedMap = Map.insert i (expr, res) (bsSharedMap s)}
804-
pure res
805-
UEUpdateIndex i p b v -> do
806-
res <- do
784+
pure . V.singleton $ v' V.! (fromIntegral _ix)
785+
UEUpdateIndex _ p b v -> do
807786
v' <- memoizedCompile v
808787
b' <- memoizedCompile b >>= assertSingleSource
809788
let p' = fromIntegral p
810789
pure $ V.imap (\_ix w -> if _ix == p' then b' else w) v'
811-
modify $ \s -> s {bsSharedMap = Map.insert i (expr, res) (bsSharedMap s)}
812-
pure res
813-
790+
814791

815792
memoizedCompile ::
816793
forall f m.
@@ -819,18 +796,10 @@ memoizedCompile ::
819796
(MonadError (CircuitBuilderError f) m) =>
820797
UntypedExpr Int Wire f ->
821798
m (V.Vector (SignalSource f))
822-
memoizedCompile expr = do
799+
memoizedCompile expr = do
823800
m <- gets bsSharedMap
824801
case Map.lookup (getAnnotation expr) m of
825-
Just (e, ws) -> pure ws
826-
--if (expr /= e)
827-
-- then do
828-
-- traceM $ "COLLISION"
829-
-- traceM $ show expr
830-
-- traceM "with"
831-
-- traceM $ show e
832-
-- panic "Cache is fucked"
833-
-- else pure ws
802+
Just ws -> pure ws
834803
Nothing -> _compile expr
835804

836805
exprToArithCircuit ::

Diff for: test/Test/Circuit/Sudoku.hs

+6-4
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
module Test.Circuit.Sudoku where
44

55
import Circuit
6+
import Data.Set qualified as Set
67
import Data.Array.IO (IOArray, getElems, newArray, readArray, writeArray)
78
import Data.Distributive (Distributive (distribute))
89
import Data.Field.Galois (GaloisField, Prime, PrimeField)
@@ -42,8 +43,8 @@ isPermutation ::
4243
isPermutation as bs =
4344
let f (a, i) =
4445
let isPresent = elem_ a bs
45-
-- isUnique = not_ $ elem_ a (take i as)
46-
in isPresent -- `and_` isUnique
46+
isUnique = not_ $ elem_ a (take i as)
47+
in isPresent `and_` isUnique
4748
in all_ f (zip as [0 ..])
4849

4950
validateBoxes ::
@@ -79,8 +80,8 @@ validate = do
7980
b <- mkBoard >>= initializeBoard
8081
let rowsValid = all_ (isPermutation $ Vec.toList sudokuSet) (Vec.toList <$> b)
8182
colsValid = all_ (isPermutation $ Vec.toList sudokuSet) (Vec.toList <$> distribute b)
82-
-- boxesValid = validateBoxes sudokuSet (mkBoxes b)
83-
retBool "out" $ rowsValid `and_` colsValid -- `and_` boxesValid
83+
boxesValid = validateBoxes sudokuSet (mkBoxes b)
84+
retBool "out" $ rowsValid `and_` colsValid `and_` boxesValid
8485

8586
type Fr = Prime 21888242871839275222246405745257275088548364400416034343698204186575808495617
8687

@@ -98,6 +99,7 @@ spec_sudokuSolver = do
9899
map (first (\a -> "private_cell_" <> show a)) $
99100
filter (\(_, v) -> v /= 0) sol
100101
BuilderState {bsVars, bsCircuit} <- snd <$> runCircuitBuilder (validate @Fr)
102+
print (Set.size $ cvVars bsVars)
101103
let pubInputs =
102104
Map.fromList $
103105
[ (var, fromIntegral value)

0 commit comments

Comments
 (0)