Skip to content

Commit 04c5382

Browse files
jamesjwu authored and
facebook-github-bot committed
Log PT2 chromium events to scuba (#2424)
Summary: X-link: pytorch/pytorch#133859. This diff implements a set of views for internal scuba viewing. TODOs that I might defer to another diff: - Saving cache stats via a counter is questionable here, but there is currently no good way to track "fx graph cache hit for this compile phase". I will think about this more. - We should log the frame id, compile id, etc. - We should log configs, so that we can A/B test based on whether a config is turned on. - I have not yet decided what to do with compile_uuid, but it is useful when you want to look at samples for a single run. If we had mast job info this field would probably not be needed, but it is nice to be able to drill down to a single run and get its chrome trace view or icicle view, so its future is still undecided. Reviewed By: ezyang. Differential Revision: D61392607
1 parent dd54789 commit 04c5382

File tree

1 file changed

+93
-14
lines changed
  • userbenchmark/dynamo/dynamobench/_dynamo

1 file changed

+93
-14
lines changed

userbenchmark/dynamo/dynamobench/_dynamo/utils.py

+93-14
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import time
2727
import types
2828
import typing
29+
import uuid
2930
import warnings
3031
import weakref
3132
from contextlib import contextmanager
@@ -64,7 +65,7 @@
6465
from torch._dispatch.python import enable_python_dispatcher
6566
from torch._guards import TracingContext
6667
from torch._subclasses.meta_utils import is_sparse_compressed
67-
from torch._utils_internal import log_compilation_event
68+
from torch._utils_internal import log_chromium_event_internal, log_compilation_event
6869
from torch.fx._utils import _format_graph_code, lazy_format_graph_code
6970
from torch.nn.modules.lazy import LazyModuleMixin
7071
from torch.utils._triton import has_triton, has_triton_package
@@ -212,6 +213,16 @@ def _add_time_spent(key: str, phase_name: str, time_spent: float) -> None:
212213
frame_phase_timing[key][phase_name] += time_spent
213214

214215

216+
def get_cache_stats() -> Dict[str, Any]:
    """Collect the inductor FX graph cache counters for use as chromium event metadata.

    Returns:
        A dict holding the current values of the "fxgraph_cache_hit",
        "fxgraph_cache_miss" and "fxgraph_cache_bypass" entries of
        ``counters["inductor"]``, keyed by those same names.
    """
    inductor_counters = counters["inductor"]
    stat_names = (
        "fxgraph_cache_hit",
        "fxgraph_cache_miss",
        "fxgraph_cache_bypass",
    )
    # Key order matches stat_names, so the resulting dict is laid out hit/miss/bypass.
    return {name: inductor_counters[name] for name in stat_names}
224+
225+
215226
# dynamo_timed is a context manager
216227
# By wrapping a function in dynamo_timed, we can store a record in compilation_time_metrics
217228
# where the key is the functions name.
@@ -245,22 +256,34 @@ def dynamo_timed(
245256
phase_name: Optional[str] = None,
246257
fwd_only: bool = True,
247258
):
259+
chromium_log: ChromiumEventLogger = get_chromium_event_logger()
248260
if key not in compilation_time_metrics:
249261
compilation_time_metrics[key] = []
250262

251263
fail_type: Optional[str] = None
252264
fail_reason: Optional[str] = None
253265
time_spent = float("-inf")
266+
if phase_name == "entire_frame_compile":
267+
chromium_log.reset()
254268
try:
255269
with torch.profiler.record_function(f"{key} (dynamo_timed)"):
256270
t0 = time.time()
257-
ChromiumEventLogger.log_event_start(key, time.time_ns())
271+
start = time.time_ns()
272+
chromium_log.log_event_start(key, start, None)
258273
if phase_name:
259-
ChromiumEventLogger.log_event_start(phase_name, time.time_ns())
274+
chromium_log.log_event_start(phase_name, start)
260275
yield
276+
261277
if phase_name:
262-
ChromiumEventLogger.log_event_end(phase_name, time.time_ns())
263-
ChromiumEventLogger.log_event_end(key, time.time_ns())
278+
chromium_log.log_event_end(
279+
phase_name,
280+
time.time_ns(),
281+
{"cache_stats": get_cache_stats()},
282+
start,
283+
)
284+
chromium_log.log_event_end(
285+
key, time.time_ns(), {"cache_stats": get_cache_stats()}, start
286+
)
264287
time_spent = time.time() - t0
265288
compilation_time_metrics[key].append(time_spent)
266289
except Exception as e:
@@ -814,8 +837,17 @@ class ChromiumEventLogger:
814837
a specification of the Chromium Event JSON format.
815838
"""
816839

817-
@staticmethod
840+
def __init__(self):
841+
self.stack = ["__start__"]
842+
# Generate a unique id for this logger, which we can use in scuba to filter down
843+
# to a single python run.
844+
self.id_ = str(uuid.uuid4())
845+
846+
# TODO: log to init/id tlparse after I add support for it
847+
log.info("ChromiumEventLogger initialized with id %s", self.id_)
848+
818849
def log_event_start(
850+
self,
819851
event_name: str,
820852
time_ns: int,
821853
metadata: Optional[Dict[str, Any]] = None,
@@ -826,18 +858,24 @@ def log_event_start(
826858
:param time_ns Timestamp in nanoseconds
827859
:param metadata: Any extra metadata associated with this event
828860
"""
829-
ChromiumEventLogger._log_timed_event(
861+
event = self._log_timed_event(
830862
event_name,
831863
time_ns,
832864
"B",
833865
metadata,
834866
)
867+
log_chromium_event_internal(event, self.stack, self.id_)
868+
self.stack.append(event_name)
869+
870+
def reset(self) -> None:
871+
self.stack = ["__start__"]
835872

836-
@staticmethod
837873
def log_event_end(
874+
self,
838875
event_name: str,
839876
time_ns: int,
840877
metadata: Optional[Dict[str, Any]] = None,
878+
start_time_ns: Optional[int] = None,
841879
) -> None:
842880
"""
843881
Logs the end of a single event. This function should only be
@@ -846,28 +884,53 @@ def log_event_end(
846884
:param time_ns: Timestamp in nanoseconds
847885
:param metadata: Any extra metadata associated with this event
848886
"""
849-
ChromiumEventLogger._log_timed_event(
887+
# These stack health checks currently never happen,
888+
# but they're written this way to future proof any weird event
889+
# overlaps in the future.
890+
if event_name not in self.stack:
891+
# Something went wrong, we never called start on this event,
892+
# or it was skipped due to overlapping events below
893+
log.warning("ChromiumEventLogger: Start event not in stack, ignoring")
894+
return
895+
896+
event = self._log_timed_event(
850897
event_name,
851898
time_ns,
852899
"E",
853900
metadata,
854901
)
855902

856-
@staticmethod
903+
while event_name != self.stack[-1]:
904+
# If the event isn't the most recent one to end, pop
905+
# off the stack until it is.
906+
# Since event_name in self.stack, this pop is always safe
907+
log.warning(
908+
"ChromiumEventLogger: Detected overlapping events, fixing stack"
909+
)
910+
self.stack.pop()
911+
912+
log_chromium_event_internal(event, self.stack, self.id_, start_time_ns)
913+
# Finally pop the actual event off the stack
914+
self.stack.pop()
915+
857916
def _log_timed_event(
917+
self,
858918
event_name: str,
859919
time_ns: int,
860920
phase: str,
861921
metadata: Optional[Dict[str, Any]] = None,
862-
) -> None:
922+
) -> Dict[str, Any]:
863923
"""
864924
Logs a timed event in chromium format. See log_event_start, log_event_end, etc.
865925
"""
866926
event = {
867927
"name": event_name,
868-
"ts": time_ns / 1000, # Chromium events are in ms
928+
"ts": time_ns / 1000, # Chromium events are in micro seconds
869929
"args": metadata,
870930
"ph": phase,
931+
# These categories are needed in all chromium traces
932+
"cat": "dynamo_timed",
933+
"tid": 0,
871934
"pid": 0, # pid should be specified on all logs, we don't personally care about the actual process id
872935
}
873936
torch._logging.trace_structured(
@@ -876,9 +939,10 @@ def _log_timed_event(
876939
suppress_context=False,
877940
expect_trace_id=False, # Not every chromium event will have a trace_id
878941
)
942+
return event
879943

880-
@staticmethod
881944
def log_instant_event(
945+
self,
882946
event_name: str,
883947
time_ns: int,
884948
metadata: Optional[Dict[str, Any]] = None,
@@ -895,7 +959,10 @@ def log_instant_event(
895959
"ts": time_ns / 1000,
896960
"args": metadata,
897961
"ph": "i",
898-
"pid": 0, # pid should be specified on all logs, we don't personally care about the actual process id
962+
# These categories are needed in all chromium traces
963+
"cat": "dynamo_timed",
964+
"tid": 0,
965+
"pid": 0,
899966
"s": "p", # We use "process" level instant events so they all appear on the same row in the trace.
900967
}
901968
torch._logging.trace_structured(
@@ -904,6 +971,18 @@ def log_instant_event(
904971
suppress_context=False,
905972
expect_trace_id=True,
906973
)
974+
# Log an instant event with the same start and end time
975+
log_chromium_event_internal(event, self.stack, self.id_)
976+
977+
978+
chromium_event_log = None
979+
980+
981+
def get_chromium_event_logger() -> ChromiumEventLogger:
    """Return the module-level ChromiumEventLogger, creating it on first use.

    The instance is stored in the module global ``chromium_event_log``; once it
    has been created, every later call returns that same object.
    """
    global chromium_event_log
    existing = chromium_event_log
    if existing is not None:
        return existing
    # First call: build the logger and remember it for subsequent callers.
    created = ChromiumEventLogger()
    chromium_event_log = created
    return created
907986

908987

909988
@dataclasses.dataclass

0 commit comments

Comments
 (0)