Skip to content

Commit c373fb8

Browse files
committed
separate out language module
1 parent f03a382 commit c373fb8

File tree

14 files changed

+190
-147
lines changed

14 files changed

+190
-147
lines changed

Diff for: arithmetic-circuits.cabal

+58-17
Original file line numberDiff line numberDiff line change
@@ -44,20 +44,14 @@ common deps
4444
aeson >=1.4
4545
, base >=4.10 && <5
4646
, containers >=0.6.0
47-
, data-reify
4847
, filepath >=1.4.2
4948
, finite-typelits >=0.1.0
5049
, galois-field >=2.0.0
51-
, hashable
52-
, intern
53-
, microlens
5450
, process
5551
, propagators
5652
, protolude >=0.2
5753
, semirings >=0.7
5854
, text >=1.2.3
59-
, vector
60-
, vector-sized
6155
, wl-pprint-text >=1.2.0
6256

6357
library
@@ -66,13 +60,9 @@ library
6660
Circuit
6761
Circuit.Affine
6862
Circuit.Arithmetic
69-
Circuit.Compile
7063
Circuit.Dataflow
7164
Circuit.Solver
7265
Circuit.Dot
73-
Circuit.Expr
74-
Circuit.Lang
75-
Circuit.TExpr
7666
R1CS
7767
Fresh
7868

@@ -83,6 +73,29 @@ library
8373

8474
default-language: GHC2021
8575

76+
library language
77+
import: deps, extensions, warnings
78+
visibility: public
79+
exposed-modules:
80+
Circuit.Language.Compile
81+
Circuit.Language.Expr
82+
Circuit.Language.TExpr
83+
Circuit.Language.DSL
84+
Circuit.Language
85+
86+
hs-source-dirs: language/src
87+
build-depends:
88+
arithmetic-circuits
89+
, hashable
90+
, microlens
91+
, vector
92+
, vector-sized
93+
94+
ghc-options: -freverse-errors -O2 -Wall
95+
96+
default-language: GHC2021
97+
98+
8699
library circom-compat
87100
import: deps, extensions, warnings
88101
visibility: public
@@ -103,12 +116,47 @@ library circom-compat
103116

104117
default-language: GHC2021
105118

119+
test-suite language-tests
120+
import: deps, extensions
121+
type: exitcode-stdio-1.0
122+
main-is: Main.hs
123+
124+
other-modules:
125+
Test.Circuit.Expr
126+
Test.Circuit.Sudoku
127+
Test.Circuit.Lang
128+
129+
hs-source-dirs: language/test
130+
131+
ghc-options: -freverse-errors -O2 -Wall -main-is Main
132+
build-depends:
133+
arithmetic-circuits
134+
, arithmetic-circuits:language
135+
, array
136+
, distributive
137+
, fin
138+
, quickcheck-instances >=0.3
139+
, QuickCheck
140+
, random
141+
, hspec
142+
, tasty >=1.2
143+
, tasty-hunit >=0.10
144+
, tasty-hspec
145+
, tasty-quickcheck >=0.10
146+
, vec
147+
148+
build-tool-depends: tasty-discover:tasty-discover >=4.2
149+
150+
default-language: GHC2021
151+
152+
106153
executable factors
107154
import: warnings, extensions, deps
108155
main-is: Main.hs
109156
build-depends:
110157
binary
111158
, arithmetic-circuits
159+
, arithmetic-circuits:language
112160
, arithmetic-circuits:circom-compat
113161

114162
hs-source-dirs: circom-compat/app
@@ -123,20 +171,14 @@ test-suite circuit-tests
123171
Paths_arithmetic_circuits
124172
Test.Circuit.Affine
125173
Test.Circuit.Arithmetic
126-
Test.Circuit.Expr
127174
Test.Circuit.R1CS
128-
Test.Circuit.Sudoku
129-
Test.Circuit.Lang
130175

131176
hs-source-dirs: test
132177

133178
ghc-options: -freverse-errors -O2 -Wall -main-is Main
134179
build-depends:
135180
arithmetic-circuits
136181
, array
137-
, distributive
138-
, fin
139-
, integer-logarithms
140182
, quickcheck-instances >=0.3
141183
, QuickCheck
142184
, random
@@ -145,7 +187,6 @@ test-suite circuit-tests
145187
, tasty-hunit >=0.10
146188
, tasty-hspec
147189
, tasty-quickcheck >=0.10
148-
, vec
149190

150191
build-tool-depends: tasty-discover:tasty-discover >=4.2
151192

Diff for: circom-compat/app/Main.hs

+4-3
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44
module Main where
55

66
import Circuit
7+
import Circuit.Language
78
import Data.Binary (encodeFile)
8-
import Data.Field.Galois (GaloisField, Prime)
9+
import Data.Field.Galois (Prime)
910
import Data.Map qualified as Map
1011
import Data.Set qualified as Set
1112
import Protolude
@@ -14,7 +15,7 @@ import R1CS.Circom (r1csToCircomR1CS, witnessToCircomWitness)
1415

1516
main :: IO ()
1617
main = do
17-
let BuilderState {..} = snd $ runCircuitBuilder $ program @Fr
18+
let BuilderState {..} = snd $ runCircuitBuilder program
1819
publicInputs = Map.fromList $ zip (Set.toAscList $ cvPublicInputs bsVars) [6]
1920
privateInputs = Map.fromList $ zip (Set.toAscList $ cvPrivateInputs bsVars) [2, 3]
2021
inputs = publicInputs <> privateInputs
@@ -25,7 +26,7 @@ main = do
2526

2627
type Fr = Prime 21888242871839275222246405745257275088548364400416034343698204186575808495617
2728

28-
program :: (GaloisField f) => ExprM f Wire
29+
program :: ExprM Fr Wire
2930
program = do
3031
n <- deref <$> fieldInput Public "n"
3132
a <- deref <$> fieldInput Private "a"

Diff for: language/src/Circuit/Language.hs

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
module Circuit.Language
2+
( module Circuit.Language.Compile,
3+
module Circuit.Language.DSL,
4+
module Circuit.Language.TExpr,
5+
)
6+
where
7+
8+
import Circuit.Language.Compile
9+
import Circuit.Language.DSL
10+
import Circuit.Language.TExpr

Diff for: src/Circuit/Compile.hs renamed to language/src/Circuit/Language/Compile.hs

+21-25
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
module Circuit.Compile
1+
module Circuit.Language.Compile
22
( ExprM,
33
BuilderState (..),
44
execCircuitBuilder,
@@ -16,7 +16,7 @@ where
1616

1717
import Circuit.Affine
1818
import Circuit.Arithmetic
19-
import Circuit.Expr
19+
import Circuit.Language.Expr
2020
( BinOp (..),
2121
Expr (..),
2222
UVar (..),
@@ -25,9 +25,8 @@ import Circuit.Expr
2525
hashCons,
2626
unType,
2727
)
28-
import Circuit.TExpr qualified as TExpr
28+
import Circuit.Language.TExpr qualified as TExpr
2929
import Data.Field.Galois (GaloisField)
30-
import Data.Interned
3130
import Data.Map qualified as Map
3231
import Data.Set qualified as Set
3332
import Data.Vector qualified as V
@@ -42,7 +41,7 @@ data BuilderState f = BuilderState
4241
{ bsCircuit :: ArithCircuit f,
4342
bsNextVar :: Int,
4443
bsVars :: CircuitVars Text,
45-
bsSharedMap :: Map Id (V.Vector (SignalSource f))
44+
bsSharedMap :: Map Int (V.Vector (SignalSource f))
4645
}
4746

4847
defaultBuilderState :: BuilderState f
@@ -64,32 +63,31 @@ instance (GaloisField f) => Pretty (CircuitBuilderError f) where
6463
ExpectedSingleWire wires -> "Expected a single wire, but got:" <+> pretty (toList wires)
6564
MismatchedWireTypes l r -> "Mismatched wire types:" <+> pretty (toList l) <+> pretty (toList r)
6665

67-
type ExprM f a = ExceptT (CircuitBuilderError f) (StateT (BuilderState f) IO) a
66+
type ExprM f a = ExceptT (CircuitBuilderError f) (State (BuilderState f)) a
6867

69-
runExprM :: (GaloisField f) => ExprM f a -> BuilderState f -> IO (a, BuilderState f)
68+
runExprM :: (GaloisField f) => ExprM f a -> BuilderState f -> (a, BuilderState f)
7069
runExprM m s = do
71-
res <- runStateT (runExceptT m) s
70+
let res = runState (runExceptT m) s
7271
case res of
7372
(Left e, _) -> panic $ Protolude.show $ pretty e
74-
(Right a, s') -> pure $ (a, s')
73+
(Right a, s') -> (a, s')
7574

76-
execCircuitBuilder :: (GaloisField f) => ExprM f a -> IO (ArithCircuit f)
77-
execCircuitBuilder m = reverseCircuit . bsCircuit . snd <$> runExprM m defaultBuilderState
75+
execCircuitBuilder :: (GaloisField f) => ExprM f a -> (ArithCircuit f)
76+
execCircuitBuilder m = reverseCircuit . bsCircuit . snd $ runExprM m defaultBuilderState
7877
where
7978
reverseCircuit = \(ArithCircuit cs) -> ArithCircuit $ reverse cs
8079

81-
evalCircuitBuilder :: (GaloisField f) => ExprM f a -> IO a
82-
evalCircuitBuilder e = fst <$> runCircuitBuilder e
80+
evalCircuitBuilder :: (GaloisField f) => ExprM f a -> a
81+
evalCircuitBuilder e = fst $ runCircuitBuilder e
8382

84-
runCircuitBuilder :: (GaloisField f) => ExprM f a -> IO (a, BuilderState f)
83+
runCircuitBuilder :: (GaloisField f) => ExprM f a -> (a, BuilderState f)
8584
runCircuitBuilder m = do
86-
(a, s) <- runExprM m defaultBuilderState
87-
pure
88-
( a,
89-
s
90-
{ bsCircuit = reverseCircuit $ bsCircuit s
91-
}
92-
)
85+
let (a, s) = runExprM m defaultBuilderState
86+
in ( a,
87+
s
88+
{ bsCircuit = reverseCircuit $ bsCircuit s
89+
}
90+
)
9391
where
9492
reverseCircuit = \(ArithCircuit cs) -> ArithCircuit $ reverse cs
9593

@@ -190,7 +188,6 @@ addWire x = case x of
190188

191189
compileWithWire ::
192190
(Hashable f, GaloisField f) =>
193-
(MonadIO m) =>
194191
(MonadState (BuilderState f) m) =>
195192
(MonadError (CircuitBuilderError f) m) =>
196193
m (TExpr.Var Wire f ty) ->
@@ -202,7 +199,6 @@ compileWithWire freshWire e = do
202199

203200
compileWithWires ::
204201
(Hashable f, GaloisField f) =>
205-
(MonadIO m) =>
206202
(MonadState (BuilderState f) m) =>
207203
(MonadError (CircuitBuilderError f) m) =>
208204
V.Vector (m (TExpr.Var Wire f f)) ->
@@ -242,7 +238,7 @@ assertSameSourceSize l r =
242238

243239
withCompilerCache ::
244240
(MonadState (BuilderState f) m) =>
245-
Id ->
241+
Int ->
246242
m (V.Vector (SignalSource f)) ->
247243
m (V.Vector (SignalSource f))
248244
withCompilerCache i m = do
@@ -380,4 +376,4 @@ exprToArithCircuit ::
380376
exprToArithCircuit expr output = do
381377
let e = hashCons $ unType expr
382378
compileOut <- memoizedCompile e >>= assertSingleSource
383-
emit $ Mul (ConstGate 1) (addVar compileOut) output
379+
emit $ Mul (ConstGate 1) (addVar compileOut) output

Diff for: src/Circuit/Lang.hs renamed to language/src/Circuit/Language/DSL.hs

+9-9
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
{-# LANGUAGE TypeFamilies #-}
22

33
-- | Surface language
4-
module Circuit.Lang
4+
module Circuit.Language.DSL
55
( Signal,
66
Bundle,
77
cField,
@@ -39,16 +39,16 @@ module Circuit.Lang
3939
where
4040

4141
import Circuit.Arithmetic (InputType (Private, Public), Wire (..))
42-
import Circuit.TExpr
42+
import Circuit.Language.Compile
43+
import Circuit.Language.TExpr
4344
import Data.Field.Galois (GaloisField, PrimeField)
4445
import Data.Finite (Finite)
4546
import Data.Maybe (fromJust)
47+
import Data.Vector qualified as V
4648
import Data.Vector.Sized (Vector)
4749
import Data.Vector.Sized qualified as SV
48-
import Data.Vector qualified as V
4950
import Protolude
5051
import Unsafe.Coerce (unsafeCoerce)
51-
import Circuit.Compile
5252

5353
--------------------------------------------------------------------------------
5454
type Signal f = Expr Wire f
@@ -128,14 +128,14 @@ bundle = EBundle
128128
boolToField :: Signal f Bool -> Signal f f
129129
boolToField = unsafeCoerce
130130

131-
132-
unBundle :: forall n f ty.
133-
(KnownNat n, GaloisField f, Hashable f) =>
134-
Expr Wire f (Vector n ty) ->
131+
unBundle ::
132+
forall n f ty.
133+
(KnownNat n, GaloisField f, Hashable f) =>
134+
Expr Wire f (Vector n ty) ->
135135
ExprM f (Vector n (Expr Wire f f))
136136
unBundle b = do
137137
let freshWires = V.replicate (fromIntegral $ natVal $ Proxy @n) (VarField <$> imm)
138-
bis <- compileWithWires freshWires b
138+
bis <- compileWithWires freshWires b
139139
pure $ fromJust $ SV.toSized (EVar . VarField <$> bis)
140140

141141
--------------------------------------------------------------------------------

0 commit comments

Comments
 (0)