Skip to content

core/state: implement optional BAL construction in statedb #31959

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 7 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmd/evm/blockrunner.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ func runBlockTest(ctx *cli.Context, fname string) ([]testResult, error) {
continue
}
result := &testResult{Name: name, Pass: true}
if err := tests[name].Run(false, rawdb.HashScheme, ctx.Bool(WitnessCrossCheckFlag.Name), tracer, func(res error, chain *core.BlockChain) {
if err := tests[name].Run(false, rawdb.HashScheme, ctx.Bool(WitnessCrossCheckFlag.Name), false, tracer, func(res error, chain *core.BlockChain) {
if ctx.Bool(DumpFlag.Name) {
if s, _ := chain.State(); s != nil {
result.State = dump(s)
Expand Down
22 changes: 13 additions & 9 deletions core/blockchain.go
Original file line number Diff line number Diff line change
Expand Up @@ -1667,7 +1667,7 @@ func (bc *BlockChain) InsertChain(chain types.Blocks) (int, error) {
}
defer bc.chainmu.Unlock()

_, n, err := bc.insertChain(chain, true, false) // No witness collection for mass inserts (would get super large)
_, n, err := bc.insertChain(chain, true, false, false) // No witness collection for mass inserts (would get super large)
return n, err
}

Expand All @@ -1679,7 +1679,7 @@ func (bc *BlockChain) InsertChain(chain types.Blocks) (int, error) {
// racey behaviour. If a sidechain import is in progress, and the historic state
// is imported, but then new canon-head is added before the actual sidechain
// completes, then the historic state could be pruned again
func (bc *BlockChain) insertChain(chain types.Blocks, setHead bool, makeWitness bool) (*stateless.Witness, int, error) {
func (bc *BlockChain) insertChain(chain types.Blocks, setHead bool, makeWitness bool, makeBAL bool) (*stateless.Witness, int, error) {
// If the chain is terminating, don't even bother starting up.
if bc.insertStopped() {
return nil, 0, nil
Expand Down Expand Up @@ -1837,7 +1837,7 @@ func (bc *BlockChain) insertChain(chain types.Blocks, setHead bool, makeWitness
}
// The traced section of block import.
start := time.Now()
res, err := bc.processBlock(parent.Root, block, setHead, makeWitness && len(chain) == 1)
res, err := bc.processBlock(parent.Root, block, setHead, makeWitness && len(chain) == 1, makeBAL && len(chain) == 1)
if err != nil {
return nil, it.index, err
}
Expand Down Expand Up @@ -1905,7 +1905,7 @@ type blockProcessingResult struct {

// processBlock executes and validates the given block. If there was no error
// it writes the block and associated state to database.
func (bc *BlockChain) processBlock(parentRoot common.Hash, block *types.Block, setHead bool, makeWitness bool) (_ *blockProcessingResult, blockEndErr error) {
func (bc *BlockChain) processBlock(parentRoot common.Hash, block *types.Block, setHead bool, makeWitness bool, makeBAL bool) (_ *blockProcessingResult, blockEndErr error) {
var (
err error
startTime = time.Now()
Expand Down Expand Up @@ -1950,6 +1950,10 @@ func (bc *BlockChain) processBlock(parentRoot common.Hash, block *types.Block, s
}(time.Now(), throwaway, block)
}

if makeBAL && bc.vmConfig.BALConstruction {
statedb.EnableBALConstruction()
}

// If we are past Byzantium, enable prefetching to pull in trie node paths
// while processing transactions. Before Byzantium the prefetcher is mostly
// useless due to the intermediate root hashing after each transaction.
Expand Down Expand Up @@ -2165,7 +2169,7 @@ func (bc *BlockChain) insertSideChain(block *types.Block, it *insertIterator, ma
// memory here.
if len(blocks) >= 2048 || memory > 64*1024*1024 {
log.Info("Importing heavy sidechain segment", "blocks", len(blocks), "start", blocks[0].NumberU64(), "end", block.NumberU64())
if _, _, err := bc.insertChain(blocks, true, false); err != nil {
if _, _, err := bc.insertChain(blocks, true, false, false); err != nil {
return nil, 0, err
}
blocks, memory = blocks[:0], 0
Expand All @@ -2179,7 +2183,7 @@ func (bc *BlockChain) insertSideChain(block *types.Block, it *insertIterator, ma
}
if len(blocks) > 0 {
log.Info("Importing sidechain segment", "start", blocks[0].NumberU64(), "end", blocks[len(blocks)-1].NumberU64())
return bc.insertChain(blocks, true, makeWitness)
return bc.insertChain(blocks, true, makeWitness, false)
}
return nil, 0, nil
}
Expand Down Expand Up @@ -2228,7 +2232,7 @@ func (bc *BlockChain) recoverAncestors(block *types.Block, makeWitness bool) (co
} else {
b = bc.GetBlock(hashes[i], numbers[i])
}
if _, _, err := bc.insertChain(types.Blocks{b}, false, makeWitness && i == 0); err != nil {
if _, _, err := bc.insertChain(types.Blocks{b}, false, makeWitness && i == 0, false); err != nil {
return b.ParentHash(), err
}
}
Expand Down Expand Up @@ -2448,13 +2452,13 @@ func (bc *BlockChain) reorg(oldHead *types.Header, newHead *types.Header) error
// The key difference between the InsertChain is it won't do the canonical chain
// updating. It relies on the additional SetCanonical call to finalize the entire
// procedure.
func (bc *BlockChain) InsertBlockWithoutSetHead(block *types.Block, makeWitness bool) (*stateless.Witness, error) {
func (bc *BlockChain) InsertBlockWithoutSetHead(block *types.Block, makeWitness bool, makeBAL bool) (*stateless.Witness, error) {
if !bc.chainmu.TryLock() {
return nil, errChainStopped
}
defer bc.chainmu.Unlock()

witness, _, err := bc.insertChain(types.Blocks{block}, false, makeWitness)
witness, _, err := bc.insertChain(types.Blocks{block}, false, makeWitness, makeBAL)
return witness, err
}

Expand Down
252 changes: 252 additions & 0 deletions core/state/statedb.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package state

import (
"bytes"
"errors"
"fmt"
"maps"
Expand Down Expand Up @@ -138,6 +139,9 @@ type StateDB struct {
// State witness if cross validation is needed
witness *stateless.Witness

// block access list, if bal construction is specified
b *bal

// Measurements gathered during execution for debugging purposes
AccountReads time.Duration
AccountHashes time.Duration
Expand All @@ -157,6 +161,15 @@ type StateDB struct {
StorageDeleted atomic.Int64 // Number of storage slots deleted during the state transition
}

// EnableBALConstruction attaches a fresh, empty block access list to the
// StateDB. Any list that was previously attached is replaced.
func (s *StateDB) EnableBALConstruction() {
	tracker := new(bal)
	tracker.accountAccesses = make(map[common.Address]*accountAccess)
	tracker.codeChanges = make(map[common.Address]codeDiff)
	tracker.prestateNonces = make(map[common.Address]nonceDiff)
	tracker.balanceChanges = make(map[common.Address]balanceDiff)
	s.b = tracker
}

// New creates a new state from a given trie.
func New(root common.Hash, db Database) (*StateDB, error) {
reader, err := db.Reader(root)
Expand All @@ -166,6 +179,221 @@ func New(root common.Hash, db Database) (*StateDB, error) {
return NewWithReader(root, db, reader)
}

// slotAccess records the accesses to a single storage slot. An entry whose
// writes map is empty was created by accountAccess.MarkRead and has no
// recorded write.
type slotAccess struct {
writes map[uint64]common.Hash // map of tx index to post-tx slot value
}

// accountAccess collects the storage slots of one account that were touched
// while the block access list was being built.
type accountAccess struct {
// address of the account these accesses belong to
address common.Address
accesses map[common.Hash]slotAccess // map of slot key to all post-tx values where that slot was read/written
// copy of the state object's code, taken when this entry is first created
// by StorageRead or StorageWrite
code []byte
}

// MarkRead notes that the storage slot key was accessed. A slot that is
// already tracked is left untouched; otherwise an entry with an empty write
// set is created for it.
func (a *accountAccess) MarkRead(key common.Hash) {
	if _, tracked := a.accesses[key]; tracked {
		return
	}
	a.accesses[key] = slotAccess{writes: make(map[uint64]common.Hash)}
}

// MarkWrite records value as the content of storage slot key after the
// transaction with index txIdx. The slot entry is created on first use; a
// repeated call for the same slot and transaction overwrites the earlier value.
func (a *accountAccess) MarkWrite(txIdx uint64, key, value common.Hash) {
	slot, tracked := a.accesses[key]
	if !tracked {
		slot = slotAccess{writes: make(map[uint64]common.Hash)}
		a.accesses[key] = slot
	}
	// slot is a copy of the map entry, but its writes map shares storage with
	// the stored one, so the assignment below is visible through a.accesses.
	slot.writes[txIdx] = value
}

// codeDiff is a single code change for an account: the index of the
// transaction that changed the code and a copy of the resulting code.
// Only one codeDiff is kept per account (see bal.codeChanges).
type codeDiff struct {
txIdx uint64
code []byte
}

// balanceDiff maps a transaction index to the account balance recorded for
// that transaction by bal.BalanceChange.
type balanceDiff map[uint64]*uint256.Int

// nonceDiff maps a transaction index to the pre-state nonce recorded for that
// transaction by bal.NonceDiff (the nonce of the state object's origin, or
// zero when the object has no origin).
type nonceDiff map[uint64]uint64

// bal accumulates the block access list while transactions are executed and
// finalised. Every map is keyed by account address.
type bal struct {
// storage slots read or written, per account
accountAccesses map[common.Address]*accountAccess
// most recently recorded code change, per account
codeChanges map[common.Address]codeDiff
// pre-state nonces by transaction index, per account
prestateNonces map[common.Address]nonceDiff
// recorded balances by transaction index, per account
balanceChanges map[common.Address]balanceDiff
}

// eq reports whether b and other hold the same block access list: the same
// storage accesses, code changes, pre-state nonces and balance changes for the
// same set of accounts.
//
// Balances are compared by numeric value. They are stored as *uint256.Int, so
// comparing them with == would test pointer identity and report two
// independently built but identical lists as different.
//
// NOTE(review): the address and code fields of accountAccess are not part of
// the comparison (the address is already the map key); confirm whether the
// code snapshot should be compared as well.
func (b *bal) eq(other *bal) bool {
	sameSlot := func(x, y slotAccess) bool {
		return maps.Equal(x.writes, y.writes)
	}
	sameAccount := func(x, y *accountAccess) bool {
		return maps.EqualFunc(x.accesses, y.accesses, sameSlot)
	}
	sameCode := func(x, y codeDiff) bool {
		return x.txIdx == y.txIdx && bytes.Equal(x.code, y.code)
	}
	sameNonces := func(x, y nonceDiff) bool {
		return maps.Equal(x, y)
	}
	sameBalance := func(x, y *uint256.Int) bool {
		// Guard against nil entries so the comparison never dereferences nil.
		if x == nil || y == nil {
			return x == y
		}
		return x.Cmp(y) == 0
	}
	sameBalances := func(x, y balanceDiff) bool {
		return maps.EqualFunc(x, y, sameBalance)
	}

	return maps.EqualFunc(b.accountAccesses, other.accountAccesses, sameAccount) &&
		maps.EqualFunc(b.codeChanges, other.codeChanges, sameCode) &&
		maps.EqualFunc(b.prestateNonces, other.prestateNonces, sameNonces) &&
		maps.EqualFunc(b.balanceChanges, other.balanceChanges, sameBalances)
}

// NonceDiff records the pre-state nonce of account for the transaction with
// index txIdx. The pre-state nonce is taken from the state object's origin and
// is zero when the object has no origin.
//
// Called during tx finalisation for each dirty account with changed nonce
// (whether by being the sender of a tx or calling CREATE).
func (b *bal) NonceDiff(account *stateObject, txIdx uint64) {
	perTx, found := b.prestateNonces[account.address]
	if !found {
		perTx = make(nonceDiff)
		b.prestateNonces[account.address] = perTx
	}
	var before uint64
	if origin := account.origin; origin != nil {
		before = origin.Nonce
	}
	perTx[txIdx] = before
}

// BalanceChange stores a copy of the account's current balance under the
// transaction index txIdx.
//
// Called during tx finalisation for each dirty account that has no origin or
// whose balance differs from its origin's balance.
func (b *bal) BalanceChange(txIdx uint64, account *stateObject) {
	perTx, found := b.balanceChanges[account.address]
	if !found {
		perTx = make(balanceDiff)
		b.balanceChanges[account.address] = perTx
	}
	perTx[txIdx] = account.Balance().Clone()
}

// TODO for eip: specify that storage slots which are read/modified for accounts that are created/selfdestructed
// in the same transaction aren't included in the BAL (?)

// TODO for eip: specify that storage slots of newly-created accounts which are only read are not included in the BAL (?)

// StorageRead notes that storage slot key of account was read. The first
// access to an account creates its accountAccess entry, capturing a copy of
// the account's code at that moment.
//
// Called during tx execution every time a storage slot is read.
func (b *bal) StorageRead(account *stateObject, key common.Hash) {
	entry, known := b.accountAccesses[account.address]
	if !known {
		entry = &accountAccess{
			address:  account.address,
			accesses: make(map[common.Hash]slotAccess),
			code:     bytes.Clone(account.code),
		}
		b.accountAccesses[account.address] = entry
	}
	entry.MarkRead(key)
}

// StorageWrite records value as the content of storage slot key of account
// after the transaction with index txIdx. The first access to an account
// creates its accountAccess entry, capturing a copy of the account's code at
// that moment.
//
// Called every time a mutated storage value is committed upon transaction
// finalization.
func (b *bal) StorageWrite(account *stateObject, txIdx uint64, key, value common.Hash) {
	entry, known := b.accountAccesses[account.address]
	if !known {
		entry = &accountAccess{
			address:  account.address,
			accesses: make(map[common.Hash]slotAccess),
			code:     bytes.Clone(account.code),
		}
		b.accountAccesses[account.address] = entry
	}
	entry.MarkWrite(txIdx, key, value)
}

// TODO: the EIP doesn't explicitly mention delegation changes being included in code changes, but they
// arguably should be. This implementation assumes that was implicit.

// CodeChange records that the code of account was changed by the transaction
// with index txIdx, storing a copy of the account's current code. Only one
// entry is kept per account, so a later call for the same address replaces the
// earlier one.
//
// Called during tx finalisation for each dirty account with mutated code.
func (b *bal) CodeChange(txIdx uint64, account *stateObject) {
	// No placeholder entry is needed first: assigning to a map key creates it.
	b.codeChanges[account.address] = codeDiff{
		txIdx: txIdx,
		code:  bytes.Clone(account.Code()),
	}
}

// NewWithReader creates a new state for the specified state root. Unlike New,
// this function accepts an additional Reader which is bound to the given root.
func NewWithReader(root common.Hash, db Database, reader Reader) (*StateDB, error) {
Expand All @@ -190,6 +418,12 @@ func NewWithReader(root common.Hash, db Database, reader Reader) (*StateDB, erro
if db.TrieDB().IsVerkle() {
sdb.accessEvents = NewAccessEvents(db.PointCache())
}
sdb.b = &bal{
make(map[common.Address]*accountAccess),
make(map[common.Address]codeDiff),
make(map[common.Address]nonceDiff),
make(map[common.Address]balanceDiff),
}
return sdb, nil
}

Expand Down Expand Up @@ -378,6 +612,9 @@ func (s *StateDB) GetCodeHash(addr common.Address) common.Hash {
func (s *StateDB) GetState(addr common.Address, hash common.Hash) common.Hash {
stateObject := s.getStateObject(addr)
if stateObject != nil {
if s.b != nil {
s.b.StorageRead(stateObject, hash)
}
return stateObject.GetState(hash)
}
return common.Hash{}
Expand Down Expand Up @@ -758,6 +995,21 @@ func (s *StateDB) Finalise(deleteEmptyObjects bool) {
s.stateObjectsDestruct[obj.address] = obj
}
} else {
if s.b != nil {
for key, val := range obj.dirtyStorage {
s.b.StorageWrite(obj, uint64(s.txIndex), key, val)
}
if obj.origin == nil || obj.origin.Balance.Cmp(obj.Balance()) != 0 {
s.b.BalanceChange(uint64(s.txIndex), obj)
}
if obj.origin == nil || obj.origin.Nonce != obj.Nonce() {
s.b.NonceDiff(obj, uint64(s.txIndex))
}
if obj.origin == nil || bytes.Compare(obj.origin.CodeHash, obj.CodeHash()) != 0 {
s.b.CodeChange(uint64(s.txIndex), obj)
}
}

obj.finalise()
s.markUpdate(addr)
}
Expand Down
Loading
Loading