@@ -27,20 +27,18 @@ module Circuit.Expr
27
27
evalExpr ,
28
28
rawWire ,
29
29
exprToArithCircuit ,
30
- type Nat. FromGHC ,
31
30
)
32
31
where
33
32
34
33
import Circuit.Affine
35
34
import Circuit.Arithmetic
36
35
import Data.Field.Galois (GaloisField , PrimeField (fromP ))
37
- import Data.Fin ( Fin )
36
+ import Data.Finite ( Finite )
38
37
import Data.Map qualified as Map
39
38
import Data.Semiring (Ring (.. ), Semiring (.. ))
40
39
import Data.Set qualified as Set
41
- import Data.Type.Nat qualified as Nat
42
- import Data.Vec.Lazy (Vec , universe )
43
- import Data.Vec.Lazy qualified as Vec
40
+ import Data.Vector.Sized (Vector )
41
+ import Data.Vector.Sized qualified as V
44
42
import Protolude hiding (Semiring )
45
43
import Text.PrettyPrint.Leijen.Text hiding ((<$>) )
46
44
@@ -92,7 +90,7 @@ rawWire :: Var i f ty -> i
92
90
rawWire (VarField i) = i
93
91
rawWire (VarBool i) = i
94
92
95
- type family NBits a :: Nat. Nat
93
+ type family NBits a :: Nat
96
94
97
95
-- | Expression data type of (arithmetic) expressions over a field @f@
98
96
-- with variable names/indices coming from @i@.
@@ -103,9 +101,9 @@ data Expr i f t ty where
103
101
EBinOp :: BinOp f ty -> Expr i f Identity ty -> Expr i f Identity ty -> Expr i f Identity ty
104
102
EIf :: Expr i f Identity Bool -> Expr i f Identity ty -> Expr i f Identity ty -> Expr i f Identity ty
105
103
EEq :: Expr i f Identity f -> Expr i f Identity f -> Expr i f Identity Bool
106
- ESplit :: (Nat. SNatI (NBits f )) => Expr i f Identity f -> Expr i f (Vec (NBits f )) Bool
107
- EJoin :: (Num f , Nat. SNatI n ) => Expr i f (Vec n ) Bool -> Expr i f Identity f
108
- EAtIndex :: (Nat. SNatI n ) => Expr i f (Vec n ) ty -> Fin n -> Expr i f Identity ty
104
+ ESplit :: (KnownNat (NBits f )) => Expr i f Identity f -> Expr i f (Vector (NBits f )) Bool
105
+ EJoin :: (Num f , KnownNat n ) => Expr i f (Vector n ) Bool -> Expr i f Identity f
106
+ EAtIndex :: (KnownNat n ) => Expr i f (Vector n ) ty -> Finite n -> Expr i f Identity ty
109
107
110
108
deriving instance (Show f ) => Show (BinOp f a )
111
109
@@ -235,15 +233,15 @@ evalExpr' expr = case expr of
235
233
pure $ Identity $ lhs' == rhs'
236
234
ESplit i -> do
237
235
x <- runIdentity <$> evalExpr' i
238
- pure $ Vec. tabulate $ \ ix -> testBit (fromP x) (fromIntegral ix)
236
+ pure $ V. generate $ \ ix -> testBit (fromP x) (fromIntegral ix)
239
237
EJoin i -> do
240
238
bits <- evalExpr' i
241
239
pure $
242
240
Identity $
243
- Vec. ifoldMap (\ ix b -> if b then fromInteger (2 ^ fromIntegral @ _ @ Integer ix) else 0 ) bits
241
+ V. ifoldl (\ acc ix b -> acc + if b then fromInteger (2 ^ fromIntegral @ _ @ Integer ix) else 0 ) 0 bits
244
242
EAtIndex v i -> do
245
243
_v <- evalExpr' v
246
- pure $ Identity $ _v Vec. ! i
244
+ pure $ Identity $ _v `V.index` i
247
245
248
246
-- pure $ Vec.fromList $ map (testBit i) [0 .. Nat.toInt (Vec.length i) - 1]
249
247
@@ -442,7 +440,7 @@ compile expr = case expr of
442
440
ESplit input -> do
443
441
i <- compile input >>= addWire . runIdentity
444
442
outputs <- traverse (\ _ -> mkBoolVar =<< imm) $ universe @ (NBits f )
445
- emit $ Split i (Vec . toList outputs)
443
+ emit $ Split i (V . toList outputs)
446
444
traverse (fmap runIdentity . compile . EVar . VarBool ) outputs
447
445
where
448
446
mkBoolVar w = do
@@ -455,7 +453,7 @@ compile expr = case expr of
455
453
pure . Identity . Right $ unsplit ws
456
454
EAtIndex v ix -> do
457
455
v' <- compile v
458
- pure . Identity $ v' Vec. ! ix
456
+ pure . Identity $ v' `V.index` ix
459
457
460
458
exprToArithCircuit ::
461
459
(Num f , Foldable t ) =>
@@ -486,3 +484,6 @@ instance (GaloisField f) => Num (Expr Wire f Identity f) where
486
484
abs = identity
487
485
signum = const 1
488
486
fromInteger = EVal . ValField . fromInteger
487
+
488
+ universe :: (KnownNat n ) => Vector n (Finite n )
489
+ universe = V. enumFromN 0
0 commit comments