Skip to content

Commit 7785d46

Browse files
committed
added vector ops for bitvec operations
1 parent c2c1943 commit 7785d46

File tree

4 files changed

+184
-47
lines changed

4 files changed

+184
-47
lines changed

Diff for: language/src/Circuit/Language/DSL.hs

+19-3
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,13 @@ module Circuit.Language.DSL
1111
sub,
1212
mul,
1313
and_,
14+
ands_,
1415
or_,
16+
ors_,
1517
xor_,
18+
xors_,
1619
not_,
20+
nots_,
1721
eq_,
1822
fieldInput,
1923
boolInput,
@@ -80,10 +84,22 @@ and_ = binOp_ BAnd
8084
or_ = binOp_ BOr
8185
xor_ = binOp_ BXor
8286

87+
ands_, ors_, xors_ ::
88+
(Hashable f) =>
89+
Signal f ('TVec n 'TBool) ->
90+
Signal f ('TVec n 'TBool) ->
91+
Signal f ('TVec n 'TBool)
92+
ands_ = binOp_ BAnds
93+
ors_ = binOp_ BOrs
94+
xors_ = binOp_ BXors
95+
8396
-- | Negate expression
8497
not_ :: (Hashable f) => Signal f 'TBool -> Signal f 'TBool
8598
not_ = unOp_ UNot
8699

100+
nots_ :: (Hashable f) => Signal f ('TVec n 'TBool) -> Signal f ('TVec n 'TBool)
101+
nots_ = unOp_ UNots
102+
87103
fieldInput :: InputType -> Text -> ExprM f (Var Wire f 'TField)
88104
fieldInput it label =
89105
case it of
@@ -106,10 +122,10 @@ fieldOutput label s = do
106122
fieldsOutput :: (KnownNat n, Hashable f, GaloisField f) => Vector n (Var Wire f 'TField) -> Signal f ('TVec n 'TField) -> ExprM f (Vector n (Var Wire f 'TField))
107123
fieldsOutput vs s = fromJust . SV.toSized <$> compileWithWires (SV.fromSized vs) s
108124

109-
boolOutput :: (Hashable f, GaloisField f) => Text -> Signal f 'TBool -> ExprM f (Var Wire f 'TBool)
125+
boolOutput :: forall f. (Hashable f, GaloisField f) => Text -> Signal f 'TBool -> ExprM f (Var Wire f 'TBool)
110126
boolOutput label s = do
111127
out <- VarBool <$> freshOutput label
112-
unsafeCoerce <$> compileWithWire (boolToField out) (boolToField s)
128+
unsafeCoerce <$> compileWithWire (boolToField @(Var Wire f 'TBool) out) (boolToField s)
113129
{-# INLINE boolOutput #-}
114130

115131
boolsOutput :: (KnownNat n, Hashable f, GaloisField f) => Vector n (Var Wire f 'TBool) -> Signal f ('TVec n 'TBool) -> ExprM f (Vector n (Var Wire f 'TBool))
@@ -213,7 +229,7 @@ any_ f = unAny_ . foldMap (Any_ . f)
213229
--------------------------------------------------------------------------------
214230

215231
class Bundle f a where
216-
type Unbundled f a
232+
type Unbundled f a = r | r -> a
217233
bundle :: Unbundled f a -> Signal f a
218234
unbundle :: Signal f a -> ExprM f (Unbundled f a)
219235

Diff for: language/src/Circuit/Language/Expr.hs

+33-3
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ module Circuit.Language.Expr
3535
)
3636
where
3737

38-
import Data.Field.Galois (Prime, PrimeField (fromP))
38+
import Data.Field.Galois (Prime, PrimeField (fromP), GaloisField)
3939
import Data.Semiring (Ring (..), Semiring (..))
4040
import Data.Sequence ((|>))
4141
import Data.Set qualified as Set
@@ -91,6 +91,7 @@ rawWire (VarBool i) = i
9191
data UnOp f (ty :: Ty) where
9292
UNeg :: UnOp f 'TField
9393
UNot :: UnOp f 'TBool
94+
UNots :: UnOp f ('TVec n 'TBool)
9495

9596
deriving instance (Show f) => Show (UnOp f a)
9697

@@ -100,15 +101,19 @@ instance Pretty (UnOp f a) where
100101
pretty op = case op of
101102
UNeg -> text "neg"
102103
UNot -> text "!"
104+
UNots -> text "nots"
103105

104106
data BinOp f (a :: Ty) where
105107
BAdd :: BinOp f 'TField
106108
BSub :: BinOp f 'TField
107109
BMul :: BinOp f 'TField
108110
BDiv :: BinOp f 'TField
109111
BAnd :: BinOp f 'TBool
112+
BAnds :: BinOp f (TVec n 'TBool)
110113
BOr :: BinOp f 'TBool
114+
BOrs :: BinOp f (TVec n 'TBool)
111115
BXor :: BinOp f 'TBool
116+
BXors :: BinOp f (TVec n 'TBool)
112117

113118
deriving instance (Show f) => Show (BinOp f a)
114119

@@ -121,13 +126,19 @@ instance Pretty (BinOp f a) where
121126
BMul -> text "*"
122127
BDiv -> text "/"
123128
BAnd -> text "&&"
129+
BAnds -> text ".&&."
124130
BOr -> text "||"
131+
BOrs -> text ".||."
125132
BXor -> text "xor"
133+
BXors -> text "xors"
126134

127135
opPrecedence :: BinOp f a -> Int
128136
opPrecedence BOr = 5
137+
opPrecedence BOrs = 5
129138
opPrecedence BXor = 5
139+
opPrecedence BXors = 5
130140
opPrecedence BAnd = 5
141+
opPrecedence BAnds = 5
131142
opPrecedence BSub = 6
132143
opPrecedence BAdd = 6
133144
opPrecedence BMul = 7
@@ -246,6 +257,8 @@ evalExpr lookupVar vars expr = case expr of
246257
Protolude.negate $ evalExpr lookupVar vars e1
247258
EUnOp _ UNot e1 ->
248259
not $ evalExpr lookupVar vars e1
260+
EUnOp _ UNots e1 ->
261+
SV.map not $ evalExpr lookupVar vars e1
249262
EBinOp _ op e1 e2 ->
250263
let e1' = evalExpr lookupVar vars e1
251264
e2' = evalExpr lookupVar vars e2
@@ -257,8 +270,11 @@ evalExpr lookupVar vars expr = case expr of
257270
BMul -> (*)
258271
BDiv -> (/)
259272
BAnd -> (&&)
273+
BAnds -> SV.zipWith (&&)
260274
BOr -> (||)
275+
BOrs -> SV.zipWith (||)
261276
BXor -> \x y -> (x || y) && not (x && y)
277+
BXors -> SV.zipWith (\x y -> (x || y) && not (x && y))
262278
EIf _ b true false ->
263279
let cond = evalExpr lookupVar vars b
264280
in if cond
@@ -401,9 +417,12 @@ bundle_ b =
401417
in EBundle h b
402418
{-# INLINE bundle_ #-}
403419

404-
class BoolToField b f | b -> f where
420+
class BoolToField b f where
405421
boolToField :: b -> f
406422

423+
instance GaloisField f => BoolToField Bool f where
424+
boolToField b = fromInteger $ if b then 1 else 0
425+
407426
instance BoolToField (Val f 'TBool) (Val f 'TField) where
408427
boolToField (ValBool b) = ValField b
409428

@@ -418,7 +437,14 @@ instance BoolToField (Expr i f ('TVec n 'TBool)) (Expr i f ('TVec n 'TField)) wh
418437

419438
-------------------------------------------------------------------------------
420439

421-
data UBinOp = UBAdd | UBSub | UBMul | UBDiv | UBAnd | UBOr | UBXor deriving (Show, Eq, Generic)
440+
data UBinOp =
441+
UBAdd |
442+
UBSub |
443+
UBMul |
444+
UBDiv |
445+
UBAnd |
446+
UBOr |
447+
UBXor deriving (Show, Eq, Generic)
422448

423449
instance Hashable UBinOp
424450

@@ -481,6 +507,7 @@ instance (Hashable i, Hashable f) => Hashable (Node i f) where
481507
untypeUnOp :: UnOp f a -> UUnOp
482508
untypeUnOp UNeg = UUNeg
483509
untypeUnOp UNot = UUNot
510+
untypeUnOp UNots = UUNot
484511
{-# INLINE untypeUnOp #-}
485512

486513
untypeBinOp :: BinOp f a -> UBinOp
@@ -489,8 +516,11 @@ untypeBinOp BSub = UBSub
489516
untypeBinOp BMul = UBMul
490517
untypeBinOp BDiv = UBDiv
491518
untypeBinOp BAnd = UBAnd
519+
untypeBinOp BAnds = UBAnd
492520
untypeBinOp BOr = UBOr
521+
untypeBinOp BOrs = UBOr
493522
untypeBinOp BXor = UBXor
523+
untypeBinOp BXors = UBXor
494524
{-# INLINE untypeBinOp #-}
495525

496526
--------------------------------------------------------------------------------

Diff for: language/test/Test/Circuit/Lang.hs

+76-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@ import Data.Finite (Finite)
1111
import Data.IntMap qualified as IntMap
1212
import Data.Map qualified as Map
1313
import Protolude
14-
import Test.QuickCheck (Property, (.&&.), (=/=), (===), (==>))
14+
import Data.Vector.Sized qualified as SV
15+
import Test.QuickCheck (Arbitrary (arbitrary), Property, forAll, vectorOf, (.&&.), (=/=), (===), (==>))
1516

1617
type Fr = Prime 21888242871839275222246405745257275088548364400416034343698204186575808495617
1718

@@ -142,4 +143,78 @@ prop_sharingProg x y =
142143
computed = evalExpr IntMap.lookup input (relabelExpr wireName prog)
143144
in res === Just expected .&&. computed === expected
144145

146+
binopsProg ::
147+
BinOp Fr ('TVec 50 'TBool) ->
148+
ExprM Fr (SV.Vector 50 (Var Wire Fr 'TBool))
149+
binopsProg op = do
150+
xs <- SV.generateM $ \i ->
151+
var_ <$> boolInput Public ("x" <> showFinite i)
152+
ys <- SV.generateM $ \i ->
153+
var_ <$> boolInput Public ("y" <> showFinite i)
154+
zs <- SV.generateM $ \i ->
155+
VarBool <$> freshOutput ("out" <> showFinite i)
156+
boolsOutput zs $ binOp_ op (bundle xs) (bundle ys)
157+
158+
propBinopsProg ::
159+
BinOp Fr ('TVec 50 'TBool) ->
160+
(Bool -> Bool -> Bool) ->
161+
Property
162+
propBinopsProg top op = forAll arbInputs $ \(bs, bs') ->
163+
let _xs = zip (map (\i -> "x" <> show @Int i) [0 ..]) bs
164+
_ys = zip (map (\i -> "y" <> show @Int i) [0 ..]) bs'
165+
(_, BuilderState {bsVars, bsCircuit}) = runCircuitBuilder (binopsProg top)
166+
inputs =
167+
assignInputs bsVars $
168+
fmap boolToField $
169+
Map.fromList (_xs <> _ys)
170+
w = solve bsVars bsCircuit inputs
171+
expected = map boolToField $ zipWith op bs bs'
172+
in all (\(i, b) -> lookupVar bsVars ("out" <> show @Int i) w == Just b) $
173+
zip [0 ..] expected
174+
where
175+
arbInputs = ((,) <$> vectorOf 50 arbitrary <*> vectorOf 50 arbitrary)
176+
177+
prop_andsProg :: Property
178+
prop_andsProg = propBinopsProg BAnds (&&)
179+
180+
prop_orsProg :: Property
181+
prop_orsProg = propBinopsProg BOrs (||)
182+
183+
prop_xorsProg :: Property
184+
prop_xorsProg = propBinopsProg BXors (/=)
185+
186+
unopsProg ::
187+
UnOp Fr ('TVec 50 'TBool) ->
188+
ExprM Fr (SV.Vector 50 (Var Wire Fr 'TBool))
189+
unopsProg op = do
190+
xs <- SV.generateM $ \i ->
191+
var_ <$> boolInput Public ("x" <> showFinite i)
192+
zs <- SV.generateM $ \i ->
193+
VarBool <$> freshOutput ("out" <> showFinite i)
194+
boolsOutput zs $ unOp_ op (bundle xs)
195+
196+
propUnopsProg ::
197+
UnOp Fr ('TVec 50 'TBool) ->
198+
(Bool -> Bool) ->
199+
Property
200+
propUnopsProg top op = forAll arbInputs $ \bs ->
201+
let _xs = zip (map (\i -> "x" <> show @Int i) [0 ..]) bs
202+
(_, BuilderState {bsVars, bsCircuit}) = runCircuitBuilder (unopsProg top)
203+
inputs =
204+
assignInputs bsVars $
205+
fmap boolToField $
206+
Map.fromList _xs
207+
w = solve bsVars bsCircuit inputs
208+
expected = map (boolToField . op) bs
209+
in all (\(i, b) -> lookupVar bsVars ("out" <> show @Int i) w == Just b) $
210+
zip [0 ..] expected
211+
where
212+
arbInputs = vectorOf 50 arbitrary
213+
214+
prop_notsProg :: Property
215+
prop_notsProg = propUnopsProg UNots not
216+
145217
--------------------------------------------------------------------------------
218+
219+
showFinite :: (KnownNat n) => Finite n -> Text
220+
showFinite = show . toInteger

0 commit comments

Comments
 (0)