Skip to content

Commit 3fe80eb

Browse files
committed
[Fix] Refine Checks for SDPA Availability (#4820)
* Do not use SDPA for SM120 GPUs * Fix
1 parent 9cc0285 commit 3fe80eb

File tree

1 file changed

+18
-2
lines changed
  • paddlex/inference/models/doc_vlm/modeling/paddleocr_vl

1 file changed

+18
-2
lines changed

paddlex/inference/models/doc_vlm/modeling/paddleocr_vl/_siglip.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,19 @@
3535

3636
# TODO: Weight initialization
3737

38+
import platform
3839
from typing import List, Optional, Tuple, Union
3940

4041
import numpy as np
4142
import paddle
4243
import paddle.nn as nn
4344
import paddle.nn.functional as F
4445

45-
from ......utils.env import get_gpu_compute_capability
46+
from ......utils.env import (
47+
get_device_type,
48+
get_gpu_compute_capability,
49+
get_paddle_cuda_version,
50+
)
4651
from ....common.vlm.activations import ACT2FN
4752
from ....common.vlm.transformers import PretrainedModel
4853
from ....common.vlm.transformers.model_outputs import (
@@ -139,7 +144,18 @@ def __init__(self, config):
139144
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
140145

141146
cap = get_gpu_compute_capability()
142-
self._supports_sdpa = cap >= (8, 0) if cap is not None else False
147+
cuda_ver = get_paddle_cuda_version()
148+
self._supports_sdpa = False
149+
if (
150+
cap is not None
151+
and cap >= (8, 0)
152+
and cuda_ver is not None
153+
and cuda_ver >= (11, 4)
154+
and platform.system() == "Linux"
155+
):
156+
self._supports_sdpa = True
157+
if get_device_type() == "iluvatar_gpu":
158+
self._supports_sdpa = True
143159

144160
def forward(
145161
self,

0 commit comments

Comments
 (0)