26
26
import time
27
27
import types
28
28
import typing
29
+ import uuid
29
30
import warnings
30
31
import weakref
31
32
from contextlib import contextmanager
64
65
from torch ._dispatch .python import enable_python_dispatcher
65
66
from torch ._guards import TracingContext
66
67
from torch ._subclasses .meta_utils import is_sparse_compressed
67
- from torch ._utils_internal import log_compilation_event
68
+ from torch ._utils_internal import log_chromium_event_internal , log_compilation_event
68
69
from torch .fx ._utils import _format_graph_code , lazy_format_graph_code
69
70
from torch .nn .modules .lazy import LazyModuleMixin
70
71
from torch .utils ._triton import has_triton , has_triton_package
@@ -212,6 +213,16 @@ def _add_time_spent(key: str, phase_name: str, time_spent: float) -> None:
212
213
frame_phase_timing [key ][phase_name ] += time_spent
213
214
214
215
216
def get_cache_stats() -> Dict[str, Any]:
    """Snapshot the inductor FX graph cache counters.

    Returns:
        A dict holding the current "fxgraph_cache_hit", "fxgraph_cache_miss"
        and "fxgraph_cache_bypass" values read from ``counters["inductor"]``,
        intended as metadata for chromium events.
    """
    inductor_counters = counters["inductor"]
    stat_names = (
        "fxgraph_cache_hit",
        "fxgraph_cache_miss",
        "fxgraph_cache_bypass",
    )
    # Keys are emitted in the same order as the tuple above.
    return {name: inductor_counters[name] for name in stat_names}
224
+
225
+
215
226
# dynamo_timed is a context manager
216
227
# By wrapping a function in dynamo_timed, we can store a record in compilation_time_metrics
217
228
# where the key is the function's name.
@@ -245,6 +256,7 @@ def dynamo_timed(
245
256
phase_name : Optional [str ] = None ,
246
257
fwd_only : bool = True ,
247
258
):
259
+ chromium_log : ChromiumEventLogger = get_chromium_event_logger ()
248
260
if key not in compilation_time_metrics :
249
261
compilation_time_metrics [key ] = []
250
262
@@ -254,13 +266,22 @@ def dynamo_timed(
254
266
try :
255
267
with torch .profiler .record_function (f"{ key } (dynamo_timed)" ):
256
268
t0 = time .time ()
257
- ChromiumEventLogger .log_event_start (key , time .time_ns ())
269
+ start = time .time_ns ()
270
+ chromium_log .log_event_start (key , start , None )
258
271
if phase_name :
259
- ChromiumEventLogger .log_event_start (phase_name , time . time_ns () )
272
+ chromium_log .log_event_start (phase_name , start )
260
273
yield
274
+
261
275
if phase_name :
262
- ChromiumEventLogger .log_event_end (phase_name , time .time_ns ())
263
- ChromiumEventLogger .log_event_end (key , time .time_ns ())
276
+ chromium_log .log_event_end (
277
+ phase_name ,
278
+ time .time_ns (),
279
+ {"cache_stats" : get_cache_stats ()},
280
+ start ,
281
+ )
282
+ chromium_log .log_event_end (
283
+ key , time .time_ns (), {"cache_stats" : get_cache_stats ()}, start
284
+ )
264
285
time_spent = time .time () - t0
265
286
compilation_time_metrics [key ].append (time_spent )
266
287
except Exception as e :
@@ -814,8 +835,17 @@ class ChromiumEventLogger:
814
835
a specification of the Chromium Event JSON format.
815
836
"""
816
837
817
- @staticmethod
838
def __init__(self):
    """Set up a fresh event stack and a unique id for this logger."""
    # Generate a unique id for this logger, which we can use in scuba to
    # filter down to a single python run.
    self.id_ = str(uuid.uuid4())
    # The stack of open event names always begins with the "__start__" entry.
    self.stack = ["__start__"]

    # TODO: log to init/id tlparse after I add support for it
    log.info("ChromiumEventLogger initialized with id %s", self.id_)
846
+
818
847
def log_event_start (
848
+ self ,
819
849
event_name : str ,
820
850
time_ns : int ,
821
851
metadata : Optional [Dict [str , Any ]] = None ,
@@ -826,18 +856,27 @@ def log_event_start(
826
856
:param time_ns Timestamp in nanoseconds
827
857
:param metadata: Any extra metadata associated with this event
828
858
"""
829
- ChromiumEventLogger ._log_timed_event (
859
+ event = self ._log_timed_event (
830
860
event_name ,
831
861
time_ns ,
832
862
"B" ,
833
863
metadata ,
834
864
)
865
+ log_chromium_event_internal (event , self .stack , self .id_ )
866
+ self .stack .append (event_name )
867
+
868
+ def reset (self ) -> None :
869
+ # We this on every compile in case a compile crashes or restarts and we haven't
870
+ # cleared the stack.
871
+ self .stack .clear ()
872
+ self .stack .append ("__start__" )
835
873
836
- @staticmethod
837
874
def log_event_end (
875
+ self ,
838
876
event_name : str ,
839
877
time_ns : int ,
840
878
metadata : Optional [Dict [str , Any ]] = None ,
879
+ start_time_ns : Optional [int ] = None ,
841
880
) -> None :
842
881
"""
843
882
Logs the end of a single event. This function should only be
@@ -846,28 +885,53 @@ def log_event_end(
846
885
:param time_ns: Timestamp in nanoseconds
847
886
:param metadata: Any extra metadata associated with this event
848
887
"""
849
- ChromiumEventLogger ._log_timed_event (
888
+ # These stack health checks currently never happen,
889
+ # but they're written this way to future proof any weird event
890
+ # overlaps in the future.
891
+ if event_name not in self .stack :
892
+ # Something went wrong, we never called start on this event,
893
+ # or it was skipped due to overlapping events below
894
+ log .warning ("ChromiumEventLogger: Start event not in stack, ignoring" )
895
+ return
896
+
897
+ event = self ._log_timed_event (
850
898
event_name ,
851
899
time_ns ,
852
900
"E" ,
853
901
metadata ,
854
902
)
855
903
856
- @staticmethod
904
+ while event_name != self .stack [- 1 ]:
905
+ # If the event isn't the most recent one to end, pop
906
+ # off the stack until it is.
907
+ # Since event_name in self.stack, this pop is always safe
908
+ log .warning (
909
+ "ChromiumEventLogger: Detected overlapping events, fixing stack"
910
+ )
911
+ self .stack .pop ()
912
+
913
+ log_chromium_event_internal (event , self .stack , self .id_ , start_time_ns )
914
+ # Finally pop the actual event off the stack
915
+ self .stack .pop ()
916
+
857
917
def _log_timed_event (
918
+ self ,
858
919
event_name : str ,
859
920
time_ns : int ,
860
921
phase : str ,
861
922
metadata : Optional [Dict [str , Any ]] = None ,
862
- ) -> None :
923
+ ) -> Dict [ str , Any ] :
863
924
"""
864
925
Logs a timed event in chromium format. See log_event_start, log_event_end, etc.
865
926
"""
866
927
event = {
867
928
"name" : event_name ,
868
- "ts" : time_ns / 1000 , # Chromium events are in ms
929
+ "ts" : time_ns / 1000 , # Chromium events are in micro seconds
869
930
"args" : metadata ,
870
931
"ph" : phase ,
932
+ # These categories are needed in all chromium traces
933
+ "cat" : "dynamo_timed" ,
934
+ "tid" : 0 ,
871
935
"pid" : 0 , # pid should be specified on all logs, we don't personally care about the actual process id
872
936
}
873
937
torch ._logging .trace_structured (
@@ -876,9 +940,10 @@ def _log_timed_event(
876
940
suppress_context = False ,
877
941
expect_trace_id = False , # Not every chromium event will have a trace_id
878
942
)
943
+ return event
879
944
880
- @staticmethod
881
945
def log_instant_event (
946
+ self ,
882
947
event_name : str ,
883
948
time_ns : int ,
884
949
metadata : Optional [Dict [str , Any ]] = None ,
@@ -895,7 +960,10 @@ def log_instant_event(
895
960
"ts" : time_ns / 1000 ,
896
961
"args" : metadata ,
897
962
"ph" : "i" ,
898
- "pid" : 0 , # pid should be specified on all logs, we don't personally care about the actual process id
963
+ # These categories are needed in all chromium traces
964
+ "cat" : "dynamo_timed" ,
965
+ "tid" : 0 ,
966
+ "pid" : 0 ,
899
967
"s" : "p" , # We use "process" level instant events so they all appear on the same row in the trace.
900
968
}
901
969
torch ._logging .trace_structured (
@@ -904,6 +972,18 @@ def log_instant_event(
904
972
suppress_context = False ,
905
973
expect_trace_id = True ,
906
974
)
975
+ # Log an instant event with the same start and end time
976
+ log_chromium_event_internal (event , self .stack , self .id_ )
977
+
978
+
979
+ CHROMIUM_EVENT_LOG : Optional [ChromiumEventLogger ] = None
980
+
981
+
982
def get_chromium_event_logger() -> ChromiumEventLogger:
    """Return the module-wide ChromiumEventLogger, creating it on first use."""
    global CHROMIUM_EVENT_LOG
    if CHROMIUM_EVENT_LOG is not None:
        return CHROMIUM_EVENT_LOG
    # First call: build the logger and cache it in the module-level global.
    CHROMIUM_EVENT_LOG = ChromiumEventLogger()
    return CHROMIUM_EVENT_LOG
907
987
908
988
909
989
@dataclasses .dataclass
0 commit comments