Skip to content

Improved error handling with pkg/errors #108

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
21 changes: 11 additions & 10 deletions chain/cheat_codes.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"github.com/ethereum/go-ethereum/accounts/abi"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto"
"github.com/pkg/errors"
"github.com/trailofbits/medusa/utils"
"math/big"
"os/exec"
Expand Down Expand Up @@ -42,43 +43,43 @@ func getStandardCheatCodeContract(tracer *cheatCodeTracer) (*CheatCodeContract,
// Define some basic ABI argument types
typeAddress, err := abi.NewType("address", "", nil)
if err != nil {
return nil, err
return nil, errors.WithStack(err)
}
typeBytes, err := abi.NewType("bytes", "", nil)
if err != nil {
return nil, err
return nil, errors.WithStack(err)
}
typeBytes32, err := abi.NewType("bytes32", "", nil)
if err != nil {
return nil, err
return nil, errors.WithStack(err)
}
typeUint8, err := abi.NewType("uint8", "", nil)
if err != nil {
return nil, err
return nil, errors.WithStack(err)
}
typeUint64, err := abi.NewType("uint64", "", nil)
if err != nil {
return nil, err
return nil, errors.WithStack(err)
}
typeUint256, err := abi.NewType("uint256", "", nil)
if err != nil {
return nil, err
return nil, errors.WithStack(err)
}
typeInt256, err := abi.NewType("int256", "", nil)
if err != nil {
return nil, err
return nil, errors.WithStack(err)
}
typeStringSlice, err := abi.NewType("string[]", "", nil)
if err != nil {
return nil, err
return nil, errors.WithStack(err)
}
typeString, err := abi.NewType("string", "", nil)
if err != nil {
return nil, err
return nil, errors.WithStack(err)
}
typeBool, err := abi.NewType("bool", "", nil)
if err != nil {
return nil, err
return nil, errors.WithStack(err)
}

// Warp: Sets VM timestamp
Expand Down
33 changes: 16 additions & 17 deletions chain/test_chain.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
package chain

import (
"errors"
"fmt"
"github.com/ethereum/go-ethereum/core/rawdb"
"github.com/pkg/errors"
"github.com/trailofbits/medusa/chain/config"
"golang.org/x/exp/maps"
"math/big"
Expand Down Expand Up @@ -207,7 +206,7 @@ func (t *TestChain) Clone(onCreateFunc func(chain *TestChain) error) (*TestChain
if onCreateFunc != nil {
err = onCreateFunc(targetChain)
if err != nil {
return nil, fmt.Errorf("could not clone chain due to error: %v", err)
return nil, errors.WithMessage(err, "could not clone chain due to error")
}
}

Expand Down Expand Up @@ -337,7 +336,7 @@ func (t *TestChain) fetchClosestInternalBlock(blockNumber uint64) (int, *chainTy
func (t *TestChain) BlockFromNumber(blockNumber uint64) (*chainTypes.Block, error) {
// If the block number is past our current head, return an error.
if blockNumber > t.HeadBlockNumber() {
return nil, fmt.Errorf("could not obtain block for block number %d because it exceeds the current head block number %d", blockNumber, t.HeadBlockNumber())
return nil, errors.Errorf("could not obtain block for block number %d because it exceeds the current head block number %d", blockNumber, t.HeadBlockNumber())
}

// We only commit blocks that were created by this chain. If block numbers are skipped, we simulate their existence
Expand Down Expand Up @@ -411,7 +410,7 @@ func getSpoofedBlockHashFromNumber(blockNumber uint64) common.Hash {
func (t *TestChain) BlockHashFromNumber(blockNumber uint64) (common.Hash, error) {
// If our block number references something too new, return an error
if blockNumber > t.HeadBlockNumber() {
return common.Hash{}, fmt.Errorf("could not obtain block hash for block number %d because it exceeds the current head block number %d", blockNumber, t.HeadBlockNumber())
return common.Hash{}, errors.Errorf("could not obtain block hash for block number %d because it exceeds the current head block number %d", blockNumber, t.HeadBlockNumber())
}

// Obtain our closest internally committed block
Expand Down Expand Up @@ -443,7 +442,7 @@ func (t *TestChain) StateFromRoot(root common.Hash) (*state.StateDB, error) {
func (t *TestChain) StateRootAfterBlockNumber(blockNumber uint64) (common.Hash, error) {
// If our block number references something too new, return an error
if blockNumber > t.HeadBlockNumber() {
return common.Hash{}, fmt.Errorf("could not obtain post-state for block number %d because it exceeds the current head block number %d", blockNumber, t.HeadBlockNumber())
return common.Hash{}, errors.Errorf("could not obtain post-state for block number %d because it exceeds the current head block number %d", blockNumber, t.HeadBlockNumber())
}

// Obtain our closest internally committed block
Expand All @@ -459,7 +458,7 @@ func (t *TestChain) StateAfterBlockNumber(blockNumber uint64) (*state.StateDB, e
// Obtain our block's post-execution state root hash
root, err := t.StateRootAfterBlockNumber(blockNumber)
if err != nil {
return nil, err
return nil, errors.WithStack(err)
}

// Load our state from the database
Expand All @@ -471,14 +470,14 @@ func (t *TestChain) StateAfterBlockNumber(blockNumber uint64) (*state.StateDB, e
func (t *TestChain) RevertToBlockNumber(blockNumber uint64) error {
// If our block number references something too new, return an error
if blockNumber > t.HeadBlockNumber() {
return fmt.Errorf("could not revert to block number %d because it exceeds the current head block number %d", blockNumber, t.HeadBlockNumber())
return errors.Errorf("could not revert to block number %d because it exceeds the current head block number %d", blockNumber, t.HeadBlockNumber())
}

// Obtain our closest internally committed block, if it's not an exact match, it means we're trying to revert
// to a spoofed block, which we disallow for now.
closestBlockIndex, closestBlock := t.fetchClosestInternalBlock(blockNumber)
if closestBlock.Header.Number.Uint64() != blockNumber {
return fmt.Errorf("could not revert to block number %d because it does not refer to an internally committed block", blockNumber)
return errors.Errorf("could not revert to block number %d because it does not refer to an internally committed block", blockNumber)
}

// Slice off our blocks to be removed (to produce relevant events)
Expand Down Expand Up @@ -567,7 +566,7 @@ func (t *TestChain) CallContract(msg core.Message, state *state.StateDB, additio
// Revert to our state snapshot to undo any changes.
state.RevertToSnapshot(snapshot)

return res, err
return res, errors.WithStack(err)
}

// PendingBlock describes the current pending block which is being constructed and awaiting commitment to the chain.
Expand All @@ -594,7 +593,7 @@ func (t *TestChain) PendingBlockCreate() (*chainTypes.Block, error) {
func (t *TestChain) PendingBlockCreateWithParameters(blockNumber uint64, blockTime uint64, blockGasLimit *uint64) (*chainTypes.Block, error) {
// If we already have a pending block, return an error.
if t.pendingBlock != nil {
return nil, fmt.Errorf("could not create a new pending block for chain, as a block is already pending")
return nil, errors.New("could not create a new pending block for chain, as a block is already pending")
}

// If our block gas limit is not specified, use the default defined by this chain.
Expand All @@ -605,7 +604,7 @@ func (t *TestChain) PendingBlockCreateWithParameters(blockNumber uint64, blockTi
// Validate our block number exceeds our previous head
currentHeadBlockNumber := t.Head().Header.Number.Uint64()
if blockNumber <= currentHeadBlockNumber {
return nil, fmt.Errorf("failed to create block with a block number of %d as does precedes the chain head block number of %d", blockNumber, currentHeadBlockNumber)
return nil, errors.Errorf("failed to create block with a block number of %d as does precedes the chain head block number of %d", blockNumber, currentHeadBlockNumber)
}

// Obtain our parent block hash to reference in our new block.
Expand All @@ -623,7 +622,7 @@ func (t *TestChain) PendingBlockCreateWithParameters(blockNumber uint64, blockTi
// block number for us to spoof the existence of those intermediate blocks, each with their own unique timestamp.
currentHeadTimeStamp := t.Head().Header.Time
if currentHeadTimeStamp >= blockTime || blockNumberDifference > blockTime-currentHeadTimeStamp {
return nil, fmt.Errorf("failed to create block as block number was advanced by %d while block timestamp was advanced by %d. timestamps must be unique per block", blockNumberDifference, blockTime-currentHeadTimeStamp)
return nil, errors.Errorf("failed to create block as block number was advanced by %d while block timestamp was advanced by %d. timestamps must be unique per block", blockNumberDifference, blockTime-currentHeadTimeStamp)
}

// Create a block header for this block:
Expand Down Expand Up @@ -705,7 +704,7 @@ func (t *TestChain) PendingBlockAddTx(message core.Message) error {
if err != nil {
// If we encountered an error, reset our state, as we couldn't add the tx.
t.state, _ = state.New(t.pendingBlock.Header.Root, t.stateDatabase, nil)
return fmt.Errorf("test chain state write error when adding tx to pending block: %v", err)
return errors.WithMessage(err, "test chain state write error when adding tx to pending block")
}

// Create our message result
Expand All @@ -725,12 +724,12 @@ func (t *TestChain) PendingBlockAddTx(message core.Message) error {
// safe to update the block header afterwards.
root, err := t.state.Commit(t.chainConfig.IsEIP158(t.pendingBlock.Header.Number))
if err != nil {
return fmt.Errorf("test chain state write error: %v", err)
return errors.Wrap(err, "test chain state write error")
}
if err := t.state.Database().TrieDB().Commit(root, false); err != nil {
// If we encountered an error, reset our state, as we couldn't add the tx.
t.state, _ = state.New(t.pendingBlock.Header.Root, t.stateDatabase, nil)
return fmt.Errorf("test chain trie write error: %v", err)
return errors.Wrap(err, "test chain trie write error")
}

// Update our gas used in the block header
Expand Down Expand Up @@ -773,7 +772,7 @@ func (t *TestChain) PendingBlockAddTx(message core.Message) error {
func (t *TestChain) PendingBlockCommit() error {
// If we have no pending block, we cannot commit it.
if t.pendingBlock == nil {
return fmt.Errorf("could not commit chain's pending block, as no pending block was created")
return errors.New("could not commit chain's pending block, as no pending block was created")
}

// Append our new block to our chain.
Expand Down
3 changes: 2 additions & 1 deletion chain/vendored/apply_transaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"github.com/ethereum/go-ethereum/core/vm"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/params"
"github.com/pkg/errors"
"math/big"
)

Expand All @@ -43,7 +44,7 @@ func EVMApplyTransaction(msg Message, config *params.ChainConfig, author *common
// Apply the transaction to the current state (included in the env).
result, err := ApplyMessage(evm, msg, gp)
if err != nil {
return nil, nil, err
return nil, nil, errors.WithStack(err)
}

// Update the state with pending changes.
Expand Down
11 changes: 6 additions & 5 deletions cmd/fuzz.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package cmd

import (
"fmt"
"github.com/pkg/errors"
"os"
"os/signal"
"path/filepath"
Expand All @@ -24,7 +25,7 @@ var fuzzCmd = &cobra.Command{
func cmdValidateFuzzArgs(cmd *cobra.Command, args []string) error {
// Make sure we have no positional args
if err := cobra.NoArgs(cmd, args); err != nil {
return fmt.Errorf("fuzz does not accept any positional arguments, only flags and their associated values")
return errors.New("fuzz does not accept any positional arguments, only flags and their associated values")
}
return nil
}
Expand Down Expand Up @@ -52,14 +53,14 @@ func cmdRunFuzz(cmd *cobra.Command, args []string) error {
configFlagUsed := cmd.Flags().Changed("config")
configPath, err := cmd.Flags().GetString("config")
if err != nil {
return err
return errors.WithStack(err)
}

// If --config was not used, look for `medusa.json` in the current work directory
if !configFlagUsed {
workingDirectory, err := os.Getwd()
if err != nil {
return err
return errors.WithStack(err)
}
configPath = filepath.Join(workingDirectory, DefaultProjectConfigFilename)
}
Expand All @@ -78,7 +79,7 @@ func cmdRunFuzz(cmd *cobra.Command, args []string) error {

// Possibility #2: If the --config flag was used, and we couldn't find the file, we'll throw an error
if configFlagUsed && existenceError != nil {
return existenceError
return errors.WithStack(existenceError)
}

// Possibility #3: --config flag was not used and medusa.json was not found, so use the default project config
Expand All @@ -104,7 +105,7 @@ func cmdRunFuzz(cmd *cobra.Command, args []string) error {
// be in the config directory when running this.
err = os.Chdir(filepath.Dir(configPath))
if err != nil {
return err
return errors.WithStack(err)
}

// Create our fuzzing
Expand Down
21 changes: 11 additions & 10 deletions cmd/fuzz_flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package cmd

import (
"fmt"
"github.com/pkg/errors"

"github.com/spf13/cobra"
"github.com/trailofbits/medusa/fuzzing/config"
Expand Down Expand Up @@ -76,7 +77,7 @@ func updateProjectConfigWithFuzzFlags(cmd *cobra.Command, projectConfig *config.
// Get the new target
newTarget, err := cmd.Flags().GetString("target")
if err != nil {
return err
return errors.WithStack(err)
}

err = projectConfig.Compilation.SetTarget(newTarget)
Expand All @@ -89,71 +90,71 @@ func updateProjectConfigWithFuzzFlags(cmd *cobra.Command, projectConfig *config.
if cmd.Flags().Changed("workers") {
projectConfig.Fuzzing.Workers, err = cmd.Flags().GetInt("workers")
if err != nil {
return err
return errors.WithStack(err)
}
}

// Update timeout
if cmd.Flags().Changed("timeout") {
projectConfig.Fuzzing.Timeout, err = cmd.Flags().GetInt("timeout")
if err != nil {
return err
return errors.WithStack(err)
}
}

// Update test limit
if cmd.Flags().Changed("test-limit") {
projectConfig.Fuzzing.TestLimit, err = cmd.Flags().GetUint64("test-limit")
if err != nil {
return err
return errors.WithStack(err)
}
}

// Update sequence length
if cmd.Flags().Changed("seq-len") {
projectConfig.Fuzzing.CallSequenceLength, err = cmd.Flags().GetInt("seq-len")
if err != nil {
return err
return errors.WithStack(err)
}
}

// Update deployment order
if cmd.Flags().Changed("deployment-order") {
projectConfig.Fuzzing.DeploymentOrder, err = cmd.Flags().GetStringSlice("deployment-order")
if err != nil {
return err
return errors.WithStack(err)
}
}

// Update corpus directory
if cmd.Flags().Changed("corpus-dir") {
projectConfig.Fuzzing.CorpusDirectory, err = cmd.Flags().GetString("corpus-dir")
if err != nil {
return err
return errors.WithStack(err)
}
}

// Update senders
if cmd.Flags().Changed("senders") {
projectConfig.Fuzzing.SenderAddresses, err = cmd.Flags().GetStringSlice("senders")
if err != nil {
return err
return errors.WithStack(err)
}
}

// Update deployer address
if cmd.Flags().Changed("deployer") {
projectConfig.Fuzzing.DeployerAddress, err = cmd.Flags().GetString("deployer")
if err != nil {
return err
return errors.WithStack(err)
}
}

// Update assertion mode enablement
if cmd.Flags().Changed("assertion-mode") {
projectConfig.Fuzzing.Testing.AssertionTesting.Enabled, err = cmd.Flags().GetBool("assertion-mode")
if err != nil {
return err
return errors.WithStack(err)
}
}

Expand Down
Loading