Skip to content

Commit ac022c6

Browse files
committed
handle vector operations with combinators
1 parent b2f20d8 commit ac022c6

File tree

10 files changed

+101
-884
lines changed

10 files changed

+101
-884
lines changed

bundleUnbundle.dot

Lines changed: 0 additions & 771 deletions
This file was deleted.

cabal.project

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ source-repository-package
1212
source-repository-package
1313
type: git
1414
location: https://github.com/l-adic/galois-fields.git
15-
tag: fc82039e811ba68c10527cf871796b7ac8514926
15+
tag: b0867ffdebda5043c80315a51b15e82ed25acba6
1616
--sha256: j/zGFd2aeowzJfgCCBmJYmG8mDsfF0irqj/cPOw9ulE=
1717

1818
source-repository-package

circuit/src/Circuit/Affine.hs

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,9 @@ instance Bifunctor AffineCircuit where
4747
ScalarMul s x -> ScalarMul (f s) (bimap f g x)
4848
ConstGate c -> ConstGate (f c)
4949
Var i -> Var (g i)
50-
5150
Nil -> Nil
5251

5352
instance (Pretty i, Pretty f) => Pretty (AffineCircuit f i) where
54-
pretty :: AffineCircuit f i -> Doc
5553
pretty = prettyPrec 0
5654
where
5755
prettyPrec :: Int -> AffineCircuit f i -> Doc
@@ -61,8 +59,7 @@ instance (Pretty i, Pretty f) => Pretty (AffineCircuit f i) where
6159
text "nil"
6260
Var v ->
6361
pretty v
64-
ConstGate f ->
65-
pretty f
62+
ConstGate f -> pretty f
6663
ScalarMul f e1 ->
6764
pretty f <+> text "*" <+> parensPrec 7 p (prettyPrec p e1)
6865
Add e1 e2 ->

circuit/src/Circuit/Arithmetic.hs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ module Circuit.Arithmetic
1616
CircuitVars (..),
1717
relabel,
1818
collectCircuitVars,
19-
booleanWires
19+
booleanWires,
2020
)
2121
where
2222

@@ -342,4 +342,4 @@ booleanWires :: ArithCircuit f -> Set Wire
342342
booleanWires (ArithCircuit gates) = foldMap f gates
343343
where
344344
f (Boolean i) = Set.singleton i
345-
f _ = mempty
345+
f _ = mempty

circuit/src/Circuit/Dot.hs

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@ where
77

88
import Circuit.Affine ()
99
import Circuit.Arithmetic (ArithCircuit (..), Gate (..), Wire (..), booleanWires)
10+
import Data.Field.Galois (PrimeField)
1011
import Data.Text qualified as Text
1112
import Protolude
12-
import Data.Field.Galois (PrimeField)
1313
import System.FilePath (replaceExtension)
1414
import System.Process (readProcessWithExitCode)
1515
import Text.PrettyPrint.Leijen.Text (Pretty (..))
@@ -30,7 +30,7 @@ arithCircuitToDot c@(ArithCircuit gates) =
3030
dotArrow s t = s <> " -> " <> t
3131

3232
dotArrowLabel :: Text -> Text -> Text -> Text -> Text
33-
dotArrowLabel s t lbl _c =
33+
dotArrowLabel s t lbl _c =
3434
dotArrow s t <> " [label=\"" <> lbl <> "\", color=\"" <> _c <> "\"]"
3535

3636
labelNode lblId lbl = lblId <> " [label=\"" <> lbl <> "\"]"
@@ -53,8 +53,10 @@ arithCircuitToDot c@(ArithCircuit gates) =
5353
gateLabel = dotWire output
5454
inputs circuit tgt =
5555
map
56-
( \a -> (\src -> dotArrowLabel src tgt (show $ pretty src) (color a))
57-
. dotWire $ a
56+
( \a ->
57+
(\src -> dotArrowLabel src tgt (show $ pretty src) (color a))
58+
. dotWire
59+
$ a
5860
)
5961
$ toList circuit
6062
graphGate (Equal i m output) =
@@ -73,8 +75,8 @@ arithCircuitToDot c@(ArithCircuit gates) =
7375
where
7476
gateLabel = Text.concat . fmap dotWire $ outputs
7577
graphGate (Boolean _) = []
76-
77-
-- gateLabel = dotWire i
78+
79+
-- gateLabel = dotWire i
7880

7981
callDot :: Text -> IO Text
8082
callDot g = do

language/src/Circuit/Language/Compile.hs

Lines changed: 24 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ module Circuit.Language.Compile
1212
compileWithWire,
1313
compileWithWires,
1414
exprToArithCircuit,
15-
unBundle
15+
_unBundle,
1616
)
1717
where
1818

@@ -21,21 +21,21 @@ import Circuit.Arithmetic
2121
import Circuit.Language.Expr
2222
( BinOp (..),
2323
Expr (..),
24+
Hash (Hash),
2425
UVar (..),
2526
UnOp (..),
2627
getAnnotation,
2728
hashCons,
28-
Hash (Hash),
2929
)
3030
import Circuit.Language.TExpr qualified as TExpr
3131
import Data.Field.Galois (GaloisField)
3232
import Data.Map qualified as Map
33+
import Data.Maybe (fromJust)
3334
import Data.Set qualified as Set
3435
import Data.Vector qualified as V
36+
import Data.Vector.Sized qualified as SV
3537
import Protolude hiding (Semiring)
3638
import Text.PrettyPrint.Leijen.Text hiding ((<$>))
37-
import Data.Maybe (fromJust)
38-
import Data.Vector.Sized qualified as SV
3939
import Unsafe.Coerce (unsafeCoerce)
4040

4141
-------------------------------------------------------------------------------
@@ -204,13 +204,13 @@ compileWithWire freshWire e = do
204204
<$> compileWithWires (V.singleton $ fmap TExpr.coerceGroundType freshWire) e
205205

206206
compileWithWires ::
207-
(Hashable f) =>
207+
(Hashable f) =>
208208
(GaloisField f) =>
209209
(MonadState (BuilderState f) m) =>
210210
(MonadError (CircuitBuilderError f) m) =>
211211
V.Vector (m (TExpr.Var Wire f f)) ->
212212
TExpr.Expr Wire f ty ->
213-
m (V.Vector (TExpr.Var Wire f f) )
213+
m (V.Vector (TExpr.Var Wire f f))
214214
compileWithWires ws expr = do
215215
e <- hashCons <$> unType expr
216216
compileOut <- memoizedCompile e
@@ -331,7 +331,7 @@ _compile expr = withCompilerCache (getAnnotation expr) $ case expr of
331331
ESplit _ n input -> do
332332
-- assertSingle is justified as the input must be of type f
333333
i <- memoizedCompile input >>= assertSingleSource >>= addWire
334-
outputs <- V.generateM n $ \_ -> do
334+
outputs <- V.generateM n $ \_ -> do
335335
w <- imm
336336
emit $ Boolean w
337337
pure w
@@ -351,14 +351,6 @@ _compile expr = withCompilerCache (getAnnotation expr) $ case expr of
351351
bs <- toList <$> memoizedCompile bits
352352
ws <- traverse addWire bs
353353
pure . V.singleton . AffineSource $ unsplit ws
354-
EAtIndex _ v _ix -> do
355-
v' <- memoizedCompile v
356-
pure . V.singleton $ v' V.! (fromIntegral _ix)
357-
EUpdateIndex _ p b v -> do
358-
v' <- memoizedCompile v
359-
b' <- memoizedCompile b >>= assertSingleSource
360-
let p' = fromIntegral p
361-
pure $ V.imap (\_ix w -> if _ix == p' then b' else w) v'
362354

363355
memoizedCompile ::
364356
forall f m.
@@ -385,32 +377,22 @@ exprToArithCircuit expr output = do
385377
compileOut <- memoizedCompile e >>= assertSingleSource
386378
emit $ Mul (ConstGate 1) (addVar compileOut) output
387379

388-
fieldToBool
389-
:: (Hashable f, GaloisField f) =>
390-
TExpr.Expr Wire f f ->
391-
ExprM f (TExpr.Expr Wire f Bool)
380+
fieldToBool ::
381+
(Hashable f, GaloisField f) =>
382+
TExpr.Expr Wire f f ->
383+
ExprM f (TExpr.Expr Wire f Bool)
392384
fieldToBool e = do
393385
eOut <- hashCons <$> unType e
394386
a <- memoizedCompile eOut >>= assertSingleSource >>= addWire
395387
emit $ Boolean a
396388
pure $ unsafeCoerce e
397389

398-
unBundle ::
399-
forall n f ty.
400-
(KnownNat n, GaloisField f, Hashable f) =>
401-
TExpr.Expr Wire f (SV.Vector n ty) ->
402-
ExprM f (SV.Vector n (TExpr.Expr Wire f f))
403-
unBundle b = do
404-
bis <- memoizedCompile . hashCons =<< unType b
405-
ws <- traverse addWire bis
406-
pure $ fromJust $ SV.toSized (TExpr.EVar . TExpr.VarField <$> ws)
407-
408-
unType :: forall f ty m. MonadState (BuilderState f) m => TExpr.Expr Wire f ty -> m (Expr () Wire f)
390+
unType :: forall f ty m. (MonadState (BuilderState f) m) => TExpr.Expr Wire f ty -> m (Expr () Wire f)
409391
unType = \case
410392
TExpr.EVal v -> pure $ case v of
411393
TExpr.ValBool b -> EVal () b
412394
TExpr.ValField f -> EVal () f
413-
TExpr.EVar v -> case v of
395+
TExpr.EVar v -> case v of
414396
TExpr.VarField w -> pure $ EVar () (UVar w)
415397
TExpr.VarBool b -> do
416398
emit $ Boolean b
@@ -421,8 +403,6 @@ unType = \case
421403
TExpr.EEq l r -> EEq () <$> unType l <*> unType r
422404
TExpr.ESplit i -> ESplit () (fromIntegral $ natVal (Proxy @(TExpr.NBits f))) <$> unType i
423405
TExpr.EJoin i -> EJoin () <$> unType i
424-
TExpr.EAtIndex v ix -> EAtIndex () <$> unType v <*> pure (fromIntegral ix)
425-
TExpr.EUpdateIndex p b v -> EUpdateIndex () (fromIntegral p) <$> unType b <*> unType v
426406
TExpr.EBundle b -> EBundle () <$> traverse unType (SV.fromSized b)
427407
where
428408
untypeBinOp :: TExpr.BinOp f a -> BinOp
@@ -440,3 +420,14 @@ unType = \case
440420
TExpr.UNeg -> UNeg
441421
TExpr.UNot -> UNot
442422

423+
_unBundle ::
424+
forall n f ty.
425+
(KnownNat n) =>
426+
(GaloisField f) =>
427+
(Hashable f) =>
428+
TExpr.Expr Wire f (SV.Vector n ty) ->
429+
ExprM f (SV.Vector n (TExpr.Expr Wire f f))
430+
_unBundle b = do
431+
bis <- memoizedCompile . hashCons =<< unType b
432+
ws <- traverse addWire bis
433+
pure $ fromJust $ SV.toSized (TExpr.EVar . TExpr.VarField <$> ws)

language/src/Circuit/Language/DSL.hs

Lines changed: 47 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1-
{-# LANGUAGE TypeFamilies #-}
1+
{-# LANGUAGE TypeFamilyDependencies #-}
22

33
-- | Surface language
44
module Circuit.Language.DSL
55
( Signal,
6-
Bundle,
6+
Bundled (..),
7+
type Unbundled,
78
cField,
89
cBool,
910
add,
@@ -25,7 +26,6 @@ module Circuit.Language.DSL
2526
joinBits,
2627
atIndex,
2728
updateIndex_,
28-
bundle,
2929

3030
-- * Monoids
3131
Any_ (..),
@@ -42,14 +42,14 @@ import Circuit.Language.Compile
4242
import Circuit.Language.TExpr
4343
import Data.Field.Galois (GaloisField, PrimeField)
4444
import Data.Finite (Finite)
45-
import Data.Vector.Sized (Vector)
45+
import Data.Vector.Sized (Vector, ix)
46+
import Lens.Micro ((.~), (^.))
4647
import Protolude
48+
import Unsafe.Coerce (unsafeCoerce)
4749

4850
--------------------------------------------------------------------------------
4951
type Signal f = Expr Wire f
5052

51-
type Bundle f n a = Expr Wire f (Vector n a)
52-
5353
-- | Convert constant to expression
5454
cField :: f -> Signal f f
5555
cField = EVal . ValField
@@ -96,10 +96,10 @@ cond = EIf
9696
splitBits ::
9797
(KnownNat (NBits f)) =>
9898
Signal f f ->
99-
Bundle f (NBits f) Bool
99+
Signal f (Vector (NBits f) Bool)
100100
splitBits = ESplit
101101

102-
joinBits :: (KnownNat n) => Bundle f n Bool -> Signal f f
102+
joinBits :: (KnownNat n) => Signal f (Vector n Bool) -> Signal f f
103103
joinBits = EJoin
104104

105105
deref :: Var Wire f ty -> Signal f ty
@@ -111,14 +111,25 @@ retBool label sig = compileWithWire (boolInput Public label) sig
111111
retField :: (PrimeField f, Hashable f) => Text -> Signal f f -> ExprM f (Var Wire f f)
112112
retField label sig = compileWithWire (fieldInput Public label) sig
113113

114-
atIndex :: (KnownNat n) => Bundle f n ty -> Finite n -> Signal f ty
115-
atIndex = EAtIndex
116-
117-
updateIndex_ :: (KnownNat n) => Finite n -> Signal f ty -> Bundle f n ty -> Bundle f n ty
118-
updateIndex_ p = EUpdateIndex p
119-
120-
bundle :: Vector n (Signal f ty) -> Bundle f n ty
121-
bundle = EBundle
114+
atIndex ::
115+
(Bundled f (Vector n ty)) =>
116+
Finite n ->
117+
Signal f (Vector n ty) ->
118+
ExprM f (Signal f ty)
119+
atIndex i b = do
120+
bs <- unbundle b
121+
return $ bs ^. ix i
122+
123+
updateIndex_ ::
124+
(Bundled f (Vector n ty)) =>
125+
Finite n ->
126+
Signal f ty ->
127+
Signal f (Vector n ty) ->
128+
ExprM f (Signal f (Vector n ty))
129+
updateIndex_ p s v = do
130+
bs <- unbundle v
131+
let bs' = bs & ix p .~ s
132+
return $ bundle bs'
122133

123134
--------------------------------------------------------------------------------
124135

@@ -140,10 +151,10 @@ instance (Num f) => Monoid (Any_ f) where
140151

141152
newtype Add_ f = Add_ {unAdd_ :: Signal f f}
142153

143-
instance GaloisField f => Semigroup (Add_ f) where
144-
Add_ a <> Add_ b = Add_ $ add a b
154+
instance (GaloisField f) => Semigroup (Add_ f) where
155+
Add_ a <> Add_ b = Add_ $ add a b
145156

146-
instance GaloisField f => Monoid (Add_ f) where
157+
instance (GaloisField f) => Monoid (Add_ f) where
147158
mempty = Add_ $ cField 0
148159

149160
--------------------------------------------------------------------------------
@@ -170,3 +181,20 @@ any_ ::
170181
t a ->
171182
Signal f Bool
172183
any_ f = unAny_ . foldMap (Any_ . f)
184+
185+
--------------------------------------------------------------------------------
186+
187+
type family Unbundled f a = res | res -> a where
188+
Unbundled f (Vector n ty) = Vector n (Signal f ty)
189+
190+
class Bundled f a where
191+
bundle :: Unbundled f a -> Signal f a
192+
unbundle :: Signal f a -> ExprM f (Unbundled f a)
193+
194+
instance (Hashable f, GaloisField f, KnownNat n) => Bundled f (Vector n f) where
195+
bundle = EBundle
196+
unbundle = _unBundle
197+
198+
instance (Hashable f, GaloisField f, KnownNat n) => Bundled f (Vector n Bool) where
199+
bundle = EBundle
200+
unbundle = fmap unsafeCoerce . _unBundle

0 commit comments

Comments
 (0)