@@ -44,17 +44,22 @@ class Print(Flag):
4444class TraceConfig :
4545 """Tracing configuration."""
4646
47- def __init__ (self ,
48- mem_address : Optional [LeakageModel ] = None ,
49- mem_value : Optional [LeakageModel ] = None ,
50- register : Optional [LeakageModel ] = None ,
51- instruction : bool = False ,
52- ignored_registers : Optional [Set [str ]] = None ):
47+ def __init__ (
48+ self ,
49+ mem_address : Optional [LeakageModel ] = None ,
50+ mem_value : Optional [LeakageModel ] = None ,
51+ register : Optional [LeakageModel ] = None ,
52+ instruction : bool = False ,
53+ ignored_registers : Optional [Set [str ]] = None ,
54+ only_when_registers : bool = True ,
55+ ):
5356 self .mem_address = mem_address
5457 self .mem_value = mem_value
5558 self .register = register
5659 self .instructions = instruction
5760 self .ignored_registers = ignored_registers
61+ # When true, only output register leakage when at least one register is written
62+ self .only_when_registers = only_when_registers
5863
5964
6065class Rainbow (abc .ABC ):
@@ -83,18 +88,21 @@ class Rainbow(abc.ABC):
8388 ENDIANNESS : str
8489 PC : int
8590
86- last_regs : Optional [ List [str ] ]
91+ last_regs : List [str ]
8792 last_reg_values : Optional [Dict [str , int ]]
8893 last_address : Optional [int ]
8994 last_value : Optional [int ]
9095 trace : List [Any ]
9196
92- block_hook : Optional [int ]
9397 mem_hook : Optional [int ]
9498 code_hook : Optional [int ]
9599
96- def __init__ (self , print_config : Print = Print (0 ), trace_config : TraceConfig = TraceConfig (),
97- allow_breakpoints : bool = False , allow_stubs : bool = False ):
100+ def __init__ (
101+ self ,
102+ print_config : Print = Print (0 ),
103+ trace_config : TraceConfig = TraceConfig (),
104+ allow_breakpoints : bool = False ,
105+ ):
98106 self .breakpoints = set ()
99107 self .functions = {}
100108 self .function_names = {}
@@ -104,7 +112,6 @@ def __init__(self, print_config: Print = Print(0), trace_config: TraceConfig = T
104112 self .print_config = print_config
105113 self .trace_config = trace_config
106114 self .allow_breakpoints = allow_breakpoints
107- self .allow_stubs = allow_stubs
108115
109116 # Leak storage
110117 self .last_reg_values = {}
@@ -311,12 +318,17 @@ def start_and_fault(self, fault_model, fault_index: int, begin: int, end: int, *
311318 return pc_fault
312319
313320 def setup (self ):
314- """Add base hooks to the engine."""
315- # We need the block hook only if we are
316- # printing functions or need to handle stubs.
317- # if self.print_config & Print.Functions or self.allow_stubs:
318- self .block_hook = self .emu .hook_add (uc .UC_HOOK_BLOCK ,
319- HookWeakMethod (self ._block_hook ))
321+ """Setup engine hooks."""
322+ # Hook functions calls for printing
323+ if self .print_config & Print .Functions :
324+ for addr , name in self .function_names .items ():
325+ self .emu .hook_add (
326+ uc .UC_HOOK_BLOCK ,
327+ self ._print_function_hook ,
328+ begin = addr ,
329+ end = addr ,
330+ user_data = name ,
331+ )
320332
321333 # We need the mem hook only if we are
322334 # printing memory or tracing memory values or addresses.
@@ -394,91 +406,72 @@ def disassemble_single_detailed(self, addr: int, size: int) -> cs.CsInsn:
394406 insn = self .emu .mem_read (addr , 2 * size )
395407 return self ._disassemble_cache (self .disasm .disasm , bytes (insn ), addr )
396408
409+ def _get_addrs (self , name_or_addr ):
410+ if isinstance (name_or_addr , str ):
411+ # Stub all function addresses matching this name
412+ addrs = [a for a , n in self .function_names .items () if n == name_or_addr ]
413+ if not addrs :
414+ raise IndexError (f"'{ name_or_addr } ' could not be found." )
415+ return addrs
416+ elif isinstance (name_or_addr , int ):
417+ # Name is only one address
418+ return [name_or_addr ]
419+ raise TypeError ("name_or_addr should be function name or address" )
420+
421+ def _stub_hook (self , _uci , _address , _size , userdata ):
422+ """Call user stub set up with hook_prolog/hook_bypass."""
423+ fn , bypass = userdata
424+ if fn is not None :
425+ fn (self )
426+ if bypass :
427+ # Make the function return early
428+ self .return_force ()
429+
397430 def hook_prolog (self , name , fn ):
398431 """
399432 Add a call to function 'fn' when 'name' is called during execution.
400433 After executing 'fn, execution resumes into 'name'.
401434 """
402- if not self .allow_stubs :
403- raise ValueError ("Cannot use stubs, allow_stubs is False." )
404-
405- def to_hook (x ):
406- if fn is not None :
407- fn (x )
408- return False
409-
410- if isinstance (name , str ):
411- # Stub all function addresses matching this name
412- addrs = [a for a , n in self .function_names .items () if n == name ]
413- if not addrs :
414- raise IndexError (f"'{ name } ' could not be found." )
415- for addr in addrs :
416- self .stubbed_functions [addr ] = to_hook
417- elif isinstance (name , int ):
418- # Name is an address
419- self .stubbed_functions [name ] = to_hook
420- else :
421- raise TypeError ("name should be function name or address" )
435+ for addr in self ._get_addrs (name ):
436+ self .stubbed_functions [addr ] = self .emu .hook_add (
437+ uc .UC_HOOK_BLOCK ,
438+ HookWeakMethod (self ._stub_hook ),
439+ begin = addr ,
440+ end = addr ,
441+ user_data = (fn , False ),
442+ )
422443
423444 def hook_bypass (self , name , fn = None ):
424445 """
425446 Add a call to function 'fn' when 'name' is called during execution.
426447 After executing 'fn', execution returns to the caller.
427448 """
428- if not self .allow_stubs :
429- raise ValueError ("Cannot use stubs, allow_stubs is False." )
430-
431- def to_hook (x ):
432- if fn is not None :
433- fn (x )
434- return True
435-
436- if isinstance (name , str ):
437- # Stub all function addresses matching this name
438- addrs = [a for a , n in self .function_names .items () if n == name ]
439- if not addrs :
440- raise IndexError (f"'{ name } ' could not be found." )
441- for addr in addrs :
442- self .stubbed_functions [addr ] = to_hook
443- elif isinstance (name , int ):
444- # Name is an address
445- self .stubbed_functions [name ] = to_hook
446- else :
447- raise TypeError ("name should be function name or address" )
449+ for addr in self ._get_addrs (name ):
450+ self .stubbed_functions [addr ] = self .emu .hook_add (
451+ uc .UC_HOOK_BLOCK ,
452+ HookWeakMethod (self ._stub_hook ),
453+ begin = addr ,
454+ end = addr ,
455+ user_data = (fn , True ),
456+ )
448457
449458 def remove_hook (self , name ):
450459 """Remove the hook."""
451- if not self .allow_stubs :
452- raise ValueError ("Cannot use stubs, allow_stubs is False." )
453- del self .stubbed_functions [name ]
460+ for addr in self ._get_addrs (name ):
461+ if addr in self .stubbed_functions :
462+ self .emu .hook_del (self .stubbed_functions [addr ])
463+ del self .stubbed_functions [addr ]
454464
455465 def remove_hooks (self ):
456- """Remove the hooked functions."""
457- if not self .allow_stubs :
458- raise ValueError ( "Cannot use stubs, allow_stubs is False." )
459- self .stubbed_functions = {}
466+ """Remove all hooked functions."""
467+ for addr , hook in self .stubbed_functions . items () :
468+ self . emu . hook_del ( hook )
469+ del self .stubbed_functions [ addr ]
460470
461- def _block_hook (self , _uci , address : int , _size , _ ):
462- """
463- Hook called on every jump to a basic block that checks if a known
464- address+function is redefined in the user's python script and if so,
465- calls that instead.
466- """
467- # Print function calls
468- if address in self .function_names and (self .allow_stubs or self .print_config & Print .Functions ):
469- # Handle the function call printing
470- f = self .function_names [address ]
471- if self .print_config & Print .Functions :
472- print (f"{ color ('MAGENTA' , f )} (...) @ 0x{ address :x} " )
473-
474- # If stub is enabled and set at this address, run it
475- if self .allow_stubs :
476- stub_func = self .stubbed_functions .get (address )
477- if stub_func is not None :
478- r = stub_func (self )
479- if r :
480- # If stub returns True, then make the function return early
481- self .return_force ()
471+ @staticmethod
472+ def _print_function_hook (_uci , address : int , _size , name : str ):
473+ """Print function call."""
474+ print (f"{ color ('MAGENTA' , name )} (...) @ 0x{ address :x} " )
482475
483476 def _mem_hook (self , uci , access , address , size , value , _ ):
484477 # Get the value
@@ -566,7 +559,7 @@ def _code_hook(self, uci, address, size, _):
566559 # - last_reg_values are register values as they were before the previous instruction
567560 #
568561 # So we need to go over last_regs, get their prev values from last_reg_values and get their current values.
569- if self .last_regs :
562+ if self .last_regs or not self . trace_config . only_when_registers :
570563 reg_values = {r : uci .reg_read (self .REGS [r ]) for r in self .last_regs }
571564 leak = sum (
572565 self .trace_config .register (reg_values [r ], self .last_reg_values .get (r , 0 )) for r in self .last_regs )
@@ -580,19 +573,18 @@ def _code_hook(self, uci, address, size, _):
580573 if ins is None :
581574 ins = self .disassemble_single_detailed (address , size )
582575 _ , regs_written = ins .regs_access ()
583- if regs_written :
584- regs = list (filter (lambda r : r not in self .IGNORED_REGS and (
585- not self .trace_config .ignored_registers or r not in self .trace_config .ignored_registers ),
586- map (ins .reg_name , regs_written ))) # type: ignore
587- else :
588- regs = None
576+ regs = list (filter (lambda r : r not in self .IGNORED_REGS and (
577+ not self .trace_config .ignored_registers or r not in self .trace_config .ignored_registers ),
578+ map (ins .reg_name , regs_written ))) # type: ignore
589579
590580 if self .trace_config .instructions :
591- if ins is None :
592- ins = self .disassemble_single_detailed (address , size )
581+ if ins is not None :
582+ address , mnemonic , op_str = ins .address , ins .mnemonic , ins .op_str
583+ else :
584+ address , _size , mnemonic , op_str = self .disassemble_single (address , size )
593585 if event is None :
594586 event = {"type" : "code" }
595- event ["instruction" ] = f"{ ins . address :8X} { ins . mnemonic :<6} { ins . op_str } "
587+ event ["instruction" ] = f"{ address :8X} { mnemonic :<6} { op_str } "
596588 if event is not None :
597589 self .trace .append (event )
598590 self .last_regs = regs
0 commit comments