Skip to content

Commit 191e890

Browse files
committed
added standard binary solver for circom
1 parent be5e58e commit 191e890

File tree

3 files changed

+77
-26
lines changed

3 files changed

+77
-26
lines changed

Diff for: .github/workflows/ormolu.yml

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
name: Ormolu CI
2+
3+
on:
4+
push:
5+
branches: [main, master]
6+
pull_request:
7+
branches: [main, master]
8+
9+
jobs:
10+
ormolu:
11+
runs-on: ubuntu-latest
12+
steps:
13+
- uses: actions/checkout@v4
14+
- uses: haskell-actions/run-ormolu@v15
15+
with:
16+
version: "0.7.2.0"

Diff for: circom-compat/src/Circuit/Solver/Circom.hs

+44-6
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ module Circuit.Solver.Circom
1515
_setInputSignal,
1616
_getWitnessSize,
1717
_getWitness,
18+
standardSolver,
1819
)
1920
where
2021

@@ -30,7 +31,10 @@ import Data.Vector.Mutable qualified as MV
3031
import FNV (FNVHash (..), hashText, mkFNV)
3132
import Protolude
3233
import R1CS (Inputs (..), Witness (..), oneVar)
33-
import R1CS.Circom (FieldSize (..), integerFromLittleEndian, integerToLittleEndian, n32)
34+
import R1CS.Circom (CircomWitness, FieldSize (..), integerFromLittleEndian, integerToLittleEndian, n32, witnessToCircomWitness)
35+
import Text.PrettyPrint.Leijen.Text (Pretty (pretty), (<+>))
36+
37+
-- WASM Solver
3438

3539
data ProgramEnv f = ProgramEnv
3640
{ peFieldSize :: FieldSize,
@@ -77,10 +81,8 @@ mkProgramState ProgramEnv {peFieldSize} = do
7781
psSharedRWMemory = sharedRWMemory
7882
}
7983

80-
-- | The arg is a bool representing 'sanityCheck'. We don't
81-
-- need this at the moment
82-
_init :: Int -> IO ()
83-
_init = mempty
84+
_init :: ProgramEnv f -> IORef (ProgramState f) -> Int -> IO ()
85+
_init env st _ = writeBuffer env st 0
8486

8587
_getNVars :: ProgramEnv f -> Int
8688
_getNVars = peWitnessSize
@@ -132,7 +134,12 @@ _setInputSignal env@(ProgramEnv {peCircuit, peInputsSize, peCircuitVars}) stRef
132134
writeIORef stRef $
133135
if IntMap.size newInputs == peInputsSize
134136
then
135-
let wtns = solve peCircuitVars peCircuit newInputs
137+
let wtns =
138+
evalArithCircuit
139+
(\w a -> IntMap.lookup (wireName w) a)
140+
(\w a -> safeAssign (wireName w) a)
141+
peCircuit
142+
newInputs
136143
in st
137144
{ psInputs = Inputs newInputs,
138145
psWitness = Witness $ IntMap.insert oneVar 1 wtns
@@ -154,13 +161,44 @@ _getWitness env stRef i = do
154161
in writeBuffer env stRef wtn
155162

156163
--------------------------------------------------------------------------------
164+
-- Standard Solver (to be used as native executable)
165+
166+
standardSolver ::
167+
forall f.
168+
(PrimeField f) =>
169+
CircuitVars Text ->
170+
ArithCircuit f ->
171+
Map Text f ->
172+
CircomWitness f
173+
standardSolver vars circ inputs =
174+
let initAssignments = assignInputs vars inputs
175+
wtns =
176+
evalArithCircuit
177+
(\w a -> IntMap.lookup (wireName w) a)
178+
(\w a -> safeAssign (wireName w) a)
179+
circ
180+
initAssignments
181+
in witnessToCircomWitness $ Witness $ IntMap.insert oneVar 1 wtns
182+
183+
--------------------------------------------------------------------------------
184+
185+
{-# INLINE safeAssign #-}
186+
safeAssign :: (Eq f) => (Pretty f) => Int -> f -> IntMap f -> IntMap f
187+
safeAssign =
188+
let f k new old =
189+
if new == old
190+
then new
191+
else panic $ show $ "Assignment contradiction for var" <+> pretty k <> ":" <> pretty new <+> " /= " <+> pretty old
192+
in IntMap.insertWithKey f
157193

194+
{-# INLINE writeBuffer #-}
158195
writeBuffer :: ProgramEnv f -> IORef (ProgramState f) -> Integer -> IO ()
159196
writeBuffer (ProgramEnv {peFieldSize}) stRef x = do
160197
let chunks = integerToLittleEndian peFieldSize x
161198
forM_ [0 .. n32 peFieldSize - 1] $ \j ->
162199
_writeSharedRWMemory stRef j (chunks V.! j)
163200

201+
{-# INLINE readBuffer #-}
164202
readBuffer :: ProgramEnv f -> IORef (ProgramState f) -> IO Integer
165203
readBuffer (ProgramEnv {peFieldSize}) stRef = do
166204
v <- V.generateM (n32 peFieldSize) $ \j ->

Diff for: language/src/Circuit/Language/Expr.hs

+17-20
Original file line numberDiff line numberDiff line change
@@ -635,7 +635,7 @@ untypeBinOp BXors = UBXor
635635

636636
--------------------------------------------------------------------------------
637637

638-
reifyGraph :: Expr i f ty -> Seq (Hash, Node i f)
638+
reifyGraph :: Expr i f ty -> Seq (Hash, Node i f)
639639
reifyGraph e =
640640
gbsEdges $ execState (buildGraph_ e) (GraphBuilderState mempty mempty)
641641

@@ -645,7 +645,7 @@ data GraphBuilderState i f = GraphBuilderState
645645
}
646646

647647
{-# SCC buildGraph_ #-}
648-
buildGraph_ :: forall i f ty. Expr i f ty -> State (GraphBuilderState i f) Hash
648+
buildGraph_ :: forall i f ty. Expr i f ty -> State (GraphBuilderState i f) Hash
649649
buildGraph_ expr =
650650
getId expr <$ case expr of
651651
EVal h v ->
@@ -656,7 +656,7 @@ buildGraph_ expr =
656656
unlessM (hasBeenVisited h) $ do
657657
let n = NVar (rawWire v)
658658
markVisited h n
659-
EUnOp h op e ->
659+
EUnOp h op e ->
660660
unlessM (hasBeenVisited h) $ do
661661
e' <- buildGraph_ e
662662
let n = NUnOp (untypeUnOp op) e'
@@ -674,7 +674,7 @@ buildGraph_ expr =
674674
f' <- buildGraph_ f
675675
let n = NIf b' t' f'
676676
markVisited h n
677-
EEq h l r ->
677+
EEq h l r ->
678678
unlessM (hasBeenVisited h) $ do
679679
l' <- buildGraph_ l
680680
r' <- buildGraph_ r
@@ -685,7 +685,7 @@ buildGraph_ expr =
685685
i' <- buildGraph_ i
686686
let n = NSplit i' (fromIntegral $ natVal (Proxy @(NBits f)))
687687
markVisited h n
688-
EJoin h i ->
688+
EJoin h i ->
689689
unlessM (hasBeenVisited h) $ do
690690
i' <- buildGraph_ i
691691
let n = NJoin i'
@@ -709,10 +709,10 @@ buildGraph_ expr =
709709
where
710710
hasBeenVisited h = gets $ Set.member h . gbsSharedNodes
711711
{-# INLINE hasBeenVisited #-}
712-
markVisited h n = modify $ \s ->
713-
s
714-
{ gbsSharedNodes = Set.insert h (gbsSharedNodes s)
715-
, gbsEdges = gbsEdges s |> (h, n)
712+
markVisited h n = modify $ \s ->
713+
s
714+
{ gbsSharedNodes = Set.insert h (gbsSharedNodes s),
715+
gbsEdges = gbsEdges s |> (h, n)
716716
}
717717
{-# INLINE markVisited #-}
718718

@@ -738,7 +738,7 @@ evalGraph ::
738738
EvalM i f (V.Vector f)
739739
evalGraph lookupVar vars graph = case graph of
740740
Empty -> panic "empty graph"
741-
ns :|> n -> traverse eval ns >> eval n
741+
ns :|> n -> traverse_ eval ns >> eval n
742742
where
743743
eval (h, n) = evalNode lookupVar vars h n
744744

@@ -759,15 +759,12 @@ evalNode lookupVar vars h node =
759759
Nothing -> throwError $ MissingVar i
760760
NUnOp op e -> do
761761
e' <- assertFromCache e
762-
res <- case op of
763-
UUNeg -> pure $ fmap Protolude.negate $ e'
764-
UUNot -> pure $ fmap (\x -> 1 - x) $ e'
765-
UURot n ->
766-
pure $ V.fromList . rotateList n $ V.toList e'
767-
UUShift n ->
768-
pure $ V.fromList . shiftList 0 n $ V.toList e'
769-
UUReverse ->
770-
pure $ V.reverse e'
762+
let res = case op of
763+
UUNeg -> fmap Protolude.negate $ e'
764+
UUNot -> fmap (\x -> 1 - x) $ e'
765+
UURot n -> V.fromList . rotateList n $ V.toList e'
766+
UUShift n -> V.fromList . shiftList 0 n $ V.toList e'
767+
UUReverse -> V.reverse e'
771768
cachResult h res
772769
NBinOp op e1 e2 -> do
773770
e1' <- assertFromCache e1
@@ -837,4 +834,4 @@ evalNode lookupVar vars h node =
837834
case Map.lookup i m of
838835
Just ws -> pure ws
839836
Nothing -> throwError $ MissingFromCache i
840-
{-# INLINE assertFromCache #-}
837+
{-# INLINE assertFromCache #-}

0 commit comments

Comments
 (0)