Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 13 additions & 19 deletions smda/intel/IndirectCallAnalyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@
class IndirectCallAnalyzer:
"""Perform basic dataflow analysis to resolve indirect call targets"""

RE_MOV_REG_REG = re.compile(r"(?P<reg1>[a-z]{3}), (?P<reg2>[a-z]{3})$")
RE_MOV_REG_CONST = re.compile(r"(?P<reg>[a-z]{3}), (?P<val>0x[0-9a-f]{,8})$")
RE_REG_DWORD_PTR_ADDR = re.compile(r"(?P<reg>[a-z]{3}), dword ptr \[(?P<addr>0x[0-9a-f]{,8})\]$")
RE_REG_QWORD_PTR_RIP_ADDR = re.compile(r"(?P<reg>[a-z]{3}), qword ptr \[rip \+ (?P<addr>0x[0-9a-f]{,8})\]$")
RE_REG_ADDR = re.compile(r"(?P<reg>[a-z]{3}), \[(?P<addr>0x[0-9a-f]{,8})\]$")

def __init__(self, disassembler):
self.disassembler = disassembler
self.disassembly = self.disassembler.disassembly
Expand Down Expand Up @@ -42,11 +48,11 @@ def processBlock(self, analysis_state, block, registers, register_name, processe
LOGGER.debug("0x%08x: %s %s", ins[0], ins[2], ins[3])
if ins[2] == "mov":
# mov <reg>, <reg>
match1 = re.match(r"(?P<reg1>[a-z]{3}), (?P<reg2>[a-z]{3})$", ins[3])
match1 = self.RE_MOV_REG_REG.match(ins[3])
if match1 and match1.group("reg1") == register_name:
register_name = match1.group("reg2")
# mov <reg>, <const>
match2 = re.match(r"(?P<reg>[a-z]{3}), (?P<val>0x[0-9a-f]{,8})$", ins[3])
match2 = self.RE_MOV_REG_CONST.match(ins[3])
if match2:
registers[match2.group("reg")] = int(match2.group("val"), 16)
LOGGER.debug(
Expand All @@ -57,10 +63,7 @@ def processBlock(self, analysis_state, block, registers, register_name, processe
if match2.group("reg") == register_name:
abs_value_found = True
# mov <reg>, dword ptr [<addr>]
match3 = re.match(
r"(?P<reg>[a-z]{3}), dword ptr \[(?P<addr>0x[0-9a-f]{,8})\]$",
ins[3],
)
match3 = self.RE_REG_DWORD_PTR_ADDR.match(ins[3])
if match3:
# HACK: test to see if the address points to a import and
# use that instead of the actual memory value
Expand Down Expand Up @@ -89,10 +92,7 @@ def processBlock(self, analysis_state, block, registers, register_name, processe
if match3.group("reg") == register_name:
abs_value_found = True
# mov <reg>, qword ptr [rip + <addr>]
match4 = re.match(
r"(?P<reg>[a-z]{3}), qword ptr \[rip \+ (?P<addr>0x[0-9a-f]{,8})\]$",
ins[3],
)
match4 = self.RE_REG_QWORD_PTR_RIP_ADDR.match(ins[3])
if match4:
rip = ins[0] + ins[1]
dword = self.getDword(rip + int(match4.group("addr"), 16))
Expand All @@ -110,10 +110,7 @@ def processBlock(self, analysis_state, block, registers, register_name, processe
elif ins[2] == "lea":
LOGGER.debug("*checking %s %s", ins[2], ins[3])
# lea <reg>, dword ptr [<addr>]
match1 = re.match(
r"(?P<reg>[a-z]{3}), dword ptr \[(?P<addr>0x[0-9a-f]{,8})\]$",
ins[3],
)
match1 = self.RE_REG_DWORD_PTR_ADDR.match(ins[3])
if match1:
dword = self.getDword(int(match1.group("addr"), 16))
if dword:
Expand All @@ -126,7 +123,7 @@ def processBlock(self, analysis_state, block, registers, register_name, processe
if match1.group("reg") == register_name:
abs_value_found = True
# lea <reg>, [<addr>]
match1 = re.match(r"(?P<reg>[a-z]{3}), \[(?P<addr>0x[0-9a-f]{,8})\]$", ins[3])
match1 = self.RE_REG_ADDR.match(ins[3])
if match1:
dword = self.getDword(int(match1.group("addr"), 16))
if dword:
Expand All @@ -140,10 +137,6 @@ def processBlock(self, analysis_state, block, registers, register_name, processe
abs_value_found = True
# not handled: lea <reg>, dword ptr [<reg> +- <val>]
# requires state-keeping of multiple registers
# there are potentially many more ways in which the register being called can be calculated;
# for now we ignore them
elif ins[2] == "other instruction":
pass
# if the absolute value was found for the call <reg> instruction, detect API
if abs_value_found:
candidate = registers.get(register_name, None)
Expand Down Expand Up @@ -201,6 +194,7 @@ def processBlock(self, analysis_state, block, registers, register_name, processe
for b in [self.searchBlock(analysis_state, i) for i in refs_in]
):
return True
return False

def resolveRegisterCalls(self, analysis_state, block_depth=3):
# after block reconstruction do simple data flow analysis to resolve open cases like "call <register>" as stored in self.call_register_ins
Expand Down
74 changes: 74 additions & 0 deletions tests/testIndirectCallAnalyzer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import unittest
from unittest.mock import MagicMock

from smda.intel.IndirectCallAnalyzer import IndirectCallAnalyzer


class IndirectCallAnalyzerTestSuite(unittest.TestCase):
"""Basic tests for IndirectCallAnalyzer regex and logic"""

def test_regex_matching(self):
    """Check each precompiled operand pattern against one sample operand string.

    Every case names a pattern attribute on the analyzer, the operand text
    handed to its ``match`` method, and the named groups (with their expected
    text) that the resulting match object has to expose.
    """
    analyzer = IndirectCallAnalyzer(MagicMock())
    cases = (
        # mov <reg>, <reg>
        ("RE_MOV_REG_REG", "eax, ebx", (("reg1", "eax"), ("reg2", "ebx"))),
        # mov <reg>, <const>
        ("RE_MOV_REG_CONST", "ecx, 0x12345678", (("reg", "ecx"), ("val", "0x12345678"))),
        # mov <reg>, dword ptr [<addr>]
        ("RE_REG_DWORD_PTR_ADDR", "edx, dword ptr [0x8048000]", (("reg", "edx"), ("addr", "0x8048000"))),
        # mov <reg>, qword ptr [rip + <addr>]
        ("RE_REG_QWORD_PTR_RIP_ADDR", "rax, qword ptr [rip + 0x1234]", (("reg", "rax"), ("addr", "0x1234"))),
        # lea <reg>, [<addr>]
        ("RE_REG_ADDR", "rsi, [0x400000]", (("reg", "rsi"), ("addr", "0x400000"))),
    )
    for pattern_name, operands, expected_groups in cases:
        # look the pattern up lazily so the cases run in the listed order
        found = getattr(analyzer, pattern_name).match(operands)
        self.assertIsNotNone(found)
        for group_name, expected_text in expected_groups:
            self.assertEqual(found.group(group_name), expected_text)
Comment on lines +10 to +41
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

It's great that you've added tests for the regexes! To make them more comprehensive, I suggest adding cases for 64-bit architecture. This would include registers with names that are not 3 characters long (e.g., r8, r10) and full 64-bit hexadecimal addresses. This will help ensure the regexes are robust across different architectures.

    def test_regex_matching(self):
        # Exercises every precompiled operand pattern of IndirectCallAnalyzer with
        # sample operand strings and checks the named groups of each match.
        # The analyzer is built around a MagicMock because only the class-level
        # patterns are used here.
        analyzer = IndirectCallAnalyzer(MagicMock())

        # Test mov <reg>, <reg>
        match = analyzer.RE_MOV_REG_REG.match("eax, ebx")
        self.assertIsNotNone(match)
        self.assertEqual(match.group("reg1"), "eax")
        self.assertEqual(match.group("reg2"), "ebx")
        # NOTE(review): RE_MOV_REG_REG captures registers as exactly three
        # lowercase letters ([a-z]{3}), so "r8, r9" does not match and the
        # assertIsNotNone below fails until that pattern is widened.
        match = analyzer.RE_MOV_REG_REG.match("r8, r9")
        self.assertIsNotNone(match)
        self.assertEqual(match.group("reg1"), "r8")
        self.assertEqual(match.group("reg2"), "r9")

        # Test mov <reg>, <const>
        match = analyzer.RE_MOV_REG_CONST.match("ecx, 0x12345678")
        self.assertIsNotNone(match)
        self.assertEqual(match.group("reg"), "ecx")
        self.assertEqual(match.group("val"), "0x12345678")
        # NOTE(review): RE_MOV_REG_CONST allows at most eight hex digits ({,8})
        # and a letters-only three-character register, so neither "r10" nor the
        # 16-digit constant matches; this case also needs the pattern widened.
        match = analyzer.RE_MOV_REG_CONST.match("r10, 0x1122334455667788")
        self.assertIsNotNone(match)
        self.assertEqual(match.group("reg"), "r10")
        self.assertEqual(match.group("val"), "0x1122334455667788")

        # Test mov <reg>, dword ptr [<addr>]
        match = analyzer.RE_REG_DWORD_PTR_ADDR.match("edx, dword ptr [0x8048000]")
        self.assertIsNotNone(match)
        self.assertEqual(match.group("reg"), "edx")
        self.assertEqual(match.group("addr"), "0x8048000")

        # Test mov <reg>, qword ptr [rip + <addr>]
        match = analyzer.RE_REG_QWORD_PTR_RIP_ADDR.match("rax, qword ptr [rip + 0x1234]")
        self.assertIsNotNone(match)
        self.assertEqual(match.group("reg"), "rax")
        self.assertEqual(match.group("addr"), "0x1234")

        # Test lea <reg>, [<addr>]
        match = analyzer.RE_REG_ADDR.match("rsi, [0x400000]")
        self.assertIsNotNone(match)
        self.assertEqual(match.group("reg"), "rsi")
        self.assertEqual(match.group("addr"), "0x400000")


def test_processBlock_logic(self):
    """Run processBlock on a two-instruction block while tracking ``ebx``.

    The analyzer's collaborators are all mocks: ``resolveApi`` yields
    ``(None, None)``, ``getDword`` yields a fixed value, and the disassembly
    reports every address as lying within the memory image.
    """
    mock_disassembler = MagicMock(**{"resolveApi.return_value": (None, None)})
    analyzer = IndirectCallAnalyzer(mock_disassembler)
    analyzer.getDword = MagicMock(return_value=0x12345678)
    # swap in a disassembly stub after construction
    analyzer.disassembly = MagicMock(**{"isAddrWithinMemoryImage.return_value": True})

    state = MagicMock()
    analyzer.state = state

    # every instruction is [address, size, mnemonic, op_str]
    instructions = [
        [0x401000, 5, "mov", "eax, 0x402000"],
        [0x401005, 2, "mov", "ebx, eax"],
    ]
    registers = {}

    # positional arguments after the register map: register name, processed list, depth
    result = analyzer.processBlock(state, instructions, registers, "ebx", [], 1)

    # True is expected because an absolute value is found for the tracked register
    self.assertTrue(result, f"processBlock should return True, but returned {result}")
    # the constant move must have been recorded for eax
    self.assertEqual(registers.get("eax"), 0x402000, f"Expected eax to be 0x402000, but got {registers.get('eax')}")


# Run this module's test cases when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()