Skip to content

Commit 6324a1c

Browse files
committed
further optimizations -- specializing and inlining
1 parent d243f4c commit 6324a1c

File tree

12 files changed

+377
-353
lines changed

12 files changed

+377
-353
lines changed

Diff for: arithmetic-circuits.cabal

+1-1
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ benchmark circuit-benchmarks
242242

243243
hs-source-dirs: bench
244244

245-
ghc-options: -freverse-errors -O2 -Wall -main-is Main
245+
ghc-options: -freverse-errors -O4 -Wall -main-is Main
246246
build-depends:
247247
arithmetic-circuits
248248
, arithmetic-circuits:language

Diff for: bench/Circuit/Bench.hs

+10-13
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ import Data.Field.Galois (Prime)
99
import Data.IntMap qualified as IntMap
1010
import Data.Map qualified as Map
1111
import Data.Vector (generateM)
12-
import GHC.TypeNats (Natural, SNat, withKnownNat, withSomeSNat)
1312
import Protolude
1413

1514
type Fr = Prime 21888242871839275222246405745257275088548364400416034343698204186575808495617
@@ -24,24 +23,22 @@ benchmarks =
2423
bench "1_000_000" $ whnf largeMult 1_000_000
2524
]
2625

27-
largeMult :: Natural -> Fr
26+
largeMult :: Int -> Fr
2827
largeMult n =
29-
withSomeSNat n $ \(sn :: SNat n) ->
30-
withKnownNat sn $
31-
let BuilderState {bsVars, bsCircuit} = snd $ runCircuitBuilder (program (Proxy @n))
32-
inputs =
33-
assignInputs bsVars $
34-
Map.fromList $
35-
map (\i -> (nameInput i, fromIntegral i + 1)) [0 .. n - 1]
36-
w = altSolve bsCircuit inputs
37-
in fromMaybe (panic "output not found") $ lookupVar bsVars "out" w
28+
let BuilderState {bsVars, bsCircuit} = snd $ runCircuitBuilder (program n)
29+
inputs =
30+
assignInputs bsVars $
31+
Map.fromList $
32+
map (\i -> (nameInput i, fromIntegral i + 1)) [0 .. n - 1]
33+
w = altSolve bsCircuit inputs
34+
in fromMaybe (panic "output not found") $ lookupVar bsVars "out" w
3835

3936
nameInput :: (Integral a) => a -> Text
4037
nameInput i = "x" <> show (toInteger i)
4138

42-
program :: forall n. (KnownNat n) => Proxy n -> ExprM Fr (Var Wire Fr Fr)
39+
program :: Int -> ExprM Fr (Var Wire Fr 'TField)
4340
program p = do
44-
xs <- generateM (fromIntegral $ natVal p) $ \i ->
41+
xs <- generateM p $ \i ->
4542
var_ <$> fieldInput Public (nameInput i)
4643
fieldOutput "out" $ product xs
4744

Diff for: circom-compat/app/Main.hs

+1-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ main = do
2626

2727
type Fr = Prime 21888242871839275222246405745257275088548364400416034343698204186575808495617
2828

29-
program :: ExprM Fr (Var Wire Fr Bool)
29+
program :: ExprM Fr (Var Wire Fr 'TBool)
3030
program = do
3131
n <- var_ <$> fieldInput Public "n"
3232
a <- var_ <$> fieldInput Private "a"

Diff for: circom-compat/src/R1CS/Circom.hs

+1-1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ module R1CS.Circom
1717
)
1818
where
1919

20+
import Circuit (CircuitVars (..))
2021
import Data.Binary (Binary (..), Get, Put)
2122
import Data.Binary.Get (getInt32le, getInt64le, getWord32le, getWord64le, lookAhead, skip)
2223
import Data.Binary.Put (putInt32le, putLazyByteString, putWord32le, putWord64le, runPut)
@@ -27,7 +28,6 @@ import Data.IntSet qualified as IntSet
2728
import Data.Vector (Vector)
2829
import Data.Vector qualified as V
2930
import Protolude
30-
import Circuit (CircuitVars(..))
3131
import R1CS (LinearPoly (..), R1C (..), R1CS (..), Witness (..))
3232
import Prelude (fail)
3333

Diff for: circuit/src/Circuit/Arithmetic.hs

+13-2
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ module Circuit.Arithmetic
2222
InputBidings (..),
2323
insertInputBinding,
2424
Reindexable (..),
25+
restrictVars,
2526
)
2627
where
2728

@@ -316,6 +317,7 @@ instance (Ord label) => Monoid (CircuitVars label) where
316317
cvOutputs = mempty,
317318
cvInputsLabels = mempty
318319
}
320+
319321
instance Reindexable (CircuitVars label) where
320322
reindex f CircuitVars {..} =
321323
CircuitVars
@@ -355,6 +357,16 @@ collectCircuitVars (ArithCircuit gates) =
355357
cvInputsLabels = inputBindingsFromList ls
356358
}
357359

360+
restrictVars :: CircuitVars label -> IntSet -> CircuitVars label
361+
restrictVars CircuitVars {..} vars =
362+
CircuitVars
363+
{ cvVars = IntSet.intersection cvVars vars,
364+
cvPrivateInputs = IntSet.intersection cvPrivateInputs vars,
365+
cvPublicInputs = IntSet.intersection cvPublicInputs vars,
366+
cvOutputs = IntSet.intersection cvOutputs vars,
367+
cvInputsLabels = cvInputsLabels
368+
}
369+
358370
assignInputs :: (Ord label) => CircuitVars label -> Map label f -> IntMap f
359371
assignInputs CircuitVars {..} inputs =
360372
IntMap.mapMaybe (\label -> Map.lookup label inputs) (varToLabel cvInputsLabels)
@@ -394,11 +406,10 @@ instance Reindexable (InputBidings label) where
394406
{ labelToVar = Map.mapMaybe (flip IntMap.lookup f) labelToVar,
395407
varToLabel = IntMap.compose varToLabel (reverseMap f)
396408
}
397-
where
409+
where
398410
reverseMap :: IntMap Int -> IntMap Int
399411
reverseMap = IntMap.foldlWithKey' (\acc k v -> IntMap.insert v k acc) mempty
400412

401-
402413
instance (Ord label) => Semigroup (InputBidings label) where
403414
a <> b =
404415
InputBidings

Diff for: language/src/Circuit/Language/Compile.hs

+20-24
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
{-# LANGUAGE DataKinds #-}
12
{-# LANGUAGE PatternSynonyms #-}
23

34
module Circuit.Language.Compile
@@ -174,17 +175,20 @@ mulToImm l r = do
174175
o <- imm
175176
emit $ Mul (addVar l) (addVar r) o
176177
pure o
178+
{-# INLINE mulToImm #-}
177179

178180
-- | Add a Mul and its output to the ArithCircuit
179181
emit :: (MonadState (BuilderState f) m) => Gate f Wire -> m ()
180182
emit c = modify $ \s@(BuilderState {bsCircuit = ArithCircuit cs}) ->
181183
s {bsCircuit = ArithCircuit (c : cs)}
184+
{-# INLINE emit #-}
182185

183186
-- | Turn a wire into an affine circuit, or leave it be
184187
addVar :: SignalSource f -> AffineCircuit f Wire
185188
addVar s = case s of
186189
WireSource w -> Var w
187190
AffineSource c -> c
191+
{-# INLINE addVar #-}
188192

189193
-- | Turn an affine circuit into a wire, or leave it be
190194
addWire :: (MonadState (BuilderState f) m, Num f) => SignalSource f -> m Wire
@@ -194,29 +198,26 @@ addWire x = case x of
194198
mulOut <- imm
195199
emit $ Mul (ConstGate 1) c mulOut
196200
pure mulOut
201+
{-# INLINE addWire #-}
197202

198203
--------------------------------------------------------------------------------
199204

200205
compileWithWire ::
201206
(Hashable f) =>
202207
(GaloisField f) =>
203-
(MonadState (BuilderState f) m) =>
204-
(MonadError (CircuitBuilderError f) m) =>
205-
Var Wire f f ->
206-
Expr Wire f f ->
207-
m (Var Wire f f)
208+
Var Wire f 'TField ->
209+
Expr Wire f 'TField ->
210+
ExprM f (Var Wire f 'TField)
208211
compileWithWire freshWire e = do
209212
res <- compileWithWires (V.singleton freshWire) e
210213
pure . V.head $ res
211214

212215
compileWithWires ::
213216
(Hashable f) =>
214217
(GaloisField f) =>
215-
(MonadState (BuilderState f) m) =>
216-
(MonadError (CircuitBuilderError f) m) =>
217-
V.Vector (Var Wire f f) ->
218+
V.Vector (Var Wire f 'TField) ->
218219
Expr Wire f ty ->
219-
m (V.Vector (Var Wire f f))
220+
ExprM f (V.Vector (Var Wire f 'TField))
220221
compileWithWires ws expr = do
221222
compileOut <- compile expr
222223
for (V.zip compileOut ws) $ \(o, freshWire) -> do
@@ -233,25 +234,20 @@ compileWithWires ws expr = do
233234
{-# SCC compile #-}
234235
compile ::
235236
(Hashable f, GaloisField f) =>
236-
(MonadState (BuilderState f) m) =>
237-
(MonadError (CircuitBuilderError f) m) =>
238237
Expr Wire f ty ->
239-
m (V.Vector (SignalSource f))
238+
ExprM f (V.Vector (SignalSource f))
240239
compile e = do
241-
let g = reifyGraph e
242-
res <- traverse _compile g
243-
case res of
244-
(_ :|> x) -> pure x
240+
case reifyGraph e of
241+
(xs :|> x) ->
242+
traverse_ _compile xs >> _compile x
245243
_ -> panic "empty graph"
246244

247245
{-# SCC _compile #-}
248246
_compile ::
249-
forall f m.
247+
forall f.
250248
(Hashable f, GaloisField f) =>
251-
(MonadState (BuilderState f) m) =>
252-
(MonadError (CircuitBuilderError f) m) =>
253249
(Hash, Node Wire f) ->
254-
m (V.Vector (SignalSource f))
250+
ExprM f (V.Vector (SignalSource f))
255251
_compile (h, expr) = case expr of
256252
NVal f -> do
257253
let source = V.singleton $ AffineSource $ ConstGate f
@@ -398,8 +394,8 @@ exprToArithCircuit expr output = do
398394

399395
fieldToBool ::
400396
(Hashable f, GaloisField f) =>
401-
Expr Wire f f ->
402-
ExprM f (Expr Wire f Bool)
397+
Expr Wire f 'TField ->
398+
ExprM f (Expr Wire f 'TBool)
403399
fieldToBool e = do
404400
-- let eOut = unType e
405401
a <- compile e >>= assertSingleSource >>= addWire
@@ -411,8 +407,8 @@ _unBundle ::
411407
(KnownNat n) =>
412408
(GaloisField f) =>
413409
(Hashable f) =>
414-
Expr Wire f (SV.Vector n ty) ->
415-
ExprM f (SV.Vector n (Expr Wire f f))
410+
Expr Wire f (TVec n ty) ->
411+
ExprM f (SV.Vector n (Expr Wire f 'TField))
416412
_unBundle b = do
417413
bis <- compile b
418414
ws <- traverse addWire bis

0 commit comments

Comments
 (0)