Skip to content

Commit 478cc35

Browse files
masnesral and facebook-github-bot
authored and committed Dec 13, 2024
Log runtime autotuning timing to scuba (#141919)
Summary: See test plan in internal diff [D66679369](https://our.internmc.facebook.com/intern/diff/D66679369) X-link: pytorch/pytorch#141919 Approved by: https://github.com/jamesjwu, https://github.com/ezyang Differential Revision: D67218561 Pulled By: masnesral
1 parent 285fb28 commit 478cc35

File tree

1 file changed

+73
-18
lines changed
  • userbenchmark/dynamo/dynamobench/_dynamo

1 file changed

+73
-18
lines changed
 

‎userbenchmark/dynamo/dynamobench/_dynamo/utils.py

Lines changed: 73 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,8 @@
7373
_push_on_torch_function_stack,
7474
)
7575
from torch._dispatch.python import enable_python_dispatcher
76-
from torch._dynamo.metrics_context import MetricsContext
77-
from torch._guards import Source, TracingContext
76+
from torch._dynamo.metrics_context import MetricsContext, RuntimeMetricsContext
77+
from torch._guards import CompileId, Source, TracingContext
7878
from torch._subclasses.meta_utils import is_sparse_compressed
7979
from torch._utils_internal import (
8080
log_chromium_event_internal,
@@ -288,12 +288,17 @@ def print_time_report() -> None:
288288
# ...
289289
#
290290
_METRICS_CONTEXT: MetricsContext
291+
_RUNTIME_METRICS_CONTEXT: RuntimeMetricsContext
291292

292293

293294
def get_metrics_context() -> MetricsContext:
294295
return _METRICS_CONTEXT
295296

296297

298+
def get_runtime_metrics_context() -> RuntimeMetricsContext:
299+
return _RUNTIME_METRICS_CONTEXT
300+
301+
297302
@contextmanager
298303
def dynamo_timed(
299304
key: str,
@@ -302,16 +307,20 @@ def dynamo_timed(
302307
log_pt2_compile_event: bool = False,
303308
metadata: Optional[Dict[str, object]] = None,
304309
dynamo_compile_column_us: Optional[str] = None,
310+
dynamo_compile_runtime_column_us: Optional[str] = None,
311+
compile_id: Optional[CompileId] = None,
312+
is_forward: Optional[bool] = None,
305313
log_waitcounter: bool = False,
306314
) -> Generator[Any, None, None]:
307315
"""
308316
dynamo_timed is a context manager
309317
By wrapping a function in dynamo_timed, we can get a few things:
310318
311-
1) Log timings to pt2_compile_events.
312-
2) Log timings to CompilationMetrics (dynamo_compile).
313-
3) Chromium events.
314-
4) Storing a record in compilation_time_metrics
319+
1) Optionally log timings to pt2_compile_events.
320+
2) Optionally log timings to CompilationMetrics (dynamo_compile).
321+
3) Optionally log chromium events.
322+
4) Optionally increment a WaitCounter.
323+
5) Store a record in compilation_time_metrics
315324
For example:
316325
317326
def _foo(...):
@@ -336,12 +345,23 @@ def _foo(...):
336345
- dynamo_compile_column_us: If provided, updates the specified CompilationMetrics
337346
field to be logged to dynamo_compile column. We expect all columns to be _us;
338347
therefore, the field name must end with "_us".
348+
- dynamo_compile_runtime_column_us: Like 'dynamo_compile_column_us', but should
349+
be used for those columns captured outside of a compile context, e.g.,
350+
runtime autotuning.
351+
- compile_id: In the typical case, this parameter should not be needed. Use to
352+
supply the compile_id for those cases where we want to log a compile_id where
353+
it's not naturally available, e.g., for runtime autotuning.
354+
- is_forward: Optionally set an is_forward field for those logging destinations
355+
that support it.
339356
- log_waitcounter: If set, we'll log a waitcounter of the form "pytorch.dynamo_timed.{key}"
340357
"""
341358
# We're standardizing on microseconds for dynamo_compile timings.
342359
if dynamo_compile_column_us is not None:
343360
assert dynamo_compile_column_us.endswith("_us")
344361

362+
# Only one of these should be set.
363+
assert dynamo_compile_column_us is None or dynamo_compile_runtime_column_us is None
364+
345365
if phase_name:
346366
event_name = phase_name
347367
fn_name = key
@@ -357,11 +377,13 @@ def _foo(...):
357377
event_metadata.update(metadata)
358378
if fn_name:
359379
event_metadata.update({"fn_name": fn_name})
380+
if is_forward is not None:
381+
event_metadata.update({"is_backward": not is_forward})
360382

361383
chromium_log: ChromiumEventLogger = get_chromium_event_logger()
362384
start_ns = time.time_ns()
363385
chromium_log.log_event_start(
364-
event_name, start_ns, event_metadata, log_pt2_compile_event
386+
event_name, start_ns, event_metadata, log_pt2_compile_event, compile_id
365387
)
366388

367389
try:
@@ -376,7 +398,7 @@ def _foo(...):
376398
time_spent_ns = end_ns - start_ns
377399
compilation_time_metrics[key].append(time_spent_ns / 1e9)
378400
chromium_log.log_event_end(
379-
event_name, end_ns, {}, start_ns, log_pt2_compile_event
401+
event_name, end_ns, {}, start_ns, log_pt2_compile_event, compile_id
380402
)
381403
if dynamo_compile_column_us:
382404
metrics_context = get_metrics_context()
@@ -391,6 +413,18 @@ def _foo(...):
391413
# this way?
392414
cumulative_time_spent_ns[event_name] += time_spent_ns
393415

416+
if dynamo_compile_runtime_column_us:
417+
get_runtime_metrics_context().increment(
418+
dynamo_compile_runtime_column_us,
419+
time_spent_ns // 1000,
420+
extra={
421+
"compile_id": compile_id,
422+
"is_runtime": True,
423+
"is_forward": is_forward,
424+
},
425+
)
426+
cumulative_time_spent_ns[event_name] += time_spent_ns
427+
394428

395429
@overload
396430
def compile_times(repr: Literal["str"], aggregate: bool = False) -> str:
@@ -858,7 +892,7 @@ class CompilationMetrics:
858892
inductor_code_gen_cumulative_compile_time_us: Optional[int] = None
859893
triton_compile_time_us: Optional[int] = None
860894
runtime_cudagraphify_time_us: Optional[int] = None # TODO: instrument
861-
runtime_triton_autotune_time_us: Optional[int] = None # TODO: instrument
895+
runtime_triton_autotune_time_us: Optional[int] = None
862896
dynamo_compile_time_before_restart_us: Optional[int] = None
863897
cuda_synchronize_time_us: Optional[int] = None # TODO: instrument
864898
distributed_ephemeral_timeout_us: Optional[int] = None
@@ -882,6 +916,7 @@ class CompilationMetrics:
882916
triton_version: Optional[str] = None
883917
feature_usage: Optional[dict[str, bool]] = None
884918
compile_time_autotune_time_us: Optional[int] = None
919+
is_runtime: Optional[bool] = False
885920

886921

887922
DEFAULT_COMPILATION_METRICS_LIMIT = 64
@@ -1022,8 +1057,14 @@ def safe_str(item: Any) -> str:
10221057
inductor_fx_remote_cache_backend_type = None
10231058
remote_cache_version = None
10241059

1060+
# Populate the compile_id from the metrics context if it's set. Otherwise
1061+
# look for it in the compile context.
1062+
compile_id = metrics.get("compile_id")
1063+
if not compile_id:
1064+
compile_id = torch._guards.CompileContext.current_compile_id()
1065+
10251066
common_metrics = {
1026-
"compile_id": str(torch._guards.CompileContext.current_compile_id()),
1067+
"compile_id": str(compile_id) if compile_id else None,
10271068
"start_time_us": start_time_ns // 1000,
10281069
"end_time_us": end_time_ns // 1000,
10291070
"duration_us": (end_time_ns - start_time_ns) // 1000,
@@ -1066,10 +1107,12 @@ def safe_str(item: Any) -> str:
10661107
)
10671108
_compilation_metrics.append(compilation_metrics)
10681109

1069-
if compilation_metrics.is_forward:
1070-
name = "compilation_metrics"
1071-
else:
1072-
name = "bwd_compilation_metrics"
1110+
name = "compilation_metrics"
1111+
if compilation_metrics.is_forward is False:
1112+
name = "bwd_" + name
1113+
if compilation_metrics.is_runtime is True:
1114+
name = name + "_runtime"
1115+
10731116
torch._logging.trace_structured(
10741117
name,
10751118
lambda: {
@@ -1081,6 +1124,10 @@ def safe_str(item: Any) -> str:
10811124
# without making it inconsistent with compilation metrics itself, so
10821125
# we ignore the (hopefully small) time spent logging compilation metrics
10831126
record_logging_overhead=False,
1127+
# These may be runtime logs, e.g., runtime autotuning, so we provide
1128+
# the CompileId from the compilation metrics in case it's not available
1129+
# in the current trace.
1130+
compile_id=compile_id,
10841131
)
10851132

10861133
# If there's a chromium event in flight, add the CompilationMetrics to it.
@@ -1093,6 +1140,7 @@ def safe_str(item: Any) -> str:
10931140

10941141
# record_compilation_metrics is called by the singleton MetricsContext exit handler.
10951142
_METRICS_CONTEXT = MetricsContext(on_exit=record_compilation_metrics)
1143+
_RUNTIME_METRICS_CONTEXT = RuntimeMetricsContext(on_exit=record_compilation_metrics)
10961144

10971145

10981146
def set_compilation_metrics_limit(new_size: int) -> None:
@@ -1196,15 +1244,18 @@ def log_event_start(
11961244
time_ns: int,
11971245
metadata: Dict[str, Any],
11981246
log_pt2_compile_event: bool = False,
1247+
compile_id: Optional[CompileId] = None,
11991248
) -> None:
12001249
"""
12011250
Logs the start of a single event.
12021251
:param event_name: Name of event to appear in trace
12031252
:param time_ns Timestamp in nanoseconds
12041253
:param metadata: Any extra metadata associated with this event
1254+
:param log_pt2_compile_event: If True, log to pt2_compile_events
1255+
:param compile_id: Explicit compile_id (rather than using the current context)
12051256
"""
1206-
compile_id = str(torch._guards.CompileContext.current_compile_id())
1207-
metadata["compile_id"] = compile_id
1257+
compile_id = compile_id or torch._guards.CompileContext.current_compile_id()
1258+
metadata["compile_id"] = str(compile_id)
12081259
self._log_timed_event(
12091260
event_name,
12101261
time_ns,
@@ -1234,16 +1285,20 @@ def log_event_end(
12341285
metadata: Dict[str, Any],
12351286
start_time_ns: int,
12361287
log_pt2_compile_event: bool,
1288+
compile_id: Optional[CompileId] = None,
12371289
) -> None:
12381290
"""
12391291
Logs the end of a single event. This function should only be
12401292
called after log_event_start with the same event_name.
12411293
:param event_name: Name of event to appear in trace
12421294
:param time_ns: Timestamp in nanoseconds
12431295
:param metadata: Any extra metadata associated with this event
1296+
:param start_time_ns: The start time timestamp in nanoseconds
1297+
:param log_pt2_compile_event: If True, log to pt2_compile_events
1298+
:param compile_id: Explicit compile_id (rather than using the current context)
12441299
"""
1245-
compile_id = str(torch._guards.CompileContext.current_compile_id())
1246-
metadata["compile_id"] = compile_id
1300+
compile_id = compile_id or torch._guards.CompileContext.current_compile_id()
1301+
metadata["compile_id"] = str(compile_id)
12471302

12481303
# Grab metadata collected during event span
12491304
all_event_data = self.get_event_data()

0 commit comments

Comments
 (0)