Skip to content

Commit 71ef9cb

Browse files
committed
vector operations for field elems
1 parent 7785d46 commit 71ef9cb

File tree

4 files changed

+181
-55
lines changed

4 files changed

+181
-55
lines changed

Diff for: language/src/Circuit/Language/DSL.hs

+42-13
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,22 @@ module Circuit.Language.DSL
77
Bundle (..),
88
cField,
99
cBool,
10-
add,
11-
sub,
12-
mul,
10+
add_,
11+
adds_,
12+
sub_,
13+
subs_,
14+
mul_,
15+
muls_,
16+
div_,
17+
divs_,
1318
and_,
1419
ands_,
1520
or_,
1621
ors_,
1722
xor_,
1823
xors_,
24+
neg_,
25+
negs_,
1926
not_,
2027
nots_,
2128
eq_,
@@ -71,10 +78,24 @@ cBool b = val_ . ValBool $ if b then 1 else 0
7178
{-# INLINE cBool #-}
7279

7380
-- | Binary arithmetic operations on expressions
74-
add, sub, mul :: (Hashable f, Num f) => Signal f 'TField -> Signal f 'TField -> Signal f 'TField
75-
add = (+)
76-
sub = (-)
77-
mul = (*)
81+
add_, sub_, mul_, div_ :: (Hashable f) => Signal f 'TField -> Signal f 'TField -> Signal f 'TField
82+
add_ = binOp_ BAdd
83+
sub_ = binOp_ BSub
84+
mul_ = binOp_ BMul
85+
div_ = binOp_ BDiv
86+
87+
adds_,
88+
subs_,
89+
muls_,
90+
divs_ ::
91+
(Hashable f) =>
92+
Signal f ('TVec n 'TField) ->
93+
Signal f ('TVec n 'TField) ->
94+
Signal f ('TVec n 'TField)
95+
adds_ = binOp_ BAdds
96+
subs_ = binOp_ BSubs
97+
muls_ = binOp_ BMuls
98+
divs_ = binOp_ BDivs
7899

79100
-- | Binary logic operations on expressions
80101
-- Have to use underscore or similar to avoid shadowing @and@ and @or@
@@ -84,11 +105,13 @@ and_ = binOp_ BAnd
84105
or_ = binOp_ BOr
85106
xor_ = binOp_ BXor
86107

87-
ands_, ors_, xors_ ::
88-
(Hashable f) =>
89-
Signal f ('TVec n 'TBool) ->
90-
Signal f ('TVec n 'TBool) ->
91-
Signal f ('TVec n 'TBool)
108+
ands_,
109+
ors_,
110+
xors_ ::
111+
(Hashable f) =>
112+
Signal f ('TVec n 'TBool) ->
113+
Signal f ('TVec n 'TBool) ->
114+
Signal f ('TVec n 'TBool)
92115
ands_ = binOp_ BAnds
93116
ors_ = binOp_ BOrs
94117
xors_ = binOp_ BXors
@@ -100,6 +123,12 @@ not_ = unOp_ UNot
100123
nots_ :: (Hashable f) => Signal f ('TVec n 'TBool) -> Signal f ('TVec n 'TBool)
101124
nots_ = unOp_ UNots
102125

126+
neg_ :: (Hashable f) => Signal f 'TField -> Signal f 'TField
127+
neg_ = unOp_ UNeg
128+
129+
negs_ :: (Hashable f) => Signal f ('TVec n 'TField) -> Signal f ('TVec n 'TField)
130+
negs_ = unOp_ UNegs
131+
103132
fieldInput :: InputType -> Text -> ExprM f (Var Wire f 'TField)
104133
fieldInput it label =
105134
case it of
@@ -188,7 +217,7 @@ instance (Eq f, Num f, Hashable f) => Monoid (Any_ f) where
188217
newtype Add_ f = Add_ {unAdd_ :: Signal f 'TField}
189218

190219
instance (Hashable f, Num f) => Semigroup (Add_ f) where
191-
Add_ a <> Add_ b = Add_ $ add a b
220+
Add_ a <> Add_ b = Add_ $ add_ a b
192221

193222
instance (Hashable f, Num f) => Monoid (Add_ f) where
194223
mempty = Add_ $ cField 0

Diff for: language/src/Circuit/Language/Expr.hs

+43-16
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ module Circuit.Language.Expr
3535
)
3636
where
3737

38-
import Data.Field.Galois (Prime, PrimeField (fromP), GaloisField)
38+
import Data.Field.Galois (GaloisField, Prime, PrimeField (fromP))
3939
import Data.Semiring (Ring (..), Semiring (..))
4040
import Data.Sequence ((|>))
4141
import Data.Set qualified as Set
@@ -90,6 +90,7 @@ rawWire (VarBool i) = i
9090

9191
data UnOp f (ty :: Ty) where
9292
UNeg :: UnOp f 'TField
93+
UNegs :: UnOp f ('TVec n 'TField)
9394
UNot :: UnOp f 'TBool
9495
UNots :: UnOp f ('TVec n 'TBool)
9596

@@ -100,14 +101,19 @@ deriving instance Eq (UnOp f a)
100101
instance Pretty (UnOp f a) where
101102
pretty op = case op of
102103
UNeg -> text "neg"
104+
UNegs -> text "negs"
103105
UNot -> text "!"
104106
UNots -> text "nots"
105107

106108
data BinOp f (a :: Ty) where
107109
BAdd :: BinOp f 'TField
110+
BAdds :: BinOp f (TVec n 'TField)
108111
BSub :: BinOp f 'TField
112+
BSubs :: BinOp f (TVec n 'TField)
109113
BMul :: BinOp f 'TField
114+
BMuls :: BinOp f (TVec n 'TField)
110115
BDiv :: BinOp f 'TField
116+
BDivs :: BinOp f (TVec n 'TField)
111117
BAnd :: BinOp f 'TBool
112118
BAnds :: BinOp f (TVec n 'TBool)
113119
BOr :: BinOp f 'TBool
@@ -122,9 +128,13 @@ deriving instance Eq (BinOp f a)
122128
instance Pretty (BinOp f a) where
123129
pretty op = case op of
124130
BAdd -> text "+"
131+
BAdds -> text ".+."
125132
BSub -> text "-"
133+
BSubs -> text ".-."
126134
BMul -> text "*"
135+
BMuls -> text ".*."
127136
BDiv -> text "/"
137+
BDivs -> text "./."
128138
BAnd -> text "&&"
129139
BAnds -> text ".&&."
130140
BOr -> text "||"
@@ -140,9 +150,13 @@ opPrecedence BXors = 5
140150
opPrecedence BAnd = 5
141151
opPrecedence BAnds = 5
142152
opPrecedence BSub = 6
153+
opPrecedence BSubs = 6
143154
opPrecedence BAdd = 6
155+
opPrecedence BAdds = 6
144156
opPrecedence BMul = 7
157+
opPrecedence BMuls = 7
145158
opPrecedence BDiv = 8
159+
opPrecedence BDivs = 8
146160

147161
type family NBits a :: Nat where
148162
NBits (Prime p) = (Log2 p) + 1
@@ -253,22 +267,29 @@ evalExpr lookupVar vars expr = case expr of
253267
case lookupVar i vars of
254268
Just v -> v == 1
255269
Nothing -> panic $ "TODO: incorrect bool var lookup: " <> Protolude.show i
256-
EUnOp _ UNeg e1 ->
257-
Protolude.negate $ evalExpr lookupVar vars e1
258-
EUnOp _ UNot e1 ->
259-
not $ evalExpr lookupVar vars e1
260-
EUnOp _ UNots e1 ->
261-
SV.map not $ evalExpr lookupVar vars e1
270+
EUnOp _ op e1 ->
271+
let e1' = evalExpr lookupVar vars e1
272+
in apply e1'
273+
where
274+
apply = case op of
275+
UNeg -> Protolude.negate
276+
UNegs -> map Protolude.negate
277+
UNot -> not
278+
UNots -> map not
262279
EBinOp _ op e1 e2 ->
263280
let e1' = evalExpr lookupVar vars e1
264281
e2' = evalExpr lookupVar vars e2
265282
in apply e1' e2'
266283
where
267284
apply = case op of
268285
BAdd -> (+)
286+
BAdds -> SV.zipWith (+)
269287
BSub -> (-)
288+
BSubs -> SV.zipWith (-)
270289
BMul -> (*)
290+
BMuls -> SV.zipWith (*)
271291
BDiv -> (/)
292+
BDivs -> SV.zipWith (/)
272293
BAnd -> (&&)
273294
BAnds -> SV.zipWith (&&)
274295
BOr -> (||)
@@ -420,7 +441,7 @@ bundle_ b =
420441
class BoolToField b f where
421442
boolToField :: b -> f
422443

423-
instance GaloisField f => BoolToField Bool f where
444+
instance (GaloisField f) => BoolToField Bool f where
424445
boolToField b = fromInteger $ if b then 1 else 0
425446

426447
instance BoolToField (Val f 'TBool) (Val f 'TField) where
@@ -437,14 +458,15 @@ instance BoolToField (Expr i f ('TVec n 'TBool)) (Expr i f ('TVec n 'TField)) wh
437458

438459
-------------------------------------------------------------------------------
439460

440-
data UBinOp =
441-
UBAdd |
442-
UBSub |
443-
UBMul |
444-
UBDiv |
445-
UBAnd |
446-
UBOr |
447-
UBXor deriving (Show, Eq, Generic)
461+
data UBinOp
462+
= UBAdd
463+
| UBSub
464+
| UBMul
465+
| UBDiv
466+
| UBAnd
467+
| UBOr
468+
| UBXor
469+
deriving (Show, Eq, Generic)
448470

449471
instance Hashable UBinOp
450472

@@ -506,15 +528,20 @@ instance (Hashable i, Hashable f) => Hashable (Node i f) where
506528

507529
untypeUnOp :: UnOp f a -> UUnOp
508530
untypeUnOp UNeg = UUNeg
531+
untypeUnOp UNegs = UUNeg
509532
untypeUnOp UNot = UUNot
510533
untypeUnOp UNots = UUNot
511534
{-# INLINE untypeUnOp #-}
512535

513536
untypeBinOp :: BinOp f a -> UBinOp
514537
untypeBinOp BAdd = UBAdd
538+
untypeBinOp BAdds = UBAdd
515539
untypeBinOp BSub = UBSub
540+
untypeBinOp BSubs = UBSub
516541
untypeBinOp BMul = UBMul
542+
untypeBinOp BMuls = UBMul
517543
untypeBinOp BDiv = UBDiv
544+
untypeBinOp BDivs = UBDiv
518545
untypeBinOp BAnd = UBAnd
519546
untypeBinOp BAnds = UBAnd
520547
untypeBinOp BOr = UBOr

0 commit comments

Comments
 (0)