27
27
DEFAULT_PROFILER_ACTIVITIES = {
28
28
torch .profiler .ProfilerActivity .CPU ,
29
29
torch .profiler .ProfilerActivity .CUDA ,
30
+ torch .profiler .ProfilerActivity .XPU ,
30
31
}
31
32
32
33
DEFAULT_SCHEDULE : dict = {
@@ -111,7 +112,7 @@ def trace_handler(
111
112
log .info (f"Finished dumping traces in { time .monotonic () - begin :.2f} seconds" )
112
113
113
114
# Memory timeline sometimes fails to export
114
- if prof .profile_memory :
115
+ if prof .profile_memory and torch . cuda . is_available () :
115
116
if rank == 0 :
116
117
try :
117
118
prof .export_memory_timeline (
@@ -185,6 +186,7 @@ def setup_torch_profiler(
185
186
enabled : bool = False ,
186
187
cpu : bool = True ,
187
188
cuda : bool = True ,
189
+ xpu : bool = True ,
188
190
profile_memory : bool = DEFAULT_TRACE_OPTS ["profile_memory" ],
189
191
with_stack : bool = DEFAULT_TRACE_OPTS ["with_stack" ],
190
192
record_shapes : bool = DEFAULT_TRACE_OPTS ["record_shapes" ],
@@ -252,6 +254,7 @@ def setup_torch_profiler(
252
254
enabled (bool): Enable pytorch profiler. Default is False.
253
255
cpu (bool): Enable cpu profiling. Default is True.
254
256
cuda (bool): Enable cuda profiling. Default is True.
257
+ xpu (bool): Enable xpu profiling. Default is True.
255
258
profile_memory (bool): Profile memory usage. Default is False.
256
259
with_stack (bool): Profile stack. Default is False.
257
260
record_shapes (bool): Record shapes. Default is True.
@@ -276,10 +279,12 @@ def setup_torch_profiler(
276
279
activities .append (torch .profiler .ProfilerActivity .CPU )
277
280
if cuda :
278
281
activities .append (torch .profiler .ProfilerActivity .CUDA )
282
+ if xpu :
283
+ activities .append (torch .profiler .ProfilerActivity .XPU )
279
284
if len (activities ) == 0 :
280
285
_warn ("No activities specified, defaulting to CPU + CUDA" )
281
286
activities = DEFAULT_PROFILER_ACTIVITIES
282
- cpu = cuda = True
287
+ cpu = cuda = xpu = True
283
288
284
289
# Check for schedule
285
290
# 1) If no schedule is provided, set to DEFAULT_SCHEDULE
@@ -372,6 +377,7 @@ def setup_torch_profiler(
372
377
"output_dir" : output_dir ,
373
378
"cpu" : cpu ,
374
379
"cuda" : cuda ,
380
+ "xpu" : xpu ,
375
381
"profile_memory" : profile_memory ,
376
382
"with_stack" : with_stack ,
377
383
"record_shapes" : record_shapes ,
0 commit comments