11
11
import enum
12
12
import functools
13
13
import gc
14
+ import uuid
14
15
import importlib
15
16
import inspect
16
17
import itertools
64
65
from torch ._dispatch .python import enable_python_dispatcher
65
66
from torch ._guards import TracingContext
66
67
from torch ._subclasses .meta_utils import is_sparse_compressed
67
- from torch ._utils_internal import log_compilation_event
68
+ from torch ._utils_internal import log_compilation_event , log_chromium_event_internal
68
69
from torch .fx ._utils import _format_graph_code , lazy_format_graph_code
69
70
from torch .nn .modules .lazy import LazyModuleMixin
70
71
from torch .utils ._triton import has_triton , has_triton_package
@@ -212,6 +213,15 @@ def _add_time_spent(key: str, phase_name: str, time_spent: float) -> None:
212
213
frame_phase_timing [key ][phase_name ] += time_spent
213
214
214
215
216
def get_cache_stats() -> Dict[str, Any]:
    """Collect the inductor FX graph cache counters for use in chromium events.

    Returns:
        A dict with the keys ``fxgraph_cache_hit``, ``fxgraph_cache_miss`` and
        ``fxgraph_cache_bypass``, each mapped to the value currently stored
        under that same key in ``counters["inductor"]``.
    """
    inductor_counters = counters["inductor"]
    stat_names = (
        "fxgraph_cache_hit",
        "fxgraph_cache_miss",
        "fxgraph_cache_bypass",
    )
    # Key order matches the tuple above, so the resulting dict is laid out
    # hit / miss / bypass.
    return {stat: inductor_counters[stat] for stat in stat_names}
224
+
215
225
# dynamo_timed is a context manager
216
226
# By wrapping a function in dynamo_timed, we can store a record in compilation_time_metrics
217
227
# where the key is the function's name.
@@ -251,16 +261,20 @@ def dynamo_timed(
251
261
fail_type : Optional [str ] = None
252
262
fail_reason : Optional [str ] = None
253
263
time_spent = float ("-inf" )
264
+ if phase_name == "entire_frame_compile" :
265
+ reset_chromium_events ()
254
266
try :
255
267
with torch .profiler .record_function (f"{ key } (dynamo_timed)" ):
256
268
t0 = time .time ()
257
- ChromiumEventLogger .log_event_start (key , time .time_ns ())
269
+ start = time .time_ns ()
270
+ ChromiumEventLogger .log_event_start (key , start , None )
258
271
if phase_name :
259
- ChromiumEventLogger .log_event_start (phase_name , time . time_ns () )
272
+ ChromiumEventLogger .log_event_start (phase_name , start )
260
273
yield
274
+
261
275
if phase_name :
262
- ChromiumEventLogger .log_event_end (phase_name , time .time_ns ())
263
- ChromiumEventLogger .log_event_end (key , time .time_ns ())
276
+ ChromiumEventLogger .log_event_end (phase_name , time .time_ns (), { "cache_stats" : get_cache_stats ()}, start )
277
+ ChromiumEventLogger .log_event_end (key , time .time_ns (), { "cache_stats" : get_cache_stats ()}, start )
264
278
time_spent = time .time () - t0
265
279
compilation_time_metrics [key ].append (time_spent )
266
280
except Exception as e :
@@ -807,6 +821,18 @@ def get_compilation_metrics() -> List[Union[CompilationMetrics, BwdCompilationMe
807
821
return list (_compilation_metrics )
808
822
809
823
824
# Stack of chromium event names; it starts out holding only the "__start__"
# sentinel entry.
chromium_event_stack = ["__start__"]
# Generate a unique id for this process, which we can use in scuba to filter down
# to a single python run.
# TODO: decide when (if ever) this id should be regenerated.
compile_unique_id = str(uuid.uuid4())


def reset_chromium_events() -> None:
    """Reset ``chromium_event_stack`` to its initial ``["__start__"]`` contents.

    The list is reset in place instead of being rebound to a new object, so
    every reference to the same list (for example a name obtained through
    ``from <module> import chromium_event_stack``) observes the reset rather
    than keeping a stale copy of the old stack.
    """
    chromium_event_stack[:] = ["__start__"]
833
+
834
+
835
+
810
836
class ChromiumEventLogger :
811
837
"""Logs chromium events to structured logs. tlparse will concatenate these into a perfetto UI link.
812
838
@@ -826,18 +852,22 @@ def log_event_start(
826
852
:param time_ns Timestamp in nanoseconds
827
853
:param metadata: Any extra metadata associated with this event
828
854
"""
829
- ChromiumEventLogger ._log_timed_event (
855
+ global chromium_event_stack
856
+ event = ChromiumEventLogger ._log_timed_event (
830
857
event_name ,
831
858
time_ns ,
832
859
"B" ,
833
860
metadata ,
834
861
)
862
+ log_chromium_event_internal (event , chromium_event_stack , compile_unique_id )
863
+ chromium_event_stack .append (event_name )
835
864
836
865
@staticmethod
837
866
def log_event_end (
838
867
event_name : str ,
839
868
time_ns : int ,
840
869
metadata : Optional [Dict [str , Any ]] = None ,
870
+ start_time_ns : Optional [int ] = None ,
841
871
) -> None :
842
872
"""
843
873
Logs the end of a single event. This function should only be
@@ -846,28 +876,53 @@ def log_event_end(
846
876
:param time_ns: Timestamp in nanoseconds
847
877
:param metadata: Any extra metadata associated with this event
848
878
"""
849
- ChromiumEventLogger ._log_timed_event (
879
+ global chromium_event_stack
880
+ # These stack health checks currently never happen,
881
+ # but they're written this way to future proof any weird event
882
+ # overlaps in the future.
883
+ if (event_name not in chromium_event_stack ):
884
+ # Something went wrong, we never called start on this event,
885
+ # or it was skipped due to overlapping events below
886
+ log .warn ("Start event not in stack, ignoring" )
887
+ return
888
+
889
+ event = ChromiumEventLogger ._log_timed_event (
850
890
event_name ,
851
891
time_ns ,
852
892
"E" ,
853
893
metadata ,
854
894
)
855
895
896
+ while event_name != chromium_event_stack [- 1 ]:
897
+ # If the event isn't the most recent one to end, pop
898
+ # off the stack until it is.
899
+ # Since event_name in chromium_event_stack, this pop is always safe
900
+ log .warn ("Detected overlapping events, fixing stack" )
901
+ chromium_event_stack .pop ()
902
+
903
+ log_chromium_event_internal (event , chromium_event_stack , compile_unique_id , start_time_ns )
904
+ # Finally pop the actual event off the stack
905
+ chromium_event_stack .pop ()
906
+
907
+
856
908
@staticmethod
857
909
def _log_timed_event (
858
910
event_name : str ,
859
911
time_ns : int ,
860
912
phase : str ,
861
913
metadata : Optional [Dict [str , Any ]] = None ,
862
- ) -> None :
914
+ ) -> Dict [ str , Any ] :
863
915
"""
864
916
Logs a timed event in chromium format. See log_event_start, log_event_end, etc.
865
917
"""
866
918
event = {
867
919
"name" : event_name ,
868
- "ts" : time_ns / 1000 , # Chromium events are in ms
920
+ "ts" : time_ns / 1000 , # Chromium events are in micro seconds
869
921
"args" : metadata ,
870
922
"ph" : phase ,
923
+ # These categories are needed in all chromium traces
924
+ "cat" : "dynamo_timed" ,
925
+ "tid" : 0 ,
871
926
"pid" : 0 , # pid should be specified on all logs, we don't personally care about the actual process id
872
927
}
873
928
torch ._logging .trace_structured (
@@ -876,6 +931,7 @@ def _log_timed_event(
876
931
suppress_context = False ,
877
932
expect_trace_id = False , # Not every chromium event will have a trace_id
878
933
)
934
+ return event
879
935
880
936
@staticmethod
881
937
def log_instant_event (
@@ -895,7 +951,10 @@ def log_instant_event(
895
951
"ts" : time_ns / 1000 ,
896
952
"args" : metadata ,
897
953
"ph" : "i" ,
898
- "pid" : 0 , # pid should be specified on all logs, we don't personally care about the actual process id
954
+ # These categories are needed in all chromium traces
955
+ "cat" : "dynamo_timed" ,
956
+ "tid" : 0 ,
957
+ "pid" : 0 ,
899
958
"s" : "p" , # We use "process" level instant events so they all appear on the same row in the trace.
900
959
}
901
960
torch ._logging .trace_structured (
@@ -904,6 +963,9 @@ def log_instant_event(
904
963
suppress_context = False ,
905
964
expect_trace_id = True ,
906
965
)
966
+ # Log an instant event with the same start and end time
967
+ log_chromium_event_internal (event , chromium_event_stack , compile_unique_id )
968
+
907
969
908
970
909
971
@dataclasses .dataclass
0 commit comments