@@ -269,32 +269,32 @@ def dynamo_timed(
269
269
fail_type : Optional [str ] = None
270
270
fail_reason : Optional [str ] = None
271
271
time_spent = float ("-inf" )
272
+ start = time .time_ns ()
272
273
try :
273
274
with torch .profiler .record_function (f"{ key } (dynamo_timed)" ):
274
275
t0 = time .time ()
275
- start = time .time_ns ()
276
276
chromium_log .log_event_start (key , start , None )
277
277
if phase_name :
278
278
chromium_log .log_event_start (phase_name , start )
279
279
yield
280
-
281
- if phase_name :
282
- chromium_log .log_event_end (
283
- phase_name ,
284
- time .time_ns (),
285
- {"cache_stats" : get_cache_stats ()},
286
- start ,
287
- )
288
- chromium_log .log_event_end (
289
- key , time .time_ns (), {"cache_stats" : get_cache_stats ()}, start
290
- )
291
280
time_spent = time .time () - t0
292
281
compilation_time_metrics [key ].append (time_spent )
293
282
except Exception as e :
294
283
fail_type = str (type (e ))
295
284
fail_reason = str (e )
296
285
raise
297
286
finally :
287
+ # Always log the end event even on exception
288
+ if phase_name :
289
+ chromium_log .log_event_end (
290
+ phase_name ,
291
+ time .time_ns (),
292
+ {"cache_stats" : get_cache_stats ()},
293
+ start ,
294
+ )
295
+ chromium_log .log_event_end (
296
+ key , time .time_ns (), {"cache_stats" : get_cache_stats ()}, start
297
+ )
298
298
# Only record backward compilation metrics if phase_name is not None!
299
299
if phase_name :
300
300
frame_key = str (curr_frame )
@@ -841,8 +841,15 @@ class ChromiumEventLogger:
841
841
a specification of the Chromium Event JSON format.
842
842
"""
843
843
844
def get_stack(self):
    """Return the event-name stack stored on ``self.tls``.

    If ``self.tls`` has no ``stack`` attribute yet, one is created holding
    the single sentinel entry ``"__start__"``. The same list object is
    returned on every later call, so callers may mutate it in place.
    """
    local_state = self.tls
    # Lazily seed the stack the first time it is requested.
    if not hasattr(local_state, "stack"):
        local_state.stack = ["__start__"]
    return local_state.stack
850
+
844
851
def __init__ (self ):
845
- self .stack = [ "__start__" ]
852
+ self .tls = threading . local ()
846
853
# Generate a unique id for this logger, which we can use in scuba to filter down
847
854
# to a single python run.
848
855
self .id_ = str (uuid .uuid4 ())
@@ -868,14 +875,15 @@ def log_event_start(
868
875
"B" ,
869
876
metadata ,
870
877
)
871
- log_chromium_event_internal (event , self .stack , self .id_ )
872
- self .stack .append (event_name )
878
+ log_chromium_event_internal (event , self .get_stack () , self .id_ )
879
+ self .get_stack () .append (event_name )
873
880
874
881
def reset(self) -> None:
    """Restore the current event stack to its initial ``["__start__"]`` state.

    This is done on every compile in case a previous compile crashed or
    restarted without the stack having been cleared. The list returned by
    ``get_stack()`` is mutated in place rather than replaced.
    """
    events = self.get_stack()
    # Slice assignment empties and re-seeds the same list object.
    events[:] = ["__start__"]
879
887
880
888
def log_event_end (
881
889
self ,
@@ -894,7 +902,8 @@ def log_event_end(
894
902
# These stack health checks currently never happen,
895
903
# but they're written this way to future proof any weird event
896
904
# overlaps in the future.
897
- if event_name not in self .stack :
905
+ stack = self .get_stack ()
906
+ if event_name not in stack :
898
907
# Something went wrong, we never called start on this event,
899
908
# or it was skipped due to overlapping events below
900
909
log .warning ("ChromiumEventLogger: Start event not in stack, ignoring" )
@@ -907,18 +916,18 @@ def log_event_end(
907
916
metadata ,
908
917
)
909
918
910
- while event_name != self . stack [- 1 ]:
919
+ while event_name != stack [- 1 ]:
911
920
# If the event isn't the most recent one to end, pop
912
921
# off the stack until it is.
913
922
# Since event_name is in the stack, this pop is always safe
914
923
log .warning (
915
924
"ChromiumEventLogger: Detected overlapping events, fixing stack"
916
925
)
917
- self . stack .pop ()
926
+ stack .pop ()
918
927
919
- log_chromium_event_internal (event , self . stack , self .id_ , start_time_ns )
928
+ log_chromium_event_internal (event , stack , self .id_ , start_time_ns )
920
929
# Finally pop the actual event off the stack
921
- self . stack .pop ()
930
+ stack .pop ()
922
931
923
932
def _log_timed_event (
924
933
self ,
@@ -979,7 +988,7 @@ def log_instant_event(
979
988
expect_trace_id = True ,
980
989
)
981
990
# Log an instant event with the same start and end time
982
- log_chromium_event_internal (event , self .stack , self .id_ )
991
+ log_chromium_event_internal (event , self .get_stack () , self .id_ )
983
992
984
993
985
994
# Module-level holder for a ChromiumEventLogger instance; starts out as None.
# NOTE(review): the code that assigns it is not visible in this section.
CHROMIUM_EVENT_LOG : Optional [ChromiumEventLogger ] = None
0 commit comments