Skip to content

Commit 8c81047

Browse files
committed
Using EventGuard in decorator form
1 parent 26af0ee commit 8c81047

File tree

2 files changed

+8
-4
lines changed

2 files changed

+8
-4
lines changed

python/paddle/jit/dy2static/pir_partial_program.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from typing import TYPE_CHECKING
2121

2222
import numpy as np
23+
from sot.profiler import event_register
2324

2425
import paddle
2526
import paddle.pir.core as ir_static
@@ -779,6 +780,7 @@ def __call__(self, inputs):
779780
restored_nest_out = self._restore_out(out)
780781
return self._remove_no_value(restored_nest_out)
781782

783+
@event_register("sot call partial_program")
782784
def sot_call(self, inputs):
783785
"""
784786
In sot, inputs and outputs of partial program only contain tensors, so we can skip some step to speed up

python/paddle/jit/sot/symbolic/compile_cache.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import paddle
2121

2222
from ..infer_meta import convert_meta_to_input_spec
23-
from ..profiler import EventGuard
23+
from ..profiler import EventGuard, event_register
2424
from ..utils import (
2525
ENV_SOT_EXPORT,
2626
Cache,
@@ -210,7 +210,8 @@ def update_compile_time_info(self, SIR, partial_program_layer):
210210
] += partial_program_layer._compile_time_counter.get_total_time()
211211

212212
def __call__(self, *args, **kwargs):
213-
with EventGuard(f"FallbackWrapper: {self.SIR.name}"):
213+
@event_register(f"FallbackWrapper: {self.SIR.name}")
214+
def call_fn():
214215
if StepInfoManager().need_back_trace:
215216
trace_back_frames()
216217

@@ -233,8 +234,7 @@ def __call__(self, *args, **kwargs):
233234
self.partial_program,
234235
) = self.compiled_fn.get_concrete_program(*args, **kwargs)
235236
self.partial_program.training = self.is_training
236-
with EventGuard("FallbackWrapper: sot call partial_program"):
237-
outputs = self.partial_program.sot_call(*args, **kwargs)
237+
outputs = self.partial_program.sot_call(*args, **kwargs)
238238

239239
clear_eager_tensor_name(outputs)
240240
log_do(
@@ -252,6 +252,8 @@ def __call__(self, *args, **kwargs):
252252
self.is_first_call = False
253253
return outputs
254254

255+
return call_fn()
256+
255257

256258
class CompileSIRCache(Cache, metaclass=Singleton):
257259
"""

0 commit comments

Comments
 (0)