28
28
from trie import HexaryTrie
29
29
30
30
from ethereum_test_forks import Fork
31
+ from evm_transition_tool import EVMTransactionTrace
31
32
32
33
from ..exceptions import ExceptionList , TransactionException
33
34
from .base_types import Address , Bytes , Hash , HexNumber , Number , ZeroPaddedHexNumber
@@ -63,6 +64,55 @@ def __repr__(self) -> str:
63
64
return "auto"
64
65
65
66
67
+ class TraceableException (Exception ):
68
+ """
69
+ Exception that can use a trace to provide more information.
70
+ """
71
+
72
+ traces : List [List [EVMTransactionTrace ]] | None
73
+
74
+ def set_traces (self , traces : List [List [EVMTransactionTrace ]]):
75
+ """
76
+ Set the traces for the exception.
77
+ """
78
+ self .traces = traces
79
+
80
+ def get_trace_context (
81
+ self ,
82
+ * ,
83
+ context_address : Optional [Address ] = None ,
84
+ opcode_name : Optional [str ] = None ,
85
+ stack_item_0 : Optional [int ] = None ,
86
+ previous_lines : int = 0 ,
87
+ ) -> List [str ]:
88
+ """
89
+ Returns a list of strings with the context of the exception.
90
+ """
91
+ lines = []
92
+ if self .traces :
93
+ for execution_trace in self .traces :
94
+ for tx_trace in execution_trace :
95
+ for i , trace in enumerate (tx_trace .trace_lines ()):
96
+ if (
97
+ (context_address is None or trace .context_address == context_address )
98
+ and (opcode_name is None or trace .opcode_name == opcode_name )
99
+ and (
100
+ stack_item_0 is None
101
+ or (len (trace .stack ) > 0 and trace .stack [- 1 ] == stack_item_0 )
102
+ )
103
+ ):
104
+ lines += [
105
+ * [
106
+ f" { dict (trace )} "
107
+ for trace in tx_trace .trace_lines ()[
108
+ max (0 , i - previous_lines ) : i
109
+ ]
110
+ ],
111
+ f" { dict (trace )} " ,
112
+ ]
113
+ return lines
114
+
115
+
66
116
MAX_STORAGE_KEY_VALUE = 2 ** 256 - 1
67
117
MIN_STORAGE_KEY_VALUE = - (2 ** 255 )
68
118
@@ -148,23 +198,31 @@ def __str__(self):
148
198
"""
149
199
150
200
@dataclass (kw_only = True )
151
- class MissingKey (Exception ):
201
+ class MissingKey (TraceableException ):
152
202
"""
153
203
Test expected to find a storage key set but key was missing.
154
204
"""
155
205
206
+ address : Address
156
207
key : int
157
208
158
- def __init__ (self , key : int , * args ):
209
+ def __init__ (self , address : Address , key : int , * args ):
159
210
super ().__init__ (args )
211
+ self .address = address
160
212
self .key = key
161
213
162
214
def __str__ (self ):
163
- """Print exception string"""
164
- return "key {0} not found in storage" .format (Storage .key_value_to_string (self .key ))
215
+ """Print exception string lines"""
216
+ lines = [
217
+ f"key { Storage .key_value_to_string (self .key )} not found in"
218
+ + f"storage of { self .address } "
219
+ ]
220
+ if self .traces :
221
+ pass
222
+ return "\n " .join (lines )
165
223
166
224
@dataclass (kw_only = True )
167
- class KeyValueMismatch (Exception ):
225
+ class KeyValueMismatch (TraceableException ):
168
226
"""
169
227
Test expected a certain value in a storage key but value found
170
228
was different.
@@ -183,13 +241,23 @@ def __init__(self, address: Address, key: int, want: int, got: int, *args):
183
241
self .got = got
184
242
185
243
def __str__ (self ):
186
- """Print exception string"""
187
- return (
244
+ """Print exception string lines """
245
+ lines = [
188
246
f"incorrect value in address { self .address } for "
189
- + f"key { Storage .key_value_to_string (self .key )} :"
190
- + f" want { Storage .key_value_to_string (self .want )} (dec:{ self .want } ),"
191
- + f" got { Storage .key_value_to_string (self .got )} (dec:{ self .got } )"
192
- )
247
+ + f"key { Storage .key_value_to_string (self .key )} :" ,
248
+ f"want: { Storage .key_value_to_string (self .want )} (dec:{ self .want } )" ,
249
+ f"got { Storage .key_value_to_string (self .got )} (dec:{ self .got } )" ,
250
+ ]
251
+ if self .traces :
252
+ lines += ["" , "Relevant EVM traces:" ]
253
+ lines += self .get_trace_context (
254
+ context_address = self .address ,
255
+ opcode_name = "SSTORE" ,
256
+ stack_item_0 = self .key ,
257
+ previous_lines = 2 ,
258
+ )
259
+
260
+ return "\n " .join (lines )
193
261
194
262
@staticmethod
195
263
def parse_key_value (input : str | int | bytes | SupportsBytes ) -> int :
@@ -309,7 +377,7 @@ def must_contain(self, address: Address, other: "Storage"):
309
377
if key not in self :
310
378
# storage[key]==0 is equal to missing storage
311
379
if other [key ] != 0 :
312
- raise Storage .MissingKey (key = key )
380
+ raise Storage .MissingKey (address = address , key = key )
313
381
elif self [key ] != other [key ]:
314
382
raise Storage .KeyValueMismatch (
315
383
address = address , key = key , want = self [key ], got = other [key ]
0 commit comments