Skip to content

Commit 953b521

Browse files
committed
add _check_triton_available
1 parent f6b277c commit 953b521

File tree

2 files changed

+32
-1
lines changed

2 files changed

+32
-1
lines changed

paddleformers/transformers/paddleocr_vl/modeling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim
9292
return q_embed, k_embed
9393

9494

95-
@dispatch_to(apply_rotary_pos_emb_vision_triton, cond=lambda *arg: True) # TODO: update this condtion function
95+
@dispatch_to(apply_rotary_pos_emb_vision_triton, cond=apply_rotary_pos_emb_vision_triton.is_available)
9696
def apply_rotary_pos_emb_vision(q, k, cos, sin):
9797
"""Applies Rotary Position Embedding to the query and key tensors."""
9898
orig_q_dtype = q.dtype

paddleformers/triton_kernels/rope_triton.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,3 +290,34 @@ def apply_rotary_pos_emb_vision(q, k, cos, sin):
290290
q_out, k_out: Same shape and dtype as input
291291
"""
292292
return ApplyRotaryPosEmbVision.apply(q, k, cos, sin)
293+
294+
295+
def _run_once(fn):
296+
"""Decorator that caches the result of a function after first call."""
297+
result = None
298+
has_run = False
299+
300+
def wrapper(*args, **kwargs):
301+
nonlocal result, has_run
302+
if not has_run:
303+
result = fn(*args, **kwargs)
304+
has_run = True
305+
return result
306+
307+
return wrapper
308+
309+
310+
@_run_once
311+
def _check_triton_available(*args, **kwargs):
312+
"""Check if triton is available and version >= 3.0.0"""
313+
try:
314+
import triton
315+
316+
version = getattr(triton, "__version__")
317+
major = int(version.split(".")[0])
318+
return major >= 3
319+
except ImportError:
320+
return False
321+
322+
323+
apply_rotary_pos_emb_vision.is_available = _check_triton_available

0 commit comments

Comments
 (0)