Skip to content

Commit f0953dc

Browse files
authored
Merge pull request #3370 from trailofbits/cli-refactor
CLI refactor and libmagic optimizations
2 parents 86e2d1c + 4ddd279 commit f0953dc

File tree

10 files changed

+851
-289
lines changed

10 files changed

+851
-289
lines changed

polyfile/__main__.py

Lines changed: 238 additions & 116 deletions
Large diffs are not rendered by default.

polyfile/debugger.py

Lines changed: 127 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from abc import ABC, abstractmethod
2+
import atexit
23
from enum import Enum
34
from pathlib import Path
45
from pdb import Pdb
@@ -9,6 +10,7 @@
910
from .magic import (
1011
AbsoluteOffset, FailedTest, InvalidOffsetError, MagicMatcher, MagicTest, Offset, TestResult, TEST_TYPES
1112
)
13+
from .profiling import Profiler, Unprofiled, unprofiled
1214
from .repl import ANSIColor, ANSIWriter, arg_completer, command, ExitREPL, log, REPL, SetCompleter
1315
from .wildcards import Wildcard
1416

@@ -383,6 +385,20 @@ def value(self, new_value):
383385
self.debugger._instrument()
384386

385387

388+
class Profile(BooleanVariable):
    """Debugger variable that switches per-test profiling on or off.

    The first time the variable is set to a truthy value, an ``atexit`` hook
    is registered that invokes ``Debugger.profile_command`` on the owning
    debugger with an empty argument string. The hook is registered at most
    once per instance, and nothing in this class unregisters it if the
    variable is later set back to a falsy value.
    """

    def __init__(self, value: bool, debugger: "Debugger"):
        # Both attributes are assigned before delegating to the base class
        # because the ``value`` setter below reads them.
        # NOTE(review): presumably the base ``__init__`` assigns ``value``
        # through the property -- confirm against ``Variable``.
        self._registered_callback = False
        self.debugger: Debugger = debugger
        super().__init__(value)

    @Variable.value.setter
    def value(self, new_value):
        # Let the base property store the value first.
        Variable.value.fset(self, new_value)
        if self._registered_callback or not new_value:
            # Either the hook is already installed or profiling is being
            # turned off; nothing else to do.
            return
        self._registered_callback = True
        atexit.register(Debugger.profile_command, self.debugger, "")
400+
401+
386402
class Debugger(REPL):
387403
def __init__(self, break_on_parsing: bool = True):
388404
super().__init__(name="the debugger")
@@ -399,13 +415,17 @@ def __init__(self, break_on_parsing: bool = True):
399415
self.repl_test: Optional[MagicTest] = None
400416
self.instrumented_parsers: List[InstrumentedParser] = []
401417
self.break_on_submatching: BreakOnSubmatching = BreakOnSubmatching(break_on_parsing, self)
418+
self.profile: Profile = Profile(False, self)
402419
self.variables_by_name: Dict[str, Variable] = {
403-
"break_on_parsing": self.break_on_submatching
420+
"break_on_parsing": self.break_on_submatching,
421+
"profile": self.profile
404422
}
405423
self.variable_descriptions: Dict[str, str] = {
406424
"break_on_parsing": "Break when a PolyFile parser is about to be invoked and debug using PDB (default=True;"
407-
" disable from the command line with `--no-debug-python`)"
425+
" disable from the command line with `--no-debug-python`)",
426+
"profile": "Profile the performance of each magic test that is run (default=False)"
408427
}
428+
self.profile_results: Dict[Union[MagicTest, Type[Parser]], float] = {}
409429
self._pdb: Optional[Pdb] = None
410430

411431
def save_context(self):
@@ -442,7 +462,7 @@ def _instrument(self):
442462
if "test" in test.__dict__:
443463
# this class actually implements the test() function
444464
self.instrumented_tests.append(InstrumentedTest(test, self))
445-
if self.break_on_submatching:
465+
if self.break_on_submatching.value:
446466
for parsers in PARSERS.values():
447467
for parser in parsers:
448468
self.instrumented_parsers.append(InstrumentedParser(parser, self))
@@ -506,6 +526,8 @@ def write_test(self, test: MagicTest, is_current_test: bool = False):
506526
indent = ""
507527
self.write(f"{'>' * test.level}{test.offset!s}\t")
508528
self.write(test.message, color=ANSIColor.BLUE, bold=True)
529+
if self.profile.value and test in self.profile_results:
530+
self.write(f"\t{int(self.profile_results[test] + 0.5)}ms")
509531
if test.mime is not None:
510532
self.write(f"\n {indent}!:mime ", dim=True)
511533
self.write(test.mime, color=ANSIColor.BLUE)
@@ -541,10 +563,14 @@ def debug(
541563
absolute_offset: int,
542564
parent_match: Optional[TestResult]
543565
) -> Optional[TestResult]:
544-
if instrumented_test.original_test is None:
545-
result = instrumented_test.test.test(test, data, absolute_offset, parent_match)
546-
else:
547-
result = instrumented_test.original_test(test, data, absolute_offset, parent_match)
566+
profiler = Profiler()
567+
with profiler:
568+
if instrumented_test.original_test is None:
569+
result = instrumented_test.test.test(test, data, absolute_offset, parent_match)
570+
else:
571+
result = instrumented_test.original_test(test, data, absolute_offset, parent_match)
572+
if self.profile.value:
573+
self.profile_results[test] = profiler.elapsed_ms
548574
if self.repl_test is test:
549575
# this is a one-off test run from the REPL, so do not save its results
550576
return result
@@ -557,6 +583,7 @@ def debug(
557583
self.repl()
558584
return self.last_result
559585

586+
@unprofiled
560587
def print_where(
561588
self,
562589
test: Optional[MagicTest] = None,
@@ -659,46 +686,59 @@ def print_location():
659686
# We are already debugging!
660687
print_location()
661688
self.write(f"Parsing for submatches using {instrumented_parser.parser!s}.\n")
662-
yield from parse(file_stream, match)
689+
profiler = Profiler()
690+
with profiler:
691+
yield from parse(file_stream, match)
692+
if self.profile.value:
693+
self.profile_results[instrumented_parser.parser] = profiler.elapsed_ms
663694
return
664-
self.print_match(match)
665-
print_location()
666-
self.write(f"About to parse for submatches using {instrumented_parser.parser!s}.\n")
667-
buffer = ANSIWriter(use_ansi=sys.stderr.isatty(), escape_for_readline=True)
668-
buffer.write("Debug using PDB? ")
669-
buffer.write("(disable this prompt with `", dim=True)
670-
buffer.write("set ", color=ANSIColor.BLUE)
671-
buffer.write("break_on_parsing ", color=ANSIColor.GREEN)
672-
buffer.write("False", color=ANSIColor.CYAN)
673-
buffer.write("`)", dim=True)
695+
with Unprofiled():
696+
self.print_match(match)
697+
print_location()
698+
self.write(f"About to parse for submatches using {instrumented_parser.parser!s}.\n")
699+
buffer = ANSIWriter(use_ansi=sys.stderr.isatty(), escape_for_readline=True)
700+
buffer.write("Debug using PDB? ")
701+
buffer.write("(disable this prompt with `", dim=True)
702+
buffer.write("set ", color=ANSIColor.BLUE)
703+
buffer.write("break_on_parsing ", color=ANSIColor.GREEN)
704+
buffer.write("False", color=ANSIColor.CYAN)
705+
buffer.write("`)", dim=True)
674706
if not self.prompt(str(buffer), default=False):
675-
yield from parse(file_stream, match)
707+
with Profiler() as p:
708+
yield from parse(file_stream, match)
709+
if self.profile.value:
710+
self.profile_results[instrumented_parser.parser] = p.elapsed_ms
676711
return
677712
try:
678-
self._pdb = Pdb(skip=["polyfile.magic_debugger", "polyfile.magic"])
679-
if sys.stderr.isatty():
680-
self._pdb.prompt = "\001\u001b[1m\002(polyfile-Pdb)\001\u001b[0m\002 "
681-
else:
682-
self._pdb.prompt = "(polyfile-Pdb) "
683-
generator = parse(file_stream, match)
684-
while True:
685-
try:
686-
result = self._pdb.runcall(next, generator)
687-
self.write(f"Got a submatch:\n", dim=True)
688-
self.print_match(result)
689-
yield result
690-
except StopIteration:
691-
self.write(f"Yielded all submatches from {match.__class__.__name__} at offset {match.offset}.\n")
692-
break
693-
print_location()
694-
if not self.prompt("Continue debugging the next submatch?", default=True):
695-
if self.prompt("Print the remaining submatches?", default=False):
696-
for result in generator:
697-
self.print_match(result)
698-
yield result
699-
else:
700-
yield from generator
701-
break
713+
if self.profile.value:
714+
self.write("Warning:", bold=True, color=ANSIColor.RED)
715+
self.write(" Profiling will be disabled for this parser while debugging!\n")
716+
with Unprofiled():
717+
self._pdb = Pdb(skip=["polyfile.magic_debugger", "polyfile.magic"])
718+
if sys.stderr.isatty():
719+
self._pdb.prompt = "\001\u001b[1m\002(polyfile-Pdb)\001\u001b[0m\002 "
720+
else:
721+
self._pdb.prompt = "(polyfile-Pdb) "
722+
generator = parse(file_stream, match)
723+
while True:
724+
try:
725+
result = self._pdb.runcall(next, generator)
726+
self.write(f"Got a submatch:\n", dim=True)
727+
self.print_match(result)
728+
yield result
729+
except StopIteration:
730+
self.write(f"Yielded all submatches from {match.__class__.__name__} at offset {match.offset}."
731+
f"\n")
732+
break
733+
print_location()
734+
if not self.prompt("Continue debugging the next submatch?", default=True):
735+
if self.prompt("Print the remaining submatches?", default=False):
736+
for result in generator:
737+
self.print_match(result)
738+
yield result
739+
else:
740+
yield from generator
741+
break
702742
finally:
703743
self._pdb = None
704744

@@ -731,6 +771,50 @@ def next(self, arguments: str):
731771
def where(self, arguments: str):
732772
self.print_where()
733773

774+
@command(allows_abbreviation=False, name="profile", help="print current profiling results (to enable profiling, "
                                                         "use `set profile True`)", )
def profile_command(self, arguments: str):
    """Write the collected profiling results, slowest entry first.

    When ``self.profile_results`` is empty, a short notice is written
    instead: either that profiling is disabled (with a hint on how to enable
    it) or that no data has been collected yet.

    Each result line consists of an icon, a label, padding that aligns the
    timing column, and the runtime in milliseconds. ``MagicTest`` entries are
    labelled ``<file name>:<line>`` when source information with an original
    line is available, and ``<'>' * level><offset>`` otherwise; any other
    key is labelled with ``str(key)``.

    :param arguments: the raw REPL argument string; not used by this command
    """
    if not self.profile_results:
        if self.profile.value:
            self.write("No profiling data yet.\n")
        else:
            self.write("Profiling is disabled.\n", color=ANSIColor.RED)
            self.write("Enable it by running `set profile True`.\n")
        return

    def has_source_line(magic_test) -> bool:
        # True when the test can be labelled by its definition file and line.
        info = magic_test.source_info
        return info is not None and info.original_line is not None

    def label_width(entry) -> int:
        # Number of characters the label written for `entry` occupies
        # (the leading icon is not counted).
        if not isinstance(entry, MagicTest):
            return len(str(entry))
        if has_source_line(entry):
            return len(entry.source_info.path.name) + 1 + len(str(entry.source_info.line))
        return entry.level + len(str(entry.offset))

    self.write("Profile Results:\n", bold=True)
    ranked = sorted(
        ((elapsed, entry) for entry, elapsed in self.profile_results.items()),
        key=lambda pair: pair[0],
        reverse=True,
    )
    # The widest label determines where the timing column starts.
    widest = max([0] + [label_width(entry) for _, entry in ranked])

    for elapsed, entry in ranked:
        if not isinstance(entry, MagicTest):
            self.write("🖥 ")
            self.write(str(entry), color=ANSIColor.BLUE)
        else:
            self.write("🪄 ")
            if has_source_line(entry):
                self.write(entry.source_info.path.name, dim=True, color=ANSIColor.CYAN)
                self.write(":", dim=True)
                self.write(entry.source_info.line, dim=True, color=ANSIColor.CYAN)
            else:
                self.write(f"{'>' * entry.level}{entry.offset!s}", color=ANSIColor.BLUE)
        self.write(" " * (widest - label_width(entry)))
        if elapsed < 1.0:
            # Sub-millisecond runtimes keep two decimals so they do not
            # all collapse to "0ms" or "1ms".
            self.write(f" ⏱ {elapsed:.2f}ms\n")
        else:
            # Round half up to the nearest whole millisecond.
            self.write(f" ⏱ {int(elapsed + 0.5)}ms\n")
817+
734818
@command(allows_abbreviation=True, help="test the following libmagic DSL test at the current position")
735819
def test(self, args: str):
736820
if args:

polyfile/fileutils.py

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import tempfile as tf
66
import shutil
77
import sys
8-
from typing import AnyStr, IO, Iterator, Iterable, List, Optional, Union
8+
from typing import AnyStr, ContextManager, IO, Iterator, Iterable, List, Optional, TextIO, Union
99

1010

1111
def make_stream(path_or_stream, mode='rb', close_on_exit=None):
@@ -41,7 +41,7 @@ def __init__(self, contents: bytes, name: str):
4141
super().__init__(contents)
4242
self._name: str = name
4343

44-
def __enter__(self):
44+
def __enter__(self) -> str:
4545
tmpdir = Path(tf.mkdtemp())
4646
file_path = tmpdir / self._name
4747
with open(file_path, "wb") as f:
@@ -56,16 +56,16 @@ def __exit__(self, exc_type, exc_val, exc_tb):
5656

5757

5858
class PathOrStdin:
59-
def __init__(self, path):
60-
self._path = path
61-
if self._path == '-':
62-
self._tempfile = Tempfile(sys.stdin.buffer.read())
59+
def __init__(self, path: str):
60+
self.path: str = path
61+
if self.path == '-':
62+
self._tempfile: Optional[ExactNamedTempfile] = ExactNamedTempfile(sys.stdin.buffer.read(), "STDIN")
6363
else:
6464
self._tempfile = None
6565

66-
def __enter__(self):
66+
def __enter__(self) -> str:
6767
if self._tempfile is None:
68-
return self._path
68+
return self.path
6969
else:
7070
return self._tempfile.__enter__()
7171

@@ -74,6 +74,27 @@ def __exit__(self, *args, **kwargs):
7474
return self._tempfile.__exit__(*args, **kwargs)
7575

7676

77+
class PathOrStdout(ContextManager[TextIO]):
    """Context manager yielding a writable text stream for ``path``.

    A ``path`` of ``'-'`` yields ``sys.stdout``; any other value is opened
    for writing in text mode. Streams opened by this class are closed on
    exit, while ``sys.stdout`` is never closed.

    The instance is reusable: the output target is resolved on every
    ``__enter__``, so a ``'-'`` instance keeps yielding standard output
    instead of opening a file literally named ``-`` the second time around.
    """

    def __init__(self, path: str):
        """
        :param path: destination file path, or ``'-'`` for standard output
        """
        self.path: str = path
        # The currently active stream, or None while the context is not
        # entered. The attribute name is kept for backward compatibility.
        self._tempfile: Optional[TextIO] = None
        # Whether this object opened the active stream and must close it.
        self._owns_stream: bool = False

    def __enter__(self) -> TextIO:
        """Return the output stream, opening ``self.path`` if necessary."""
        if self._tempfile is None:
            if self.path == '-':
                # Resolve stdout at entry time so that any redirection in
                # effect when the block runs is honoured.
                self._tempfile = sys.stdout
                self._owns_stream = False
            else:
                self._tempfile = open(self.path, "w")
                self._owns_stream = True
        return self._tempfile

    def __exit__(self, *args, **kwargs) -> None:  # type: ignore
        """Close the stream if this object opened it; never close stdout.

        Exceptions raised inside the ``with`` block are not suppressed.
        """
        stream, self._tempfile = self._tempfile, None
        owned, self._owns_stream = self._owns_stream, False
        # Decide by recorded ownership rather than by identity with the
        # current sys.stdout, which may have been rebound since entry.
        if stream is not None and owned:
            stream.close()
        return None
96+
97+
7798
class FileStream(IO):
7899
def __init__(
79100
self,

polyfile/kaitai/parsers/microsoft_pe.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -93,12 +93,12 @@ def __init__(self, _io, _parent=None, _root=None):
9393
self._debug = collections.defaultdict(dict)
9494

9595
def _read(self):
96-
if self._parent.std.format == MicrosoftPe.PeFormat.pe32:
96+
if self._parent.std.output_format == MicrosoftPe.PeFormat.pe32:
9797
self._debug['image_base_32']['start'] = self._io.pos()
9898
self.image_base_32 = self._io.read_u4le()
9999
self._debug['image_base_32']['end'] = self._io.pos()
100100

101-
if self._parent.std.format == MicrosoftPe.PeFormat.pe32_plus:
101+
if self._parent.std.output_format == MicrosoftPe.PeFormat.pe32_plus:
102102
self._debug['image_base_64']['start'] = self._io.pos()
103103
self.image_base_64 = self._io.read_u8le()
104104
self._debug['image_base_64']['end'] = self._io.pos()
@@ -145,42 +145,42 @@ def _read(self):
145145
self._debug['dll_characteristics']['start'] = self._io.pos()
146146
self.dll_characteristics = self._io.read_u2le()
147147
self._debug['dll_characteristics']['end'] = self._io.pos()
148-
if self._parent.std.format == MicrosoftPe.PeFormat.pe32:
148+
if self._parent.std.output_format == MicrosoftPe.PeFormat.pe32:
149149
self._debug['size_of_stack_reserve_32']['start'] = self._io.pos()
150150
self.size_of_stack_reserve_32 = self._io.read_u4le()
151151
self._debug['size_of_stack_reserve_32']['end'] = self._io.pos()
152152

153-
if self._parent.std.format == MicrosoftPe.PeFormat.pe32_plus:
153+
if self._parent.std.output_format == MicrosoftPe.PeFormat.pe32_plus:
154154
self._debug['size_of_stack_reserve_64']['start'] = self._io.pos()
155155
self.size_of_stack_reserve_64 = self._io.read_u8le()
156156
self._debug['size_of_stack_reserve_64']['end'] = self._io.pos()
157157

158-
if self._parent.std.format == MicrosoftPe.PeFormat.pe32:
158+
if self._parent.std.output_format == MicrosoftPe.PeFormat.pe32:
159159
self._debug['size_of_stack_commit_32']['start'] = self._io.pos()
160160
self.size_of_stack_commit_32 = self._io.read_u4le()
161161
self._debug['size_of_stack_commit_32']['end'] = self._io.pos()
162162

163-
if self._parent.std.format == MicrosoftPe.PeFormat.pe32_plus:
163+
if self._parent.std.output_format == MicrosoftPe.PeFormat.pe32_plus:
164164
self._debug['size_of_stack_commit_64']['start'] = self._io.pos()
165165
self.size_of_stack_commit_64 = self._io.read_u8le()
166166
self._debug['size_of_stack_commit_64']['end'] = self._io.pos()
167167

168-
if self._parent.std.format == MicrosoftPe.PeFormat.pe32:
168+
if self._parent.std.output_format == MicrosoftPe.PeFormat.pe32:
169169
self._debug['size_of_heap_reserve_32']['start'] = self._io.pos()
170170
self.size_of_heap_reserve_32 = self._io.read_u4le()
171171
self._debug['size_of_heap_reserve_32']['end'] = self._io.pos()
172172

173-
if self._parent.std.format == MicrosoftPe.PeFormat.pe32_plus:
173+
if self._parent.std.output_format == MicrosoftPe.PeFormat.pe32_plus:
174174
self._debug['size_of_heap_reserve_64']['start'] = self._io.pos()
175175
self.size_of_heap_reserve_64 = self._io.read_u8le()
176176
self._debug['size_of_heap_reserve_64']['end'] = self._io.pos()
177177

178-
if self._parent.std.format == MicrosoftPe.PeFormat.pe32:
178+
if self._parent.std.output_format == MicrosoftPe.PeFormat.pe32:
179179
self._debug['size_of_heap_commit_32']['start'] = self._io.pos()
180180
self.size_of_heap_commit_32 = self._io.read_u4le()
181181
self._debug['size_of_heap_commit_32']['end'] = self._io.pos()
182182

183-
if self._parent.std.format == MicrosoftPe.PeFormat.pe32_plus:
183+
if self._parent.std.output_format == MicrosoftPe.PeFormat.pe32_plus:
184184
self._debug['size_of_heap_commit_64']['start'] = self._io.pos()
185185
self.size_of_heap_commit_64 = self._io.read_u8le()
186186
self._debug['size_of_heap_commit_64']['end'] = self._io.pos()

0 commit comments

Comments
 (0)