@@ -6,14 +6,14 @@ import Circuit.Arithmetic (CircuitVars (..), VarType (..), InputBindings (labelT
6
6
import Circuit.Dataflow qualified as DataFlow
7
7
import Circuit.Dot (arithCircuitToDot )
8
8
import Circuit.Language.Compile (BuilderState (.. ), ExprM , runCircuitBuilder )
9
- import Data.Aeson (decodeFileStrict )
10
9
import Data.Aeson qualified as A
10
+ import Data.Aeson.Types qualified as A
11
11
import Data.Binary (decodeFile , encodeFile )
12
12
import Data.Field.Galois (Prime , PrimeField (fromP ))
13
13
import Data.IntSet qualified as IntSet
14
14
import Data.Text qualified as Text
15
15
import GHC.TypeNats (SNat , withKnownNat , withSomeSNat )
16
- import Options.Applicative (CommandFields , Mod , Parser , ParserInfo , command , execParser , fullDesc , header , help , helper , hsubparser , info , long , progDesc , showDefault , strOption , switch , value )
16
+ import Options.Applicative (CommandFields , Mod , Parser , ParserInfo , command , execParser , fullDesc , header , help , helper , hsubparser , info , long , progDesc , showDefault , strOption , switch , value , option , eitherReader , showDefaultWith )
17
17
import Protolude
18
18
import R1CS (R1CS , Witness (Witness ), isValidWitness , toR1CS )
19
19
import Data.Text.Read (decimal , hexadecimal )
@@ -22,9 +22,12 @@ import Data.Map qualified as Map
22
22
import Protolude.Unsafe (unsafeHead )
23
23
import Data.Maybe (fromJust )
24
24
import qualified Data.IntMap as IntMap
25
+ import Numeric (showHex )
26
+ import Control.Error (hoistEither )
25
27
26
28
data GlobalOpts = GlobalOpts
27
29
{ cmd :: Command
30
+ , encoding :: Encoding
28
31
}
29
32
30
33
optsParser :: Text -> ParserInfo GlobalOpts
@@ -40,6 +43,25 @@ optsParser progName =
40
43
globalOptsParser =
41
44
GlobalOpts
42
45
<$> hsubparser (compileCommand <> solveCommand <> verifyCommand)
46
+ <*> encodingParser
47
+
48
+ encodingParser :: Parser Encoding
49
+ encodingParser =
50
+ let readEncoding = eitherReader $ \ case
51
+ " hex" -> pure HexString
52
+ " decimal-string" -> pure DecString
53
+ " decimal" -> pure Dec
54
+ _ -> throwError $ " Invalid encoding, expected one of: hex, decimal-string, decimal"
55
+ in option readEncoding
56
+ ( long " encoding"
57
+ <> help " encoding for inputs and outputs"
58
+ <> showDefaultWith (\ case
59
+ HexString -> " hex"
60
+ DecString -> " decimal-string"
61
+ Dec -> " decimal"
62
+ )
63
+ <> value Dec
64
+ )
43
65
44
66
compileCommand :: Mod CommandFields Command
45
67
compileCommand =
@@ -186,9 +208,9 @@ defaultMain progName program = do
186
208
let binFilePath = coCircuitBinFile compilerOpts
187
209
encodeFile binFilePath prog
188
210
when (coGenInputsTemplate compilerOpts) $ do
189
- let inputsTemplate = mkInputsTemplate $ cpVars prog
211
+ let inputsTemplate = mkInputsTemplate (encoding opts) ( cpVars prog)
190
212
inputsTemplateFilePath = Text. unpack progName <> " -inputs-template.json"
191
- A. encodeFile inputsTemplateFilePath inputsTemplate
213
+ writeIOVars inputsTemplateFilePath inputsTemplate
192
214
when (coIncludeJson compilerOpts) $ do
193
215
A. encodeFile (r1csFilePath <> " .json" ) (map fromP r1cs)
194
216
A. encodeFile (binFilePath <> " .json" ) (map fromP prog)
@@ -197,8 +219,8 @@ defaultMain progName program = do
197
219
writeFile dotFilePath $ arithCircuitToDot (cpCircuit prog)
198
220
Solve solveOpts -> do
199
221
inputs <- do
200
- mInputs <- decodeFileStrict (soInputsFile solveOpts)
201
- maybe (panic " Failed to decode inputs " ) ( pure . map (map (fromInteger @ f . unFieldElem))) mInputs
222
+ IOVars _ is <- readIOVars (encoding opts) (soInputsFile solveOpts)
223
+ pure $ map (map (fromInteger @ f . unFieldElem)) is
202
224
let binFilePath = soCircuitBinFile solveOpts
203
225
circuit <- decodeFile binFilePath
204
226
let wtns = nativeGenWitness circuit inputs
@@ -207,8 +229,8 @@ defaultMain progName program = do
207
229
when (soIncludeJson solveOpts) $ do
208
230
A. encodeFile (wtnsFilePath <> " .json" ) (map fromP wtns)
209
231
when (soShowOutputs solveOpts) $ do
210
- let outputs = mkOutputs (cpVars circuit) (witnessFromCircomWitness wtns)
211
- print $ A. encode ( map ( map fromP) outputs)
232
+ let outputs = mkOutputs (encoding opts) ( cpVars circuit) (witnessFromCircomWitness wtns)
233
+ print $ A. encode $ encodeIOVars outputs
212
234
Verify verifyOpts -> do
213
235
let r1csFilePath = voR1CSFile verifyOpts
214
236
cr1cs <- decodeR1CSHeaderFromFile r1csFilePath
@@ -246,26 +268,62 @@ optimize opts =
246
268
else mempty
247
269
248
270
--------------------------------------------------------------------------------
271
+ -- Programs expecting to interact with Circom via the file system and solver API can
272
+ -- be incredibly stupid w.r.t. to accepting / demanding inputs be encoded as strings (either hex or dec)
273
+ -- or as numbers.
274
+
275
+ data Encoding = HexString | DecString | Dec deriving (Eq , Show )
249
276
250
277
newtype FieldElem = FieldElem { unFieldElem :: Integer } deriving newtype (Eq , Ord , Enum , Num , Real , Integral )
251
278
252
- instance A. FromJSON FieldElem where
253
- parseJSON v = case v of
254
- A. String s ->
255
- case hexadecimal s <> decimal s of
256
- Left e -> fail e
257
- Right (a, rest) ->
258
- if Text. null rest
259
- then pure a
260
- else fail $ " FieldElem parser failed to consume all input: " <> Text. unpack rest
261
- _ -> FieldElem <$> A. parseJSON v
262
- instance A. ToJSON FieldElem where
263
- toJSON (FieldElem a) = A. toJSON a
264
-
265
- newtype Inputs = Inputs (Map Text (VarType FieldElem )) deriving newtype (A.FromJSON , A.ToJSON )
266
-
267
- mkInputsTemplate :: CircuitVars Text -> Inputs
268
- mkInputsTemplate vars =
279
+ encodeFieldElem :: Encoding -> FieldElem -> A. Value
280
+ encodeFieldElem enc (FieldElem a) = case enc of
281
+ HexString -> A. toJSON $ " 0x" <> (Text. pack $ showHex a " " )
282
+ DecString -> A. toJSON $ Text. pack $ show a
283
+ Dec -> A. toJSON a
284
+
285
+ decodeFieldElem :: Encoding -> A. Value -> A. Parser FieldElem
286
+ decodeFieldElem enc _v = case enc of
287
+ Dec -> FieldElem <$> A. parseJSON _v
288
+ DecString -> do
289
+ s <- A. parseJSON _v
290
+ FieldElem <$> parseDec s
291
+ where
292
+ parseDec str = case decimal str of
293
+ Right (n, " " ) -> pure n
294
+ _ -> fail " FieldElem: expected a decimal string"
295
+ HexString -> do
296
+ s <- A. parseJSON _v
297
+ FieldElem <$> parseHex s
298
+ where
299
+ parseHex str = case hexadecimal str of
300
+ Right (n, " " ) -> pure n
301
+ _ -> fail " FieldElem: expected a hexadecimal string"
302
+
303
+ encodeVarType :: Encoding -> VarType FieldElem -> A. Value
304
+ encodeVarType enc = \ case
305
+ Simple a -> encodeFieldElem enc a
306
+ Array as -> A. toJSON $ map (encodeFieldElem enc) as
307
+
308
+ decodeVarType :: Encoding -> A. Value -> A. Parser (VarType FieldElem )
309
+ decodeVarType enc v = do
310
+ vs <- A. parseJSON v
311
+ case vs of
312
+ A. Array as -> Array <$> traverse (decodeFieldElem enc) (toList as)
313
+ _ -> Simple <$> decodeFieldElem enc v
314
+
315
+ data IOVars = IOVars Encoding (Map Text (VarType FieldElem ))
316
+
317
+ encodeIOVars :: IOVars -> A. Value
318
+ encodeIOVars (IOVars enc vs) = A. toJSON $ map (encodeVarType enc) vs
319
+
320
+ decodeIOVars :: Encoding -> A. Value -> A. Parser IOVars
321
+ decodeIOVars enc v = do
322
+ kvs <- A. parseJSON v
323
+ IOVars enc <$> traverse (decodeVarType enc) kvs
324
+
325
+ mkInputsTemplate :: Encoding -> CircuitVars Text -> IOVars
326
+ mkInputsTemplate enc vars =
269
327
let inputsOnly = cvInputsLabels $ restrictVars vars (cvPrivateInputs vars `IntSet.union` cvPublicInputs vars)
270
328
vs =
271
329
map (\ a -> (fst $ unsafeHead a, length a)) $
@@ -276,10 +334,10 @@ mkInputsTemplate vars =
276
334
if len > 1
277
335
then (label, Array (replicate len 0 ))
278
336
else (label, Simple 0 )
279
- in Inputs $ Map. fromList $ map f vs
337
+ in IOVars enc $ Map. fromList $ map f vs
280
338
281
- mkOutputs :: CircuitVars Text -> Witness f -> Map Text ( VarType f )
282
- mkOutputs vars (Witness w) =
339
+ mkOutputs :: PrimeField f => Encoding -> CircuitVars Text -> Witness f -> IOVars
340
+ mkOutputs enc vars (Witness w) =
283
341
let vs :: [[((Text ,Int ), Int )]]
284
342
vs = groupBy (\ a b -> fst (fst a) == fst (fst b)) $
285
343
Map. toList $
@@ -289,12 +347,20 @@ mkOutputs vars (Witness w) =
289
347
f = \ case
290
348
[((label, _), v)] ->
291
349
let val = fromJust $ IntMap. lookup v w
292
- in (label, Simple val)
350
+ in (label, Simple . FieldElem . fromP $ val)
293
351
as@ ( ((l, _), _) : _ ) ->
294
352
( l
295
353
, Array $ fromJust $ for as $ \ (_, i) ->
296
- IntMap. lookup i w
354
+ FieldElem . fromP <$> IntMap. lookup i w
297
355
298
356
)
299
357
_ -> panic " impossible: groupBy lists are non empty"
300
- in Map. fromList $ map f vs
358
+ in IOVars enc (Map. fromList $ map f vs)
359
+
360
+ writeIOVars :: FilePath -> IOVars -> IO ()
361
+ writeIOVars fp (IOVars enc vs) = A. encodeFile fp (encodeIOVars (IOVars enc vs))
362
+
363
+ readIOVars :: Encoding -> FilePath -> IO IOVars
364
+ readIOVars enc fp = map (either (panic . Text. pack) identity) $ runExceptT $ do
365
+ contents <- ExceptT $ A. eitherDecodeFileStrict fp
366
+ hoistEither $ A. parseEither (decodeIOVars enc) contents
0 commit comments