
Commit 34a8799
Merge branch 'vllm-project:main' into main
2 parents 256475a + bf3b79e

117 files changed: +5471 -2585 lines changed

New file (+11 lines)

@@ -0,0 +1,11 @@
+# bash ./run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM -b "auto" -t 2
+model_name: "nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM"
+tasks:
+- name: "gsm8k"
+  metrics:
+  - name: "exact_match,strict-match"
+    value: 0.6353
+  - name: "exact_match,flexible-extract"
+    value: 0.637
+limit: null
+num_fewshot: null
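The config above records expected GSM8K scores for the lm-eval baseline run shown in the first comment line. For illustration only, here is a minimal sketch of how such a file could be checked against freshly measured scores; the helper name, the (task, metric) score layout, and the 5% relative tolerance are assumptions, not the actual vLLM test harness.

```python
# Hypothetical checker for a config like the one above; the function name,
# the (task, metric) score layout, and RTOL are illustrative assumptions.
import yaml

RTOL = 0.05  # assumed relative tolerance

def check_lm_eval_config(config_path: str, measured: dict) -> None:
    """measured maps (task_name, metric_name) -> score from an lm-eval run."""
    with open(config_path) as f:
        cfg = yaml.safe_load(f)
    for task in cfg["tasks"]:
        for metric in task["metrics"]:
            expected = metric["value"]
            got = measured[(task["name"], metric["name"])]
            assert abs(got - expected) <= RTOL * expected, (
                f"{task['name']}/{metric['name']}: expected ~{expected}, got {got}")
```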

.buildkite/test-pipeline.yaml (+2)

@@ -128,6 +128,7 @@ steps:
   - tests/spec_decode/e2e/test_integration_dist_tp4
   - tests/compile
   - examples/offline_inference/rlhf.py
+  - examples/offline_inference/ray_placement.py
   commands:
   - pytest -v -s distributed/test_utils.py
   - pytest -v -s compile/test_basic_correctness.py
@@ -136,6 +137,7 @@ steps:
   # TODO: create a dedicated test section for multi-GPU example tests
   # when we have multiple distributed example tests
   - python3 ../examples/offline_inference/rlhf.py
+  - RAY_DEDUP_LOGS=0 python3 ../examples/offline_inference/ray_placement.py
 
 - label: Metrics, Tracing Test # 10min
   num_gpus: 2

.github/workflows/reminder_comment.yml (+6 -2)

@@ -2,7 +2,6 @@ name: PR Reminder Comment Bot
 on:
   pull_request_target:
     types: [opened]
-
 jobs:
   pr_reminder:
     runs-on: ubuntu-latest
@@ -15,7 +14,12 @@ jobs:
             owner: context.repo.owner,
             repo: context.repo.repo,
             issue_number: context.issue.number,
-            body: '👋 Hi! Thank you for contributing to the vLLM project.\n Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run `fastcheck` CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your `fastcheck` build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping `simon-mo` or `khluu` to add you in our Buildkite org. \n\nOnce the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.\n\n To run CI, PR reviewers can do one of these:\n- Add `ready` label to the PR\n- Enable auto-merge.\n\n🚀'
+            body: '👋 Hi! Thank you for contributing to the vLLM project.\n\n' +
+              '💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.\n\n' +
+              'Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run `fastcheck` CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your `fastcheck` build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping `simon-mo` or `khluu` to add you in our Buildkite org.\n\n' +
+              'Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.\n\n' +
+              'To run CI, PR reviewers can either: Add `ready` label to the PR or enable auto-merge.\n\n' +
+              '🚀'
           })
         env:
           GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

csrc/cache.h (+3)

@@ -15,6 +15,9 @@ void copy_blocks(std::vector<torch::Tensor> const& key_caches,
                  std::vector<torch::Tensor> const& value_caches,
                  const torch::Tensor& block_mapping);
 
+void copy_blocks_mla(std::vector<torch::Tensor> const& kv_caches,
+                     const torch::Tensor& block_mapping);
+
 void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
                        torch::Tensor& key_cache, torch::Tensor& value_cache,
                        torch::Tensor& slot_mapping,

csrc/cache_kernels.cu (+70 -12)

@@ -46,7 +46,10 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
   char* src_ptr = static_cast<char*>(src.data_ptr());
   char* dst_ptr = static_cast<char*>(dst.data_ptr());
 
-  const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
+  // We use the stride instead of numel in case the cache is padded for memory
+  // alignment reasons, we assume the blocks data (inclusive of any padding)
+  // is contiguous in memory
+  const int64_t block_size_in_bytes = src.element_size() * src.stride(0);
   const at::cuda::OptionalCUDAGuard device_guard(
       src_device.is_cuda() ? src_device : dst_device);
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
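The new comment is about padding: when each block row of the cache is padded for alignment, `src[0].numel()` counts only the visible elements while `src.stride(0)` counts the full per-block memory footprint. A tiny PyTorch illustration with made-up shapes:

```python
import torch

# 4 blocks, 80 useful elements per block, 16 elements of padding per block.
cache = torch.empty(4, 96)[:, :80]
print(cache[0].numel())   # 80  -> would under-count the bytes to copy
print(cache.stride(0))    # 96  -> full memory footprint of one block
```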
@@ -93,6 +96,24 @@ __global__ void copy_blocks_kernel(int64_t* key_cache_ptrs,
   }
 }
 
+// Kernel for MLA, which works on a single joint kv_cache
+// Grid: (num_layers, num_pairs)
+template <typename scalar_t>
+__global__ void copy_blocks_mla_kernel(
+    int64_t* cache_ptrs, const int64_t* __restrict__ block_mapping,
+    const int mem_footprint_per_block) {
+  const int layer_idx = blockIdx.x;
+  const int pair_idx = blockIdx.y;
+  scalar_t* cache = reinterpret_cast<scalar_t*>(cache_ptrs[layer_idx]);
+  int64_t src_block = block_mapping[2 * pair_idx];
+  int64_t dst_block = block_mapping[2 * pair_idx + 1];
+  int64_t src_offset = src_block * mem_footprint_per_block;
+  int64_t dst_offset = dst_block * mem_footprint_per_block;
+  for (int i = threadIdx.x; i < mem_footprint_per_block; i += blockDim.x) {
+    cache[dst_offset + i] = cache[src_offset + i];
+  }
+}
+
 }  // namespace vllm
 
 // Note: the key_caches and value_caches vectors are constant but
@@ -147,6 +168,42 @@ void copy_blocks(std::vector<torch::Tensor> const& key_caches,
       }));
 }
 
+// copy blocks kernel for MLA (assumes a joint KV-cache)
+void copy_blocks_mla(std::vector<torch::Tensor> const& kv_caches,
+                     const torch::Tensor& block_mapping) {
+  int num_layers = kv_caches.size();
+  if (num_layers == 0) {
+    return;
+  }
+  torch::Device cache_device = kv_caches[0].device();
+  TORCH_CHECK(cache_device.is_cuda(), "kv_cache must be on CUDA");
+
+  std::vector<int64_t> cache_ptrs(num_layers);
+  for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
+    cache_ptrs[layer_idx] =
+        reinterpret_cast<int64_t>(kv_caches[layer_idx].data_ptr());
+  }
+  torch::Tensor cache_ptrs_tensor =
+      torch::from_blob(cache_ptrs.data(), {num_layers}, torch::kInt64)
+          .to(cache_device);
+
+  int num_pairs = block_mapping.size(0);
+  // We use the stride instead of numel in case the cache is padded for memory
+  // alignment reasons, we assume the blocks data (inclusive of any padding)
+  // is contiguous in memory
+  int mem_footprint_per_block = kv_caches[0].stride(0);
+  dim3 grid(num_layers, num_pairs);
+  dim3 block(std::min(1024, mem_footprint_per_block));
+  const at::cuda::OptionalCUDAGuard device_guard(cache_device);
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
+      kv_caches[0].scalar_type(), "copy_blocks_mla_kernel", ([&] {
+        vllm::copy_blocks_mla_kernel<scalar_t><<<grid, block, 0, stream>>>(
+            cache_ptrs_tensor.data_ptr<int64_t>(),
+            block_mapping.data_ptr<int64_t>(), mem_footprint_per_block);
+      }));
+}
+
 namespace vllm {
 
 template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
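For readers following the new kernel, here is a pure-PyTorch sketch of what `copy_blocks_mla` does per layer. This is an illustrative reimplementation, not code from this commit; it reads `block_mapping` as [num_pairs, 2] rows of (src, dst) block indices, matching the indexing in the kernel above.

```python
import torch

def copy_blocks_mla_reference(kv_caches: list[torch.Tensor],
                              block_mapping: torch.Tensor) -> None:
    """Copy whole blocks (src -> dst) inside each layer's joint MLA KV cache.

    kv_caches: one tensor per layer, indexed by block along dim 0.
    block_mapping: int64 tensor of shape [num_pairs, 2] holding (src, dst) rows.
    """
    for cache in kv_caches:
        for src, dst in block_mapping.tolist():
            # The CUDA kernel copies cache.stride(0) elements per block, so it
            # also carries along any alignment padding; this sketch copies only
            # the visible elements of each block.
            cache[dst].copy_(cache[src])
```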
@@ -254,6 +311,7 @@ __global__ void concat_and_cache_mla_kernel(
                                            //  + pe_dim)]
     const int64_t* __restrict__ slot_mapping,  // [num_tokens]
     const int block_stride,                    //
+    const int entry_stride,                    //
     const int kv_c_stride,                     //
     const int k_pe_stride,                     //
     const int kv_lora_rank,                    //
@@ -274,9 +332,8 @@
       int src_stride, int dst_stride, int size, int offset) {
     for (int i = threadIdx.x; i < size; i += blockDim.x) {
       const int64_t src_idx = token_idx * src_stride + i;
-      const int64_t dst_idx = block_idx * block_stride +
-                              block_offset * (kv_lora_rank + pe_dim) + i +
-                              offset;
+      const int64_t dst_idx =
+          block_idx * block_stride + block_offset * entry_stride + i + offset;
       if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
         dst[dst_idx] = src[src_idx];
       } else {
@@ -391,14 +448,14 @@ void reshape_and_cache_flash(
 // KV_T is the stored data type of kv-cache.
 // CACHE_T is the data type of key and value tensors.
 // KV_DTYPE is the real data type of kv-cache.
-#define CALL_CONCAT_AND_CACHE_MLA(KV_T, CACHE_T, KV_DTYPE)                \
-  vllm::concat_and_cache_mla_kernel<KV_T, CACHE_T, KV_DTYPE>              \
-      <<<grid, block, 0, stream>>>(                                       \
-          reinterpret_cast<KV_T*>(kv_c.data_ptr()),                       \
-          reinterpret_cast<KV_T*>(k_pe.data_ptr()),                       \
-          reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()),                \
-          slot_mapping.data_ptr<int64_t>(), block_stride, kv_c_stride,    \
-          k_pe_stride, kv_lora_rank, pe_dim, block_size,                  \
+#define CALL_CONCAT_AND_CACHE_MLA(KV_T, CACHE_T, KV_DTYPE)                \
+  vllm::concat_and_cache_mla_kernel<KV_T, CACHE_T, KV_DTYPE>              \
+      <<<grid, block, 0, stream>>>(                                       \
+          reinterpret_cast<KV_T*>(kv_c.data_ptr()),                       \
+          reinterpret_cast<KV_T*>(k_pe.data_ptr()),                       \
+          reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()),                \
+          slot_mapping.data_ptr<int64_t>(), block_stride, entry_stride,   \
+          kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size,     \
           reinterpret_cast<const float*>(scale.data_ptr()));
 
 void concat_and_cache_mla(
@@ -428,6 +485,7 @@ void concat_and_cache_mla(
   int kv_c_stride = kv_c.stride(0);
   int k_pe_stride = k_pe.stride(0);
   int block_stride = kv_cache.stride(0);
+  int entry_stride = kv_cache.stride(1);
 
   dim3 grid(num_tokens);
   dim3 block(std::min(kv_lora_rank, 512));

csrc/torch_bindings.cpp (+4)

@@ -450,6 +450,10 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
       "Tensor block_mapping) -> ()");
   cache_ops.impl("copy_blocks", torch::kCUDA, &copy_blocks);
 
+  cache_ops.def(
+      "copy_blocks_mla(Tensor(a!)[] kv_caches, Tensor block_mapping) -> ()");
+  cache_ops.impl("copy_blocks_mla", torch::kCUDA, &copy_blocks_mla);
+
   // Reshape the key and value tensors and cache them.
   cache_ops.def(
       "reshape_and_cache(Tensor key, Tensor value,"

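Once registered, the new op is reachable from Python through the `torch.ops` namespace. A hedged sketch of a call site, assuming `TORCH_EXTENSION_NAME` expands to `_C` as in vLLM's build (so the library above is exposed as `torch.ops._C_cache_ops`); in practice calls typically go through a thin wrapper in `vllm._custom_ops`.

```python
import torch

# Assumes the extension is built and loaded, and that TORCH_EXTENSION_NAME
# expands to "_C", so the cache ops library is torch.ops._C_cache_ops.
def copy_mla_blocks(kv_caches: list[torch.Tensor],
                    block_mapping: torch.Tensor) -> None:
    # block_mapping: int64 [num_pairs, 2] of (src_block, dst_block) pairs,
    # matching the schema registered above.
    torch.ops._C_cache_ops.copy_blocks_mla(kv_caches, block_mapping)
```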
docs/source/conf.py (-1)

@@ -37,7 +37,6 @@
 # ones.
 extensions = [
     "sphinx.ext.napoleon",
-    "sphinx.ext.viewcode",
     "sphinx.ext.linkcode",
     "sphinx.ext.intersphinx",
     "sphinx_copybutton",

docs/source/contributing/model/multimodal.md (+5 -1)

@@ -250,7 +250,11 @@ def get_max_image_tokens(self) -> int:
 And thus, we can override the method as:
 
 ```python
-def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
+def get_mm_max_tokens_per_item(
+    self,
+    seq_len: int,
+    mm_counts: Mapping[str, int],
+) -> Mapping[str, int]:
     return {"image": self.get_max_image_tokens()}
 ```
docs/source/features/quantization/auto_awq.md (-6)

@@ -2,12 +2,6 @@
 
 # AutoAWQ
 
-:::{warning}
-Please note that AWQ support in vLLM is under-optimized at the moment. We would recommend using the unquantized version of the model for better
-accuracy and higher throughput. Currently, you can use AWQ as a way to reduce memory footprint. As of now, it is more suitable for low latency
-inference with small number of concurrent requests. vLLM's AWQ implementation have lower throughput than unquantized version.
-:::
-
 To create a new 4-bit quantized model, you can leverage [AutoAWQ](https://github.com/casper-hansen/AutoAWQ).
 Quantizing reduces the model's precision from FP16 to INT4 which effectively reduces the file size by ~70%.
 The main benefits are lower latency and memory usage.
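The page goes on to describe producing an AWQ checkpoint with AutoAWQ; loading an existing one in vLLM is a short script. A minimal sketch, where the checkpoint name is only an example of a public AWQ model and `quantization="awq"` requests AWQ weight loading:

```python
from vllm import LLM, SamplingParams

# Example AWQ checkpoint; substitute your own quantized model.
llm = LLM(model="TheBloke/Llama-2-7b-Chat-AWQ", quantization="awq")
outputs = llm.generate(["What is AWQ quantization?"],
                       SamplingParams(temperature=0.8, max_tokens=64))
print(outputs[0].outputs[0].text)
```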

docs/source/features/spec_decode.md (+6 -6)

@@ -131,7 +131,7 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
 llm = LLM(
     model="meta-llama/Meta-Llama-3.1-70B-Instruct",
     tensor_parallel_size=4,
-    speculative_model="ibm-fms/llama3-70b-accelerator",
+    speculative_model="ibm-ai-platform/llama3-70b-accelerator",
     speculative_draft_tensor_parallel_size=1,
 )
 outputs = llm.generate(prompts, sampling_params)
@@ -149,11 +149,11 @@ limitation will be fixed in a future release.
 
 A variety of speculative models of this type are available on HF hub:
 
-- [llama-13b-accelerator](https://huggingface.co/ibm-fms/llama-13b-accelerator)
-- [llama3-8b-accelerator](https://huggingface.co/ibm-fms/llama3-8b-accelerator)
-- [codellama-34b-accelerator](https://huggingface.co/ibm-fms/codellama-34b-accelerator)
-- [llama2-70b-accelerator](https://huggingface.co/ibm-fms/llama2-70b-accelerator)
-- [llama3-70b-accelerator](https://huggingface.co/ibm-fms/llama3-70b-accelerator)
+- [llama-13b-accelerator](https://huggingface.co/ibm-ai-platform/llama-13b-accelerator)
+- [llama3-8b-accelerator](https://huggingface.co/ibm-ai-platform/llama3-8b-accelerator)
+- [codellama-34b-accelerator](https://huggingface.co/ibm-ai-platform/codellama-34b-accelerator)
+- [llama2-70b-accelerator](https://huggingface.co/ibm-ai-platform/llama2-70b-accelerator)
+- [llama3-70b-accelerator](https://huggingface.co/ibm-ai-platform/llama3-70b-accelerator)
 - [granite-3b-code-instruct-accelerator](https://huggingface.co/ibm-granite/granite-3b-code-instruct-accelerator)
 - [granite-8b-code-instruct-accelerator](https://huggingface.co/ibm-granite/granite-8b-code-instruct-accelerator)
 - [granite-7b-instruct-accelerator](https://huggingface.co/ibm-granite/granite-7b-instruct-accelerator)

docs/source/models/supported_models.md (+19 -4)

@@ -726,14 +726,14 @@ See [this page](#generative-models) for more information on how to use generativ
   * `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc.
   *
   * ✅︎
-  *
+  * \*
 - * `Idefics3ForConditionalGeneration`
   * Idefics3
   * T + I
   * `HuggingFaceM4/Idefics3-8B-Llama3` etc.
   * ✅︎
   *
-  *
+  * ✅︎
 - * `InternVLChatModel`
   * InternVL 2.5, Mono-InternVL, InternVL 2.0
   * T + I<sup>E+</sup>
@@ -799,7 +799,7 @@ See [this page](#generative-models) for more information on how to use generativ
   * ✅︎
 - * `NVLM_D_Model`
   * NVLM-D 1.0
-  * T + I<sup>E+</sup>
+  * T + I<sup>+</sup>
   * `nvidia/NVLM-D-72B`, etc.
   *
   * ✅︎
@@ -846,6 +846,13 @@ See [this page](#generative-models) for more information on how to use generativ
   * ✅︎
   * ✅︎
   * ✅︎
+- * `Qwen2_5_VLForConditionalGeneration`
+  * Qwen2.5-VL
+  * T + I<sup>E+</sup> + V<sup>E+</sup>
+  * `Qwen/Qwen2.5-VL-3B-Instruct`, `Qwen/Qwen2.5-VL-72B-Instruct`, etc.
+  *
+  * ✅︎
+  * ✅︎
 - * `UltravoxModel`
   * Ultravox
   * T + A<sup>E+</sup>
@@ -859,7 +866,11 @@ See [this page](#generative-models) for more information on how to use generativ
 <sup>+</sup> Multiple items can be inputted per text prompt for this modality.
 
 :::{note}
-To use `DeepSeek-VL2` series models, you have to pass `--hf_overrides '{"architectures": ["DeepseekVLV2ForCausalLM"]}'` when running vLLM.
+To use DeepSeek-VL2 series models, you have to pass `--hf_overrides '{"architectures": ["DeepseekVLV2ForCausalLM"]}'` when running vLLM.
+:::
+
+:::{note}
+H2O-VL series models will be available in V1 once we support backends other than FlashAttention.
 :::
 
 :::{note}
@@ -876,6 +887,10 @@ The chat template for Pixtral-HF is incorrect (see [discussion](https://huggingf
 A corrected version is available at <gh-file:examples/template_pixtral_hf.jinja>.
 :::
 
+:::{note}
+To use Qwen2.5-VL series models, you have to install Huggingface `transformers` library from source via `pip install git+https://github.com/huggingface/transformers`.
+:::
+
 ### Pooling Models
 
 See [this page](pooling-models) for more information on how to use pooling models.
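The DeepSeek-VL2 note above shows the `--hf_overrides` CLI flag; the same override can also be passed from the offline `LLM` API. A minimal sketch, assuming the `hf_overrides` keyword argument mirrors the CLI flag and using an example checkpoint name:

```python
from vllm import LLM

# Offline equivalent of the CLI flag
# --hf_overrides '{"architectures": ["DeepseekVLV2ForCausalLM"]}'
llm = LLM(
    model="deepseek-ai/deepseek-vl2-tiny",  # example DeepSeek-VL2 checkpoint
    hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]},
)
```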

examples/offline_inference/mlpspeculator.py (+1 -1)

@@ -51,7 +51,7 @@ def time_generation(llm: LLM, prompts: List[str],
 # Create an LLM with spec decoding
 llm = LLM(
     model="meta-llama/Llama-2-13b-chat-hf",
-    speculative_model="ibm-fms/llama-13b-accelerator",
+    speculative_model="ibm-ai-platform/llama-13b-accelerator",
 )
 
 print("With speculation")
