diff --git a/.gitignore b/.gitignore index e9f9d788..1f92dfb2 100644 --- a/.gitignore +++ b/.gitignore @@ -29,3 +29,5 @@ docs/book # Build results result + +.vscode/* diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 00000000..e69de29b diff --git a/chain/cheat_code_tracer.go b/chain/cheat_code_tracer.go index 9ea79b11..b2cf8e1e 100644 --- a/chain/cheat_code_tracer.go +++ b/chain/cheat_code_tracer.go @@ -1,6 +1,7 @@ package chain import ( + "fmt" "math/big" "github.com/crytic/medusa-geth/common" @@ -9,6 +10,7 @@ import ( "github.com/crytic/medusa-geth/core/vm" "github.com/crytic/medusa-geth/eth/tracers" "github.com/crytic/medusa/chain/types" + "github.com/holiman/uint256" ) // cheatCodeTracer represents an EVM.Logger which tracks and patches EVM execution state to enable extended @@ -232,6 +234,43 @@ func (t *cheatCodeTracer) OnOpcode(pc uint64, op byte, gas, cost uint64, scope t if t.callDepth > 0 { t.callFrames[t.callDepth-1].onNextFrameEnterHooks.Execute(true, true) } + + // Support for expectRevert cheatcode (see standard_cheat_code_contrat.go) + // TODO : support dynamic value for expectRevert + if _, ok := currentCallFrame.extraData["expectRevert"]; ok { + // expectRevert does not affect the calls to the cheatcode VM (ex: prank) + // So if the next call is another call to the VM + // We forward the extraData to the next Frame + if scope.Address() == StandardCheatcodeContractAddress { + // TODO refactor the following to be an internal function used in both here and standard_cheat_code_contrat + delete(currentCallFrame.extraData, "expectRevert") + + cheatCodeCallerFrame := t.PreviousCallFrame() + cheatCodeCallerFrame.onNextFrameEnterHooks.Push(func() { + revertFrame := t.PreviousCallFrame() + // TODO : support dynamic value for expectRevert instead of a bool + revertFrame.extraData["expectRevert"] = true + }) + } else { + delete(currentCallFrame.extraData, "expectRevert") + + stack := scope.StackData() + index := len(stack) - 1 + return_value := stack[index] + if return_value.Eq(uint256.NewInt(0)) { + stack[index] = *uint256.NewInt(1) + } else { + + if !return_value.Eq(uint256.NewInt(1)) { + // TODO: find a better error handling + panic(fmt.Sprintf("expected revert but got return value %v", return_value)) + } + + stack[index] = *uint256.NewInt(0) + } + + } + } } // CaptureTxEndSetAdditionalResults can be used to set additional results captured from execution tracing. If this diff --git a/chain/standard_cheat_code_contract.go b/chain/standard_cheat_code_contract.go index dde693e6..8c0d2698 100644 --- a/chain/standard_cheat_code_contract.go +++ b/chain/standard_cheat_code_contract.go @@ -10,13 +10,12 @@ import ( "strconv" "strings" - "github.com/crytic/medusa/chain/types" - "github.com/crytic/medusa-geth/accounts/abi" "github.com/crytic/medusa-geth/common" "github.com/crytic/medusa-geth/core/tracing" "github.com/crytic/medusa-geth/core/vm" "github.com/crytic/medusa-geth/crypto" + "github.com/crytic/medusa/chain/types" "github.com/crytic/medusa/utils" "github.com/holiman/uint256" ) @@ -59,6 +58,10 @@ func getStandardCheatCodeContract(tracer *cheatCodeTracer) (*CheatCodeContract, if err != nil { return nil, err } + typeBytes4, err := abi.NewType("bytes4", "", nil) + if err != nil { + return nil, err + } typeBytes32, err := abi.NewType("bytes32", "", nil) if err != nil { return nil, err @@ -810,10 +813,234 @@ func getStandardCheatCodeContract(tracer *cheatCodeTracer) (*CheatCodeContract, }, ) + // getCode: Retrieves the runtime bytecode for a contract + contract.addMethod("getDeployedCode", abi.Arguments{{Type: typeString}}, abi.Arguments{}, + func(tracer *cheatCodeTracer, inputs []any) ([]any, *cheatCodeRawReturnData) { + + contractPath := inputs[0].(string) + + _, contractName, err := parseContractPath(contractPath) + if err != nil { + fmt.Println("getDeployedCode error: invalid path format: %v", err) + return nil, cheatCodeRevertData([]byte(fmt.Sprintf("getCode error: invalid path format: %v", err))) + } + + compiledContract, exists := tracer.chain.CompiledContracts[contractName] + if !exists { + // TODO: this is probably not the best to print given it will be printed in loop + // But it should be shown once to the user to avoid mistakes + // Same is true for getCode + fmt.Println("getDeployedCode error: contract not found: %s (did you forget to deploy the contract with predeployedContracts?)", contractName) + return nil, cheatCodeRevertData([]byte(fmt.Sprintf("getCode error: contract not found: %s", contractName))) + } + + bytecode := compiledContract.RuntimeBytecode + if len(bytecode) == 0 { + fmt.Println("getDeployedCode error: contract bytecode is empty: %s", contractName) + return nil, cheatCodeRevertData([]byte(fmt.Sprintf("getCode error: contract bytecode is empty: %s", contractName))) + } + + fmt.Println("getDeployedCode found") + // Return the bytecode + return []any{bytecode}, nil + }, + ) + + // assertTrue + contract.addMethod("assertTrue", abi.Arguments{{Type: typeBool}}, abi.Arguments{}, + func(tracer *cheatCodeTracer, inputs []any) ([]any, *cheatCodeRawReturnData) { + bool_result := inputs[0].(bool) + + if !bool_result { + return nil, cheatCodeRevertData([]byte("assertFalse failed")) + } + + // Return nothing + return nil, nil + }, + ) + + // assertTrue with reason + contract.addMethod("assertTrue", abi.Arguments{{Type: typeBool}, {Type: typeString}}, abi.Arguments{}, + func(tracer *cheatCodeTracer, inputs []any) ([]any, *cheatCodeRawReturnData) { + bool_result := inputs[0].(bool) + + if !bool_result { + return nil, cheatCodeRevertData([]byte(inputs[1].(string))) + } + + // Return nothing + return nil, nil + }, + ) + + // assertFalse + contract.addMethod("assertFalse", abi.Arguments{{Type: typeBool}}, abi.Arguments{}, + func(tracer *cheatCodeTracer, inputs []any) ([]any, *cheatCodeRawReturnData) { + bool_result := inputs[0].(bool) + + if bool_result { + return nil, cheatCodeRevertData([]byte("assertFalse failed")) + } + + // Return nothing + return nil, nil + }, + ) + + // assertFalse with reason + contract.addMethod("assertFalse", abi.Arguments{{Type: typeBool}, {Type: typeString}}, abi.Arguments{}, + func(tracer *cheatCodeTracer, inputs []any) ([]any, *cheatCodeRawReturnData) { + bool_result := inputs[0].(bool) + + if bool_result { + return nil, cheatCodeRevertData([]byte(inputs[1].(string))) + } + + // Return nothing + return nil, nil + }, + ) + + // assume: Revert if the condition if false + contract.addMethod("assume", abi.Arguments{{Type: typeBool}}, abi.Arguments{}, + func(tracer *cheatCodeTracer, inputs []any) ([]any, *cheatCodeRawReturnData) { + bool_result := inputs[0].(bool) + + if !bool_result { + return nil, cheatCodeRevertData([]byte("assume failed")) + } + + // Return nothing + return nil, nil + }, + ) + + // expectEmit: NOOP for now (TODO) + contract.addMethod("expectEmit", abi.Arguments{}, abi.Arguments{}, + func(tracer *cheatCodeTracer, inputs []any) ([]any, *cheatCodeRawReturnData) { + // Return nothing + return nil, nil + }, + ) + + // expectRevert: Follow the expect revert logic + // TODO: merge the different expectRevert to reduce the code dupplicate and handle properly the reason check + contract.addMethod("expectRevert", abi.Arguments{}, abi.Arguments{}, + func(tracer *cheatCodeTracer, inputs []any) ([]any, *cheatCodeRawReturnData) { + + // To implement expectRevert we follow this logic: + // We add a hook on the next call that happen after the call to the cheatcode's VM + // Which is the "next" frame of the "previous frame" + // The previous frame being the caller of the cheatcode, and its next frame the next call + // The hook just set expectRevert to true + // The actual update is done in the OnOpcode hook (cheat_code_tracer.go) + + cheatCodeCallerFrame := tracer.PreviousCallFrame() + cheatCodeCallerFrame.onNextFrameEnterHooks.Push(func() { + revertFrame := tracer.PreviousCallFrame() + // TODO : support dynamic value for expectRevert instead of a bool + revertFrame.extraData["expectRevert"] = true + }) + return nil, nil + }, + ) + + // expectRevert: Follow the expect revert logic + contract.addMethod("expectRevert", abi.Arguments{{Type: typeBytes4}}, abi.Arguments{}, + func(tracer *cheatCodeTracer, inputs []any) ([]any, *cheatCodeRawReturnData) { + + // To implement expectRevert we follow this logic: + // We add a hook on the next call that happen after the call to the cheatcode's VM + // Which is the "next" frame of the "previous frame" + // The previous frame being the caller of the cheatcode, and its next frame the next call + // The hook just set expectRevert to true + // The actual update is done in the OnOpcode hook (cheat_code_tracer.go) + + cheatCodeCallerFrame := tracer.PreviousCallFrame() + cheatCodeCallerFrame.onNextFrameEnterHooks.Push(func() { + revertFrame := tracer.PreviousCallFrame() + // TODO : support dynamic value for expectRevert instead of a bool + revertFrame.extraData["expectRevert"] = true + }) + return nil, nil + }, + ) + + // expectRevert: Follow the expect revert logic + contract.addMethod("expectRevert", abi.Arguments{{Type: typeBytes}}, abi.Arguments{}, + func(tracer *cheatCodeTracer, inputs []any) ([]any, *cheatCodeRawReturnData) { + + // To implement expectRevert we follow this logic: + // We add a hook on the next call that happen after the call to the cheatcode's VM + // Which is the "next" frame of the "previous frame" + // The previous frame being the caller of the cheatcode, and its next frame the next call + // The hook just set expectRevert to true + // The actual update is done in the OnOpcode hook (cheat_code_tracer.go) + + cheatCodeCallerFrame := tracer.PreviousCallFrame() + cheatCodeCallerFrame.onNextFrameEnterHooks.Push(func() { + revertFrame := tracer.PreviousCallFrame() + // TODO : support dynamic value for expectRevert instead of a bool + revertFrame.extraData["expectRevert"] = true + }) + return nil, nil + }, + ) + + // assertEq: Register assertEq for all supported types + assertEqTypes := []abi.Type{typeAddress, typeBytes, typeBytes4, typeBytes32, typeUint8, typeUint64, typeUint256, typeInt256, typeString, typeBool} + for _, t := range assertEqTypes { + currentType := t // capture range variable + contract.addMethod("assertEq", abi.Arguments{{Type: currentType}, {Type: currentType}}, abi.Arguments{}, + func(tracer *cheatCodeTracer, inputs []any) ([]any, *cheatCodeRawReturnData) { + if !assertEqGenerator(inputs, currentType) { + return nil, cheatCodeRevertData([]byte("assertEq failed")) + } + return nil, nil + }, + ) + + contract.addMethod("assertEq", abi.Arguments{{Type: currentType}, {Type: currentType}, {Type: typeString}}, abi.Arguments{}, + func(tracer *cheatCodeTracer, inputs []any) ([]any, *cheatCodeRawReturnData) { + if !assertEqGenerator(inputs, currentType) { + reason := inputs[2].(string) + return nil, cheatCodeRevertData([]byte(reason)) + } + return nil, nil + }, + ) + } + // Return our precompile contract information. return contract, nil } +func assertEqGenerator(inputs []any, t abi.Type) bool { + l := inputs[0] + r := inputs[1] + + // Use type-specific comparisons based on the ABI type + switch t.T { + case abi.AddressTy: + return l.(common.Address) == r.(common.Address) + case abi.BoolTy: + return l.(bool) == r.(bool) + case abi.IntTy, abi.UintTy: + return l.(*big.Int).Cmp(r.(*big.Int)) == 0 + case abi.StringTy: + return l.(string) == r.(string) + case abi.BytesTy: + return string(l.([]byte)) == string(r.([]byte)) + case abi.FixedBytesTy: + lBytes := l.([32]byte) + rBytes := r.([32]byte) + return string(lBytes[:]) == string(rBytes[:]) + default: + return l == r + } +} + // parseContractPath parses a contract path in the following formats: // - "MyContract.sol:MyContract" // - "MyContract" diff --git a/chain/test_chain.go b/chain/test_chain.go index 9d244d3d..5a7013c6 100644 --- a/chain/test_chain.go +++ b/chain/test_chain.go @@ -3,9 +3,10 @@ package chain import ( "errors" "fmt" - compilationTypes "github.com/crytic/medusa/compilation/types" "math/big" + compilationTypes "github.com/crytic/medusa/compilation/types" + "github.com/crytic/medusa/chain/state" "golang.org/x/net/context" diff --git a/docs/src/project_configuration/fuzzing_config.md b/docs/src/project_configuration/fuzzing_config.md index 13974545..d25463be 100644 --- a/docs/src/project_configuration/fuzzing_config.md +++ b/docs/src/project_configuration/fuzzing_config.md @@ -103,6 +103,13 @@ The fuzzing configuration defines the parameters for the fuzzing campaign. then `A` will have a starting balance of `1,234 wei`, `B` will have `4,660 wei (0x1234 in decimal)`, and `C` will have `1.2 ETH (1.2 × 10^18 wei)`. - **Default**: `[]` + +### `targetContractsInitFunctions` + +- **Type**: [String] (e.g. `["setUp", "initialize", ""]`) +- **Description**: Specifies post-deployment initialization functions to call for each contract in `targetContracts`. This array has a one-to-one mapping with `targetContracts`, where each element corresponds to the initialization function for the contract at the same index. Empty strings indicate no initialization for that contract. +- **Default**: `[]` + ### `constructorArgs` - **Type**: `{"contractName": {"variableName": _value}}` @@ -110,6 +117,20 @@ The fuzzing configuration defines the parameters for the fuzzing campaign. An example can be found [here](#using-constructorargs). - **Default**: `{}` +### `initializationArgs` + +- **Type**: `{"contractName": {"parameterName": _value}}` +- **Description**: Specifies arguments to pass to initialization functions defined in `targetContractsInitFunctions`. The keys in this map must match the contract names exactly, and the parameter names must match the parameter names in the function signature. + For example, if contract `MyContract` has an initialization function `initialize(uint256 _value, address _owner)`, then you would configure: + ```json + { + "MyContract": { + "_value": "100", + "_owner": "0x1234..." + } + } + ``` + ### `deployerAddress` - **Type**: Address diff --git a/docs/src/static/function_level_testing_medusa.json b/docs/src/static/function_level_testing_medusa.json index 11934058..e036c562 100644 --- a/docs/src/static/function_level_testing_medusa.json +++ b/docs/src/static/function_level_testing_medusa.json @@ -9,7 +9,9 @@ "coverageEnabled": true, "targetContracts": ["TestDepositContract"], "targetContractsBalances": ["21267647932558653966460912964485513215"], + "TargetContractsInitFunctions": [], "constructorArgs": {}, + "initializationArgs": {}, "deployerAddress": "0x30000", "senderAddresses": ["0x10000", "0x20000", "0x30000"], "blockNumberDelayMax": 60480, diff --git a/docs/src/static/medusa.json b/docs/src/static/medusa.json index d8d4d3a9..7b33624c 100644 --- a/docs/src/static/medusa.json +++ b/docs/src/static/medusa.json @@ -13,7 +13,9 @@ "targetContracts": [], "predeployedContracts": {}, "targetContractsBalances": [], + "TargetContractsInitFunctions": [], "constructorArgs": {}, + "initializationArgs": {}, "deployerAddress": "0x30000", "senderAddresses": ["0x10000", "0x20000", "0x30000"], "blockNumberDelayMax": 60480, diff --git a/fuzzing/config/config.go b/fuzzing/config/config.go index 1f7aa0f3..c4018759 100644 --- a/fuzzing/config/config.go +++ b/fuzzing/config/config.go @@ -79,10 +79,17 @@ type FuzzingConfig struct { // TargetContracts TargetContractsBalances []*ContractBalance `json:"targetContractsBalances"` + // TargetContractsInitFunctions is the list of functions to users to specify an "init function" (with setUp() as the default) + TargetContractsInitFunctions []string `json:"targetContractsInitFunctions"` + // ConstructorArgs holds the constructor arguments for TargetContracts deployments. It is available via the project // configuration ConstructorArgs map[string]map[string]any `json:"constructorArgs"` + // InitializationArgs holds the arguments for TargetContractsInitFunctions deployments. It is available via the project + // configuration + InitializationArgs map[string]map[string]any `json:"initializationArgs"` + // DeployerAddress describe the account address to be used to deploy contracts. DeployerAddress string `json:"deployerAddress"` diff --git a/fuzzing/config/config_defaults.go b/fuzzing/config/config_defaults.go index 0e04e201..58deeff7 100644 --- a/fuzzing/config/config_defaults.go +++ b/fuzzing/config/config_defaults.go @@ -39,19 +39,21 @@ func GetDefaultProjectConfig(platform string) (*ProjectConfig, error) { // Create a project configuration projectConfig := &ProjectConfig{ Fuzzing: FuzzingConfig{ - Workers: 10, - WorkerResetLimit: 50, - Timeout: 0, - TestLimit: 0, - ShrinkLimit: 5_000, - CallSequenceLength: 100, - TargetContracts: []string{}, - TargetContractsBalances: []*ContractBalance{}, - PredeployedContracts: map[string]string{}, - ConstructorArgs: map[string]map[string]any{}, - CorpusDirectory: "", - CoverageEnabled: true, - CoverageFormats: []string{"html", "lcov"}, + Workers: 10, + WorkerResetLimit: 50, + Timeout: 0, + TestLimit: 0, + ShrinkLimit: 5_000, + CallSequenceLength: 100, + TargetContracts: []string{}, + TargetContractsBalances: []*ContractBalance{}, + TargetContractsInitFunctions: []string{}, + PredeployedContracts: map[string]string{}, + ConstructorArgs: map[string]map[string]any{}, + InitializationArgs: map[string]map[string]any{}, + CorpusDirectory: "", + CoverageEnabled: true, + CoverageFormats: []string{"html", "lcov"}, SenderAddresses: []string{ "0x10000", "0x20000", diff --git a/fuzzing/fuzzer.go b/fuzzing/fuzzer.go index d37d4770..537926f1 100644 --- a/fuzzing/fuzzer.go +++ b/fuzzing/fuzzer.go @@ -483,15 +483,35 @@ func chainSetupFromCompilations(fuzzer *Fuzzer, testChain *chain.TestChain) (*ex // while still being able to use the contract address overrides contractsToDeploy := make([]string, 0) balances := make([]*config.ContractBalance, 0) + initFunctions := make([]string, 0) for contractName := range fuzzer.config.Fuzzing.PredeployedContracts { contractsToDeploy = append(contractsToDeploy, contractName) // Preserve index of target contract balances balances = append(balances, &config.ContractBalance{Int: *big.NewInt(0)}) + // Set default empty init function for predeployed contracts + initFunctions = append(initFunctions, "") } + contractsToDeploy = append(contractsToDeploy, fuzzer.config.Fuzzing.TargetContracts...) balances = append(balances, fuzzer.config.Fuzzing.TargetContractsBalances...) + // Process target contracts init functions + targetContractsCount := len(fuzzer.config.Fuzzing.TargetContracts) + initConfigCount := len(fuzzer.config.Fuzzing.TargetContractsInitFunctions) + + // Add initialization functions for target contracts + for i := 0; i < targetContractsCount; i++ { + initFunction := "" // No default initialization + + // Use custom init function if available + if i < initConfigCount && fuzzer.config.Fuzzing.TargetContractsInitFunctions[i] != "" { + initFunction = fuzzer.config.Fuzzing.TargetContractsInitFunctions[i] + } + + initFunctions = append(initFunctions, initFunction) + } + deployedContractAddr := make(map[string]common.Address) // Loop for all contracts to deploy for i, contractName := range contractsToDeploy { @@ -585,9 +605,114 @@ func chainSetupFromCompilations(fuzzer *Fuzzer, testChain *chain.TestChain) (*ex // Record our deployed contract so the next config-specified constructor args can reference this // contract by name. deployedContractAddr[contractName] = block.MessageResults[0].Receipt.ContractAddress + contractAddr := deployedContractAddr[contractName] + + // Get the initialization function name if exists + if i < len(initFunctions) && initFunctions[i] != "" { + initFunction := initFunctions[i] + fuzzer.logger.Info(fmt.Sprintf("Checking if init function %s on %s exists", initFunction, contractName)) + + // Check if the initialization function exists + contractABI := contract.CompiledContract().Abi + if method, exists := contractABI.Methods[initFunction]; !exists { + fuzzer.logger.Info(fmt.Sprintf("Init function %s not found on %s, skipping", initFunction, contractName)) + } else { + // Initialization function exists, proceed with calling it + fuzzer.logger.Info(fmt.Sprintf("Found init function %s with %d inputs", initFunction, len(method.Inputs))) + + // Check if the init function accepts parameters and process them if needed + var args []any + if len(method.Inputs) > 0 { + // Verify InitializationArgs map exists + if fuzzer.config.Fuzzing.InitializationArgs == nil { + fuzzer.logger.Error(fmt.Errorf("initialization args map is nil but function requires args")) + continue + } + + // Look for initialization arguments in the config + jsonArgs, ok := fuzzer.config.Fuzzing.InitializationArgs[contractName] + if !ok { + fuzzer.logger.Error(fmt.Errorf("initialization arguments for contract %s not provided", contractName)) + continue + } + + // Debug what args we found + fuzzer.logger.Info(fmt.Sprintf("Found args for %s: %+v", contractName, jsonArgs)) + + // Decode the arguments + decoded, err := valuegeneration.DecodeJSONArgumentsFromMap(method.Inputs, + jsonArgs, deployedContractAddr) + if err != nil { + fuzzer.logger.Error(fmt.Errorf("decoding failed for initialization arguments for contract %s: %v", + contractName, err)) + continue + } + + args = decoded + fuzzer.logger.Info(fmt.Sprintf("Decoded %d args for %s function %s", + len(args), contractName, initFunction)) + } + + // Log before packing + fuzzer.logger.Info(fmt.Sprintf("About to call initialization function %s on contract %s with %d args", + initFunction, contractName, len(args))) + + // Pack the function call data with arguments + callData, err := contractABI.Pack(initFunction, args...) + if err != nil { + fuzzer.logger.Error(fmt.Errorf("failed to encode init call to %s: %v", initFunction, err)) + continue + } - // Flag that we found a matching compiled contract definition and deployed it, then exit out of this - // inner loop to process the next contract to deploy in the outer loop. + // Create and send the transaction + destAddr := contractAddr + msg := calls.NewCallMessage(fuzzer.deployer, &destAddr, 0, big.NewInt(0), + fuzzer.config.Fuzzing.BlockGasLimit, nil, nil, nil, callData) + msg.FillFromTestChainProperties(testChain) + + // Debug log after creating the message + fuzzer.logger.Info(fmt.Sprintf("Created message for init function call to %s", initFunction)) + + // Create and commit a block with the transaction + block, err = testChain.PendingBlockCreate() + if err != nil { + fuzzer.logger.Error(fmt.Errorf("failed to create pending block for init call: %v", err)) + continue + } + + if err = testChain.PendingBlockAddTx(msg.ToCoreMessage()); err != nil { + fuzzer.logger.Error(fmt.Errorf("failed to add initialization transaction for function %s on contract %s to pending block: %v", + initFunction, contractName, err)) + continue + } + + if err = testChain.PendingBlockCommit(); err != nil { + fuzzer.logger.Error(fmt.Errorf("failed to commit block containing initialization call to function %s on contract %s: %v", + initFunction, contractName, err)) + continue + } + + // Check if the call succeeded + if block.MessageResults[0].Receipt.Status != types.ReceiptStatusSuccessful { + // Create a call sequence element for the trace + cse := calls.NewCallSequenceElement(nil, msg, 0, 0) + cse.ChainReference = &calls.CallSequenceElementChainReference{ + Block: block, + TransactionIndex: len(block.Messages) - 1, + } + + fuzzer.logger.Error(fmt.Errorf("init function %s call failed on %s: %v", + initFunction, contractName, + block.MessageResults[0].ExecutionResult.Err)) + } else { + fuzzer.logger.Info(fmt.Sprintf("Successfully called %s on %s with %d args", + initFunction, contractName, len(args))) + } + } + } + + // Flag that we found a matching compiled contract definition, deployed it and called available init functions if any, + // then exit out of this inner loop to process the next contract to deploy in the outer loop. found = true break } @@ -598,6 +723,7 @@ func chainSetupFromCompilations(fuzzer *Fuzzer, testChain *chain.TestChain) (*ex return nil, fmt.Errorf("%v was specified in the target contracts but was not found in the compilation artifacts", contractName) } } + return nil, nil } diff --git a/fuzzing/fuzzer_hooks.go b/fuzzing/fuzzer_hooks.go index f6be4324..ede734bf 100644 --- a/fuzzing/fuzzer_hooks.go +++ b/fuzzing/fuzzer_hooks.go @@ -1,9 +1,10 @@ package fuzzing import ( - "github.com/crytic/medusa/fuzzing/config" "math/rand" + "github.com/crytic/medusa/fuzzing/config" + "github.com/crytic/medusa/fuzzing/executiontracer" "github.com/crytic/medusa/chain" diff --git a/fuzzing/fuzzer_test.go b/fuzzing/fuzzer_test.go index 6b1292ee..f9b2825e 100644 --- a/fuzzing/fuzzer_test.go +++ b/fuzzing/fuzzer_test.go @@ -349,6 +349,7 @@ func TestConsoleLog(t *testing.T) { filePath: filePath, configUpdates: func(config *config.ProjectConfig) { config.Fuzzing.TargetContracts = []string{"TestContract"} + config.Fuzzing.TargetContractsInitFunctions = []string{"testConsoleLog"} config.Fuzzing.TestLimit = 10000 config.Fuzzing.Testing.PropertyTesting.Enabled = false config.Fuzzing.Testing.OptimizationTesting.Enabled = false @@ -514,6 +515,51 @@ func TestDeploymentsWithPayableConstructors(t *testing.T) { }) } +// TestInitializationFunctions runs a test to ensure initialization functions work both with and without arguments +func TestInitializationWithParam(t *testing.T) { + runFuzzerTest(t, &fuzzerSolcFileTest{ + filePath: "testdata/contracts/deployments/deploy_with_init_fns.sol", + configUpdates: func(pkgConfig *config.ProjectConfig) { + // Just a single contract + pkgConfig.Fuzzing.TargetContracts = []string{"SimpleInitParamTest"} + + // With zero balance + pkgConfig.Fuzzing.TargetContractsBalances = []*config.ContractBalance{ + {Int: *big.NewInt(0)}, + } + + // Initialization function with a parameter + pkgConfig.Fuzzing.TargetContractsInitFunctions = []string{"initWithParam"} + + // Create the initialization args map if it doesn't exist + if pkgConfig.Fuzzing.InitializationArgs == nil { + pkgConfig.Fuzzing.InitializationArgs = make(map[string]map[string]any) + } + + // Specify the parameter value - must match the exact parameter name + pkgConfig.Fuzzing.InitializationArgs["SimpleInitParamTest"] = map[string]any{ + "_value": "42", + } + + // Enable property testing + pkgConfig.Fuzzing.Testing.PropertyTesting.Enabled = false + pkgConfig.Fuzzing.TestLimit = 10 + pkgConfig.Fuzzing.Testing.AssertionTesting.Enabled = true + pkgConfig.Fuzzing.Testing.OptimizationTesting.Enabled = false + pkgConfig.Slither.UseSlither = false + + }, + method: func(f *fuzzerTestContext) { + // Start the fuzzer + err := f.fuzzer.Start() + assert.NoError(t, err) + + assertFailedTestsExpected(f, false) + + }, + }) +} + // TestDeploymentsSelfDestruct runs a test to ensure dynamically deployed contracts are detected by the Fuzzer and // their properties are tested appropriately. func TestDeploymentsSelfDestruct(t *testing.T) { @@ -697,7 +743,7 @@ func TestTestingScope(t *testing.T) { // TestDeploymentsWithArgs runs tests to ensure contracts deployed with config provided constructor arguments are // deployed as expected. It expects all properties should fail (indicating values provided were set accordingly). func TestDeploymentsWithArgs(t *testing.T) { - // This contract deploys a contract with specific constructor arguments. Property tests will fail if they are + // This contract deploys a contract with specific constructor arguments as well as init functions with arguments. Property tests will fail if they are // set correctly. runFuzzerTest(t, &fuzzerSolcFileTest{ filePath: "testdata/contracts/deployments/deployment_with_args.sol", @@ -716,6 +762,16 @@ func TestDeploymentsWithArgs(t *testing.T) { "_deployed": "DeployedContract:DeploymentWithArgs", }, } + config.Fuzzing.TargetContractsInitFunctions = []string{"dummyFunction", "dummyFunction"} // this should execute predefined functions in the respective contracts + config.Fuzzing.InitializationArgs = map[string]map[string]any{ + "DeploymentWithArgs": { + "a": "100", // argument for DeploymentWithArgs.dummyFunction + }, + "Dependent": { + "a": "200", // argument for Dependent.dummyFunction + }, + } + config.Fuzzing.Testing.StopOnFailedTest = false config.Fuzzing.TestLimit = 500 // this test should expose a failure quickly. config.Fuzzing.Testing.AssertionTesting.Enabled = false diff --git a/fuzzing/testdata/contracts/cheat_codes/vm/assume.sol b/fuzzing/testdata/contracts/cheat_codes/vm/assume.sol new file mode 100644 index 00000000..6cba5bd5 --- /dev/null +++ b/fuzzing/testdata/contracts/cheat_codes/vm/assume.sol @@ -0,0 +1,21 @@ +// This test ensures that the chainId can be set with cheat codes +interface CheatCodes { + function assume(bool) external; +} + +contract TestContract { + + function test_true() public { + CheatCodes cheats = CheatCodes(0x7109709ECfa91a80626fF3989D68f67F5b1DD12D); + cheats.assume(true); + // this always happens (only useful if tested with manual coverage check) + assert(true); + } + + function test_false() public { + CheatCodes cheats = CheatCodes(0x7109709ECfa91a80626fF3989D68f67F5b1DD12D); + cheats.assume(false); + // this is not reachable + assert(false); + } +} \ No newline at end of file diff --git a/fuzzing/testdata/contracts/cheat_codes/vm/expectRevert.sol b/fuzzing/testdata/contracts/cheat_codes/vm/expectRevert.sol new file mode 100644 index 00000000..535c8029 --- /dev/null +++ b/fuzzing/testdata/contracts/cheat_codes/vm/expectRevert.sol @@ -0,0 +1,60 @@ +// This test ensures that the expectRevert works as expected +interface CheatCodes { + function expectRevert() external; + function prank(address msgSender) external; +} +contract Target{ + function good() public {} +} +interface FakeTarget{ + function good() external; + function bad() external; +} +contract TestContract { + + function test_true() public { + CheatCodes cheats = CheatCodes(0x7109709ECfa91a80626fF3989D68f67F5b1DD12D); + FakeTarget target = FakeTarget(address(new Target())); + cheats.expectRevert(); + target.bad(); + + // this always happens (only useful if tested with manual coverage check) + assert(true); + } + function test_false() public { + CheatCodes cheats = CheatCodes(0x7109709ECfa91a80626fF3989D68f67F5b1DD12D); + FakeTarget target = FakeTarget(address(new Target())); + cheats.expectRevert(); + target.good(); + + // this never happens (only useful if tested with manual coverage check) + assert(false); + } + + function test_true_with_prank() public { + CheatCodes cheats = CheatCodes(0x7109709ECfa91a80626fF3989D68f67F5b1DD12D); + FakeTarget target = FakeTarget(address(new Target())); + cheats.expectRevert(); + // Calls to the VM are ignored from the expectRevert logic + // So this test that + cheats.prank(address(0x41414141)); + target.bad(); + + // this always happens (only useful if tested with manual coverage check) + assert(true); + } + + + function test_false_with_prank() public { + CheatCodes cheats = CheatCodes(0x7109709ECfa91a80626fF3989D68f67F5b1DD12D); + FakeTarget target = FakeTarget(address(new Target())); + cheats.expectRevert(); + // Calls to the VM are ignored from the expectRevert logic + // So this test that + cheats.prank(address(0x41414141)); + target.good(); + + // this never happens (only useful if tested with manual coverage check) + assert(false); + } +} \ No newline at end of file diff --git a/fuzzing/testdata/contracts/deployments/deploy_with_init_fns.sol b/fuzzing/testdata/contracts/deployments/deploy_with_init_fns.sol new file mode 100644 index 00000000..6556de80 --- /dev/null +++ b/fuzzing/testdata/contracts/deployments/deploy_with_init_fns.sol @@ -0,0 +1,19 @@ +// Ultra-simple test for initialization functions with parameters +contract SimpleInitParamTest { + // Track if functions were called and parameter values + bool public initCalled; + uint public initValue; + + // Empty constructor + constructor() {} + + // Initialization function with a parameter + function initWithParam(uint _value) public { + initCalled = true; + initValue = _value; + emit InitCalled("initWithParam", _value); + } + + // Event for tracking + event InitCalled(string functionName, uint value); +} \ No newline at end of file