Skip to content

Commit b49f89a

Browse files
authored
Merge pull request #10 from l-adic/bitvector-ops
added vector ops for bitvec operations
2 parents c2c1943 + 71ef9cb commit b49f89a

File tree

4 files changed

+322
-59
lines changed

4 files changed

+322
-59
lines changed

Diff for: language/src/Circuit/Language/DSL.hs

+56-11
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,24 @@ module Circuit.Language.DSL
77
Bundle (..),
88
cField,
99
cBool,
10-
add,
11-
sub,
12-
mul,
10+
add_,
11+
adds_,
12+
sub_,
13+
subs_,
14+
mul_,
15+
muls_,
16+
div_,
17+
divs_,
1318
and_,
19+
ands_,
1420
or_,
21+
ors_,
1522
xor_,
23+
xors_,
24+
neg_,
25+
negs_,
1626
not_,
27+
nots_,
1728
eq_,
1829
fieldInput,
1930
boolInput,
@@ -67,10 +78,24 @@ cBool b = val_ . ValBool $ if b then 1 else 0
6778
{-# INLINE cBool #-}
6879

6980
-- | Binary arithmetic operations on expressions
70-
add, sub, mul :: (Hashable f, Num f) => Signal f 'TField -> Signal f 'TField -> Signal f 'TField
71-
add = (+)
72-
sub = (-)
73-
mul = (*)
81+
add_, sub_, mul_, div_ :: (Hashable f) => Signal f 'TField -> Signal f 'TField -> Signal f 'TField
82+
add_ = binOp_ BAdd
83+
sub_ = binOp_ BSub
84+
mul_ = binOp_ BMul
85+
div_ = binOp_ BDiv
86+
87+
adds_,
88+
subs_,
89+
muls_,
90+
divs_ ::
91+
(Hashable f) =>
92+
Signal f ('TVec n 'TField) ->
93+
Signal f ('TVec n 'TField) ->
94+
Signal f ('TVec n 'TField)
95+
adds_ = binOp_ BAdds
96+
subs_ = binOp_ BSubs
97+
muls_ = binOp_ BMuls
98+
divs_ = binOp_ BDivs
7499

75100
-- | Binary logic operations on expressions
76101
-- Have to use underscore or similar to avoid shadowing @and@ and @or@
@@ -80,10 +105,30 @@ and_ = binOp_ BAnd
80105
or_ = binOp_ BOr
81106
xor_ = binOp_ BXor
82107

108+
ands_,
109+
ors_,
110+
xors_ ::
111+
(Hashable f) =>
112+
Signal f ('TVec n 'TBool) ->
113+
Signal f ('TVec n 'TBool) ->
114+
Signal f ('TVec n 'TBool)
115+
ands_ = binOp_ BAnds
116+
ors_ = binOp_ BOrs
117+
xors_ = binOp_ BXors
118+
83119
-- | Negate expression
84120
not_ :: (Hashable f) => Signal f 'TBool -> Signal f 'TBool
85121
not_ = unOp_ UNot
86122

123+
nots_ :: (Hashable f) => Signal f ('TVec n 'TBool) -> Signal f ('TVec n 'TBool)
124+
nots_ = unOp_ UNots
125+
126+
neg_ :: (Hashable f) => Signal f 'TField -> Signal f 'TField
127+
neg_ = unOp_ UNeg
128+
129+
negs_ :: (Hashable f) => Signal f ('TVec n 'TField) -> Signal f ('TVec n 'TField)
130+
negs_ = unOp_ UNegs
131+
87132
fieldInput :: InputType -> Text -> ExprM f (Var Wire f 'TField)
88133
fieldInput it label =
89134
case it of
@@ -106,10 +151,10 @@ fieldOutput label s = do
106151
fieldsOutput :: (KnownNat n, Hashable f, GaloisField f) => Vector n (Var Wire f 'TField) -> Signal f ('TVec n 'TField) -> ExprM f (Vector n (Var Wire f 'TField))
107152
fieldsOutput vs s = fromJust . SV.toSized <$> compileWithWires (SV.fromSized vs) s
108153

109-
boolOutput :: (Hashable f, GaloisField f) => Text -> Signal f 'TBool -> ExprM f (Var Wire f 'TBool)
154+
boolOutput :: forall f. (Hashable f, GaloisField f) => Text -> Signal f 'TBool -> ExprM f (Var Wire f 'TBool)
110155
boolOutput label s = do
111156
out <- VarBool <$> freshOutput label
112-
unsafeCoerce <$> compileWithWire (boolToField out) (boolToField s)
157+
unsafeCoerce <$> compileWithWire (boolToField @(Var Wire f 'TBool) out) (boolToField s)
113158
{-# INLINE boolOutput #-}
114159

115160
boolsOutput :: (KnownNat n, Hashable f, GaloisField f) => Vector n (Var Wire f 'TBool) -> Signal f ('TVec n 'TBool) -> ExprM f (Vector n (Var Wire f 'TBool))
@@ -172,7 +217,7 @@ instance (Eq f, Num f, Hashable f) => Monoid (Any_ f) where
172217
newtype Add_ f = Add_ {unAdd_ :: Signal f 'TField}
173218

174219
instance (Hashable f, Num f) => Semigroup (Add_ f) where
175-
Add_ a <> Add_ b = Add_ $ add a b
220+
Add_ a <> Add_ b = Add_ $ add_ a b
176221

177222
instance (Hashable f, Num f) => Monoid (Add_ f) where
178223
mempty = Add_ $ cField 0
@@ -213,7 +258,7 @@ any_ f = unAny_ . foldMap (Any_ . f)
213258
--------------------------------------------------------------------------------
214259

215260
class Bundle f a where
216-
type Unbundled f a
261+
type Unbundled f a = r | r -> a
217262
bundle :: Unbundled f a -> Signal f a
218263
unbundle :: Signal f a -> ExprM f (Unbundled f a)
219264

Diff for: language/src/Circuit/Language/Expr.hs

+64-7
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ module Circuit.Language.Expr
3535
)
3636
where
3737

38-
import Data.Field.Galois (Prime, PrimeField (fromP))
38+
import Data.Field.Galois (GaloisField, Prime, PrimeField (fromP))
3939
import Data.Semiring (Ring (..), Semiring (..))
4040
import Data.Sequence ((|>))
4141
import Data.Set qualified as Set
@@ -90,7 +90,9 @@ rawWire (VarBool i) = i
9090

9191
data UnOp f (ty :: Ty) where
9292
UNeg :: UnOp f 'TField
93+
UNegs :: UnOp f ('TVec n 'TField)
9394
UNot :: UnOp f 'TBool
95+
UNots :: UnOp f ('TVec n 'TBool)
9496

9597
deriving instance (Show f) => Show (UnOp f a)
9698

@@ -99,16 +101,25 @@ deriving instance Eq (UnOp f a)
99101
instance Pretty (UnOp f a) where
100102
pretty op = case op of
101103
UNeg -> text "neg"
104+
UNegs -> text "negs"
102105
UNot -> text "!"
106+
UNots -> text "nots"
103107

104108
data BinOp f (a :: Ty) where
105109
BAdd :: BinOp f 'TField
110+
BAdds :: BinOp f (TVec n 'TField)
106111
BSub :: BinOp f 'TField
112+
BSubs :: BinOp f (TVec n 'TField)
107113
BMul :: BinOp f 'TField
114+
BMuls :: BinOp f (TVec n 'TField)
108115
BDiv :: BinOp f 'TField
116+
BDivs :: BinOp f (TVec n 'TField)
109117
BAnd :: BinOp f 'TBool
118+
BAnds :: BinOp f (TVec n 'TBool)
110119
BOr :: BinOp f 'TBool
120+
BOrs :: BinOp f (TVec n 'TBool)
111121
BXor :: BinOp f 'TBool
122+
BXors :: BinOp f (TVec n 'TBool)
112123

113124
deriving instance (Show f) => Show (BinOp f a)
114125

@@ -117,21 +128,35 @@ deriving instance Eq (BinOp f a)
117128
instance Pretty (BinOp f a) where
118129
pretty op = case op of
119130
BAdd -> text "+"
131+
BAdds -> text ".+."
120132
BSub -> text "-"
133+
BSubs -> text ".-."
121134
BMul -> text "*"
135+
BMuls -> text ".*."
122136
BDiv -> text "/"
137+
BDivs -> text "./."
123138
BAnd -> text "&&"
139+
BAnds -> text ".&&."
124140
BOr -> text "||"
141+
BOrs -> text ".||."
125142
BXor -> text "xor"
143+
BXors -> text "xors"
126144

127145
opPrecedence :: BinOp f a -> Int
128146
opPrecedence BOr = 5
147+
opPrecedence BOrs = 5
129148
opPrecedence BXor = 5
149+
opPrecedence BXors = 5
130150
opPrecedence BAnd = 5
151+
opPrecedence BAnds = 5
131152
opPrecedence BSub = 6
153+
opPrecedence BSubs = 6
132154
opPrecedence BAdd = 6
155+
opPrecedence BAdds = 6
133156
opPrecedence BMul = 7
157+
opPrecedence BMuls = 7
134158
opPrecedence BDiv = 8
159+
opPrecedence BDivs = 8
135160

136161
type family NBits a :: Nat where
137162
NBits (Prime p) = (Log2 p) + 1
@@ -242,23 +267,35 @@ evalExpr lookupVar vars expr = case expr of
242267
case lookupVar i vars of
243268
Just v -> v == 1
244269
Nothing -> panic $ "TODO: incorrect bool var lookup: " <> Protolude.show i
245-
EUnOp _ UNeg e1 ->
246-
Protolude.negate $ evalExpr lookupVar vars e1
247-
EUnOp _ UNot e1 ->
248-
not $ evalExpr lookupVar vars e1
270+
EUnOp _ op e1 ->
271+
let e1' = evalExpr lookupVar vars e1
272+
in apply e1'
273+
where
274+
apply = case op of
275+
UNeg -> Protolude.negate
276+
UNegs -> map Protolude.negate
277+
UNot -> not
278+
UNots -> map not
249279
EBinOp _ op e1 e2 ->
250280
let e1' = evalExpr lookupVar vars e1
251281
e2' = evalExpr lookupVar vars e2
252282
in apply e1' e2'
253283
where
254284
apply = case op of
255285
BAdd -> (+)
286+
BAdds -> SV.zipWith (+)
256287
BSub -> (-)
288+
BSubs -> SV.zipWith (-)
257289
BMul -> (*)
290+
BMuls -> SV.zipWith (*)
258291
BDiv -> (/)
292+
BDivs -> SV.zipWith (/)
259293
BAnd -> (&&)
294+
BAnds -> SV.zipWith (&&)
260295
BOr -> (||)
296+
BOrs -> SV.zipWith (||)
261297
BXor -> \x y -> (x || y) && not (x && y)
298+
BXors -> SV.zipWith (\x y -> (x || y) && not (x && y))
262299
EIf _ b true false ->
263300
let cond = evalExpr lookupVar vars b
264301
in if cond
@@ -401,9 +438,12 @@ bundle_ b =
401438
in EBundle h b
402439
{-# INLINE bundle_ #-}
403440

404-
class BoolToField b f | b -> f where
441+
class BoolToField b f where
405442
boolToField :: b -> f
406443

444+
instance (GaloisField f) => BoolToField Bool f where
445+
boolToField b = fromInteger $ if b then 1 else 0
446+
407447
instance BoolToField (Val f 'TBool) (Val f 'TField) where
408448
boolToField (ValBool b) = ValField b
409449

@@ -418,7 +458,15 @@ instance BoolToField (Expr i f ('TVec n 'TBool)) (Expr i f ('TVec n 'TField)) wh
418458

419459
-------------------------------------------------------------------------------
420460

421-
data UBinOp = UBAdd | UBSub | UBMul | UBDiv | UBAnd | UBOr | UBXor deriving (Show, Eq, Generic)
461+
data UBinOp
462+
= UBAdd
463+
| UBSub
464+
| UBMul
465+
| UBDiv
466+
| UBAnd
467+
| UBOr
468+
| UBXor
469+
deriving (Show, Eq, Generic)
422470

423471
instance Hashable UBinOp
424472

@@ -480,17 +528,26 @@ instance (Hashable i, Hashable f) => Hashable (Node i f) where
480528

481529
untypeUnOp :: UnOp f a -> UUnOp
482530
untypeUnOp UNeg = UUNeg
531+
untypeUnOp UNegs = UUNeg
483532
untypeUnOp UNot = UUNot
533+
untypeUnOp UNots = UUNot
484534
{-# INLINE untypeUnOp #-}
485535

486536
untypeBinOp :: BinOp f a -> UBinOp
487537
untypeBinOp BAdd = UBAdd
538+
untypeBinOp BAdds = UBAdd
488539
untypeBinOp BSub = UBSub
540+
untypeBinOp BSubs = UBSub
489541
untypeBinOp BMul = UBMul
542+
untypeBinOp BMuls = UBMul
490543
untypeBinOp BDiv = UBDiv
544+
untypeBinOp BDivs = UBDiv
491545
untypeBinOp BAnd = UBAnd
546+
untypeBinOp BAnds = UBAnd
492547
untypeBinOp BOr = UBOr
548+
untypeBinOp BOrs = UBOr
493549
untypeBinOp BXor = UBXor
550+
untypeBinOp BXors = UBXor
494551
{-# INLINE untypeBinOp #-}
495552

496553
--------------------------------------------------------------------------------

0 commit comments

Comments
 (0)