Skip to content

Commit ee599ec

Browse files
authored
Merge pull request #2 from l-adic/sized-vector
use sized-vector instead of vec package
2 parents d90bfd0 + df51e5c commit ee599ec

File tree

4 files changed

+28
-27
lines changed

4 files changed

+28
-27
lines changed

Diff for: arithmetic-circuits.cabal

+4-2
Original file line numberDiff line numberDiff line change
@@ -45,14 +45,14 @@ common deps
4545
, base >=4.10 && <5
4646
, containers >=0.6.0
4747
, filepath >=1.4.2
48+
, finite-typelits >=0.1.0
4849
, galois-field >=2.0.0
49-
, fin
5050
, process
5151
, propagators
5252
, protolude >=0.2
5353
, semirings >=0.7
5454
, text >=1.2.3
55-
, vec
55+
, vector-sized
5656
, wl-pprint-text >=1.2.0
5757

5858
library
@@ -128,6 +128,7 @@ test-suite circuit-tests
128128
arithmetic-circuits
129129
, array
130130
, distributive
131+
, fin
131132
, integer-logarithms
132133
, quickcheck-instances >=0.3
133134
, QuickCheck
@@ -137,6 +138,7 @@ test-suite circuit-tests
137138
, tasty-hunit >=0.10
138139
, tasty-hspec
139140
, tasty-quickcheck >=0.10
141+
, vec
140142

141143
build-tool-depends: tasty-discover:tasty-discover >=4.2
142144

Diff for: src/Circuit/Expr.hs

+15-14
Original file line numberDiff line numberDiff line change
@@ -27,20 +27,18 @@ module Circuit.Expr
2727
evalExpr,
2828
rawWire,
2929
exprToArithCircuit,
30-
type Nat.FromGHC,
3130
)
3231
where
3332

3433
import Circuit.Affine
3534
import Circuit.Arithmetic
3635
import Data.Field.Galois (GaloisField, PrimeField (fromP))
37-
import Data.Fin (Fin)
36+
import Data.Finite (Finite)
3837
import Data.Map qualified as Map
3938
import Data.Semiring (Ring (..), Semiring (..))
4039
import Data.Set qualified as Set
41-
import Data.Type.Nat qualified as Nat
42-
import Data.Vec.Lazy (Vec, universe)
43-
import Data.Vec.Lazy qualified as Vec
40+
import Data.Vector.Sized (Vector)
41+
import Data.Vector.Sized qualified as V
4442
import Protolude hiding (Semiring)
4543
import Text.PrettyPrint.Leijen.Text hiding ((<$>))
4644

@@ -92,7 +90,7 @@ rawWire :: Var i f ty -> i
9290
rawWire (VarField i) = i
9391
rawWire (VarBool i) = i
9492

95-
type family NBits a :: Nat.Nat
93+
type family NBits a :: Nat
9694

9795
-- | Expression data type of (arithmetic) expressions over a field @f@
9896
-- with variable names/indices coming from @i@.
@@ -103,9 +101,9 @@ data Expr i f t ty where
103101
EBinOp :: BinOp f ty -> Expr i f Identity ty -> Expr i f Identity ty -> Expr i f Identity ty
104102
EIf :: Expr i f Identity Bool -> Expr i f Identity ty -> Expr i f Identity ty -> Expr i f Identity ty
105103
EEq :: Expr i f Identity f -> Expr i f Identity f -> Expr i f Identity Bool
106-
ESplit :: (Nat.SNatI (NBits f)) => Expr i f Identity f -> Expr i f (Vec (NBits f)) Bool
107-
EJoin :: (Num f, Nat.SNatI n) => Expr i f (Vec n) Bool -> Expr i f Identity f
108-
EAtIndex :: (Nat.SNatI n) => Expr i f (Vec n) ty -> Fin n -> Expr i f Identity ty
104+
ESplit :: (KnownNat (NBits f)) => Expr i f Identity f -> Expr i f (Vector (NBits f)) Bool
105+
EJoin :: (Num f, KnownNat n) => Expr i f (Vector n) Bool -> Expr i f Identity f
106+
EAtIndex :: (KnownNat n) => Expr i f (Vector n) ty -> Finite n -> Expr i f Identity ty
109107

110108
deriving instance (Show f) => Show (BinOp f a)
111109

@@ -235,15 +233,15 @@ evalExpr' expr = case expr of
235233
pure $ Identity $ lhs' == rhs'
236234
ESplit i -> do
237235
x <- runIdentity <$> evalExpr' i
238-
pure $ Vec.tabulate $ \ix -> testBit (fromP x) (fromIntegral ix)
236+
pure $ V.generate $ \ix -> testBit (fromP x) (fromIntegral ix)
239237
EJoin i -> do
240238
bits <- evalExpr' i
241239
pure $
242240
Identity $
243-
Vec.ifoldMap (\ix b -> if b then fromInteger (2 ^ fromIntegral @_ @Integer ix) else 0) bits
241+
V.ifoldl (\acc ix b -> acc + if b then fromInteger (2 ^ fromIntegral @_ @Integer ix) else 0) 0 bits
244242
EAtIndex v i -> do
245243
_v <- evalExpr' v
246-
pure $ Identity $ _v Vec.! i
244+
pure $ Identity $ _v `V.index` i
247245

248246
-- pure $ Vec.fromList $ map (testBit i) [0 .. Nat.toInt (Vec.length i) - 1]
249247

@@ -442,7 +440,7 @@ compile expr = case expr of
442440
ESplit input -> do
443441
i <- compile input >>= addWire . runIdentity
444442
outputs <- traverse (\_ -> mkBoolVar =<< imm) $ universe @(NBits f)
445-
emit $ Split i (Vec.toList outputs)
443+
emit $ Split i (V.toList outputs)
446444
traverse (fmap runIdentity . compile . EVar . VarBool) outputs
447445
where
448446
mkBoolVar w = do
@@ -455,7 +453,7 @@ compile expr = case expr of
455453
pure . Identity . Right $ unsplit ws
456454
EAtIndex v ix -> do
457455
v' <- compile v
458-
pure . Identity $ v' Vec.! ix
456+
pure . Identity $ v' `V.index` ix
459457

460458
exprToArithCircuit ::
461459
(Num f, Foldable t) =>
@@ -486,3 +484,6 @@ instance (GaloisField f) => Num (Expr Wire f Identity f) where
486484
abs = identity
487485
signum = const 1
488486
fromInteger = EVal . ValField . fromInteger
487+
488+
universe :: (KnownNat n) => Vector n (Finite n)
489+
universe = V.enumFromN 0

Diff for: src/Circuit/Lang.hs

+6-7
Original file line numberDiff line numberDiff line change
@@ -34,16 +34,15 @@ import Circuit.Affine (AffineCircuit (..))
3434
import Circuit.Arithmetic (Gate (..), InputType (Private, Public), Wire (..))
3535
import Circuit.Expr
3636
import Data.Field.Galois (GaloisField)
37-
import Data.Fin (Fin)
38-
import Data.Type.Nat qualified as Nat
39-
import Data.Vec.Lazy (Vec)
37+
import Data.Finite (Finite)
38+
import Data.Vector.Sized (Vector)
4039
import Protolude
4140

4241
--------------------------------------------------------------------------------
4342

4443
type Signal f a = Expr Wire f Identity a
4544

46-
type Bundle f n a = Expr Wire f (Vec n) a
45+
type Bundle f n a = Expr Wire f (Vector n) a
4746

4847
-- | Convert constant to expression
4948
cField :: f -> Signal f f
@@ -89,12 +88,12 @@ cond :: Signal f Bool -> Signal f ty -> Signal f ty -> Signal f ty
8988
cond = EIf
9089

9190
splitBits ::
92-
(Nat.SNatI (NBits f)) =>
91+
(KnownNat (NBits f)) =>
9392
Signal f f ->
9493
Bundle f (NBits f) Bool
9594
splitBits = ESplit
9695

97-
joinBits :: (Num f, Nat.SNatI n) => Bundle f n Bool -> Signal f f
96+
joinBits :: (Num f, KnownNat n) => Bundle f n Bool -> Signal f f
9897
joinBits = EJoin
9998

10099
deref :: Var Wire f ty -> Signal f ty
@@ -119,7 +118,7 @@ retBool label = compileWithWire (boolInput Public label)
119118
retField :: (Num f) => Text -> Signal f f -> ExprM f Wire
120119
retField label = compileWithWire (fieldInput Public label)
121120

122-
atIndex :: (Nat.SNatI n) => Bundle f n ty -> Fin n -> Signal f ty
121+
atIndex :: (KnownNat n) => Bundle f n ty -> Finite n -> Signal f ty
123122
atIndex = EAtIndex
124123

125124
--------------------------------------------------------------------------------

Diff for: test/Test/Circuit/Lang.hs

+3-4
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,19 @@
11
{-# LANGUAGE DataKinds #-}
22
{-# LANGUAGE TypeFamilies #-}
3-
{-# OPTIONS_GHC -freduction-depth=0 #-}
43

54
module Test.Circuit.Lang where
65

76
import Circuit
87
import Data.Field.Galois (Prime, PrimeField (fromP))
9-
import Data.Fin (Fin)
8+
import Data.Finite (Finite)
109
import Data.Map qualified as Map
1110
import Data.Maybe (fromJust)
1211
import Protolude hiding (Show, show)
1312
import Test.QuickCheck (Property, (==>))
1413

1514
type Fr = Prime 21888242871839275222246405745257275088548364400416034343698204186575808495617
1615

17-
type instance NBits (Prime p) = FromGHC 256
16+
type instance NBits (Prime p) = 256
1817

1918
bitSplitJoin :: ExprM Fr Wire
2019
bitSplitJoin = do
@@ -59,7 +58,7 @@ prop_factorizationContra x y z =
5958
w = solve bsVars bsCircuit inputs
6059
in lookupVar bsVars "out" w == Just 0
6160

62-
bitIndex :: Fin (NBits Fr) -> ExprM Fr Wire
61+
bitIndex :: Finite (NBits Fr) -> ExprM Fr Wire
6362
bitIndex i = do
6463
x <- deref <$> fieldInput Public "x"
6564
let bits = splitBits x

0 commit comments

Comments
 (0)