Skip to content

Commit 54baddb

Browse files
committed
[Feat] Support vLLM deployment on DCUs (#4710)
* Support vLLM deployment on DCUs
* Fix
* Fix DCU check
1 parent a88b267 commit 54baddb

File tree

2 files changed

+19
-2
lines changed

2 files changed

+19
-2
lines changed

paddlex/inference/genai/backends/vllm.py

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from ....utils import logging
1516
from ....utils.deps import is_genai_engine_plugin_available, require_genai_engine_plugin
1617
from ..configs.utils import (
1718
backend_config_to_args,
@@ -61,6 +62,16 @@ def run_vllm_server(host, port, model_name, model_dir, config, chat_template_pat
6162
},
6263
)
6364

65+
import torch
66+
67+
if torch.version.hip is not None and torch.version.cuda is None:
68+
# For DCU
69+
if "api-server-count" in config:
70+
logging.warning(
71+
"Key 'api-server-count' will be popped as it is not supported"
72+
)
73+
config.pop("api-server-count")
74+
6475
args = backend_config_to_args(config)
6576
args = parser.parse_args(args)
6677
validate_parsed_serve_args(args)

paddlex/utils/env.py

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -65,13 +65,18 @@ def is_cuda_available():
6565
import paddle.device
6666

6767
# TODO: Check runtime availability
68-
return paddle.device.is_compiled_with_cuda()
68+
return (
69+
paddle.device.is_compiled_with_cuda() and not paddle.is_compiled_with_rocm()
70+
)
6971
else:
7072
# If Paddle is unavailable, check GPU availability using PyTorch API.
7173
require_deps("torch")
74+
7275
import torch.cuda
76+
import torch.version
7377

74-
return torch.cuda.is_available()
78+
# Distinguish GPUs and DCUs by checking `torch.version.cuda`
79+
return torch.cuda.is_available() and torch.version.cuda
7580

7681

7782
def get_gpu_compute_capability():
@@ -85,6 +90,7 @@ def get_gpu_compute_capability():
8590
else:
8691
# If Paddle is unavailable, retrieve GPU compute capability from PyTorch instead.
8792
require_deps("torch")
93+
8894
import torch.cuda
8995

9096
cap = torch.cuda.get_device_capability()

0 commit comments

Comments (0)