Skip to content

Reduce memory usage of tests #1606

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 17, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion eth/precompiles/ecadd.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def ecadd(computation: BaseComputation) -> BaseComputation:
computation.consume_gas(constants.GAS_ECADD, reason='ECADD Precompile')

try:
result = _ecadd(computation.msg.data)
result = _ecadd(computation.msg.data_as_bytes)
except ValidationError:
raise VMError("Invalid ECADD parameters")

Expand Down
2 changes: 1 addition & 1 deletion eth/precompiles/ecmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def ecmul(computation: BaseComputation) -> BaseComputation:
computation.consume_gas(constants.GAS_ECMUL, reason='ECMUL Precompile')

try:
result = _ecmull(computation.msg.data)
result = _ecmull(computation.msg.data_as_bytes)
except ValidationError:
raise VMError("Invalid ECMUL parameters")

Expand Down
6 changes: 5 additions & 1 deletion eth/precompiles/ecpairing.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@
VMError,
)

from eth.typing import (
BytesOrView,
)

from eth.utils.bn128 import (
validate_point,
FQP_point_to_FQ2_point,
Expand Down Expand Up @@ -60,7 +64,7 @@ def ecpairing(computation: BaseComputation) -> BaseComputation:
return computation


def _ecpairing(data: bytes) -> bool:
def _ecpairing(data: BytesOrView) -> bool:
exponent = bn128.FQ12.one()

processing_pipeline = (
Expand Down
9 changes: 5 additions & 4 deletions eth/precompiles/ecrecover.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,17 @@

def ecrecover(computation: BaseComputation) -> BaseComputation:
computation.consume_gas(constants.GAS_ECRECOVER, reason="ECRecover Precompile")
raw_message_hash = computation.msg.data[:32]
data = computation.msg.data_as_bytes
raw_message_hash = data[:32]
message_hash = pad32r(raw_message_hash)

v_bytes = pad32r(computation.msg.data[32:64])
v_bytes = pad32r(data[32:64])
v = big_endian_to_int(v_bytes)

r_bytes = pad32r(computation.msg.data[64:96])
r_bytes = pad32r(data[64:96])
r = big_endian_to_int(r_bytes)

s_bytes = pad32r(computation.msg.data[96:128])
s_bytes = pad32r(data[96:128])
s = big_endian_to_int(s_bytes)

try:
Expand Down
2 changes: 1 addition & 1 deletion eth/precompiles/identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,5 @@ def identity(computation: BaseComputation) -> BaseComputation:

computation.consume_gas(gas_fee, reason="Identity Precompile")

computation.output = computation.msg.data
computation.output = computation.msg.data_as_bytes
return computation
8 changes: 5 additions & 3 deletions eth/precompiles/modexp.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,12 +125,14 @@ def modexp(computation: BaseComputation) -> BaseComputation:
"""
https://github.com/ethereum/EIPs/pull/198
"""
gas_fee = _compute_modexp_gas_fee(computation.msg.data)
data = computation.msg.data_as_bytes

gas_fee = _compute_modexp_gas_fee(data)
computation.consume_gas(gas_fee, reason='MODEXP Precompile')

result = _modexp(computation.msg.data)
result = _modexp(data)

_, _, modulus_length = _extract_lengths(computation.msg.data)
_, _, modulus_length = _extract_lengths(data)

# Modulo 0 is undefined, return zero
# https://math.stackexchange.com/questions/516251/why-is-n-mod-0-undefined
Expand Down
2 changes: 2 additions & 0 deletions eth/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@

GenesisDict = Dict[str, Union[int, BlockNumber, bytes, Hash32]]

BytesOrView = Union[bytes, memoryview]

Normalizer = Callable[[Dict[Any, Any]], Dict[str, Any]]

RawAccountDetails = TypedDict('RawAccountDetails',
Expand Down
12 changes: 12 additions & 0 deletions eth/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@
UINT_256_MAX,
)

from eth.typing import (
BytesOrView,
)

if TYPE_CHECKING:
from eth.vm.base import BaseVM # noqa: F401

Expand All @@ -51,6 +55,14 @@ def validate_is_bytes(value: bytes, title: str="Value") -> None:
)


def validate_is_bytes_or_view(value: BytesOrView, title: str="Value") -> None:
if isinstance(value, (bytes, memoryview)):
return
raise ValidationError(
"{title} must be bytes or memoryview. Got {0}".format(type(value), title=title)
)


def validate_is_integer(value: Union[int, bool], title: str="Value") -> None:
if not isinstance(value, int) or isinstance(value, bool):
raise ValidationError(
Expand Down
15 changes: 12 additions & 3 deletions eth/vm/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@
Halt,
VMError,
)
from eth.typing import (
BytesOrView,
)
from eth.tools.logging import (
ExtendedDebugLogger,
)
Expand Down Expand Up @@ -243,12 +246,18 @@ def memory_write(self, start_position: int, size: int, value: bytes) -> None:
"""
return self._memory.write(start_position, size, value)

def memory_read(self, start_position: int, size: int) -> bytes:
def memory_read(self, start_position: int, size: int) -> memoryview:
"""
Read and return ``size`` bytes from memory starting at ``start_position``.
Read and return a view of ``size`` bytes from memory starting at ``start_position``.
"""
return self._memory.read(start_position, size)

def memory_read_bytes(self, start_position: int, size: int) -> bytes:
"""
Read and return ``size`` bytes from memory starting at ``start_position``.
"""
return self._memory.read_bytes(start_position, size)

#
# Gas Consumption
#
Expand Down Expand Up @@ -360,7 +369,7 @@ def prepare_child_message(self,
gas: int,
to: Address,
value: int,
data: bytes,
data: BytesOrView,
code: bytes,
**kwargs: Any) -> Message:
"""
Expand Down
6 changes: 4 additions & 2 deletions eth/vm/logic/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def calldataload(computation: BaseComputation) -> None:
"""
start_position = computation.stack_pop(type_hint=constants.UINT256)

value = computation.msg.data[start_position:start_position + 32]
value = computation.msg.data_as_bytes[start_position:start_position + 32]
padded_value = value.ljust(32, b'\x00')
normalized_value = padded_value.lstrip(b'\x00')

Expand All @@ -68,7 +68,9 @@ def calldatacopy(computation: BaseComputation) -> None:

computation.consume_gas(copy_gas_cost, reason="CALLDATACOPY fee")

value = computation.msg.data[calldata_start_position: calldata_start_position + size]
value = computation.msg.data_as_bytes[
calldata_start_position: calldata_start_position + size
]
padded_value = value.ljust(size, b'\x00')

computation.memory_write(mem_start_position, size, padded_value)
Expand Down
2 changes: 1 addition & 1 deletion eth/vm/logic/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def log_XX(computation: BaseComputation, topic_count: int) -> None:
)

computation.extend_memory(mem_start_position, size)
log_data = computation.memory_read(mem_start_position, size)
log_data = computation.memory_read_bytes(mem_start_position, size)

computation.add_log_entry(
account=computation.msg.storage_address,
Expand Down
2 changes: 1 addition & 1 deletion eth/vm/logic/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def mload(computation: BaseComputation) -> None:

computation.extend_memory(start_position, 32)

value = computation.memory_read(start_position, 32)
value = computation.memory_read_bytes(start_position, 32)
computation.stack_push(value)


Expand Down
2 changes: 1 addition & 1 deletion eth/vm/logic/sha3.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def sha3(computation: BaseComputation) -> None:

computation.extend_memory(start_position, size)

sha3_bytes = computation.memory_read(start_position, size)
sha3_bytes = computation.memory_read_bytes(start_position, size)
word_count = ceil32(len(sha3_bytes)) // 32

gas_cost = constants.GAS_SHA3WORD * word_count
Expand Down
10 changes: 5 additions & 5 deletions eth/vm/logic/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,7 @@ def return_op(computation: BaseComputation) -> None:

computation.extend_memory(start_position, size)

output = computation.memory_read(start_position, size)
computation.output = bytes(output)
computation.output = computation.memory_read_bytes(start_position, size)
raise Halt('RETURN')


Expand All @@ -42,8 +41,7 @@ def revert(computation: BaseComputation) -> None:

computation.extend_memory(start_position, size)

output = computation.memory_read(start_position, size)
computation.output = bytes(output)
computation.output = computation.memory_read_bytes(start_position, size)
raise Revert(computation.output)


Expand Down Expand Up @@ -163,7 +161,9 @@ def __call__(self, computation: BaseComputation) -> None:
computation.stack_push(0)
return

call_data = computation.memory_read(stack_data.memory_start, stack_data.memory_length)
call_data = computation.memory_read_bytes(
stack_data.memory_start, stack_data.memory_length
)

create_msg_gas = self.max_child_gas_modifier(
computation.get_gas_remaining()
Expand Down
21 changes: 18 additions & 3 deletions eth/vm/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,16 @@ def extend(self, start_position: int, size: int) -> None:
return

size_to_extend = new_size - len(self)
self._bytes.extend(itertools.repeat(0, size_to_extend))
try:
self._bytes.extend(itertools.repeat(0, size_to_extend))
except BufferError:
# we can't extend the buffer (which might involve relocating it) if a
# memoryview (which stores a pointer into the buffer) has been created by
# read() and not released. Callers of read() will never try to write to the
# buffer so we're not missing anything by making a new buffer and forgetting
# about the old one. We're keeping too much memory around but this is still a
# net savings over having read() return a new bytes() object every time.
self._bytes = self._bytes + bytearray(size_to_extend)

def __len__(self) -> int:
return len(self._bytes)
Expand All @@ -51,8 +60,14 @@ def write(self, start_position: int, size: int, value: bytes) -> None:
for idx, v in enumerate(value):
self._bytes[start_position + idx] = v

def read(self, start_position: int, size: int) -> bytes:
def read(self, start_position: int, size: int) -> memoryview:
"""
Read a value from memory.
Return a view into the memory
"""
return memoryview(self._bytes)[start_position:start_position + size]

def read_bytes(self, start_position: int, size: int) -> bytes:
"""
Read a value from memory and return a fresh bytes instance
"""
return bytes(self._bytes[start_position:start_position + size])
12 changes: 10 additions & 2 deletions eth/vm/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,13 @@
from eth.constants import (
CREATE_CONTRACT_ADDRESS,
)
from eth.typing import (
BytesOrView,
)
from eth.validation import (
validate_canonical_address,
validate_is_bytes,
validate_is_bytes_or_view,
validate_is_integer,
validate_gte,
validate_uint256,
Expand All @@ -31,7 +35,7 @@ def __init__(self,
to: Address,
sender: Address,
value: int,
data: bytes,
data: BytesOrView,
code: bytes,
depth: int=0,
create_address: Address=None,
Expand All @@ -51,7 +55,7 @@ def __init__(self,
validate_uint256(value, title="Message.value")
self.value = value

validate_is_bytes(data, title="Message.data")
validate_is_bytes_or_view(data, title="Message.data")
self.data = data

validate_is_integer(depth, title="Message.depth")
Expand Down Expand Up @@ -100,3 +104,7 @@ def storage_address(self, value: Address) -> None:
@property
def is_create(self) -> bool:
return self.to == CREATE_CONTRACT_ADDRESS

@property
def data_as_bytes(self) -> bytes:
return bytes(self.data)