26
26
import time
27
27
import types
28
28
import typing
29
+ import uuid
29
30
import warnings
30
31
import weakref
31
32
from contextlib import contextmanager
64
65
from torch ._dispatch .python import enable_python_dispatcher
65
66
from torch ._guards import TracingContext
66
67
from torch ._subclasses .meta_utils import is_sparse_compressed
67
- from torch ._utils_internal import log_compilation_event
68
+ from torch ._utils_internal import log_chromium_event_internal , log_compilation_event
68
69
from torch .fx ._utils import _format_graph_code , lazy_format_graph_code
69
70
from torch .nn .modules .lazy import LazyModuleMixin
70
71
from torch .utils ._triton import has_triton , has_triton_package
@@ -212,6 +213,16 @@ def _add_time_spent(key: str, phase_name: str, time_spent: float) -> None:
212
213
frame_phase_timing [key ][phase_name ] += time_spent
213
214
214
215
216
+ def get_cache_stats () -> Dict [str , Any ]:
217
+ """Get a bunch of metadata about cache hits and misses to use in chromium events"""
218
+ cache_stats = {
219
+ "fxgraph_cache_hit" : counters ["inductor" ]["fxgraph_cache_hit" ],
220
+ "fxgraph_cache_miss" : counters ["inductor" ]["fxgraph_cache_miss" ],
221
+ "fxgraph_cache_bypass" : counters ["inductor" ]["fxgraph_cache_bypass" ],
222
+ }
223
+ return cache_stats
224
+
225
+
215
226
# dynamo_timed is a context manager
216
227
# By wrapping a function in dynamo_timed, we can store a record in compilation_time_metrics
217
228
# where the key is the functions name.
@@ -245,6 +256,7 @@ def dynamo_timed(
245
256
phase_name : Optional [str ] = None ,
246
257
fwd_only : bool = True ,
247
258
):
259
+ chromium_log : ChromiumEventLogger = get_chromium_event_logger ()
248
260
if key not in compilation_time_metrics :
249
261
compilation_time_metrics [key ] = []
250
262
@@ -254,13 +266,22 @@ def dynamo_timed(
254
266
try :
255
267
with torch .profiler .record_function (f"{ key } (dynamo_timed)" ):
256
268
t0 = time .time ()
257
- ChromiumEventLogger .log_event_start (key , time .time_ns ())
269
+ start = time .time_ns ()
270
+ chromium_log .log_event_start (key , start , None )
258
271
if phase_name :
259
- ChromiumEventLogger .log_event_start (phase_name , time . time_ns () )
272
+ chromium_log .log_event_start (phase_name , start )
260
273
yield
274
+
261
275
if phase_name :
262
- ChromiumEventLogger .log_event_end (phase_name , time .time_ns ())
263
- ChromiumEventLogger .log_event_end (key , time .time_ns ())
276
+ chromium_log .log_event_end (
277
+ phase_name ,
278
+ time .time_ns (),
279
+ {"cache_stats" : get_cache_stats ()},
280
+ start ,
281
+ )
282
+ chromium_log .log_event_end (
283
+ key , time .time_ns (), {"cache_stats" : get_cache_stats ()}, start
284
+ )
264
285
time_spent = time .time () - t0
265
286
compilation_time_metrics [key ].append (time_spent )
266
287
except Exception as e :
@@ -814,8 +835,17 @@ class ChromiumEventLogger:
814
835
a specification of the Chromium Event JSON format.
815
836
"""
816
837
817
- @staticmethod
838
+ def __init__ (self ):
839
+ self .stack = ["__start__" ]
840
+ # Generate a unique id for this logger, which we can use in scuba to filter down
841
+ # to a single python run.
842
+ self .id_ = str (uuid .uuid4 ())
843
+
844
+ # TODO: log to init/id tlparse after I add support for it
845
+ log .info ("ChromiumEventLogger initialized with id %s" , self .id_ )
846
+
818
847
def log_event_start (
848
+ self ,
819
849
event_name : str ,
820
850
time_ns : int ,
821
851
metadata : Optional [Dict [str , Any ]] = None ,
@@ -826,18 +856,24 @@ def log_event_start(
826
856
:param time_ns Timestamp in nanoseconds
827
857
:param metadata: Any extra metadata associated with this event
828
858
"""
829
- ChromiumEventLogger ._log_timed_event (
859
+ event = self ._log_timed_event (
830
860
event_name ,
831
861
time_ns ,
832
862
"B" ,
833
863
metadata ,
834
864
)
865
+ log_chromium_event_internal (event , self .stack , self .id_ )
866
+ self .stack .append (event_name )
867
+
868
+ def reset (self ) -> None :
869
+ self .stack = ["__start__" ]
835
870
836
- @staticmethod
837
871
def log_event_end (
872
+ self ,
838
873
event_name : str ,
839
874
time_ns : int ,
840
875
metadata : Optional [Dict [str , Any ]] = None ,
876
+ start_time_ns : Optional [int ] = None ,
841
877
) -> None :
842
878
"""
843
879
Logs the end of a single event. This function should only be
@@ -846,28 +882,53 @@ def log_event_end(
846
882
:param time_ns: Timestamp in nanoseconds
847
883
:param metadata: Any extra metadata associated with this event
848
884
"""
849
- ChromiumEventLogger ._log_timed_event (
885
+ # These stack health checks currently never happen,
886
+ # but they're written this way to future proof any weird event
887
+ # overlaps in the future.
888
+ if event_name not in self .stack :
889
+ # Something went wrong, we never called start on this event,
890
+ # or it was skipped due to overlapping events below
891
+ log .warning ("ChromiumEventLogger: Start event not in stack, ignoring" )
892
+ return
893
+
894
+ event = self ._log_timed_event (
850
895
event_name ,
851
896
time_ns ,
852
897
"E" ,
853
898
metadata ,
854
899
)
855
900
856
- @staticmethod
901
+ while event_name != self .stack [- 1 ]:
902
+ # If the event isn't the most recent one to end, pop
903
+ # off the stack until it is.
904
+ # Since event_name in self.stack, this pop is always safe
905
+ log .warning (
906
+ "ChromiumEventLogger: Detected overlapping events, fixing stack"
907
+ )
908
+ self .stack .pop ()
909
+
910
+ log_chromium_event_internal (event , self .stack , self .id_ , start_time_ns )
911
+ # Finally pop the actual event off the stack
912
+ self .stack .pop ()
913
+
857
914
def _log_timed_event (
915
+ self ,
858
916
event_name : str ,
859
917
time_ns : int ,
860
918
phase : str ,
861
919
metadata : Optional [Dict [str , Any ]] = None ,
862
- ) -> None :
920
+ ) -> Dict [ str , Any ] :
863
921
"""
864
922
Logs a timed event in chromium format. See log_event_start, log_event_end, etc.
865
923
"""
866
924
event = {
867
925
"name" : event_name ,
868
- "ts" : time_ns / 1000 , # Chromium events are in ms
926
+ "ts" : time_ns / 1000 , # Chromium events are in micro seconds
869
927
"args" : metadata ,
870
928
"ph" : phase ,
929
+ # These categories are needed in all chromium traces
930
+ "cat" : "dynamo_timed" ,
931
+ "tid" : 0 ,
871
932
"pid" : 0 , # pid should be specified on all logs, we don't personally care about the actual process id
872
933
}
873
934
torch ._logging .trace_structured (
@@ -876,9 +937,10 @@ def _log_timed_event(
876
937
suppress_context = False ,
877
938
expect_trace_id = False , # Not every chromium event will have a trace_id
878
939
)
940
+ return event
879
941
880
- @staticmethod
881
942
def log_instant_event (
943
+ self ,
882
944
event_name : str ,
883
945
time_ns : int ,
884
946
metadata : Optional [Dict [str , Any ]] = None ,
@@ -895,7 +957,10 @@ def log_instant_event(
895
957
"ts" : time_ns / 1000 ,
896
958
"args" : metadata ,
897
959
"ph" : "i" ,
898
- "pid" : 0 , # pid should be specified on all logs, we don't personally care about the actual process id
960
+ # These categories are needed in all chromium traces
961
+ "cat" : "dynamo_timed" ,
962
+ "tid" : 0 ,
963
+ "pid" : 0 ,
899
964
"s" : "p" , # We use "process" level instant events so they all appear on the same row in the trace.
900
965
}
901
966
torch ._logging .trace_structured (
@@ -904,6 +969,16 @@ def log_instant_event(
904
969
suppress_context = False ,
905
970
expect_trace_id = True ,
906
971
)
972
+ # Log an instant event with the same start and end time
973
+ log_chromium_event_internal (event , self .stack , self .id_ )
974
+
975
+
976
+ CHROMIUM_EVENT_LOG : Optional [ChromiumEventLogger ] = None
977
+ def get_chromium_event_logger () -> ChromiumEventLogger :
978
+ global CHROMIUM_EVENT_LOG
979
+ if CHROMIUM_EVENT_LOG is None :
980
+ CHROMIUM_EVENT_LOG = ChromiumEventLogger ()
981
+ return CHROMIUM_EVENT_LOG
907
982
908
983
909
984
@dataclasses .dataclass
0 commit comments