73
73
_push_on_torch_function_stack ,
74
74
)
75
75
from torch ._dispatch .python import enable_python_dispatcher
76
- from torch ._dynamo .metrics_context import MetricsContext
77
- from torch ._guards import Source , TracingContext
76
+ from torch ._dynamo .metrics_context import MetricsContext , RuntimeMetricsContext
77
+ from torch ._guards import CompileId , Source , TracingContext
78
78
from torch ._subclasses .meta_utils import is_sparse_compressed
79
79
from torch ._utils_internal import (
80
80
log_chromium_event_internal ,
@@ -288,12 +288,17 @@ def print_time_report() -> None:
288
288
# ...
289
289
#
290
290
_METRICS_CONTEXT : MetricsContext
291
+ _RUNTIME_METRICS_CONTEXT : RuntimeMetricsContext
291
292
292
293
293
294
def get_metrics_context () -> MetricsContext :
294
295
return _METRICS_CONTEXT
295
296
296
297
298
+ def get_runtime_metrics_context () -> RuntimeMetricsContext :
299
+ return _RUNTIME_METRICS_CONTEXT
300
+
301
+
297
302
@contextmanager
298
303
def dynamo_timed (
299
304
key : str ,
@@ -302,16 +307,20 @@ def dynamo_timed(
302
307
log_pt2_compile_event : bool = False ,
303
308
metadata : Optional [Dict [str , object ]] = None ,
304
309
dynamo_compile_column_us : Optional [str ] = None ,
310
+ dynamo_compile_runtime_column_us : Optional [str ] = None ,
311
+ compile_id : Optional [CompileId ] = None ,
312
+ is_forward : Optional [bool ] = None ,
305
313
log_waitcounter : bool = False ,
306
314
) -> Generator [Any , None , None ]:
307
315
"""
308
316
dynamo_timed is a context manager
309
317
By wrapping a function in dynamo_timed, we can get a few things:
310
318
311
- 1) Log timings to pt2_compile_events.
312
- 2) Log timings to CompilationMetrics (dynamo_compile).
313
- 3) Chromium events.
314
- 4) Storing a record in compilation_time_metrics
319
+ 1) Optionally log timings to pt2_compile_events.
320
+ 2) Optionally log timings to CompilationMetrics (dynamo_compile).
321
+ 3) Optionally log chromium events.
322
+ 4) Optionally increment a WaitCounter.
323
+ 5) Store a record in compilation_time_metrics
315
324
For example:
316
325
317
326
def _foo(...):
@@ -336,12 +345,23 @@ def _foo(...):
336
345
- dynamo_compile_column_us: If provided, updates the specified CompilationMetrics
337
346
field to be logged to dyname_compile column. We expect all columns to be _us;
338
347
therefore, the field name must end with "_us".
348
+ - dynamo_compile_runtime_column_us: Like 'dynamo_compile_column_us', but should
349
+ be used for those columns captured outside of a compile context, e.g.,
350
+ runtime autotuning.
351
+ - compile_id: In the typical case, this parameter should not be needed. Use to
352
+ supply the compile_id for those cases where we want to log a compile_id where
353
+ it's not naturally available, e.g., for runtime autotuning.
354
+ - is_forward: Optionally set an is_forward field for those logging destinations
355
+ that support it.
339
356
- log_waitcounter: If set, we'll log a waitcounter of the form "pytorch.dynamo_timed.{key}"
340
357
"""
341
358
# We're standardizing on microseconds for dynamo_compile timings.
342
359
if dynamo_compile_column_us is not None :
343
360
assert dynamo_compile_column_us .endswith ("_us" )
344
361
362
+ # Only one of these should be set.
363
+ assert dynamo_compile_column_us is None or dynamo_compile_runtime_column_us is None
364
+
345
365
if phase_name :
346
366
event_name = phase_name
347
367
fn_name = key
@@ -357,11 +377,13 @@ def _foo(...):
357
377
event_metadata .update (metadata )
358
378
if fn_name :
359
379
event_metadata .update ({"fn_name" : fn_name })
380
+ if is_forward is not None :
381
+ event_metadata .update ({"is_backward" : not is_forward })
360
382
361
383
chromium_log : ChromiumEventLogger = get_chromium_event_logger ()
362
384
start_ns = time .time_ns ()
363
385
chromium_log .log_event_start (
364
- event_name , start_ns , event_metadata , log_pt2_compile_event
386
+ event_name , start_ns , event_metadata , log_pt2_compile_event , compile_id
365
387
)
366
388
367
389
try :
@@ -376,7 +398,7 @@ def _foo(...):
376
398
time_spent_ns = end_ns - start_ns
377
399
compilation_time_metrics [key ].append (time_spent_ns / 1e9 )
378
400
chromium_log .log_event_end (
379
- event_name , end_ns , {}, start_ns , log_pt2_compile_event
401
+ event_name , end_ns , {}, start_ns , log_pt2_compile_event , compile_id
380
402
)
381
403
if dynamo_compile_column_us :
382
404
metrics_context = get_metrics_context ()
@@ -391,6 +413,18 @@ def _foo(...):
391
413
# this way?
392
414
cumulative_time_spent_ns [event_name ] += time_spent_ns
393
415
416
+ if dynamo_compile_runtime_column_us :
417
+ get_runtime_metrics_context ().increment (
418
+ dynamo_compile_runtime_column_us ,
419
+ time_spent_ns // 1000 ,
420
+ extra = {
421
+ "compile_id" : compile_id ,
422
+ "is_runtime" : True ,
423
+ "is_forward" : is_forward ,
424
+ },
425
+ )
426
+ cumulative_time_spent_ns [event_name ] += time_spent_ns
427
+
394
428
395
429
@overload
396
430
def compile_times (repr : Literal ["str" ], aggregate : bool = False ) -> str :
@@ -858,7 +892,7 @@ class CompilationMetrics:
858
892
inductor_code_gen_cumulative_compile_time_us : Optional [int ] = None
859
893
triton_compile_time_us : Optional [int ] = None
860
894
runtime_cudagraphify_time_us : Optional [int ] = None # TODO: instrument
861
- runtime_triton_autotune_time_us : Optional [int ] = None # TODO: instrument
895
+ runtime_triton_autotune_time_us : Optional [int ] = None
862
896
dynamo_compile_time_before_restart_us : Optional [int ] = None
863
897
cuda_synchronize_time_us : Optional [int ] = None # TODO: instrument
864
898
distributed_ephemeral_timeout_us : Optional [int ] = None
@@ -882,6 +916,7 @@ class CompilationMetrics:
882
916
triton_version : Optional [str ] = None
883
917
feature_usage : Optional [dict [str , bool ]] = None
884
918
compile_time_autotune_time_us : Optional [int ] = None
919
+ is_runtime : Optional [bool ] = False
885
920
886
921
887
922
DEFAULT_COMPILATION_METRICS_LIMIT = 64
@@ -1022,8 +1057,14 @@ def safe_str(item: Any) -> str:
1022
1057
inductor_fx_remote_cache_backend_type = None
1023
1058
remote_cache_version = None
1024
1059
1060
+ # Populate the compile_id from the metrics context if it's set. Otherwise
1061
+ # look for it in the compile context.
1062
+ compile_id = metrics .get ("compile_id" )
1063
+ if not compile_id :
1064
+ compile_id = torch ._guards .CompileContext .current_compile_id ()
1065
+
1025
1066
common_metrics = {
1026
- "compile_id" : str (torch . _guards . CompileContext . current_compile_id ()) ,
1067
+ "compile_id" : str (compile_id ) if compile_id else None ,
1027
1068
"start_time_us" : start_time_ns // 1000 ,
1028
1069
"end_time_us" : end_time_ns // 1000 ,
1029
1070
"duration_us" : (end_time_ns - start_time_ns ) // 1000 ,
@@ -1066,10 +1107,12 @@ def safe_str(item: Any) -> str:
1066
1107
)
1067
1108
_compilation_metrics .append (compilation_metrics )
1068
1109
1069
- if compilation_metrics .is_forward :
1070
- name = "compilation_metrics"
1071
- else :
1072
- name = "bwd_compilation_metrics"
1110
+ name = "compilation_metrics"
1111
+ if compilation_metrics .is_forward is False :
1112
+ name = "bwd_" + name
1113
+ if compilation_metrics .is_runtime is True :
1114
+ name = name + "_runtime"
1115
+
1073
1116
torch ._logging .trace_structured (
1074
1117
name ,
1075
1118
lambda : {
@@ -1081,6 +1124,10 @@ def safe_str(item: Any) -> str:
1081
1124
# without making it inconsistent with compilation metrics itself, so
1082
1125
# we ignore the (hopefully small) time spent logging compilation metrics
1083
1126
record_logging_overhead = False ,
1127
+ # These may be runtime logs, e.g., runtime autotunning, so we provide
1128
+ # the CompileId from the compilation metrics in case it's not available
1129
+ # in the current trace.
1130
+ compile_id = compile_id ,
1084
1131
)
1085
1132
1086
1133
# If there's a chromium event in flight, add the CompilationMetrics to it.
@@ -1093,6 +1140,7 @@ def safe_str(item: Any) -> str:
1093
1140
1094
1141
# record_compilation_metrics is called by the singleton MetricsContext exit handler.
1095
1142
_METRICS_CONTEXT = MetricsContext (on_exit = record_compilation_metrics )
1143
+ _RUNTIME_METRICS_CONTEXT = RuntimeMetricsContext (on_exit = record_compilation_metrics )
1096
1144
1097
1145
1098
1146
def set_compilation_metrics_limit (new_size : int ) -> None :
@@ -1196,15 +1244,18 @@ def log_event_start(
1196
1244
time_ns : int ,
1197
1245
metadata : Dict [str , Any ],
1198
1246
log_pt2_compile_event : bool = False ,
1247
+ compile_id : Optional [CompileId ] = None ,
1199
1248
) -> None :
1200
1249
"""
1201
1250
Logs the start of a single event.
1202
1251
:param str event_name Name of event to appear in trace
1203
1252
:param time_ns Timestamp in nanoseconds
1204
1253
:param metadata: Any extra metadata associated with this event
1254
+ :param log_pt_compile_event: If True, log to pt2_compile_events
1255
+ :param compile_id: Explicit compile_id (rather than using the current context)
1205
1256
"""
1206
- compile_id = str ( torch ._guards .CompileContext .current_compile_id () )
1207
- metadata ["compile_id" ] = compile_id
1257
+ compile_id = compile_id or torch ._guards .CompileContext .current_compile_id ()
1258
+ metadata ["compile_id" ] = str ( compile_id )
1208
1259
self ._log_timed_event (
1209
1260
event_name ,
1210
1261
time_ns ,
@@ -1234,16 +1285,20 @@ def log_event_end(
1234
1285
metadata : Dict [str , Any ],
1235
1286
start_time_ns : int ,
1236
1287
log_pt2_compile_event : bool ,
1288
+ compile_id : Optional [CompileId ] = None ,
1237
1289
) -> None :
1238
1290
"""
1239
1291
Logs the end of a single event. This function should only be
1240
1292
called after log_event_start with the same event_name.
1241
1293
:param event_name: Name of event to appear in trace
1242
1294
:param time_ns: Timestamp in nanoseconds
1243
1295
:param metadata: Any extra metadata associated with this event
1296
+ :param start_time_ns: The start time timestamp in nanoseconds
1297
+ :param log_pt_compile_event: If True, log to pt2_compile_events
1298
+ :param compile_id: Explicit compile_id (rather than using the current context)
1244
1299
"""
1245
- compile_id = str ( torch ._guards .CompileContext .current_compile_id () )
1246
- metadata ["compile_id" ] = compile_id
1300
+ compile_id = compile_id or torch ._guards .CompileContext .current_compile_id ()
1301
+ metadata ["compile_id" ] = str ( compile_id )
1247
1302
1248
1303
# Grab metadata collected during event span
1249
1304
all_event_data = self .get_event_data ()
0 commit comments