Skip to content

Commit 1d0e124

Browse files
committed
clean up
1 parent 684854c commit 1d0e124

File tree

2 files changed

+82
-87
lines changed

2 files changed

+82
-87
lines changed

Diff for: src/Circuit/Expr.hs

+76-58
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
{-# LANGUAGE DataKinds #-}
22
{-# LANGUAGE TypeFamilies #-}
33
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
4+
45
{-# HLINT ignore "Use zipWithM" #-}
56

67
module Circuit.Expr
@@ -29,25 +30,25 @@ module Circuit.Expr
2930
evalExpr,
3031
rawWire,
3132
exprToArithCircuit,
33+
compileWithWire,
3234
)
3335
where
3436

3537
import Circuit.Affine
3638
import Circuit.Arithmetic
3739
import Data.Field.Galois (GaloisField, PrimeField (fromP))
3840
import Data.Finite (Finite)
41+
import Data.List.NonEmpty qualified as NE
3942
import Data.Map qualified as Map
43+
import Data.Semigroup qualified as NE
4044
import Data.Semiring (Ring (..), Semiring (..))
4145
import Data.Set qualified as Set
4246
import Data.Vector.Sized (Vector)
4347
import Data.Vector.Sized qualified as V
48+
import Lens.Micro (ix, (.~))
4449
import Protolude hiding (Semiring)
4550
import Text.PrettyPrint.Leijen.Text hiding ((<$>))
46-
import Data.List.NonEmpty qualified as NE
4751
import Prelude (foldl1)
48-
import Lens.Micro ( (.~), ix )
49-
import qualified Data.Semigroup as NE
50-
5152

5253
data UnOp f a where
5354
UNeg :: UnOp f f
@@ -99,28 +100,27 @@ rawWire (VarBool i) = i
99100

100101
type family NBits a :: Nat
101102

102-
103103
-- | This constring prevents us from building up nested vectors inside the expression type
104-
class GaloisField f => Ground f ty
104+
class (GaloisField f) => Ground f ty
105105

106-
instance GaloisField f => Ground f f
106+
instance (GaloisField f) => Ground f f
107107

108-
instance GaloisField f => Ground f Bool
108+
instance (GaloisField f) => Ground f Bool
109109

110110
-- | Expression data type of (arithmetic) expressions over a field @f@
111111
-- with variable names/indices coming from @i@.
112112
data Expr i f ty where
113113
EVal :: Val f ty -> Expr i f ty
114114
EVar :: Var i f ty -> Expr i f ty
115115
EUnOp :: UnOp f ty -> Expr i f ty -> Expr i f ty
116-
EBinOp :: BinOp f ty -> Expr i f ty -> Expr i f ty -> Expr i f ty
117-
EIf :: Expr i f Bool -> Expr i f ty -> Expr i f ty -> Expr i f ty
116+
EBinOp :: BinOp f ty -> Expr i f ty -> Expr i f ty -> Expr i f ty
117+
EIf :: Expr i f Bool -> Expr i f ty -> Expr i f ty -> Expr i f ty
118118
EEq :: Expr i f f -> Expr i f f -> Expr i f Bool
119119
ESplit :: (KnownNat (NBits f)) => Expr i f f -> Expr i f (Vector (NBits f) Bool)
120120
EJoin :: (KnownNat n) => Expr i f (Vector n Bool) -> Expr i f f
121121
EAtIndex :: (KnownNat n, Ground f ty) => Expr i f (Vector n ty) -> Finite n -> Expr i f ty
122-
EUpdateIndex :: (KnownNat n, Ground f ty) => Finite n -> (Expr i f ty) -> Expr i f (Vector n ty) -> Expr i f (Vector n ty)
123-
EBundle :: Ground f ty => Vector n (Expr i f ty) -> Expr i f (Vector n ty)
122+
EUpdateIndex :: (KnownNat n, Ground f ty) => Finite n -> (Expr i f ty) -> Expr i f (Vector n ty) -> Expr i f (Vector n ty)
123+
EBundle :: (Ground f ty) => Vector n (Expr i f ty) -> Expr i f (Vector n ty)
124124

125125
deriving instance (Show f) => Show (BinOp f a)
126126

@@ -215,7 +215,7 @@ evalExpr' expr = case expr of
215215
ValField f -> f
216216
EVar var -> do
217217
m <- get
218-
pure $ case var of
218+
pure $ case var of
219219
VarField i -> do
220220
case Map.lookup i m of
221221
Just v -> v
@@ -229,8 +229,8 @@ evalExpr' expr = case expr of
229229
EUnOp UNot e1 ->
230230
not <$> evalExpr' e1
231231
EBinOp op e1 e2 -> do
232-
e1' <- evalExpr' e1
233-
e2' <- evalExpr' e2
232+
e1' <- evalExpr' e1
233+
e2' <- evalExpr' e2
234234
pure $ apply e1' e2'
235235
where
236236
apply = case op of
@@ -257,7 +257,7 @@ evalExpr' expr = case expr of
257257
EJoin i -> do
258258
bits <- evalExpr' i
259259
pure $
260-
V.ifoldl (\acc _ix b -> acc + if b then fromInteger (2 ^ fromIntegral @_ @Integer _ix) else 0) 0 bits
260+
V.ifoldl (\acc _ix b -> acc + if b then fromInteger (2 ^ fromIntegral @_ @Integer _ix) else 0) 0 bits
261261
EAtIndex v i -> do
262262
_v <- evalExpr' v
263263
pure $ _v `V.index` i
@@ -266,7 +266,6 @@ evalExpr' expr = case expr of
266266
_b <- evalExpr' b
267267
pure $ _v & V.ix p .~ _b
268268

269-
270269
-- pure $ Vec.fromList $ map (testBit i) [0 .. Nat.toInt (Vec.length i) - 1]
271270

272271
-------------------------------------------------------------------------------
@@ -393,7 +392,21 @@ addWire x = case x of
393392
emit $ Mul (ConstGate 1) c mulOut
394393
pure mulOut
395394

396-
assertSingle :: NonEmpty a -> a
395+
compileWithWire :: (Num f) => ExprM f (Var Wire f ty) -> Expr Wire f ty -> ExprM f Wire
396+
compileWithWire freshWire expr = do
397+
-- the use of assertSingle here is justified because Var constraints what ty can be to either Bool or f
398+
compileOut <- assertSingle <$> compile expr
399+
case compileOut of
400+
Left wire -> do
401+
wire' <- rawWire <$> freshWire
402+
emit $ Mul (ConstGate 1) (Var wire') wire
403+
pure wire
404+
Right circ -> do
405+
wire <- rawWire <$> freshWire
406+
emit $ Mul (ConstGate 1) circ wire
407+
pure wire
408+
409+
assertSingle :: NonEmpty a -> a
397410
assertSingle xs = case xs of
398411
x NE.:| [] -> x
399412
_ -> panic "Expected single wire"
@@ -404,21 +417,23 @@ compile ::
404417
Expr Wire f ty ->
405418
ExprM f (NonEmpty (Either Wire (AffineCircuit f Wire)))
406419
compile expr = case expr of
407-
EVal v -> NE.singleton <$> case v of
408-
ValField f -> pure . Right $ ConstGate f
409-
ValBool b -> pure . Right $ ConstGate b
410-
EVar var -> NE.singleton <$> case var of
411-
VarField i -> pure . Left $ i
412-
VarBool i ->do
413-
squared <- mulToImm (Left i) (Left i)
414-
emit $ Mul (Var squared) (ConstGate 1) i
415-
pure . Left $ i
420+
EVal v ->
421+
NE.singleton <$> case v of
422+
ValField f -> pure . Right $ ConstGate f
423+
ValBool b -> pure . Right $ ConstGate b
424+
EVar var ->
425+
NE.singleton <$> case var of
426+
VarField i -> pure . Left $ i
427+
VarBool i -> do
428+
squared <- mulToImm (Left i) (Left i)
429+
emit $ Mul (Var squared) (ConstGate 1) i
430+
pure . Left $ i
416431
EUnOp op e1 -> do
417432
e1Outs <- compile e1
418433
for e1Outs $ \e1Out ->
419434
case op of
420-
UNeg -> pure . Right $ ScalarMul (-1) (addVar e1Out)
421-
UNot -> pure . Right $ Add (ConstGate 1) (ScalarMul (-1) (addVar e1Out))
435+
UNeg -> pure . Right $ ScalarMul (-1) (addVar e1Out)
436+
UNot -> pure . Right $ Add (ConstGate 1) (ScalarMul (-1) (addVar e1Out))
422437
EBinOp op e1 e2 -> do
423438
e1Outs <- fmap addVar <$> compile e1
424439
e2Outs <- fmap addVar <$> compile e2
@@ -436,15 +451,15 @@ compile expr = case expr of
436451
emit $ Mul e1Out (Var _recip) out
437452
pure $ Left out
438453
-- SUB(x, y) = x + (-y)
439-
BSub -> pure . Right $ Add e1Out (ScalarMul (-1) e2Out)
454+
BSub -> pure . Right $ Add e1Out (ScalarMul (-1) e2Out)
440455
BAnd -> do
441456
tmp1 <- mulToImm (Right e1Out) (Right e2Out)
442-
pure . Left $ tmp1
457+
pure . Left $ tmp1
443458
BOr -> do
444459
-- OR(input1, input2) = (input1 + input2) - (input1 * input2)
445460
tmp1 <- imm
446461
emit $ Mul e1Out e2Out tmp1
447-
pure . Right $ Add (Add e1Out e2Out) (ScalarMul (-1) (Var tmp1))
462+
pure . Right $ Add (Add e1Out e2Out) (ScalarMul (-1) (Var tmp1))
448463
BXor -> do
449464
-- XOR(input1, input2) = (input1 + input2) - 2 * (input1 * input2)
450465
tmp1 <- imm
@@ -455,24 +470,25 @@ compile expr = case expr of
455470
-- assertSingle is justified as the cond must be of type bool
456471
condOut <- addVar . assertSingle <$> compile cond
457472
trueOuts <- fmap addVar <$> compile true
458-
falseOuts <- fmap addVar <$> compile false
473+
falseOuts <- fmap addVar <$> compile false
459474
tmp1 <- imm
460475
for_ trueOuts $ \trueOut -> emit $ Mul condOut trueOut tmp1
461476
tmp2 <- imm
462-
for_ falseOuts $ \falseOut ->
477+
for_ falseOuts $ \falseOut ->
463478
emit $ Mul (Add (ConstGate 1) (ScalarMul (-1) condOut)) falseOut tmp2
464479
pure . NE.singleton . Right $ Add (Var tmp1) (Var tmp2)
465480
-- EQ(lhs, rhs) = (lhs - rhs == 1) only allowed for field comparison
466-
EEq lhs rhs -> NE.singleton <$> do
467-
-- assertSingle is justified as the lhs and rhs must be of type f
468-
lhsSubRhs <- assertSingle <$> compile (EBinOp BSub lhs rhs)
469-
eqInWire <- addWire lhsSubRhs
470-
eqFreeWire <- imm
471-
eqOutWire <- imm
472-
emit $ Equal eqInWire eqFreeWire eqOutWire
473-
-- eqOutWire == 0 if lhs == rhs, so we need to return 1 -
474-
-- neqOutWire instead.
475-
pure . Right $ Add (ConstGate 1) (ScalarMul (-1) (Var eqOutWire))
481+
EEq lhs rhs ->
482+
NE.singleton <$> do
483+
-- assertSingle is justified as the lhs and rhs must be of type f
484+
lhsSubRhs <- assertSingle <$> compile (EBinOp BSub lhs rhs)
485+
eqInWire <- addWire lhsSubRhs
486+
eqFreeWire <- imm
487+
eqOutWire <- imm
488+
emit $ Equal eqInWire eqFreeWire eqOutWire
489+
-- eqOutWire == 0 if lhs == rhs, so we need to return 1 -
490+
-- neqOutWire instead.
491+
pure . Right $ Add (ConstGate 1) (ScalarMul (-1) (Var eqOutWire))
476492
ESplit input -> do
477493
-- assertSingle is justified as the input must be of type f
478494
i <- compile input >>= addWire . assertSingle
@@ -484,16 +500,18 @@ compile expr = case expr of
484500
squared <- mulToImm (Left w) (Left w)
485501
emit $ Mul (Var squared) (ConstGate 1) w
486502
pure w
487-
EBundle as -> do
488-
as' <- traverse compile as
489-
pure $ Prelude.foldl1 (<>) (toList as')
490-
EJoin bits -> NE.singleton <$> do
491-
bs <- toList <$> compile bits
492-
ws <- traverse addWire bs
493-
pure . Right $ unsplit ws
494-
EAtIndex v _ix -> NE.singleton <$> do
495-
v' <- compile v
496-
pure $ v' NE.!! (fromIntegral _ix)
503+
EBundle as -> do
504+
as' <- traverse compile as
505+
pure $ Prelude.foldl1 (<>) (toList as')
506+
EJoin bits ->
507+
NE.singleton <$> do
508+
bs <- toList <$> compile bits
509+
ws <- traverse addWire bs
510+
pure . Right $ unsplit ws
511+
EAtIndex v _ix ->
512+
NE.singleton <$> do
513+
v' <- compile v
514+
pure $ v' NE.!! (fromIntegral _ix)
497515
EUpdateIndex p b v -> do
498516
v' <- compile v
499517
b' <- assertSingle <$> compile b
@@ -510,16 +528,16 @@ exprToArithCircuit expr output = do
510528
exprOut <- assertSingle <$> compile expr
511529
emit $ Mul (ConstGate 1) (addVar exprOut) output
512530

513-
instance (GaloisField f) => Semiring (Expr Wire f f) where
531+
instance (GaloisField f) => Semiring (Expr Wire f f) where
514532
plus = EBinOp BAdd
515533
zero = EVal $ ValField 0
516534
times = EBinOp BMul
517535
one = EVal $ ValField 1
518536

519-
instance (GaloisField f) => Ring (Expr Wire f f) where
537+
instance (GaloisField f) => Ring (Expr Wire f f) where
520538
negate = EUnOp UNeg
521539

522-
instance (GaloisField f) => Num (Expr Wire f f) where
540+
instance (GaloisField f) => Num (Expr Wire f f) where
523541
(+) = plus
524542
(*) = times
525543
(-) = EBinOp BSub
@@ -529,4 +547,4 @@ instance (GaloisField f) => Num (Expr Wire f f) where
529547
fromInteger = EVal . ValField . fromInteger
530548

531549
universe :: (KnownNat n) => Vector n (Finite n)
532-
universe = V.enumFromN 0
550+
universe = V.enumFromN 0

Diff for: src/Circuit/Lang.hs

+6-29
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ module Circuit.Lang
2727
updateIndex_,
2828
bundle,
2929
unBundle,
30+
31+
-- * Monoids
3032
Any_ (..),
3133
And_ (..),
3234
elem_,
@@ -35,9 +37,7 @@ module Circuit.Lang
3537
)
3638
where
3739

38-
import qualified Data.List.NonEmpty as NE
39-
import Circuit.Affine (AffineCircuit (..))
40-
import Circuit.Arithmetic (Gate (..), InputType (Private, Public), Wire (..))
40+
import Circuit.Arithmetic (InputType (Private, Public), Wire (..))
4141
import Circuit.Expr
4242
import Data.Field.Galois (GaloisField)
4343
import Data.Finite (Finite)
@@ -46,7 +46,6 @@ import Data.Vector.Sized qualified as V
4646
import Protolude
4747

4848
--------------------------------------------------------------------------------
49-
5049
type Signal f a = Expr Wire f a
5150

5251
type Bundle f n a = Expr Wire f (Vector n a)
@@ -103,44 +102,22 @@ splitBits = ESplit
103102
joinBits :: (KnownNat n) => Bundle f n Bool -> Signal f f
104103
joinBits = EJoin
105104

106-
107105
deref :: Var Wire f ty -> Signal f ty
108106
deref = EVar
109107

110-
compileWithWire :: (Num f) => ExprM f (Var Wire f ty) -> Signal f ty -> ExprM f (NonEmpty Wire)
111-
compileWithWire freshWire expr = do
112-
compileOuts <- compile expr
113-
for compileOuts $ \case
114-
Left wire -> do
115-
wire' <- rawWire <$> freshWire
116-
emit $ Mul (ConstGate 1) (Var wire') wire
117-
pure wire
118-
Right circ -> do
119-
wire <- rawWire <$> freshWire
120-
emit $ Mul (ConstGate 1) circ wire
121-
pure wire
122-
123108
retBool :: (Num f) => Text -> Signal f Bool -> ExprM f Wire
124-
retBool label sig = do
125-
as <- compileWithWire (boolInput Public label) sig
126-
case as of
127-
a NE.:| [] -> pure a
128-
_ -> panic "retBool: expected single wire"
109+
retBool label sig = compileWithWire (boolInput Public label) sig
129110

130111
retField :: (Num f) => Text -> Signal f f -> ExprM f Wire
131-
retField label sig = do
132-
as <- compileWithWire (fieldInput Public label) sig
133-
case as of
134-
a NE.:| [] -> pure a
135-
_ -> panic "retField: expected single wire"
112+
retField label sig = compileWithWire (fieldInput Public label) sig
136113

137114
atIndex :: (KnownNat n, Ground f ty) => Bundle f n ty -> Finite n -> Signal f ty
138115
atIndex = EAtIndex
139116

140117
updateIndex_ :: (KnownNat n, Ground f ty) => Finite n -> Signal f ty -> Bundle f n ty -> Bundle f n ty
141118
updateIndex_ p = EUpdateIndex p
142119

143-
bundle :: Ground f ty => Vector n (Signal f ty) -> Bundle f n ty
120+
bundle :: (Ground f ty) => Vector n (Signal f ty) -> Bundle f n ty
144121
bundle = EBundle
145122

146123
unBundle :: (KnownNat n, Ground f ty) => Bundle f n ty -> Vector n (Signal f ty)

0 commit comments

Comments
 (0)