diff --git a/.buildkite/lm-eval-harness/conftest.py b/.buildkite/lm-eval-harness/conftest.py new file mode 100644 index 00000000000..a0bcc993ed4 --- /dev/null +++ b/.buildkite/lm-eval-harness/conftest.py @@ -0,0 +1,39 @@ +# SPDX-License-Identifier: Apache-2.0 +from pathlib import Path + +import pytest + + +def pytest_addoption(parser): + parser.addoption( + "--config-list-file", + action="store", + help="Path to the file listing model config YAMLs (one per line)") + parser.addoption("--tp-size", + action="store", + default="1", + help="Tensor parallel size to use for evaluation") + + +@pytest.fixture(scope="session") +def config_list_file(pytestconfig, config_dir): + rel_path = pytestconfig.getoption("--config-list-file") + return config_dir / rel_path + + +@pytest.fixture(scope="session") +def tp_size(pytestconfig): + return pytestconfig.getoption("--tp-size") + + +def pytest_generate_tests(metafunc): + if "config_filename" in metafunc.fixturenames: + rel_path = metafunc.config.getoption("--config-list-file") + config_list_file = Path(rel_path).resolve() + config_dir = config_list_file.parent + with open(config_list_file, encoding="utf-8") as f: + configs = [ + config_dir / line.strip() for line in f + if line.strip() and not line.startswith("#") + ] + metafunc.parametrize("config_filename", configs) diff --git a/.buildkite/lm-eval-harness/run-tests.sh b/.buildkite/lm-eval-harness/run-tests.sh deleted file mode 100644 index 26f33b74428..00000000000 --- a/.buildkite/lm-eval-harness/run-tests.sh +++ /dev/null @@ -1,59 +0,0 @@ -#!/bin/bash - -usage() { - echo`` - echo "Runs lm eval harness on GSM8k using vllm and compares to " - echo "precomputed baseline (measured by HF transformers.)" - echo - echo "usage: ${0} " - echo - echo " -c - path to the test data config (e.g. configs/small-models.txt)" - echo " -t - tensor parallel size" - echo -} - -SUCCESS=0 - -while getopts "c:t:" OPT; do - case ${OPT} in - c ) - CONFIG="$OPTARG" - ;; - t ) - TP_SIZE="$OPTARG" - ;; - \? ) - usage - exit 1 - ;; - esac -done - -# Parse list of configs. -IFS=$'\n' read -d '' -r -a MODEL_CONFIGS < "$CONFIG" - -for MODEL_CONFIG in "${MODEL_CONFIGS[@]}" -do - LOCAL_SUCCESS=0 - - echo "=== RUNNING MODEL: $MODEL_CONFIG WITH TP SIZE: $TP_SIZE===" - - export LM_EVAL_TEST_DATA_FILE=$PWD/configs/${MODEL_CONFIG} - export LM_EVAL_TP_SIZE=$TP_SIZE - pytest -s test_lm_eval_correctness.py || LOCAL_SUCCESS=$? - - if [[ $LOCAL_SUCCESS == 0 ]]; then - echo "=== PASSED MODEL: ${MODEL_CONFIG} ===" - else - echo "=== FAILED MODEL: ${MODEL_CONFIG} ===" - fi - - SUCCESS=$((SUCCESS + LOCAL_SUCCESS)) - -done - -if [ "${SUCCESS}" -eq "0" ]; then - exit 0 -else - exit 1 -fi diff --git a/.buildkite/lm-eval-harness/test_lm_eval_correctness.py b/.buildkite/lm-eval-harness/test_lm_eval_correctness.py index 6015a83e829..c5411daf0df 100644 --- a/.buildkite/lm-eval-harness/test_lm_eval_correctness.py +++ b/.buildkite/lm-eval-harness/test_lm_eval_correctness.py @@ -3,35 +3,25 @@ LM eval harness on model to compare vs HF baseline computed offline. Configs are found in configs/$MODEL.yaml -* export LM_EVAL_TEST_DATA_FILE=configs/Meta-Llama-3-70B-Instruct.yaml -* export LM_EVAL_TP_SIZE=4 -* pytest -s test_lm_eval_correctness.py +pytest -s -v test_lm_eval_correctness.py \ + --config-list-file=configs/models-small.txt \ + --tp-size=1 """ -import os -from pathlib import Path - import lm_eval -import numpy -import pytest +import numpy as np import yaml RTOL = 0.08 -TEST_DATA_FILE = os.environ.get( - "LM_EVAL_TEST_DATA_FILE", - ".buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct.yaml") - -TP_SIZE = os.environ.get("LM_EVAL_TP_SIZE", 1) -def launch_lm_eval(eval_config): +def launch_lm_eval(eval_config, tp_size): trust_remote_code = eval_config.get('trust_remote_code', False) - model_args = f"pretrained={eval_config['model_name']}," \ - f"tensor_parallel_size={TP_SIZE}," \ + f"tensor_parallel_size={tp_size}," \ + f"enforce_eager=true," \ f"add_bos_token=true," \ f"trust_remote_code={trust_remote_code}" - results = lm_eval.simple_evaluate( model="vllm", model_args=model_args, @@ -39,22 +29,14 @@ def launch_lm_eval(eval_config): num_fewshot=eval_config["num_fewshot"], limit=eval_config["limit"], batch_size="auto") - return results -def test_lm_eval_correctness(): - eval_config = yaml.safe_load( - Path(TEST_DATA_FILE).read_text(encoding="utf-8")) - - if eval_config[ - "model_name"] == "nm-testing/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform": #noqa: E501 - pytest.skip("FBGEMM is currently failing on main.") +def test_lm_eval_correctness_param(config_filename, tp_size): + eval_config = yaml.safe_load(config_filename.read_text(encoding="utf-8")) - # Launch eval requests. - results = launch_lm_eval(eval_config) + results = launch_lm_eval(eval_config, tp_size) - # Confirm scores match ground truth. success = True for task in eval_config["tasks"]: for metric in task["metrics"]: @@ -62,8 +44,7 @@ def test_lm_eval_correctness(): measured_value = results["results"][task["name"]][metric["name"]] print(f'{task["name"]} | {metric["name"]}: ' f'ground_truth={ground_truth} | measured={measured_value}') - success = success and numpy.isclose( + success = success and np.isclose( ground_truth, measured_value, rtol=RTOL) - # Assert at the end, print all scores even on failure for debugging. assert success diff --git a/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh b/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh index 21982b01b9c..07b898787eb 100755 --- a/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh +++ b/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh @@ -47,7 +47,9 @@ docker run --privileged --net host --shm-size=16G -it \ && echo TEST_10 \ && pytest -s -v /workspace/vllm/tests/v1/tpu/test_pallas.py \ && echo TEST_11 \ - && pytest -s -v /workspace/vllm/tests/v1/entrypoints/llm/test_struct_output_generate.py" \ + && pytest -s -v /workspace/vllm/tests/v1/entrypoints/llm/test_struct_output_generate.py \ + && echo TEST_12 \ + && pytest -s -v /workspace/vllm/tests/tpu/test_moe_pallas.py" \ # TODO: This test fails because it uses RANDOM_SEED sampling diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index b3005b1b4b0..01d04759f53 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -408,7 +408,7 @@ steps: - vllm/model_executor/layers/quantization commands: - export VLLM_WORKER_MULTIPROC_METHOD=spawn - - bash ./run-tests.sh -c configs/models-small.txt -t 1 + - pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-small.txt --tp-size=1 - label: OpenAI API correctness source_file_dependencies: @@ -713,4 +713,4 @@ steps: - vllm/model_executor/layers/quantization commands: - export VLLM_WORKER_MULTIPROC_METHOD=spawn - - bash ./run-tests.sh -c configs/models-large.txt -t 4 + - pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-large.txt --tp-size=4 diff --git a/benchmarks/benchmark_dataset.py b/benchmarks/benchmark_dataset.py index b81c2f8192d..98d3360cd6f 100644 --- a/benchmarks/benchmark_dataset.py +++ b/benchmarks/benchmark_dataset.py @@ -887,6 +887,94 @@ def sample(self, return sampled_requests +# ----------------------------------------------------------------------------- +# Next Edit Prediction Dataset Implementation +# ----------------------------------------------------------------------------- + + +zeta_prompt = """### Instruction: +You are a code completion assistant and your task is to analyze user edits and then rewrite an excerpt that the user provides, suggesting the appropriate edits within the excerpt, taking into account the cursor location. + +### User Edits: + +{} + +### User Excerpt: + +{} + +### Response: + +""" # noqa: E501 + + +def _format_zeta_prompt( + sample: dict, + original_start_marker: str = "<|editable_region_start|>") -> dict: + """Format the zeta prompt for the Next Edit Prediction (NEP) dataset. + + This function formats examples from the NEP dataset + into prompts and expected outputs. It could be + further extended to support more NEP datasets. + + Args: + sample: The dataset sample containing events, + inputs, and outputs. + original_start_marker: The marker indicating the + start of the editable region. Defaults to + "<|editable_region_start|>". + + Returns: + A dictionary with the formatted prompts and expected outputs. + """ + events = sample["events"] + input = sample["input"] + output = sample["output"] + prompt = zeta_prompt.format(events, input) + + # following the original implementation, extract the focused region + # from the raw output + output_start_index = output.find(original_start_marker) + output_focused_region = output[output_start_index:] + expected_output = output_focused_region + + return {"prompt": prompt, "expected_output": expected_output} + + +class NextEditPredictionDataset(HuggingFaceDataset): + """ + Dataset class for processing a Next Edit Prediction dataset. + """ + + SUPPORTED_DATASET_PATHS = { + "zed-industries/zeta", + } + MAPPING_PROMPT_FUNCS = { + "zed-industries/zeta": _format_zeta_prompt, + } + + def sample(self, tokenizer: PreTrainedTokenizerBase, num_requests: int, + **kwargs): + formatting_prompt_func = self.MAPPING_PROMPT_FUNCS.get( + self.dataset_path) + if formatting_prompt_func is None: + raise ValueError(f"Unsupported dataset path: {self.dataset_path}") + samples = [] + for sample in self.data: + sample = formatting_prompt_func(sample) + samples.append( + SampleRequest( + prompt=sample["prompt"], + prompt_len=len(tokenizer(sample["prompt"]).input_ids), + expected_output_len=len( + tokenizer(sample["expected_output"]).input_ids), + )) + if len(samples) >= num_requests: + break + self.maybe_oversample_requests(samples, num_requests) + return samples + + # ----------------------------------------------------------------------------- # ASR Dataset Implementation # ----------------------------------------------------------------------------- diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index c236d64261d..89fb0e1df03 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -53,8 +53,9 @@ from benchmark_dataset import (AIMODataset, ASRDataset, BurstGPTDataset, ConversationDataset, HuggingFaceDataset, InstructCoderDataset, MTBenchDataset, - RandomDataset, SampleRequest, ShareGPTDataset, - SonnetDataset, VisionArenaDataset) + NextEditPredictionDataset, RandomDataset, + SampleRequest, ShareGPTDataset, SonnetDataset, + VisionArenaDataset) from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json MILLISECONDS_TO_SECONDS_CONVERSION = 1000 @@ -603,6 +604,9 @@ def main(args: argparse.Namespace): elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS: dataset_class = AIMODataset args.hf_split = "train" + elif args.dataset_path in NextEditPredictionDataset.SUPPORTED_DATASET_PATHS: # noqa: E501 + dataset_class = NextEditPredictionDataset + args.hf_split = "train" elif args.dataset_path in ASRDataset.SUPPORTED_DATASET_PATHS: dataset_class = ASRDataset args.hf_split = "train" diff --git a/csrc/cpu/pos_encoding.cpp b/csrc/cpu/pos_encoding.cpp index 8a59e884d6c..74bb014cf39 100644 --- a/csrc/cpu/pos_encoding.cpp +++ b/csrc/cpu/pos_encoding.cpp @@ -9,7 +9,8 @@ void rotary_embedding_impl( scalar_t* __restrict__ query, /// [batch_size, seq_len, num_heads, /// head_size] or [num_tokens, num_heads, /// head_size] - scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, + scalar_t* __restrict__ key, // nullptr (optional) or + // [batch_size, seq_len, num_kv_heads, // head_size] or [num_tokens, num_kv_heads, // head_size] const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // @@ -85,10 +86,13 @@ void rotary_embedding_impl( compute_loop(token_head, cache_ptr, query); } - for (int i = 0; i < num_kv_heads; ++i) { - const int head_idx = i; - const int64_t token_head = token_idx * key_stride + head_idx * head_size; - compute_loop(token_head, cache_ptr, key); + if (key != nullptr) { + for (int i = 0; i < num_kv_heads; ++i) { + const int head_idx = i; + const int64_t token_head = + token_idx * key_stride + head_idx * head_size; + compute_loop(token_head, cache_ptr, key); + } } } } @@ -100,7 +104,8 @@ void rotary_embedding_gptj_impl( scalar_t* __restrict__ query, /// [batch_size, seq_len, num_heads, /// head_size] or [num_tokens, num_heads, /// head_size] - scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, + scalar_t* __restrict__ key, // nullptr (optional) or + // [batch_size, seq_len, num_kv_heads, // head_size] or [num_tokens, num_kv_heads, // head_size] const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // @@ -138,6 +143,10 @@ void rotary_embedding_gptj_impl( } } + if (key == nullptr) { + return; + } + #pragma omp parallel for collapse(2) for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { for (int i = 0; i < num_kv_heads; ++i) { @@ -168,13 +177,13 @@ void rotary_embedding_gptj_impl( }; // namespace void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, - torch::Tensor& key, int64_t head_size, + std::optional key, int64_t head_size, torch::Tensor& cos_sin_cache, bool is_neox) { int num_tokens = positions.numel(); int rot_dim = cos_sin_cache.size(1); int num_heads = query.size(-1) / head_size; - int num_kv_heads = key.size(-1) / head_size; - int64_t key_stride = key.stride(-2); + int num_kv_heads = key.has_value() ? key->size(-1) / head_size : num_heads; + int64_t key_stride = key.has_value() ? key->stride(-2) : 0; int64_t query_stride = query.stride(-2); VLLM_DISPATCH_FLOATING_TYPES( @@ -183,15 +192,15 @@ void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, if (is_neox) { rotary_embedding_impl( positions.data_ptr(), query.data_ptr(), - key.data_ptr(), cos_sin_cache.data_ptr(), - rot_dim, query_stride, key_stride, num_heads, num_kv_heads, - head_size, num_tokens); + key.has_value() ? key->data_ptr() : nullptr, + cos_sin_cache.data_ptr(), rot_dim, query_stride, + key_stride, num_heads, num_kv_heads, head_size, num_tokens); } else { rotary_embedding_gptj_impl( positions.data_ptr(), query.data_ptr(), - key.data_ptr(), cos_sin_cache.data_ptr(), - rot_dim, query_stride, key_stride, num_heads, num_kv_heads, - head_size, num_tokens); + key.has_value() ? key->data_ptr() : nullptr, + cos_sin_cache.data_ptr(), rot_dim, query_stride, + key_stride, num_heads, num_kv_heads, head_size, num_tokens); } CPU_KERNEL_GUARD_OUT(rotary_embedding_impl) diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp index 7ae7e3386b4..84b2a8555cc 100644 --- a/csrc/cpu/torch_bindings.cpp +++ b/csrc/cpu/torch_bindings.cpp @@ -117,7 +117,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Apply GPT-NeoX or GPT-J style rotary embedding to query and key. ops.def( "rotary_embedding(Tensor positions, Tensor! query," - " Tensor! key, int head_size," + " Tensor!? key, int head_size," " Tensor cos_sin_cache, bool is_neox) -> ()"); ops.impl("rotary_embedding", torch::kCPU, &rotary_embedding); diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu index fb6882f3e7c..d073dd6d2de 100644 --- a/csrc/layernorm_kernels.cu +++ b/csrc/layernorm_kernels.cu @@ -140,6 +140,10 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size] torch::Tensor& input, // [..., hidden_size] torch::Tensor& weight, // [hidden_size] double epsilon) { + TORCH_CHECK(out.is_contiguous()); + TORCH_CHECK(input.is_contiguous()); + TORCH_CHECK(weight.is_contiguous()); + int hidden_size = input.size(-1); int num_tokens = input.numel() / hidden_size; diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel.h b/csrc/moe/marlin_kernels/marlin_moe_kernel.h deleted file mode 100644 index a217401b3d7..00000000000 --- a/csrc/moe/marlin_kernels/marlin_moe_kernel.h +++ /dev/null @@ -1,1616 +0,0 @@ -#pragma once - -#include - -#include -#include -#include -#include -#include - -#include - -#include "core/scalar_type.hpp" - -namespace marlin_moe { - -constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; } - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - -// Instances of `Vec` are used to organize groups of >>registers<<, as needed -// for instance as inputs to tensor core operations. Consequently, all -// corresponding index accesses must be compile-time constants, which is why we -// extensively use `#pragma unroll` throughout the kernel code to guarantee -// this. -template -struct Vec { - T elems[n]; - __device__ T& operator[](int i) { return elems[i]; } -}; - -using I4 = Vec; - -// Matrix fragments for tensor core instructions; their precise layout is -// documented here: -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type -using FragA = Vec; -using FragB = Vec; -using FragC = Vec; -using FragS = Vec; // quantization scales -using FragZP = Vec; - -// Predicated asynchronous global->shared copy; used for inputs A where we apply -// predication to handle batchsizes that are not multiples of 16. -__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, - bool pred = true) { - const int BYTES = 16; - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile( - "{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %0, 0;\n" - " @p cp.async.cg.shared.global [%1], [%2], %3;\n" - "}\n" ::"r"((int)pred), - "r"(smem), "l"(glob_ptr), "n"(BYTES)); -} - -// Asynchronous global->shared copy -__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { - const int BYTES = 16; - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile( - "{\n" - " cp.async.cg.shared.global [%0], [%1], %2;\n" - "}\n" ::"r"(smem), - "l"(glob_ptr), "n"(BYTES)); -} - -// Async copy fence. -__device__ inline void cp_async_fence() { - asm volatile("cp.async.commit_group;\n" ::); -} - -// Wait until at most `n` async copy stages are still pending. -template -__device__ inline void cp_async_wait() { - asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); -} - -// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 -// output/accumulation. -__device__ inline void mma(const FragA& a_frag, const FragB& frag_b, - FragC& frag_c) { - const uint32_t* a = reinterpret_cast(&a_frag); - const uint32_t* b = reinterpret_cast(&frag_b); - float* c = reinterpret_cast(&frag_c); - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); -} - -// Instruction for loading a full 16x16 matrix fragment of operand A from shared -// memory, directly in tensor core layout. -__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) { - uint32_t* a = reinterpret_cast(&frag_a); - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" - : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) - : "r"(smem)); -} - -// Lookup-table based 3-input logical operation; explicitly used for -// dequantization as the compiler does not seem to automatically recognize it in -// all cases. -template -__device__ inline int lop3(int a, int b, int c) { - int res; - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(res) - : "r"(a), "r"(b), "r"(c), "n"(lut)); - return res; -} - -// Constructs destination register by taking bytes from 2 sources (based on -// mask) -template -__device__ inline uint32_t prmt(uint32_t a) { - uint32_t res; - asm volatile("prmt.b32 %0, %1, %2, %3;\n" - : "=r"(res) - : "r"(a), "n"(start_byte), "n"(mask)); - return res; -} - -template -__device__ inline FragB dequant(int q); - -// Efficiently dequantize 4bit values packed in an int32 value into a full -// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below, -// with some small changes: -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287 -template <> -__device__ inline FragB dequant(int q) { - const int LO = 0x000f000f; - const int HI = 0x00f000f0; - const int EX = 0x64006400; - // Guarantee that the `(a & b) | c` operations are LOP3s. - int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); - int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); - // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point - // directly into `SUB` and `ADD`. - const int SUB = 0x64086408; - const int MUL = 0x2c002c00; - const int ADD = 0xd480d480; - FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&SUB)); - frag_b[1] = __hfma2(*reinterpret_cast(&hi), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); - return frag_b; -} - -// Fast Int8ToFp16: Efficiently dequantize 8bit int values to fp16 -// Reference: -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85 -template <> -__device__ inline FragB dequant(int q) { - static constexpr uint32_t mask_for_elt_01 = 0x5250; - static constexpr uint32_t mask_for_elt_23 = 0x5351; - static constexpr uint32_t start_byte_for_fp16 = 0x64646464; - - uint32_t lo = prmt(q); - uint32_t hi = prmt(q); - - static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; - - FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - frag_b[1] = __hsub2(*reinterpret_cast(&hi), - *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - return frag_b; -} - -template <> -__device__ inline FragB dequant(int q) { - const int LO = 0x000f000f; - const int HI = 0x00f000f0; - const int EX = 0x64006400; - // Guarantee that the `(a & b) | c` operations are LOP3s. - int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); - int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); - - const int SUB = 0x64006400; - const int MUL = 0x2c002c00; - const int ADD = 0xd400d400; - FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&SUB)); - frag_b[1] = __hfma2(*reinterpret_cast(&hi), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); - return frag_b; -} - -template <> -__device__ inline FragB dequant(int q) { - static constexpr uint32_t mask_for_elt_01 = 0x5250; - static constexpr uint32_t mask_for_elt_23 = 0x5351; - static constexpr uint32_t start_byte_for_fp16 = 0x64646464; - - uint32_t lo = prmt(q); - uint32_t hi = prmt(q); - - static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400; - - FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - frag_b[1] = __hsub2(*reinterpret_cast(&hi), - *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - return frag_b; -} - -// Multiply dequantized values by the corresponding quantization scale; used -// only for grouped quantization. -__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) { - half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]); - frag_b[0] = __hmul2(frag_b[0], s); - frag_b[1] = __hmul2(frag_b[1], s); -} - -__device__ inline void sub_zp(FragB& frag_b, half2& frag_zp, int i) { - half2 zp = __half2half2(reinterpret_cast<__half*>(&frag_zp)[i]); - frag_b[0] = __hsub2(frag_b[0], zp); - frag_b[1] = __hsub2(frag_b[1], zp); -} - -// Same as above, but for act_order (each K is multiplied individually) -__device__ inline void scale4(FragB& frag_b, FragS& frag_s_1, FragS& frag_s_2, - FragS& frag_s_3, FragS& frag_s_4, int i) { - __half2 s_val_1_2; - s_val_1_2.x = reinterpret_cast<__half*>(&frag_s_1)[i]; - s_val_1_2.y = reinterpret_cast<__half*>(&frag_s_2)[i]; - - __half2 s_val_3_4; - s_val_3_4.x = reinterpret_cast<__half*>(&frag_s_3)[i]; - s_val_3_4.y = reinterpret_cast<__half*>(&frag_s_4)[i]; - - frag_b[0] = __hmul2(frag_b[0], s_val_1_2); - frag_b[1] = __hmul2(frag_b[1], s_val_3_4); -} - -// Given 2 floats multiply by 2 scales (halves) -__device__ inline void scale_float(float* c, FragS& s) { - __half* s_ptr = reinterpret_cast<__half*>(&s); - c[0] = __fmul_rn(c[0], __half2float(s_ptr[0])); - c[1] = __fmul_rn(c[1], __half2float(s_ptr[1])); -} - -// Wait until barrier reaches `count`, then lock for current threadblock. -__device__ inline void barrier_acquire(int* lock, int count) { - if (threadIdx.x == 0) { - int state = -1; - do - // Guarantee that subsequent writes by this threadblock will be visible - // globally. - asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" - : "=r"(state) - : "l"(lock)); - while (state != count); - } - __syncthreads(); -} - -// Release barrier and increment visitation count. -__device__ inline void barrier_release(int* lock, bool reset = false) { - __syncthreads(); - if (threadIdx.x == 0) { - if (reset) { - lock[0] = 0; - return; - } - int val = 1; - // Make sure that all writes since acquiring this barrier are visible - // globally, while releasing the barrier. - asm volatile("fence.acq_rel.gpu;\n"); - asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" - : - : "l"(lock), "r"(val)); - } -} - -template shared - // fetch pipeline - const bool has_act_order, // whether act_order is enabled - const bool has_zp, // whether zero-points are enabled - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale - > -__device__ void MarlinMoESingle( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // fp16 output buffer of shape mxn - const int* __restrict__ sorted_ids, // int32 sorted ids of experts - const float* __restrict__ topk_weights, // float topk weights - const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape - // (k/groupsize)xn - const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape - // (k/groupsize)x(n/pack_factor) - const int* __restrict__ g_idx, // int32 group indices of shape k - const int* __restrict__ expert_offsets, - int num_groups, // number of scale groups per output channel - int expert_idx, // idx of current expert - int num_experts, // number of experts - int topk, // topk parameter of moe - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int tot_m, // total number of rows in A and C - int* locks, // extra global storage for barrier synchronization - bool replicate_input, // do we use the same input for each expert? - bool apply_weights, // apply weights to output - int current_m_block // current m block to start kernel computation from -) { - static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id); - constexpr int pack_factor = 32 / w_type.size_bits(); - - // For larger GEMMs we run multiple batchsize 64 versions in parallel for a - // better partitioning with less reductions - int parallel = 1; - if (prob_m > 16 * thread_m_blocks) { - parallel = prob_m / (16 * thread_m_blocks); - prob_m = 16 * thread_m_blocks; - } - - int k_tiles = prob_k / 16 / thread_k_blocks; - int n_tiles = prob_n / 16 / thread_n_blocks; - int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x); - - if constexpr (!has_act_order && group_blocks != -1) { - if (group_blocks >= thread_k_blocks) { - // Ensure that the number of tiles in each stripe is a multiple of the - // groupsize; this avoids an annoying special case where a stripe starts - // in the middle of group. - iters = (group_blocks / thread_k_blocks) * - ceildiv(iters, (group_blocks / thread_k_blocks)); - } - } - - int slice_row = (iters * blockIdx.x) % k_tiles; - int slice_col_par = (iters * blockIdx.x) / k_tiles; - int slice_col = slice_col_par; - int slice_iters; // number of threadblock tiles in the current slice - int slice_count = - 0; // total number of active threadblocks in the current slice - int slice_idx; // index of threadblock in current slice; numbered bottom to - // top - - // We can easily implement parallel problem execution by just remapping - // indices and advancing global pointers - if (slice_col_par >= n_tiles) { - locks += (slice_col_par / n_tiles) * n_tiles; - slice_col = slice_col_par % n_tiles; - sorted_ids += (slice_col_par / n_tiles) * 16 * thread_m_blocks; - } - - // Compute all information about the current slice which is required for - // synchronization. - auto init_slice = [&]() { - slice_iters = - iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); - if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; - if (slice_iters == 0) return; - if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; - slice_count = 1; - slice_idx = 0; - int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); - if (col_first <= k_tiles * (slice_col_par + 1)) { - int col_off = col_first - k_tiles * slice_col_par; - slice_count = ceildiv(k_tiles - col_off, iters); - if (col_off > 0) slice_count++; - int delta_first = iters * blockIdx.x - col_first; - if (delta_first < 0 || (col_off == 0 && delta_first == 0)) - slice_idx = slice_count - 1; - else { - slice_idx = slice_count - 1 - delta_first / iters; - if (col_off > 0) slice_idx--; - } - } - if (slice_col == n_tiles) { - sorted_ids += 16 * thread_m_blocks; - locks += n_tiles; - slice_col = 0; - } - }; - init_slice(); - - // A sizes/strides - - // stride of the A matrix in global memory - int a_gl_stride = prob_k / 8; - // stride of an A matrix tile in shared memory - constexpr int a_sh_stride = 16 * thread_k_blocks / 8; - // delta between subsequent A tiles in global memory - constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; - // between subsequent accesses within a tile - int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); - // between shared memory writes - constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); - // between shared memory tile reads - constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); - // within a shared memory tile - constexpr int a_sh_rd_delta_i = a_sh_stride * 16; - // overall size of a tile - constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); - // number of shared write iterations for a tile - constexpr int a_sh_wr_iters = ceildiv(a_sh_stage, a_sh_wr_delta); - - // B sizes/strides - int b_gl_stride = 16 * prob_n / (pack_factor * 4); - constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4; - constexpr int b_thread_vecs = w_type.size_bits() == 4 ? 1 : 2; - constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs; - - int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; - int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads); - constexpr int b_sh_wr_delta = threads * b_thread_vecs; - constexpr int b_sh_rd_delta = threads * b_thread_vecs; - constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; - constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; - - // Scale sizes/strides without act_order - int s_gl_stride = prob_n / 8; - constexpr int s_sh_stride = 16 * thread_n_blocks / 8; - constexpr int s_tb_groups = - !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks - ? thread_k_blocks / group_blocks - : 1; - constexpr int s_sh_stage = s_tb_groups * s_sh_stride; - int s_gl_rd_delta = s_gl_stride; - // Scale size/strides with act_order - constexpr int tb_k = 16 * thread_k_blocks; - constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0; - // constexpr int act_s_row_stride = 1; - // int act_s_col_stride = act_s_row_stride * num_groups; - int act_s_col_stride = 1; - int act_s_col_warp_stride = act_s_col_stride * 8; - int tb_n_warps = thread_n_blocks / 4; - int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; - - // Zero-points sizes/strides - int zp_gl_stride = (prob_n / pack_factor) / 4; - constexpr int zp_sh_stride = ((16 * thread_n_blocks) / pack_factor) / 4; - constexpr int zp_tb_groups = s_tb_groups; - constexpr int zp_sh_stage = has_zp ? zp_tb_groups * zp_sh_stride : 0; - int zp_gl_rd_delta = zp_gl_stride; - - // Global A read index of current thread. - int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - a_gl_rd += a_gl_rd_delta_o * slice_row; - // Shared write index of current thread. - int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - // Shared read index. - int a_sh_rd = - a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; - a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); - - int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + - (threadIdx.x % b_sh_stride_threads) * b_thread_vecs; - b_gl_rd += b_sh_stride * slice_col; - b_gl_rd += b_gl_rd_delta_o * slice_row; - int b_sh_wr = threadIdx.x * b_thread_vecs; - int b_sh_rd = threadIdx.x * b_thread_vecs; - - // For act_order - constexpr int k_iter_size = tb_k / b_sh_wr_iters; - int slice_k_start = tb_k * slice_row; - int slice_k_finish = slice_k_start + tb_k * slice_iters; - int slice_k_start_shared_fetch = slice_k_start; - int slice_n_offset = act_s_col_tb_stride * slice_col; - - // No act_order - int s_gl_rd; - if constexpr (!has_act_order) { - if constexpr (group_blocks == -1) { - s_gl_rd = s_sh_stride * slice_col + threadIdx.x; - } else { - s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + - s_sh_stride * slice_col + threadIdx.x; - } - } - int s_sh_wr = threadIdx.x; - bool s_sh_wr_pred = threadIdx.x < s_sh_stride; - - // Zero-points - int zp_gl_rd; - if constexpr (has_zp) { - if constexpr (group_blocks == -1) { - zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; - } else { - zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + - zp_sh_stride * slice_col + threadIdx.x; - } - } - int zp_sh_wr = threadIdx.x; - bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride; - - // We use a different scale layout for grouped and column-wise quantization as - // we scale a `half2` tile in column-major layout in the former and in - // row-major in the latter case. - int s_sh_rd; - if constexpr (group_blocks != -1) - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) / 4; - else - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) % 4; - - // Zero-points have the same read layout as the scales - // (without column-wise case) - constexpr int num_col_threads = 8; - constexpr int num_row_threads = 4; - constexpr int num_ints_per_thread = 8 / pack_factor; - int zp_sh_rd; - if constexpr (has_zp) { - zp_sh_rd = num_ints_per_thread * num_col_threads * - ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads); - } - - int sh_first_group_id = -1; - int sh_num_groups = -1; - constexpr int sh_max_num_groups = 32; - - extern __shared__ int4 sh[]; - // Shared memory storage for global fetch pipelines. - int4* sh_a = sh; - int4* sh_b = sh_a + (stages * a_sh_stage); - int4* sh_g_idx = sh_b + (stages * b_sh_stage); - int4* sh_zp = sh_g_idx + (stages * g_idx_stage); - int4* sh_s = sh_zp + (stages * zp_sh_stage); - - // Precompute which thread should not read memory in which iterations; this is - // needed if there are more threads than required for a certain tilesize or - // when the batchsize is not a multiple of 16. - bool a_sh_wr_pred[a_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) { - int a_idx = a_sh_wr_delta * i + a_sh_wr; - int row = a_idx / a_gl_rd_delta_o; - if (row >= prob_m) { - a_sh_wr_pred[i] = false; - } else { - a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; - } - } - - // To ensure that writing and reading A tiles to/from shared memory, the - // latter in fragment format, is fully bank conflict free, we need to use a - // rather fancy XOR-based layout. The key here is that neither reads nor - // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the - // same shared memory banks. Further, it seems (based on NSight-Compute) that - // each warp must also write a consecutive memory segment? - auto transform_a = [&](int i) { - int row = i / a_gl_rd_delta_o; - return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; - }; - // Since the computation of this remapping is non-trivial and, due to our main - // loop unrolls, all shared memory accesses are static, we simply precompute - // both transformed reads and writes. - int a_sh_wr_trans[a_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) - a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); - int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - #pragma unroll - for (int j = 0; j < thread_m_blocks; j++) - a_sh_rd_trans[i][j] = - transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); - } - - // Since B-accesses have non-constant stride they have to be computed at - // runtime; we break dependencies between subsequent accesses with a tile by - // maintining multiple pointers (we have enough registers), a tiny - // optimization. - const int4* B_ptr[b_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; - - // Register storage for double buffer of shared memory reads. - FragA frag_a[2][thread_m_blocks]; - I4 frag_b_quant[2][b_thread_vecs]; - FragC frag_c[thread_m_blocks][4][2]; - FragS frag_s[2][4]; // No act-order - FragS act_frag_s[2][4][4]; // For act-order - int frag_qzp[2][num_ints_per_thread]; // Zero-points - FragZP frag_zp; // Zero-points in fp16 - - // Zero accumulators. - auto zero_accums = [&]() { - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) - reinterpret_cast(frag_c)[i] = 0; - }; - - auto fetch_scales_to_shared = [&](bool is_async, int first_group_id, - int last_group_id) { - sh_first_group_id = first_group_id; - sh_num_groups = last_group_id - first_group_id + 1; - - if (sh_num_groups < sh_max_num_groups) { - sh_num_groups = sh_max_num_groups; - } - - if (sh_first_group_id + sh_num_groups > num_groups) { - sh_num_groups = num_groups - sh_first_group_id; - } - - int row_offset = first_group_id * s_gl_stride; - - if (is_async) { - for (int i = 0; i < sh_num_groups; i++) { - if (threadIdx.x < s_sh_stride) { - cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x], - &scales_ptr[row_offset + (i * s_gl_stride) + - slice_n_offset + threadIdx.x]); - } - } - } else { - for (int i = 0; i < sh_num_groups; i++) { - if (threadIdx.x < s_sh_stride) { - sh_s[(i * s_sh_stride) + threadIdx.x] = - scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + - threadIdx.x]; - } - } - } - }; - // Asynchronously fetch the next A, B and s tile from global to the next - // shared memory pipeline location. - auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { - if (pred) { - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) { - int a_idx = a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off; - int row = a_idx / a_gl_stride; - int sorted_row = - replicate_input ? sorted_ids[row] / topk : sorted_ids[row]; - int new_idx = sorted_row * a_gl_stride + a_idx % a_gl_stride; - if (sorted_row < tot_m * (replicate_input ? 1 : topk) && - new_idx < a_gl_stride * tot_m * (replicate_input ? 1 : topk)) { - cp_async4_pred(&sh_a_stage[a_sh_wr_trans[i]], &A[new_idx], - a_sh_wr_pred[i]); - } - } - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - #pragma unroll - for (int j = 0; j < b_thread_vecs; j++) { - cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j); - } - B_ptr[i] += b_gl_rd_delta_o; - } - - if constexpr (has_act_order) { - // Fetch g_idx thread-block portion - int full_pipe = a_off; - int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe; - if (cur_k < prob_k && cur_k < slice_k_finish) { - int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; - - int4 const* cur_g_idx_stage_ptr = - reinterpret_cast(&g_idx[cur_k]); - - if (threadIdx.x < g_idx_stage) { - cp_async4_pred(&sh_g_idx_stage[threadIdx.x], - &cur_g_idx_stage_ptr[threadIdx.x]); - } - } - } else { - if constexpr (group_blocks != -1) { - int4* sh_s_stage = sh_s + s_sh_stage * pipe; - - if constexpr (group_blocks >= thread_k_blocks) { - // Only fetch scales if this tile starts a new group - if (pipe % (group_blocks / thread_k_blocks) == 0) { - if (s_sh_wr_pred) { - cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); - } - s_gl_rd += s_gl_rd_delta; - } - } else { - for (int i = 0; i < s_tb_groups; i++) { - if (s_sh_wr_pred) { - cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr], - &scales_ptr[s_gl_rd]); - } - s_gl_rd += s_gl_rd_delta; - } - } - } - - if constexpr (has_zp && group_blocks != -1) { - int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; - - if constexpr (group_blocks >= thread_k_blocks) { - // Only fetch zero-points if this tile starts a new group - if (pipe % (group_blocks / thread_k_blocks) == 0) { - if (zp_sh_wr_pred) { - cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]); - } - zp_gl_rd += zp_gl_rd_delta; - } - } else { - for (int i = 0; i < zp_tb_groups; i++) { - if (zp_sh_wr_pred) { - cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr], - &zp_ptr[zp_gl_rd]); - } - zp_gl_rd += zp_gl_rd_delta; - } - } - } - } - } - // Insert a fence even when we are winding down the pipeline to ensure that - // waiting is also correct at this point. - cp_async_fence(); - }; - - auto fetch_zp_to_shared = [&]() { - if (zp_sh_wr_pred) { - cp_async4(&sh_zp[zp_sh_wr], &zp_ptr[zp_gl_rd]); - } - }; - - // Wait until the next thread tile has been loaded to shared memory. - auto wait_for_stage = [&]() { - // We only have `stages - 2` active fetches since we are double buffering - // and can only issue the next fetch when it is guaranteed that the previous - // shared memory load is fully complete (as it may otherwise be - // overwritten). - cp_async_wait(); - __syncthreads(); - }; - - // Load the next sub-tile from the current location in the shared memory pipe - // into the current register buffer. - auto fetch_to_registers = [&](int k, int pipe) { - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) - ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - - #pragma unroll - for (int i = 0; i < b_thread_vecs; i++) { - frag_b_quant[k % 2][i] = *reinterpret_cast( - &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); - } - }; - - bool is_same_group[stages]; - int same_group_id[stages]; - - auto init_same_group = [&](int pipe) { - if constexpr (!has_act_order) { - is_same_group[pipe] = false; - same_group_id[pipe] = 0; - return; - } - - int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; - int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); - - int group_id_1 = sh_g_idx_int_ptr[0]; - int group_id_2 = sh_g_idx_int_ptr[tb_k - 1]; - - is_same_group[pipe] = group_id_1 == group_id_2; - same_group_id[pipe] = group_id_1; - }; - - auto fetch_scales_to_registers = [&](int k, int full_pipe) { - int pipe = full_pipe % stages; - - if constexpr (!has_act_order) { - // No act-order case - if constexpr (group_blocks != -1) { - if constexpr (group_blocks >= thread_k_blocks) { - int4* sh_s_stage = - sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * - (pipe / (group_blocks / thread_k_blocks))); - reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; - } else { - int warp_id = threadIdx.x / 32; - int n_warps = thread_n_blocks / 4; - - int warp_row = warp_id / n_warps; - - int cur_k = warp_row * 16; - cur_k += k_iter_size * (k % b_sh_wr_iters); - - int k_blocks = cur_k / 16; - int cur_group_id = k_blocks / group_blocks; - - int4* sh_s_stage = sh_s + s_sh_stage * pipe; - - reinterpret_cast(&frag_s[k % 2])[0] = - sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; - } - } - - return; - } - - // Act-order case - - // Determine K of the "current" thread-block - int cur_k = slice_k_start + tb_k * full_pipe; - if (cur_k >= prob_k || cur_k >= slice_k_finish) { - return; - } - - // Reset (to current thread-block) since we read g_idx portion from the - // shared memory - cur_k = 0; - - // Progress to current iteration - cur_k += k_iter_size * (k % b_sh_wr_iters); - - // Determine "position" inside the thread-block (based on warp and - // thread-id) - int warp_id = threadIdx.x / 32; - int n_warps = - thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N - - int warp_row = warp_id / n_warps; - int warp_col = warp_id % n_warps; - - cur_k += warp_row * 16; - - int th_id = threadIdx.x % 32; - cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix - - int s_col_shift = - /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) + - (th_id / 4) * act_s_col_stride; - - if (is_same_group[pipe]) { - if (k % 2 == 0) { - *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = - sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride + - s_col_shift]; - } else { - *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = - *(reinterpret_cast(&(act_frag_s[(k - 1) % 2][0][0]))); - } - - for (int i = 1; i < 4; i++) { - *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = - *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))); - } - return; - } - - int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; - int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); - - constexpr int k_frag_offsets[4] = {0, 1, 8, - 9}; // Tensor core offsets per thread - - #pragma unroll - for (int i = 0; i < 4; i++) { - int actual_k = cur_k + k_frag_offsets[i]; - - int group_id = sh_g_idx_int_ptr[actual_k]; - int rel_group_id = group_id - sh_first_group_id; - - *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = - sh_s[rel_group_id * s_sh_stride + s_col_shift]; - } - }; - - auto fetch_zp_to_registers = [&](int k, int full_pipe) { - // This code does not handle group_blocks == 0, - // which signifies act_order. - // has_zp implies AWQ, which doesn't have act_order, - static_assert(!has_zp || group_blocks != 0); - - if constexpr (has_zp) { - int pipe = full_pipe % stages; - - if constexpr (group_blocks == -1) { - for (int i = 0; i < num_ints_per_thread; i++) { - frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp))[zp_sh_rd + i]; - } - - } else if constexpr (group_blocks >= thread_k_blocks) { - int4* sh_zp_stage = - sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * - (pipe / (group_blocks / thread_k_blocks))); - for (int i = 0; i < num_ints_per_thread; i++) { - frag_qzp[k % 2][i] = - (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; - } - } else { - int warp_id = threadIdx.x / 32; - int n_warps = thread_n_blocks / 4; - - int warp_row = warp_id / n_warps; - - int cur_k = warp_row * 16; - cur_k += k_iter_size * (k % b_sh_wr_iters); - - int k_blocks = cur_k / 16; - int cur_group_id = 0; - - // Suppress bogus and persistent divide-by-zero warning - #pragma nv_diagnostic push - #pragma nv_diag_suppress divide_by_zero - cur_group_id = k_blocks / group_blocks; - #pragma nv_diagnostic pop - - int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; - - sh_zp_stage += cur_group_id * zp_sh_stride; - - for (int i = 0; i < num_ints_per_thread; i++) { - frag_qzp[k % 2][i] = - (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; - } - } - } - }; - - // Execute the actual tensor core matmul of a sub-tile. - auto matmul = [&](int k) { - if constexpr (has_zp) { - FragB frag_zp_0; - FragB frag_zp_1; - int zp_quant_0, zp_quant_1; - - if constexpr (w_type.size_bits() == 4) { - zp_quant_0 = frag_qzp[k % 2][0]; - zp_quant_1 = zp_quant_0 >> 8; - } else { - static_assert(w_type.size_bits() == 8); - zp_quant_0 = frag_qzp[k % 2][0]; - zp_quant_1 = frag_qzp[k % 2][1]; - } - - frag_zp_0 = dequant(zp_quant_0); - frag_zp_1 = dequant(zp_quant_1); - - frag_zp[0] = frag_zp_0[0]; - frag_zp[1] = frag_zp_0[1]; - frag_zp[2] = frag_zp_1[0]; - frag_zp[3] = frag_zp_1[1]; - } - - // We have the m dimension as the inner loop in order to encourage overlapping - // dequantization and matmul operations. - #pragma unroll - for (int j = 0; j < 4; j++) { - int b_quant_0, b_quant_1; - if constexpr (w_type.size_bits() == 4) { - b_quant_0 = frag_b_quant[k % 2][0][j]; - b_quant_1 = b_quant_0 >> 8; - } else { - static_assert(w_type.size_bits() == 8); - int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); - b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; - b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; - } - - FragB frag_b0 = dequant(b_quant_0); - FragB frag_b1 = dequant(b_quant_1); - // Apply zero-point to frag_b0 - if constexpr (has_zp) { - sub_zp(frag_b0, frag_zp[j], 0); - } - - // Apply scale to frag_b0 - if constexpr (has_act_order) { - scale4(frag_b0, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], - act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 0); - } else { - if constexpr (group_blocks != -1) { - scale(frag_b0, frag_s[k % 2][j], 0); - } - } - - // Apply zero-point to frag_b1 - if constexpr (has_zp) { - sub_zp(frag_b1, frag_zp[j], 1); - } - - // Apply scale to frag_b1 - if constexpr (has_act_order) { - scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], - act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 1); - - } else { - if constexpr (group_blocks != -1) { - scale(frag_b1, frag_s[k % 2][j], 1); - } - } - - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); - mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); - } - } - }; - - // Since we slice across the k dimension of a tile in order to increase the - // number of warps while keeping the n dimension of a tile reasonable, we have - // multiple warps that accumulate their partial sums of the same output - // location; which we have to reduce over in the end. We do in shared memory. - auto thread_block_reduce = [&]() { - constexpr int red_off = threads / b_sh_stride_threads / 2; - if (red_off >= 1) { - int red_idx = threadIdx.x / b_sh_stride_threads; - constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; - constexpr int red_sh_delta = b_sh_stride_threads; - int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + - (threadIdx.x % b_sh_stride_threads); - - // Parallel logarithmic shared memory reduction. We make sure to avoid any - // unnecessary read or write iterations, e.g., for two warps we write only - // once by warp 1 and read only once by warp 0. - - #pragma unroll - for (int m_block = 0; m_block < thread_m_blocks; m_block++) { - #pragma unroll - for (int i = red_off; i > 0; i /= 2) { - if (i <= red_idx && red_idx < 2 * i) { - #pragma unroll - for (int j = 0; j < 4 * 2; j++) { - int red_sh_wr = - red_sh_delta * j + (red_sh_rd - red_sh_stride * i); - if (i < red_off) { - float* c_rd = - reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); - float* c_wr = reinterpret_cast(&sh[red_sh_wr]); - #pragma unroll - for (int k = 0; k < 4; k++) - reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += - c_rd[k] + c_wr[k]; - } - sh[red_sh_wr] = - reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; - } - } - __syncthreads(); - } - if (red_idx == 0) { - #pragma unroll - for (int i = 0; i < 4 * 2; i++) { - float* c_rd = - reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); - #pragma unroll - for (int j = 0; j < 4; j++) - reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += - c_rd[j]; - } - } - __syncthreads(); - } - } - }; - - // Since multiple threadblocks may process parts of the same column slice, we - // finally have to globally reduce over the results. As the striped - // partitioning minimizes the number of such reductions and our outputs are - // usually rather small, we perform this reduction serially in L2 cache. - auto global_reduce = [&](bool first = false, bool last = false) { - // We are very careful here to reduce directly in the output buffer to - // maximize L2 cache utilization in this step. To do this, we write out - // results in FP16 (but still reduce with FP32 compute). - constexpr int active_threads = 32 * thread_n_blocks / 4; - if (threadIdx.x < active_threads) { - int c_gl_stride = prob_n / 8; - int c_gl_wr_delta_o = 8 * c_gl_stride; - int c_gl_wr_delta_i = 4 * (active_threads / 32); - int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + - 4 * (threadIdx.x / 32) + threadIdx.x % 4; - c_gl_wr += (2 * thread_n_blocks) * slice_col; - constexpr int c_sh_wr_delta = active_threads; - int c_sh_wr = threadIdx.x; - - int row = (threadIdx.x % 32) / 4; - - if (!first) { - // Interestingly, doing direct global accesses here really seems to mess up - // the compiler and lead to slowdowns, hence we also use async-copies even - // though these fetches are not actually asynchronous. - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4; i++) { - int c_idx = - c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2); - int sorted_row = sorted_ids[c_idx / c_gl_stride]; - int new_idx = sorted_row * c_gl_stride + c_idx % c_gl_stride; - cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i], &C[new_idx], - sorted_row < tot_m * topk && - (8 * (i / 2) + row < prob_m && - (i < (thread_m_blocks - 1) * 4 || - sorted_ids[8 * (i / 2) + row] < tot_m * topk))); - } - cp_async_fence(); - cp_async_wait<0>(); - } - - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4; i++) { - if (8 * (i / 2) + row < prob_m && - (i < (thread_m_blocks - 1) * 4 || - sorted_ids[8 * (i / 2) + row] < tot_m * topk)) { - if (!first) { - int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; - #pragma unroll - for (int j = 0; j < 2 * 4; j++) { - reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += - __half2float(reinterpret_cast<__half*>(&c_red)[j]); - } - } - if (!last) { - int4 c; - #pragma unroll - for (int j = 0; j < 2 * 4; j++) { - reinterpret_cast<__half*>(&c)[j] = - __float2half(reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]); - } - int c_idx = - c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2); - int row = sorted_ids[c_idx / c_gl_stride]; - if (row < tot_m * topk) { - int new_idx = row * c_gl_stride + c_idx % c_gl_stride; - C[new_idx] = c; - } - } - } - } - } - }; - - // Write out the reduce final result in the correct layout. We only actually - // reshuffle matrix fragments in this step, the reduction above is performed - // in fragment layout. - auto write_result = [&]() { - int c_gl_stride = prob_n / 8; - constexpr int c_sh_stride = 2 * thread_n_blocks + 1; - int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); - constexpr int c_sh_rd_delta = - c_sh_stride * (threads / (2 * thread_n_blocks)); - - int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + - (threadIdx.x % (2 * thread_n_blocks)); - c_gl_wr += (2 * thread_n_blocks) * slice_col; - int c_sh_wr = - (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; - c_sh_wr += 32 * (threadIdx.x / 32); - int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + - (threadIdx.x % (2 * thread_n_blocks)); - - int c_gl_wr_end = c_gl_stride * prob_m; - - // We first reorder in shared memory to guarantee the most efficient final - // global write patterns - auto write = [&](int idx, float c0, float c1, FragS& s) { - half2 res = __halves2half2(__float2half(c0), __float2half(c1)); - - // For per-column quantization we finally apply the scale here (only for - // 4-bit) - if constexpr (!has_act_order && group_blocks == -1 && - w_type.size_bits() == 4) { - res = __hmul2(res, s[0]); - } - - ((half2*)sh)[idx] = res; - }; - if (threadIdx.x / 32 < thread_n_blocks / 4) { - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - #pragma unroll - for (int j = 0; j < 4; j++) { - int wr = c_sh_wr + 8 * j; - write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], - frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); - write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], - frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); - write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], - frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); - write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], - frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); - } - c_sh_wr += 16 * (4 * c_sh_stride); - } - } - __syncthreads(); - - #pragma unroll - for (int i = 0; - i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); - i++) { - if (c_gl_wr < c_gl_wr_end) { - int row = sorted_ids[c_gl_wr / c_gl_stride]; - if (row < tot_m * topk) { - int off = row * c_gl_stride + c_gl_wr % c_gl_stride; - if (!apply_weights) { - C[off] = sh[c_sh_rd]; - } else { - __half* ctrg = reinterpret_cast<__half*>(&C[off]); - __half* csrc = reinterpret_cast<__half*>(&sh[c_sh_rd]); - for (int j = 0; j < 8; ++j) { - ctrg[j] = __float2half(topk_weights[row] * __half2float(csrc[j])); - } - } - c_gl_wr += c_gl_wr_delta; - c_sh_rd += c_sh_rd_delta; - } - } - } - }; - - // Start global fetch and register load pipelines. - auto start_pipes = [&]() { - - #pragma unroll - for (int i = 0; i < stages - 1; i++) { - if (has_act_order && i == 0) { - int last_g_idx = slice_k_start + stages * tb_k * 2; - if (last_g_idx >= prob_k) { - last_g_idx = prob_k - 1; - } - fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]); - } - - if constexpr (has_zp && group_blocks == -1) { - if (i == 0) { - fetch_zp_to_shared(); - } - } - fetch_to_shared(i, i, i < slice_iters); - } - - zero_accums(); - wait_for_stage(); - init_same_group(0); - fetch_to_registers(0, 0); - fetch_scales_to_registers(0, 0); - fetch_zp_to_registers(0, 0); - a_gl_rd += a_gl_rd_delta_o * (stages - 1); - slice_k_start_shared_fetch += tb_k * (stages - 1); - }; - if (slice_iters) { - start_pipes(); - } - - // Main loop. - while (slice_iters) { - // We unroll over both the global fetch and the register load pipeline to - // ensure all shared memory accesses are static. Note that both pipelines - // have even length meaning that the next iteration will always start at - // index 0. - #pragma unroll - for (int pipe = 0; pipe < stages;) { - #pragma unroll - for (int k = 0; k < b_sh_wr_iters; k++) { - fetch_to_registers(k + 1, pipe % stages); - fetch_scales_to_registers(k + 1, pipe); - fetch_zp_to_registers(k + 1, pipe); - if (k == b_sh_wr_iters - 2) { - fetch_to_shared((pipe + stages - 1) % stages, pipe, - slice_iters >= stages); - pipe++; - wait_for_stage(); - init_same_group(pipe % stages); - } - matmul(k); - } - slice_iters--; - if (slice_iters == 0) { - break; - } - } - - a_gl_rd += a_gl_rd_delta_o * stages; - slice_k_start += tb_k * stages; - slice_k_start_shared_fetch += tb_k * stages; - - if constexpr (has_act_order) { - int first_group_id = g_idx[slice_k_start]; - int last_g_idx = slice_k_start + stages * tb_k * 2; - if (last_g_idx >= prob_k) { - last_g_idx = prob_k - 1; - } - int last_group_id = g_idx[last_g_idx]; - if (last_group_id >= sh_first_group_id + sh_num_groups) { - fetch_scales_to_shared(false, first_group_id, last_group_id); - __syncthreads(); - } - } - - // Process results and, if necessary, proceed to the next column slice. - // While this pattern may not be the most readable, other ways of writing - // the loop seemed to noticeably worse performance after compilation. - if (slice_iters == 0) { - cp_async_wait<0>(); - bool last = slice_idx == slice_count - 1; - if constexpr (!has_act_order && group_blocks == -1) { - if constexpr (w_type.size_bits() == 8) { - if (s_sh_wr_pred) { - cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); - } - cp_async_fence(); - } else { - // For 4-bit per-column scales, we only fetch them here in the - // final step before write-out - if (last) { - if (s_sh_wr_pred) { - cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); - } - cp_async_fence(); - } - } - } - - thread_block_reduce(); - if constexpr (!has_act_order && group_blocks == -1) { - if constexpr (w_type.size_bits() == 8) { - cp_async_wait<0>(); - __syncthreads(); - if (threadIdx.x / 32 < thread_n_blocks / 4) { - reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; - reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; - } - - } else { - if (last) { - cp_async_wait<0>(); - __syncthreads(); - if (threadIdx.x / 32 < thread_n_blocks / 4) { - reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; - reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; - } - } - } - } - - // For 8-bit channelwise, we apply the scale before the global reduction - // that converts the fp32 results to fp16 (so that we avoid possible - // overflow in fp16) - if constexpr (!has_act_order && group_blocks == -1 && - w_type.size_bits() == 8) { - if (threadIdx.x / 32 < thread_n_blocks / 4) { - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - #pragma unroll - for (int j = 0; j < 4; j++) { - scale_float(reinterpret_cast(&frag_c[i][j][0][0]), - frag_s[j / 2][2 * (j % 2) + 0]); - scale_float(reinterpret_cast(&frag_c[i][j][0][2]), - frag_s[j / 2][2 * (j % 2) + 0]); - - scale_float(reinterpret_cast(&frag_c[i][j][1][0]), - frag_s[j / 2][2 * (j % 2) + 1]); - scale_float(reinterpret_cast(&frag_c[i][j][1][2]), - frag_s[j / 2][2 * (j % 2) + 1]); - } - } - } - } - - if (slice_count > 1) { // only globally reduce if there is more than one - // block in a slice - barrier_acquire(&locks[slice_col], slice_idx); - global_reduce(slice_idx == 0, last); - barrier_release(&locks[slice_col], last); - } - if (last) // only the last block in a slice actually writes the result - write_result(); - slice_row = 0; - slice_col_par++; - slice_col++; - init_slice(); - if (slice_iters) { - a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; - if (slice_col == 0) { - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; - } - - // Update slice k/n for scales loading - if constexpr (has_act_order) { - slice_k_start = tb_k * slice_row; - slice_k_finish = slice_k_start + tb_k * slice_iters; - slice_k_start_shared_fetch = slice_k_start; - slice_n_offset = act_s_col_tb_stride * slice_col; - - } else { - s_gl_rd = s_sh_stride * slice_col + threadIdx.x; - zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; - } - - start_pipes(); - } - } - } -} - -template shared - // fetch pipeline - const bool has_act_order, // whether act_order is enabled - const bool has_zp, // whether zero-points are enabled - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale - > -__global__ void MarlinMoE( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // fp16 output buffer of shape mxn - const int* __restrict__ sorted_ids_base, // int32 sorted ids of experts - const float* __restrict__ topk_weights, // float topk weights - const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape - // (k/groupsize)xn - const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape - // (k/groupsize)x(n/pack_factor) - const int* __restrict__ g_idx, // int32 group indices of shape k - const int* __restrict__ expert_offsets, - int num_groups, // number of scale groups per output channel - int expert_idx, // idx of current expert - int num_experts, // number of experts - int topk, // topk parameter of moe - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int tot_m, // total number of rows in A and C - int* locks, // extra global storage for barrier synchronization - bool replicate_input, // do we use the same input for each expert? - bool apply_weights, // apply weights to output - int current_m_block, // current m block to start kernel computation from - int max_par, // maximum parallelism - int cfg_max_m_blocks // upper bound on m blocks -) { - int m_block_ctr = current_m_block; - - const int* sorted_ids_expert = - sorted_ids_base + expert_offsets[expert_idx] + m_block_ctr * 4 * max_par; - int tot_its = expert_offsets[expert_idx + 1] - expert_offsets[expert_idx]; - if (tot_its == 0) { - return; - } - int tot_m_blocks = ceildiv(tot_its, 16); - int pad = 16 * tot_m_blocks - tot_its; - - if (m_block_ctr >= tot_m_blocks) { - return; - } - - int max_block = tot_m_blocks - m_block_ctr; - prob_m = tot_its - 16 * m_block_ctr; - - int par = 1; - if (max_block > cfg_max_m_blocks) { - // Note that parallel > 1 currently only works for inputs without any - // padding - par = (16 * max_block - pad) / (16 * cfg_max_m_blocks); - if (par > max_par) par = max_par; - prob_m = (16 * cfg_max_m_blocks) * par; - m_block_ctr += cfg_max_m_blocks * (par - 1); - max_block = cfg_max_m_blocks; - } - - if (max_block == 1) { - MarlinMoESingle( - A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx, - expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, - prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, - current_m_block); - } else if (max_block == 2) { - MarlinMoESingle( - A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx, - expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, - prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, - current_m_block); - } else if (max_block == 3) { - MarlinMoESingle( - A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx, - expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, - prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, - current_m_block); - } else { - MarlinMoESingle( - A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx, - expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, - prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, - current_m_block); - } -} - -#else - -template shared - // fetch pipeline - const bool has_act_order, // whether act_order is enabled - const bool has_zp, // whether zero-points are enabled - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale - > -__global__ void MarlinMoE( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // fp16 output buffer of shape mxn - const int* __restrict__ sorted_ids, // int32 sorted ids of experts - const float* __restrict__ topk_weights, // float topk weights - const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape - // (k/groupsize)xn - const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape - // (k/groupsize)x(n/pack_factor) - const int* __restrict__ g_idx, // int32 group indices of shape k - const int* __restrict__ expert_offsets, - int num_groups, // number of scale groups per output channel - int expert_idx, // idx of current expert - int num_experts, // number of experts - int topk, // topk parameter of moe - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int tot_m, // total number of rows in A and C - int* locks, // extra global storage for barrier synchronization - bool replicate_input, // do we use the same input for each expert? - bool apply_weights, // apply weights to output - int current_m_block, // current m block to start kernel computation from - int max_par, // maximum parallelism - int cfg_max_m_blocks // upper bound on m blocks -) { - // Marlin is not implemented yet for SM < 8.0 - assert(false); - return; -} - -#endif - -// 8 warps are a good choice since every SM has 4 schedulers and having more -// than 1 warp per schedule allows some more latency hiding. At the same time, -// we want relatively few warps to have many registers per warp and small tiles. -const int USER_THREADS = - 256; // Note: This is only used with user-provided thread_k/n -const int STAGES = 4; // 4 pipeline stages fit into shared memory - -static constexpr int min_thread_n = 64; -static constexpr int min_thread_k = 64; - -#define __CALL_IF_MOE(W_TYPE, THREAD_N_BLOCKS, THREAD_K_BLOCKS, HAS_ACT_ORDER, \ - HAS_ZP, GROUP_BLOCKS, NUM_THREADS) \ - else if (q_type == W_TYPE && thread_n_blocks == THREAD_N_BLOCKS && \ - thread_k_blocks == THREAD_K_BLOCKS && \ - has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \ - group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \ - cudaFuncSetAttribute( \ - MarlinMoE, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ - MarlinMoE \ - <<>>( \ - A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \ - zp_ptr, g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \ - num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \ - replicate_input, apply_weights, m_block, max_par, \ - cfg_max_m_blocks); \ - } - -#define GPTQ_CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) - -#define AWQ_CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) - -} // namespace marlin_moe diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.cu b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.cu deleted file mode 100644 index 77bc0dd90ed..00000000000 --- a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.cu +++ /dev/null @@ -1,31 +0,0 @@ -#include "marlin_moe_kernel_ku4.h" - -namespace marlin_moe { - -// We return bool so we can create these different kernel calls as a sequence -// of if-elseif's. -bool call_marlin_moe_kernel_ku4( - vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks, - bool has_act_order, int group_blocks, int num_threads, int blocks, - int max_shared_mem, cudaStream_t stream, const int4* A_ptr, - const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, - const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr, - const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups, - int expert_idx, int num_experts, int topk, int prob_m, int prob_n, - int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights, - int m_block, int max_par, int cfg_max_m_blocks) { - bool has_zp = true; - - if (false) { - } - AWQ_CALL_IF_MOE(vllm::kU4, 16, 4, 256) - AWQ_CALL_IF_MOE(vllm::kU4, 8, 8, 256) - AWQ_CALL_IF_MOE(vllm::kU4, 8, 4, 128) - AWQ_CALL_IF_MOE(vllm::kU4, 4, 8, 128) - else { - return false; - } - return true; -} - -} // namespace marlin_moe diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.h b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.h deleted file mode 100644 index 833fadf3772..00000000000 --- a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.h +++ /dev/null @@ -1,20 +0,0 @@ -#pragma once - -#include "marlin_moe_kernel.h" - -namespace marlin_moe { - -// We return bool so we can create these different kernel calls as a sequence -// of if-elseif's. -bool call_marlin_moe_kernel_ku4( - vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks, - bool has_act_order, int group_blocks, int num_threads, int blocks, - int max_shared_mem, cudaStream_t stream, const int4* A_ptr, - const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, - const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr, - const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups, - int expert_idx, int num_experts, int topk, int prob_m, int prob_n, - int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights, - int m_block, int max_par, int cfg_max_m_blocks); - -} // namespace marlin_moe diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu deleted file mode 100644 index f7e57b03759..00000000000 --- a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu +++ /dev/null @@ -1,31 +0,0 @@ -#include "marlin_moe_kernel_ku4b8.h" - -namespace marlin_moe { - -// We return bool so we can create these different kernel calls as a sequence -// of if-elseif's. -bool call_marlin_moe_kernel_ku4b8( - vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks, - bool has_act_order, int group_blocks, int num_threads, int blocks, - int max_shared_mem, cudaStream_t stream, const int4* A_ptr, - const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, - const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr, - const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups, - int expert_idx, int num_experts, int topk, int prob_m, int prob_n, - int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights, - int m_block, int max_par, int cfg_max_m_blocks) { - bool has_zp = false; - - if (false) { - } - GPTQ_CALL_IF_MOE(vllm::kU4B8, 16, 4, 256) - GPTQ_CALL_IF_MOE(vllm::kU4B8, 8, 8, 256) - GPTQ_CALL_IF_MOE(vllm::kU4B8, 8, 4, 128) - GPTQ_CALL_IF_MOE(vllm::kU4B8, 4, 8, 128) - else { - return false; - } - return true; -} - -} // namespace marlin_moe diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h deleted file mode 100644 index 494da8f10e2..00000000000 --- a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h +++ /dev/null @@ -1,20 +0,0 @@ -#pragma once - -#include "marlin_moe_kernel.h" - -namespace marlin_moe { - -// We return bool so we can create these different kernel calls as a sequence -// of if-elseif's. -bool call_marlin_moe_kernel_ku4b8( - vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks, - bool has_act_order, int group_blocks, int num_threads, int blocks, - int max_shared_mem, cudaStream_t stream, const int4* A_ptr, - const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, - const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr, - const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups, - int expert_idx, int num_experts, int topk, int prob_m, int prob_n, - int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights, - int m_block, int max_par, int cfg_max_m_blocks); - -} // namespace marlin_moe diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu deleted file mode 100644 index a901f0b11cd..00000000000 --- a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu +++ /dev/null @@ -1,31 +0,0 @@ -#include "marlin_moe_kernel_ku8b128.h" - -namespace marlin_moe { - -// We return bool so we can create these different kernel calls as a sequence -// of if-elseif's. -bool call_marlin_moe_kernel_ku8b128( - vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks, - bool has_act_order, int group_blocks, int num_threads, int blocks, - int max_shared_mem, cudaStream_t stream, const int4* A_ptr, - const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, - const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr, - const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups, - int expert_idx, int num_experts, int topk, int prob_m, int prob_n, - int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights, - int m_block, int max_par, int cfg_max_m_blocks) { - bool has_zp = false; - - if (false) { - } - GPTQ_CALL_IF_MOE(vllm::kU8B128, 16, 4, 256) - GPTQ_CALL_IF_MOE(vllm::kU8B128, 8, 8, 256) - GPTQ_CALL_IF_MOE(vllm::kU8B128, 8, 4, 128) - GPTQ_CALL_IF_MOE(vllm::kU8B128, 4, 8, 128) - else { - return false; - } - return true; -} - -} // namespace marlin_moe diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h deleted file mode 100644 index f3018aa0c1a..00000000000 --- a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h +++ /dev/null @@ -1,18 +0,0 @@ -#pragma once - -#include "marlin_moe_kernel.h" - -namespace marlin_moe { - -bool call_marlin_moe_kernel_ku8b128( - vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks, - bool has_act_order, int group_blocks, int num_threads, int blocks, - int max_shared_mem, cudaStream_t stream, const int4* A_ptr, - const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, - const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr, - const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups, - int expert_idx, int num_experts, int topk, int prob_m, int prob_n, - int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights, - int m_block, int max_par, int cfg_max_m_blocks); - -} diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu deleted file mode 100644 index 5f12483e951..00000000000 --- a/csrc/moe/marlin_moe_ops.cu +++ /dev/null @@ -1,588 +0,0 @@ -/* - * Modified by Neural Magic - * Copyright (C) Marlin.2024 Elias Frantar - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -#include -#include -#include -#include -#include - -#include - -#include "core/exception.hpp" -#include "core/scalar_type.hpp" -#include "core/registration.h" -#include "marlin_kernels/marlin_moe_kernel_ku4b8.h" -#include "marlin_kernels/marlin_moe_kernel_ku8b128.h" -#include "marlin_kernels/marlin_moe_kernel_ku4.h" - -template -inline std::string str(T x) { - return std::to_string(x); -} - -namespace marlin_moe { - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - -// For a given "a" of size [M,K] performs a permutation of the K columns based -// on the given "perm" indices. -__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, - int const* __restrict__ perm_int_ptr, - int4* __restrict__ out_int4_ptr, int size_m, - int size_k, int block_rows) { - int start_row = block_rows * blockIdx.x; - int finish_row = start_row + block_rows; - if (finish_row > size_m) { - finish_row = size_m; - } - int cur_block_rows = finish_row - start_row; - - int row_stride = size_k * sizeof(half) / 16; - - auto permute_row = [&](int row) { - int iters = size_k / blockDim.x; - int rest = size_k % blockDim.x; - - int offset = row * row_stride; - - half const* a_row_half = reinterpret_cast(a_int4_ptr + offset); - half* out_half = reinterpret_cast(out_int4_ptr + offset); - - int base_k = 0; - - for (int i = 0; i < iters; i++) { - int cur_k = base_k + threadIdx.x; - int src_pos = perm_int_ptr[cur_k]; - - out_half[cur_k] = a_row_half[src_pos]; - - base_k += blockDim.x; - } - - if (rest) { - if (threadIdx.x < rest) { - int cur_k = base_k + threadIdx.x; - int src_pos = perm_int_ptr[cur_k]; - - out_half[cur_k] = a_row_half[src_pos]; - } - } - }; - - for (int i = 0; i < cur_block_rows; i++) { - int cur_row = start_row + i; - if (cur_row < size_m) { - permute_row(cur_row); - } - } -} - -__global__ void compute_expert_offsets(int const* __restrict__ topk_ids, - int* __restrict__ expert_offsets, - int topk_length, int block_size) { - int expert_id = threadIdx.x; - int num_experts = blockDim.x; - - int occurrences = 0; - for (int i = 0; i < topk_length; ++i) { - occurrences += (topk_ids[i] == expert_id); - } - expert_offsets[expert_id + 1] = occurrences; - __syncthreads(); - - if (threadIdx.x == 0) { - int tot_offset = 0; - expert_offsets[0] = 0; - for (int i = 0; i < num_experts; ++i) { - tot_offset += ceildiv(expert_offsets[i + 1], block_size) * block_size; - expert_offsets[i + 1] = tot_offset; - } - } - __syncthreads(); -} - -#else - -__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, - int const* __restrict__ perm_int_ptr, - int4* __restrict__ out_int4_ptr, int size_m, - int size_k, int block_rows) { - // Marlin is not implemented yet for SM < 8.0 - assert(false); - return; -} - -__global__ void compute_expert_offsets(int const* __restrict__ topk_ids, - int* __restrict__ expert_offsets, - int topk_length, int block_size) { - // Marlin is not implemented yet for SM < 8.0 - assert(false); - return; -} - -#endif - -typedef struct { - int thread_k; - int thread_n; - int num_threads; -} thread_config_t; - -typedef struct { - int max_m_blocks; - thread_config_t tb_cfg; -} exec_config_t; - -thread_config_t small_batch_thread_configs[] = { - // Ordered by priority - - // thread_k, thread_n, num_threads - {128, 128, 256}, // Default - {128, 64, 128}, // Reduce N 2X, same K - {64, 256, 256}, // Reduce K 2X, increase N 2X - {64, 128, 128}, // Reduce K 2X, same N - {64, 64, 128}, // Reduce both 2X -}; - -thread_config_t large_batch_thread_configs[] = { - // Ordered by priority - - // thread_k, thread_n, num_threads - {64, 256, 256}, // Default - {128, 128, 256}, // Reduce N 2X, increase K 2X - {64, 128, 128}, // Reduce N 2X, same K - {128, 64, 128}, // Reduce N 4X, increase K 2X - {64, 64, 128}, // Reduce N 4X, same K -}; - -int get_scales_cache_size(thread_config_t const& th_config, int prob_m, - int prob_n, int prob_k, int num_bits, int group_size, - bool has_act_order, bool is_k_full) { - bool cache_scales_chunk = has_act_order && !is_k_full; - - int tb_n = th_config.thread_n; - int tb_k = th_config.thread_k; - - // Get max scale groups per thread-block - int tb_groups; - if (group_size == -1) { - tb_groups = 1; - } else if (group_size == 0) { - tb_groups = ceildiv(tb_k, 32); // Worst case is 32 group size - } else { - tb_groups = ceildiv(tb_k, group_size); - } - - if (cache_scales_chunk) { - int load_groups = - tb_groups * STAGES * 2; // Chunk size is 2x pipeline over dim K - load_groups = max(load_groups, 32); // We load at least 32 scale groups - return load_groups * tb_n * 4; - - } else { - int tb_scales = tb_groups * tb_n * 2; - - return tb_scales * STAGES; - } -} - -bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks, - int prob_m, int prob_n, int prob_k, int num_bits, - int scales_cache_size, int max_shared_mem) { - int pack_factor = 32 / num_bits; - - // Get B size - int tb_k = th_config.thread_k; - int tb_n = th_config.thread_n; - - int b_size = (tb_k * tb_n / pack_factor) * 4; - - // Get A size - int m_blocks = ceildiv(prob_m, 16); - int tb_max_m = 16; - - while (true) { - if (m_blocks >= max_m_blocks) { - tb_max_m *= max_m_blocks; - break; - } - - max_m_blocks--; - if (max_m_blocks == 0) { - TORCH_CHECK(false, "Unexpected m_blocks = ", m_blocks); - } - } - - int a_size = (tb_max_m * tb_k) * 2; - - float pipe_size = (a_size + b_size) * STAGES; - - TORCH_CHECK(max_shared_mem / 2 > scales_cache_size); // Sanity - - return pipe_size < 0.95f * (max_shared_mem - scales_cache_size); -} - -bool is_valid_config(thread_config_t const& th_config, int max_m_blocks, - int prob_m, int prob_n, int prob_k, int num_bits, - int group_size, bool has_act_order, bool is_k_full, - int max_shared_mem) { - // Sanity - if (th_config.thread_k == -1 || th_config.thread_n == -1 || - th_config.num_threads == -1) { - return false; - } - - // Verify K/N are divisible by thread K/N - if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) { - return false; - } - - // thread_k can be only 128 or 64 (because it must be less than groupsize - // which is 128) - if (th_config.thread_k != 128 && th_config.thread_k != 64) { - return false; - } - - // Verify min for thread K/N - if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) { - return false; - } - - // num_threads must be at least 128 (= 4 warps) - if (th_config.num_threads < 128) { - return false; - } - - // Determine cache for scales - int scales_cache_size = - get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits, - group_size, has_act_order, is_k_full); - - // Check that pipeline fits into cache - if (!is_valid_cache_size(th_config, max_m_blocks, prob_m, prob_n, prob_k, - num_bits, scales_cache_size, max_shared_mem)) { - return false; - } - - return true; -} - -exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, - int num_bits, int group_size, - bool has_act_order, bool is_k_full, - int max_shared_mem) { - int max_m_blocks = 4; - while (max_m_blocks > 0) { - if (prob_m <= 16) { - for (auto th_config : small_batch_thread_configs) { - if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, - num_bits, group_size, has_act_order, is_k_full, - max_shared_mem)) { - return exec_config_t{max_m_blocks, th_config}; - } - } - } else { - for (auto th_config : large_batch_thread_configs) { - if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, - num_bits, group_size, has_act_order, is_k_full, - max_shared_mem)) { - return exec_config_t{max_m_blocks, th_config}; - } - } - } - - max_m_blocks--; // Process less M blocks per invocation to reduce cache - // usage - } - - return exec_config_t{0, {-1, -1, -1}}; -} - -#define CALL_MOE_KERNEL_FUNCTION(KERNEL_FUNCTION) \ - else if (KERNEL_FUNCTION( \ - q_type, thread_n_blocks, thread_k_blocks, has_act_order, \ - group_blocks, num_threads, blocks, max_shared_mem, stream, \ - A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \ - zp_ptr, g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \ - num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \ - replicate_input, apply_weights, m_block, max_par, \ - exec_cfg.max_m_blocks)) { \ - } - -void marlin_mm_moe(const void* A, const void* B, void* C, - const void* sorted_ids, const void* topk_weights, - const void* topk_ids, const void* s, void* zp, - const void* g_idx, const void* perm, void* a_tmp, - void* expert_offsets, int prob_m, int prob_n, int prob_k, - void* workspace, vllm::ScalarType const& q_type, - bool has_act_order, bool is_k_full, bool has_zp, - int num_groups, int group_size, int num_experts, int topk, - int moe_block_size, int dev, cudaStream_t stream, - int thread_k, int thread_n, int sms, int max_par, - bool replicate_input, bool apply_weights) { - TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, - ", ", prob_n, ", ", prob_k, "]"); - - if (sms == -1) { - cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); - } - - int max_shared_mem = 0; - cudaDeviceGetAttribute(&max_shared_mem, - cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); - TORCH_CHECK(max_shared_mem > 0); - - int num_bits = q_type.size_bits(); - - // Set thread config - exec_config_t exec_cfg; - if (thread_k != -1 && thread_n != -1) { - // User-defined config - exec_cfg = - exec_config_t{4, thread_config_t{thread_k, thread_n, USER_THREADS}}; - } else { - // Auto config - exec_cfg = - determine_thread_config(prob_m, prob_n, prob_k, num_bits, group_size, - has_act_order, is_k_full, max_shared_mem); - } - - TORCH_CHECK(exec_cfg.max_m_blocks > 0 && - is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks, - prob_m, prob_n, prob_k, num_bits, group_size, - has_act_order, is_k_full, max_shared_mem), - "Invalid thread config: max_m_blocks = ", exec_cfg.max_m_blocks, - ", thread_k = ", exec_cfg.tb_cfg.thread_k, - ", thread_n = ", exec_cfg.tb_cfg.thread_n, - ", num_threads = ", exec_cfg.tb_cfg.num_threads, " for MKN = [", - prob_m, ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits, - ", group_size = ", group_size, - ", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full, - ", max_shared_mem = ", max_shared_mem); - - int num_threads = exec_cfg.tb_cfg.num_threads; - thread_k = exec_cfg.tb_cfg.thread_k; - thread_n = exec_cfg.tb_cfg.thread_n; - - int thread_k_blocks = thread_k / 16; - int thread_n_blocks = thread_n / 16; - - int blocks = sms; - - TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n, - " is not divisible by thread_n = ", thread_n); - TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k, - " is not divisible by thread_k = ", thread_k); - - int group_blocks = 0; - if (has_act_order) { - if (is_k_full) { - TORCH_CHECK(group_size != -1); - group_blocks = group_size / 16; - TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k, - " is not divisible by group_blocks = ", group_blocks); - } else { - TORCH_CHECK(group_size == 0); - group_blocks = 0; - } - - } else { - if (group_size == -1) { - group_blocks = -1; - } else { - group_blocks = group_size / 16; - TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k, - " is not divisible by group_blocks = ", group_blocks); - } - } - - int tot_m = prob_m; - - const int* topk_ids_ptr = (const int*)topk_ids; - int* expert_offsets_ptr = (int*)expert_offsets; - compute_expert_offsets<<<1, num_experts, 0, stream>>>( - topk_ids_ptr, expert_offsets_ptr, tot_m * topk, moe_block_size); - - bool do_permute_a = has_act_order; - - // If we have a full K, then we can run the non-act-order version of Marlin - // (since the weight rows are reordered by increasing group ids, and by - // having a full K, we have full original groups) - if (is_k_full) { - has_act_order = false; - } - - int pack_factor = 32 / q_type.size_bits(); - - for (int expert_idx = 0; expert_idx < num_experts; ++expert_idx) { - const int4* A_ptr = (const int4*)A; - int4* a_tmp_ptr = (int4*)a_tmp; - const int4* B_ptr = - (const int4*)B + (prob_n * prob_k / (pack_factor * 4)) * expert_idx; - int4* C_ptr = (int4*)C; - const float* topk_weights_ptr = (const float*)topk_weights; - const int* sorted_ids_ptr = (const int*)sorted_ids; - const int4* s_ptr = (const int4*)s + num_groups * prob_n / 8 * expert_idx; - const int4* zp_ptr = - (const int4*)zp + num_groups * prob_n / (pack_factor * 4) * expert_idx; - const int* g_idx_ptr = (const int*)g_idx + prob_k * expert_idx; - const int* perm_ptr = (const int*)perm + prob_k * expert_idx; - int* locks = (int*)workspace; - - if (do_permute_a) { - // Permute A columns - int topk_rows = replicate_input ? tot_m : tot_m * topk; - int block_rows = ceildiv(topk_rows, blocks); - permute_cols_kernel<<>>( - A_ptr, perm_ptr, a_tmp_ptr, topk_rows, prob_k, block_rows); - A_ptr = a_tmp_ptr; - } - - int tot_m_blocks = ceildiv(tot_m, 16); - for (int m_block = 0; m_block < tot_m_blocks; - m_block += 4 * exec_cfg.max_m_blocks) { - if (false) { - } - CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku4b8) - CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku8b128) - CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku4) - else { - TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " + - str(prob_n) + ", " + str(prob_k) + "]" + - ", has_act_order = " + str(has_act_order) + - ", num_groups = " + str(num_groups) + - ", group_size = " + str(group_size) + - ", thread_n_blocks = " + str(thread_n_blocks) + - ", thread_k_blocks = " + str(thread_k_blocks)); - } - } - } -} - -} // namespace marlin_moe - -torch::Tensor marlin_gemm_moe( - const torch::Tensor& a, const torch::Tensor& b_q_weights, - const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights, - const torch::Tensor& topk_ids, const torch::Tensor& b_scales, - torch::Tensor& b_zeros, const torch::Tensor& g_idx, - const torch::Tensor& perm, torch::Tensor& workspace, - vllm::ScalarTypeId const b_q_type_id, int64_t size_m, int64_t size_n, - int64_t size_k, bool is_k_full, int64_t num_experts, int64_t topk, - int64_t moe_block_size, bool replicate_input, bool apply_weights) { - vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id); - bool has_zp = b_zeros.size(1) != 0; - if (has_zp) { - TORCH_CHECK( - b_q_type == vllm::kU4, - "b_q_type must be u4 when has_zp = True. Got = ", b_q_type.str()); - } else { - TORCH_CHECK( - b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128, - "b_q_type must be uint4b8 or uint8b128. Got = ", b_q_type.str()); - } - - int pack_factor = 32 / b_q_type.size_bits(); - - int max_par = 4; - - int dev = a.get_device(); - - auto options_dtype = - torch::TensorOptions().dtype(a.dtype()).device(a.device()); - auto options_int = - torch::TensorOptions().dtype(torch::kInt).device(a.device()); - torch::Tensor c = torch::zeros({size_m, topk, size_n}, options_dtype); - torch::Tensor a_tmp = - replicate_input ? torch::zeros({size_m, size_k}, options_dtype) - : torch::zeros({size_m, topk, size_k}, options_dtype); - torch::Tensor expert_offsets = torch::empty({num_experts + 1}, options_int); - - // thread_k: `k` size of a thread_tile in `weights` (can usually be left as - // auto -1) - int thread_k = -1; - // thread_n: `n` size of a thread_tile in `weights` (can usually be left as - // auto -1) - int thread_n = -1; - // sms: number of SMs to use for the kernel (can usually be left as auto -1) - int sms = -1; - - // Detect groupsize and act_order - int num_groups = -1; - int group_size = -1; - bool has_act_order = g_idx.size(1) != 0; - - int b_rank = b_scales.sizes().size(); - TORCH_CHECK(b_rank == 3, "b_scales rank = ", b_rank, " is not 3"); - TORCH_CHECK(b_scales.size(2) == size_n, "b_scales dim 2 = ", b_scales.size(2), - " is not size_n = ", size_n); - num_groups = b_scales.size(1); - - TORCH_CHECK(VLLM_IMPLIES(!is_k_full, has_act_order), - "if is_k_full is false, has_act_order must be true"); - - if (has_act_order) { - if (is_k_full) { - TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1"); - TORCH_CHECK(size_k % num_groups == 0, "size_k = ", size_k, - ", is not divisible by num_groups = ", num_groups); - group_size = size_k / num_groups; - } else { - group_size = 0; - } - - } else { - if (num_groups > 1) { - TORCH_CHECK( - size_k % num_groups == 0, "size_k = ", size_k, - ", is not divisible by b_scales.size(0) = ", b_scales.size(0)); - group_size = size_k / num_groups; - } else { - group_size = -1; - } - } - - // Verify b_zeros - if (has_zp) { - int rank = b_zeros.sizes().size(); - TORCH_CHECK(rank == 3, "b_zeros rank = ", rank, " is not 3"); - TORCH_CHECK(b_zeros.size(1) == num_groups, - "b_zeros dim 1 = ", b_zeros.size(1), - " is not num_groups = ", num_groups); - TORCH_CHECK(b_zeros.size(2) == size_n / pack_factor, - "b_zeros dim 2 = ", b_zeros.size(2), - " is not size_n / pack_factor = ", size_n / pack_factor); - } - - marlin_moe::marlin_mm_moe( - a.data_ptr(), b_q_weights.data_ptr(), c.data_ptr(), sorted_ids.data_ptr(), - topk_weights.data_ptr(), topk_ids.data_ptr(), b_scales.data_ptr(), - b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), - expert_offsets.data_ptr(), size_m, size_n, size_k, workspace.data_ptr(), - b_q_type, has_act_order, is_k_full, has_zp, num_groups, group_size, - num_experts, topk, moe_block_size, dev, - at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, max_par, - replicate_input, apply_weights); - return c; -} - -TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { - m.impl("marlin_gemm_moe", &marlin_gemm_moe); -} diff --git a/csrc/ops.h b/csrc/ops.h index 59ae0937604..1dfd2e067e8 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -86,13 +86,13 @@ void rms_norm_dynamic_per_token_quant(torch::Tensor& out, std::optional residual); void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, - torch::Tensor& key, int64_t head_size, + std::optional key, int64_t head_size, torch::Tensor& cos_sin_cache, bool is_neox); void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query, - torch::Tensor& key, int64_t head_size, - torch::Tensor& cos_sin_cache, bool is_neox, - int64_t rot_dim, + std::optional key, + int64_t head_size, torch::Tensor& cos_sin_cache, + bool is_neox, int64_t rot_dim, torch::Tensor& cos_sin_cache_offsets); void silu_and_mul(torch::Tensor& out, torch::Tensor& input); @@ -178,6 +178,10 @@ torch::Tensor ggml_moe_a8(torch::Tensor X, torch::Tensor W, torch::Tensor num_tokens_post_padded, int64_t type, int64_t row, int64_t top_k, int64_t tokens); +torch::Tensor ggml_moe_a8_vec(torch::Tensor X, torch::Tensor W, + torch::Tensor topk_ids, int64_t top_k, + int64_t type, int64_t row, int64_t tokens); + int64_t ggml_moe_get_block_size(int64_t type); #ifndef USE_ROCM diff --git a/csrc/pos_encoding_kernels.cu b/csrc/pos_encoding_kernels.cu index c085d31a3e9..ef6dd1c0978 100644 --- a/csrc/pos_encoding_kernels.cu +++ b/csrc/pos_encoding_kernels.cu @@ -38,7 +38,8 @@ inline __device__ void apply_rotary_embedding( scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, // head_size] or [num_tokens, num_heads, // head_size] - scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, + scalar_t* __restrict__ key, // nullptr or + // [batch_size, seq_len, num_kv_heads, // head_size] or [num_tokens, num_kv_heads, // head_size] const scalar_t* cache_ptr, const int head_size, const int num_heads, @@ -57,13 +58,15 @@ inline __device__ void apply_rotary_embedding( query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim); } - const int nk = num_kv_heads * embed_dim; - for (int i = threadIdx.x; i < nk; i += blockDim.x) { - const int head_idx = i / embed_dim; - const int64_t token_head = token_idx * key_stride + head_idx * head_size; - const int rot_offset = i % embed_dim; - apply_token_rotary_embedding( - key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim); + if (key != nullptr) { + const int nk = num_kv_heads * embed_dim; + for (int i = threadIdx.x; i < nk; i += blockDim.x) { + const int head_idx = i / embed_dim; + const int64_t token_head = token_idx * key_stride + head_idx * head_size; + const int rot_offset = i % embed_dim; + apply_token_rotary_embedding( + key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim); + } } } @@ -74,7 +77,8 @@ __global__ void rotary_embedding_kernel( scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, // head_size] or [num_tokens, num_heads, // head_size] - scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, + scalar_t* __restrict__ key, // nullptr or + // [batch_size, seq_len, num_kv_heads, // head_size] or [num_tokens, num_kv_heads, // head_size] const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // @@ -98,7 +102,8 @@ __global__ void batched_rotary_embedding_kernel( scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, // head_size] or [num_tokens, num_heads, // head_size] - scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, + scalar_t* __restrict__ key, // nullptr or + // [batch_size, seq_len, num_kv_heads, // head_size] or [num_tokens, num_kv_heads, // head_size] const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // @@ -127,10 +132,12 @@ void rotary_embedding( // [num_tokens, num_heads * head_size] or // [batch_size, seq_len, num_heads, head_size] or // [num_tokens, num_heads, head_size] - torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or - // [num_tokens, num_kv_heads * head_size] or - // [batch_size, seq_len, num_heads, head_size] or - // [num_tokens, num_heads, head_size] + std::optional key, + // null or + // [batch_size, seq_len, num_kv_heads * head_size] or + // [num_tokens, num_kv_heads * head_size] or + // [batch_size, seq_len, num_heads, head_size] or + // [num_tokens, num_heads, head_size] int64_t head_size, torch::Tensor& cos_sin_cache, // [max_position, rot_dim] bool is_neox) { @@ -138,40 +145,40 @@ void rotary_embedding( int64_t num_tokens = positions.numel(); int positions_ndim = positions.dim(); - // Make sure num_tokens dim is consistent across positions, query, and key. + // Make sure num_tokens dim is consistent across positions, query, and key TORCH_CHECK( positions_ndim == 1 || positions_ndim == 2, "positions must have shape [num_tokens] or [batch_size, seq_len]"); if (positions_ndim == 1) { - TORCH_CHECK( - query.size(0) == positions.size(0) && key.size(0) == positions.size(0), - "query, key and positions must have the same number of tokens"); + TORCH_CHECK(query.size(0) == positions.size(0) && + (!key.has_value() || key->size(0) == positions.size(0)), + "query, key and positions must have the same number of tokens"); } if (positions_ndim == 2) { TORCH_CHECK( query.size(0) == positions.size(0) && - key.size(0) == positions.size(0) && + (!key.has_value() || key->size(0) == positions.size(0)) && query.size(1) == positions.size(1) && - key.size(1) == positions.size(1), + (!key.has_value() || key->size(1) == positions.size(1)), "query, key and positions must have the same batch_size and seq_len"); } // Make sure head_size is valid for query and key // hidden_size = num_heads * head_size int query_hidden_size = query.numel() / num_tokens; - int key_hidden_size = key.numel() / num_tokens; + int key_hidden_size = key.has_value() ? key->numel() / num_tokens : 0; TORCH_CHECK(query_hidden_size % head_size == 0); TORCH_CHECK(key_hidden_size % head_size == 0); // Make sure query and key have consistent number of heads int num_heads = query_hidden_size / head_size; - int num_kv_heads = key_hidden_size / head_size; + int num_kv_heads = key.has_value() ? key_hidden_size / head_size : num_heads; TORCH_CHECK(num_heads % num_kv_heads == 0); int rot_dim = cos_sin_cache.size(1); int seq_dim_idx = positions_ndim - 1; int64_t query_stride = query.stride(seq_dim_idx); - int64_t key_stride = key.stride(seq_dim_idx); + int64_t key_stride = key.has_value() ? key->stride(seq_dim_idx) : 0; dim3 grid(num_tokens); dim3 block(std::min(num_heads * rot_dim / 2, 512)); @@ -181,15 +188,16 @@ void rotary_embedding( if (is_neox) { vllm::rotary_embedding_kernel<<>>( positions.data_ptr(), query.data_ptr(), - key.data_ptr(), cos_sin_cache.data_ptr(), rot_dim, - query_stride, key_stride, num_heads, num_kv_heads, head_size); + key.has_value() ? key->data_ptr() : nullptr, + cos_sin_cache.data_ptr(), rot_dim, query_stride, key_stride, + num_heads, num_kv_heads, head_size); } else { vllm::rotary_embedding_kernel <<>>( positions.data_ptr(), query.data_ptr(), - key.data_ptr(), cos_sin_cache.data_ptr(), - rot_dim, query_stride, key_stride, num_heads, num_kv_heads, - head_size); + key.has_value() ? key->data_ptr() : nullptr, + cos_sin_cache.data_ptr(), rot_dim, query_stride, + key_stride, num_heads, num_kv_heads, head_size); } }); } @@ -204,10 +212,12 @@ void batched_rotary_embedding( // [num_tokens, num_heads * head_size] or // [batch_size, seq_len, num_heads, head_size] or // [num_tokens, num_heads, head_size] - torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or - // [num_tokens, num_kv_heads * head_size] or - // [batch_size, seq_len, num_heads, head_size] or - // [num_tokens, num_heads, head_size] + std::optional + key, // null or + // [batch_size, seq_len, num_kv_heads * head_size] or + // [num_tokens, num_kv_heads * head_size] or + // [batch_size, seq_len, num_heads, head_size] or + // [num_tokens, num_heads, head_size] int64_t head_size, torch::Tensor& cos_sin_cache, // [max_position, rot_dim] bool is_neox, int64_t rot_dim, @@ -221,38 +231,38 @@ void batched_rotary_embedding( "cos_sin_cache_offsets"); int positions_ndim = positions.dim(); - // Make sure num_tokens dim is consistent across positions, query, and key. + // Make sure num_tokens dim is consistent across positions, query, and key TORCH_CHECK( positions_ndim == 1 || positions_ndim == 2, "positions must have shape [num_tokens] or [batch_size, seq_len]"); if (positions_ndim == 1) { - TORCH_CHECK( - query.size(0) == positions.size(0) && key.size(0) == positions.size(0), - "query, key and positions must have the same number of tokens"); + TORCH_CHECK(query.size(0) == positions.size(0) && + (!key.has_value() || key->size(0) == positions.size(0)), + "query, key and positions must have the same number of tokens"); } if (positions_ndim == 2) { TORCH_CHECK( query.size(0) == positions.size(0) && - key.size(0) == positions.size(0) && + (!key.has_value() || key->size(0) == positions.size(0)) && query.size(1) == positions.size(1) && - key.size(1) == positions.size(1), + (!key.has_value() || key->size(1) == positions.size(1)), "query, key and positions must have the same batch_size and seq_len"); } // Make sure head_size is valid for query and key int query_hidden_size = query.numel() / num_tokens; - int key_hidden_size = key.numel() / num_tokens; + int key_hidden_size = key.has_value() ? key->numel() / num_tokens : 0; TORCH_CHECK(query_hidden_size % head_size == 0); TORCH_CHECK(key_hidden_size % head_size == 0); // Make sure query and key have concistent number of heads int num_heads = query_hidden_size / head_size; - int num_kv_heads = key_hidden_size / head_size; + int num_kv_heads = key.has_value() ? key_hidden_size / head_size : num_heads; TORCH_CHECK(num_heads % num_kv_heads == 0); int seq_dim_idx = positions_ndim - 1; int64_t query_stride = query.stride(seq_dim_idx); - int64_t key_stride = key.stride(seq_dim_idx); + int64_t key_stride = key.has_value() ? key->stride(seq_dim_idx) : 0; dim3 grid(num_tokens); dim3 block(std::min(num_heads * rot_dim / 2, 512)); @@ -263,14 +273,16 @@ void batched_rotary_embedding( vllm::batched_rotary_embedding_kernel <<>>( positions.data_ptr(), query.data_ptr(), - key.data_ptr(), cos_sin_cache.data_ptr(), + key.has_value() ? key->data_ptr() : nullptr, + cos_sin_cache.data_ptr(), cos_sin_cache_offsets.data_ptr(), rot_dim, query_stride, key_stride, num_heads, num_kv_heads, head_size); } else { vllm::batched_rotary_embedding_kernel <<>>( positions.data_ptr(), query.data_ptr(), - key.data_ptr(), cos_sin_cache.data_ptr(), + key.has_value() ? key->data_ptr() : nullptr, + cos_sin_cache.data_ptr(), cos_sin_cache_offsets.data_ptr(), rot_dim, query_stride, key_stride, num_heads, num_kv_heads, head_size); } diff --git a/csrc/quantization/fp8/fp8_marlin.cu b/csrc/quantization/fp8/fp8_marlin.cu deleted file mode 100644 index 376bbd498ca..00000000000 --- a/csrc/quantization/fp8/fp8_marlin.cu +++ /dev/null @@ -1,1311 +0,0 @@ -/* - * Modified by Neural Magic - * Copyright (C) Marlin.2024 Elias Frantar - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* - * Adapted from https://github.com/IST-DASLab/marlin - */ - -#include "../gptq_marlin/marlin.cuh" -#include "../gptq_marlin/marlin_dtypes.cuh" - -#include "core/registration.h" - -using namespace marlin; - -#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ - static_assert(std::is_same::value || \ - std::is_same::value, \ - "only float16 and bfloat16 is supported"); - -template -inline std::string str(T x) { - return std::to_string(x); -} - -namespace fp8_marlin { - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - -template shared - // fetch pipeline - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale - > -__global__ void Marlin( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // fp16 output buffer of shape mxn - const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape - // (k/groupsize)xn - int num_groups, // number of scale groups per output channel - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int* locks // extra global storage for barrier synchronization -) {} - -} // namespace fp8_marlin - -torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, - torch::Tensor& b_scales, torch::Tensor& workspace, - int64_t num_bits, int64_t size_m, int64_t size_n, - int64_t size_k) { - TORCH_CHECK_NOT_IMPLEMENTED(false, - "marlin_gemm(..) requires CUDA_ARCH >= 8.0"); - return torch::empty({1, 1}); -} - -#else - -// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 -// output/accumulation. -template -__device__ inline void mma(const typename ScalarType::FragA& a_frag, - const typename ScalarType::FragB& frag_b, - typename ScalarType::FragC& frag_c) { - const uint32_t* a = reinterpret_cast(&a_frag); - const uint32_t* b = reinterpret_cast(&frag_b); - float* c = reinterpret_cast(&frag_c); - if constexpr (std::is_same::value) { - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); - } else if constexpr (std::is_same::value) { - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); - } else { - STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); - } -} - -// Instruction for loading a full 16x16 matrix fragment of operand A from shared -// memory, directly in tensor core layout. -template -__device__ inline void ldsm4(typename ScalarType::FragA& frag_a, - const void* smem_ptr) { - uint32_t* a = reinterpret_cast(&frag_a); - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" - : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) - : "r"(smem)); -} - -// Fast FP8ToFp16/FP8ToBf16: Efficiently dequantize 8bit fp8_e4m3 values to fp16 -// bf16 Reference: -// - FP16: -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85 -// - BF16: -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175 -template -__device__ inline typename ScalarType::FragB dequant_8bit(int q) { - STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); -} - -template <> -__device__ inline typename ScalarType::FragB dequant_8bit(int q) { - // Constants for FP8 (E4M3) and FP16 formats - constexpr int FP8_EXPONENT = 4, FP8_MANTISSA = 3, FP16_EXPONENT = 5; - constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP8_EXPONENT; - - // Calculate MASK for extracting mantissa and exponent - constexpr int MASK1 = 0x80000000; - constexpr int MASK2 = MASK1 >> (FP8_EXPONENT + FP8_MANTISSA); - constexpr int MASK3 = MASK2 & 0x7fffffff; - constexpr int MASK = MASK3 | (MASK3 >> 16); - // Final MASK value: 0x7F007F00 - - // Extract and shift FP8 values to FP16 format - int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); - int Out2 = ((q << 8) & 0x80008000) | (((q << 8) & MASK) >> RIGHT_SHIFT); - - // Construct and apply exponent bias - constexpr int BIAS_OFFSET = - (1 << (FP16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1)); - const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET)); - - // Convert to half2 and apply bias - typename ScalarType::FragB frag_b; - // Note: reverse indexing is intentional because weights are permuted - frag_b[1] = __hmul2(*reinterpret_cast(&Out1), bias_reg); - frag_b[0] = __hmul2(*reinterpret_cast(&Out2), bias_reg); - return frag_b; -} - -template <> -__device__ inline typename ScalarType::FragB -dequant_8bit(int q) { - // Constants for FP8 (E4M3) and BF16 formats - constexpr int FP8_EXPONENT = 4, FP8_MANTISSA = 3, BF16_EXPONENT = 8; - constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT; - - // Calculate MASK for extracting mantissa and exponent - constexpr int MASK1 = 0x80000000; - constexpr int MASK2 = MASK1 >> (FP8_EXPONENT + FP8_MANTISSA); - constexpr int MASK3 = MASK2 & 0x7fffffff; - constexpr int MASK = MASK3 | (MASK3 >> 16); - // Final MASK value: 0x7F007F00 - - // Extract and shift FP8 values to BF16 format - int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); - int Out2 = ((q << 8) & 0x80008000) | (((q << 8) & MASK) >> RIGHT_SHIFT); - - // Construct and apply exponent bias - constexpr int BIAS_OFFSET = - (1 << (BF16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1)); - // Add 127 (float exponent bias) to BIAS_OFFSET and shift to float exponent - // position - constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23; - const nv_bfloat162 bias_reg = - __float2bfloat162_rn(*reinterpret_cast(&BIAS)); - - // Convert to bfloat162 and apply bias - typename ScalarType::FragB frag_b; - // Note: reverse indexing is intentional because weights are permuted - frag_b[1] = __hmul2(*reinterpret_cast(&Out1), bias_reg); - frag_b[0] = __hmul2(*reinterpret_cast(&Out2), bias_reg); - return frag_b; -} - -// Multiply dequantized values by the corresponding quantization scale; used -// only for grouped quantization. -template -__device__ inline void scale(typename ScalarType::FragB& frag_b, - typename ScalarType::FragS& frag_s, - int i) { - using scalar_t2 = typename ScalarType::scalar_t2; - scalar_t2 s = - ScalarType::num2num2(reinterpret_cast(&frag_s)[i]); - frag_b[0] = __hmul2(frag_b[0], s); - frag_b[1] = __hmul2(frag_b[1], s); -} - -// Given 2 floats multiply by 2 scales (halves) -template -__device__ inline void scale_float(float* c, - typename ScalarType::FragS& s) { - scalar_t* s_ptr = reinterpret_cast(&s); - c[0] = __fmul_rn(c[0], ScalarType::num2float(s_ptr[0])); - c[1] = __fmul_rn(c[1], ScalarType::num2float(s_ptr[1])); -} - -// Wait until barrier reaches `count`, then lock for current threadblock. -__device__ inline void barrier_acquire(int* lock, int count) { - if (threadIdx.x == 0) { - int state = -1; - do - // Guarantee that subsequent writes by this threadblock will be visible - // globally. - asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" - : "=r"(state) - : "l"(lock)); - while (state != count); - } - __syncthreads(); -} - -// Release barrier and increment visitation count. -__device__ inline void barrier_release(int* lock, bool reset = false) { - __syncthreads(); - if (threadIdx.x == 0) { - if (reset) { - lock[0] = 0; - return; - } - int val = 1; - // Make sure that all writes since acquiring this barrier are visible - // globally, while releasing the barrier. - asm volatile("fence.acq_rel.gpu;\n"); - asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" - : - : "l"(lock), "r"(val)); - } -} - -template shared - // fetch pipeline - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale - > -__global__ void Marlin( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // fp16 output buffer of shape mxn - const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape - // (k/groupsize)xn - int num_groups, // number of scale groups per output channel - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int* locks // extra global storage for barrier synchronization -) { - // Each threadblock processes one "stripe" of the B matrix with (roughly) the - // same size, which might involve multiple column "slices" (of width 16 * - // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM - // example: - // 0 1 3 - // 0 2 3 - // 1 2 4 - // While this kind of partitioning makes things somewhat more complicated, it - // ensures good utilization of all SMs for many kinds of shape and GPU - // configurations, while requiring as few slow global cross-threadblock - // reductions as possible. - using Dtype = ScalarType; - using scalar_t2 = typename ScalarType::scalar_t2; - using FragA = typename ScalarType::FragA; - using FragB = typename ScalarType::FragB; - using FragC = typename ScalarType::FragC; - using FragS = typename ScalarType::FragS; - - constexpr int pack_factor = 32 / num_bits; - - // For larger GEMMs we run multiple batchsize 64 versions in parallel for a - // better partitioning with less reductions - int parallel = 1; - if (prob_m > 16 * thread_m_blocks) { - parallel = prob_m / (16 * thread_m_blocks); - prob_m = 16 * thread_m_blocks; - } - - int k_tiles = prob_k / 16 / thread_k_blocks; - int n_tiles = prob_n / 16 / thread_n_blocks; - int iters = div_ceil(k_tiles * n_tiles * parallel, gridDim.x); - - int slice_row = (iters * blockIdx.x) % k_tiles; - int slice_col_par = (iters * blockIdx.x) / k_tiles; - int slice_col = slice_col_par; - int slice_iters; // number of threadblock tiles in the current slice - int slice_count = - 0; // total number of active threadblocks in the current slice - int slice_idx; // index of threadblock in current slice; numbered bottom to - // top - - // We can easily implement parallel problem execution by just remapping - // indices and advancing global pointers - if (slice_col_par >= n_tiles) { - A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8; - C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; - locks += (slice_col_par / n_tiles) * n_tiles; - slice_col = slice_col_par % n_tiles; - } - - // Compute all information about the current slice which is required for - // synchronization. - auto init_slice = [&]() { - slice_iters = - iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); - if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; - if (slice_iters == 0) return; - if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; - slice_count = 1; - slice_idx = 0; - int col_first = iters * div_ceil(k_tiles * slice_col_par, iters); - if (col_first <= k_tiles * (slice_col_par + 1)) { - int col_off = col_first - k_tiles * slice_col_par; - slice_count = div_ceil(k_tiles - col_off, iters); - if (col_off > 0) slice_count++; - int delta_first = iters * blockIdx.x - col_first; - if (delta_first < 0 || (col_off == 0 && delta_first == 0)) - slice_idx = slice_count - 1; - else { - slice_idx = slice_count - 1 - delta_first / iters; - if (col_off > 0) slice_idx--; - } - } - if (slice_col == n_tiles) { - A += 16 * thread_m_blocks * prob_k / 8; - C += 16 * thread_m_blocks * prob_n / 8; - locks += n_tiles; - slice_col = 0; - } - }; - init_slice(); - - // A sizes/strides - - // stride of the A matrix in global memory - int a_gl_stride = prob_k / 8; - // stride of an A matrix tile in shared memory - constexpr int a_sh_stride = 16 * thread_k_blocks / 8; - // delta between subsequent A tiles in global memory - constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; - // between subsequent accesses within a tile - int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); - // between shared memory writes - constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); - // between shared memory tile reads - constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); - // within a shared memory tile - constexpr int a_sh_rd_delta_i = a_sh_stride * 16; - // overall size of a tile - constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); - // number of shared write iterations for a tile - constexpr int a_sh_wr_iters = div_ceil(a_sh_stage, a_sh_wr_delta); - - // B sizes/strides - int b_gl_stride = 16 * prob_n / (pack_factor * 4); - constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4; - constexpr int b_thread_vecs = num_bits == 4 ? 1 : 2; - constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs; - - int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; - int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads); - constexpr int b_sh_wr_delta = threads * b_thread_vecs; - constexpr int b_sh_rd_delta = threads * b_thread_vecs; - constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; - constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; - - // Scale sizes/strides without act_order - int s_gl_stride = prob_n / 8; - constexpr int s_sh_stride = 16 * thread_n_blocks / 8; - - // Scale size/strides with act_order - constexpr int tb_k = 16 * thread_k_blocks; - constexpr int g_idx_stage = 0; - // constexpr int act_s_row_stride = 1; - // int act_s_col_stride = act_s_row_stride * num_groups; - int act_s_col_stride = 1; - int act_s_col_warp_stride = act_s_col_stride * 8; - int tb_n_warps = thread_n_blocks / 4; - int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; - - // Global A read index of current thread. - int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - a_gl_rd += a_gl_rd_delta_o * slice_row; - // Shared write index of current thread. - int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - // Shared read index. - int a_sh_rd = - a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; - a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); - - int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + - (threadIdx.x % b_sh_stride_threads) * b_thread_vecs; - b_gl_rd += b_sh_stride * slice_col; - b_gl_rd += b_gl_rd_delta_o * slice_row; - int b_sh_wr = threadIdx.x * b_thread_vecs; - int b_sh_rd = threadIdx.x * b_thread_vecs; - - // For act_order - int slice_k_start = tb_k * slice_row; - int slice_k_start_shared_fetch = slice_k_start; - int slice_n_offset = act_s_col_tb_stride * slice_col; - - // No act_order - int s_gl_rd = s_sh_stride * slice_col + threadIdx.x; - int s_sh_wr = threadIdx.x; - bool s_sh_wr_pred = threadIdx.x < s_sh_stride; - - // We scale a `half2` tile in row-major layout for column-wise quantization. - int s_sh_rd = - 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) % 4; - - // Precompute which thread should not read memory in which iterations; this is - // needed if there are more threads than required for a certain tilesize or - // when the batchsize is not a multiple of 16. - bool a_sh_wr_pred[a_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) - a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; - - // To ensure that writing and reading A tiles to/from shared memory, the - // latter in fragment format, is fully bank conflict free, we need to use a - // rather fancy XOR-based layout. The key here is that neither reads nor - // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the - // same shared memory banks. Further, it seems (based on NSight-Compute) that - // each warp must also write a consecutive memory segment? - auto transform_a = [&](int i) { - int row = i / a_gl_rd_delta_o; - return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; - }; - // Since the computation of this remapping is non-trivial and, due to our main - // loop unrolls, all shared memory accesses are static, we simply precompute - // both transformed reads and writes. - int a_sh_wr_trans[a_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) - a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); - int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - #pragma unroll - for (int j = 0; j < thread_m_blocks; j++) - a_sh_rd_trans[i][j] = - transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); - } - - // Since B-accesses have non-constant stride they have to be computed at - // runtime; we break dependencies between subsequent accesses with a tile by - // maintining multiple pointers (we have enough registers), a tiny - // optimization. - const int4* B_ptr[b_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; - - extern __shared__ int4 sh[]; - // Shared memory storage for global fetch pipelines. - int4* sh_a = sh; - int4* sh_b = sh_a + (stages * a_sh_stage); - int4* sh_g_idx = sh_b + (stages * b_sh_stage); - int4* sh_s = sh_g_idx + (stages * g_idx_stage); - - // Register storage for double buffer of shared memory reads. - FragA frag_a[2][thread_m_blocks]; - I4 frag_b_quant[2][b_thread_vecs]; - FragC frag_c[thread_m_blocks][4][2]; - FragS frag_s[2][4]; - - // Zero accumulators. - auto zero_accums = [&]() { - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) - reinterpret_cast(frag_c)[i] = 0; - }; - - int sh_first_group_id = -1; - int sh_num_groups = -1; - constexpr int sh_max_num_groups = 32; - - auto fetch_scales_to_shared = [&](bool is_async, int first_group_id, - int last_group_id) { - sh_first_group_id = first_group_id; - sh_num_groups = last_group_id - first_group_id + 1; - - if (sh_num_groups < sh_max_num_groups) { - sh_num_groups = sh_max_num_groups; - } - - if (sh_first_group_id + sh_num_groups > num_groups) { - sh_num_groups = num_groups - sh_first_group_id; - } - - int row_offset = first_group_id * s_gl_stride; - - if (is_async) { - for (int i = 0; i < sh_num_groups; i++) { - if (threadIdx.x < s_sh_stride) { - cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x], - &scales_ptr[row_offset + (i * s_gl_stride) + - slice_n_offset + threadIdx.x]); - } - } - } else { - for (int i = 0; i < sh_num_groups; i++) { - if (threadIdx.x < s_sh_stride) { - sh_s[(i * s_sh_stride) + threadIdx.x] = - scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + - threadIdx.x]; - } - } - } - }; - // Asynchronously fetch the next A, B and s tile from global to the next - // shared memory pipeline location. - auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { - if (pred) { - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) { - cp_async4_pred( - &sh_a_stage[a_sh_wr_trans[i]], - &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], - a_sh_wr_pred[i]); - } - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - #pragma unroll - for (int j = 0; j < b_thread_vecs; j++) { - cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j); - } - - B_ptr[i] += b_gl_rd_delta_o; - } - } - // Insert a fence even when we are winding down the pipeline to ensure that - // waiting is also correct at this point. - cp_async_fence(); - }; - - // Wait until the next thread tile has been loaded to shared memory. - auto wait_for_stage = [&]() { - // We only have `stages - 2` active fetches since we are double buffering - // and can only issue the next fetch when it is guaranteed that the previous - // shared memory load is fully complete (as it may otherwise be - // overwritten). - cp_async_wait(); - __syncthreads(); - }; - - // Load the next sub-tile from the current location in the shared memory pipe - // into the current register buffer. - auto fetch_to_registers = [&](int k, int pipe) { - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) - ldsm4(frag_a[k % 2][i], - &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - - #pragma unroll - for (int i = 0; i < b_thread_vecs; i++) { - frag_b_quant[k % 2][i] = *reinterpret_cast( - &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); - } - }; - - bool is_same_group[stages]; - int same_group_id[stages]; - - auto init_same_group = [&](int pipe) { - is_same_group[pipe] = false; - same_group_id[pipe] = 0; - return; - }; - - // Execute the actual tensor core matmul of a sub-tile. - auto matmul = [&](int k) { - // We have the m dimension as the inner loop in order to encourage overlapping - // dequantization and matmul operations. - #pragma unroll - for (int j = 0; j < 4; j++) { - FragB frag_b0; - FragB frag_b1; - - int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); - int b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; - int b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; - - frag_b0 = dequant_8bit(b_quant_0); - frag_b1 = dequant_8bit(b_quant_1); - - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); - mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); - } - } - }; - - // Since we slice across the k dimension of a tile in order to increase the - // number of warps while keeping the n dimension of a tile reasonable, we have - // multiple warps that accumulate their partial sums of the same output - // location; which we have to reduce over in the end. We do in shared memory. - auto thread_block_reduce = [&]() { - constexpr int red_off = threads / b_sh_stride_threads / 2; - if (red_off >= 1) { - int red_idx = threadIdx.x / b_sh_stride_threads; - constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; - constexpr int red_sh_delta = b_sh_stride_threads; - int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + - (threadIdx.x % b_sh_stride_threads); - - // Parallel logarithmic shared memory reduction. We make sure to avoid any - // unnecessary read or write iterations, e.g., for two warps we write only - // once by warp 1 and read only once by warp 0. - - #pragma unroll - for (int m_block = 0; m_block < thread_m_blocks; m_block++) { - #pragma unroll - for (int i = red_off; i > 0; i /= 2) { - if (i <= red_idx && red_idx < 2 * i) { - #pragma unroll - for (int j = 0; j < 4 * 2; j++) { - int red_sh_wr = - red_sh_delta * j + (red_sh_rd - red_sh_stride * i); - if (i < red_off) { - float* c_rd = - reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); - float* c_wr = reinterpret_cast(&sh[red_sh_wr]); - #pragma unroll - for (int k = 0; k < 4; k++) - reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += - c_rd[k] + c_wr[k]; - } - sh[red_sh_wr] = - reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; - } - } - __syncthreads(); - } - if (red_idx == 0) { - #pragma unroll - for (int i = 0; i < 4 * 2; i++) { - float* c_rd = - reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); - #pragma unroll - for (int j = 0; j < 4; j++) - reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += - c_rd[j]; - } - } - __syncthreads(); - } - } - }; - - // Since multiple threadblocks may process parts of the same column slice, we - // finally have to globally reduce over the results. As the striped - // partitioning minimizes the number of such reductions and our outputs are - // usually rather small, we perform this reduction serially in L2 cache. - auto global_reduce = [&](bool first = false, bool last = false) { - // We are very careful here to reduce directly in the output buffer to - // maximize L2 cache utilization in this step. To do this, we write out - // results in FP16 (but still reduce with FP32 compute). - constexpr int active_threads = 32 * thread_n_blocks / 4; - if (threadIdx.x < active_threads) { - int c_gl_stride = prob_n / 8; - int c_gl_wr_delta_o = 8 * c_gl_stride; - int c_gl_wr_delta_i = 4 * (active_threads / 32); - int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + - 4 * (threadIdx.x / 32) + threadIdx.x % 4; - c_gl_wr += (2 * thread_n_blocks) * slice_col; - constexpr int c_sh_wr_delta = active_threads; - int c_sh_wr = threadIdx.x; - - int row = (threadIdx.x % 32) / 4; - - if (!first) { - // Interestingly, doing direct global accesses here really seems to mess up - // the compiler and lead to slowdowns, hence we also use async-copies even - // though these fetches are not actually asynchronous. - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4; i++) { - cp_async4_pred( - &sh[c_sh_wr + c_sh_wr_delta * i], - &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + - c_gl_wr_delta_i * (i % 2)], - i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m); - } - cp_async_fence(); - cp_async_wait<0>(); - } - - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4; i++) { - if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) { - if (!first) { - int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; - #pragma unroll - for (int j = 0; j < 2 * 4; j++) { - reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += - Dtype::num2float(reinterpret_cast(&c_red)[j]); - } - } - if (!last) { - int4 c; - #pragma unroll - for (int j = 0; j < 2 * 4; j++) { - reinterpret_cast(&c)[j] = - Dtype::float2num(reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]); - } - C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = - c; - } - } - } - } - }; - - // Write out the reduce final result in the correct layout. We only actually - // reshuffle matrix fragments in this step, the reduction above is performed - // in fragment layout. - auto write_result = [&]() { - int c_gl_stride = prob_n / 8; - constexpr int c_sh_stride = 2 * thread_n_blocks + 1; - int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); - constexpr int c_sh_rd_delta = - c_sh_stride * (threads / (2 * thread_n_blocks)); - - int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + - (threadIdx.x % (2 * thread_n_blocks)); - c_gl_wr += (2 * thread_n_blocks) * slice_col; - int c_sh_wr = - (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; - c_sh_wr += 32 * (threadIdx.x / 32); - int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + - (threadIdx.x % (2 * thread_n_blocks)); - - int c_gl_wr_end = c_gl_stride * prob_m; - - // We first reorder in shared memory to guarantee the most efficient final - // global write patterns - auto write = [&](int idx, float c0, float c1, FragS& s) { - scalar_t2 res = - Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1)); - - ((scalar_t2*)sh)[idx] = res; - }; - - if (threadIdx.x / 32 < thread_n_blocks / 4) { - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - #pragma unroll - for (int j = 0; j < 4; j++) { - int wr = c_sh_wr + 8 * j; - write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], - frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); - write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], - frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); - write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], - frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); - write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], - frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); - } - c_sh_wr += 16 * (4 * c_sh_stride); - } - } - __syncthreads(); - - #pragma unroll - for (int i = 0; - i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); - i++) { - if (c_gl_wr < c_gl_wr_end) { - C[c_gl_wr] = sh[c_sh_rd]; - c_gl_wr += c_gl_wr_delta; - c_sh_rd += c_sh_rd_delta; - } - } - }; - - // Start global fetch and register load pipelines. - auto start_pipes = [&]() { - - #pragma unroll - for (int i = 0; i < stages - 1; i++) { - fetch_to_shared(i, i, i < slice_iters); - } - - zero_accums(); - wait_for_stage(); - init_same_group(0); - fetch_to_registers(0, 0); - a_gl_rd += a_gl_rd_delta_o * (stages - 1); - slice_k_start_shared_fetch += tb_k * (stages - 1); - }; - if (slice_iters) { - start_pipes(); - } - - // Main loop. - while (slice_iters) { - // We unroll over both the global fetch and the register load pipeline to - // ensure all shared memory accesses are static. Note that both pipelines - // have even length meaning that the next iteration will always start at - // index 0. - - #pragma unroll - for (int pipe = 0; pipe < stages;) { - #pragma unroll - for (int k = 0; k < b_sh_wr_iters; k++) { - fetch_to_registers(k + 1, pipe % stages); - if (k == b_sh_wr_iters - 2) { - fetch_to_shared((pipe + stages - 1) % stages, pipe, - slice_iters >= stages); - pipe++; - wait_for_stage(); - init_same_group(pipe % stages); - } - matmul(k); - } - slice_iters--; - if (slice_iters == 0) { - break; - } - } - - a_gl_rd += a_gl_rd_delta_o * stages; - slice_k_start += tb_k * stages; - slice_k_start_shared_fetch += tb_k * stages; - - // Process results and, if necessary, proceed to the next column slice. - // While this pattern may not be the most readable, other ways of writing - // the loop seemed to noticeably worse performance after compilation. - if (slice_iters == 0) { - cp_async_wait<0>(); - bool last = slice_idx == slice_count - 1; - // For per-column scales, we only fetch them here in the final step before - // write-out - if (s_sh_wr_pred) { - cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); - } - cp_async_fence(); - - thread_block_reduce(); - - cp_async_wait<0>(); - __syncthreads(); - if (threadIdx.x / 32 < thread_n_blocks / 4) { - reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; - reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; - } - - // For 8-bit channelwise, we apply the scale before the global reduction - // that converts the fp32 results to fp16 (so that we avoid possible - // overflow in fp16) - if (threadIdx.x / 32 < thread_n_blocks / 4) { - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - #pragma unroll - for (int j = 0; j < 4; j++) { - scale_float(reinterpret_cast(&frag_c[i][j][0][0]), - frag_s[j / 2][2 * (j % 2) + 0]); - scale_float(reinterpret_cast(&frag_c[i][j][0][2]), - frag_s[j / 2][2 * (j % 2) + 0]); - - scale_float(reinterpret_cast(&frag_c[i][j][1][0]), - frag_s[j / 2][2 * (j % 2) + 1]); - scale_float(reinterpret_cast(&frag_c[i][j][1][2]), - frag_s[j / 2][2 * (j % 2) + 1]); - } - } - } - - if (slice_count > 1) { // only globally reduce if there is more than one - // block in a slice - barrier_acquire(&locks[slice_col], slice_idx); - global_reduce(slice_idx == 0, last); - barrier_release(&locks[slice_col], last); - } - if (last) // only the last block in a slice actually writes the result - write_result(); - slice_row = 0; - slice_col_par++; - slice_col++; - init_slice(); - if (slice_iters) { - a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; - if (slice_col == 0) { - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; - } - - // Update slice k/n for scales loading - s_gl_rd = s_sh_stride * slice_col + threadIdx.x; - - start_pipes(); - } - } - } -} - - #define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \ - THREAD_K_BLOCKS, GROUP_BLOCKS, NUM_THREADS) \ - else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \ - thread_n_blocks == THREAD_N_BLOCKS && \ - thread_k_blocks == THREAD_K_BLOCKS && \ - group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \ - cudaFuncSetAttribute( \ - Marlin, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ - Marlin \ - <<>>( \ - A_ptr, B_ptr, C_ptr, s_ptr, num_groups, prob_m, prob_n, prob_k, \ - locks); \ - } - -typedef struct { - int thread_k; - int thread_n; - int num_threads; -} thread_config_t; - -typedef struct { - int max_m_blocks; - thread_config_t tb_cfg; -} exec_config_t; - -thread_config_t small_batch_thread_configs[] = { - // Ordered by priority - - // thread_k, thread_n, num_threads - {128, 128, 256}, - {64, 128, 128}, - {128, 64, 128}, -}; - -thread_config_t large_batch_thread_configs[] = { - // Ordered by priority - - // thread_k, thread_n, num_threads - {64, 256, 256}, - {64, 128, 128}, - {128, 64, 128}, - -}; - -int get_scales_cache_size(thread_config_t const& th_config, int prob_m, - int prob_n, int prob_k, int num_bits, - int group_size) { - int tb_n = th_config.thread_n; - - // Get max scale groups per thread-block - // Fixed for channelwise - int tb_groups = 1; - int tb_scales = tb_groups * tb_n * 2; - - return tb_scales * pipe_stages; -} - -bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks, - int prob_m, int prob_n, int prob_k, int num_bits, - int scales_cache_size, int max_shared_mem) { - int pack_factor = 32 / num_bits; - - // Get B size - int tb_k = th_config.thread_k; - int tb_n = th_config.thread_n; - - int b_size = (tb_k * tb_n / pack_factor) * 4; - - // Get A size - int m_blocks = div_ceil(prob_m, 16); - int tb_max_m = 16; - - while (true) { - if (m_blocks >= max_m_blocks) { - tb_max_m *= max_m_blocks; - break; - } - - max_m_blocks--; - if (max_m_blocks == 0) { - TORCH_CHECK(false, "Unexpected m_blocks = ", m_blocks); - } - } - - int a_size = (tb_max_m * tb_k) * 2; - - float pipe_size = (a_size + b_size) * pipe_stages; - - TORCH_CHECK(max_shared_mem / 2 > scales_cache_size); // Sanity - - return pipe_size < 0.95f * (max_shared_mem - scales_cache_size); -} - -bool is_valid_config(thread_config_t const& th_config, int max_m_blocks, - int prob_m, int prob_n, int prob_k, int num_bits, - int group_size, int max_shared_mem) { - // Sanity - if (th_config.thread_k == -1 || th_config.thread_n == -1 || - th_config.num_threads == -1) { - return false; - } - - // Verify K/N are divisible by thread K/N - if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) { - return false; - } - - // Verify min for thread K/N - if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) { - return false; - } - - // num_threads must be at least 128 (= 4 warps) - if (th_config.num_threads < 128) { - return false; - } - - // Determine cache for scales - int scales_cache_size = get_scales_cache_size(th_config, prob_m, prob_n, - prob_k, num_bits, group_size); - - // Check that pipeline fits into cache - if (!is_valid_cache_size(th_config, max_m_blocks, prob_m, prob_n, prob_k, - num_bits, scales_cache_size, max_shared_mem)) { - return false; - } - - return true; -} - -exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, - int num_bits, int group_size, - int max_shared_mem) { - int max_m_blocks = 4; - while (max_m_blocks > 0) { - if (prob_m <= 16) { - for (auto th_config : small_batch_thread_configs) { - if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, - num_bits, group_size, max_shared_mem)) { - return exec_config_t{max_m_blocks, th_config}; - } - } - } else { - for (auto th_config : large_batch_thread_configs) { - if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, - num_bits, group_size, max_shared_mem)) { - return exec_config_t{max_m_blocks, th_config}; - } - } - } - - max_m_blocks--; // Process less M blocks per invocation to reduce cache - // usage - } - - return exec_config_t{0, {-1, -1, -1}}; -} - - #define CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) - -template -void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s, int prob_m, - int prob_n, int prob_k, void* workspace, int num_bits, - int num_groups, int group_size, int dev, - cudaStream_t stream, int thread_k, int thread_n, int sms, - int max_par) { - TORCH_CHECK(num_bits == 8, "num_bits must be 8. Got = ", num_bits); - TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, - ", ", prob_n, ", ", prob_k, "]"); - - int tot_m = prob_m; - int tot_m_blocks = div_ceil(tot_m, 16); - int pad = 16 * tot_m_blocks - tot_m; - - if (sms == -1) { - cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); - } - - int max_shared_mem = 0; - cudaDeviceGetAttribute(&max_shared_mem, - cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); - TORCH_CHECK(max_shared_mem > 0); - - // Set thread config - exec_config_t exec_cfg; - if (thread_k != -1 && thread_n != -1) { - // User-defined config - exec_cfg = - exec_config_t{4, thread_config_t{thread_k, thread_n, default_threads}}; - } else { - // Auto config - exec_cfg = determine_thread_config(prob_m, prob_n, prob_k, num_bits, - group_size, max_shared_mem); - } - - TORCH_CHECK( - exec_cfg.max_m_blocks > 0 && - is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks, prob_m, - prob_n, prob_k, num_bits, group_size, max_shared_mem), - "Invalid thread config: max_m_blocks = ", exec_cfg.max_m_blocks, - ", thread_k = ", exec_cfg.tb_cfg.thread_k, - ", thread_n = ", exec_cfg.tb_cfg.thread_n, - ", num_threads = ", exec_cfg.tb_cfg.num_threads, " for MKN = [", prob_m, - ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits, - ", group_size = ", group_size, ", max_shared_mem = ", max_shared_mem); - - int num_threads = exec_cfg.tb_cfg.num_threads; - thread_k = exec_cfg.tb_cfg.thread_k; - thread_n = exec_cfg.tb_cfg.thread_n; - - int thread_k_blocks = thread_k / 16; - int thread_n_blocks = thread_n / 16; - - int blocks = sms; - - TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n, - " is not divisible by thread_n = ", thread_n); - TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k, - " is not divisible by thread_k = ", thread_k); - - int group_blocks = -1; - - const int4* A_ptr = (const int4*)A; - const int4* B_ptr = (const int4*)B; - int4* C_ptr = (int4*)C; - const int4* s_ptr = (const int4*)s; - - int* locks = (int*)workspace; - - // Main loop - for (int i = 0; i < tot_m_blocks; i += exec_cfg.max_m_blocks) { - int thread_m_blocks = tot_m_blocks - i; - prob_m = tot_m - 16 * i; - int par = 1; - if (thread_m_blocks > exec_cfg.max_m_blocks) { - // Note that parallel > 1 currently only works for inputs without any - // padding - par = (16 * thread_m_blocks - pad) / (16 * exec_cfg.max_m_blocks); - if (par > max_par) par = max_par; - prob_m = (16 * exec_cfg.max_m_blocks) * par; - i += exec_cfg.max_m_blocks * (par - 1); - thread_m_blocks = exec_cfg.max_m_blocks; - } - - // Define kernel configurations - if (false) { - } - CALL_IF(8, 32, 2, 256) - CALL_IF(8, 16, 4, 256) - CALL_IF(8, 8, 8, 256) - CALL_IF(8, 8, 4, 128) - CALL_IF(8, 4, 8, 128) - else { - TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " + - str(prob_n) + ", " + str(prob_k) + "]" + - ", num_groups = " + str(num_groups) + - ", group_size = " + str(group_size) + - ", thread_m_blocks = " + str(thread_m_blocks) + - ", thread_n_blocks = " + str(thread_n_blocks) + - ", thread_k_blocks = " + str(thread_k_blocks)); - } - - A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par; - C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par; - } -} - -} // namespace fp8_marlin - -torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, - torch::Tensor& b_scales, torch::Tensor& workspace, - int64_t num_bits, int64_t size_m, int64_t size_n, - int64_t size_k) { - // Verify num_bits - TORCH_CHECK(num_bits == 8, "num_bits must be 8. Got = ", num_bits); - int pack_factor = 32 / num_bits; - - // Verify A - TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0), - ", size_m = ", size_m); - TORCH_CHECK(a.size(1) == size_k, "Shape mismatch: a.size(1) = ", a.size(1), - ", size_k = ", size_k); - - // Verify B - TORCH_CHECK(size_k % marlin::tile_size == 0, "size_k = ", size_k, - " is not divisible by tile_size = ", marlin::tile_size); - TORCH_CHECK((size_k / marlin::tile_size) == b_q_weight.size(0), - "Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0), - ", size_k = ", size_k, ", tile_size = ", marlin::tile_size); - TORCH_CHECK(b_q_weight.size(1) % marlin::tile_size == 0, - "b_q_weight.size(1) = ", b_q_weight.size(1), - " is not divisible by tile_size = ", marlin::tile_size); - int actual_size_n = (b_q_weight.size(1) / marlin::tile_size) * pack_factor; - TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n, - ", actual_size_n = ", actual_size_n); - - // Verify device and strides - TORCH_CHECK(a.device().is_cuda(), "A is not on GPU"); - TORCH_CHECK(a.is_contiguous(), "A is not contiguous"); - - TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU"); - TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous"); - - TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU"); - TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous"); - - // Alloc buffers - const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); - auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); - torch::Tensor c = torch::empty({size_m, size_n}, options); - - // thread_k: `k` size of a thread_tile in `weights` (can usually be left as - // auto -1) - int thread_k = -1; - // thread_n: `n` size of a thread_tile in `weights` (can usually be left as - // auto -1) - int thread_n = -1; - // sms: number of SMs to use for the kernel (can usually be left as auto -1) - int sms = -1; - - // Detect groupsize and act_order - int num_groups = -1; - int group_size = -1; - - int b_rank = b_scales.sizes().size(); - TORCH_CHECK(b_rank == 2, "b_scales rank = ", b_rank, " is not 2"); - TORCH_CHECK(b_scales.size(1) == size_n, "b_scales dim 1 = ", b_scales.size(1), - " is not size_n = ", size_n); - // Channelwise only for FP8 - TORCH_CHECK(b_scales.size(0) == 1) - num_groups = b_scales.size(0); - - // Verify workspace size - TORCH_CHECK(size_n % marlin::min_thread_n == 0, "size_n = ", size_n, - ", is not divisible by min_thread_n = ", marlin::min_thread_n); - int min_workspace_size = (size_n / marlin::min_thread_n) * marlin::max_par; - TORCH_CHECK(workspace.numel() >= min_workspace_size, - "workspace.numel = ", workspace.numel(), - " is below min_workspace_size = ", min_workspace_size); - - int dev = a.get_device(); - if (a.scalar_type() == at::ScalarType::Half) { - fp8_marlin::marlin_mm_f16i4( - a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), - b_scales.data_ptr(), size_m, size_n, size_k, - workspace.data_ptr(), num_bits, num_groups, group_size, dev, - at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, - marlin::max_par); - } else if (a.scalar_type() == at::ScalarType::BFloat16) { - fp8_marlin::marlin_mm_f16i4( - a.data_ptr(), b_q_weight.data_ptr(), - c.data_ptr(), b_scales.data_ptr(), size_m, - size_n, size_k, workspace.data_ptr(), num_bits, num_groups, group_size, - dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, - marlin::max_par); - } else { - TORCH_CHECK(false, "fp8_marlin_gemm only supports bfloat16 and float16"); - } - - return c; -} - -#endif - -TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { - m.impl("fp8_marlin_gemm", &fp8_marlin_gemm); -} \ No newline at end of file diff --git a/csrc/quantization/gguf/gguf_kernel.cu b/csrc/quantization/gguf/gguf_kernel.cu index 56b78f1834d..6c146c3fb6f 100644 --- a/csrc/quantization/gguf/gguf_kernel.cu +++ b/csrc/quantization/gguf/gguf_kernel.cu @@ -13,6 +13,7 @@ #include "mmvq.cuh" #include "mmq.cuh" #include "moe.cuh" +#include "moe_vec.cuh" // Q8 gemv template @@ -377,6 +378,142 @@ torch::Tensor ggml_moe_a8(torch::Tensor X, // input return Y; } +torch::Tensor ggml_moe_a8_vec(torch::Tensor X, // input + torch::Tensor W, // expert weights + torch::Tensor topk_ids, int64_t top_k, + int64_t type, int64_t row, int64_t tokens) { + int col = X.sizes()[1]; + const int padded = (col + 512 - 1) / 512 * 512; + const at::cuda::OptionalCUDAGuard device_guard(device_of(X)); + auto options = torch::TensorOptions().dtype(X.dtype()).device(W.device()); + at::Tensor Y = torch::zeros({tokens * top_k, row}, options); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + options = torch::TensorOptions().dtype(torch::kInt32).device(W.device()); + at::Tensor quant_X = torch::empty({tokens, padded / 32 * 9}, options); + VLLM_DISPATCH_FLOATING_TYPES(X.scalar_type(), "ggml_moe_vec_a8", [&] { + quantize_row_q8_1_cuda((scalar_t*)X.data_ptr(), + (void*)quant_X.data_ptr(), col, tokens, + stream); + switch (type) { + case 2: + moe_vec_q4_0_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens, + col, row, quant_X.stride(0), stream); + break; + case 3: + moe_vec_q4_1_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens, + col, row, quant_X.stride(0), stream); + break; + case 6: + moe_vec_q5_0_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens, + col, row, quant_X.stride(0), stream); + break; + case 7: + moe_vec_q5_1_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens, + col, row, quant_X.stride(0), stream); + break; + case 8: + moe_vec_q8_0_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens, + col, row, quant_X.stride(0), stream); + break; + case 10: + moe_vec_q2_K_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens, + col, row, quant_X.stride(0), stream); + break; + case 11: + moe_vec_q3_K_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens, + col, row, quant_X.stride(0), stream); + break; + case 12: + moe_vec_q4_K_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens, + col, row, quant_X.stride(0), stream); + break; + case 13: + moe_vec_q5_K_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens, + col, row, quant_X.stride(0), stream); + break; + case 14: + moe_vec_q6_K_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens, + col, row, quant_X.stride(0), stream); + break; + case 16: + moe_vec_iq2_xxs_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens, + col, row, quant_X.stride(0), stream); + break; + case 17: + moe_vec_iq2_xs_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens, + col, row, quant_X.stride(0), stream); + break; + case 18: + moe_vec_iq3_xxs_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens, + col, row, quant_X.stride(0), stream); + break; + case 19: + moe_vec_iq1_s_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens, + col, row, quant_X.stride(0), stream); + break; + case 20: + moe_vec_iq4_nl_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens, + col, row, quant_X.stride(0), stream); + break; + case 21: + moe_vec_iq3_s_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens, + col, row, quant_X.stride(0), stream); + break; + case 22: + moe_vec_iq2_s_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens, + col, row, quant_X.stride(0), stream); + break; + case 23: + moe_vec_iq4_xs_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens, + col, row, quant_X.stride(0), stream); + break; + case 29: + moe_vec_iq1_m_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens, + col, row, quant_X.stride(0), stream); + break; + } + }); + return Y; +} + int64_t ggml_moe_get_block_size(int64_t type) { switch (type) { case 2: diff --git a/csrc/quantization/gguf/moe_vec.cuh b/csrc/quantization/gguf/moe_vec.cuh new file mode 100644 index 00000000000..60f65a1bfdc --- /dev/null +++ b/csrc/quantization/gguf/moe_vec.cuh @@ -0,0 +1,338 @@ +// copied and adapted from +// https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-cuda/mmvq.cu +template +static __global__ void moe_vec_q(const void* __restrict__ vx, + const void* __restrict__ vy, + scalar_t* __restrict__ dst, + const int* topk_ids, const int topk, + const int ncols, const int nrows, + const int token_stride) { + const auto row = blockIdx.x * blockDim.y + threadIdx.y; + + const auto token = blockIdx.z / topk; + const auto expert = (topk_ids)[blockIdx.z]; + + if (row >= nrows) { + return; + } + + const int blocks_per_row = ncols / qk; + const int blocks_per_warp = vdr * WARP_SIZE / qi; + + // partial sum for each thread + float tmp = 0.0f; + + const block_q_t* x = ((const block_q_t*)vx) + expert * nrows * blocks_per_row; + const block_q8_1* y = + (const block_q8_1*)(((const int*)vy) + token * token_stride); + + for (auto i = threadIdx.x / (qi / vdr); i < blocks_per_row; + i += blocks_per_warp) { + const int ibx = row * blocks_per_row + i; // x block index + + const int iby = i * (qk / QK8_1); // y block index that aligns with ibx + + const int iqs = + vdr * + (threadIdx.x % + (qi / vdr)); // x block quant index when casting the quants to int + + tmp += vec_dot_q_cuda(&x[ibx], &y[iby], iqs); + } + + // sum up partial sums and write back result +#pragma unroll + for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { + tmp += VLLM_SHFL_XOR_SYNC(tmp, mask); + } + + if (threadIdx.x == 0) { + dst[blockIdx.z * nrows + row] = tmp; + } +} + +template +static void moe_vec_q4_0_q8_1_cuda(const void* vx, const void* vy, + scalar_t* dst, const int* topk_ids, + const int top_k, const int tokens, + const int ncols, const int nrows, + const int token_stride, + cudaStream_t stream) { + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(block_num_y, 1, tokens * top_k); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + moe_vec_q<<>>( + vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride); +} + +template +static void moe_vec_q4_1_q8_1_cuda(const void* vx, const void* vy, + scalar_t* dst, const int* topk_ids, + const int top_k, const int tokens, + const int ncols, const int nrows, + const int token_stride, + cudaStream_t stream) { + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(block_num_y, 1, tokens * top_k); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + moe_vec_q<<>>( + vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride); +} + +template +static void moe_vec_q5_0_q8_1_cuda(const void* vx, const void* vy, + scalar_t* dst, const int* topk_ids, + const int top_k, const int tokens, + const int ncols, const int nrows, + const int token_stride, + cudaStream_t stream) { + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(block_num_y, 1, tokens * top_k); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + moe_vec_q<<>>( + vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride); +} + +template +static void moe_vec_q5_1_q8_1_cuda(const void* vx, const void* vy, + scalar_t* dst, const int* topk_ids, + const int top_k, const int tokens, + const int ncols, const int nrows, + const int token_stride, + cudaStream_t stream) { + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(block_num_y, 1, tokens * top_k); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + moe_vec_q<<>>( + vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride); +} + +template +static void moe_vec_q8_0_q8_1_cuda(const void* vx, const void* vy, + scalar_t* dst, const int* topk_ids, + const int top_k, const int tokens, + const int ncols, const int nrows, + const int token_stride, + cudaStream_t stream) { + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(block_num_y, 1, tokens * top_k); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + moe_vec_q<<>>( + vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride); +} + +template +static void moe_vec_q2_K_q8_1_cuda(const void* vx, const void* vy, + scalar_t* dst, const int* topk_ids, + const int top_k, const int tokens, + const int ncols, const int nrows, + const int token_stride, + cudaStream_t stream) { + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(block_num_y, 1, tokens * top_k); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + moe_vec_q<<>>( + vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride); +} + +template +static void moe_vec_q3_K_q8_1_cuda(const void* vx, const void* vy, + scalar_t* dst, const int* topk_ids, + const int top_k, const int tokens, + const int ncols, const int nrows, + const int token_stride, + cudaStream_t stream) { + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(block_num_y, 1, tokens * top_k); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + moe_vec_q<<>>( + vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride); +} + +template +static void moe_vec_q4_K_q8_1_cuda(const void* vx, const void* vy, + scalar_t* dst, const int* topk_ids, + const int top_k, const int tokens, + const int ncols, const int nrows, + const int token_stride, + cudaStream_t stream) { + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(block_num_y, 1, tokens * top_k); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + moe_vec_q<<>>( + vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride); +} + +template +static void moe_vec_q5_K_q8_1_cuda(const void* vx, const void* vy, + scalar_t* dst, const int* topk_ids, + const int top_k, const int tokens, + const int ncols, const int nrows, + const int token_stride, + cudaStream_t stream) { + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(block_num_y, 1, tokens * top_k); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + moe_vec_q<<>>( + vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride); +} + +template +static void moe_vec_q6_K_q8_1_cuda(const void* vx, const void* vy, + scalar_t* dst, const int* topk_ids, + const int top_k, const int tokens, + const int ncols, const int nrows, + const int token_stride, + cudaStream_t stream) { + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(block_num_y, 1, tokens * top_k); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + moe_vec_q<<>>( + vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride); +} + +template +static void moe_vec_iq2_xxs_q8_1_cuda(const void* vx, const void* vy, + scalar_t* dst, const int* topk_ids, + const int top_k, const int tokens, + const int ncols, const int nrows, + const int token_stride, + cudaStream_t stream) { + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(block_num_y, 1, tokens * top_k); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + moe_vec_q + <<>>(vx, vy, dst, topk_ids, top_k, + ncols, nrows, token_stride); +} + +template +static void moe_vec_iq2_xs_q8_1_cuda(const void* vx, const void* vy, + scalar_t* dst, const int* topk_ids, + const int top_k, const int tokens, + const int ncols, const int nrows, + const int token_stride, + cudaStream_t stream) { + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(block_num_y, 1, tokens * top_k); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + moe_vec_q + <<>>(vx, vy, dst, topk_ids, top_k, + ncols, nrows, token_stride); +} + +template +static void moe_vec_iq2_s_q8_1_cuda(const void* vx, const void* vy, + scalar_t* dst, const int* topk_ids, + const int top_k, const int tokens, + const int ncols, const int nrows, + const int token_stride, + cudaStream_t stream) { + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(block_num_y, 1, tokens * top_k); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + moe_vec_q + <<>>(vx, vy, dst, topk_ids, top_k, + ncols, nrows, token_stride); +} + +template +static void moe_vec_iq3_xxs_q8_1_cuda(const void* vx, const void* vy, + scalar_t* dst, const int* topk_ids, + const int top_k, const int tokens, + const int ncols, const int nrows, + const int token_stride, + cudaStream_t stream) { + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(block_num_y, 1, tokens * top_k); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + moe_vec_q + <<>>(vx, vy, dst, topk_ids, top_k, + ncols, nrows, token_stride); +} + +template +static void moe_vec_iq1_s_q8_1_cuda(const void* vx, const void* vy, + scalar_t* dst, const int* topk_ids, + const int top_k, const int tokens, + const int ncols, const int nrows, + const int token_stride, + cudaStream_t stream) { + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(block_num_y, 1, tokens * top_k); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + moe_vec_q + <<>>(vx, vy, dst, topk_ids, top_k, + ncols, nrows, token_stride); +} + +template +static void moe_vec_iq1_m_q8_1_cuda(const void* vx, const void* vy, + scalar_t* dst, const int* topk_ids, + const int top_k, const int tokens, + const int ncols, const int nrows, + const int token_stride, + cudaStream_t stream) { + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(block_num_y, 1, tokens * top_k); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + moe_vec_q + <<>>(vx, vy, dst, topk_ids, top_k, + ncols, nrows, token_stride); +} + +template +static void moe_vec_iq4_nl_q8_1_cuda(const void* vx, const void* vy, + scalar_t* dst, const int* topk_ids, + const int top_k, const int tokens, + const int ncols, const int nrows, + const int token_stride, + cudaStream_t stream) { + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(block_num_y, 1, tokens * top_k); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + moe_vec_q<<>>( + vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride); +} + +template +static void moe_vec_iq4_xs_q8_1_cuda(const void* vx, const void* vy, + scalar_t* dst, const int* topk_ids, + const int top_k, const int tokens, + const int ncols, const int nrows, + const int token_stride, + cudaStream_t stream) { + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(block_num_y, 1, tokens * top_k); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + moe_vec_q + <<>>(vx, vy, dst, topk_ids, top_k, + ncols, nrows, token_stride); +} + +template +static void moe_vec_iq3_s_q8_1_cuda(const void* vx, const void* vy, + scalar_t* dst, const int* topk_ids, + const int top_k, const int tokens, + const int ncols, const int nrows, + const int token_stride, + cudaStream_t stream) { + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(block_num_y, 1, tokens * top_k); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + moe_vec_q + <<>>(vx, vy, dst, topk_ids, top_k, + ncols, nrows, token_stride); +} diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index f59b42d88c6..7ca40a5e782 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -176,7 +176,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Apply GPT-NeoX or GPT-J style rotary embedding to query and key. ops.def( "rotary_embedding(Tensor positions, Tensor! query," - " Tensor! key, int head_size," + " Tensor!? key, int head_size," " Tensor cos_sin_cache, bool is_neox) -> ()"); ops.impl("rotary_embedding", torch::kCUDA, &rotary_embedding); @@ -184,7 +184,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // (supports multiple loras). ops.def( "batched_rotary_embedding(Tensor positions, Tensor! query," - " Tensor! key, int head_size," + " Tensor!? key, int head_size," " Tensor cos_sin_cache, bool is_neox," " int rot_dim," " Tensor cos_sin_cache_offsets) -> ()"); @@ -337,6 +337,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "int type, SymInt row, SymInt top_k, SymInt tokens) -> Tensor"); ops.impl("ggml_moe_a8", torch::kCUDA, &ggml_moe_a8); + ops.def( + "ggml_moe_a8_vec(Tensor X, Tensor W, " + "Tensor topk_ids, int top_k, " + "int type, SymInt row, SymInt tokens) -> Tensor"); + ops.impl("ggml_moe_a8_vec", torch::kCUDA, &ggml_moe_a8_vec); + ops.def("ggml_moe_get_block_size", &ggml_moe_get_block_size); #ifndef USE_ROCM diff --git a/docker/Dockerfile.nightly_torch b/docker/Dockerfile.nightly_torch index 6989106c429..53b8ccd8049 100644 --- a/docker/Dockerfile.nightly_torch +++ b/docker/Dockerfile.nightly_torch @@ -309,5 +309,7 @@ ENV HF_HUB_ENABLE_HF_TRANSFER 1 RUN --mount=type=cache,target=/root/.cache/uv \ uv pip install --system -r requirements/nightly_torch_test.txt -#################### UNITTEST IMAGE ############################# +# Logging to confirm the torch versions +RUN pip freeze | grep -E 'torch|xformers|vllm|flashinfer' +#################### UNITTEST IMAGE ############################# diff --git a/docs/source/conf.py b/docs/source/conf.py index 060649e43b9..5620d6de2c5 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -60,9 +60,6 @@ autodoc2_output_dir = "api" autodoc2_render_plugin = "myst" autodoc2_hidden_objects = ["dunder", "private", "inherited"] -autodoc2_docstring_parser_regexes = [ - (".*", "docs.source.autodoc2_docstring_parser"), -] autodoc2_sort_names = True autodoc2_index_template = None diff --git a/docs/source/deployment/frameworks/index.md b/docs/source/deployment/frameworks/index.md index 683fa8217a8..d1c058eafa4 100644 --- a/docs/source/deployment/frameworks/index.md +++ b/docs/source/deployment/frameworks/index.md @@ -11,6 +11,7 @@ helm lws modal open-webui +retrieval_augmented_generation skypilot streamlit triton diff --git a/docs/source/deployment/frameworks/retrieval_augmented_generation.md b/docs/source/deployment/frameworks/retrieval_augmented_generation.md new file mode 100644 index 00000000000..f84451fafe9 --- /dev/null +++ b/docs/source/deployment/frameworks/retrieval_augmented_generation.md @@ -0,0 +1,84 @@ +(deployment-retrieval-augmented-generation)= + +# Retrieval-Augmented Generation + +[Retrieval-augmented generation (RAG)](https://en.wikipedia.org/wiki/Retrieval-augmented_generation) is a technique that enables generative artificial intelligence (Gen AI) models to retrieve and incorporate new information. It modifies interactions with a large language model (LLM) so that the model responds to user queries with reference to a specified set of documents, using this information to supplement information from its pre-existing training data. This allows LLMs to use domain-specific and/or updated information. Use cases include providing chatbot access to internal company data or generating responses based on authoritative sources. + +Here are the integrations: +- vLLM + [langchain](https://github.com/langchain-ai/langchain) + [milvus](https://github.com/milvus-io/milvus) +- vLLM + [llamaindex](https://github.com/run-llama/llama_index) + [milvus](https://github.com/milvus-io/milvus) + +## vLLM + langchain + +### Prerequisites + +- Setup vLLM and langchain environment + +```console +pip install -U vllm \ + langchain_milvus langchain_openai \ + langchain_community beautifulsoup4 \ + langchain-text-splitters +``` + +### Deploy + +- Start the vLLM server with the supported embedding model, e.g. + +```console +# Start embedding service (port 8000) +vllm serve ssmits/Qwen2-7B-Instruct-embed-base +``` + +- Start the vLLM server with the supported chat completion model, e.g. + +```console +# Start chat service (port 8001) +vllm serve qwen/Qwen1.5-0.5B-Chat --port 8001 +``` + +- Use the script: + +- Run the script + +```python +python retrieval_augmented_generation_with_langchain.py +``` + +## vLLM + llamaindex + +### Prerequisites + +- Setup vLLM and llamaindex environment + +```console +pip install vllm \ + llama-index llama-index-readers-web \ + llama-index-llms-openai-like \ + llama-index-embeddings-openai-like \ + llama-index-vector-stores-milvus \ +``` + +### Deploy + +- Start the vLLM server with the supported embedding model, e.g. + +```console +# Start embedding service (port 8000) +vllm serve ssmits/Qwen2-7B-Instruct-embed-base +``` + +- Start the vLLM server with the supported chat completion model, e.g. + +```console +# Start chat service (port 8001) +vllm serve qwen/Qwen1.5-0.5B-Chat --port 8001 +``` + +- Use the script: + +- Run the script + +```python +python retrieval_augmented_generation_with_llamaindex.py +``` diff --git a/docs/source/features/tool_calling.md b/docs/source/features/tool_calling.md index f98ec6108ce..f3b808b3d2b 100644 --- a/docs/source/features/tool_calling.md +++ b/docs/source/features/tool_calling.md @@ -141,9 +141,9 @@ Known issues: much shorter than what vLLM generates. Since an exception is thrown when this condition is not met, the following additional chat templates are provided: -* `examples/tool_chat_template_mistral.jinja` - this is the "official" Mistral chat template, but tweaked so that +* - this is the "official" Mistral chat template, but tweaked so that it works with vLLM's tool call IDs (provided `tool_call_id` fields are truncated to the last 9 digits) -* `examples/tool_chat_template_mistral_parallel.jinja` - this is a "better" version that adds a tool-use system prompt +* - this is a "better" version that adds a tool-use system prompt when tools are provided, that results in much better reliability when working with parallel tool calling. Recommended flags: `--tool-call-parser mistral --chat-template examples/tool_chat_template_mistral_parallel.jinja` @@ -170,15 +170,15 @@ Known issues: VLLM provides two JSON based chat templates for Llama 3.1 and 3.2: -* `examples/tool_chat_template_llama3.1_json.jinja` - this is the "official" chat template for the Llama 3.1 +* - this is the "official" chat template for the Llama 3.1 models, but tweaked so that it works better with vLLM. -* `examples/tool_chat_template_llama3.2_json.jinja` - this extends upon the Llama 3.1 chat template by adding support for +* - this extends upon the Llama 3.1 chat template by adding support for images. Recommended flags: `--tool-call-parser llama3_json --chat-template {see_above}` VLLM also provides a JSON based chat template for Llama 4: -* `examples/tool_chat_template_llama4_json.jinja` - this is based on the "official" chat template for the Llama 4 +* - this is based on the "official" chat template for the Llama 4 models, but tweaked so that it works better with vLLM. For Llama 4 use `--tool-call-parser llama4_json examples/tool_chat_template_llama4_json.jinja`. @@ -191,7 +191,7 @@ Supported models: Recommended flags: `--tool-call-parser granite --chat-template examples/tool_chat_template_granite.jinja` -`examples/tool_chat_template_granite.jinja`: this is a modified chat template from the original on Huggingface. Parallel function calls are supported. +: this is a modified chat template from the original on Huggingface. Parallel function calls are supported. * `ibm-granite/granite-3.1-8b-instruct` @@ -203,7 +203,7 @@ The chat template from Huggingface can be used directly. Parallel function calls Recommended flags: `--tool-call-parser granite-20b-fc --chat-template examples/tool_chat_template_granite_20b_fc.jinja` -`examples/tool_chat_template_granite_20b_fc.jinja`: this is a modified chat template from the original on Huggingface, which is not vLLM compatible. It blends function description elements from the Hermes template and follows the same system prompt as "Response Generation" mode from [the paper](https://arxiv.org/abs/2407.00121). Parallel function calls are supported. +: this is a modified chat template from the original on Huggingface, which is not vLLM compatible. It blends function description elements from the Hermes template and follows the same system prompt as "Response Generation" mode from [the paper](https://arxiv.org/abs/2407.00121). Parallel function calls are supported. ### InternLM Models (`internlm`) @@ -253,12 +253,12 @@ Limitations: Example supported models: -* `meta-llama/Llama-3.2-1B-Instruct`\* (use with `examples/tool_chat_template_llama3.2_pythonic.jinja`) -* `meta-llama/Llama-3.2-3B-Instruct`\* (use with `examples/tool_chat_template_llama3.2_pythonic.jinja`) -* `Team-ACE/ToolACE-8B` (use with `examples/tool_chat_template_toolace.jinja`) -* `fixie-ai/ultravox-v0_4-ToolACE-8B` (use with `examples/tool_chat_template_toolace.jinja`) -* `meta-llama/Llama-4-Scout-17B-16E-Instruct`\* (use with `examples/tool_chat_template_llama4_pythonic.jinja`) -* `meta-llama/Llama-4-Maverick-17B-128E-Instruct`\* (use with `examples/tool_chat_template_llama4_pythonic.jinja`) +* `meta-llama/Llama-3.2-1B-Instruct`\* (use with ) +* `meta-llama/Llama-3.2-3B-Instruct`\* (use with ) +* `Team-ACE/ToolACE-8B` (use with ) +* `fixie-ai/ultravox-v0_4-ToolACE-8B` (use with ) +* `meta-llama/Llama-4-Scout-17B-16E-Instruct`\* (use with ) +* `meta-llama/Llama-4-Maverick-17B-128E-Instruct`\* (use with ) Flags: `--tool-call-parser pythonic --chat-template {see_above}` @@ -270,7 +270,7 @@ Llama's smaller models frequently fail to emit tool calls in the correct format. ## How to write a tool parser plugin -A tool parser plugin is a Python file containing one or more ToolParser implementations. You can write a ToolParser similar to the `Hermes2ProToolParser` in vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py. +A tool parser plugin is a Python file containing one or more ToolParser implementations. You can write a ToolParser similar to the `Hermes2ProToolParser` in . Here is a summary of a plugin file: diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md index a5b63cf7bed..287947feb3d 100644 --- a/docs/source/models/supported_models.md +++ b/docs/source/models/supported_models.md @@ -239,7 +239,9 @@ print(output) See [this page](#generative-models) for more information on how to use generative models. -#### Text Generation (`--task generate`) +#### Text Generation + +Specified using `--task generate`. :::{list-table} :widths: 25 25 50 5 5 @@ -605,7 +607,9 @@ Since some model architectures support both generative and pooling tasks, you should explicitly specify the task type to ensure that the model is used in pooling mode instead of generative mode. ::: -#### Text Embedding (`--task embed`) +#### Text Embedding + +Specified using `--task embed`. :::{list-table} :widths: 25 25 50 5 5 @@ -670,7 +674,9 @@ If your model is not in the above list, we will try to automatically convert the {func}`~vllm.model_executor.models.adapters.as_embedding_model`. By default, the embeddings of the whole prompt are extracted from the normalized hidden state corresponding to the last token. -#### Reward Modeling (`--task reward`) +#### Reward Modeling + +Specified using `--task reward`. :::{list-table} :widths: 25 25 50 5 5 @@ -711,7 +717,9 @@ For process-supervised reward models such as `peiyi9979/math-shepherd-mistral-7b e.g.: `--override-pooler-config '{"pooling_type": "STEP", "step_tag_id": 123, "returned_token_ids": [456, 789]}'`. ::: -#### Classification (`--task classify`) +#### Classification + +Specified using `--task classify`. :::{list-table} :widths: 25 25 50 5 5 @@ -737,7 +745,9 @@ e.g.: `--override-pooler-config '{"pooling_type": "STEP", "step_tag_id": 123, "r If your model is not in the above list, we will try to automatically convert the model using {func}`~vllm.model_executor.models.adapters.as_classification_model`. By default, the class probabilities are extracted from the softmaxed hidden state corresponding to the last token. -#### Sentence Pair Scoring (`--task score`) +#### Sentence Pair Scoring + +Specified using `--task score`. :::{list-table} :widths: 25 25 50 5 5 @@ -824,7 +834,9 @@ vLLM currently only supports adding LoRA to the language backbone of multimodal See [this page](#generative-models) for more information on how to use generative models. -#### Text Generation (`--task generate`) +#### Text Generation + +Specified using `--task generate`. :::{list-table} :widths: 25 25 15 20 5 5 5 @@ -1200,7 +1212,9 @@ Since some model architectures support both generative and pooling tasks, you should explicitly specify the task type to ensure that the model is used in pooling mode instead of generative mode. ::: -#### Text Embedding (`--task embed`) +#### Text Embedding + +Specified using `--task embed`. Any text generation model can be converted into an embedding model by passing `--task embed`. @@ -1240,7 +1254,9 @@ The following table lists those that are tested in vLLM. * ✅︎ ::: -#### Transcription (`--task transcription`) +#### Transcription + +Specified using `--task transcription`. Speech2Text models trained specifically for Automatic Speech Recognition. diff --git a/docs/source/serving/multimodal_inputs.md b/docs/source/serving/multimodal_inputs.md index d9a093e8d14..bcaa4f9b96c 100644 --- a/docs/source/serving/multimodal_inputs.md +++ b/docs/source/serving/multimodal_inputs.md @@ -216,7 +216,7 @@ A chat template is **required** to use Chat Completions API. Although most models come with a chat template, for others you have to define one yourself. The chat template can be inferred based on the documentation on the model's HuggingFace repo. -For example, LLaVA-1.5 (`llava-hf/llava-1.5-7b-hf`) requires a chat template that can be found here: +For example, DeepSeek-VL2 requires a chat template that can be found here: ::: ### Image Inputs diff --git a/examples/offline_inference/lora_with_quantization_inference.py b/examples/offline_inference/lora_with_quantization_inference.py index ab235ddd754..b6608ec6e95 100644 --- a/examples/offline_inference/lora_with_quantization_inference.py +++ b/examples/offline_inference/lora_with_quantization_inference.py @@ -75,43 +75,38 @@ def initialize_engine(model: str, quantization: str, lora_repo: Optional[str]) -> LLMEngine: """Initialize the LLMEngine.""" - if quantization == "bitsandbytes": - # QLoRA (https://arxiv.org/abs/2305.14314) is a quantization technique. - # It quantizes the model when loading, with some config info from the - # LoRA adapter repo. So need to set the parameter of load_format and - # qlora_adapter_name_or_path as below. - engine_args = EngineArgs(model=model, - quantization=quantization, - qlora_adapter_name_or_path=lora_repo, - enable_lora=True, - max_lora_rank=64) - else: - engine_args = EngineArgs(model=model, - quantization=quantization, - enable_lora=True, - max_loras=4) + engine_args = EngineArgs(model=model, + quantization=quantization, + enable_lora=True, + max_lora_rank=64, + max_loras=4) return LLMEngine.from_engine_args(engine_args) def main(): """Main function that sets up and runs the prompt processing.""" - test_configs = [{ - "name": "qlora_inference_example", - 'model': "huggyllama/llama-7b", - 'quantization': "bitsandbytes", - 'lora_repo': 'timdettmers/qlora-flan-7b' - }, { - "name": "AWQ_inference_with_lora_example", - 'model': 'TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ', - 'quantization': "awq", - 'lora_repo': 'jashing/tinyllama-colorist-lora' - }, { - "name": "GPTQ_inference_with_lora_example", - 'model': 'TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ', - 'quantization': "gptq", - 'lora_repo': 'jashing/tinyllama-colorist-lora' - }] + test_configs = [ + # QLoRA (https://arxiv.org/abs/2305.14314) + { + "name": "qlora_inference_example", + 'model': "huggyllama/llama-7b", + 'quantization': "bitsandbytes", + 'lora_repo': 'timdettmers/qlora-flan-7b' + }, + { + "name": "AWQ_inference_with_lora_example", + 'model': 'TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ', + 'quantization': "awq", + 'lora_repo': 'jashing/tinyllama-colorist-lora' + }, + { + "name": "GPTQ_inference_with_lora_example", + 'model': 'TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ', + 'quantization': "gptq", + 'lora_repo': 'jashing/tinyllama-colorist-lora' + } + ] for test_config in test_configs: print( diff --git a/examples/offline_inference/neuron_eagle.py b/examples/offline_inference/neuron_eagle.py new file mode 100644 index 00000000000..4f63f1a2fb3 --- /dev/null +++ b/examples/offline_inference/neuron_eagle.py @@ -0,0 +1,54 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +This example shows how to run offline inference with an EAGLE speculative +decoding model on neuron. To use EAGLE speculative decoding, you must use +a draft model that is specifically fine-tuned for EAGLE speculation. +Additionally, to use EAGLE with NxD Inference, the draft model must include +the LM head weights from the target model. These weights are shared between +the draft and target model. +""" + +from vllm import LLM, SamplingParams + +# Sample prompts. +prompts = [ + "What is annapurna labs?", +] + +# Create a sampling params object. +sampling_params = SamplingParams(top_k=1, max_tokens=500, ignore_eos=True) + +# Create an LLM. +llm = LLM( + model="/home/ubuntu/model_hf/Meta-Llama-3.1-70B-Instruct", + speculative_config={ + "model": "/home/ubuntu/model_hf/Llama-3.1-70B-Instruct-EAGLE-Draft", + "num_speculative_tokens": 5, + "max_model_len": 2048 + }, + max_num_seqs=4, + # The max_model_len and block_size arguments are required to be same as + # max sequence length when targeting neuron device. + # Currently, this is a known limitation in continuous batching support + # in neuronx-distributed-inference. + max_model_len=2048, + block_size=2048, + # The device can be automatically detected when AWS Neuron SDK is installed. + # The device argument can be either unspecified for automated detection, + # or explicitly assigned. + device="neuron", + tensor_parallel_size=32, + override_neuron_config={ + "enable_eagle_speculation": True, + "enable_fused_speculation": True + }, +) + +# Generate texts from the prompts. The output is a list of RequestOutput objects +# that contain the prompt, generated text, and other information. +outputs = llm.generate(prompts, sampling_params) +# Print the outputs. +for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, \n\n\n\ Generated text: {generated_text!r}") diff --git a/examples/offline_inference/neuron_speculation.py b/examples/offline_inference/neuron_speculation.py new file mode 100644 index 00000000000..bef434bae5b --- /dev/null +++ b/examples/offline_inference/neuron_speculation.py @@ -0,0 +1,64 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +This example shows how to run offline inference with a speculative +decoding model on neuron. +""" + +import os + +from vllm import LLM, SamplingParams + +# Sample prompts. +prompts = [ + "Hello, I am a language model and I can help", + "The president of the United States is", + "The capital of France is", +] + + +def config_buckets(): + """Configure context length and token gen buckets.""" + # creates XLA hlo graphs for all the context length buckets. + os.environ['NEURON_CONTEXT_LENGTH_BUCKETS'] = "128,512,1024,2048" + # creates XLA hlo graphs for all the token gen buckets. + os.environ['NEURON_TOKEN_GEN_BUCKETS'] = "128,512,1024,2048" + + +def initialize_model(): + """Create an LLM with speculative decoding.""" + return LLM( + model="openlm-research/open_llama_7b", + speculative_config={ + "model": "openlm-research/open_llama_3b", + "num_speculative_tokens": 4, + "max_model_len": 2048 + }, + max_num_seqs=4, + max_model_len=2048, + block_size=2048, + use_v2_block_manager=True, + device="neuron", + tensor_parallel_size=32, + ) + + +def process_requests(model: LLM, sampling_params: SamplingParams): + """Generate texts from prompts and print them.""" + outputs = model.generate(prompts, sampling_params) + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + + +def main(): + """Main function that sets up the model and processes prompts.""" + config_buckets() + model = initialize_model() + # Create a sampling params object. + sampling_params = SamplingParams(max_tokens=100, top_k=1) + process_requests(model, sampling_params) + + +if __name__ == '__main__': + main() diff --git a/examples/offline_inference/tpu.py b/examples/offline_inference/tpu.py index dea717c3608..71cd88f2788 100644 --- a/examples/offline_inference/tpu.py +++ b/examples/offline_inference/tpu.py @@ -22,7 +22,8 @@ def main(): # In real workloads, `enforace_eager` should be `False`. llm = LLM(model="Qwen/Qwen2-1.5B-Instruct", max_num_batched_tokens=64, - max_num_seqs=4) + max_num_seqs=4, + max_model_len=128) outputs = llm.generate(prompts, sampling_params) print("-" * 50) for output, answer in zip(outputs, answers): diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index aca11f5c50b..5c173ab1abb 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -45,7 +45,7 @@ def run_aria(questions: list[str], modality: str) -> ModelRequestData: max_model_len=4096, max_num_seqs=2, dtype="bfloat16", - limit_mm_per_prompt={"image": 1}, + limit_mm_per_prompt={modality: 1}, ) prompts = [(f"<|im_start|>user\n<|img|>{question}" @@ -71,7 +71,7 @@ def run_aya_vision(questions: list[str], modality: str) -> ModelRequestData: max_model_len=2048, max_num_seqs=2, mm_processor_kwargs={"crop_to_patches": True}, - limit_mm_per_prompt={"image": 1}, + limit_mm_per_prompt={modality: 1}, ) prompts = [ f"<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{question}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>" @@ -92,7 +92,7 @@ def run_blip2(questions: list[str], modality: str) -> ModelRequestData: prompts = [f"Question: {question} Answer:" for question in questions] engine_args = EngineArgs( model="Salesforce/blip2-opt-6.7b", - limit_mm_per_prompt={"image": 1}, + limit_mm_per_prompt={modality: 1}, ) return ModelRequestData( @@ -110,7 +110,7 @@ def run_chameleon(questions: list[str], modality: str) -> ModelRequestData: model="facebook/chameleon-7b", max_model_len=4096, max_num_seqs=2, - limit_mm_per_prompt={"image": 1}, + limit_mm_per_prompt={modality: 1}, ) return ModelRequestData( @@ -130,7 +130,7 @@ def run_deepseek_vl2(questions: list[str], modality: str) -> ModelRequestData: max_model_len=4096, max_num_seqs=2, hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]}, - limit_mm_per_prompt={"image": 1}, + limit_mm_per_prompt={modality: 1}, ) prompts = [ @@ -155,7 +155,7 @@ def run_florence2(questions: list[str], modality: str) -> ModelRequestData: max_num_seqs=2, trust_remote_code=True, dtype="bfloat16", - limit_mm_per_prompt={"image": 1}, + limit_mm_per_prompt={modality: 1}, ) prompts = ["" for _ in questions] @@ -175,7 +175,7 @@ def run_fuyu(questions: list[str], modality: str) -> ModelRequestData: model="adept/fuyu-8b", max_model_len=2048, max_num_seqs=2, - limit_mm_per_prompt={"image": 1}, + limit_mm_per_prompt={modality: 1}, ) return ModelRequestData( @@ -194,7 +194,7 @@ def run_gemma3(questions: list[str], modality: str) -> ModelRequestData: max_model_len=2048, max_num_seqs=2, mm_processor_kwargs={"do_pan_and_scan": True}, - limit_mm_per_prompt={"image": 1}, + limit_mm_per_prompt={modality: 1}, ) prompts = [("user\n" @@ -219,7 +219,7 @@ def run_glm4v(questions: list[str], modality: str) -> ModelRequestData: trust_remote_code=True, enforce_eager=True, hf_overrides={"architectures": ["GLM4VForCausalLM"]}, - limit_mm_per_prompt={"image": 1}, + limit_mm_per_prompt={modality: 1}, ) prompts = [ @@ -246,7 +246,7 @@ def run_h2ovl(questions: list[str], modality: str) -> ModelRequestData: model=model_name, trust_remote_code=True, max_model_len=8192, - limit_mm_per_prompt={"image": 1}, + limit_mm_per_prompt={modality: 1}, ) tokenizer = AutoTokenizer.from_pretrained(model_name, @@ -287,7 +287,7 @@ def run_idefics3(questions: list[str], modality: str) -> ModelRequestData: "longest_edge": 3 * 364 }, }, - limit_mm_per_prompt={"image": 1}, + limit_mm_per_prompt={modality: 1}, ) prompts = [( f"<|begin_of_text|>User:{question}\nAssistant:" @@ -314,7 +314,7 @@ def run_smolvlm(questions: list[str], modality: str) -> ModelRequestData: "longest_edge": 384 }, }, - limit_mm_per_prompt={"image": 1}, + limit_mm_per_prompt={modality: 1}, ) prompts = [ (f"<|im_start|>User:{question}\nAssistant:") @@ -337,7 +337,7 @@ def run_internvl(questions: list[str], modality: str) -> ModelRequestData: model=model_name, trust_remote_code=True, max_model_len=4096, - limit_mm_per_prompt={"image": 1}, + limit_mm_per_prompt={modality: 1}, ) tokenizer = AutoTokenizer.from_pretrained(model_name, @@ -378,7 +378,7 @@ def run_kimi_vl(questions: list[str], modality: str) -> ModelRequestData: model="moonshotai/Kimi-VL-A3B-Instruct", trust_remote_code=True, max_model_len=4096, - limit_mm_per_prompt={"image": 1}, + limit_mm_per_prompt={modality: 1}, ) return ModelRequestData( @@ -398,7 +398,7 @@ def run_llava(questions: list[str], modality: str) -> ModelRequestData: engine_args = EngineArgs( model="llava-hf/llava-1.5-7b-hf", max_model_len=4096, - limit_mm_per_prompt={"image": 1}, + limit_mm_per_prompt={modality: 1}, ) return ModelRequestData( @@ -415,7 +415,7 @@ def run_llava_next(questions: list[str], modality: str) -> ModelRequestData: engine_args = EngineArgs( model="llava-hf/llava-v1.6-mistral-7b-hf", max_model_len=8192, - limit_mm_per_prompt={"image": 1}, + limit_mm_per_prompt={modality: 1}, ) return ModelRequestData( @@ -437,7 +437,7 @@ def run_llava_next_video(questions: list[str], model="llava-hf/LLaVA-NeXT-Video-7B-hf", max_model_len=8192, max_num_seqs=2, - limit_mm_per_prompt={"image": 1}, + limit_mm_per_prompt={modality: 1}, ) return ModelRequestData( @@ -465,7 +465,7 @@ def run_llava_onevision(questions: list[str], engine_args = EngineArgs( model="llava-hf/llava-onevision-qwen2-7b-ov-hf", max_model_len=16384, - limit_mm_per_prompt={"image": 1}, + limit_mm_per_prompt={modality: 1}, ) return ModelRequestData( @@ -488,7 +488,7 @@ def run_mantis(questions: list[str], modality: str) -> ModelRequestData: model="TIGER-Lab/Mantis-8B-siglip-llama3", max_model_len=4096, hf_overrides={"architectures": ["MantisForConditionalGeneration"]}, - limit_mm_per_prompt={"image": 1}, + limit_mm_per_prompt={modality: 1}, ) stop_token_ids = [128009] @@ -529,7 +529,7 @@ def run_minicpmv_base(questions: list[str], modality: str, model_name): max_model_len=4096, max_num_seqs=2, trust_remote_code=True, - limit_mm_per_prompt={"image": 1}, + limit_mm_per_prompt={modality: 1}, ) # NOTE The stop_token_ids are different for various versions of MiniCPM-V # 2.0 @@ -584,7 +584,7 @@ def run_mistral3(questions: list[str], modality: str) -> ModelRequestData: max_model_len=8192, max_num_seqs=2, tensor_parallel_size=2, - limit_mm_per_prompt={"image": 1}, + limit_mm_per_prompt={modality: 1}, ) prompts = [f"[INST]{question}\n[IMG][/INST]" for question in questions] @@ -610,7 +610,7 @@ def run_mllama(questions: list[str], modality: str) -> ModelRequestData: model=model_name, max_model_len=8192, max_num_seqs=2, - limit_mm_per_prompt={"image": 1}, + limit_mm_per_prompt={modality: 1}, ) tokenizer = AutoTokenizer.from_pretrained(model_name) @@ -645,7 +645,7 @@ def run_llama4(questions: list[str], modality: str) -> ModelRequestData: max_num_seqs=4, tensor_parallel_size=8, gpu_memory_utilization=0.4, - limit_mm_per_prompt={"image": 1}, + limit_mm_per_prompt={modality: 1}, ) tokenizer = AutoTokenizer.from_pretrained(model_name) @@ -680,7 +680,7 @@ def run_molmo(questions: list[str], modality: str) -> ModelRequestData: model=model_name, trust_remote_code=True, dtype="bfloat16", - limit_mm_per_prompt={"image": 1}, + limit_mm_per_prompt={modality: 1}, ) prompts = [ @@ -706,7 +706,7 @@ def run_nvlm_d(questions: list[str], modality: str) -> ModelRequestData: trust_remote_code=True, max_model_len=4096, tensor_parallel_size=4, - limit_mm_per_prompt={"image": 1}, + limit_mm_per_prompt={modality: 1}, ) tokenizer = AutoTokenizer.from_pretrained(model_name, @@ -738,7 +738,7 @@ def run_ovis2(questions: list[str], modality: str) -> ModelRequestData: trust_remote_code=True, dtype="half", hf_overrides={"architectures": ["Ovis2ForConditionalGeneration"]}, - limit_mm_per_prompt={"image": 1}, + limit_mm_per_prompt={modality: 1}, ) placeholder = "\n" @@ -761,7 +761,7 @@ def run_paligemma(questions: list[str], modality: str) -> ModelRequestData: prompts = ["caption en" for _ in questions] engine_args = EngineArgs( model="google/paligemma-3b-mix-224", - limit_mm_per_prompt={"image": 1}, + limit_mm_per_prompt={modality: 1}, ) return ModelRequestData( @@ -778,7 +778,7 @@ def run_paligemma2(questions: list[str], modality: str) -> ModelRequestData: prompts = ["caption en" for _ in questions] engine_args = EngineArgs( model="google/paligemma2-3b-ft-docci-448", - limit_mm_per_prompt={"image": 1}, + limit_mm_per_prompt={modality: 1}, ) return ModelRequestData( @@ -815,7 +815,7 @@ def run_phi3v(questions: list[str], modality: str) -> ModelRequestData: max_num_seqs=2, # Note - mm_processor_kwargs can also be passed to generate/chat calls mm_processor_kwargs={"num_crops": 16}, - limit_mm_per_prompt={"image": 1}, + limit_mm_per_prompt={modality: 1}, ) return ModelRequestData( @@ -849,7 +849,7 @@ def run_phi4mm(questions: list[str], modality: str) -> ModelRequestData: max_lora_rank=320, # Note - mm_processor_kwargs can also be passed to generate/chat calls mm_processor_kwargs={"dynamic_hd": 16}, - limit_mm_per_prompt={"image": 1}, + limit_mm_per_prompt={modality: 1}, ) return ModelRequestData( @@ -870,7 +870,7 @@ def run_pixtral_hf(questions: list[str], modality: str) -> ModelRequestData: model=model_name, max_model_len=6144, max_num_seqs=2, - limit_mm_per_prompt={"image": 1}, + limit_mm_per_prompt={modality: 1}, ) prompts = [f"[INST]{question}\n[IMG][/INST]" for question in questions] @@ -891,7 +891,7 @@ def run_qwen_vl(questions: list[str], modality: str) -> ModelRequestData: max_model_len=1024, max_num_seqs=2, hf_overrides={"architectures": ["QwenVLForConditionalGeneration"]}, - limit_mm_per_prompt={"image": 1}, + limit_mm_per_prompt={modality: 1}, ) prompts = [f"{question}Picture 1: \n" for question in questions] @@ -916,7 +916,7 @@ def run_qwen2_vl(questions: list[str], modality: str) -> ModelRequestData: "min_pixels": 28 * 28, "max_pixels": 1280 * 28 * 28, }, - limit_mm_per_prompt={"image": 1}, + limit_mm_per_prompt={modality: 1}, ) if modality == "image": @@ -951,7 +951,7 @@ def run_qwen2_5_vl(questions: list[str], modality: str) -> ModelRequestData: "max_pixels": 1280 * 28 * 28, "fps": 1, }, - limit_mm_per_prompt={"image": 1}, + limit_mm_per_prompt={modality: 1}, ) if modality == "image": @@ -985,7 +985,7 @@ def run_qwen2_5_omni(questions: list[str], modality: str): "max_pixels": 1280 * 28 * 28, "fps": [1], }, - limit_mm_per_prompt={"image": 1}, + limit_mm_per_prompt={modality: 1}, ) if modality == "image": @@ -1018,7 +1018,7 @@ def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData: model=model_name, trust_remote_code=True, max_model_len=4096, - limit_mm_per_prompt={"image": 1}, + limit_mm_per_prompt={modality: 1}, ) tokenizer = AutoTokenizer.from_pretrained(model_name, diff --git a/examples/online_serving/openai_chat_completion_client_for_multimodal.py b/examples/online_serving/openai_chat_completion_client_for_multimodal.py index 70db4d95e64..cffd093c983 100644 --- a/examples/online_serving/openai_chat_completion_client_for_multimodal.py +++ b/examples/online_serving/openai_chat_completion_client_for_multimodal.py @@ -5,7 +5,7 @@ Launch the vLLM server with the following command: (single image inference with Llava) -vllm serve llava-hf/llava-1.5-7b-hf --chat-template template_llava.jinja +vllm serve llava-hf/llava-1.5-7b-hf (multi-image inference with Phi-3.5-vision-instruct) vllm serve microsoft/Phi-3.5-vision-instruct --task generate \ diff --git a/examples/online_serving/retrieval_augmented_generation_with_langchain.py b/examples/online_serving/retrieval_augmented_generation_with_langchain.py new file mode 100644 index 00000000000..73063065cb3 --- /dev/null +++ b/examples/online_serving/retrieval_augmented_generation_with_langchain.py @@ -0,0 +1,249 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Retrieval Augmented Generation (RAG) Implementation with Langchain +================================================================== + +This script demonstrates a RAG implementation using LangChain, Milvus +and vLLM. RAG enhances LLM responses by retrieving relevant context +from a document collection. + +Features: +- Web content loading and chunking +- Vector storage with Milvus +- Embedding generation with vLLM +- Question answering with context + +Prerequisites: +1. Install dependencies: + pip install -U vllm \ + langchain_milvus langchain_openai \ + langchain_community beautifulsoup4 \ + langchain-text-splitters + +2. Start services: + # Start embedding service (port 8000) + vllm serve ssmits/Qwen2-7B-Instruct-embed-base + + # Start chat service (port 8001) + vllm serve qwen/Qwen1.5-0.5B-Chat --port 8001 + +Usage: + python retrieval_augmented_generation_with_langchain.py + +Notes: + - Ensure both vLLM services are running before executing + - Default ports: 8000 (embedding), 8001 (chat) + - First run may take time to download models +""" + +import argparse +from argparse import Namespace +from typing import Any + +from langchain_community.document_loaders import WebBaseLoader +from langchain_core.documents import Document +from langchain_core.output_parsers import StrOutputParser +from langchain_core.prompts import PromptTemplate +from langchain_core.runnables import RunnablePassthrough +from langchain_milvus import Milvus +from langchain_openai import ChatOpenAI, OpenAIEmbeddings +from langchain_text_splitters import RecursiveCharacterTextSplitter + + +def load_and_split_documents(config: dict[str, Any]): + """ + Load and split documents from web URL + """ + try: + loader = WebBaseLoader(web_paths=(config["url"], )) + docs = loader.load() + + text_splitter = RecursiveCharacterTextSplitter( + chunk_size=config["chunk_size"], + chunk_overlap=config["chunk_overlap"], + ) + return text_splitter.split_documents(docs) + except Exception as e: + print(f"Error loading document from {config['url']}: {str(e)}") + raise + + +def init_vectorstore(config: dict[str, Any], documents: list[Document]): + """ + Initialize vector store with documents + """ + return Milvus.from_documents( + documents=documents, + embedding=OpenAIEmbeddings( + model=config["embedding_model"], + openai_api_key=config["vllm_api_key"], + openai_api_base=config["vllm_embedding_endpoint"], + ), + connection_args={"uri": config["uri"]}, + drop_old=True, + ) + + +def init_llm(config: dict[str, Any]): + """ + Initialize llm + """ + return ChatOpenAI( + model=config["chat_model"], + openai_api_key=config["vllm_api_key"], + openai_api_base=config["vllm_chat_endpoint"], + ) + + +def get_qa_prompt(): + """ + Get question answering prompt template + """ + template = """You are an assistant for question-answering tasks. +Use the following pieces of retrieved context to answer the question. +If you don't know the answer, just say that you don't know. +Use three sentences maximum and keep the answer concise. +Question: {question} +Context: {context} +Answer: +""" + return PromptTemplate.from_template(template) + + +def format_docs(docs: list[Document]): + """ + Format documents for prompt + """ + return "\n\n".join(doc.page_content for doc in docs) + + +def create_qa_chain(retriever: Any, llm: ChatOpenAI, prompt: PromptTemplate): + """ + Set up question answering chain + """ + return ({ + "context": retriever | format_docs, + "question": RunnablePassthrough(), + } + | prompt + | llm + | StrOutputParser()) + + +def get_parser() -> argparse.ArgumentParser: + """ + Parse command line arguments + """ + parser = argparse.ArgumentParser(description='RAG with vLLM and langchain') + + # Add command line arguments + parser.add_argument('--vllm-api-key', + default="EMPTY", + help='API key for vLLM compatible services') + parser.add_argument('--vllm-embedding-endpoint', + default="http://localhost:8000/v1", + help='Base URL for embedding service') + parser.add_argument('--vllm-chat-endpoint', + default="http://localhost:8001/v1", + help='Base URL for chat service') + parser.add_argument('--uri', + default="./milvus.db", + help='URI for Milvus database') + parser.add_argument( + '--url', + default=("https://docs.vllm.ai/en/latest/getting_started/" + "quickstart.html"), + help='URL of the document to process') + parser.add_argument('--embedding-model', + default="ssmits/Qwen2-7B-Instruct-embed-base", + help='Model name for embeddings') + parser.add_argument('--chat-model', + default="qwen/Qwen1.5-0.5B-Chat", + help='Model name for chat') + parser.add_argument('-i', + '--interactive', + action='store_true', + help='Enable interactive Q&A mode') + parser.add_argument('-k', + '--top-k', + type=int, + default=3, + help='Number of top results to retrieve') + parser.add_argument('-c', + '--chunk-size', + type=int, + default=1000, + help='Chunk size for document splitting') + parser.add_argument('-o', + '--chunk-overlap', + type=int, + default=200, + help='Chunk overlap for document splitting') + + return parser + + +def init_config(args: Namespace): + """ + Initialize configuration settings from command line arguments + """ + + return { + "vllm_api_key": args.vllm_api_key, + "vllm_embedding_endpoint": args.vllm_embedding_endpoint, + "vllm_chat_endpoint": args.vllm_chat_endpoint, + "uri": args.uri, + "embedding_model": args.embedding_model, + "chat_model": args.chat_model, + "url": args.url, + "chunk_size": args.chunk_size, + "chunk_overlap": args.chunk_overlap, + "top_k": args.top_k + } + + +def main(): + # Parse command line arguments + args = get_parser().parse_args() + + # Initialize configuration + config = init_config(args) + + # Load and split documents + documents = load_and_split_documents(config) + + # Initialize vector store and retriever + vectorstore = init_vectorstore(config, documents) + retriever = vectorstore.as_retriever(search_kwargs={"k": config["top_k"]}) + + # Initialize llm and prompt + llm = init_llm(config) + prompt = get_qa_prompt() + + # Set up QA chain + qa_chain = create_qa_chain(retriever, llm, prompt) + + # Interactive mode + if args.interactive: + print("\nWelcome to Interactive Q&A System!") + print("Enter 'q' or 'quit' to exit.") + + while True: + question = input("\nPlease enter your question: ") + if question.lower() in ['q', 'quit']: + print("\nThank you for using! Goodbye!") + break + + output = qa_chain.invoke(question) + print(output) + else: + # Default single question mode + question = ("How to install vLLM?") + output = qa_chain.invoke(question) + print("-" * 50) + print(output) + print("-" * 50) + + +if __name__ == "__main__": + main() diff --git a/examples/online_serving/retrieval_augmented_generation_with_llamaindex.py b/examples/online_serving/retrieval_augmented_generation_with_llamaindex.py new file mode 100644 index 00000000000..a8f76dfe4c6 --- /dev/null +++ b/examples/online_serving/retrieval_augmented_generation_with_llamaindex.py @@ -0,0 +1,217 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +RAG (Retrieval Augmented Generation) Implementation with LlamaIndex +================================================================ + +This script demonstrates a RAG system using: +- LlamaIndex: For document indexing and retrieval +- Milvus: As vector store backend +- vLLM: For embedding and text generation + +Features: +1. Document Loading & Processing +2. Embedding & Storage +3. Query Processing + +Requirements: +1. Install dependencies: +pip install llama-index llama-index-readers-web \ + llama-index-llms-openai-like \ + llama-index-embeddings-openai-like \ + llama-index-vector-stores-milvus \ + +2. Start services: + # Start embedding service (port 8000) + vllm serve ssmits/Qwen2-7B-Instruct-embed-base + + # Start chat service (port 8001) + vllm serve qwen/Qwen1.5-0.5B-Chat --port 8001 + +Usage: + python retrieval_augmented_generation_with_llamaindex.py + +Notes: + - Ensure both vLLM services are running before executing + - Default ports: 8000 (embedding), 8001 (chat) + - First run may take time to download models +""" +import argparse +from argparse import Namespace +from typing import Any + +from llama_index.core import Settings, StorageContext, VectorStoreIndex +from llama_index.core.node_parser import SentenceSplitter +from llama_index.embeddings.openai_like import OpenAILikeEmbedding +from llama_index.llms.openai_like import OpenAILike +from llama_index.readers.web import SimpleWebPageReader +from llama_index.vector_stores.milvus import MilvusVectorStore + + +def init_config(args: Namespace): + """Initialize configuration with command line arguments""" + return { + "url": args.url, + "embedding_model": args.embedding_model, + "chat_model": args.chat_model, + "vllm_api_key": args.vllm_api_key, + "embedding_endpoint": args.embedding_endpoint, + "chat_endpoint": args.chat_endpoint, + "db_path": args.db_path, + "chunk_size": args.chunk_size, + "chunk_overlap": args.chunk_overlap, + "top_k": args.top_k + } + + +def load_documents(url: str) -> list: + """Load and process web documents""" + return SimpleWebPageReader(html_to_text=True).load_data([url]) + + +def setup_models(config: dict[str, Any]): + """Configure embedding and chat models""" + Settings.embed_model = OpenAILikeEmbedding( + api_base=config["embedding_endpoint"], + api_key=config["vllm_api_key"], + model_name=config["embedding_model"], + ) + + Settings.llm = OpenAILike( + model=config["chat_model"], + api_key=config["vllm_api_key"], + api_base=config["chat_endpoint"], + context_window=128000, + is_chat_model=True, + is_function_calling_model=False, + ) + + Settings.transformations = [ + SentenceSplitter( + chunk_size=config["chunk_size"], + chunk_overlap=config["chunk_overlap"], + ) + ] + + +def setup_vector_store(db_path: str) -> MilvusVectorStore: + """Initialize vector store""" + sample_emb = Settings.embed_model.get_text_embedding("test") + print(f"Embedding dimension: {len(sample_emb)}") + return MilvusVectorStore(uri=db_path, dim=len(sample_emb), overwrite=True) + + +def create_index(documents: list, vector_store: MilvusVectorStore): + """Create document index""" + storage_context = StorageContext.from_defaults(vector_store=vector_store) + return VectorStoreIndex.from_documents( + documents, + storage_context=storage_context, + ) + + +def query_document(index: VectorStoreIndex, question: str, top_k: int): + """Query document with given question""" + query_engine = index.as_query_engine(similarity_top_k=top_k) + return query_engine.query(question) + + +def get_parser() -> argparse.ArgumentParser: + """Parse command line arguments""" + parser = argparse.ArgumentParser( + description='RAG with vLLM and LlamaIndex') + + # Add command line arguments + parser.add_argument( + '--url', + default=("https://docs.vllm.ai/en/latest/getting_started/" + "quickstart.html"), + help='URL of the document to process') + parser.add_argument('--embedding-model', + default="ssmits/Qwen2-7B-Instruct-embed-base", + help='Model name for embeddings') + parser.add_argument('--chat-model', + default="qwen/Qwen1.5-0.5B-Chat", + help='Model name for chat') + parser.add_argument('--vllm-api-key', + default="EMPTY", + help='API key for vLLM compatible services') + parser.add_argument('--embedding-endpoint', + default="http://localhost:8000/v1", + help='Base URL for embedding service') + parser.add_argument('--chat-endpoint', + default="http://localhost:8001/v1", + help='Base URL for chat service') + parser.add_argument('--db-path', + default="./milvus_demo.db", + help='Path to Milvus database') + parser.add_argument('-i', + '--interactive', + action='store_true', + help='Enable interactive Q&A mode') + parser.add_argument('-c', + '--chunk-size', + type=int, + default=1000, + help='Chunk size for document splitting') + parser.add_argument('-o', + '--chunk-overlap', + type=int, + default=200, + help='Chunk overlap for document splitting') + parser.add_argument('-k', + '--top-k', + type=int, + default=3, + help='Number of top results to retrieve') + + return parser + + +def main(): + # Parse command line arguments + args = get_parser().parse_args() + + # Initialize configuration + config = init_config(args) + + # Load documents + documents = load_documents(config["url"]) + + # Setup models + setup_models(config) + + # Setup vector store + vector_store = setup_vector_store(config["db_path"]) + + # Create index + index = create_index(documents, vector_store) + + if args.interactive: + print("\nEntering interactive mode. Type 'quit' to exit.") + while True: + # Get user question + question = input("\nEnter your question: ") + + # Check for exit command + if question.lower() in ['quit', 'exit', 'q']: + print("Exiting interactive mode...") + break + + # Get and print response + print("\n" + "-" * 50) + print("Response:\n") + response = query_document(index, question, config["top_k"]) + print(response) + print("-" * 50) + else: + # Single query mode + question = "How to install vLLM?" + response = query_document(index, question, config["top_k"]) + print("-" * 50) + print("Response:\n") + print(response) + print("-" * 50) + + +if __name__ == "__main__": + main() diff --git a/examples/template_chameleon.jinja b/examples/template_chameleon.jinja new file mode 100644 index 00000000000..3fa2cccc240 --- /dev/null +++ b/examples/template_chameleon.jinja @@ -0,0 +1,3 @@ +{%- for message in messages -%} + {{- message['content'] -}} +{%- endfor -%} diff --git a/examples/template_florence2.jinja b/examples/template_florence2.jinja index d257aed6a85..3fa2cccc240 100644 --- a/examples/template_florence2.jinja +++ b/examples/template_florence2.jinja @@ -1,7 +1,3 @@ {%- for message in messages -%} - {%- if message['role'] == 'user' -%} - {{- message['content'] -}} - {%- elif message['role'] == 'assistant' -%} - {{- message['content'] -}} - {%- endif -%} + {{- message['content'] -}} {%- endfor -%} diff --git a/examples/template_fuyu.jinja b/examples/template_fuyu.jinja new file mode 100644 index 00000000000..ec337d0c644 --- /dev/null +++ b/examples/template_fuyu.jinja @@ -0,0 +1,3 @@ +{%- for message in messages -%} + {{- message['content'] + '\n' -}} +{%- endfor -%} diff --git a/examples/template_llava.jinja b/examples/template_llava.jinja deleted file mode 100644 index 6a902ee1677..00000000000 --- a/examples/template_llava.jinja +++ /dev/null @@ -1,23 +0,0 @@ -{%- if messages[0]['role'] == 'system' -%} - {%- set system_message = messages[0]['content'] -%} - {%- set messages = messages[1:] -%} -{%- else -%} - {% set system_message = '' -%} -{%- endif -%} - -{{ bos_token + system_message }} -{%- for message in messages -%} - {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%} - {{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }} - {%- endif -%} - - {%- if message['role'] == 'user' -%} - {{ 'USER: ' + message['content'] + '\n' }} - {%- elif message['role'] == 'assistant' -%} - {{ 'ASSISTANT: ' + message['content'] + eos_token + '\n' }} - {%- endif -%} -{%- endfor -%} - -{%- if add_generation_prompt -%} - {{ 'ASSISTANT:' }} -{% endif %} diff --git a/examples/template_paligemma.jinja b/examples/template_paligemma.jinja new file mode 100644 index 00000000000..3fa2cccc240 --- /dev/null +++ b/examples/template_paligemma.jinja @@ -0,0 +1,3 @@ +{%- for message in messages -%} + {{- message['content'] -}} +{%- endfor -%} diff --git a/examples/template_qwen_vl.jinja b/examples/template_qwen_vl.jinja new file mode 100644 index 00000000000..3fa2cccc240 --- /dev/null +++ b/examples/template_qwen_vl.jinja @@ -0,0 +1,3 @@ +{%- for message in messages -%} + {{- message['content'] -}} +{%- endfor -%} diff --git a/examples/template_qwen_vl_chat.jinja b/examples/template_qwen_vl_chat.jinja new file mode 100644 index 00000000000..e76ab0c2d25 --- /dev/null +++ b/examples/template_qwen_vl_chat.jinja @@ -0,0 +1,10 @@ +{%- for message in messages -%} + {{- '<|im_start|>' + message['role'] + '\n' + message['content'] -}} + {%- if (loop.last and add_generation_prompt) or not loop.last -%} + {{- '<|im_end|>' + '\n' -}} + {%- endif -%} +{%- endfor -%} + +{%- if add_generation_prompt and messages[-1]['role'] != 'assistant' -%} + {{- '<|im_start|>assistant\n' -}} +{%- endif -%} diff --git a/requirements/docs.txt b/requirements/docs.txt index 385de841691..ccc5ef0aa97 100644 --- a/requirements/docs.txt +++ b/requirements/docs.txt @@ -1,6 +1,5 @@ sphinx==7.4.7 sphinx-argparse==0.5.2 -sphinx-autodoc2==0.5.0 sphinx-book-theme==1.1.4 sphinx-copybutton==0.5.2 sphinx-design==0.6.1 @@ -9,6 +8,10 @@ myst-parser==3.0.1 # `myst-parser==4.0.1` breaks inline code in titles msgspec commonmark # Required by sphinx-argparse when using :markdownhelp: +# Custom autodoc2 is necessary for faster docstring processing +# see: https://github.com/sphinx-extensions2/sphinx-autodoc2/issues/33#issuecomment-2856386035 +git+https://github.com/hmellor/sphinx-autodoc2.git # sphinx-autodoc2==0.5.0 + # packages to install to build the documentation cachetools -f https://download.pytorch.org/whl/cpu diff --git a/requirements/neuron.txt b/requirements/neuron.txt index f8e3030834e..7df478eddde 100644 --- a/requirements/neuron.txt +++ b/requirements/neuron.txt @@ -5,4 +5,5 @@ packaging>=24.2 setuptools>=77.0.3,<80.0.0 torch-neuronx >= 2.5.0 -neuronx-cc +neuronx-cc>=2.0.0a0 +torchvision # Required for Llama3.2 multimodal image preprocessing diff --git a/requirements/nightly_torch_test.txt b/requirements/nightly_torch_test.txt index e2711354ac1..3aebcaa623c 100644 --- a/requirements/nightly_torch_test.txt +++ b/requirements/nightly_torch_test.txt @@ -8,7 +8,6 @@ pytest-rerunfailures pytest-shard pytest-timeout - librosa # required by audio tests in entrypoints/openai sentence-transformers numba == 0.61.2; python_version > '3.9' @@ -31,3 +30,12 @@ bitsandbytes>=0.45.3 # required for minicpmo_26 test vector_quantize_pytorch vocos + +# required for Basic Models Test +blobfile # required for kimi-vl test +matplotlib # required for qwen-vl test + +# required for Multi-Modal Models Test (Standard) +num2words # required for smolvlm test +pqdm +timm # required for internvl test diff --git a/requirements/tpu.txt b/requirements/tpu.txt index 17d57058bfa..11501bc5d92 100644 --- a/requirements/tpu.txt +++ b/requirements/tpu.txt @@ -18,9 +18,9 @@ setuptools==78.1.0 --find-links https://storage.googleapis.com/libtpu-releases/index.html --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html -torch==2.8.0.dev20250408 -torchvision==0.22.0.dev20250408 -torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250408-cp39-cp39-linux_x86_64.whl ; python_version == "3.9" -torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250408-cp310-cp310-linux_x86_64.whl ; python_version == "3.10" -torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250408-cp311-cp311-linux_x86_64.whl ; python_version == "3.11" +torch==2.8.0.dev20250430 +torchvision==0.22.0.dev20250430 +torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250430-cp39-cp39-linux_x86_64.whl ; python_version == "3.9" +torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250430-cp310-cp310-linux_x86_64.whl ; python_version == "3.10" +torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250430-cp311-cp311-linux_x86_64.whl ; python_version == "3.11" diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index 03de8d9b92b..9c90fe381bb 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -100,9 +100,8 @@ def detailed( eager_mode=True, chunked_prefill=False), ], - # only ray is supported for V1 - distributed_backends=["mp", "ray", "ray"], - vllm_major_versions=["0", "0", "1"], + distributed_backends=["mp", "mp", "ray", "ray"], + vllm_major_versions=["0", "1", "0", "1"], task=task, test_options=PPTestOptions(multi_node_only=multi_node_only, load_format=load_format), @@ -350,6 +349,11 @@ def _compare_tp( # Temporary. Currently when zeromq + SPMD is used, it does not properly # terminate because of a Ray Compiled Graph issue. common_args.append("--disable-frontend-multiprocessing") + elif distributed_backend == "mp": + # Both V0/V1 of multiprocessing executor support PP + pp_env = { + "VLLM_USE_V1": vllm_major_version, + } else: pp_env = None diff --git a/tests/entrypoints/test_chat_utils.py b/tests/entrypoints/test_chat_utils.py index 92c1e0fec6b..1de30f0ac05 100644 --- a/tests/entrypoints/test_chat_utils.py +++ b/tests/entrypoints/test_chat_utils.py @@ -900,6 +900,7 @@ def test_resolve_content_format_hf_defined(model, expected_format): [("template_alpaca.jinja", "string"), ("template_baichuan.jinja", "string"), ("template_blip2.jinja", "string"), + ("template_chameleon.jinja", "string"), ("template_chatglm.jinja", "string"), ("template_chatglm2.jinja", "string"), ("template_chatml.jinja", "string"), @@ -908,9 +909,12 @@ def test_resolve_content_format_hf_defined(model, expected_format): ("template_falcon_180b.jinja", "string"), ("template_falcon.jinja", "string"), ("template_florence2.jinja", "string"), + ("template_fuyu.jinja", "string"), ("template_inkbot.jinja", "string"), - ("template_llava.jinja", "string"), + ("template_paligemma.jinja", "string"), ("template_teleflm.jinja", "string"), + ("template_qwen_vl.jinja", "string"), + ("template_qwen_vl_chat.jinja", "string"), ("template_vlm2vec.jinja", "openai"), ("tool_chat_template_granite_20b_fc.jinja", "string"), ("tool_chat_template_hermes.jinja", "string"), diff --git a/tests/kernels/core/test_pos_encoding.py b/tests/kernels/core/test_pos_encoding.py index 9907b5c863e..78211a2bd67 100644 --- a/tests/kernels/core/test_pos_encoding.py +++ b/tests/kernels/core/test_pos_encoding.py @@ -24,6 +24,7 @@ if current_platform.is_hpu(): import habana_frameworks.torch as htorch CUDA_DEVICES = ['hpu'] +USE_KEY = [True, False] def _get_flat_tensor_shape(batch_size: int, seq_len: int, num_heads: int, @@ -49,6 +50,7 @@ def _get_batch_tensor_shape(batch_size: int, seq_len: int, num_heads: int, @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("use_key", USE_KEY) @torch.inference_mode() def test_rotary_embedding( is_neox_style: bool, @@ -61,6 +63,7 @@ def test_rotary_embedding( dtype: torch.dtype, seed: int, device: str, + use_key: bool, max_position: int = 8192, base: int = 10000, ) -> None: @@ -77,7 +80,7 @@ def test_rotary_embedding( positions = torch.randint(0, max_position, (batch_size, seq_len)) query_shape = tensor_shape_fn(batch_size, seq_len, num_heads, head_size) query = torch.randn(query_shape, dtype=dtype) - key = torch.randn_like(query) + key = torch.randn_like(query) if use_key else None # NOTE(woosuk): The reference implementation should be executed first # because the custom kernel is in-place. @@ -90,10 +93,14 @@ def test_rotary_embedding( ref_query, atol=get_default_atol(out_query), rtol=get_default_rtol(out_query)) - torch.testing.assert_close(out_key, - ref_key, - atol=get_default_atol(out_key), - rtol=get_default_rtol(out_key)) + if use_key: + torch.testing.assert_close(out_key, + ref_key, + atol=get_default_atol(out_key), + rtol=get_default_rtol(out_key)) + else: + assert ref_key is None and out_key is None, \ + "expected returned key to be None" @pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE) @@ -106,6 +113,7 @@ def test_rotary_embedding( @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("use_key", USE_KEY) @torch.inference_mode() def test_batched_rotary_embedding( is_neox_style: bool, @@ -118,6 +126,7 @@ def test_batched_rotary_embedding( dtype: torch.dtype, seed: int, device: str, + use_key: bool, max_position: int = 8192, base: int = 10000, ) -> None: @@ -134,7 +143,7 @@ def test_batched_rotary_embedding( positions = torch.randint(0, max_position, (batch_size, seq_len)) query_shape = tensor_shape_fn(batch_size, seq_len, num_heads, head_size) query = torch.randn(query_shape, dtype=dtype) - key = torch.randn_like(query) + key = torch.randn_like(query) if use_key else None # NOTE(woosuk): The reference implementation should be executed first # because the custom kernel is in-place. @@ -152,10 +161,14 @@ def test_batched_rotary_embedding( ref_query, atol=get_default_atol(out_query), rtol=get_default_rtol(out_query)) - torch.testing.assert_close(out_key, - ref_key, - atol=get_default_atol(out_key), - rtol=get_default_rtol(out_key)) + if use_key: + torch.testing.assert_close(out_key, + ref_key, + atol=get_default_atol(out_key), + rtol=get_default_rtol(out_key)) + else: + assert ref_key is None and out_key is None, \ + "expected returned key to be None" @pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE) @@ -167,6 +180,7 @@ def test_batched_rotary_embedding( @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("use_key", USE_KEY) @torch.inference_mode() def test_batched_rotary_embedding_multi_lora( is_neox_style: bool, @@ -178,6 +192,7 @@ def test_batched_rotary_embedding_multi_lora( dtype: torch.dtype, seed: int, device: str, + use_key: bool, max_position: int = 8192, base: int = 10000, ) -> None: @@ -197,7 +212,7 @@ def test_batched_rotary_embedding_multi_lora( seq_len, num_heads * head_size, dtype=dtype) - key = torch.randn_like(query) + key = torch.randn_like(query) if use_key else None offset_map = torch.tensor( list( @@ -223,10 +238,14 @@ def test_batched_rotary_embedding_multi_lora( ref_query, atol=get_default_atol(out_query), rtol=get_default_rtol(out_query)) - torch.testing.assert_close(out_key, - ref_key, - atol=get_default_atol(out_key), - rtol=get_default_rtol(out_key)) + if use_key: + torch.testing.assert_close(out_key, + ref_key, + atol=get_default_atol(out_key), + rtol=get_default_rtol(out_key)) + else: + assert ref_key is None and out_key is None, \ + "expected returned key to be None" @torch.inference_mode() diff --git a/tests/kernels/core/test_rotary_embedding.py b/tests/kernels/core/test_rotary_embedding.py index c497dd90edd..4e54861005f 100644 --- a/tests/kernels/core/test_rotary_embedding.py +++ b/tests/kernels/core/test_rotary_embedding.py @@ -15,7 +15,7 @@ def rotary_embedding_opcheck(rot, positions: torch.Tensor, query: torch.Tensor, - key: torch.Tensor, + key: Optional[torch.Tensor] = None, offsets: Optional[torch.Tensor] = None): cos_sin_cache = rot.cos_sin_cache.to(query.device, dtype=query.dtype) @@ -37,9 +37,10 @@ def rotary_embedding_opcheck(rot, @pytest.mark.parametrize("rotary_dim", [32]) @pytest.mark.parametrize("head_size", [32, 108]) @pytest.mark.parametrize("seq_len", [11, 1024]) +@pytest.mark.parametrize("use_key", [True, False]) def test_rotary_embedding_opcheck(dist_init, device, max_position, is_neox_style, rotary_dim, head_size, - seq_len): + seq_len, use_key): batch_size = 1 base = 10000 num_heads = 7 @@ -54,7 +55,7 @@ def test_rotary_embedding_opcheck(dist_init, device, max_position, num_heads * head_size, dtype=torch.float32, device=device) - key = torch.randn_like(query) + key = torch.randn_like(query) if use_key else None rotary_embedding_opcheck(rot, positions, query, key) offsets = torch.zeros(batch_size * seq_len, diff --git a/tests/kernels/mamba/test_mamba_ssm_ssd.py b/tests/kernels/mamba/test_mamba_ssm_ssd.py index ee908105f55..f5e751bea41 100644 --- a/tests/kernels/mamba/test_mamba_ssm_ssd.py +++ b/tests/kernels/mamba/test_mamba_ssm_ssd.py @@ -6,7 +6,7 @@ from einops import rearrange, repeat from vllm.model_executor.layers.mamba.mamba2_metadata import ( - _seq_idx_to_chunk_indices_offsets) + _query_start_loc_to_chunk_indices_offsets) from vllm.model_executor.layers.mamba.ops.ssd_combined import ( mamba_chunk_scan_combined) from vllm.platforms import current_platform @@ -274,8 +274,9 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases, last_taken, exhausted, n_heads, d_head, itype): - chunk_indices, chunk_offsets = _seq_idx_to_chunk_indices_offsets( - seq_idx, chunk_size) + chunk_indices, chunk_offsets = \ + _query_start_loc_to_chunk_indices_offsets( + cu_seqlens, chunk_size, cu_seqlens[-1]) Y, new_states = mamba_chunk_scan_combined( X, diff --git a/tests/kernels/quantization/test_ggml.py b/tests/kernels/quantization/test_ggml.py index cc157da518c..73697a6d124 100644 --- a/tests/kernels/quantization/test_ggml.py +++ b/tests/kernels/quantization/test_ggml.py @@ -36,3 +36,9 @@ def test_ggml_opcheck(quant_type): opcheck(torch.ops._C.ggml_moe_a8, (x, qweight, sorted_token_ids, expert_ids, num_tokens_post_padded, quant_type, qweight.shape[0], 1, x.shape[0])) + + topk_ids = torch.zeros((1, 1), device='cuda', dtype=torch.int32) + + opcheck( + torch.ops._C.ggml_moe_a8_vec, + (x, qweight, topk_ids, 1, quant_type, qweight.shape[0], x.shape[0])) diff --git a/tests/kernels/quantization/test_gguf.py b/tests/kernels/quantization/test_gguf.py index 4c0fae9d9fd..6cf88604ec6 100644 --- a/tests/kernels/quantization/test_gguf.py +++ b/tests/kernels/quantization/test_gguf.py @@ -151,20 +151,7 @@ def test_mmq(num_tokens: int, hidden_size: int, dtype: torch.dtype, @pytest.mark.parametrize("hidden_size", [512]) @pytest.mark.parametrize("top_k", [4, 8]) @pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize( - "quant_type", - [ - # k-quants - GGMLQuantizationType.Q2_K, - GGMLQuantizationType.Q3_K, - GGMLQuantizationType.Q4_K, - GGMLQuantizationType.Q5_K, - GGMLQuantizationType.Q6_K, - # standard quants - GGMLQuantizationType.Q4_0, - GGMLQuantizationType.Q5_0, - GGMLQuantizationType.Q8_0, - ]) +@pytest.mark.parametrize("quant_type", QUANT_TYPES) @torch.inference_mode() def test_moe(num_tokens: int, hidden_size: int, dtype: torch.dtype, quant_type: GGMLQuantizationType, top_k: int): @@ -174,7 +161,10 @@ def test_moe(num_tokens: int, hidden_size: int, dtype: torch.dtype, x = torch.rand((num_tokens, H), dtype=dtype, device="cuda") topk_weights = torch.rand(num_tokens, top_k, device="cuda", dtype=dtype) - topk_ids = torch.randint(0, E, (num_tokens, top_k), device="cuda") + topk_ids = torch.randint(0, + E, (num_tokens, top_k), + device="cuda", + dtype=torch.int32) tensors = get_gguf_MoE_tensors(hidden_size, quant_type) diff --git a/tests/kernels/test_triton_unified_attention.py b/tests/kernels/test_triton_unified_attention.py new file mode 100644 index 00000000000..50da8e5fd5c --- /dev/null +++ b/tests/kernels/test_triton_unified_attention.py @@ -0,0 +1,189 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional + +import pytest +import torch + +from vllm.attention.ops.triton_unified_attention import unified_attention +from vllm.platforms import current_platform + +NUM_HEADS = [(4, 4), (8, 2), (16, 2)] +HEAD_SIZES = [128, 256] +BLOCK_SIZES = [16, 32] + +DTYPES = [torch.float16, torch.bfloat16] +QDTYPES = [None, torch.float8_e4m3fn] +# one value large enough to test overflow in index calculation. +# one value small enough to test the schema op check +NUM_BLOCKS = [32768, 2048] + + +def ref_paged_attn( + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + query_lens: list[int], + kv_lens: list[int], + block_tables: torch.Tensor, + scale: float, + sliding_window: Optional[int] = None, + soft_cap: Optional[float] = None, +) -> torch.Tensor: + num_seqs = len(query_lens) + block_tables = block_tables.cpu().numpy() + _, block_size, num_kv_heads, head_size = key_cache.shape + + outputs: list[torch.Tensor] = [] + start_idx = 0 + for i in range(num_seqs): + query_len = query_lens[i] + kv_len = kv_lens[i] + q = query[start_idx:start_idx + query_len] + q *= scale + + num_kv_blocks = (kv_len + block_size - 1) // block_size + block_indices = block_tables[i, :num_kv_blocks] + + k = key_cache[block_indices].view(-1, num_kv_heads, head_size) + k = k[:kv_len] + v = value_cache[block_indices].view(-1, num_kv_heads, head_size) + v = v[:kv_len] + + if q.shape[1] != k.shape[1]: + k = torch.repeat_interleave(k, q.shape[1] // k.shape[1], dim=1) + v = torch.repeat_interleave(v, q.shape[1] // v.shape[1], dim=1) + attn = torch.einsum("qhd,khd->hqk", q, k).float() + empty_mask = torch.ones(query_len, kv_len) + mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool() + if sliding_window is not None: + sliding_window_mask = torch.triu(empty_mask, + diagonal=kv_len - + (query_len + sliding_window) + + 1).bool().logical_not() + mask |= sliding_window_mask + if soft_cap is not None and soft_cap > 0: + attn = soft_cap * torch.tanh(attn / soft_cap) + attn.masked_fill_(mask, float("-inf")) + attn = torch.softmax(attn, dim=-1).to(v.dtype) + out = torch.einsum("hqk,khd->qhd", attn, v) + + outputs.append(out) + start_idx += query_len + + return torch.cat(outputs, dim=0) + + +@pytest.mark.parametrize("seq_lens", + [[(1, 1328), (5, 18), + (129, 463)], [(1, 523), (1, 37), (1, 2011)]]) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("sliding_window", [None, 256]) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0]) +@pytest.mark.parametrize("num_blocks", NUM_BLOCKS) +@pytest.mark.parametrize("q_dtype", QDTYPES) +@torch.inference_mode() +def test_triton_unified_attn( + seq_lens: list[tuple[int, int]], + num_heads: tuple[int, int], + head_size: int, + sliding_window: Optional[int], + dtype: torch.dtype, + block_size: int, + soft_cap: Optional[float], + num_blocks: int, + q_dtype: Optional[torch.dtype], +) -> None: + torch.set_default_device("cuda") + + current_platform.seed_everything(0) + num_seqs = len(seq_lens) + query_lens = [x[0] for x in seq_lens] + kv_lens = [x[1] for x in seq_lens] + num_query_heads = num_heads[0] + num_kv_heads = num_heads[1] + assert num_query_heads % num_kv_heads == 0 + max_query_len = max(query_lens) + max_kv_len = max(kv_lens) + window_size = ((sliding_window - 1, 0) if sliding_window is not None else + (-1, -1)) + scale = head_size**-0.5 + + query = torch.randn(sum(query_lens), + num_query_heads, + head_size, + dtype=dtype) + key_cache = torch.randn(num_blocks, + block_size, + num_kv_heads, + head_size, + dtype=dtype) + value_cache = torch.randn_like(key_cache) + cu_query_lens = torch.tensor([0] + query_lens, + dtype=torch.int32).cumsum(dim=0, + dtype=torch.int32) + kv_lens = torch.tensor(kv_lens, dtype=torch.int32) + + max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size + block_tables = torch.randint(0, + num_blocks, + (num_seqs, max_num_blocks_per_seq), + dtype=torch.int32) + + output = torch.empty_like(query) + + maybe_quantized_query = query + maybe_quantized_key_cache = key_cache + maybe_quantized_value_cache = value_cache + q_descale = None + k_descale = None + v_descale = None + if q_dtype is not None: + # QKV are drawn from N(0, 1): no need for a fp8 scaling factor + maybe_quantized_query = query.to(q_dtype) + maybe_quantized_key_cache = key_cache.to(q_dtype) + maybe_quantized_value_cache = value_cache.to(q_dtype) + + scale_shape = (num_seqs, num_kv_heads) + q_descale = None # Not yet supported + k_descale = torch.rand(scale_shape, dtype=torch.float32) + v_descale = torch.rand(scale_shape, dtype=torch.float32) + + unified_attention( + q=maybe_quantized_query, + k=maybe_quantized_key_cache, + v=maybe_quantized_value_cache, + out=output, + cu_seqlens_q=cu_query_lens, + seqused_k=kv_lens, + max_seqlen_q=max_query_len, + max_seqlen_k=max_kv_len, + softmax_scale=scale, + causal=True, + window_size=window_size, + block_table=block_tables, + softcap=soft_cap if soft_cap is not None else 0, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + ) + + ref_output = ref_paged_attn( + query=query, + key_cache=key_cache, + value_cache=value_cache, + query_lens=query_lens, + kv_lens=kv_lens, + block_tables=block_tables, + scale=scale, + sliding_window=sliding_window, + soft_cap=soft_cap, + ) + atol, rtol = 1.5e-2, 1e-2 + if q_dtype is not None: + atol, rtol = 1.5e-1, 1.5e-1 + torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol), \ + f"{torch.max(torch.abs(output - ref_output))}" diff --git a/tests/neuron/1_core/test_neuron_model_runner.py b/tests/neuron/1_core/test_neuron_model_runner.py new file mode 100644 index 00000000000..92417fb64f7 --- /dev/null +++ b/tests/neuron/1_core/test_neuron_model_runner.py @@ -0,0 +1,126 @@ +# SPDX-License-Identifier: Apache-2.0 +import os +from unittest.mock import MagicMock + +from vllm.config import VllmConfig +from vllm.engine.arg_utils import EngineArgs +from vllm.platforms import current_platform +from vllm.platforms.neuron import NeuronFramework +from vllm.sampling_params import SamplingParams +from vllm.sequence import SequenceData, SequenceGroupMetadata +from vllm.worker.neuron_model_runner import NeuronModelRunner + +os.environ[ + 'VLLM_NEURON_FRAMEWORK'] = NeuronFramework.TRANSFORMERS_NEURONX.value + + +def _create_neuron_model_runner(model: str, *args, + **kwargs) -> NeuronModelRunner: + engine_args = EngineArgs(model, *args, **kwargs) + engine_config = engine_args.create_engine_config() + vllm_config = VllmConfig( + model_config=engine_config.model_config, + parallel_config=engine_config.parallel_config, + scheduler_config=engine_config.scheduler_config, + device_config=engine_config.device_config, + ) + neuron_model_runner = NeuronModelRunner(vllm_config=vllm_config) + return neuron_model_runner + + +def test_update_neuron_sampling_params_not_full_batch(): + os.environ["NEURON_ON_DEVICE_SAMPLING_DISABLED"] = "0" + model_runner = _create_neuron_model_runner( + "facebook/opt-125m", + seed=0, + dtype="float16", + max_num_seqs=2, + ) + assert not model_runner._on_device_sampling_disabled + # Test sampling param updating only when TNx is framework + # NxDI handles sampling parameter updating inside model + if current_platform.use_transformers_neuronx(): + model_mock = MagicMock() + model_runner.model = model_mock + + seq_group_metadata_list = [ + SequenceGroupMetadata( + request_id="test_0", + is_prompt=True, + seq_data={0: SequenceData.from_seqs([1, 2, 3])}, + sampling_params=SamplingParams(temperature=0.5, + top_k=1, + top_p=0.5), + block_tables={0: [1]}, + ) + ] + + model_runner.prepare_model_input(seq_group_metadata_list) + + # Index neuron sampling parameters based on block_tables indices. + # The first block_id of the sequence 0 is 1, so its parameters are + # placed at index 1. So the sampling parameters will be: + # Index 0: default sampling parameters + # Index 1: sequecne 0's sampling parameters. + neuron_sampling_params = ( + model_runner.model_config.neuron_sampling_params) + assert neuron_sampling_params.temperature == [1.0, 0.5] + assert neuron_sampling_params.top_k == [ + model_runner._MAX_NEURON_SAMPLING_TOP_K, 1 + ] + assert neuron_sampling_params.top_p == [1.0, 0.5] + model_mock.model.update_generation_config.assert_called_once_with( + neuron_sampling_params) + + +def test_update_neuron_sampling_params_full_batch(): + os.environ["NEURON_ON_DEVICE_SAMPLING_DISABLED"] = "0" + model_runner = _create_neuron_model_runner( + "facebook/opt-125m", + seed=0, + dtype="float16", + max_num_seqs=2, + ) + assert not model_runner._on_device_sampling_disabled + + # Test sampling param updating only when TNx is framework + # NxDI handles sampling parameter updating inside model + if current_platform.use_transformers_neuronx(): + model_mock = MagicMock() + model_runner.model = model_mock + + seq_group_metadata_list = [ + SequenceGroupMetadata( + request_id="test_0", + is_prompt=True, + seq_data={0: SequenceData.from_seqs([1, 2, 3])}, + sampling_params=SamplingParams(temperature=0.5, + top_k=1, + top_p=0.5), + block_tables={0: [1]}, + ), + SequenceGroupMetadata( + request_id="test_0", + is_prompt=True, + seq_data={1: SequenceData.from_seqs([4, 5, 6])}, + sampling_params=SamplingParams(temperature=0.2, + top_k=2, + top_p=0.2), + block_tables={1: [0]}, + ) + ] + + model_runner.prepare_model_input(seq_group_metadata_list) + + # Index neuron sampling parameters based on block_tables indices. + # The first block_id of the sequence 0 is 1, so its parameters are + # placed at index 1. So the sampling parameters will be: + # Index 0: sequence 1's sampling parameters + # Index 1: sequecne 0's sampling parameters. + neuron_sampling_params = ( + model_runner.model_config.neuron_sampling_params) + assert neuron_sampling_params.temperature == [0.2, 0.5] + assert neuron_sampling_params.top_k == [2, 1] + assert neuron_sampling_params.top_p == [0.2, 0.5] + model_mock.model.update_generation_config.assert_called_once_with( + neuron_sampling_params) diff --git a/tests/neuron/1_core/test_rotary_embedding.py b/tests/neuron/1_core/test_rotary_embedding.py index c015b80bd47..da57631fcfc 100644 --- a/tests/neuron/1_core/test_rotary_embedding.py +++ b/tests/neuron/1_core/test_rotary_embedding.py @@ -11,14 +11,16 @@ @pytest.mark.parametrize( - "max_position,is_neox_style,rotary_dim,head_size,seq_len", [ - (16, False, 32, 32, 1024), - (16, False, 32, 128, 1024), - (16, True, 32, 32, 1024), - (16, True, 32, 128, 1024), + "max_position,is_neox_style,rotary_dim,head_size,seq_len,use_key", [ + (16, False, 32, 32, 1024, True), + (16, False, 32, 128, 1024, True), + (16, True, 32, 32, 1024, True), + (16, True, 32, 128, 1024, True), + (16, False, 32, 128, 1024, False), + (16, True, 32, 128, 1024, False), ]) def test_rotary_embedding_opcheck(max_position, is_neox_style, rotary_dim, - head_size, seq_len): + head_size, seq_len, use_key): import torch_xla.core.xla_model as xm device = xm.xla_device() @@ -40,19 +42,26 @@ def test_rotary_embedding_opcheck(max_position, is_neox_style, rotary_dim, num_heads * head_size, dtype=torch.float32, device="cpu") - key = torch.randn_like(query) - + key = torch.randn_like(query) if use_key else None assert positions.is_cpu, \ "reference input tensor is expected to be CPU tensor." ref_query, ref_key = rot.to(device="cpu").forward_native( positions, query, key) out_query, out_key = rot.to(device=device).forward_neuron( positions.to(device=device), query.to(device=device), - key.to(device=device)) - assert out_query.is_xla and out_key.is_xla, \ - "output tensor is expected to be XLA tensor" + key.to(device=device) if key is not None else None) + if use_key: + assert out_query.is_xla and out_key.is_xla, \ + "output tensor is expected to be XLA tensor" + torch.testing.assert_close(out_key.cpu(), + ref_key, + atol=1e-2, + rtol=1e-2) + else: + assert out_key is None, "expected returned key to be None" + assert out_query.is_xla, \ + "output tensor is expected to be XLA tensor" torch.testing.assert_close(out_query.cpu(), ref_query, atol=1e-2, rtol=1e-2) - torch.testing.assert_close(out_key.cpu(), ref_key, atol=1e-2, rtol=1e-2) diff --git a/tests/runai_model_streamer_test/test_runai_model_streamer_loader.py b/tests/runai_model_streamer_test/test_runai_model_streamer_loader.py index aa91fa8e1c1..8b96184f579 100644 --- a/tests/runai_model_streamer_test/test_runai_model_streamer_loader.py +++ b/tests/runai_model_streamer_test/test_runai_model_streamer_loader.py @@ -2,8 +2,7 @@ from vllm import SamplingParams from vllm.config import LoadConfig, LoadFormat -from vllm.model_executor.model_loader.loader import (RunaiModelStreamerLoader, - get_model_loader) +from vllm.model_executor.model_loader import get_model_loader test_model = "openai-community/gpt2" @@ -24,7 +23,7 @@ def get_runai_model_loader(): def test_get_model_loader_with_runai_flag(): model_loader = get_runai_model_loader() - assert isinstance(model_loader, RunaiModelStreamerLoader) + assert model_loader.__class__.__name__ == "RunaiModelStreamerLoader" def test_runai_model_loader_download_files(vllm_runner): diff --git a/tests/spec_decode/test_memory_usage.py b/tests/spec_decode/test_memory_usage.py index 7a205f2ab18..16dffe6d7d6 100644 --- a/tests/spec_decode/test_memory_usage.py +++ b/tests/spec_decode/test_memory_usage.py @@ -42,12 +42,12 @@ def add_seq_group_to_engine(engine: vllm.LLMEngine, seq_group: SequenceGroup): def test_memory_usage_no_spec(): previous_memory_allocated = None - llm = vllm.LLM( - model=MAIN_MODEL, - speculative_model=SPEC_MODEL, - num_speculative_tokens=3, - speculative_disable_by_batch_size=SPEC_DISABLE_BATCH_SIZE, - ) + llm = vllm.LLM(model=MAIN_MODEL, + speculative_config={ + "model": SPEC_MODEL, + "num_speculative_tokens": 3, + "disable_by_batch_size": SPEC_DISABLE_BATCH_SIZE, + }) batch_sequences = set() engine = llm.llm_engine diff --git a/tests/test_sharded_state_loader.py b/tests/test_sharded_state_loader.py index 94b0156e104..77fec096800 100644 --- a/tests/test_sharded_state_loader.py +++ b/tests/test_sharded_state_loader.py @@ -10,7 +10,7 @@ from huggingface_hub import snapshot_download from vllm import LLM, SamplingParams -from vllm.model_executor.model_loader.loader import ShardedStateLoader +from vllm.model_executor.model_loader import ShardedStateLoader prompts = [ "Hello, my name is", diff --git a/tests/tpu/test_moe_pallas.py b/tests/tpu/test_moe_pallas.py new file mode 100644 index 00000000000..13fc8bc8fa2 --- /dev/null +++ b/tests/tpu/test_moe_pallas.py @@ -0,0 +1,87 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for the Pallas MOE implementation. + +Run `pytest tests/kernels/moe/test_moe_pallas.py`. +""" +import pytest +import torch + +# yapf conflicts with isort for this block +# yapf: disable +from vllm.model_executor.layers.fused_moe.moe_pallas import ( + fused_moe as pallas_moe) +from vllm.model_executor.layers.fused_moe.moe_torch_iterative import ( + fused_moe as torch_moe) +# yapf: enable +from vllm.platforms import current_platform + +if not current_platform.is_tpu(): + pytest.skip("This test needs a TPU.", allow_module_level=True) + +NUM_EXPERTS = [8, 64] +EP_SIZE = [1] +TOP_KS = [2, 6] + + +# The Pallas GMM kernel requires num_tokens * topk to be a multiple of 16 +@pytest.mark.parametrize("m", [8, 16, 64, 2048]) +@pytest.mark.parametrize("n", [128, 1024, 2048]) +@pytest.mark.parametrize("k", [128, 511, 1024]) +@pytest.mark.parametrize("e", NUM_EXPERTS) +@pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("ep_size", EP_SIZE) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +def test_pallas_moe( + m: int, + n: int, + k: int, + e: int, + topk: int, + ep_size: int, + dtype: torch.dtype, +): + import torch_xla.core.xla_model as xm + with torch.device(xm.xla_device()): + a = torch.randn((m, k), dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), dtype=dtype) / 10 + w2 = torch.randn((e, k, n), dtype=dtype) / 10 + + score = torch.randn((m, e), dtype=dtype) + + # TODO: Support ep + if ep_size > 1: + pytest.skip("No support for ep_size > 1 yet") + else: + e_map = None + + # Run both implementations + torch_output = torch_moe( + hidden_states=a, + w1=w1, + w2=w2, + gating_output=score, + topk=topk, + global_num_experts=e, + expert_map=e_map, + renormalize=False, + ) + + pallas_output = pallas_moe( + hidden_states=a, + w1=w1, + w2=w2, + gating_output=score, + topk=topk, + global_num_experts=e, + expert_map=e_map, + renormalize=False, + ) + xm.mark_step() + + # Compare outputs + torch.testing.assert_close( + pallas_output.cpu(), + torch_output.cpu(), + atol=2e-2, + rtol=0, + ) diff --git a/tests/utils.py b/tests/utils.py index 0983687e2ce..bf38d784385 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -29,7 +29,7 @@ init_distributed_environment) from vllm.engine.arg_utils import AsyncEngineArgs from vllm.entrypoints.openai.cli_args import make_arg_parser -from vllm.model_executor.model_loader.loader import get_model_loader +from vllm.model_executor.model_loader import get_model_loader from vllm.platforms import current_platform from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.utils import (FlexibleArgumentParser, GB_bytes, diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index e8069b8c6d7..df487ec2cca 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -542,7 +542,7 @@ def test_allocate_with_lookahead(): num_tokens=3, num_lookahead_tokens=2, # Total required: 3+2=5 tokens ) - assert len(blocks) == 2 # ceil(5/4)=2 blocks + assert len(blocks.blocks) == 2 # ceil(5/4)=2 blocks # Test case 2: With precomputed blocks kv_cache_manager = KVCacheManager(kv_cache_config=config, @@ -553,7 +553,7 @@ def test_allocate_with_lookahead(): num_tokens=3, num_lookahead_tokens=2, ) - assert len(blocks) == 2 + assert len(blocks.blocks) == 2 # Test case 3: With precomputed blocks # required_blocks = ceil((3 + 4) / 4) = 2 @@ -564,4 +564,4 @@ def test_allocate_with_lookahead(): num_tokens=3, num_lookahead_tokens=4, ) - assert len(blocks) == 2 + assert len(blocks.blocks) == 2 diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index 4c05e0b87fc..01295e848ee 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -79,10 +79,10 @@ def test_prefill(hash_algo): req0 = make_request("0", all_token_ids) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert len(manager.req_to_block_hashes[req0.request_id]) == 3 - assert not computed_blocks + assert not computed_blocks.blocks assert num_computed_tokens == 0 blocks = manager.allocate_slots(req0, 55, computed_blocks) - assert [b.block_id for b in blocks] == [1, 2, 3, 4] + assert blocks.get_block_ids() == [1, 2, 3, 4] # Check full block metadata parent_block_hash = None @@ -105,12 +105,12 @@ def test_prefill(hash_algo): req1 = make_request("1", common_token_ids + unique_token_ids) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert len(manager.req_to_block_hashes[req1.request_id]) == 3 - assert [b.block_id for b in computed_blocks] == [1, 2, 3] + assert computed_blocks.get_block_ids() == [1, 2, 3] assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks) - assert [b.block_id for b in blocks] == [5] - for block in computed_blocks: + assert blocks.get_block_ids() == [5] + for block in computed_blocks.blocks: assert block.ref_cnt == 2 # At this point, we should have 5 free blocks left. @@ -137,11 +137,11 @@ def test_prefill(hash_algo): req2 = make_request("2", common_token_ids + unique_token_ids) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert len(manager.req_to_block_hashes[req2.request_id]) == 3 - assert [b.block_id for b in computed_blocks] == [1, 2, 3] + assert computed_blocks.get_block_ids() == [1, 2, 3] assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 blocks = manager.allocate_slots(req2, num_new_tokens, computed_blocks) - assert [b.block_id for b in blocks] == [6] + assert blocks.get_block_ids() == [6] # Although we only have 6 free blocks, we have 8 blocks in # the free block queue due to lazy removal. @@ -159,11 +159,11 @@ def test_prefill(hash_algo): # Cache miss and eviction. req3 = make_request("3", [99] * (16 * 10)) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) - assert not computed_blocks + assert not computed_blocks.blocks assert num_computed_tokens == 0 blocks = manager.allocate_slots(req3, 16 * 10, computed_blocks) # This block ID order also checks the eviction order. - assert [b.block_id for b in blocks] == [7, 8, 9, 10, 4, 5, 6, 3, 2, 1] + assert blocks.get_block_ids() == [7, 8, 9, 10, 4, 5, 6, 3, 2, 1] assert manager.block_pool.free_block_queue.num_free_blocks == 0 assert manager.block_pool.free_block_queue.free_list_head is None assert manager.block_pool.free_block_queue.free_list_tail is None @@ -195,11 +195,11 @@ def test_prefill_plp(): req0 = make_request("0", all_token_ids, prompt_logprobs=5) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert len(manager.req_to_block_hashes[req0.request_id]) == 3 - assert not computed_blocks + assert not computed_blocks.blocks assert num_computed_tokens == 0 blocks = manager.allocate_slots(req0, 55, computed_blocks) - assert [b.block_id for b in blocks] == [1, 2, 3, 4] - req0_block_hashes = [b.block_hash for b in blocks] + assert blocks.get_block_ids() == [1, 2, 3, 4] + req0_block_hashes = [b.block_hash for b in blocks.blocks] # Check full block metadata parent_block_hash = None @@ -223,12 +223,12 @@ def test_prefill_plp(): req1 = make_request("1", common_token_ids + unique_token_ids) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert len(manager.req_to_block_hashes[req1.request_id]) == 3 - assert [b.block_id for b in computed_blocks] == [1, 2, 3] + assert computed_blocks.get_block_ids() == [1, 2, 3] assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks) - assert [b.block_id for b in blocks] == [5] - for block in computed_blocks: + assert blocks.get_block_ids() == [5] + for block in computed_blocks.blocks: assert block.ref_cnt == 2 # At this point, we should have 5 free blocks left. @@ -257,12 +257,12 @@ def test_prefill_plp(): prompt_logprobs=5) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert len(manager.req_to_block_hashes[req2.request_id]) == 3 - assert not computed_blocks + assert not computed_blocks.blocks assert num_computed_tokens == 0 blocks = manager.allocate_slots(req2, 55, computed_blocks) - block_ids = [b.block_id for b in blocks] + block_ids = blocks.get_block_ids() # Duplicate cached blocks have different ids but same hashes vs request #0 - assert [b.block_hash for b in blocks] == req0_block_hashes + assert [b.block_hash for b in blocks.blocks] == req0_block_hashes assert block_ids != [1, 2, 3, 4] # Request #2 block hashes are valid since request #0 hashes are. @@ -288,17 +288,17 @@ def test_decode(): unique_token_ids = [3] * 7 req0 = make_request("0", common_token_ids + unique_token_ids) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) - assert not computed_blocks + assert not computed_blocks.blocks assert num_computed_tokens == 0 blocks = manager.allocate_slots(req0, 55, computed_blocks) - assert [b.block_id for b in blocks] == [1, 2, 3, 4] + assert blocks.get_block_ids() == [1, 2, 3, 4] # Append slots without allocating a new block. req0.num_computed_tokens = 55 for _ in range(4): req0.append_output_token_ids(8) new_blocks = manager.allocate_slots(req0, 4) - assert new_blocks is not None and len(new_blocks) == 0 + assert new_blocks is not None and len(new_blocks.blocks) == 0 assert manager.req_to_blocks[req0.request_id][-1].block_hash is None # Append slots with allocating a new block. @@ -308,7 +308,7 @@ def test_decode(): for _ in range(9 + 10): req0.append_output_token_ids(7) new_blocks = manager.allocate_slots(req0, 19) - assert new_blocks is not None and len(new_blocks) == 1 + assert new_blocks is not None and len(new_blocks.blocks) == 1 assert manager.req_to_blocks[req0.request_id][-2].block_hash is not None assert manager.req_to_blocks[req0.request_id][-1].block_hash is None @@ -323,19 +323,19 @@ def test_evict(): last_token_id = 5 * 16 + 7 req0 = make_request("0", list(range(last_token_id))) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) - assert not computed_blocks + assert not computed_blocks.blocks assert num_computed_tokens == 0 blocks = manager.allocate_slots(req0, 5 * 16 + 7, computed_blocks) - assert len(blocks) == 6 # 5 full + 1 partial + assert len(blocks.blocks) == 6 # 5 full + 1 partial # 3 blocks. req1 = make_request("1", list(range(last_token_id, last_token_id + 3 * 16))) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) - assert not computed_blocks + assert not computed_blocks.blocks assert num_computed_tokens == 0 blocks = manager.allocate_slots(req1, 3 * 16, computed_blocks) - assert len(blocks) == 3 # 3 full blocks + assert len(blocks.blocks) == 3 # 3 full blocks last_token_id += 3 * 16 # 10 - (6 + 3) == 1 @@ -352,10 +352,10 @@ def test_evict(): # Touch the first 2 blocks. req2 = make_request("2", list(range(2 * 16 + 3))) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) - assert [b.block_id for b in computed_blocks] == [1, 2] + assert computed_blocks.get_block_ids() == [1, 2] assert num_computed_tokens == 2 * 16 blocks = manager.allocate_slots(req2, 3, computed_blocks) - assert [b.block_id for b in blocks] == [10] + assert blocks.get_block_ids() == [10] assert manager.block_pool.free_block_queue.num_free_blocks == 7 @@ -375,10 +375,10 @@ def test_hash_block_correct_reuse(): num_tokens = block_size * 1 req = make_request("0", list(range(num_tokens))) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) - assert not computed_blocks + assert not computed_blocks.blocks assert num_computed_tokens == 0 blocks = manager.allocate_slots(req, num_tokens, computed_blocks) - assert len(blocks) == 1 + assert len(blocks.blocks) == 1 # Deallocate the block. manager.free(req) @@ -387,12 +387,13 @@ def test_hash_block_correct_reuse(): # block is cleared. req = make_request("1", list(range(num_tokens - 1))) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) - assert not computed_blocks + assert not computed_blocks.blocks assert num_computed_tokens == 0 blocks = manager.allocate_slots(req, num_tokens - 1, computed_blocks) - assert len(blocks) == 1 + assert len(blocks.blocks) == 1 - assert manager.block_pool.blocks[blocks[0].block_id].block_hash is None + assert manager.block_pool.blocks[ + blocks.blocks[0].block_id].block_hash is None def test_computed_blocks_not_evicted(): @@ -411,20 +412,20 @@ def test_computed_blocks_not_evicted(): num_tokens = block_size * 1 req0 = make_request("0", list(range(num_tokens))) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) - assert not computed_blocks + assert not computed_blocks.blocks assert num_computed_tokens == 0 blocks = manager.allocate_slots(req0, num_tokens, computed_blocks) - assert len(blocks) == 1 - assert blocks[0].block_id == 1 + assert len(blocks.blocks) == 1 + assert blocks.blocks[0].block_id == 1 # Allocate another block. req1 = make_request("1", list(range(num_tokens, num_tokens * 2))) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) - assert not computed_blocks + assert not computed_blocks.blocks assert num_computed_tokens == 0 blocks = manager.allocate_slots(req1, num_tokens, computed_blocks) - assert len(blocks) == 1 - assert blocks[0].block_id == 2 + assert len(blocks.blocks) == 1 + assert blocks.blocks[0].block_id == 2 # Free the blocks. manager.free(req0) @@ -434,14 +435,14 @@ def test_computed_blocks_not_evicted(): # cached block rather than the first one. req2 = make_request("2", list(range(num_tokens * 2))) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) - assert len(computed_blocks) == 1 - assert computed_blocks[0].block_id == 1 + assert len(computed_blocks.blocks) == 1 + assert computed_blocks.blocks[0].block_id == 1 assert num_computed_tokens == block_size blocks = manager.allocate_slots(req2, num_tokens * 2 - num_tokens, computed_blocks) - assert len(blocks) == 1 - assert blocks[0].block_id == 2 + assert len(blocks.blocks) == 1 + assert blocks.blocks[0].block_id == 2 def test_basic_prefix_caching_disabled(): @@ -458,10 +459,10 @@ def test_basic_prefix_caching_disabled(): req1 = make_request("1", list(range(10))) # 2 blocks and some more computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) - assert not computed_blocks + assert not computed_blocks.blocks assert num_computed_tokens == 0 blocks = manager.allocate_slots(req1, 10, computed_blocks) - assert len(blocks) == 3 + assert len(blocks.blocks) == 3 # Free the blocks. manager.free(req1) @@ -469,15 +470,15 @@ def test_basic_prefix_caching_disabled(): # No caching. req2 = make_request("2", list(range(16))) # shared prefix computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) - assert not computed_blocks + assert not computed_blocks.blocks assert num_computed_tokens == 0 blocks = manager.allocate_slots(req2, 16, computed_blocks) - assert len(blocks) == 4 + assert len(blocks.blocks) == 4 # New requests should not have any blocks. req3 = make_request("3", list(range(4))) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) - assert not computed_blocks + assert not computed_blocks.blocks assert num_computed_tokens == 0 blocks = manager.allocate_slots(req3, 4, computed_blocks) assert not blocks @@ -569,7 +570,7 @@ def test_mm_prefix_caching(): computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) # Completed block should have hashes with extra keys. - assert not computed_blocks + assert not computed_blocks.blocks assert num_computed_tokens == 0 block_hashes = manager.req_to_block_hashes[req0.request_id] assert len(block_hashes) == 3 @@ -578,14 +579,14 @@ def test_mm_prefix_caching(): assert block_hashes[2].extra_keys == ("bbb", ) blocks = manager.allocate_slots(req0, 59, computed_blocks) - assert [b.block_id for b in blocks] == [1, 2, 3, 4] + assert blocks.get_block_ids() == [1, 2, 3, 4] req0.num_computed_tokens = 59 # Append slots without allocating a new block. for _ in range(5): req0.append_output_token_ids(8) new_blocks = manager.allocate_slots(req0, 5) - assert new_blocks is not None and len(new_blocks) == 0 + assert new_blocks is not None and len(new_blocks.blocks) == 0 # The just completed block should have hashes with extra keys. assert len(block_hashes) == 4 @@ -603,7 +604,7 @@ def test_mm_prefix_caching(): mm_positions=mm_positions, mm_hashes=mm_hashes) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) - assert len(computed_blocks) == 3 + assert len(computed_blocks.blocks) == 3 assert num_computed_tokens == 3 * 16 @@ -626,7 +627,7 @@ def test_cache_key_salting(): computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) # Completed block should have hashes with extra keys. - assert not computed_blocks + assert not computed_blocks.blocks assert num_computed_tokens == 0 block_hashes = manager.req_to_block_hashes[req0.request_id] assert len(block_hashes) == 3 @@ -635,14 +636,14 @@ def test_cache_key_salting(): assert block_hashes[2].extra_keys is None blocks = manager.allocate_slots(req0, 59, computed_blocks) - assert [b.block_id for b in blocks] == [1, 2, 3, 4] + assert blocks.get_block_ids() == [1, 2, 3, 4] req0.num_computed_tokens = 59 # Append slots without allocating a new block. for _ in range(5): req0.append_output_token_ids(8) new_blocks = manager.allocate_slots(req0, 5) - assert new_blocks is not None and len(new_blocks) == 0 + assert new_blocks is not None and len(new_blocks.blocks) == 0 # Now one more block that should not have extra keys. assert len(block_hashes) == 4 @@ -653,14 +654,14 @@ def test_cache_key_salting(): req1 = make_request("1", token_ids, cache_salt="salt1") computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) # Should match only a prefix of 3 blocks. - assert len(computed_blocks) == 3 + assert len(computed_blocks.blocks) == 3 assert num_computed_tokens == 3 * block_size # Test cache miss with same content but different salt. token_ids = common_token_ids + [4] * 11 req2 = make_request("2", token_ids, cache_salt="salt2") computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) - assert len(computed_blocks) == 0 + assert len(computed_blocks.blocks) == 0 assert num_computed_tokens == 0 block_hashes = manager.req_to_block_hashes[req2.request_id] assert len(block_hashes) == 3 @@ -685,7 +686,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): common_token_ids = [i for i in range(3) for _ in range(16)] req0 = make_request("0", common_token_ids) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) - assert not computed_blocks + assert not computed_blocks.blocks assert num_computed_tokens == 0 manager.allocate_slots(req0, 48, computed_blocks) block_part0 = manager.req_to_blocks[req0.request_id] @@ -693,7 +694,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): # | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... | req1 = make_request("1", common_token_ids * 2) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) - assert computed_blocks == block_part0 + assert computed_blocks.blocks == block_part0 assert num_computed_tokens == 3 * 16 manager.allocate_slots(req1, 48, computed_blocks) block_part1 = manager.req_to_blocks[req1.request_id] @@ -707,7 +708,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): # | Req1-5(F)| Req2-0 | Req2-1 | ... | req2 = make_request("2", [7] * block_size * 2) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) - assert not computed_blocks + assert not computed_blocks.blocks assert num_computed_tokens == 0 manager.allocate_slots(req2, block_size * 2, computed_blocks) @@ -717,7 +718,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): assert manager.block_pool.free_block_queue.num_free_blocks == 5 req3 = make_request("3", common_token_ids * 3) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) - assert computed_blocks == block_part1 + assert computed_blocks.blocks == block_part1 assert num_computed_tokens == 6 * 16 # Req3 cannot be allocated. assert manager.allocate_slots(req3, 48, computed_blocks) is None @@ -739,16 +740,16 @@ def test_reset_prefix_cache(): all_token_ids = full_block_token_ids + unique_token_ids req0 = make_request("0", all_token_ids) blocks = manager.allocate_slots(req0, 55) - assert [b.block_id for b in blocks] == [1, 2, 3, 4] + assert blocks.get_block_ids() == [1, 2, 3, 4] unique_token_ids = [4] * 7 all_token_ids = full_block_token_ids + unique_token_ids req1 = make_request("1", all_token_ids) computed_blocks, _ = manager.get_computed_blocks(req1) assert len(manager.req_to_block_hashes[req1.request_id]) == 3 - assert len(computed_blocks) == 3 + assert len(computed_blocks.blocks) == 3 blocks = manager.allocate_slots(req1, 7, computed_blocks) - assert [b.block_id for b in blocks] == [5] + assert blocks.get_block_ids() == [5] # Failed to reset prefix cache because some blocks are not freed yet. assert not manager.reset_prefix_cache() @@ -776,7 +777,7 @@ def test_prefix_cache_stats_disabled(): # Call all functions that check whether log_stats is disabled. req = make_request("0", list(range(16))) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) - assert not computed_blocks + assert not computed_blocks.blocks assert num_computed_tokens == 0 manager.allocate_slots(req, 16, computed_blocks) manager.reset_prefix_cache() @@ -866,7 +867,7 @@ def test_eagle_enabled_removes_last_block(): # Should retain 1 block: # 1. Original 3 blocks → pop last hash → 2 matched blocks # 2. drop last matched block → 1 remaining block - assert len(computed_blocks) == 1 + assert len(computed_blocks.blocks) == 1 assert num_tokens == 1 * block_size # 16 tokens @@ -892,7 +893,7 @@ def test_eagle_with_partial_blocks(): req_eagle = make_request("partial_eagle", token_ids) computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle) # Original match: 2 full blocks → Eagle removes 1 → 1 remaining - assert len(computed_blocks) == 1 + assert len(computed_blocks.blocks) == 1 assert num_tokens == 1 * block_size @@ -934,7 +935,7 @@ def test_eagle_with_sliding_window(): req_eagle = make_request("partial_eagle", token_ids) computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle) # Original match: 2 full blocks → Eagle removes 1 → 1 remaining - assert len(computed_blocks) == 1 + assert len(computed_blocks.blocks) == 1 assert num_tokens == 1 * block_size # Evict the first block in the request @@ -948,5 +949,5 @@ def test_eagle_with_sliding_window(): # Cache miss. The only hit prefix is [NULL_BLOCK, BLOCK_2] if eagle is # not considered. But after dropping the last matched block due to eagle, # there will be no matched prefix. - assert len(computed_blocks) == 0 + assert len(computed_blocks.blocks) == 0 assert num_tokens == 0 diff --git a/tests/v1/tpu/test_multimodal.py b/tests/v1/tpu/test_multimodal.py index dbd2e220451..8c87fc836b5 100644 --- a/tests/v1/tpu/test_multimodal.py +++ b/tests/v1/tpu/test_multimodal.py @@ -64,8 +64,6 @@ def whats_in_this_image_msg(b64): "576", # NOTE: max-num-batched-tokens>=mm_item_size "--disable_chunked_mm_input", - "--chat-template", - "examples/template_llava.jinja" ] # Server will pre-compile on first startup (takes a long time). diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 9018bbfa960..d73860827c3 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -153,40 +153,44 @@ def merge_attn_states(output: torch.Tensor, def rotary_embedding( positions: torch.Tensor, query: torch.Tensor, - key: torch.Tensor, + key: Optional[torch.Tensor], head_size: int, cos_sin_cache: torch.Tensor, is_neox: bool, ) -> None: # TODO: Remove this contiguous call when the kernel is updated to support tensor slices query_contiguous = query.contiguous() - key_contiguous = key.contiguous() + key_contiguous = key.contiguous() if key is not None else None torch.ops._C.rotary_embedding(positions, query_contiguous, key_contiguous, head_size, cos_sin_cache, is_neox) query.copy_(query_contiguous) - key.copy_(key_contiguous) + if key is not None: + key.copy_(key_contiguous) def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor, - key: torch.Tensor, head_size: int, + key: Optional[torch.Tensor], head_size: int, cos_sin_cache: torch.Tensor, is_neox: bool, rot_dim: int, cos_sin_cache_offsets: torch.Tensor) -> None: # TODO: Remove this contiguous call when the kernel is updated to support tensor slices query_contiguous = query.contiguous() - key_contiguous = key.contiguous() + key_contiguous = key.contiguous() if key is not None else None torch.ops._C.batched_rotary_embedding(positions, query_contiguous, key_contiguous, head_size, cos_sin_cache, is_neox, rot_dim, cos_sin_cache_offsets) query.copy_(query_contiguous) - key.copy_(key_contiguous) + if key is not None: + key.copy_(key_contiguous) # layer norm ops def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, epsilon: float) -> None: - torch.ops._C.rms_norm(out, input, weight, epsilon) + # TODO: Remove this contiguous call when the kernel is updated to support non-contiguous input + input_contiguous = input.contiguous() + torch.ops._C.rms_norm(out, input_contiguous, weight, epsilon) def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor, @@ -496,6 +500,24 @@ def _ggml_moe_a8_fake( device=W.device) +if hasattr(torch.ops._C, "ggml_moe_a8_vec"): + + @register_fake("_C::ggml_moe_a8_vec") + def _ggml_moe_a8_vec_fake( + X: torch.Tensor, + W: torch.Tensor, + topk_ids: torch.Tensor, + top_k: int, + quant_type: int, + row: torch.SymInt, + tokens: torch.SymInt, + ) -> torch.Tensor: + tokens = X.size(0) + return torch.empty((tokens * top_k, row), + dtype=X.dtype, + device=W.device) + + # cutlass def cutlass_scaled_mm_supports_fp4(cuda_device_capability: int) -> bool: return torch.ops._C.cutlass_scaled_mm_supports_fp4(cuda_device_capability) @@ -1145,6 +1167,19 @@ def ggml_moe_a8( top_k, tokens) +def ggml_moe_a8_vec( + X: torch.Tensor, + W: torch.Tensor, + topk_ids: torch.Tensor, + top_k: int, + quant_type: int, + row: torch.SymInt, + tokens: torch.SymInt, +) -> torch.Tensor: + return torch.ops._C.ggml_moe_a8_vec(X, W, topk_ids, top_k, quant_type, row, + tokens) + + def ggml_moe_get_block_size(quant_type: int) -> int: return torch.ops._C.ggml_moe_get_block_size(quant_type) diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py index 91d20a4e7bf..19642a939b4 100644 --- a/vllm/attention/backends/pallas.py +++ b/vllm/attention/backends/pallas.py @@ -123,7 +123,8 @@ def __init__( self.num_queries_per_kv = self.num_heads // self.num_kv_heads self.logits_soft_cap = logits_soft_cap if head_size % 128 != 0: - raise NotImplementedError("Head size must be a multiple of 128.") + raise NotImplementedError( + f"Head size must be a multiple of 128, found {head_size}.") if alibi_slopes is not None: raise NotImplementedError("Alibi slopes is not supported.") if sliding_window is not None: diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 7cea191eccb..139d407b214 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -210,6 +210,8 @@ def forward( if self.use_direct_call: forward_context: ForwardContext = get_forward_context() attn_metadata = forward_context.attn_metadata + if isinstance(attn_metadata, dict): + attn_metadata = attn_metadata[self.layer_name] self_kv_cache = self.kv_cache[forward_context.virtual_engine] self.impl.forward(self, query, @@ -226,6 +228,8 @@ def forward( if self.use_direct_call: forward_context = get_forward_context() attn_metadata = forward_context.attn_metadata + if isinstance(attn_metadata, dict): + attn_metadata = attn_metadata[self.layer_name] self_kv_cache = self.kv_cache[forward_context.virtual_engine] return self.impl.forward(self, query, key, value, self_kv_cache, attn_metadata) @@ -373,7 +377,7 @@ def wait_for_kv_layer_from_connector(layer_name: str): attn_metadata = forward_context.attn_metadata if attn_metadata is None: return - + assert isinstance(attn_metadata, dict) connector.wait_for_layer_load(layer_name) @@ -390,8 +394,9 @@ def maybe_save_kv_layer_to_connector( attn_metadata = forward_context.attn_metadata if attn_metadata is None: return - - connector.save_kv_layer(layer_name, kv_cache_layer, attn_metadata) + assert isinstance(attn_metadata, dict) + connector.save_kv_layer(layer_name, kv_cache_layer, + attn_metadata[layer_name]) def unified_attention( @@ -404,6 +409,8 @@ def unified_attention( forward_context: ForwardContext = get_forward_context() attn_metadata = forward_context.attn_metadata + if isinstance(attn_metadata, dict): + attn_metadata = attn_metadata[layer_name] self = forward_context.no_compile_layers[layer_name] kv_cache = self.kv_cache[forward_context.virtual_engine] output = self.impl.forward(self, query, key, value, kv_cache, @@ -441,6 +448,8 @@ def unified_attention_with_output( wait_for_kv_layer_from_connector(layer_name) forward_context: ForwardContext = get_forward_context() attn_metadata = forward_context.attn_metadata + if isinstance(attn_metadata, dict): + attn_metadata = attn_metadata[layer_name] self = forward_context.no_compile_layers[layer_name] kv_cache = self.kv_cache[forward_context.virtual_engine] self.impl.forward(self, diff --git a/vllm/attention/ops/triton_unified_attention.py b/vllm/attention/ops/triton_unified_attention.py new file mode 100644 index 00000000000..8c0cf9267f3 --- /dev/null +++ b/vllm/attention/ops/triton_unified_attention.py @@ -0,0 +1,333 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Authors: +# - Burkhard Ringlein +# - Jan van Lunteren +# - Chih-Chieh Yang +# - Thomas Parnell + +import triton +import triton.language as tl + +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +@triton.jit +def cdiv_fn(x, y): + return (x + y - 1) // y + + +@triton.jit +def apply_softcap(S, x): + Sdiv = S / x + p1 = tl.exp(Sdiv) + p2 = tl.exp(-Sdiv) + return x * (p1 - p2) / (p1 + p2) + + +@triton.jit +def kernel_unified_attention_2d( + output_ptr, # [num_tokens, num_query_heads, head_size] + query_ptr, # [num_tokens, num_query_heads, head_size] + key_cache_ptr, # [num_blks, num_kv_heads, head_size // x, blk_size, x] + value_cache_ptr, # [num_blks, num_kv_heads, head_size, blk_size] + block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] + seq_lens_ptr, # [num_seqs] + alibi_slopes_ptr, # [num_query_heads] + scale, # float32 + k_scale, # float32 + v_scale, # float32 + softcap, # float32 + num_query_heads: tl.constexpr, # int + num_queries_per_kv: tl.constexpr, # int + block_table_stride: tl.int64, # int + query_stride_0: tl.int64, # int + query_stride_1: tl.int64, # int, should be equal to head_size + output_stride_0: tl.int64, # int + output_stride_1: tl.int64, # int, should be equal to head_size + BLOCK_SIZE: tl.constexpr, # int + HEAD_SIZE: tl.constexpr, # int + HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 + USE_ALIBI_SLOPES: tl.constexpr, # bool + USE_SOFTCAP: tl.constexpr, # bool + SLIDING_WINDOW: tl.constexpr, # int + stride_k_cache_0: tl.int64, # int + stride_k_cache_1: tl.int64, # int + stride_k_cache_2: tl.int64, # int + stride_k_cache_3: tl.int64, # int + stride_v_cache_0: tl.int64, # int + stride_v_cache_1: tl.int64, # int + stride_v_cache_2: tl.int64, # int + stride_v_cache_3: tl.int64, # int + query_start_len_ptr, # [num_seqs+1] + BLOCK_Q: tl.constexpr, # int + num_seqs: tl.int32, +): + + q_block_global_idx = tl.program_id(0) + kv_head_idx = tl.program_id(1) + + left: tl.int32 = 0 + right = num_seqs + while left < right: + mid = (left + right) // 2 + mid_val = tl.load(query_start_len_ptr + mid) // BLOCK_Q + mid + if mid_val <= q_block_global_idx: + left = mid + 1 + else: + right = mid + + seq_idx = left - 1 + q_block_start_idx = tl.load(query_start_len_ptr + + seq_idx) // BLOCK_Q + seq_idx + + q_block_local_idx = q_block_global_idx - q_block_start_idx + + cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx) + cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1) + + cur_batch_query_len = cur_batch_in_all_stop_index \ + - cur_batch_in_all_start_index + + if q_block_local_idx * BLOCK_Q >= cur_batch_query_len: + return + + offs_m = tl.arange(0, BLOCK_Q * num_queries_per_kv) + offs_d = tl.arange(0, HEAD_SIZE_PADDED) + + query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv + + query_offset_0 = cur_batch_in_all_start_index + query_pos + query_offset_1 = kv_head_idx * num_queries_per_kv + \ + offs_m % num_queries_per_kv + + query_offset = (query_offset_0[:, None] * query_stride_0 + + query_offset_1[:, None] * query_stride_1 + offs_d[None, :]) + + dim_mask = tl.where(offs_d < HEAD_SIZE, 1, 0).to(tl.int1) + query_mask_0 = tl.where(query_pos < cur_batch_query_len, 1, 0).to(tl.int1) + query_mask_1 = tl.where(query_offset_1 < num_query_heads, 1, 0).to(tl.int1) + + # Q : (BLOCK_Q * num_queries_per_kv, HEAD_SIZE,) + Q = tl.load( + query_ptr + query_offset, + mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None], + other=0.0, + ) + + block_table_offset = seq_idx * block_table_stride + + M = tl.full([BLOCK_Q * num_queries_per_kv], + float("-inf"), + dtype=tl.float32) + L = tl.full([BLOCK_Q * num_queries_per_kv], 1.0, dtype=tl.float32) + acc = tl.zeros([BLOCK_Q * num_queries_per_kv, HEAD_SIZE_PADDED], + dtype=tl.float32) + + # sequence len for this particular sequence + seq_len = tl.load(seq_lens_ptr + seq_idx) + + # context length for this particular sequences + context_len = seq_len - cur_batch_query_len + + # alibi slope for this head + if USE_ALIBI_SLOPES: + alibi_slope = tl.load(alibi_slopes_ptr + query_offset_1, + mask=query_mask_1, + other=0.0) + + num_blocks = cdiv_fn(seq_len, BLOCK_SIZE) + + # iterate through tiles + for j in range(0, num_blocks): + + physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j) + + offs_n = tl.arange(0, BLOCK_SIZE) + + v_offset = (physical_block_idx * stride_v_cache_0 + + kv_head_idx * stride_v_cache_2 + + offs_d[None, :] * stride_v_cache_3 + + offs_n[:, None] * stride_v_cache_1) + + k_offset = (physical_block_idx * stride_k_cache_0 + + kv_head_idx * stride_k_cache_2 + + offs_d[:, None] * stride_k_cache_3 + + offs_n[None, :] * stride_k_cache_1) + + # K : (HEAD_SIZE, BLOCK_SIZE) + K_load = tl.load(key_cache_ptr + k_offset, + mask=dim_mask[:, None], + other=0.0) + + if K_load.dtype.is_fp8(): + if Q.dtype.is_fp8(): + K = K_load + else: + K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype) + else: + K = K_load + + # V : (BLOCK_SIZE, HEAD_SIZE) + V_load = tl.load(value_cache_ptr + v_offset, + mask=dim_mask[None, :], + other=0.0) + + if V_load.dtype.is_fp8(): + if Q.dtype.is_fp8(): + V = V_load + else: + V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype) + else: + V = V_load + + seq_offset = j * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + + seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1 + + # S : (BLOCK_Q * num_queries_per_kv, BLOCK_SIZE,) + S = tl.zeros(shape=(BLOCK_Q * num_queries_per_kv, BLOCK_SIZE), + dtype=tl.float32) + + S += scale * tl.dot(Q, K) + + if USE_SOFTCAP: + S = apply_softcap(S, softcap) + + S = tl.where(query_mask_1[:, None] & query_mask_0[:, None] & seq_mask, + S, float("-inf")) + + if SLIDING_WINDOW > 0: + S = tl.where((context_len + query_pos[:, None] - seq_offset) + < SLIDING_WINDOW, S, float("-inf")) + + if USE_ALIBI_SLOPES: + S += alibi_slope[:, None] * (seq_offset - context_len) + + # compute running maximum + # m_j : (BLOCK_Q * num_queries_per_kv,) + m_j = tl.maximum(M, tl.max(S, axis=1)) + # For sliding window there's a chance the max is -inf due to masking of + # the entire row. In this case we need to set m_j 0 to avoid NaN + m_j = tl.where(m_j > float("-inf"), m_j, 0.0) + + # P : (BLOCK_Q * num_queries_per_kv, BLOCK_SIZE,) + P = tl.exp(S - m_j[:, None]) + + # l_j : (BLOCK_Q * num_queries_per_kv,) + l_j = tl.sum(P, axis=1) + + # alpha : (BLOCK_Q * num_queries_per_kv, ) + alpha = tl.exp(M - m_j) + + # acc : (BLOCK_Q * num_queries_per_kv, BLOCK_SIZE,) + acc = acc * alpha[:, None] + + # update constants + L = L * alpha + l_j + M = m_j + + # acc : (BLOCK_Q * num_queries_per_kv, BLOCK_SIZE,) + acc += tl.dot(P.to(V.dtype), V) + + # epilogue + acc = acc / L[:, None] + + output_offset = (query_offset_0[:, None] * output_stride_0 + + query_offset_1[:, None] * output_stride_1 + + offs_d[None, :]) + + tl.store( + output_ptr + output_offset, + acc, + mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None], + ) + + +def unified_attention( + q, + k, + v, + out, + cu_seqlens_q, + max_seqlen_q, + seqused_k, + max_seqlen_k, + softmax_scale, + causal, + window_size, + block_table, + softcap, + q_descale, + k_descale, + v_descale, + alibi_slopes=None, +): + assert causal, "Only causal attention is supported" + assert q_descale is None, "Q scales not supported" + + use_alibi_slopes = alibi_slopes is not None + + block_size = v.shape[1] + num_seqs = len(seqused_k) + num_query_heads = q.shape[1] + num_kv_heads = k.shape[2] + num_queries_per_kv = num_query_heads // num_kv_heads + head_size = q.shape[2] + + BLOCK_M = 16 + BLOCK_Q = BLOCK_M // num_queries_per_kv + + # Ideally we would launch with kernel with: + # \sum_i[ceil(query_len[i] / BLOCK_Q)] blocks. + # However, it is slow to realize the query_lens on cpu. + # Instead we use upper-bound: + # \sum_i[ceil(query_len[i] / BLOCK_Q)] + # <= \sum_i[floor(query_len[i] / BLOCK_Q) + 1] + # = \sum_i[floor(query_len[i] / BLOCK_Q)] + num_seqs + # <= floor(\sum_i(query_len[i]) / BLOCK_Q) + num_seqs + # = floor(q.shape[0] / BLOCK_Q) + num_seqs + total_num_q_blocks = q.shape[0] // BLOCK_Q + num_seqs + + kernel_unified_attention_2d[( + total_num_q_blocks, + num_kv_heads, + )]( + output_ptr=out, + query_ptr=q, + key_cache_ptr=k, + value_cache_ptr=v, + block_tables_ptr=block_table, + seq_lens_ptr=seqused_k, + alibi_slopes_ptr=alibi_slopes, + scale=softmax_scale, + k_scale=k_descale, + v_scale=v_descale, + softcap=softcap, + num_query_heads=num_query_heads, + num_queries_per_kv=num_queries_per_kv, + block_table_stride=block_table.stride(0), + query_stride_0=q.stride(0), + query_stride_1=q.stride(1), + output_stride_0=out.stride(0), + output_stride_1=out.stride(1), + BLOCK_SIZE=block_size, + HEAD_SIZE=head_size, + HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), + USE_ALIBI_SLOPES=use_alibi_slopes, + USE_SOFTCAP=(softcap > 0), + SLIDING_WINDOW=(1 + window_size[0]), + stride_k_cache_0=k.stride(0), + stride_k_cache_1=k.stride(1), + stride_k_cache_2=k.stride(2), + stride_k_cache_3=k.stride(3), + stride_v_cache_0=v.stride(0), + stride_v_cache_1=v.stride(1), + stride_v_cache_2=v.stride(2), + stride_v_cache_3=v.stride(3), + query_start_len_ptr=cu_seqlens_q, + BLOCK_Q=BLOCK_Q, + num_seqs=num_seqs, + ) diff --git a/vllm/benchmarks/datasets.py b/vllm/benchmarks/datasets.py index 299c888c2e7..fab44fb6062 100644 --- a/vllm/benchmarks/datasets.py +++ b/vllm/benchmarks/datasets.py @@ -829,3 +829,91 @@ def sample(self, )) self.maybe_oversample_requests(sampled_requests, num_requests) return sampled_requests + + +# ----------------------------------------------------------------------------- +# Next Edit Prediction Dataset Implementation +# ----------------------------------------------------------------------------- + + +zeta_prompt = """### Instruction: +You are a code completion assistant and your task is to analyze user edits and then rewrite an excerpt that the user provides, suggesting the appropriate edits within the excerpt, taking into account the cursor location. + +### User Edits: + +{} + +### User Excerpt: + +{} + +### Response: + +""" # noqa: E501 + + +def _format_zeta_prompt( + sample: dict, + original_start_marker: str = "<|editable_region_start|>") -> dict: + """Format the zeta prompt for the Next Edit Prediction (NEP) dataset. + + This function formats examples from the NEP dataset + into prompts and expected outputs. It could be + further extended to support more NEP datasets. + + Args: + sample: The dataset sample containing events, + inputs, and outputs. + original_start_marker: The marker indicating the + start of the editable region. Defaults to + "<|editable_region_start|>". + + Returns: + A dictionary with the formatted prompts and expected outputs. + """ + events = sample["events"] + input = sample["input"] + output = sample["output"] + prompt = zeta_prompt.format(events, input) + + # following the original implementation, extract the focused region + # from the raw output + output_start_index = output.find(original_start_marker) + output_focused_region = output[output_start_index:] + expected_output = output_focused_region + + return {"prompt": prompt, "expected_output": expected_output} + + +class NextEditPredictionDataset(HuggingFaceDataset): + """ + Dataset class for processing a Next Edit Prediction dataset. + """ + + SUPPORTED_DATASET_PATHS = { + "zed-industries/zeta", + } + MAPPING_PROMPT_FUNCS = { + "zed-industries/zeta": _format_zeta_prompt, + } + + def sample(self, tokenizer: PreTrainedTokenizerBase, num_requests: int, + **kwargs): + formatting_prompt_func = self.MAPPING_PROMPT_FUNCS.get( + self.dataset_path) + if formatting_prompt_func is None: + raise ValueError(f"Unsupported dataset path: {self.dataset_path}") + samples = [] + for sample in self.data: + sample = formatting_prompt_func(sample) + samples.append( + SampleRequest( + prompt=sample["prompt"], + prompt_len=len(tokenizer(sample["prompt"]).input_ids), + expected_output_len=len( + tokenizer(sample["expected_output"]).input_ids), + )) + if len(samples) >= num_requests: + break + self.maybe_oversample_requests(samples, num_requests) + return samples diff --git a/vllm/config.py b/vllm/config.py index ef3270bf15b..968a94a859f 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -54,7 +54,7 @@ from vllm.executor.executor_base import ExecutorBase from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) - from vllm.model_executor.model_loader.loader import BaseModelLoader + from vllm.model_executor.model_loader import BaseModelLoader ConfigType = type[DataclassInstance] else: @@ -2300,6 +2300,9 @@ class SpeculativeConfig: """Scaling factor for entropy-based threshold, applied when using `TypicalAcceptanceSampler`.""" + speculative_token_tree: Optional[str] = None + """Specifies the tree structure for speculative token generation. + """ # required configuration params passed from engine target_model_config: ModelConfig = field(default=None, init=True) # type: ignore @@ -2474,10 +2477,11 @@ def __post_init__(self): "Chunked prefill and EAGLE are not compatible " "when using V0.") + from vllm.platforms import current_platform from vllm.transformers_utils.configs.eagle import ( EAGLEConfig) if isinstance(self.draft_model_config.hf_config, - EAGLEConfig): + EAGLEConfig) or current_platform.is_neuron(): pass else: eagle_config = EAGLEConfig( diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 99321966bcf..68f0e7a5387 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -7,6 +7,7 @@ import os import re import threading +import warnings from dataclasses import MISSING, dataclass, fields from itertools import permutations from typing import (Any, Callable, Dict, List, Literal, Optional, Type, @@ -415,7 +416,13 @@ def __post_init__(self): if isinstance(self.compilation_config, (int, dict)): self.compilation_config = CompilationConfig.from_cli( str(self.compilation_config)) - + if self.qlora_adapter_name_or_path is not None: + warnings.warn( + "The `qlora_adapter_name_or_path` is deprecated " + "and will be removed in v0.10.0. ", + DeprecationWarning, + stacklevel=2, + ) # Setup plugins from vllm.plugins import load_general_plugins load_general_plugins() @@ -531,10 +538,14 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: **load_kwargs["ignore_patterns"]) load_group.add_argument("--use-tqdm-on-load", **load_kwargs["use_tqdm_on_load"]) - load_group.add_argument('--qlora-adapter-name-or-path', - type=str, - default=None, - help='Name or path of the QLoRA adapter.') + load_group.add_argument( + "--qlora-adapter-name-or-path", + type=str, + default=None, + help="The `--qlora-adapter-name-or-path` has no effect, do not set" + " it, and it will be removed in v0.10.0.", + deprecated=True, + ) load_group.add_argument('--pt-load-map-location', **load_kwargs["pt_load_map_location"]) @@ -558,9 +569,10 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: guided_decoding_group.add_argument( "--enable-reasoning", action=argparse.BooleanOptionalAction, + deprecated=True, help="[DEPRECATED] The `--enable-reasoning` flag is deprecated as " "of v0.8.6. Use `--reasoning-parser` to specify the reasoning " - "parser backend insteadThis flag (`--enable-reasoning`) will be " + "parser backend instead. This flag (`--enable-reasoning`) will be " "removed in v0.10.0. When `--reasoning-parser` is specified, " "reasoning mode is automatically enabled.") guided_decoding_group.add_argument( @@ -943,12 +955,6 @@ def create_model_config(self) -> ModelConfig: def create_load_config(self) -> LoadConfig: - if(self.qlora_adapter_name_or_path is not None) and \ - self.quantization != "bitsandbytes": - raise ValueError( - "QLoRA adapter only support " - f"'bitsandbytes' quantization, but got {self.quantization}") - if self.quantization == "bitsandbytes": self.load_format = "bitsandbytes" @@ -1149,11 +1155,6 @@ def create_engine_config( max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras and self.max_cpu_loras > 0 else None) if self.enable_lora else None - if self.qlora_adapter_name_or_path is not None and \ - self.qlora_adapter_name_or_path != "": - self.model_loader_extra_config[ - "qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path - # bitsandbytes pre-quantized model need a specific model loader if model_config.quantization == "bitsandbytes": self.quantization = self.load_format = "bitsandbytes" @@ -1391,11 +1392,10 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: and _warn_or_fallback("Engine in background thread")): return False - # PP is supported on V1 with Ray distributed executor, - # but off for MP distributed executor for now. if (self.pipeline_parallel_size > 1 - and self.distributed_executor_backend != "ray"): - name = "Pipeline Parallelism without Ray distributed executor" + and self.distributed_executor_backend not in ["ray", "mp"]): + name = "Pipeline Parallelism without Ray distributed executor " \ + "or multiprocessing executor" _raise_or_fallback(feature_name=name, recommend_to_remove=False) return False @@ -1407,9 +1407,10 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: if is_eagle_enabled and _warn_or_fallback("Eagle"): return False - # Non-CUDA is supported on V1, but off by default for now. - not_cuda = not current_platform.is_cuda() - if not_cuda and _warn_or_fallback( # noqa: SIM103 + # Non-[CUDA, TPU] may be supported on V1, but off by default for now. + v0_hardware = not any( + (current_platform.is_cuda(), current_platform.is_tpu())) + if v0_hardware and _warn_or_fallback( # noqa: SIM103 current_platform.device_name): return False ############################################################# diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index d413cc7051d..d966e75f1fc 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -393,10 +393,8 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: self.scheduler, self.seq_counter, get_tokenizer_for_seq, - stop_checker=StopChecker( - self.scheduler_config.max_model_len, - get_tokenizer_for_seq, - ), + stop_checker=StopChecker(self.scheduler_config.max_model_len, + get_tokenizer_for_seq), )) self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {} @@ -2015,7 +2013,7 @@ def _validate_model_input( if not prompt_ids: if prompt_type == "encoder" and model_config.is_multimodal_model: pass # Mllama may have empty encoder inputs for text-only data - if prompt_inputs["type"] == "embeds": + elif prompt_inputs["type"] == "embeds": pass else: raise ValueError(f"The {prompt_type} prompt cannot be empty") diff --git a/vllm/forward_context.py b/vllm/forward_context.py index c75d8f088c5..9ddc3d1f2c5 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -4,7 +4,7 @@ from collections import defaultdict from contextlib import contextmanager from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, Optional, Union import torch import torch.distributed as dist @@ -38,8 +38,13 @@ class DPMetadata: class ForwardContext: # copy from vllm_config.compilation_config.static_forward_context no_compile_layers: dict[str, Any] - # TODO: extend to support per-layer dynamic forward context - attn_metadata: "AttentionMetadata" # set dynamically for each forward pass + """ + Type AttentionMetadata for v0, + Type Dict[str, AttentionMetadata] for v1, map from layer_name of each + attention layer to its attention metadata + set dynamically for each forward pass + """ + attn_metadata: Union["AttentionMetadata", dict[str, "AttentionMetadata"]] # TODO: remove after making all virtual_engines share the same kv cache virtual_engine: int # set dynamically for each forward pass # set dynamically for each forward pass diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 8e34a35ebef..745e8098f36 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -751,13 +751,15 @@ def get_default_config( if dtype == "fp8_w8a8" and block_shape is not None: # Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0] # BLOCK_SIZE_K must be divisible by block_shape[1] + # num_stages=3 can cause triton.runtime.errors.OutOfResources + # on ROCm, set it to 2 instead. config = { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": block_shape[0], "BLOCK_SIZE_K": block_shape[1], "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 3, + "num_stages": 3 if not current_platform.is_rocm() else 2, } elif dtype in ["int4_w4a16", "int8_w8a16"] and block_shape is not None: # moe wna16 kernels diff --git a/vllm/model_executor/layers/fused_moe/moe_pallas.py b/vllm/model_executor/layers/fused_moe/moe_pallas.py index 0365afa10a4..8f28b64ed48 100644 --- a/vllm/model_executor/layers/fused_moe/moe_pallas.py +++ b/vllm/model_executor/layers/fused_moe/moe_pallas.py @@ -11,7 +11,9 @@ def fused_moe( w2: torch.Tensor, gating_output: torch.Tensor, topk: int, - renormalize: bool, + global_num_experts: int, + expert_map: torch.Tensor = None, + renormalize: bool = False, ) -> torch.Tensor: """ Args: @@ -20,6 +22,7 @@ def fused_moe( w2: [num_experts, hidden_size, intermediate_size] gating_output: [*, num_experts] """ + assert expert_map is None, "expert_map is not supported for pallas MoE." orig_shape = hidden_states.shape hidden_size = hidden_states.shape[-1] num_tokens = hidden_states.shape[:-1].numel() diff --git a/vllm/model_executor/layers/mamba/mamba2_metadata.py b/vllm/model_executor/layers/mamba/mamba2_metadata.py index b1c46190403..e5b88de2fcc 100644 --- a/vllm/model_executor/layers/mamba/mamba2_metadata.py +++ b/vllm/model_executor/layers/mamba/mamba2_metadata.py @@ -13,7 +13,6 @@ @dataclass class Mamba2Metadata: - has_prefill: bool has_initial_states: torch.Tensor prep_initial_states: bool @@ -24,21 +23,23 @@ class Mamba2Metadata: chunk_offsets: torch.Tensor -def _seq_idx_to_chunk_indices_offsets(seq_idx, chunk_size: int): +def _query_start_loc_to_chunk_indices_offsets(query_start_loc: torch.Tensor, + chunk_size: int, + total_seqlens: int): - # convert seq_idx to chunk indices and offsets - # - derive the cu_seqlens - _, cu_seqlens = torch.where(seq_idx.diff()) - cu_seqlens += 1 + cu_seqlens = query_start_loc[1:] # remove prepended 0 # outputs will have length expansion of chunks that do not divide # chunk_size - N = math.ceil(seq_idx.shape[-1] / chunk_size) + (cu_seqlens % chunk_size - > 0).sum() - chunk_indices = torch.arange(N, dtype=torch.int, device=seq_idx.device) - chunk_offsets = torch.zeros((N, ), dtype=torch.int, device=seq_idx.device) + N = math.ceil(total_seqlens / chunk_size) + (cu_seqlens[:-1] % chunk_size + > 0).sum() + chunk_indices = torch.arange(N, + dtype=torch.int, + device=query_start_loc.device) + chunk_offsets = torch.zeros((N, ), + dtype=torch.int, + device=query_start_loc.device) - cu_seqlens = cu_seqlens.tolist() + [seq_idx.shape[-1]] p = 0 # num of insertions for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]): @@ -60,48 +61,49 @@ def _seq_idx_to_chunk_indices_offsets(seq_idx, chunk_size: int): def prepare_mamba2_metadata( chunk_size: int, - input_ids: torch.Tensor, attn_metadata: AttentionMetadata, ) -> Mamba2Metadata: + # compute number of prefill and decode requests + # NOTE: in V0 we assume prefills are before decodes + num_prefills = attn_metadata.num_prefills + num_prefill_tokens = attn_metadata.num_prefill_tokens + + seq_idx = None + chunk_indices, chunk_offsets = None, None # Need flags to indicate if there are initial states # currently we really only support the FlashAttention backend has_initial_states = None prep_initial_states = False - if (isinstance(attn_metadata, (FlashAttentionMetadata, XFormersMetadata, - PlaceholderAttentionMetadata)) - and attn_metadata.context_lens_tensor is not None): - has_initial_states = attn_metadata.context_lens_tensor > 0 - # precompute flag to avoid device syncs later in mamba2 forwards - prep_initial_states = torch.any(has_initial_states).item() - - has_prefill = attn_metadata.num_prefills > 0 - seq_idx = None - chunk_indices, chunk_offsets = None, None - if has_prefill: - seq_idx = torch.zeros_like(input_ids, dtype=torch.int32) - for i, (srt, end) in enumerate( - zip( - attn_metadata.query_start_loc, - attn_metadata.query_start_loc[1:], - )): - seq_idx[srt:end] = i + # Compute seq_idx, chunk_indices and chunk_offsets for prefill only + if num_prefills > 0: + if (isinstance(attn_metadata, + (FlashAttentionMetadata, XFormersMetadata, + PlaceholderAttentionMetadata)) + and attn_metadata.context_lens_tensor is not None): + has_initial_states = \ + attn_metadata.context_lens_tensor[:num_prefills] > 0 #[batch,] + # precompute flag to avoid device syncs in mamba2 layer forwards + # prep is only needed for mamba2 ssd prefill processing + prep_initial_states = torch.any(has_initial_states).item() + + query_start_loc = attn_metadata.query_start_loc[:num_prefills + 1] + seq_idx = torch.repeat_interleave(torch.arange( + num_prefills, dtype=torch.int32, device=query_start_loc.device), + query_start_loc.diff(), + output_size=num_prefill_tokens) seq_idx.unsqueeze_(0) - # compute metadata for chunked prefill. - # actually this is only needed if there are initial states, - # but this is determinable only from attention metadata yet - # unavailable from the top-level model forward. Rather than - # complicating things to extract said metadata, we simply just - # compute them once at the top level model forward and reuse - # them in mamba layers. If not needed, they will be ignored - # inside mamba kernels. - chunk_indices, chunk_offsets = _seq_idx_to_chunk_indices_offsets( - seq_idx, chunk_size) - - return Mamba2Metadata(has_prefill=has_prefill, - has_initial_states=has_initial_states, + # We compute metadata for chunked prefill once at the top level model + # forward and reuse them in mamba layers. If not needed, they will be + # ignored inside mamba kernels. + if prep_initial_states: + chunk_indices, chunk_offsets = \ + _query_start_loc_to_chunk_indices_offsets( + query_start_loc, chunk_size, num_prefill_tokens) + + return Mamba2Metadata(has_initial_states=has_initial_states, prep_initial_states=prep_initial_states, chunk_size=chunk_size, seq_idx=seq_idx, diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index d459c93a26b..05b9d87ac0a 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -388,10 +388,15 @@ def forward_cuda( # mamba2_metadata contains metadata necessary for the mamba2 triton # kernels to operate in continuous batching and in chunked prefill # modes; they are computed at top-level model forward since they - # are the same and reused for all mamba layers in the same iteration + # stay the same and reused for all mamba layers in the same iteration attn_metadata: AttentionMetadata = get_forward_context().attn_metadata - seq_len, _ = hidden_states.shape + num_prefills = attn_metadata.num_prefills # request count + num_decodes = attn_metadata.num_decode_tokens # token count (=request) + num_prefill_tokens = attn_metadata.num_prefill_tokens # token count + has_prefill = num_prefills > 0 + has_decode = num_decodes > 0 + groups_time_state_size = self.n_groups * self.ssm_state_size # 1. Gated MLP's linear projection @@ -406,44 +411,32 @@ def forward_cuda( dim=-1, ) - # 2. Convolution sequence transformation conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)) - if mamba2_metadata.has_prefill: - # |---------- N-1 iteration --------| - # |---------------- N iteration ---------------------| - # |- tokenA -|......................|-- newTokens ---| - # |---------- context_len ----------| - # |-------------------- seq_len ---------------------| - # |-- query_len ---| - - # - "cache_indices" updates the conv_state cache in positions - # pointed to by "mamba_cache_params.state_indices_tensor" - hidden_states_B_C = causal_conv1d_fn( - hidden_states_B_C.transpose(0, 1), - conv_weights, - self.conv1d.bias, - activation=self.activation, - conv_states=mamba_cache_params.conv_state, - has_initial_state=mamba2_metadata.has_initial_states, - cache_indices=mamba_cache_params.state_indices_tensor, - query_start_loc=attn_metadata.query_start_loc).transpose( - 0, 1)[:seq_len] - - # TODO: Why is this needed? - hidden_states_B_C = hidden_states_B_C.contiguous() - else: - hidden_states_B_C = causal_conv1d_update( - hidden_states_B_C, - mamba_cache_params.conv_state, - conv_weights, - self.conv1d.bias, - self.activation, - conv_state_indices=mamba_cache_params.state_indices_tensor) + # Separate prefill and decode by splitting varlen input + # Split along token dimension + hidden_states_B_C_p, hidden_states_B_C_d = torch.split( + hidden_states_B_C, + [num_prefill_tokens, num_decodes], + dim=0, + ) + dt_p, dt_d = torch.split( + dt, + [num_prefill_tokens, num_decodes], + dim=0, + ) + # Split along batch dimension + state_indices_tensor_p, state_indices_tensor_d = torch.split( + mamba_cache_params.state_indices_tensor, + [num_prefills, num_decodes], + dim=0, + ) + query_start_loc_p = (attn_metadata.query_start_loc[:num_prefills + 1] + if has_prefill else None) # - get hidden_states, B and C after depthwise convolution. - hidden_states, B, C = torch.split( + split_hidden_states_B_C_fn = lambda hidden_states_B_C: torch.split( hidden_states_B_C, [ self.intermediate_size // self.tp_size, @@ -453,24 +446,48 @@ def forward_cuda( dim=-1, ) - # 3. State Space Model sequence transformation - if mamba2_metadata.has_prefill: + ssd_output_list = [] + + # Process prefill requests + if has_prefill: + # 2. Convolution sequence transformation + # - "cache_indices" updates the conv_state cache in positions + # pointed to by "mamba_cache_params.state_indices_tensor" + hidden_states_B_C_p = causal_conv1d_fn( + hidden_states_B_C_p.transpose(0, 1), + conv_weights, + self.conv1d.bias, + activation=self.activation, + conv_states=mamba_cache_params.conv_state, + has_initial_state=mamba2_metadata.has_initial_states, + cache_indices=state_indices_tensor_p, + query_start_loc=query_start_loc_p).transpose( + 0, 1)[:num_prefill_tokens] + + # TODO: Why is this needed? + hidden_states_B_C_p = hidden_states_B_C_p.contiguous() + hidden_states_p, B_p, C_p = split_hidden_states_B_C_fn( + hidden_states_B_C_p) + + # 3. State Space Model sequence transformation initial_states = None if (mamba2_metadata.has_initial_states is not None and mamba2_metadata.prep_initial_states): # making a copy of the states initial_states = torch.where( mamba2_metadata.has_initial_states[:, None, None, None], - mamba_cache_params.ssm_state[ - mamba_cache_params.state_indices_tensor], 0) + mamba_cache_params.ssm_state[state_indices_tensor_p], 0) scan_output, varlen_state = mamba_chunk_scan_combined( - hidden_states.view(1, seq_len, self.num_heads // self.tp_size, - self.head_dim), - dt.unsqueeze(0), + hidden_states_p.view(1, num_prefill_tokens, + self.num_heads // self.tp_size, + self.head_dim), + dt_p.unsqueeze(0), self.A, - B.view(1, seq_len, self.n_groups // self.tp_size, -1), - C.view(1, seq_len, self.n_groups // self.tp_size, -1), + B_p.view(1, num_prefill_tokens, self.n_groups // self.tp_size, + -1), + C_p.view(1, num_prefill_tokens, self.n_groups // self.tp_size, + -1), chunk_size=mamba2_metadata.chunk_size, D=self.D, z=None, @@ -478,7 +495,7 @@ def forward_cuda( seq_idx=mamba2_metadata.seq_idx, chunk_indices=mamba2_metadata.chunk_indices, chunk_offsets=mamba2_metadata.chunk_offsets, - cu_seqlens=attn_metadata.query_start_loc, + cu_seqlens=attn_metadata.query_start_loc[:num_prefills + 1], initial_states=initial_states, return_varlen_states=True, return_final_states=False, @@ -487,52 +504,65 @@ def forward_cuda( ) # update ssm states - # - varlen state is a (batch, nheads, headdim, dstate) tensor - mamba_cache_params.ssm_state[ - mamba_cache_params.state_indices_tensor] = varlen_state + # - varlen state is a (num_prefills, nheads, headdim, dstate) tensor + mamba_cache_params.ssm_state[state_indices_tensor_p] = varlen_state # - reshape - hidden_states = scan_output.view(seq_len, -1) - else: + ssd_output_list.append(scan_output.view(num_prefill_tokens, -1)) + # Process decode requests + if has_decode: + # 2. Convolution sequence transformation + hidden_states_B_C_d = causal_conv1d_update( + hidden_states_B_C_d, + mamba_cache_params.conv_state, + conv_weights, + self.conv1d.bias, + self.activation, + conv_state_indices=state_indices_tensor_d) + + hidden_states_d, B_d, C_d = split_hidden_states_B_C_fn( + hidden_states_B_C_d) + + # 3. State Space Model sequence transformation n_groups = self.n_groups // self.tp_size - A = self.A[:, None, ...][:, :, None].expand( + A_d = self.A[:, None, ...][:, :, None].expand( -1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32) - dt = dt[:, :, None].expand(-1, -1, self.head_dim) + dt_d = dt_d[:, :, None].expand(-1, -1, self.head_dim) dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim) - D = self.D[:, None, ...].expand(-1, self.head_dim) - B = B.view(-1, n_groups, B.shape[1] // n_groups) - C = C.view(-1, n_groups, C.shape[1] // n_groups) - hidden_states_reshaped = hidden_states.view( + D_d = self.D[:, None, ...].expand(-1, self.head_dim) + B_d = B_d.view(-1, n_groups, B_d.shape[1] // n_groups) + C_d = C_d.view(-1, n_groups, C_d.shape[1] // n_groups) + hidden_states_d = hidden_states_d.view( -1, self.num_heads // self.tp_size, self.head_dim) - # - the hidden is reshaped into number of current batches - # - in this case there is no more prefill, so the batches gen - # 1 token at a time - # - thus hidden will be (bs, num_heads, head_dim) + # - the hidden is reshaped into (bs, num_heads, head_dim) # - mamba_cache_params.ssm_state's slots will be selected - # using "mamba_cache_params.state_indices_tensor", just as - # above in the prefill case + # using state_indices_tensor_d - hidden_states = selective_state_update( + hidden_states_d = selective_state_update( mamba_cache_params.ssm_state, - hidden_states_reshaped, - dt, - A, - B, - C, - D, + hidden_states_d, + dt_d, + A_d, + B_d, + C_d, + D_d, z=None, dt_bias=dt_bias, dt_softplus=True, - state_batch_indices=mamba_cache_params.state_indices_tensor, + state_batch_indices=state_indices_tensor_d, ) - hidden_states = hidden_states.view( - -1, (self.num_heads // self.tp_size) * self.head_dim) + ssd_output_list.append( + hidden_states_d.view(-1, (self.num_heads // self.tp_size) * + self.head_dim)) + + # Merge prefill and decode outputs before passing to gated MLP + hidden_states = torch.vstack(ssd_output_list) - # # 4. gated MLP + # 4. gated MLP hidden_states = self.norm(hidden_states, gate) - # # 5. Final linear projection + # 5. Final linear projection out, _ = self.out_proj(hidden_states) return out diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py index e9efe642825..79a1663b85b 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_combined.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_combined.py @@ -40,7 +40,6 @@ def _mamba_chunk_scan_combined_fwd(x, _, _, ngroups, dstate = B.shape assert nheads % ngroups == 0 assert B.shape == (batch, seqlen, ngroups, dstate) - assert x.shape == (batch, seqlen, nheads, headdim) assert dt.shape == (batch, seqlen, nheads) assert A.shape == (nheads, ) assert C.shape == B.shape diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py index 05058dfaa73..c8815245494 100644 --- a/vllm/model_executor/layers/quantization/gguf.py +++ b/vllm/model_executor/layers/quantization/gguf.py @@ -145,7 +145,9 @@ def _fused_moe_gguf( moe_align_block_size) out_hidden_states = torch.empty_like(x) - if qweight_type2 in MMQ_QUANT_TYPES and qweight_type in MMQ_QUANT_TYPES: + # unless we decent expert reuse we are better off running moe_vec kernel + if (qweight_type2 in MMQ_QUANT_TYPES and qweight_type in MMQ_QUANT_TYPES + and x.shape[0] > 64): num_tokens, _ = x.shape E, N, _ = w1.shape top_k = topk_ids.shape[1] @@ -163,6 +165,20 @@ def _fused_moe_gguf( out = out.reshape(num_tokens, top_k, w2.shape[1]).mul_( topk_weights.view(num_tokens, top_k, 1)) ops.moe_sum(out, out_hidden_states) + elif qweight_type2 in MMVQ_QUANT_TYPES and qweight_type in MMVQ_QUANT_TYPES: + num_tokens, _ = x.shape + E, N, _ = w1.shape + top_k = topk_ids.shape[1] + + out = ops.ggml_moe_a8_vec(x, w1, topk_ids, top_k, qweight_type, N, + num_tokens) + out = act(out) + + out = ops.ggml_moe_a8_vec(out, w2, topk_ids, 1, qweight_type2, + w2.shape[1], num_tokens * top_k) + out = out.reshape(num_tokens, top_k, w2.shape[1]).mul_( + topk_weights.view(num_tokens, top_k, 1)) + ops.moe_sum(out, out_hidden_states) else: logger.warning_once("There is no support for fast MoE kernel " "for current quantization method. " diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index 99c35c4273c..fdf4c039fd8 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -165,9 +165,9 @@ def forward_native( self, positions: torch.Tensor, query: torch.Tensor, - key: torch.Tensor, + key: Optional[torch.Tensor] = None, offsets: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """A PyTorch-native implementation of forward().""" if offsets is not None: positions = positions + offsets @@ -184,22 +184,24 @@ def forward_native( self.is_neox_style) query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) - key_shape = key.shape - key = key.view(num_tokens, -1, self.head_size) - key_rot = key[..., :self.rotary_dim] - key_pass = key[..., self.rotary_dim:] - key_rot = _apply_rotary_emb_torch(key_rot, cos, sin, - self.is_neox_style) - key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) + # key may be None in some cases, e.g. cross-layer KV sharing + if key is not None: + key_shape = key.shape + key = key.view(num_tokens, -1, self.head_size) + key_rot = key[..., :self.rotary_dim] + key_pass = key[..., self.rotary_dim:] + key_rot = _apply_rotary_emb_torch(key_rot, cos, sin, + self.is_neox_style) + key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) return query, key def forward_cuda( self, positions: torch.Tensor, query: torch.Tensor, - key: torch.Tensor, + key: Optional[torch.Tensor] = None, offsets: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: from vllm import _custom_ops as ops # __setattr__ in nn.Module (called by `self.cos_sin_cache = ...`) @@ -225,32 +227,39 @@ def forward_xpu( self, positions: torch.Tensor, query: torch.Tensor, - key: torch.Tensor, + key: Optional[torch.Tensor] = None, offsets: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: from vllm._ipex_ops import ipex_ops as ops self.cos_sin_cache = self.cos_sin_cache.to(positions.device, dtype=query.dtype) # ops.rotary_embedding()/batched_rotary_embedding() # are in-place operations that update the query and key tensors. - if offsets is not None: - ops.batched_rotary_embedding(positions, query, key, self.head_size, - self.cos_sin_cache, - self.is_neox_style, self.rotary_dim, - offsets) + if key is None: + # XPU kernel doesn't support key=None so fall back to native impl + # TODO(sarckk): add support for optional key in + # ipex.llm.functional.rotary_embedding_batched + return self.forward_native(positions, query, key, offsets) else: - ops.rotary_embedding(positions, query, key, self.head_size, - self.cos_sin_cache, self.is_neox_style) + if offsets is not None: + ops.batched_rotary_embedding(positions, query, key, + self.head_size, + self.cos_sin_cache, + self.is_neox_style, + self.rotary_dim, offsets) + else: + ops.rotary_embedding(positions, query, key, self.head_size, + self.cos_sin_cache, self.is_neox_style) return query, key def forward_hpu( self, positions: torch.Tensor, query: torch.Tensor, - key: torch.Tensor, + key: Optional[torch.Tensor] = None, offsets: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: from habana_frameworks.torch.hpex.kernels import ( RotaryPosEmbeddingMode, apply_rotary_pos_emb) @@ -277,12 +286,12 @@ def forward_hpu( sin = self.sin cos = self.cos query_shape = query.shape - key_shape = key.shape query = query.view(num_tokens, -1, self.head_size) - key = key.view(num_tokens, -1, self.head_size) - if self.head_size == self.rotary_dim: + if self.head_size == self.rotary_dim and key is not None: # Avoid unnecessary slicing and concatenation + key_shape = key.shape + key = key.view(num_tokens, -1, self.head_size) query = apply_rotary_pos_emb(query, cos, sin, None, 0, rope_mode) key = apply_rotary_pos_emb(key, cos, sin, None, 0, rope_mode) return query.reshape(query_shape), key.reshape(key_shape) @@ -292,20 +301,23 @@ def forward_hpu( query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode) query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) - - key_rot = key[..., :self.rotary_dim] - key_pass = key[..., self.rotary_dim:] - key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode) - key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) + if key is not None: + key_shape = key.shape + key = key.view(num_tokens, -1, self.head_size) + key_rot = key[..., :self.rotary_dim] + key_pass = key[..., self.rotary_dim:] + key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, + rope_mode) + key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) return query, key def forward_neuron( self, positions: torch.Tensor, query: torch.Tensor, - key: torch.Tensor, + key: Optional[torch.Tensor] = None, offsets: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: def _apply_rotary_emb_neuron( x: torch.Tensor, @@ -345,14 +357,16 @@ def _apply_rotary_emb_neuron( query_shape = query.shape query = query.view(num_tokens, -1, self.head_size) - key_shape = key.shape - key = key.view(num_tokens, -1, self.head_size) + if key is not None: + key_shape = key.shape + key = key.view(num_tokens, -1, self.head_size) if self.rotary_dim == self.head_size: query = _apply_rotary_emb(query, cos, sin, self.is_neox_style) query = query.reshape(query_shape) - key = _apply_rotary_emb(key, cos, sin, self.is_neox_style) - key = key.reshape(key_shape) + if key is not None: + key = _apply_rotary_emb(key, cos, sin, self.is_neox_style) + key = key.reshape(key_shape) else: head_size = query.shape[-1] query_reshaped = query.view(-1, head_size) @@ -365,14 +379,15 @@ def _apply_rotary_emb_neuron( query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) - key_reshaped = key.view(-1, head_size) - key_pass = key_reshaped[:, self.rotary_dim:].view( - *key.shape[:-1], head_size - self.rotary_dim) - key_rot = key_reshaped[:, :self.rotary_dim].view( - *key.shape[:-1], self.rotary_dim) - key_rot = _apply_rotary_emb_neuron(key_rot, cos, sin, - self.is_neox_style) - key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) + if key is not None: + key_reshaped = key.view(-1, head_size) + key_pass = key_reshaped[:, self.rotary_dim:].view( + *key.shape[:-1], head_size - self.rotary_dim) + key_rot = key_reshaped[:, :self.rotary_dim].view( + *key.shape[:-1], self.rotary_dim) + key_rot = _apply_rotary_emb_neuron(key_rot, cos, sin, + self.is_neox_style) + key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) return query, key def extra_repr(self) -> str: @@ -718,9 +733,10 @@ def forward_native( self, positions: torch.Tensor, query: torch.Tensor, - key: torch.Tensor, + key: Optional[torch.Tensor] = None, offsets: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + assert key is not None query = query.view(*query.shape[:-1], -1, self.head_size) key = key.view(*key.shape[:-1], -1, self.head_size) @@ -872,10 +888,11 @@ def forward( self, positions: torch.Tensor, query: torch.Tensor, - key: torch.Tensor, + key: Optional[torch.Tensor] = None, offsets: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """PyTorch-native implementation equivalent to forward().""" + assert key is not None query_rot = query[..., :self.rotary_dim] key_rot = key[..., :self.rotary_dim] if self.rotary_dim < self.head_size: @@ -1002,8 +1019,9 @@ def _compute_cos_sin_cache(self) -> torch.Tensor: def forward( self, query: torch.Tensor, - key: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: + key: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + assert key is not None self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(query.device) query_ = torch.view_as_complex(query.float().reshape( *query.shape[:-1], -1, 2)) @@ -1047,8 +1065,8 @@ def forward( self, positions: torch.Tensor, query: torch.Tensor, - key: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: + key: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """PyTorch-native implementation equivalent to forward(). Args: @@ -1059,6 +1077,7 @@ def forward( key: [num_tokens, num_kv_heads * head_size] """ assert positions.ndim == 1 or positions.ndim == 2 + assert key is not None num_tokens = positions.shape[-1] cos_sin = self.cos_sin_cache[positions] diff --git a/vllm/model_executor/model_loader/__init__.py b/vllm/model_executor/model_loader/__init__.py index 9048c70c7a7..92a0b0923b6 100644 --- a/vllm/model_executor/model_loader/__init__.py +++ b/vllm/model_executor/model_loader/__init__.py @@ -2,19 +2,67 @@ from torch import nn -from vllm.config import VllmConfig -from vllm.model_executor.model_loader.loader import (BaseModelLoader, - get_model_loader) +from vllm.config import LoadConfig, LoadFormat, VllmConfig +from vllm.model_executor.model_loader.base_loader import BaseModelLoader +from vllm.model_executor.model_loader.bitsandbytes_loader import ( + BitsAndBytesModelLoader) +from vllm.model_executor.model_loader.default_loader import DefaultModelLoader +from vllm.model_executor.model_loader.dummy_loader import DummyModelLoader +from vllm.model_executor.model_loader.gguf_loader import GGUFModelLoader +from vllm.model_executor.model_loader.runai_streamer_loader import ( + RunaiModelStreamerLoader) +from vllm.model_executor.model_loader.sharded_state_loader import ( + ShardedStateLoader) +from vllm.model_executor.model_loader.tensorizer_loader import TensorizerLoader from vllm.model_executor.model_loader.utils import ( get_architecture_class_name, get_model_architecture) +def get_model_loader(load_config: LoadConfig) -> BaseModelLoader: + """Get a model loader based on the load format.""" + if isinstance(load_config.load_format, type): + return load_config.load_format(load_config) + + if load_config.load_format == LoadFormat.DUMMY: + return DummyModelLoader(load_config) + + if load_config.load_format == LoadFormat.TENSORIZER: + return TensorizerLoader(load_config) + + if load_config.load_format == LoadFormat.SHARDED_STATE: + return ShardedStateLoader(load_config) + + if load_config.load_format == LoadFormat.BITSANDBYTES: + return BitsAndBytesModelLoader(load_config) + + if load_config.load_format == LoadFormat.GGUF: + return GGUFModelLoader(load_config) + + if load_config.load_format == LoadFormat.RUNAI_STREAMER: + return RunaiModelStreamerLoader(load_config) + + if load_config.load_format == LoadFormat.RUNAI_STREAMER_SHARDED: + return ShardedStateLoader(load_config, runai_model_streamer=True) + + return DefaultModelLoader(load_config) + + def get_model(*, vllm_config: VllmConfig) -> nn.Module: loader = get_model_loader(vllm_config.load_config) return loader.load_model(vllm_config=vllm_config) __all__ = [ - "get_model", "get_model_loader", "BaseModelLoader", - "get_architecture_class_name", "get_model_architecture" + "get_model", + "get_model_loader", + "get_architecture_class_name", + "get_model_architecture", + "BaseModelLoader", + "BitsAndBytesModelLoader", + "GGUFModelLoader", + "DefaultModelLoader", + "DummyModelLoader", + "RunaiModelStreamerLoader", + "ShardedStateLoader", + "TensorizerLoader", ] diff --git a/vllm/model_executor/model_loader/base_loader.py b/vllm/model_executor/model_loader/base_loader.py new file mode 100644 index 00000000000..f17cab05c25 --- /dev/null +++ b/vllm/model_executor/model_loader/base_loader.py @@ -0,0 +1,23 @@ +# SPDX-License-Identifier: Apache-2.0 +from abc import ABC, abstractmethod + +import torch.nn as nn + +from vllm.config import LoadConfig, ModelConfig, VllmConfig + + +class BaseModelLoader(ABC): + """Base class for model loaders.""" + + def __init__(self, load_config: LoadConfig): + self.load_config = load_config + + @abstractmethod + def download_model(self, model_config: ModelConfig) -> None: + """Download a model so that it can be immediately loaded.""" + raise NotImplementedError + + @abstractmethod + def load_model(self, *, vllm_config: VllmConfig) -> nn.Module: + """Load a model with the given configurations.""" + raise NotImplementedError diff --git a/vllm/model_executor/model_loader/bitsandbytes_loader.py b/vllm/model_executor/model_loader/bitsandbytes_loader.py new file mode 100644 index 00000000000..57189bfafc0 --- /dev/null +++ b/vllm/model_executor/model_loader/bitsandbytes_loader.py @@ -0,0 +1,568 @@ +# SPDX-License-Identifier: Apache-2.0 +# ruff: noqa: SIM117 +import copy +import fnmatch +import glob +import itertools +import math +import os +from typing import Any, Callable, Dict, Generator, List, Optional, Tuple + +import numpy as np +import torch +from huggingface_hub import HfApi +from torch import nn +from transformers.utils import SAFE_WEIGHTS_INDEX_NAME + +from vllm.config import LoadConfig, ModelConfig, VllmConfig +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +# yapf: enable +from vllm.logger import init_logger +# yapf conflicts with isort for this block +# yapf: disable +from vllm.model_executor.layers.linear import (LinearBase, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.model_loader.base_loader import BaseModelLoader +from vllm.model_executor.model_loader.utils import (ParamMapping, + initialize_model, + set_default_torch_dtype) +from vllm.model_executor.model_loader.weight_utils import ( + download_safetensors_index_file_from_hf, download_weights_from_hf, + filter_duplicate_safetensors_files, filter_files_not_needed_for_inference, + pt_weights_iterator, safetensors_weights_iterator) +from vllm.model_executor.utils import set_weight_attrs +from vllm.platforms import current_platform + +logger = init_logger(__name__) + + +class BitsAndBytesModelLoader(BaseModelLoader): + """Model loader to load model weights with BitAndBytes quantization.""" + + possible_config_file_names = ["adapter_config.json"] + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + + # Save the module names without sharding. + self.unsharded_weights_modules: List[str] = [] + # Save the module names that are sharded by column. + self.column_sharded_weights_modules: List[str] = [] + # Store all module names (from transformers) that support + # BNB quantization. + self.target_modules: List[str] = [] + # mapping weight names from transformers to vllm. + self.weight_mapper: Callable = lambda name: name + + def _get_weight_files( + self, + model_name_or_path: str, + allowed_patterns: List[str], + revision: Optional[str] = None, + ) -> Tuple[str, List[str], str]: + """Retrieve weight files. Download the files if necessary. + + Return the weight files and the file pattern.""" + is_local = os.path.isdir(model_name_or_path) + + if is_local: + for pattern in allowed_patterns: + weight_files = glob.glob( + os.path.join(model_name_or_path, pattern)) + if weight_files: + return model_name_or_path, weight_files, pattern + else: + hf_api = HfApi() + repo_files = hf_api.list_repo_files(repo_id=model_name_or_path) + for pattern in allowed_patterns: + matching_files = fnmatch.filter(repo_files, pattern) + if matching_files: + hf_folder = download_weights_from_hf( + model_name_or_path, + self.load_config.download_dir, + [pattern], + revision, + ignore_patterns=self.load_config.ignore_patterns, + ) + return hf_folder, glob.glob( + os.path.join(hf_folder, pattern)), pattern + + raise RuntimeError( + f"No model weights found in: `{model_name_or_path}`") + + def _prepare_weights(self, model_name_or_path: str, + revision: Optional[str]) -> Tuple[List[str], bool]: + """Prepare weight files for the model.""" + + allowed_patterns = ["*.safetensors", "*.bin", "*.pt"] + + hf_folder, hf_weights_files, matched_pattern = self._get_weight_files( + model_name_or_path, allowed_patterns, revision) + + use_safetensors = matched_pattern == "*.safetensors" + is_local = os.path.isdir(model_name_or_path) + index_file = SAFE_WEIGHTS_INDEX_NAME + if use_safetensors: + # For models like Mistral-7B-Instruct-v0.3 + # there are both sharded safetensors files and a consolidated + # safetensors file. Using both breaks. + # Here, we download the `model.safetensors.index.json` and filter + # any files not found in the index. + if not is_local: + download_safetensors_index_file_from_hf( + model_name_or_path, + index_file, + self.load_config.download_dir, + revision, + ) + hf_weights_files = filter_duplicate_safetensors_files( + hf_weights_files, hf_folder, index_file) + else: + hf_weights_files = filter_files_not_needed_for_inference( + hf_weights_files) + + if len(hf_weights_files) == 0: + raise RuntimeError( + f"Cannot find any model weights with `{model_name_or_path}`") + + return hf_weights_files, use_safetensors + + def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool): + if use_safetensors: + iterator = safetensors_weights_iterator( + hf_weights_files, + self.load_config.use_tqdm_on_load, + ) + else: + iterator = pt_weights_iterator( + hf_weights_files, + self.load_config.use_tqdm_on_load, + self.load_config.pt_load_map_location, + ) + for org_name, param in iterator: + # mapping weight names from transformers to vllm while preserving + # original names. + mapped_name = self.weight_mapper(org_name) + yield org_name, mapped_name, param + + def _get_quantized_weights_iterator( + self, + model_name_or_path: str, + revision: Optional[str], + pre_quant: bool, + load_8bit: bool, + ) -> Tuple[Generator[Tuple[str, torch.Tensor], None, None], Dict[str, + Any]]: + """Get an iterator to the model weights with bitsandbytes quantization, + as well as the quantization state dictionary.""" + + # only load the bitsandbytes module when needed + try: + import bitsandbytes + + if bitsandbytes.__version__ < "0.45.3": + raise ImportError("bitsandbytes version is wrong. Please " + "install bitsandbytes>=0.45.3.") + except ImportError as err: + raise ImportError("Please install bitsandbytes>=0.45.3 via " + "`pip install bitsandbytes>=0.45.3` to use " + "bitsandbytes quantizer.") from err + + hf_weights_files, use_safetensors = self._prepare_weights( + model_name_or_path, revision) + + quant_state_dict: Dict[str, Any] = {} + + if pre_quant: + if load_8bit: + return self._quantized_8bit_generator( + hf_weights_files, use_safetensors, + quant_state_dict), quant_state_dict + else: + return self._quantized_4bit_generator( + hf_weights_files, use_safetensors, + quant_state_dict), quant_state_dict + + return self._unquantized_generator(hf_weights_files, use_safetensors, + quant_state_dict), quant_state_dict + + def _is_8bit_weight_name(self, weight_name: str): + quantized_suffix = {".scb", ".weight_format"} + return any(weight_name.lower().endswith(suffix) + for suffix in quantized_suffix) + + def _is_4bit_weight_name(self, weight_name: str): + quantized_suffix = { + "absmax", + "quant_map", + "nested_absmax", + "nested_quant_map", + "bitsandbytes", + } + suffix = weight_name.split(".")[-1] + return any(q_suffix in suffix for q_suffix in quantized_suffix) + + def _quantized_8bit_generator(self, hf_weights_files, use_safetensors, + quant_state_dict) -> Generator: + for ( + org_weight_name, + mapped_weight_name, + weight_tensor, + ) in self._hf_weight_iter(hf_weights_files, use_safetensors): + if not mapped_weight_name.lower().endswith(".scb"): + continue + + weight_key = mapped_weight_name.lower().replace(".scb", ".weight") + quant_state_dict[weight_key] = weight_tensor + + for ( + org_weight_name, + mapped_weight_name, + weight_tensor, + ) in self._hf_weight_iter(hf_weights_files, use_safetensors): + if self._is_8bit_weight_name(mapped_weight_name): + continue + + if mapped_weight_name in quant_state_dict: + set_weight_attrs(weight_tensor, {"load_in_8bit": True}) + yield org_weight_name, weight_tensor + else: + yield org_weight_name, weight_tensor + + def _quantized_4bit_generator(self, hf_weights_files, use_safetensors, + quant_state_dict) -> Generator: + from bitsandbytes.functional import QuantState + + # First iterate over all quant state weights + weight_iterator = self._hf_weight_iter(hf_weights_files, + use_safetensors) + temp_state_dict = {} + for ( + org_weight_name, + mapped_weight_name, + weight_tensor, + ) in weight_iterator: + if not self._is_4bit_weight_name(mapped_weight_name): + continue + # bitsandbytes library requires + # weight.quant_state.bitsandbytes__* in CPU + if "quant_state.bitsandbytes" in mapped_weight_name: + temp_state_dict[mapped_weight_name] = weight_tensor.cpu().data + else: + temp_state_dict[mapped_weight_name] = weight_tensor + + # Closure to parse quant_state for each prequant weight + def _parse_quant_state(param_name: str, + temp_state_dict: Dict) -> QuantState: + quant_state = {} + for k in temp_state_dict: + if param_name + "." in k: + quant_state[k] = temp_state_dict[k] + + return QuantState.from_dict(quant_state, + device=current_platform.device_type) + + # Second iterate over all prequant and normal weights + # pre quantized weights would have a quant_state + for ( + org_weight_name, + mapped_weight_name, + weight_tensor, + ) in self._hf_weight_iter(hf_weights_files, use_safetensors): + if self._is_4bit_weight_name(mapped_weight_name): + continue + + if (f"{mapped_weight_name}.quant_state.bitsandbytes__nf4" + in temp_state_dict) or ( + f"{mapped_weight_name}.quant_state.bitsandbytes__fp4" + in temp_state_dict): + quant_state = _parse_quant_state(mapped_weight_name, + temp_state_dict) + quant_state_dict[mapped_weight_name] = quant_state + yield org_weight_name, weight_tensor + else: + yield org_weight_name, weight_tensor + + def _unquantized_generator(self, hf_weights_files, use_safetensors, + quant_state_dict) -> Generator: + from bitsandbytes.functional import quantize_4bit + + tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + + for ( + org_weight_name, + mapped_weight_name, + weight_tensor, + ) in self._hf_weight_iter(hf_weights_files, use_safetensors): + if any(target_module in mapped_weight_name + for target_module in self.target_modules + ) and mapped_weight_name.endswith(".weight"): + # Without sharding + if any( + mapped_weight_name.startswith(module) + for module in self.unsharded_weights_modules): + weight_sub_tensor = weight_tensor + # Shard by column + elif any( + mapped_weight_name.startswith(module) + for module in self.column_sharded_weights_modules): + total_size = weight_tensor.size(-1) + start_index = total_size // tp_size * tp_rank + end_index = total_size // tp_size * (tp_rank + 1) + weight_sub_tensor = weight_tensor[..., + start_index:end_index] + # Weights have fused on disk. In this case, we assume that the + # weight and module use same name. + elif any( + mapped_weight_name.startswith(module) + for module in self.maybe_fused_weights_modules): + # special case for fused weights + # get the size of each shard weight tensor + total_shard_sizes = next( + (sizes for module, sizes in + self.maybe_fused_weights_modules.items() + if mapped_weight_name.startswith(module))) + total_size = weight_tensor.size(0) + assert total_size == sum(total_shard_sizes) + # get the start/end index of each shard weight tensor + total_start_index = list( + itertools.accumulate([0] + total_shard_sizes))[:-1] + shard_weights_index = [( + idx + size // tp_size * tp_rank, + idx + size // tp_size * (tp_rank + 1), + ) for idx, size in zip(total_start_index, + total_shard_sizes)] + # slice and reorder the weight tensor + weight_tensor = [ + weight_tensor[start_index:end_index, ...] + for start_index, end_index in shard_weights_index + ] + weight_sub_tensor = torch.cat(weight_tensor, dim=0) + # Shard by row + else: + total_size = weight_tensor.size(0) + start_index = total_size // tp_size * tp_rank + end_index = total_size // tp_size * (tp_rank + 1) + weight_sub_tensor = weight_tensor[start_index:end_index, + ...] + + # bitsandbytes requires data in GPU + if weight_sub_tensor.is_cuda: + loaded_weight = weight_sub_tensor + else: + loaded_weight = weight_sub_tensor.cuda() + + # remove the following after the issue is fixed: + # https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1342 + if loaded_weight.is_contiguous() is False: + loaded_weight = loaded_weight.contiguous() + + with set_default_torch_dtype(torch.float32): + processed_weight, quant_state = quantize_4bit( + loaded_weight, + compress_statistics=True, + quant_type="nf4", + ) + + quant_state_dict[mapped_weight_name] = quant_state + else: + processed_weight = weight_tensor + yield org_weight_name, processed_weight + + def _get_bnb_target_modules(self, model: nn.Module) -> None: + + for name, module in model.named_modules(): + if isinstance(module, (LinearBase, )): + if modules_info := self.modules_mapping.get_sub_modules(name): + # Map vllm's names to transformers's names. + rep_name, sub_modules = modules_info + for sub_name in sub_modules: + self.target_modules.append( + name.replace(rep_name, sub_name)) + # Add original module name even if the module has stacked map, + # in case model has a mixture of disk-merged and disk-splitted + # weights with same last name. + self.target_modules.append(name) + + assert (self.target_modules + ), "vllm currently does not support BNB quantization for" + f" {type(model).__name__}" + + def _load_weights(self, model_config: ModelConfig, + model: nn.Module) -> None: + if not hasattr(model, "load_weights"): + raise AttributeError( + "The required method 'load_weights' is not defined in class" + f" {type(model).__name__}.") + + if not hasattr(model, "packed_modules_mapping"): + raise AttributeError( + f"Model {type(model).__name__} does not support BitsAndBytes " + "quantization yet. No 'packed_modules_mapping' found.") + + self.modules_mapping = ParamMapping( + copy.deepcopy(model.packed_modules_mapping)) + + # For some models like Molmo, we need to use hf_to_vllm_mapper + # to ensure correct loading of weights. + if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None): + self.weight_mapper = lambda name: hf_to_vllm_mapper._map_name(name) + + # Modules whose weights might have fused on disk + # we need their output_sizes to make shard in flight correctly with TP + self.maybe_fused_weights_modules: Dict[str, List[int]] = {} + self._get_bnb_target_modules(model) + for name, module in model.named_modules(): + # Some modules like `ReplicatedLinear` should not have their weights + # sharded. The reason for implementing it this way is to avoid new + # static variable in the model implementation. + if isinstance(module, (ReplicatedLinear, )): + self.unsharded_weights_modules.append(name) + # `QKVParallelLinear` and `MergedColumnParallelLinear` might have + # fused weights on disk. We need to use the output sizes of these + # modules to shard the weights correctly. + elif isinstance(module, + (QKVParallelLinear, MergedColumnParallelLinear)): + self.maybe_fused_weights_modules[name] = module.output_sizes + # In TP, these weights are partitioned along the column + # dimension (dim=-1) + elif isinstance(module, (RowParallelLinear, )): + self.column_sharded_weights_modules.append(name) + + self.model_type = type(model).__name__ + + logger.info("Loading weights with BitsAndBytes quantization. " + "May take a while ...") + + quant_config = getattr(model_config.hf_config, "quantization_config", + None) + + pre_quant = False + if quant_config is not None: + quant_method = quant_config.get("quant_method") + if quant_method == "bitsandbytes": + pre_quant = True + else: + raise ValueError( + f"BitsAndBytes loader does not support {quant_method} " + "quantization") + + # The quant_states in pre_quantized models cannot work with a split + # weight tensor. So TP does not work with pre_quantized bnb models. + if pre_quant and get_tensor_model_parallel_world_size() > 1: + raise ValueError( + "Prequant BitsAndBytes models with tensor parallelism is not " + "supported. Please try with pipeline parallelism.") + + load_8bit = False + if pre_quant: + load_8bit = quant_config.get("load_in_8bit", False) + + qweight_iterator, quant_state_dict = ( + self._get_quantized_weights_iterator(model_config.model, + model_config.revision, + pre_quant, load_8bit)) + + weights_to_load = {name for name, _ in model.named_parameters()} + loaded_weights = model.load_weights(qweight_iterator) + # Some models may have weights loading tracker unimplemented. + if loaded_weights is not None: + weights_not_loaded = weights_to_load - loaded_weights + if weights_not_loaded: + raise ValueError("Following weights were not initialized from " + f"checkpoint: {weights_not_loaded}") + + torch.cuda.empty_cache() + + param_dict = dict(model.named_parameters()) + stacked_quant_state_dict: Dict[str, Dict[int, Any]] = {} + # TODO: Change this lazy import to normal import + # after the checks are updated to run on a new version + from vllm.model_executor.models.utils import is_pp_missing_parameter + + for quant_param_name in quant_state_dict: + if is_pp_missing_parameter(quant_param_name, model): + continue + + non_stacked_param_name = quant_param_name + + shard_index = 0 + for shard_name, ( + weight_name, + index, + ) in self.modules_mapping.inverse_packed_mapping.items(): + # Some models, such as MiniCPM V2.5/2.6, contain both + # module names 'kv_proj' and 'qkv_proj'. To prevent 'kv_proj' + # from being incorrectly identified as being present in + # 'vpm.encoder.layers.0.self_attn.qkv_proj.weight + shard_pos = quant_param_name.find(shard_name) + can_correct_rename = (shard_pos + > 0) and (quant_param_name[shard_pos - 1] + == ".") + # If the quant_param_name is packed, it won't occur in the + # param_dict before renaming. + new_quant_param_name = quant_param_name.replace( + shard_name, weight_name) + need_rename = (quant_param_name not in param_dict) \ + and (new_quant_param_name in param_dict) + if can_correct_rename and need_rename: + shard_index = index + quant_param_name = new_quant_param_name + break + + # Models like Clip/Siglip may skip some layers in initialization, + # causing unused quant_param_name in state_dict. + if quant_param_name not in param_dict: + continue + + if quant_param_name not in stacked_quant_state_dict: + stacked_quant_state_dict[quant_param_name] = {} + + stacked_quant_state_dict[quant_param_name][shard_index] = ( + quant_state_dict[non_stacked_param_name]) + + # save quant_states and offsets as the attributes of the parameters + for param_name, param in param_dict.items(): + if param_name in stacked_quant_state_dict: + quant_states = stacked_quant_state_dict[param_name] + set_weight_attrs(param, {"bnb_quant_state": quant_states}) + + pack_ratio = getattr(param, "pack_factor", -1) + if pack_ratio == -1: + raise ValueError( + f"pack_factor not set for parameter {param_name}.") + + num_elements = [0] * len(quant_states) + for seq, quant_state in quant_states.items(): + num_elements[seq] = (math.prod(quant_state.shape) // + pack_ratio) + + offsets = np.concatenate(([0], np.cumsum(num_elements))) + # Make torch infer_schema happy + offsets = torch.tensor(offsets).cpu() + set_weight_attrs(param, {"bnb_shard_offsets": offsets}) + + if load_8bit: + set_weight_attrs( + param, {"matmul_state": [None] * len(quant_states)}) + + def download_model(self, model_config: ModelConfig) -> None: + self._prepare_weights(model_config.model, model_config.revision) + + def load_model(self, vllm_config: VllmConfig) -> nn.Module: + device_config = vllm_config.device_config + model_config = vllm_config.model_config + + with set_default_torch_dtype(model_config.dtype): + with torch.device(device_config.device): + + model = initialize_model(vllm_config=vllm_config) + + self._load_weights(model_config, model) + + return model.eval() diff --git a/vllm/model_executor/model_loader/default_loader.py b/vllm/model_executor/model_loader/default_loader.py new file mode 100644 index 00000000000..c8bc4aecaec --- /dev/null +++ b/vllm/model_executor/model_loader/default_loader.py @@ -0,0 +1,293 @@ +# SPDX-License-Identifier: Apache-2.0 +import dataclasses +import glob +import os +import time +from typing import Generator, Iterable, List, Optional, Tuple, cast + +import huggingface_hub +import torch +from torch import nn +from transformers.utils import SAFE_WEIGHTS_INDEX_NAME + +from vllm.config import LoadConfig, LoadFormat, ModelConfig, VllmConfig +from vllm.envs import VLLM_USE_MODELSCOPE +from vllm.logger import init_logger +from vllm.model_executor.model_loader.base_loader import BaseModelLoader +from vllm.model_executor.model_loader.utils import ( + initialize_model, process_weights_after_loading, set_default_torch_dtype) +from vllm.model_executor.model_loader.weight_utils import ( + download_safetensors_index_file_from_hf, download_weights_from_hf, + fastsafetensors_weights_iterator, filter_duplicate_safetensors_files, + filter_files_not_needed_for_inference, get_lock, np_cache_weights_iterator, + pt_weights_iterator, safetensors_weights_iterator) +from vllm.platforms import current_platform + +logger = init_logger(__name__) + + +class DefaultModelLoader(BaseModelLoader): + """Model loader that can load different file types from disk.""" + + @dataclasses.dataclass + class Source: + """A source for weights.""" + + model_or_path: str + """The model ID or path.""" + + revision: Optional[str] + """The optional model revision.""" + + prefix: str = "" + """A prefix to prepend to all weights.""" + + fall_back_to_pt: bool = True + """Whether .pt weights can be used.""" + + allow_patterns_overrides: Optional[list[str]] = None + """If defined, weights will load exclusively using these patterns.""" + + counter_before_loading_weights: float = 0.0 + counter_after_loading_weights: float = 0.0 + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + if load_config.model_loader_extra_config: + raise ValueError(f"Model loader extra config is not supported for " + f"load format {load_config.load_format}") + + def _maybe_download_from_modelscope( + self, model: str, revision: Optional[str]) -> Optional[str]: + """Download model from ModelScope hub if VLLM_USE_MODELSCOPE is True. + + Returns the path to the downloaded model, or None if the model is not + downloaded from ModelScope.""" + if VLLM_USE_MODELSCOPE: + # download model from ModelScope hub, + # lazy import so that modelscope is not required for normal use. + # pylint: disable=C. + from modelscope.hub.snapshot_download import snapshot_download + + if not os.path.exists(model): + # Use file lock to prevent multiple processes from + # downloading the same model weights at the same time. + with get_lock(model, self.load_config.download_dir): + model_path = snapshot_download( + model_id=model, + cache_dir=self.load_config.download_dir, + local_files_only=huggingface_hub.constants. + HF_HUB_OFFLINE, + revision=revision, + ignore_file_pattern=self.load_config.ignore_patterns, + ) + else: + model_path = model + return model_path + return None + + def _prepare_weights( + self, + model_name_or_path: str, + revision: Optional[str], + fall_back_to_pt: bool, + allow_patterns_overrides: Optional[list[str]], + ) -> Tuple[str, List[str], bool]: + """Prepare weights for the model. + + If the model is not local, it will be downloaded.""" + model_name_or_path = (self._maybe_download_from_modelscope( + model_name_or_path, revision) or model_name_or_path) + + is_local = os.path.isdir(model_name_or_path) + load_format = self.load_config.load_format + use_safetensors = False + index_file = SAFE_WEIGHTS_INDEX_NAME + # Some quantized models use .pt files for storing the weights. + if load_format == LoadFormat.AUTO: + allow_patterns = ["*.safetensors", "*.bin"] + elif (load_format == LoadFormat.SAFETENSORS + or load_format == LoadFormat.FASTSAFETENSORS): + use_safetensors = True + allow_patterns = ["*.safetensors"] + elif load_format == LoadFormat.MISTRAL: + use_safetensors = True + allow_patterns = ["consolidated*.safetensors"] + index_file = "consolidated.safetensors.index.json" + elif load_format == LoadFormat.PT: + allow_patterns = ["*.pt"] + elif load_format == LoadFormat.NPCACHE: + allow_patterns = ["*.bin"] + else: + raise ValueError(f"Unknown load_format: {load_format}") + + if fall_back_to_pt: + allow_patterns += ["*.pt"] + + if allow_patterns_overrides is not None: + allow_patterns = allow_patterns_overrides + + if not is_local: + hf_folder = download_weights_from_hf( + model_name_or_path, + self.load_config.download_dir, + allow_patterns, + revision, + ignore_patterns=self.load_config.ignore_patterns, + ) + else: + hf_folder = model_name_or_path + + hf_weights_files: List[str] = [] + for pattern in allow_patterns: + hf_weights_files += glob.glob(os.path.join(hf_folder, pattern)) + if len(hf_weights_files) > 0: + if pattern == "*.safetensors": + use_safetensors = True + break + + if use_safetensors: + # For models like Mistral-7B-Instruct-v0.3 + # there are both sharded safetensors files and a consolidated + # safetensors file. Using both breaks. + # Here, we download the `model.safetensors.index.json` and filter + # any files not found in the index. + if not is_local: + download_safetensors_index_file_from_hf( + model_name_or_path, + index_file, + self.load_config.download_dir, + revision, + ) + hf_weights_files = filter_duplicate_safetensors_files( + hf_weights_files, hf_folder, index_file) + else: + hf_weights_files = filter_files_not_needed_for_inference( + hf_weights_files) + + if len(hf_weights_files) == 0: + raise RuntimeError( + f"Cannot find any model weights with `{model_name_or_path}`") + + return hf_folder, hf_weights_files, use_safetensors + + def _get_weights_iterator( + self, source: "Source" + ) -> Generator[Tuple[str, torch.Tensor], None, None]: + """Get an iterator for the model weights based on the load format.""" + hf_folder, hf_weights_files, use_safetensors = self._prepare_weights( + source.model_or_path, source.revision, source.fall_back_to_pt, + source.allow_patterns_overrides) + if self.load_config.load_format == LoadFormat.NPCACHE: + # Currently np_cache only support *.bin checkpoints + assert use_safetensors is False + weights_iterator = np_cache_weights_iterator( + source.model_or_path, + self.load_config.download_dir, + hf_folder, + hf_weights_files, + self.load_config.use_tqdm_on_load, + ) + elif use_safetensors: + if self.load_config.load_format == LoadFormat.FASTSAFETENSORS: + weights_iterator = fastsafetensors_weights_iterator( + hf_weights_files, + self.load_config.use_tqdm_on_load, + ) + else: + weights_iterator = safetensors_weights_iterator( + hf_weights_files, + self.load_config.use_tqdm_on_load, + ) + else: + weights_iterator = pt_weights_iterator( + hf_weights_files, + self.load_config.use_tqdm_on_load, + self.load_config.pt_load_map_location, + ) + + if current_platform.is_tpu(): + # In PyTorch XLA, we should call `xm.mark_step` frequently so that + # not too many ops are accumulated in the XLA program. + import torch_xla.core.xla_model as xm + + def _xla_weights_iterator(iterator: Generator): + for weights in iterator: + yield weights + xm.mark_step() + + weights_iterator = _xla_weights_iterator(weights_iterator) + + elif current_platform.is_hpu(): + import habana_frameworks.torch.core as htcore + + def _hpu_weights_iterator(iterator: Generator): + for weights in iterator: + yield weights + htcore.mark_step() + + weights_iterator = _hpu_weights_iterator(weights_iterator) + + if self.counter_before_loading_weights == 0.0: + self.counter_before_loading_weights = time.perf_counter() + # Apply the prefix. + return ((source.prefix + name, tensor) + for (name, tensor) in weights_iterator) + + def get_all_weights( + self, + model_config: ModelConfig, + model: nn.Module, + ) -> Generator[Tuple[str, torch.Tensor], None, None]: + primary_weights = DefaultModelLoader.Source( + model_config.model, + model_config.revision, + prefix="", + fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load", + True), + allow_patterns_overrides=getattr(model, "allow_patterns_overrides", + None), + ) + yield from self._get_weights_iterator(primary_weights) + + secondary_weights = cast( + Iterable[DefaultModelLoader.Source], + getattr(model, "secondary_weights", ()), + ) + for source in secondary_weights: + yield from self._get_weights_iterator(source) + + def download_model(self, model_config: ModelConfig) -> None: + self._prepare_weights(model_config.model, + model_config.revision, + fall_back_to_pt=True, + allow_patterns_overrides=None) + + def load_model(self, vllm_config: VllmConfig) -> nn.Module: + device_config = vllm_config.device_config + model_config = vllm_config.model_config + target_device = torch.device(device_config.device) + with set_default_torch_dtype(model_config.dtype): + with target_device: + model = initialize_model(vllm_config=vllm_config) + + weights_to_load = {name for name, _ in model.named_parameters()} + loaded_weights = model.load_weights( + self.get_all_weights(model_config, model)) + self.counter_after_loading_weights = time.perf_counter() + logger.info( + "Loading weights took %.2f seconds", + self.counter_after_loading_weights - + self.counter_before_loading_weights) + # We only enable strict check for non-quantized models + # that have loaded weights tracking currently. + if model_config.quantization is None and loaded_weights is not None: + weights_not_loaded = weights_to_load - loaded_weights + if weights_not_loaded: + raise ValueError( + "Following weights were not initialized from " + f"checkpoint: {weights_not_loaded}") + + process_weights_after_loading(model, model_config, target_device) + + return model.eval() diff --git a/vllm/model_executor/model_loader/dummy_loader.py b/vllm/model_executor/model_loader/dummy_loader.py new file mode 100644 index 00000000000..5047a161f3f --- /dev/null +++ b/vllm/model_executor/model_loader/dummy_loader.py @@ -0,0 +1,37 @@ +# SPDX-License-Identifier: Apache-2.0 +import torch +import torch.nn as nn + +from vllm.config import LoadConfig, ModelConfig, VllmConfig +from vllm.model_executor.model_loader.base_loader import BaseModelLoader +from vllm.model_executor.model_loader.utils import ( + initialize_model, process_weights_after_loading, set_default_torch_dtype) +from vllm.model_executor.model_loader.weight_utils import ( + initialize_dummy_weights) + + +class DummyModelLoader(BaseModelLoader): + """Model loader that will set model weights to random values.""" + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + if load_config.model_loader_extra_config: + raise ValueError(f"Model loader extra config is not supported for " + f"load format {load_config.load_format}") + + def download_model(self, model_config: ModelConfig) -> None: + pass # Nothing to download + + def load_model(self, vllm_config: VllmConfig) -> nn.Module: + device_config = vllm_config.device_config + model_config = vllm_config.model_config + target_device = torch.device(device_config.device) + with set_default_torch_dtype(model_config.dtype): + with target_device: + model = initialize_model(vllm_config=vllm_config) + # NOTE(woosuk): For accurate performance evaluation, we assign + # random values to the weights. + initialize_dummy_weights(model) + + process_weights_after_loading(model, model_config, target_device) + return model.eval() diff --git a/vllm/model_executor/model_loader/gguf_loader.py b/vllm/model_executor/model_loader/gguf_loader.py new file mode 100644 index 00000000000..ace1cd37128 --- /dev/null +++ b/vllm/model_executor/model_loader/gguf_loader.py @@ -0,0 +1,113 @@ +# SPDX-License-Identifier: Apache-2.0 +import os +from typing import Dict, Generator, Tuple + +import gguf +import torch +import torch.nn as nn +from transformers import AutoModelForCausalLM + +from vllm.config import LoadConfig, ModelConfig, VllmConfig +from vllm.model_executor.model_loader.base_loader import BaseModelLoader +from vllm.model_executor.model_loader.utils import ( + initialize_model, process_weights_after_loading, set_default_torch_dtype) +from vllm.model_executor.model_loader.weight_utils import ( + get_gguf_extra_tensor_names, gguf_quant_weights_iterator) + + +class GGUFModelLoader(BaseModelLoader): + """ + Model loader that can load GGUF files. This is useful for loading models + that are quantized with GGUF and saved in the GGUF format. This loader + supports loading both full models and sharded models. + """ + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + if load_config.model_loader_extra_config: + raise ValueError(f"Model loader extra config is not supported for " + f"load format {load_config.load_format}") + + def _prepare_weights(self, model_name_or_path: str): + if os.path.isfile(model_name_or_path): + return model_name_or_path + else: + raise ValueError(f"{model_name_or_path} is not a file.") + + def _get_gguf_weights_map(self, model_config: ModelConfig): + """ + GGUF uses this naming convention for their tensors from HF checkpoint: + `blk.N.BB.weight` and `blk.N.BB.bias` + where N signifies the block number of a layer, and BB signifies the + attention/mlp layer components. + See "Standardized tensor names" in + https://github.com/ggerganov/ggml/blob/master/docs/gguf.md for details. + """ + config = model_config.hf_config + model_type = config.model_type + gguf_to_hf_name_map = {} + # hack: ggufs have a different name than transformers + if model_type == "cohere": + model_type = "command-r" + if model_type in ("deepseek_v3", "deepseek_v2"): + model_type = "deepseek2" + # GGUF layer map assumes that we will have a merged expert weights + # so we need to map them manually + for idx in range(config.num_hidden_layers): + gguf_to_hf_name_map[f"blk.{idx}.exp_probs_b.bias"] = \ + f"model.layers.{idx}.mlp.gate.e_score_correction_bias" + gguf_to_hf_name_map[f"blk.{idx}.ffn_down_exps.weight"] = \ + f"model.layers.{idx}.mlp.experts.0.down_proj.weight" + gguf_to_hf_name_map[f"blk.{idx}.ffn_gate_exps.weight"] = \ + f"model.layers.{idx}.mlp.experts.0.gate_proj.weight" + gguf_to_hf_name_map[f"blk.{idx}.ffn_up_exps.weight"] = \ + f"model.layers.{idx}.mlp.experts.0.up_proj.weight" + + arch = None + for key, value in gguf.MODEL_ARCH_NAMES.items(): + if value == model_type: + arch = key + break + if arch is None: + raise RuntimeError(f"Unknown gguf model_type: {model_type}") + num_layers = config.num_hidden_layers + name_map = gguf.get_tensor_name_map(arch, num_layers) + with torch.device("meta"): + dummy_model = AutoModelForCausalLM.from_config( + config, trust_remote_code=model_config.trust_remote_code) + state_dict = dummy_model.state_dict() + + for hf_name in state_dict: + name, suffix = hf_name.rsplit(".", 1) + gguf_name = name_map.get_name(name) + gguf_to_hf_name_map[f"{gguf_name}.{suffix}"] = hf_name + return gguf_to_hf_name_map + + def _get_weights_iterator( + self, model_name_or_path: str, gguf_to_hf_name_map: Dict[str, str] + ) -> Generator[Tuple[str, torch.Tensor], None, None]: + return gguf_quant_weights_iterator(model_name_or_path, + gguf_to_hf_name_map) + + def download_model(self, model_config: ModelConfig) -> None: + self._prepare_weights(model_config.model) + + def load_model(self, vllm_config: VllmConfig) -> nn.Module: + device_config = vllm_config.device_config + model_config = vllm_config.model_config + local_model_path = self._prepare_weights(model_config.model) + gguf_weights_map = self._get_gguf_weights_map(model_config) + # we can only know if tie word embeddings after mapping weights + if "lm_head.weight" in get_gguf_extra_tensor_names( + local_model_path, gguf_weights_map): + model_config.hf_config.update({"tie_word_embeddings": True}) + + target_device = torch.device(device_config.device) + with set_default_torch_dtype(model_config.dtype): + with target_device: + model = initialize_model(vllm_config=vllm_config) + model.load_weights( + self._get_weights_iterator(local_model_path, gguf_weights_map)) + + process_weights_after_loading(model, model_config, target_device) + return model diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py deleted file mode 100644 index ac984e8f2da..00000000000 --- a/vllm/model_executor/model_loader/loader.py +++ /dev/null @@ -1,1550 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -# ruff: noqa: SIM117 -import collections -import copy -import dataclasses -import fnmatch -import glob -import inspect -import itertools -import math -import os -import time -import warnings -from abc import ABC, abstractmethod -from contextlib import contextmanager -from typing import (Any, Callable, Dict, Generator, Iterable, List, Optional, - Tuple, cast) - -import gguf -import huggingface_hub -import numpy as np -import torch -from huggingface_hub import HfApi -from torch import nn -from transformers import AutoModelForCausalLM -from transformers.utils import SAFE_WEIGHTS_INDEX_NAME - -from vllm.attention import Attention -from vllm.config import (LoadConfig, LoadFormat, ModelConfig, ParallelConfig, - VllmConfig, set_current_vllm_config) -from vllm.distributed import (get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) -from vllm.envs import VLLM_USE_MODELSCOPE -from vllm.logger import init_logger -# yapf conflicts with isort for this block -# yapf: disable -from vllm.model_executor.layers.linear import (LinearBase, - MergedColumnParallelLinear, - QKVCrossParallelLinear, - QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) -# yapf: enable -from vllm.model_executor.layers.quantization.base_config import ( - QuantizeMethodBase) -from vllm.model_executor.model_loader.tensorizer import ( - TensorizerConfig, is_vllm_tensorized, load_with_tensorizer, - serialize_vllm_model, tensorizer_weights_iterator) -from vllm.model_executor.model_loader.utils import (ParamMapping, - configure_quant_config, - get_model_architecture, - set_default_torch_dtype) -from vllm.model_executor.model_loader.weight_utils import ( - download_safetensors_index_file_from_hf, download_weights_from_hf, - fastsafetensors_weights_iterator, filter_duplicate_safetensors_files, - filter_files_not_needed_for_inference, get_gguf_extra_tensor_names, - get_lock, gguf_quant_weights_iterator, initialize_dummy_weights, - np_cache_weights_iterator, pt_weights_iterator, - runai_safetensors_weights_iterator, safetensors_weights_iterator) -from vllm.model_executor.utils import set_weight_attrs -from vllm.platforms import current_platform -from vllm.transformers_utils.s3_utils import glob as s3_glob -from vllm.transformers_utils.utils import is_s3 -from vllm.utils import is_pin_memory_available - - -@contextmanager -def device_loading_context(module: torch.nn.Module, - target_device: torch.device): - if target_device.type == "cpu": - # If target is CPU, no need to move anything - yield module - return - - original_device_states: Dict[str, torch.device] = {} - - # Store original device states and move parameters to GPU if they're on CPU - for name, p in module.named_parameters(): - if p.device.type == "cpu" and target_device.type != 'hpu': - original_device_states[name] = p.device - p.data = p.data.to(target_device) - # Parameters already on target device are not touched - - try: - yield module - - finally: - # Restore parameters to their original devices, ignoring new parameters - pin_memory = is_pin_memory_available() - for name, p in module.named_parameters(): - if name in original_device_states: - original_device: torch.device = original_device_states[name] - if original_device.type == "cpu": - # `torch.empty_like` does not support `pin_memory` argument - cpu_data = torch.empty_strided( - size=p.data.size(), - stride=p.data.stride(), - dtype=p.data.dtype, - layout=p.data.layout, - device="cpu", - pin_memory=pin_memory, - ) - cpu_data.copy_(p.data) - p.data = cpu_data - else: - p.data = p.data.to(original_device) - # New parameters or parameters already on target device are untouched - - -logger = init_logger(__name__) - - -def _initialize_model( - vllm_config: VllmConfig, - *, - prefix: str = "", - model_class: Optional[type[nn.Module]] = None, -) -> nn.Module: - """Initialize a model with the given configurations.""" - model_config = vllm_config.model_config - if model_class is None: - model_class, _ = get_model_architecture(model_config) - - if vllm_config.quant_config is not None: - configure_quant_config(vllm_config.quant_config, model_class) - - signatures = inspect.signature(model_class.__init__) - all_params = [param.name for param in signatures.parameters.values()] - if "vllm_config" in all_params and "prefix" in all_params: - # new-style model class - with set_current_vllm_config(vllm_config, check_compile=True): - return model_class(vllm_config=vllm_config, prefix=prefix) - - msg = ("vLLM model class should accept `vllm_config` and `prefix` as " - "input arguments. Possibly you have an old-style model class" - " registered from out of tree and it is used for new vLLM version. " - "Check https://docs.vllm.ai/en/latest/design/arch_overview.html " - "for the design and update the model class accordingly.") - warnings.warn(msg, DeprecationWarning, stacklevel=2) - - logger.warning( - "Trying to guess the arguments for old-style model class %s", - model_class, - ) - # try to be compatible with old-style model class - kwargs = {} - if "prefix" in all_params: - kwargs["prefix"] = prefix - if "config" in all_params: - kwargs["config"] = model_config.hf_config - if "cache_config" in all_params: - kwargs["cache_config"] = vllm_config.cache_config - if "quant_config" in all_params: - kwargs["quant_config"] = vllm_config.quant_config - if "lora_config" in all_params: - kwargs["lora_config"] = vllm_config.lora_config - if "scheduler_config" in all_params: - kwargs["scheduler_config"] = vllm_config.scheduler_config - with set_current_vllm_config(vllm_config, check_compile=True): - return model_class(**kwargs) - - -def _process_weights_after_loading(model: nn.Module, model_config: ModelConfig, - target_device: torch.device) -> None: - for _, module in model.named_modules(): - if isinstance(module, QKVCrossParallelLinear): - # NOTE(Isotr0py): special case for cross QKV layer because - # q and kv proj aren't registered as submodules intentionally - module.process_weights_after_loading() - continue - quant_method = getattr(module, "quant_method", None) - if isinstance(quant_method, QuantizeMethodBase): - # When quant methods need to process weights after loading - # (for repacking, quantizing, etc), they expect parameters - # to be on the global target device. This scope is for the - # case where cpu offloading is used, where we will move the - # parameters onto device for processing and back off after. - with device_loading_context(module, target_device): - quant_method.process_weights_after_loading(module) - - # Currently only used by MLA. - # NOTE: This intentionally happens after other modules so we can easily - # decompress the weights for MLA. - for _, module in model.named_modules(): - if isinstance(module, Attention) and \ - hasattr(module, "process_weights_after_loading"): - # TODO(lucas): see if there is a way to unify the signatures - # of process_weights_after_loading - module.process_weights_after_loading(model_config.dtype) - - -class BaseModelLoader(ABC): - """Base class for model loaders.""" - - def __init__(self, load_config: LoadConfig): - self.load_config = load_config - - @abstractmethod - def download_model(self, model_config: ModelConfig) -> None: - """Download a model so that it can be immediately loaded.""" - raise NotImplementedError - - @abstractmethod - def load_model(self, *, vllm_config: VllmConfig) -> nn.Module: - """Load a model with the given configurations.""" - raise NotImplementedError - - -class DefaultModelLoader(BaseModelLoader): - """Model loader that can load different file types from disk.""" - - @dataclasses.dataclass - class Source: - """A source for weights.""" - - model_or_path: str - """The model ID or path.""" - - revision: Optional[str] - """The optional model revision.""" - - prefix: str = "" - """A prefix to prepend to all weights.""" - - fall_back_to_pt: bool = True - """Whether .pt weights can be used.""" - - allow_patterns_overrides: Optional[list[str]] = None - """If defined, weights will load exclusively using these patterns.""" - - counter_before_loading_weights: float = 0.0 - counter_after_loading_weights: float = 0.0 - - def __init__(self, load_config: LoadConfig): - super().__init__(load_config) - if load_config.model_loader_extra_config: - raise ValueError(f"Model loader extra config is not supported for " - f"load format {load_config.load_format}") - - def _maybe_download_from_modelscope( - self, model: str, revision: Optional[str]) -> Optional[str]: - """Download model from ModelScope hub if VLLM_USE_MODELSCOPE is True. - - Returns the path to the downloaded model, or None if the model is not - downloaded from ModelScope.""" - if VLLM_USE_MODELSCOPE: - # download model from ModelScope hub, - # lazy import so that modelscope is not required for normal use. - # pylint: disable=C. - from modelscope.hub.snapshot_download import snapshot_download - - if not os.path.exists(model): - # Use file lock to prevent multiple processes from - # downloading the same model weights at the same time. - with get_lock(model, self.load_config.download_dir): - model_path = snapshot_download( - model_id=model, - cache_dir=self.load_config.download_dir, - local_files_only=huggingface_hub.constants. - HF_HUB_OFFLINE, - revision=revision, - ignore_file_pattern=self.load_config.ignore_patterns, - ) - else: - model_path = model - return model_path - return None - - def _prepare_weights( - self, - model_name_or_path: str, - revision: Optional[str], - fall_back_to_pt: bool, - allow_patterns_overrides: Optional[list[str]], - ) -> Tuple[str, List[str], bool]: - """Prepare weights for the model. - - If the model is not local, it will be downloaded.""" - model_name_or_path = (self._maybe_download_from_modelscope( - model_name_or_path, revision) or model_name_or_path) - - is_local = os.path.isdir(model_name_or_path) - load_format = self.load_config.load_format - use_safetensors = False - index_file = SAFE_WEIGHTS_INDEX_NAME - # Some quantized models use .pt files for storing the weights. - if load_format == LoadFormat.AUTO: - allow_patterns = ["*.safetensors", "*.bin"] - elif (load_format == LoadFormat.SAFETENSORS - or load_format == LoadFormat.FASTSAFETENSORS): - use_safetensors = True - allow_patterns = ["*.safetensors"] - elif load_format == LoadFormat.MISTRAL: - use_safetensors = True - allow_patterns = ["consolidated*.safetensors"] - index_file = "consolidated.safetensors.index.json" - elif load_format == LoadFormat.PT: - allow_patterns = ["*.pt"] - elif load_format == LoadFormat.NPCACHE: - allow_patterns = ["*.bin"] - else: - raise ValueError(f"Unknown load_format: {load_format}") - - if fall_back_to_pt: - allow_patterns += ["*.pt"] - - if allow_patterns_overrides is not None: - allow_patterns = allow_patterns_overrides - - if not is_local: - hf_folder = download_weights_from_hf( - model_name_or_path, - self.load_config.download_dir, - allow_patterns, - revision, - ignore_patterns=self.load_config.ignore_patterns, - ) - else: - hf_folder = model_name_or_path - - hf_weights_files: List[str] = [] - for pattern in allow_patterns: - hf_weights_files += glob.glob(os.path.join(hf_folder, pattern)) - if len(hf_weights_files) > 0: - if pattern == "*.safetensors": - use_safetensors = True - break - - if use_safetensors: - # For models like Mistral-7B-Instruct-v0.3 - # there are both sharded safetensors files and a consolidated - # safetensors file. Using both breaks. - # Here, we download the `model.safetensors.index.json` and filter - # any files not found in the index. - if not is_local: - download_safetensors_index_file_from_hf( - model_name_or_path, - index_file, - self.load_config.download_dir, - revision, - ) - hf_weights_files = filter_duplicate_safetensors_files( - hf_weights_files, hf_folder, index_file) - else: - hf_weights_files = filter_files_not_needed_for_inference( - hf_weights_files) - - if len(hf_weights_files) == 0: - raise RuntimeError( - f"Cannot find any model weights with `{model_name_or_path}`") - - return hf_folder, hf_weights_files, use_safetensors - - def _get_weights_iterator( - self, source: "Source" - ) -> Generator[Tuple[str, torch.Tensor], None, None]: - """Get an iterator for the model weights based on the load format.""" - hf_folder, hf_weights_files, use_safetensors = self._prepare_weights( - source.model_or_path, source.revision, source.fall_back_to_pt, - source.allow_patterns_overrides) - if self.load_config.load_format == LoadFormat.NPCACHE: - # Currently np_cache only support *.bin checkpoints - assert use_safetensors is False - weights_iterator = np_cache_weights_iterator( - source.model_or_path, - self.load_config.download_dir, - hf_folder, - hf_weights_files, - self.load_config.use_tqdm_on_load, - ) - elif use_safetensors: - if self.load_config.load_format == LoadFormat.FASTSAFETENSORS: - weights_iterator = fastsafetensors_weights_iterator( - hf_weights_files, - self.load_config.use_tqdm_on_load, - ) - else: - weights_iterator = safetensors_weights_iterator( - hf_weights_files, - self.load_config.use_tqdm_on_load, - ) - else: - weights_iterator = pt_weights_iterator( - hf_weights_files, - self.load_config.use_tqdm_on_load, - self.load_config.pt_load_map_location, - ) - - if current_platform.is_tpu(): - # In PyTorch XLA, we should call `xm.mark_step` frequently so that - # not too many ops are accumulated in the XLA program. - import torch_xla.core.xla_model as xm - - def _xla_weights_iterator(iterator: Generator): - for weights in iterator: - yield weights - xm.mark_step() - - weights_iterator = _xla_weights_iterator(weights_iterator) - - elif current_platform.is_hpu(): - import habana_frameworks.torch.core as htcore - - def _hpu_weights_iterator(iterator: Generator): - for weights in iterator: - yield weights - htcore.mark_step() - - weights_iterator = _hpu_weights_iterator(weights_iterator) - - if self.counter_before_loading_weights == 0.0: - self.counter_before_loading_weights = time.perf_counter() - # Apply the prefix. - return ((source.prefix + name, tensor) - for (name, tensor) in weights_iterator) - - def get_all_weights( - self, - model_config: ModelConfig, - model: nn.Module, - ) -> Generator[Tuple[str, torch.Tensor], None, None]: - primary_weights = DefaultModelLoader.Source( - model_config.model, - model_config.revision, - prefix="", - fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load", - True), - allow_patterns_overrides=getattr(model, "allow_patterns_overrides", - None), - ) - yield from self._get_weights_iterator(primary_weights) - - secondary_weights = cast( - Iterable[DefaultModelLoader.Source], - getattr(model, "secondary_weights", ()), - ) - for source in secondary_weights: - yield from self._get_weights_iterator(source) - - def download_model(self, model_config: ModelConfig) -> None: - self._prepare_weights(model_config.model, - model_config.revision, - fall_back_to_pt=True, - allow_patterns_overrides=None) - - def load_model(self, vllm_config: VllmConfig) -> nn.Module: - device_config = vllm_config.device_config - load_config = vllm_config.load_config - model_config = vllm_config.model_config - - load_device = device_config.device if load_config.device is None else \ - load_config.device - target_device = torch.device(load_device) - with set_default_torch_dtype(model_config.dtype): - with target_device: - model = _initialize_model(vllm_config=vllm_config) - - logger.info("Loading weights on %s...", load_device) - weights_to_load = {name for name, _ in model.named_parameters()} - loaded_weights = model.load_weights( - self.get_all_weights(model_config, model)) - self.counter_after_loading_weights = time.perf_counter() - logger.info( - "Loading weights took %.2f seconds", - self.counter_after_loading_weights - - self.counter_before_loading_weights) - # We only enable strict check for non-quantized models - # that have loaded weights tracking currently. - if model_config.quantization is None and loaded_weights is not None: - weights_not_loaded = weights_to_load - loaded_weights - if weights_not_loaded: - warning_msg = f"Following weights were not initialized \ - from checkpoint: {weights_not_loaded}" - - logger.warning(warning_msg) - - _process_weights_after_loading(model, model_config, target_device) - - return model.eval() - - -class DummyModelLoader(BaseModelLoader): - """Model loader that will set model weights to random values.""" - - def __init__(self, load_config: LoadConfig): - super().__init__(load_config) - if load_config.model_loader_extra_config: - raise ValueError(f"Model loader extra config is not supported for " - f"load format {load_config.load_format}") - - def download_model(self, model_config: ModelConfig) -> None: - pass # Nothing to download - - def load_model(self, vllm_config: VllmConfig) -> nn.Module: - device_config = vllm_config.device_config - model_config = vllm_config.model_config - target_device = torch.device(device_config.device) - with set_default_torch_dtype(model_config.dtype): - with target_device: - model = _initialize_model(vllm_config=vllm_config) - # NOTE(woosuk): For accurate performance evaluation, we assign - # random values to the weights. - initialize_dummy_weights(model) - - _process_weights_after_loading(model, model_config, target_device) - return model.eval() - - -class TensorizerLoader(BaseModelLoader): - """Model loader using CoreWeave's tensorizer library.""" - - def __init__(self, load_config: LoadConfig): - super().__init__(load_config) - if isinstance(load_config.model_loader_extra_config, TensorizerConfig): - self.tensorizer_config = load_config.model_loader_extra_config - else: - self.tensorizer_config = TensorizerConfig( - **load_config.model_loader_extra_config) - - def _verify_config(self, model_config: ModelConfig, - parallel_config: ParallelConfig): - self.tensorizer_config.verify_with_model_config(model_config) - self.tensorizer_config.verify_with_parallel_config(parallel_config) - - def _get_weights_iterator( - self, ) -> Generator[Tuple[str, torch.Tensor], None, None]: - tensorizer_args = self.tensorizer_config._construct_tensorizer_args() - return tensorizer_weights_iterator(tensorizer_args) - - def _load_model_serialized_cpu( - self, - vllm_config: VllmConfig, - ) -> nn.Module: - """Load a serialized model with tensorizer to the CPU. - - This is only necessary when the model isn't vLLM-tensorized (see - examples/other/tensorize_vllm_model.py) This should still - be faster than default HuggingFace loading, but will be slower than - loading a vLLM-tensorized model. - """ - device_config = vllm_config.device_config - model_config = vllm_config.model_config - with set_default_torch_dtype(model_config.dtype): - with torch.device(device_config.device): - model = _initialize_model(vllm_config=vllm_config) - - model.load_weights(self._get_weights_iterator()) - return model.eval() - - def _load_model_serialized( - self, - vllm_config: VllmConfig, - ) -> nn.Module: - """Load a serialized model with tensorizer. - - Expects a vLLM-tensorized model. See the - examples/other/tensorize_vllm_model.py example script - for serializing vLLM models.""" - - device_config = vllm_config.device_config - model_config = vllm_config.model_config - - with set_default_torch_dtype(model_config.dtype): - with torch.device(device_config.device): - model_class = get_model_architecture(model_config)[0] - - tensorizer_config = copy.copy(self.tensorizer_config) - tensorizer_config.model_class = model_class - tensorizer_config.hf_config = model_config.hf_config - tensorizer_config.dtype = model_config.dtype - - model = load_with_tensorizer(tensorizer_config, - vllm_config=vllm_config) - return model.eval() - - def download_model(self, model_config: ModelConfig) -> None: - self.tensorizer_config.verify_with_model_config(model_config) - - with self.tensorizer_config.open_stream(): - pass - - def load_model(self, vllm_config: VllmConfig) -> nn.Module: - model_config = vllm_config.model_config - parallel_config = vllm_config.parallel_config - self._verify_config(model_config, parallel_config) - - if parallel_config.tensor_parallel_size > 1: - from vllm.distributed import get_tensor_model_parallel_rank - - self.tensorizer_config.tensorizer_uri = ( - self.tensorizer_config.tensorizer_uri % - get_tensor_model_parallel_rank()) - - if is_vllm_tensorized(self.tensorizer_config): - return self._load_model_serialized(vllm_config=vllm_config) - return self._load_model_serialized_cpu(vllm_config=vllm_config) - - @staticmethod - def save_model( - model: torch.nn.Module, - tensorizer_config: TensorizerConfig, - ) -> None: - serialize_vllm_model( - model=model, - tensorizer_config=tensorizer_config, - ) - - -class ShardedStateLoader(BaseModelLoader): - """ - Model loader that directly loads each worker's model state dict, which - enables a fast load path for large tensor-parallel models where each worker - only needs to read its own shard rather than the entire checkpoint. See - `examples/offline_inference/save_sharded_state.py` for creating a sharded - checkpoint. - """ - - DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors" - - def __init__(self, - load_config: LoadConfig, - runai_model_streamer: bool = False): - super().__init__(load_config) - - self.runai_model_streamer = runai_model_streamer - extra_config = ({} if load_config.model_loader_extra_config is None - else load_config.model_loader_extra_config.copy()) - self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN) - if extra_config: - raise ValueError(f"Unexpected extra config keys for load format " - f"{load_config.load_format}: " - f"{load_config.model_loader_extra_config.keys()}") - - @staticmethod - def _filter_subtensors( - tensors: Dict[str, torch.Tensor], ) -> Dict[str, torch.Tensor]: - """ - Filter out all tensors that share the same memory or a subset of the - memory of another tensor. - """ - same_storage_groups: Dict[Any, List[Tuple[str, torch.Tensor]]] = ( - collections.defaultdict(list)) - for key, tensor in tensors.items(): - if tensor.numel(): - ptr = tensor.untyped_storage().data_ptr() - same_storage_groups[tensor.device, ptr].append((key, tensor)) - - def get_end_ptr(tensor: torch.Tensor) -> int: - return tensor.view(-1)[-1].data_ptr() + tensor.element_size() - - result: Dict[str, torch.Tensor] = {} - for group in same_storage_groups.values(): - for k, t in group: - a, b = t.data_ptr(), get_end_ptr(t) - for k2, t2 in group: - if not t2.is_contiguous(): - continue - a2, b2 = t2.data_ptr(), get_end_ptr(t2) - if a < a2 or b2 < b: - continue - if a2 < a or b < b2 or not t.is_contiguous(): - break # t2 covers strictly more memory than t. - if k2 < k: - # Same tensors, keep the one with the smaller key. - break - else: - result[k] = t - return result - - def _prepare_weights(self, model_name_or_path: str, - revision: Optional[str]): - if is_s3(model_name_or_path) or os.path.isdir(model_name_or_path): - return model_name_or_path - else: - allow_patterns = ["*.safetensors"] - return download_weights_from_hf( - model_name_or_path, - self.load_config.download_dir, - allow_patterns, - revision, - ignore_patterns=self.load_config.ignore_patterns, - ) - - def download_model(self, model_config: ModelConfig) -> None: - self._prepare_weights(model_config.model, model_config.revision) - - def load_model(self, vllm_config: VllmConfig) -> nn.Module: - device_config = vllm_config.device_config - model_config = vllm_config.model_config - target_device = torch.device(device_config.device) - - from vllm.distributed import get_tensor_model_parallel_rank - - model_weights = model_config.model - if hasattr(model_config, "model_weights"): - model_weights = model_config.model_weights - local_model_path = model_weights - - with set_default_torch_dtype(model_config.dtype): - with target_device: - model = _initialize_model(vllm_config=vllm_config) - _process_weights_after_loading(model, model_config, - target_device) - rank = get_tensor_model_parallel_rank() - pattern = os.path.join( - local_model_path, - self.pattern.format(rank=rank, part="*"), - ) - - filepaths = [] - if is_s3(local_model_path): - file_pattern = f"*{self.pattern.format(rank=rank, part=' * ')}" - filepaths = s3_glob(path=local_model_path, - allow_pattern=[file_pattern]) - else: - filepaths = glob.glob(pattern) - if not filepaths: - # TODO: support un-sharded checkpoints too - raise ValueError( - f"Could not find checkpoint files '{pattern}', only " - f"pre-sharded checkpoints are currently supported!") - state_dict = self._filter_subtensors(model.state_dict()) - for key, tensor in self.iterate_over_files(filepaths): - # If loading with LoRA enabled, additional padding may - # be added to certain parameters. We only load into a - # narrowed view of the parameter data. - param_data = state_dict[key].data - param_shape = state_dict[key].shape - for dim, size in enumerate(tensor.shape): - if size < param_shape[dim]: - param_data = param_data.narrow(dim, 0, size) - if tensor.shape != param_shape: - logger.warning( - "loading tensor of shape %s into " - "parameter '%s' of shape %s", - tensor.shape, - key, - param_shape, - ) - param_data.copy_(tensor) - state_dict.pop(key) - if state_dict: - raise ValueError( - f"Missing keys {tuple(state_dict)} in loaded state!") - return model.eval() - - def iterate_over_files( - self, paths) -> Generator[Tuple[str, torch.Tensor], None, None]: - if self.runai_model_streamer: - yield from runai_safetensors_weights_iterator(paths, True) - else: - from safetensors.torch import safe_open - for path in paths: - with safe_open(path, framework="pt") as f: - for key in f.keys(): # noqa: SIM118 - tensor = f.get_tensor(key) - yield key, tensor - - @staticmethod - def save_model( - model: torch.nn.Module, - path: str, - pattern: Optional[str] = None, - max_size: Optional[int] = None, - ) -> None: - from safetensors.torch import save_file - - from vllm.distributed import get_tensor_model_parallel_rank - - if pattern is None: - pattern = ShardedStateLoader.DEFAULT_PATTERN - rank = get_tensor_model_parallel_rank() - part_idx = 0 - total_size = 0 - state_dict = ShardedStateLoader._filter_subtensors(model.state_dict()) - state_dict_part: Dict[str, torch.Tensor] = {} - for key, tensor in state_dict.items(): - param_size = tensor.nelement() * tensor.element_size() - if max_size is not None and total_size + param_size > max_size: - filename = pattern.format(rank=rank, part=part_idx) - save_file( - state_dict_part, - os.path.join(path, filename), - ) - part_idx += 1 - total_size = 0 - state_dict_part = {} - state_dict_part[key] = tensor - total_size += param_size - if len(state_dict_part) > 0: - filename = pattern.format(rank=rank, part=part_idx) - save_file( - state_dict_part, - os.path.join(path, filename), - ) - - -class BitsAndBytesModelLoader(BaseModelLoader): - """Model loader to load model weights with BitAndBytes quantization.""" - - possible_config_file_names = ["adapter_config.json"] - - def __init__(self, load_config: LoadConfig): - super().__init__(load_config) - - # Save the module names without sharding. - self.unsharded_weights_modules: List[str] = [] - # Save the module names that are sharded by column. - self.column_sharded_weights_modules: List[str] = [] - # Store all module names (from transformers) that support - # BNB quantization. - self.target_modules: List[str] = [] - # mapping weight names from transformers to vllm. - self.weight_mapper: Callable = lambda name: name - - def _get_weight_files( - self, - model_name_or_path: str, - allowed_patterns: List[str], - revision: Optional[str] = None, - ) -> Tuple[str, List[str], str]: - """Retrieve weight files. Download the files if necessary. - - Return the weight files and the file pattern.""" - is_local = os.path.isdir(model_name_or_path) - - if is_local: - for pattern in allowed_patterns: - weight_files = glob.glob( - os.path.join(model_name_or_path, pattern)) - if weight_files: - return model_name_or_path, weight_files, pattern - else: - hf_api = HfApi() - repo_files = hf_api.list_repo_files(repo_id=model_name_or_path) - for pattern in allowed_patterns: - matching_files = fnmatch.filter(repo_files, pattern) - if matching_files: - hf_folder = download_weights_from_hf( - model_name_or_path, - self.load_config.download_dir, - [pattern], - revision, - ignore_patterns=self.load_config.ignore_patterns, - ) - return hf_folder, glob.glob( - os.path.join(hf_folder, pattern)), pattern - - raise RuntimeError( - f"No model weights found in: `{model_name_or_path}`") - - def _prepare_weights(self, model_name_or_path: str, - revision: Optional[str]) -> Tuple[List[str], bool]: - """Prepare weight files for the model.""" - - allowed_patterns = ["*.safetensors", "*.bin", "*.pt"] - - hf_folder, hf_weights_files, matched_pattern = self._get_weight_files( - model_name_or_path, allowed_patterns, revision) - - use_safetensors = matched_pattern == "*.safetensors" - is_local = os.path.isdir(model_name_or_path) - index_file = SAFE_WEIGHTS_INDEX_NAME - if use_safetensors: - # For models like Mistral-7B-Instruct-v0.3 - # there are both sharded safetensors files and a consolidated - # safetensors file. Using both breaks. - # Here, we download the `model.safetensors.index.json` and filter - # any files not found in the index. - if not is_local: - download_safetensors_index_file_from_hf( - model_name_or_path, - index_file, - self.load_config.download_dir, - revision, - ) - hf_weights_files = filter_duplicate_safetensors_files( - hf_weights_files, hf_folder, index_file) - else: - hf_weights_files = filter_files_not_needed_for_inference( - hf_weights_files) - - if len(hf_weights_files) == 0: - raise RuntimeError( - f"Cannot find any model weights with `{model_name_or_path}`") - - return hf_weights_files, use_safetensors - - def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool): - if use_safetensors: - iterator = safetensors_weights_iterator( - hf_weights_files, - self.load_config.use_tqdm_on_load, - ) - else: - iterator = pt_weights_iterator( - hf_weights_files, - self.load_config.use_tqdm_on_load, - self.load_config.pt_load_map_location, - ) - for org_name, param in iterator: - # mapping weight names from transformers to vllm while preserving - # original names. - mapped_name = self.weight_mapper(org_name) - yield org_name, mapped_name, param - - def _get_quantized_weights_iterator( - self, - model_name_or_path: str, - revision: Optional[str], - pre_quant: bool, - load_8bit: bool, - ) -> Tuple[Generator[Tuple[str, torch.Tensor], None, None], Dict[str, - Any]]: - """Get an iterator to the model weights with bitsandbytes quantization, - as well as the quantization state dictionary.""" - - # only load the bitsandbytes module when needed - try: - import bitsandbytes - - if bitsandbytes.__version__ < "0.45.3": - raise ImportError("bitsandbytes version is wrong. Please " - "install bitsandbytes>=0.45.3.") - except ImportError as err: - raise ImportError("Please install bitsandbytes>=0.45.3 via " - "`pip install bitsandbytes>=0.45.3` to use " - "bitsandbytes quantizer.") from err - - hf_weights_files, use_safetensors = self._prepare_weights( - model_name_or_path, revision) - - quant_state_dict: Dict[str, Any] = {} - - if pre_quant: - if load_8bit: - return self._quantized_8bit_generator( - hf_weights_files, use_safetensors, - quant_state_dict), quant_state_dict - else: - return self._quantized_4bit_generator( - hf_weights_files, use_safetensors, - quant_state_dict), quant_state_dict - - return self._unquantized_generator(hf_weights_files, use_safetensors, - quant_state_dict), quant_state_dict - - def _is_8bit_weight_name(self, weight_name: str): - quantized_suffix = {".scb", ".weight_format"} - return any(weight_name.lower().endswith(suffix) - for suffix in quantized_suffix) - - def _is_4bit_weight_name(self, weight_name: str): - quantized_suffix = { - "absmax", - "quant_map", - "nested_absmax", - "nested_quant_map", - "bitsandbytes", - } - suffix = weight_name.split(".")[-1] - return any(q_suffix in suffix for q_suffix in quantized_suffix) - - def _quantized_8bit_generator(self, hf_weights_files, use_safetensors, - quant_state_dict) -> Generator: - for ( - org_weight_name, - mapped_weight_name, - weight_tensor, - ) in self._hf_weight_iter(hf_weights_files, use_safetensors): - if not mapped_weight_name.lower().endswith(".scb"): - continue - - weight_key = mapped_weight_name.lower().replace(".scb", ".weight") - quant_state_dict[weight_key] = weight_tensor - - for ( - org_weight_name, - mapped_weight_name, - weight_tensor, - ) in self._hf_weight_iter(hf_weights_files, use_safetensors): - if self._is_8bit_weight_name(mapped_weight_name): - continue - - if mapped_weight_name in quant_state_dict: - set_weight_attrs(weight_tensor, {"load_in_8bit": True}) - yield org_weight_name, weight_tensor - else: - yield org_weight_name, weight_tensor - - def _quantized_4bit_generator(self, hf_weights_files, use_safetensors, - quant_state_dict) -> Generator: - from bitsandbytes.functional import QuantState - - # First iterate over all quant state weights - weight_iterator = self._hf_weight_iter(hf_weights_files, - use_safetensors) - temp_state_dict = {} - for ( - org_weight_name, - mapped_weight_name, - weight_tensor, - ) in weight_iterator: - if not self._is_4bit_weight_name(mapped_weight_name): - continue - # bitsandbytes library requires - # weight.quant_state.bitsandbytes__* in CPU - if "quant_state.bitsandbytes" in mapped_weight_name: - temp_state_dict[mapped_weight_name] = weight_tensor.cpu().data - else: - temp_state_dict[mapped_weight_name] = weight_tensor - - # Closure to parse quant_state for each prequant weight - def _parse_quant_state(param_name: str, - temp_state_dict: Dict) -> QuantState: - quant_state = {} - for k in temp_state_dict: - if param_name + "." in k: - quant_state[k] = temp_state_dict[k] - - return QuantState.from_dict(quant_state, - device=current_platform.device_type) - - # Second iterate over all prequant and normal weights - # pre quantized weights would have a quant_state - for ( - org_weight_name, - mapped_weight_name, - weight_tensor, - ) in self._hf_weight_iter(hf_weights_files, use_safetensors): - if self._is_4bit_weight_name(mapped_weight_name): - continue - - if (f"{mapped_weight_name}.quant_state.bitsandbytes__nf4" - in temp_state_dict) or ( - f"{mapped_weight_name}.quant_state.bitsandbytes__fp4" - in temp_state_dict): - quant_state = _parse_quant_state(mapped_weight_name, - temp_state_dict) - quant_state_dict[mapped_weight_name] = quant_state - yield org_weight_name, weight_tensor - else: - yield org_weight_name, weight_tensor - - def _unquantized_generator(self, hf_weights_files, use_safetensors, - quant_state_dict) -> Generator: - from bitsandbytes.functional import quantize_4bit - - tp_size = get_tensor_model_parallel_world_size() - tp_rank = get_tensor_model_parallel_rank() - - for ( - org_weight_name, - mapped_weight_name, - weight_tensor, - ) in self._hf_weight_iter(hf_weights_files, use_safetensors): - if any(target_module in mapped_weight_name - for target_module in self.target_modules - ) and mapped_weight_name.endswith(".weight"): - # Without sharding - if any( - mapped_weight_name.startswith(module) - for module in self.unsharded_weights_modules): - weight_sub_tensor = weight_tensor - # Shard by column - elif any( - mapped_weight_name.startswith(module) - for module in self.column_sharded_weights_modules): - total_size = weight_tensor.size(-1) - start_index = total_size // tp_size * tp_rank - end_index = total_size // tp_size * (tp_rank + 1) - weight_sub_tensor = weight_tensor[..., - start_index:end_index] - # Weights have fused on disk. In this case, we assume that the - # weight and module use same name. - elif any( - mapped_weight_name.startswith(module) - for module in self.maybe_fused_weights_modules): - # special case for fused weights - # get the size of each shard weight tensor - total_shard_sizes = next( - (sizes for module, sizes in - self.maybe_fused_weights_modules.items() - if mapped_weight_name.startswith(module))) - total_size = weight_tensor.size(0) - assert total_size == sum(total_shard_sizes) - # get the start/end index of each shard weight tensor - total_start_index = list( - itertools.accumulate([0] + total_shard_sizes))[:-1] - shard_weights_index = [( - idx + size // tp_size * tp_rank, - idx + size // tp_size * (tp_rank + 1), - ) for idx, size in zip(total_start_index, - total_shard_sizes)] - # slice and reorder the weight tensor - weight_tensor = [ - weight_tensor[start_index:end_index, ...] - for start_index, end_index in shard_weights_index - ] - weight_sub_tensor = torch.cat(weight_tensor, dim=0) - # Shard by row - else: - total_size = weight_tensor.size(0) - start_index = total_size // tp_size * tp_rank - end_index = total_size // tp_size * (tp_rank + 1) - weight_sub_tensor = weight_tensor[start_index:end_index, - ...] - - # bitsandbytes requires data in GPU - if weight_sub_tensor.is_cuda: - loaded_weight = weight_sub_tensor - else: - loaded_weight = weight_sub_tensor.cuda() - - # remove the following after the issue is fixed: - # https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1342 - if loaded_weight.is_contiguous() is False: - loaded_weight = loaded_weight.contiguous() - - with set_default_torch_dtype(torch.float32): - processed_weight, quant_state = quantize_4bit( - loaded_weight, - compress_statistics=True, - quant_type="nf4", - ) - - quant_state_dict[mapped_weight_name] = quant_state - else: - processed_weight = weight_tensor - yield org_weight_name, processed_weight - - def _get_bnb_target_modules(self, model: nn.Module) -> None: - - for name, module in model.named_modules(): - if isinstance(module, (LinearBase, )): - if modules_info := self.modules_mapping.get_sub_modules(name): - # Map vllm's names to transformers's names. - rep_name, sub_modules = modules_info - for sub_name in sub_modules: - self.target_modules.append( - name.replace(rep_name, sub_name)) - # Add original module name even if the module has stacked map, - # in case model has a mixture of disk-merged and disk-splitted - # weights with same last name. - self.target_modules.append(name) - - assert (self.target_modules - ), "vllm currently does not support BNB quantization for" - f" {type(model).__name__}" - - def _load_weights(self, model_config: ModelConfig, - model: nn.Module) -> None: - if not hasattr(model, "load_weights"): - raise AttributeError( - "The required method 'load_weights' is not defined in class" - f" {type(model).__name__}.") - - if not hasattr(model, "packed_modules_mapping"): - raise AttributeError( - f"Model {type(model).__name__} does not support BitsAndBytes " - "quantization yet. No 'packed_modules_mapping' found.") - - self.modules_mapping = ParamMapping( - copy.deepcopy(model.packed_modules_mapping)) - - # For some models like Molmo, we need to use hf_to_vllm_mapper - # to ensure correct loading of weights. - if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None): - self.weight_mapper = lambda name: hf_to_vllm_mapper._map_name(name) - - # Modules whose weights might have fused on disk - # we need their output_sizes to make shard in flight correctly with TP - self.maybe_fused_weights_modules: Dict[str, List[int]] = {} - self._get_bnb_target_modules(model) - for name, module in model.named_modules(): - # Some modules like `ReplicatedLinear` should not have their weights - # sharded. The reason for implementing it this way is to avoid new - # static variable in the model implementation. - if isinstance(module, (ReplicatedLinear, )): - self.unsharded_weights_modules.append(name) - # `QKVParallelLinear` and `MergedColumnParallelLinear` might have - # fused weights on disk. We need to use the output sizes of these - # modules to shard the weights correctly. - elif isinstance(module, - (QKVParallelLinear, MergedColumnParallelLinear)): - self.maybe_fused_weights_modules[name] = module.output_sizes - # In TP, these weights are partitioned along the column - # dimension (dim=-1) - elif isinstance(module, (RowParallelLinear, )): - self.column_sharded_weights_modules.append(name) - - self.model_type = type(model).__name__ - - logger.info("Loading weights with BitsAndBytes quantization. " - "May take a while ...") - - quant_config = getattr(model_config.hf_config, "quantization_config", - None) - - pre_quant = False - if quant_config is not None: - quant_method = quant_config.get("quant_method") - if quant_method == "bitsandbytes": - pre_quant = True - else: - raise ValueError( - f"BitsAndBytes loader does not support {quant_method} " - "quantization") - - # The quant_states in pre_quantized models cannot work with a split - # weight tensor. So TP does not work with pre_quantized bnb models. - if pre_quant and get_tensor_model_parallel_world_size() > 1: - raise ValueError( - "Prequant BitsAndBytes models with tensor parallelism is not " - "supported. Please try with pipeline parallelism.") - - load_8bit = False - if pre_quant: - load_8bit = quant_config.get("load_in_8bit", False) - - qweight_iterator, quant_state_dict = ( - self._get_quantized_weights_iterator(model_config.model, - model_config.revision, - pre_quant, load_8bit)) - - weights_to_load = {name for name, _ in model.named_parameters()} - loaded_weights = model.load_weights(qweight_iterator) - # Some models may have weights loading tracker unimplemented. - if loaded_weights is not None: - weights_not_loaded = weights_to_load - loaded_weights - if weights_not_loaded: - raise ValueError("Following weights were not initialized from " - f"checkpoint: {weights_not_loaded}") - - torch.cuda.empty_cache() - - param_dict = dict(model.named_parameters()) - stacked_quant_state_dict: Dict[str, Dict[int, Any]] = {} - # TODO: Change this lazy import to normal import - # after the checks are updated to run on a new version - from vllm.model_executor.models.utils import is_pp_missing_parameter - - for quant_param_name in quant_state_dict: - if is_pp_missing_parameter(quant_param_name, model): - continue - - non_stacked_param_name = quant_param_name - - shard_index = 0 - for shard_name, ( - weight_name, - index, - ) in self.modules_mapping.inverse_packed_mapping.items(): - # Some models, such as MiniCPM V2.5/2.6, contain both - # module names 'kv_proj' and 'qkv_proj'. To prevent 'kv_proj' - # from being incorrectly identified as being present in - # 'vpm.encoder.layers.0.self_attn.qkv_proj.weight - shard_pos = quant_param_name.find(shard_name) - can_correct_rename = (shard_pos - > 0) and (quant_param_name[shard_pos - 1] - == ".") - # If the quant_param_name is packed, it won't occur in the - # param_dict before renaming. - new_quant_param_name = quant_param_name.replace( - shard_name, weight_name) - need_rename = (quant_param_name not in param_dict) \ - and (new_quant_param_name in param_dict) - if can_correct_rename and need_rename: - shard_index = index - quant_param_name = new_quant_param_name - break - - # Models like Clip/Siglip may skip some layers in initialization, - # causing unused quant_param_name in state_dict. - if quant_param_name not in param_dict: - continue - - if quant_param_name not in stacked_quant_state_dict: - stacked_quant_state_dict[quant_param_name] = {} - - stacked_quant_state_dict[quant_param_name][shard_index] = ( - quant_state_dict[non_stacked_param_name]) - - # save quant_states and offsets as the attributes of the parameters - for param_name, param in param_dict.items(): - if param_name in stacked_quant_state_dict: - quant_states = stacked_quant_state_dict[param_name] - set_weight_attrs(param, {"bnb_quant_state": quant_states}) - - pack_ratio = getattr(param, "pack_factor", -1) - if pack_ratio == -1: - raise ValueError( - f"pack_factor not set for parameter {param_name}.") - - num_elements = [0] * len(quant_states) - for seq, quant_state in quant_states.items(): - num_elements[seq] = (math.prod(quant_state.shape) // - pack_ratio) - - offsets = np.concatenate(([0], np.cumsum(num_elements))) - # Make torch infer_schema happy - offsets = torch.tensor(offsets).cpu() - set_weight_attrs(param, {"bnb_shard_offsets": offsets}) - - if load_8bit: - set_weight_attrs( - param, {"matmul_state": [None] * len(quant_states)}) - - def download_model(self, model_config: ModelConfig) -> None: - self._prepare_weights(model_config.model, model_config.revision) - - def load_model(self, vllm_config: VllmConfig) -> nn.Module: - device_config = vllm_config.device_config - model_config = vllm_config.model_config - with set_default_torch_dtype(model_config.dtype): - with torch.device(device_config.device): - model = _initialize_model(vllm_config=vllm_config) - - self._load_weights(model_config, model) - - return model.eval() - - -class GGUFModelLoader(BaseModelLoader): - """ - Model loader that can load GGUF files. This is useful for loading models - that are quantized with GGUF and saved in the GGUF format. This loader - supports loading both full models and sharded models. - """ - - def __init__(self, load_config: LoadConfig): - super().__init__(load_config) - if load_config.model_loader_extra_config: - raise ValueError(f"Model loader extra config is not supported for " - f"load format {load_config.load_format}") - - def _prepare_weights(self, model_name_or_path: str): - if os.path.isfile(model_name_or_path): - return model_name_or_path - else: - raise ValueError(f"{model_name_or_path} is not a file.") - - def _get_gguf_weights_map(self, model_config: ModelConfig): - """ - GGUF uses this naming convention for their tensors from HF checkpoint: - `blk.N.BB.weight` and `blk.N.BB.bias` - where N signifies the block number of a layer, and BB signifies the - attention/mlp layer components. - See "Standardized tensor names" in - https://github.com/ggerganov/ggml/blob/master/docs/gguf.md for details. - """ - config = model_config.hf_config - model_type = config.model_type - gguf_to_hf_name_map = {} - # hack: ggufs have a different name than transformers - if model_type == "cohere": - model_type = "command-r" - if model_type in ("deepseek_v3", "deepseek_v2"): - model_type = "deepseek2" - # GGUF layer map assumes that we will have a merged expert weights - # so we need to map them manually - for idx in range(config.num_hidden_layers): - gguf_to_hf_name_map[f"blk.{idx}.exp_probs_b.bias"] = \ - f"model.layers.{idx}.mlp.gate.e_score_correction_bias" - gguf_to_hf_name_map[f"blk.{idx}.ffn_down_exps.weight"] = \ - f"model.layers.{idx}.mlp.experts.0.down_proj.weight" - gguf_to_hf_name_map[f"blk.{idx}.ffn_gate_exps.weight"] = \ - f"model.layers.{idx}.mlp.experts.0.gate_proj.weight" - gguf_to_hf_name_map[f"blk.{idx}.ffn_up_exps.weight"] = \ - f"model.layers.{idx}.mlp.experts.0.up_proj.weight" - - arch = None - for key, value in gguf.MODEL_ARCH_NAMES.items(): - if value == model_type: - arch = key - break - if arch is None: - raise RuntimeError(f"Unknown gguf model_type: {model_type}") - num_layers = config.num_hidden_layers - name_map = gguf.get_tensor_name_map(arch, num_layers) - with torch.device("meta"): - dummy_model = AutoModelForCausalLM.from_config( - config, trust_remote_code=model_config.trust_remote_code) - state_dict = dummy_model.state_dict() - - for hf_name in state_dict: - name, suffix = hf_name.rsplit(".", 1) - gguf_name = name_map.get_name(name) - gguf_to_hf_name_map[f"{gguf_name}.{suffix}"] = hf_name - return gguf_to_hf_name_map - - def _get_weights_iterator( - self, model_name_or_path: str, gguf_to_hf_name_map: Dict[str, str] - ) -> Generator[Tuple[str, torch.Tensor], None, None]: - return gguf_quant_weights_iterator(model_name_or_path, - gguf_to_hf_name_map) - - def download_model(self, model_config: ModelConfig) -> None: - self._prepare_weights(model_config.model) - - def load_model(self, vllm_config: VllmConfig) -> nn.Module: - device_config = vllm_config.device_config - model_config = vllm_config.model_config - local_model_path = self._prepare_weights(model_config.model) - gguf_weights_map = self._get_gguf_weights_map(model_config) - # we can only know if tie word embeddings after mapping weights - if "lm_head.weight" in get_gguf_extra_tensor_names( - local_model_path, gguf_weights_map): - model_config.hf_config.update({"tie_word_embeddings": True}) - - target_device = torch.device(device_config.device) - with set_default_torch_dtype(model_config.dtype): - with target_device: - model = _initialize_model(vllm_config=vllm_config) - model.load_weights( - self._get_weights_iterator(local_model_path, gguf_weights_map)) - - _process_weights_after_loading(model, model_config, target_device) - return model - - -class RunaiModelStreamerLoader(BaseModelLoader): - """ - Model loader that can load safetensors - files from local FS or S3 bucket. - """ - - def __init__(self, load_config: LoadConfig): - super().__init__(load_config) - if load_config.model_loader_extra_config: - extra_config = load_config.model_loader_extra_config - - if ("concurrency" in extra_config - and isinstance(extra_config.get("concurrency"), int)): - os.environ["RUNAI_STREAMER_CONCURRENCY"] = str( - extra_config.get("concurrency")) - - if ("memory_limit" in extra_config - and isinstance(extra_config.get("memory_limit"), int)): - os.environ["RUNAI_STREAMER_MEMORY_LIMIT"] = str( - extra_config.get("memory_limit")) - - runai_streamer_s3_endpoint = os.getenv( - 'RUNAI_STREAMER_S3_ENDPOINT') - aws_endpoint_url = os.getenv('AWS_ENDPOINT_URL') - if (runai_streamer_s3_endpoint is None - and aws_endpoint_url is not None): - os.environ["RUNAI_STREAMER_S3_ENDPOINT"] = aws_endpoint_url - - def _prepare_weights(self, model_name_or_path: str, - revision: Optional[str]) -> List[str]: - """Prepare weights for the model. - - If the model is not local, it will be downloaded.""" - - is_s3_path = is_s3(model_name_or_path) - is_local = os.path.isdir(model_name_or_path) - safetensors_pattern = "*.safetensors" - index_file = SAFE_WEIGHTS_INDEX_NAME - - hf_folder = (model_name_or_path if - (is_local or is_s3_path) else download_weights_from_hf( - model_name_or_path, - self.load_config.download_dir, - [safetensors_pattern], - revision, - ignore_patterns=self.load_config.ignore_patterns, - )) - if is_s3_path: - hf_weights_files = s3_glob(path=hf_folder, - allow_pattern=[safetensors_pattern]) - else: - hf_weights_files = glob.glob( - os.path.join(hf_folder, safetensors_pattern)) - - if not is_local and not is_s3_path: - download_safetensors_index_file_from_hf( - model_name_or_path, index_file, self.load_config.download_dir, - revision) - - if not hf_weights_files: - raise RuntimeError( - f"Cannot find any safetensors model weights with " - f"`{model_name_or_path}`") - - return hf_weights_files - - def _get_weights_iterator( - self, model_or_path: str, - revision: str) -> Generator[Tuple[str, torch.Tensor], None, None]: - """Get an iterator for the model weights based on the load format.""" - hf_weights_files = self._prepare_weights(model_or_path, revision) - return runai_safetensors_weights_iterator( - hf_weights_files, - self.load_config.use_tqdm_on_load, - ) - - def download_model(self, model_config: ModelConfig) -> None: - """Download model if necessary""" - self._prepare_weights(model_config.model, model_config.revision) - - def load_model(self, vllm_config: VllmConfig) -> nn.Module: - """Perform streaming of the model to destination""" - device_config = vllm_config.device_config - model_config = vllm_config.model_config - - target_device = torch.device(device_config.device) - with set_default_torch_dtype(model_config.dtype): - with target_device: - model = _initialize_model(vllm_config=vllm_config) - - model_weights = model_config.model - if hasattr(model_config, "model_weights"): - model_weights = model_config.model_weights - model.load_weights( - self._get_weights_iterator(model_weights, - model_config.revision)) - - _process_weights_after_loading(model, model_config, target_device) - return model.eval() - - -def get_model_loader(load_config: LoadConfig) -> BaseModelLoader: - """Get a model loader based on the load format.""" - if isinstance(load_config.load_format, type): - return load_config.load_format(load_config) - - if load_config.load_format == LoadFormat.DUMMY: - return DummyModelLoader(load_config) - - if load_config.load_format == LoadFormat.TENSORIZER: - return TensorizerLoader(load_config) - - if load_config.load_format == LoadFormat.SHARDED_STATE: - return ShardedStateLoader(load_config) - - if load_config.load_format == LoadFormat.BITSANDBYTES: - return BitsAndBytesModelLoader(load_config) - - if load_config.load_format == LoadFormat.GGUF: - return GGUFModelLoader(load_config) - - if load_config.load_format == LoadFormat.RUNAI_STREAMER: - return RunaiModelStreamerLoader(load_config) - - if load_config.load_format == LoadFormat.RUNAI_STREAMER_SHARDED: - return ShardedStateLoader(load_config, runai_model_streamer=True) - - return DefaultModelLoader(load_config) diff --git a/vllm/model_executor/model_loader/neuron.py b/vllm/model_executor/model_loader/neuron.py index a7b313f4e50..e4a48483764 100644 --- a/vllm/model_executor/model_loader/neuron.py +++ b/vllm/model_executor/model_loader/neuron.py @@ -1,5 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 -"""Utilities for selecting and loading neuron models.""" +"""Utilities for selecting and loading Neuron models in transformers-neuronx +framework.""" +import ast import copy import importlib import os @@ -9,7 +11,8 @@ import torch.nn as nn from transformers import PretrainedConfig -from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig +from vllm.config import (ModelConfig, ParallelConfig, SchedulerConfig, + SpeculativeConfig) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import get_quantization_config from vllm.model_executor.layers.sampler import Sampler, SamplerOutput @@ -113,6 +116,67 @@ def load_weights(self, model_name_or_path: str, **kwargs): self.model.to_neuron() +class NeuronSpeculationCausalLM(nn.Module): + """A Neuron-optimized causal language model with speculative decoding.""" + + SPECULATION_TERMINATION_ID = -1 + + def __init__(self, speculation_model) -> None: + super().__init__() + self.model = speculation_model + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + input_block_ids: torch.Tensor, + ) -> torch.Tensor: + tokens, counts = self.model.speculative_iteration( + input_ids, positions, input_block_ids) + + # Mark the end of accepted speculative tokens for each sequence with the + # speculation termination id. + batch_size, steps = tokens.shape + mask = torch.arange(steps).expand(batch_size, -1) >= counts + tokens[mask] = self.SPECULATION_TERMINATION_ID + + return tokens + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[List[SamplerOutput]]: + batch_size, num_steps = logits.shape + seq_ids = [ + seq_id for sg in sampling_metadata.seq_groups + for seq_id in sg.seq_ids + ] + # Organize input tensors by step instead of by sequence. + accepted_token_ids_by_step = logits.transpose(0, 1) + accepted_token_ids_by_step = accepted_token_ids_by_step.tolist() + + sampler_output_list = [] + for step_index in range(num_steps): + if all(token_id == self.SPECULATION_TERMINATION_ID + for token_id in accepted_token_ids_by_step[step_index]): + break + step_output_token_ids = [] + for sequence_index in range(batch_size): + token_id = accepted_token_ids_by_step[step_index][ + sequence_index] + step_output_token_ids.append( + CompletionSequenceGroupOutput(samples=[ + SequenceOutput(parent_seq_id=seq_ids[sequence_index], + output_token=token_id, + logprobs={token_id: Logprob(token_id)}) + ], + prompt_logprobs=None)) + sampler_output_list.append( + SamplerOutput(outputs=step_output_token_ids)) + return sampler_output_list + + def _get_model_architecture(config: PretrainedConfig) -> str: architectures = getattr(config, "architectures", []) for arch in architectures: @@ -138,6 +202,7 @@ def _get_buckets(env: str, default_value: List[int]) -> List[int]: def _get_default_neuron_config(model_config: ModelConfig, parallel_config: ParallelConfig, scheduler_config: SchedulerConfig): + """Generate a neuron config based on vllm config args.""" from transformers_neuronx.config import ContinuousBatchingConfig from transformers_neuronx.constants import LAYOUT_BSH @@ -162,6 +227,27 @@ def _get_default_neuron_config(model_config: ModelConfig, return default_neuron_args +def _get_default_neuron_config_for_speculation( + model_config: ModelConfig, parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig): + """Generate a neuron config for speculative decoding based on + vllm config args.""" + from transformers_neuronx.config import ContinuousBatchingConfig + from transformers_neuronx.constants import LAYOUT_BSH + + continuous_batching_config = ContinuousBatchingConfig( + batch_size_for_shared_caches=scheduler_config.max_num_seqs) + + default_neuron_args = dict(collectives_layout=LAYOUT_BSH, + attention_layout=LAYOUT_BSH, + fuse_qkv=True, + on_device_embedding=True, + continuous_batching=continuous_batching_config, + on_device_generation=copy.deepcopy( + model_config.neuron_sampling_params)) + return default_neuron_args + + def _get_neuron_on_device_generation_config(model_config: ModelConfig): if not _is_neuron_on_device_sampling_disabled(model_config): return copy.deepcopy(model_config.neuron_sampling_params) @@ -213,7 +299,7 @@ def _get_neuron_config_after_override(default_neuron_config, def get_neuron_model(model_config: ModelConfig, parallel_config: ParallelConfig, scheduler_config: SchedulerConfig) -> nn.Module: - + """Initializes a neuron-optimized model for inference.""" # Create a model instance. model = NeuronCausalLM( model_config.hf_config, @@ -230,7 +316,6 @@ def get_neuron_model(model_config: ModelConfig, n_positions = _get_buckets("NEURON_TOKEN_GEN_BUCKETS", [scheduler_config.max_model_len]) - # Load the weights from the cached or downloaded files. model.load_weights(model_config.model, tp_degree=parallel_config.tensor_parallel_size, amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype], @@ -240,3 +325,151 @@ def get_neuron_model(model_config: ModelConfig, batch_size=scheduler_config.max_num_seqs) return model.eval() + + +def get_neuron_speculation_model(model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + speculation_config: SpeculativeConfig): + """Initializes a neuron-optimized speculation model for inference. + + This method is only applicable for speculation with a standalone draft model + """ + from transformers_neuronx.fused_speculation import FusedSpeculativeDecoder + + # For Eagle SD, we need to pass in additional parameters in neuron config. + is_eagle = getattr(speculation_config.draft_model_config.hf_config, + "is_eagle", False) + + # Create target model instance. + target_model = NeuronCausalLM(model_config.hf_config) + + default_neuron_config_args = _get_default_neuron_config_for_speculation( + model_config, parallel_config, scheduler_config) + if is_eagle: + default_neuron_config_args['is_eagle_target'] = True + + neuron_config = _get_neuron_config_after_override( + default_neuron_config_args, model_config.override_neuron_config) + + context_length_estimates = _get_buckets("NEURON_CONTEXT_LENGTH_BUCKETS", + [scheduler_config.max_model_len]) + n_positions = _get_buckets("NEURON_TOKEN_GEN_BUCKETS", + [scheduler_config.max_model_len]) + + target_model.load_weights( + model_config.model, + tp_degree=parallel_config.tensor_parallel_size, + amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype], + neuron_config=neuron_config, + context_length_estimate=context_length_estimates, + n_positions=n_positions, + batch_size=scheduler_config.max_num_seqs) + + target_model.eval() + + # Create draft model instance. + draft_model = NeuronCausalLM( + speculation_config.draft_model_config.hf_config) + + default_draft_neuron_config_args = ( + _get_default_neuron_config_for_speculation( + speculation_config.draft_model_config, parallel_config, + scheduler_config)) + if is_eagle: + default_draft_neuron_config_args['is_eagle_draft'] = True + default_draft_neuron_config_args['has_pre_attention_norm'] = False + + draft_neuron_config = _get_neuron_config_after_override( + default_draft_neuron_config_args, + speculation_config.draft_model_config.override_neuron_config) + + draft_model.load_weights(speculation_config.draft_model_config.model, + tp_degree=speculation_config. + draft_parallel_config.tensor_parallel_size, + amp=TORCH_DTYPE_TO_NEURON_AMP[ + speculation_config.draft_model_config.dtype], + neuron_config=draft_neuron_config, + context_length_estimate=context_length_estimates, + n_positions=n_positions, + batch_size=scheduler_config.max_num_seqs) + + draft_model.eval() + + num_speculative_tokens = speculation_config.num_speculative_tokens + # Create speculation model instance. + speculation_model = FusedSpeculativeDecoder(draft_model.model, + target_model.model, + num_speculative_tokens) + speculation_model.to_neuron() + + return NeuronSpeculationCausalLM(speculation_model) + + +def get_neuron_eagle_speculation_model(model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + speculation_config: SpeculativeConfig): + """Initializes a neuron-optimized EAGLE speculation model for inference.""" + from transformers_neuronx.eagle_speculation import EagleSpeculativeDecoder + + # Create target model instance. + target_model = NeuronCausalLM(model_config.hf_config) + + default_neuron_config_args = _get_default_neuron_config_for_speculation( + model_config, parallel_config, scheduler_config) + default_neuron_config_args['is_eagle_target'] = True + neuron_config = _get_neuron_config_after_override( + default_neuron_config_args, model_config.override_neuron_config) + + context_length_estimates = _get_buckets("NEURON_CONTEXT_LENGTH_BUCKETS", + [scheduler_config.max_model_len]) + n_positions = _get_buckets("NEURON_TOKEN_GEN_BUCKETS", + [scheduler_config.max_model_len]) + + target_model.load_weights( + model_config.model, + tp_degree=parallel_config.tensor_parallel_size, + amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype], + neuron_config=neuron_config, + context_length_estimate=context_length_estimates, + n_positions=n_positions, + batch_size=scheduler_config.max_num_seqs) + + target_model.eval() + + # Create draft model instance. + draft_model = NeuronCausalLM( + speculation_config.draft_model_config.hf_config) + + default_draft_neuron_config_args = ( + _get_default_neuron_config_for_speculation( + speculation_config.draft_model_config, parallel_config, + scheduler_config)) + default_draft_neuron_config_args['is_eagle_draft'] = True + default_draft_neuron_config_args['has_pre_attention_norm'] = False + draft_neuron_config = _get_neuron_config_after_override( + default_draft_neuron_config_args, + speculation_config.draft_model_config.override_neuron_config) + + draft_model.load_weights(speculation_config.draft_model_config.model, + tp_degree=speculation_config. + draft_parallel_config.tensor_parallel_size, + amp=TORCH_DTYPE_TO_NEURON_AMP[ + speculation_config.draft_model_config.dtype], + neuron_config=draft_neuron_config, + context_length_estimate=context_length_estimates, + n_positions=n_positions, + batch_size=scheduler_config.max_num_seqs) + + draft_model.eval() + + token_tree: Dict[int, List[int]] = ast.literal_eval( + speculation_config.speculative_token_tree) + + speculation_model = EagleSpeculativeDecoder(draft_model.model, + target_model.model, + token_tree=token_tree) + speculation_model.to_neuron() + + return NeuronSpeculationCausalLM(speculation_model) diff --git a/vllm/model_executor/model_loader/neuronx_distributed.py b/vllm/model_executor/model_loader/neuronx_distributed.py new file mode 100644 index 00000000000..f879c99ac2e --- /dev/null +++ b/vllm/model_executor/model_loader/neuronx_distributed.py @@ -0,0 +1,584 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Utilities for selecting and loading Neuron models in +neuronx-distributed-inference framework.""" +# Disabling yapf because yapf and isort have conflicts for the below imports +# yapf: disable +import copy +import hashlib +import importlib +import multiprocessing +import os +import shutil +from typing import Dict, List, Optional, Tuple + +import torch +import torch.nn as nn +from neuronx_distributed_inference.models.config import ( + FusedSpecNeuronConfig, OnDeviceSamplingConfig) +from neuronx_distributed_inference.models.mllama.utils import ( + create_vision_mask) +from neuronx_distributed_inference.utils.hf_adapter import ( + load_pretrained_config) +from transformers import AutoModelForCausalLM, AutoTokenizer, PretrainedConfig + +from vllm.config import (ModelConfig, ParallelConfig, SchedulerConfig, + SpeculativeConfig) +from vllm.logger import init_logger +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, + SequenceOutput) + +# yapf: enable +logger = init_logger(__name__) + +TORCH_DTYPE_TO_NEURON_AMP = { + "auto": "float32", + "half": "float16", + "float16": "float16", + "bfloat16": "bfloat16", + "float": "float32", + "float32": "float32", + torch.float16: "float16", + torch.bfloat16: "bfloat16", + torch.float32: "float32", +} + +# Models supported by Neuronx distributed for inference. +_NEURON_SUPPORTED_MODELS: Dict[str, Tuple[str, str]] = { + "LlamaForCausalLM": + ("neuronx_distributed_inference.models.llama.modeling_llama", + "NeuronLlamaForCausalLM"), + "DbrxForCausalLM": + ("neuronx_distributed_inference.models.dbrx.modeling_dbrx", + "NeuronDbrxForCausalLM"), + "MixtralForCausalLM": + ("neuronx_distributed_inference.models.mixtral.modeling_mixtral", + "NeuronMixtralForCausalLM"), + "MllamaForConditionalGeneration": + ("neuronx_distributed_inference.models.mllama.modeling_mllama", + "NeuronMllamaForCausalLM"), +} + + +class NeuronCausalLM(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + ) -> None: + super().__init__() + self.config = config + self.logits_processor = LogitsProcessor(config.vocab_size, + logits_as_input=True) + self.sampler = Sampler() + + # Lazy initialized + self.model: nn.Module + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + input_block_ids: torch.Tensor, + sampling_params: torch.Tensor, + ) -> torch.Tensor: + output = self.model(input_ids, + attention_mask=None, + position_ids=positions, + seq_ids=input_block_ids, + sampling_params=sampling_params) + # on-device sampling + if self.config.neuron_config.on_device_sampling_config: + return output.hidden_states + else: + return output.logits[:, -1, :] + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(None, hidden_states, sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + # on-device sampling + if self.config.neuron_config.on_device_sampling_config: + batch_size = logits.shape + seq_ids = [ + seq_id for sg in sampling_metadata.seq_groups + for seq_id in sg.seq_ids + ] + assert len(seq_ids) == list(batch_size)[0], "batch size mismatch" + # Organize input tensors by step instead of by sequence. + accepted_token_ids_by_step = logits.flatten() + accepted_token_ids_by_step = accepted_token_ids_by_step.tolist() + + step_output_token_ids = [] + for i, seq_id in enumerate(seq_ids): + token_id = accepted_token_ids_by_step[i] + step_output_token_ids.append( + CompletionSequenceGroupOutput(samples=[ + SequenceOutput(parent_seq_id=seq_id, + output_token=token_id, + logprobs={token_id: Logprob(token_id)}) + ], + prompt_logprobs=None)) + return SamplerOutput(outputs=step_output_token_ids) + else: + return self.sampler(logits, sampling_metadata) + + def load_weights(self, model_name_or_path: str, **kwargs): + arch = _get_model_architecture(self.config) + neuronx_module_path, neuronx_model_cls_name = ( + _NEURON_SUPPORTED_MODELS[arch]) + neuronx_module = importlib.import_module(neuronx_module_path) + neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name) + neuron_config = neuronx_model_cls.get_neuron_config_cls()( + **kwargs['neuron_config']) + self.config.neuron_config = neuron_config + config = neuronx_model_cls.get_config_cls()( + neuron_config, + load_config=load_pretrained_config(model_name_or_path)) + hashed_config = hashlib.md5( + config.to_json_string().encode('utf-8')).hexdigest() + if os.getenv("NEURON_COMPILED_ARTIFACTS") is not None: + compiled_model_path = os.getenv("NEURON_COMPILED_ARTIFACTS") + elif os.path.exists(model_name_or_path): + compiled_model_path = os.path.join(model_name_or_path, + "neuron-compiled-artifacts", + hashed_config) + shutil.rmtree(compiled_model_path, ignore_errors=True) + else: + compiled_model_path = os.path.join("local-models", + model_name_or_path, + "neuron-compiled-artifacts", + hashed_config) + shutil.rmtree(compiled_model_path, ignore_errors=True) + try: + self.model = neuronx_model_cls(compiled_model_path) + override_neuron_config = kwargs["override_neuron_config"] + for k, v in override_neuron_config.items(): + setattr(self.model.config.neuron_config, k, v) + self.model.load(compiled_model_path) + return + except (FileNotFoundError, ValueError) as e: + logger.warning("Exception: %s", e) + logger.warning("Failed to load the model from %s, Recompiling...", + compiled_model_path) + if not os.path.exists(model_name_or_path): + hf_model = AutoModelForCausalLM.from_pretrained(model_name_or_path) + saved_path = os.path.join("local-models", model_name_or_path) + hf_model.save_pretrained(saved_path) + model_name_or_path = saved_path + self.model = neuronx_model_cls(model_name_or_path, config) + self.model.compile(compiled_model_path) + self.model.load(compiled_model_path) + + +class NeuronMllamaForCausalLM(nn.Module): + + def __init__(self, + config: PretrainedConfig, + on_device_sampling_disabled: bool = False) -> None: + super().__init__() + self.config = config + self.logits_processor = LogitsProcessor( + config.get_text_config().vocab_size, logits_as_input=True) + + self.on_device_sampling_disabled = on_device_sampling_disabled + if self.on_device_sampling_disabled: + # Use default sampler + self.sampler = Sampler() + + # Lazy initialized + self.model: nn.Module + + def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, + seq_ids: torch.Tensor, pixel_values: torch.Tensor, + aspect_ratios: torch.Tensor, num_chunks: torch.Tensor, + has_image: torch.Tensor, sampling_params) -> torch.Tensor: + self.vision_mask = create_vision_mask(input_ids, self.vision_token_id) + output = self.model( + input_ids.to(torch.int32), + attention_mask=None, + position_ids=positions.to(torch.int32), + seq_ids=seq_ids.flatten().to(torch.int32), + pixel_values=pixel_values.to( + self.config.vision_config.torch_dtype), + aspect_ratios=aspect_ratios.to(torch.int32), + vision_mask=self.vision_mask.to(torch.int32), + sampling_params=sampling_params, + num_chunks=num_chunks.to(torch.int32), + has_image=has_image.to(torch.int32), + ) + if self.config.neuron_config.on_device_sampling_config: + return output.hidden_states + return output.logits[:, -1, :] + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(None, hidden_states, sampling_metadata) + return logits + + def sample(self, hidden_states, sampling_metadata): + if not self.on_device_sampling_disabled: + with torch.profiler.record_function("sample"): + hidden_states = hidden_states.flatten() + res = [] + sample_idx = 0 + for seq_group in sampling_metadata.seq_groups: + seq_ids = seq_group.seq_ids + samples = [] + for seq_id in seq_ids: + token_id = hidden_states[sample_idx].item() + samples.append( + SequenceOutput( + parent_seq_id=seq_id, + output_token=token_id, + logprobs={token_id: Logprob(token_id)})) + sample_idx += 1 + res.append( + CompletionSequenceGroupOutput(samples=samples, + prompt_logprobs=None)) + next_tokens = SamplerOutput(outputs=res) + else: + next_tokens = self.sampler(None, hidden_states, sampling_metadata) + return next_tokens + + def load_weights(self, model_name_or_path: str, **kwargs): + arch = _get_model_architecture(self.config) + neuronx_module_path, neuronx_model_cls_name = ( + _NEURON_SUPPORTED_MODELS[arch]) + neuronx_module = importlib.import_module(neuronx_module_path) + neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name) + neuron_config = neuronx_model_cls.get_neuron_config_cls()( + **kwargs['neuron_config']) + self.config.neuron_config = neuron_config + logger.info("neuron_config buckets: %s", + self.config.neuron_config.buckets) + config = neuronx_model_cls.get_config_cls()( + neuron_config, + load_config=load_pretrained_config(model_name_or_path)) + hashed_config = hashlib.md5( + config.to_json_string().encode('utf-8')).hexdigest() + if os.getenv("NEURON_COMPILED_ARTIFACTS") is not None: + compiled_model_path = os.getenv("NEURON_COMPILED_ARTIFACTS") + elif os.path.exists(model_name_or_path): + compiled_model_path = os.path.join(model_name_or_path, + "neuron-compiled-artifacts", + hashed_config) + else: + compiled_model_path = os.path.join("local-models", + model_name_or_path, + "neuron-compiled-artifacts", + hashed_config) + try: + self.model = neuronx_model_cls(compiled_model_path) + tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) + self.vision_token_id = tokenizer( + "<|image|>", add_special_tokens=False).input_ids + self.model.load(compiled_model_path) + return + except (FileNotFoundError, ValueError): + logger.warning("Failed to load the model from %s, Recompiling...", + compiled_model_path) + if not os.path.exists(model_name_or_path): + hf_model = AutoModelForCausalLM.from_pretrained(model_name_or_path) + saved_path = os.path.join("local-models", model_name_or_path) + hf_model.save_pretrained(saved_path) + model_name_or_path = saved_path + self.model = neuronx_model_cls(model_name_or_path, config) + + logger.info("\nCompiling and saving model to %s", model_name_or_path) + + p = multiprocessing.Process(target=compile_model, + args=(self, compiled_model_path)) + p.start() + p.join() + + tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) + tokenizer.save_pretrained(compiled_model_path) + logger.info("Successfully compiled and saved the model in %s", + compiled_model_path) + + # Read "<|image|>" token_id from the tokenizer + self.vision_token_id = tokenizer("<|image|>", + add_special_tokens=False).input_ids + logger.info("\nLoading model from compiled checkpoint...") + self.model.load(compiled_model_path) + + +def compile_model(neuron_model, traced_model_path): + neuron_model.model.compile(traced_model_path) + + +class NeuronSpeculationCausalLM(nn.Module): + """A Neuron-optimized causal language model with speculative decoding.""" + + def __init__( + self, + config: PretrainedConfig, + ) -> None: + super().__init__() + self.config = config + self.logits_processor = LogitsProcessor(config.vocab_size, + logits_as_input=True) + # Lazy initialized + self.model: nn.Module + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + input_block_ids: torch.Tensor, + sampling_params: torch.Tensor, + ) -> torch.Tensor: + output = self.model(input_ids, + attention_mask=None, + position_ids=positions, + seq_ids=input_block_ids, + sampling_params=sampling_params) + # CTX encoding + if (positions[:, 0]).sum().item() == 0: + return output.fused_outputs[0][:, 0:1] + + # Fused Spec (Generation) + accepted_tokens_with_padding = output.fused_outputs[0] + next_pos_ids = output.fused_outputs[-1] + generated_token_counts = next_pos_ids - positions + + assert torch.any(generated_token_counts == 0).item() is False, \ + "NxDI model generated no output for one or more sequences." + + batch_size, steps = accepted_tokens_with_padding.shape + mask = torch.arange(steps).expand(batch_size, + -1) >= generated_token_counts + accepted_tokens_with_padding[mask] = -1 + + return accepted_tokens_with_padding + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[List[SamplerOutput]]: + batch_size, num_steps = logits.shape + seq_ids = [ + seq_id for sg in sampling_metadata.seq_groups + for seq_id in sg.seq_ids + ] + # Organize input tensors by step instead of by sequence. + accepted_token_ids_by_step = logits.transpose(0, 1) + accepted_token_ids_by_step = accepted_token_ids_by_step.tolist() + + sampler_output_list = [] + for step_index in range(num_steps): + if all(token_id == -1 + for token_id in accepted_token_ids_by_step[step_index]): + break + step_output_token_ids = [] + for sequence_index in range(batch_size): + token_id = accepted_token_ids_by_step[step_index][ + sequence_index] + step_output_token_ids.append( + CompletionSequenceGroupOutput(samples=[ + SequenceOutput(parent_seq_id=seq_ids[sequence_index], + output_token=token_id, + logprobs={token_id: Logprob(token_id)}) + ], + prompt_logprobs=None)) + sampler_output_list.append( + SamplerOutput(outputs=step_output_token_ids)) + return sampler_output_list + + def load_weights(self, model_name_or_path: str, + draft_model_name_or_path: str, **kwargs): + arch = _get_model_architecture(self.config) + neuronx_module_path, neuronx_model_cls_name = ( + _NEURON_SUPPORTED_MODELS[arch]) + neuronx_module = importlib.import_module(neuronx_module_path) + neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name) + neuron_config = neuronx_model_cls.get_neuron_config_cls()( + **kwargs['neuron_config']) + config = neuronx_model_cls.get_config_cls()( + neuron_config, + load_config=load_pretrained_config(model_name_or_path)) + + draft_neuron_config = copy.deepcopy(config.neuron_config) + if not config.neuron_config.enable_eagle_speculation: + draft_neuron_config.speculation_length = 0 + draft_neuron_config.trace_tokengen_model = True + draft_neuron_config.enable_fused_speculation = False + if config.neuron_config.enable_eagle_speculation: + draft_neuron_config.is_eagle_draft = True + draft_neuron_config.sequence_parallel_enabled = False + draft_config = neuronx_model_cls.get_config_cls()( + draft_neuron_config, + load_config=load_pretrained_config(draft_model_name_or_path)) + fused_spec_config = (FusedSpecNeuronConfig( + neuronx_model_cls._model_cls, + draft_config=draft_config, + draft_model_path=draft_model_name_or_path)) + config.fused_spec_config = fused_spec_config + self.config.neuron_config = neuron_config + + hashed_config = hashlib.md5( + config.to_json_string().encode('utf-8')).hexdigest() + if os.getenv("NEURON_COMPILED_ARTIFACTS") is not None: + compiled_model_path = os.getenv("NEURON_COMPILED_ARTIFACTS") + elif os.path.exists(model_name_or_path): + compiled_model_path = os.path.join(model_name_or_path, + "neuron-compiled-artifacts", + hashed_config) + shutil.rmtree(compiled_model_path, ignore_errors=True) + else: + compiled_model_path = os.path.join("local-models", + model_name_or_path, + "neuron-compiled-artifacts", + hashed_config) + shutil.rmtree(compiled_model_path, ignore_errors=True) + try: + self.model = neuronx_model_cls(compiled_model_path) + override_neuron_config = kwargs["override_neuron_config"] + for k, v in override_neuron_config.items(): + setattr(self.model.config.neuron_config, k, v) + self.model.load(compiled_model_path) + return + except (FileNotFoundError, ValueError) as e: + logger.warning("Exception: %s", e) + logger.warning("Failed to load the model from %s Recompiling...", + compiled_model_path) + if not os.path.exists(model_name_or_path): + hf_model = AutoModelForCausalLM.from_pretrained(model_name_or_path) + saved_path = os.path.join("local-models", model_name_or_path) + hf_model.save_pretrained(saved_path) + model_name_or_path = saved_path + if not os.path.exists(draft_model_name_or_path): + if draft_model_name_or_path != model_name_or_path: + hf_model = AutoModelForCausalLM.from_pretrained( + draft_model_name_or_path) + saved_path = os.path.join("local-models", + draft_model_name_or_path) + hf_model.save_pretrained(saved_path) + draft_model_name_or_path = saved_path + else: + draft_model_name_or_path = model_name_or_path + config.fused_spec_config.draft_model_path = draft_model_name_or_path + self.model = neuronx_model_cls(model_name_or_path, config) + self.model.compile(compiled_model_path) + self.model.load(compiled_model_path) + + +def _get_model_architecture(config: PretrainedConfig) -> str: + architectures = getattr(config, "architectures", []) + for arch in architectures: + if arch in _NEURON_SUPPORTED_MODELS: + return arch + raise ValueError( + f"Model architectures {architectures} are not supported on Neuron " + f"for now. Supported architectures: " + f"{list(_NEURON_SUPPORTED_MODELS.keys())}") + + +def _get_default_neuron_config(model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig): + """Generate a neuron config based on vllm config args.""" + on_device_sampling_config = OnDeviceSamplingConfig(dynamic=True, + deterministic=False) + batch_size = scheduler_config.max_num_seqs + + neuron_config = dict( + tp_degree=parallel_config.tensor_parallel_size, + ctx_batch_size=1, + batch_size=batch_size, + max_context_length=scheduler_config.max_model_len, + seq_len=scheduler_config.max_model_len, + enable_bucketing=True, + is_continuous_batching=(batch_size > 1), + quantized=False, + torch_dtype=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype], + padding_side="right", + on_device_sampling_config=on_device_sampling_config, + sequence_parallel_enabled=True, + ) + return neuron_config + + +def _get_default_speculation_config(model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + speculation_config: SpeculativeConfig): + """Generate a neuron config for speculative decoding based on vllm config + args.""" + neuron_config = dict( + tp_degree=parallel_config.tensor_parallel_size, + batch_size=scheduler_config.max_num_seqs, + max_context_length=scheduler_config.max_model_len, + seq_len=scheduler_config.max_model_len, + speculation_length=speculation_config.num_speculative_tokens, + trace_tokengen_model=False, + enable_fused_speculation=True, + enable_bucketing=True, + quantized=False, + torch_dtype=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype], + on_device_sampling_config=dict( + top_k=1, + do_sample=False, + )) + return neuron_config + + +def _get_neuron_config_after_override(default_neuron_config, + overridden_neuron_config): + """Update default neuron config values with override args""" + overridden_neuron_config = overridden_neuron_config or {} + default_neuron_config.update(overridden_neuron_config) + return default_neuron_config + + +def get_neuron_model(model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig) -> nn.Module: + """Initializes a neuron-optimized model for inference.""" + model_arch = _get_model_architecture(model_config.hf_config) + if model_arch == "MllamaForConditionalGeneration": + model = NeuronMllamaForCausalLM(model_config.hf_config) + else: + model = NeuronCausalLM(model_config.hf_config) + default_neuron_config_args = _get_default_neuron_config( + model_config, parallel_config, scheduler_config) + neuron_config = _get_neuron_config_after_override( + default_neuron_config_args, model_config.override_neuron_config) + + override_neuron_config = model_config.override_neuron_config + model.load_weights(model_config.model, + neuron_config=neuron_config, + override_neuron_config=override_neuron_config) + return model.eval() + + +def get_neuron_speculation_model(model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + speculation_config: SpeculativeConfig): + """Initializes a neuron-optimized speculation model for inference. + + This model handles speculation using both a draft model and an EAGLE draft. + """ + model = NeuronSpeculationCausalLM(model_config.hf_config) + default_neuron_config_args = _get_default_speculation_config( + model_config, parallel_config, scheduler_config, speculation_config) + neuron_config = _get_neuron_config_after_override( + default_neuron_config_args, model_config.override_neuron_config) + + override_neuron_config = model_config.override_neuron_config + model.load_weights(model_config.model, + speculation_config.draft_model_config.model, + neuron_config=neuron_config, + override_neuron_config=override_neuron_config) + return model.eval() diff --git a/vllm/model_executor/model_loader/runai_streamer_loader.py b/vllm/model_executor/model_loader/runai_streamer_loader.py new file mode 100644 index 00000000000..1fbb5ca5664 --- /dev/null +++ b/vllm/model_executor/model_loader/runai_streamer_loader.py @@ -0,0 +1,120 @@ +# SPDX-License-Identifier: Apache-2.0 +# ruff: noqa: SIM117 +import glob +import os +from typing import Generator, List, Optional, Tuple + +import torch +from torch import nn +from transformers.utils import SAFE_WEIGHTS_INDEX_NAME + +from vllm.config import LoadConfig, ModelConfig, VllmConfig +from vllm.model_executor.model_loader.base_loader import BaseModelLoader +from vllm.model_executor.model_loader.utils import ( + initialize_model, process_weights_after_loading, set_default_torch_dtype) +from vllm.model_executor.model_loader.weight_utils import ( + download_safetensors_index_file_from_hf, download_weights_from_hf, + runai_safetensors_weights_iterator) +from vllm.transformers_utils.s3_utils import glob as s3_glob +from vllm.transformers_utils.utils import is_s3 + + +class RunaiModelStreamerLoader(BaseModelLoader): + """ + Model loader that can load safetensors + files from local FS or S3 bucket. + """ + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + if load_config.model_loader_extra_config: + extra_config = load_config.model_loader_extra_config + + if ("concurrency" in extra_config + and isinstance(extra_config.get("concurrency"), int)): + os.environ["RUNAI_STREAMER_CONCURRENCY"] = str( + extra_config.get("concurrency")) + + if ("memory_limit" in extra_config + and isinstance(extra_config.get("memory_limit"), int)): + os.environ["RUNAI_STREAMER_MEMORY_LIMIT"] = str( + extra_config.get("memory_limit")) + + runai_streamer_s3_endpoint = os.getenv( + 'RUNAI_STREAMER_S3_ENDPOINT') + aws_endpoint_url = os.getenv('AWS_ENDPOINT_URL') + if (runai_streamer_s3_endpoint is None + and aws_endpoint_url is not None): + os.environ["RUNAI_STREAMER_S3_ENDPOINT"] = aws_endpoint_url + + def _prepare_weights(self, model_name_or_path: str, + revision: Optional[str]) -> List[str]: + """Prepare weights for the model. + + If the model is not local, it will be downloaded.""" + + is_s3_path = is_s3(model_name_or_path) + is_local = os.path.isdir(model_name_or_path) + safetensors_pattern = "*.safetensors" + index_file = SAFE_WEIGHTS_INDEX_NAME + + hf_folder = (model_name_or_path if + (is_local or is_s3_path) else download_weights_from_hf( + model_name_or_path, + self.load_config.download_dir, + [safetensors_pattern], + revision, + ignore_patterns=self.load_config.ignore_patterns, + )) + if is_s3_path: + hf_weights_files = s3_glob(path=hf_folder, + allow_pattern=[safetensors_pattern]) + else: + hf_weights_files = glob.glob( + os.path.join(hf_folder, safetensors_pattern)) + + if not is_local and not is_s3_path: + download_safetensors_index_file_from_hf( + model_name_or_path, index_file, self.load_config.download_dir, + revision) + + if not hf_weights_files: + raise RuntimeError( + f"Cannot find any safetensors model weights with " + f"`{model_name_or_path}`") + + return hf_weights_files + + def _get_weights_iterator( + self, model_or_path: str, + revision: str) -> Generator[Tuple[str, torch.Tensor], None, None]: + """Get an iterator for the model weights based on the load format.""" + hf_weights_files = self._prepare_weights(model_or_path, revision) + return runai_safetensors_weights_iterator( + hf_weights_files, + self.load_config.use_tqdm_on_load, + ) + + def download_model(self, model_config: ModelConfig) -> None: + """Download model if necessary""" + self._prepare_weights(model_config.model, model_config.revision) + + def load_model(self, vllm_config: VllmConfig) -> nn.Module: + """Perform streaming of the model to destination""" + device_config = vllm_config.device_config + model_config = vllm_config.model_config + + target_device = torch.device(device_config.device) + with set_default_torch_dtype(model_config.dtype): + with target_device: + model = initialize_model(vllm_config=vllm_config) + + model_weights = model_config.model + if hasattr(model_config, "model_weights"): + model_weights = model_config.model_weights + model.load_weights( + self._get_weights_iterator(model_weights, + model_config.revision)) + + process_weights_after_loading(model, model_config, target_device) + return model.eval() diff --git a/vllm/model_executor/model_loader/sharded_state_loader.py b/vllm/model_executor/model_loader/sharded_state_loader.py new file mode 100644 index 00000000000..152a3d69972 --- /dev/null +++ b/vllm/model_executor/model_loader/sharded_state_loader.py @@ -0,0 +1,210 @@ +# SPDX-License-Identifier: Apache-2.0 + +import collections +import glob +import os +from typing import Any, Dict, Generator, List, Optional, Tuple + +import torch +from torch import nn + +from vllm.config import LoadConfig, ModelConfig, VllmConfig +from vllm.logger import init_logger +from vllm.model_executor.model_loader.base_loader import BaseModelLoader +from vllm.model_executor.model_loader.utils import ( + initialize_model, process_weights_after_loading, set_default_torch_dtype) +from vllm.model_executor.model_loader.weight_utils import ( + download_weights_from_hf, runai_safetensors_weights_iterator) +from vllm.transformers_utils.s3_utils import glob as s3_glob +from vllm.transformers_utils.utils import is_s3 + +logger = init_logger(__name__) + + +class ShardedStateLoader(BaseModelLoader): + """ + Model loader that directly loads each worker's model state dict, which + enables a fast load path for large tensor-parallel models where each worker + only needs to read its own shard rather than the entire checkpoint. See + `examples/offline_inference/save_sharded_state.py` for creating a sharded + checkpoint. + """ + + DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors" + + def __init__(self, + load_config: LoadConfig, + runai_model_streamer: bool = False): + super().__init__(load_config) + + self.runai_model_streamer = runai_model_streamer + extra_config = ({} if load_config.model_loader_extra_config is None + else load_config.model_loader_extra_config.copy()) + self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN) + if extra_config: + raise ValueError(f"Unexpected extra config keys for load format " + f"{load_config.load_format}: " + f"{load_config.model_loader_extra_config.keys()}") + + @staticmethod + def _filter_subtensors( + tensors: Dict[str, torch.Tensor], ) -> Dict[str, torch.Tensor]: + """ + Filter out all tensors that share the same memory or a subset of the + memory of another tensor. + """ + same_storage_groups: Dict[Any, List[Tuple[str, torch.Tensor]]] = ( + collections.defaultdict(list)) + for key, tensor in tensors.items(): + if tensor.numel(): + ptr = tensor.untyped_storage().data_ptr() + same_storage_groups[tensor.device, ptr].append((key, tensor)) + + def get_end_ptr(tensor: torch.Tensor) -> int: + return tensor.view(-1)[-1].data_ptr() + tensor.element_size() + + result: Dict[str, torch.Tensor] = {} + for group in same_storage_groups.values(): + for k, t in group: + a, b = t.data_ptr(), get_end_ptr(t) + for k2, t2 in group: + if not t2.is_contiguous(): + continue + a2, b2 = t2.data_ptr(), get_end_ptr(t2) + if a < a2 or b2 < b: + continue + if a2 < a or b < b2 or not t.is_contiguous(): + break # t2 covers strictly more memory than t. + if k2 < k: + # Same tensors, keep the one with the smaller key. + break + else: + result[k] = t + return result + + def _prepare_weights(self, model_name_or_path: str, + revision: Optional[str]): + if is_s3(model_name_or_path) or os.path.isdir(model_name_or_path): + return model_name_or_path + else: + allow_patterns = ["*.safetensors"] + return download_weights_from_hf( + model_name_or_path, + self.load_config.download_dir, + allow_patterns, + revision, + ignore_patterns=self.load_config.ignore_patterns, + ) + + def download_model(self, model_config: ModelConfig) -> None: + self._prepare_weights(model_config.model, model_config.revision) + + def load_model(self, vllm_config: VllmConfig) -> nn.Module: + device_config = vllm_config.device_config + model_config = vllm_config.model_config + target_device = torch.device(device_config.device) + + from vllm.distributed import get_tensor_model_parallel_rank + + model_weights = model_config.model + if hasattr(model_config, "model_weights"): + model_weights = model_config.model_weights + local_model_path = model_weights + + with set_default_torch_dtype(model_config.dtype): + with target_device: + model = initialize_model(vllm_config=vllm_config) + process_weights_after_loading(model, model_config, + target_device) + rank = get_tensor_model_parallel_rank() + pattern = os.path.join( + local_model_path, + self.pattern.format(rank=rank, part="*"), + ) + + filepaths = [] + if is_s3(local_model_path): + file_pattern = f"*{self.pattern.format(rank=rank, part=' * ')}" + filepaths = s3_glob(path=local_model_path, + allow_pattern=[file_pattern]) + else: + filepaths = glob.glob(pattern) + if not filepaths: + # TODO: support un-sharded checkpoints too + raise ValueError( + f"Could not find checkpoint files '{pattern}', only " + f"pre-sharded checkpoints are currently supported!") + state_dict = self._filter_subtensors(model.state_dict()) + for key, tensor in self.iterate_over_files(filepaths): + # If loading with LoRA enabled, additional padding may + # be added to certain parameters. We only load into a + # narrowed view of the parameter data. + param_data = state_dict[key].data + param_shape = state_dict[key].shape + for dim, size in enumerate(tensor.shape): + if size < param_shape[dim]: + param_data = param_data.narrow(dim, 0, size) + if tensor.shape != param_shape: + logger.warning( + "loading tensor of shape %s into " + "parameter '%s' of shape %s", + tensor.shape, + key, + param_shape, + ) + param_data.copy_(tensor) + state_dict.pop(key) + if state_dict: + raise ValueError( + f"Missing keys {tuple(state_dict)} in loaded state!") + return model.eval() + + def iterate_over_files( + self, paths) -> Generator[Tuple[str, torch.Tensor], None, None]: + if self.runai_model_streamer: + yield from runai_safetensors_weights_iterator(paths, True) + else: + from safetensors.torch import safe_open + for path in paths: + with safe_open(path, framework="pt") as f: + for key in f.keys(): # noqa: SIM118 + tensor = f.get_tensor(key) + yield key, tensor + + @staticmethod + def save_model( + model: torch.nn.Module, + path: str, + pattern: Optional[str] = None, + max_size: Optional[int] = None, + ) -> None: + from safetensors.torch import save_file + + from vllm.distributed import get_tensor_model_parallel_rank + + if pattern is None: + pattern = ShardedStateLoader.DEFAULT_PATTERN + rank = get_tensor_model_parallel_rank() + part_idx = 0 + total_size = 0 + state_dict = ShardedStateLoader._filter_subtensors(model.state_dict()) + state_dict_part: Dict[str, torch.Tensor] = {} + for key, tensor in state_dict.items(): + param_size = tensor.nelement() * tensor.element_size() + if max_size is not None and total_size + param_size > max_size: + filename = pattern.format(rank=rank, part=part_idx) + save_file( + state_dict_part, + os.path.join(path, filename), + ) + part_idx += 1 + total_size = 0 + state_dict_part = {} + state_dict_part[key] = tensor + total_size += param_size + if len(state_dict_part) > 0: + filename = pattern.format(rank=rank, part=part_idx) + save_file( + state_dict_part, + os.path.join(path, filename), + ) diff --git a/vllm/model_executor/model_loader/tensorizer_loader.py b/vllm/model_executor/model_loader/tensorizer_loader.py new file mode 100644 index 00000000000..7cf3940ab64 --- /dev/null +++ b/vllm/model_executor/model_loader/tensorizer_loader.py @@ -0,0 +1,119 @@ +# SPDX-License-Identifier: Apache-2.0 +# ruff: noqa: SIM117 +import copy +from typing import Generator, Tuple + +import torch +from torch import nn + +from vllm.config import LoadConfig, ModelConfig, ParallelConfig, VllmConfig +from vllm.logger import init_logger +from vllm.model_executor.model_loader.base_loader import BaseModelLoader +from vllm.model_executor.model_loader.tensorizer import ( + TensorizerConfig, is_vllm_tensorized, load_with_tensorizer, + serialize_vllm_model, tensorizer_weights_iterator) +from vllm.model_executor.model_loader.utils import (get_model_architecture, + initialize_model, + set_default_torch_dtype) + +logger = init_logger(__name__) + + +class TensorizerLoader(BaseModelLoader): + """Model loader using CoreWeave's tensorizer library.""" + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + if isinstance(load_config.model_loader_extra_config, TensorizerConfig): + self.tensorizer_config = load_config.model_loader_extra_config + else: + self.tensorizer_config = TensorizerConfig( + **load_config.model_loader_extra_config) + + def _verify_config(self, model_config: ModelConfig, + parallel_config: ParallelConfig): + self.tensorizer_config.verify_with_model_config(model_config) + self.tensorizer_config.verify_with_parallel_config(parallel_config) + + def _get_weights_iterator( + self, ) -> Generator[Tuple[str, torch.Tensor], None, None]: + tensorizer_args = self.tensorizer_config._construct_tensorizer_args() + return tensorizer_weights_iterator(tensorizer_args) + + def _load_model_serialized_cpu( + self, + vllm_config: VllmConfig, + ) -> nn.Module: + """Load a serialized model with tensorizer to the CPU. + + This is only necessary when the model isn't vLLM-tensorized (see + examples/other/tensorize_vllm_model.py) This should still + be faster than default HuggingFace loading, but will be slower than + loading a vLLM-tensorized model. + """ + device_config = vllm_config.device_config + model_config = vllm_config.model_config + with set_default_torch_dtype(model_config.dtype): + with torch.device(device_config.device): + model = initialize_model(vllm_config=vllm_config) + + model.load_weights(self._get_weights_iterator()) + return model.eval() + + def _load_model_serialized( + self, + vllm_config: VllmConfig, + ) -> nn.Module: + """Load a serialized model with tensorizer. + + Expects a vLLM-tensorized model. See the + examples/other/tensorize_vllm_model.py example script + for serializing vLLM models.""" + + device_config = vllm_config.device_config + model_config = vllm_config.model_config + + with set_default_torch_dtype(model_config.dtype): + with torch.device(device_config.device): + model_class = get_model_architecture(model_config)[0] + + tensorizer_config = copy.copy(self.tensorizer_config) + tensorizer_config.model_class = model_class + tensorizer_config.hf_config = model_config.hf_config + tensorizer_config.dtype = model_config.dtype + + model = load_with_tensorizer(tensorizer_config, + vllm_config=vllm_config) + return model.eval() + + def download_model(self, model_config: ModelConfig) -> None: + self.tensorizer_config.verify_with_model_config(model_config) + + with self.tensorizer_config.open_stream(): + pass + + def load_model(self, vllm_config: VllmConfig) -> nn.Module: + model_config = vllm_config.model_config + parallel_config = vllm_config.parallel_config + self._verify_config(model_config, parallel_config) + + if parallel_config.tensor_parallel_size > 1: + from vllm.distributed import get_tensor_model_parallel_rank + + self.tensorizer_config.tensorizer_uri = ( + self.tensorizer_config.tensorizer_uri % + get_tensor_model_parallel_rank()) + + if is_vllm_tensorized(self.tensorizer_config): + return self._load_model_serialized(vllm_config=vllm_config) + return self._load_model_serialized_cpu(vllm_config=vllm_config) + + @staticmethod + def save_model( + model: torch.nn.Module, + tensorizer_config: TensorizerConfig, + ) -> None: + serialize_vllm_model( + model=model, + tensorizer_config=tensorizer_config, + ) diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index d4a89242347..2b636954a42 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -1,6 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 """Utilities for selecting and loading models.""" import contextlib +import inspect +import warnings +from contextlib import contextmanager from dataclasses import dataclass, field from typing import Dict, List, Optional, Tuple, Type @@ -9,14 +12,18 @@ from torch import nn from transformers.dynamic_module_utils import get_class_from_dynamic_module -from vllm.config import ModelConfig, ModelImpl +from vllm.attention import Attention +from vllm.config import (ModelConfig, ModelImpl, VllmConfig, + set_current_vllm_config) from vllm.logger import init_logger +from vllm.model_executor.layers.linear import QKVCrossParallelLinear from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) + QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.models import ModelRegistry from vllm.model_executor.models.adapters import (as_classification_model, as_embedding_model, as_reward_model) +from vllm.utils import is_pin_memory_available logger = init_logger(__name__) @@ -30,6 +37,128 @@ def set_default_torch_dtype(dtype: torch.dtype): torch.set_default_dtype(old_dtype) +def initialize_model( + vllm_config: VllmConfig, + *, + prefix: str = "", + model_class: Optional[type[nn.Module]] = None, +) -> nn.Module: + """Initialize a model with the given configurations.""" + model_config = vllm_config.model_config + if model_class is None: + model_class, _ = get_model_architecture(model_config) + + if vllm_config.quant_config is not None: + configure_quant_config(vllm_config.quant_config, model_class) + + signatures = inspect.signature(model_class.__init__) + all_params = [param.name for param in signatures.parameters.values()] + if "vllm_config" in all_params and "prefix" in all_params: + # new-style model class + with set_current_vllm_config(vllm_config, check_compile=True): + return model_class(vllm_config=vllm_config, prefix=prefix) + + msg = ("vLLM model class should accept `vllm_config` and `prefix` as " + "input arguments. Possibly you have an old-style model class" + " registered from out of tree and it is used for new vLLM version. " + "Check https://docs.vllm.ai/en/latest/design/arch_overview.html " + "for the design and update the model class accordingly.") + warnings.warn(msg, DeprecationWarning, stacklevel=2) + + logger.warning( + "Trying to guess the arguments for old-style model class %s", + model_class, + ) + # try to be compatible with old-style model class + kwargs = {} + if "prefix" in all_params: + kwargs["prefix"] = prefix + if "config" in all_params: + kwargs["config"] = model_config.hf_config + if "cache_config" in all_params: + kwargs["cache_config"] = vllm_config.cache_config + if "quant_config" in all_params: + kwargs["quant_config"] = vllm_config.quant_config + if "lora_config" in all_params: + kwargs["lora_config"] = vllm_config.lora_config + if "scheduler_config" in all_params: + kwargs["scheduler_config"] = vllm_config.scheduler_config + with set_current_vllm_config(vllm_config, check_compile=True): + return model_class(**kwargs) + + +def process_weights_after_loading(model: nn.Module, model_config: ModelConfig, + target_device: torch.device) -> None: + for _, module in model.named_modules(): + if isinstance(module, QKVCrossParallelLinear): + # NOTE(Isotr0py): special case for cross QKV layer because + # q and kv proj aren't registered as submodules intentionally + module.process_weights_after_loading() + continue + quant_method = getattr(module, "quant_method", None) + if isinstance(quant_method, QuantizeMethodBase): + # When quant methods need to process weights after loading + # (for repacking, quantizing, etc), they expect parameters + # to be on the global target device. This scope is for the + # case where cpu offloading is used, where we will move the + # parameters onto device for processing and back off after. + with device_loading_context(module, target_device): + quant_method.process_weights_after_loading(module) + + # Currently only used by MLA. + # NOTE: This intentionally happens after other modules so we can easily + # decompress the weights for MLA. + for _, module in model.named_modules(): + if isinstance(module, Attention) and \ + hasattr(module, "process_weights_after_loading"): + # TODO(lucas): see if there is a way to unify the signatures + # of process_weights_after_loading + module.process_weights_after_loading(model_config.dtype) + + +@contextmanager +def device_loading_context(module: torch.nn.Module, + target_device: torch.device): + if target_device.type == "cpu": + # If target is CPU, no need to move anything + yield module + return + + original_device_states: Dict[str, torch.device] = {} + + # Store original device states and move parameters to GPU if they're on CPU + for name, p in module.named_parameters(): + if p.device.type == "cpu": + original_device_states[name] = p.device + p.data = p.data.to(target_device) + # Parameters already on target device are not touched + + try: + yield module + + finally: + # Restore parameters to their original devices, ignoring new parameters + pin_memory = is_pin_memory_available() + for name, p in module.named_parameters(): + if name in original_device_states: + original_device: torch.device = original_device_states[name] + if original_device.type == "cpu": + # `torch.empty_like` does not support `pin_memory` argument + cpu_data = torch.empty_strided( + size=p.data.size(), + stride=p.data.stride(), + dtype=p.data.dtype, + layout=p.data.layout, + device="cpu", + pin_memory=pin_memory, + ) + cpu_data.copy_(p.data) + p.data = cpu_data + else: + p.data = p.data.to(original_device) + # New parameters or parameters already on target device are untouched + + def resolve_transformers_arch(model_config: ModelConfig, architectures: list[str]): for i, arch in enumerate(architectures): diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index f775c1edce1..c7e7cd9c07d 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -162,23 +162,15 @@ def get_quant_config(model_config: ModelConfig, None) if hf_quant_config is not None: return quant_cls.from_config(hf_quant_config) - # In case of bitsandbytes/QLoRA, get quant config from the adapter model. + # Inflight BNB quantization if model_config.quantization == "bitsandbytes": - if (not load_config.model_loader_extra_config - or "qlora_adapter_name_or_path" - not in load_config.model_loader_extra_config): - return quant_cls.from_config({"adapter_name_or_path": ""}) - model_name_or_path = load_config.model_loader_extra_config[ - "qlora_adapter_name_or_path"] - - else: - model_name_or_path = model_config.model - is_local = os.path.isdir(model_name_or_path) + return quant_cls.from_config({}) + is_local = os.path.isdir(model_config.model) if not is_local: # Download the config files. - with get_lock(model_name_or_path, load_config.download_dir): + with get_lock(model_config.model, load_config.download_dir): hf_folder = snapshot_download( - model_name_or_path, + model_config.model, revision=model_config.revision, allow_patterns="*.json", cache_dir=load_config.download_dir, @@ -186,7 +178,7 @@ def get_quant_config(model_config: ModelConfig, tqdm_class=DisabledTqdm, ) else: - hf_folder = model_name_or_path + hf_folder = model_config.model possible_config_filenames = quant_cls.get_config_filenames() @@ -213,7 +205,7 @@ def get_quant_config(model_config: ModelConfig, config = json.load(f) if model_config.quantization == "bitsandbytes": - config["adapter_name_or_path"] = model_name_or_path + config["adapter_name_or_path"] = model_config.model elif model_config.quantization == "modelopt": if config["producer"]["name"] == "modelopt": return quant_cls.from_config(config) diff --git a/vllm/model_executor/models/bamba.py b/vllm/model_executor/models/bamba.py index 16dac6123d6..87e1e102efd 100644 --- a/vllm/model_executor/models/bamba.py +++ b/vllm/model_executor/models/bamba.py @@ -313,7 +313,6 @@ def forward( mamba2_metadata = prepare_mamba2_metadata( chunk_size=self.config.mamba_chunk_size, - input_ids=input_ids, attn_metadata=attn_metadata, ) diff --git a/vllm/model_executor/models/granitemoehybrid.py b/vllm/model_executor/models/granitemoehybrid.py index dea9a0da312..706e648f1b4 100644 --- a/vllm/model_executor/models/granitemoehybrid.py +++ b/vllm/model_executor/models/granitemoehybrid.py @@ -338,7 +338,6 @@ def forward( attn_metadata = get_forward_context().attn_metadata mamba2_metadata = prepare_mamba2_metadata( chunk_size=self.config.mamba_chunk_size, - input_ids=input_ids, attn_metadata=attn_metadata, ) diff --git a/vllm/model_executor/models/intern_vit.py b/vllm/model_executor/models/intern_vit.py index 0499f339b24..fdcef8b9be8 100644 --- a/vllm/model_executor/models/intern_vit.py +++ b/vllm/model_executor/models/intern_vit.py @@ -190,8 +190,8 @@ def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor): if self.tp_size > 1: q = tensor_model_parallel_all_gather(q.contiguous()) k = tensor_model_parallel_all_gather(k.contiguous()) - q = self.q_norm.forward_native(q) - k = self.k_norm.forward_native(k) + q = self.q_norm(q) + k = self.k_norm(k) if self.tp_size > 1: splitter = partial(split_tensor_along_last_dim, num_partitions=self.tp_size) @@ -264,10 +264,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: if self.qk_normalization: B_, N_, H_, D_ = q.shape - q = self.q_norm.forward_native(q.flatten(-2, - -1)).view(B_, N_, H_, D_) - k = self.k_norm.forward_native(k.flatten(-2, - -1)).view(B_, N_, H_, D_) + q = self.q_norm(q.flatten(-2, -1)).view(B_, N_, H_, D_) + k = self.k_norm(k.flatten(-2, -1)).view(B_, N_, H_, D_) q = q.transpose(1, 2) k = k.transpose(1, 2) v = v.transpose(1, 2) diff --git a/vllm/model_executor/models/mamba2.py b/vllm/model_executor/models/mamba2.py index 78303733f6b..72daf34c441 100644 --- a/vllm/model_executor/models/mamba2.py +++ b/vllm/model_executor/models/mamba2.py @@ -142,7 +142,6 @@ def forward( mamba2_metadata = prepare_mamba2_metadata( chunk_size=self.config.chunk_size, - input_ids=input_ids, attn_metadata=attn_metadata, ) diff --git a/vllm/model_executor/models/mllama4.py b/vllm/model_executor/models/mllama4.py index 56a7f02c415..741b9837398 100644 --- a/vllm/model_executor/models/mllama4.py +++ b/vllm/model_executor/models/mllama4.py @@ -37,7 +37,7 @@ RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.model_loader.loader import _initialize_model +from vllm.model_executor.model_loader.utils import initialize_model from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY @@ -670,7 +670,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.config, None, prefix=maybe_prefix(prefix, "multi_modal_projector")) - self.language_model = _initialize_model( + self.language_model = initialize_model( vllm_config=vllm_config.with_hf_config(config.text_config, ["LlamaForCausalLM"]), prefix=maybe_prefix(prefix, "language_model"), diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index 75eebdacfdc..42bbb77a22c 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -438,8 +438,8 @@ def _apply_qk_norm(self, q: torch.Tensor, if self.tp_size > 1: q = tensor_model_parallel_all_gather(q.contiguous()) k = tensor_model_parallel_all_gather(k.contiguous()) - q = self.q_norm.forward_native(q) - k = self.k_norm.forward_native(k) + q = self.q_norm(q) + k = self.k_norm(k) if self.tp_size > 1: splitter = partial(split_tensor_along_last_dim, num_partitions=self.tp_size) diff --git a/vllm/model_executor/models/olmo2.py b/vllm/model_executor/models/olmo2.py index 44beae5726d..422b53d86f1 100644 --- a/vllm/model_executor/models/olmo2.py +++ b/vllm/model_executor/models/olmo2.py @@ -139,8 +139,8 @@ def _apply_qk_norm(self, q: torch.Tensor, if self.tp_size > 1: q = tensor_model_parallel_all_gather(q.contiguous()) k = tensor_model_parallel_all_gather(k.contiguous()) - q = self.q_norm.forward_native(q) - k = self.k_norm.forward_native(k) + q = self.q_norm(q) + k = self.k_norm(k) if self.tp_size > 1: splitter = partial(split_tensor_along_last_dim, num_partitions=self.tp_size) diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 6eda89ac193..e545bdac121 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -298,13 +298,8 @@ def forward( q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)) if rotary_pos_emb is not None: - use_flash_attn = self.attn_backend == _Backend.FLASH_ATTN - q = apply_rotary_pos_emb_vision(q, - rotary_pos_emb, - use_flash_attn=use_flash_attn) - k = apply_rotary_pos_emb_vision(k, - rotary_pos_emb, - use_flash_attn=use_flash_attn) + q = apply_rotary_pos_emb_vision(q, rotary_pos_emb) + k = apply_rotary_pos_emb_vision(k, rotary_pos_emb) if self.attn_backend == _Backend.FLASH_ATTN: # from vllm_flash_attn.flash_attn_interface import ( diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 95f0c29d485..a00b756ecec 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -64,7 +64,7 @@ BaseProcessingInfo, PromptReplacement, PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder -from vllm.platforms import _Backend +from vllm.platforms import _Backend, current_platform from vllm.sequence import IntermediateTensors from vllm.transformers_utils.config import uses_mrope from vllm.transformers_utils.processor import ( @@ -230,14 +230,13 @@ def apply_rotary_emb_torch(x: torch.Tensor, def apply_rotary_pos_emb_vision(t: torch.Tensor, - freqs: torch.Tensor, - use_flash_attn=False) -> torch.Tensor: + freqs: torch.Tensor) -> torch.Tensor: t_ = t.float() cos = freqs.cos() sin = freqs.sin() apply_rotary_emb = apply_rotary_emb_torch - if use_flash_attn: - from flash_attn.layers.rotary import apply_rotary_emb + if current_platform.is_cuda(): + from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb output = apply_rotary_emb(t_, cos, sin).type_as(t) return output diff --git a/vllm/model_executor/models/qwen3.py b/vllm/model_executor/models/qwen3.py index 73d2838f461..40e0ccc1bab 100644 --- a/vllm/model_executor/models/qwen3.py +++ b/vllm/model_executor/models/qwen3.py @@ -133,11 +133,11 @@ def forward( # Add qk-norm q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim) - q_by_head = self.q_norm.forward_native(q_by_head) + q_by_head = self.q_norm(q_by_head) q = q_by_head.view(q.shape) k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim) - k_by_head = self.k_norm.forward_native(k_by_head) + k_by_head = self.k_norm(k_by_head) k = k_by_head.view(k.shape) q, k = self.rotary_emb(positions, q, k) attn_output = self.attn(q, k, v) diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py index 97acbaa2ac3..fe6b303ba0b 100644 --- a/vllm/model_executor/models/qwen3_moe.py +++ b/vllm/model_executor/models/qwen3_moe.py @@ -225,12 +225,12 @@ def forward( # Add qk-norm q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim) - q_by_head = self.q_norm.forward_native(q_by_head) + q_by_head = self.q_norm(q_by_head) q = q_by_head.view(q.shape) k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim) - k_by_head = self.k_norm.forward_native(k_by_head) + k_by_head = self.k_norm(k_by_head) k = k_by_head.view(k.shape) q, k = self.rotary_emb(positions, q, k) attn_output = self.attn(q, k, v) diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index bfa48099b74..0bc5d218f8d 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -17,7 +17,7 @@ from vllm.forward_context import get_forward_context from vllm.model_executor.layers.activation import MulAndSilu, get_act_fn from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.model_loader.loader import DefaultModelLoader +from vllm.model_executor.model_loader import DefaultModelLoader from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 9ae80055a52..975c6505c31 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -273,7 +273,7 @@ def init_vllm_registered_model( Helper function to initialize an inner model registered to vLLM, based on the arguments passed to the outer vLLM model. """ - from vllm.model_executor.model_loader.loader import _initialize_model + from vllm.model_executor.model_loader.utils import initialize_model if hf_config is None and architectures is not None: # So that the architectures field is overridden @@ -283,7 +283,7 @@ def init_vllm_registered_model( vllm_config = vllm_config.with_hf_config(hf_config, architectures=architectures) - return _initialize_model(vllm_config=vllm_config, prefix=prefix) + return initialize_model(vllm_config=vllm_config, prefix=prefix) @overload diff --git a/vllm/model_executor/models/zamba2.py b/vllm/model_executor/models/zamba2.py index d34033e3ac9..eddccbba5a2 100644 --- a/vllm/model_executor/models/zamba2.py +++ b/vllm/model_executor/models/zamba2.py @@ -751,7 +751,6 @@ def forward( mamba2_metadata = prepare_mamba2_metadata( chunk_size=self.config.chunk_size, - input_ids=input_ids, attn_metadata=attn_metadata, ) diff --git a/vllm/platforms/__init__.py b/vllm/platforms/__init__.py index ba77546992e..1b98b89b9d3 100644 --- a/vllm/platforms/__init__.py +++ b/vllm/platforms/__init__.py @@ -185,17 +185,26 @@ def cpu_platform_plugin() -> Optional[str]: def neuron_platform_plugin() -> Optional[str]: - is_neuron = False + tnx_installed = False + nxd_installed = False logger.debug("Checking if Neuron platform is available.") try: import transformers_neuronx # noqa: F401 - is_neuron = True + tnx_installed = True logger.debug("Confirmed Neuron platform is available because" " transformers_neuronx is found.") - except ImportError as e: - logger.debug("Neuron platform is not available because: %s", str(e)) + except ImportError: pass + try: + import neuronx_distributed_inference # noqa: F401 + nxd_installed = True + logger.debug("Confirmed Neuron platform is available because" + " neuronx_distributed_inference is found.") + except ImportError: + pass + + is_neuron = tnx_installed or nxd_installed return "vllm.platforms.neuron.NeuronPlatform" if is_neuron else None diff --git a/vllm/platforms/neuron.py b/vllm/platforms/neuron.py index e37a3a578cf..71f7c718cdf 100644 --- a/vllm/platforms/neuron.py +++ b/vllm/platforms/neuron.py @@ -1,5 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 - +import enum +import os +from functools import lru_cache from typing import TYPE_CHECKING, Optional from vllm import envs @@ -15,6 +17,11 @@ logger = init_logger(__name__) +class NeuronFramework(enum.Enum): + TRANSFORMERS_NEURONX = "transformers-neuronx" + NEURONX_DISTRIBUTED_INFERENCE = "neuronx-distributed-inference" + + class NeuronPlatform(Platform): _enum = PlatformEnum.NEURON device_name: str = "neuron" @@ -43,8 +50,6 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: assert (vllm_config.lora_config is None), "LoRA is not supported for Neuron backend." - assert (not vllm_config.speculative_config - ), "Speculative decoding not yet supported for Neuron backend." cache_config = vllm_config.cache_config if cache_config: @@ -67,3 +72,71 @@ def get_device_communicator_cls(cls) -> str: @classmethod def use_all_gather(cls) -> bool: return True + + @classmethod + @lru_cache + def is_neuronx_distributed_inference(cls) -> bool: + try: + import neuronx_distributed_inference + except ImportError: + neuronx_distributed_inference = None + return neuronx_distributed_inference is not None + + @classmethod + @lru_cache + def is_transformers_neuronx(cls) -> bool: + try: + import transformers_neuronx + except ImportError: + transformers_neuronx = None + return transformers_neuronx is not None + + def get_neuron_framework_to_use(self): + """Return the specified framework if corresponding installations are + available. + + If no framework is specified, use neuronx-distributed-inference by + default. + If that's unavailable, check and switch to transformers-neuronx. + """ + if not self.is_neuron(): + raise AssertionError( + f"Neuron Framework unavailable for platform: {self}") + + tnx_installed = self.is_transformers_neuronx() + nxd_installed = self.is_neuronx_distributed_inference() + + specified_framework = os.environ.get("VLLM_NEURON_FRAMEWORK") + tnx_framework = NeuronFramework.TRANSFORMERS_NEURONX.value + nxd_framework = NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE.value + if specified_framework == tnx_framework and tnx_installed: + return self.TRANSFORMERS_NEURONX + + if ((specified_framework == nxd_framework and nxd_installed) + or (specified_framework is None and nxd_installed)): + return NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE + + if specified_framework is None and tnx_installed: + return NeuronFramework.TRANSFORMERS_NEURONX + + return None + + def use_neuronx_distributed(self): + """ + Return True if the framework determined in get_neuron_framework_to_use() + is NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE, False otherwise. This + is used to select the Neuron model framework and framework-specific + configuration to apply during model compilation. + """ + nxd_framework = NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE + return self.get_neuron_framework_to_use() == nxd_framework + + def use_transformers_neuronx(self): + """ + Return True if the framework determined in get_neuron_framework_to_use() + is NeuronFramework.TRANSFORMERS_NEURONX, False otherwise. This is used + to select the Neuron model framework and framework-specific + configuration to apply during model compilation. + """ + return self.get_neuron_framework_to_use( + ) == NeuronFramework.TRANSFORMERS_NEURONX diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 52deaf12248..8c968e7df3e 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -76,9 +76,9 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: from vllm.config import CompilationLevel cache_config = vllm_config.cache_config + # For v0, the default block size is 16. if cache_config and cache_config.block_size is None: cache_config.block_size = 16 - compilation_config = vllm_config.compilation_config # TPU only supports DYNAMO_ONCE compilation level @@ -101,16 +101,18 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: if envs.VLLM_USE_V1: from vllm.v1.attention.backends.pallas import ( PallasAttentionBackend) + cache_config.block_size = PallasAttentionBackend.get_page_size( + vllm_config) min_page_size = PallasAttentionBackend.get_min_page_size( vllm_config) - if min_page_size > vllm_config.cache_config.block_size: + if min_page_size > cache_config.block_size: logger.warning( "Increase the page size from %s to %s to make sure there's" "no SMEM OOM", - vllm_config.cache_config.block_size, + cache_config.block_size, min_page_size, ) - vllm_config.cache_config.block_size = min_page_size + cache_config.block_size = min_page_size parallel_config = vllm_config.parallel_config scheduler_config = vllm_config.scheduler_config diff --git a/vllm/utils.py b/vllm/utils.py index 652d257fcb6..8c7873b2399 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -34,7 +34,7 @@ import warnings import weakref from argparse import (Action, ArgumentDefaultsHelpFormatter, ArgumentParser, - ArgumentTypeError) + ArgumentTypeError, _ArgumentGroup) from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task from collections import UserDict, defaultdict from collections.abc import (AsyncGenerator, Awaitable, Generator, Hashable, @@ -42,6 +42,7 @@ from concurrent.futures.process import ProcessPoolExecutor from dataclasses import dataclass, field from functools import cache, lru_cache, partial, wraps +from gettext import gettext as _gettext from types import MappingProxyType from typing import (TYPE_CHECKING, Any, Callable, Generic, Literal, NamedTuple, Optional, Sequence, Tuple, Type, TypeVar, Union, cast, @@ -71,6 +72,8 @@ from vllm.logger import enable_trace_function_call, init_logger if TYPE_CHECKING: + from argparse import Namespace + from vllm.config import ModelConfig, VllmConfig logger = init_logger(__name__) @@ -723,6 +726,13 @@ def cdiv(a: int, b: int) -> int: return -(a // -b) +def next_power_of_2(n) -> int: + """The next power of 2 (inclusive)""" + if n < 1: + return 1 + return 1 << (n - 1).bit_length() + + def round_up(x: int, y: int) -> int: return ((x + y - 1) // y) * y @@ -1424,16 +1434,78 @@ def add_arguments(self, actions): super().add_arguments(actions) +class _FlexibleArgumentGroup(_ArgumentGroup): + + def __init__(self, parser: FlexibleArgumentParser, *args, **kwargs): + self._parser = parser + super().__init__(*args, **kwargs) + + def add_argument(self, *args: Any, **kwargs: Any): + if sys.version_info < (3, 13): + deprecated = kwargs.pop('deprecated', False) + action = super().add_argument(*args, **kwargs) + object.__setattr__(action, 'deprecated', deprecated) + if deprecated and action.dest not in \ + self._parser.__class__._deprecated: + self._parser._deprecated.add(action) + return action + + # python>3.13 + return super().add_argument(*args, **kwargs) + + class FlexibleArgumentParser(ArgumentParser): """ArgumentParser that allows both underscore and dash in names.""" + _deprecated: set[Action] = set() + _seen: set[str] = set() + def __init__(self, *args, **kwargs): # Set the default 'formatter_class' to SortedHelpFormatter if 'formatter_class' not in kwargs: kwargs['formatter_class'] = SortedHelpFormatter super().__init__(*args, **kwargs) - def parse_args(self, args=None, namespace=None): + if sys.version_info < (3, 13): + + def parse_known_args( # type: ignore[override] + self, + args: Sequence[str] | None = None, + namespace: Namespace | None = None, + ) -> tuple[Namespace | None, list[str]]: + namespace, args = super().parse_known_args(args, namespace) + for action in FlexibleArgumentParser._deprecated: + if action.dest not in FlexibleArgumentParser._seen and getattr( + namespace, action.dest, + None) != action.default: # noqa: E501 + self._warning( + _gettext("argument '%(argument_name)s' is deprecated") + % {'argument_name': action.dest}) + FlexibleArgumentParser._seen.add(action.dest) + return namespace, args + + def add_argument(self, *args: Any, **kwargs: Any): + # add a deprecated=True compatibility + # for python < 3.13 + deprecated = kwargs.pop('deprecated', False) + action = super().add_argument(*args, **kwargs) + object.__setattr__(action, 'deprecated', deprecated) + if deprecated and \ + action not in FlexibleArgumentParser._deprecated: + self._deprecated.add(action) + + return action + + def _warning(self, message: str): + self._print_message( + _gettext('warning: %(message)s\n') % {'message': message}, + sys.stderr) + + def parse_args( # type: ignore[override] + self, + args: list[str] | None = None, + namespace: Namespace | None = None, + ): if args is None: args = sys.argv[1:] @@ -1604,6 +1676,15 @@ def _load_config_file(self, file_path: str) -> list[str]: return processed_args + def add_argument_group( + self, + *args: Any, + **kwargs: Any, + ) -> _FlexibleArgumentGroup: + group = _FlexibleArgumentGroup(self, self, *args, **kwargs) + self._action_groups.append(group) + return group + async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args, **kwargs): diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index f986d797f2b..db792690215 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -18,6 +18,7 @@ from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils import cdiv +from vllm.v1.attention.backends.utils import CommonAttentionMetadata if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput @@ -309,13 +310,11 @@ def reorder_batch(self, input_batch: "InputBatch", return False def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, - common_prefix_len: int): + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata): max_seq_len = self.runner.seq_lens_np[:num_reqs].max() - query_start_loc_cpu = self.runner.query_start_loc_cpu[:num_reqs + 1] - query_start_loc = query_start_loc_cpu.to(self.runner.device, - non_blocking=True) - seq_lens_cpu = self.runner.seq_lens_cpu[:num_reqs] - seq_lens = seq_lens_cpu.to(self.runner.device, non_blocking=True) + query_start_loc = common_attn_metadata.query_start_loc + seq_lens = common_attn_metadata.seq_lens block_table = ( self.runner.input_batch.block_table.get_device_tensor()[:num_reqs]) slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to( diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 6e964b471fa..0852e15f9c1 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -18,6 +18,7 @@ get_layers_from_vllm_config) from vllm.logger import init_logger from vllm.v1.attention.backends.flash_attn import use_cascade_attention +from vllm.v1.attention.backends.utils import CommonAttentionMetadata if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput @@ -394,16 +395,15 @@ def _plan(self, attn_metadata: FlashInferMetadata): ) def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, - common_prefix_len: int): + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata): assert self._num_decodes + self._num_prefills == num_reqs assert (self._num_decode_tokens + self._num_prefill_tokens == num_actual_tokens) page_size = self.runner.block_size device = self.runner.device - qo_indptr = self.runner.query_start_loc_cpu[:num_reqs + 1].to( - self.runner.device, non_blocking=True) - seq_lens = self.runner.seq_lens_cpu[:num_reqs].to(self.runner.device, - non_blocking=True) + qo_indptr = common_attn_metadata.query_start_loc + seq_lens = common_attn_metadata.seq_lens block_table = ( self.runner.input_batch.block_table.get_device_tensor()[:num_reqs]) slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to( diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 8b1875e7356..0d18a5639c2 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -207,6 +207,7 @@ from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding from vllm.platforms import current_platform from vllm.utils import cdiv, round_down +from vllm.v1.attention.backends.utils import CommonAttentionMetadata try: from vllm.vllm_flash_attn import flash_attn_varlen_func @@ -451,7 +452,8 @@ def _build_decode(self, input_positions: torch.Tensor, ) def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, - common_prefix_len: int) -> M: + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata) -> M: assert self._num_decodes + self._num_prefills == num_reqs # Note(simon): be careful about the CPU <> GPU memory movement in this @@ -460,15 +462,13 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, device = self.runner.device block_table = ( self.runner.input_batch.block_table.get_device_tensor()[:num_reqs]) - query_start_loc = self.runner.query_start_loc_cpu[:num_reqs + 1].to( - device, non_blocking=True) slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to( device, non_blocking=True).long() input_positions = self.runner.positions_cpu[:num_actual_tokens].to( device, non_blocking=True).long() - seq_lens_cpu = self.runner.seq_lens_cpu[:num_reqs] - seq_lens = seq_lens_cpu.to(device, non_blocking=True) + query_start_loc = common_attn_metadata.query_start_loc + seq_lens = common_attn_metadata.seq_lens prefill_metadata = None if self._num_prefills > 0: diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py index 05b97172bc6..79ec67b89e9 100644 --- a/vllm/v1/attention/backends/pallas.py +++ b/vllm/v1/attention/backends/pallas.py @@ -12,7 +12,7 @@ from vllm.attention.backends.utils import CommonAttentionState from vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.utils import cdiv +from vllm.utils import cdiv, next_power_of_2 logger = init_logger(__name__) @@ -65,6 +65,20 @@ def get_min_page_size(vllm_config: VllmConfig) -> int: min_page_size = 1 << (min_page_size - 1).bit_length() return min_page_size + # TPU has limited SREGs (scalar registers), if page_size is too small, we + # can spill SREGs easily which leads to bad performance. The strategy we + # apply here is trying to split max-model-len to 16 pages which make the + # spill less likely. Meanwhile we make sure the page size is in [16, 256]. + @staticmethod + def get_page_size(vllm_config: VllmConfig) -> int: + page_size = next_power_of_2( + vllm_config.model_config.max_model_len) // 16 + if page_size <= 16: + return 16 + if page_size >= 256: + return 256 + return page_size + @dataclass class PallasMetadata: diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index 5f961047056..bb700c8e2e7 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -4,11 +4,10 @@ import torch +from vllm import _custom_ops as ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) -from vllm.attention.ops.chunked_prefill_paged_decode import ( - chunked_prefill_paged_decode) -from vllm.attention.ops.paged_attn import PagedAttention +from vllm.attention.ops.triton_unified_attention import unified_attention from vllm.logger import init_logger from vllm.v1.attention.backends.flash_attn import ( FlashAttentionMetadata, FlashAttentionMetadataBuilder) @@ -87,6 +86,11 @@ def __init__( else: self.sliding_window = (sliding_window - 1, 0) self.kv_cache_dtype = kv_cache_dtype + if logits_soft_cap is None: + # In flash-attn, setting logits_soft_cap as 0 means no soft cap. + logits_soft_cap = 0 + self.logits_soft_cap = logits_soft_cap + self.use_irope = use_irope assert self.num_heads % self.num_kv_heads == 0 @@ -143,11 +147,9 @@ def forward( # performance to make sure it does not introduce any overhead. num_actual_tokens = attn_metadata.num_actual_tokens - key_cache, value_cache = PagedAttention.split_kv_cache( - kv_cache, self.num_kv_heads, self.head_size) - # Reshape the input keys and values and store them in the cache. - PagedAttention.write_to_paged_cache( + key_cache, value_cache = kv_cache.unbind(0) + torch.ops._C_cache_ops.reshape_and_cache_flash( key, value, key_cache, @@ -158,6 +160,18 @@ def forward( layer._v_scale, ) + if self.kv_cache_dtype.startswith("fp8"): + key_cache = key_cache.view(torch.float8_e4m3fn) + value_cache = value_cache.view(torch.float8_e4m3fn) + num_tokens, num_heads, head_size = query.shape + assert layer._q_scale == 1.0, \ + "A non 1.0 q_scale is not currently supported." + query, _ = ops.scaled_fp8_quant( + query.reshape( + (num_tokens, num_heads * head_size)).contiguous(), + layer._q_scale) + query = query.reshape((num_tokens, num_heads, head_size)) + use_local_attn = \ (self.use_irope and attn_metadata.local_attn_metadata is not None) @@ -165,34 +179,37 @@ def forward( assert attn_metadata.local_attn_metadata is not None local_metadata = attn_metadata.local_attn_metadata cu_seqlens_q = local_metadata.local_query_start_loc - sequesd_k = local_metadata.local_seqused_k + seqused_k = local_metadata.local_seqused_k max_seqlen_q = local_metadata.local_max_query_len max_seqlen_k = local_metadata.local_max_seq_len block_table = local_metadata.local_block_table else: cu_seqlens_q = attn_metadata.query_start_loc - sequesd_k = attn_metadata.seq_lens + seqused_k = attn_metadata.seq_lens max_seqlen_q = attn_metadata.max_query_len max_seqlen_k = attn_metadata.max_seq_len block_table = attn_metadata.block_table - # Compute attention and update output up to `num_actual_tokens`. - chunked_prefill_paged_decode(query=query[:num_actual_tokens], - key=key[:num_actual_tokens], - value=value[:num_actual_tokens], - output=output[:num_actual_tokens], - kv_cache_dtype=self.kv_cache_dtype, - key_cache=key_cache, - value_cache=value_cache, - block_table=block_table, - query_start_loc=cu_seqlens_q, - seq_lens=sequesd_k, - max_seq_len=max_seqlen_k, - max_query_len=max_seqlen_q, - k_scale=layer._k_scale, - v_scale=layer._v_scale, - alibi_slopes=self.alibi_slopes, - sliding_window=self.sliding_window[0], - sm_scale=self.scale) + descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1]) + + unified_attention( + q=query[:num_actual_tokens], + k=key_cache, + v=value_cache, + out=output[:num_actual_tokens], + cu_seqlens_q=cu_seqlens_q, + max_seqlen_q=max_seqlen_q, + seqused_k=seqused_k, + max_seqlen_k=max_seqlen_k, + softmax_scale=self.scale, + causal=True, + alibi_slopes=self.alibi_slopes, + window_size=self.sliding_window, + block_table=block_table, + softcap=self.logits_soft_cap, + q_descale=None, # Not supported + k_descale=layer._k_scale.expand(descale_shape), + v_descale=layer._v_scale.expand(descale_shape), + ) return output diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py new file mode 100644 index 00000000000..10a771e830b --- /dev/null +++ b/vllm/v1/attention/backends/utils.py @@ -0,0 +1,18 @@ +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass + +import torch + + +@dataclass +class CommonAttentionMetadata: + """ + Attention metadata attributes that can be shared by layers in different KV + cache groups and thus having different block table. + """ + + query_start_loc: torch.Tensor + """(batch_size + 1,), the start location of each request in query Tensor""" + seq_lens: torch.Tensor + """(batch_size,), the length of each request including both computed tokens + and newly scheduled tokens""" diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index a2fa5825bb1..9e172b6bdb0 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -2,6 +2,7 @@ from collections import defaultdict from collections.abc import Iterable +from dataclasses import dataclass from typing import Optional from vllm.distributed.kv_events import KVCacheEvent @@ -18,6 +19,24 @@ logger = init_logger(__name__) +@dataclass +class KVCacheBlocks: + blocks: list[KVCacheBlock] + + def __add__(self, other: "KVCacheBlocks") -> "KVCacheBlocks": + """Adds two KVCacheBlocks instances.""" + return KVCacheBlocks(self.blocks + other.blocks) + + @classmethod + def create_empty(cls) -> "KVCacheBlocks": + """Creates a new KVCacheBlocks instance with no blocks.""" + return cls([]) + + def get_block_ids(self) -> list[int]: + """Converts the KVCacheBlocks instance to a list of block IDs.""" + return [block.block_id for block in self.blocks] + + class KVCacheManager: def __init__( @@ -94,8 +113,8 @@ def make_prefix_cache_stats(self) -> Optional[PrefixCacheStats]: self.prefix_cache_stats = PrefixCacheStats() return stats - def get_computed_blocks( - self, request: Request) -> tuple[list[KVCacheBlock], int]: + def get_computed_blocks(self, + request: Request) -> tuple[KVCacheBlocks, int]: """Get the computed (cached) blocks for the request. Note that the computed blocks must be full. @@ -109,7 +128,7 @@ def get_computed_blocks( """ if not self.enable_caching: # Prefix caching is disabled. - return [], 0 + return KVCacheBlocks.create_empty(), 0 # The block hashes for the request may already be computed # if the scheduler has tried to schedule the request before. @@ -124,7 +143,7 @@ def get_computed_blocks( self.prefix_cache_stats.requests += 1 # When the request requires prompt logprobs, we skip prefix caching. if request.sampling_params.prompt_logprobs is not None: - return [], 0 + return KVCacheBlocks.create_empty(), 0 if len(block_hashes) * self.block_size == request.num_tokens: # When prompt length is divisible by the block size and all @@ -157,15 +176,15 @@ def get_computed_blocks( # sharing, `num_computed_tokens` is always a multiple of # `block_size`. num_computed_tokens = len(computed_blocks) * self.block_size - return computed_blocks, num_computed_tokens + return KVCacheBlocks(computed_blocks), num_computed_tokens def allocate_slots( self, request: Request, num_tokens: int, - new_computed_blocks: Optional[list[KVCacheBlock]] = None, + new_computed_blocks: Optional[KVCacheBlocks] = None, num_lookahead_tokens: int = 0, - ) -> Optional[list[KVCacheBlock]]: + ) -> Optional[KVCacheBlocks]: """Add slots for a request with new tokens to append. Args: @@ -173,7 +192,7 @@ def allocate_slots( num_tokens: The number of tokens to allocate, including external tokens. Note that this does not include tokens that have already been computed locally (i.e. new_computed_blocks). - new_computed_blocks: A list of new computed blocks just hitting the + new_computed_blocks: The new computed blocks just hitting the prefix caching. num_lookahead_tokens: The number of speculative tokens to allocate. This is used by spec decode proposers with kv-cache such @@ -199,7 +218,10 @@ def allocate_slots( if num_tokens == 0: raise ValueError("num_tokens must be greater than 0") - new_computed_blocks = new_computed_blocks or [] + if new_computed_blocks is not None: + new_computed_block_list = new_computed_blocks.blocks + else: + new_computed_block_list = [] req_blocks = self.req_to_blocks[request.request_id] @@ -216,17 +238,18 @@ def allocate_slots( # The number of computed tokens is the number of computed tokens plus # the new prefix caching hits num_computed_tokens = (request.num_computed_tokens + - len(new_computed_blocks) * self.block_size) + len(new_computed_block_list) * self.block_size) num_required_blocks = cdiv( num_computed_tokens + num_tokens + num_lookahead_tokens, self.block_size) num_new_blocks = (num_required_blocks - len(req_blocks) - - len(new_computed_blocks)) + len(new_computed_block_list)) # If a computed block of a request is an eviction candidate (in the # free queue and ref_cnt == 0), it cannot be counted as a free block # when allocating this request. - num_evictable_computed_blocks = sum(1 for blk in new_computed_blocks + num_evictable_computed_blocks = sum(1 + for blk in new_computed_block_list if blk.ref_cnt == 0) if (num_new_blocks > self.block_pool.get_num_free_blocks() - num_evictable_computed_blocks): @@ -235,15 +258,15 @@ def allocate_slots( # Touch the computed blocks to make sure they won't be evicted. if self.enable_caching: - self.block_pool.touch(new_computed_blocks) + self.block_pool.touch(new_computed_block_list) else: - assert not new_computed_blocks, ( + assert not new_computed_block_list, ( "Computed blocks should be empty when " "prefix caching is disabled") # Append the new computed blocks to the request blocks until now to # avoid the case where the new blocks cannot be allocated. - req_blocks.extend(new_computed_blocks) + req_blocks.extend(new_computed_block_list) # Start to handle new blocks @@ -267,12 +290,12 @@ def allocate_slots( req_blocks.extend(new_blocks) if not self.enable_caching: - return new_blocks + return KVCacheBlocks(new_blocks) - # Use `new_computed_blocks` for a new request, and `num_cached_block` - # for a running request. - num_cached_blocks = self.num_cached_block.get(request.request_id, - len(new_computed_blocks)) + # Use `new_computed_block_list` for a new request, and + # `num_cached_block` for a running request. + num_cached_blocks = self.num_cached_block.get( + request.request_id, len(new_computed_block_list)) # Speculated tokens might be rejected in the future, so we does # not cache any speculated tokens. We only cache blocks with # generated (accepted) tokens. @@ -291,7 +314,7 @@ def allocate_slots( self.num_cached_block[ request.request_id] = num_full_blocks_after_append - return new_blocks + return KVCacheBlocks(new_blocks) def free(self, request: Request) -> None: """Free the blocks allocated for the request. diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 05472ea573d..258e0d570e3 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -261,9 +261,8 @@ def schedule(self) -> SchedulerOutput: # Therefore, we might introduce some additional # cycle to fill in the bitmask, which could be a big no-op. structured_output_request_ids[request.request_id] = req_index - req_to_new_block_ids[request.request_id] = [ - b.block_id for b in new_blocks - ] + req_to_new_block_ids[request.request_id] = ( + new_blocks.get_block_ids()) num_scheduled_tokens[request.request_id] = num_new_tokens token_budget -= num_new_tokens req_index += 1 @@ -407,9 +406,8 @@ def schedule(self) -> SchedulerOutput: if self.lora_config and request.lora_request: scheduled_loras.add(request.lora_request.lora_int_id) - req_to_new_block_ids[request.request_id] = [ - b.block_id for b in computed_blocks + new_blocks - ] + req_to_new_block_ids[request.request_id] = ( + computed_blocks + new_blocks).get_block_ids() num_scheduled_tokens[request.request_id] = num_new_tokens token_budget -= num_new_tokens request.status = RequestStatus.RUNNING diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index cb125bf4bf1..ff449901030 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -8,7 +8,7 @@ import time import traceback import weakref -from concurrent.futures import Future +from concurrent.futures import Future, ThreadPoolExecutor from dataclasses import dataclass from enum import Enum, auto from functools import partial @@ -53,10 +53,11 @@ def _init_executor(self) -> None: self.world_size = self.parallel_config.world_size tensor_parallel_size = self.parallel_config.tensor_parallel_size - assert self.world_size == tensor_parallel_size, ( + pp_parallel_size = self.parallel_config.pipeline_parallel_size + assert self.world_size == tensor_parallel_size * pp_parallel_size, ( f"world_size ({self.world_size}) must be equal to the " - f"tensor_parallel_size ({tensor_parallel_size}). " - f"Pipeline parallelism is not yet implemented in v1") + f"tensor_parallel_size ({tensor_parallel_size}) x pipeline" + f"_parallel_size ({pp_parallel_size}). ") # Set multiprocessing envs that are common to V0 and V1 set_multiprocessing_worker_envs(self.parallel_config) @@ -104,6 +105,17 @@ def _init_executor(self) -> None: self._ensure_worker_termination( [w.proc for w in unready_workers]) + # For pipeline parallel, we use a thread pool for asynchronous + # execute_model. + self.io_thread_pool: Optional[ThreadPoolExecutor] = None + if self.max_concurrent_batches > 1: + # Note: must use only 1 IO thread to keep dequeue sequence + # from the response queue + self.io_thread_pool = ThreadPoolExecutor( + max_workers=1, thread_name_prefix="mp_exec_io") + + self.output_rank = self._get_output_rank() + def start_worker_monitor(self): workers = self.workers self_ref = weakref.ref(self) @@ -145,7 +157,9 @@ def execute_model( ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]: (output, ) = self.collective_rpc("execute_model", args=(scheduler_output, ), - rank0_reply_only=True, + unique_reply_rank=self.output_rank, + non_block=self.max_concurrent_batches + > 1, timeout=EXECUTE_MODEL_TIMEOUT_S) return output @@ -154,7 +168,8 @@ def collective_rpc(self, timeout: Optional[float] = None, args: tuple = (), kwargs: Optional[dict] = None, - rank0_reply_only: bool = False) -> list[Any]: + non_block: bool = False, + unique_reply_rank: Optional[int] = None) -> list[Any]: if self.is_failed: raise RuntimeError("Executor failed.") @@ -171,22 +186,35 @@ def collective_rpc(self, send_method = cloudpickle.dumps( method, protocol=pickle.HIGHEST_PROTOCOL) self.rpc_broadcast_mq.enqueue( - (send_method, args, kwargs, rank0_reply_only)) + (send_method, args, kwargs, unique_reply_rank)) - workers = (self.workers[0], ) if rank0_reply_only else self.workers - responses = [None] * len(workers) - for w in workers: - dequeue_timeout = None if deadline is None else ( - deadline - time.monotonic()) + workers = (self.workers[unique_reply_rank], + ) if unique_reply_rank is not None else self.workers + responses = [] + + def get_response(w: WorkerProcHandle, + dequeue_timeout: Optional[float] = None, + cancel_event: Optional[threading.Event] = None): status, result = w.worker_response_mq.dequeue( - timeout=dequeue_timeout, cancel=self.shutdown_event) + timeout=dequeue_timeout, cancel=cancel_event) if status != WorkerProc.ResponseStatus.SUCCESS: raise RuntimeError( f"Worker failed with error '{result}', please check the" " stack trace above for the root cause") + return result - responses[w.rank] = result + for w in workers: + dequeue_timeout = None if deadline is None else ( + deadline - time.monotonic()) + + if non_block: + result = self.io_thread_pool.submit( # type: ignore + get_response, w, dequeue_timeout, self.shutdown_event) + else: + result = get_response(w, dequeue_timeout) + + responses.append(result) return responses except TimeoutError as e: @@ -225,6 +253,11 @@ def shutdown(self): if not getattr(self, 'shutting_down', False): self.shutting_down = True self.shutdown_event.set() + + if self.io_thread_pool is not None: + self.io_thread_pool.shutdown(wait=False, cancel_futures=True) + self.io_thread_pool = None + for w in self.workers: w.worker_response_mq = None self._ensure_worker_termination([w.proc for w in self.workers]) @@ -235,6 +268,22 @@ def check_health(self) -> None: self.collective_rpc("check_health", timeout=10) return + @property + def max_concurrent_batches(self) -> int: + return self.parallel_config.pipeline_parallel_size + + def _get_output_rank(self) -> int: + # Only returns ModelRunnerOutput from TP rank=0 and PP rank=-1 + # (the first TP worker of the last PP stage). + # Example: + # Assuming TP=8, PP=4, then the world_size=32 + # 0-7, PP rank 0 + # 8-15, PP rank 1 + # 16-23, PP rank 2 + # 24-31, PP rank 3 + # so world_size - tp_size = 32 - 8 = 24 should be PP rank = -1 (i.e. 3) + return self.world_size - self.parallel_config.tensor_parallel_size + @dataclass class UnreadyWorkerProcHandle: @@ -280,12 +329,14 @@ def __init__( all_kwargs: list[dict] = [ {} for _ in range(vllm_config.parallel_config.world_size) ] + is_driver_worker = ( + rank % vllm_config.parallel_config.tensor_parallel_size == 0) all_kwargs[rank] = { "vllm_config": vllm_config, "local_rank": local_rank, "rank": rank, "distributed_init_method": distributed_init_method, - "is_driver_worker": rank == 0, + "is_driver_worker": is_driver_worker, } wrapper.init_worker(all_kwargs) self.worker = wrapper @@ -455,7 +506,7 @@ class ResponseStatus(Enum): def worker_busy_loop(self): """Main busy loop for Multiprocessing Workers""" while True: - method, args, kwargs, rank0_only = self.rpc_broadcast_mq.dequeue() + method, args, kwargs, output_rank = self.rpc_broadcast_mq.dequeue() try: if isinstance(method, str): @@ -470,11 +521,11 @@ def worker_busy_loop(self): logger.exception("WorkerProc hit an exception.") # exception might not be serializable, so we convert it to # string, only for logging purpose. - if not rank0_only or self.rank == 0: + if output_rank is None or self.rank == output_rank: self.worker_response_mq.enqueue( (WorkerProc.ResponseStatus.FAILURE, str(e))) continue - if not rank0_only or self.rank == 0: + if output_rank is None or self.rank == output_rank: self.worker_response_mq.enqueue( (WorkerProc.ResponseStatus.SUCCESS, output)) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 6d71743c5e3..13cfcc4bbb6 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -2,10 +2,12 @@ import torch import torch.nn as nn -from vllm.config import CompilationLevel, VllmConfig, set_current_vllm_config +from vllm.attention.layer import Attention +from vllm.config import (CompilationLevel, VllmConfig, + get_layers_from_vllm_config, set_current_vllm_config) from vllm.forward_context import set_forward_context from vllm.logger import init_logger -from vllm.model_executor.model_loader.loader import get_model_loader +from vllm.model_executor.model_loader import get_model_loader from vllm.model_executor.model_loader.utils import set_default_torch_dtype from vllm.model_executor.models import ModelRegistry from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM @@ -26,23 +28,25 @@ def __init__( device: torch.device, ): self.vllm_config = vllm_config - self.method = self.vllm_config.speculative_config.method - self.num_speculative_tokens = ( - vllm_config.speculative_config.num_speculative_tokens) - self.max_model_len = vllm_config.model_config.max_model_len - self.block_size = vllm_config.cache_config.block_size + self.speculative_config = vllm_config.speculative_config + self.draft_model_config = self.speculative_config.draft_model_config + self.method = self.speculative_config.method self.dtype = vllm_config.model_config.dtype - - self.max_num_tokens = vllm_config.scheduler_config \ - .max_num_batched_tokens - - self.hidden_size = vllm_config.model_config.get_hidden_size() + self.max_model_len = vllm_config.model_config.max_model_len + self.block_size = vllm_config.cache_config.block_size + self.num_speculative_tokens = ( + self.speculative_config.num_speculative_tokens) + self.max_num_tokens = ( + vllm_config.scheduler_config.max_num_batched_tokens) + # We need to get the hidden size from the draft model config because + # the draft model's hidden size can be different from the target model's + # hidden size (e.g., Llama 3.3 70B). + self.hidden_size = self.draft_model_config.get_hidden_size() self.use_cuda_graph = (self.vllm_config.compilation_config.level == CompilationLevel.PIECEWISE and not self.vllm_config.model_config.enforce_eager) - self.cudagraph_batch_sizes = list( reversed( self.vllm_config.compilation_config.cudagraph_capture_sizes)) @@ -54,7 +58,6 @@ def __init__( self.positions = torch.zeros(self.max_num_tokens, dtype=torch.int64, device=device) - self.hidden_states = torch.zeros( (self.max_num_tokens, self.hidden_size), dtype=self.dtype, @@ -129,7 +132,6 @@ def propose( num_input_tokens = num_tokens # copy inputs to buffer for cudagraph self.positions[:num_tokens] = target_positions - self.hidden_states[:num_tokens] = target_hidden_states with set_forward_context(attn_metadata, @@ -207,7 +209,6 @@ def propose( # copy inputs to buffer for cudagraph self.input_ids[:batch_size] = input_ids self.positions[:batch_size] = clamped_positions - self.hidden_states[:batch_size] = hidden_states # Run the model. @@ -276,6 +277,8 @@ def load_model(self, target_model: nn.Module) -> None: loader = get_model_loader(self.vllm_config.load_config) target_layer_num = self.vllm_config.model_config.get_num_layers( self.vllm_config.parallel_config) + target_attn_layer_names = set( + get_layers_from_vllm_config(self.vllm_config, Attention).keys()) draft_model_config = \ self.vllm_config.speculative_config.draft_model_config @@ -292,6 +295,11 @@ def load_model(self, target_model: nn.Module) -> None: vllm_config=self.vllm_config, start_layer_id=target_layer_num).to(target_device) + draft_attn_layer_names = ( + get_layers_from_vllm_config(self.vllm_config, Attention).keys() - + target_attn_layer_names) + assert len(draft_attn_layer_names) == 1 + self.attn_layer_name = next(iter(draft_attn_layer_names)) loaded_weights = self.model.load_weights( loader.get_all_weights(draft_model_config, self.model)) if self.vllm_config.speculative_config.method == "eagle3": diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 97d8c91b465..e0c3d05c797 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -30,6 +30,7 @@ GiB_bytes, LayerBlockType, LazyLoader, cdiv, check_use_alibi, is_pin_memory_available) from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata +from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.core.encoder_cache_manager import compute_encoder_budget from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec, KVCacheConfig, KVCacheSpec, @@ -157,9 +158,12 @@ def __init__( # Sampler self.sampler = Sampler() - # Lazy initialization + # Lazy initializations # self.model: nn.Module # Set after load_model + # Initialize in initialize_kv_cache self.kv_caches: list[torch.Tensor] = [] + # self.kv_cache_config: KVCacheConfig + # req_id -> (input_id -> encoder_output) self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {} @@ -488,7 +492,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: def _prepare_inputs( self, scheduler_output: "SchedulerOutput", - ) -> tuple[FlashAttentionMetadata, torch.Tensor, + ) -> tuple[dict[str, FlashAttentionMetadata], torch.Tensor, Optional[SpecDecodeMetadata]]: total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens assert total_num_scheduled_tokens > 0 @@ -585,20 +589,39 @@ def _prepare_inputs( self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True) - # Prepare for cascade attention if enabled & beneficial. - common_prefix_len = 0 - if self.cascade_attn_enabled: - common_prefix_len = self._compute_cascade_attn_prefix_len( - num_scheduled_tokens, - scheduler_output.num_common_prefix_blocks, - ) + query_start_loc = self.query_start_loc_cpu[:num_reqs + 1].to( + self.device, non_blocking=True) + seq_lens = self.seq_lens_cpu[:num_reqs].to(self.device, + non_blocking=True) + common_attn_metadata = CommonAttentionMetadata( + query_start_loc=query_start_loc, seq_lens=seq_lens) + + attn_metadata: dict[str, FlashAttentionMetadata] = {} + # Prepare the attention metadata for each KV cache group and make layers + # in the same group share the same metadata. + # NOTE(Chen): there is exactly one KV cache group that contains all + # attetnion layers in the model for now, so the current logic for + # getting attn_metadata is not related to kv_cache_group information. + # Will extend this part to support multiple KV cache groups later. + for kv_cache_group_id, kv_cache_group_spec in enumerate( + self.kv_cache_config.kv_cache_groups): + + # Prepare for cascade attention if enabled & beneficial. + common_prefix_len = 0 + if self.cascade_attn_enabled: + common_prefix_len = self._compute_cascade_attn_prefix_len( + num_scheduled_tokens, + scheduler_output.num_common_prefix_blocks, + ) - attn_metadata = self.attn_metadata_builder.build( - num_reqs=num_reqs, - num_actual_tokens=total_num_scheduled_tokens, - max_query_len=max_num_scheduled_tokens, - common_prefix_len=common_prefix_len, - ) + attn_metadata_i = self.attn_metadata_builder.build( + num_reqs=num_reqs, + num_actual_tokens=total_num_scheduled_tokens, + max_query_len=max_num_scheduled_tokens, + common_prefix_len=common_prefix_len, + common_attn_metadata=common_attn_metadata) + for layer_name in kv_cache_group_spec.layer_names: + attn_metadata[layer_name] = attn_metadata_i use_spec_decode = len( scheduler_output.scheduled_spec_decode_tokens) > 0 @@ -608,7 +631,7 @@ def _prepare_inputs( # from these partial requests, we do so for simplicity. # We will ignore the sampled tokens from the partial requests. # TODO: Support prompt logprobs. - logits_indices = attn_metadata.query_start_loc[1:] - 1 + logits_indices = query_start_loc[1:] - 1 spec_decode_metadata = None else: # Get the number of draft tokens for each request. @@ -1016,7 +1039,7 @@ def execute_model( self, scheduler_output: "SchedulerOutput", intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> Union[ModelRunnerOutput, torch.Tensor]: + ) -> Union[ModelRunnerOutput, IntermediateTensors]: # Update KVConnector with the KVConnector metadata forward(). if has_kv_transfer_group(): get_kv_transfer_group().bind_connector_metadata( @@ -1230,6 +1253,7 @@ def execute_model( next_token_ids = torch.tensor(next_token_ids, dtype=torch.int32, device=self.device) + eagle_attn_metadata = attn_metadata[self.drafter.attn_layer_name] if spec_decode_metadata is None: # input_ids can be None for multimodal models. @@ -1241,8 +1265,8 @@ def execute_model( dim=-1) else: target_hidden_states = hidden_states[:num_scheduled_tokens] - target_slot_mapping = attn_metadata.slot_mapping - cu_num_tokens = attn_metadata.query_start_loc + target_slot_mapping = eagle_attn_metadata.slot_mapping + cu_num_tokens = eagle_attn_metadata.query_start_loc else: # TODO(woosuk): Refactor this. num_draft_tokens = spec_decode_metadata.num_draft_tokens @@ -1256,7 +1280,7 @@ def execute_model( device=self.device, ) cu_num_tokens, token_indices = self.drafter.prepare_inputs( - attn_metadata.query_start_loc, + eagle_attn_metadata.query_start_loc, num_rejected_tokens, ) target_token_ids = self.input_ids[token_indices] @@ -1266,7 +1290,8 @@ def execute_model( [h[token_indices] for h in aux_hidden_states], dim=-1) else: target_hidden_states = hidden_states[token_indices] - target_slot_mapping = attn_metadata.slot_mapping[token_indices] + target_slot_mapping = eagle_attn_metadata.slot_mapping[ + token_indices] draft_token_ids = self.drafter.propose( target_token_ids=target_token_ids, @@ -1275,7 +1300,7 @@ def execute_model( target_slot_mapping=target_slot_mapping, next_token_ids=next_token_ids, cu_num_tokens=cu_num_tokens, - block_table=attn_metadata.block_table, + block_table=eagle_attn_metadata.block_table, sampling_metadata=sampling_metadata, ) spec_token_ids = draft_token_ids.tolist() @@ -1708,6 +1733,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: raise NotImplementedError( "Hybrid models with more than one KV cache type are not " "supported yet.") + self.kv_cache_config = kv_cache_config kv_caches: dict[str, torch.Tensor] = {} diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index ac6861f93a8..5352b1c5a37 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -15,11 +15,12 @@ init_distributed_environment, set_custom_all_reduce) from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized -from vllm.distributed.parallel_state import get_pp_group +from vllm.distributed.parallel_state import get_pp_group, get_tp_group from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor import set_random_seed from vllm.platforms import current_platform +from vllm.sequence import IntermediateTensors from vllm.utils import GiB_bytes from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec from vllm.v1.outputs import ModelRunnerOutput @@ -266,7 +267,22 @@ def execute_model( self, scheduler_output: "SchedulerOutput", ) -> Optional[ModelRunnerOutput]: - output = self.model_runner.execute_model(scheduler_output) + intermediate_tensors = None + if not get_pp_group().is_first_rank: + intermediate_tensors = IntermediateTensors( + get_pp_group().recv_tensor_dict( + all_gather_group=get_tp_group())) + + output = self.model_runner.execute_model(scheduler_output, + intermediate_tensors) + + if not get_pp_group().is_last_rank: + assert isinstance(output, IntermediateTensors) + get_pp_group().send_tensor_dict(output.tensors, + all_gather_group=get_tp_group()) + return None + + assert isinstance(output, ModelRunnerOutput) return output if self.is_driver_worker else None def profile(self, is_start: bool = True): @@ -302,7 +318,7 @@ def save_sharded_state( pattern: Optional[str] = None, max_size: Optional[int] = None, ) -> None: - from vllm.model_executor.model_loader.loader import ShardedStateLoader + from vllm.model_executor.model_loader import ShardedStateLoader ShardedStateLoader.save_model( self.model_runner.model, path, diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 8e162d5170d..f5626abb2a1 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -588,7 +588,14 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): # Padded to avoid recompiling when `num_reqs` varies. logits_indices = self.query_start_loc_cpu[1:padded_num_reqs + 1] - 1 logits_indices = logits_indices.to(self.device) - return attn_metadata, logits_indices, padded_num_reqs + + layer_names = get_layers_from_vllm_config(self.vllm_config, + Attention).keys() + per_layer_attn_metadata = { + layer_name: attn_metadata + for layer_name in layer_names + } + return per_layer_attn_metadata, logits_indices, padded_num_reqs def _scatter_placeholders( self, @@ -956,7 +963,14 @@ def _dummy_run(self, num_tokens: int) -> None: torch._dynamo.mark_dynamic(position_ids, 0) torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0) - with set_forward_context(attn_metadata, self.vllm_config, 0): + layer_names = get_layers_from_vllm_config(self.vllm_config, + Attention).keys() + per_layer_attn_metadata = { + layer_name: attn_metadata + for layer_name in layer_names + } + + with set_forward_context(per_layer_attn_metadata, self.vllm_config, 0): out = self.model(input_ids=input_ids, positions=position_ids, inputs_embeds=inputs_embeds) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index e22bbcc656f..d96021cc688 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1220,7 +1220,7 @@ def save_sharded_state( pattern: Optional[str] = None, max_size: Optional[int] = None, ) -> None: - from vllm.model_executor.model_loader.loader import ShardedStateLoader + from vllm.model_executor.model_loader import ShardedStateLoader ShardedStateLoader.save_model( self.model, path, @@ -1232,7 +1232,7 @@ def save_tensorized_model( self, tensorizer_config: TensorizerConfig, ) -> None: - from vllm.model_executor.model_loader.loader import TensorizerLoader + from vllm.model_executor.model_loader import TensorizerLoader TensorizerLoader.save_model( self.model, tensorizer_config=tensorizer_config, diff --git a/vllm/worker/multi_step_neuron_model_runner.py b/vllm/worker/multi_step_neuron_model_runner.py new file mode 100644 index 00000000000..9618a4b49ff --- /dev/null +++ b/vllm/worker/multi_step_neuron_model_runner.py @@ -0,0 +1,81 @@ +# SPDX-License-Identifier: Apache-2.0 + +from importlib.util import find_spec +from typing import List, Optional + +import torch + +from vllm.config import VllmConfig +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.multimodal import MultiModalKwargs +from vllm.sequence import IntermediateTensors +from vllm.worker.neuron_model_runner import (ModelInputForNeuron, + NeuronModelRunner) + + +class MultiStepNeuronModelRunner(NeuronModelRunner): + """A model runner for multi step decoding using the transformers_neuronx + framework""" + + def __init__( + self, + vllm_config: VllmConfig, + ): + super().__init__(vllm_config) + self.speculation_config = self.speculative_config + from transformers_neuronx.config import GenerationConfig + self.speculation_config.draft_model_config.neuron_sampling_params = ( + GenerationConfig( + max_length=self.scheduler_config.max_model_len, + do_sample=True, + per_batch_line=True, + top_k=[self._MAX_NEURON_SAMPLING_TOP_K] \ + * self.scheduler_config.max_num_seqs, + top_p=[1.0] * self.scheduler_config.max_num_seqs, + temperature=[1.0] * self.scheduler_config.max_num_seqs, + dynamic=True, + global_top_k=self._MAX_NEURON_SAMPLING_TOP_K + )) + + def load_model(self) -> None: + if find_spec("transformers_neuronx") is not None: + from vllm.model_executor.model_loader.neuron import ( + get_neuron_eagle_speculation_model, + get_neuron_speculation_model) + if self.speculation_config.speculative_token_tree is not None: + self.model = get_neuron_eagle_speculation_model( + self.model_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + speculation_config=self.speculation_config) + else: + self.model = get_neuron_speculation_model( + self.model_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + speculation_config=self.speculation_config) + else: + raise NotImplementedError( + "Supports only Transformer-NeuronX based models.") + + @torch.inference_mode() + def execute_model( + self, + model_input: ModelInputForNeuron, + kv_caches: Optional[List[torch.Tensor]] = None, + intermediate_tensors: Optional[IntermediateTensors] = None, + num_steps: int = 1, + ) -> Optional[List[SamplerOutput]]: + logits = self.model( + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + input_block_ids=model_input.input_block_ids, + **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {}, + device=self.device), + ) + + output = self.model.sample( + logits=logits, + sampling_metadata=model_input.sampling_metadata, + ) + return output diff --git a/vllm/worker/multi_step_neuronx_distributed_model_runner.py b/vllm/worker/multi_step_neuronx_distributed_model_runner.py new file mode 100644 index 00000000000..b6a3492a493 --- /dev/null +++ b/vllm/worker/multi_step_neuronx_distributed_model_runner.py @@ -0,0 +1,60 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import List, Optional + +import torch + +from vllm.config import VllmConfig +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.multimodal import MultiModalKwargs +from vllm.sequence import IntermediateTensors +from vllm.worker.neuronx_distributed_model_runner import ( + NeuronxDistributedModelRunner) + + +class MultiStepNeuronxDistributedModelRunner(NeuronxDistributedModelRunner): + """A model runner for multi-step decoding using the + neuronx-distributed-inference framework""" + + def __init__( + self, + vllm_config: VllmConfig, + ): + super().__init__(vllm_config) + + def load_model(self) -> None: + from vllm.model_executor.model_loader.neuronx_distributed import ( + get_neuron_speculation_model) + self.model = get_neuron_speculation_model( + self.model_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + speculation_config=self.speculative_config) + + @torch.inference_mode() + def execute_model( + self, + model_input, + kv_caches: Optional[List[torch.Tensor]] = None, + intermediate_tensors: Optional[IntermediateTensors] = None, + num_steps: int = 1, + ) -> Optional[List[SamplerOutput]]: + sampling_params = torch.tensor([[ + seq_group.sampling_params.top_k, + seq_group.sampling_params.top_p, + seq_group.sampling_params.temperature, + ] for seq_group in model_input.sampling_metadata.seq_groups]) + + logits = self.model( + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + input_block_ids=model_input.input_block_ids, + sampling_params=sampling_params, + **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {}, + device=self.device), + ) + + output = self.model.sample( + logits=logits, + sampling_metadata=model_input.sampling_metadata, + ) + return output diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py index e046ebc449d..c80b69e78dc 100644 --- a/vllm/worker/neuron_model_runner.py +++ b/vllm/worker/neuron_model_runner.py @@ -2,20 +2,20 @@ import os from dataclasses import dataclass -from importlib.util import find_spec from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union import torch from torch import nn -from transformers_neuronx.config import GenerationConfig -from vllm.config import VllmConfig -from vllm.forward_context import set_forward_context +from vllm.config import DeviceConfig, VllmConfig from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader.neuron import get_neuron_model -from vllm.multimodal import BatchedTensorInputs, MultiModalKwargs +from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, + MultiModalKwargs) +from vllm.platforms import current_platform +from vllm.sampling_params import SamplingParams from vllm.sequence import IntermediateTensors, SequenceGroupMetadata from vllm.utils import is_pin_memory_available, make_tensor_with_pad from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase @@ -34,12 +34,18 @@ class ModelInputForNeuron(ModelRunnerInputBase): input_tokens: Optional[torch.Tensor] = None input_positions: Optional[torch.Tensor] = None input_block_ids: Optional[torch.Tensor] = None - sampling_metadata: Optional["SamplingMetadata"] = None - multi_modal_kwargs: Optional[BatchedTensorInputs] = None + sampling_metadata: SamplingMetadata = None + multi_modal_kwargs: BatchedTensorInputs = None def as_broadcastable_tensor_dict( self) -> Dict[str, Union[int, torch.Tensor]]: - raise NotImplementedError("ModelInputForNeuron cannot be broadcast.") + return { + "input_tokens": self.input_tokens, + "input_positions": self.input_positions, + "input_block_ids": self.input_block_ids, + "sampling_metadata": self.sampling_metadata, + "multi_modal_kwargs": self.multi_modal_kwargs, + } @classmethod def from_broadcasted_tensor_dict( @@ -47,11 +53,17 @@ def from_broadcasted_tensor_dict( tensor_dict: Dict[str, Any], attn_backend: Optional["AttentionBackend"] = None, ) -> "ModelInputForNeuron": - assert attn_backend is None - return cls.from_broadcasted_tensor_dict(tensor_dict) + return ModelInputForNeuron( + input_tokens=tensor_dict["input_tokens"], + input_positions=tensor_dict["input_positions"], + input_block_ids=tensor_dict["input_block_ids"], + sampling_metadata=tensor_dict["sampling_metadata"], + multi_modal_kwargs=tensor_dict["multi_modal_kwargs"], + ) class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]): + """A model runner for AWS Neuron hardware""" # NEURON has an upper limit on the top_k _MAX_NEURON_SAMPLING_TOP_K = 256 @@ -61,13 +73,20 @@ def __init__( vllm_config: VllmConfig, ): ModelRunnerBase.__init__(self, vllm_config) - model_config = self.model_config - if model_config is not None and model_config.get_sliding_window(): + + if (self.model_config is not None + and self.model_config.get_sliding_window()): logger.warning("Sliding window is not supported on Neuron. " "The model will run without sliding window.") + self.device_config = (self.device_config if self.device_config + is not None else DeviceConfig()) self.device = self.device_config.device self.pin_memory = is_pin_memory_available() + # Multi-modal data support + self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \ + .create_input_mapper(self.model_config) + # Lazy initialization. self.model: nn.Module # initialize after load_model. @@ -82,32 +101,33 @@ def __init__( self._previous_batch_request_ids: List[str] = [] if not self._on_device_sampling_disabled: - logger.warning( - "On-device sampling is turned on in Neuron by default, only " - "top_k, top_p, and temperature are current supported sampling " - "parameters. To turn off the on-device sampling, please set " - "the environment variable NEURON_ON_DEVICE_SAMPLING_DISABLED=1." - ) - self.model_config.neuron_sampling_params = GenerationConfig( - max_length=self.scheduler_config.max_model_len, - do_sample=True, - per_batch_line=True, - top_k=[self._MAX_NEURON_SAMPLING_TOP_K] \ - * self.scheduler_config.max_num_seqs, - top_p=[1.0] * self.scheduler_config.max_num_seqs, - temperature=[1.0] * self.scheduler_config.max_num_seqs, - dynamic=True, - global_top_k=self._MAX_NEURON_SAMPLING_TOP_K) + self._init_neuron_sampling() - def load_model(self) -> None: - if find_spec("transformers_neuronx") is not None: - self.model = get_neuron_model( - self.model_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config) + def _init_neuron_sampling(self) -> None: + if current_platform.use_transformers_neuronx(): + from transformers_neuronx.config import GenerationConfig else: - raise NotImplementedError( - "Supports only Transformer-NeuronX based models.") + from transformers import GenerationConfig + logger.warning( + "On-device sampling is turned on in Neuron by default, only " + "top_k, top_p, and temperature are current supported sampling " + "parameters. To turn off the on-device sampling, please set " + "the environment variable NEURON_ON_DEVICE_SAMPLING_DISABLED=1.") + self.model_config.neuron_sampling_params = GenerationConfig( + max_length=self.scheduler_config.max_model_len, + do_sample=True, + per_batch_line=True, + top_k=[self._MAX_NEURON_SAMPLING_TOP_K] \ + * self.scheduler_config.max_num_seqs, + top_p=[1.0] * self.scheduler_config.max_num_seqs, + temperature=[1.0] * self.scheduler_config.max_num_seqs, + dynamic=True, + global_top_k=self._MAX_NEURON_SAMPLING_TOP_K) + + def load_model(self) -> None: + self.model = get_neuron_model(self.model_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config) def get_model(self) -> nn.Module: return self.model @@ -240,6 +260,16 @@ def prepare_model_input( (input_tokens, input_positions, input_block_ids) = self._prepare_decode(seq_group_metadata_list) seq_lens = None + + if not self._on_device_sampling_disabled: + for seq_group_metadata in seq_group_metadata_list: + sampling_params = seq_group_metadata.sampling_params + top_k, top_p, temperature = ( + self._convert_to_neuron_sampling_params(sampling_params)) + sampling_params.top_k = top_k + sampling_params.top_p = top_p + sampling_params.temperature = temperature + sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, seq_lens, @@ -251,7 +281,8 @@ def prepare_model_input( self.pin_memory, generators=self.get_generators(finished_requests_ids)) - if not self._on_device_sampling_disabled: + if current_platform.use_transformers_neuronx( + ) and not self._on_device_sampling_disabled: # Once the request IDs are changed in current iteration, we will # update the on-device sampling parameters. current_batch_request_ids = [ @@ -259,7 +290,7 @@ def prepare_model_input( for seq_group_meta_data in seq_group_metadata_list ] if current_batch_request_ids != self._previous_batch_request_ids: - self._update_neuron_sampling_params(sampling_metadata) + self._update_neuron_sampling_params(seq_group_metadata_list) self._previous_batch_request_ids = current_batch_request_ids return ModelInputForNeuron(input_tokens=input_tokens, @@ -268,31 +299,59 @@ def prepare_model_input( sampling_metadata=sampling_metadata, multi_modal_kwargs=multi_modal_kwargs) - def _update_neuron_sampling_params(self, - sampling_metadata: SamplingMetadata): + def _update_neuron_sampling_params( + self, seq_group_metadata_list: List[SequenceGroupMetadata]): # Update Neuron sampling parameters (GenerationConfig in Neuron) current_sampling_params = self.model_config.neuron_sampling_params assert current_sampling_params is not None, ( f"Failed to update sampling_params, " f"current sampling params is {current_sampling_params}") + is_update_needed = False + top_k = current_sampling_params.top_k top_p = current_sampling_params.top_p temperature = current_sampling_params.temperature - for index, sequence_group_to_sample in enumerate( - sampling_metadata.seq_groups): - top_k[index] = self._convert_to_neuron_top_k( - sequence_group_to_sample.sampling_params.top_k) - top_p[index] = sequence_group_to_sample.sampling_params.top_p - temperature[index] = \ - sequence_group_to_sample.sampling_params.temperature - self.model.model.update_generation_config(current_sampling_params) + # The index of a sequence's sampling parameters in neuron is equal to + # its index in `input_block_ids`. + for seq_group_metadata in seq_group_metadata_list: + seq_ids = list(seq_group_metadata.seq_data.keys()) + sampling_params = seq_group_metadata.sampling_params + + seq_group_top_k = sampling_params.top_k + seq_group_top_p = sampling_params.top_p + seq_group_temperature = sampling_params.temperature - def _convert_to_neuron_top_k(self, top_k: int) -> int: + for seq_id in seq_ids: + index = seq_group_metadata.block_tables[seq_id][0] + if (top_k[index] != seq_group_top_k + or top_p[index] != seq_group_top_p + or temperature[index] != seq_group_temperature): + is_update_needed = True + + top_k[index] = seq_group_top_k + top_p[index] = seq_group_top_p + temperature[index] = seq_group_temperature + + # update_generation_config is only available in transformers-neuronx + if is_update_needed and current_platform.use_transformers_neuronx(): + self.model.model.update_generation_config(current_sampling_params) + + def _convert_to_neuron_sampling_params( + self, sampling_params: SamplingParams) -> Tuple[int, float, float]: + # Returns the top_k, top_p and temperature parameters for neuron. + top_k = sampling_params.top_k + top_p = sampling_params.top_p + temperature = sampling_params.temperature + + if temperature == 0.0: + # Enable greedy sampling on zero temperature + return (1, 1.0, 1.0) if top_k < 0 or top_k > self._MAX_NEURON_SAMPLING_TOP_K: - return self._MAX_NEURON_SAMPLING_TOP_K - return top_k + top_k = self._MAX_NEURON_SAMPLING_TOP_K + + return (top_k, top_p, temperature) @torch.inference_mode() def execute_model( @@ -306,7 +365,26 @@ def execute_model( raise ValueError( "NeuronModelRunner does not support multi-step execution.") - with set_forward_context(None, self.vllm_config, 0): + # extract top_k, top_p and temperature from model_input for neuron + # forward call + sampling_params = (torch.tensor([[ + seq_group.sampling_params.top_k, seq_group.sampling_params.top_p, + seq_group.sampling_params.temperature + ] for seq_group in model_input.sampling_metadata.seq_groups])) + + if current_platform.use_neuronx_distributed(): + hidden_states = self.model( + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + input_block_ids=model_input.input_block_ids, + sampling_params=sampling_params, + **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs + or {}, + device=self.device), + ) + elif current_platform.use_transformers_neuronx(): + # [TODO] validate on-device sampling + # The model signature may need change for on-device sampling hidden_states = self.model( input_ids=model_input.input_tokens, positions=model_input.input_positions, diff --git a/vllm/worker/neuron_worker.py b/vllm/worker/neuron_worker.py index df651e05a7b..aa8e39613ee 100644 --- a/vllm/worker/neuron_worker.py +++ b/vllm/worker/neuron_worker.py @@ -1,61 +1,81 @@ # SPDX-License-Identifier: Apache-2.0 """A Neuron worker class.""" +import os from typing import List, Optional, Tuple -import torch import torch.distributed from vllm.config import VllmConfig from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment) +from vllm.logger import init_logger from vllm.model_executor import set_random_seed -from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.platforms import current_platform +from vllm.platforms.neuron import NeuronFramework from vllm.sequence import ExecuteModelRequest from vllm.worker.neuron_model_runner import NeuronModelRunner from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, LoRANotSupportedWorkerBase, WorkerBase, WorkerInput) +logger = init_logger(__name__) + class NeuronWorker(LoRANotSupportedWorkerBase, LocalOrDistributedWorkerBase): """A worker class that executes the model on a group of neuron cores. """ - def __init__( - self, - vllm_config: VllmConfig, - local_rank: int, - rank: int, - distributed_init_method: str, - is_driver_worker: bool = True, - ) -> None: + model_runner: NeuronModelRunner + + def __init__(self, + vllm_config: VllmConfig, + local_rank: int, + rank: int, + distributed_init_method: str, + is_driver_worker: bool = False) -> None: WorkerBase.__init__(self, vllm_config=vllm_config) self.local_rank = local_rank self.rank = rank self.distributed_init_method = distributed_init_method + self.is_driver_worker = is_driver_worker + if self.model_config.trust_remote_code: # note: lazy import to avoid importing torch before initializing from vllm.utils import init_cached_hf_modules init_cached_hf_modules() - self.model_runner: NeuronModelRunner = NeuronModelRunner( - vllm_config=vllm_config) - self.is_driver_worker = is_driver_worker - - def execute_model( - self, - execute_model_req: Optional[ExecuteModelRequest] = None, - ) -> Optional[List[SamplerOutput]]: - assert execute_model_req is not None - assert (not execute_model_req.blocks_to_swap_in - and not execute_model_req.blocks_to_swap_out - and not execute_model_req.blocks_to_copy), ( - "Cache operations are not supported for Neuron backend.") - assert execute_model_req.num_lookahead_slots == 0, ( - "lookahead not supported for Neuron backend.") - output = LocalOrDistributedWorkerBase.execute_model( - self, execute_model_req) - return output + neuron_framework = current_platform.get_neuron_framework_to_use() + if neuron_framework == NeuronFramework.TRANSFORMERS_NEURONX: + self.model_runner = self.get_tnx_model_runner(vllm_config) + elif neuron_framework == NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE: + self.model_runner = self.get_neuronx_distributed_model_runner( + vllm_config) + else: + raise NotImplementedError( + "Specified framework" + + f" {os.environ.get('VLLM_NEURON_FRAMEWORK')}" + + " is either not installed or not supported." + + " Supported frameworks: " + + "[transformers-neuronx, neuronx-distributed-inference]") + + def get_tnx_model_runner(self, vllm_config): + from vllm.worker.multi_step_neuron_model_runner import ( + MultiStepNeuronModelRunner) + if self.speculative_config is not None: + return MultiStepNeuronModelRunner(vllm_config=vllm_config) + else: + return NeuronModelRunner(vllm_config=vllm_config) + + def get_neuronx_distributed_model_runner(self, vllm_config): + from vllm.worker.multi_step_neuronx_distributed_model_runner import ( + MultiStepNeuronxDistributedModelRunner) + from vllm.worker.neuronx_distributed_model_runner import ( + NeuronxDistributedModelRunner) + if self.speculative_config is not None: + return MultiStepNeuronxDistributedModelRunner( + vllm_config=vllm_config) + else: + return NeuronxDistributedModelRunner(vllm_config=vllm_config) def init_device(self) -> None: self.init_distributed_environment() @@ -121,17 +141,17 @@ def get_cache_block_size_bytes(self) -> int: def init_distributed_environment(self): """Neuron uses transformers-neuronx for tensor parallelism. - It has only one process to control multiple devices. - vLLM still needs the environment initialized when TP/PP > 1, - so we initialize a distributed environment with one process. + + vLLM still needs the environment initialized when TP/PP > 1 """ init_distributed_environment( world_size=1, - rank=0, - local_rank=0, + rank=self.rank, + local_rank=self.local_rank, distributed_init_method=self.distributed_init_method, backend="gloo", ) + ensure_model_parallel_initialized( 1, 1, diff --git a/vllm/worker/neuronx_distributed_model_runner.py b/vllm/worker/neuronx_distributed_model_runner.py new file mode 100644 index 00000000000..4e784e5e030 --- /dev/null +++ b/vllm/worker/neuronx_distributed_model_runner.py @@ -0,0 +1,136 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import List, Optional + +import torch +from neuronx_distributed_inference.modules.generation.sampling import ( + prepare_sampling_params) + +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.model_loader.neuronx_distributed import ( + _get_model_architecture, get_neuron_model) +from vllm.sequence import IntermediateTensors +from vllm.worker.neuron_model_runner import (ModelInputForNeuron, + NeuronModelRunner) + +logger = init_logger(__name__) + + +class NeuronxDistributedModelRunner(NeuronModelRunner): + + def __init__( + self, + vllm_config: VllmConfig, + ): + super().__init__(vllm_config) + + def load_model(self) -> None: + self.model = get_neuron_model(self.model_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config) + + def get_nxd_sampling_params(self, sampling_metadata): + if self.model.config.neuron_config.on_device_sampling_config: + max_topk = (self.model.config.neuron_config. + on_device_sampling_config.global_topk) + else: + max_topk = self.model.config.vocab_size + + top_k = [1] * self.scheduler_config.max_num_seqs + top_p = [1.0] * self.scheduler_config.max_num_seqs + temperature = [1.0] * self.scheduler_config.max_num_seqs + + for index, sequenceGroupToSample in enumerate( + sampling_metadata.seq_groups): + top_k[index] = (sequenceGroupToSample.sampling_params.top_k + if sequenceGroupToSample.sampling_params.top_k > 0 + else max_topk) + top_p[index] = sequenceGroupToSample.sampling_params.top_p + temperature[index] = ( + sequenceGroupToSample.sampling_params.temperature) + + sampling_params = prepare_sampling_params( + batch_size=self.scheduler_config.max_num_seqs, + top_k=top_k, + top_p=top_p, + temperature=temperature) + return sampling_params + + def get_multi_modal_data_neuron(self, input_images): + raise NotImplementedError("need to restore multi-modal support") + + @torch.inference_mode() + def execute_model( + self, + model_input: ModelInputForNeuron, + kv_caches: Optional[List[torch.Tensor]] = None, + intermediate_tensors: Optional[IntermediateTensors] = None, + num_steps: int = 1, + ) -> Optional[List[SamplerOutput]]: + if num_steps > 1: + raise ValueError( + "NeuronModelRunner does not support multi-step execution.") + + if _get_model_architecture( + self.model.config) != "MllamaForConditionalGeneration": + return super().execute_model(model_input, kv_caches, + intermediate_tensors, num_steps) + + sampling_params = self.get_nxd_sampling_params( + model_input.sampling_metadata) + + if model_input.multi_modal_kwargs.get('image') is not None: + pixel_values = [] + aspect_ratios = [] + num_chunks = [] + has_image = [] + for multi_modal_input in model_input.multi_modal_kwargs.get( + 'image'): + image_tensors = self.get_multi_modal_data_neuron( + multi_modal_input.squeeze(0)) + pixel_values.append(image_tensors[0]) + aspect_ratios.append(image_tensors[1]) + num_chunks.append(image_tensors[2]) + has_image.append(image_tensors[3]) + + pixel_values = torch.cat(pixel_values, dim=0) + aspect_ratios = torch.cat(aspect_ratios, dim=0) + num_chunks = torch.cat(num_chunks, dim=0) + has_image = torch.cat(has_image, dim=0) + + hidden_states = self.model( + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + seq_ids=model_input.input_block_ids, + pixel_values=pixel_values, + aspect_ratios=aspect_ratios, + sampling_params=sampling_params, + num_chunks=num_chunks, + has_image=has_image, + ) + else: + empty_pixel_values = torch.zeros([1, 1, 4, 3, 560, 560], + dtype=torch.bfloat16) + empty_aspect_ratios = torch.ones([1, 1, 2], dtype=torch.int64) + num_chunks = torch.tensor([[1] + ]) # dummy num_chunks, will not be used + has_image = torch.tensor([0]) + hidden_states = self.model( + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + seq_ids=model_input.input_block_ids, + pixel_values=empty_pixel_values, + aspect_ratios=empty_aspect_ratios, + sampling_params=sampling_params, + num_chunks=num_chunks, + has_image=has_image, + ) + + output = self.model.sample( + hidden_states=hidden_states, + sampling_metadata=model_input.sampling_metadata, + ) + + return [output]