Skip to content

Commit bfdcfb3

Browse files
authored
[GENERAL] add device index in vllmGetQueue (vllm-project#61)
Signed-off-by: Yizhou Wang <yizhou.wang@intel.com>
1 parent e81f006 commit bfdcfb3

2 files changed

Lines changed: 3 additions & 3 deletions

File tree

csrc/flash_attn/flash_api.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,7 @@ std::vector<at::Tensor> mha_varlen_fwd(
21 | 21 |       bool is_causal, int window_size_left, int window_size_right,
22 | 22 |       const float softcap, const bool return_softmax,
23 | 23 |       std::optional<at::Generator> gen_) {
24 |    | -   auto& queue = vllm::xpu::vllmGetQueue();
   | 24 | +   auto& queue = vllm::xpu::vllmGetQueue(q.device().index());
25 | 25 |
26 | 26 |     at::Tensor out;
27 | 27 |     if (out_.has_value()) {

csrc/utils.h

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -7,8 +7,8 @@
 7 |  7 |   namespace vllm {
 8 |  8 |   namespace xpu {
 9 |  9 |
10 |    | - static inline sycl::queue& vllmGetQueue() {
11 |    | -   auto current_stream = c10::xpu::getCurrentXPUStream();
   | 10 | + static inline sycl::queue& vllmGetQueue(at::DeviceIndex device_index = -1) {
   | 11 | +   auto current_stream = c10::xpu::getCurrentXPUStream(device_index);
12 | 12 |     auto& queue = current_stream.queue();
13 | 13 |     return queue;
14 | 14 |   }

0 commit comments

Comments
 (0)