Skip to content

Commit 916d2fc

Browse files
authored
Refactor (VM, WorkerState) to WorkerState (#1070)
1 parent b43ecf7 commit 916d2fc

File tree

1 file changed

+69
-63
lines changed

1 file changed

+69
-63
lines changed

lib/Echidna/Campaign.hs

+69-63
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ import Control.Monad.Catch (MonadCatch(..), MonadThrow(..))
1111
import Control.Monad.Random.Strict (MonadRandom, RandT, evalRandT)
1212
import Control.Monad.Reader (MonadReader, asks, liftIO, ask)
1313
import Control.Monad.State.Strict
14-
(MonadState(..), StateT(..), evalStateT, execStateT, gets, MonadIO, modify')
14+
(MonadState(..), StateT(..), evalStateT, gets, MonadIO, modify')
1515
import Control.Monad.Trans (lift)
1616
import Control.Monad.Trans.Random.Strict (liftCatch)
1717
import Data.Binary.Get (runGetOrFail)
@@ -203,62 +203,61 @@ callseq vm txSeq = do
203203
execFunc =
204204
if coverageEnabled
205205
then execTxOptC
206-
else \tx -> do
207-
(v, ca) <- get
208-
(r, vm') <- runStateT (execTx tx) v
209-
put (vm', ca)
210-
pure r
211-
-- Then, we get the current campaign state
212-
campaign <- get
206+
else \vm' tx -> runStateT (execTx tx) vm'
213207

214208
-- Run each call sequentially. This gives us the result of each call
215209
-- and the new state
216-
(res, (vm', campaign')) <- runStateT (evalSeq vm execFunc txSeq) (vm, campaign)
217-
218-
let
219-
-- compute the addresses not present in the old VM via set difference
220-
newAddrs = Map.keys $ vm'.env.contracts \\ vm.env.contracts
221-
-- and construct a set to union to the constants table
222-
diffs = Map.fromList [(AbiAddressType, Set.fromList $ AbiAddress <$> newAddrs)]
223-
-- Now we try to parse the return values as solidity constants, and add them to 'GenDict'
224-
results = returnValues (map (\(t, (vr, _)) -> (t, vr)) res) campaign'.genDict.rTypes
225-
-- union the return results with the new addresses
226-
additions = Map.unionWith Set.union diffs results
227-
-- append to the constants dictionary
228-
updatedDict = campaign'.genDict
229-
{ constants = Map.unionWith Set.union additions campaign'.genDict.constants
230-
, dictValues = Set.union (mkDictValues $ Set.unions $ Map.elems additions)
231-
campaign'.genDict.dictValues
232-
}
210+
(results, vm') <- evalSeq vm execFunc txSeq
233211

234212
-- If there is new coverage, add the transaction list to the corpus
235-
when campaign'.newCoverage $ do
213+
newCoverage <- gets (.newCoverage)
214+
when newCoverage $ do
215+
ncallseqs <- gets (.ncallseqs)
236216
-- Even if this takes a bit of time, this is okay as finding new coverage
237217
-- is expected to be infrequent in the long term
238218
newSize <- liftIO $ atomicModifyIORef' env.corpusRef $ \corp ->
239219
-- Corpus is a bit too lazy, force the evaluation to reduce the memory usage
240-
let !corp' = force $ addToCorpus (campaign'.ncallseqs + 1) res corp
220+
let !corp' = force $ addToCorpus (ncallseqs + 1) results corp
241221
in (corp', corpusSize corp')
242222

243223
cov <- liftIO . readIORef =<< asks (.coverageRef)
244224
points <- liftIO $ scoveragePoints cov
245225
pushEvent (NewCoverage points (length cov) newSize)
246226

247-
-- Update the campaign state
248-
put campaign'
249-
{ genDict = updatedDict
250-
-- Update the gas estimation
251-
, gasInfo =
252-
if conf.estimateGas
253-
then updateGasInfo res [] campaign'.gasInfo
254-
else campaign'.gasInfo
255-
-- Reset the new coverage flag
256-
, newCoverage = False
257-
-- Keep track of the number of calls to `callseq`
258-
, ncallseqs = campaign'.ncallseqs + 1
259-
}
227+
modify' $ \workerState ->
228+
229+
let
230+
-- compute the addresses not present in the old VM via set difference
231+
newAddrs = Map.keys $ vm'.env.contracts \\ vm.env.contracts
232+
-- and construct a set to union to the constants table
233+
diffs = Map.fromList [(AbiAddressType, Set.fromList $ AbiAddress <$> newAddrs)]
234+
-- Now we try to parse the return values as solidity constants, and add them to 'GenDict'
235+
resultMap = returnValues (map (\(t, (vr, _)) -> (t, vr)) results) workerState.genDict.rTypes
236+
-- union the return results with the new addresses
237+
additions = Map.unionWith Set.union diffs resultMap
238+
-- append to the constants dictionary
239+
updatedDict = workerState.genDict
240+
{ constants = Map.unionWith Set.union additions workerState.genDict.constants
241+
, dictValues = Set.union (mkDictValues $ Set.unions $ Map.elems additions)
242+
workerState.genDict.dictValues
243+
}
244+
245+
-- Update the worker state
246+
in workerState
247+
{ genDict = updatedDict
248+
-- Update the gas estimation
249+
, gasInfo =
250+
if conf.estimateGas
251+
then updateGasInfo results [] workerState.gasInfo
252+
else workerState.gasInfo
253+
-- Reset the new coverage flag
254+
, newCoverage = False
255+
-- Keep track of the number of calls to `callseq`
256+
, ncallseqs = workerState.ncallseqs + 1
257+
}
260258

261259
pure vm'
260+
262261
where
263262
-- Given a list of transactions and a return typing rule, checks whether we
264263
-- know the return type for each function called. If yes, tries to parse the
@@ -291,19 +290,19 @@ callseq vm txSeq = do
291290
-- | Execute a transaction, capturing the PC and codehash of each instruction
292291
-- executed, saving the transaction if it finds new coverage.
293292
execTxOptC
294-
:: (MonadIO m, MonadReader Env m, MonadState (VM, WorkerState) m, MonadThrow m)
295-
=> Tx
296-
-> m (VMResult, Gas)
297-
execTxOptC tx = do
298-
(vm, camp) <- get
293+
:: (MonadIO m, MonadReader Env m, MonadState WorkerState m, MonadThrow m)
294+
=> VM -> Tx
295+
-> m ((VMResult, Gas), VM)
296+
execTxOptC vm tx = do
299297
((res, grew), vm') <- runStateT (execTxWithCov tx) vm
300-
put (vm', camp)
301298
when grew $ do
302-
let dict' = case tx.call of
303-
SolCall c -> gaddCalls (Set.singleton c) camp.genDict
304-
_ -> camp.genDict
305-
modify' $ \(_vm, c) -> (_vm, c { newCoverage = True, genDict = dict' })
306-
pure res
299+
modify' $ \workerState ->
300+
let
301+
dict' = case tx.call of
302+
SolCall c -> gaddCalls (Set.singleton c) workerState.genDict
303+
_ -> workerState.genDict
304+
in workerState { newCoverage = True, genDict = dict' }
305+
pure (res, vm')
307306

308307
-- | Given current `gasInfo` and a sequence of executed transactions, updates
309308
-- information on highest gas usage for each call
@@ -328,19 +327,26 @@ updateGasInfo ((t, _):ts) tseq gi = updateGasInfo ts (t:tseq) gi
328327
-- of transactions, constantly checking if we've solved any tests or can shrink
329328
-- known solves.
330329
evalSeq
331-
:: (MonadIO m, MonadCatch m, MonadRandom m, MonadReader Env m, MonadState (VM, WorkerState) m)
332-
=> VM
333-
-> (Tx -> m a)
330+
:: (MonadIO m, MonadCatch m, MonadRandom m, MonadReader Env m, MonadState WorkerState m)
331+
=> VM -- ^ Initial VM
332+
-> (VM -> Tx -> m (result, VM))
334333
-> [Tx]
335-
-> m [(Tx, a)]
336-
evalSeq vmForShrink e = go [] where
337-
go r xs = do
338-
(v', camp) <- get
339-
camp' <- execStateT (runUpdate (updateTest vmForShrink (v', reverse r))) camp
340-
put (v', camp' { ncalls = camp'.ncalls + 1 })
341-
case xs of
342-
[] -> pure []
343-
(y:ys) -> e y >>= \a -> ((y, a) :) <$> go (y:r) ys
334+
-> m ([(Tx, result)], VM)
335+
evalSeq vm0 execFunc = go vm0 [] where
336+
go vm executedSoFar toExecute = do
337+
-- NOTE: we do reverse here because we build up this list by prepending,
338+
-- see the last line of this function.
339+
runUpdate (updateTest vm0 (vm, reverse executedSoFar))
340+
modify' $ \workerState -> workerState { ncalls = workerState.ncalls + 1 }
341+
case toExecute of
342+
[] -> pure ([], vm)
343+
(tx:remainingTxs) -> do
344+
(result, vm') <- execFunc vm tx
345+
-- NOTE: we don't use the intermediate VMs, just the last one. If any of
346+
-- the intermediate VMs are needed, they can be put next to the result
347+
-- of each transaction - `m ([(Tx, result, VM)])`
348+
(remaining, _vm) <- go vm' (tx:executedSoFar) remainingTxs
349+
pure ((tx, result) : remaining, vm')
344350

345351
-- | Given a rule for updating a particular test's state, apply it to each test
346352
-- in a 'Campaign'.

0 commit comments

Comments
 (0)