Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 25 additions & 4 deletions src/ethereum_test_fixtures/pre_alloc_groups.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Pre-allocation group models for test fixture generation."""

import json
from pathlib import Path
from typing import Any, Dict, List

Expand Down Expand Up @@ -57,10 +58,26 @@ def to_file(self, file: Path) -> None:
if file.exists():
with open(file, "r") as f:
previous_pre_alloc_group = PreAllocGroup.model_validate_json(f.read())
for account in previous_pre_alloc_group.pre:
if account not in self.pre:
self.pre[account] = previous_pre_alloc_group.pre[account]
self.test_ids.extend(previous_pre_alloc_group.test_ids)
for account in previous_pre_alloc_group.pre:
existing_account = previous_pre_alloc_group.pre[account]
if account not in self.pre:
self.pre[account] = existing_account
else:
new_account = self.pre[account]
if new_account != existing_account:
# This procedure fails during xdist worker's pytest_sessionfinish
# and is not reported to the master thread.
# We signal here that the groups created contain a collision.
collision_file_path = file.with_suffix(".fail")
collision_exception = Alloc.CollisionError(
address=account,
account_1=existing_account,
account_2=new_account,
)
with open(collision_file_path, "w") as f:
f.write(json.dumps(collision_exception.to_json()))
raise collision_exception
self.test_ids.extend(previous_pre_alloc_group.test_ids)

with open(file, "w") as f:
f.write(self.model_dump_json(by_alias=True, exclude_none=True, indent=2))
Expand All @@ -78,6 +95,10 @@ def __setitem__(self, key: str, value: Any):
@classmethod
def from_folder(cls, folder: Path) -> "PreAllocGroups":
"""Create PreAllocGroups from a folder of pre-allocation files."""
# First check for collision failures
for fail_file in folder.glob("*.fail"):
with open(fail_file) as f:
raise Alloc.CollisionError.from_json(json.loads(f.read()))
data = {}
for file in folder.glob("*.json"):
with open(file) as f:
Expand Down
2 changes: 1 addition & 1 deletion src/ethereum_test_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def update_pre_alloc_groups(
group.pre = Alloc.merge(
group.pre,
self.pre,
allow_key_collision=True,
key_collision_mode=Alloc.KeyCollisionMode.ALLOW_IDENTICAL_ACCOUNTS,
)
group.fork = fork
group.test_ids.append(str(test_id))
Expand Down
94 changes: 74 additions & 20 deletions src/ethereum_test_types/account_types.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
"""Account-related types for Ethereum tests."""

import json
from dataclasses import dataclass, field
from typing import Dict, List, Literal, Optional, Tuple
from enum import Enum, auto
from typing import Any, Dict, List, Literal, Optional, Self, Tuple

from coincurve.keys import PrivateKey
from ethereum_types.bytes import Bytes20
from ethereum_types.numeric import U256, Bytes32, Uint
from pydantic import PrivateAttr
from typing_extensions import Self

from ethereum_test_base_types import (
Account,
Expand Down Expand Up @@ -144,12 +145,6 @@ class UnexpectedAccountError(Exception):
address: Address
account: Account | None

def __init__(self, address: Address, account: Account | None, *args):
"""Initialize the exception."""
super().__init__(args)
self.address = address
self.account = account

def __str__(self):
"""Print exception string."""
return f"unexpected account in allocation {self.address}: {self.account}"
Expand All @@ -160,25 +155,82 @@ class MissingAccountError(Exception):

address: Address

def __init__(self, address: Address, *args):
"""Initialize the exception."""
super().__init__(args)
self.address = address

def __str__(self):
"""Print exception string."""
return f"Account missing from allocation {self.address}"

@dataclass(kw_only=True)
class CollisionError(Exception):
"""Different accounts at the same address."""

address: Address
account_1: Account | None
account_2: Account | None

def to_json(self) -> Dict[str, Any]:
"""Dump to json object."""
return {
"address": self.address.hex(),
"account_1": self.account_1.model_dump(mode="json")
if self.account_1 is not None
else None,
"account_2": self.account_2.model_dump(mode="json")
if self.account_2 is not None
else None,
}

@classmethod
def from_json(cls, obj: Dict[str, Any]) -> Self:
"""Parse from a json dict."""
return cls(
address=Address(obj["address"]),
account_1=Account.model_validate(obj["account_1"])
if obj["account_1"] is not None
else None,
account_2=Account.model_validate(obj["account_2"])
if obj["account_2"] is not None
else None,
)

def __str__(self) -> str:
"""Print exception string."""
return (
"Overlapping key defining different accounts detected:\n"
f"{json.dumps(self.to_json(), indent=2)}"
)

class KeyCollisionMode(Enum):
"""Mode for handling key collisions when merging allocations."""

ERROR = auto()
OVERWRITE = auto()
ALLOW_IDENTICAL_ACCOUNTS = auto()

@classmethod
def merge(
cls, alloc_1: "Alloc", alloc_2: "Alloc", allow_key_collision: bool = True
cls,
alloc_1: "Alloc",
alloc_2: "Alloc",
key_collision_mode: KeyCollisionMode = KeyCollisionMode.OVERWRITE,
) -> "Alloc":
"""Return merged allocation of two sources."""
overlapping_keys = alloc_1.root.keys() & alloc_2.root.keys()
if overlapping_keys and not allow_key_collision:
raise Exception(
f"Overlapping keys detected: {[key.hex() for key in overlapping_keys]}"
)
if overlapping_keys:
if key_collision_mode == cls.KeyCollisionMode.ERROR:
raise Exception(
f"Overlapping keys detected: {[key.hex() for key in overlapping_keys]}"
)
elif key_collision_mode == cls.KeyCollisionMode.ALLOW_IDENTICAL_ACCOUNTS:
# The overlapping keys must point to the exact same account
for key in overlapping_keys:
account_1 = alloc_1[key]
account_2 = alloc_2[key]
if account_1 != account_2:
raise Alloc.CollisionError(
address=key,
account_1=account_1,
account_2=account_2,
)
merged = alloc_1.model_dump()

for address, other_account in alloc_2.root.items():
Expand Down Expand Up @@ -267,15 +319,17 @@ def verify_post_alloc(self, got_alloc: "Alloc"):
if account is None:
# Account must not exist
if address in got_alloc.root and got_alloc.root[address] is not None:
raise Alloc.UnexpectedAccountError(address, got_alloc.root[address])
raise Alloc.UnexpectedAccountError(
address=address, account=got_alloc.root[address]
)
else:
if address in got_alloc.root:
got_account = got_alloc.root[address]
assert isinstance(got_account, Account)
assert isinstance(account, Account)
account.check_alloc(address, got_account)
else:
raise Alloc.MissingAccountError(address)
raise Alloc.MissingAccountError(address=address)

def deploy_contract(
self,
Expand Down
Loading