26
26
import time
27
27
import types
28
28
import typing
29
+ import uuid
29
30
import warnings
30
31
import weakref
31
32
from contextlib import contextmanager
64
65
from torch ._dispatch .python import enable_python_dispatcher
65
66
from torch ._guards import TracingContext
66
67
from torch ._subclasses .meta_utils import is_sparse_compressed
67
- from torch ._utils_internal import log_compilation_event
68
+ from torch ._utils_internal import log_chromium_event_internal , log_compilation_event
68
69
from torch .fx ._utils import _format_graph_code , lazy_format_graph_code
69
70
from torch .nn .modules .lazy import LazyModuleMixin
70
71
from torch .utils ._triton import has_triton , has_triton_package
@@ -212,6 +213,16 @@ def _add_time_spent(key: str, phase_name: str, time_spent: float) -> None:
212
213
frame_phase_timing [key ][phase_name ] += time_spent
213
214
214
215
216
+ def get_cache_stats () -> Dict [str , Any ]:
217
+ """Get a bunch of metadata about cache hits and misses to use in chromium events"""
218
+ cache_stats = {
219
+ "fxgraph_cache_hit" : counters ["inductor" ]["fxgraph_cache_hit" ],
220
+ "fxgraph_cache_miss" : counters ["inductor" ]["fxgraph_cache_miss" ],
221
+ "fxgraph_cache_bypass" : counters ["inductor" ]["fxgraph_cache_bypass" ],
222
+ }
223
+ return cache_stats
224
+
225
+
215
226
# dynamo_timed is a context manager
216
227
# By wrapping a function in dynamo_timed, we can store a record in compilation_time_metrics
217
228
# where the key is the functions name.
@@ -245,22 +256,34 @@ def dynamo_timed(
245
256
phase_name : Optional [str ] = None ,
246
257
fwd_only : bool = True ,
247
258
):
259
+ chromium_log : ChromiumEventLogger = get_chromium_event_logger ()
248
260
if key not in compilation_time_metrics :
249
261
compilation_time_metrics [key ] = []
250
262
251
263
fail_type : Optional [str ] = None
252
264
fail_reason : Optional [str ] = None
253
265
time_spent = float ("-inf" )
266
+ if phase_name == "entire_frame_compile" :
267
+ chromium_log .reset ()
254
268
try :
255
269
with torch .profiler .record_function (f"{ key } (dynamo_timed)" ):
256
270
t0 = time .time ()
257
- ChromiumEventLogger .log_event_start (key , time .time_ns ())
271
+ start = time .time_ns ()
272
+ chromium_log .log_event_start (key , start , None )
258
273
if phase_name :
259
- ChromiumEventLogger .log_event_start (phase_name , time . time_ns () )
274
+ chromium_log .log_event_start (phase_name , start )
260
275
yield
276
+
261
277
if phase_name :
262
- ChromiumEventLogger .log_event_end (phase_name , time .time_ns ())
263
- ChromiumEventLogger .log_event_end (key , time .time_ns ())
278
+ chromium_log .log_event_end (
279
+ phase_name ,
280
+ time .time_ns (),
281
+ {"cache_stats" : get_cache_stats ()},
282
+ start ,
283
+ )
284
+ chromium_log .log_event_end (
285
+ key , time .time_ns (), {"cache_stats" : get_cache_stats ()}, start
286
+ )
264
287
time_spent = time .time () - t0
265
288
compilation_time_metrics [key ].append (time_spent )
266
289
except Exception as e :
@@ -814,8 +837,17 @@ class ChromiumEventLogger:
814
837
a specification of the Chromium Event JSON format.
815
838
"""
816
839
817
- @staticmethod
840
+ def __init__ (self ):
841
+ self .stack = ["__start__" ]
842
+ # Generate a unique id for this logger, which we can use in scuba to filter down
843
+ # to a single python run.
844
+ self .id_ = str (uuid .uuid4 ())
845
+
846
+ # TODO: log to init/id tlparse after I add support for it
847
+ log .info ("ChromiumEventLogger initialized with id %s" , self .id_ )
848
+
818
849
def log_event_start (
850
+ self ,
819
851
event_name : str ,
820
852
time_ns : int ,
821
853
metadata : Optional [Dict [str , Any ]] = None ,
@@ -826,18 +858,24 @@ def log_event_start(
826
858
:param time_ns Timestamp in nanoseconds
827
859
:param metadata: Any extra metadata associated with this event
828
860
"""
829
- ChromiumEventLogger ._log_timed_event (
861
+ event = self ._log_timed_event (
830
862
event_name ,
831
863
time_ns ,
832
864
"B" ,
833
865
metadata ,
834
866
)
867
+ log_chromium_event_internal (event , self .stack , self .id_ )
868
+ self .stack .append (event_name )
869
+
870
+ def reset (self ) -> None :
871
+ self .stack = ["__start__" ]
835
872
836
- @staticmethod
837
873
def log_event_end (
874
+ self ,
838
875
event_name : str ,
839
876
time_ns : int ,
840
877
metadata : Optional [Dict [str , Any ]] = None ,
878
+ start_time_ns : Optional [int ] = None ,
841
879
) -> None :
842
880
"""
843
881
Logs the end of a single event. This function should only be
@@ -846,28 +884,53 @@ def log_event_end(
846
884
:param time_ns: Timestamp in nanoseconds
847
885
:param metadata: Any extra metadata associated with this event
848
886
"""
849
- ChromiumEventLogger ._log_timed_event (
887
+ # These stack health checks currently never happen,
888
+ # but they're written this way to future proof any weird event
889
+ # overlaps in the future.
890
+ if event_name not in self .stack :
891
+ # Something went wrong, we never called start on this event,
892
+ # or it was skipped due to overlapping events below
893
+ log .warning ("ChromiumEventLogger: Start event not in stack, ignoring" )
894
+ return
895
+
896
+ event = self ._log_timed_event (
850
897
event_name ,
851
898
time_ns ,
852
899
"E" ,
853
900
metadata ,
854
901
)
855
902
856
- @staticmethod
903
+ while event_name != self .stack [- 1 ]:
904
+ # If the event isn't the most recent one to end, pop
905
+ # off the stack until it is.
906
+ # Since event_name in self.stack, this pop is always safe
907
+ log .warning (
908
+ "ChromiumEventLogger: Detected overlapping events, fixing stack"
909
+ )
910
+ self .stack .pop ()
911
+
912
+ log_chromium_event_internal (event , self .stack , self .id_ , start_time_ns )
913
+ # Finally pop the actual event off the stack
914
+ self .stack .pop ()
915
+
857
916
def _log_timed_event (
917
+ self ,
858
918
event_name : str ,
859
919
time_ns : int ,
860
920
phase : str ,
861
921
metadata : Optional [Dict [str , Any ]] = None ,
862
- ) -> None :
922
+ ) -> Dict [ str , Any ] :
863
923
"""
864
924
Logs a timed event in chromium format. See log_event_start, log_event_end, etc.
865
925
"""
866
926
event = {
867
927
"name" : event_name ,
868
- "ts" : time_ns / 1000 , # Chromium events are in ms
928
+ "ts" : time_ns / 1000 , # Chromium events are in micro seconds
869
929
"args" : metadata ,
870
930
"ph" : phase ,
931
+ # These categories are needed in all chromium traces
932
+ "cat" : "dynamo_timed" ,
933
+ "tid" : 0 ,
871
934
"pid" : 0 , # pid should be specified on all logs, we don't personally care about the actual process id
872
935
}
873
936
torch ._logging .trace_structured (
@@ -876,9 +939,10 @@ def _log_timed_event(
876
939
suppress_context = False ,
877
940
expect_trace_id = False , # Not every chromium event will have a trace_id
878
941
)
942
+ return event
879
943
880
- @staticmethod
881
944
def log_instant_event (
945
+ self ,
882
946
event_name : str ,
883
947
time_ns : int ,
884
948
metadata : Optional [Dict [str , Any ]] = None ,
@@ -895,7 +959,10 @@ def log_instant_event(
895
959
"ts" : time_ns / 1000 ,
896
960
"args" : metadata ,
897
961
"ph" : "i" ,
898
- "pid" : 0 , # pid should be specified on all logs, we don't personally care about the actual process id
962
+ # These categories are needed in all chromium traces
963
+ "cat" : "dynamo_timed" ,
964
+ "tid" : 0 ,
965
+ "pid" : 0 ,
899
966
"s" : "p" , # We use "process" level instant events so they all appear on the same row in the trace.
900
967
}
901
968
torch ._logging .trace_structured (
@@ -904,6 +971,18 @@ def log_instant_event(
904
971
suppress_context = False ,
905
972
expect_trace_id = True ,
906
973
)
974
+ # Log an instant event with the same start and end time
975
+ log_chromium_event_internal (event , self .stack , self .id_ )
976
+
977
+
978
+ chromium_event_log = None
979
+
980
+
981
+ def get_chromium_event_logger () -> ChromiumEventLogger :
982
+ global chromium_event_log
983
+ if chromium_event_log is None :
984
+ chromium_event_log = ChromiumEventLogger ()
985
+ return chromium_event_log
907
986
908
987
909
988
@dataclasses .dataclass
0 commit comments