Skip to content

Commit b5f560a

Browse files
committed
newtype hash
1 parent c373fb8 commit b5f560a

File tree

2 files changed

+34
-21
lines changed

2 files changed

+34
-21
lines changed

Diff for: language/src/Circuit/Language/Compile.hs

+8-8
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ import Circuit.Language.Expr
2323
UnOp (..),
2424
getAnnotation,
2525
hashCons,
26-
unType,
26+
unType, Hash (Hash),
2727
)
2828
import Circuit.Language.TExpr qualified as TExpr
2929
import Data.Field.Galois (GaloisField)
@@ -41,7 +41,7 @@ data BuilderState f = BuilderState
4141
{ bsCircuit :: ArithCircuit f,
4242
bsNextVar :: Int,
4343
bsVars :: CircuitVars Text,
44-
bsSharedMap :: Map Int (V.Vector (SignalSource f))
44+
bsSharedMap :: Map Hash (V.Vector (SignalSource f))
4545
}
4646

4747
defaultBuilderState :: BuilderState f
@@ -206,7 +206,7 @@ compileWithWires ::
206206
m (V.Vector Wire)
207207
compileWithWires ws expr = do
208208
let e = hashCons $ unType expr
209-
compileOut <- memoizedCompile $ e
209+
compileOut <- memoizedCompile e
210210
for (V.zip compileOut ws) $ \(o, freshWire) -> do
211211
case o of
212212
WireSource wire -> do
@@ -238,7 +238,7 @@ assertSameSourceSize l r =
238238

239239
withCompilerCache ::
240240
(MonadState (BuilderState f) m) =>
241-
Int ->
241+
Hash ->
242242
m (V.Vector (SignalSource f)) ->
243243
m (V.Vector (SignalSource f))
244244
withCompilerCache i m = do
@@ -251,7 +251,7 @@ _compile ::
251251
(Hashable f, GaloisField f) =>
252252
(MonadState (BuilderState f) m) =>
253253
(MonadError (CircuitBuilderError f) m) =>
254-
Expr Int Wire f ->
254+
Expr Hash Wire f ->
255255
m (V.Vector (SignalSource f))
256256
_compile expr = withCompilerCache (getAnnotation expr) $ case expr of
257257
EVal _ f -> pure $ V.singleton $ AffineSource $ ConstGate f
@@ -310,7 +310,7 @@ _compile expr = withCompilerCache (getAnnotation expr) $ case expr of
310310
-- EQ(lhs, rhs) = (lhs - rhs == 1) only allowed for field comparison
311311
EEq _ lhs rhs -> do
312312
-- assertSingle is justified as the lhs and rhs must be of type f
313-
let e = EBinOp (hash (BSub, getAnnotation lhs, getAnnotation rhs)) BSub lhs rhs
313+
let e = EBinOp (Hash $ hash (BSub, getAnnotation lhs, getAnnotation rhs)) BSub lhs rhs
314314
eqInWire <- do
315315
eOut <- withCompilerCache (getAnnotation e) (memoizedCompile e)
316316
assertSingleSource eOut >>= addWire
@@ -330,7 +330,7 @@ _compile expr = withCompilerCache (getAnnotation expr) $ case expr of
330330
outputs
331331
( \o ->
332332
let v = UVar o
333-
e = EVar (hash v) v
333+
e = EVar (Hash $ hash v) v
334334
in memoizedCompile e
335335
)
336336
where
@@ -358,7 +358,7 @@ memoizedCompile ::
358358
(Hashable f, GaloisField f) =>
359359
(MonadState (BuilderState f) m) =>
360360
(MonadError (CircuitBuilderError f) m) =>
361-
Expr Int Wire f ->
361+
Expr Hash Wire f ->
362362
m (V.Vector (SignalSource f))
363363
memoizedCompile expr = do
364364
m <- gets bsSharedMap

Diff for: language/src/Circuit/Language/Expr.hs

+26-13
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,13 @@
1-
module Circuit.Language.Expr where
1+
module Circuit.Language.Expr
2+
( Expr(..),
3+
UVar(..),
4+
BinOp(..),
5+
UnOp(..),
6+
unType,
7+
Hash(..),
8+
hashCons,
9+
getAnnotation,
10+
) where
211

312
import Circuit.Language.TExpr qualified as TExpr
413
import Data.Vector qualified as V
@@ -77,52 +86,56 @@ getAnnotation = \case
7786
EUpdateIndex a _ _ _ -> a
7887
EBundle a _ -> a
7988

80-
hashCons :: (Hashable i, Hashable f) => Expr () i f -> Expr Int i f
89+
newtype Hash = Hash Int
90+
deriving (Show, Eq, Ord)
91+
deriving (Hashable) via Int
92+
93+
hashCons :: (Hashable i, Hashable f) => Expr () i f -> Expr Hash i f
8194
hashCons = \case
8295
EVal _ f ->
83-
let i = hash (hash @Text "EVal", f)
96+
let i = Hash $ hash (hash @Text "EVal", f)
8497
in EVal i f
8598
EVar _ v ->
86-
let i = hash (hash @Text "EVar", v)
99+
let i = Hash $ hash (hash @Text "EVar", v)
87100
in EVar i v
88101
EUnOp _ op e ->
89102
let e' = hashCons e
90-
i = hash (hash @Text "EUnOp", op, getAnnotation e')
103+
i = Hash $ hash (hash @Text "EUnOp", op, getAnnotation e')
91104
in EUnOp i op e'
92105
EBinOp _ op e1 e2 ->
93106
let e1' = hashCons e1
94107
e2' = hashCons e2
95-
i = hash (hash @Text "EBinOp", op, getAnnotation e1', getAnnotation e2')
108+
i = Hash $ hash (hash @Text "EBinOp", op, getAnnotation e1', getAnnotation e2')
96109
in EBinOp i op e1' e2'
97110
EIf _ b t e ->
98111
let b' = hashCons b
99112
t' = hashCons t
100113
e' = hashCons e
101-
i = hash (hash @Text "EIf", getAnnotation b', getAnnotation t', getAnnotation e')
114+
i = Hash $ hash (hash @Text "EIf", getAnnotation b', getAnnotation t', getAnnotation e')
102115
in EIf i b' t' e'
103116
EEq _ l r ->
104117
let l' = hashCons l
105118
r' = hashCons r
106-
i = hash (hash @Text "EEq", getAnnotation l', getAnnotation r')
119+
i = Hash $ hash (hash @Text "EEq", getAnnotation l', getAnnotation r')
107120
in EEq i l' r'
108121
ESplit _ n e ->
109122
let e' = hashCons e
110-
i = hash (hash @Text "ESplit", n, getAnnotation e')
123+
i = Hash $ hash (hash @Text "ESplit", n, getAnnotation e')
111124
in ESplit i n e'
112125
EJoin _ e ->
113126
let e' = hashCons e
114-
i = hash (hash @Text "EJoin", getAnnotation e')
127+
i = Hash $ hash (hash @Text "EJoin", getAnnotation e')
115128
in EJoin i e'
116129
EAtIndex _ v ix ->
117130
let v' = hashCons v
118-
i = hash (hash @Text "AtIndex", getAnnotation v', ix)
131+
i = Hash $ hash (hash @Text "AtIndex", getAnnotation v', ix)
119132
in EAtIndex i v' ix
120133
EUpdateIndex _ p b v ->
121134
let b' = hashCons b
122135
v' = hashCons v
123-
i = hash (hash @Text "UpdateIndex", p, getAnnotation b', getAnnotation v')
136+
i = Hash $ hash (hash @Text "UpdateIndex", p, getAnnotation b', getAnnotation v')
124137
in EUpdateIndex i p b' v'
125138
EBundle _ b ->
126139
let b' = V.map hashCons b
127-
i = hash (hash @Text "Bundle", toList $ fmap getAnnotation b')
140+
i = Hash $ hash (hash @Text "Bundle", toList $ fmap getAnnotation b')
128141
in EBundle i b'

0 commit comments

Comments
 (0)