Skip to content

Commit f8fbfe5

Browse files
jamesjwu authored and pytorchmergebot committed
Always emit end events even on failure, use thread local storage for stack (pytorch#134279)
Summary: We should always emit an end event in a `finally` block so that the stack stays correct even if a unit test or job fails. We also keep the stack in thread-local storage, so that in multithreaded scenarios each thread maintains its own correct stack. Test Plan: Run the benchmark and see that everything still works. Then run ``` TORCH_LOGS=dynamo buck run test/functorch:test_aotdispatch -- -r test_backward_mutation_on_grad_out ``` with some extra logging to confirm that start events are emitted with the correct stack, and that the end events are also emitted even though the test fails at runtime. Differential Revision: D61682556 Pull Request resolved: pytorch#134279 Approved by: https://github.com/aorenste
1 parent a23d86c commit f8fbfe5

File tree

1 file changed

+32
-23
lines changed

1 file changed

+32
-23
lines changed

torch/_dynamo/utils.py

+32-23
Original file line numberDiff line numberDiff line change
@@ -269,32 +269,32 @@ def dynamo_timed(
269269
fail_type: Optional[str] = None
270270
fail_reason: Optional[str] = None
271271
time_spent = float("-inf")
272+
start = time.time_ns()
272273
try:
273274
with torch.profiler.record_function(f"{key} (dynamo_timed)"):
274275
t0 = time.time()
275-
start = time.time_ns()
276276
chromium_log.log_event_start(key, start, None)
277277
if phase_name:
278278
chromium_log.log_event_start(phase_name, start)
279279
yield
280-
281-
if phase_name:
282-
chromium_log.log_event_end(
283-
phase_name,
284-
time.time_ns(),
285-
{"cache_stats": get_cache_stats()},
286-
start,
287-
)
288-
chromium_log.log_event_end(
289-
key, time.time_ns(), {"cache_stats": get_cache_stats()}, start
290-
)
291280
time_spent = time.time() - t0
292281
compilation_time_metrics[key].append(time_spent)
293282
except Exception as e:
294283
fail_type = str(type(e))
295284
fail_reason = str(e)
296285
raise
297286
finally:
287+
# Always log the end event even on exception
288+
if phase_name:
289+
chromium_log.log_event_end(
290+
phase_name,
291+
time.time_ns(),
292+
{"cache_stats": get_cache_stats()},
293+
start,
294+
)
295+
chromium_log.log_event_end(
296+
key, time.time_ns(), {"cache_stats": get_cache_stats()}, start
297+
)
298298
# Only record backward compilation metrics if phase_name is not None!
299299
if phase_name:
300300
frame_key = str(curr_frame)
@@ -841,8 +841,15 @@ class ChromiumEventLogger:
841841
a specification of the Chromium Event JSON format.
842842
"""
843843

844+
def get_stack(self):
845+
if hasattr(self.tls, "stack"):
846+
return self.tls.stack
847+
else:
848+
self.tls.stack = ["__start__"]
849+
return self.tls.stack
850+
844851
def __init__(self):
845-
self.stack = ["__start__"]
852+
self.tls = threading.local()
846853
# Generate a unique id for this logger, which we can use in scuba to filter down
847854
# to a single python run.
848855
self.id_ = str(uuid.uuid4())
@@ -868,14 +875,15 @@ def log_event_start(
868875
"B",
869876
metadata,
870877
)
871-
log_chromium_event_internal(event, self.stack, self.id_)
872-
self.stack.append(event_name)
878+
log_chromium_event_internal(event, self.get_stack(), self.id_)
879+
self.get_stack().append(event_name)
873880

874881
def reset(self) -> None:
875882
# We call this on every compile in case a compile crashes or restarts and we haven't
876883
# cleared the stack.
877-
self.stack.clear()
878-
self.stack.append("__start__")
884+
stack = self.get_stack()
885+
stack.clear()
886+
stack.append("__start__")
879887

880888
def log_event_end(
881889
self,
@@ -894,7 +902,8 @@ def log_event_end(
894902
# These stack health checks currently never happen,
895903
# but they're written this way to future proof any weird event
896904
# overlaps in the future.
897-
if event_name not in self.stack:
905+
stack = self.get_stack()
906+
if event_name not in stack:
898907
# Something went wrong, we never called start on this event,
899908
# or it was skipped due to overlapping events below
900909
log.warning("ChromiumEventLogger: Start event not in stack, ignoring")
@@ -907,18 +916,18 @@ def log_event_end(
907916
metadata,
908917
)
909918

910-
while event_name != self.stack[-1]:
919+
while event_name != stack[-1]:
911920
# If the event isn't the most recent one to end, pop
912921
# off the stack until it is.
913922
# Since event_name is in this thread's stack, this pop is always safe
914923
log.warning(
915924
"ChromiumEventLogger: Detected overlapping events, fixing stack"
916925
)
917-
self.stack.pop()
926+
stack.pop()
918927

919-
log_chromium_event_internal(event, self.stack, self.id_, start_time_ns)
928+
log_chromium_event_internal(event, stack, self.id_, start_time_ns)
920929
# Finally pop the actual event off the stack
921-
self.stack.pop()
930+
stack.pop()
922931

923932
def _log_timed_event(
924933
self,
@@ -979,7 +988,7 @@ def log_instant_event(
979988
expect_trace_id=True,
980989
)
981990
# Log an instant event with the same start and end time
982-
log_chromium_event_internal(event, self.stack, self.id_)
991+
log_chromium_event_internal(event, self.get_stack(), self.id_)
983992

984993

985994
CHROMIUM_EVENT_LOG: Optional[ChromiumEventLogger] = None

0 commit comments

Comments
 (0)