|
35 | 35 |
|
36 | 36 | # TODO: Weight initialization |
37 | 37 |
|
| 38 | +import platform |
38 | 39 | from typing import List, Optional, Tuple, Union |
39 | 40 |
|
40 | 41 | import numpy as np |
41 | 42 | import paddle |
42 | 43 | import paddle.nn as nn |
43 | 44 | import paddle.nn.functional as F |
44 | 45 |
|
45 | | -from ......utils.env import get_gpu_compute_capability |
| 46 | +from ......utils.env import ( |
| 47 | + get_device_type, |
| 48 | + get_gpu_compute_capability, |
| 49 | + get_paddle_cuda_version, |
| 50 | +) |
46 | 51 | from ....common.vlm.activations import ACT2FN |
47 | 52 | from ....common.vlm.transformers import PretrainedModel |
48 | 53 | from ....common.vlm.transformers.model_outputs import ( |
@@ -139,7 +144,18 @@ def __init__(self, config): |
139 | 144 | self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) |
140 | 145 |
|
141 | 146 | cap = get_gpu_compute_capability() |
142 | | - self._supports_sdpa = cap >= (8, 0) if cap is not None else False |
| 147 | + cuda_ver = get_paddle_cuda_version() |
| 148 | + self._supports_sdpa = False |
| 149 | + if ( |
| 150 | + cap is not None |
| 151 | + and cap >= (8, 0) |
| 152 | + and cuda_ver is not None |
| 153 | + and cuda_ver >= (11, 4) |
| 154 | + and platform.system() == "Linux" |
| 155 | + ): |
| 156 | + self._supports_sdpa = True |
| 157 | + if get_device_type() == "iluvatar_gpu": |
| 158 | + self._supports_sdpa = True |
143 | 159 |
|
144 | 160 | def forward( |
145 | 161 | self, |
|
0 commit comments