Skip to content

Commit 12362de

Browse files
Create parse_xla_logs.py
1 parent 036fce8 commit 12362de

File tree

1 file changed

+120
-0
lines changed

1 file changed

+120
-0
lines changed
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
import re
2+
import sys
3+
4+
class BenchmarkResult: # Define the class here
5+
def __init__(self):
6+
self.name = ""
7+
self.flavor = ""
8+
self.description = ""
9+
self.xla_flags = ""
10+
self.build_flags = []
11+
self.source_uri = ""
12+
self.device_name = ""
13+
self.time_us = 0.0
14+
self.last_average_loss = 0.0
15+
self.memory_read_written = 0
16+
self.buffer_allocation = []
17+
self.nvptx_compilation_time_us = 0
18+
self.hlo_passes_time_us = 0
19+
self.run_backend_time_us = 0
20+
self.device_time_us = 0.0
21+
self.device_memcpy_time_us = 0.0
22+
self.xprof_session_id = ""
23+
self.error = ""
24+
self.custom_metrics = []
25+
26+
class LogParser:
27+
def __init__(self):
28+
self.compilation_time = {} # You'll need to initialize this appropriately
29+
# Units defined in google3/third_party/tensorflow/tsl/platform/numbers.cc
30+
_TIME_UNITS = {"us": 1e-6, "ms": 1e-3, "s": 1, "min": 60, "h": 3600}
31+
_TIME_REGEXP = re.compile(r"time: (\d+\.?\d*) (%s)" % "|".join(_TIME_UNITS))
32+
_SIZE_UNITS = "KMGTPE"
33+
_SIZE_REGEXP = re.compile(r"(\d+\.?\d*)([%s]i)?B" % _SIZE_UNITS)
34+
35+
def _ParseTimeFromLog(self, time_str: bytes, position: int = -1) -> int:
36+
"""Returns the time in microseconds parsed from XLA logs."""
37+
match = self._TIME_REGEXP.search(time_str.decode())
38+
assert match, "Unable to parse the time on log line %d" % position
39+
exp_ = self._TIME_UNITS[match.group(2)] * 1e6
40+
return int(float(match.group(1)) * exp_)
41+
42+
def _ParseBytesFromLog(self, size_str: bytes, position: int = -1) -> int:
43+
"""Returns the size in bytes parsed from XLA logs."""
44+
match = self._SIZE_REGEXP.search(size_str.decode())
45+
assert match, "Unable to parse the size on log line %d" % position
46+
prefix = match.group(2) or "-"
47+
exp_ = pow(2, 10 * (self._SIZE_UNITS.find(prefix[0]) + 1))
48+
return int(float(match.group(1)) * exp_)
49+
50+
def _ParseLogLine(self, result, line, position) -> None:
51+
"""Parses a single XLA log line and updates the BenchmarkResult proto.
52+
53+
It looks for certain text patterns and updates the result proto with
54+
compilation stats and Xprof session id.
55+
56+
Args:
57+
result: a benchmark result proto to be updated
58+
line: a single line of benchmark log output
59+
position: line number (for debugging)
60+
"""
61+
# Log output generated by --vmodule=nvptx_compiler=1
62+
if b"NVPTXCompiler::CompileTargetBinary - CompileToPtx" in line:
63+
thread_id = re.split(rb"\s+", line)[2]
64+
self.compilation_time[thread_id] += self._ParseTimeFromLog(line, position)
65+
return
66+
67+
# Log output generated by --vmodule=gpu_compiler=1
68+
if b"HLO memory read+written:" in line:
69+
result.memory_read_written += self._ParseBytesFromLog(line, position)
70+
return
71+
if b"GpuCompiler::RunHloPasses for" in line:
72+
result.hlo_passes_time_us += self._ParseTimeFromLog(line, position)
73+
return
74+
if b"GpuCompiler::RunBackend for" in line:
75+
result.run_backend_time_us += self._ParseTimeFromLog(line, position)
76+
return
77+
78+
# Log output generated by --vmodule=bfc_allocator=2
79+
if b"New Peak memory usage" in line and b"for GPU" in line:
80+
match = re.search(rb"(\d+) bytes", line)
81+
assert match, "Unable to parse the size on log line %d" % position
82+
alloc_size = int(match.group(1))
83+
if alloc_size > max(result.buffer_allocation, default=0):
84+
result.buffer_allocation[:] = [alloc_size]
85+
return
86+
87+
# Log output generated by --xprof_end_2_end_upload
88+
if b"XprofResponse uploaded to http://xprof/" in line:
89+
match = re.search(rb"session_id=([\w\-]+)", line)
90+
assert match, "Unable to parse the XProf link"
91+
result.xprof_session_id = match.group(1)
92+
return
93+
94+
def parse_log_file(self, log_file_path):
95+
result = BenchmarkResult()
96+
with open(log_file_path, "rb") as f:
97+
for i, line in enumerate(f):
98+
self._ParseLogLine(result, line, i + 1)
99+
100+
# Update compilation time (total across threads) in the result object
101+
for time_taken in self.compilation_time.values():
102+
result.nvptx_compilation_time_us += time_taken
103+
return result # Return the result object
104+
105+
if __name__ == "__main__":
106+
if len(sys.argv) != 2:
107+
print("Usage: python parse_log.py <log_file_path>")
108+
sys.exit(1)
109+
110+
log_file_path = sys.argv[1]
111+
112+
parser = LogParser()
113+
result = parser.parse_log_file(log_file_path)
114+
115+
# Access results
116+
print(f"Memory Read+Written: {result.memory_read_written} bytes")
117+
print(f"NVPTX Compilation Time: {result.nvptx_compilation_time_us} us")
118+
print(f"HLO Passes Time: {result.hlo_passes_time_us} us")
119+
print(f"Run Backend Time: {result.run_backend_time_us} us")
120+
print(f"Buffer Allocation: {result.buffer_allocation} bytes")

0 commit comments

Comments
 (0)