Skip to content

Commit 4272695

Browse files
committed
tighten the belt
1 parent a47ed81 commit 4272695

File tree

19 files changed

+371
-308
lines changed

19 files changed

+371
-308
lines changed

Diff for: arithmetic-circuits.cabal

+8-4
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ common extensions
3636
RecordWildCards
3737

3838
common warnings
39-
ghc-options: -Wall -Wredundant-constraints
39+
ghc-options: -Wall -Wredundant-constraints -Werror
4040

4141

4242
common deps
@@ -87,7 +87,6 @@ library language
8787
arithmetic-circuits
8888
, array
8989
, hashable
90-
, intern
9190
, microlens
9291
, mtl
9392
, unordered-containers
@@ -154,6 +153,7 @@ test-suite language-tests
154153
, crypton
155154
, bytestring
156155
, memory
156+
, unordered-containers
157157

158158

159159
build-tool-depends: tasty-discover:tasty-discover >=4.2
@@ -169,9 +169,11 @@ executable factors
169169
, arithmetic-circuits
170170
, arithmetic-circuits:language
171171
, arithmetic-circuits:circom-compat
172+
, vector
172173

173174
hs-source-dirs: circom-compat/app
174175
default-language: GHC2021
176+
ghc-options: -freverse-errors -O2 -Wall -main-is Main
175177

176178
test-suite circuit-tests
177179
import: deps, extensions
@@ -235,14 +237,16 @@ benchmark circuit-benchmarks
235237
type: exitcode-stdio-1.0
236238
main-is: Main.hs
237239
other-modules:
238-
Circuit
240+
Circuit.Bench
239241
Paths_arithmetic_circuits
240242

241243
hs-source-dirs: bench
242244

243245
ghc-options: -freverse-errors -O2 -Wall -main-is Main
244246
build-depends:
245247
arithmetic-circuits
246-
, criterion >=1.6
248+
, arithmetic-circuits:language
249+
, criterion
250+
, vector
247251

248252
default-language: GHC2021

Diff for: bench/Circuit.hs

-25
This file was deleted.

Diff for: bench/Circuit/Bench.hs

+54
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
{-# LANGUAGE DataKinds #-}
2+
3+
module Circuit.Bench where
4+
5+
import Circuit
6+
import Circuit.Language
7+
import Criterion
8+
import Data.Field.Galois (Prime)
9+
import Data.IntMap qualified as IntMap
10+
import Data.Map qualified as Map
11+
import Data.Vector (generateM)
12+
import GHC.TypeNats (Natural, SNat, withKnownNat, withSomeSNat)
13+
import Protolude
14+
15+
type Fr = Prime 21888242871839275222246405745257275088548364400416034343698204186575808495617
16+
17+
benchmarks :: Benchmark
18+
benchmarks =
19+
bgroup
20+
"largeMult"
21+
[ bench "1_000" $ whnf largeMult 1000,
22+
bench "10_000" $ whnf largeMult 10000,
23+
bench "100_000" $ whnf largeMult 100_000,
24+
bench "1_000_000" $ whnf largeMult 1_000_000
25+
]
26+
27+
largeMult :: Natural -> Fr
28+
largeMult n =
29+
withSomeSNat n $ \(sn :: SNat n) ->
30+
withKnownNat sn $
31+
let BuilderState {bsVars, bsCircuit} = snd $ runCircuitBuilder (program (Proxy @n))
32+
inputs =
33+
assignInputs bsVars $
34+
Map.fromList $
35+
map (\i -> (nameInput i, fromIntegral i + 1)) [0 .. n - 1]
36+
w = altSolve bsCircuit inputs
37+
in fromMaybe (panic "output not found") $ lookupVar bsVars "out" w
38+
39+
nameInput :: (Integral a) => a -> Text
40+
nameInput i = "x" <> show (toInteger i)
41+
42+
program :: forall n. (KnownNat n) => Proxy n -> ExprM Fr (Var Wire Fr Fr)
43+
program p = do
44+
xs <- generateM (fromIntegral $ natVal p) $ \i ->
45+
var_ <$> fieldInput Public (nameInput i)
46+
fieldOutput "out" $ product xs
47+
48+
altSolve :: ArithCircuit Fr -> IntMap Fr -> IntMap Fr
49+
altSolve p inputs =
50+
evalArithCircuit
51+
(\w m -> IntMap.lookup (wireName w) m)
52+
(\w m -> IntMap.insert (wireName w) m)
53+
p
54+
inputs

Diff for: bench/Main.hs

+2-2
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@
44

55
module Main where
66

7-
import Circuit qualified
7+
import Circuit.Bench qualified
88
import Criterion.Main
99
import Protolude
1010

1111
main :: IO ()
1212
main =
1313
defaultMain
14-
[ bgroup "Circuit to QAP translation" Circuit.benchmarks
14+
[ Circuit.Bench.benchmarks
1515
]

Diff for: cabal.project

+2-6
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
tests: True
2+
benchmarks: True
23

34
packages: .
45

6+
57
source-repository-package
68
type: git
79
location: https://github.com/chessai/semirings.git
@@ -19,9 +21,3 @@ source-repository-package
1921
location: https://github.com/martyall/propagators.git
2022
tag: 6c1778171ee7c37ca29cf1ec4efaf79c5ec8af62
2123
--sha256: 6MUrOeUvp/iRFCTdvAyAGT2ymm5yKRqfN4GCkZ4eDyo=
22-
23-
source-repository-package
24-
type: git
25-
location: https://github.com/l-adic/intern.git
26-
tag: 5d0df17db13978d16533530370c2f7fd51e37cc8
27-
--sha256: 6MUrOeUvp/iRFCTdvAyAGT2ymm5yKRqfN4GCkZ4eDyo=

Diff for: circom-compat/app/Main.hs

+7-7
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,17 @@ import Circuit
77
import Circuit.Language
88
import Data.Binary (encodeFile)
99
import Data.Field.Galois (Prime)
10-
import Data.Map qualified as Map
11-
import Data.Set qualified as Set
10+
import Data.IntMap qualified as IntMap
11+
import Data.IntSet qualified as IntSet
1212
import Protolude
1313
import R1CS (Inputs (..), calculateWitness, isValidWitness)
1414
import R1CS.Circom (r1csToCircomR1CS, witnessToCircomWitness)
1515

1616
main :: IO ()
1717
main = do
1818
let BuilderState {..} = snd $ runCircuitBuilder program
19-
publicInputs = Map.fromList $ zip (Set.toAscList $ cvPublicInputs bsVars) [6]
20-
privateInputs = Map.fromList $ zip (Set.toAscList $ cvPrivateInputs bsVars) [2, 3]
19+
publicInputs = IntMap.fromList $ zip (IntSet.toAscList $ cvPublicInputs bsVars) [6]
20+
privateInputs = IntMap.fromList $ zip (IntSet.toAscList $ cvPrivateInputs bsVars) [2, 3]
2121
inputs = publicInputs <> privateInputs
2222
(r1cs, wtns) = calculateWitness bsVars bsCircuit (Inputs inputs)
2323
unless (isValidWitness wtns r1cs) $ panic "Invalid witness"
@@ -28,7 +28,7 @@ type Fr = Prime 2188824287183927522224640574525727508854836440041603434369820418
2828

2929
program :: ExprM Fr (Var Wire Fr Bool)
3030
program = do
31-
n <- deref <$> fieldInput Public "n"
32-
a <- deref <$> fieldInput Private "a"
33-
b <- deref <$> fieldInput Private "b"
31+
n <- var_ <$> fieldInput Public "n"
32+
a <- var_ <$> fieldInput Private "a"
33+
b <- var_ <$> fieldInput Private "b"
3434
boolOutput "out" $ eq_ n (a * b)

Diff for: circom-compat/src/Circuit/Solver/Circom.hs

+9-8
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,9 @@ where
2121
import Circuit
2222
import Data.Field.Galois (GaloisField, PrimeField (fromP), char)
2323
import Data.IORef (IORef, readIORef, writeIORef)
24+
import Data.IntMap qualified as IntMap
25+
import Data.IntSet qualified as IntSet
2426
import Data.Map qualified as Map
25-
import Data.Set qualified as Set
2627
import Data.Vector qualified as V
2728
import Data.Vector.Mutable (IOVector)
2829
import Data.Vector.Mutable qualified as MV
@@ -52,8 +53,8 @@ mkProgramEnv vars circ =
5253
{ peFieldSize = FieldSize 32,
5354
peRawPrime = toInteger $ char (1 :: f),
5455
peVersion = 2,
55-
peInputsSize = Set.size $ cvPrivateInputs vars <> cvPublicInputs vars,
56-
peWitnessSize = Set.size $ Set.insert oneVar $ cvVars vars,
56+
peInputsSize = IntSet.size $ cvPrivateInputs vars <> cvPublicInputs vars,
57+
peWitnessSize = IntSet.size $ IntSet.insert oneVar $ cvVars vars,
5758
peCircuit = circ,
5859
peCircuitVars = relabel hashText vars
5960
}
@@ -125,16 +126,16 @@ _setInputSignal env@(ProgramEnv {peCircuit, peInputsSize, peCircuitVars}) stRef
125126
st <- readIORef stRef
126127
let Inputs inputs = psInputs st
127128
let h = mkFNV msb lsb
128-
i = fromMaybe (panic $ "Hash not found: " <> show h) $ Map.lookup h (cvInputsLabels peCircuitVars)
129+
i = fromMaybe (panic $ "Hash not found: " <> show h) $ Map.lookup h (labelToVar $ cvInputsLabels peCircuitVars)
129130
newInput <- fromInteger <$> readBuffer env stRef
130-
let newInputs = Map.insert i newInput inputs
131+
let newInputs = IntMap.insert i newInput inputs
131132
writeIORef stRef $
132-
if Map.size newInputs == peInputsSize
133+
if IntMap.size newInputs == peInputsSize
133134
then
134135
let wtns = solve peCircuitVars peCircuit newInputs
135136
in st
136137
{ psInputs = Inputs newInputs,
137-
psWitness = Witness $ Map.insert oneVar 1 wtns
138+
psWitness = Witness $ IntMap.insert oneVar 1 wtns
138139
}
139140
else st {psInputs = Inputs newInputs}
140141

@@ -149,7 +150,7 @@ _getWitness ::
149150
IO ()
150151
_getWitness env stRef i = do
151152
ProgramState {psWitness = Witness wtns} <- readIORef stRef
152-
let wtn = maybe (panic $ "missing witness " <> show i) fromP $ Map.lookup i wtns
153+
let wtn = maybe (panic $ "missing witness " <> show i) fromP $ IntMap.lookup i wtns
153154
in writeBuffer env stRef wtn
154155

155156
--------------------------------------------------------------------------------

Diff for: circom-compat/src/R1CS/Circom.hs

+8-6
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ import Data.Binary.Get (getInt32le, getInt64le, getWord32le, getWord64le, lookAh
2121
import Data.Binary.Put (putInt32le, putLazyByteString, putWord32le, putWord64le, runPut)
2222
import Data.ByteString.Lazy qualified as LBS
2323
import Data.Field.Galois (GaloisField (char), PrimeField, fromP)
24-
import Data.Map qualified as Map
24+
import Data.IntMap qualified as IntMap
2525
import Data.Vector (Vector)
2626
import Data.Vector qualified as V
2727
import Protolude
@@ -261,14 +261,14 @@ getPoly fieldSize = do
261261
LinearCombination factors <- getLinearCombination fieldSize
262262
pure $
263263
LinearPoly $
264-
foldl (\acc (Factor {wireId, value}) -> Map.insert (fromIntegral wireId) value acc) mempty factors
264+
foldl (\acc (Factor {wireId, value}) -> IntMap.insert (fromIntegral wireId) value acc) mempty factors
265265

266266
putPoly :: (PrimeField k) => FieldSize -> LinearPoly k -> Put
267267
putPoly fieldSize (LinearPoly p) =
268268
putLinearCombination fieldSize $
269269
LinearCombination
270270
[ Factor {wireId = fromIntegral var, value}
271-
| (var, value) <- Map.toAscList p,
271+
| (var, value) <- IntMap.toAscList p,
272272
value /= 0
273273
]
274274

@@ -325,14 +325,14 @@ witnessToCircomWitness (Witness m) =
325325
WitnessHeader
326326
{ whFieldSize = FieldSize 32,
327327
whPrime = fromIntegral $ char (1 :: f),
328-
whWitnessSize = fromIntegral $ Map.size m
328+
whWitnessSize = fromIntegral $ IntMap.size m
329329
},
330-
wtnsValues = snd <$> Map.toAscList m
330+
wtnsValues = snd <$> IntMap.toAscList m
331331
}
332332

333333
witnessFromCircomWitness :: CircomWitness f -> Witness f
334334
witnessFromCircomWitness (CircomWitness {wtnsValues}) =
335-
Witness $ Map.fromList $ zip [0 ..] wtnsValues
335+
Witness $ IntMap.fromList $ zip [0 ..] wtnsValues
336336

337337
instance (PrimeField k) => Binary (CircomWitness k) where
338338
get = do
@@ -425,6 +425,7 @@ putWitnessValues fieldSize values = do
425425
integerFromLittleEndian :: Vector Word32 -> Integer
426426
integerFromLittleEndian bytes =
427427
foldl' (\acc (i, byte) -> acc .|. (fromIntegral byte `shiftL` (i * 32))) 0 (V.zip (V.fromList [0 ..]) bytes)
428+
{-# INLINE integerFromLittleEndian #-}
428429

429430
integerToLittleEndian :: FieldSize -> Integer -> Vector Word32
430431
integerToLittleEndian fieldSize n =
@@ -434,3 +435,4 @@ integerToLittleEndian fieldSize n =
434435
where
435436
go 0 = mempty
436437
go x = fromIntegral (x .&. 0xffffffff) `V.cons` go (x `shiftR` 32)
438+
{-# INLINE integerToLittleEndian #-}

Diff for: circuit/src/Circuit/Affine.hs

+4-8
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ data AffineCircuit f i
2828
| ScalarMul f (AffineCircuit f i)
2929
| ConstGate f
3030
| Var i
31-
| Nil
3231
deriving (Read, Eq, Ord, Show, Generic, NFData)
3332

3433
instance (FromJSON i, FromJSON f) => FromJSON (AffineCircuit f i)
@@ -47,16 +46,13 @@ instance Bifunctor AffineCircuit where
4746
ScalarMul s x -> ScalarMul (f s) (bimap f g x)
4847
ConstGate c -> ConstGate (f c)
4948
Var i -> Var (g i)
50-
Nil -> Nil
5149

5250
instance (Pretty i, Pretty f) => Pretty (AffineCircuit f i) where
5351
pretty = prettyPrec 0
5452
where
5553
prettyPrec :: Int -> AffineCircuit f i -> Doc
5654
prettyPrec p e =
5755
case e of
58-
Nil ->
59-
text "nil"
6056
Var v ->
6157
pretty v
6258
ConstGate f -> pretty f
@@ -71,11 +67,13 @@ instance (Pretty i, Pretty f) => Pretty (AffineCircuit f i) where
7167
parensPrec :: Int -> Int -> Doc -> Doc
7268
parensPrec opPrec p = if p > opPrec then parens else identity
7369

70+
{-# SCC evalAffineCircuit #-}
71+
7472
-- | Evaluate the arithmetic circuit without mul-gates on the given
7573
-- input. Variable map is assumed to have all the variables referred
7674
-- to in the circuit. Failed lookups are currently treated as 0.
7775
evalAffineCircuit ::
78-
(Num f) =>
76+
(Num f, Show i) =>
7977
-- | lookup function for variable mapping
8078
(i -> vars -> Maybe f) ->
8179
-- | variables
@@ -84,9 +82,8 @@ evalAffineCircuit ::
8482
AffineCircuit f i ->
8583
f
8684
evalAffineCircuit lookupVar vars = \case
87-
Nil -> 0
8885
ConstGate f -> f
89-
Var i -> fromMaybe 0 $ lookupVar i vars
86+
Var i -> fromMaybe (panic $ "missing variable assignment: " <> show i) $ lookupVar i vars
9087
Add l r -> evalAffineCircuit lookupVar vars l + evalAffineCircuit lookupVar vars r
9188
ScalarMul scalar expr -> evalAffineCircuit lookupVar vars expr * scalar
9289

@@ -99,7 +96,6 @@ affineCircuitToAffineMap ::
9996
-- | constant part and non-constant part
10097
(f, Map i f)
10198
affineCircuitToAffineMap = \case
102-
Nil -> (0, mempty)
10399
Var i -> (0, Map.singleton i 1)
104100
Add l r -> (constLeft + constRight, Map.unionWith (+) vecLeft vecRight)
105101
where

0 commit comments

Comments
 (0)