diff --git a/lib/Echidna.hs b/lib/Echidna.hs index 0795016bf..c4d624687 100644 --- a/lib/Echidna.hs +++ b/lib/Echidna.hs @@ -18,6 +18,8 @@ import EVM.Fetch qualified import EVM.Solidity (BuildOutput(..), Contracts(Contracts)) import EVM.Types hiding (Env) +import Bandit.EpsGreedy +import Bandit.Types import Echidna.ABI import Echidna.Onchain as Onchain import Echidna.Output.Corpus @@ -108,6 +110,7 @@ mkEnv cfg buildOutput tests world slitherInfo = do coverageRefInit <- newIORef mempty coverageRefRuntime <- newIORef mempty corpusRef <- newIORef mempty + banditRef <- newIORef (EpsGreedy 0 (FixedRate 0.1) undefined undefined) testRefs <- traverse newIORef tests (contractCache, slotCache) <- Onchain.loadRpcCache cfg fetchContractCache <- newIORef contractCache @@ -116,6 +119,6 @@ mkEnv cfg buildOutput tests world slitherInfo = do -- TODO put in real path let dapp = dappInfo "/" buildOutput pure $ Env { cfg, dapp, codehashMap, fetchContractCache, fetchSlotCache, contractNameCache - , chainId, eventQueue, coverageRefInit, coverageRefRuntime, corpusRef, testRefs, world + , chainId, eventQueue, coverageRefInit, coverageRefRuntime, corpusRef, banditRef, testRefs, world , slitherInfo } diff --git a/lib/Echidna/Campaign.hs b/lib/Echidna/Campaign.hs index 6abc97d35..f24c4a19e 100644 --- a/lib/Echidna/Campaign.hs +++ b/lib/Echidna/Campaign.hs @@ -7,10 +7,11 @@ import Control.Concurrent import Control.DeepSeq (force) import Control.Monad (replicateM, when, unless, void, forM_) import Control.Monad.Catch (MonadThrow(..)) -import Control.Monad.Random.Strict (MonadRandom, RandT, evalRandT) +import Control.Monad.Random.Strict (MonadRandom, RandT, evalRandT, runRand, runRandT) import Control.Monad.Reader (MonadReader, asks, liftIO, ask) +import Prelude hiding (init) import Control.Monad.State.Strict - (MonadState(..), StateT(..), gets, MonadIO, modify') + (MonadState(..), StateT(..), gets, MonadIO, modify', runState) import Control.Monad.ST (RealWorld) import Control.Monad.Trans (lift) import Data.Binary.Get (runGetOrFail) @@ -18,6 +19,7 @@ import Data.ByteString.Lazy qualified as LBS import Data.IORef (readIORef, atomicModifyIORef', writeIORef) import Data.Foldable (foldlM) import Data.List qualified as List +import Data.List.NonEmpty qualified as NE import Data.List.NonEmpty qualified as NEList import Data.Map qualified as Map import Data.Map (Map, (\\)) @@ -27,8 +29,12 @@ import Data.Set qualified as Set import Data.Text (Text, unpack) import Data.Time (LocalTime) import Data.Vector qualified as V -import System.Random (mkStdGen) +import System.Random (mkStdGen, getStdGen, setStdGen, StdGen) +import Bandit.Class +import Bandit.EpsGreedy +import Bandit.Types +import Bandit.UCB import EVM (cheatCode) import EVM.ABI (getAbi, AbiType(AbiAddressType, AbiTupleType), AbiValue(AbiAddress, AbiTuple), abiValueType) import EVM.Dapp (DappInfo(..)) @@ -207,7 +213,7 @@ runSymWorker callback vm dict workerId _ name = do txsToTxAndVmsSym _ [] = pure [(Nothing, vm, [])] txsToTxAndVmsSym False txs = do -- Separate the last tx, which should be the one increasing coverage - let (itxs, ltx) = (init txs, last txs) + let (itxs, ltx) = (List.init txs, last txs) ivm <- foldlM (\vm' tx -> snd <$> execTx vm' tx) vm itxs -- Split the sequence randomly and select any next transaction i <- if length txs == 1 then pure 0 else rElem $ NEList.fromList [1 .. length txs - 1] @@ -334,9 +340,11 @@ runFuzzWorker callback vm dict workerId initialCorpus testLimit = do let effectiveSeed = dict.defSeed + workerId effectiveGenDict = dict { defSeed = effectiveSeed } + (bandit, _, _) = init (mkStdGen effectiveSeed) (EpsGreedyHyper (FixedRate 0.1) (Arms (NE.fromList (zip [0..] (snd <$> initialCorpus))))) initialState = WorkerState { workerId , genDict = effectiveGenDict + , bandit = bandit , newCoverage = False , ncallseqs = 0 , ncalls = 0 @@ -444,7 +452,13 @@ randseq deployedContracts = do corpus <- liftIO $ readIORef env.corpusRef if null corpus then pure randTxs -- Use the generated random transactions - else mut seqLen corpus randTxs -- Apply the mutator + else do + banditState <- gets (.bandit) + g <- getStdGen + let ((selectedCorpus', newG), banditState') = runState (step g 0.0) banditState + setStdGen newG + modify' $ \ws -> ws { bandit = banditState' } + mut seqLen (Set.fromList [selectedCorpus']) randTxs -- Apply the mutator -- TODO callseq ideally shouldn't need to be MonadRandom @@ -486,6 +500,18 @@ callseq vm txSeq = do , corpusSize = newSize , transactions = fst <$> results } + ws <- get + g <- getStdGen + let ((_, newG), bandit') = runState (step g 1.0) ws.bandit + setStdGen newG + put $ ws { bandit = bandit' } + + unless newCoverage $ do + ws <- get + g <- getStdGen + let ((_, newG), bandit') = runState (step g 0.0) ws.bandit + setStdGen newG + put $ ws { bandit = bandit' } modify' $ \workerState -> diff --git a/lib/Echidna/Types/Campaign.hs b/lib/Echidna/Types/Campaign.hs index 132411169..1ed18a0f1 100644 --- a/lib/Echidna/Types/Campaign.hs +++ b/lib/Echidna/Types/Campaign.hs @@ -7,9 +7,12 @@ import GHC.Conc (numCapabilities) import EVM.Solvers (Solver(..)) +import Bandit.EpsGreedy +import Bandit.Types import Echidna.ABI (GenDict, emptyDict) import Echidna.Types import Echidna.Types.Coverage (CoverageFileType, CoverageMap) +import Echidna.Types.Tx (Tx) -- | Configuration for running an Echidna 'Campaign'. data CampaignConf = CampaignConf @@ -73,6 +76,8 @@ data WorkerState = WorkerState -- ^ Worker ID starting from 0 , genDict :: !GenDict -- ^ Generation dictionary + , bandit :: EpsGreedy (Int, [Tx]) FixedRate + -- ^ Multi-armed bandit state , newCoverage :: !Bool -- ^ Flag to indicate new coverage found , ncallseqs :: !Int @@ -90,6 +95,7 @@ initialWorkerState :: WorkerState initialWorkerState = WorkerState { workerId = 0 , genDict = emptyDict + , bandit = EpsGreedy 0 (FixedRate 0.1) undefined undefined , newCoverage = False , ncallseqs = 0 , ncalls = 0 diff --git a/lib/Echidna/Types/Config.hs b/lib/Echidna/Types/Config.hs index 637fd5959..f019e8047 100644 --- a/lib/Echidna/Types/Config.hs +++ b/lib/Echidna/Types/Config.hs @@ -11,6 +11,8 @@ import Data.Word (Word64) import EVM.Dapp (DappInfo) import EVM.Types (Addr, W256) +import Bandit.EpsGreedy +import Bandit.Types import Echidna.SourceAnalysis.Slither (SlitherInfo) import Echidna.SourceMapping (CodehashMap) import Echidna.Types.Campaign (CampaignConf) @@ -19,7 +21,7 @@ import Echidna.Types.Corpus (Corpus) import Echidna.Types.Coverage (CoverageMap) import Echidna.Types.Solidity (SolConf) import Echidna.Types.Test (TestConf, EchidnaTest) -import Echidna.Types.Tx (TxConf) +import Echidna.Types.Tx (TxConf, Tx) import Echidna.Types.Cache import Echidna.Types.World (World) @@ -76,6 +78,7 @@ data Env = Env , coverageRefInit :: IORef CoverageMap , coverageRefRuntime :: IORef CoverageMap , corpusRef :: IORef Corpus + , banditRef :: IORef (EpsGreedy (Int, [Tx]) FixedRate) , slitherInfo :: Maybe SlitherInfo , codehashMap :: CodehashMap diff --git a/package.yaml b/package.yaml index 99bf5e138..f660f8b57 100644 --- a/package.yaml +++ b/package.yaml @@ -16,6 +16,9 @@ ghc-options: - -fexpose-all-unfoldings - -Wunused-packages +packages: +- hbandit + dependencies: - aeson - base @@ -79,6 +82,7 @@ library: - word-wrap - xml-conduit - yaml + - hbandit executables: echidna: diff --git a/stack.yaml b/stack.yaml index 9c3e91909..f136a82e0 100644 --- a/stack.yaml +++ b/stack.yaml @@ -3,6 +3,7 @@ resolver: lts-23.24 packages: - '.' +- 'hbandit' extra-deps: - git: https://github.com/argotorg/hevm.git @@ -13,3 +14,4 @@ extra-deps: - spool-0.1@sha256:77780cbfc2c0be23ff2ea9e474062f3df97fcd9db946ee0b3508280a923b83e2,1461 - strip-ansi-escape-0.1.0.0@sha256:08f2ed93b16086a837ec46eab7ce8d27cf39d47783caaeb818878ea33c2ff75f,1628 - vty-windows-0.2.0.3@sha256:0c010b1086a725046a8bb08bb1e6bfdfdb3cfe1c72d6fa77c37306ef9ec774d8,2844 +- list-extras-0.4.1.6@sha256:2b8b7c2632f7a98ee94c74bcb836ba2df7c7089ac7ccb16b6e11aa9df4b20c21,2691