1
1
{-# LANGUAGE DataKinds #-}
2
2
{-# LANGUAGE TypeFamilies #-}
3
3
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
4
+
4
5
{-# HLINT ignore "Use zipWithM" #-}
5
6
6
7
module Circuit.Expr
@@ -29,25 +30,25 @@ module Circuit.Expr
29
30
evalExpr ,
30
31
rawWire ,
31
32
exprToArithCircuit ,
33
+ compileWithWire ,
32
34
)
33
35
where
34
36
35
37
import Circuit.Affine
36
38
import Circuit.Arithmetic
37
39
import Data.Field.Galois (GaloisField , PrimeField (fromP ))
38
40
import Data.Finite (Finite )
41
+ import Data.List.NonEmpty qualified as NE
39
42
import Data.Map qualified as Map
43
+ import Data.Semigroup qualified as NE
40
44
import Data.Semiring (Ring (.. ), Semiring (.. ))
41
45
import Data.Set qualified as Set
42
46
import Data.Vector.Sized (Vector )
43
47
import Data.Vector.Sized qualified as V
48
+ import Lens.Micro (ix , (.~) )
44
49
import Protolude hiding (Semiring )
45
50
import Text.PrettyPrint.Leijen.Text hiding ((<$>) )
46
- import Data.List.NonEmpty qualified as NE
47
51
import Prelude (foldl1 )
48
- import Lens.Micro ( (.~) , ix )
49
- import qualified Data.Semigroup as NE
50
-
51
52
52
53
data UnOp f a where
53
54
UNeg :: UnOp f f
@@ -99,28 +100,27 @@ rawWire (VarBool i) = i
99
100
100
101
type family NBits a :: Nat
101
102
102
-
103
103
-- | This constring prevents us from building up nested vectors inside the expression type
104
- class GaloisField f => Ground f ty
104
+ class ( GaloisField f ) => Ground f ty
105
105
106
- instance GaloisField f => Ground f f
106
+ instance ( GaloisField f ) => Ground f f
107
107
108
- instance GaloisField f => Ground f Bool
108
+ instance ( GaloisField f ) => Ground f Bool
109
109
110
110
-- | Expression data type of (arithmetic) expressions over a field @f@
111
111
-- with variable names/indices coming from @i@.
112
112
data Expr i f ty where
113
113
EVal :: Val f ty -> Expr i f ty
114
114
EVar :: Var i f ty -> Expr i f ty
115
115
EUnOp :: UnOp f ty -> Expr i f ty -> Expr i f ty
116
- EBinOp :: BinOp f ty -> Expr i f ty -> Expr i f ty -> Expr i f ty
117
- EIf :: Expr i f Bool -> Expr i f ty -> Expr i f ty -> Expr i f ty
116
+ EBinOp :: BinOp f ty -> Expr i f ty -> Expr i f ty -> Expr i f ty
117
+ EIf :: Expr i f Bool -> Expr i f ty -> Expr i f ty -> Expr i f ty
118
118
EEq :: Expr i f f -> Expr i f f -> Expr i f Bool
119
119
ESplit :: (KnownNat (NBits f )) => Expr i f f -> Expr i f (Vector (NBits f ) Bool )
120
120
EJoin :: (KnownNat n ) => Expr i f (Vector n Bool ) -> Expr i f f
121
121
EAtIndex :: (KnownNat n , Ground f ty ) => Expr i f (Vector n ty ) -> Finite n -> Expr i f ty
122
- EUpdateIndex :: (KnownNat n , Ground f ty ) => Finite n -> (Expr i f ty ) -> Expr i f (Vector n ty ) -> Expr i f (Vector n ty )
123
- EBundle :: Ground f ty => Vector n (Expr i f ty ) -> Expr i f (Vector n ty )
122
+ EUpdateIndex :: (KnownNat n , Ground f ty ) => Finite n -> (Expr i f ty ) -> Expr i f (Vector n ty ) -> Expr i f (Vector n ty )
123
+ EBundle :: ( Ground f ty ) => Vector n (Expr i f ty ) -> Expr i f (Vector n ty )
124
124
125
125
deriving instance (Show f ) => Show (BinOp f a )
126
126
@@ -215,7 +215,7 @@ evalExpr' expr = case expr of
215
215
ValField f -> f
216
216
EVar var -> do
217
217
m <- get
218
- pure $ case var of
218
+ pure $ case var of
219
219
VarField i -> do
220
220
case Map. lookup i m of
221
221
Just v -> v
@@ -229,8 +229,8 @@ evalExpr' expr = case expr of
229
229
EUnOp UNot e1 ->
230
230
not <$> evalExpr' e1
231
231
EBinOp op e1 e2 -> do
232
- e1' <- evalExpr' e1
233
- e2' <- evalExpr' e2
232
+ e1' <- evalExpr' e1
233
+ e2' <- evalExpr' e2
234
234
pure $ apply e1' e2'
235
235
where
236
236
apply = case op of
@@ -257,7 +257,7 @@ evalExpr' expr = case expr of
257
257
EJoin i -> do
258
258
bits <- evalExpr' i
259
259
pure $
260
- V. ifoldl (\ acc _ix b -> acc + if b then fromInteger (2 ^ fromIntegral @ _ @ Integer _ix) else 0 ) 0 bits
260
+ V. ifoldl (\ acc _ix b -> acc + if b then fromInteger (2 ^ fromIntegral @ _ @ Integer _ix) else 0 ) 0 bits
261
261
EAtIndex v i -> do
262
262
_v <- evalExpr' v
263
263
pure $ _v `V.index` i
@@ -266,7 +266,6 @@ evalExpr' expr = case expr of
266
266
_b <- evalExpr' b
267
267
pure $ _v & V. ix p .~ _b
268
268
269
-
270
269
-- pure $ Vec.fromList $ map (testBit i) [0 .. Nat.toInt (Vec.length i) - 1]
271
270
272
271
-------------------------------------------------------------------------------
@@ -393,7 +392,21 @@ addWire x = case x of
393
392
emit $ Mul (ConstGate 1 ) c mulOut
394
393
pure mulOut
395
394
396
- assertSingle :: NonEmpty a -> a
395
+ compileWithWire :: (Num f ) => ExprM f (Var Wire f ty ) -> Expr Wire f ty -> ExprM f Wire
396
+ compileWithWire freshWire expr = do
397
+ -- the use of assertSingle here is justified because Var constraints what ty can be to either Bool or f
398
+ compileOut <- assertSingle <$> compile expr
399
+ case compileOut of
400
+ Left wire -> do
401
+ wire' <- rawWire <$> freshWire
402
+ emit $ Mul (ConstGate 1 ) (Var wire') wire
403
+ pure wire
404
+ Right circ -> do
405
+ wire <- rawWire <$> freshWire
406
+ emit $ Mul (ConstGate 1 ) circ wire
407
+ pure wire
408
+
409
+ assertSingle :: NonEmpty a -> a
397
410
assertSingle xs = case xs of
398
411
x NE. :| [] -> x
399
412
_ -> panic " Expected single wire"
@@ -404,21 +417,23 @@ compile ::
404
417
Expr Wire f ty ->
405
418
ExprM f (NonEmpty (Either Wire (AffineCircuit f Wire )))
406
419
compile expr = case expr of
407
- EVal v -> NE. singleton <$> case v of
408
- ValField f -> pure . Right $ ConstGate f
409
- ValBool b -> pure . Right $ ConstGate b
410
- EVar var -> NE. singleton <$> case var of
411
- VarField i -> pure . Left $ i
412
- VarBool i -> do
413
- squared <- mulToImm (Left i) (Left i)
414
- emit $ Mul (Var squared) (ConstGate 1 ) i
415
- pure . Left $ i
420
+ EVal v ->
421
+ NE. singleton <$> case v of
422
+ ValField f -> pure . Right $ ConstGate f
423
+ ValBool b -> pure . Right $ ConstGate b
424
+ EVar var ->
425
+ NE. singleton <$> case var of
426
+ VarField i -> pure . Left $ i
427
+ VarBool i -> do
428
+ squared <- mulToImm (Left i) (Left i)
429
+ emit $ Mul (Var squared) (ConstGate 1 ) i
430
+ pure . Left $ i
416
431
EUnOp op e1 -> do
417
432
e1Outs <- compile e1
418
433
for e1Outs $ \ e1Out ->
419
434
case op of
420
- UNeg -> pure . Right $ ScalarMul (- 1 ) (addVar e1Out)
421
- UNot -> pure . Right $ Add (ConstGate 1 ) (ScalarMul (- 1 ) (addVar e1Out))
435
+ UNeg -> pure . Right $ ScalarMul (- 1 ) (addVar e1Out)
436
+ UNot -> pure . Right $ Add (ConstGate 1 ) (ScalarMul (- 1 ) (addVar e1Out))
422
437
EBinOp op e1 e2 -> do
423
438
e1Outs <- fmap addVar <$> compile e1
424
439
e2Outs <- fmap addVar <$> compile e2
@@ -436,15 +451,15 @@ compile expr = case expr of
436
451
emit $ Mul e1Out (Var _recip) out
437
452
pure $ Left out
438
453
-- SUB(x, y) = x + (-y)
439
- BSub -> pure . Right $ Add e1Out (ScalarMul (- 1 ) e2Out)
454
+ BSub -> pure . Right $ Add e1Out (ScalarMul (- 1 ) e2Out)
440
455
BAnd -> do
441
456
tmp1 <- mulToImm (Right e1Out) (Right e2Out)
442
- pure . Left $ tmp1
457
+ pure . Left $ tmp1
443
458
BOr -> do
444
459
-- OR(input1, input2) = (input1 + input2) - (input1 * input2)
445
460
tmp1 <- imm
446
461
emit $ Mul e1Out e2Out tmp1
447
- pure . Right $ Add (Add e1Out e2Out) (ScalarMul (- 1 ) (Var tmp1))
462
+ pure . Right $ Add (Add e1Out e2Out) (ScalarMul (- 1 ) (Var tmp1))
448
463
BXor -> do
449
464
-- XOR(input1, input2) = (input1 + input2) - 2 * (input1 * input2)
450
465
tmp1 <- imm
@@ -455,24 +470,25 @@ compile expr = case expr of
455
470
-- assertSingle is justified as the cond must be of type bool
456
471
condOut <- addVar . assertSingle <$> compile cond
457
472
trueOuts <- fmap addVar <$> compile true
458
- falseOuts <- fmap addVar <$> compile false
473
+ falseOuts <- fmap addVar <$> compile false
459
474
tmp1 <- imm
460
475
for_ trueOuts $ \ trueOut -> emit $ Mul condOut trueOut tmp1
461
476
tmp2 <- imm
462
- for_ falseOuts $ \ falseOut ->
477
+ for_ falseOuts $ \ falseOut ->
463
478
emit $ Mul (Add (ConstGate 1 ) (ScalarMul (- 1 ) condOut)) falseOut tmp2
464
479
pure . NE. singleton . Right $ Add (Var tmp1) (Var tmp2)
465
480
-- EQ(lhs, rhs) = (lhs - rhs == 1) only allowed for field comparison
466
- EEq lhs rhs -> NE. singleton <$> do
467
- -- assertSingle is justified as the lhs and rhs must be of type f
468
- lhsSubRhs <- assertSingle <$> compile (EBinOp BSub lhs rhs)
469
- eqInWire <- addWire lhsSubRhs
470
- eqFreeWire <- imm
471
- eqOutWire <- imm
472
- emit $ Equal eqInWire eqFreeWire eqOutWire
473
- -- eqOutWire == 0 if lhs == rhs, so we need to return 1 -
474
- -- neqOutWire instead.
475
- pure . Right $ Add (ConstGate 1 ) (ScalarMul (- 1 ) (Var eqOutWire))
481
+ EEq lhs rhs ->
482
+ NE. singleton <$> do
483
+ -- assertSingle is justified as the lhs and rhs must be of type f
484
+ lhsSubRhs <- assertSingle <$> compile (EBinOp BSub lhs rhs)
485
+ eqInWire <- addWire lhsSubRhs
486
+ eqFreeWire <- imm
487
+ eqOutWire <- imm
488
+ emit $ Equal eqInWire eqFreeWire eqOutWire
489
+ -- eqOutWire == 0 if lhs == rhs, so we need to return 1 -
490
+ -- neqOutWire instead.
491
+ pure . Right $ Add (ConstGate 1 ) (ScalarMul (- 1 ) (Var eqOutWire))
476
492
ESplit input -> do
477
493
-- assertSingle is justified as the input must be of type f
478
494
i <- compile input >>= addWire . assertSingle
@@ -484,16 +500,18 @@ compile expr = case expr of
484
500
squared <- mulToImm (Left w) (Left w)
485
501
emit $ Mul (Var squared) (ConstGate 1 ) w
486
502
pure w
487
- EBundle as -> do
488
- as' <- traverse compile as
489
- pure $ Prelude. foldl1 (<>) (toList as')
490
- EJoin bits -> NE. singleton <$> do
491
- bs <- toList <$> compile bits
492
- ws <- traverse addWire bs
493
- pure . Right $ unsplit ws
494
- EAtIndex v _ix -> NE. singleton <$> do
495
- v' <- compile v
496
- pure $ v' NE. !! (fromIntegral _ix)
503
+ EBundle as -> do
504
+ as' <- traverse compile as
505
+ pure $ Prelude. foldl1 (<>) (toList as')
506
+ EJoin bits ->
507
+ NE. singleton <$> do
508
+ bs <- toList <$> compile bits
509
+ ws <- traverse addWire bs
510
+ pure . Right $ unsplit ws
511
+ EAtIndex v _ix ->
512
+ NE. singleton <$> do
513
+ v' <- compile v
514
+ pure $ v' NE. !! (fromIntegral _ix)
497
515
EUpdateIndex p b v -> do
498
516
v' <- compile v
499
517
b' <- assertSingle <$> compile b
@@ -510,16 +528,16 @@ exprToArithCircuit expr output = do
510
528
exprOut <- assertSingle <$> compile expr
511
529
emit $ Mul (ConstGate 1 ) (addVar exprOut) output
512
530
513
- instance (GaloisField f ) => Semiring (Expr Wire f f ) where
531
+ instance (GaloisField f ) => Semiring (Expr Wire f f ) where
514
532
plus = EBinOp BAdd
515
533
zero = EVal $ ValField 0
516
534
times = EBinOp BMul
517
535
one = EVal $ ValField 1
518
536
519
- instance (GaloisField f ) => Ring (Expr Wire f f ) where
537
+ instance (GaloisField f ) => Ring (Expr Wire f f ) where
520
538
negate = EUnOp UNeg
521
539
522
- instance (GaloisField f ) => Num (Expr Wire f f ) where
540
+ instance (GaloisField f ) => Num (Expr Wire f f ) where
523
541
(+) = plus
524
542
(*) = times
525
543
(-) = EBinOp BSub
@@ -529,4 +547,4 @@ instance (GaloisField f) => Num (Expr Wire f f) where
529
547
fromInteger = EVal . ValField . fromInteger
530
548
531
549
universe :: (KnownNat n ) => Vector n (Finite n )
532
- universe = V. enumFromN 0
550
+ universe = V. enumFromN 0
0 commit comments