@@ -35,7 +35,7 @@ module Circuit.Language.Expr
35
35
)
36
36
where
37
37
38
- import Data.Field.Galois (Prime , PrimeField (fromP ))
38
+ import Data.Field.Galois (GaloisField , Prime , PrimeField (fromP ))
39
39
import Data.Semiring (Ring (.. ), Semiring (.. ))
40
40
import Data.Sequence ((|>) )
41
41
import Data.Set qualified as Set
@@ -90,7 +90,9 @@ rawWire (VarBool i) = i
90
90
91
91
data UnOp f (ty :: Ty ) where
92
92
UNeg :: UnOp f 'TField
93
+ UNegs :: UnOp f ('TVec n 'TField)
93
94
UNot :: UnOp f 'TBool
95
+ UNots :: UnOp f ('TVec n 'TBool)
94
96
95
97
deriving instance (Show f ) => Show (UnOp f a )
96
98
@@ -99,16 +101,25 @@ deriving instance Eq (UnOp f a)
99
101
instance Pretty (UnOp f a ) where
100
102
pretty op = case op of
101
103
UNeg -> text " neg"
104
+ UNegs -> text " negs"
102
105
UNot -> text " !"
106
+ UNots -> text " nots"
103
107
104
108
data BinOp f (a :: Ty ) where
105
109
BAdd :: BinOp f 'TField
110
+ BAdds :: BinOp f (TVec n 'TField)
106
111
BSub :: BinOp f 'TField
112
+ BSubs :: BinOp f (TVec n 'TField)
107
113
BMul :: BinOp f 'TField
114
+ BMuls :: BinOp f (TVec n 'TField)
108
115
BDiv :: BinOp f 'TField
116
+ BDivs :: BinOp f (TVec n 'TField)
109
117
BAnd :: BinOp f 'TBool
118
+ BAnds :: BinOp f (TVec n 'TBool)
110
119
BOr :: BinOp f 'TBool
120
+ BOrs :: BinOp f (TVec n 'TBool)
111
121
BXor :: BinOp f 'TBool
122
+ BXors :: BinOp f (TVec n 'TBool)
112
123
113
124
deriving instance (Show f ) => Show (BinOp f a )
114
125
@@ -117,21 +128,35 @@ deriving instance Eq (BinOp f a)
117
128
instance Pretty (BinOp f a ) where
118
129
pretty op = case op of
119
130
BAdd -> text " +"
131
+ BAdds -> text " .+."
120
132
BSub -> text " -"
133
+ BSubs -> text " .-."
121
134
BMul -> text " *"
135
+ BMuls -> text " .*."
122
136
BDiv -> text " /"
137
+ BDivs -> text " ./."
123
138
BAnd -> text " &&"
139
+ BAnds -> text " .&&."
124
140
BOr -> text " ||"
141
+ BOrs -> text " .||."
125
142
BXor -> text " xor"
143
+ BXors -> text " xors"
126
144
127
145
opPrecedence :: BinOp f a -> Int
128
146
opPrecedence BOr = 5
147
+ opPrecedence BOrs = 5
129
148
opPrecedence BXor = 5
149
+ opPrecedence BXors = 5
130
150
opPrecedence BAnd = 5
151
+ opPrecedence BAnds = 5
131
152
opPrecedence BSub = 6
153
+ opPrecedence BSubs = 6
132
154
opPrecedence BAdd = 6
155
+ opPrecedence BAdds = 6
133
156
opPrecedence BMul = 7
157
+ opPrecedence BMuls = 7
134
158
opPrecedence BDiv = 8
159
+ opPrecedence BDivs = 8
135
160
136
161
type family NBits a :: Nat where
137
162
NBits (Prime p ) = (Log2 p ) + 1
@@ -242,23 +267,35 @@ evalExpr lookupVar vars expr = case expr of
242
267
case lookupVar i vars of
243
268
Just v -> v == 1
244
269
Nothing -> panic $ " TODO: incorrect bool var lookup: " <> Protolude. show i
245
- EUnOp _ UNeg e1 ->
246
- Protolude. negate $ evalExpr lookupVar vars e1
247
- EUnOp _ UNot e1 ->
248
- not $ evalExpr lookupVar vars e1
270
+ EUnOp _ op e1 ->
271
+ let e1' = evalExpr lookupVar vars e1
272
+ in apply e1'
273
+ where
274
+ apply = case op of
275
+ UNeg -> Protolude. negate
276
+ UNegs -> map Protolude. negate
277
+ UNot -> not
278
+ UNots -> map not
249
279
EBinOp _ op e1 e2 ->
250
280
let e1' = evalExpr lookupVar vars e1
251
281
e2' = evalExpr lookupVar vars e2
252
282
in apply e1' e2'
253
283
where
254
284
apply = case op of
255
285
BAdd -> (+)
286
+ BAdds -> SV. zipWith (+)
256
287
BSub -> (-)
288
+ BSubs -> SV. zipWith (-)
257
289
BMul -> (*)
290
+ BMuls -> SV. zipWith (*)
258
291
BDiv -> (/)
292
+ BDivs -> SV. zipWith (/)
259
293
BAnd -> (&&)
294
+ BAnds -> SV. zipWith (&&)
260
295
BOr -> (||)
296
+ BOrs -> SV. zipWith (||)
261
297
BXor -> \ x y -> (x || y) && not (x && y)
298
+ BXors -> SV. zipWith (\ x y -> (x || y) && not (x && y))
262
299
EIf _ b true false ->
263
300
let cond = evalExpr lookupVar vars b
264
301
in if cond
@@ -401,9 +438,12 @@ bundle_ b =
401
438
in EBundle h b
402
439
{-# INLINE bundle_ #-}
403
440
404
- class BoolToField b f | b -> f where
441
+ class BoolToField b f where
405
442
boolToField :: b -> f
406
443
444
+ instance (GaloisField f ) => BoolToField Bool f where
445
+ boolToField b = fromInteger $ if b then 1 else 0
446
+
407
447
instance BoolToField (Val f 'TBool) (Val f 'TField) where
408
448
boolToField (ValBool b) = ValField b
409
449
@@ -418,7 +458,15 @@ instance BoolToField (Expr i f ('TVec n 'TBool)) (Expr i f ('TVec n 'TField)) wh
418
458
419
459
-------------------------------------------------------------------------------
420
460
421
- data UBinOp = UBAdd | UBSub | UBMul | UBDiv | UBAnd | UBOr | UBXor deriving (Show , Eq , Generic )
461
+ data UBinOp
462
+ = UBAdd
463
+ | UBSub
464
+ | UBMul
465
+ | UBDiv
466
+ | UBAnd
467
+ | UBOr
468
+ | UBXor
469
+ deriving (Show , Eq , Generic )
422
470
423
471
instance Hashable UBinOp
424
472
@@ -480,17 +528,26 @@ instance (Hashable i, Hashable f) => Hashable (Node i f) where
480
528
481
529
untypeUnOp :: UnOp f a -> UUnOp
482
530
untypeUnOp UNeg = UUNeg
531
+ untypeUnOp UNegs = UUNeg
483
532
untypeUnOp UNot = UUNot
533
+ untypeUnOp UNots = UUNot
484
534
{-# INLINE untypeUnOp #-}
485
535
486
536
untypeBinOp :: BinOp f a -> UBinOp
487
537
untypeBinOp BAdd = UBAdd
538
+ untypeBinOp BAdds = UBAdd
488
539
untypeBinOp BSub = UBSub
540
+ untypeBinOp BSubs = UBSub
489
541
untypeBinOp BMul = UBMul
542
+ untypeBinOp BMuls = UBMul
490
543
untypeBinOp BDiv = UBDiv
544
+ untypeBinOp BDivs = UBDiv
491
545
untypeBinOp BAnd = UBAnd
546
+ untypeBinOp BAnds = UBAnd
492
547
untypeBinOp BOr = UBOr
548
+ untypeBinOp BOrs = UBOr
493
549
untypeBinOp BXor = UBXor
550
+ untypeBinOp BXors = UBXor
494
551
{-# INLINE untypeBinOp #-}
495
552
496
553
--------------------------------------------------------------------------------
0 commit comments