Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 154 additions & 0 deletions slither/analyses/data_flow/analyses/reentrancy/analysis/analysis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
from typing import Optional, Set, Union

from slither.analyses.data_flow.analyses.reentrancy.analysis.domain import (
DomainVariant,
ReentrancyDomain,
)
from slither.analyses.data_flow.analyses.reentrancy.core.state import State
from slither.analyses.data_flow.engine.analysis import Analysis
from slither.analyses.data_flow.engine.direction import Direction, Forward
from slither.analyses.data_flow.engine.domain import Domain
from slither.core.cfg.node import Node
from slither.core.declarations.function import Function
from slither.core.variables.state_variable import StateVariable
from slither.slithir.operations.event_call import EventCall
from slither.slithir.operations.high_level_call import HighLevelCall
from slither.slithir.operations.internal_call import InternalCall
from slither.slithir.operations.low_level_call import LowLevelCall
from slither.slithir.operations.operation import Operation
from slither.slithir.operations.send import Send
from slither.slithir.operations.transfer import Transfer


class ReentrancyAnalysis(Analysis):
def __init__(self):
self._direction = Forward()

def domain(self) -> Domain:
return ReentrancyDomain.bottom()

def direction(self) -> Direction:
return self._direction

def bottom_value(self) -> Domain:
return ReentrancyDomain.bottom()

def transfer_function(self, node: Node, domain: ReentrancyDomain, operation: Operation):
self.transfer_function_helper(node, domain, operation, private_functions_seen=set())

def transfer_function_helper(
self,
node: Node,
domain: ReentrancyDomain,
operation: Operation,
private_functions_seen: Optional[Set[Function]] = None,
):
if private_functions_seen is None:
private_functions_seen = set()

if domain.variant == DomainVariant.BOTTOM:
domain.variant = DomainVariant.STATE
domain.state = State()

self._analyze_operation_by_type(operation, domain, node, private_functions_seen)

def _analyze_operation_by_type(
self,
operation: Operation,
domain: ReentrancyDomain,
node: Node,
private_functions_seen: Set[Function],
):
if isinstance(operation, EventCall):
self._handle_event_call_operation(operation, domain)
elif isinstance(operation, InternalCall):
self._handle_internal_call_operation(operation, domain, private_functions_seen)
elif isinstance(operation, (HighLevelCall, LowLevelCall, Transfer, Send)):
self._handle_abi_call_contract_operation(operation, domain, node)

self._handle_storage(domain, node)
self._update_writes_after_calls(domain, node)

def _handle_storage(self, domain: ReentrancyDomain, node: Node):
# Track state reads
for var in node.state_variables_read:
if isinstance(var, StateVariable) and var.is_stored:
domain.state.add_read(var, node)
# Track state writes
for var in node.state_variables_written:
if isinstance(var, StateVariable) and var.is_stored:
domain.state.add_written(var, node)

def _update_writes_after_calls(self, domain: ReentrancyDomain, node: Node):
# Writes after any external call
if node in domain.state.calls:
for var_name, write_nodes in domain.state.written.items():
for wn in write_nodes:
domain.state.add_write_after_call(var_name, wn)
# Writes after ETH-sending calls
if node in domain.state.send_eth:
for var_name, write_nodes in domain.state.written.items():
for wn in write_nodes:
domain.state.add_write_after_call(var_name, wn)

def _handle_internal_call_operation(
self,
operation: InternalCall,
domain: ReentrancyDomain,
private_functions_seen: Set[Function],
):
function = operation.function
if not isinstance(function, Function) or function in private_functions_seen:
return

private_functions_seen.add(function)
for node in function.nodes:
for internal_operation in node.irs:
if isinstance(internal_operation, (HighLevelCall, LowLevelCall, Transfer, Send)):
continue
self.transfer_function_helper(
node,
domain,
internal_operation,
private_functions_seen,
)
# Mark cross-function reentrancy for written variables
for var_name in domain.state.written.keys():
domain.state.add_cross_function(var_name, function)

def _handle_abi_call_contract_operation(
self,
operation: Union[LowLevelCall, HighLevelCall, Send, Transfer],
domain: ReentrancyDomain,
node: Node,
):
# Track all external calls - avoid duplicates
if operation.node not in domain.state.calls.get(node, set()):
domain.state.add_call(node, operation.node)

# Track variables read prior to this call
for var_name in domain.state.reads.keys():
domain.state.add_reads_prior_calls(node, var_name)

# Track external calls that send ETH - avoid duplicates
if operation.can_send_eth:
if operation.node not in domain.state.send_eth.get(node, set()):
domain.state.add_send_eth(node, operation.node)

def _handle_event_call_operation(self, operation: EventCall, domain: ReentrancyDomain):
# Track events and propagate previous external calls
# Only propagate calls that haven't already been propagated to this event node
existing_calls = domain.state.calls.get(operation.node, set())

# Collect all calls to add before modifying the dictionary
calls_to_add = []
for calls_set in domain.state.calls.values():
for call_node in calls_set:
if call_node not in existing_calls:
calls_to_add.append(call_node)

# Add all collected calls
for call_node in calls_to_add:
domain.state.add_call(operation.node, call_node)

domain.state.add_event(operation, operation.node)
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from enum import Enum, auto
from typing import Optional

from slither.analyses.data_flow.analyses.reentrancy.core.state import State
from slither.analyses.data_flow.engine.domain import Domain


class DomainVariant(Enum):
BOTTOM = auto()
TOP = auto()
STATE = auto()


class ReentrancyDomain(Domain):
def __init__(self, variant: DomainVariant, state: Optional[State] = None):
self.variant = variant
self.state = state or State()

@classmethod
def bottom(cls) -> "ReentrancyDomain":

Check warning on line 20 in slither/analyses/data_flow/analyses/reentrancy/analysis/domain.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

W0221: Number of parameters was 1 in 'Domain.bottom' and is now 1 in overriding 'ReentrancyDomain.bottom' method (arguments-differ)
return cls(DomainVariant.BOTTOM)

@classmethod
def top(cls) -> "ReentrancyDomain":

Check warning on line 24 in slither/analyses/data_flow/analyses/reentrancy/analysis/domain.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

W0221: Number of parameters was 1 in 'Domain.top' and is now 1 in overriding 'ReentrancyDomain.top' method (arguments-differ)
return cls(DomainVariant.TOP)

@classmethod
def with_state(cls, info: State) -> "ReentrancyDomain":
return cls(DomainVariant.STATE, info)

def join(self, other: "ReentrancyDomain") -> bool:
if self.variant == DomainVariant.TOP or other.variant == DomainVariant.BOTTOM:
return False

if self.variant == DomainVariant.BOTTOM and other.variant == DomainVariant.STATE:
self.variant = DomainVariant.STATE
self.state = other.state.deep_copy()
self.state.written.clear()
self.state.events.clear()
self.state.writes_after_calls.clear()
self.state.cross_function.clear()
return True

if self.variant == DomainVariant.STATE and other.variant == DomainVariant.STATE:

Check warning on line 44 in slither/analyses/data_flow/analyses/reentrancy/analysis/domain.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

R1705: Unnecessary "else" after "return", remove the "else" and de-indent the code inside it (no-else-return)
if self.state == other.state:
return False

self.state.send_eth.update(other.state.send_eth)
self.state.calls.update(other.state.calls)
self.state.reads.update(other.state.reads)
self.state.reads_prior_calls.update(other.state.reads_prior_calls)
self.state.safe_send_eth.update(other.state.safe_send_eth)
self.state.writes_after_calls.update(other.state.writes_after_calls)
self.state.cross_function.update(other.state.cross_function)
return True

else:
self.variant = DomainVariant.TOP

return True
162 changes: 162 additions & 0 deletions slither/analyses/data_flow/analyses/reentrancy/core/state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
import copy
from collections import defaultdict
from typing import Dict, Set

from slither.core.cfg.node import Node
from slither.core.declarations.function import Function
from slither.core.variables.state_variable import StateVariable
from slither.slithir.operations.event_call import EventCall


class State:

Check warning on line 11 in slither/analyses/data_flow/analyses/reentrancy/core/state.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

R0902: Too many instance attributes (9/7) (too-many-instance-attributes)
def __init__(self):
self._send_eth: Dict[Node, Set[Node]] = defaultdict(set)
self._safe_send_eth: Dict[Node, Set[Node]] = defaultdict(set)
self._calls: Dict[Node, Set[Node]] = defaultdict(set)
self._reads: Dict[str, Set[Node]] = defaultdict(set)
self._reads_prior_calls: Dict[Node, Set[str]] = defaultdict(set)
self._events: Dict[EventCall, Set[Node]] = defaultdict(set)
self._written: Dict[str, Set[Node]] = defaultdict(set)
self.writes_after_calls: Dict[str, Set[Node]] = defaultdict(set)
self.cross_function: Dict[StateVariable, Set[Function]] = defaultdict(set)

# -------------------- Add methods --------------------
def add_call(self, node: Node, call_node: Node):
self._calls[node].add(call_node)

def add_send_eth(self, node: Node, call_node: Node):
self._send_eth[node].add(call_node)

def add_safe_send_eth(self, node: Node, call_node: Node):
self._safe_send_eth[node].add(call_node)

def add_written(self, var: StateVariable, node: Node):
# Ensure the canonical name exists and is not None
if var.canonical_name is not None:
# Ensure the key exists in the defaultdict
if var.canonical_name not in self._written:
self._written[var.canonical_name] = set()
self._written[var.canonical_name].add(node)

def add_read(self, var: StateVariable, node: Node):
# Ensure the canonical name exists and is not None
if var.canonical_name is not None:
# Ensure the key exists in the defaultdict
if var.canonical_name not in self._reads:
self._reads[var.canonical_name] = set()
self._reads[var.canonical_name].add(node)

def add_reads_prior_calls(self, node: Node, var_name: str):
self._reads_prior_calls[node].add(var_name)

def add_write_after_call(self, var_name: str, node: Node):
self.writes_after_calls[var_name].add(node)

def add_cross_function(self, var: StateVariable, function: Function):
self.cross_function[var].add(function)

def add_event(self, event: EventCall, node: Node):
self._events[event].add(node)

# -------------------- Properties --------------------
@property
def send_eth(self) -> Dict[Node, Set[Node]]:
return self._send_eth

@property
def safe_send_eth(self) -> Dict[Node, Set[Node]]:
return self._safe_send_eth

@property
def all_eth_calls(self) -> Dict[Node, Set[Node]]:
result = defaultdict(set)
for node, calls in self._send_eth.items():
result[node].update(calls)
for node, calls in self._safe_send_eth.items():
result[node].update(calls)
return result

@property
def calls(self) -> Dict[Node, Set[Node]]:
return self._calls

@property
def reads(self) -> Dict[str, Set[Node]]:
return self._reads

@property
def written(self) -> Dict[str, Set[Node]]:
return self._written

@property
def reads_prior_calls(self) -> Dict[Node, Set[str]]:
return self._reads_prior_calls

@property
def events(self) -> Dict[EventCall, Set[Node]]:
return self._events

# -------------------- Utilities --------------------
def __eq__(self, other):
if not isinstance(other, State):
return False
return (
self._send_eth == other._send_eth
and self._safe_send_eth == other._safe_send_eth
and self._calls == other._calls
and self._reads == other._reads
and self._reads_prior_calls == other._reads_prior_calls
and self._events == other._events
and self._written == other._written
and self.writes_after_calls == other.writes_after_calls
and self.cross_function == other.cross_function
)

def __hash__(self):
return hash(
(
frozenset(self._send_eth.items()),
frozenset(self._safe_send_eth.items()),
frozenset(self._calls.items()),
frozenset(self._reads.items()),
frozenset(self._reads_prior_calls.items()),
frozenset(self._events.items()),
frozenset(self._written.items()),
frozenset((k, frozenset(v)) for k, v in self.writes_after_calls.items()),
frozenset((k, frozenset(v)) for k, v in self.cross_function.items()),
)
)

def __str__(self):
return (
f"State(\n"
f" send_eth: {len(self._send_eth)} items,\n"
f" safe_send_eth: {len(self._safe_send_eth)} items,\n"
f" calls: {len(self._calls)} items,\n"
f" reads: {len(self._reads)} items,\n"
f" reads_prior_calls: {len(self._reads_prior_calls)} items,\n"
f" events: {len(self._events)} items,\n"
f" written: {len(self._written)} items,\n"
f" writes_after_calls: {len(self.writes_after_calls)} items,\n"
f" cross_function: {len(self.cross_function)} items,\n"
f")"
)

def deep_copy(self) -> "State":
new_state = State()
# Use shallow copy for Node objects to avoid circular reference issues

new_state._send_eth.update({k: v.copy() for k, v in self._send_eth.items()})

Check warning on line 149 in slither/analyses/data_flow/analyses/reentrancy/core/state.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

W0212: Access to a protected member _send_eth of a client class (protected-access)
new_state._safe_send_eth.update({k: v.copy() for k, v in self._safe_send_eth.items()})

Check warning on line 150 in slither/analyses/data_flow/analyses/reentrancy/core/state.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

W0212: Access to a protected member _safe_send_eth of a client class (protected-access)
new_state._calls.update({k: v.copy() for k, v in self._calls.items()})

Check warning on line 151 in slither/analyses/data_flow/analyses/reentrancy/core/state.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

W0212: Access to a protected member _calls of a client class (protected-access)
new_state._reads.update({k: v.copy() for k, v in self._reads.items()})

Check warning on line 152 in slither/analyses/data_flow/analyses/reentrancy/core/state.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

W0212: Access to a protected member _reads of a client class (protected-access)
new_state._reads_prior_calls.update(

Check warning on line 153 in slither/analyses/data_flow/analyses/reentrancy/core/state.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

W0212: Access to a protected member _reads_prior_calls of a client class (protected-access)
{k: v.copy() for k, v in self._reads_prior_calls.items()}
)
new_state._events.update({k: v.copy() for k, v in self._events.items()})

Check warning on line 156 in slither/analyses/data_flow/analyses/reentrancy/core/state.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

W0212: Access to a protected member _events of a client class (protected-access)
new_state._written.update({k: v.copy() for k, v in self._written.items()})
new_state.writes_after_calls.update(
{k: v.copy() for k, v in self.writes_after_calls.items()}
)
new_state.cross_function.update({k: v.copy() for k, v in self.cross_function.items()})
return new_state
Loading
Loading