diff --git a/hta/common/trace_symbol_table.py b/hta/common/trace_symbol_table.py index 6e97de8..947e722 100644 --- a/hta/common/trace_symbol_table.py +++ b/hta/common/trace_symbol_table.py @@ -311,10 +311,13 @@ def get_operator_or_cuda_runtime_mask(self, df: pd.DataFrame) -> pd.Series: cpu_op_id = self.sym_index.get("cpu_op") cuda_runtime_id = self.sym_index.get("cuda_driver", self.NULL) cuda_driver_id = self.sym_index.get("cuda_runtime", self.NULL) + xpu_runtime_id = self.sym_index.get("xpu_runtime", self.NULL) + return ( (df["cat"] == cpu_op_id) | (df["cat"] == cuda_runtime_id) | (df["cat"] == cuda_driver_id) + | (df["cat"] == xpu_runtime_id) ) def get_runtime_launch_events_mask(self, df: pd.DataFrame) -> pd.Series: