Skip to content

Commit 1e16a0c

Browse files
authored
Merge pull request #7 from l-adic/middle-out
Middle out
2 parents e9bf4ac + 4272695 commit 1e16a0c

File tree

23 files changed

+1721
-882
lines changed

23 files changed

+1721
-882
lines changed

Diff for: arithmetic-circuits.cabal

+20-5
Original file line numberDiff line numberDiff line change
@@ -79,19 +79,21 @@ library language
7979
exposed-modules:
8080
Circuit.Language.Compile
8181
Circuit.Language.Expr
82-
Circuit.Language.TExpr
8382
Circuit.Language.DSL
8483
Circuit.Language
8584

8685
hs-source-dirs: language/src
8786
build-depends:
8887
arithmetic-circuits
88+
, array
8989
, hashable
9090
, microlens
91+
, mtl
92+
, unordered-containers
9193
, vector
9294
, vector-sized
9395

94-
ghc-options: -freverse-errors -O2 -Wall
96+
ghc-options: -freverse-errors -O2 -Wall
9597

9698
default-language: GHC2021
9799

@@ -125,16 +127,18 @@ test-suite language-tests
125127
Test.Circuit.Expr
126128
Test.Circuit.Sudoku
127129
Test.Circuit.Lang
130+
Test.Circuit.SHA3
128131

129132
hs-source-dirs: language/test
130133

131-
ghc-options: -freverse-errors -O2 -Wall -main-is Main
134+
ghc-options: -freverse-errors -O2 -Wall -main-is Main
132135
build-depends:
133136
arithmetic-circuits
134137
, arithmetic-circuits:language
135138
, array
136139
, distributive
137140
, fin
141+
, microlens
138142
, quickcheck-instances >=0.3
139143
, QuickCheck
140144
, random
@@ -144,6 +148,13 @@ test-suite language-tests
144148
, tasty-hspec
145149
, tasty-quickcheck >=0.10
146150
, vec
151+
, vector
152+
, vector-sized
153+
, crypton
154+
, bytestring
155+
, memory
156+
, unordered-containers
157+
147158

148159
build-tool-depends: tasty-discover:tasty-discover >=4.2
149160

@@ -158,9 +169,11 @@ executable factors
158169
, arithmetic-circuits
159170
, arithmetic-circuits:language
160171
, arithmetic-circuits:circom-compat
172+
, vector
161173

162174
hs-source-dirs: circom-compat/app
163175
default-language: GHC2021
176+
ghc-options: -freverse-errors -O2 -Wall -main-is Main
164177

165178
test-suite circuit-tests
166179
import: deps, extensions
@@ -224,14 +237,16 @@ benchmark circuit-benchmarks
224237
type: exitcode-stdio-1.0
225238
main-is: Main.hs
226239
other-modules:
227-
Circuit
240+
Circuit.Bench
228241
Paths_arithmetic_circuits
229242

230243
hs-source-dirs: bench
231244

232245
ghc-options: -freverse-errors -O2 -Wall -main-is Main
233246
build-depends:
234247
arithmetic-circuits
235-
, criterion >=1.6
248+
, arithmetic-circuits:language
249+
, criterion
250+
, vector
236251

237252
default-language: GHC2021

Diff for: bench/Circuit.hs

-25
This file was deleted.

Diff for: bench/Circuit/Bench.hs

+54
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
{-# LANGUAGE DataKinds #-}
2+
3+
module Circuit.Bench where
4+
5+
import Circuit
6+
import Circuit.Language
7+
import Criterion
8+
import Data.Field.Galois (Prime)
9+
import Data.IntMap qualified as IntMap
10+
import Data.Map qualified as Map
11+
import Data.Vector (generateM)
12+
import GHC.TypeNats (Natural, SNat, withKnownNat, withSomeSNat)
13+
import Protolude
14+
15+
type Fr = Prime 21888242871839275222246405745257275088548364400416034343698204186575808495617
16+
17+
benchmarks :: Benchmark
18+
benchmarks =
19+
bgroup
20+
"largeMult"
21+
[ bench "1_000" $ whnf largeMult 1000,
22+
bench "10_000" $ whnf largeMult 10000,
23+
bench "100_000" $ whnf largeMult 100_000,
24+
bench "1_000_000" $ whnf largeMult 1_000_000
25+
]
26+
27+
largeMult :: Natural -> Fr
28+
largeMult n =
29+
withSomeSNat n $ \(sn :: SNat n) ->
30+
withKnownNat sn $
31+
let BuilderState {bsVars, bsCircuit} = snd $ runCircuitBuilder (program (Proxy @n))
32+
inputs =
33+
assignInputs bsVars $
34+
Map.fromList $
35+
map (\i -> (nameInput i, fromIntegral i + 1)) [0 .. n - 1]
36+
w = altSolve bsCircuit inputs
37+
in fromMaybe (panic "output not found") $ lookupVar bsVars "out" w
38+
39+
nameInput :: (Integral a) => a -> Text
40+
nameInput i = "x" <> show (toInteger i)
41+
42+
program :: forall n. (KnownNat n) => Proxy n -> ExprM Fr (Var Wire Fr Fr)
43+
program p = do
44+
xs <- generateM (fromIntegral $ natVal p) $ \i ->
45+
var_ <$> fieldInput Public (nameInput i)
46+
fieldOutput "out" $ product xs
47+
48+
altSolve :: ArithCircuit Fr -> IntMap Fr -> IntMap Fr
49+
altSolve p inputs =
50+
evalArithCircuit
51+
(\w m -> IntMap.lookup (wireName w) m)
52+
(\w m -> IntMap.insert (wireName w) m)
53+
p
54+
inputs

Diff for: bench/Main.hs

+2-2
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@
44

55
module Main where
66

7-
import Circuit qualified
7+
import Circuit.Bench qualified
88
import Criterion.Main
99
import Protolude
1010

1111
main :: IO ()
1212
main =
1313
defaultMain
14-
[ bgroup "Circuit to QAP translation" Circuit.benchmarks
14+
[ Circuit.Bench.benchmarks
1515
]

Diff for: cabal.project

+2-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ benchmarks: True
33

44
packages: .
55

6+
67
source-repository-package
78
type: git
89
location: https://github.com/chessai/semirings.git
@@ -12,7 +13,7 @@ source-repository-package
1213
source-repository-package
1314
type: git
1415
location: https://github.com/l-adic/galois-fields.git
15-
tag: fc82039e811ba68c10527cf871796b7ac8514926
16+
tag: b0867ffdebda5043c80315a51b15e82ed25acba6
1617
--sha256: j/zGFd2aeowzJfgCCBmJYmG8mDsfF0irqj/cPOw9ulE=
1718

1819
source-repository-package

Diff for: circom-compat/app/Main.hs

+9-9
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,17 @@ import Circuit
77
import Circuit.Language
88
import Data.Binary (encodeFile)
99
import Data.Field.Galois (Prime)
10-
import Data.Map qualified as Map
11-
import Data.Set qualified as Set
10+
import Data.IntMap qualified as IntMap
11+
import Data.IntSet qualified as IntSet
1212
import Protolude
1313
import R1CS (Inputs (..), calculateWitness, isValidWitness)
1414
import R1CS.Circom (r1csToCircomR1CS, witnessToCircomWitness)
1515

1616
main :: IO ()
1717
main = do
1818
let BuilderState {..} = snd $ runCircuitBuilder program
19-
publicInputs = Map.fromList $ zip (Set.toAscList $ cvPublicInputs bsVars) [6]
20-
privateInputs = Map.fromList $ zip (Set.toAscList $ cvPrivateInputs bsVars) [2, 3]
19+
publicInputs = IntMap.fromList $ zip (IntSet.toAscList $ cvPublicInputs bsVars) [6]
20+
privateInputs = IntMap.fromList $ zip (IntSet.toAscList $ cvPrivateInputs bsVars) [2, 3]
2121
inputs = publicInputs <> privateInputs
2222
(r1cs, wtns) = calculateWitness bsVars bsCircuit (Inputs inputs)
2323
unless (isValidWitness wtns r1cs) $ panic "Invalid witness"
@@ -26,9 +26,9 @@ main = do
2626

2727
type Fr = Prime 21888242871839275222246405745257275088548364400416034343698204186575808495617
2828

29-
program :: ExprM Fr Wire
29+
program :: ExprM Fr (Var Wire Fr Bool)
3030
program = do
31-
n <- deref <$> fieldInput Public "n"
32-
a <- deref <$> fieldInput Private "a"
33-
b <- deref <$> fieldInput Private "b"
34-
retBool "out" $ eq n (a * b)
31+
n <- var_ <$> fieldInput Public "n"
32+
a <- var_ <$> fieldInput Private "a"
33+
b <- var_ <$> fieldInput Private "b"
34+
boolOutput "out" $ eq_ n (a * b)

Diff for: circom-compat/src/Circuit/Solver/Circom.hs

+9-8
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,9 @@ where
2121
import Circuit
2222
import Data.Field.Galois (GaloisField, PrimeField (fromP), char)
2323
import Data.IORef (IORef, readIORef, writeIORef)
24+
import Data.IntMap qualified as IntMap
25+
import Data.IntSet qualified as IntSet
2426
import Data.Map qualified as Map
25-
import Data.Set qualified as Set
2627
import Data.Vector qualified as V
2728
import Data.Vector.Mutable (IOVector)
2829
import Data.Vector.Mutable qualified as MV
@@ -52,8 +53,8 @@ mkProgramEnv vars circ =
5253
{ peFieldSize = FieldSize 32,
5354
peRawPrime = toInteger $ char (1 :: f),
5455
peVersion = 2,
55-
peInputsSize = Set.size $ cvPrivateInputs vars <> cvPublicInputs vars,
56-
peWitnessSize = Set.size $ Set.insert oneVar $ cvVars vars,
56+
peInputsSize = IntSet.size $ cvPrivateInputs vars <> cvPublicInputs vars,
57+
peWitnessSize = IntSet.size $ IntSet.insert oneVar $ cvVars vars,
5758
peCircuit = circ,
5859
peCircuitVars = relabel hashText vars
5960
}
@@ -125,16 +126,16 @@ _setInputSignal env@(ProgramEnv {peCircuit, peInputsSize, peCircuitVars}) stRef
125126
st <- readIORef stRef
126127
let Inputs inputs = psInputs st
127128
let h = mkFNV msb lsb
128-
i = fromMaybe (panic $ "Hash not found: " <> show h) $ Map.lookup h (cvInputsLabels peCircuitVars)
129+
i = fromMaybe (panic $ "Hash not found: " <> show h) $ Map.lookup h (labelToVar $ cvInputsLabels peCircuitVars)
129130
newInput <- fromInteger <$> readBuffer env stRef
130-
let newInputs = Map.insert i newInput inputs
131+
let newInputs = IntMap.insert i newInput inputs
131132
writeIORef stRef $
132-
if Map.size newInputs == peInputsSize
133+
if IntMap.size newInputs == peInputsSize
133134
then
134135
let wtns = solve peCircuitVars peCircuit newInputs
135136
in st
136137
{ psInputs = Inputs newInputs,
137-
psWitness = Witness $ Map.insert oneVar 1 wtns
138+
psWitness = Witness $ IntMap.insert oneVar 1 wtns
138139
}
139140
else st {psInputs = Inputs newInputs}
140141

@@ -149,7 +150,7 @@ _getWitness ::
149150
IO ()
150151
_getWitness env stRef i = do
151152
ProgramState {psWitness = Witness wtns} <- readIORef stRef
152-
let wtn = maybe (panic $ "missing witness " <> show i) fromP $ Map.lookup i wtns
153+
let wtn = maybe (panic $ "missing witness " <> show i) fromP $ IntMap.lookup i wtns
153154
in writeBuffer env stRef wtn
154155

155156
--------------------------------------------------------------------------------

Diff for: circom-compat/src/R1CS/Circom.hs

+8-6
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ import Data.Binary.Get (getInt32le, getInt64le, getWord32le, getWord64le, lookAh
2121
import Data.Binary.Put (putInt32le, putLazyByteString, putWord32le, putWord64le, runPut)
2222
import Data.ByteString.Lazy qualified as LBS
2323
import Data.Field.Galois (GaloisField (char), PrimeField, fromP)
24-
import Data.Map qualified as Map
24+
import Data.IntMap qualified as IntMap
2525
import Data.Vector (Vector)
2626
import Data.Vector qualified as V
2727
import Protolude
@@ -261,14 +261,14 @@ getPoly fieldSize = do
261261
LinearCombination factors <- getLinearCombination fieldSize
262262
pure $
263263
LinearPoly $
264-
foldl (\acc (Factor {wireId, value}) -> Map.insert (fromIntegral wireId) value acc) mempty factors
264+
foldl (\acc (Factor {wireId, value}) -> IntMap.insert (fromIntegral wireId) value acc) mempty factors
265265

266266
putPoly :: (PrimeField k) => FieldSize -> LinearPoly k -> Put
267267
putPoly fieldSize (LinearPoly p) =
268268
putLinearCombination fieldSize $
269269
LinearCombination
270270
[ Factor {wireId = fromIntegral var, value}
271-
| (var, value) <- Map.toAscList p,
271+
| (var, value) <- IntMap.toAscList p,
272272
value /= 0
273273
]
274274

@@ -325,14 +325,14 @@ witnessToCircomWitness (Witness m) =
325325
WitnessHeader
326326
{ whFieldSize = FieldSize 32,
327327
whPrime = fromIntegral $ char (1 :: f),
328-
whWitnessSize = fromIntegral $ Map.size m
328+
whWitnessSize = fromIntegral $ IntMap.size m
329329
},
330-
wtnsValues = snd <$> Map.toAscList m
330+
wtnsValues = snd <$> IntMap.toAscList m
331331
}
332332

333333
witnessFromCircomWitness :: CircomWitness f -> Witness f
334334
witnessFromCircomWitness (CircomWitness {wtnsValues}) =
335-
Witness $ Map.fromList $ zip [0 ..] wtnsValues
335+
Witness $ IntMap.fromList $ zip [0 ..] wtnsValues
336336

337337
instance (PrimeField k) => Binary (CircomWitness k) where
338338
get = do
@@ -425,6 +425,7 @@ putWitnessValues fieldSize values = do
425425
integerFromLittleEndian :: Vector Word32 -> Integer
426426
integerFromLittleEndian bytes =
427427
foldl' (\acc (i, byte) -> acc .|. (fromIntegral byte `shiftL` (i * 32))) 0 (V.zip (V.fromList [0 ..]) bytes)
428+
{-# INLINE integerFromLittleEndian #-}
428429

429430
integerToLittleEndian :: FieldSize -> Integer -> Vector Word32
430431
integerToLittleEndian fieldSize n =
@@ -434,3 +435,4 @@ integerToLittleEndian fieldSize n =
434435
where
435436
go 0 = mempty
436437
go x = fromIntegral (x .&. 0xffffffff) `V.cons` go (x `shiftR` 32)
438+
{-# INLINE integerToLittleEndian #-}

0 commit comments

Comments
 (0)