diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index b964272bf..9087a2847 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1 +1 @@ -* @sfc-gh-aqiao @sfc-gh-jrasley @sfc-gh-mhidayetoglu @sfc-gh-yewang @sfc-gh-goliaro +* @sfc-gh-aqiao @sfc-gh-jrasley @sfc-gh-mhidayetoglu @sfc-gh-yewang @sfc-gh-goliaro @sfc-gh-reyazda diff --git a/README.md b/README.md index 5cf6734de..81499aa5c 100644 --- a/README.md +++ b/README.md @@ -36,10 +36,10 @@ Arctic Inference achieves high throughput and low latency through a wholistic se - Arctic Ulysses (blog, - paper) + Arctic Ulysses (blog)
- Shift Parallelism (blog) + Shift Parallelism (blog, + paper) Arctic Speculator (blog) @@ -105,7 +105,7 @@ By using the examples below, you can get benefits from Shift Parallelism, Specul #### Serving ```console -vllm serve Snowflake/Llama-3.1-SwiftKV-8B-Instruct \ +ARCTIC_INFERENCE_ENABLED=1 vllm serve Snowflake/Llama-3.1-SwiftKV-8B-Instruct \ --quantization "fp8" \ --tensor-parallel-size 1 \ --ulysses-sequence-parallel-size 2 \ @@ -121,6 +121,8 @@ vllm serve Snowflake/Llama-3.1-SwiftKV-8B-Instruct \ #### Offline +Save the following script to `arctic_example.py`: + ```python import vllm from vllm import LLM, SamplingParams @@ -156,6 +158,12 @@ outputs = llm.chat(conversation, sampling_params=sampling_params) print(outputs[0].outputs[0].text) ``` +Run the script with Arctic Inference enabled: + +```console +ARCTIC_INFERENCE_ENABLED=1 python arctic_example.py +``` + ## Citation ``` @misc{arcticinference2025, diff --git a/arctic_inference/envs.py b/arctic_inference/envs.py index 02056a158..6b3da367a 100644 --- a/arctic_inference/envs.py +++ b/arctic_inference/envs.py @@ -20,10 +20,18 @@ ARCTIC_INFERENCE_SKIP_SPEC_MODEL_CHECK: bool = False environment_variables: dict[str, Callable[[], Any]] = { + "ARCTIC_INFERENCE_ENABLED": + lambda: os.getenv("ARCTIC_INFERENCE_ENABLED", "0") == "1", + "ARCTIC_INFERENCE_SKIP_PLATFORM_CHECK": + lambda: os.getenv("ARCTIC_INFERENCE_SKIP_PLATFORM_CHECK", "0") == "1", "ARCTIC_INFERENCE_SKIP_SPEC_MODEL_CHECK": lambda: os.getenv("ARCTIC_INFERENCE_SKIP_SPEC_MODEL_CHECK", "0") == "1", + "ARCTIC_INFERENCE_SKIP_VERSION_CHECK": + lambda: os.getenv("ARCTIC_INFERENCE_SKIP_VERSION_CHECK", "0") == "1", } +# temporary workaround for gpt-oss model +ARCTIC_INFERENCE_SKIP_SPEC_MODEL_CHECK = 1 def __getattr__(name: str) -> Any: if name in environment_variables: diff --git a/arctic_inference/op_builder/__init__.py b/arctic_inference/op_builder/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/arctic_inference/op_builder/builder.py 
def get_default_compute_capabilities():
    """
    Return the default ``;``-separated list of CUDA compute capabilities.

    The list starts from ``DEFAULT_COMPUTE_CAPABILITIES``. When a CUDA toolkit
    is available (``CUDA_HOME`` is not ``None``), the newer architectures that
    the installed toolkit version can target are appended.

    Returns:
        str: compute capabilities such as ``"6.0;6.1;7.0;8.0;8.6;9.0"``.
    """
    compute_caps = DEFAULT_COMPUTE_CAPABILITIES
    # Update compute capability according to: https://en.wikipedia.org/wiki/CUDA#GPUs_supported
    import torch.utils.cpp_extension
    if torch.utils.cpp_extension.CUDA_HOME is None:
        return compute_caps

    # Query the toolkit version once: every installed_cuda_version() call
    # spawns an `nvcc -V` subprocess.
    cuda_major, cuda_minor = installed_cuda_version()
    if cuda_major == 11:
        # Every CUDA 11.x toolkit supports sm_80.
        compute_caps += ";8.0"
        if cuda_minor >= 1:
            compute_caps += ";8.6"
        if cuda_minor >= 8:
            compute_caps += ";9.0"
    elif cuda_major == 12:
        compute_caps += ";8.0;8.6;9.0"
        if cuda_minor >= 8:
            compute_caps += ";10.0;12.0"
    return compute_caps
+ ) + return True + raise CUDAMismatchException( + f">- DeepSpeed Op Builder: Installed CUDA version {sys_cuda_version} does not match the " + f"version torch was compiled with {torch.version.cuda}, unable to compile " + "cuda/cpp extensions without a matching cuda version.") + return True + + +class OpBuilder(ABC): + _loaded_ops = {} + + def __init__(self, name): + self.name = name + self.jit_mode = False + self.enable_bf16 = False + self.error_log = None + + @abstractmethod + def absolute_name(self): + ''' + Returns absolute build path for cases where the op is pre-installed will be installed. + ''' + pass + + @abstractmethod + def sources(self): + ''' + Returns list of source files for your op, relative to root of deepspeed package (i.e., DeepSpeed/deepspeed) + ''' + pass + + @staticmethod + def validate_torch_version(torch_info): + install_torch_version = torch_info['version'] + current_torch_version = ".".join(torch.__version__.split('.')[:2]) + if install_torch_version != current_torch_version: + raise RuntimeError("PyTorch version mismatch! DeepSpeed ops were compiled and installed " + "with a different version than what is being used at runtime. " + f"Please re-install DeepSpeed or switch torch versions. " + f"Install torch version={install_torch_version}, " + f"Runtime torch version={current_torch_version}") + + @staticmethod + def validate_torch_op_version(torch_info): + + current_hip_version = ".".join(torch.version.hip.split('.')[:2]) + install_hip_version = torch_info['hip_version'] + if install_hip_version != current_hip_version: + raise RuntimeError("HIP version mismatch! DeepSpeed ops were compiled and installed " + "with a different version than what is being used at runtime. " + f"Please re-install DeepSpeed or switch torch versions. 
" + f"Install HIP version={install_hip_version}, " + f"Runtime HIP version={current_hip_version}") + + def include_paths(self): + ''' + Returns list of include paths, relative to root of deepspeed package (i.e., DeepSpeed/deepspeed) + ''' + return [] + + def nvcc_args(self): + ''' + Returns optional list of compiler flags to forward to nvcc when building CUDA sources + ''' + return [] + + def cxx_args(self): + ''' + Returns optional list of compiler flags to forward to the build + ''' + return [] + + def is_compatible(self, verbose=False): + ''' + Check if all non-python dependencies are satisfied to build this op + ''' + return True + + def extra_ldflags(self): + return [] + + def has_function(self, funcname, libraries, library_dirs=None, verbose=False): + ''' + Test for existence of a function within a tuple of libraries. + + This is used as a smoke test to check whether a certain library is available. + As a test, this creates a simple C program that calls the specified function, + and then distutils is used to compile that program and link it with the specified libraries. + Returns True if both the compile and link are successful, False otherwise. + ''' + tempdir = None # we create a temporary directory to hold various files + filestderr = None # handle to open file to which we redirect stderr + oldstderr = None # file descriptor for stderr + try: + # Echo compile and link commands that are used. + if verbose: + distutils.log.set_verbosity(1) + + # Create a compiler object. + compiler = distutils.ccompiler.new_compiler(verbose=verbose) + + # Configure compiler and linker to build according to Python install. + distutils.sysconfig.customize_compiler(compiler) + + # Create a temporary directory to hold test files. + tempdir = tempfile.mkdtemp() + + # Define a simple C program that calls the function in question + prog = "void %s(void); int main(int argc, char** argv) { %s(); return 0; }" % (funcname, funcname) + + # Write the test program to a file. 
+ filename = os.path.join(tempdir, 'test.c') + with open(filename, 'w') as f: + f.write(prog) + + # Redirect stderr file descriptor to a file to silence compile/link warnings. + if not verbose: + filestderr = open(os.path.join(tempdir, 'stderr.txt'), 'w') + oldstderr = os.dup(sys.stderr.fileno()) + os.dup2(filestderr.fileno(), sys.stderr.fileno()) + + # Workaround for behavior in distutils.ccompiler.CCompiler.object_filenames() + # Otherwise, a local directory will be used instead of tempdir + drive, driveless_filename = os.path.splitdrive(filename) + root_dir = driveless_filename[0] if os.path.isabs(driveless_filename) else '' + output_dir = os.path.join(drive, root_dir) + + # Attempt to compile the C program into an object file. + cflags = shlex.split(os.environ.get('CFLAGS', "")) + objs = compiler.compile([filename], output_dir=output_dir, extra_preargs=self.strip_empty_entries(cflags)) + + # Attempt to link the object file into an executable. + # Be sure to tack on any libraries that have been specified. + ldflags = shlex.split(os.environ.get('LDFLAGS', "")) + compiler.link_executable(objs, + os.path.join(tempdir, 'a.out'), + extra_preargs=self.strip_empty_entries(ldflags), + libraries=libraries, + library_dirs=library_dirs) + + # Compile and link succeeded + return True + + except CompileError: + return False + + except LinkError: + return False + + except: + return False + + finally: + # Restore stderr file descriptor and close the stderr redirect file. + if oldstderr is not None: + os.dup2(oldstderr, sys.stderr.fileno()) + if filestderr is not None: + filestderr.close() + + # Delete the temporary directory holding the test program and stderr files. 
+ if tempdir is not None: + shutil.rmtree(tempdir) + + def strip_empty_entries(self, args): + ''' + Drop any empty strings from the list of compile and link flags + ''' + return [x for x in args if len(x) > 0] + + def get_cuda_compile_flag(self): + try: + assert_no_cuda_mismatch(self.name) + return "-D__ENABLE_CUDA__" + except MissingCUDAException: + print(f"{WARNING} {self.name} cuda is missing or is incompatible with installed torch, " + "only cpu ops can be compiled!") + return '-D__DISABLE_CUDA__' + + def command_exists(self, cmd): + if '|' in cmd: + cmds = cmd.split("|") + else: + cmds = [cmd] + valid = False + for cmd in cmds: + safe_cmd = ["bash", "-c", f"type {cmd}"] + result = subprocess.Popen(safe_cmd, stdout=subprocess.PIPE) + valid = valid or result.wait() == 0 + + if not valid and len(cmds) > 1: + print(f"{WARNING} {self.name} requires one of the following commands '{cmds}', but it does not exist!") + elif not valid and len(cmds) == 1: + print(f"{WARNING} {self.name} requires the '{cmd}' command, but it does not exist!") + return valid + + def warning(self, msg): + self.error_log = f"{msg}" + print(f"{WARNING} {msg}") + + def _src_path(self, code_path): + if os.path.isabs(code_path): + return code_path + else: + return os.path.join(Path(__file__).parent.parent.absolute(), code_path) + + def builder(self): + from torch.utils.cpp_extension import CppExtension + include_dirs = [os.path.abspath(x) for x in self.strip_empty_entries(self.include_paths())] + return CppExtension(name=self.absolute_name(), + sources=self.strip_empty_entries(self.sources()), + include_dirs=include_dirs, + extra_compile_args={'cxx': self.strip_empty_entries(self.cxx_args())}, + extra_link_args=self.strip_empty_entries(self.extra_ldflags())) + + def load(self, verbose=False): + if self.name in __class__._loaded_ops: + return __class__._loaded_ops[self.name] + + from deepspeed.git_version_info import installed_ops, torch_info, accelerator_name + from deepspeed.accelerator import 
get_accelerator + if installed_ops.get(self.name, False) and accelerator_name == get_accelerator()._name: + # Ensure the op we're about to load was compiled with the same + # torch/cuda versions we are currently using at runtime. + self.validate_torch_version(torch_info) + if torch.cuda.is_available() and isinstance(self, CUDAOpBuilder): + self.validate_torch_op_version(torch_info) + + op_module = importlib.import_module(self.absolute_name()) + __class__._loaded_ops[self.name] = op_module + return op_module + else: + return self.jit_load(verbose) + + def jit_load(self, verbose=True): + if not self.is_compatible(verbose): + raise RuntimeError( + f"Unable to JIT load the {self.name} op due to it not being compatible due to hardware/software issue. {self.error_log}" + ) + try: + import ninja # noqa: F401 # type: ignore + except ImportError: + raise RuntimeError(f"Unable to JIT load the {self.name} op due to ninja not being installed.") + + self.jit_mode = True + from torch.utils.cpp_extension import load + + start_build = time.time() + sources = [os.path.abspath(self._src_path(path)) for path in self.sources()] + extra_include_paths = [os.path.abspath(self._src_path(path)) for path in self.include_paths()] + + # Torch will try and apply whatever CCs are in the arch list at compile time, + # we have already set the intended targets ourselves we know that will be + # needed at runtime. This prevents CC collisions such as multiple __half + # implementations. Stash arch list to reset after build. 
+ torch_arch_list = None + if "TORCH_CUDA_ARCH_LIST" in os.environ: + torch_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST") + os.environ["TORCH_CUDA_ARCH_LIST"] = "" + + nvcc_args = self.strip_empty_entries(self.nvcc_args()) + cxx_args = self.strip_empty_entries(self.cxx_args()) + + cxx_args.append("-UC10_USE_GLOG") + nvcc_args.append("-UC10_USE_GLOG") + if isinstance(self, CUDAOpBuilder): + if self.enable_bf16: + cxx_args.append("-DBF16_AVAILABLE") + nvcc_args.append("-DBF16_AVAILABLE") + nvcc_args.append("-U__CUDA_NO_BFLOAT16_OPERATORS__") + nvcc_args.append("-U__CUDA_NO_BFLOAT162_OPERATORS__") + nvcc_args.append("-U__CUDA_NO_BFLOAT16_CONVERSIONS__") + + op_module = load(name=self.name, + sources=self.strip_empty_entries(sources), + extra_include_paths=self.strip_empty_entries(extra_include_paths), + extra_cflags=cxx_args, + extra_cuda_cflags=nvcc_args, + extra_ldflags=self.strip_empty_entries(self.extra_ldflags()), + with_cuda=True if (isinstance(self, CUDAOpBuilder)) else None, + verbose=verbose) + + build_duration = time.time() - start_build + if verbose: + print(f"Time to load {self.name} op: {build_duration} seconds") + + # Reset arch list so we are not silently removing it for other possible use cases + if torch_arch_list: + os.environ["TORCH_CUDA_ARCH_LIST"] = torch_arch_list + + __class__._loaded_ops[self.name] = op_module + + return op_module + + +class CUDAOpBuilder(OpBuilder): + + def compute_capability_args(self, cross_compile_archs=None): + """ + Returns nvcc compute capability compile flags. + + 1. `TORCH_CUDA_ARCH_LIST` takes priority over `cross_compile_archs`. + 2. If neither is set default compute capabilities will be used + 3. Under `jit_mode` compute capabilities of all visible cards will be used plus PTX + + Format: + + - `TORCH_CUDA_ARCH_LIST` may use ; or whitespace separators. Examples: + + TORCH_CUDA_ARCH_LIST="6.1;7.5;8.6;9.0;10.0" pip install ... + TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0 7.5 8.0 8.6 9.0 10.0+PTX" pip install ... 
+ + - `cross_compile_archs` uses ; separator. + + """ + ccs = [] + if self.jit_mode: + # Compile for underlying architectures since we know those at runtime + for i in range(torch.cuda.device_count()): + CC_MAJOR, CC_MINOR = torch.cuda.get_device_capability(i) + cc = f"{CC_MAJOR}.{CC_MINOR}" + if cc not in ccs: + ccs.append(cc) + ccs = sorted(ccs) + ccs[-1] += '+PTX' + else: + # Cross-compile mode, compile for various architectures + # env override takes priority + cross_compile_archs_env = os.environ.get('TORCH_CUDA_ARCH_LIST', None) + if cross_compile_archs_env is not None: + if cross_compile_archs is not None: + print( + f"{WARNING} env var TORCH_CUDA_ARCH_LIST={cross_compile_archs_env} overrides cross_compile_archs={cross_compile_archs}" + ) + cross_compile_archs = cross_compile_archs_env.replace(' ', ';') + else: + if cross_compile_archs is None: + cross_compile_archs = get_default_compute_capabilities() + ccs = cross_compile_archs.split(';') + + ccs = self.filter_ccs(ccs) + if len(ccs) == 0: + raise RuntimeError( + f"Unable to load {self.name} op due to no compute capabilities remaining after filtering") + + args = [] + self.enable_bf16 = True + for cc in ccs: + num = cc[0] + cc[1].split('+')[0] + args.append(f'-gencode=arch=compute_{num},code=sm_{num}') + if cc[1].endswith('+PTX'): + args.append(f'-gencode=arch=compute_{num},code=compute_{num}') + + if int(cc[0]) <= 7: + self.enable_bf16 = False + + return args + + def filter_ccs(self, ccs: List[str]): + """ + Prune any compute capabilities that are not compatible with the builder. Should log + which CCs have been pruned. 
+ """ + return [cc.split('.') for cc in ccs] + + def version_dependent_macros(self): + # Fix from apex that might be relevant for us as well, related to https://github.com/NVIDIA/apex/issues/456 + version_ge_1_1 = [] + if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 0): + version_ge_1_1 = ['-DVERSION_GE_1_1'] + version_ge_1_3 = [] + if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 2): + version_ge_1_3 = ['-DVERSION_GE_1_3'] + version_ge_1_5 = [] + if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 4): + version_ge_1_5 = ['-DVERSION_GE_1_5'] + return version_ge_1_1 + version_ge_1_3 + version_ge_1_5 + + def is_compatible(self, verbose=False): + return super().is_compatible(verbose) + + def builder(self): + from torch.utils.cpp_extension import CUDAExtension as ExtensionBuilder + include_dirs = [os.path.abspath(x) for x in self.strip_empty_entries(self.include_paths())] + compile_args = {'cxx': self.strip_empty_entries(self.cxx_args())} if self.build_for_cpu else \ + {'cxx': self.strip_empty_entries(self.cxx_args()), \ + 'nvcc': self.strip_empty_entries(self.nvcc_args())} + + if self.enable_bf16: + compile_args['cxx'].append("-DBF16_AVAILABLE") + compile_args['nvcc'].append("-DBF16_AVAILABLE") + + cuda_ext = ExtensionBuilder(name=self.absolute_name(), + sources=self.strip_empty_entries(self.sources()), + include_dirs=include_dirs, + libraries=self.strip_empty_entries(self.libraries_args()), + extra_compile_args=compile_args, + extra_link_args=self.strip_empty_entries(self.extra_ldflags())) + + return cuda_ext + + def cxx_args(self): + if sys.platform == "win32": + return ['-O2'] + else: + return ['-O3', '-std=c++17', '-g', '-Wno-reorder'] + + def nvcc_args(self): + args = ['-O3'] + try: + nvcc_threads = int(os.getenv("DS_NVCC_THREADS", "")) + if nvcc_threads <= 0: + raise ValueError("") + except ValueError: + nvcc_threads = min(os.cpu_count(), 8) + cuda_major, cuda_minor = installed_cuda_version() + if cuda_major > 10: + if cuda_major 
== 12 and cuda_minor >= 5: + std_lib = '-std=c++20' + else: + std_lib = '-std=c++17' + else: + std_lib = '-std=c++14' + args += [ + '-allow-unsupported-compiler' if sys.platform == "win32" else '', '--use_fast_math', std_lib, + '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', + f'--threads={nvcc_threads}' + ] + if os.environ.get('DS_DEBUG_CUDA_BUILD', '0') == '1': + args.append('--ptxas-options=-v') + args += self.compute_capability_args() + return args + + def libraries_args(self): + if sys.platform == "win32": + return ['cublas', 'curand'] + else: + return [] + diff --git a/arctic_inference/op_builder/swiftkv_ops_builder.py b/arctic_inference/op_builder/swiftkv_ops_builder.py new file mode 100644 index 000000000..77ebc8abe --- /dev/null +++ b/arctic_inference/op_builder/swiftkv_ops_builder.py @@ -0,0 +1,29 @@ +import os +from .builder import CUDAOpBuilder + +class SwiftKVOpsBuilder(CUDAOpBuilder): + def __init__(self): + super().__init__(name="reshape_and_cache_flash_bulk") + + def absolute_name(self): + return f'arctic_inference.swiftkv_ops.{self.name}' + + def get_prefix(self): + # borrowed from moe_op. refactor later + ai_path = self._src_path("arctic_inference") + return "arctic_inference" if os.path.isdir(ai_path) else ".." 
def try_load_jit_library() -> bool:
    """
    Try to load the custom ops through the `SwiftKVOpsBuilder` JIT builder.

    Returns:
        bool: True when `SwiftKVOpsBuilder().load()` completed without
        raising, False otherwise. Failures are logged and never propagated.
    """
    try:
        from arctic_inference.op_builder.swiftkv_ops_builder import SwiftKVOpsBuilder
        # Only the success of load() matters here; the returned module
        # object is not used.
        SwiftKVOpsBuilder().load()

        logger.info("Successfully loaded SwiftKVOpsBuilder JIT library.")
        return True
    except ImportError as e:
        logger.info(
            f"Unable to import SwiftKVOpsBuilder. ImportError: {e}. Falling back to original implementation."
        )
        return False
    except Exception as e:
        # Deliberately broad: any build/load failure falls back to the
        # original implementation.
        logger.info(
            f"Unable to load JIT library. Exception: {e}. Falling back to original implementation."
        )
        return False
b/arctic_inference/suffix_decoding/__init__.py @@ -13,6 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .suffix_cache import SuffixCache, SuffixSpecResult +from .cache import SuffixDecodingCache, SuffixDecodingDraft -__all__ = ["SuffixCache", "SuffixSpecResult"] +__all__ = ["SuffixDecodingCache", "SuffixDecodingDraft"] diff --git a/arctic_inference/common/suffix_cache/suffix_cache.py b/arctic_inference/suffix_decoding/cache.py similarity index 61% rename from arctic_inference/common/suffix_cache/suffix_cache.py rename to arctic_inference/suffix_decoding/cache.py index c06cf19e5..2c0cd32a0 100644 --- a/arctic_inference/common/suffix_cache/suffix_cache.py +++ b/arctic_inference/suffix_decoding/cache.py @@ -16,13 +16,15 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import Hashable, KeysView, List, Optional, Sequence, Union +from typing import Hashable, KeysView, List, Optional, Sequence -from arctic_inference.common.suffix_cache._C import SuffixTree, Candidate +import numpy as np + +from arctic_inference.suffix_decoding._C import SuffixTree, Draft @dataclass -class SuffixSpecResult: +class SuffixDecodingDraft: """ A dataclass representing the result of a speculation using SuffixDecoding. @@ -34,7 +36,7 @@ class SuffixSpecResult: probs (List[float]): List of estimated probabilities for each token. score (float): The overall score of the suffix match computed as the sum of the estimated probabilities of each speculated token. - match_len (int): The length of the pattern match that yielded this + match_len (int): The length of the context match that yielded this speculation result. 
""" token_ids: List[int] = field(default_factory=list) @@ -44,30 +46,33 @@ class SuffixSpecResult: match_len: int = 0 @staticmethod - def from_candidate(candidate: Candidate) -> SuffixSpecResult: - return SuffixSpecResult( - token_ids=candidate.token_ids, - parents=candidate.parents, - probs=candidate.probs, - score=candidate.score, - match_len=candidate.match_len, + def from_native(draft: Draft) -> SuffixDecodingDraft: + return SuffixDecodingDraft( + token_ids=draft.token_ids, + parents=draft.parents, + probs=draft.probs, + score=draft.score, + match_len=draft.match_len, ) -class SuffixCache: +class SuffixDecodingCache: def __init__(self, max_tree_depth: int = 64, - max_cached_requests: Optional[int] = None): + max_cached_requests: int = -1): """ - Initialize the SuffixCache. + Initialize the SuffixDecodingCache. Args: max_tree_depth (int): The maximum depth of the suffix trees. max_cached_requests (int, optional): The maximum number of cached - requests. Cache eviction is used when the limit is reached. If - `None`, there is no limit on the number of cached requests. + requests. Eviction is triggered when the limit is reached. `-1` + means no limit on the number of cached requests. """ + if max_cached_requests > 0x7FFFFFFF: + raise ValueError("max_cached_requests must be at most 2^31") + self._max_tree_depth = max_tree_depth self._max_cached_requests = max_cached_requests @@ -78,7 +83,7 @@ def __init__(self, self._local_trees = {} # Maps between Python request ID and int32_t sequence ID. Tracks all - # request IDs that are in the global tree or one of the local trees. + # request IDs that are in the global tree. 
    def start_request(
        self,
        req_id: Hashable,
        prompt_token_ids: np.ndarray | Sequence[int],
    ):
        """
        This method should be called when starting to process a new request. It
        will store the prompt for the request, allowing future speculations for
        the same request to use the prompt context. The prompt will be stored
        until `stop_request` is called. If `max_cached_requests != 0`, then a
        new slot is allocated in the global cache for the response, triggering
        cache eviction (FIFO order) if needed. If a response is already cached
        under the same `req_id`, it is evicted first and replaced.

        Args:
            req_id (Hashable): The request identifier. Must be a hashable value
                that uniquely identifies the request.
            prompt_token_ids (np.ndarray | Sequence[int]): A sequence of token
                IDs representing the prompt of the request.

        Raises:
            ValueError: If a request with the same `req_id` is already active.
        """
        if req_id in self._local_trees:
            raise ValueError(f"Request '{req_id}' is already active")

        if isinstance(prompt_token_ids, np.ndarray):
            # If input is a numpy array, use the zero-copy ndarray overload.
            self._validate_ndarray(prompt_token_ids)
            extend_func = SuffixTree.extend_ndarray
        else:
            extend_func = SuffixTree.extend

        # Build a fresh per-request tree holding only the prompt.
        self._local_trees[req_id] = SuffixTree(self._max_tree_depth)
        extend_func(self._local_trees[req_id], 0, prompt_token_ids)
        if self._max_cached_requests != 0:
            # Global cache is enabled.
            if req_id in self._req_to_seq_id:
                # Evict existing cached response for the request if present.
                self.evict_cached_response(req_id)
            # Allocate a new seq_id for the request.
            self._generate_seq_id(req_id)
+ if isinstance(token_ids, np.ndarray): + # If input is a numpy array, use the zero-copy ndarray overload. + self._validate_ndarray(token_ids) + extend_func = SuffixTree.extend_ndarray + else: + extend_func = SuffixTree.extend - Args: - req_id (Hashable): The unique identifier for the request. - token_ids (Sequence[int]): A sequence of token IDs to be inserted - as the response for the given request. + # Update the local tree for the active request. + extend_func(self._local_trees[req_id], 0, token_ids) - Raises: - ValueError: If a request with the same `req_id` is already active - or cached. - """ + # Also update the response if the request is in the global cache (it + # may be evicted from the global cache before the request is stopped). if req_id in self._req_to_seq_id: - raise ValueError(f"Request '{req_id}' is already active or cached") - seq_id = self._generate_seq_id(req_id) - self._global_tree.extend(seq_id, token_ids) + seq_id = self._req_to_seq_id[req_id] + extend_func(self._global_tree, seq_id, token_ids) - def evict_request(self, req_id: Hashable): + def evict_cached_response(self, req_id: Hashable): """ - Evicts the given request's prompt and response from the cache. If the - request is active, it becomes inactive. The `req_id` can then be reused - after eviction. + Evicts the given request's response from the global cache. `req_id` can + be safely reused for a new request after eviction. Args: req_id (Hashable): The unique identifier for the request that @@ -217,9 +227,7 @@ def evict_request(self, req_id: Hashable): ValueError: If no response exists for the given request identifier. 
""" if req_id not in self._req_to_seq_id: - raise ValueError(f"Request '{req_id}' is not active or cached") - if req_id in self._local_trees: - del self._local_trees[req_id] + raise ValueError(f"Request '{req_id}' is not cached") seq_id = self._req_to_seq_id.pop(req_id) self._seq_to_req_id.pop(seq_id) self._global_tree.remove(seq_id) @@ -227,29 +235,33 @@ def evict_request(self, req_id: Hashable): def speculate( self, req_id: Hashable, - pattern: Sequence[int], + context: np.ndarray | Sequence[int], max_spec_tokens: Optional[int] = None, max_spec_factor: float = 1.0, max_spec_offset: float = 0.0, min_token_prob: float = 0.1, use_tree_spec: bool = False, - ) -> SuffixSpecResult: + ) -> SuffixDecodingDraft: """ Speculates and returns the most likely continuation of a given token - pattern using the request's prompt and the global cache of previous + context using the request's prompt and the global cache of previous responses. This method can only be called for active requests (i.e. after calling `start_request` and before calling `stop_request`). Args: req_id (Hashable): The unique identifier for the request. - pattern (Sequence[int]): The sequence of token IDs to match and - continue from. + context (np.ndarray | Sequence[int]): A sequence of token IDs to + match and speculate subsequent tokens from. max_spec_tokens (int): Maximum number of tokens to speculate. If 0, uses the cache's max_depth. max_spec_factor (float): Factor that limits speculation based on - matched pattern length. + matched context length. The number of speculated tokens is + limited by `max_spec_factor * match_length + max_spec_offset`. + max_spec_offset (float): Offset that limits speculation based on + matched context length. The number of speculated tokens is + limited by `max_spec_factor * match_length + max_spec_offset`. min_token_prob (float): Minimum estimated probability threshold for - candidate tokens. + draft tokens. use_tree_spec (bool): If True, uses tree-based speculation. 
Returns: @@ -262,32 +274,40 @@ def speculate( if req_id not in self._local_trees: raise ValueError(f"Request '{req_id}' is not active") + if isinstance(context, np.ndarray): + # If input is a numpy array, use the zero-copy ndarray overload. + self._validate_ndarray(context) + spec_func = SuffixTree.speculate_ndarray + else: + spec_func = SuffixTree.speculate + if max_spec_tokens is None: max_spec_tokens = self.max_depth - if len(pattern) > self._max_tree_depth: - pattern = pattern[-self._max_tree_depth :] + if len(context) > self._max_tree_depth: + context = context[-self._max_tree_depth :] - candidate = self._local_trees[req_id].speculate( - pattern, + draft1 = spec_func( + self._local_trees[req_id], + context, max_spec_tokens, max_spec_factor, max_spec_offset, min_token_prob, use_tree_spec) - result = SuffixSpecResult.from_candidate(candidate) - candidate = self._global_tree.speculate( - pattern, + draft2 = spec_func( + self._global_tree, + context, max_spec_tokens, max_spec_factor, max_spec_offset, min_token_prob, use_tree_spec) - if candidate.score > result.score: - result = SuffixSpecResult.from_candidate(candidate) - return result + draft = draft1 if draft1.score >= draft2.score else draft2 + + return SuffixDecodingDraft.from_native(draft) def _generate_seq_id(self, req_id: Hashable) -> int: # Find the next available seq_id not used by an active request. @@ -313,16 +333,25 @@ def _generate_seq_id(self, req_id: Hashable) -> int: return seq_id def _maybe_evict_requests(self, new_seq_id: int): - if self._max_cached_requests is None: + if self._max_cached_requests < 0: + # Negative value means no global cache size limit. return + assert self._max_cached_requests != 0 # Global cache must be enabled. while len(self._req_to_seq_id) > self._max_cached_requests: - # Evict the first request that is not active. Should be FIFO order - # in python 3.7+ as dict preserves insertion order. We also want to - # avoid evicting the request that was just added (new_seq_id). 
+ # Evict the first eligible request. Should be FIFO order in Python + # 3.7+ since dict preserves insertion order. Avoid evicting the + # request that was just added (new_seq_id). for req_id, seq_id in self._req_to_seq_id.items(): - if seq_id != new_seq_id and req_id not in self._local_trees: - self.evict_request(req_id) + if seq_id != new_seq_id: + self.evict_cached_response(req_id) break - else: - # All previously cached requests are active, cannot evict any. - break + + def _validate_ndarray(self, arr: np.ndarray): + if arr.ndim != 1: + raise ValueError(f"ndarray input must have ndim=1, " + f"got ndim={arr.ndim}") + if arr.dtype != np.int32: + raise ValueError(f"ndarray input must have dtype=int32, " + f"got dtype={arr.dtype.name}") + if not arr.flags["CONTIGUOUS"]: + raise ValueError(f"ndarray input must be contiguous") diff --git a/arctic_inference/common/suffix_cache/simulator.py b/arctic_inference/suffix_decoding/simulator.py similarity index 93% rename from arctic_inference/common/suffix_cache/simulator.py rename to arctic_inference/suffix_decoding/simulator.py index 19a158c9f..0e70ef630 100644 --- a/arctic_inference/common/suffix_cache/simulator.py +++ b/arctic_inference/suffix_decoding/simulator.py @@ -22,20 +22,21 @@ from collections import OrderedDict from typing import Dict, List, Optional, Tuple +import numpy as np import pandas as pd from tqdm import tqdm from transformers import AutoTokenizer -from arctic_inference.common.suffix_cache import SuffixCache +from arctic_inference.suffix_decoding import SuffixDecodingCache os.environ["TOKENIZERS_PARALLELISM"] = "false" def suffix_decode( - suffix_cache: SuffixCache, + suffix_cache: SuffixDecodingCache, request_id: int, - prompt: List[int], - ground_truth_response: List[int], + prompt: np.ndarray, + ground_truth_response: np.ndarray, max_spec_tokens: int, max_spec_factor: float, min_token_prob: float, @@ -47,17 +48,21 @@ def suffix_decode( suffix_cache.start_request(request_id, prompt if 
use_cached_prompt else []) - assert isinstance(prompt, list) and isinstance(ground_truth_response, list) + assert isinstance(prompt, np.ndarray) + assert isinstance(ground_truth_response, np.ndarray) results = [] response = [] while len(response) < len(ground_truth_response): - text = prompt + response + if response: + sequence = np.concatenate([prompt, response], dtype=np.int32) + else: + sequence = prompt start_time = time.perf_counter() result = suffix_cache.speculate( request_id, - text, + sequence, max_spec_tokens=max_spec_tokens, max_spec_factor=max_spec_factor, min_token_prob=min_token_prob, @@ -105,7 +110,7 @@ def suffix_decode( "update_ms": update_time * 1000, }) - assert response == ground_truth_response + assert np.array_equal(response, ground_truth_response) suffix_cache.stop_request(request_id) @@ -159,7 +164,7 @@ def process_task( use_cached_prompt: bool, evict_fraction: float, evict_strategy: str, - max_cached_requests: Optional[int], + max_cached_requests: int, ) -> List[Dict]: eval_subset, train_subset = sample_data( dataset, @@ -168,8 +173,8 @@ def process_task( num_train, seed, ) - suffix_cache = SuffixCache(max_tree_depth=max_depth, - max_cached_requests=max_cached_requests) + suffix_cache = SuffixDecodingCache(max_tree_depth=max_depth, + max_cached_requests=max_cached_requests) train_request_ids = [] num_cached_tokens = {} # request_id -> num tokens for request_id, example in tqdm(train_subset.iterrows(), @@ -178,7 +183,9 @@ def process_task( # Use negative request_id to indicate training examples and avoid # conflicts with eval request_ids numbered 0, .., num_eval - 1. 
train_request_id = -1 - request_id - suffix_cache.insert_new_response(train_request_id, example["response"]) + suffix_cache.start_request(train_request_id, example["prompt"]) + suffix_cache.add_active_response(train_request_id, example["response"]) + suffix_cache.stop_request(train_request_id) train_request_ids.append(train_request_id) num_cached_tokens[train_request_id] = len(example["response"]) @@ -194,7 +201,7 @@ def process_task( rng = random.Random(seed) evict_ids = rng.sample(cached_request_ids, num_evict) for request_id in tqdm(evict_ids, desc="Evicting cached responses"): - suffix_cache.evict_request(request_id) + suffix_cache.evict_cached_response(request_id) print("Checking cache integrity...", end=" ", flush=True) if ret := suffix_cache._global_tree.check_integrity(): @@ -321,8 +328,10 @@ def tokenize_data(dataset: pd.DataFrame, tokenizer_name: str) -> pd.DataFrame: responses = [] for _, row in tqdm(dataset.iterrows(), total=len(dataset), desc="Tokenizing dataset"): - prompts.append(tokenizer.encode(row["prompt"])) - responses.append(tokenizer.encode(row["response"])) + prompt = tokenizer.encode(row["prompt"], return_tensors="np") + prompts.append(prompt.astype(np.int32).flatten()) + response = tokenizer.encode(row["response"], return_tensors="np") + responses.append(response.astype(np.int32).flatten()) return pd.DataFrame({ "prompt": prompts, "response": responses, @@ -550,8 +559,8 @@ def get_parser(): "--max-cached-requests", type=int, nargs="+", - default=[None], - help="Max number of cached requests (if None, unlimited)", + default=[-1], + help="Max number of cached requests (if -1, unlimited)", ) parser.add_argument( "--evict-fraction", diff --git a/arctic_inference/vllm/args.py b/arctic_inference/vllm/args.py index 8407e4bec..9cba50abd 100644 --- a/arctic_inference/vllm/args.py +++ b/arctic_inference/vllm/args.py @@ -20,7 +20,7 @@ from vllm.config import ParallelConfig from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs -from vllm.utils 
import FlexibleArgumentParser +from vllm.utils.argparse_utils import FlexibleArgumentParser from arctic_inference.patching import ArcticPatch from arctic_inference.vllm.config import ArcticParallelConfig @@ -50,7 +50,6 @@ class EngineArgsPatch(ArcticPatch[EngineArgs]): _orig_add_cli_args = EngineArgs.add_cli_args _orig_from_cli_args = EngineArgs.__dict__["from_cli_args"].__wrapped__ _orig_create_engine_config = EngineArgs.create_engine_config - _orig_is_v1_supported_oracle = EngineArgs._is_v1_supported_oracle def __new__(cls, *args, **kwargs): # Override __new__ to return an ArcticEngineArgs instead of an @@ -109,6 +108,11 @@ def create_engine_config(self, *args, **kwargs): if (self.ulysses_sequence_parallel_size > 1 and self.distributed_executor_backend is None): self.distributed_executor_backend = "mp" + + # Store ulysses_sequence_parallel_size for access during config initialization + from arctic_inference.vllm import ulysses + ulysses._ulysses_sp_size = self.ulysses_sequence_parallel_size + vllm_config = self._orig_create_engine_config(*args, **kwargs) # Recreate the parallel config with Arctic parameters since they might # not be passed to the parallel config __init__ when first initialized. @@ -121,21 +125,6 @@ def create_engine_config(self, *args, **kwargs): vllm_config.parallel_config = ArcticParallelConfig(**kwargs) return vllm_config - def _is_v1_supported_oracle(self, *args, **kwargs): - orig_speculative_config = self.speculative_config - - # Since Arctic Inference is only compatible with v1 and we already - # check it earlier, we can just disable this check altogether. 
- if (self.speculative_config is not None and - self.speculative_config.get("method") in ("arctic", "suffix")): - self.speculative_config = None - - res = self._orig_is_v1_supported_oracle(*args, **kwargs) - - self.speculative_config = orig_speculative_config - - return res - class AsyncEngineArgsPatch(ArcticPatch[AsyncEngineArgs]): diff --git a/arctic_inference/vllm/config.py b/arctic_inference/vllm/config.py index 4cda01436..b310e1f4f 100644 --- a/arctic_inference/vllm/config.py +++ b/arctic_inference/vllm/config.py @@ -13,9 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass +from pydantic.dataclasses import dataclass import logging +import vllm from vllm.config import ParallelConfig, SpeculativeConfig, VllmConfig from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig @@ -55,8 +56,10 @@ def world_size(self, value: int) -> None: @dataclass class ArcticSpeculativeConfig(SpeculativeConfig): + method: str | None = None enable_suffix_decoding: bool = False suffix_cache_max_depth: int = 64 + suffix_speculative_tokens: int = 0 suffix_cache_max_requests: int = 100000 suffix_max_spec_factor: float = 1.0 suffix_max_spec_offset: float = 0.0 @@ -76,46 +79,66 @@ def __new__(cls, *args, **kwargs): class SpeculativeConfigPatch(ArcticPatch[SpeculativeConfig]): - _orig_from_dict = SpeculativeConfig.__dict__["from_dict"].__wrapped__ _orig_post_init = SpeculativeConfig.__post_init__ def __new__(cls, *args, **kwargs): - # Override __new__ to return an ArcticSpeculativeConfig instead of a - # SpeculativeConfig when creating a new instance of the class. 
if cls is SpeculativeConfig: return ArcticSpeculativeConfig.__new__(ArcticSpeculativeConfig, *args, **kwargs) return super(SpeculativeConfig, cls).__new__(cls) def __post_init__(self): - use_suffix = (self.method - == "suffix") or (self.method is None - and self.enable_suffix_decoding) - if (use_suffix or self.method == "arctic") and \ - self.disable_by_batch_size is None: + is_arctic_method = self.method in ("arctic", "mlp_speculator") + use_suffix = (self.method == "suffix") or (self.method is None + and self.enable_suffix_decoding) + use_hybrid = (self.method == "arctic" and self.enable_suffix_decoding) + + if (use_suffix or is_arctic_method) and self.disable_by_batch_size is None: logger.info("Defaulting disable_by_batch_size to 64") self.disable_by_batch_size = 64 + + if use_hybrid: + self.suffix_speculative_tokens = self.suffix_cache_max_depth if use_suffix: self.method = "suffix" self.enable_suffix_decoding = True - self.num_speculative_tokens = self.suffix_cache_max_depth + # Use suffix_speculative_tokens if explicitly set, otherwise + # default to 16 (not suffix_cache_max_depth which can be very + # large and makes every step process 1+N tokens even when the + # suffix cache has no matches). + # NOTE: num_speculative_tokens defaults to None (not 0). 
+ if self.suffix_speculative_tokens > 0: + self.num_speculative_tokens = self.suffix_speculative_tokens + elif self.num_speculative_tokens is None: + self.num_speculative_tokens = 16 self._verify_args() + return + + if is_arctic_method: + actual_draft_model = getattr(self, "draft_model", None) + + self.draft_model = None + + try: + self._orig_post_init() + finally: + self.draft_model = actual_draft_model + + if self.num_speculative_tokens == 0: + self.num_speculative_tokens = getattr(self, "num_lookahead_slots", 1) else: self._orig_post_init() - @classmethod - def from_dict(cls, dict_value: dict) -> SpeculativeConfig: - """Parse the CLI value for the speculative config.""" - if cls is SpeculativeConfig: - return SpeculativeConfigPatch._orig_from_dict( - ArcticSpeculativeConfig, dict_value) - return SpeculativeConfigPatch._orig_from_dict(cls, dict_value) - class VllmConfigPatch(ArcticPatch[VllmConfig]): _orig_str = VllmConfig.__str__ + _orig_post_init = VllmConfig.__post_init__ + + from typing import Literal + OldEagleModelTypes = vllm.config.speculative.EagleModelTypes + NewEagleModelTypes = Literal["arctic", "suffix", OldEagleModelTypes] def __str__(self, *args, **kwargs): string = self._orig_str(*args, **kwargs) @@ -124,6 +147,24 @@ def __str__(self, *args, **kwargs): string += f", shift_parallel_threshold={self.parallel_config.shift_parallel_threshold}" return string + def __post_init__(self, *args, **kwargs): + # if self.speculative_config is not None: + # if self.speculative_config.method not in get_args(EagleModelTypes): + # raise ValueError( + # "Currently, async scheduling is only supported " + # "with EAGLE/MTP kind of speculative decoding" + # ) + import sys + from typing import Literal + target_module = sys.modules[VllmConfig.__module__] + original_types = getattr(target_module, "EagleModelTypes") + NewEagleModelTypes = Literal["mlp_speculator", "suffix", original_types] + setattr(target_module, "EagleModelTypes", NewEagleModelTypes) + try: + 
self._orig_post_init(*args, **kwargs) + finally: + setattr(target_module, "EagleModelTypes", original_types) + class MLPSpeculatorConfigPatch(ArcticPatch[MLPSpeculatorConfig]): @@ -132,3 +173,16 @@ class MLPSpeculatorConfigPatch(ArcticPatch[MLPSpeculatorConfig]): def __init__(self, *args, **kwargs): self.base_model_arch = kwargs.pop("base_model_arch", "") self._orig_init(*args, **kwargs) + + # Inject dummy attributes required by vLLM's ModelArchConfigConvertor + # The convertor tries to calculate head_size = hidden_size // num_attention_heads + if not hasattr(self, "num_attention_heads"): + self.num_attention_heads = 1 + + if not hasattr(self, "hidden_size"): + # Fallback to n_embd if present, otherwise default to a safe dummy value + self.hidden_size = getattr(self, "n_embd", 1024) + + # Ensure hidden_size is an integer to prevent TypeError during division + if hasattr(self, "hidden_size"): + self.hidden_size = int(self.hidden_size) diff --git a/arctic_inference/vllm/model_runner.py b/arctic_inference/vllm/model_runner.py index b5530ca1d..2cff488ef 100644 --- a/arctic_inference/vllm/model_runner.py +++ b/arctic_inference/vllm/model_runner.py @@ -5,7 +5,7 @@ # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -15,47 +15,58 @@ import contextlib import copy +import gc import time -from typing import Any, Union, Optional, TYPE_CHECKING -from itertools import tee +from typing import Any, Optional, TYPE_CHECKING, Union import numpy as np import torch +from tqdm import tqdm + import vllm.distributed.parallel_state as parallel_state import vllm.envs as envs -from tqdm import tqdm -from vllm.attention.layer import Attention from vllm.compilation.counter import compilation_counter -from vllm.config import CompilationLevel +from vllm.compilation.monitor import set_cudagraph_capturing_enabled +from vllm.config import CUDAGraphMode, VllmConfig from vllm.distributed.kv_transfer import (get_kv_transfer_group, has_kv_transfer_group) from vllm.distributed.parallel_state import (get_pp_group, get_tp_group, is_global_first_rank) -from vllm.forward_context import set_forward_context -from vllm.config import VllmConfig +from vllm.forward_context import BatchDescriptor, set_forward_context from vllm.model_executor.model_loader import get_model from vllm.sequence import IntermediateTensors -from vllm.utils import round_up -from vllm.v1.kv_cache_interface import KVCacheConfig -from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput +from vllm.utils.math_utils import round_up, cdiv +from vllm.v1.attention.backend import CommonAttentionMetadata +from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput, + SamplerOutput) from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.rejection_sampler import MAX_SPEC_LEN, RejectionSampler from vllm.v1.spec_decode.metadata import SpecDecodeMetadata -from vllm.v1.worker.gpu_model_runner import GPUModelRunner, logger +from 
vllm.v1.utils import record_function_or_nullcontext +from vllm.v1.worker.gpu_model_runner import ( + GPUModelRunner, + logger, + AsyncGPUModelRunnerOutput, +) +from vllm.v1.structured_output.utils import apply_grammar_bitmask if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput -from arctic_inference.common.suffix_cache import SuffixCache from arctic_inference.patching import ArcticPatch -from arctic_inference.vllm.spec_dec.arctic_proposer import ArcticProposer -from arctic_inference.common.suffix_cache import SuffixSpecResult +from arctic_inference.suffix_decoding import (SuffixDecodingCache, + SuffixDecodingDraft) +from arctic_inference.vllm.spec_dec.arctic_proposer import (ArcticProposer, + SuffixProposer) SP_TP_MODE = None @contextlib.contextmanager def set_shift_parallel_mode(mode: Optional[bool]): + """ + Swap the tensor-parallel group to an SP-compatible variant when 'mode' is True. + """ if mode is None: yield return @@ -76,7 +87,6 @@ def set_shift_parallel_mode(mode: Optional[bool]): try: yield finally: - # restore the original state SP_TP_MODE = old_mode parallel_state._TP = old_tp_group @@ -88,113 +98,299 @@ def is_shift_parallel_mode() -> bool: class GPUModelRunnerPatch(ArcticPatch[GPUModelRunner]): + """ + Rebased GPUModelRunnerPatch for vLLM v14. 
+ """ - _orig_initialize_kv_cache = GPUModelRunner.initialize_kv_cache - _orig_prepare_inputs = GPUModelRunner._prepare_inputs + _orig_capture_cudagraphs = GPUModelRunner._capture_cudagraphs _orig_profile_run = GPUModelRunner.profile_run _orig_load_model = GPUModelRunner.load_model _orig_propose_draft_token_ids = GPUModelRunner.propose_draft_token_ids + _orig_dummy_run = GPUModelRunner._dummy_run _orig_init = GPUModelRunner.__init__ + _orig_build_attention_metadata = GPUModelRunner._build_attention_metadata + _orig_execute_model = GPUModelRunner.execute_model + _orig_bookkeeping_sync = GPUModelRunner._bookkeeping_sync + _orig_sample_tokens = GPUModelRunner.sample_tokens + _orig_initialize_kv_cache = GPUModelRunner.initialize_kv_cache + # _orig_pad_for_sequence_parallelism = GPUModelRunner._pad_for_sequence_parallelism def __init__( self, vllm_config: VllmConfig, device: torch.device, ): - # Ulysses sequence parallelism if vllm_config.parallel_config.ulysses_sequence_parallel_size > 1: self.use_ulysses = True pass_config = vllm_config.compilation_config.pass_config - if pass_config.enable_sequence_parallelism: + if pass_config.enable_sp: raise ValueError( "Ulysses sequence parallelism is incompatible with native " "sequence parallelism. Set enable_sequence_parallelism " - "to False in the pass config to use Ulysses.") + "to False in the pass config to use Ulysses." + ) else: self.use_ulysses = False - # Speculative decoding - # TODO: Use "arctic" as an umbrella method that also covers the Arctic - # Inverence version of "mlp_speculator". - if (vllm_config.speculative_config is not None and \ - vllm_config.speculative_config.method in ( - "arctic", "suffix", "mlp_speculator")): - # Delay the creation of the drafter until - # after the child class has been initialized. 
+ arctic_methods = ("arctic", "suffix", "mlp_speculator") + is_arctic_spec = (vllm_config.speculative_config is not None and + vllm_config.speculative_config.method in arctic_methods) + + arctic_speculative_config = None + if is_arctic_spec: arctic_speculative_config = vllm_config.speculative_config vllm_config.speculative_config = None - else: - arctic_speculative_config = None self._orig_init(vllm_config, device) - # Set up speculative decoding. - self._suffix_cache = None - if arctic_speculative_config is not None: - # Restore the speculative config. + self._suffix_cache: Optional[SuffixDecodingCache] = None + + if is_arctic_spec: self.vllm_config.speculative_config = arctic_speculative_config self.speculative_config = arctic_speculative_config + self.num_spec_tokens = getattr(self.speculative_config, + "num_speculative_tokens", 0) + self.uniform_decode_query_len = 1 + self.num_spec_tokens + + if not hasattr(self, "draft_token_ids_cpu") or self.draft_token_ids_cpu is None: + self.draft_token_ids_event = torch.Event() + self.draft_token_ids_copy_stream = torch.cuda.Stream() + self.draft_token_ids_cpu = torch.empty( + (self.max_num_reqs, self.num_spec_tokens), + dtype=torch.int64, + device="cpu", + pin_memory=self.pin_memory, + ) + + if (self.use_async_scheduling + and self.speculative_config.method in ("arctic", "mlp_speculator", "suffix")): + if not hasattr(self, "valid_sampled_token_count_cpu") or self.valid_sampled_token_count_cpu is None: + self.valid_sampled_token_count_event = torch.Event() + self.valid_sampled_token_count_copy_stream = torch.cuda.Stream() + self.valid_sampled_token_count_cpu = torch.empty( + self.max_num_reqs, + dtype=torch.int64, + device="cpu", + pin_memory=self.pin_memory, + ) if get_pp_group().is_last_rank: - if (self.speculative_config.method == "arctic" - or self.speculative_config.method == "mlp_speculator"): + if self.speculative_config.method in ("arctic", "mlp_speculator"): self.drafter = ArcticProposer(self.vllm_config) - elif 
self.speculative_config.method != "suffix": - raise ValueError("Unknown speculative decoding method: " - f"{self.speculative_config.method}") - - self.rejection_sampler = RejectionSampler() - - if (self.speculative_config is not None - and self.speculative_config.enable_suffix_decoding): - if self.speculative_config.method not in ("arctic", "suffix", - "mlp_speculator"): + elif self.speculative_config.method == "suffix": + self.drafter = SuffixProposer() + else: + raise ValueError(f"Unknown speculative decoding method: {self.speculative_config.method}") + + self.rejection_sampler = RejectionSampler(self.sampler) + + if (self.speculative_config is not None and + getattr(self.speculative_config, "enable_suffix_decoding", False)): + + if self.speculative_config.method not in arctic_methods: raise ValueError( "Suffix decoding is only supported with the 'arctic', " - "'mlp_speculator' or 'suffix' spec decoding methods.") + "'mlp_speculator' or 'suffix' spec decoding methods." + ) spec_cfg = self.speculative_config - self._suffix_cache = SuffixCache( + self._suffix_cache = SuffixDecodingCache( max_tree_depth=spec_cfg.suffix_cache_max_depth, - max_cached_requests=spec_cfg.suffix_cache_max_requests) + max_cached_requests=spec_cfg.suffix_cache_max_requests + ) + + # Async suffix decoding infrastructure: a dedicated CUDA stream and + # pinned buffer for copying sampled token IDs to CPU *without* + # serialising behind Arctic GPU drafting work on the default stream. + if self._suffix_cache is not None and self.use_async_scheduling: + self.suffix_copy_stream = torch.cuda.Stream() + self.suffix_copy_done_event = torch.Event() + max_gen_len = 1 + self.num_spec_tokens + self.suffix_sampled_ids_pinned = torch.empty( + (self.max_num_reqs, max_gen_len), + dtype=torch.int64, + device="cpu", + pin_memory=self.pin_memory, + ) + # Pinned buffer for suffix merge results. 
Using pinned memory + # for H2C copies avoids the implicit default-stream + # synchronisation that cudaMemcpyAsync performs with pageable + # (non-pinned) source memory. Without this, the merge step + # blocks the CPU until ALL pending GPU work (including Arctic + # drafting) completes, destroying the async overlap. + self._suffix_merge_pinned = torch.zeros( + (self.max_num_reqs, self.num_spec_tokens), + dtype=torch.int64, + device="cpu", + pin_memory=self.pin_memory, + ) + + # Pre-allocated GPU buffer for the merged draft tensor. + # Avoids per-step F.pad / torch.zeros allocations in the async + # Arctic drafting path (propose_draft_token_ids + suffix merge). + # Shape: [max_num_reqs, num_spec_tokens], int64 (matches + # draft_token_ids_cpu for zero-cost _copy_draft_token_ids_to_cpu). + if (self.speculative_config is not None + and self.use_async_scheduling + and self.speculative_config.method + in ("arctic", "mlp_speculator")): + self._draft_merged_gpu = torch.zeros( + (self.max_num_reqs, self.num_spec_tokens), + dtype=torch.int64, device=self.device, + ) + + # Pre-allocated pinned index buffer for suffix merge overlay. + # Avoids per-step torch.tensor(...).pin_memory() allocations. + if self._suffix_cache is not None and self.use_async_scheduling: + self._suffix_index_pinned = torch.empty( + self.max_num_reqs, dtype=torch.long, + device="cpu", pin_memory=self.pin_memory, + ) + + # Per-request response tokens for suffix pattern building in async + # mode. In async scheduling, _bookkeeping_sync writes -1 placeholders + # to token_ids_cpu instead of real values, corrupting the pattern that + # propose_suffix_draft_token_ids reads. We keep a clean copy here. + self._suffix_response_tokens: dict[str, list[int]] = {} + + # Actual draft lengths per request from the previous step. 
Used + # by execute_model to trim the scheduler's spec token allocation + # down to the real draft width, and communicated back to the + # scheduler (via scheduler_output._actual_draft_lens) so + # _update_after_schedule can set dynamic placeholder counts. + self._prev_actual_draft_lens: dict[str, int] = {} + + # Backup-token buffer used by suffix-only async rejection sampling. + # The arctic proposer has its own buffer; this one covers the case + # where no arctic drafter is present. + self._suffix_backup_tokens_gpu: Optional[torch.Tensor] = None + if (self._suffix_cache is not None + and self.use_async_scheduling + and self.speculative_config.method not in ("arctic", + "mlp_speculator")): + self._suffix_backup_tokens_gpu = torch.zeros( + self.max_num_reqs, dtype=torch.int32, device=self.device, + ) + + def _suffix_only_rejection_sample( + self, + sampled_token_ids: torch.Tensor, + common_attn_metadata: "CommonAttentionMetadata", + ) -> None: + """Rejection-sample accepted tokens for suffix-only async scheduling. + + EAGLE / arctic do this inside propose_draft_token_ids via + prepare_next_token_ids_padded. For suffix-only there is no + drafter with that method, so we call the same Triton kernel + directly and feed the results into _copy_valid_sampled_token_count. + """ + from vllm.triton_utils import triton + from vllm.v1.spec_decode.utils import ( + eagle_prepare_next_token_padded_kernel, + ) + + num_reqs = self.input_batch.num_reqs + batch_size, num_tokens = sampled_token_ids.shape + device = sampled_token_ids.device + + # Compute backup tokens (last accepted token per request) on CPU, + # then copy to GPU in one shot to avoid per-element synchronisation. 
+ backup = self._suffix_backup_tokens_gpu + assert backup is not None + backup_np = np.empty(num_reqs, dtype=np.int32) + for i in range(num_reqs): + req_id = self.input_batch.req_ids[i] + seq_len = int(common_attn_metadata.seq_lens_cpu[i].item()) + backup_np[i] = self.requests[req_id].get_token_id(seq_len) + # Copy directly from CPU numpy-backed tensor to GPU; avoids + # creating an intermediate GPU tensor via .to(device). + backup[:num_reqs].copy_( + torch.from_numpy(backup_np), non_blocking=True, + ) + + next_token_ids = torch.empty(batch_size, dtype=torch.int32, + device=device) + valid_counts = torch.empty(batch_size, dtype=torch.int32, + device=device) + + BLOCK_SIZE_TOKENS = triton.next_power_of_2(num_tokens) + eagle_prepare_next_token_padded_kernel[(batch_size,)]( + sampled_token_ids, + self.discard_request_mask.gpu, + backup, + next_token_ids, + valid_counts, + self.model_config.get_vocab_size(), + num_tokens, + batch_size, + sampled_token_ids.stride(0), + BLOCK_SIZE_TOKENS=BLOCK_SIZE_TOKENS, + ) + + self._copy_valid_sampled_token_count(next_token_ids, valid_counts) + + def _build_attention_metadata(self, *args, **kwargs): + attn_metadata, spec_decode_common_attn_metadata = \ + self._orig_build_attention_metadata(*args, **kwargs) + + logits_indices = kwargs.get("logits_indices", None) + if logits_indices is not None: + if isinstance(attn_metadata, list): + for ub in attn_metadata: + for meta in ub.values(): + meta.swiftkv_logits_indices = logits_indices + else: + for meta in attn_metadata.values(): + meta.swiftkv_logits_indices = logits_indices + + return attn_metadata, spec_decode_common_attn_metadata + + # set padding for SP here + def _pad_for_sequence_parallelism(self, num_scheduled_tokens: int) -> int: + + sp_size = self.parallel_config.ulysses_sequence_parallel_size + num_input_tokens = round_up(num_scheduled_tokens, sp_size) + + #if torch.distributed.get_rank() == 0: + # print(f"padding num_scheduled_tokens {num_scheduled_tokens} -> num_input_tokens 
{num_input_tokens}") + + return num_input_tokens + def profile_run(self) -> None: self._orig_profile_run() - if self.shift_model is not None: - # Run the shift model to trigger compilation. + if getattr(self, "shift_model", None) is not None: orig_model, self.model = self.model, self.shift_model + cc = self.vllm_config.compilation_config + base_ctx = cc.static_forward_context + shift_ctx = getattr(self, 'shift_forward_context', None) try: + if shift_ctx is not None: + cc.static_forward_context = shift_ctx with set_shift_parallel_mode(True): self._dummy_run(self.max_num_tokens, is_profile=True) finally: self.model = orig_model + cc.static_forward_context = base_ctx - def _prepare_inputs(self, *args, **kwargs): - attn_metadata, attention_cuda_graphs, logits_indices, *rest = ( - self._orig_prepare_inputs(*args, **kwargs)) - # SwiftKV requires knowing the logits indices from inside the model - # definition in order to early-stop the prefill tokens. - for meta in attn_metadata.values(): - meta.swiftkv_logits_indices = logits_indices - return attn_metadata, attention_cuda_graphs, logits_indices, *rest def monkeypatch_forward(self: GPUModelRunner): + """ + Slice the batch across Ulysses SP ranks for forward, then all-gather. 
+ """ sp_size = parallel_state._SP.world_size sp_rank = parallel_state._SP.rank_in_group device_group = parallel_state._SP.device_group model_forward = self.model.forward - input_key = 'inputs_embeds' if self.is_multimodal_model else 'input_ids' + input_key = 'inputs_embeds' if self.supports_mm_inputs else 'input_ids' def ulysses_forward(*args, **kwargs): - # update inputs input_tensor = kwargs[input_key] positions = kwargs['positions'] - # Ulysses parameters - N = input_tensor.shape[0] + N = input_tensor.shape[0] N_ulysses = N // sp_size N_offset = N_ulysses * sp_rank - # narrow the input kwargs[input_key] = input_tensor[N_offset:N_offset + N_ulysses] kwargs['positions'] = positions[N_offset:N_offset + N_ulysses] @@ -202,459 +398,629 @@ def ulysses_forward(*args, **kwargs): output = model_forward(*args, **kwargs) if output.size(0) == N_ulysses: - # all-gather model_output - model_output = torch.empty((N, self.hidden_size), + model_output = torch.empty((N, output.shape[1]), dtype=output.dtype, device=output.device) torch.distributed.all_gather_into_tensor(model_output, output, group=device_group) else: - # SwiftKV models will already have all-gathered the output. assert output.size(0) == N model_output = output return model_output - self.model.forward = ulysses_forward + self.get_model().forward = ulysses_forward @torch.inference_mode() - def execute_model( + def _dummy_run( self, - scheduler_output: "SchedulerOutput", - intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> Union[ModelRunnerOutput, IntermediateTensors]: - self._update_states(scheduler_output) - if not scheduler_output.total_num_scheduled_tokens: - if not has_kv_transfer_group(): - # Return empty ModelRunnerOutput if there's no work to do. - return EMPTY_MODEL_RUNNER_OUTPUT - - return self.kv_connector_no_forward(scheduler_output) - - # Prepare the decoder inputs. 
- (attn_metadata, attention_cuda_graphs, logits_indices, - spec_decode_metadata, - num_scheduled_tokens_np) = (self._prepare_inputs(scheduler_output)) - num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens - use_shift_model = (self.use_ulysses and self.shift_model is not None - and num_scheduled_tokens - <= self.shift_parallel_threshold) - if self.use_ulysses and not use_shift_model: - # add padding to the batch size to make it a multiple of SP - sp_size = self.parallel_config.ulysses_sequence_parallel_size - num_input_tokens = round_up(num_scheduled_tokens, sp_size) - if (self.use_cuda_graph and num_input_tokens // sp_size - <= self.cudagraph_batch_sizes[-1]): - num_input_tokens = self.vllm_config.pad_for_cudagraph( - num_input_tokens // sp_size) * sp_size - elif (self.use_cuda_graph - and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]): - # Use piecewise CUDA graphs. - # Add padding to the batch size. - num_input_tokens = self.vllm_config.pad_for_cudagraph( - num_scheduled_tokens) + num_tokens: int, + cudagraph_runtime_mode: CUDAGraphMode | None = None, + force_attention: bool = False, + uniform_decode: bool = False, + allow_microbatching: bool = True, + skip_eplb: bool = False, + is_profile: bool = False, + create_mixed_batch: bool = False, + remove_lora: bool = True, + activate_lora: bool = False, + is_graph_capturing: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor]: + + from vllm.v1.worker.gpu_model_runner import supports_mm_encoder_only + if supports_mm_encoder_only(self.model): + return torch.tensor([]), torch.tensor([]) + + max_query_len = self.uniform_decode_query_len if uniform_decode else num_tokens + max_num_reqs = self.scheduler_config.max_num_seqs + + if create_mixed_batch: + num_decode_tokens = min(max_num_reqs - 1, num_tokens // 2) + num_prefill_tokens = num_tokens - num_decode_tokens + num_reqs = num_decode_tokens + 1 + num_scheduled_tokens_list = [1] * num_decode_tokens + [num_prefill_tokens] + max_query_len = 
num_prefill_tokens + elif uniform_decode: + num_reqs = min(max_num_reqs, cdiv(num_tokens, max_query_len)) + num_scheduled_tokens_list = [max_query_len] * num_reqs + if num_tokens % max_query_len != 0: + num_scheduled_tokens_list[-1] = num_tokens % max_query_len else: - # Eager mode. - # Pad tokens to multiple of tensor_parallel_size when - # enabled collective fusion for SP - tp_size = self.vllm_config.parallel_config.tensor_parallel_size - if self.compilation_config.pass_config. \ - enable_sequence_parallelism and tp_size > 1: - num_input_tokens = round_up(num_scheduled_tokens, tp_size) - else: - num_input_tokens = num_scheduled_tokens - - # Padding for DP - num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens) - num_input_tokens += num_pad - - # _prepare_inputs may reorder the batch, so we must gather multi - # modal outputs after that to ensure the correct order - if self.is_multimodal_model: - # Run the multimodal encoder if any. - self._execute_mm_encoder(scheduler_output) - mm_embeds = self._gather_mm_embeddings(scheduler_output) - else: - mm_embeds = [] - - if self.is_multimodal_model and get_pp_group().is_first_rank: - # NOTE(woosuk): To unify token ids and soft tokens (vision - # embeddings), we always use embeddings (rather than token ids) - # as input to the multimodal model, even when the input is text. 
- input_ids = self.input_ids[:num_scheduled_tokens] - if mm_embeds: - inputs_embeds = self.model.get_input_embeddings( - input_ids, mm_embeds) + num_reqs = min(num_tokens, max_num_reqs) + min_tokens_per_req = num_tokens // num_reqs + num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs + num_scheduled_tokens_list[-1] += num_tokens % num_reqs + + num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32) + num_tokens_unpadded = int(num_scheduled_tokens.sum()) + num_sampled_tokens = np.ones(num_reqs, dtype=np.int32) + + if torch.distributed.get_rank() == 0: + print(f"num_tokens_unpadded: {num_tokens_unpadded}, num_reqs: {num_reqs}") + + _cg_mode, batch_desc, should_ubatch, num_tokens_across_dp, _ = ( + self._determine_batch_execution_and_padding( + num_tokens=num_tokens_unpadded, + num_reqs=num_reqs, + num_scheduled_tokens_np=num_scheduled_tokens, + max_num_scheduled_tokens=max_query_len, + use_cascade_attn=False, + allow_microbatching=allow_microbatching, + force_eager=is_profile or (cudagraph_runtime_mode == CUDAGraphMode.NONE), + force_uniform_decode=uniform_decode, + force_has_lora=activate_lora, + ) + ) + + if cudagraph_runtime_mode is None: + cudagraph_runtime_mode = _cg_mode + + num_tokens_padded = batch_desc.num_tokens + num_reqs_padded = batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs + + from vllm.v1.worker.gpu_model_runner import maybe_create_ubatch_slices + ubatch_slices, ubatch_slices_padded = maybe_create_ubatch_slices( + should_ubatch, + num_scheduled_tokens, + num_tokens_padded, + num_reqs_padded, + self.vllm_config.parallel_config.num_ubatches, + ) + + logits_indices_cpu = np.cumsum(num_scheduled_tokens) - 1 + logits_indices = torch.from_numpy(logits_indices_cpu).to(self.device) + + attn_metadata = None + if force_attention or cudagraph_runtime_mode == CUDAGraphMode.FULL: + if create_mixed_batch: + seq_lens_list = [1] * num_decode_tokens + [num_prefill_tokens + 1] else: - inputs_embeds = 
self.model.get_input_embeddings(input_ids) - # TODO(woosuk): Avoid the copy. Optimize. - self.inputs_embeds[:num_scheduled_tokens].copy_(inputs_embeds) - inputs_embeds = self.inputs_embeds[:num_input_tokens] - input_ids = None - else: - # For text-only models, we use token ids as input. - # While it is possible to use embeddings as input just like the - # multimodal models, it is not desirable for performance since - # then the embedding layer is not included in the CUDA graph. - input_ids = self.input_ids[:num_input_tokens] - inputs_embeds = None - if self.uses_mrope: - positions = self.mrope_positions[:, :num_input_tokens] - else: - positions = self.positions[:num_input_tokens] + seq_lens_list = [max_query_len] * num_reqs # simplified + + self.seq_lens.np[:num_reqs] = seq_lens_list + self.seq_lens.np[num_reqs:] = 0 + self.seq_lens.copy_to_gpu() + + cum_num_tokens, _ = self._get_cumsum_and_arange(num_scheduled_tokens) + self.query_start_loc.np[1 : num_reqs + 1] = cum_num_tokens + self.query_start_loc.copy_to_gpu() + + pad_attn = (cudagraph_runtime_mode == CUDAGraphMode.FULL) + attn_metadata, _ = self._build_attention_metadata( + num_tokens=num_tokens_unpadded, + num_reqs=num_reqs_padded, + max_query_len=max_query_len, + ubatch_slices=ubatch_slices_padded if pad_attn else ubatch_slices, + for_cudagraph_capture=is_graph_capturing, + ) - if get_pp_group().is_first_rank: - intermediate_tensors = None - else: - intermediate_tensors = self.sync_and_slice_intermediate_tensors( - num_input_tokens, intermediate_tensors, True) - - # Some attention backends only support CUDA Graphs in pure decode. - # If attention doesn't support CUDA Graphs for this batch, but we - # compiled with full CUDA graphs, we have to skip them entirely. - skip_cuda_graphs = self.full_cuda_graph and not attention_cuda_graphs - - # Run the model. - # Use persistent buffers for CUDA graphs. 
- with set_forward_context( - attn_metadata, - self.vllm_config, - num_tokens=num_input_tokens, - num_tokens_across_dp=num_tokens_across_dp, - skip_cuda_graphs=skip_cuda_graphs, - ): - self.maybe_setup_kv_connector(scheduler_output) - - model = self.shift_model if use_shift_model else self.model - with set_shift_parallel_mode(use_shift_model): - model_output = model( - input_ids=input_ids, - positions=positions, - intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds, - ) + if attn_metadata is not None: + if isinstance(attn_metadata, list): + for ub_meta in attn_metadata: + for meta in ub_meta.values(): + meta.swiftkv_logits_indices = logits_indices + else: + for meta in attn_metadata.values(): + meta.swiftkv_logits_indices = logits_indices + + with self.maybe_dummy_run_with_lora(self.lora_config, num_scheduled_tokens, + num_sampled_tokens, activate_lora, remove_lora): + + model_kwargs = self._init_model_kwargs() + if self.supports_mm_inputs and not self.model_config.is_encoder_decoder: + input_ids, inputs_embeds = self._prepare_mm_inputs(num_tokens_padded) + model_kwargs.update(self._dummy_mm_kwargs(num_reqs)) + elif self.enable_prompt_embeds: + input_ids = None + inputs_embeds = self.inputs_embeds.gpu[:num_tokens_padded] + else: + input_ids = self.input_ids.gpu[:num_tokens_padded] + inputs_embeds = None - self.maybe_wait_for_kv_save() - finished_sending, finished_recving = ( - self.get_finished_kv_transfers(scheduler_output)) + positions = self.positions.gpu[:num_tokens_padded] + if self.uses_mrope: + positions = self.mrope_positions.gpu[:, :num_tokens_padded] - if self.use_aux_hidden_state_outputs: - hidden_states, aux_hidden_states = model_output - else: - hidden_states = model_output - aux_hidden_states = None - - # Broadcast PP output for external_launcher (torchrun) - # to make sure we are synced across pp ranks - # TODO: Support overlapping mirco-batches - # https://github.com/vllm-project/vllm/issues/18019 - broadcast_pp_output = \ - 
self.parallel_config.distributed_executor_backend \ - == "external_launcher" and len(get_pp_group().ranks) > 0 - if not get_pp_group().is_last_rank: - # For mid-pipeline stages, return the hidden states. - if not broadcast_pp_output: - return hidden_states - assert isinstance(hidden_states, IntermediateTensors) - get_pp_group().send_tensor_dict(hidden_states.tensors, - all_gather_group=get_tp_group()) - logits = None - else: - if self.input_batch.pooling_params: - return self._pool(hidden_states, num_scheduled_tokens, - num_scheduled_tokens_np, finished_sending, - finished_recving) - - sample_hidden_states = hidden_states[logits_indices] - logits = self.model.compute_logits(sample_hidden_states, None) - if broadcast_pp_output: - model_output_broadcast_data = { - "logits": logits.contiguous(), - } if logits is not None else {} - model_output_broadcast_data = get_pp_group().broadcast_tensor_dict( - model_output_broadcast_data, src=len(get_pp_group().ranks) - 1) - assert model_output_broadcast_data is not None - logits = model_output_broadcast_data["logits"] - - # Apply structured output bitmasks if present - if scheduler_output.grammar_bitmask is not None: - self.apply_grammar_bitmask(scheduler_output, logits) - - # Sample the next token and get logprobs if needed. 
+ intermediate_tensors = None + if not get_pp_group().is_first_rank: + if self.intermediate_tensors is None: + self.intermediate_tensors = self.model.make_empty_intermediate_tensors( + batch_size=self.max_num_tokens, dtype=self.model_config.dtype, device=self.device) + intermediate_tensors = self.sync_and_slice_intermediate_tensors(num_tokens_padded, None, False) + + target_num_tokens = num_tokens_padded + if ubatch_slices_padded is not None: + target_num_tokens = ubatch_slices_padded[0].num_tokens + if num_tokens_across_dp is not None: + num_tokens_across_dp[:] = target_num_tokens + + with self.maybe_randomize_inputs(input_ids, inputs_embeds), set_forward_context( + attn_metadata, self.vllm_config, num_tokens=target_num_tokens, + num_tokens_across_dp=num_tokens_across_dp, cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=batch_desc, ubatch_slices=ubatch_slices_padded): + + outputs = self.model(input_ids=input_ids, positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, **model_kwargs) + + hidden_states = outputs[0] if self.use_aux_hidden_state_outputs else outputs + + if self.speculative_config and self.speculative_config.use_eagle(): + self.drafter.dummy_run(num_tokens, use_cudagraphs=False, is_graph_capturing=is_graph_capturing) + + if not skip_eplb: + self.eplb_step(is_dummy=True, is_profile=is_profile) + + return hidden_states, hidden_states[logits_indices] + + # ------------------------------------------------------------------ + # _sample: inline the base GPUModelRunner._sample logic here because + # the class is monkey-patched at runtime, making both super() and + # GPUModelRunner._sample(self, ...) resolve back to this method. 
+ # ------------------------------------------------------------------ + def _sample( + self, + logits: torch.Tensor | None, + spec_decode_metadata: SpecDecodeMetadata | None, + ) -> SamplerOutput: sampling_metadata = self.input_batch.sampling_metadata + self.input_batch.update_async_output_token_ids() if spec_decode_metadata is None: - sampler_output = self.sampler( - logits=logits, - sampling_metadata=sampling_metadata, - ) - else: - # When indexing with a tensor (bonus_logits_indices), PyTorch - # creates a new tensor with separate storage from the original - # logits tensor. This means any in-place operations on bonus_logits - # won't affect the original logits tensor. - assert logits is not None - bonus_logits = logits[spec_decode_metadata.bonus_logits_indices] - sampler_output = self.sampler( - logits=bonus_logits, - sampling_metadata=sampling_metadata, - ) - bonus_token_ids = sampler_output.sampled_token_ids - - # Just like `bonus_logits`, `target_logits` is a new tensor with - # separate storage from the original `logits` tensor. Therefore, - # it is safe to update `target_logits` in place. - target_logits = logits[spec_decode_metadata.target_logits_indices] - output_token_ids = self.rejection_sampler( - spec_decode_metadata, - None, # draft_probs - target_logits, - bonus_token_ids, - sampling_metadata, - ) - sampler_output.sampled_token_ids = output_token_ids - - num_nans_in_logits = {} - if envs.VLLM_COMPUTE_NANS_IN_LOGITS: - num_nans_in_logits = self._get_nans_in_logits(logits) - - # TODO(woosuk): The following loop can be slow since it iterates over - # the requests one by one. Optimize. - discard_sampled_tokens_req_indices = [] - for i, req_id in enumerate(self.input_batch.req_ids): - req_state = self.requests[req_id] - seq_len = (req_state.num_computed_tokens + - scheduler_output.num_scheduled_tokens[req_id]) - if seq_len < req_state.num_tokens: - # Ignore the sampled token for partial prefills. 
- # Rewind the generator state as if the token was not sampled. - # This relies on cuda-specific torch-internal impl details - generator = self.input_batch.generators.get(i) - if generator is not None: - generator.set_offset(generator.get_offset() - 4) - # Record the index of the request that should not be sampled, - # so that we could clear the sampled tokens before returning. - discard_sampled_tokens_req_indices.append(i) - - # NOTE: GPU -> CPU Sync happens here. - # Move as many CPU operations as possible before this sync point. - logprobs_tensors = sampler_output.logprobs_tensors - logprobs_lists = logprobs_tensors.tolists() \ - if logprobs_tensors is not None else None - - # Compute prompt logprobs if needed. - prompt_logprobs_dict = self._get_prompt_logprobs_dict( - hidden_states[:num_scheduled_tokens], - scheduler_output, + return self.sampler( + logits=logits, sampling_metadata=sampling_metadata) + + if (self.use_async_scheduling + and self._draft_token_req_ids is not None): + draft_token_ids_cpu, _ = self._get_draft_token_ids_cpu() + self.input_batch.update_async_spec_token_ids( + draft_token_ids_cpu) + + sampler_output = self.rejection_sampler( + spec_decode_metadata, + None, # draft_probs + logits, + sampling_metadata, ) + self._update_states_after_model_execute( + sampler_output.sampled_token_ids) + return sampler_output - # Get the valid generated tokens. - sampled_token_ids = sampler_output.sampled_token_ids - max_gen_len = sampled_token_ids.shape[-1] - if max_gen_len == 1: - # No spec decode tokens. - valid_sampled_token_ids = sampled_token_ids.tolist() - else: - # Includes spec decode tokens. 
- valid_sampled_token_ids = self.rejection_sampler.parse_output( - sampled_token_ids, - self.input_batch.vocab_size, + @torch.inference_mode() + def execute_model( + self, + scheduler_output: "SchedulerOutput", + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Union[ + ModelRunnerOutput, AsyncGPUModelRunnerOutput, IntermediateTensors + ]: + num_scheduled_tokens = getattr(scheduler_output, "total_num_scheduled_tokens", None) + if num_scheduled_tokens is None: + try: + num_scheduled_tokens = int( + sum(scheduler_output.num_scheduled_tokens.values()) + ) + except Exception: + num_scheduled_tokens = 0 + + use_shift_model = ( + getattr(self, "use_ulysses", False) + and getattr(self, "shift_model", None) is not None + and num_scheduled_tokens <= int(getattr(self, "shift_parallel_threshold", 0)) + ) + + if not use_shift_model: + return self._orig_execute_model(scheduler_output, intermediate_tensors) + + orig_model = self.model + cc = self.vllm_config.compilation_config + base_ctx = cc.static_forward_context + shift_ctx = getattr(self, 'shift_forward_context', None) + try: + self.model = self.shift_model + if shift_ctx is not None: + cc.static_forward_context = shift_ctx + with set_shift_parallel_mode(True), \ + self._use_shift_cudagraph_tables(): + result = self._orig_execute_model(scheduler_output, intermediate_tensors) + finally: + self.model = orig_model + cc.static_forward_context = base_ctx + return result + + @torch.inference_mode + def sample_tokens(self, grammar_output): + """Wrapper around base sample_tokens for arctic async spec decode. + + Saves execute_model_state before the base clears it, then handles + the 'not-fits-in-drafter' case that the base only handles for Eagle. 
+ """ + _arctic_saved_state = None + if (self.execute_model_state is not None + and self.speculative_config is not None + and self.speculative_config.method + in ("arctic", "mlp_speculator", "suffix") + and self.use_async_scheduling): + _arctic_saved_state = ( + self.execute_model_state.scheduler_output, + self.execute_model_state.spec_decode_common_attn_metadata, ) - # Mask out the sampled tokens that should not be sampled. - for i in discard_sampled_tokens_req_indices: - valid_sampled_token_ids[i].clear() - # Cache the sampled tokens in the model runner, so that the scheduler - # doesn't need to send them back. - # NOTE(woosuk): As an exception, when using PP, the scheduler sends - # the sampled tokens back, because there's no direct communication - # between the first-stage worker and the last-stage worker. - for req_idx, sampled_ids in enumerate(valid_sampled_token_ids): - if not sampled_ids: - continue + result = self._orig_sample_tokens(grammar_output) - start_idx = self.input_batch.num_tokens_no_spec[req_idx] - end_idx = start_idx + len(sampled_ids) - assert end_idx <= self.max_model_len, ( - "Sampled token IDs exceed the max model length. " - f"Total number of tokens: {end_idx} > max_model_len: " - f"{self.max_model_len}") - - self.input_batch.token_ids_cpu[req_idx, - start_idx:end_idx] = sampled_ids - self.input_batch.num_tokens_no_spec[req_idx] = end_idx - self.input_batch.num_tokens[req_idx] = end_idx - req_id = self.input_batch.req_ids[req_idx] - req_state = self.requests[req_id] - req_state.output_token_ids.extend(sampled_ids) + # If _arctic_async_sampled_tensor was stashed by _bookkeeping_sync + # but never consumed by propose_draft_token_ids, this is the + # not-fits-in-drafter case. Mirror Eagle's handling: call + # _copy_valid_sampled_token_count and set draft tokens to zeros. 
+ stashed = getattr(self, '_arctic_async_sampled_tensor', None) + if stashed is not None: + del self._arctic_async_sampled_tensor + if _arctic_saved_state is not None: + scheduler_output, common_attn_meta = _arctic_saved_state + self._arctic_handle_not_fits( + stashed, scheduler_output, common_attn_meta) - if self._suffix_cache is not None: - self._update_suffix_cache(valid_sampled_token_ids) + return result - if not self.speculative_config: - # Speculative decoding is not enabled. - spec_token_ids = None + def _arctic_handle_not_fits( + self, + sampled_token_ids: torch.Tensor, + scheduler_output: "SchedulerOutput", + common_attn_metadata, + ) -> None: + """Mirror Eagle's not-fits-in-drafter path for arctic async. + + When the input is too long for the drafter but spec decode is + active, Eagle still calls prepare_next_token_ids_padded / + _copy_valid_sampled_token_count and sets draft tokens to zeros. + Without this, _get_valid_sampled_token_count returns stale counts + and _prepare_input_ids scatters -1 placeholders into the + embedding layer. + """ + if (hasattr(self, 'drafter') + and hasattr(self.drafter, 'prepare_next_token_ids_padded') + and common_attn_metadata is not None): + next_token_ids, valid_sampled_tokens_count = ( + self.drafter.prepare_next_token_ids_padded( + common_attn_metadata, + sampled_token_ids, + self.requests, + self.input_batch, + self.discard_request_mask.gpu, + ) + ) + self._copy_valid_sampled_token_count( + next_token_ids, valid_sampled_tokens_count) else: - spec_token_ids = self.propose_draft_token_ids( - scheduler_output, - valid_sampled_token_ids, - sampler_output.sampled_token_ids, - sampling_metadata, - hidden_states, - sample_hidden_states, - aux_hidden_states, - spec_decode_metadata, - attn_metadata, - ) - - # Clear KVConnector state after all KVs are generated. 
- if has_kv_transfer_group(): - get_kv_transfer_group().clear_connector_metadata() - - self.eplb_step() - - return ModelRunnerOutput( - req_ids=self.input_batch.req_ids, - req_id_to_index=self.input_batch.req_id_to_index, - sampled_token_ids=valid_sampled_token_ids, - spec_token_ids=spec_token_ids, - logprobs=logprobs_lists, - prompt_logprobs_dict=prompt_logprobs_dict, - pooler_output=[], - finished_sending=finished_sending, - finished_recving=finished_recving, - num_nans_in_logits=num_nans_in_logits, - ) + # Fallback for drafters without prepare_next_token_ids_padded + # (e.g. suffix-only). Compute valid counts with PyTorch ops. + mask = sampled_token_ids != -1 + valid_counts = mask.sum(dim=1) + batch_size = sampled_token_ids.shape[0] + col_indices = torch.arange( + sampled_token_ids.shape[1], + device=sampled_token_ids.device, + ).unsqueeze(0).expand_as(sampled_token_ids) + last_valid_col = ( + col_indices.masked_fill(~mask, -1).max(dim=1).values) + last_valid_col = last_valid_col.clamp(min=0) + next_token_ids = sampled_token_ids[ + torch.arange(batch_size, + device=sampled_token_ids.device), + last_valid_col, + ] + self._copy_valid_sampled_token_count( + next_token_ids, valid_counts) + + # Zero draft tokens -- same as Eagle's not-fits path. + self._draft_token_ids = torch.zeros( + 1, device=self.device, dtype=torch.int32, + ).expand(len(self.input_batch.req_ids), self.num_spec_tokens) + self._copy_draft_token_ids_to_cpu( + scheduler_output, zeros_only=True) + + def _bookkeeping_sync( + self, + scheduler_output: "SchedulerOutput", + sampler_output: SamplerOutput, + logits: torch.Tensor | None, + hidden_states: torch.Tensor, + num_scheduled_tokens: int, + spec_decode_metadata: SpecDecodeMetadata | None, + ): + """Wrap base _bookkeeping_sync to handle arctic async spec decode. + + In the base vLLM code, only Eagle-style drafters run *before* + bookkeeping (setting prev_sampled_token_ids via + _copy_valid_sampled_token_count). 
Arctic/suffix drafting runs + *after* bookkeeping, so prev_sampled_token_ids is still None + when bookkeeping checks ``assert sampled_token_ids.shape[-1] == 1``. + + We fix this by: + 1. Saving the GPU sampled tensor for propose_draft_token_ids. + 2. Setting prev_sampled_token_ids to a placeholder so the + assertion is skipped. The real value will be written by + _copy_valid_sampled_token_count inside propose_draft_token_ids + (fits case) or sample_tokens (not-fits case). + """ + sampled_token_ids = sampler_output.sampled_token_ids + if (self.use_async_scheduling + and self.speculative_config is not None + and self.speculative_config.method + in ("arctic", "mlp_speculator", "suffix") + and spec_decode_metadata is not None + and sampled_token_ids.shape[-1] > 1 + and self.input_batch.prev_sampled_token_ids is None): + # Stash the full GPU tensor so propose_draft_token_ids can + # pick it up later (it normally only receives an empty list + # in the post-bookkeeping path). + self._arctic_async_sampled_tensor = sampled_token_ids + # Placeholder: first column only (bonus token per request). + # Prevents the assertion from firing; the correct value will + # be overwritten by _copy_valid_sampled_token_count shortly. 
+ self.input_batch.prev_sampled_token_ids = ( + sampled_token_ids[:, :1]) + + return self._orig_bookkeeping_sync( + scheduler_output, sampler_output, logits, hidden_states, + num_scheduled_tokens, spec_decode_metadata) def propose_draft_token_ids( self, scheduler_output: "SchedulerOutput", - sampled_token_ids: list[list[int]], - original_sampled_token_ids: np.ndarray, + sampled_token_ids: torch.Tensor | list[list[int]], sampling_metadata: SamplingMetadata, hidden_states: torch.Tensor, sample_hidden_states: torch.Tensor, - aux_hidden_states: Optional[torch.Tensor], - spec_decode_metadata: Optional[SpecDecodeMetadata], - attn_metadata: dict[str, Any], - ) -> list[list[int]]: - disable_spec_decode = (self.speculative_config and - self.speculative_config.disable_by_batch_size - and len(self.input_batch.req_ids) - > self.speculative_config.disable_by_batch_size) - if disable_spec_decode: - # No speculative decoding is enabled. - return [[] for _ in sampled_token_ids] + aux_hidden_states: list[torch.Tensor] | None, + spec_decode_metadata: SpecDecodeMetadata | None, + common_attn_metadata: CommonAttentionMetadata, + ) -> list[list[int]] | torch.Tensor: + # In async mode, the base vLLM dispatches arctic to the + # post-bookkeeping path which passes valid_sampled_token_ids + # (an empty list for async). Recover the stashed GPU tensor + # so the fast async drafting path below can activate. + if (isinstance(sampled_token_ids, list) + and len(sampled_token_ids) == 0 + and hasattr(self, '_arctic_async_sampled_tensor')): + sampled_token_ids = self._arctic_async_sampled_tensor + del self._arctic_async_sampled_tensor + + # Compute the maximum number of requests to draft for. + # When disable_by_batch_size is set and the batch exceeds it, + # we still draft for the first N requests instead of disabling + # entirely. This avoids the stale-data crash that occurs when + # drafting is fully disabled one step and re-enabled the next. 
+ batch_size = len(self.input_batch.req_ids) + draft_limit = batch_size # default: draft for all + if ( + self.speculative_config + and self.speculative_config.disable_by_batch_size + and batch_size > self.speculative_config.disable_by_batch_size + ): + draft_limit = self.speculative_config.disable_by_batch_size + + use_async_path = ( + self.speculative_config.method in ("arctic", "mlp_speculator") + and isinstance(sampled_token_ids, torch.Tensor) + and self.use_async_scheduling + and common_attn_metadata is not None + ) + + if use_async_path: + assert isinstance(sampled_token_ids, torch.Tensor) + + next_token_ids, valid_sampled_tokens_count = ( + self.drafter.prepare_next_token_ids_padded( + common_attn_metadata, + sampled_token_ids, + self.requests, + self.input_batch, + self.discard_request_mask.gpu, + ) + ) + self._copy_valid_sampled_token_count( + next_token_ids, valid_sampled_tokens_count + ) + + target_hidden_states = self.drafter.prepare_hidden_states( + sample_hidden_states=sample_hidden_states, + sampled_token_ids=sampled_token_ids, + spec_decode_metadata=spec_decode_metadata, + ) + + # Only draft for the first draft_limit requests. + raw_draft = self.drafter.propose( + context_token_ids=next_token_ids[:draft_limit], + previous_hidden_states=target_hidden_states[:draft_limit], + num_predict_tokens=self.drafter.model.n_predict, + ) + # Use pre-allocated GPU buffer when available. This avoids + # per-step F.pad + torch.zeros allocations. The buffer is + # [max_num_reqs, num_spec_tokens] so a single zero_() + + # copy_() handles both width and batch padding in one shot. 
+ merged_buf = getattr(self, '_draft_merged_gpu', None) + if merged_buf is not None: + draft = merged_buf[:batch_size] + draft.zero_() + rd_rows, rd_cols = raw_draft.shape + draft[:rd_rows, :rd_cols].copy_(raw_draft) + else: + draft = raw_draft + if draft.shape[1] < self.num_spec_tokens: + draft = torch.nn.functional.pad( + draft, + (0, self.num_spec_tokens - draft.shape[1]), + value=0, + ) + if draft_limit < batch_size: + full_draft = torch.zeros( + batch_size, draft.shape[1], + dtype=draft.dtype, device=draft.device, + ) + full_draft[:draft_limit] = draft + draft = full_draft + return draft + + if isinstance(sampled_token_ids, torch.Tensor): + vocab_size = self.model_config.get_vocab_size() + sampled_token_ids_list = [ + [t for t in seq if t != -1 and t < vocab_size] + for seq in sampled_token_ids.tolist() + ] + sampled_token_ids_tensor = sampled_token_ids + else: + sampled_token_ids_list = sampled_token_ids + sampled_token_ids_tensor = None + + arctic_spec_token_ids = None suffix_spec_token_ids = None - new_sampled_token_ids = sampled_token_ids.copy() + + if self.speculative_config.method in ("arctic", "mlp_speculator"): + if sampled_token_ids_tensor is None: + import numpy as np + sampled_token_ids_tensor = torch.tensor(sampled_token_ids_list, device=self.device) + + previous_hidden_states = self.drafter.prepare_hidden_states( + sample_hidden_states=sample_hidden_states, + sampled_token_ids=sampled_token_ids_tensor, + spec_decode_metadata=spec_decode_metadata, + ) + + next_token_ids = self.drafter.prepare_next_token_ids_cpu( + sampled_token_ids_list, + self.requests, + self.input_batch, + scheduler_output.num_scheduled_tokens, + ) + + # Only draft for the first draft_limit requests. 
+ arctic_output_tensor = self.drafter.propose( + context_token_ids=next_token_ids[:draft_limit], + previous_hidden_states=previous_hidden_states[:draft_limit], + num_predict_tokens=self.drafter.model.n_predict, + ) + + arctic_spec_token_ids = arctic_output_tensor.tolist() + # Pad with empty lists for requests beyond draft_limit. + if draft_limit < batch_size: + arctic_spec_token_ids.extend( + [] for _ in range(batch_size - draft_limit) + ) + if self._suffix_cache is not None: - results = self.propose_suffix_draft_token_ids( - new_sampled_token_ids) + self._update_suffix_cache(sampled_token_ids_list) + results = self.propose_suffix_draft_token_ids(sampled_token_ids_list) + suffix_spec_token_ids = [] - # The score is an estimate of the acceptance length. Thus, the - # heuristic is to use the suffix decoded tokens if the score is - # greater than the # of tokens we would speculate otherwise. - min_score = (self.speculative_config.num_speculative_tokens - if self.speculative_config.method != "suffix" else 0) - min_score = (0 if self.speculative_config.method == "suffix" else - self.speculative_config.num_speculative_tokens) - for i, result in enumerate(results): + min_score = 0 if self.speculative_config.method == "suffix" \ + else self.drafter.model.n_predict + + for result in results: if result.score >= min_score: - # Use suffix decoded tokens, disable other speculation - # methods for this request. 
- new_sampled_token_ids[i] = [] suffix_spec_token_ids.append(result.token_ids) else: suffix_spec_token_ids.append([]) spec_token_ids = None - if self.speculative_config.method == "suffix": - pass - elif (self.speculative_config.method == "arctic" - or self.speculative_config.method == "mlp_speculator"): - assert isinstance(self.drafter, ArcticProposer) - previous_hidden_states = self.drafter.prepare_hidden_states( - sample_hidden_states=sample_hidden_states, - sampled_token_ids=original_sampled_token_ids, - spec_decode_metadata=spec_decode_metadata, - ) - spec_token_ids = self.propose_arctic_draft_token_ids( - scheduler_output, - new_sampled_token_ids, - previous_hidden_states=previous_hidden_states) + if suffix_spec_token_ids is not None and arctic_spec_token_ids is not None: + spec_token_ids = [ + s_tokens if s_tokens else a_tokens + for s_tokens, a_tokens in zip(suffix_spec_token_ids, arctic_spec_token_ids) + ] + elif suffix_spec_token_ids is not None: + spec_token_ids = suffix_spec_token_ids + elif arctic_spec_token_ids is not None: + spec_token_ids = arctic_spec_token_ids else: spec_token_ids = self._orig_propose_draft_token_ids( scheduler_output, - new_sampled_token_ids, + sampled_token_ids_list, sampling_metadata, hidden_states, sample_hidden_states, aux_hidden_states, spec_decode_metadata, - attn_metadata, + common_attn_metadata, ) if spec_token_ids is None: - spec_token_ids = suffix_spec_token_ids - elif suffix_spec_token_ids is not None: - spec_token_ids = [ - suffix_spec_token_ids[i] or spec_token_ids[i] - for i in range(len(suffix_spec_token_ids)) - ] - - return spec_token_ids - - def propose_arctic_draft_token_ids( - self, - scheduler_output: "SchedulerOutput", - sampled_token_ids: list[list[int]], - previous_hidden_states: Optional[torch.Tensor] = None, - ) -> list[list[int]]: - last_tokens: list[int] = [] - max_spec_tokens = self.speculative_config.num_speculative_tokens - for i, sampled_ids in enumerate(sampled_token_ids): - num_sampled_ids = 
len(sampled_ids) - - if (num_sampled_ids == 0): - if self.speculative_config.enable_suffix_decoding: - return [[]] * len(sampled_token_ids) - req_id = self.input_batch.req_ids[i] - req_state = self.requests[req_id] - seq_len = (req_state.num_computed_tokens + - scheduler_output.num_scheduled_tokens[req_id]) - sampled_ids = [req_state.get_token_id(seq_len)] - - # Add sampled_token_ids to token_ids_cpu. - start_idx = self.input_batch.num_tokens_no_spec[i] - end_idx = start_idx + num_sampled_ids - - max_spec_tokens = min( - max_spec_tokens, - self.max_model_len - end_idx - 1, + spec_token_ids = [[] for _ in range(len(self.input_batch.req_ids))] + + # For async scheduling the base _prepare_input_ids asserts that + # _draft_token_ids is a torch.Tensor and uses it to scatter draft + # tokens into the input. If we reached here (non-async code path) + # while async scheduling is active we must: + # 1. Convert the list-of-lists draft tokens to a padded tensor. + # 2. Call _copy_valid_sampled_token_count so the next step's + # _get_valid_sampled_token_count returns correct counts. + if (self.use_async_scheduling + and isinstance(spec_token_ids, list) + and isinstance(sampled_token_ids, torch.Tensor)): + # --- _copy_valid_sampled_token_count --- + if (hasattr(self, 'drafter') + and hasattr(self.drafter, 'prepare_next_token_ids_padded') + and common_attn_metadata is not None): + next_tok, valid_cnt = ( + self.drafter.prepare_next_token_ids_padded( + common_attn_metadata, + sampled_token_ids, + self.requests, + self.input_batch, + self.discard_request_mask.gpu, + )) + self._copy_valid_sampled_token_count(next_tok, valid_cnt) + else: + # Manual fallback (suffix-only drafter). 
+ mask = sampled_token_ids != -1 + valid_cnt = mask.sum(dim=1) + _bs = sampled_token_ids.shape[0] + cols = torch.arange( + sampled_token_ids.shape[1], + device=sampled_token_ids.device, + ).unsqueeze(0).expand_as(sampled_token_ids) + last_col = cols.masked_fill(~mask, -1).max(dim=1).values + last_col = last_col.clamp(min=0) + next_tok = sampled_token_ids[ + torch.arange(_bs, device=sampled_token_ids.device), + last_col, + ] + self._copy_valid_sampled_token_count(next_tok, valid_cnt) + + # --- Convert list[list[int]] -> padded tensor --- + padded = torch.zeros( + batch_size, self.num_spec_tokens, + dtype=torch.int32, device=self.device, ) - if max_spec_tokens <= 0: - continue - - self.input_batch.token_ids_cpu[i, - start_idx:end_idx] = sampled_ids[-1] - last_tokens.append(self.input_batch.token_ids_cpu[i, end_idx - 1]) + for i, tokens in enumerate(spec_token_ids): + length = min(len(tokens), self.num_spec_tokens) + if length > 0: + padded[i, :length] = torch.tensor( + tokens[:length], dtype=torch.int32, + device=self.device) + spec_token_ids = padded - if max_spec_tokens <= 0: - return [[] for _ in sampled_token_ids] - - drafter_output = self.drafter.propose( - last_tokens, - previous_hidden_states=previous_hidden_states, - num_predict_tokens=max_spec_tokens, - ) - - draft_token_ids = drafter_output.tolist() - - for i, sampled_ids in enumerate(sampled_token_ids): - if not sampled_ids: - draft_token_ids[i] = [] - - return draft_token_ids + return spec_token_ids def _update_suffix_cache(self, sampled_token_ids: list[list[int]]) -> None: seen_req_ids = set() @@ -666,207 +1032,842 @@ def _update_suffix_cache(self, sampled_token_ids: list[list[int]]) -> None: continue index = self.input_batch.req_id_to_index[req_id] - if req_id not in self._suffix_cache.active_requests: + is_new = req_id not in self._suffix_cache.active_requests + if is_new: if req_id in self._suffix_cache.cached_requests: - # Reset the suffix cache for this request. 
- self._suffix_cache.evict_request(req_id) + self._suffix_cache.evict_cached_response(req_id) num_prompt_tokens = self.input_batch.num_prompt_tokens[index] - prompt_token_ids = ( - self.input_batch.token_ids_cpu[index, :num_prompt_tokens]) - self._suffix_cache.start_request(req_id, prompt_token_ids) + prompt_token_ids = self.input_batch.token_ids_cpu[index, :num_prompt_tokens] + self._suffix_cache.start_request(req_id, prompt_token_ids.tolist()) + self._suffix_response_tokens[req_id] = [] self._suffix_cache.add_active_response(req_id, sampled_ids) + self._suffix_response_tokens[req_id].extend(sampled_ids) - # Stop requests that are not seen + stopped_ids = [] for req_id in list(self._suffix_cache.active_requests): if req_id not in seen_req_ids: self._suffix_cache.stop_request(req_id) + self._suffix_response_tokens.pop(req_id, None) + stopped_ids.append(req_id) def propose_suffix_draft_token_ids( self, sampled_token_ids: list[list[int]], - spec_token_ids: Optional[list[list[int]]] = None, - ) -> list[list[int]]: + ) -> list[SuffixDecodingDraft]: config = self.speculative_config results = [] for i, sampled_ids in enumerate(sampled_token_ids): - spec_ids = spec_token_ids[i] if spec_token_ids is not None else [] num_sampled_ids = len(sampled_ids) if not num_sampled_ids: - # Skip speculative decoding. - results.append(SuffixSpecResult()) + results.append(SuffixDecodingDraft()) continue req_id = self.input_batch.req_ids[i] + index = self.input_batch.req_id_to_index[req_id] - # Add sampled_token_ids to token_ids_cpu. - start_idx = self.input_batch.num_tokens_no_spec[i] - end_idx = start_idx + len(sampled_ids) - - if end_idx >= self.max_model_len: - results.append(SuffixSpecResult()) - self.input_batch.token_ids_cpu[ - i, start_idx:self. 
- max_model_len] = sampled_ids[:self.max_model_len - - start_idx] - continue - - self.input_batch.token_ids_cpu[i, start_idx:end_idx] = sampled_ids - - size = min(end_idx, config.suffix_cache_max_depth) - pattern = self.input_batch.token_ids_cpu[i, end_idx - size:end_idx] - pattern = pattern.tolist() + spec_ids - if len(pattern) > config.suffix_cache_max_depth: - pattern = pattern[-config.suffix_cache_max_depth:] - max_spec_tokens = min(MAX_SPEC_LEN - len(spec_ids), - config.suffix_cache_max_depth, - self.max_model_len - end_idx - 1) - # max_spec_offset is modified to mimic the behavior of the original - # max_spec_factor and max_spec_offset as if the speculative tokens - # were generated by suffix decoding. For example, if: - # - max_spec_factor = 2 - # - max_spec_offset = -1 - # - we've already speculated 3 tokens - # - and the suffix match length is 6 - # Then: - # - The match length before the already-speculated tokens is 3 - # - The original config allow up to 5 speculated tokens total - # - Already speculated 3 tokens, so should allow 2 more tokens - # So the new config should map match length 6 to 2 max spec tokens. - max_spec_factor = config.suffix_max_spec_factor - max_spec_offset = (config.suffix_max_spec_offset - len(spec_ids) * - (max_spec_factor + 1)) + # In async mode, token_ids_cpu contains -1 placeholders at + # decoded positions (written by _bookkeeping_sync). Build the + # pattern from the clean _suffix_response_tokens instead. + if (self.use_async_scheduling + and req_id in self._suffix_response_tokens): + response = self._suffix_response_tokens[req_id] + num_prompt = int( + self.input_batch.num_prompt_tokens[index]) + num_tokens = num_prompt + len(response) + if num_tokens >= self.max_model_len: + results.append(SuffixDecodingDraft()) + continue + # Take up to suffix_cache_max_depth tokens from the tail. 
+ depth = config.suffix_cache_max_depth + if len(response) >= depth: + pattern = response[-depth:] + else: + need = depth - len(response) + prompt_start = max(0, num_prompt - need) + prompt_part = self.input_batch.token_ids_cpu[ + index, prompt_start:num_prompt].tolist() + pattern = prompt_part + response + else: + num_tokens = self.input_batch.num_tokens_no_spec[i] + if num_tokens >= self.max_model_len: + results.append(SuffixDecodingDraft()) + continue + start = max(0, num_tokens - config.suffix_cache_max_depth) + pattern = self.input_batch.token_ids_cpu[ + i, start:num_tokens].tolist() + + max_spec = min( + MAX_SPEC_LEN, self.max_model_len - num_tokens - 1 + ) result = self._suffix_cache.speculate( req_id, pattern, - max_spec_tokens=max_spec_tokens, - max_spec_factor=max_spec_factor, - max_spec_offset=max_spec_offset, - min_token_prob=config.suffix_min_token_prob) + max_spec_tokens=max_spec, + max_spec_factor=config.suffix_max_spec_factor, + max_spec_offset=config.suffix_max_spec_offset, + min_token_prob=config.suffix_min_token_prob, + ) results.append(result) return results - def load_model(self) -> None: + + def _start_suffix_copy( + self, + sampled_token_ids: torch.Tensor, + ) -> None: + """Initiate an async D2H copy of sampled token IDs for suffix decoding. + + Copies ``sampled_token_ids`` to a pinned CPU buffer on a dedicated + CUDA stream (``suffix_copy_stream``) that only waits for prior work + on the default stream (i.e. sampling). + + **This MUST be called BEFORE launching Arctic GPU work on the default + stream** so that the copy is not ordered behind Arctic kernels. The + resulting timeline is:: + + Default stream: [sample] -> [arctic prepare / propose ...] + Suffix stream : [sample] -> [D2H copy] -> [event] + CPU : wait -> suffix logic + + The companion ``_finish_suffix_copy`` synchronises on the copy event + and returns the materialised CPU list. 
+ """ + n_rows = sampled_token_ids.shape[0] + n_cols = sampled_token_ids.shape[-1] + default_stream = torch.cuda.current_stream() + with torch.cuda.stream(self.suffix_copy_stream): + self.suffix_copy_stream.wait_stream(default_stream) + self.suffix_sampled_ids_pinned[:n_rows, :n_cols].copy_( + sampled_token_ids, non_blocking=True, + ) + self.suffix_copy_done_event.record() + self._suffix_copy_shape = (n_rows, n_cols) + + def _finish_suffix_copy(self) -> list[list[int]]: + """Wait for the suffix copy and return sampled token IDs as CPU lists. + + Synchronises on the copy event recorded by ``_start_suffix_copy``, + then parses the pinned buffer into ``list[list[int]]``, applying + rejection-sampling for spec-decode batches and masking discarded + (still-in-prefill) requests. + """ + self.suffix_copy_done_event.synchronize() + n_rows, n_cols = self._suffix_copy_shape + pinned = self.suffix_sampled_ids_pinned[:n_rows, :n_cols] + + num_reqs = self.input_batch.num_reqs + discard_indices = np.nonzero( + self.discard_request_mask.np[:num_reqs] + )[0] + + if n_cols == 1: + result = pinned.tolist() + for i in discard_indices: + result[int(i)].clear() + else: + result, _ = RejectionSampler.parse_output( + pinned, + self.input_batch.vocab_size, + discard_indices, + ) + + return result + + @torch.inference_mode + def sample_tokens( + self, grammar_output: "GrammarOutput | None" + ) -> Union[ModelRunnerOutput, AsyncGPUModelRunnerOutput, IntermediateTensors]: + kv_connector_output = self.kv_connector_output + self.kv_connector_output = None + + if self.execute_model_state is None: + # Nothing to do (PP non-final rank case), output isn't used. 
+ if not kv_connector_output: + return None # noqa + + # In case of PP with kv transfer, we need to pass through the + # kv_connector_output + if kv_connector_output.is_empty(): + return EMPTY_MODEL_RUNNER_OUTPUT + + output = copy(EMPTY_MODEL_RUNNER_OUTPUT) + output.kv_connector_output = kv_connector_output + return output + + # Unpack ephemeral state. + ( + scheduler_output, + logits, + spec_decode_metadata, + spec_decode_common_attn_metadata, + hidden_states, + sample_hidden_states, + aux_hidden_states, + ec_connector_output, + cudagraph_stats, + ) = self.execute_model_state + # Clear ephemeral state. + self.execute_model_state = None + + # Apply structured output bitmasks if present. + if grammar_output is not None: + apply_grammar_bitmask( + scheduler_output, grammar_output, self.input_batch, logits + ) + + with record_function_or_nullcontext("gpu_model_runner: sample"): + sampler_output = self._sample(logits, spec_decode_metadata) + + self._draft_token_ids = None + self._draft_token_req_ids = None + self.input_batch.prev_sampled_token_ids = None + + def propose_draft_token_ids( + sampled_token_ids: torch.Tensor | list[np.ndarray], + ) -> None: + assert spec_decode_common_attn_metadata is not None + with record_function_or_nullcontext("gpu_model_runner: draft"): + self._draft_token_ids = self.propose_draft_token_ids( + scheduler_output, + sampled_token_ids, + self.input_batch.sampling_metadata, + hidden_states, + sample_hidden_states, + aux_hidden_states, + spec_decode_metadata, + spec_decode_common_attn_metadata, + ) + self._copy_draft_token_ids_to_cpu(scheduler_output) + + # --- Draft proposal orchestration --- + # + # There are three drafting modes depending on async scheduling and + # which drafters are active: + # + # A) Arctic-only + async: + # Pre-bookkeeping. Arctic runs entirely on GPU tensors + # (the existing EAGLE-like async path). + # + # B) Suffix (with or without arctic) + async: + # Pre-bookkeeping. 
Uses a *dedicated* suffix_copy_stream so + # that the D2H transfer of sampled token IDs only waits on + # "sampling complete" -- NOT on Arctic GPU work. Timeline: + # + # Default stream: [sample] -> [arctic GPU work ----------->] + # Suffix stream: [sample] -> [D2H copy] -> [event] + # CPU : wait -> suffix + # + # At merge time, if any requests need arctic fallback, we + # sync on the arctic result (.tolist()); by then the GPU + # kernels have had the full suffix-CPU window to finish. + # + # C) Non-async (any combination): + # Post-bookkeeping. propose_draft_token_ids handles everything + # using CPU valid_sampled_token_ids from bookkeeping. + + has_suffix = self._suffix_cache is not None + is_arctic_method = ( + self.speculative_config is not None + and self.speculative_config.method in ("arctic", "mlp_speculator") + ) + input_fits_in_drafter = spec_decode_common_attn_metadata is not None + sampled_token_ids = sampler_output.sampled_token_ids + + # Determine if we should use async pre-bookkeeping drafting. + use_async_spec = ( + self.use_async_scheduling + and self.speculative_config is not None + and (is_arctic_method or has_suffix) + and input_fits_in_drafter + ) + + if use_async_spec: + if is_arctic_method and not has_suffix: + # (A) Arctic-only async: use existing closure which calls + # propose_draft_token_ids (the method) with a GPU tensor + # and triggers the async GPU path internally. + propose_draft_token_ids(sampled_token_ids) + + # Track actual draft lengths for next step's allocation. 
+ _n_predict = self.drafter.model.n_predict + _batch_size = len(self.input_batch.req_ids) + _disable_bs = ( + self.speculative_config.disable_by_batch_size + if self.speculative_config else None + ) + _draft_limit = _batch_size + if _disable_bs and _batch_size > _disable_bs: + _draft_limit = _disable_bs + self._prev_actual_draft_lens = { + req_id: _n_predict if i < _draft_limit else 0 + for i, req_id in enumerate(self.input_batch.req_ids) + } + scheduler_output._actual_draft_lens = ( + self._prev_actual_draft_lens + ) + + elif is_arctic_method: + # (B) Arctic + suffix async. + # D2H copy of sampled tokens runs on a dedicated + # suffix_copy_stream that only waits for sampling, + # NOT for Arctic GPU kernels enqueued afterwards. + # Suffix CPU work then overlaps with Arctic GPU. + + # Step 1: Initiate async D2H copy before Arctic GPU work. + self._start_suffix_copy(sampled_token_ids) + + # Step 2: Launch arctic on the default stream. + with record_function_or_nullcontext( + "gpu_model_runner: draft (arctic)" + ): + arctic_draft = self.propose_draft_token_ids( + scheduler_output, + sampled_token_ids, + self.input_batch.sampling_metadata, + hidden_states, + sample_hidden_states, + aux_hidden_states, + spec_decode_metadata, + spec_decode_common_attn_metadata, + ) + + # Step 3: Wait for D2H copy, then run suffix on CPU. + # This overlaps with remaining Arctic GPU work. + with record_function_or_nullcontext( + "gpu_model_runner: draft (suffix)" + ): + sampled_cpu = self._finish_suffix_copy() + self._update_suffix_cache(sampled_cpu) + suffix_results = self.propose_suffix_draft_token_ids( + sampled_cpu + ) + + min_score = self.drafter.model.n_predict + suffix_draft = [ + result.token_ids if result.score >= min_score + else [] + for result in suffix_results + ] + + # Step 4: Merge arctic + suffix results. + # Suffix takes priority when available. 
+ # + # The merged result MUST be a torch.Tensor on GPU + # because upstream _prepare_input_ids (next iteration) + # asserts isinstance(self._draft_token_ids, torch.Tensor) + # and uses it for on-device scatter. + + # Collect suffix rows that have results. + suffix_indices = [ + i for i, s in enumerate(suffix_draft) if s + ] + + # arctic_draft is already a [batch, num_spec_tokens] + # GPU tensor (from the pre-allocated _draft_merged_gpu + # buffer when available, or F.pad fallback). + if isinstance(arctic_draft, torch.Tensor): + merged = arctic_draft + else: + # Shouldn't happen in async path, but handle + # gracefully. + k = self.num_spec_tokens + n = len(arctic_draft) + pin = self._suffix_merge_pinned[:n, :k] + pin_np = pin.numpy() + pin_np[:] = 0 + for i, row in enumerate(arctic_draft): + t = row[:k] + if t: + pin_np[i, :len(t)] = t + merged = pin.to( + device=self.device, non_blocking=True) + + if merged.shape[1] < self.num_spec_tokens: + merged = torch.nn.functional.pad( + merged, + (0, self.num_spec_tokens - merged.shape[1]), + value=0, + ) + + # Batch-overwrite rows that have suffix results. + # CRITICAL: use the pre-allocated *pinned* merge + # buffer so that cudaMemcpyAsync does NOT synchronise + # the default stream. With pageable (non-pinned) + # memory CUDA must sync the stream first, blocking the + # CPU until Arctic GPU work finishes -- destroying + # the async overlap. + width = merged.shape[1] + if suffix_indices: + n_sfx = len(suffix_indices) + overlay_pin = \ + self._suffix_merge_pinned[:n_sfx, :width] + overlay_np = overlay_pin.numpy() + overlay_np[:] = 0 + for j, idx in enumerate(suffix_indices): + s = suffix_draft[idx] + slen = min(len(s), width) + overlay_np[j, :slen] = s[:slen] + # Pinned H2C -- truly non-blocking. + overlay_t = overlay_pin.to( + device=self.device, non_blocking=True) + # Use pre-allocated pinned index buffer when + # available to avoid per-step allocation. 
+ idx_pinned = getattr( + self, '_suffix_index_pinned', None) + if idx_pinned is not None: + idx_pin = idx_pinned[:n_sfx] + idx_pin[:] = torch.tensor( + suffix_indices, dtype=torch.long) + else: + idx_pin = torch.tensor( + suffix_indices, dtype=torch.long + ).pin_memory() + idx_t = idx_pin.to( + device=self.device, non_blocking=True) + merged.index_copy_(0, idx_t, overlay_t) + + self._draft_token_ids = merged + self._copy_draft_token_ids_to_cpu(scheduler_output) + + # Track actual draft lengths for next step's allocation. + _n_predict = self.drafter.model.n_predict + _batch_size = len(self.input_batch.req_ids) + _disable_bs = ( + self.speculative_config.disable_by_batch_size + if self.speculative_config else None + ) + _draft_limit = _batch_size + if _disable_bs and _batch_size > _disable_bs: + _draft_limit = _disable_bs + _actual_lens: dict[str, int] = {} + for _i, _req_id in enumerate(self.input_batch.req_ids): + _s = (len(suffix_draft[_i]) + if _i < len(suffix_draft) else 0) + _a = _n_predict if _i < _draft_limit else 0 + # Suffix takes priority when available. + _actual_lens[_req_id] = _s if _s > 0 else _a + self._prev_actual_draft_lens = _actual_lens + scheduler_output._actual_draft_lens = _actual_lens + + else: + # (B2) Suffix-only async. + # No Arctic GPU work to overlap with, so skip the + # copy stream (its overhead exceeds the marginal + # overlap with the lightweight rejection kernel). + # Instead: rejection sample → parse tokens directly + # → suffix CPU work → build GPU tensor via pinned buf. + self._suffix_only_rejection_sample( + sampled_token_ids, + spec_decode_common_attn_metadata, + ) + + # Parse sampled tokens to CPU lists. .cpu() syncs the + # default stream (waits for the rejection kernel, which + # is lightweight), then parse_output performs CPU-side + # rejection to extract accepted token IDs. 
+ with record_function_or_nullcontext( + "gpu_model_runner: draft (suffix)" + ): + num_reqs = self.input_batch.num_reqs + discard_indices = np.nonzero( + self.discard_request_mask.np[:num_reqs] + )[0] + n_cols = sampled_token_ids.shape[-1] + if n_cols == 1: + sampled_cpu = sampled_token_ids.tolist() + for idx in discard_indices: + sampled_cpu[int(idx)].clear() + else: + sampled_cpu, _ = RejectionSampler.parse_output( + sampled_token_ids.cpu(), + self.input_batch.vocab_size, + discard_indices, + ) + + self._update_suffix_cache(sampled_cpu) + suffix_results = self.propose_suffix_draft_token_ids( + sampled_cpu + ) + suffix_draft = [ + result.token_ids if result.score >= 0 + else [] + for result in suffix_results + ] + + # Build GPU tensor from suffix lists via pinned buffer. + k = self.num_spec_tokens + n = len(suffix_draft) + pin = self._suffix_merge_pinned[:n, :k] + pin_np = pin.numpy() + pin_np[:] = 0 + for i, s in enumerate(suffix_draft): + if s: + slen = min(len(s), k) + pin_np[i, :slen] = s[:slen] + self._draft_token_ids = pin.to( + device=self.device, non_blocking=True) + self._copy_draft_token_ids_to_cpu(scheduler_output) + + # Track actual draft lengths for next step's allocation. + _actual_lens_b2: dict[str, int] = {} + for _i, _req_id in enumerate(self.input_batch.req_ids): + _s = (len(suffix_draft[_i]) + if _i < len(suffix_draft) else 0) + _actual_lens_b2[_req_id] = _s + self._prev_actual_draft_lens = _actual_lens_b2 + scheduler_output._actual_draft_lens = _actual_lens_b2 + + # --- Bookkeeping --- + with record_function_or_nullcontext("gpu_model_runner: bookkeep"): + ( + num_nans_in_logits, + logprobs_lists, + valid_sampled_token_ids, + prompt_logprobs_dict, + req_ids_output_copy, + req_id_to_index_output_copy, + invalid_req_indices, + ) = self._bookkeeping_sync( + scheduler_output, + sampler_output, + logits, + hidden_states, + scheduler_output.total_num_scheduled_tokens, + spec_decode_metadata, + ) + + # (C) Non-async drafting: run after bookkeeping. 
+ # Pass the raw GPU sampled_token_ids tensor (not the cleaned + # valid_sampled_token_ids list) so that propose_draft_token_ids + # gets a properly shaped 2-D tensor with -1 markers for rejected + # tokens. The method already handles tensors correctly: it + # filters out -1 to build the CPU list for + # prepare_next_token_ids_cpu and keeps the raw tensor for + # prepare_hidden_states. + if ( + self.speculative_config is not None + and not use_async_spec + and input_fits_in_drafter + ): + propose_draft_token_ids(sampled_token_ids) + + with record_function_or_nullcontext("gpu_model_runner: eplb"): + self.eplb_step() + + with record_function_or_nullcontext("gpu_model_runner: ModelRunnerOutput"): + if self.model_config.enable_return_routed_experts: + capturer = RoutedExpertsCapturer.get_instance() + if capturer is not None: + capturer.save_captured_experts(indices=self.slot_mapping) # noqa + else: + logger.error("RoutedExpertsCapturer not initialized.") + + output = ModelRunnerOutput( + req_ids=req_ids_output_copy, + req_id_to_index=req_id_to_index_output_copy, + sampled_token_ids=valid_sampled_token_ids, + logprobs=logprobs_lists, + prompt_logprobs_dict=prompt_logprobs_dict, + kv_connector_output=kv_connector_output, + ec_connector_output=ec_connector_output + if self.supports_mm_inputs + else None, + num_nans_in_logits=num_nans_in_logits, + cudagraph_stats=cudagraph_stats, + ) + # Attach actual draft lengths to the ModelRunnerOutput so the + # scheduler can read them reliably in update_from_output. + # This survives the async pipeline (scheduler_output attrs + # may not due to object lifecycle in the batch queue). 
+ output._actual_draft_lens = getattr( + scheduler_output, '_actual_draft_lens', None) + + if not self.use_async_scheduling: + return output + + with record_function_or_nullcontext( + "gpu_model_runner: AsyncGPUModelRunnerOutput" + ): + async_output = AsyncGPUModelRunnerOutput( + model_runner_output=output, + sampled_token_ids=sampler_output.sampled_token_ids, + logprobs_tensors=sampler_output.logprobs_tensors, + invalid_req_indices=invalid_req_indices, + async_output_copy_stream=self.async_output_copy_stream, + vocab_size=self.input_batch.vocab_size, + ) + with record_function_or_nullcontext( + "gpu_model_runner: set_async_sampled_token_ids" + ): + # Save ref of sampled_token_ids CPU tensor if the batch contains + # any requests with sampling params that require output ids. + self.input_batch.set_async_sampled_token_ids( + async_output.sampled_token_ids_cpu, + async_output.async_copy_ready_event, + ) + + return async_output + + + def load_model(self, eep_scale_up: bool = False) -> None: load_shift_model = ( - self.vllm_config.parallel_config.enable_shift_parallel) + self.vllm_config.parallel_config.enable_shift_parallel + ) if load_shift_model: - # Make a deep copy of the config before loading the model. 
shift_config = copy.deepcopy(self.vllm_config) - self._orig_load_model() + self._orig_load_model(eep_scale_up) if self.parallel_config.ulysses_sequence_parallel_size > 1: self.monkeypatch_forward() if load_shift_model: shift_config.parallel_config.tensor_parallel_size *= ( - shift_config.parallel_config.ulysses_sequence_parallel_size) + shift_config.parallel_config.ulysses_sequence_parallel_size + ) shift_config.parallel_config.ulysses_sequence_parallel_size = 1 with set_shift_parallel_mode(True): self.shift_model = get_model(vllm_config=shift_config) self.shift_parallel_threshold = ( - shift_config.parallel_config.shift_parallel_threshold) + shift_config.parallel_config.shift_parallel_threshold + ) + self.shift_forward_context = ( + shift_config.compilation_config.static_forward_context + ) + if "SwiftKV" in self.model.__class__.__name__: - # HACK: Replace the decode-runner since it always runs in full - # TP, but the original model is captured using SP * BATCH_SIZE, - # which does not cover all its cuda graph sizes. The shift-mode - # model should have all its cuda graphs captured correctly. 
- self.model.model.decode_runner = ( - self.shift_model.model.decode_runner) + if hasattr(self.model, "model") and hasattr(self.model.model, "decode_runner"): + self.model.model.decode_runner = self.shift_model.model.decode_runner + else: + logger.warning("Could not apply SwiftKV HACK: " + "model.model.decode_runner not found.") + + cudagraph_mode = self.compilation_config.cudagraph_mode + if (cudagraph_mode is not None + and cudagraph_mode.has_full_cudagraphs() + and not self.parallel_config.use_ubatching): + from vllm.compilation.cuda_graph import CUDAGraphWrapper + self.shift_model = CUDAGraphWrapper( + self.shift_model, self.vllm_config, + runtime_mode=CUDAGraphMode.FULL, + ) else: self.shift_model = None self.shift_parallel_threshold = 0 + self.shift_forward_context = None + - def capture_model(self) -> None: - if not self.use_cuda_graph: - logger.warning( - "Skipping CUDA graph capture. To turn on CUDA graph capture, " - "set -O %s and ensure `use_cudagraph` was not manually set to " - "False", CompilationLevel.PIECEWISE) + def initialize_kv_cache(self, kv_cache_config) -> None: + self._orig_initialize_kv_cache(kv_cache_config) + shift_ctx = getattr(self, 'shift_forward_context', None) + if shift_ctx is None: return + base_ctx = self.compilation_config.static_forward_context + bound = 0 + for name, shift_attn in shift_ctx.items(): + base_attn = base_ctx.get(name) + if base_attn is not None and hasattr(base_attn, 'kv_cache'): + shift_attn.kv_cache = base_attn.kv_cache + bound += 1 + if is_global_first_rank(): + logger.info("Bound KV cache to %d shift model attention layers", + bound) + + from vllm.forward_context import BatchDescriptor + def _case_bs(self, case) -> int: + # vLLM can pass ints, tuples, or sometimes BatchDescriptor-like objects + if isinstance(case, int): + return case + if isinstance(case, BatchDescriptor): + return int(case.num_tokens) + if isinstance(case, tuple): + return int(case[0]) + # last resort + return int(getattr(case, "num_tokens")) 
+ + def _with_bs(self, case, new_bs: int): + if isinstance(case, tuple): + return (new_bs, *case[1:]) + if isinstance(case, BatchDescriptor): + # Best-effort reconstruction; adjust if your BatchDescriptor signature differs. + return BatchDescriptor( + num_tokens=new_bs, + num_reqs=case.num_reqs, + uniform=case.uniform, + has_lora=case.has_lora, + ) + return new_bs - compilation_counter.num_gpu_runner_capture_triggers += 1 - - start_time = time.perf_counter() - start_free_gpu_memory = torch.cuda.mem_get_info()[0] - - # Trigger CUDA graph capture for specific shapes. - # Capture the large shapes first so that the smaller shapes - # can reuse the memory pool allocated for the large shapes. - with parallel_state.graph_capture(device=self.device): - sp_size = self.parallel_config.ulysses_sequence_parallel_size - full_cg = self.full_cuda_graph - # capture original model shapes - compilation_cases = ( - shape for shape in reversed(self.cudagraph_batch_sizes) - if shape * sp_size > self.shift_parallel_threshold and shape * - sp_size <= self.max_num_tokens) - # Only rank 0 should print progress bar during capture - if is_global_first_rank(): - print_cases, compilation_cases = tee(compilation_cases) - logger.info(f"original model shapes {list(print_cases)}") - compilation_cases = tqdm( - list(compilation_cases), - desc="Capturing CUDA graph shapes of original model") - for num_tokens in compilation_cases: - # We skip EPLB here since we don't want to record dummy metrics - for _ in range(self.vllm_config.compilation_config. 
- cudagraph_num_of_warmups): - self._dummy_run(num_tokens * sp_size, - capture_attn_cudagraph=full_cg, - skip_eplb=True) - self._dummy_run(num_tokens * sp_size, - capture_attn_cudagraph=full_cg, - skip_eplb=True) - - # Capture shift model shapes - if self.shift_model is not None: - orig_model, self.model = self.model, self.shift_model - # Reset compilation cases - compilation_cases = ( - shape for shape in reversed(self.cudagraph_batch_sizes) - if shape <= self.shift_parallel_threshold - or "SwiftKV" in self.model.__class__.__name__) - # Note: We want to capture all shapes for the SwiftKV shift model. - # This is necessary since SwiftKV always uses full TP for the decode runner. - # For all other models, we only capture necessary shapes for the SP_TP mode, - # yielding less setup time. - if is_global_first_rank(): - print_cases, compilation_cases = tee(compilation_cases) - logger.info(f"shift model shapes {list(print_cases)}") - compilation_cases = tqdm( - list(compilation_cases), - desc="Capturing CUDA graph shapes of shift model") - with set_shift_parallel_mode(True): - for num_tokens in compilation_cases: - for _ in range(self.vllm_config.compilation_config. - cudagraph_num_of_warmups): - self._dummy_run(num_tokens, - capture_attn_cudagraph=full_cg, - skip_eplb=True) - self._dummy_run(num_tokens, - capture_attn_cudagraph=full_cg, - skip_eplb=True) - self.model = orig_model - end_time = time.perf_counter() - end_free_gpu_memory = torch.cuda.mem_get_info()[0] - elapsed_time = end_time - start_time - cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory - # This usually takes 5~20 seconds. 
- logger.info("Graph capturing finished in %.0f secs, took %.2f GiB", - elapsed_time, cuda_graph_size / (1 << 30)) + def _register_shift_cudagraph_keys( + self, + compilation_cases, + cudagraph_runtime_mode: CUDAGraphMode, + ): + """Register shift model batch sizes in the cudagraph dispatcher so + that runtime dispatch correctly routes to captured FULL/PIECEWISE + graphs.""" + dispatcher = getattr(self, 'cudagraph_dispatcher', None) + if dispatcher is None: + return - def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: - self._orig_initialize_kv_cache(kv_cache_config) + uniform = cudagraph_runtime_mode == CUDAGraphMode.FULL + added = 0 + for case in compilation_cases: + bs = self._case_bs(case) + bd = dispatcher._create_padded_batch_descriptor( + bs, uniform, False, + ) + if not uniform: + bd = bd.relax_for_mixed_batch_cudagraphs() + dispatcher.add_cudagraph_key(cudagraph_runtime_mode, bd) + added += 1 + + @contextlib.contextmanager + def _shift_graph_capture_context(self): + """Enable ca_comm for shift model graph capture.""" + yield - if self.shift_model is not None: - # Bind the KV caches to the shift parallel model. 
- forward_context = ( - self.vllm_config.compilation_config.static_forward_context) - for mod in self.shift_model.modules(): - if isinstance(mod, Attention): - mod.kv_cache = forward_context[mod.layer_name].kv_cache + @contextlib.contextmanager + def _use_shift_cudagraph_tables(self): + """Temporarily swap compilation_config sizes to the shift (unscaled) + lookup table so that vLLM internals (dispatcher, pad_for_cudagraph, + bounds checks) all see the shift model's sizes.""" + cc = self.compilation_config + saved_sizes = cc.cudagraph_capture_sizes + saved_max = cc.max_cudagraph_capture_size + saved_table = cc.bs_to_padded_graph_size + + shift_sizes = self.vllm_config._shift_cudagraph_capture_sizes + shift_max = self.vllm_config._shift_max_cudagraph_capture_size + shift_table = self.vllm_config._shift_bs_to_padded_graph_size + + cc.cudagraph_capture_sizes = shift_sizes + cc.max_cudagraph_capture_size = shift_max + cc.bs_to_padded_graph_size = shift_table + try: + yield + finally: + cc.cudagraph_capture_sizes = saved_sizes + cc.max_cudagraph_capture_size = saved_max + cc.bs_to_padded_graph_size = saved_table + + def _capture_cudagraphs( + self, + compilation_cases: list[tuple[int, bool]], + cudagraph_runtime_mode: CUDAGraphMode, + uniform_decode: bool, + ): + """ + Capture CUDA graphs for both base (SP) and shift (TP) variants, splitting + shapes by threshold so both models have required graphs captured. 
+ """ + assert cudagraph_runtime_mode != CUDAGraphMode.NONE and \ + cudagraph_runtime_mode in [CUDAGraphMode.FULL, CUDAGraphMode.PIECEWISE] + + sp_size = parallel_state._SP.world_size + tp_size = parallel_state._TP.world_size + threshold = int(getattr(self, "shift_parallel_threshold", 0)) + has_shift = getattr(self, "shift_model", None) is not None + is_swiftkv = "SwiftKV" in self.model.__class__.__name__ + + # --- Base model (Ulysses SP): uses the scaled lookup table (default) --- + # Exclude sizes at or below the shift threshold -- those batches + # are routed to the shift model at runtime. Capturing them for the + # base model would deadlock because the Ulysses all-to-all collectives + # diverge across ranks at small batch sizes. + if has_shift and not is_swiftkv: + compilation_cases_base = [ + case for case in compilation_cases + if self._case_bs(case) > threshold + ] + else: + compilation_cases_base = list(compilation_cases) + + if is_global_first_rank(): + logger.info( + "base model (SP=%s, TP=%s) cudagraph mode %s shapes %s", + sp_size, tp_size, cudagraph_runtime_mode, + [self._case_bs(c) for c in compilation_cases_base], + ) + + if compilation_cases_base: + self._orig_capture_cudagraphs( + compilation_cases_base, cudagraph_runtime_mode, uniform_decode + ) + + # --- Shift model (SP*TP fused as TP-only): uses the unscaled lookup table --- + # The incoming compilation_cases contain *scaled* base sizes (e.g. + # [4, 8, ..., 2048] with sp_size=4). The shift model needs the + # *unscaled* sizes from its own capture list (e.g. [1, 2, ..., 512]). + # We rebuild the cases from _shift_cudagraph_capture_sizes, copying + # the non-bs fields (like has_lora) from the first matching base case. 
+ if has_shift: + shift_sizes = self.vllm_config._shift_cudagraph_capture_sizes + # Use the first base case as a template for non-bs fields + template = compilation_cases[0] if compilation_cases else None + compilation_cases_shift = [ + self._with_bs(template, bs) if template is not None else bs + for bs in reversed(shift_sizes) + ] + + if is_global_first_rank(): + logger.info( + "shift model (SPxTP=%s) shapes %s", + sp_size * tp_size, + [self._case_bs(c) for c in compilation_cases_shift], + ) + + if compilation_cases_shift: + orig_model, self.model = self.model, self.shift_model + cc = self.vllm_config.compilation_config + base_ctx = cc.static_forward_context + shift_ctx = getattr(self, 'shift_forward_context', None) + try: + if shift_ctx is not None: + cc.static_forward_context = shift_ctx + _CA_MIN_BS = 8 + compilation_cases_shift = [ + c for c in compilation_cases_shift + if self._case_bs(c) >= _CA_MIN_BS + ] + shift_sizes = [ + s for s in shift_sizes if s >= _CA_MIN_BS + ] + self.vllm_config._shift_cudagraph_capture_sizes = ( + shift_sizes) + self.vllm_config._shift_max_cudagraph_capture_size = ( + max(shift_sizes) if shift_sizes else 0) + self.vllm_config._shift_bs_to_padded_graph_size = { + bs: bs for bs in shift_sizes + } + for bs in range(1, _CA_MIN_BS): + self.vllm_config._shift_bs_to_padded_graph_size[ + bs] = _CA_MIN_BS + + if is_global_first_rank(): + logger.info( + "shift model: skipping bs < %d for " + "ca_comm graph capture (will pad to %d)", + _CA_MIN_BS, _CA_MIN_BS, + ) + + with set_shift_parallel_mode(True), \ + self._use_shift_cudagraph_tables(), \ + self._shift_graph_capture_context(): + self._register_shift_cudagraph_keys( + compilation_cases_shift, + cudagraph_runtime_mode, + ) + self._orig_capture_cudagraphs( + compilation_cases_shift, + cudagraph_runtime_mode, + uniform_decode, + ) + finally: + self.model = orig_model + cc.static_forward_context = base_ctx diff --git a/arctic_inference/vllm/patches.py b/arctic_inference/vllm/patches.py new 
file mode 100644 index 000000000..0d13baed7 --- /dev/null +++ b/arctic_inference/vllm/patches.py @@ -0,0 +1,302 @@ +# Copyright 2025 Snowflake Inc. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import vllm +from vllm.logger import init_logger +from vllm.v1.core.sched.async_scheduler import AsyncScheduler +from vllm.v1.core.sched.scheduler import Scheduler +from vllm.v1.engine.core import EngineCoreProc +from vllm.v1.worker.worker_base import WorkerBase + +from arctic_inference.patching import ArcticPatch +from arctic_inference.utils import get_compatible_vllm_version +from arctic_inference.vllm.args import EngineArgsPatch, AsyncEngineArgsPatch +from arctic_inference.vllm.config import (ParallelConfigPatch, + SpeculativeConfigPatch, + VllmConfigPatch, + MLPSpeculatorConfigPatch) +from arctic_inference.vllm.stats import (SpecDecodingStatsPatch, + SpecDecodingLoggingPatch) +from arctic_inference.vllm.structured_output import XgrammarBackendPatch +from arctic_inference.vllm.ulysses import apply_shift_parallel_patches + + +logger = init_logger(__name__) + + +class AsyncSchedulerPatch(ArcticPatch[AsyncScheduler]): + """Patch AsyncScheduler to: + 1. Respect ``disable_by_batch_size`` when allocating spec token + placeholders (the worker only drafts for the first N requests). + 2. Use the previous step's actual draft length for dynamic placeholder + allocation, avoiding wasted verification compute when the real draft + width (e.g. 
Arctic n_predict=3) is much smaller than + num_speculative_tokens (e.g. 12). + 3. Store ``_scheduled_spec_count`` so that the post-fix in + ``update_from_output`` can compensate for worker-side trimming. + """ + + _orig_update_after_schedule = AsyncScheduler._update_after_schedule + + def _update_after_schedule(self, scheduler_output): + # Call the base Scheduler._update_after_schedule (NOT the + # AsyncScheduler override which we are replacing). + Scheduler._update_after_schedule(self, scheduler_output) + + has_structured_output_requests = False + pending_structured_output_tokens = False + spec_decode_tokens = scheduler_output.scheduled_spec_decode_tokens + + # Respect disable_by_batch_size: only add spec token placeholders + # for the first N decode requests (matching the worker's draft_limit + # in propose_draft_token_ids). + spec_config = getattr(self.vllm_config, 'speculative_config', None) + disable_bs = ( + spec_config.disable_by_batch_size if spec_config else None + ) + decode_with_spec_count = 0 + for req_id in scheduler_output.num_scheduled_tokens: + request = self.requests[req_id] + has_structured_output_requests |= request.use_structured_output + pending_structured_output_tokens |= ( + request.use_structured_output + and request.num_output_placeholders > 0 + ) + cur_num_spec_tokens = len(spec_decode_tokens.get(req_id, ())) + # Store the originally-scheduled spec count so that + # update_from_output can compensate for worker-side trimming. + request._scheduled_spec_count = cur_num_spec_tokens + + if ( + request.num_computed_tokens + == request.num_tokens + + request.num_output_placeholders + + cur_num_spec_tokens + ): + # The request will generate a new token + spec tokens. + request.num_output_placeholders += 1 + cur_num_spec_tokens + + # Check if beyond the disable_by_batch_size limit. + decode_with_spec_count += 1 + if disable_bs and decode_with_spec_count > disable_bs: + # Beyond limit: no spec token placeholders. 
+ request.spec_token_ids = [] + continue + + # Use previous step's actual draft length to size + # placeholders. When suffix had a good match (actual + # > n_predict), allocate the full width so the next + # step can verify all suffix tokens. When suffix + # didn't match (actual = n_predict from arctic), + # allocate only that many to avoid wasting attention + # compute on zero-padded positions. + # Cold start: allocate full width (generous). + prev_actual = getattr( + request, '_prev_actual_draft_len', None) + if prev_actual is not None: + num_placeholders = min( + max(prev_actual, 1), self.num_spec_tokens) + else: + num_placeholders = self.num_spec_tokens + + request.spec_token_ids = [-1] * num_placeholders + + scheduler_output.has_structured_output_requests = ( + has_structured_output_requests) + scheduler_output.pending_structured_output_tokens = ( + pending_structured_output_tokens) + + def update_from_output(self, scheduler_output, model_runner_output): + """Wrap Scheduler.update_from_output to store actual draft counts. + + We infer the drafter's real capability from the acceptance + results so that the next ``_update_after_schedule`` can size + placeholders correctly. This works even when the worker runs + in a separate process (where scheduler_output._actual_draft_lens + set by the worker doesn't survive serialisation back to the + scheduler). + + Strategy: + 1. **Primary path**: read ``_actual_draft_lens`` from the + ``model_runner_output`` object (attached by the model runner + to the ``ModelRunnerOutput`` dataclass, which reliably + survives the async pipeline). + 2. **Legacy path**: read from ``scheduler_output._actual_draft_lens`` + (works in same-process non-async mode). + 3. **Fallback**: infer from acceptance results with exponential + growth — when all drafted tokens are accepted, double the + allocation so suffix decoding reaches full capacity in + O(log n) steps instead of O(n). 
+ """ + sampled_token_ids = model_runner_output.sampled_token_ids + req_id_to_index = model_runner_output.req_id_to_index + + result = Scheduler.update_from_output( + self, scheduler_output, model_runner_output) + + # Primary path: read from model_runner_output (most reliable + # for async scheduling — the ModelRunnerOutput object is + # returned by get_output() and guaranteed to survive). + actual_lens = getattr( + model_runner_output, '_actual_draft_lens', None) + + # Legacy path: read from scheduler_output (works for non-async + # or same-process setups where the attribute is preserved). + if not actual_lens: + actual_lens = getattr( + scheduler_output, '_actual_draft_lens', None) + + if actual_lens: + for req_id, actual_len in actual_lens.items(): + request = self.requests.get(req_id) + if request is not None: + request._prev_actual_draft_len = actual_len + return result + + # Fallback: infer from acceptance results (multi-process case). + # Uses exponential growth when all drafted tokens are accepted + # (indicating the drafter / suffix cache can handle more), so + # the allocation converges to num_spec_tokens in O(log n) + # steps: + # step 0: 1 position → accept 1/1 → prev = 2 + # step 1: 2 positions → accept 2/2 → prev = 4 + # step 2: 4 positions → accept 4/4 → prev = 8 + # ... 
+ # When not all are accepted (normal drafter), linear growth: + # step 0: 1 position → accept 1 → prev = 2 + # step 1: 2 positions → accept 2 → prev = 3 + # step 2: 3 positions → steady state (n_predict = 3) + if not sampled_token_ids: + return result + + for req_id in scheduler_output.num_scheduled_tokens: + scheduled_spec = ( + scheduler_output.scheduled_spec_decode_tokens.get(req_id)) + if not scheduled_spec: + continue + + req_index = req_id_to_index.get(req_id) + if req_index is None: + continue + + request = self.requests.get(req_id) + if request is None: + continue + + generated = sampled_token_ids[req_index] + num_accepted = (len(generated) - 1) if generated else 0 + num_draft = len(scheduled_spec) + + prev = getattr(request, '_prev_actual_draft_len', None) + + if num_accepted > 0: + if num_accepted >= num_draft and num_draft > 0: + # All drafted tokens accepted — the drafter (or + # suffix cache) could produce more if given room. + # Double the allocation for exponential ramp-up. + new_val = min( + num_draft * 2, self.num_spec_tokens) + else: + # Partial acceptance: grow linearly. + new_val = min( + num_accepted + 1, self.num_spec_tokens) + request._prev_actual_draft_len = max( + prev or 0, new_val) + elif prev is None: + # First step for this request, zero acceptance. + # Seed with 1 so _update_after_schedule doesn't + # fall back to the full num_spec_tokens next time. + request._prev_actual_draft_len = 1 + + return result + + +class EngineCoreProcPatch(ArcticPatch[EngineCoreProc]): + + _orig_run_engine_core = EngineCoreProc.run_engine_core + + @staticmethod + def run_engine_core(*args, **kwargs): + # When starting the API server, it will spawn a new process to run the + # EngineCore. We need to load the plugins in the new process before it + # initializes the Executor. 
+ vllm.plugins.load_general_plugins() + return EngineCoreProcPatch._orig_run_engine_core(*args, **kwargs) + + +class WorkerBasePatch(ArcticPatch[WorkerBase]): + + _orig_init = WorkerBase.__init__ + + def __init__(self, *args, **kwargs): + # Some patches like the GPUModelRunner will import CUDA libraries when + # they are initialized, which will cause process forking to fail. For + # these patches, we need to delay the initialization until after the + # process has been forked (i.e., in the WorkerBase initializer). + from arctic_inference.vllm.model_runner import GPUModelRunnerPatch + + GPUModelRunnerPatch.apply_patch() + + return self._orig_init(*args, **kwargs) + + +def apply_arctic_patches(): + + from transformers import AutoConfig + from arctic_inference.common.swiftkv import LlamaSwiftKVConfig + + # Register SwiftKV model configurations to transformers. + AutoConfig.register("llama_swiftkv", LlamaSwiftKVConfig) + + from vllm import ModelRegistry + #from arctic_inference.vllm.swiftkv import LlamaSwiftKVForCausalLM + + # Register SwiftKV model definitions to vLLM. + ModelRegistry.register_model( + "LlamaSwiftKVForCausalLM", + "arctic_inference.vllm.swiftkv:LlamaSwiftKVForCausalLM") + + # Register ArcticSpeculator models to vLLM. + from arctic_inference.vllm.spec_dec.arctic_speculator import ( + ArcticMLPSpeculator, ArcticLSTMSpeculator) + ModelRegistry.register_model("ArcticMLPSpeculatorPreTrainedModel", + ArcticMLPSpeculator) + ModelRegistry.register_model("ArcticLSTMSpeculatorPreTrainedModel", + ArcticLSTMSpeculator) + # This name is currently used in corvo + ModelRegistry.register_model("MLPVariantSpeculatorPreTrainedModel", + ArcticLSTMSpeculator) + + # Patches that make later patches work properly. + EngineCoreProcPatch.apply_patch() + WorkerBasePatch.apply_patch() + + # Async scheduler patches for spec decode (disable_by_batch_size + # interaction + dynamic draft width allocation). 
+ AsyncSchedulerPatch.apply_patch() + + # Patches to vLLM arguments and configuration objects. + EngineArgsPatch.apply_patch() + AsyncEngineArgsPatch.apply_patch() + ParallelConfigPatch.apply_patch() + SpeculativeConfigPatch.apply_patch() + SpecDecodingStatsPatch.apply_patch() + SpecDecodingLoggingPatch.apply_patch() + VllmConfigPatch.apply_patch() + XgrammarBackendPatch.apply_patch() + MLPSpeculatorConfigPatch.apply_patch() + + # Main optimization patches. + apply_shift_parallel_patches() diff --git a/arctic_inference/vllm/plugin.py b/arctic_inference/vllm/plugin.py new file mode 100644 index 000000000..afe100ce4 --- /dev/null +++ b/arctic_inference/vllm/plugin.py @@ -0,0 +1,45 @@ +# Copyright 2025 Snowflake Inc. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import sys + +import vllm + +import arctic_inference.envs as envs +from arctic_inference.utils import get_compatible_vllm_version + + +def arctic_inference_plugin(): + if not envs.ARCTIC_INFERENCE_ENABLED: + return + + if not envs.ARCTIC_INFERENCE_SKIP_VERSION_CHECK: + compatible_version = get_compatible_vllm_version() + if vllm.__version__ != compatible_version: + raise RuntimeError( + f"Arctic Inference plugin requires vllm=={compatible_version} " + f"but found vllm=={vllm.__version__}!") + + if not envs.ARCTIC_INFERENCE_SKIP_PLATFORM_CHECK: + if not vllm.platforms.current_platform.is_cuda(): + raise RuntimeError( + f"Arctic Inference plugin requires the cuda platform!") + + print("\x1b[36;1mArctic Inference plugin is enabled!\x1b[0m", + file=sys.stderr) + + # Lazy import to avoid potential errors when the plugin is disabled. + from .patches import apply_arctic_patches + apply_arctic_patches() diff --git a/arctic_inference/vllm/plugins.py b/arctic_inference/vllm/plugins.py deleted file mode 100644 index 0f34a3999..000000000 --- a/arctic_inference/vllm/plugins.py +++ /dev/null @@ -1,126 +0,0 @@ -# Copyright 2025 Snowflake Inc. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import os -import vllm -from vllm.logger import init_logger -from vllm.v1.engine.core import EngineCoreProc -from vllm.v1.worker.worker_base import WorkerBase - -from arctic_inference.patching import ArcticPatch -from arctic_inference.utils import get_compatible_vllm_version -from arctic_inference.vllm.args import EngineArgsPatch, AsyncEngineArgsPatch -from arctic_inference.vllm.config import (ParallelConfigPatch, - SpeculativeConfigPatch, - VllmConfigPatch, - MLPSpeculatorConfigPatch) -from arctic_inference.vllm.stats import (SpecDecodingStatsPatch, - SpecDecodingLoggingPatch) -from arctic_inference.vllm.ulysses import apply_shift_parallel_patches - - -logger = init_logger(__name__) - - -class EngineCoreProcPatch(ArcticPatch[EngineCoreProc]): - - _orig_run_engine_core = EngineCoreProc.run_engine_core - - @staticmethod - def run_engine_core(*args, **kwargs): - # When starting the API server, it will spawn a new process to run the - # EngineCore. We need to load the plugins in the new process before it - # initializes the Executor. - vllm.plugins.load_general_plugins() - return EngineCoreProcPatch._orig_run_engine_core(*args, **kwargs) - - -class WorkerBasePatch(ArcticPatch[WorkerBase]): - - _orig_init = WorkerBase.__init__ - - def __init__(self, *args, **kwargs): - # Some patches like the GPUModelRunner will import CUDA libraries when - # they are initialized, which will cause process forking to fail. For - # these patches, we need to delay the initialization until after the - # process has been forked (i.e., in the WorkerBase initializer). 
- from arctic_inference.vllm.model_runner import GPUModelRunnerPatch - - GPUModelRunnerPatch.apply_patch() - - return self._orig_init(*args, **kwargs) - - -def arctic_inference_plugin(): - if (vllm.__version__ != get_compatible_vllm_version() and not - vllm.__version__.startswith("0.1.dev")): # Make it work with dev - logger.warning( - f"ArcticInference requires vllm=={get_compatible_vllm_version()} " - f"but found vllm=={vllm.__version__}. Ignoring plugin!") - return - - if not vllm.platforms.current_platform.is_cuda(): - logger.warning( - f"ArcticInference requires the cuda platform. Ignoring plugin!") - return - - if os.getenv("VLLM_USE_V1") == "0": - logger.warning("ArcticInference only supports vLLM V1, but detected V0 engine. " - "Ignoring plugin!\n" - "Hint: To strictly enforce the V1 vLLM engine, please set " - "VLLM_USE_V1=1.") - return - - from transformers import AutoConfig - from arctic_inference.common.swiftkv import LlamaSwiftKVConfig - - # Register SwiftKV model configurations to transformers. - AutoConfig.register("llama_swiftkv", LlamaSwiftKVConfig) - - from vllm import ModelRegistry - #from arctic_inference.vllm.swiftkv import LlamaSwiftKVForCausalLM - - # Register SwiftKV model definitions to vLLM. - ModelRegistry.register_model( - "LlamaSwiftKVForCausalLM", - "arctic_inference.vllm.swiftkv:LlamaSwiftKVForCausalLM") - - # Register ArcticSpeculator models to vLLM. - from arctic_inference.vllm.spec_dec.arctic_speculator import ( - ArcticMLPSpeculator, ArcticLSTMSpeculator) - ModelRegistry.register_model("ArcticMLPSpeculatorPreTrainedModel", - ArcticMLPSpeculator) - ModelRegistry.register_model("ArcticLSTMSpeculatorPreTrainedModel", - ArcticLSTMSpeculator) - # This name is currently used in corvo - ModelRegistry.register_model("MLPVariantSpeculatorPreTrainedModel", - ArcticLSTMSpeculator) - - # Patches that make later patches work properly. 
- EngineCoreProcPatch.apply_patch() - WorkerBasePatch.apply_patch() - - # Patches to vLLM arguments and configuration objects. - EngineArgsPatch.apply_patch() - AsyncEngineArgsPatch.apply_patch() - ParallelConfigPatch.apply_patch() - SpeculativeConfigPatch.apply_patch() - SpecDecodingStatsPatch.apply_patch() - SpecDecodingLoggingPatch.apply_patch() - VllmConfigPatch.apply_patch() - MLPSpeculatorConfigPatch.apply_patch() - - # Main optimization patches. - apply_shift_parallel_patches() diff --git a/arctic_inference/vllm/spec_dec/arctic_proposer.py b/arctic_inference/vllm/spec_dec/arctic_proposer.py index 5d7612a33..4f471f66d 100644 --- a/arctic_inference/vllm/spec_dec/arctic_proposer.py +++ b/arctic_inference/vllm/spec_dec/arctic_proposer.py @@ -13,12 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Union +from typing import Optional, Union, List from vllm.config import VllmConfig from vllm.model_executor.model_loader import get_model +from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.worker.gpu_model_runner import logger +from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch +from vllm.v1.utils import CpuGpuBuffer +from vllm.utils.platform_utils import is_pin_memory_available import numpy as np import torch @@ -39,6 +43,9 @@ def __init__( self.model = None self.device = None + self.max_batch_size = vllm_config.scheduler_config.max_num_seqs + self.backup_next_token_ids = None # type: Optional[CpuGpuBuffer] + def load_model( self, model: Union[ArcticMLPSpeculator, ArcticLSTMSpeculator], @@ -100,67 +107,195 @@ def load_model( model_config=draft_config_model_config, quant_config=draft_config_quant_config, parallel_config=draft_config_parallel_config, + scheduler_config=self.vllm_config.scheduler_config, + speculative_config=self.speculative_config, 
load_config=self.vllm_config.load_config, device_config=self.vllm_config.device_config, ) self.model = get_model(vllm_config=draft_worker_config) - self.device = next(model.parameters()).device + self.device = next(self.model.parameters()).device self.input_hidden_dim = self.model.input_hidden_dim if isinstance( self.model, ArcticLSTMSpeculator) else self.model.emb_dim + self.backup_next_token_ids = CpuGpuBuffer( + self.max_batch_size, + dtype=torch.int32, + pin_memory=is_pin_memory_available(), + device=self.device, + with_numpy=True, + ) + def prepare_hidden_states( self, sample_hidden_states: torch.Tensor, - sampled_token_ids: Union[np.ndarray, list[list[int]]], + sampled_token_ids: Union[torch.Tensor, np.ndarray, List[List[int]]], spec_decode_metadata: SpecDecodeMetadata, ) -> torch.Tensor: - if sample_hidden_states is not None: - assert sample_hidden_states.shape[-1] == self.input_hidden_dim, \ - f"hidden_states shape mismatch: {sample_hidden_states.shape[-1]} != {self.input_hidden_dim}. \ - Please make sure spec model is trained using the same base model." 
- - # if isinstance(sampled_token_ids, list): - # # Pad the list of lists to create a uniform tensor - # max_len = max(len(x) for x in sampled_token_ids) if sampled_token_ids else 0 - # if max_len == 0: - # return sample_hidden_states - # padded_ids = [l + [-1] * (max_len - len(l)) for l in sampled_token_ids] - # sampled_token_ids = torch.tensor(padded_ids, - # device=sample_hidden_states.device) + assert sample_hidden_states is not None, "sample_hidden_states must be provided" + + if isinstance(sampled_token_ids, np.ndarray): + sampled_token_ids = torch.as_tensor( + sampled_token_ids, device=sample_hidden_states.device, dtype=torch.long + ) + elif isinstance(sampled_token_ids, list): + sampled_token_ids = torch.as_tensor( + sampled_token_ids, device=sample_hidden_states.device, dtype=torch.long + ) + elif sampled_token_ids.device != sample_hidden_states.device: + sampled_token_ids = sampled_token_ids.to(sample_hidden_states.device, non_blocking=True) max_gen_len = sampled_token_ids.shape[-1] - if max_gen_len == 1: + num_requests = sampled_token_ids.shape[0] + if max_gen_len == 1 and sample_hidden_states.shape[0] == num_requests: + # Fast path: one row per request, no index-select needed. 
return sample_hidden_states assert spec_decode_metadata is not None - valid_mask = sampled_token_ids != -1 - gen_lens = valid_mask.sum(dim=1) - num_sampled_tokens = np.array(spec_decode_metadata.num_draft_tokens) - num_sampled_tokens = torch.tensor(num_sampled_tokens, - device=gen_lens.device) + 1 - hidden_states_idx = (gen_lens - 1) + torch.cumsum( - num_sampled_tokens, 0) - num_sampled_tokens - previous_hidden_states = sample_hidden_states[hidden_states_idx] + if hasattr(spec_decode_metadata, "cu_num_draft_tokens") and spec_decode_metadata.cu_num_draft_tokens is not None: + cu = spec_decode_metadata.cu_num_draft_tokens + num_draft_tokens_gpu = torch.cat([cu[0:1], cu[1:] - cu[:-1]]) + else: + num_draft_tokens_gpu = torch.as_tensor( + spec_decode_metadata.num_draft_tokens, + device=sample_hidden_states.device, + dtype=torch.int64 + ) + + num_processed_tokens_per_req = num_draft_tokens_gpu + 1 + + offsets = torch.cumsum(num_processed_tokens_per_req, dim=0) - num_processed_tokens_per_req + + vocab_size = self.vllm_config.model_config.get_vocab_size() + valid_mask = (sampled_token_ids != -1) & (sampled_token_ids < vocab_size) + gen_lens = valid_mask.sum(dim=1).to(dtype=torch.int64) + + last_valid = torch.clamp(gen_lens - 1, min=0) + hidden_states_idx = offsets + last_valid + + previous_hidden_states = sample_hidden_states.index_select( + dim=0, index=hidden_states_idx + ) + + assert previous_hidden_states.size(-1) == self.input_hidden_dim, ( + f"hidden_states dim {previous_hidden_states.size(-1)} != speculator expected {self.input_hidden_dim}. " + "Make sure the spec model is trained with the same base model." 
+ ) + return previous_hidden_states def propose( self, - context_token_ids: np.ndarray, - previous_hidden_states: torch.Tensor, + context_token_ids: Union[torch.Tensor, np.ndarray, List[int]], + previous_hidden_states: Optional[torch.Tensor], num_predict_tokens: int, - ) -> Optional[np.ndarray]: - assert num_predict_tokens > 0, \ - f"num_predict_tokens must be greater than 0, got {num_predict_tokens}." - - input_ids = torch.tensor(context_token_ids, device=self.device) + ) -> Optional[torch.Tensor]: + assert num_predict_tokens > 0 + if isinstance(context_token_ids, torch.Tensor): + if context_token_ids.device != self.device: + input_ids = context_token_ids.to(self.device, non_blocking=True) + else: + input_ids = context_token_ids + else: + input_ids = torch.as_tensor(context_token_ids, device=self.device, dtype=torch.long) next_tokens = self.model.generate_proposals( input_ids=input_ids, previous_hidden_states=previous_hidden_states, num_predict_tokens=num_predict_tokens, ) + return next_tokens - return next_tokens.cpu().numpy() + # Borrow from eagle + def prepare_next_token_ids_cpu( + self, + sampled_token_ids: list[list[int]], + requests: dict[str, CachedRequestState], + gpu_input_batch: InputBatch, + num_scheduled_tokens: dict[str, int], + ) -> torch.Tensor: + req_ids = gpu_input_batch.req_ids + next_token_ids: list[int] = [] + for i, token_ids in enumerate(sampled_token_ids): + if token_ids: + # Common case. + next_token_id = token_ids[-1] + else: + # Partial prefill (rare case). + # Get the next token id from the request state. 
+ req_id = req_ids[i] + req_state = requests[req_id] + seq_len = req_state.num_computed_tokens + num_scheduled_tokens[req_id] + next_token_id = req_state.get_token_id(seq_len) + next_token_ids.append(next_token_id) + next_token_ids = torch.tensor( + next_token_ids, dtype=torch.int32, device=self.device + ) + return next_token_ids + + + def prepare_next_token_ids_padded( + self, + common_attn_metadata: CommonAttentionMetadata, + sampled_token_ids: torch.Tensor, + requests: dict[str, CachedRequestState], + gpu_input_batch: InputBatch, + discard_request_mask: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + from vllm.triton_utils import triton + from vllm.v1.spec_decode.utils import eagle_prepare_next_token_padded_kernel + + num_reqs = gpu_input_batch.num_reqs + self.backup_next_token_ids.np[:num_reqs] = np.array( + [ + requests[gpu_input_batch.req_ids[i]].get_token_id( + common_attn_metadata.seq_lens_cpu[i].item() + ) + for i in range(num_reqs) + ], + dtype=np.int32, + ) + self.backup_next_token_ids.copy_to_gpu(num_reqs) + backup_tokens_gpu = self.backup_next_token_ids.gpu + + batch_size, num_tokens = sampled_token_ids.shape + device = sampled_token_ids.device + + assert discard_request_mask.dtype == torch.bool + assert backup_tokens_gpu.dtype == torch.int32 + + next_token_ids = torch.empty(batch_size, dtype=torch.int32, device=device) + valid_sampled_tokens_count = next_token_ids.new_empty(batch_size) + + # Kernel grid: one program per request (row) + grid = (batch_size,) + + # Find the next power of 2 for block sizes + BLOCK_SIZE_TOKENS = triton.next_power_of_2(num_tokens) + eagle_prepare_next_token_padded_kernel[grid]( + sampled_token_ids, + discard_request_mask, + backup_tokens_gpu, + next_token_ids, + valid_sampled_tokens_count, + gpu_input_batch.vocab_size, + num_tokens, + batch_size, + sampled_token_ids.stride(0), + BLOCK_SIZE_TOKENS=BLOCK_SIZE_TOKENS, + ) + + return next_token_ids, valid_sampled_tokens_count + + +class SuffixProposer: + def 
__init__(self): + pass + + def load_model( + self, + model: None, + ): + pass diff --git a/arctic_inference/vllm/spec_dec/arctic_speculator.py b/arctic_inference/vllm/spec_dec/arctic_speculator.py index ecd857e53..22c4a09d2 100644 --- a/arctic_inference/vllm/spec_dec/arctic_speculator.py +++ b/arctic_inference/vllm/spec_dec/arctic_speculator.py @@ -22,7 +22,7 @@ from vllm.config import VllmConfig from arctic_inference.vllm.spec_dec.logits_processor_opt import LogitsProcessorOpt -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.v1.outputs import SamplerOutput from arctic_inference.vllm.spec_dec.fp8 import (Fp8ConfigWithEmbedding, OriginalFp8LinearMethod) @@ -33,7 +33,11 @@ ) from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from arctic_inference.py_custom_ops import (try_load_torch_library, + speculator_ln, sum_lstm) + SQRT2 = 2**0.5 +USE_CUSTOM_OP = try_load_torch_library() def padding_size(size: int) -> int: @@ -85,7 +89,7 @@ def __init__( self.bias = nn.Parameter(torch.empty(normalized_shape)) self.eps = eps - def forward(self, x): + def forward_fallback(self, x): xf = x xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps) x = xf.type_as(x) @@ -94,6 +98,22 @@ def forward(self, x): x = x + self.bias return x + def forward_opt(self, x): + return speculator_ln( + x, + self.weight if self.elementwise_scale_and_shift else None, + self.bias if self.elementwise_scale_and_shift else None, + float(self.eps), + ) + + def forward(self, x): + if USE_CUSTOM_OP: + return self.forward_opt(x) + else: + return self.forward_fallback(x) + + + def _generate_cg_key(padding_size: int, head_index: int): return (padding_size << 16) + head_index @@ -222,13 +242,19 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: scale=1.0, skip_last_gather=True, ) - self.sampler = get_sampler() self.cuda_graph_max_batch_size = 0 self.cuda_graph_mode = False if not 
vllm_config.model_config.enforce_eager: self.cuda_graph_mode = True self.cuda_graphs = {} + spec_config = vllm_config.speculative_config + disable_by_batch_size = ( + spec_config.disable_by_batch_size + if spec_config is not None and spec_config.disable_by_batch_size is not None + else vllm_config.scheduler_config.max_num_seqs + ) + self.cuda_graph_max_batch_size = padding_size(disable_by_batch_size) self.cuda_graph_max_batch_size = padding_size( vllm_config.scheduler_config.max_num_seqs) self.static_cuda_buffers = { @@ -315,6 +341,7 @@ def generate_token_ids( argidx = torch.argmax(vals, -1).reshape(batch_size, -1) last_tokens = torch.gather(indices, -1, argidx) + last_tokens.clamp_(0, self.vocab_size - 1) if next_tokens_tensors[head_index] == None: next_tokens_tensors[head_index] = last_tokens else: @@ -421,11 +448,24 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: self.n_predict = config.n_predict self.vocab_size = config.vocab_size self.input_hidden_dim = config.input_hidden_dim - config.inner_dim = [int(i) for i in config.inner_dim.split(".")] + + def _parse_dim(value): + """Helper to normalize dimension config into a list of ints.""" + if isinstance(value, str): + return [int(i) for i in value.split(".")] + elif isinstance(value, int): + return [value] + elif isinstance(value, list): + return [int(i) for i in value] + return value + + config.inner_dim = _parse_dim(config.inner_dim) self.inner_dim = config.inner_dim - config.emb_dim = [int(i) for i in config.emb_dim.split(".")] - self.emb_dim = config.emb_dim - config.proj_dim = [int(i) for i in config.proj_dim.split(".")] + + config.emb_dim = _parse_dim(config.emb_dim) + self.emb_dim = config.emb_dim + + config.proj_dim = _parse_dim(config.proj_dim) self.proj_dim = config.proj_dim self.max_speculative_tokens = config.num_lookahead_tokens @@ -582,13 +622,18 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: scale=1.0, skip_last_gather=True, ) - self.sampler 
= get_sampler() self.cuda_graph_max_batch_size = 0 self.cuda_graph_mode = False - self.cuda_graph_max_batch_size = padding_size( - vllm_config.scheduler_config.max_num_seqs) + spec_config = vllm_config.speculative_config + disable_by_batch_size = ( + spec_config.disable_by_batch_size + if spec_config is not None and spec_config.disable_by_batch_size is not None + else vllm_config.scheduler_config.max_num_seqs + ) + self.cuda_graph_max_batch_size = padding_size(disable_by_batch_size) + self.static_cuda_buffers = { "last_tokens": torch.empty(self.cuda_graph_max_batch_size, 1, dtype=torch.long), @@ -622,6 +667,7 @@ def _prepare_cuda_graph_ios( cell_states=None, use_lstm=False, ): + # TODO: optimize this self.static_cuda_buffers["last_tokens"][:size] = last_tokens if cell_states is not None: self.static_cuda_buffers["cell_states"][:size] = cell_states @@ -664,32 +710,69 @@ def generate_states( prev_state = previous_hidden_states - z = self.forget_emb[actual_i](last_tokens).repeat(1, 1, 4) # b n d + z4 = self.forget_emb[actual_i](last_tokens).repeat(1, 1, 4) # b n d states = self.projs[actual_proj_i](prev_state) - added_states = torch.add(states, - z, - alpha=self.emb_weight / self.state_weight) - forget_input_output, cell_candidate = added_states.split( - [self.proj_dim[0] * 3, self.proj_dim[0]], dim=-1) - forget_gate, input_gate, output_gate = torch.sigmoid( - forget_input_output).split( - [self.proj_dim[0], self.proj_dim[0], self.proj_dim[0]], - dim=-1) + if USE_CUSTOM_OP: + # Shapes: + # prev_state: [B, 1, D_eff] (e.g., 2880 in the first round and 4096 later) + # states: [B, 1, 4*D_gate] (e.g., 4*4096) + # z4: [B, 1, 4*D_gate] + states_4d = states.flatten(0, 1).contiguous() # [B, 4*D_gate] + z4_4d = z4.flatten(0, 1).contiguous() # [B, 4*D_gate] + + orig_cell_shape = cell_states.shape # [B, 1, D_gate] + pc_d = cell_states.flatten(0, 1).contiguous() # [B, D_gate] + + # Optional precondition checks that mirror the kernel's TORCH_CHECKs: + assert states_4d.size(-1) % 4 
== 0 + assert z4_4d.size(-1) == states_4d.size(-1) + assert pc_d.size(-1) == states_4d.size(-1) // 4 + + w_cell = self.cell_ln[actual_i].weight + b_cell = self.cell_ln[actual_i].bias + w_state = self.state_ln[actual_i].weight + b_state = self.state_ln[actual_i].bias + + alpha = float(self.emb_weight / self.state_weight) + eps_cell = float(self.cell_ln[actual_i].eps) + eps_state = float(self.state_ln[actual_i].eps) + use_fast_gelu = False + + state_d, cell_d = sum_lstm( + states_4d, z4_4d, pc_d, + w_cell, b_cell, w_state, b_state, + alpha, eps_cell, eps_state, use_fast_gelu + ) + + state = state_d.reshape(orig_cell_shape) # [B, 1, D_gate] + cell_states = cell_d.reshape(orig_cell_shape) # [B, 1, D_gate] + + return state, cell_states + else: + added_states = torch.add(states, + z4, + alpha=self.emb_weight / self.state_weight) - cell_candidate = self.activation( - self.cell_ln[actual_i](cell_candidate)) # b n d - cell_candidate = cell_candidate * input_gate + forget_input_output, cell_candidate = added_states.split( + [self.proj_dim[0] * 3, self.proj_dim[0]], dim=-1) + forget_gate, input_gate, output_gate = torch.sigmoid( + forget_input_output).split( + [self.proj_dim[0], self.proj_dim[0], self.proj_dim[0]], + dim=-1) - cell_states = cell_states * forget_gate - cell_states = cell_states + cell_candidate + cell_candidate = self.activation( + self.cell_ln[actual_i](cell_candidate)) # b n d + cell_candidate = cell_candidate * input_gate - state_candidate = self.activation( - self.state_ln[actual_i](cell_states)) - state = state_candidate * output_gate + cell_states = cell_states * forget_gate + cell_states = cell_states + cell_candidate - return state, cell_states + state_candidate = self.activation( + self.state_ln[actual_i](cell_states)) + state = state_candidate * output_gate + return state, cell_states else: # Project and predict z = self.emb[actual_i](last_tokens) # b k d @@ -712,6 +795,7 @@ def generate_token_ids( next_tokens_tensors: List[torch.Tensor], 
cell_states: torch.Tensor = None, ) -> torch.Tensor: + last_tokens.clamp_(0, self.vocab_size - 1) for head_index in range(num_predict_tokens): if self.method == "sum_lstm": states, cell_states = self.generate_states( @@ -731,6 +815,7 @@ def generate_token_ids( last_tokens = torch.argmax(logits, dim=-1).reshape(batch_size, -1) else: + # TODO: fuse topk + all_gather vals, indices = torch.topk(logits, 1, dim=-1) indices = indices + self.tp_rank * logits.shape[-1] @@ -743,6 +828,7 @@ def generate_token_ids( argidx = torch.argmax(vals, -1).reshape(batch_size, -1) last_tokens = torch.gather(indices, -1, argidx) + last_tokens.clamp_(0, self.vocab_size - 1) if next_tokens_tensors[head_index] == None: next_tokens_tensors[head_index] = last_tokens else: @@ -817,6 +903,9 @@ def generate_proposals( if g is None: device = torch.cuda.current_device() + for i in range(num_predict_tokens): + self.static_cuda_buffers["next_tokens"][i][:padded_size] = torch.zeros( + (padded_size, 1), dtype=torch.long, device=device) with graph_capture(device=device) as capture_context: g = torch.cuda.CUDAGraph() with torch.cuda.graph(g, stream=capture_context.stream): @@ -874,9 +963,15 @@ def maybe_load_weight(self, param, loaded_weight): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weights = collections.OrderedDict(weights) if self.method == "sum_lstm" and self.tie_lstm_embs: - weights.pop("input_emb.0.weight") - weights.pop("cell_emb.0.weight") - weights.pop("output_emb.0.weight") + try: + weights.pop("input_emb.0.weight") + weights.pop("cell_emb.0.weight") + weights.pop("output_emb.0.weight") + except KeyError: + # If the weights are not present, it means they are not tied + # and we should not try to pop them. + print("No tied LSTM embeddings found, skipping.") + pass for name, param in self.named_parameters(): if "projs." 
in name: print(f"REPLACING {name}") diff --git a/arctic_inference/vllm/spec_dec/fp8.py b/arctic_inference/vllm/spec_dec/fp8.py index 7366d9127..b2a749694 100644 --- a/arctic_inference/vllm/spec_dec/fp8.py +++ b/arctic_inference/vllm/spec_dec/fp8.py @@ -66,9 +66,7 @@ def __init__(self, quant_config: Fp8Config): # Marlin doesn't support block-wise fp8 self.use_marlin = False - self.fp8_linear = Fp8LinearOp( - # Default to using per_token quantization if cutlass is supported - use_per_token_if_dynamic=cutlass_fp8_supported()) + self.fp8_linear = Fp8LinearOp(act_quant_static=False) def create_weights( self, diff --git a/arctic_inference/vllm/spec_dec/logits_processor_opt.py b/arctic_inference/vllm/spec_dec/logits_processor_opt.py index 4516a8656..b90c9ba86 100644 --- a/arctic_inference/vllm/spec_dec/logits_processor_opt.py +++ b/arctic_inference/vllm/spec_dec/logits_processor_opt.py @@ -44,8 +44,7 @@ def __init__(self, self.soft_cap = soft_cap # Whether to use gather or all-gather to gather the logits. - self.use_gather = not current_platform.is_tpu( - ) and not envs.VLLM_USE_V1 + self.use_gather = False self.skip_last_gather = skip_last_gather diff --git a/arctic_inference/vllm/spec_dec/vocab_parallel_embedding.py b/arctic_inference/vllm/spec_dec/vocab_parallel_embedding.py index 149f40a3d..0b766229d 100644 --- a/arctic_inference/vllm/spec_dec/vocab_parallel_embedding.py +++ b/arctic_inference/vllm/spec_dec/vocab_parallel_embedding.py @@ -408,19 +408,8 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): # Copy the data. Select chunk corresponding to current shard. loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) - if current_platform.is_hpu(): - # FIXME(kzawora): Weight copy with slicing bugs out on Gaudi here, - # so we're using a workaround. Remove this when fixed in - # HPU PT bridge. 
- padded_weight = torch.cat([ - loaded_weight, - torch.zeros(param.shape[0] - loaded_weight.shape[0], - *loaded_weight.shape[1:]) - ]) - param.data.copy_(padded_weight) - else: - param[:loaded_weight.shape[0]].data.copy_(loaded_weight) - param[loaded_weight.shape[0]:].data.fill_(0) + param[:loaded_weight.shape[0]].data.copy_(loaded_weight) + param[loaded_weight.shape[0]:].data.fill_(0) def forward(self, input_): if self.tp_size > 1: @@ -432,7 +421,12 @@ def forward(self, input_): self.shard_indices.added_vocab_start_index, self.shard_indices.added_vocab_end_index) else: - masked_input = input_ + # For tp_size==1 there is no masking, so clamp to the + # valid embedding range. Async scheduling + spec decode + # can leave -1 sentinel tokens in input_ids that would + # otherwise crash F.embedding. + masked_input = input_.clamp( + 0, self.num_embeddings_per_partition - 1) # Get the embeddings. output_parallel = self.quant_method.embedding(self, masked_input.long()) diff --git a/arctic_inference/vllm/structured_output.py b/arctic_inference/vllm/structured_output.py new file mode 100644 index 000000000..9409ea784 --- /dev/null +++ b/arctic_inference/vllm/structured_output.py @@ -0,0 +1,38 @@ +# Copyright 2025 Snowflake Inc. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from vllm.logger import init_logger +from vllm.v1.structured_output.backend_xgrammar import XgrammarBackend + +from arctic_inference.patching import ArcticPatch + +logger = init_logger(__name__) + + +class XgrammarBackendPatch(ArcticPatch[XgrammarBackend]): + """Patch for XgrammarBackend to handle additional structured output.""" + + _orig_post_init = XgrammarBackend.__post_init__ + + def __post_init__(self): + self._orig_post_init() + + if self.vllm_config.speculative_config is not None: + self.num_speculative_tokens = \ + max(self.vllm_config.speculative_config.num_speculative_tokens, + self.vllm_config.speculative_config.suffix_speculative_tokens) + + logger.info(f"XgrammarBackendPatch: num_speculative_tokens=" + f"{self.num_speculative_tokens}") diff --git a/arctic_inference/vllm/swiftkv/llama_swiftkv.py b/arctic_inference/vllm/swiftkv/llama_swiftkv.py index 93f003579..675dc4aa8 100644 --- a/arctic_inference/vllm/swiftkv/llama_swiftkv.py +++ b/arctic_inference/vllm/swiftkv/llama_swiftkv.py @@ -20,10 +20,12 @@ from torch import nn import vllm.distributed.parallel_state as parallel_state -from vllm.attention.backends.abstract import AttentionType +from vllm.v1.attention.backend import AttentionType from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig -from vllm.forward_context import ForwardContext, get_forward_context +from vllm.config.compilation import CUDAGraphMode +from vllm.forward_context import (BatchDescriptor, ForwardContext, + get_forward_context) from vllm.logger import init_logger from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -40,19 +42,20 @@ LlamaMLP) from vllm.model_executor.models.utils import (AutoWeightsLoader, maybe_prefix) -from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.v1.sample.metadata import SamplingMetadata from vllm.sequence import IntermediateTensors # Add 
FlashInfer backend detection try: from vllm.v1.attention.backends.flashinfer import FlashInferMetadata FLASHINFER_AVAILABLE = True -except ImportError: +except (ImportError, RuntimeError): FLASHINFER_AVAILABLE = False FlashInferMetadata = None import arctic_inference.vllm.model_runner as model_runner from arctic_inference.common.swiftkv.configs import LlamaSwiftKVConfig +import arctic_inference.envs as envs logger = init_logger(__name__) @@ -75,8 +78,6 @@ def __init__( hidden_size: int, num_heads: int, num_kv_heads: int, - rope_theta: float = 10000, - rope_scaling: Optional[dict[str, Any]] = None, max_position_embeddings: int = 8192, quant_config: Optional[QuantizationConfig] = None, bias: bool = False, @@ -90,8 +91,6 @@ def __init__( hidden_size=hidden_size, num_heads=num_heads, num_kv_heads=num_kv_heads, - rope_theta=rope_theta, - rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, quant_config=quant_config, bias=bias, @@ -147,16 +146,8 @@ def __init__( ) -> None: super().__init__() self.hidden_size = config.hidden_size - rope_theta = getattr(config, "rope_theta", 10000) - rope_scaling = getattr(config, "rope_scaling", None) - if rope_scaling is not None and getattr( - config, "original_max_position_embeddings", None): - rope_scaling["original_max_position_embeddings"] = ( - config.original_max_position_embeddings) max_position_embeddings = getattr(config, "max_position_embeddings", 8192) - # Support abacusai/Smaug-72B-v0.1 with attention_bias - # Support internlm/internlm-7b with bias attention_bias = getattr(config, "attention_bias", False) or getattr( config, "bias", False) self.self_attn = LlamaSwiftKVAttention( @@ -165,8 +156,6 @@ def __init__( num_heads=config.num_attention_heads, num_kv_heads=getattr(config, "num_key_value_heads", config.num_attention_heads), - rope_theta=rope_theta, - rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, quant_config=quant_config, bias=attention_bias, @@ -343,10 +332,9 @@ def 
__init__(self, *, vllm_config: VllmConfig, prefix: str = ""): quant_config=self.quant_config, ) self.layers = torch.nn.ModuleList([ - LlamaDecoderLayer(config=config, - cache_config=vllm_config.cache_config, - quant_config=vllm_config.quant_config, - prefix=f"{prefix}.layers.{idx}") + LlamaDecoderLayer(vllm_config=vllm_config, + prefix=f"{prefix}.layers.{idx}", + config=config,) for idx in range(config.num_key_value_layers) ]) with model_runner.set_shift_parallel_mode(True): @@ -367,12 +355,18 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self._init_prefill_runner(vllm_config) self._init_decode_runner(vllm_config) - from arctic_inference.py_custom_ops import try_load_torch_library - self.use_custom_ops = True if try_load_torch_library() else False + from arctic_inference.py_custom_ops import (try_load_torch_library, + try_load_jit_library) + + self.use_custom_ops = try_load_torch_library() or try_load_jit_library() + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def _init_prefill_runner(self, vllm_config: VllmConfig): vllm_config.compilation_config = copy.copy( vllm_config.compilation_config) @@ -609,7 +603,7 @@ def swiftkv_select( kv_cache = attn.kv_cache[forward_context.virtual_engine] if kv_cache.numel(): # different cache layouts - if isinstance(attn_metadata, FlashInferMetadata): + if FLASHINFER_AVAILABLE and isinstance(attn_metadata, FlashInferMetadata): # FlashInfer: [num_blocks, 2, block_size, num_kv_heads, head_size] key_caches.append(kv_cache[:, 0]) value_caches.append(kv_cache[:, 1]) @@ -637,7 +631,7 @@ def swiftkv_select( attn = layer.self_attn.attn kv_cache = attn.kv_cache[forward_context.virtual_engine] if kv_cache.numel(): - if isinstance(attn_metadata, FlashInferMetadata): + if FLASHINFER_AVAILABLE and isinstance(attn_metadata, FlashInferMetadata): # 
FlashInfer: [num_blocks, 2, block_size, num_kv_heads, head_size] k_cache, v_cache = kv_cache.unbind(1) else: @@ -658,7 +652,7 @@ def swiftkv_select( logits_indices = attn_metadata.swiftkv_logits_indices num_surviving_tokens = logits_indices.numel() - if isinstance(attn_metadata, FlashInferMetadata): + if FLASHINFER_AVAILABLE and isinstance(attn_metadata, FlashInferMetadata): # Handle FlashInfer metadata final_logits_indices = self._fix_flashinfer_metadata(attn_metadata, logits_indices, num_surviving_tokens) else: @@ -702,6 +696,28 @@ def forward( k_states, v_states)) + # When swiftkv_select filters tokens (mixed prefill-decode batches), + # the decode runner processes fewer tokens than the original batch. + # Piecewise CUDA graphs captured for the original batch size cannot + # be replayed with modified attention metadata (stale FA3 scheduler + # metadata, changed query_start_loc, etc.), so we fall back to eager + # compiled execution for the decode runner on these batches. + # For decode-only batches all tokens survive, so CUDA graphs are + # used normally -- preserving decode throughput. 
+ fwd_ctx = get_forward_context() + saved_batch_descriptor = fwd_ctx.batch_descriptor + saved_cudagraph_mode = fwd_ctx.cudagraph_runtime_mode + decode_num_tokens = hidden_states.shape[0] + if (saved_batch_descriptor is not None + and saved_batch_descriptor.num_tokens != decode_num_tokens): + fwd_ctx.batch_descriptor = BatchDescriptor( + num_tokens=decode_num_tokens, + num_reqs=saved_batch_descriptor.num_reqs, + uniform=saved_batch_descriptor.uniform, + has_lora=saved_batch_descriptor.has_lora, + ) + fwd_ctx.cudagraph_runtime_mode = CUDAGraphMode.NONE + with model_runner.set_shift_parallel_mode(True): hidden_states = self.decode_runner( hidden_states, @@ -711,12 +727,15 @@ def forward( v_states, ) + fwd_ctx.batch_descriptor = saved_batch_descriptor + fwd_ctx.cudagraph_runtime_mode = saved_cudagraph_mode + attn_metadata = get_attn_metadata_for_swiftkv() if attn_metadata is not None: logits_indices = attn_metadata.swiftkv_logits_indices batch_size = logits_indices.numel() - if isinstance(attn_metadata, FlashInferMetadata): + if FLASHINFER_AVAILABLE and isinstance(attn_metadata, FlashInferMetadata): inverse_sort_indices = attn_metadata.swiftkv_inverse_sort_indices orig_hidden_states[logits_indices] = hidden_states[inverse_sort_indices][:batch_size] else: @@ -839,6 +858,9 @@ def _init_model(self, def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -853,10 +875,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git 
a/arctic_inference/vllm/ulysses.py b/arctic_inference/vllm/ulysses.py index 23210a015..7de12536a 100644 --- a/arctic_inference/vllm/ulysses.py +++ b/arctic_inference/vllm/ulysses.py @@ -16,44 +16,54 @@ import threading import weakref from contextlib import contextmanager -from concurrent.futures import ThreadPoolExecutor -from typing import Optional, Any +from concurrent.futures import Future +from collections import deque +from collections.abc import Callable +from typing import Optional, cast +import time import torch import vllm.distributed.parallel_state as parallel_state import vllm.envs as envs from vllm.attention.layer import Attention -from vllm.config import ModelConfig, ParallelConfig +from vllm.config import ModelConfig, ParallelConfig, CUDAGraphMode, VllmConfig from vllm.distributed.device_communicators.shm_broadcast import MessageQueue from vllm.distributed.parallel_state import (init_model_parallel_group, get_world_group, destroy_model_parallel, destroy_distributed_environment) -from vllm.executor.multiproc_worker_utils import ( +from vllm.v1.executor.multiproc_executor import ( set_multiprocessing_worker_envs) -from vllm.utils import get_distributed_init_method, get_open_port +from vllm.utils.network_utils import get_distributed_init_method, get_open_port, get_loopback_ip +from vllm.utils.system_utils import get_mp_context from vllm.v1.executor.abstract import FailureCallback from vllm.v1.executor.multiproc_executor import (MultiprocExecutor, WorkerProc, - UnreadyWorkerProcHandle) -from vllm.platforms import current_platform -from vllm.utils import resolve_obj_by_qualname -from vllm.compilation.backends import PiecewiseCompileInterpreter -from vllm.model_executor.layers.fused_moe import FusedMoE + UnreadyWorkerProcHandle, + FutureWrapper) +from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher +from vllm.config.compilation import CompilationConfig +from vllm.v1.engine.core import EngineCore, EngineCoreOutputs +from vllm.v1.outputs import 
ModelRunnerOutput + from arctic_inference.patching import ArcticPatch +# global variable to hack compilation config +_ulysses_sp_size = 1 def apply_shift_parallel_patches(): - UlyssesModelConfigPatch.apply_patch() - UlyssesParallelStatePatch.apply_patch() - UlyssesWorkerProcPatch.apply_patch() - UlyssesMultiprocExecutorPatch.apply_patch() - UlyssesAttentionPatch.apply_patch() - PiecewiseCompileInterpreterPatch.apply_patch() - UlyssesFusedMoEPatch.apply_patch() + UlyssesModelConfig.apply_patch() + UlyssesParallelState.apply_patch() + UlyssesWorkerProc.apply_patch() + UlyssesMultiprocExecutor.apply_patch() + UlyssesAttention.apply_patch() + UlyssesCudagraphDispatcher.apply_patch() + UlyssesCompilationConfig.apply_patch() + UlyssesVllmConfig.apply_patch() + UlyssesEngineCore.apply_patch() -class UlyssesModelConfigPatch(ArcticPatch[ModelConfig]): +class UlyssesModelConfig(ArcticPatch[ModelConfig]): _orig_get_num_kv_heads = ModelConfig.get_num_kv_heads _orig_get_num_attention_heads = ModelConfig.get_num_attention_heads @@ -74,7 +84,8 @@ def get_layers_start_end_indices( self, parallel_config: "ParallelConfig") -> tuple[int, int]: from vllm.distributed.utils import get_pp_indices if (self.hf_text_config.model_type == "deepseek_mtp" - or self.hf_config.model_type == "mimo_mtp"): + or self.hf_config.model_type == "mimo_mtp" + or self.hf_config.model_type == "glm4_moe_mtp"): total_num_hidden_layers = getattr(self.hf_text_config, "num_nextn_predict_layers", 0) else: @@ -90,54 +101,27 @@ def get_layers_start_end_indices( return start, end -class UlyssesParallelStatePatch(ArcticPatch[parallel_state]): +class UlyssesParallelState(ArcticPatch[parallel_state]): _SP = None _SP_TP = None - _SP_AA = None - _SP_AG = None - # Rationale for SP_AA and SP_AG groups: - # When num_kv_heads > SP, the kv heads are distributed and replicated as in TP. 
- # To implement the logic, the distributed kv heads are exchanged with a local - # all-to-all within SP_AA group followed by an local all-gather within SP_AG - # group. The SP_AA and SP_AG groups partitions the SP group into two orthogonal - # sub-groups and will not be initialized if max(1, num_kv_heads / TP) < SP. - # See the figure in PR #126 https://github.com/snowflakedb/ArcticInference/pull/126 def initialize_model_parallel( tensor_model_parallel_size: int = 1, pipeline_model_parallel_size: int = 1, + prefill_context_model_parallel_size: int = 1, + decode_context_model_parallel_size: Optional[int] = 1, backend: Optional[str] = None, ) -> None: - """ - Initialize model parallel groups. - - Arguments: - tensor_model_parallel_size: number of GPUs used for tensor model - parallelism. - pipeline_model_parallel_size: number of GPUs used for pipeline model - parallelism. - - Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we - use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize - the model pipeline. The present function will - create 4 tensor model-parallel groups and 2 pipeline model-parallel groups: - 4 tensor model-parallel groups: - [g0, g1], [g2, g3], [g4, g5], [g6, g7] - 2 pipeline model-parallel groups: - [g0, g2, g4, g6], [g1, g3, g5, g7] - Note that for efficiency, the caller should make sure adjacent ranks - are on the same DGX box. For example if we are using 2 DGX-1 boxes - with a total of 16 GPUs, rank 0 to 7 belong to the first box and - ranks 8 to 15 belong to the second box. - """ - from vllm.distributed.parallel_state import _DP, _EP, _PP, _TP - # Get world size and rank. Ensure some consistencies. 
+ + from vllm.distributed.parallel_state import _DP, _EP, _PP, _TP, _DCP, _PCP + assert torch.distributed.is_initialized() world_size: int = torch.distributed.get_world_size() rank = torch.distributed.get_rank() backend = backend or torch.distributed.get_backend( - get_world_group().device_group) + get_world_group().device_group + ) data_parallel_size = 1 from vllm.config import get_current_vllm_config @@ -145,154 +129,180 @@ def initialize_model_parallel( if config is not None: data_parallel_size = config.parallel_config.data_parallel_size - sequence_parallel_size = \ - config.parallel_config.ulysses_sequence_parallel_size - - # the layout order is: ExternalDP x DP x PP x SP x TP - # ExternalDP is the data parallel group that is not part of the model, - # every dp rank can generate independently (in verl integration). - # DP is the data parallel group that is part of the model, - # all the ranks in the same DP group should generate simultaneously, - # i.e. the `generate` call in the same DP group should be called together, - # otherwise it will cause deadlock. - # to get group_ranks for each dimension, transpose that dimension to the - # last dimension, then reshape to 2D, then unbind the last dimension - all_ranks = torch.arange(world_size).reshape( - -1, data_parallel_size, pipeline_model_parallel_size, - sequence_parallel_size, tensor_model_parallel_size) # noqa + sequence_parallel_size = config.parallel_config.ulysses_sequence_parallel_size + + # vLLM types allow None, but group building needs an int + if decode_context_model_parallel_size is None: + # treat "no DCP" as DCP==TP (common interpretation) + decode_context_model_parallel_size = tensor_model_parallel_size - # Build the tensor model-parallel groups. 
- assert _TP is None, ("tensor model parallel group is already initialized") + # Layout order (extended from vLLM's): ExternalDP x DP x PP x PCP x SP x TP + all_ranks = torch.arange(world_size).reshape( + -1, + data_parallel_size, + pipeline_model_parallel_size, + prefill_context_model_parallel_size, + sequence_parallel_size, + tensor_model_parallel_size, + ) + + assert _TP is None, "tensor model parallel group is already initialized" group_ranks = all_ranks.view(-1, tensor_model_parallel_size).unbind(0) group_ranks = [x.tolist() for x in group_ranks] TP_group_ranks = group_ranks - # message queue broadcaster is only used in tensor model parallel group - _TP = init_model_parallel_group(group_ranks, - get_world_group().local_rank, - backend, - use_message_queue_broadcaster=True, - group_name="tp") - - # Build the pipeline model-parallel groups. - assert _PP is None, ( - "pipeline model parallel group is already initialized") - group_ranks = all_ranks.transpose(2, 4).reshape( - -1, pipeline_model_parallel_size).unbind(0) + _TP = init_model_parallel_group( + group_ranks, + get_world_group().local_rank, + backend, + use_message_queue_broadcaster=True, + group_name="tp", + ) + + assert _DCP is None, "decode context model parallel group is already initialized" + group_ranks = all_ranks.reshape(-1, decode_context_model_parallel_size).unbind(0) + group_ranks = [x.tolist() for x in group_ranks] + DCP_group_ranks = group_ranks + _DCP = init_model_parallel_group( + group_ranks, + get_world_group().local_rank, + backend, + use_message_queue_broadcaster=True, + group_name="dcp", + ) + + assert _PCP is None, "prefill context parallel group is already initialized" + group_ranks = ( + all_ranks.transpose(3, 5) + .reshape(-1, prefill_context_model_parallel_size) + .unbind(0) + ) + group_ranks = [x.tolist() for x in group_ranks] + PCP_group_ranks = group_ranks + _PCP = init_model_parallel_group( + group_ranks, + get_world_group().local_rank, + backend, + group_name="pcp", + ) + + 
assert _PP is None, "pipeline model parallel group is already initialized" + group_ranks = ( + all_ranks.transpose(2, 5) + .reshape(-1, pipeline_model_parallel_size) + .unbind(0) + ) group_ranks = [x.tolist() for x in group_ranks] PP_group_ranks = group_ranks - _PP = init_model_parallel_group(group_ranks, - get_world_group().local_rank, - backend, - group_name="pp") - - assert _DP is None, ("data parallel group is already initialized") - group_ranks = all_ranks.transpose(1, - 4).reshape(-1, - data_parallel_size).unbind(0) + _PP = init_model_parallel_group( + group_ranks, + get_world_group().local_rank, + backend, + group_name="pp", + ) + + assert _DP is None, "data parallel group is already initialized" + group_ranks = ( + all_ranks.transpose(1, 5) + .reshape(-1, data_parallel_size) + .unbind(0) + ) group_ranks = [x.tolist() for x in group_ranks] DP_group_ranks = group_ranks - _DP = init_model_parallel_group(group_ranks, - get_world_group().local_rank, - backend, - group_name="dp") - - assert _EP is None, ("expert parallel group is already initialized") - group_ranks = all_ranks.transpose(1, 3).reshape( - -1, data_parallel_size * tensor_model_parallel_size).unbind(0) + _DP = init_model_parallel_group( + group_ranks, + get_world_group().local_rank, + backend, + group_name="dp", + ) + + assert _EP is None, "expert parallel group is already initialized" + group_ranks = ( + all_ranks.permute(0, 4, 2, 1, 3, 5) # ExternalDP, SP, PP, DP, PCP, TP + .reshape(-1, data_parallel_size * prefill_context_model_parallel_size * tensor_model_parallel_size) + .unbind(0) + ) group_ranks = [x.tolist() for x in group_ranks] EP_group_ranks = group_ranks - _EP = init_model_parallel_group(group_ranks, - get_world_group().local_rank, - backend, - group_name="ep") - - # Build the sequence parallel groups. 
- assert parallel_state._SP is None, ( - "sequence parallel group is already initialized") - group_ranks = all_ranks.transpose(3, 4).reshape( - -1, sequence_parallel_size).unbind(0) + _EP = init_model_parallel_group( + group_ranks, + get_world_group().local_rank, + backend, + group_name="ep", + ) + + assert parallel_state._SP is None, "sequence parallel group is already initialized" + group_ranks = ( + all_ranks.transpose(4, 5) + .reshape(-1, sequence_parallel_size) + .unbind(0) + ) group_ranks = [x.tolist() for x in group_ranks] SP_group_ranks = group_ranks - _SP = init_model_parallel_group(group_ranks, - get_world_group().local_rank, - backend, - group_name="sp") - - # Build full-TP groups for ShiftParallel - shift_parallel_size = (tensor_model_parallel_size * - sequence_parallel_size) - assert parallel_state._SP_TP is None, ( - "full-TP group is already initialized") - # transpose(3, 4) for obtaining the correct attn head order - group_ranks = all_ranks.transpose(3, 4).reshape( - -1, shift_parallel_size).unbind(0) + _SP = init_model_parallel_group( + group_ranks, + get_world_group().local_rank, + backend, + group_name="sp", + ) + + shift_parallel_size = tensor_model_parallel_size * sequence_parallel_size + assert parallel_state._SP_TP is None, "full-TP group is already initialized" + group_ranks = ( + all_ranks.transpose(4, 5) # keep same head-order trick as your old transpose(3,4) + .reshape(-1, shift_parallel_size) + .unbind(0) + ) group_ranks = [x.tolist() for x in group_ranks] SP_TP_group_ranks = group_ranks - _SP_TP = init_model_parallel_group(group_ranks, - get_world_group().local_rank, - backend, - group_name="sp_tp") + _SP_TP = init_model_parallel_group( + group_ranks, + get_world_group().local_rank, + backend, + group_name="sp_tp", + ) parallel_state.logger.info( "rank %s in world size %s is assigned as DP rank %s, PP rank %s, " - "TP rank %s, EP rank %s, SP rank %s, SP_TP rank %s", rank, - world_size, _DP.rank_in_group, _PP.rank_in_group, - 
_TP.rank_in_group, _EP.rank_in_group, _SP.rank_in_group, - _SP_TP.rank_in_group) + "PCP rank %s, TP rank %s, DCP rank %s, EP rank %s, SP rank %s, SP_TP rank %s", + rank, + world_size, + _DP.rank_in_group, + _PP.rank_in_group, + _PCP.rank_in_group, + _TP.rank_in_group, + _DCP.rank_in_group, + _EP.rank_in_group, + _SP.rank_in_group, + _SP_TP.rank_in_group, + ) parallel_state._TP = _TP + parallel_state._DCP = _DCP + parallel_state._PCP = _PCP parallel_state._PP = _PP + parallel_state._DP = _DP + parallel_state._EP = _EP parallel_state._SP = _SP parallel_state._SP_TP = _SP_TP - parallel_state._DP = _DP - - # check if SP requires kv replication - num_kv_heads = config.model_config._orig_get_num_kv_heads(config.parallel_config) - if num_kv_heads < sequence_parallel_size: - - # divide SP group into two orthogonal sub-groups: - sp_aa_size = num_kv_heads - sp_ag_size = sequence_parallel_size // num_kv_heads - all_ranks_ = torch.arange(world_size).reshape( - -1, data_parallel_size, pipeline_model_parallel_size, - sp_aa_size, sp_ag_size, tensor_model_parallel_size) - - group_ranks = all_ranks_.transpose(3, 5).reshape( - -1, sp_aa_size).unbind(0) - group_ranks = [x.tolist() for x in group_ranks] - SP_AA_group_ranks = group_ranks - # SP_AA group is used for all-to-all communication of kv heads - _SP_AA = init_model_parallel_group(group_ranks, - get_world_group().local_rank, - backend, - group_name="sp_aa") - - group_ranks = all_ranks_.transpose(4, 5).reshape( - -1, sp_ag_size).unbind(0) - group_ranks = [x.tolist() for x in group_ranks] - SP_AG_group_ranks = group_ranks - # SP_AG group is used for all-gather communication of kv heads - _SP_AG = init_model_parallel_group(group_ranks, - get_world_group().local_rank, - backend, - group_name="sp_ag") - - parallel_state._SP_AA = _SP_AA - parallel_state._SP_AG = _SP_AG if get_world_group().local_rank == 0: parallel_state.logger.info( - f"UlyssesParallelStatePatch initialized:\n" - f" PP {_PP.world_size} ranks {PP_group_ranks}\n" - f" 
TP {_TP.world_size} ranks {TP_group_ranks}\n" - f" SP {_SP.world_size} ranks {SP_group_ranks}\n" - f" DP {_DP.world_size} ranks {DP_group_ranks}\n" - f" EP {_EP.world_size} ranks {EP_group_ranks}\n" - f" SP_TP {_SP_TP.world_size} ranks {SP_TP_group_ranks}") - if num_kv_heads < sequence_parallel_size: - parallel_state.logger.info( - f" SP_AA {parallel_state._SP_AA.world_size} ranks {SP_AA_group_ranks}\n" - f" SP_AG {parallel_state._SP_AG.world_size} ranks {SP_AG_group_ranks}\n") + "UlyssesParallelState initialized:\n" + f" PP {_PP.world_size} ranks {PP_group_ranks}\n" + f" TP {_TP.world_size} ranks {TP_group_ranks}\n" + f" DCP {_DCP.world_size} ranks {DCP_group_ranks}\n" + f" PCP {_PCP.world_size} ranks {PCP_group_ranks}\n" + f" SP {_SP.world_size} ranks {SP_group_ranks}\n" + f" DP {_DP.world_size} ranks {DP_group_ranks}\n" + f" EP {_EP.world_size} ranks {EP_group_ranks}\n" + f" SP_TP {_SP_TP.world_size} ranks {SP_TP_group_ranks}" + ) + + num_kv_heads = config.model_config._orig_get_num_kv_heads(config.parallel_config) + if get_world_group().local_rank == 0 and num_kv_heads < sequence_parallel_size: + parallel_state.logger.info( + f"KV cache is replicated by factor {sequence_parallel_size // num_kv_heads}" + ) @contextmanager def graph_capture(device: torch.device): @@ -316,22 +326,16 @@ def graph_capture(device: torch.device): yield context -class UlyssesWorkerProcPatch(ArcticPatch[WorkerProc]): +class UlyssesWorkerProc(ArcticPatch[WorkerProc]): def destroy_model_parallel(self): - from vllm.distributed.parallel_state import _SP, _SP_TP, _SP_AA, _SP_AG + from vllm.distributed.parallel_state import _SP, _SP_TP if _SP: _SP.destroy() _SP = None if _SP_TP: _SP_TP.destroy() _SP_TP = None - if _SP_AA: - _SP_AA.destroy() - _SP_AA = None - if _SP_AG: - _SP_AG.destroy() - _SP_AG = None def shutdown(self): self.rpc_broadcast_mq = None @@ -342,7 +346,7 @@ def shutdown(self): destroy_distributed_environment() -class UlyssesMultiprocExecutorPatch(ArcticPatch[MultiprocExecutor]): 
+class UlyssesMultiprocExecutor(ArcticPatch[MultiprocExecutor]): def _init_executor(self) -> None: # Call self.shutdown at exit to clean up @@ -350,81 +354,125 @@ def _init_executor(self) -> None: self._finalizer = weakref.finalize(self, self.shutdown) self.is_failed = False self.shutdown_event = threading.Event() - self.failure_callback: Optional[FailureCallback] = None - self.io_thread_pool: Optional[ThreadPoolExecutor] = None + self.failure_callback: FailureCallback | None = None self.world_size = self.parallel_config.world_size - tensor_parallel_size = self.parallel_config.tensor_parallel_size - pp_parallel_size = self.parallel_config.pipeline_parallel_size - sp_parallel_size = self.parallel_config.ulysses_sequence_parallel_size - assert (self.world_size == - tensor_parallel_size * pp_parallel_size * sp_parallel_size), ( + assert self.world_size % self.parallel_config.nnodes_within_dp == 0, ( + f"global world_size ({self.parallel_config.world_size}) must be " + f"divisible by nnodes_within_dp " + f"({self.parallel_config.nnodes_within_dp}). " + ) + self.local_world_size = self.parallel_config.local_world_size + tp_size = self.parallel_config.tensor_parallel_size + pp_size = self.parallel_config.pipeline_parallel_size + pcp_size = self.parallel_config.prefill_context_parallel_size + sp_size = self.parallel_config.ulysses_sequence_parallel_size + + assert self.world_size == tp_size * pp_size * pcp_size * sp_size, ( f"world_size ({self.world_size}) must be equal to the " - f"tensor_parallel_size ({tensor_parallel_size}) x pipeline" - f"_parallel_size ({pp_parallel_size}) x ulysses_sequence_parallel" - f"_size ({sp_parallel_size}).") + f"tensor_parallel_size ({tp_size}) x pipeline" + f"_parallel_size ({pp_size}) x prefill_context" + f"_parallel_size ({pcp_size}) x ulysses_sequence_parallel" + f"_size ({sp_size})." 
+ ) - # Set multiprocessing envs that are common to V0 and V1 - set_multiprocessing_worker_envs(self.parallel_config) + # Set multiprocessing envs + set_multiprocessing_worker_envs() - # Multiprocessing-based executor does not support multi-node setting. - # Since it only works for single node, we can use the loopback address - # 127.0.0.1 for communication. + # use the loopback address get_loopback_ip() for communication. distributed_init_method = get_distributed_init_method( - "127.0.0.1", get_open_port()) - + get_loopback_ip(), get_open_port() + ) + self.rpc_broadcast_mq: MessageQueue | None = None + scheduler_output_handle: Handle | None = None # Initialize worker and set up message queues for SchedulerOutputs # and ModelRunnerOutputs - max_chunk_bytes = envs.VLLM_MQ_MAX_CHUNK_BYTES_MB * 1024 * 1024 - self.rpc_broadcast_mq = MessageQueue(self.world_size, - self.world_size, - max_chunk_bytes=max_chunk_bytes) - scheduler_output_handle = self.rpc_broadcast_mq.export_handle() - + if self.parallel_config.node_rank_within_dp == 0: + # For leader node within each dp rank, + # each dp will have its own leader multiproc executor. 
+ max_chunk_bytes = envs.VLLM_MQ_MAX_CHUNK_BYTES_MB * 1024 * 1024 + self.rpc_broadcast_mq = MessageQueue( + self.world_size, + self.local_world_size, + max_chunk_bytes=max_chunk_bytes, + connect_ip=self.parallel_config.master_addr, + ) + scheduler_output_handle = self.rpc_broadcast_mq.export_handle() + # Create workers + # FIX: Removed duplicate initialization and local import that caused UnboundLocalError + context = get_mp_context() + shared_worker_lock = context.Lock() unready_workers: list[UnreadyWorkerProcHandle] = [] + success = False try: - for rank in range(self.world_size): + global_start_rank = ( + self.local_world_size * self.parallel_config.node_rank_within_dp + ) + for local_rank in range(self.local_world_size): + global_rank = global_start_rank + local_rank unready_workers.append( WorkerProc.make_worker_process( vllm_config=self.vllm_config, - local_rank=rank, - rank=rank, + local_rank=local_rank, + rank=global_rank, distributed_init_method=distributed_init_method, input_shm_handle=scheduler_output_handle, - )) + shared_worker_lock=shared_worker_lock, + ) + ) # Workers must be created before wait_for_ready to avoid # deadlock, since worker.init_device() does a device sync. + + # Wait for all local workers to be ready. self.workers = WorkerProc.wait_for_ready(unready_workers) + # Start background thread to monitor worker health if not in headless mode. 
+ if self.monitor_workers: + self.start_worker_monitor() + + self.response_mqs = [] + # Only leader node have remote response mqs + if self.parallel_config.node_rank_within_dp == 0: + for rank in range(self.world_size): + if rank < self.local_world_size: + local_message_queue = self.workers[rank].worker_response_mq + assert local_message_queue is not None + self.response_mqs.append(local_message_queue) + else: + remote_message_queue = self.workers[0].peer_worker_response_mqs[ + rank + ] + assert remote_message_queue is not None + self.response_mqs.append(remote_message_queue) + # Ensure message queues are ready. Will deadlock if re-ordered # Must be kept consistent with the WorkerProc. - self.rpc_broadcast_mq.wait_until_ready() - for w in self.workers: - w.worker_response_mq.wait_until_ready() - self.start_worker_monitor() + # Wait for all input mqs to be ready. + if self.rpc_broadcast_mq is not None: + self.rpc_broadcast_mq.wait_until_ready() + # Wait for all remote response mqs to be ready. + for response_mq in self.response_mqs: + response_mq.wait_until_ready() success = True finally: if not success: # Clean up the worker procs if there was a failure. - self._ensure_worker_termination( - [w.proc for w in unready_workers]) + # Close death_writers first to signal workers to exit + for uw in unready_workers: + if uw.death_writer is not None: + uw.death_writer.close() + self._ensure_worker_termination([uw.proc for uw in unready_workers]) - # For pipeline parallel, we use a thread pool for asynchronous - # execute_model. 
- if self.max_concurrent_batches > 1: - # Note: must use only 1 IO thread to keep dequeue sequence - # from the response queue - self.io_thread_pool = ThreadPoolExecutor( - max_workers=1, thread_name_prefix="mp_exec_io") + self.futures_queue = deque[tuple[FutureWrapper, Callable]]() self.output_rank = self._get_output_rank() -class UlyssesAttentionPatch(ArcticPatch[Attention]): +class UlyssesAttention(ArcticPatch[Attention]): _orig_init = Attention.__init__ _orig_forward = Attention.forward @@ -438,17 +486,8 @@ def __init__(self, num_heads, *args, **kwargs): num_kv_heads = kwargs["num_kv_heads"] self.is_kv_replicated = True if num_kv_heads < self.sp_size else False if self.is_kv_replicated: + self.replication_factor = self.sp_size // num_kv_heads num_kv_heads = 1 - assert parallel_state._SP_AA is not None and parallel_state._SP_AG is not None, ( - "UlyssesAttentionPatch requires SP_AA and SP_AG groups to be initialized.") - self.sp_aa_device_group = parallel_state._SP_AA.device_group - self.sp_ag_device_group = parallel_state._SP_AG.device_group - self.sp_aa_size = parallel_state._SP_AA.world_size - self.sp_ag_size = parallel_state._SP_AG.world_size - # this reorders the all-gathered sequence - self.order = [j * self.sp_aa_size + i - for i in range(self.sp_aa_size) - for j in range(self.sp_ag_size)] else: num_kv_heads //= self.sp_size kwargs["num_kv_heads"] = num_kv_heads @@ -459,52 +498,29 @@ def forward(self, query, key, value, **kwargs): if self.sp_size == 1 or is_shift_parallel_mode(): return self._orig_forward(query, key, value, **kwargs) + # prepare + q = query.view(-1, self.sp_size, self.num_heads * self.head_size) if self.is_kv_replicated: - # Ulysses all-to-all 1/2 (query) - q = query.view(-1, - self.sp_size, self.num_heads * self.head_size).transpose( - 0, 1).reshape(-1, - self.num_heads * self.head_size) - q_ = torch.empty_like(q) - torch.distributed.all_to_all_single(q_, q, group=self.sp_device_group) - # Ulysses pack (key, value) - kv = 
torch.cat((key.view(-1, self.sp_aa_size, self.num_kv_heads * self.head_size), - value.view(-1, self.sp_aa_size, self.num_kv_heads * self.head_size)), - dim=-1).transpose(0, 1).reshape( - -1, 2 * self.num_kv_heads * self.head_size) - # Ulysses all-to-all (key, value) - kv_part = torch.empty_like(kv) - torch.distributed.all_to_all_single(kv_part, kv, group=self.sp_aa_device_group) - # Ulysses all-gather (key, value) - kv_ = torch.empty(q_.shape[0], - 2 * self.num_kv_heads * self.head_size, - dtype=query.dtype, - device=query.device) - torch.distributed.all_gather_into_tensor(kv_, - kv_part, - group=self.sp_ag_device_group) - # reorder - kv_chunk = kv_.chunk(self.sp_size) - kv_ordered = torch.cat([kv_chunk[i] for i in self.order]) - # unpack (key, value) - k_, v_ = kv_ordered.split([self.num_kv_heads * self.head_size] * 2, dim=-1) + k = key.view(-1, self.sp_size // self.replication_factor, self.head_size).repeat_interleave(self.replication_factor, dim=1) + v = value.view(-1, self.sp_size // self.replication_factor, self.head_size).repeat_interleave(self.replication_factor, dim=1) else: - # pack - qkv = (torch.cat( - (query.view(-1, self.sp_size, self.num_heads * self.head_size), - key.view(-1, self.sp_size, self.num_kv_heads * self.head_size), - value.view(-1, self.sp_size, self.num_kv_heads * self.head_size)), - dim=-1) - .transpose(0, 1) - .reshape(-1, (self.num_heads + 2 * self.num_kv_heads) * self.head_size)) - # Ulysses all-to-all 1/2 - qkv_ = torch.empty_like(qkv) - torch.distributed.all_to_all_single(qkv_, qkv, group=self.sp_device_group) - # unpack - q_, k_, v_ = qkv_.split([ - self.num_heads * self.head_size, self.num_kv_heads * - self.head_size, self.num_kv_heads * self.head_size - ], dim=-1) + k = key.view(-1, self.sp_size, self.num_kv_heads * self.head_size) + v = value.view(-1, self.sp_size, self.num_kv_heads * self.head_size) + + # pack + qkv = torch.cat((q, k, v), dim=-1).transpose(0, 1).reshape( + -1, (self.num_heads + 2 * self.num_kv_heads) * 
self.head_size) + + # Ulysses all-to-all 1/2 + qkv_ = torch.empty_like(qkv) + torch.distributed.all_to_all_single(qkv_, qkv, group=self.sp_device_group) + + # unpack + q_, k_, v_ = qkv_.split([ + self.num_heads * self.head_size, + self.num_kv_heads * self.head_size, + self.num_kv_heads * self.head_size + ], dim=-1) # original attention c_ = self._orig_forward(q_, k_, v_, **kwargs) @@ -519,81 +535,288 @@ def forward(self, query, key, value, **kwargs): return output -class PiecewiseCompileInterpreterPatch(ArcticPatch[PiecewiseCompileInterpreter]): - - # find the symbolic shape of the subgraph - def find_symbolic_shape(self, args: tuple[torch.fx.node.Argument, - ...]) -> torch.SymInt: - symbols = set() - for x in args: - if isinstance(x, torch._subclasses.fake_tensor.FakeTensor): - for dim in x.shape: - if isinstance(dim, torch.SymInt): - symbols.update(dim.node.expr.free_symbols) - assert len(symbols) == 1, ( - f"Expected exactly one symbolic shape, but found {len(symbols)}: {symbols}") - return list(symbols)[0] - - def call_module(self, target: torch.fx.node.Target, - args: tuple[torch.fx.node.Argument, - ...], kwargs: dict[str, Any]) -> Any: - assert isinstance(target, str) - # [Arctic Inference] - # Since monkeypatching inherits the original class - # through ArcticPatch class, we lose the access to the original class' - # super() function. Instead of using super(), we directly invoke call_module - # from the super class torch.fx.Interpreter of PiecewiseCompileInterpreter. 
- # see - v0.9.0.1/compilation/backends.py#L241 - output = torch.fx.Interpreter.call_module(self, target, args, kwargs) - - if target in self.compile_submod_names: - index = self.compile_submod_names.index(target) - submod = self.fetch_attr(target) - # [Arctic Inference] - # Compiler may create subgraphs with certain symbolic - # integer values that violates vllm's assumption here: - # - v0.9.0.1/compilation/base_piecewise_backend.py#L64 - # The index of the significant symbol determines the runtime shape here: - # - v0.9.0.1/compilation/cuda_piecewise_backend.py#L112 - # The fix is relaxing vllm's original assumption that there is only a - # single symbolic that determines the shape.We then find the matching - # symbol indices. - sym_shape = self.find_symbolic_shape(args) - sym_shape_indices = [] - for i, x in enumerate(args): - if isinstance(x, torch.SymInt): - if sym_shape == x: - sym_shape_indices.append(i) - - global compilation_start_time - compiled_graph_for_general_shape = self.vllm_backend.\ - compiler_manager.compile( - submod, - args, - self.compilation_config.inductor_compile_config, - self.compilation_config, - graph_index=index, - num_graphs=len(self.compile_submod_names), - runtime_shape=None) - - piecewise_backend = resolve_obj_by_qualname( - current_platform.get_piecewise_backend_cls()) - self.module.__dict__[target] = piecewise_backend( - submod, self.vllm_config, self.graph_pool, index, - len(self.compile_submod_names), sym_shape_indices, - compiled_graph_for_general_shape, self.vllm_backend) - - from vllm.compilation.counter import compilation_counter - compilation_counter.num_piecewise_capturable_graphs_seen += 1 +class UlyssesCudagraphDispatcher(ArcticPatch[CudagraphDispatcher]): + + _orig_initialize_cudagraph_keys = CudagraphDispatcher.initialize_cudagraph_keys + + def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode, + uniform_decode_query_len: int): + self._orig_initialize_cudagraph_keys(cudagraph_mode, 
uniform_decode_query_len) + + # sp_group = getattr(parallel_state, "_SP", None) + # sp_size = sp_group.world_size if sp_group is not None else 1 + # if sp_size <= 1: + # return - return output + # if self.vllm_config.lora_config: + # if self.compilation_config.cudagraph_specialize_lora: + # lora_cases = [True, False] + # else: + # lora_cases = [True] + # else: + # lora_cases = [False] + + # if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE: + # for bs, has_lora in product( + # self.compilation_config.cudagraph_capture_sizes, lora_cases + # ): + # bd = self._create_padded_batch_descriptor( + # num_tokens=bs, # * sp_size, + # uniform_decode=False, + # has_lora=has_lora, + # ).relax_for_mixed_batch_cudagraphs() + + # self.add_cudagraph_key(cudagraph_mode.mixed_mode(), bd) + + # if (cudagraph_mode.decode_mode() == CUDAGraphMode.FULL + # and cudagraph_mode.separate_routine()): + # max_num_tokens = ( + # uniform_decode_query_len + # * self.vllm_config.scheduler_config.max_num_seqs + # ) + # cudagraph_capture_sizes_for_decode = [ + # x for x in self.compilation_config.cudagraph_capture_sizes + # if uniform_decode_query_len <= x <= max_num_tokens + # ] + # for bs, has_lora in product(cudagraph_capture_sizes_for_decode, lora_cases): + # bd = self._create_padded_batch_descriptor( + # num_tokens=bs, # * sp_size, + # uniform_decode=True, + # has_lora=has_lora, + # ) + # self.add_cudagraph_key(CUDAGraphMode.FULL, bd) + + +class UlyssesCompilationConfig(ArcticPatch[CompilationConfig]): + + _orig_post_init_cudagraph_sizes = CompilationConfig.post_init_cudagraph_sizes + + def post_init_cudagraph_sizes(self) -> None: + +# # print(f"Before post_init_cudagraph_sizes: max_cudagraph_capture_size={self.max_cudagraph_capture_size}, cudagraph_capture_sizes={self.cudagraph_capture_sizes}") + +# # Access the module-level variable set during engine config creation +# sp_size = _ulysses_sp_size + +# # scale sizes by Ulysses sequence parallel size +# self.max_cudagraph_capture_size *= 
sp_size +# self.cudagraph_capture_sizes = [size * sp_size for size in self.cudagraph_capture_sizes] + +# # print(f"After scaling for SP size {sp_size}: max_cudagraph_capture_size={self.max_cudagraph_capture_size}, cudagraph_capture_sizes={self.cudagraph_capture_sizes}") + + self._orig_post_init_cudagraph_sizes() + +# # revert back to original shapes +# self.max_cudagraph_capture_size //= sp_size +# self.cudagraph_capture_sizes = [size // sp_size for size in self.cudagraph_capture_sizes] + +# # print(f"self.bs_to_padded_graph_size {self.bs_to_padded_graph_size}") + +# # import traceback +# # traceback.print_stack() + +class UlyssesVllmConfig(ArcticPatch[VllmConfig]): + + _orig_set_cudagraph_sizes = VllmConfig._set_cudagraph_sizes + + @staticmethod + def _generate_capture_sizes(max_size: int) -> list[int]: + sizes = [i for i in [1, 2, 4] if i <= max_size] + if max_size >= 8: + sizes += list(range(8, min(max_size + 1, 256), 8)) + if max_size >= 256: + sizes += list(range(256, min(max_size + 1, 512), 16)) + if max_size >= 512: + sizes += list(range(512, max_size + 1, 32)) + return sizes + + @staticmethod + def _build_bs_to_padded(capture_sizes: list[int], + max_capture_size: int) -> list[int]: + table = [0] * (max_capture_size + 1) + for end, start in zip( + capture_sizes + [max_capture_size + 1], + [0] + capture_sizes, + ): + for bs in range(start, end): + table[bs] = start if bs == start else end + return table + + def _set_cudagraph_sizes(self): + sp_size = _ulysses_sp_size + + max_cudagraph_capture_size = self.compilation_config.max_cudagraph_capture_size + cudagraph_capture_sizes = self.compilation_config.cudagraph_capture_sizes + + if cudagraph_capture_sizes is None: + if max_cudagraph_capture_size is None: + max_cudagraph_capture_size = 512 + # Canonical (unscaled) sizes: [1, 2, 4, 8, ..., 512] + canonical_sizes = self._generate_capture_sizes( + max_cudagraph_capture_size) + + # Base model (Ulysses): scale by sp_size + 
self.compilation_config.cudagraph_capture_sizes = [ + s * sp_size for s in canonical_sizes + ] + self.compilation_config.max_cudagraph_capture_size = ( + max_cudagraph_capture_size * sp_size + ) + + # Shift model: scale by 1 (use canonical sizes as-is) + shift_sizes = list(canonical_sizes) + shift_max = max_cudagraph_capture_size + else: + shift_sizes = list(cudagraph_capture_sizes) + shift_max = max(shift_sizes) if shift_sizes else 0 + self._shift_cudagraph_capture_sizes = shift_sizes + self._shift_max_cudagraph_capture_size = shift_max + self._shift_bs_to_padded_graph_size = self._build_bs_to_padded( + shift_sizes, shift_max) if shift_sizes else [] -class UlyssesFusedMoEPatch(ArcticPatch[FusedMoE]): + print( + f"UlyssesVllmConfig: base max_cudagraph_capture_size=" + f"{self.compilation_config.max_cudagraph_capture_size}, " + f"base sizes={self.compilation_config.cudagraph_capture_sizes}, " + f"shift max={shift_max}, shift sizes={shift_sizes}" + ) - def forward(self, hidden_states: torch.Tensor, - router_logits: torch.Tensor): - # directly call forward_impl to bypass custom opt - # custom opt prevents using the shift model - # we will expand this function to fuse SP with EP - return self.forward_impl(hidden_states, router_logits) + self._orig_set_cudagraph_sizes() + + def pad_for_cudagraph(self, batch_size: int) -> int: + from .model_runner import is_shift_parallel_mode + if is_shift_parallel_mode() and self._shift_bs_to_padded_graph_size: + return self._shift_bs_to_padded_graph_size[batch_size] + return self.compilation_config.bs_to_padded_graph_size[batch_size] + + +class UlyssesEngineCore(ArcticPatch[EngineCore]): + + iteration = 0 + + def step_with_batch_queue( + self, + ) -> tuple[dict[int, EngineCoreOutputs] | None, bool]: + """Schedule and execute batches with the batch queue. + Note that if nothing to output in this step, None is returned. + + The execution flow is as follows: + 1. Try to schedule a new batch if the batch queue is not full. 
+ If a new batch is scheduled, directly return an empty engine core + output. In other words, fulfilling the batch queue has a higher priority + than getting model outputs. + 2. If there is no new scheduled batch, meaning that the batch queue + is full or no other requests can be scheduled, we block until the first + batch in the job queue is finished. + 3. Update the scheduler from the output. + """ + batch_queue = self.batch_queue + assert batch_queue is not None + + # Try to schedule a new batch if the batch queue is not full, but + # the scheduler may return an empty batch if all requests are scheduled. + # Note that this is not blocking. + assert len(batch_queue) < self.batch_queue_size + + step_start_time = time.monotonic() + + model_executed = False + deferred_scheduler_output = None + if self.scheduler.has_requests(): + scheduler_output = self.scheduler.schedule() + exec_future = self.model_executor.execute_model( + scheduler_output, non_block=True + ) + if not self.is_ec_producer: + model_executed = scheduler_output.total_num_scheduled_tokens > 0 + + if self.is_pooling_model or not model_executed: + # No sampling required (no requests scheduled). + future = cast(Future[ModelRunnerOutput], exec_future) + else: + if not scheduler_output.pending_structured_output_tokens: + # We aren't waiting for any tokens, get any grammar output + # and sample immediately. + grammar_output = self.scheduler.get_grammar_bitmask( + scheduler_output + ) + future = self.model_executor.sample_tokens( + grammar_output, non_block=True + ) + else: + # We need to defer sampling until we have processed the model output + # from the prior step. + deferred_scheduler_output = scheduler_output + + if not deferred_scheduler_output: + # Add this step's future to the queue. 
+ batch_queue.appendleft((future, scheduler_output, exec_future)) + if ( + model_executed + and len(batch_queue) < self.batch_queue_size + and not batch_queue[-1][0].done() + ): + # Don't block on next worker response unless the queue is full + # or there are no more requests to schedule. + return None, True + + elif not batch_queue: + # Queue is empty. We should not reach here since this method should + # only be called when the scheduler contains requests or the queue + # is non-empty. + return None, False + + # Block until the next result is available. + future, scheduler_output, exec_model_fut = batch_queue.pop() + with ( + self.log_error_detail(scheduler_output), + self.log_iteration_details(scheduler_output), + ): + model_output = future.result() + if model_output is None: + # None from sample_tokens() implies that the original execute_model() + # call failed - raise that exception. + exec_model_fut.result() + raise RuntimeError("unexpected error") + + # Before processing the model output, process any aborts that happened + # during the model execution. + self._process_aborts_queue() + engine_core_outputs = self.scheduler.update_from_output( + scheduler_output, model_output + ) + + # NOTE(nick): We can either handle the deferred tasks here or save + # in a field and do it immediately once step_with_batch_queue is + # re-called. The latter slightly favors TTFT over TPOT/throughput. + if deferred_scheduler_output: + # If we are doing speculative decoding with structured output, + # we need to get the draft token ids from the prior step before + # we can compute the grammar bitmask for the deferred request. + if self.use_spec_decode: + draft_token_ids = self.model_executor.take_draft_token_ids() + assert draft_token_ids is not None + # Update the draft token ids in the scheduler output to + # filter out the invalid spec tokens, which will be padded + # with -1 and skipped by the grammar bitmask computation. 
+ self.scheduler.update_draft_token_ids_in_output( + draft_token_ids, deferred_scheduler_output + ) + # We now have the tokens needed to compute the bitmask for the + # deferred request. Get the bitmask and call sample tokens. + grammar_output = self.scheduler.get_grammar_bitmask( + deferred_scheduler_output + ) + future = self.model_executor.sample_tokens(grammar_output, non_block=True) + batch_queue.appendleft((future, deferred_scheduler_output, exec_future)) + + total_time_ms = (time.monotonic() - step_start_time) * 1000 + + running, waiting = self.scheduler.get_request_counts() + scheduled_tokens = scheduler_output.total_num_scheduled_tokens + concurrency = len(scheduler_output.num_scheduled_tokens.keys()) + # print(f"iteration {self.iteration}, running: {running}, waiting: {waiting}, scheduled tokens: {scheduled_tokens}, concurrency: {concurrency}, total_time_ms: {total_time_ms:.2f}") + self.iteration += 1 + + return engine_core_outputs, model_executed \ No newline at end of file diff --git a/benchmark/reproducibility/README.md b/benchmark/reproducibility/README.md new file mode 100644 index 000000000..eaa2bcc80 --- /dev/null +++ b/benchmark/reproducibility/README.md @@ -0,0 +1,147 @@ + +## Reproducibility + +This page is on reproducibility of the [Shift Parallelism](https://arxiv.org/pdf/2509.16495) paper. Please see the Artifact Appendix. + +## Instructions + +### Step 1: Making vLLM work + +1. Create a clean environment + +- If python 3.10 is already installed, create a clean virtual environment +```console +python3.10 -m venv myvenv +source myvenv/bin/activate +``` + +- If not, use conda to create an environment with python 3.10 +```console +conda create -n myenv python=3.10 +conda activate myenv +``` + +2. Install vLLM +```console +pip install vllm==v0.10.1 +``` + +3. 
Download the models +```console +huggingface-cli download RedHatAI/Llama-3.3-70B-Instruct-FP8-dynamic --local-dir RedHatAI/Llama-3.3-70B-Instruct-FP8-dynamic +huggingface-cli download Qwen/Qwen3-32B-FP8 --local-dir Qwen/Qwen3-32B-FP8 +``` + +### Step 2: Install ArcticInference +Check out the commit that matches vLLM v0.10.1. +```console +git clone https://github.com/snowflakedb/ArcticInference.git +cd ArcticInference +git checkout d096fdf +pip install . +``` + +Once installed, Arctic Inference automatically patches vLLM to use Arctic Inference with Shift Parallelism, and users can continue to use their familiar vLLM APIs and CLI. + +### Step 3: Vibe test +`vibe_test.py` +```python +import vllm +from vllm import LLM, SamplingParams + +vllm.plugins.load_general_plugins() + +llm = LLM( + model="RedHatAI/Llama-3.3-70B-Instruct-FP8-dynamic", + ulysses_sequence_parallel_size=8, + enable_shift_parallel=True, + shift_parallel_threshold=8, +) + +conversation = [ + { + "role": "user", + "content": "Write an essay about the importance of higher education.", + }, +] + +sampling_params = SamplingParams(temperature=0.0, max_tokens=800) + +outputs = llm.chat(conversation, sampling_params=sampling_params) + +print(outputs[0].outputs[0].text) +``` + +```console +ARCTIC_INFERENCE_ENABLED=1 VLLM_DISABLE_COMPILE_CACHE=1 python vibe_test.py +``` +We turn off the compile cache for now to prevent complications. + +### Step 4: Patch vLLM bench + +This patch is necessary to replay the request traces with `vllm bench serve`. + +Please replace your `vllm/benchmarks/serve.py` with `patches_v0.10.1/serve.py` and `vllm/benchmarks/datasets.py` with `patches_v0.10.1/datasets.py`. The vLLM installation path can be found with `pip show vllm`.
+ +### Step 5: Run Traces + +The traces are available at https://doi.org/10.5281/zenodo.18240909 + +```console +wget https://zenodo.org/records/18240909/files/AzureLLMInferenceTrace_code_15mins.jsonl +wget https://zenodo.org/records/18240909/files/conversation_trace_15mins.jsonl +``` + +There are two traces and three parallelisms in the full reproduction. For each case, the server and the client are run in separate terminals. For example: + +1. Azure LLM code (Figure 9) + +server: +```console +ARCTIC_INFERENCE_ENABLED=1 VLLM_DISABLE_COMPILE_CACHE=1 vllm serve RedHatAI/Llama-3.3-70B-Instruct-FP8-dynamic \ + --disable-log-requests \ + --no-enable-prefix-caching \ + --ulysses-sequence-parallel-size 8 \ + --enable-shift-parallel \ + --max-num-batched-tokens 131072 +``` + +client: +```console +vllm bench serve --model RedHatAI/Llama-3.3-70B-Instruct-FP8-dynamic --trace-dataset-path AzureLLMInferenceTrace_code_15mins.jsonl --ignore-eos --trace-output-path code_output.csv +``` + +2. Mooncake conversation (Figure 10) + +For the Mooncake dataset, we first increase the context length with YaRN: https://huggingface.co/Qwen/Qwen3-32B-FP8#processing-long-texts + +server: +```console +ARCTIC_INFERENCE_ENABLED=1 VLLM_DISABLE_COMPILE_CACHE=1 vllm serve Qwen/Qwen3-32B-FP8 \ + --disable-log-requests \ + --no-enable-prefix-caching \ + --ulysses-sequence-parallel-size 8 \ + --enable-shift-parallel \ + --max-num-batched-tokens 131072 +``` + +client: +```console +vllm bench serve --model Qwen/Qwen3-32B-FP8 --trace-dataset-path conversation_trace_15mins.jsonl --ignore-eos --trace-output-path conversation_output.csv +``` + +Ready-to-use `server.sh` and `client.sh` are included. + +The key results involve two figures (Figures 9–10). The breakdown of reproduction times is given below: + +Figure 9: DP (15 mins), TP (15 mins), SP (15 mins), Shift Parallel (15 mins) +Figure 10: DP (~2.5 hrs), TP (~1 hr), SP (15 mins), Shift Parallel (15 mins) + +### Step 6: Plotting + +Install the plotting library.
+```console +pip install matplotlib +``` + +Run `plot.py` with the appropriate trace output path. diff --git a/benchmark/reproducibility/client.sh b/benchmark/reproducibility/client.sh new file mode 100644 index 000000000..7e54661b4 --- /dev/null +++ b/benchmark/reproducibility/client.sh @@ -0,0 +1,6 @@ + +model="RedHatAI/Llama-3.3-70B-Instruct-FP8-dynamic" +vllm bench serve --model $model --trace-dataset-path AzureLLMInferenceTrace_code_15mins.jsonl --ignore-eos --trace-output-path code_output.csv + +# model="Qwen/Qwen3-32B-FP8" +# vllm bench serve --model $model --trace-dataset-path conversation_trace_15mins.jsonl --ignore-eos --trace-output-path conversation_output.csv diff --git a/benchmark/reproducibility/patches_v0.10.1/datasets.py b/benchmark/reproducibility/patches_v0.10.1/datasets.py new file mode 100644 index 000000000..f5812f064 --- /dev/null +++ b/benchmark/reproducibility/patches_v0.10.1/datasets.py @@ -0,0 +1,1675 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +This module defines a framework for sampling benchmark requests from various +datasets. Each dataset subclass of BenchmarkDataset must implement sample +generation.
Supported dataset types include: + - ShareGPT + - Random (synthetic) + - Sonnet + - BurstGPT + - HuggingFace + - VisionArena +""" +import base64 +import io +import json +import logging +import random +from abc import ABC, abstractmethod +from collections.abc import Mapping +from dataclasses import dataclass +from functools import cache +from io import BytesIO +from typing import Any, Callable, Optional, Union + +import numpy as np +from PIL import Image +from transformers import PreTrainedTokenizerBase +from typing_extensions import deprecated + +from vllm.lora.request import LoRARequest +from vllm.lora.utils import get_adapter_absolute_path +from vllm.multimodal import MultiModalDataDict +from vllm.multimodal.image import convert_image_mode +from vllm.transformers_utils.tokenizer import AnyTokenizer, get_lora_tokenizer +from vllm.utils import PlaceholderModule + +try: + from datasets import load_dataset +except ImportError: + datasets = PlaceholderModule("datasets") + load_dataset = datasets.placeholder_attr("load_dataset") + +try: + import pandas as pd +except ImportError: + pd = PlaceholderModule("pandas") + +try: + import librosa +except ImportError: + librosa = PlaceholderModule("librosa") + +try: + from vllm.utils import FlexibleArgumentParser +except ImportError: + from argparse import ArgumentParser as FlexibleArgumentParser + +logger = logging.getLogger(__name__) + +# ----------------------------------------------------------------------------- +# Data Classes +# ----------------------------------------------------------------------------- + + +@dataclass +class SampleRequest: + """ + Represents a single inference request for benchmarking. 
+ """ + + prompt: Union[str, Any] + prompt_len: int + expected_output_len: int + multi_modal_data: Optional[ + Union[MultiModalDataDict, dict, list[dict]] + ] = None + lora_request: Optional[LoRARequest] = None + trace_timestamp: int = 0 + + +# ----------------------------------------------------------------------------- +# Benchmark Dataset Base Class +# ----------------------------------------------------------------------------- + + +class BenchmarkDataset(ABC): + DEFAULT_SEED = 0 + IS_MULTIMODAL = False + + def __init__( + self, + dataset_path: Optional[str] = None, + random_seed: int = DEFAULT_SEED, + ) -> None: + """ + Initialize the BenchmarkDataset with an optional dataset path and random + seed. + + Args: + dataset_path (Optional[str]): Path to the dataset. If None, it + indicates that a default or random dataset might be used. + random_seed (int): Seed value for reproducible shuffling or + sampling. Defaults to DEFAULT_SEED. + """ + self.dataset_path = dataset_path + # Set the random seed, ensuring that a None value is replaced with the + # default seed. + self.random_seed = (random_seed + if random_seed is not None else self.DEFAULT_SEED) + self.data = None + + def apply_multimodal_chat_transformation( + self, + prompt: str, + mm_content: Optional[MultiModalDataDict] = None) -> list[dict]: + """ + Transform a prompt and optional multimodal content into a chat format. + This method is used for chat models that expect a specific conversation + format. + """ + content = [{"text": prompt, "type": "text"}] + if mm_content is not None: + content.append(mm_content) + return [{"role": "user", "content": content}] + + def load_data(self) -> None: + """ + Load data from the dataset path into self.data. + + This method must be overridden by subclasses since the method to load + data will vary depending on the dataset format and source. + + Raises: + NotImplementedError: If a subclass does not implement this method. 
+ """ + # TODO (jenniferzhao): add support for downloading data + raise NotImplementedError( + "load_data must be implemented in subclasses.") + + def get_random_lora_request( + self, + tokenizer: PreTrainedTokenizerBase, + max_loras: Optional[int] = None, + lora_path: Optional[str] = None, + ) -> tuple[Optional[LoRARequest], AnyTokenizer]: + """ + Optionally select a random LoRA request and return its associated + tokenizer. + + This method is used when LoRA parameters are provided. It randomly + selects a LoRA based on max_loras and retrieves a cached tokenizer for + that LoRA if available. Otherwise, it returns the base tokenizer. + + Args: + tokenizer (PreTrainedTokenizerBase): The base tokenizer to use if no + LoRA is selected. + max_loras (Optional[int]): The maximum number of LoRAs available. + If `None`, LoRA is not used. + lora_path (Optional[str]): Path to the LoRA parameters on disk. + If `None`, LoRA is not used. + + Returns: + A tuple with the following elements: + - A new [LoRARequest][] (or `None` if not applicable). + - The tokenizer associated with the LoRA request + (or the base tokenizer). + """ + if max_loras is None or lora_path is None: + return None, tokenizer + + # Generate a random LoRA ID in the range [1, max_loras]. + lora_id = random.randint(1, max_loras) + lora_request = LoRARequest( + lora_name=str(lora_id), + lora_int_id=lora_id, + lora_path=lora_path_on_disk(lora_path), + ) + if lora_id not in lora_tokenizer_cache: + lora_tokenizer_cache[lora_id] = get_lora_tokenizer(lora_request) + # Return lora_request and the cached tokenizer if available; otherwise, + # return the base tokenizer + return lora_request, lora_tokenizer_cache[lora_id] or tokenizer + + @abstractmethod + def sample(self, tokenizer: PreTrainedTokenizerBase, + num_requests: int) -> list[SampleRequest]: + """ + Abstract method to generate sample requests from the dataset. 
+ + Subclasses must override this method to implement dataset-specific logic + for generating a list of SampleRequest objects. + + Args: + tokenizer (PreTrainedTokenizerBase): The tokenizer to be used + for processing the dataset's text. + num_requests (int): The number of sample requests to generate. + + Returns: + list[SampleRequest]: A list of sample requests generated from the + dataset. + """ + raise NotImplementedError("sample must be implemented in subclasses.") + + def maybe_oversample_requests(self, requests: list[SampleRequest], + num_requests: int) -> None: + """ + Oversamples the list of requests if its size is less than the desired + number. + + Args: + requests (List[SampleRequest]): The current list of sampled + requests. + num_requests (int): The target number of requests. + """ + if len(requests) < num_requests: + random.seed(self.random_seed) + additional = random.choices(requests, + k=num_requests - len(requests)) + requests.extend(additional) + logger.info("Oversampled requests to reach %d total samples.", + num_requests) + + +# ----------------------------------------------------------------------------- +# Utility Functions and Global Caches +# ----------------------------------------------------------------------------- + + +def is_valid_sequence( + prompt_len: int, + output_len: int, + min_len: int = 4, + max_prompt_len: int = 1024, + max_total_len: int = 2048, + skip_min_output_len_check: bool = False, +) -> bool: + """ + Validate a sequence based on prompt and output lengths. + + Default pruning criteria are copied from the original `sample_hf_requests` + and `sample_sharegpt_requests` functions in benchmark_serving.py, as well as + from `sample_requests` in benchmark_throughput.py. 
+ """ + # Check for invalid conditions + prompt_too_short = prompt_len < min_len + output_too_short = (not skip_min_output_len_check) and (output_len + < min_len) + prompt_too_long = prompt_len > max_prompt_len + combined_too_long = (prompt_len + output_len) > max_total_len + + # Return True if none of the invalid conditions are met + return not (prompt_too_short or output_too_short or prompt_too_long + or combined_too_long) + + +@cache +def lora_path_on_disk(lora_path: str) -> str: + return get_adapter_absolute_path(lora_path) + + +# Global cache for LoRA tokenizers. +lora_tokenizer_cache: dict[int, AnyTokenizer] = {} + + +def process_image(image: Any) -> Mapping[str, Any]: + """ + Process a single image input and return a multimedia content dictionary. + + Supports three input types: + + 1. Dictionary with raw image bytes: - Expects a dict with a 'bytes' key + containing raw image data. - Loads the bytes as a PIL.Image.Image. + + 2. PIL.Image.Image input: - Converts the image to RGB. - Saves the image as + a JPEG in memory. - Encodes the JPEG data as a base64 string. - Returns + a dictionary with the image as a base64 data URL. + + 3. String input: - Treats the string as a URL or local file path. - + Prepends "file://" if the string doesn't start with "http://" or + "file://". - Returns a dictionary with the image URL. + + Raises: + ValueError: If the input is not a supported type. 
+ """ + if isinstance(image, dict) and 'bytes' in image: + image = Image.open(BytesIO(image['bytes'])) + if isinstance(image, Image.Image): + image = convert_image_mode(image, "RGB") + with io.BytesIO() as image_data: + image.save(image_data, format="JPEG") + image_base64 = base64.b64encode( + image_data.getvalue()).decode("utf-8") + return { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{image_base64}" + }, + } + + if isinstance(image, str): + image_url = (image if image.startswith( + ("http://", "file://")) else f"file://{image}") + return {"type": "image_url", "image_url": {"url": image_url}} + + raise ValueError(f"Invalid image input {image}. Must be a PIL.Image.Image" + " or str or dictionary with raw image bytes.") + + +# ----------------------------------------------------------------------------- +# Random Dataset Implementation (Synthetic Data) +# ----------------------------------------------------------------------------- + + +class RandomDataset(BenchmarkDataset): + # Default values copied from benchmark_serving.py for the random dataset. 
+ DEFAULT_PREFIX_LEN = 0 + DEFAULT_RANGE_RATIO = 0.0 + DEFAULT_INPUT_LEN = 1024 + DEFAULT_OUTPUT_LEN = 128 + + def __init__( + self, + **kwargs, + ) -> None: + super().__init__(**kwargs) + random.seed(self.random_seed) + np.random.seed(self.random_seed) + + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + prefix_len: int = DEFAULT_PREFIX_LEN, + range_ratio: float = DEFAULT_RANGE_RATIO, + input_len: int = DEFAULT_INPUT_LEN, + output_len: int = DEFAULT_OUTPUT_LEN, + trace_dataset_path = None, + **kwargs, + ) -> list[SampleRequest]: + # Enforce range_ratio < 1 + assert range_ratio < 1.0, ( + "random_range_ratio must be < 1.0 to ensure a valid sampling range" + ) + + vocab_size = tokenizer.vocab_size + num_special_tokens = tokenizer.num_special_tokens_to_add() + real_input_len = input_len - num_special_tokens + + prefix_token_ids = (np.random.randint( + 0, vocab_size, size=prefix_len).tolist() if prefix_len > 0 else []) + + print(f"RandomDataset sample {trace_dataset_path}") + + events = [] + with open(trace_dataset_path, "r") as f: + for line in f: + if not line.strip(): + continue + obj = json.loads(line) + timestamp = obj["timestamp"] + input_length = obj["input_length"] + output_length = obj["output_length"] + print(f"read trace timestamp {timestamp} input_length {input_length} output_length {output_length}") + events.append((timestamp, input_length, output_length)) + # Ensure chronological order + events.sort(key=lambda x: x[0]) + + num_requests = len(events) + timestamps = [i[0] for i in events] + input_lens = [i[1] for i in events] + output_lens = [i[2] for i in events] + offsets = np.random.randint(0, vocab_size, size=num_requests) + + requests = [] + for i in range(num_requests): + inner_seq = ((offsets[i] + i + np.arange(input_lens[i])) % + vocab_size).tolist() + token_sequence = prefix_token_ids + inner_seq + prompt = tokenizer.decode(token_sequence) + # After decoding the prompt we have to encode and decode it again. 
+ # This is done because in some cases N consecutive tokens + # give a string tokenized into != N number of tokens. + # For example for GPT2Tokenizer: + # [6880, 6881] -> ['Ġcalls', 'here'] -> + # [1650, 939, 486] -> ['Ġcall', 'sh', 'ere'] + # To avoid uncontrolled change of the prompt length, + # the encoded sequence is truncated before being decode again. + total_input_len = prefix_len + int(input_lens[i]) + re_encoded_sequence = tokenizer.encode( + prompt, add_special_tokens=False)[:total_input_len] + prompt = tokenizer.decode(re_encoded_sequence) + total_input_len = len(re_encoded_sequence) + requests.append( + SampleRequest( + prompt=prompt, + prompt_len=total_input_len, + expected_output_len=int(output_lens[i]), + trace_timestamp=timestamps[i] + )) + return requests + + +# ----------------------------------------------------------------------------- +# ShareGPT Dataset Implementation +# ----------------------------------------------------------------------------- + + +class ShareGPTDataset(BenchmarkDataset): + """ + Implements the ShareGPT dataset. Loads data from a JSON file and generates + sample requests based on conversation turns. + """ + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + self.load_data() + + def load_data(self) -> None: + if self.dataset_path is None: + raise ValueError("dataset_path must be provided for loading data.") + + with open(self.dataset_path, encoding="utf-8") as f: + self.data = json.load(f) + # Filter entries with at least two conversation turns. 
+ self.data = [ + entry for entry in self.data + if "conversations" in entry and len(entry["conversations"]) >= 2 + ] + random.seed(self.random_seed) + random.shuffle(self.data) + + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + lora_path: Optional[str] = None, + max_loras: Optional[int] = None, + output_len: Optional[int] = None, + enable_multimodal_chat: bool = False, + **kwargs, + ) -> list: + samples: list = [] + for entry in self.data: + if len(samples) >= num_requests: + break + prompt, completion = ( + entry["conversations"][0]["value"], + entry["conversations"][1]["value"], + ) + + lora_request, tokenizer = self.get_random_lora_request( + tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path) + prompt_ids = tokenizer(prompt).input_ids + completion_ids = tokenizer(completion).input_ids + prompt_len = len(prompt_ids) + new_output_len = (len(completion_ids) + if output_len is None else output_len) + if not is_valid_sequence(prompt_len, + new_output_len, + skip_min_output_len_check=output_len + is not None): + continue + # TODO: Also support ShareGPT4Video. 
+ if image_path := entry.get("image"): + mm_content = process_image(image_path) + else: + mm_content = None + if enable_multimodal_chat: + prompt = self.apply_multimodal_chat_transformation( + prompt, mm_content) + samples.append( + SampleRequest( + prompt=prompt, + prompt_len=prompt_len, + expected_output_len=new_output_len, + lora_request=lora_request, + multi_modal_data=mm_content, + )) + self.maybe_oversample_requests(samples, num_requests) + return samples + + +def add_dataset_parser(parser: FlexibleArgumentParser): + parser.add_argument("--seed", type=int, default=0) + parser.add_argument( + "--num-prompts", + type=int, + default=1000, + help="Number of prompts to process.", + ) + parser.add_argument( + "--dataset-name", + type=str, + default="random", + choices=[ + "sharegpt", "burstgpt", "sonnet", "random", "hf", "custom", + "prefix_repetition" + ], + help="Name of the dataset to benchmark on.", + ) + parser.add_argument( + "--no-stream", + action="store_true", + help="Do not load the dataset in streaming mode.", + ) + parser.add_argument( + "--dataset-path", + type=str, + default=None, + help="Path to the sharegpt/sonnet dataset. 
" + "Or the huggingface dataset ID if using HF dataset.", + ) + + # group for dataset specific arguments + custom_group = parser.add_argument_group("custom dataset options") + custom_group.add_argument( + "--custom-output-len", + type=int, + default=256, + help= + "Number of output tokens per request, used only for custom dataset.", + ) + custom_group.add_argument( + "--custom-skip-chat-template", + action="store_true", + help= + "Skip applying chat template to prompt, used only for custom dataset.", + ) + + sonnet_group = parser.add_argument_group("sonnet dataset options") + sonnet_group.add_argument( + "--sonnet-input-len", + type=int, + default=550, + help= + "Number of input tokens per request, used only for sonnet dataset.", + ) + sonnet_group.add_argument( + "--sonnet-output-len", + type=int, + default=150, + help= + "Number of output tokens per request, used only for sonnet dataset.", + ) + sonnet_group.add_argument( + "--sonnet-prefix-len", + type=int, + default=200, + help= + "Number of prefix tokens per request, used only for sonnet dataset.", + ) + + sharegpt_group = parser.add_argument_group("sharegpt dataset options") + sharegpt_group.add_argument( + "--sharegpt-output-len", + type=int, + default=None, + help="Output length for each request. Overrides the output length " + "from the ShareGPT dataset.", + ) + + random_group = parser.add_argument_group("random dataset options") + random_group.add_argument( + "--random-input-len", + type=int, + default=1024, + help= + "Number of input tokens per request, used only for random sampling.", + ) + random_group.add_argument( + "--random-output-len", + type=int, + default=128, + help= + "Number of output tokens per request, used only for random sampling.", + ) + random_group.add_argument( + "--random-range-ratio", + type=float, + default=0.0, + help="Range ratio for sampling input/output length, " + "used only for random sampling. 
Must be in the range [0, 1) to define " + "a symmetric sampling range" + "[length * (1 - range_ratio), length * (1 + range_ratio)].", + ) + random_group.add_argument( + "--random-prefix-len", + type=int, + default=0, + help=("Number of fixed prefix tokens before the random context " + "in a request. " + "The total input length is the sum of `random-prefix-len` and " + "a random " + "context length sampled from [input_len * (1 - range_ratio), " + "input_len * (1 + range_ratio)]."), + ) + + hf_group = parser.add_argument_group("hf dataset options") + hf_group.add_argument("--hf-subset", + type=str, + default=None, + help="Subset of the HF dataset.") + hf_group.add_argument("--hf-split", + type=str, + default=None, + help="Split of the HF dataset.") + hf_group.add_argument( + "--hf-output-len", + type=int, + default=None, + help="Output length for each request. Overrides the output lengths " + "from the sampled HF dataset.", + ) + + prefix_repetition_group = parser.add_argument_group( + "prefix repetition dataset options") + prefix_repetition_group.add_argument( + "--prefix-repetition-prefix-len", + type=int, + default=256, + help="Number of prefix tokens per request, used only for prefix " + "repetition dataset.", + ) + prefix_repetition_group.add_argument( + "--prefix-repetition-suffix-len", + type=int, + default=256, + help="Number of suffix tokens per request, used only for prefix " + "repetition dataset. Total input length is prefix_len + suffix_len.", + ) + prefix_repetition_group.add_argument( + "--prefix-repetition-num-prefixes", + type=int, + default=10, + help="Number of prefixes to generate, used only for prefix repetition " + "dataset. 
Prompts per prefix is num_requests // num_prefixes.", + ) + prefix_repetition_group.add_argument( + "--prefix-repetition-output-len", + type=int, + default=128, + help="Number of output tokens per request, used only for prefix " + "repetition dataset.", + ) + + +def get_samples(args, tokenizer) -> list[SampleRequest]: + if args.dataset_name == "custom": + dataset = CustomDataset(dataset_path=args.dataset_path) + input_requests = dataset.sample( + num_requests=args.num_prompts, + tokenizer=tokenizer, + output_len=args.custom_output_len, + skip_chat_template=args.custom_skip_chat_template, + ) + + elif args.dataset_name == "sonnet": + dataset = SonnetDataset(dataset_path=args.dataset_path) + # For the "sonnet" dataset, formatting depends on the backend. + if args.endpoint_type == "openai-chat": + input_requests = dataset.sample( + num_requests=args.num_prompts, + input_len=args.sonnet_input_len, + output_len=args.sonnet_output_len, + prefix_len=args.sonnet_prefix_len, + tokenizer=tokenizer, + return_prompt_formatted=False, + ) + else: + assert tokenizer.chat_template or tokenizer.default_chat_template, ( + "Tokenizer/model must have chat template for sonnet dataset.") + input_requests = dataset.sample( + num_requests=args.num_prompts, + input_len=args.sonnet_input_len, + output_len=args.sonnet_output_len, + prefix_len=args.sonnet_prefix_len, + tokenizer=tokenizer, + return_prompt_formatted=True, + ) + + elif args.dataset_name == "hf": + # all following datasets are implemented from the + # HuggingFaceDataset base class + if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS: + dataset_class = VisionArenaDataset + args.hf_split = "train" + args.hf_subset = None + elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS: + dataset_class = InstructCoderDataset + args.hf_split = "train" + elif args.dataset_path in MTBenchDataset.SUPPORTED_DATASET_PATHS: + dataset_class = MTBenchDataset + args.hf_split = "train" + elif args.dataset_path in 
ConversationDataset.SUPPORTED_DATASET_PATHS: + dataset_class = ConversationDataset + elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS: + dataset_class = AIMODataset + args.hf_split = "train" + elif args.dataset_path in NextEditPredictionDataset.SUPPORTED_DATASET_PATHS: # noqa: E501 + dataset_class = NextEditPredictionDataset + args.hf_split = "train" + elif args.dataset_path in ASRDataset.SUPPORTED_DATASET_PATHS: + dataset_class = ASRDataset + args.hf_split = "train" + elif args.dataset_path in MLPerfDataset.SUPPORTED_DATASET_PATHS: + dataset_class = MLPerfDataset + args.hf_split = "train" + else: + supported_datasets = set([ + dataset_name for cls in HuggingFaceDataset.__subclasses__() + for dataset_name in cls.SUPPORTED_DATASET_PATHS + ]) + raise ValueError( + f"Unsupported dataset path: {args.dataset_path}. " + "Huggingface dataset only supports dataset_path" + f" from one of following: {supported_datasets}. " + "Please consider contributing if you would " + "like to add support for additional dataset formats.") + + if dataset_class.IS_MULTIMODAL and args.endpoint_type not in [ + "openai-chat", + "openai-audio", + ]: + # multi-modal benchmark is only available on OpenAI Chat backend. + raise ValueError( + "Multi-modal content is only supported on 'openai-chat' and " + "'openai-audio' backend.") + input_requests = dataset_class( + dataset_path=args.dataset_path, + dataset_subset=args.hf_subset, + dataset_split=args.hf_split, + random_seed=args.seed, + no_stream=args.no_stream, + ).sample( + num_requests=args.num_prompts, + tokenizer=tokenizer, + output_len=args.hf_output_len, + ) + + else: + # For datasets that follow a similar structure, use a mapping. 
+ dataset_mapping = { + "sharegpt": + lambda: ShareGPTDataset(random_seed=args.seed, + dataset_path=args.dataset_path).sample( + tokenizer=tokenizer, + num_requests=args.num_prompts, + output_len=args.sharegpt_output_len, + ), + "burstgpt": + lambda: BurstGPTDataset(random_seed=args.seed, + dataset_path=args.dataset_path). + sample(tokenizer=tokenizer, num_requests=args.num_prompts), + "random": + lambda: RandomDataset(random_seed=args.seed, + dataset_path=args.dataset_path).sample( + tokenizer=tokenizer, + num_requests=args.num_prompts, + prefix_len=args.random_prefix_len, + input_len=args.random_input_len, + output_len=args.random_output_len, + range_ratio=args.random_range_ratio, + trace_dataset_path=args.trace_dataset_path, + ), + "prefix_repetition": + lambda: PrefixRepetitionRandomDataset( + random_seed=args.seed, dataset_path=args.dataset_path + ).sample( + tokenizer=tokenizer, + num_requests=args.num_prompts, + prefix_len=args.prefix_repetition_prefix_len, + suffix_len=args.prefix_repetition_suffix_len, + num_prefixes=args.prefix_repetition_num_prefixes, + output_len=args.prefix_repetition_output_len, + ), + } + + try: + input_requests = dataset_mapping[args.dataset_name]() + except KeyError as err: + raise ValueError(f"Unknown dataset: {args.dataset_name}") from err + + return input_requests + + +# ----------------------------------------------------------------------------- +# Custom Dataset Implementation +# ----------------------------------------------------------------------------- + + +class CustomDataset(BenchmarkDataset): + """ + Implements the Custom dataset. Loads data from a JSONL file and generates + sample requests based on conversation turns. 
E.g., + ``` + {"prompt": "What is the capital of India?"} + {"prompt": "What is the capital of Iran?"} + {"prompt": "What is the capital of China?"} + ``` + """ + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + self.load_data() + + def load_data(self) -> None: + if self.dataset_path is None: + raise ValueError("dataset_path must be provided for loading data.") + + # self.data will be a list of dictionaries + # e.g., [{"prompt": "What is the capital of India?"}, ...] + # This will be the standardized format which load_data() + # has to convert into depending on the filetype of dataset_path. + # sample() will assume this standardized format of self.data + self.data = [] + + # Load the JSONL file + if self.dataset_path.endswith(".jsonl"): + jsonl_data = pd.read_json(path_or_buf=self.dataset_path, + lines=True) + + # check if the JSONL file has a 'prompt' column + if "prompt" not in jsonl_data.columns: + raise ValueError("JSONL file must contain a 'prompt' column.") + + # Convert each row to a dictionary and append to self.data + # This will convert the DataFrame to a list of dictionaries + # where each dictionary corresponds to a row in the DataFrame. 
+ # This is the standardized format we want for self.data + for _, row in jsonl_data.iterrows(): + self.data.append(row.to_dict()) + else: + raise NotImplementedError( + "Only JSONL format is supported for CustomDataset.") + + random.seed(self.random_seed) + random.shuffle(self.data) + + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + lora_path: Optional[str] = None, + max_loras: Optional[int] = None, + output_len: Optional[int] = None, + enable_multimodal_chat: bool = False, + skip_chat_template: bool = False, + **kwargs, + ) -> list: + sampled_requests = [] + for item in self.data: + if len(sampled_requests) >= num_requests: + break + prompt = item["prompt"] + + # apply template + if not skip_chat_template: + prompt = tokenizer.apply_chat_template( + [{ + "role": "user", + "content": prompt + }], + add_generation_prompt=True, + tokenize=False, + ) + + prompt_len = len(tokenizer(prompt).input_ids) + sampled_requests.append( + SampleRequest( + prompt=prompt, + prompt_len=prompt_len, + expected_output_len=output_len, + )) + self.maybe_oversample_requests(sampled_requests, num_requests) + + return sampled_requests + + +# ----------------------------------------------------------------------------- +# Sonnet Dataset Implementation +# ----------------------------------------------------------------------------- + +@deprecated( + "SonnetDataset is deprecated and will be removed in a future version.", +) +class SonnetDataset(BenchmarkDataset): + """ + Simplified implementation of the Sonnet dataset. Loads poem lines from a + text file and generates sample requests. Default values here copied from + `benchmark_serving.py` for the sonnet dataset. 
+ """ + + DEFAULT_PREFIX_LEN = 200 + DEFAULT_INPUT_LEN = 550 + DEFAULT_OUTPUT_LEN = 150 + + def __init__( + self, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.load_data() + + def load_data(self) -> None: + if not self.dataset_path: + raise ValueError("dataset_path must be provided.") + with open(self.dataset_path, encoding="utf-8") as f: + self.data = f.readlines() + + def sample( + self, + tokenizer, + num_requests: int, + prefix_len: int = DEFAULT_PREFIX_LEN, + input_len: int = DEFAULT_INPUT_LEN, + output_len: int = DEFAULT_OUTPUT_LEN, + return_prompt_formatted: bool = False, + **kwargs, + ) -> list: + # Calculate average token length for a poem line. + tokenized_lines = [tokenizer(line).input_ids for line in self.data] + avg_len = sum(len(tokens) + for tokens in tokenized_lines) / len(tokenized_lines) + + # Build the base prompt. + base_prompt = "Pick as many lines as you can from these poem lines:\n" + base_msg = [{"role": "user", "content": base_prompt}] + base_fmt = tokenizer.apply_chat_template(base_msg, + add_generation_prompt=True, + tokenize=False) + base_offset = len(tokenizer(base_fmt).input_ids) + if input_len <= base_offset: + raise ValueError( + f"'input_len' must be higher than the base prompt length " + f"({base_offset}).") + + # Determine how many poem lines to use. 
class BurstGPTDataset(BenchmarkDataset):
    """
    BurstGPT trace dataset.

    Reads a CSV file, keeps the rows whose "Model" column equals "GPT-4" and
    whose "Response tokens" value is positive, and turns sampled rows into
    requests with synthetically generated prompts.
    """

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
        self.load_data()

    def load_data(self, ):
        """Read the CSV at ``dataset_path`` and keep the filtered DataFrame.

        Raises:
            ValueError: If ``dataset_path`` is None.
        """
        if self.dataset_path is None:
            raise ValueError("dataset_path must be provided for loading data.")

        trace = pd.read_csv(self.dataset_path)
        # Keep GPT-4 rows only.
        gpt4_rows = trace[trace["Model"] == "GPT-4"]
        # Drop failed requests (non-positive response token count).
        succeeded = gpt4_rows["Response tokens"] > 0
        self.data = gpt4_rows[succeeded]

    def _sample_loaded_data(self, num_requests: int) -> list:
        """Draw ``num_requests`` rows and return them as a list of lists.

        Rows are drawn with replacement only when more rows are requested
        than the filtered frame holds.
        """
        if num_requests > len(self.data):
            picked = self.data.sample(
                n=num_requests,
                random_state=self.random_seed,
                replace=True,
            )
        else:
            picked = self.data.sample(n=num_requests,
                                      random_state=self.random_seed)
        return picked.values.tolist()

    def sample(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        max_loras: Optional[int] = None,
        lora_path: Optional[str] = None,
        **kwargs,
    ) -> list[SampleRequest]:
        """Build ``num_requests`` requests with synthetic prompts.

        Each prompt is the decoding of the token ids
        ``(i + j) % vocab_size`` for ``j`` in ``range(input_len)``, where
        ``i`` is the request index.
        """
        rows = self._sample_loaded_data(num_requests=num_requests)
        requests = []
        for i in range(num_requests):
            row = rows[i]
            # NOTE(review): positional columns 2 and 3 are presumably the
            # request/response token counts -- confirm against the CSV
            # schema.
            input_len = int(row[2])
            output_len = int(row[3])
            # The helper also returns the tokenizer to use from here on.
            lora_req, tokenizer = self.get_random_lora_request(
                tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path)
            vocab_size = tokenizer.vocab_size
            token_ids = [(i + j) % vocab_size for j in range(input_len)]
            requests.append(
                SampleRequest(
                    prompt=tokenizer.decode(token_ids),
                    prompt_len=input_len,
                    expected_output_len=output_len,
                    lora_request=lora_req,
                ))
        return requests
class ConversationDataset(HuggingFaceDataset):
    """Conversation-style dataset; items may carry an optional image."""
    SUPPORTED_DATASET_PATHS = {
        'lmms-lab/LLaVA-OneVision-Data', 'Aeala/ShareGPT_Vicuna_unfiltered'
    }
    IS_MULTIMODAL = True

    def sample(self,
               tokenizer: PreTrainedTokenizerBase,
               num_requests: int,
               output_len: Optional[int] = None,
               enable_multimodal_chat: bool = False,
               **kwargs) -> list:
        """Turn the first two turns of each conversation into a request.

        The first turn is the prompt and the second the reference
        completion. When ``output_len`` is None the tokenized completion
        length is used as the expected output length, and items rejected
        by ``is_valid_sequence`` are skipped.
        """
        use_reference_len = output_len is None
        # Only conversations with at least two turns are usable.
        candidates = self.data.filter(
            lambda x: len(x["conversations"]) >= 2)

        collected = []
        for record in candidates:
            if len(collected) >= num_requests:
                break
            turns = record["conversations"]
            prompt = turns[0]["value"]
            completion = turns[1]["value"]

            prompt_len = len(tokenizer(prompt).input_ids)
            completion_len = len(tokenizer(completion).input_ids)
            if use_reference_len:
                output_len = completion_len
            assert isinstance(output_len, int) and output_len > 0
            if use_reference_len and not is_valid_sequence(
                    prompt_len, completion_len):
                continue

            if "image" in record:
                mm_content = process_image(record["image"])
            else:
                mm_content = None

            if enable_multimodal_chat:
                # Note: with chat enabled the recorded prompt_len is no
                # longer accurate; the request output is used to count the
                # actual prompt and output lengths.
                prompt = self.apply_multimodal_chat_transformation(
                    prompt, mm_content)

            collected.append(
                SampleRequest(
                    prompt=prompt,
                    prompt_len=prompt_len,
                    expected_output_len=output_len,
                    multi_modal_data=mm_content,
                ))
        self.maybe_oversample_requests(collected, num_requests)
        return collected
class MTBenchDataset(HuggingFaceDataset):
    """
    MT-Bench Dataset.
    https://huggingface.co/datasets/philschmid/mt-bench

    Only the first turn of each item is used, giving a single-turn dataset.
    This is similar to the Spec decoding benchmark setup in vLLM
    https://github.com/vllm-project/vllm/blob/9d98ab5ec/examples/offline_inference/eagle.py#L14-L18
    """  # noqa: E501

    DEFAULT_OUTPUT_LEN = 256  # avg len used in SD bench in vLLM
    SUPPORTED_DATASET_PATHS = {
        "philschmid/mt-bench",
    }

    def sample(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        output_len: Optional[int] = None,
        enable_multimodal_chat: bool = False,
        **kwargs,
    ) -> list:
        """Build up to ``num_requests`` chat-templated single-turn requests.

        ``output_len`` falls back to ``DEFAULT_OUTPUT_LEN`` when None.
        """
        if output_len is None:
            output_len = self.DEFAULT_OUTPUT_LEN

        collected = []
        for record in self.data:
            if len(collected) >= num_requests:
                break

            # Wrap the first turn as a single user message and render it
            # through the tokenizer's chat template.
            user_message = {"role": "user", "content": record["turns"][0]}
            templated = tokenizer.apply_chat_template(
                [user_message],
                add_generation_prompt=True,
                tokenize=False,
            )

            collected.append(
                SampleRequest(
                    prompt=templated,
                    prompt_len=len(tokenizer(templated).input_ids),
                    expected_output_len=output_len,
                ))
        self.maybe_oversample_requests(collected, num_requests)
        return collected
+ """ + SUPPORTED_DATASET_PATHS = { + "AI-MO/aimo-validation-aime", "AI-MO/NuminaMath-1.5", + "AI-MO/NuminaMath-CoT" + } + + def sample(self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + output_len: Optional[int] = None, + **kwargs) -> list: + sampled_requests = [] + dynamic_output = output_len is None + + for item in self.data: + if len(sampled_requests) >= num_requests: + break + prompt, completion = item['problem'], item["solution"] + + prompt_ids = tokenizer(prompt).input_ids + completion_ids = tokenizer(completion).input_ids + prompt_len = len(prompt_ids) + completion_len = len(completion_ids) + output_len = completion_len if dynamic_output else output_len + assert isinstance(output_len, int) and output_len > 0 + if dynamic_output and not is_valid_sequence(prompt_len, + completion_len, + max_prompt_len=2048, + max_total_len=32000): + continue + sampled_requests.append( + SampleRequest( + prompt=prompt, + prompt_len=prompt_len, + expected_output_len=output_len, + multi_modal_data=None, + )) + self.maybe_oversample_requests(sampled_requests, num_requests) + return sampled_requests + + +# ----------------------------------------------------------------------------- +# Next Edit Prediction Dataset Implementation +# ----------------------------------------------------------------------------- + + +zeta_prompt = """### Instruction: +You are a code completion assistant and your task is to analyze user edits and then rewrite an excerpt that the user provides, suggesting the appropriate edits within the excerpt, taking into account the cursor location. + +### User Edits: + +{} + +### User Excerpt: + +{} + +### Response: + +""" # noqa: E501 + + +def _format_zeta_prompt( + sample: dict, + original_start_marker: str = "<|editable_region_start|>") -> dict: + """Format the zeta prompt for the Next Edit Prediction (NEP) dataset. + + This function formats examples from the NEP dataset + into prompts and expected outputs. 
It could be + further extended to support more NEP datasets. + + Args: + sample: The dataset sample containing events, + inputs, and outputs. + original_start_marker: The marker indicating the + start of the editable region. Defaults to + "<|editable_region_start|>". + + Returns: + A dictionary with the formatted prompts and expected outputs. + """ + events = sample["events"] + input = sample["input"] + output = sample["output"] + prompt = zeta_prompt.format(events, input) + + # following the original implementation, extract the focused region + # from the raw output + output_start_index = output.find(original_start_marker) + output_focused_region = output[output_start_index:] + expected_output = output_focused_region + + return {"prompt": prompt, "expected_output": expected_output} + + +class NextEditPredictionDataset(HuggingFaceDataset): + """ + Dataset class for processing a Next Edit Prediction dataset. + """ + + SUPPORTED_DATASET_PATHS = { + "zed-industries/zeta", + } + MAPPING_PROMPT_FUNCS = { + "zed-industries/zeta": _format_zeta_prompt, + } + + def sample(self, tokenizer: PreTrainedTokenizerBase, num_requests: int, + **kwargs): + formatting_prompt_func = self.MAPPING_PROMPT_FUNCS.get( + self.dataset_path) + if formatting_prompt_func is None: + raise ValueError(f"Unsupported dataset path: {self.dataset_path}") + samples = [] + for sample in self.data: + sample = formatting_prompt_func(sample) + samples.append( + SampleRequest( + prompt=sample["prompt"], + prompt_len=len(tokenizer(sample["prompt"]).input_ids), + expected_output_len=len( + tokenizer(sample["expected_output"]).input_ids), + )) + if len(samples) >= num_requests: + break + self.maybe_oversample_requests(samples, num_requests) + return samples + + +# ----------------------------------------------------------------------------- +# ASR Dataset Implementation +# ----------------------------------------------------------------------------- + + +class ASRDataset(HuggingFaceDataset): + """ + Dataset 
class for processing a ASR dataset for transcription. + Tested on the following set: + + +----------------+----------------------------------------+--------------------------+-----------------------------+ + | Dataset | Domain | Speaking Style | hf-subset | + +----------------+----------------------------------------+--------------------------+-----------------------------+ + | TED-LIUM | TED talks | Oratory | release1, release2, release3| + | | | | release3-speaker-adaptation | + | VoxPopuli | European Parliament | Oratory | en, de, it, fr, ... | + | LibriSpeech | Audiobook | Narrated | "LIUM/tedlium" | + | GigaSpeech | Audiobook, podcast, YouTube | Narrated, spontaneous | xs, s, m, l, xl, dev, test | + | SPGISpeech | Financial meetings | Oratory, spontaneous | S, M, L, dev, test | + | AMI | Meetings | Spontaneous | ihm, sdm | + +----------------+----------------------------------------+--------------------------+-----------------------------+ + + """ # noqa: E501 + + SUPPORTED_DATASET_PATHS = { + "openslr/librispeech_asr", + "facebook/voxpopuli", + "LIUM/tedlium", + "edinburghcstr/ami", + "speechcolab/gigaspeech", + "kensho/spgispeech", + } + + DEFAULT_OUTPUT_LEN = 128 + IS_MULTIMODAL = True + + # TODO Whisper-specific. Abstract interface when more models are supported. 
class MLPerfDataset(HuggingFaceDataset):
    """
    MLPerf Inference Dataset.

    Dataset on HF:
    https://huggingface.co/datasets/mgoin/mlperf-inference-llama2-data
    https://huggingface.co/datasets/mgoin/mlperf-inference-llama3.1-data

    Each record is read through three keys:
      - "system_prompt": system role instruction.
      - "question": user question.
      - "output": reference answer.

    The system prompt and question are rendered into one prompt with the
    tokenizer's chat template. Unless an explicit ``output_len`` is given,
    the expected output length is the tokenized length of the reference
    answer.
    """

    SUPPORTED_DATASET_PATHS = {
        "mgoin/mlperf-inference-llama2-data",
        "mgoin/mlperf-inference-llama3.1-data",
    }

    def sample(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        output_len: Optional[int] = None,
        **kwargs,
    ) -> list[SampleRequest]:
        """Build up to ``num_requests`` chat-templated requests.

        Records rejected by ``is_valid_sequence`` are skipped.
        """
        # The reference answer length is used only when the caller did not
        # supply an explicit output length.
        use_reference_len = output_len is None
        collected: list[SampleRequest] = []

        for record in self.data:
            if len(collected) >= num_requests:
                break

            system_prompt = record["system_prompt"]
            question = record["question"]
            reference_answer = record["output"]

            # Render a system + user exchange through the chat template.
            chat = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": question},
            ]
            templated = tokenizer.apply_chat_template(
                chat, add_generation_prompt=True, tokenize=False
            )
            prompt_len = len(tokenizer(templated).input_ids)

            # The reference answer is tokenized without special tokens.
            reference_ids = tokenizer(
                reference_answer, add_special_tokens=False
            ).input_ids
            if use_reference_len:
                expected_output_len = len(reference_ids)
            else:
                expected_output_len = output_len

            if not is_valid_sequence(prompt_len, expected_output_len):
                continue

            collected.append(
                SampleRequest(
                    prompt=templated,
                    prompt_len=prompt_len,
                    expected_output_len=expected_output_len,
                )
            )

        self.maybe_oversample_requests(collected, num_requests)
        return collected
len(re_encoded) + extra_tokens = _generate_exact_length_tokens(needed) + return re_encoded + extra_tokens + else: + # Truncate to target length + return re_encoded[:target_length] + + requests = [] + for _ in range(num_prefixes): + prefix_tokens = _generate_exact_length_tokens(prefix_len) + + for _ in range(prompts_per_prefix): + suffix_tokens = _generate_exact_length_tokens(suffix_len) + + combined_tokens = prefix_tokens + suffix_tokens + prompt = tokenizer.decode(combined_tokens) + prompt_len = len(combined_tokens) + requests.append( + SampleRequest( + prompt=prompt, + prompt_len=prompt_len, + expected_output_len=output_len, + ) + ) + + random.shuffle(requests) + return requests diff --git a/benchmark/reproducibility/patches_v0.10.1/serve.py b/benchmark/reproducibility/patches_v0.10.1/serve.py new file mode 100644 index 000000000..928dd27b0 --- /dev/null +++ b/benchmark/reproducibility/patches_v0.10.1/serve.py @@ -0,0 +1,1199 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +r"""Benchmark online serving throughput. 
+ +On the server side, run one of the following commands +to launch the vLLM OpenAI API server: + vllm serve + +On the client side, run: + vllm bench serve \ + --endpoint-type \ + --label \ + --model \ + --dataset-name \ + --request-rate \ + --num-prompts +""" +import argparse +import asyncio +import gc +import json +import os +import random +import time +import warnings +from collections.abc import AsyncGenerator, Iterable +from dataclasses import dataclass +from datetime import datetime +from typing import Any, Literal, Optional + +import aiohttp +import numpy as np +from tqdm.asyncio import tqdm +from transformers import PreTrainedTokenizerBase + +from vllm.benchmarks.datasets import (SampleRequest, add_dataset_parser, + get_samples) +from vllm.benchmarks.lib.endpoint_request_func import ( + ASYNC_REQUEST_FUNCS, OPENAI_COMPATIBLE_BACKENDS, RequestFuncInput, + RequestFuncOutput) +from vllm.benchmarks.lib.ready_checker import wait_for_endpoint +from vllm.benchmarks.lib.utils import (convert_to_pytorch_benchmark_format, + write_to_json) +from vllm.transformers_utils.tokenizer import get_tokenizer + +MILLISECONDS_TO_SECONDS_CONVERSION = 1000 + + +@dataclass +class BenchmarkMetrics: + completed: int + total_input: int + total_output: int + request_throughput: float + request_goodput: float + output_throughput: float + total_token_throughput: float + mean_ttft_ms: float + median_ttft_ms: float + std_ttft_ms: float + percentiles_ttft_ms: list[tuple[float, float]] + mean_tpot_ms: float + median_tpot_ms: float + std_tpot_ms: float + percentiles_tpot_ms: list[tuple[float, float]] + mean_itl_ms: float + median_itl_ms: float + std_itl_ms: float + percentiles_itl_ms: list[tuple[float, float]] + # E2EL stands for end-to-end latency per request. + # It is the time taken on the client side from sending + # a request to receiving a complete response. 
+ mean_e2el_ms: float + median_e2el_ms: float + std_e2el_ms: float + percentiles_e2el_ms: list[tuple[float, float]] + + +def _get_current_request_rate( + ramp_up_strategy: Optional[Literal["linear", "exponential"]], + ramp_up_start_rps: Optional[int], + ramp_up_end_rps: Optional[int], + request_index: int, + total_requests: int, + request_rate: float, +) -> float: + if (ramp_up_strategy and ramp_up_start_rps is not None + and ramp_up_end_rps is not None): + progress = request_index / max(total_requests - 1, 1) + if ramp_up_strategy == "linear": + increase = (ramp_up_end_rps - ramp_up_start_rps) * progress + return ramp_up_start_rps + increase + elif ramp_up_strategy == "exponential": + ratio = ramp_up_end_rps / ramp_up_start_rps + return ramp_up_start_rps * (ratio**progress) + else: + raise ValueError(f"Unknown ramp-up strategy: {ramp_up_strategy}") + return request_rate + + +async def get_request_trace(input_requests: list[SampleRequest]) -> AsyncGenerator[tuple[SampleRequest, float], None]: + + print(f"get_request_trace") + + prev_ts = 0 + for i in range(len(input_requests)): + request = input_requests[i] + curr_ts = request.trace_timestamp + delay = curr_ts - prev_ts + print(f"request {i} prompt_len {request.prompt_len} expected_output_len {request.expected_output_len} timestamp {request.trace_timestamp} delay {delay} ms") + if delay > 0: + await asyncio.sleep(delay/1000) + prev_ts = curr_ts + yield request, 1 + + +async def get_request( + input_requests: list[SampleRequest], + request_rate: float, + burstiness: float = 1.0, + ramp_up_strategy: Optional[Literal["linear", "exponential"]] = None, + ramp_up_start_rps: Optional[int] = None, + ramp_up_end_rps: Optional[int] = None, +) -> AsyncGenerator[tuple[SampleRequest, float], None]: + """ + Asynchronously generates requests at a specified rate + with OPTIONAL burstiness and OPTIONAL ramp-up strategy. + + Args: + input_requests: + A list of input requests, each represented as a SampleRequest. 
+ request_rate: + The rate at which requests are generated (requests/s). + burstiness (optional): + The burstiness factor of the request generation. + Only takes effect when request_rate is not inf. + Default value is 1, which follows a Poisson process. + Otherwise, the request intervals follow a gamma distribution. + A lower burstiness value (0 < burstiness < 1) results + in more bursty requests, while a higher burstiness value + (burstiness > 1) results in a more uniform arrival of requests. + ramp_up_strategy (optional): + The ramp-up strategy. Can be "linear" or "exponential". + If None, uses constant request rate (specified by request_rate). + ramp_up_start_rps (optional): + The starting request rate for ramp-up. + ramp_up_end_rps (optional): + The ending request rate for ramp-up. + """ + assert burstiness > 0, ( + f"A positive burstiness factor is expected, but given {burstiness}.") + # Convert to list to get length for ramp-up calculations + if isinstance(input_requests, Iterable) and not isinstance( + input_requests, list): + input_requests = list(input_requests) + + total_requests = len(input_requests) + assert total_requests > 0, "No requests provided." + + # Precompute delays among requests to minimize request send laggings + request_rates = [] + delay_ts = [] + for request_index, request in enumerate(input_requests): + current_request_rate = _get_current_request_rate(ramp_up_strategy, + ramp_up_start_rps, + ramp_up_end_rps, + request_index, + total_requests, + request_rate) + request_rates.append(current_request_rate) + if current_request_rate == float("inf"): + delay_ts.append(0) + else: + theta = 1.0 / (current_request_rate * burstiness) + + # Sample the request interval from the gamma distribution. + # If burstiness is 1, it follows exponential distribution. + delay_ts.append(np.random.gamma(shape=burstiness, scale=theta)) + + # Calculate the cumulative delay time from the first sent out requests. 
def _emit_request_trace(trace_rows, trace_output_path):
    """Write one trace record per request to a CSV file or to stdout.

    Args:
        trace_rows: One ``(index, timestamp, input_len, output_len, ttft_s,
            tpot_s, e2el_s)`` tuple per request, in request order. The three
            latencies are in seconds; failed requests carry NaN latencies.
        trace_output_path: Destination CSV file. When ``None`` the records
            are printed to stdout instead.
    """
    if trace_output_path is not None:
        import csv
        header = ["request", "timestamp", "input_len", "output_len",
                  "ttft_ms", "tpot_ms", "e2el_ms"]
        with open(trace_output_path, 'w', newline='') as f:
            writer = csv.writer(f)
            writer.writerow(header)
            for idx, timestamp, in_len, out_len, ttft, tpot, e2el in trace_rows:
                writer.writerow([
                    idx,
                    timestamp,
                    in_len,
                    out_len,
                    round(ttft * 1000, 2),
                    round(tpot * 1000, 2),
                    round(e2el * 1000, 2),
                ])
        print(f"Metrics saved to {trace_output_path}")
    else:
        print("request, timestamp, input_len, output_len, "
              "ttft_ms, tpot_ms, e2el_ms")
        for idx, timestamp, in_len, out_len, ttft, tpot, e2el in trace_rows:
            print(f"{idx} {timestamp} {in_len} {out_len} "
                  f"{ttft*1000:.2f} {tpot*1000:.2f} {e2el*1000:.2f}")


def calculate_metrics(
    input_requests: list[SampleRequest],
    outputs: list[RequestFuncOutput],
    dur_s: float,
    tokenizer: PreTrainedTokenizerBase,
    selected_percentiles: list[float],
    goodput_config_dict: dict[str, float],
    trace_output_path: str | None = None,
) -> tuple[BenchmarkMetrics, list[int]]:
    """Calculate the metrics for the benchmark.

    Besides the aggregate metrics, a per-request trace (index, timestamp,
    input/output length, TTFT, TPOT, end-to-end latency) is written to
    ``trace_output_path`` as CSV, or printed to stdout when no path is given.

    Args:
        input_requests: The input requests.
        outputs: The outputs of the requests; ``outputs[i]`` is the result
            of ``input_requests[i]``.
        dur_s: The duration of the benchmark.
        tokenizer: The tokenizer to use.
        selected_percentiles: The percentiles to select.
        goodput_config_dict: The goodput configuration.
        trace_output_path: Optional CSV path for the per-request trace.

    Returns:
        A tuple of the benchmark metrics and the actual output lengths.
    """
    actual_output_lens: list[int] = []
    total_input = 0
    completed = 0
    good_completed = 0
    itls: list[float] = []
    tpots: list[float] = []
    all_tpots: list[float] = []
    ttfts: list[float] = []
    e2els: list[float] = []
    # One row per request (successful or not). The aggregate lists above only
    # hold successful requests, so they must NOT be indexed by request index.
    trace_rows: list[tuple] = []
    nan = float("nan")
    for i, output in enumerate(outputs):
        request = input_requests[i]
        timestamp = getattr(request, "trace_timestamp", None)
        if not output.success:
            actual_output_lens.append(0)
            trace_rows.append(
                (i, timestamp, request.prompt_len, 0, nan, nan, nan))
            continue

        output_len = output.output_tokens
        if not output_len:
            # We use the tokenizer to count the number of output tokens
            # for some serving backends instead of looking at
            # len(outputs[i].itl) since multiple output tokens may be
            # bundled together
            # Note : this may inflate the output token count slightly
            output_len = len(
                tokenizer(output.generated_text,
                          add_special_tokens=False).input_ids)
        actual_output_lens.append(output_len)
        total_input += request.prompt_len
        tpot = 0
        if output_len > 1:
            latency_minus_ttft = output.latency - output.ttft
            tpot = latency_minus_ttft / (output_len - 1)
            tpots.append(tpot)
        # Note: if output_len <= 1, we regard tpot as 0 for goodput
        all_tpots.append(tpot)
        itls += output.itl
        ttfts.append(output.ttft)
        e2els.append(output.latency)
        completed += 1
        trace_rows.append((i, timestamp, request.prompt_len, output_len,
                           output.ttft, tpot, output.latency))

    _emit_request_trace(trace_rows, trace_output_path)

    if goodput_config_dict:
        valid_metrics = []
        slo_values = []

        if "ttft" in goodput_config_dict:
            valid_metrics.append(ttfts)
            slo_values.append(goodput_config_dict["ttft"] /
                              MILLISECONDS_TO_SECONDS_CONVERSION)
        if "tpot" in goodput_config_dict:
            valid_metrics.append(all_tpots)
            slo_values.append(goodput_config_dict["tpot"] /
                              MILLISECONDS_TO_SECONDS_CONVERSION)
        if "e2el" in goodput_config_dict:
            valid_metrics.append(e2els)
            slo_values.append(goodput_config_dict["e2el"] /
                              MILLISECONDS_TO_SECONDS_CONVERSION)

        # A request is "good" when every selected metric meets its SLO.
        for req_metric in zip(*valid_metrics):
            if all(s >= r for s, r in zip(slo_values, req_metric)):
                good_completed += 1

    if completed == 0:
        warnings.warn(
            "All requests failed. This is likely due to a misconfiguration "
            "on the benchmark arguments.",
            stacklevel=2)

    total_output = sum(actual_output_lens)
    metrics = BenchmarkMetrics(
        completed=completed,
        total_input=total_input,
        total_output=total_output,
        request_throughput=completed / dur_s,
        request_goodput=good_completed / dur_s,
        output_throughput=total_output / dur_s,
        total_token_throughput=(total_input + total_output) / dur_s,
        mean_ttft_ms=np.mean(ttfts or 0) *
        1000,  # ttfts is empty if streaming is not supported by the endpoint
        std_ttft_ms=np.std(ttfts or 0) * 1000,
        median_ttft_ms=np.median(ttfts or 0) * 1000,
        percentiles_ttft_ms=[(p, np.percentile(ttfts or 0, p) * 1000)
                             for p in selected_percentiles],
        mean_tpot_ms=np.mean(tpots or 0) * 1000,
        std_tpot_ms=np.std(tpots or 0) * 1000,
        median_tpot_ms=np.median(tpots or 0) * 1000,
        percentiles_tpot_ms=[(p, np.percentile(tpots or 0, p) * 1000)
                             for p in selected_percentiles],
        mean_itl_ms=np.mean(itls or 0) * 1000,
        std_itl_ms=np.std(itls or 0) * 1000,
        median_itl_ms=np.median(itls or 0) * 1000,
        percentiles_itl_ms=[(p, np.percentile(itls or 0, p) * 1000)
                            for p in selected_percentiles],
        mean_e2el_ms=np.mean(e2els or 0) * 1000,
        std_e2el_ms=np.std(e2els or 0) * 1000,
        median_e2el_ms=np.median(e2els or 0) * 1000,
        percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000)
                             for p in selected_percentiles],
    )

    return metrics, actual_output_lens
async def benchmark(
    endpoint_type: str,
    api_url: str,
    base_url: str,
    model_id: str,
    model_name: str,
    tokenizer: PreTrainedTokenizerBase,
    input_requests: list[SampleRequest],
    logprobs: Optional[int],
    request_rate: float,
    burstiness: float,
    disable_tqdm: bool,
    profile: bool,
    selected_percentile_metrics: list[str],
    selected_percentiles: list[float],
    ignore_eos: bool,
    goodput_config_dict: dict[str, float],
    max_concurrency: Optional[int],
    lora_modules: Optional[Iterable[str]],
    extra_body: Optional[dict],
    ramp_up_strategy: Optional[Literal["linear", "exponential"]] = None,
    ramp_up_start_rps: Optional[int] = None,
    ramp_up_end_rps: Optional[int] = None,
    ready_check_timeout_sec: int = 600,
    trace_dataset_path: str | None = None,
    trace_output_path: str | None = None,
):
    """Send all requests to the endpoint and collect benchmark statistics.

    The function first issues a single readiness request, then dispatches
    every entry of ``input_requests`` as an asyncio task, paced either by
    ``get_request_trace`` (when ``trace_dataset_path`` is set) or by
    ``get_request`` (request rate / burstiness / ramp-up), and finally
    prints and returns the aggregated metrics.

    Args:
        endpoint_type: Key into ``ASYNC_REQUEST_FUNCS`` selecting the
            request function.
        api_url: URL that benchmark requests are sent to.
        base_url: Server base URL; ``/start_profile`` and ``/stop_profile``
            are appended to it when ``profile`` is set.
        model_id: Model identifier sent with each request.
        model_name: Served model name sent with each request.
        tokenizer: Tokenizer forwarded to ``calculate_metrics``.
        input_requests: Requests to send, in order. Must be non-empty.
        logprobs: Number of logprobs to request, or ``None``.
        request_rate: Target request rate forwarded to ``get_request``.
        burstiness: Burstiness factor forwarded to ``get_request``.
        disable_tqdm: Disable the progress bar when true.
        profile: Start/stop the server-side profiler around the run.
        selected_percentile_metrics: Metric names ("ttft", "tpot", "itl",
            "e2el") to print and include in the result.
        selected_percentiles: Percentiles to report for those metrics.
        ignore_eos: Forwarded to each request.
        goodput_config_dict: SLOs forwarded to ``calculate_metrics``.
        max_concurrency: Upper bound on in-flight requests, or ``None``.
        lora_modules: LoRA module names; one is chosen at random per request.
        extra_body: Extra request fields forwarded with each request.
        ramp_up_strategy: Ramp-up strategy forwarded to ``get_request``.
        ramp_up_start_rps: Initial RPS for ramp-up.
        ramp_up_end_rps: Final RPS for ramp-up.
        ready_check_timeout_sec: Timeout for the readiness request.
        trace_dataset_path: When not ``None``, requests are paced by
            ``get_request_trace`` instead of ``get_request``.
        trace_output_path: Forwarded to ``calculate_metrics``.

    Returns:
        A dict of aggregate and per-request benchmark results.

    Raises:
        ValueError: If ``endpoint_type`` is unknown, ``input_requests`` is
            empty, or the initial test request fails.
    """
    if endpoint_type in ASYNC_REQUEST_FUNCS:
        request_func = ASYNC_REQUEST_FUNCS[endpoint_type]
    else:
        raise ValueError(f"Unknown endpoint_type: {endpoint_type}")

    if not input_requests:
        # The readiness probe below is built from the first request.
        raise ValueError("No input requests to benchmark.")

    # Reuses connections across requests to reduce TLS handshake overhead.
    connector = aiohttp.TCPConnector(
        limit=max_concurrency or 0,
        limit_per_host=max_concurrency or 0,
        ttl_dns_cache=300,
        use_dns_cache=True,
        keepalive_timeout=60,
        enable_cleanup_closed=True,
        force_close=False,
        ssl=("https://" in api_url),
    )

    # The context manager closes the session (and its connector) on every
    # exit path, including a failed readiness check or a raised exception.
    async with aiohttp.ClientSession(
            connector=connector,
            trust_env=True,
            timeout=aiohttp.ClientTimeout(total=6 * 60 * 60),
    ) as session:
        print("Starting initial single prompt test run...")
        test_prompt, test_prompt_len, test_output_len, test_mm_content = (
            input_requests[0].prompt,
            input_requests[0].prompt_len,
            input_requests[0].expected_output_len,
            input_requests[0].multi_modal_data,
        )

        assert (
            test_mm_content is None
            or isinstance(test_mm_content, dict)
            or (
                isinstance(test_mm_content, list)
                and all(isinstance(item, dict) for item in test_mm_content)
            )
        ), "multi_modal_data must be a dict or list[dict]"
        test_input = RequestFuncInput(
            model=model_id,
            model_name=model_name,
            prompt=test_prompt,
            api_url=api_url,
            prompt_len=test_prompt_len,
            output_len=test_output_len,
            logprobs=logprobs,
            multi_modal_content=test_mm_content,
            ignore_eos=ignore_eos,
            extra_body=extra_body,
        )

        test_output = await wait_for_endpoint(
            request_func,
            test_input,
            session,
            timeout_seconds=ready_check_timeout_sec,
        )
        if not test_output.success:
            raise ValueError(
                "Initial test run failed - Please make sure benchmark "
                "arguments are correctly specified. "
                f"Error: {test_output.error}")
        print("Initial test run completed. Starting main benchmark run...")

        lora_module_iter = None
        if lora_modules:
            # For each input request, choose a LoRA module at random.
            lora_module_iter = iter([
                random.choice(lora_modules)
                for _ in range(len(input_requests))
            ])

        if profile:
            print("Starting profiler...")
            profile_input = RequestFuncInput(
                model=model_id,
                model_name=model_name,
                prompt=test_prompt,
                api_url=base_url + "/start_profile",
                prompt_len=test_prompt_len,
                output_len=test_output_len,
                logprobs=logprobs,
                multi_modal_content=test_mm_content,
                ignore_eos=ignore_eos,
                extra_body=extra_body)
            profile_output = await request_func(
                request_func_input=profile_input, session=session)
            if profile_output.success:
                print("Profiler started")

        distribution = ("Poisson process"
                        if burstiness == 1.0 else "Gamma distribution")

        if ramp_up_strategy is not None:
            print(f"Traffic ramp-up strategy: {ramp_up_strategy}.")
            print(f"Will increase RPS from {ramp_up_start_rps} to "
                  f"{ramp_up_end_rps} RPS over the duration of the benchmark.")
        else:
            print(f"Traffic request rate: {request_rate}")

        print(f"Burstiness factor: {burstiness} ({distribution})")
        print(f"Maximum request concurrency: {max_concurrency}")

        pbar = None if disable_tqdm else tqdm(total=len(input_requests))

        # This can be used once the minimum Python version is 3.10 or higher,
        # and it will simplify the code in limited_request_func.
        #    semaphore = (asyncio.Semaphore(max_concurrency)
        #                 if max_concurrency else contextlib.nullcontext())
        semaphore = (asyncio.Semaphore(max_concurrency)
                     if max_concurrency else None)

        async def limited_request_func(request_func_input, session, pbar):
            if semaphore is None:
                return await request_func(
                    request_func_input=request_func_input,
                    session=session,
                    pbar=pbar)
            async with semaphore:
                return await request_func(
                    request_func_input=request_func_input,
                    session=session,
                    pbar=pbar)

        if trace_dataset_path is not None:
            iterator = get_request_trace(input_requests)
        else:
            iterator = get_request(
                input_requests,
                request_rate,
                burstiness,
                ramp_up_strategy,
                ramp_up_start_rps,
                ramp_up_end_rps,
            )

        benchmark_start_time = time.perf_counter()
        tasks: list[asyncio.Task] = []

        rps_change_events = []
        last_int_rps = -1
        if ramp_up_strategy is not None and ramp_up_start_rps is not None:
            last_int_rps = ramp_up_start_rps
            rps_change_events.append({
                "rps": last_int_rps,
                "timestamp": datetime.now().isoformat(),
            })

        async for request, current_request_rate in iterator:
            if ramp_up_strategy is not None:
                current_int_rps = int(current_request_rate)
                if current_int_rps > last_int_rps:
                    # Record one event per integer RPS step crossed.
                    timestamp = datetime.now().isoformat()
                    for rps_val in range(last_int_rps + 1,
                                         current_int_rps + 1):
                        rps_change_events.append({
                            "rps": rps_val,
                            "timestamp": timestamp
                        })
                    last_int_rps = current_int_rps
            prompt, prompt_len, output_len, mm_content = (
                request.prompt,
                request.prompt_len,
                request.expected_output_len,
                request.multi_modal_data,
            )
            req_model_id, req_model_name = model_id, model_name
            if lora_module_iter is not None:
                req_lora_module = next(lora_module_iter)
                req_model_id, req_model_name = req_lora_module, req_lora_module

            request_func_input = RequestFuncInput(
                model=req_model_id,
                model_name=req_model_name,
                prompt=prompt,
                api_url=api_url,
                prompt_len=prompt_len,
                output_len=output_len,
                logprobs=logprobs,
                multi_modal_content=mm_content,
                ignore_eos=ignore_eos,
                extra_body=extra_body)
            tasks.append(
                asyncio.create_task(
                    limited_request_func(
                        request_func_input=request_func_input,
                        session=session,
                        pbar=pbar)))
        outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks)

        if pbar is not None:
            pbar.close()

        benchmark_duration = time.perf_counter() - benchmark_start_time

        metrics, actual_output_lens = calculate_metrics(
            input_requests=input_requests,
            outputs=outputs,
            dur_s=benchmark_duration,
            tokenizer=tokenizer,
            selected_percentiles=selected_percentiles,
            goodput_config_dict=goodput_config_dict,
            trace_output_path=trace_output_path,
        )

        print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ',
                                   n=50,
                                   c='='))
        print("{:<40} {:<10}".format("Successful requests:",
                                     metrics.completed))
        if max_concurrency is not None:
            print("{:<40} {:<10}".format("Maximum request concurrency:",
                                         max_concurrency))
        if request_rate != float('inf'):
            print("{:<40} {:<10.2f}".format("Request rate configured (RPS):",
                                            request_rate))
        print("{:<40} {:<10.2f}".format("Benchmark duration (s):",
                                        benchmark_duration))
        print("{:<40} {:<10}".format("Total input tokens:",
                                     metrics.total_input))
        print("{:<40} {:<10}".format("Total generated tokens:",
                                     metrics.total_output))
        print("{:<40} {:<10.2f}".format("Request throughput (req/s):",
                                        metrics.request_throughput))
        if goodput_config_dict:
            print("{:<40} {:<10.2f}".format("Request goodput (req/s):",
                                            metrics.request_goodput))
        print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):",
                                        metrics.output_throughput))
        print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):",
                                        metrics.total_token_throughput))

        result = {
            "duration": benchmark_duration,
            "completed": metrics.completed,
            "total_input_tokens": metrics.total_input,
            "total_output_tokens": metrics.total_output,
            "request_throughput": metrics.request_throughput,
            "request_goodput":
            metrics.request_goodput if goodput_config_dict else None,
            "output_throughput": metrics.output_throughput,
            "total_token_throughput": metrics.total_token_throughput,
            "input_lens": [output.prompt_len for output in outputs],
            "output_lens": actual_output_lens,
            "ttfts": [output.ttft for output in outputs],
            "itls": [output.itl for output in outputs],
            "generated_texts": [output.generated_text for output in outputs],
            "errors": [output.error for output in outputs],
        }

        if rps_change_events:
            result["rps_change_events"] = rps_change_events

        def process_one_metric(
            # E.g., "ttft"
            metric_attribute_name: str,
            # E.g., "TTFT"
            metric_name: str,
            # E.g., "Time to First Token"
            metric_header: str,
        ):
            # This function prints and adds statistics of the specified
            # metric.
            if metric_attribute_name not in selected_percentile_metrics:
                return
            print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-'))
            print("{:<40} {:<10.2f}".format(
                f"Mean {metric_name} (ms):",
                getattr(metrics, f"mean_{metric_attribute_name}_ms")))
            print("{:<40} {:<10.2f}".format(
                f"Median {metric_name} (ms):",
                getattr(metrics, f"median_{metric_attribute_name}_ms")))
            result[f"mean_{metric_attribute_name}_ms"] = getattr(
                metrics, f"mean_{metric_attribute_name}_ms")
            result[f"median_{metric_attribute_name}_ms"] = getattr(
                metrics, f"median_{metric_attribute_name}_ms")
            result[f"std_{metric_attribute_name}_ms"] = getattr(
                metrics, f"std_{metric_attribute_name}_ms")
            for p, value in getattr(
                    metrics, f"percentiles_{metric_attribute_name}_ms"):
                p_word = str(int(p)) if int(p) == p else str(p)
                print("{:<40} {:<10.2f}".format(
                    f"P{p_word} {metric_name} (ms):", value))
                result[f"p{p_word}_{metric_attribute_name}_ms"] = value

        process_one_metric("ttft", "TTFT", "Time to First Token")
        process_one_metric("tpot", "TPOT",
                           "Time per Output Token (excl. 1st token)")
        process_one_metric("itl", "ITL", "Inter-token Latency")
        process_one_metric("e2el", "E2EL", "End-to-end Latency")

        print("=" * 50)

        if profile:
            print("Stopping profiler...")
            profile_input = RequestFuncInput(
                model=model_id,
                prompt=test_prompt,
                api_url=base_url + "/stop_profile",
                prompt_len=test_prompt_len,
                output_len=test_output_len,
                logprobs=logprobs,
            )
            profile_output = await request_func(
                request_func_input=profile_input, session=session)
            if profile_output.success:
                print("Profiler stopped")

        return result
def check_goodput_args(args):
    """Parse and validate the ``--goodput`` service level objectives.

    Args:
        args: Parsed CLI namespace; only ``args.goodput`` (a list of
            ``"KEY:VALUE"`` strings, or a falsy value) is read.

    Returns:
        A dict mapping metric name to its SLO value; empty when
        ``args.goodput`` is not set.

    Raises:
        ValueError: If a metric name is not one of ``ttft``/``tpot``/``e2el``
            or a value is negative or NaN.
        argparse.ArgumentTypeError: If a pair is not in ``KEY:VALUE`` form.
    """
    goodput_config_dict = {}
    VALID_NAMES = ["ttft", "tpot", "e2el"]
    if args.goodput:
        goodput_config_dict = parse_goodput(args.goodput)
        for slo_name, slo_val in goodput_config_dict.items():
            if slo_name not in VALID_NAMES:
                raise ValueError(
                    f"Invalid metric name found, {slo_name}: {slo_val}. "
                    "The service level objective name should be one of "
                    f"{str(VALID_NAMES)}. ")
            # `not (x >= 0)` rather than `x < 0`: NaN fails both comparisons,
            # and a NaN SLO would silently mark every request as not good.
            if not slo_val >= 0:
                raise ValueError(
                    f"Invalid value found, {slo_name}: {slo_val}. "
                    "The service level objective value should be "
                    "non-negative.")
    return goodput_config_dict


def parse_goodput(slo_pairs):
    """Convert ``"KEY:VALUE"`` strings into a ``{key: float(value)}`` dict.

    Surrounding whitespace around the key is ignored.

    Raises:
        argparse.ArgumentTypeError: If a pair does not contain exactly one
            ``:`` or the value is not a number.
    """
    goodput_config_dict = {}
    try:
        for slo_pair in slo_pairs:
            slo_name, slo_val = slo_pair.split(":")
            goodput_config_dict[slo_name.strip()] = float(slo_val)
    except ValueError as err:
        raise argparse.ArgumentTypeError(
            "Invalid format found for service level objectives. "
            "Specify service level objectives for goodput as \"KEY:VALUE\" "
            "pairs, where the key is a metric name, and the value is a "
            "number in milliseconds.") from err
    return goodput_config_dict


def save_to_pytorch_benchmark_format(args: argparse.Namespace,
                                     results: dict[str, Any],
                                     file_name: str) -> None:
    """Write the results in PyTorch benchmark format next to ``file_name``.

    The listed latency metrics become the records' metrics; every other
    result key, except the bulky raw per-request fields, goes into
    ``extra_info``. Nothing is written when the converter returns no records.
    """
    metrics = [
        "median_ttft_ms", "mean_ttft_ms", "std_ttft_ms", "p99_ttft_ms",
        "mean_tpot_ms", "median_tpot_ms", "std_tpot_ms", "p99_tpot_ms",
        "median_itl_ms", "mean_itl_ms", "std_itl_ms", "p99_itl_ms"
    ]
    # These raw data might be useful, but they are rather big. They can be added
    # later if needed
    ignored_metrics = ["ttfts", "itls", "generated_texts", "errors"]
    pt_records = convert_to_pytorch_benchmark_format(
        args=args,
        metrics={k: [results[k]]
                 for k in metrics if k in results},
        extra_info={
            k: results[k]
            for k in results if k not in metrics and k not in ignored_metrics
        })
    if pt_records:
        # Written alongside the regular result file, with the extension
        # replaced by ".pytorch.json".
        pt_file = f"{os.path.splitext(file_name)[0]}.pytorch.json"
        write_to_json(pt_file, pt_records)
def add_cli_args(parser: argparse.ArgumentParser):
    """Register every serving-benchmark command-line option on ``parser``.

    Dataset options are added first via ``add_dataset_parser``; the
    endpoint, traffic-shaping, result-saving, metric, sampling, LoRA,
    ramp-up and trace options follow.
    """
    add_dataset_parser(parser)
    parser.add_argument(
        "--endpoint-type",
        type=str,
        default="openai",
        choices=list(ASYNC_REQUEST_FUNCS.keys()),
    )
    parser.add_argument(
        "--label",
        type=str,
        default=None,
        help="The label (prefix) of the benchmark results. If not specified, "
        "the endpoint type will be used as the label.",
    )
    parser.add_argument(
        "--backend",
        type=str,
        default="vllm",
        choices=list(ASYNC_REQUEST_FUNCS.keys()),
    )
    parser.add_argument(
        "--base-url",
        type=str,
        default=None,
        help="Server or API base url if not using http host and port.",
    )
    # Use 127.0.0.1 here instead of localhost to force the use of ipv4
    parser.add_argument("--host", type=str, default="127.0.0.1")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument(
        "--endpoint",
        type=str,
        default="/v1/completions",
        help="API endpoint.",
    )
    parser.add_argument(
        "--max-concurrency",
        type=int,
        default=None,
        help="Maximum number of concurrent requests. This can be used "
        "to help simulate an environment where a higher level component "
        "is enforcing a maximum number of concurrent requests. While the "
        "--request-rate argument controls the rate at which requests are "
        "initiated, this argument will control how many are actually allowed "
        "to execute at a time. This means that when used in combination, the "
        "actual request rate may be lower than specified with --request-rate, "
        "if the server is not processing requests fast enough to keep up.")

    parser.add_argument(
        "--model",
        type=str,
        required=True,
        help="Name of the model.",
    )
    parser.add_argument(
        "--tokenizer",
        type=str,
        help=
        "Name or path of the tokenizer, if not using the default tokenizer.",  # noqa: E501
    )
    parser.add_argument("--use-beam-search", action="store_true")
    parser.add_argument(
        "--logprobs",
        type=int,
        default=None,
        help=("Number of logprobs-per-token to compute & return as part of "
              "the request. If unspecified, then either (1) if beam search "
              "is disabled, no logprobs are computed & a single dummy "
              "logprob is returned for each token; or (2) if beam search "
              "is enabled 1 logprob per token is computed"),
    )
    parser.add_argument(
        "--request-rate",
        type=float,
        default=float("inf"),
        help="Number of requests per second. If this is inf, "
        "then all the requests are sent at time 0. "
        "Otherwise, we use Poisson process or gamma distribution "
        "to synthesize the request arrival times.",
    )
    parser.add_argument(
        "--burstiness",
        type=float,
        default=1.0,
        help="Burstiness factor of the request generation. "
        "Only take effect when request_rate is not inf. "
        "Default value is 1, which follows Poisson process. "
        "Otherwise, the request intervals follow a gamma distribution. "
        "A lower burstiness value (0 < burstiness < 1) results in more "
        "bursty requests. A higher burstiness value (burstiness > 1) "
        "results in a more uniform arrival of requests.",
    )
    parser.add_argument(
        "--trust-remote-code",
        action="store_true",
        help="Trust remote code from huggingface",
    )
    parser.add_argument(
        "--disable-tqdm",
        action="store_true",
        help="Specify to disable tqdm progress bar.",
    )
    parser.add_argument(
        "--profile",
        action="store_true",
        help="Use Torch Profiler. The endpoint must be launched with "
        "VLLM_TORCH_PROFILER_DIR to enable profiler.",
    )
    parser.add_argument(
        "--save-result",
        action="store_true",
        help="Specify to save benchmark results to a json file",
    )
    parser.add_argument(
        "--save-detailed",
        action="store_true",
        help="When saving the results, whether to include per request "
        "information such as response, error, ttfts, tpots, etc.",
    )
    parser.add_argument(
        "--append-result",
        action="store_true",
        help="Append the benchmark result to the existing json file.",
    )
    parser.add_argument(
        "--metadata",
        metavar="KEY=VALUE",
        nargs="*",
        help="Key-value pairs (e.g, --metadata version=0.3.3 tp=1) "
        "for metadata of this run to be saved in the result JSON file "
        "for record keeping purposes.",
    )
    parser.add_argument(
        "--result-dir",
        type=str,
        default=None,
        help="Specify directory to save benchmark json results. "
        "If not specified, results are saved in the current directory.",
    )
    parser.add_argument(
        "--result-filename",
        type=str,
        default=None,
        help="Specify the filename to save benchmark json results. "
        "If not specified, results will be saved in "
        "{label}-{args.request_rate}qps-{base_model_id}-{current_dt}.json"  # noqa
        " format.",
    )
    parser.add_argument(
        "--ignore-eos",
        action="store_true",
        help="Set ignore_eos flag when sending the benchmark request. "
        "Warning: ignore_eos is not supported in deepspeed_mii and tgi.")
    parser.add_argument(
        "--percentile-metrics",
        type=str,
        default="ttft,tpot,itl",
        help="Comma-separated list of selected metrics to report percentiles. "
        "This argument specifies the metrics to report percentiles. "
        "Allowed metric names are \"ttft\", \"tpot\", \"itl\", \"e2el\". ")
    parser.add_argument(
        "--metric-percentiles",
        type=str,
        default="99",
        help="Comma-separated list of percentiles for selected metrics. "
        "To report 25-th, 50-th, and 75-th percentiles, use \"25,50,75\". "
        "Default value is \"99\". "
        "Use \"--percentile-metrics\" to select metrics.",
    )
    parser.add_argument(
        "--goodput",
        nargs="+",
        required=False,
        help="Specify service level objectives for goodput as \"KEY:VALUE\" "
        "pairs, where the key is a metric name, and the value is in "
        "milliseconds. Multiple \"KEY:VALUE\" pairs can be provided, "
        "separated by spaces. Allowed request level metric names are "
        "\"ttft\", \"tpot\", \"e2el\". For more context on the definition of "
        "goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 "
        "and the blog: https://hao-ai-lab.github.io/blogs/distserve",
    )

    sampling_group = parser.add_argument_group("sampling parameters")
    sampling_group.add_argument(
        "--top-p",
        type=float,
        default=None,
        help="Top-p sampling parameter. Only has effect on "
        "openai-compatible backends.",
    )
    sampling_group.add_argument(
        "--top-k",
        type=int,
        default=None,
        help="Top-k sampling parameter. Only has effect on "
        "openai-compatible backends.",
    )
    sampling_group.add_argument(
        "--min-p",
        type=float,
        default=None,
        help="Min-p sampling parameter. Only has effect on "
        "openai-compatible backends.",
    )
    sampling_group.add_argument(
        "--temperature",
        type=float,
        default=None,
        help="Temperature sampling parameter. Only has effect on "
        "openai-compatible backends. If not specified, default to greedy "
        "decoding (i.e. temperature==0.0).",
    )

    parser.add_argument(
        '--tokenizer-mode',
        type=str,
        default="auto",
        choices=['auto', 'slow', 'mistral', 'custom'],
        help='The tokenizer mode.\n\n* "auto" will use the '
        'fast tokenizer if available.\n* "slow" will '
        'always use the slow tokenizer. \n* '
        '"mistral" will always use the `mistral_common` tokenizer. \n* '
        '"custom" will use --tokenizer to select the preregistered tokenizer.')

    parser.add_argument("--served-model-name",
                        type=str,
                        default=None,
                        help="The model name used in the API. "
                        "If not specified, the model name will be the "
                        "same as the ``--model`` argument. ")

    parser.add_argument("--lora-modules",
                        nargs='+',
                        default=None,
                        help="A subset of LoRA module names passed in when "
                        "launching the server. For each request, the "
                        "script chooses a LoRA module at random.")

    parser.add_argument(
        "--ramp-up-strategy",
        type=str,
        default=None,
        choices=["linear", "exponential"],
        help="The ramp-up strategy. This would be used to "
        "ramp up the request rate from initial RPS to final "
        "RPS rate (specified by --ramp-up-start-rps and "
        "--ramp-up-end-rps) over the duration of the benchmark.")
    parser.add_argument(
        "--ramp-up-start-rps",
        type=int,
        default=None,
        help="The starting request rate for ramp-up (RPS). "
        "Needs to be specified when --ramp-up-strategy is used.",
    )
    parser.add_argument(
        "--ramp-up-end-rps",
        type=int,
        default=None,
        help="The ending request rate for ramp-up (RPS). "
        "Needs to be specified when --ramp-up-strategy is used.",
    )
    parser.add_argument(
        "--ready-check-timeout-sec",
        type=int,
        default=600,
        help="Maximum time to wait for the endpoint to become ready "
        "in seconds (default: 600 seconds / 10 minutes).",
    )
    parser.add_argument(
        "--trace-dataset-path",
        type=str,
        default=None,
        help="When set, requests are paced by the trace-based request "
        "iterator instead of --request-rate / --burstiness.",
    )
    parser.add_argument(
        "--trace-output-path",
        type=str,
        default=None,
        help="CSV file to write per-request metrics (timestamp, input/output "
        "length, TTFT, TPOT, E2EL) to. If not specified, they are printed "
        "to stdout.",
    )
def main(args: argparse.Namespace) -> dict[str, Any]:
    """Synchronous entry point: run ``main_async`` on a new event loop."""
    return asyncio.run(main_async(args))


async def main_async(args: argparse.Namespace) -> dict[str, Any]:
    """Validate the CLI arguments, run the benchmark and save the results.

    Args:
        args: Parsed command-line namespace (see ``add_cli_args``).

    Returns:
        The result dict: run configuration merged with the benchmark
        result. Bulky per-request fields are removed unless
        ``args.save_detailed`` is set.

    Raises:
        ValueError: On inconsistent ramp-up arguments, a missing dataset
            name, sampling parameters for a backend that is not
            OpenAI-compatible, or malformed ``--metadata`` items.
    """
    print(args)
    random.seed(args.seed)
    np.random.seed(args.seed)

    # Validate ramp-up arguments
    if args.ramp_up_strategy is not None:
        if args.request_rate != float("inf"):
            raise ValueError(
                "When using ramp-up, do not specify --request-rate. "
                "The request rate will be controlled by ramp-up parameters. "
                "Please remove the --request-rate argument.")
        if args.ramp_up_start_rps is None or args.ramp_up_end_rps is None:
            raise ValueError(
                "When using --ramp-up-strategy, both --ramp-up-start-rps and "
                "--ramp-up-end-rps must be specified")
        if args.ramp_up_start_rps < 0 or args.ramp_up_end_rps < 0:
            raise ValueError("Ramp-up start and end RPS must be non-negative")
        if args.ramp_up_start_rps > args.ramp_up_end_rps:
            # Equal start and end rates are accepted.
            raise ValueError(
                "Ramp-up start RPS must be less than or equal to end RPS")
        if (args.ramp_up_strategy == "exponential"
                and args.ramp_up_start_rps == 0):
            raise ValueError(
                "For exponential ramp-up, the start RPS cannot be 0.")

    endpoint_type = args.endpoint_type
    label = args.label
    model_id = args.model
    model_name = args.served_model_name
    tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model
    tokenizer_mode = args.tokenizer_mode

    if args.base_url is not None:
        api_url = f"{args.base_url}{args.endpoint}"
        base_url = f"{args.base_url}"
    else:
        api_url = f"http://{args.host}:{args.port}{args.endpoint}"
        base_url = f"http://{args.host}:{args.port}"

    tokenizer = get_tokenizer(tokenizer_id,
                              tokenizer_mode=tokenizer_mode,
                              trust_remote_code=args.trust_remote_code)

    if args.dataset_name is None:
        raise ValueError(
            "Please specify '--dataset-name' and the corresponding "
            "'--dataset-path' if required.")

    # Load the dataset.
    input_requests = get_samples(args, tokenizer)
    goodput_config_dict = check_goodput_args(args)

    # Collect the sampling parameters.
    sampling_params = {
        k: v
        for k, v in {
            "top_p": args.top_p,
            "top_k": args.top_k,
            "min_p": args.min_p,
            "temperature": args.temperature,
        }.items() if v is not None
    }

    # Sampling parameters are only supported by openai-compatible backend.
    if sampling_params and args.backend not in OPENAI_COMPATIBLE_BACKENDS:
        raise ValueError("Sampling parameters are only supported by "
                         "openai-compatible backends.")

    if "temperature" not in sampling_params:
        sampling_params["temperature"] = 0.0  # Default to greedy decoding.

    # Avoid GC processing "static" data - reduce pause times.
    gc.collect()
    gc.freeze()

    benchmark_result = await benchmark(
        endpoint_type=args.endpoint_type,
        api_url=api_url,
        base_url=base_url,
        model_id=model_id,
        model_name=model_name,
        tokenizer=tokenizer,
        input_requests=input_requests,
        logprobs=args.logprobs,
        request_rate=args.request_rate,
        burstiness=args.burstiness,
        disable_tqdm=args.disable_tqdm,
        profile=args.profile,
        selected_percentile_metrics=args.percentile_metrics.split(","),
        selected_percentiles=[
            float(p) for p in args.metric_percentiles.split(",")
        ],
        ignore_eos=args.ignore_eos,
        goodput_config_dict=goodput_config_dict,
        max_concurrency=args.max_concurrency,
        lora_modules=args.lora_modules,
        extra_body=sampling_params,
        ramp_up_strategy=args.ramp_up_strategy,
        ramp_up_start_rps=args.ramp_up_start_rps,
        ramp_up_end_rps=args.ramp_up_end_rps,
        ready_check_timeout_sec=args.ready_check_timeout_sec,
        trace_dataset_path=args.trace_dataset_path,
        trace_output_path=args.trace_output_path,
    )

    # Save config and results to json
    result_json: dict[str, Any] = {}

    # Setup
    current_dt = datetime.now().strftime("%Y%m%d-%H%M%S")
    result_json["date"] = current_dt
    result_json["endpoint_type"] = args.endpoint_type
    result_json["label"] = label
    result_json["model_id"] = model_id
    result_json["tokenizer_id"] = tokenizer_id
    result_json["num_prompts"] = args.num_prompts

    # Metadata
    if args.metadata:
        for item in args.metadata:
            if "=" not in item:
                raise ValueError(
                    "Invalid metadata format. Please use KEY=VALUE format.")
            # Split on the first "=" only so values may themselves contain
            # "=" (e.g. "cmd=--tp=2").
            key, value = item.split("=", 1)
            result_json[key.strip()] = value.strip()

    # Traffic
    result_json["request_rate"] = (args.request_rate if args.request_rate
                                   < float("inf") else "inf")
    result_json["burstiness"] = args.burstiness
    result_json["max_concurrency"] = args.max_concurrency

    if args.ramp_up_strategy is not None:
        result_json["ramp_up_strategy"] = args.ramp_up_strategy
        result_json["ramp_up_start_rps"] = args.ramp_up_start_rps
        result_json["ramp_up_end_rps"] = args.ramp_up_end_rps

    # Merge with benchmark result
    result_json = {**result_json, **benchmark_result}

    if not args.save_detailed:
        # Remove fields with too many data points
        for field in [
                "input_lens",
                "output_lens",
                "ttfts",
                "itls",
                "generated_texts",
                "errors",
        ]:
            result_json.pop(field, None)
            benchmark_result.pop(field, None)

    # Save to file
    if args.save_result or args.append_result:
        base_model_id = model_id.split("/")[-1]
        max_concurrency_str = (f"-concurrency{args.max_concurrency}"
                               if args.max_concurrency is not None else "")
        label = label or endpoint_type
        if args.ramp_up_strategy is not None:
            file_name = f"{label}-ramp-up-{args.ramp_up_strategy}-{args.ramp_up_start_rps}qps-{args.ramp_up_end_rps}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json"  # noqa
        else:
            file_name = f"{label}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json"  # noqa
        if args.result_filename:
            file_name = args.result_filename
        if args.result_dir:
            os.makedirs(args.result_dir, exist_ok=True)
            file_name = os.path.join(args.result_dir, file_name)
        with open(file_name,
                  mode="a+" if args.append_result else "w",
                  encoding="utf-8") as outfile:
            # Append a newline.
            if args.append_result and outfile.tell() != 0:
                outfile.write("\n")
            json.dump(result_json, outfile)
        save_to_pytorch_benchmark_format(args, result_json, file_name)

    return result_json
import pandas as pd
import matplotlib.pyplot as plt


def _save_series_plot(series, heading, out_png, marker=None, log_y=False):
    """Plot one per-request series on a fresh figure and save it as a PNG.

    The x axis is the series index (labelled "Request Id") and the y axis is
    labelled "Milliseconds". ``marker`` is passed to ``plt.plot`` as the
    format string when given; ``log_y`` switches the y axis to log scale.
    """
    plt.figure()
    if marker is None:
        plt.plot(series)
    else:
        plt.plot(series, marker)
    if log_y:
        plt.yscale('log')
    plt.xlabel("Request Id")
    plt.ylabel("Milliseconds")
    plt.title(heading)
    plt.savefig(out_png)


# modify the path to your CSV file as needed
df = pd.read_csv("code_output.csv")

# Pull out every column of the trace CSV; only the three latency columns
# are plotted below.
request = df["request"]
timestamp = df["timestamp"]
input_length = df["input_len"]
output_length = df["output_len"]
ttft = df["ttft_ms"]
tpot = df["tpot_ms"]
latency = df["e2el_ms"]

# TTFT: line plot
_save_series_plot(ttft, "TTFT", "ttft.png")

# TPOT: point markers on a log-scaled y axis
_save_series_plot(tpot, "TPOT", "tpot.png", marker='o', log_y=True)

# Latency: point markers
_save_series_plot(latency, "Completion Time", "latency.png", marker='o')
def parse_timestamp(timestamp_str):
    """Parse a ``YYYY-MM-DD HH:MM:SS[.fraction]`` string into a datetime.

    Fractional seconds longer than 6 digits are truncated to microseconds,
    because ``%f`` accepts at most 6 digits.
    """
    if '.' in timestamp_str:
        parts = timestamp_str.split('.')
        if len(parts[1]) > 6:
            timestamp_str = parts[0] + '.' + parts[1][:6]
    try:
        return datetime.strptime(timestamp_str, "%Y-%m-%d %H:%M:%S.%f")
    except ValueError:
        # Try without microseconds if the format doesn't match
        return datetime.strptime(timestamp_str, "%Y-%m-%d %H:%M:%S")


def convert_to_relative_ms(timestamps):
    """Convert timestamp strings to milliseconds relative to the first entry.

    Returns an empty list for empty input.
    """
    if not timestamps:
        return []

    first_time = parse_timestamp(timestamps[0])
    return [(parse_timestamp(ts) - first_time).total_seconds() * 1000
            for ts in timestamps]


def read_csv(file_path):
    """Read a CSV trace with TIMESTAMP/ContextTokens/GeneratedTokens columns.

    Returns a list of dicts with keys ``timestamp`` (milliseconds relative
    to the first row), ``input_length`` and ``output_length``.
    """
    data = []
    timestamps = []
    with open(file_path, 'r') as f:
        for row in csv.DictReader(f):
            timestamps.append(row.get("TIMESTAMP"))
            data.append({
                "timestamp": row.get("TIMESTAMP"),
                "input_length": int(row.get("ContextTokens")),
                "output_length": int(row.get("GeneratedTokens")),
            })

    # Convert timestamps to relative milliseconds
    for entry, rel_ms in zip(data, convert_to_relative_ms(timestamps)):
        entry["timestamp"] = rel_ms

    return data


def read_jsonl(file_path):
    """Read a JSONL trace, keeping timestamp/input_length/output_length.

    Blank lines (e.g. a trailing newline at the end of the file) are skipped.
    """
    data = []
    with open(file_path, 'r') as f:
        for line in f:
            if not line.strip():
                continue
            obj = json.loads(line)
            # Extract required columns
            data.append({
                "timestamp": obj.get("timestamp"),
                "input_length": obj.get("input_length"),
                "output_length": obj.get("output_length"),
            })
    return data


def compute_request_rate(timestamps, window_size):
    """Bucket request timestamps into fixed windows.

    Args:
        timestamps: Non-empty list of integer timestamps in milliseconds.
        window_size: Window length in milliseconds.

    Returns:
        ``(time_axis, request_rate)``: window start times in seconds, and
        requests per second for each window. Windows start at 0 and cover
        up to ``max(timestamps)``; negative timestamps fall in no window.
    """
    starts = range(0, max(timestamps) + window_size, window_size)
    counts = [0] * len(starts)
    # Single pass over the timestamps instead of rescanning them per window.
    for ts in timestamps:
        idx = ts // window_size
        if 0 <= idx < len(counts):
            counts[idx] += 1
    time_axis = [start / 1000 for start in starts]
    request_rate = [count / (window_size / 1000) for count in counts]
    return time_axis, request_rate


def main():
    """Load a trace, print batch statistics and show summary plots."""
    # Example usage:
    # file_path = "synthetic_trace.jsonl"
    # file_path = "toolagent_trace.jsonl"
    file_path = "conversation_trace.jsonl"
    records = read_jsonl(file_path)

    # file_path = "AzureLLMInferenceTrace_conv.csv"
    # file_path = "AzureLLMInferenceTrace_code.csv"
    # records = read_csv(file_path)

    # file_path = "AzureLLMInferenceTrace_code_1min_section.jsonl"
    # records = read_jsonl(file_path)

    # threshold_low = 2 * 60 * 1000
    # threshold_high = 17 * 60 * 1000
    threshold_low = 0 * 60 * 1000
    threshold_high = 15 * 60 * 1000
    records = [
        record for record in records
        if threshold_low <= record["timestamp"] <= threshold_high
    ]

    if not records:
        # Everything below indexes the first record or takes a max().
        print(f"No records between {threshold_low} ms and "
              f"{threshold_high} ms in {file_path}")
        return

    num_prompts = len(records)
    timestamp = [int(record["timestamp"]) for record in records]
    first_ts = timestamp[0]
    timestamp = [ts - first_ts for ts in timestamp]  # Normalize to start from 0 ms
    input_length = [record["input_length"] for record in records]
    output_length = [record["output_length"] for record in records]

    # Calculate sum of input_length per batch (prompts with same timestamp)
    batch_sums = defaultdict(int)
    batch_counts = defaultdict(int)
    for ts, in_len in zip(timestamp, input_length):
        batch_sums[ts] += in_len
        batch_counts[ts] += 1

    # Print batch statistics
    print("\nBatch statistics (prompts with same timestamp):")
    print(f"Number of unique timestamps (batches): {len(batch_sums)}")
    for ts in sorted(batch_sums):
        print(f"Timestamp {ts} ms: {batch_counts[ts]} prompts, "
              f"total input_length = {batch_sums[ts]}")
    print(f"max input_length in a batch: {max(batch_sums.values())}\n")

    # # Write to jsonl file
    # output_file = file_path.replace('.jsonl', '_processed_15mins.jsonl').replace('.csv', '_processed_15mins.jsonl')
    # with open(output_file, 'w') as f:
    #     for req in range(num_prompts):
    #         entry = {
    #             "timestamp": timestamp[req],
    #             "input_length": input_length[req],
    #             "output_length": output_length[req]
    #         }
    #         f.write(json.dumps(entry) + '\n')
    print("timestamp input_length output_length")
    for ts, in_len, out_len in zip(timestamp, input_length, output_length):
        print(f"{ts} {in_len} {out_len}")

    # Imported here so the helpers above can be used without matplotlib.
    import matplotlib.pyplot as plt

    window_size = 80000  # in ms
    time_axis, request_rate = compute_request_rate(timestamp, window_size)

    plt.figure(figsize=(10, 6))
    plt.plot(time_axis, request_rate, linestyle='-')
    plt.xlabel('Time (s)')
    plt.ylabel('Request Rate (requests/second)')
    plt.title(f'Request Rate over Time {file_path}')
    plt.grid(True)
    plt.tight_layout()
    plt.show()

    plt.figure(figsize=(10, 6))
    plt.plot(range(num_prompts), output_length, linestyle='-')
    plt.xlabel('Request Index')
    plt.ylabel('Output Length')
    plt.title(f'Output Length over Requests {file_path}')
    plt.grid(True)
    plt.tight_layout()
    plt.show()

    plt.figure(figsize=(10, 6))
    plt.plot(range(num_prompts), timestamp, linestyle='-')
    plt.xlabel('Request Index')
    plt.ylabel('Timestamp')
    plt.title(f'Request Timestamps {file_path}')
    plt.grid(True)
    plt.tight_layout()
    plt.show()

    plt.figure(figsize=(10, 6))
    plt.scatter(output_length, input_length, alpha=0.5)
    plt.xlabel('Output Length')
    plt.ylabel('Input Length')
    plt.yscale('log')
    plt.title(f'Input Length vs Output Length {file_path}')
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

    print(f"Total number of prompts: {num_prompts} in {file_path}")


if __name__ == "__main__":
    main()
vllm serve $model --disable-log-requests --no-enable-prefix-caching --ulysses-sequence-parallel-size 8 --enable-shift-parallel --max-num-batched-tokens 131072 +# vllm serve $model --disable-log-requests --no-enable-prefix-caching --tensor-parallel-size 8 --max-num-batched-tokens 131072 +# vllm serve $model --disable-log-requests --no-enable-prefix-caching --data-parallel-size 8 --max-num-batched-tokens 131072 + +# model="Qwen/Qwen3-32B-FP8" +# ARCTIC_INFERENCE_ENABLED=1 vllm serve $model --disable-log-requests --no-enable-prefix-caching --ulysses-sequence-parallel-size 8 --enable-shift-parallel --max-num-batched-tokens 131072 --kv-cache-dtype fp8 +# vllm serve $model --disable-log-requests --no-enable-prefix-caching --tensor-parallel-size 8 --max-num-batched-tokens 131072 --kv-cache-dtype fp8 +# vllm serve $model --disable-log-requests --no-enable-prefix-caching --data-parallel-size 8 --max-num-batched-tokens 131072 --kv-cache-dtype fp8 diff --git a/benchmark/reproducibility/test.jsonl b/benchmark/reproducibility/test.jsonl new file mode 100644 index 000000000..e5566daf6 --- /dev/null +++ b/benchmark/reproducibility/test.jsonl @@ -0,0 +1,26 @@ +{"timestamp": 0, "input_length": 6758, "output_length": 500} +{"timestamp": 0, "input_length": 7322, "output_length": 490} +{"timestamp": 0, "input_length": 7236, "output_length": 794} +{"timestamp": 0, "input_length": 2290, "output_length": 316} +{"timestamp": 0, "input_length": 6760, "output_length": 3} +{"timestamp": 0, "input_length": 4834, "output_length": 173} +{"timestamp": 0, "input_length": 23141, "output_length": 453} +{"timestamp": 0, "input_length": 26888, "output_length": 458} +{"timestamp": 0, "input_length": 10498, "output_length": 402} +{"timestamp": 0, "input_length": 17450, "output_length": 610} +{"timestamp": 3000, "input_length": 13544, "output_length": 71} +{"timestamp": 3000, "input_length": 87169, "output_length": 402} +{"timestamp": 3000, "input_length": 6324, "output_length": 548} +{"timestamp": 
3000, "input_length": 2012, "output_length": 354} +{"timestamp": 3000, "input_length": 7324, "output_length": 14} +{"timestamp": 3000, "input_length": 9418, "output_length": 145} +{"timestamp": 3000, "input_length": 915, "output_length": 355} +{"timestamp": 3000, "input_length": 12846, "output_length": 466} +{"timestamp": 3000, "input_length": 20506, "output_length": 929} +{"timestamp": 3000, "input_length": 16609, "output_length": 349} +{"timestamp": 3000, "input_length": 26353, "output_length": 370} +{"timestamp": 3000, "input_length": 6059, "output_length": 475} +{"timestamp": 3000, "input_length": 5954, "output_length": 420} +{"timestamp": 3000, "input_length": 11339, "output_length": 848} +{"timestamp": 3000, "input_length": 15172, "output_length": 80} +{"timestamp": 3000, "input_length": 45922, "output_length": 265} diff --git a/benchmark/reproducibility/vibe_test.py b/benchmark/reproducibility/vibe_test.py new file mode 100644 index 000000000..46a15d459 --- /dev/null +++ b/benchmark/reproducibility/vibe_test.py @@ -0,0 +1,25 @@ +import vllm +from vllm import LLM, SamplingParams + +vllm.plugins.load_general_plugins() + +llm = LLM( + model="RedHatAI/Llama-3.3-70B-Instruct-FP8-dynamic", + ulysses_sequence_parallel_size=8, + enable_shift_parallel=True, + shift_parallel_threshold=8, +) + +conversation = [ + { + "role": "user", + "content": "Write an essay about the importance of higher education.", + }, +] + +sampling_params = SamplingParams(temperature=0.0, max_tokens=800) + +outputs = llm.chat(conversation, sampling_params=sampling_params) + +print(outputs[0].outputs[0].text) + diff --git a/benchmark/rollout/README.md b/benchmark/rollout/README.md new file mode 100644 index 000000000..96fc91d47 --- /dev/null +++ b/benchmark/rollout/README.md @@ -0,0 +1,51 @@ +## Rollout Replay Patch for v0.14.1 + +This patch extends SamplingParams to specify the length of each sequence when n > 1. + +The patch is applied as `source patch_sampling.sh`. 
+ +As a result, you can specify `max_tokens_n` as a list in sampling params and set `ignore_eos` so that each sequence generates exactly the specified number of tokens. + +``` +# Sample prompts. +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + # "The future of AI is", +] + +sampling_params = [SamplingParams(n=2, + temperature=0.8, + top_p=1.0, + max_tokens_n=[25, 50], + ignore_eos=True, + ), + SamplingParams(n=3, + temperature=0.8, + top_p=1.0, + max_tokens_n=[5, 10, 15], + ignore_eos=True, + ), + SamplingParams(n=1, + temperature=0.8, + top_p=1.0, + max_tokens=100, + ignore_eos=True, + # max_tokens_n=[100], this will be ineffective since n = 1 + ), + ] + +outputs = llm.generate(prompts, sampling_params=sampling_params) +``` + +The number of resulting input and output tokens per sequence: +``` +prompt 0 seq 0: input 5 output 25 +prompt 0 seq 1: input 5 output 50 +prompt 1 seq 0: input 7 output 5 +prompt 1 seq 1: input 7 output 10 +prompt 1 seq 2: input 7 output 15 +prompt 2 seq 0: input 5 output 100 +``` + diff --git a/benchmark/rollout/parallel_sampling.patch b/benchmark/rollout/parallel_sampling.patch new file mode 100644 index 000000000..8c7bbb7a5 --- /dev/null +++ b/benchmark/rollout/parallel_sampling.patch @@ -0,0 +1,18 @@ +--- /home/yak/myenv_v14/lib/python3.12/site-packages/vllm/v1/engine/parallel_sampling.py 2026-01-31 22:11:45.340329657 +0000 ++++ parallel_sampling.py 2026-01-31 22:08:27.000000000 +0000 +@@ -66,11 +66,13 @@ + Child `sampling_params` instance. 
+ """ + seed = self.sampling_params.seed +- if self.cached_child_sampling_params: ++ # if self.cached_child_sampling_params: + # Reuse child sampling_params data structure +- return self.cached_child_sampling_params ++ # return self.cached_child_sampling_params + # Build child sampling_params + child_sampling_params = copy(self.sampling_params) ++ print(f"child index: {index}, max_tokens_n: {child_sampling_params.max_tokens_n}") ++ child_sampling_params.max_tokens = child_sampling_params.max_tokens_n[index] + child_sampling_params.n = 1 + if seed is None: + # Cache child sampling_params for later reuse diff --git a/benchmark/rollout/patch_sampling.sh b/benchmark/rollout/patch_sampling.sh new file mode 100755 index 000000000..531bed562 --- /dev/null +++ b/benchmark/rollout/patch_sampling.sh @@ -0,0 +1,12 @@ + +VLLM_PATH="$(pip show vllm | awk '/^Location: /{print $2}')" + +if [ -z "$VLLM_PATH" ]; then + echo "Error: could not find VLLM in current env" + exit 1 +else + echo "VLLM path is: $VLLM_PATH" +fi + +patch $VLLM_PATH/vllm/sampling_params.py < sampling_params.patch +patch $VLLM_PATH/vllm/v1/engine/parallel_sampling.py < parallel_sampling.patch diff --git a/benchmark/rollout/sampling_params.patch b/benchmark/rollout/sampling_params.patch new file mode 100644 index 000000000..7647b8b6d --- /dev/null +++ b/benchmark/rollout/sampling_params.patch @@ -0,0 +1,26 @@ +--- /home/yak/myenv_v14/lib/python3.12/site-packages/vllm/sampling_params.py 2026-01-31 22:11:45.180329435 +0000 ++++ sampling_params.py 2026-01-31 22:16:23.000000000 +0000 +@@ -124,6 +124,7 @@ + """ + + n: int = 1 ++ max_tokens_n: list[int] | None = None + """Number of outputs to return for the given prompt request. 
+ + NOTE: +@@ -250,6 +251,7 @@ + @staticmethod + def from_optional( + n: int | None = 1, ++ max_tokens_n: list[int] | None = None, + presence_penalty: float | None = 0.0, + frequency_penalty: float | None = 0.0, + repetition_penalty: float | None = 1.0, +@@ -289,6 +291,7 @@ + + return SamplingParams( + n=1 if n is None else n, ++ max_tokens_n=max_tokens_n, + presence_penalty=0.0 if presence_penalty is None else presence_penalty, + frequency_penalty=0.0 if frequency_penalty is None else frequency_penalty, + repetition_penalty=1.0 diff --git a/benchmark/trace/README.md b/benchmark/trace/README.md new file mode 100644 index 000000000..29b0b9c73 --- /dev/null +++ b/benchmark/trace/README.md @@ -0,0 +1,18 @@ +The patch allows running datasets with the following json format: +```jsonl +{"timestamp": 15, "input_length": 1000, "output_length": 128} +... +``` +where each line corresponds to a request with random prompt to be sent at the timestamp (in ms) from t = 0. + +The patch is applied as +``` +bash apply_vllm_bench_patch_v10p1.sh +``` +This will modify a couple of files of `vllm bench serve`. Please make sure vllm versions match. 
+ +Then use serve as +``` +vllm bench serve --model $model --trace-dataset-path example_trace.jsonl --ignore-eos +``` + diff --git a/benchmark/trace/apply_vllm_bench_patch_v10p1.sh b/benchmark/trace/apply_vllm_bench_patch_v10p1.sh new file mode 100644 index 000000000..397135015 --- /dev/null +++ b/benchmark/trace/apply_vllm_bench_patch_v10p1.sh @@ -0,0 +1,115 @@ +#!/usr/bin/env bash +set -euo pipefail + +SERVE_PY="/home/yak/myvenv/lib/python3.10/site-packages/vllm/benchmarks/serve.py" +DATASETS_PY="/home/yak/myvenv/lib/python3.10/site-packages/vllm/benchmarks/datasets.py" + +patch "$DATASETS_PY" << 'EOF' +79d78 +< trace_timestamp: int = 0 +338d336 +< trace_dataset_path = None, +353,373c351,367 +< print(f"RandomDataset sample {trace_dataset_path}") +< +< events = [] +< with open(trace_dataset_path, "r") as f: +< for line in f: +< if not line.strip(): +< continue +< obj = json.loads(line) +< timestamp = obj["timestamp"] +< input_length = obj["input_length"] +< output_length = obj["output_length"] +< print(f"read trace timestamp {timestamp} input_length {input_length} output_length {output_length}") +< events.append((timestamp, input_length, output_length)) +< # Ensure chronological order +< events.sort(key=lambda x: x[0]) +< print(f"events {events}") +< +< num_requests = len(events) +< timestamps = [i[0] for i in events] +< input_lens = [i[1] for i in events] +< output_lens = [i[2] for i in events] +--- +> # New sampling logic: [X * (1 - b), X * (1 + b)] +> input_low = int(real_input_len * (1 - range_ratio)) +> input_high = int(real_input_len * (1 + range_ratio)) +> output_low = int(output_len * (1 - range_ratio)) +> output_high = int(output_len * (1 + range_ratio)) +> +> # Add logging for debugging +> logger.info( +> "Sampling input_len from [%s, %s] and output_len from [%s, %s]", +> input_low, input_high, output_low, output_high) +> +> input_lens = np.random.randint(input_low, +> input_high + 1, +> size=num_requests) +> output_lens = np.random.randint(output_low, 
+> output_high + 1, +> size=num_requests) +400d393 +< trace_timestamp=timestamps[i] +765d757 +< trace_dataset_path=args.trace_dataset_path, +EOF + +patch "$SERVE_PY" << 'EOF' +101,116d100 +< async def get_request_trace(input_requests: list[SampleRequest]) -> AsyncGenerator[tuple[SampleRequest, float], None]: +< +< print(f"get_request_trace") +< +< prev_ts = 0 +< for i in range(len(input_requests)): +< request = input_requests[i] +< curr_ts = request.trace_timestamp +< delay = curr_ts - prev_ts +< print(f"request {i} prompt_len {request.prompt_len} expected_output_len {request.expected_output_len} timestamp {request.trace_timestamp} delay {delay} ms") +< if delay > 0: +< await asyncio.sleep(delay/1000) +< prev_ts = curr_ts +< yield request, 1 +< +< +267,270d250 +< print(f"requsest, timestamp, input_len, output_len, ttft_ms, tpot_ms, e2el_ms") +< for i in range(len(outputs)): +< print(f"{i} {input_requests[i].trace_timestamp} {input_requests[i].prompt_len} {actual_output_lens[i]} {ttfts[i]*1000:.2f} {all_tpots[i]*1000:.2f} {e2els[i]*1000:.2f}") +< +356d335 +< trace_dataset_path: str | None = None, +477,488d455 +< if trace_dataset_path is not None: +< iterator = get_request_trace(input_requests) +< else: +< iterator = get_request( +< input_requests, +< request_rate, +< burstiness, +< ramp_up_strategy, +< ramp_up_start_rps, +< ramp_up_end_rps, +< ) +< +501c468,470 +< async for request, current_request_rate in iterator: +--- +> async for request, current_request_rate in get_request( +> input_requests, request_rate, burstiness, ramp_up_strategy, +> ramp_up_start_rps, ramp_up_end_rps): +987d955 +< parser.add_argument("--trace-dataset-path", type=str, default=None) +1096d1063 +< trace_dataset_path=args.trace_dataset_path, +EOF + +echo "Patch applied successfully!" 
+ +echo "You can now use:" +echo "" +echo " vllm bench serve \\" +echo " --model RedHatAI/Meta-Llama-3.1-70B-Instruct-FP8 \\" +echo " --trace-dataset-path example_trace.jsonl" +echo "" diff --git a/benchmark/trace/example_trace.jsonl b/benchmark/trace/example_trace.jsonl new file mode 100644 index 000000000..7ffb0b673 --- /dev/null +++ b/benchmark/trace/example_trace.jsonl @@ -0,0 +1,6 @@ +{"timestamp": 0, "input_length": 6758, "output_length": 500} +{"timestamp": 0, "input_length": 7322, "output_length": 490} +{"timestamp": 0, "input_length": 7236, "output_length": 794} +{"timestamp": 0, "input_length": 2290, "output_length": 316} +{"timestamp": 0, "input_length": 6760, "output_length": 3} +{"timestamp": 15000, "input_length": 15, "output_length": 3} diff --git a/benchmark/trace/run_client.sh b/benchmark/trace/run_client.sh new file mode 100644 index 000000000..6512b735a --- /dev/null +++ b/benchmark/trace/run_client.sh @@ -0,0 +1,6 @@ + +model="/data-fast/RedHatAI/Llama-3.3-70B-Instruct-FP8-dynamic" +dataset=example_trace.jsonl + +vllm bench serve --model $model --trace-dataset-path $dataset --ignore-eos + diff --git a/csrc/custom_ops/CMakeLists.txt b/csrc/custom_ops/CMakeLists.txt index 353c895bb..fa101e1e5 100644 --- a/csrc/custom_ops/CMakeLists.txt +++ b/csrc/custom_ops/CMakeLists.txt @@ -16,7 +16,10 @@ find_package(pybind11 REQUIRED) find_package(Torch REQUIRED) pybind11_add_module(custom_ops - kernels.cu + reshape_and_cache_flash_fp4.cu + reshape_and_cache_flash_bulk.cu + speculator_ln.cu + sum_lstm.cu torch_bindings.cpp ) diff --git a/csrc/custom_ops/attention_generic.cuh b/csrc/custom_ops/attention_generic.cuh index 607a062fb..6605da3b6 100644 --- a/csrc/custom_ops/attention_generic.cuh +++ b/csrc/custom_ops/attention_generic.cuh @@ -23,32 +23,26 @@ namespace vllm { // A vector type to store Q, K, V elements. -template -struct Vec {}; +template struct Vec {}; // A vector type to store FP32 accumulators. 
-template -struct FloatVec {}; +template struct FloatVec {}; // Template vector operations. template inline __device__ Acc mul(A a, B b); -template -inline __device__ float sum(T v); +template inline __device__ float sum(T v); -template -inline __device__ float dot(T a, T b) { +template inline __device__ float dot(T a, T b) { return sum(mul(a, b)); } -template -inline __device__ float dot(T a, T b) { +template inline __device__ float dot(T a, T b) { return sum(mul(a, b)); } -template -inline __device__ void zero(T& dst) { +template inline __device__ void zero(T &dst) { constexpr int WORDS = sizeof(T) / 4; union { T raw; @@ -62,5 +56,4 @@ inline __device__ void zero(T& dst) { dst = tmp.raw; } -} // namespace vllm - +} // namespace vllm diff --git a/csrc/custom_ops/custom_ops.h b/csrc/custom_ops/custom_ops.h index 3edf44ece..336f59370 100644 --- a/csrc/custom_ops/custom_ops.h +++ b/csrc/custom_ops/custom_ops.h @@ -7,14 +7,35 @@ #include void reshape_and_cache_flash_bulk( - torch::Tensor& keys, - torch::Tensor& values, - std::vector const& key_caches, - std::vector const& value_caches, - torch::Tensor& slot_mapping, - const std::string& kv_cache_dtype, - std::vector const& k_scales, - std::vector const& v_scales, - int64_t num_heads, - int64_t head_size -); \ No newline at end of file + torch::Tensor &keys, torch::Tensor &values, + std::vector const &key_caches, + std::vector const &value_caches, torch::Tensor &slot_mapping, + const std::string &kv_cache_dtype, + std::vector const &k_scales, + std::vector const &v_scales, int64_t num_heads, + int64_t head_size); + +void reshape_and_cache_flash_fp4(torch::Tensor &key, torch::Tensor &value, + torch::Tensor &key_cache, + torch::Tensor &value_cache, + torch::Tensor &slot_mapping, + const std::string &kv_cache_dtype, + torch::Tensor &k_scale, torch::Tensor &v_scale, + torch::Tensor &key_scale_cache, + torch::Tensor &value_scale_cache); + +torch::Tensor speculator_ln_cuda(const torch::Tensor &input, + const c10::optional 
&weight, + const c10::optional &bias, + double eps); + +std::tuple sum_lstm_cuda( + const torch::Tensor& states_4d, // [..., 4D] + const torch::Tensor& z4_4d, // [..., 4D] (repeat along last dim) + const torch::Tensor& prev_cell_d, // [..., D] + const c10::optional& w_cell, + const c10::optional& b_cell, + const c10::optional& w_state, + const c10::optional& b_state, + double alpha, double eps_cell, double eps_state, + bool use_fast_gelu); \ No newline at end of file diff --git a/csrc/custom_ops/dispatch_utils.h b/csrc/custom_ops/dispatch_utils.h index 3361922f0..1a778131b 100644 --- a/csrc/custom_ops/dispatch_utils.h +++ b/csrc/custom_ops/dispatch_utils.h @@ -8,63 +8,62 @@ // Need a special dispatch case macro since we will nest the FP8 dispatch. // Instead of the usual 'scalar_t', this names the dispatched type 'fp8_t'. -#define AT_DISPATCH_FP8_CASE(enum_type, ...) \ +#define AT_DISPATCH_FP8_CASE(enum_type, ...) \ AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, fp8_t, __VA_ARGS__) -#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ - AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ +#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) -#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ +#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) // ROCm devices might use either fn or fnuz, so set up dispatch table for both. // A host-based check at runtime will create a preferred FP8 type for ROCm // such that the correct kernel is dispatched. #ifdef USE_ROCM - #define VLLM_DISPATCH_CASE_FP8_TYPES(...) 
\ - AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \ - AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e4m3fnuz, __VA_ARGS__) +#define VLLM_DISPATCH_CASE_FP8_TYPES(...) \ + AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \ + AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e4m3fnuz, __VA_ARGS__) - #define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \ - AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fnuz, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) +#define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fnuz, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) #else - #define VLLM_DISPATCH_CASE_FP8_TYPES(...) \ - AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) +#define VLLM_DISPATCH_CASE_FP8_TYPES(...) \ + AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) - #define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \ - AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) +#define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) #endif // When using this dispatch macro, the type is 'fp8_t' not 'scalar_t'. // See AT_DISPATCH_FP8_CASE above. -#define VLLM_DISPATCH_FP8_TYPES(TYPE, NAME, ...) \ +#define VLLM_DISPATCH_FP8_TYPES(TYPE, NAME, ...) \ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FP8_TYPES(__VA_ARGS__)) -#define VLLM_DISPATCH_QUANT_TYPES(TYPE, NAME, ...) \ +#define VLLM_DISPATCH_QUANT_TYPES(TYPE, NAME, ...) \ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_QUANT_TYPES(__VA_ARGS__)) -#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) 
\ - AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \ +#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) -#define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \ - AT_DISPATCH_SWITCH(TYPE, NAME, \ +#define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, \ VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__)) -#define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...) \ - AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \ +#define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) -#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \ +#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) 
\ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__)) - diff --git a/csrc/custom_ops/dtype_common.cuh b/csrc/custom_ops/dtype_common.cuh new file mode 100644 index 000000000..0ea648f56 --- /dev/null +++ b/csrc/custom_ops/dtype_common.cuh @@ -0,0 +1,72 @@ +#pragma once + +#include +#include +#include +#include +#include + +template struct IsSupported : std::false_type {}; +template <> struct IsSupported : std::true_type {}; +template <> struct IsSupported : std::true_type {}; + +__device__ __forceinline__ float to_float_device(__half h) { + return __half2float(h); +} + +__device__ __forceinline__ __half from_float_device(float f) { + return __float2half(f); +} + +__device__ __forceinline__ float to_float_device(__nv_bfloat16 h) { +#if !defined(__CUDA_NO_BF16__) + return __bfloat162float(h); +#else + uint16_t raw = *reinterpret_cast(&h); + uint32_t u32 = (uint32_t)raw << 16; + float out = __uint_as_float(u32); + return out; +#endif +} + +__device__ __forceinline__ __nv_bfloat16 from_float_device_bf16(float f) { +#if !defined(__CUDA_NO_BF16__) + return __float2bfloat16(f); +#else + uint32_t u = __float_as_uint(f); + uint16_t hi = (uint16_t)(u >> 16); + __nv_bfloat16 h; + *reinterpret_cast(&h) = hi; + return h; +#endif +} + +template struct DevHalf; + +template <> struct DevHalf { + using type = __half; + static __device__ __forceinline__ float to_float(__half h) { + return to_float_device(h); + } + static __device__ __forceinline__ __half from_float(float f) { + return from_float_device(f); + } +}; + +template <> struct DevHalf { + using type = __nv_bfloat16; + static __device__ __forceinline__ float to_float(__nv_bfloat16 h) { + return to_float_device(h); + } + static __device__ __forceinline__ __nv_bfloat16 from_float(float f) { + return from_float_device_bf16(f); + } +}; + +template struct alignas(sizeof(T) * N) Pack { + T v[N]; +}; + +static inline bool is_aligned(const void *p, size_t bytes) { + return (reinterpret_cast(p) % bytes) == 
0; +} diff --git a/csrc/custom_ops/dtype_fp8.cuh b/csrc/custom_ops/dtype_fp8.cuh index 3bb195524..86f15adb5 100644 --- a/csrc/custom_ops/dtype_fp8.cuh +++ b/csrc/custom_ops/dtype_fp8.cuh @@ -4,10 +4,10 @@ #include #ifdef ENABLE_FP8 - #ifndef USE_ROCM - #include - #endif // USE_ROCM -#endif // ENABLE_FP8 +#ifndef USE_ROCM +#include +#endif // USE_ROCM +#endif // ENABLE_FP8 namespace vllm { @@ -18,25 +18,20 @@ enum class Fp8KVCacheDataType { }; // fp8 vector types for quantization of kv cache -template <> -struct Vec { +template <> struct Vec { using Type = uint8_t; }; -template <> -struct Vec { +template <> struct Vec { using Type = uint16_t; }; -template <> -struct Vec { +template <> struct Vec { using Type = uint32_t; }; -template <> -struct Vec { +template <> struct Vec { using Type = uint2; }; -} // namespace vllm - +} // namespace vllm diff --git a/csrc/custom_ops/quant_utils.cuh b/csrc/custom_ops/quant_utils.cuh index 1f2b13690..e23bbed64 100644 --- a/csrc/custom_ops/quant_utils.cuh +++ b/csrc/custom_ops/quant_utils.cuh @@ -10,9 +10,9 @@ namespace vllm { #ifndef USE_ROCM namespace fp8 { - #ifdef ENABLE_FP8 +#ifdef ENABLE_FP8 - #if 0 // Disable the following code to reduce the binary size. +#if 0 // Disable the following code to reduce the binary size. 
template __inline__ __device__ Tout vec_conversion(const Tin &x, const __nv_fp8_interpretation_t fp8_type) { @@ -177,13 +177,13 @@ __inline__ __device__ uint8_t vec_conversion( template <> __inline__ __device__ uint8_t vec_conversion( const __nv_bfloat16 &a, const __nv_fp8_interpretation_t fp8_type) { - #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 assert(false); - #else +#else __nv_fp8_storage_t res = __nv_cvt_bfloat16raw_to_fp8( __nv_bfloat16_raw(a), __NV_SATFINITE, fp8_type); return (uint8_t)res; - #endif +#endif } // float -> fp8 @@ -276,7 +276,7 @@ __inline__ __device__ bf16_8_t vec_conversion( from_float(b, a); return b; } - #endif +#endif /* Scaled and vectorized conversions, for data exchange between high and low precision domains Convention of the scale in API, e.g: FP8_data = @@ -286,14 +286,14 @@ __inline__ __device__ bf16_8_t vec_conversion( template __inline__ __device__ Tout scaled_vec_conversion( - const Tin& x, const float scale, const __nv_fp8_interpretation_t fp8_type) { + const Tin &x, const float scale, const __nv_fp8_interpretation_t fp8_type) { return x; } // fp8 -> half template <> __inline__ __device__ uint16_t scaled_vec_conversion( - const uint8_t& a, const float scale, + const uint8_t &a, const float scale, const __nv_fp8_interpretation_t fp8_type) { __half_raw tmp = __nv_cvt_fp8_to_halfraw(a, fp8_type); return float_to_half(half_to_float(tmp.x) * scale); @@ -302,7 +302,7 @@ __inline__ __device__ uint16_t scaled_vec_conversion( // fp8x2 -> half2 template <> __inline__ __device__ uint32_t scaled_vec_conversion( - const uint16_t& a, const float scale, + const uint16_t &a, const float scale, const __nv_fp8_interpretation_t fp8_type) { union { uint16_t u16[2]; @@ -317,7 +317,7 @@ __inline__ __device__ uint32_t scaled_vec_conversion( // fp8x4 -> half2x2 template <> __inline__ __device__ uint2 scaled_vec_conversion( - const uint32_t& a, const float scale, + const uint32_t &a, const float 
scale, const __nv_fp8_interpretation_t fp8_type) { union { uint2 u32x2; @@ -333,7 +333,7 @@ __inline__ __device__ uint2 scaled_vec_conversion( // fp8x8 -> half2x4 template <> __inline__ __device__ uint4 -scaled_vec_conversion(const uint2& a, const float scale, +scaled_vec_conversion(const uint2 &a, const float scale, const __nv_fp8_interpretation_t fp8_type) { union { uint4 u64x2; @@ -348,7 +348,7 @@ scaled_vec_conversion(const uint2& a, const float scale, template <> __inline__ __device__ __nv_bfloat16 scaled_vec_conversion<__nv_bfloat16, uint8_t>( - const uint8_t& a, const float scale, + const uint8_t &a, const float scale, const __nv_fp8_interpretation_t fp8_type) { // Note there is no direct convert function from fp8 to bf16. // fp8 -> half @@ -362,7 +362,7 @@ scaled_vec_conversion<__nv_bfloat16, uint8_t>( template <> __inline__ __device__ __nv_bfloat162 scaled_vec_conversion<__nv_bfloat162, uint16_t>( - const uint16_t& a, const float scale, + const uint16_t &a, const float scale, const __nv_fp8_interpretation_t fp8_type) { __nv_bfloat162 res; res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale, @@ -375,7 +375,7 @@ scaled_vec_conversion<__nv_bfloat162, uint16_t>( // fp8x4 -> bf16_4_t template <> __inline__ __device__ bf16_4_t scaled_vec_conversion( - const uint32_t& a, const float scale, + const uint32_t &a, const float scale, const __nv_fp8_interpretation_t fp8_type) { bf16_4_t res; res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale, @@ -388,7 +388,7 @@ __inline__ __device__ bf16_4_t scaled_vec_conversion( // fp8x8 -> bf16_8_t template <> __inline__ __device__ bf16_8_t scaled_vec_conversion( - const uint2& a, const float scale, + const uint2 &a, const float scale, const __nv_fp8_interpretation_t fp8_type) { bf16_4_t tmp1, tmp2; tmp1 = scaled_vec_conversion(a.x, scale, fp8_type); @@ -404,7 +404,7 @@ __inline__ __device__ bf16_8_t scaled_vec_conversion( // fp8 -> float template <> __inline__ __device__ float 
scaled_vec_conversion( - const uint8_t& a, const float scale, + const uint8_t &a, const float scale, const __nv_fp8_interpretation_t fp8_type) { // fp8 -> half __half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type); @@ -417,7 +417,7 @@ __inline__ __device__ float scaled_vec_conversion( // fp8x2 -> float2 template <> __inline__ __device__ float2 scaled_vec_conversion( - const uint16_t& a, const float scale, + const uint16_t &a, const float scale, const __nv_fp8_interpretation_t fp8_type) { // fp8x2 -> half2 uint32_t tmp = scaled_vec_conversion(a, scale, fp8_type); @@ -428,7 +428,7 @@ __inline__ __device__ float2 scaled_vec_conversion( // fp8x4 -> float4 template <> __inline__ __device__ Float4_ scaled_vec_conversion( - const uint32_t& a, const float scale, + const uint32_t &a, const float scale, const __nv_fp8_interpretation_t fp8_type) { Float4_ res; res.x = scaled_vec_conversion((uint16_t)a, scale, fp8_type); @@ -440,7 +440,7 @@ __inline__ __device__ Float4_ scaled_vec_conversion( // fp8x8 -> float8 template <> __inline__ __device__ Float8_ scaled_vec_conversion( - const uint2& a, const float scale, + const uint2 &a, const float scale, const __nv_fp8_interpretation_t fp8_type) { Float4_ tmp1, tmp2; tmp1 = scaled_vec_conversion(a.x, scale, fp8_type); @@ -456,7 +456,7 @@ __inline__ __device__ Float8_ scaled_vec_conversion( // half -> fp8 template <> __inline__ __device__ uint8_t scaled_vec_conversion( - const uint16_t& a, const float scale, + const uint16_t &a, const float scale, const __nv_fp8_interpretation_t fp8_type) { __nv_fp8_storage_t res = __nv_cvt_float_to_fp8(half_to_float(a) / scale, __NV_SATFINITE, fp8_type); @@ -466,22 +466,22 @@ __inline__ __device__ uint8_t scaled_vec_conversion( // bf16 -> fp8 template <> __inline__ __device__ uint8_t scaled_vec_conversion( - const __nv_bfloat16& a, const float scale, + const __nv_bfloat16 &a, const float scale, const __nv_fp8_interpretation_t fp8_type) { - #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 +#if 
defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 assert(false); - #else +#else __nv_fp8_storage_t res = __nv_cvt_float_to_fp8(__bfloat162float(a) / scale, __NV_SATFINITE, fp8_type); return (uint8_t)res; - #endif - __builtin_unreachable(); // Suppress missing return statement warning +#endif + __builtin_unreachable(); // Suppress missing return statement warning } // float -> fp8 template <> __inline__ __device__ uint8_t scaled_vec_conversion( - const float& a, const float scale, + const float &a, const float scale, const __nv_fp8_interpretation_t fp8_type) { __nv_fp8_storage_t res = __nv_cvt_float_to_fp8(a / scale, __NV_SATFINITE, fp8_type); @@ -491,84 +491,81 @@ __inline__ __device__ uint8_t scaled_vec_conversion( // fp8x4 -> float4 template <> __inline__ __device__ float4 scaled_vec_conversion( - const uint32_t& a, const float scale, + const uint32_t &a, const float scale, const __nv_fp8_interpretation_t fp8_type) { Float4_ tmp = scaled_vec_conversion(a, scale, fp8_type); float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y); return res; } - #endif // ENABLE_FP8 +#endif // ENABLE_FP8 template -__inline__ __device__ Tout convert(const Tin& x) { - #if 0 // Disable the following code to reduce the binary size. +__inline__ __device__ Tout convert(const Tin &x) { +#if 0 // Disable the following code to reduce the binary size. 
if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) { return vec_conversion(x, __NV_E4M3); } else if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5M2) { return vec_conversion(x, __NV_E5M2); } - #endif +#endif assert(false); - __builtin_unreachable(); // Suppress missing return statement warning + __builtin_unreachable(); // Suppress missing return statement warning } template -__inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) { - #ifdef ENABLE_FP8 +__inline__ __device__ Tout scaled_convert(const Tin &x, const float scale) { +#ifdef ENABLE_FP8 if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) { return scaled_vec_conversion(x, scale, __NV_E4M3); } else if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5M2) { return scaled_vec_conversion(x, scale, __NV_E5M2); } - #endif +#endif assert(false); - __builtin_unreachable(); // Suppress missing return statement warning -} - - // The following macro is used to dispatch the conversion function based on - // the data type of the key and value cache. The FN is a macro that calls a - // function with template. - #define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \ - if (KV_DTYPE == "auto") { \ + __builtin_unreachable(); // Suppress missing return statement warning +} + +// The following macro is used to dispatch the conversion function based on +// the data type of the key and value cache. The FN is a macro that calls a +// function with template. 
+#define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \ + if (KV_DTYPE == "auto") { \ + if (SRC_DTYPE == at::ScalarType::Float) { \ + FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \ + } else if (SRC_DTYPE == at::ScalarType::Half) { \ + FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); \ + } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \ + FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); \ + } else { \ + TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \ + } \ + } else { \ + if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") { \ if (SRC_DTYPE == at::ScalarType::Float) { \ - FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \ + FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ } else if (SRC_DTYPE == at::ScalarType::Half) { \ - FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); \ + FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \ - FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); \ + FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ } else { \ TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \ } \ - } else { \ - if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") { \ - if (SRC_DTYPE == at::ScalarType::Float) { \ - FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ - } else if (SRC_DTYPE == at::ScalarType::Half) { \ - FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ - } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \ - FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ - } else { \ - TORCH_CHECK(false, \ - "Unsupported input type of kv cache: ", SRC_DTYPE); \ - } \ - } else if (KV_DTYPE == "fp8_e5m2") { \ - if (SRC_DTYPE == at::ScalarType::Float) { \ - FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \ - } else if (SRC_DTYPE == at::ScalarType::Half) { \ - FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \ - } else if 
(SRC_DTYPE == at::ScalarType::BFloat16) { \ - FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \ - } else { \ - TORCH_CHECK(false, \ - "Unsupported input type of kv cache: ", SRC_DTYPE); \ - } \ + } else if (KV_DTYPE == "fp8_e5m2") { \ + if (SRC_DTYPE == at::ScalarType::Float) { \ + FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \ + } else if (SRC_DTYPE == at::ScalarType::Half) { \ + FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \ + } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \ + FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \ } else { \ - TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \ + TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \ } \ - } - -} // namespace fp8 -#endif // not USE_ROCM -} // namespace vllm + } else { \ + TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \ + } \ + } +} // namespace fp8 +#endif // not USE_ROCM +} // namespace vllm diff --git a/csrc/custom_ops/kernels.cu b/csrc/custom_ops/reshape_and_cache_flash_bulk.cu similarity index 73% rename from csrc/custom_ops/kernels.cu rename to csrc/custom_ops/reshape_and_cache_flash_bulk.cu index fe6ee9ba1..cf9dce1ed 100644 --- a/csrc/custom_ops/kernels.cu +++ b/csrc/custom_ops/reshape_and_cache_flash_bulk.cu @@ -2,30 +2,21 @@ #include "dispatch_utils.h" #include "quant_utils.cuh" -#include #include +#include #include namespace vllm { -template +template __global__ void reshape_and_cache_flash_bulk_kernel( - const scalar_t* __restrict__ keys, - const scalar_t* __restrict__ values, - int64_t* key_cache_ptrs, - int64_t* value_cache_ptrs, - const int64_t* __restrict__ slot_mapping, - const int block_stride, - const int key_stride, - const int value_stride, - const int num_heads, - const int head_size, - const int block_size, - int64_t* k_scale_ptrs, - int64_t* v_scale_ptrs) { + const scalar_t *__restrict__ keys, const scalar_t *__restrict__ values, + int64_t *key_cache_ptrs, 
int64_t *value_cache_ptrs, + const int64_t *__restrict__ slot_mapping, const int block_stride, + const int key_stride, const int value_stride, const int num_heads, + const int head_size, const int block_size, int64_t *k_scale_ptrs, + int64_t *v_scale_ptrs) { const int64_t layer_idx = blockIdx.x; const int64_t token_idx = blockIdx.y; const int64_t slot_idx = slot_mapping[token_idx]; @@ -37,14 +28,14 @@ __global__ void reshape_and_cache_flash_bulk_kernel( const int64_t block_offset = slot_idx % block_size; const int n = num_heads * head_size; - cache_t* __restrict__ key_cache = - reinterpret_cast(key_cache_ptrs[layer_idx]); - cache_t* __restrict__ value_cache = - reinterpret_cast(value_cache_ptrs[layer_idx]); - const float* __restrict__ k_scale = - reinterpret_cast(k_scale_ptrs[layer_idx]); - const float* __restrict__ v_scale = - reinterpret_cast(v_scale_ptrs[layer_idx]); + cache_t *__restrict__ key_cache = + reinterpret_cast(key_cache_ptrs[layer_idx]); + cache_t *__restrict__ value_cache = + reinterpret_cast(value_cache_ptrs[layer_idx]); + const float *__restrict__ k_scale = + reinterpret_cast(k_scale_ptrs[layer_idx]); + const float *__restrict__ v_scale = + reinterpret_cast(v_scale_ptrs[layer_idx]); for (int i = threadIdx.x; i < n; i += blockDim.x) { const int64_t src_key_idx = token_idx * key_stride + layer_idx * n + i; @@ -70,29 +61,26 @@ __global__ void reshape_and_cache_flash_bulk_kernel( } // namespace vllm -#define CALL_RESHAPE_AND_CACHE_FLASH_BULK(KV_T, CACHE_T, KV_DTYPE) \ - vllm::reshape_and_cache_flash_bulk_kernel \ - <<>>( \ - reinterpret_cast(keys.data_ptr()), \ - reinterpret_cast(values.data_ptr()), \ - key_cache_ptrs_tensor.data_ptr(), \ - value_cache_ptrs_tensor.data_ptr(), \ - slot_mapping.data_ptr(), block_stride, key_stride, \ - value_stride, static_cast(num_heads), \ - static_cast(head_size), block_size, \ - k_scale_ptrs_tensor.data_ptr(), \ +#define CALL_RESHAPE_AND_CACHE_FLASH_BULK(KV_T, CACHE_T, KV_DTYPE) \ + 
vllm::reshape_and_cache_flash_bulk_kernel \ + <<>>( \ + reinterpret_cast(keys.data_ptr()), \ + reinterpret_cast(values.data_ptr()), \ + key_cache_ptrs_tensor.data_ptr(), \ + value_cache_ptrs_tensor.data_ptr(), \ + slot_mapping.data_ptr(), block_stride, key_stride, \ + value_stride, static_cast(num_heads), \ + static_cast(head_size), block_size, \ + k_scale_ptrs_tensor.data_ptr(), \ v_scale_ptrs_tensor.data_ptr()); void reshape_and_cache_flash_bulk( - torch::Tensor& keys, - torch::Tensor& values, - std::vector const& key_caches, - std::vector const& value_caches, - torch::Tensor& slot_mapping, - const std::string& kv_cache_dtype, - std::vector const& k_scales, - std::vector const& v_scales, - int64_t num_heads, + torch::Tensor &keys, torch::Tensor &values, + std::vector const &key_caches, + std::vector const &value_caches, torch::Tensor &slot_mapping, + const std::string &kv_cache_dtype, + std::vector const &k_scales, + std::vector const &v_scales, int64_t num_heads, int64_t head_size) { int num_layers = key_caches.size(); @@ -146,7 +134,8 @@ void reshape_and_cache_flash_bulk( .to(device_of_key); dim3 grid(num_layers, num_tokens); - dim3 block(std::min(static_cast(num_heads) * static_cast(head_size), 512)); + dim3 block( + std::min(static_cast(num_heads) * static_cast(head_size), 512)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); diff --git a/csrc/custom_ops/reshape_and_cache_flash_fp4.cu b/csrc/custom_ops/reshape_and_cache_flash_fp4.cu new file mode 100644 index 000000000..ce0ffffc5 --- /dev/null +++ b/csrc/custom_ops/reshape_and_cache_flash_fp4.cu @@ -0,0 +1,338 @@ +#include "custom_ops.h" +#include "dispatch_utils.h" +#include "quant_utils.cuh" +#include "vectorization_utils.cuh" + +#include +#include +#include + +#include + +namespace vllm { + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 +template struct TypeConverter { + using Type = half2; +}; +template <> struct TypeConverter { + using Type = half; +}; +template <> struct 
TypeConverter { + using Type = half2; +}; +template <> struct TypeConverter<__nv_bfloat162> { + using Type = __nv_bfloat16; +}; +template <> struct TypeConverter<__nv_bfloat16> { + using Type = __nv_bfloat162; +}; + +template struct PackedVec8 { + typename TypeConverter::Type elts[4]; +}; + +#if __CUDA_ARCH__ < 1000 +__device__ uint8_t float_to_e2m1_rn(float val) { + if (isnan(val)) { + return 0x0; + } + if (isinf(val)) { + val = val < 0.f ? -6.f : 6.f; + } + uint32_t sign_bit = (reinterpret_cast(val) & 0x80000000) >> 28; + float x = fabsf(val); + uint8_t magnitude_bits; + if (x > 5.0f) + magnitude_bits = 0x7; // 6.0 + else if (x > 3.5f) + magnitude_bits = 0x6; // 4.0 + else if (x > 2.5f) + magnitude_bits = 0x5; // 3.0 + else if (x > 1.75f) + magnitude_bits = 0x4; // 2.0 + else if (x > 1.25f) + magnitude_bits = 0x3; // 1.5 + else if (x > 0.75f) + magnitude_bits = 0x2; // 1.0 + else if (x > 0.25f) + magnitude_bits = 0x1; // 0.5 + else + magnitude_bits = 0x0; // 0.0 + return sign_bit | magnitude_bits; +} +#endif + +// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t). 
+inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + uint32_t val; + asm volatile("{\n" + ".reg .b8 byte0, byte1, byte2, byte3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" + "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" + "}" + : "=r"(val) + : "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), + "f"(array[1].y), "f"(array[2].x), "f"(array[2].y), + "f"(array[3].x), "f"(array[3].y)); + return val; +#else + uint32_t result = 0; + uint8_t *result_bytes = reinterpret_cast(&result); +#pragma unroll + for (int i = 0; i < 4; ++i) { + uint8_t val1 = float_to_e2m1_rn(array[i].x); + uint8_t val2 = float_to_e2m1_rn(array[i].y); + result_bytes[i] = (val2 << 4) | (val1 & 0x0F); + } + return result; +#endif +} + +inline __device__ float reciprocal_approximate_ftz(float a) { + float b; + asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a)); + return b; +} + +template +__device__ __forceinline__ void +quantize_16_to_fp4(const scalar_t *__restrict__ src_base, + uint32_t *__restrict__ dst_base, + uint8_t *__restrict__ scale_dst, const int tid_in_block) { + + const int lane_in_pair = tid_in_block % 2; + + using PackedVec = PackedVec8; + PackedVec vec = + *reinterpret_cast(src_base + lane_in_pair * 8); + + auto localMax = __habs2(vec.elts[0]); +#pragma unroll + for (int i = 1; i < 4; i++) { + localMax = __hmax2(localMax, __habs2(vec.elts[i])); + } + + localMax = __hmax2(__shfl_xor_sync(0xffffffff, localMax, 1), localMax); + float vecMax = float(__hmax(localMax.x, localMax.y)); + float SFValue = vecMax * reciprocal_approximate_ftz(6.0f); + + if (lane_in_pair == 0) { + __nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue); + *scale_dst = reinterpret_cast(tmp); + SFValue = float(tmp); + } + + SFValue = __shfl_sync(0xffffffff, SFValue, tid_in_block & ~1); + float 
outputScale = + (SFValue != 0.0f) ? reciprocal_approximate_ftz(SFValue) : 0.0f; + + float2 fp2Vals[4]; +#pragma unroll + for (int i = 0; i < 4; i++) { + if constexpr (std::is_same_v) { + fp2Vals[i] = __half22float2(vec.elts[i]); + } else { + fp2Vals[i] = __bfloat1622float2(vec.elts[i]); + } + fp2Vals[i].x *= outputScale; + fp2Vals[i].y *= outputScale; + } + + dst_base[lane_in_pair] = fp32_vec_to_e2m1(fp2Vals); +} +#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 + +template +__global__ void reshape_and_cache_flash_kernel_fp4( + const scalar_t *__restrict__ key, // [num_tokens, num_heads, head_size] + const scalar_t *__restrict__ value, // [num_tokens, num_heads, head_size] + cache_t *__restrict__ key_cache, // NHD or HND + cache_t *__restrict__ value_cache, + const int64_t *__restrict__ slot_mapping, // [num_tokens] + const int64_t block_stride, const int64_t page_stride, + const int64_t head_stride, const int64_t key_stride, + const int64_t value_stride, const int num_heads, const int head_size, + const int block_size, const float *k_scale, const float *v_scale, + uint8_t *__restrict__ key_scale_cache, + uint8_t *__restrict__ value_scale_cache, const bool is_nhd) { + const int64_t token_idx = blockIdx.x; + const int64_t slot_idx = slot_mapping[token_idx]; + if (slot_idx < 0) { + return; + } + const int64_t block_idx = slot_idx / block_size; + const int64_t block_offset = slot_idx % block_size; + const int n_elems = num_heads * head_size; + + const scalar_t *__restrict__ key_src = key + token_idx * key_stride; + const scalar_t *__restrict__ value_src = value + token_idx * value_stride; + + cache_t *__restrict__ key_dst = + key_cache + block_idx * block_stride + block_offset * page_stride; + cache_t *__restrict__ value_dst = + value_cache + block_idx * block_stride + block_offset * page_stride; + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 + uint32_t *__restrict__ key_dst_fp4 = reinterpret_cast(key_dst); + uint32_t *__restrict__ value_dst_fp4 = + 
reinterpret_cast(value_dst); + + const int64_t scale_block_stride = block_stride / 8; + const int64_t scale_page_stride = page_stride / 8; + const int64_t scale_head_stride = head_stride / 8; + + uint8_t *__restrict__ key_scale_dst = key_scale_cache + + block_idx * scale_block_stride + + block_offset * scale_page_stride; + uint8_t *__restrict__ value_scale_dst = value_scale_cache + + block_idx * scale_block_stride + + block_offset * scale_page_stride; + + constexpr int CHUNK_SIZE = 16; + constexpr int THREADS_PER_CHUNK = 2; + + const int lane = threadIdx.x % 32; + const int warp_id = threadIdx.x / 32; + const int warps_in_block = blockDim.x / 32; + const int chunks_per_warp = 32 / THREADS_PER_CHUNK; + + if (is_nhd) { + // NHD layout + const int num_chunks = n_elems / CHUNK_SIZE; + for (int chunk_base_idx = warp_id * chunks_per_warp; + chunk_base_idx < num_chunks; + chunk_base_idx += warps_in_block * chunks_per_warp) { + const int chunk_idx = chunk_base_idx + (lane / THREADS_PER_CHUNK); + if (chunk_idx < num_chunks) { + quantize_16_to_fp4(key_src + chunk_idx * CHUNK_SIZE, + key_dst_fp4 + chunk_idx * (CHUNK_SIZE / 8), + key_scale_dst + chunk_idx, threadIdx.x); + quantize_16_to_fp4(value_src + chunk_idx * CHUNK_SIZE, + value_dst_fp4 + chunk_idx * (CHUNK_SIZE / 8), + value_scale_dst + chunk_idx, threadIdx.x); + } + } + } else { + // HND layout + const int num_chunks_per_head = head_size / CHUNK_SIZE; + for (int head = warp_id; head < num_heads; head += warps_in_block) { + const scalar_t *__restrict__ k_src_h = key_src + head * head_size; + const scalar_t *__restrict__ v_src_h = value_src + head * head_size; + + cache_t *__restrict__ k_dst_head_u8 = + key_dst + static_cast(head) * head_stride; + cache_t *__restrict__ v_dst_head_u8 = + value_dst + static_cast(head) * head_stride; + + uint32_t *__restrict__ k_dst_h = + reinterpret_cast(k_dst_head_u8); + uint32_t *__restrict__ v_dst_h = + reinterpret_cast(v_dst_head_u8); + + uint8_t *__restrict__ k_scale_dst_h = + 
key_scale_dst + static_cast(head) * scale_head_stride; + uint8_t *__restrict__ v_scale_dst_h = + value_scale_dst + static_cast(head) * scale_head_stride; + + for (int chunk_idx = lane / THREADS_PER_CHUNK; + chunk_idx < num_chunks_per_head; chunk_idx += chunks_per_warp) { + quantize_16_to_fp4(k_src_h + chunk_idx * CHUNK_SIZE, + k_dst_h + chunk_idx * (CHUNK_SIZE / 8), + k_scale_dst_h + chunk_idx, threadIdx.x); + quantize_16_to_fp4(v_src_h + chunk_idx * CHUNK_SIZE, + v_dst_h + chunk_idx * (CHUNK_SIZE / 8), + v_scale_dst_h + chunk_idx, threadIdx.x); + } + } + } +#endif // __CUDA_ARCH__ >= 900 +} + +} // namespace vllm + +void reshape_and_cache_flash_fp4( + torch::Tensor &key, // [num_tokens, num_heads, head_size] + torch::Tensor &value, // [num_tokens, num_heads, head_size] + torch::Tensor &key_cache, torch::Tensor &value_cache, + torch::Tensor &slot_mapping, const std::string &kv_cache_dtype, + torch::Tensor &k_scale, torch::Tensor &v_scale, + torch::Tensor &key_scale_cache, torch::Tensor &value_scale_cache) { + const int64_t num_tokens = slot_mapping.size(0); + const int64_t num_heads = key.size(1); + const int64_t head_size = key.size(2); + + TORCH_CHECK(key_cache.dim() == 4 && value_cache.dim() == 4, + "KV cache must be rank-4"); + + const bool is_nhd = (key_cache.size(2) == num_heads); + + const int64_t block_stride = key_cache.stride(0); + const int64_t page_stride = + is_nhd ? key_cache.stride(1) : key_cache.stride(2); + const int64_t head_stride = + is_nhd ? key_cache.stride(2) : key_cache.stride(1); + const int block_size = is_nhd ? key_cache.size(1) : key_cache.size(2); + + TORCH_CHECK(value_cache.stride(0) == block_stride, "block_stride mismatch"); + TORCH_CHECK((is_nhd ? value_cache.stride(1) : value_cache.stride(2)) == + page_stride, + "page_stride mismatch"); + TORCH_CHECK((is_nhd ? value_cache.stride(2) : value_cache.stride(1)) == + head_stride, + "head_stride mismatch"); + TORCH_CHECK((is_nhd ? 
value_cache.size(1) : value_cache.size(2)) == + block_size, + "block_size mismatch between key_cache and value_cache"); + + int64_t key_stride = key.stride(0); + int64_t value_stride = value.stride(0); + TORCH_CHECK(key_cache.stride(0) == value_cache.stride(0)); + + dim3 grid(num_tokens); + int threads = std::min(num_heads * head_size, 512); + threads = std::max(32, ((threads + 31) / 32) * 32); + dim3 block(threads); + + const at::cuda::OptionalCUDAGuard device_guard(device_of(key)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + switch (key.scalar_type()) { + case torch::kHalf: { + vllm::reshape_and_cache_flash_kernel_fp4 + <<>>( + reinterpret_cast(key.data_ptr()), + reinterpret_cast(value.data_ptr()), + reinterpret_cast(key_cache.data_ptr()), + reinterpret_cast(value_cache.data_ptr()), + slot_mapping.data_ptr(), block_stride, page_stride, + head_stride, key_stride, value_stride, num_heads, head_size, + block_size, k_scale.data_ptr(), v_scale.data_ptr(), + key_scale_cache.data_ptr(), + value_scale_cache.data_ptr(), is_nhd); + break; + } + case torch::kBFloat16: { + vllm::reshape_and_cache_flash_kernel_fp4<__nv_bfloat16, uint8_t> + <<>>( + reinterpret_cast<__nv_bfloat16 *>(key.data_ptr()), + reinterpret_cast<__nv_bfloat16 *>(value.data_ptr()), + reinterpret_cast(key_cache.data_ptr()), + reinterpret_cast(value_cache.data_ptr()), + slot_mapping.data_ptr(), block_stride, page_stride, + head_stride, key_stride, value_stride, num_heads, head_size, + block_size, k_scale.data_ptr(), v_scale.data_ptr(), + key_scale_cache.data_ptr(), + value_scale_cache.data_ptr(), is_nhd); + break; + } + default: { + TORCH_CHECK(false, "Unsupported input dtype for reshape_and_cache_fp4. 
" + "Must be half or bfloat16."); + } + } +} \ No newline at end of file diff --git a/csrc/custom_ops/speculator_ln.cu b/csrc/custom_ops/speculator_ln.cu new file mode 100644 index 000000000..e3c376747 --- /dev/null +++ b/csrc/custom_ops/speculator_ln.cu @@ -0,0 +1,252 @@ +#include +#include +#include + +#include +#include +#include + +#include +#include + +#include "dtype_common.cuh" + +struct SumOp { + __device__ __forceinline__ float operator()(const float &a, + const float &b) const { + return a + b; + } +}; + +template +__global__ void spec_ln_vec_kernel( + typename DevHalf::type *__restrict__ out, + const typename DevHalf::type *__restrict__ in, + const typename DevHalf::type *__restrict__ weight, + const typename DevHalf::type *__restrict__ bias, + int64_t row_stride, int hidden, float eps) { + + using DH = DevHalf; + using H = typename DH::type; + using VPack = Pack; + + extern __shared__ char smem[]; + __shared__ float s_inv_rms; + + const int row = blockIdx.x; + const H *row_in = in + row * row_stride; + H *row_out = out + row * hidden; + + const int vec_len = hidden / VEC; + const int tail = hidden - vec_len * VEC; + + float local_ss = 0.f; + + const VPack *in_vec = reinterpret_cast(row_in); + + for (int i = threadIdx.x; i < vec_len; i += blockDim.x) { + VPack p = in_vec[i]; +#pragma unroll + for (int k = 0; k < VEC; ++k) { + float x = DH::to_float(p.v[k]); + local_ss += x * x; + } + } + + for (int j = threadIdx.x + vec_len * VEC; j < hidden; j += blockDim.x) { + float x = DH::to_float(row_in[j]); + local_ss += x * x; + } + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + float sumsq = BlockReduce(temp_storage).Reduce(local_ss, SumOp{}, blockDim.x); + + if (threadIdx.x == 0) { + float mean = sumsq / static_cast(hidden); + s_inv_rms = rsqrtf(mean + eps); + } + __syncthreads(); + + const VPack *w_vec = reinterpret_cast(weight); + const VPack *b_vec = reinterpret_cast(bias); + VPack *out_vec = 
reinterpret_cast(row_out); + + for (int i = threadIdx.x; i < vec_len; i += blockDim.x) { + VPack px = reinterpret_cast(row_in)[i]; + VPack py; + VPack pw; + VPack pb; + bool use_w = weight != nullptr; + bool use_b = bias != nullptr; + if (use_w) + pw = w_vec[i]; + if (use_b) + pb = b_vec[i]; +#pragma unroll + for (int k = 0; k < VEC; ++k) { + float y = DH::to_float(px.v[k]); + y = y * s_inv_rms; + if (use_w) + y *= DH::to_float(pw.v[k]); + if (use_b) + y += DH::to_float(pb.v[k]); + py.v[k] = DH::from_float(y); + } + out_vec[i] = py; + } + + for (int j = threadIdx.x + vec_len * VEC; j < hidden; j += blockDim.x) { + float y = DH::to_float(row_in[j]) * s_inv_rms; + if (weight) + y *= DH::to_float(weight[j]); + if (bias) + y += DH::to_float(bias[j]); + row_out[j] = DH::from_float(y); + } +} + +template +__global__ void spec_ln_scalar_kernel( + typename DevHalf::type *__restrict__ out, + const typename DevHalf::type *__restrict__ in, + const typename DevHalf::type *__restrict__ weight, + const typename DevHalf::type *__restrict__ bias, + int64_t row_stride, int hidden, float eps) { + + using DH = DevHalf; + using H = typename DH::type; + + __shared__ float s_inv_rms; + + const int row = blockIdx.x; + const H *row_in = in + row * row_stride; + H *row_out = out + row * hidden; + + float local_ss = 0.f; + for (int j = threadIdx.x; j < hidden; j += blockDim.x) { + float x = DH::to_float(row_in[j]); + local_ss += x * x; + } + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + float sumsq = BlockReduce(temp_storage).Reduce(local_ss, SumOp{}, blockDim.x); + + if (threadIdx.x == 0) { + float mean = sumsq / static_cast(hidden); + s_inv_rms = rsqrtf(mean + eps); + } + __syncthreads(); + + for (int j = threadIdx.x; j < hidden; j += blockDim.x) { + float y = DH::to_float(row_in[j]) * s_inv_rms; + if (weight) + y *= DH::to_float(weight[j]); + if (bias) + y += DH::to_float(bias[j]); + row_out[j] = DH::from_float(y); + } +} + 
+template +torch::Tensor speculator_ln_cuda_impl( + const torch::Tensor &input, // [..., hidden] + const c10::optional &w, // [hidden] or None + const c10::optional &b, // [hidden] or None + double eps_d) { + + TORCH_CHECK(input.is_cuda(), "speculator_ln: input must be CUDA"); + TORCH_CHECK(IsSupported::value, + "speculator_ln: dtype must be fp16 or bf16"); + + const auto dtype = input.scalar_type(); + const auto device = input.device(); + const float eps = static_cast(eps_d); + + const int64_t hidden = input.size(-1); + TORCH_CHECK(hidden > 0, "speculator_ln: hidden size must be > 0"); + TORCH_CHECK( + input.stride(-1) == 1, + "speculator_ln: last dimension must be contiguous (stride -1 == 1)"); + + const typename DevHalf::type *w_ptr = nullptr; + const typename DevHalf::type *b_ptr = nullptr; + + if (w.has_value() && w->defined()) { + TORCH_CHECK(w->is_cuda(), "weight must be CUDA"); + TORCH_CHECK(w->scalar_type() == dtype, + "weight dtype must match input dtype"); + TORCH_CHECK(w->dim() == 1 && w->numel() == hidden, + "weight must be 1D of size hidden"); + TORCH_CHECK(w->is_contiguous(), "weight must be contiguous"); + w_ptr = reinterpret_cast::type *>( + w->data_ptr()); + } + if (b.has_value() && b->defined()) { + TORCH_CHECK(b->is_cuda(), "bias must be CUDA"); + TORCH_CHECK(b->scalar_type() == dtype, "bias dtype must match input dtype"); + TORCH_CHECK(b->dim() == 1 && b->numel() == hidden, + "bias must be 1D of size hidden"); + TORCH_CHECK(b->is_contiguous(), "bias must be contiguous"); + b_ptr = reinterpret_cast::type *>( + b->data_ptr()); + } + + auto in_2d = input.view({-1, hidden}); + const int64_t num_rows = in_2d.size(0); + const int64_t row_stride = in_2d.stride(0); + + auto out = at::empty_like(input); + auto out_2d = out.view({-1, hidden}); + + const auto stream = at::cuda::getCurrentCUDAStream(); + const int BLOCK = (num_rows < 256) ? 
1024 : 256; + dim3 grid(num_rows); + dim3 block(std::min(hidden, BLOCK)); + + using H = typename DevHalf::type; + const H *in_ptr = reinterpret_cast(in_2d.data_ptr()); + H *out_ptr = reinterpret_cast(out_2d.data_ptr()); + + const bool can_vec8 = (hidden % 8 == 0) && (row_stride % 8 == 0) && + is_aligned(in_ptr, 16) && is_aligned(out_ptr, 16) && + (!w_ptr || is_aligned(w_ptr, 16)) && + (!b_ptr || is_aligned(b_ptr, 16)); + + const bool can_vec4 = (hidden % 4 == 0) && (row_stride % 4 == 0) && + is_aligned(in_ptr, 8) && is_aligned(out_ptr, 8) && + (!w_ptr || is_aligned(w_ptr, 8)) && + (!b_ptr || is_aligned(b_ptr, 8)); + + if (can_vec8) { + spec_ln_vec_kernel<<>>( + out_ptr, in_ptr, w_ptr, b_ptr, row_stride, (int)hidden, eps); + } else if (can_vec4) { + spec_ln_vec_kernel<<>>( + out_ptr, in_ptr, w_ptr, b_ptr, row_stride, (int)hidden, eps); + } else { + spec_ln_scalar_kernel<<>>( + out_ptr, in_ptr, w_ptr, b_ptr, row_stride, (int)hidden, eps); + } + + C10_CUDA_KERNEL_LAUNCH_CHECK(); + return out; +} + +torch::Tensor speculator_ln_cuda(const torch::Tensor &input, + const c10::optional &weight, + const c10::optional &bias, + double eps) { + + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + + switch (input.scalar_type()) { + case at::kHalf: + return speculator_ln_cuda_impl(input, weight, bias, eps); + case at::kBFloat16: + return speculator_ln_cuda_impl(input, weight, bias, eps); + default: + TORCH_CHECK(false, "speculator_ln: only fp16 and bf16 are supported."); + } +} diff --git a/csrc/custom_ops/sum_lstm.cu b/csrc/custom_ops/sum_lstm.cu new file mode 100644 index 000000000..afe6f1478 --- /dev/null +++ b/csrc/custom_ops/sum_lstm.cu @@ -0,0 +1,473 @@ +#include +#include +#include + +#include +#include +#include + +#include +#include + +#include "dtype_common.cuh" + +__device__ __forceinline__ float sigmoid_f(float x) { + return 1.0f / (1.0f + expf(-x)); +} + +__device__ __forceinline__ float gelu_erf(float x) { + const float kInvSqrt2 = 
0.70710678118654752440f;
+  return 0.5f * x * (1.0f + erff(x * kInvSqrt2)); // 0.5 * x * (1 + erf(x / sqrt(2)))
+}
+// GELU, tanh form: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
+__device__ __forceinline__ float gelu_tanh(float x) {
+  const float kSqrt2OverPi = 0.79788456080286535588f;
+  return 0.5f * x * (1.0f + tanhf(kSqrt2OverPi * (x + 0.044715f * x * x * x)));
+}
+/* Vectorised fused cell/state update; one thread block per row (blockIdx.x).
+ *
+ * Each row of states_4d and z4_4d holds four segments of D_gate elements at
+ * offsets 0, 1, 2 and 3 * D_gate. The pre-activation of every segment is
+ * states + alpha * z4. Using the names in the code:
+ *   segment 0 -> fgate = sigmoid(pre_f)
+ *   segment 1 -> igate = sigmoid(pre_i)
+ *   segment 2 -> ogate = sigmoid(pre_o)
+ *   segment 3 -> cpre  (normalised, then passed through GELU)
+ *
+ * Three passes over the first D_eff elements, VEC elements per load:
+ *   1. sum of squares of cpre  -> inv_rms_cpre = rsqrt(mean + eps_cell)
+ *   2. cn = cpre * inv_rms_cpre (* w_cell, + b_cell when non-null),
+ *      cnew = prev_cell * fgate + gelu(cn) * igate, stored to out_cell,
+ *      while accumulating the sum of squares of cnew
+ *   3. cn = cnew * inv_rms_cnew (* w_state, + b_state when non-null),
+ *      out_state = gelu(cn) * ogate
+ * gelu is gelu_tanh when use_fast_gelu is non-zero, otherwise gelu_erf.
+ *
+ * There is no scalar tail loop: only vec_len = D_eff / VEC full packs are
+ * processed, so D_eff is expected to be a multiple of VEC, and every pointer
+ * that is reinterpreted as a VPack is expected to be suitably aligned.
+ *
+ * NOTE(review): template parameter lists and cast targets appear stripped in
+ * this copy of the patch (`template`, `Pack`, `cub::BlockReduce`,
+ * `reinterpret_cast(...)`); confirm against the repository.
+ */
+template
+__global__ void sum_lstm_vec_kernel(
+    typename DevHalf::type *__restrict__ out_state,
+    typename DevHalf::type *__restrict__ out_cell,
+    // per-row inputs
+    const typename DevHalf::type *__restrict__ states_4d,
+    const typename DevHalf::type *__restrict__ z4_4d,
+    const typename DevHalf::type *__restrict__ prev_cell,
+    // optional per-column scale / shift vectors; nullptr skips the term
+    const typename DevHalf::type *__restrict__ w_cell,
+    const typename DevHalf::type *__restrict__ b_cell,
+    const typename DevHalf::type *__restrict__ w_state,
+    const typename DevHalf::type *__restrict__ b_state,
+    // row strides (in elements) and sizes
+    int64_t states_row_stride, int64_t z4_row_stride, int64_t cell_row_stride,
+    int64_t out_row_stride, int D_eff, int D_gate,
+    // scalars
+    float alpha, float eps_cell, float eps_state, int use_fast_gelu) {
+  using DH = DevHalf;
+  using H = typename DH::type;
+  using VPack = Pack;
+  using BlockReduce = cub::BlockReduce;
+
+  const int row = blockIdx.x;
+  const H *row_s = states_4d + row * states_row_stride;
+  const H *row_z4 = z4_4d + row * z4_row_stride;
+  const H *row_pc = prev_cell + row * cell_row_stride;
+  // Both outputs share out_row_stride.
+  H *row_out_cell = out_cell + row * out_row_stride;
+  H *row_out_state = out_state + row * out_row_stride;
+
+  const int vec_len = D_eff / VEC;
+  // ---- Pass 1: sum of squares of the segment-3 pre-activation ----
+  float local_ss_cpre = 0.f;
+  const VPack *s3_vec = reinterpret_cast(row_s + 3 * D_gate);
+  const VPack *z3_vec = reinterpret_cast(row_z4 + 3 * D_gate);
+
+  for (int i = threadIdx.x; i < vec_len; i += BLOCK_THREADS) {
+    VPack s3 = s3_vec[i];
+    VPack z3 = z3_vec[i];
+#pragma unroll
+    for (int k = 0; k < VEC; ++k) {
+      float c = DH::to_float(s3.v[k]) + alpha * DH::to_float(z3.v[k]);
+      local_ss_cpre += c * c;
+    }
+  }
+
+  __shared__ typename BlockReduce::TempStorage temp0;
+  float sumsq_cpre = BlockReduce(temp0).Sum(local_ss_cpre); // consumed by thread 0 only
+
+  __shared__ float inv_rms_cpre;
+  if (threadIdx.x == 0) {
+    float mean = sumsq_cpre / static_cast(D_eff);
+    inv_rms_cpre = rsqrtf(mean + eps_cell);
+  }
+  __syncthreads(); // inv_rms_cpre must be visible to every thread before pass 2
+  // ---- Pass 2: new cell value, plus its sum of squares ----
+  float local_ss_cnew = 0.f;
+
+  const VPack *s0_vec = reinterpret_cast(row_s + 0 * D_gate);
+  const VPack *s1_vec = reinterpret_cast(row_s + 1 * D_gate);
+  const VPack *s2_vec = reinterpret_cast(row_s + 2 * D_gate);
+
+  const VPack *z0_vec = reinterpret_cast(row_z4 + 0 * D_gate);
+  const VPack *z1_vec = reinterpret_cast(row_z4 + 1 * D_gate);
+  const VPack *z2_vec = reinterpret_cast(row_z4 + 2 * D_gate);
+
+  const VPack *pc_vec = reinterpret_cast(row_pc);
+  VPack *outc_vec = reinterpret_cast(row_out_cell);
+
+  const VPack *wcell_vec = reinterpret_cast(w_cell);
+  const VPack *bcell_vec = reinterpret_cast(b_cell);
+
+  const bool use_w_cell = (w_cell != nullptr);
+  const bool use_b_cell = (b_cell != nullptr);
+
+  for (int i = threadIdx.x; i < vec_len; i += BLOCK_THREADS) {
+    VPack s0 = s0_vec[i], s1 = s1_vec[i], s2 = s2_vec[i], s3 = s3_vec[i]; // NOTE(review): s2 is not used in this loop
+    VPack z0 = z0_vec[i], z1 = z1_vec[i], z2 = z2_vec[i], z3 = z3_vec[i]; // NOTE(review): z2 is not used in this loop
+    VPack pc = pc_vec[i];
+
+    VPack oc;
+    VPack wcell, bcell;
+    if (use_w_cell)
+      wcell = wcell_vec[i];
+    if (use_b_cell)
+      bcell = bcell_vec[i];
+
+#pragma unroll
+    for (int k = 0; k < VEC; ++k) {
+      float pre_f = DH::to_float(s0.v[k]) + alpha * DH::to_float(z0.v[k]);
+      float pre_i = DH::to_float(s1.v[k]) + alpha * DH::to_float(z1.v[k]);
+      float cpre = DH::to_float(s3.v[k]) + alpha * DH::to_float(z3.v[k]);
+
+      float fgate = sigmoid_f(pre_f);
+      float igate = sigmoid_f(pre_i);
+
+      float cn = cpre * inv_rms_cpre;
+      if (use_w_cell)
+        cn *= DH::to_float(wcell.v[k]);
+      if (use_b_cell)
+        cn += DH::to_float(bcell.v[k]);
+
+      float cact = (use_fast_gelu ? gelu_tanh(cn) : gelu_erf(cn));
+      float pcv = DH::to_float(pc.v[k]);
+
+      float cnew = pcv * fgate + cact * igate;
+      // The sum of squares uses the float value; the stored value is converted to H.
+      local_ss_cnew += cnew * cnew;
+      oc.v[k] = DH::from_float(cnew);
+    }
+    outc_vec[i] = oc;
+  }
+
+  __shared__ typename BlockReduce::TempStorage temp1;
+  float sumsq_cnew = BlockReduce(temp1).Sum(local_ss_cnew); // consumed by thread 0 only
+
+  __shared__ float inv_rms_cnew;
+  if (threadIdx.x == 0) {
+    float mean = sumsq_cnew / static_cast(D_eff);
+    inv_rms_cnew = rsqrtf(mean + eps_state);
+  }
+  __syncthreads(); // inv_rms_cnew must be visible to every thread before pass 3
+  // ---- Pass 3: output state. Reads back out_cell at the same indices i this thread stored in pass 2. ----
+  const VPack *outc_read = reinterpret_cast(row_out_cell);
+  VPack *outs_vec = reinterpret_cast(row_out_state);
+
+  const VPack *wstate_vec = reinterpret_cast(w_state);
+  const VPack *bstate_vec = reinterpret_cast(b_state);
+  const bool use_w_st = (w_state != nullptr);
+  const bool use_b_st = (b_state != nullptr);
+
+  for (int i = threadIdx.x; i < vec_len; i += BLOCK_THREADS) {
+    VPack s2 = s2_vec[i];
+    VPack z2 = z2_vec[i];
+    VPack oc = outc_read[i];
+
+    VPack wst, bst;
+    if (use_w_st)
+      wst = wstate_vec[i];
+    if (use_b_st)
+      bst = bstate_vec[i];
+
+    VPack os;
+#pragma unroll
+    for (int k = 0; k < VEC; ++k) {
+      float cnew = DH::to_float(oc.v[k]);
+      float cn = cnew * inv_rms_cnew;
+      if (use_w_st)
+        cn *= DH::to_float(wst.v[k]);
+      if (use_b_st)
+        cn += DH::to_float(bst.v[k]);
+
+      float sact = (use_fast_gelu ? gelu_tanh(cn) : gelu_erf(cn));
+
+      float pre_o = DH::to_float(s2.v[k]) + alpha * DH::to_float(z2.v[k]);
+      float ogate = sigmoid_f(pre_o);
+
+      float st = sact * ogate;
+      os.v[k] = DH::from_float(st);
+    }
+    outs_vec[i] = os;
+  }
+}
+/* Scalar counterpart of sum_lstm_vec_kernel: same parameters and the same
+ * three passes, but each thread handles single elements j = threadIdx.x,
+ * threadIdx.x + BLOCK_THREADS, ... < D_eff, so no divisibility or alignment
+ * requirement follows from the indexing used here.
+ */
+template
+__global__ void sum_lstm_scalar_kernel(
+
+    typename DevHalf::type *__restrict__ out_state,
+    typename DevHalf::type *__restrict__ out_cell,
+
+    const typename DevHalf::type *__restrict__ states_4d,
+    const typename DevHalf::type *__restrict__ z4_4d,
+    const typename DevHalf::type *__restrict__ prev_cell,
+
+    const typename DevHalf::type *__restrict__ w_cell,
+    const typename DevHalf::type *__restrict__ b_cell,
+    const typename DevHalf::type *__restrict__ w_state,
+    const typename DevHalf::type *__restrict__ b_state,
+
+    int64_t states_row_stride, int64_t z4_row_stride, int64_t cell_row_stride,
+    int64_t out_row_stride, int D_eff, int D_gate,
+
+    float alpha, float eps_cell, float eps_state, int use_fast_gelu) {
+  using DH = DevHalf;
+  using H = typename DH::type;
+  using BlockReduce = cub::BlockReduce;
+
+  const int row = blockIdx.x;
+  const H *s = states_4d + row * states_row_stride;
+  const H *z4 = z4_4d + row * z4_row_stride;
+  const H *pc = prev_cell + row * cell_row_stride;
+
+  H *oc = out_cell + row * out_row_stride;
+  H *os = out_state + row * out_row_stride;
+  // ---- Pass 1: sum of squares of the segment-3 pre-activation ----
+  float local_ss_cpre = 0.f;
+  for (int j = threadIdx.x; j < D_eff; j += BLOCK_THREADS) {
+    float c = DH::to_float(s[3 * D_gate + j]) +
+              alpha * DH::to_float(z4[3 * D_gate + j]);
+    local_ss_cpre += c * c;
+  }
+
+  __shared__ typename BlockReduce::TempStorage temp0;
+  float sumsq_cpre = BlockReduce(temp0).Sum(local_ss_cpre); // consumed by thread 0 only
+
+  __shared__ float inv_rms_cpre;
+  if (threadIdx.x == 0) {
+    float mean = sumsq_cpre / static_cast(D_eff);
+    inv_rms_cpre = rsqrtf(mean + eps_cell);
+  }
+  __syncthreads();
+  // ---- Pass 2: new cell value, plus its sum of squares ----
+  float local_ss_cnew = 0.f;
+  for (int j = threadIdx.x; j < D_eff; j += BLOCK_THREADS) {
+    float pre_f = DH::to_float(s[0 * D_gate + j]) +
+                  alpha * DH::to_float(z4[0 * D_gate + j]);
+    float pre_i = DH::to_float(s[1 * D_gate + j]) +
+                  alpha * DH::to_float(z4[1 * D_gate + j]);
+    float cpre = DH::to_float(s[3 * D_gate + j]) +
+                 alpha * DH::to_float(z4[3 * D_gate + j]);
+
+    float fgate = sigmoid_f(pre_f);
+    float igate = sigmoid_f(pre_i);
+
+    float cn = cpre * inv_rms_cpre;
+    if (w_cell)
+      cn *= DH::to_float(w_cell[j]);
+    if (b_cell)
+      cn += DH::to_float(b_cell[j]);
+
+    float cact = (use_fast_gelu ? gelu_tanh(cn) : gelu_erf(cn));
+    float cnew = DH::to_float(pc[j]) * fgate + cact * igate;
+
+    oc[j] = DH::from_float(cnew);
+    local_ss_cnew += cnew * cnew;
+  }
+
+  __shared__ typename BlockReduce::TempStorage temp1;
+  float sumsq_cnew = BlockReduce(temp1).Sum(local_ss_cnew); // consumed by thread 0 only
+
+  __shared__ float inv_rms_cnew;
+  if (threadIdx.x == 0) {
+    float mean = sumsq_cnew / static_cast(D_eff);
+    inv_rms_cnew = rsqrtf(mean + eps_state);
+  }
+  __syncthreads();
+  // ---- Pass 3: output state, reading back the value stored in oc[j] by this thread ----
+  for (int j = threadIdx.x; j < D_eff; j += BLOCK_THREADS) {
+    float cn = DH::to_float(oc[j]) * inv_rms_cnew;
+    if (w_state)
+      cn *= DH::to_float(w_state[j]);
+    if (b_state)
+      cn += DH::to_float(b_state[j]);
+
+    float sact = (use_fast_gelu ?
gelu_tanh(cn) : gelu_erf(cn)); + float pre_o = DH::to_float(s[2 * D_gate + j]) + + alpha * DH::to_float(z4[2 * D_gate + j]); + float ogate = sigmoid_f(pre_o); + + float st = sact * ogate; + os[j] = DH::from_float(st); + } +} + +template +static std::tuple +sum_lstm_cuda_impl(const torch::Tensor &states_4d, const torch::Tensor &z4_4d, + const torch::Tensor &prev_cell_d, + const c10::optional &w_cell, + const c10::optional &b_cell, + const c10::optional &w_state, + const c10::optional &b_state, double alpha_d, + double eps_cell_d, double eps_state_d, bool use_fast_gelu) { + TORCH_CHECK(states_4d.is_cuda() && z4_4d.is_cuda() && prev_cell_d.is_cuda(), + "sum_lstm: inputs must be CUDA tensors"); + TORCH_CHECK(IsSupported::value, + "sum_lstm: dtype must be fp16 or bf16"); + + const auto dtype = states_4d.scalar_type(); + TORCH_CHECK(z4_4d.scalar_type() == dtype && + prev_cell_d.scalar_type() == dtype, + "sum_lstm: all input dtypes must match"); + + const int64_t hidden4 = states_4d.size(-1); + TORCH_CHECK(hidden4 > 0 && hidden4 % 4 == 0, + "sum_lstm: last dim of states must be 4*D_gate"); + const int64_t D_gate = hidden4 / 4; + + TORCH_CHECK(z4_4d.size(-1) == hidden4, + "sum_lstm: z4 must have last dim 4*D_gate"); + + const int64_t D_cell = prev_cell_d.size(-1); + TORCH_CHECK(D_cell == D_gate, + "sum_lstm: prev_cell last dim must equal D_gate. 
Got ", D_cell, + " vs expected ", D_gate, "."); + const int64_t D_eff = D_gate; + + TORCH_CHECK(states_4d.stride(-1) == 1 && z4_4d.stride(-1) == 1 && + prev_cell_d.stride(-1) == 1, + "sum_lstm: last dimension must be contiguous (stride -1 == 1)"); + + auto check_opt_len = [&](const c10::optional &t, + const char *name) { + if (t.has_value()) { + TORCH_CHECK(t->is_cuda(), "sum_lstm: ", name, " must be CUDA"); + TORCH_CHECK(t->scalar_type() == dtype, "sum_lstm: ", name, + " dtype mismatch"); + TORCH_CHECK(t->numel() >= D_eff, "sum_lstm: ", name, + " must have length >= D_eff"); + TORCH_CHECK(t->is_contiguous(), "sum_lstm: ", name, + " must be contiguous"); + } + }; + check_opt_len(w_cell, "w_cell"); + check_opt_len(b_cell, "b_cell"); + check_opt_len(w_state, "w_state"); + check_opt_len(b_state, "b_state"); + + auto s2 = states_4d.view({-1, hidden4}); + auto z2 = z4_4d.view({-1, hidden4}); + auto p2 = prev_cell_d.view({-1, D_cell}); + + const int64_t rows = s2.size(0); + const int64_t s_stride = s2.stride(0); + const int64_t z_stride = z2.stride(0); + const int64_t p_stride = p2.stride(0); + + auto out_cell = at::empty_strided(p2.sizes(), p2.strides(), p2.options()); + auto out_state = at::empty_strided(p2.sizes(), p2.strides(), p2.options()); + + const int64_t out_stride_cell = out_cell.stride(0); + const int64_t out_stride_state = out_state.stride(0); + TORCH_CHECK(out_stride_cell == out_stride_state, + "sum_lstm: internal - output strides mismatch"); + + auto out_cell_orig = out_cell.view(prev_cell_d.sizes()); + auto out_state_orig = out_state.view(prev_cell_d.sizes()); + + using H = typename DevHalf::type; + const auto stream = at::cuda::getCurrentCUDAStream(); + + H *out_state_ptr = reinterpret_cast(out_state.data_ptr()); + H *out_cell_ptr = reinterpret_cast(out_cell.data_ptr()); + const H *s_ptr = reinterpret_cast(s2.data_ptr()); + const H *z_ptr = reinterpret_cast(z2.data_ptr()); + const H *p_ptr = reinterpret_cast(p2.data_ptr()); + + const bool row_aligned_8 = 
+ (s_stride % 8 == 0) && (z_stride % 8 == 0) && (out_stride_cell % 8 == 0); + const bool row_aligned_4 = + (s_stride % 4 == 0) && (z_stride % 4 == 0) && (out_stride_cell % 4 == 0); + + const bool base_aligned_16 = + is_aligned(s_ptr, 16) && is_aligned(z_ptr, 16) && is_aligned(p_ptr, 16) && + is_aligned(out_state_ptr, 16) && is_aligned(out_cell_ptr, 16); + + const bool base_aligned_8 = + is_aligned(s_ptr, 8) && is_aligned(z_ptr, 8) && is_aligned(p_ptr, 8) && + is_aligned(out_state_ptr, 8) && is_aligned(out_cell_ptr, 8); + + const bool gates_div_8 = (D_gate % 8 == 0); + const bool gates_div_4 = (D_gate % 4 == 0); + + const bool can_vec8 = + (D_eff % 8 == 0) && gates_div_8 && row_aligned_8 && base_aligned_16; + + const bool can_vec4 = + (D_eff % 4 == 0) && gates_div_4 && row_aligned_4 && base_aligned_8; + + float alpha = static_cast(alpha_d); + float eps_cell = static_cast(eps_cell_d); + float eps_state = static_cast(eps_state_d); + int fast = use_fast_gelu ? 1 : 0; + + constexpr int BLK_128 = 128; + constexpr int BLK_256 = 256; + + dim3 grid(rows); + + auto wcell_ptr = + w_cell ? reinterpret_cast(w_cell->data_ptr()) : nullptr; + auto bcell_ptr = + b_cell ? reinterpret_cast(b_cell->data_ptr()) : nullptr; + auto wstate_ptr = + w_state ? reinterpret_cast(w_state->data_ptr()) : nullptr; + auto bstate_ptr = + b_state ? 
reinterpret_cast(b_state->data_ptr()) : nullptr; + +#define LAUNCH_VEC(V, BLK) \ + do { \ + dim3 block((BLK)); \ + sum_lstm_vec_kernel<<>>( \ + out_state_ptr, out_cell_ptr, s_ptr, z_ptr, p_ptr, wcell_ptr, \ + bcell_ptr, wstate_ptr, bstate_ptr, s_stride, z_stride, p_stride, \ + out_stride_cell, static_cast(D_eff), static_cast(D_gate), \ + alpha, eps_cell, eps_state, fast); \ + } while (0) + +#define LAUNCH_SCALAR(BLK) \ + do { \ + dim3 block((BLK)); \ + sum_lstm_scalar_kernel<<>>( \ + out_state_ptr, out_cell_ptr, s_ptr, z_ptr, p_ptr, wcell_ptr, \ + bcell_ptr, wstate_ptr, bstate_ptr, s_stride, z_stride, p_stride, \ + out_stride_cell, static_cast(D_eff), static_cast(D_gate), \ + alpha, eps_cell, eps_state, fast); \ + } while (0) + + if (can_vec8) { + LAUNCH_VEC(8, BLK_128); + } else if (can_vec4) { + LAUNCH_VEC(4, BLK_256); + } else { + LAUNCH_SCALAR(BLK_256); + } + +#undef LAUNCH_VEC +#undef LAUNCH_SCALAR + + C10_CUDA_KERNEL_LAUNCH_CHECK(); + return {out_state_orig, out_cell_orig}; +} + +std::tuple +sum_lstm_cuda(const torch::Tensor &states_4d, const torch::Tensor &z4_4d, + const torch::Tensor &prev_cell_d, + const c10::optional &w_cell, + const c10::optional &b_cell, + const c10::optional &w_state, + const c10::optional &b_state, double alpha, + double eps_cell, double eps_state, bool use_fast_gelu) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(states_4d)); + switch (states_4d.scalar_type()) { + case at::kHalf: + return sum_lstm_cuda_impl(states_4d, z4_4d, prev_cell_d, w_cell, + b_cell, w_state, b_state, alpha, + eps_cell, eps_state, use_fast_gelu); + case at::kBFloat16: + return sum_lstm_cuda_impl( + states_4d, z4_4d, prev_cell_d, w_cell, b_cell, w_state, b_state, alpha, + eps_cell, eps_state, use_fast_gelu); + default: + TORCH_CHECK(false, "sum_lstm: only fp16 and bf16 are supported."); + } +} \ No newline at end of file diff --git a/csrc/custom_ops/torch_bindings.cpp b/csrc/custom_ops/torch_bindings.cpp index bfba28a4c..a283e6c18 100644 --- 
a/csrc/custom_ops/torch_bindings.cpp
+++ b/csrc/custom_ops/torch_bindings.cpp
@@ -1,19 +1,53 @@
 #include "custom_ops.h"
+#include
 #include
 TORCH_LIBRARY(arctic_inference, ops) {
-  ops.def(
-      "reshape_and_cache_flash_bulk(Tensor keys,"
-      " Tensor values,"
-      " Tensor(c!)[] key_caches,"
-      " Tensor(d!)[] value_caches,"
-      " Tensor slot_mapping,"
-      " str kv_cache_dtype,"
-      " Tensor(e)[] k_scales,"
-      " Tensor(f)[] v_scales,"
-      " int num_heads,"
-      " int head_size) -> ()");
+  ops.def("reshape_and_cache_flash_bulk(Tensor keys,"
+          " Tensor values,"
+          " Tensor(c!)[] key_caches,"
+          " Tensor(d!)[] value_caches,"
+          " Tensor slot_mapping,"
+          " str kv_cache_dtype,"
+          " Tensor(e)[] k_scales,"
+          " Tensor(f)[] v_scales,"
+          " int num_heads,"
+          " int head_size) -> ()");
   ops.impl("reshape_and_cache_flash_bulk", torch::kCUDA,
            &reshape_and_cache_flash_bulk);
+  // reshape_and_cache_flash_fp4: the schema marks both caches and both scale caches with "!"; returns nothing.
+  ops.def("reshape_and_cache_flash_fp4(Tensor key,"
+          " Tensor value,"
+          " Tensor(c!) key_cache,"
+          " Tensor(d!) value_cache,"
+          " Tensor slot_mapping,"
+          " str kv_cache_dtype,"
+          " Tensor k_scale,"
+          " Tensor v_scale,"
+          " Tensor(e!) key_scale_cache,"
+          " Tensor(f!) value_scale_cache) -> ()");
+  ops.impl("reshape_and_cache_flash_fp4", torch::kCUDA,
+           &reshape_and_cache_flash_fp4);
+  // speculator_ln_cuda: weight and bias are optional ("Tensor?"); returns a new Tensor.
+  ops.def("speculator_ln_cuda(Tensor input,"
+          " Tensor? weight,"
+          " Tensor? bias,"
+          " float eps) -> Tensor");
+  ops.impl("speculator_ln_cuda", torch::kCUDA, &speculator_ln_cuda);
+  // sum_lstm_cuda: four optional scale / shift tensors; returns a (Tensor, Tensor) pair.
+  ops.def("sum_lstm_cuda(Tensor states_4d,"
+          " Tensor z4_4d,"
+          " Tensor prev_cell_d,"
+          " Tensor? w_cell,"
+          " Tensor? b_cell,"
+          " Tensor? w_state,"
+          " Tensor? b_state,"
+          " float alpha,"
+          " float eps_cell,"
+          " float eps_state,"
+          " bool use_fast_gelu) -> (Tensor, Tensor)");
+  ops.impl("sum_lstm_cuda", torch::kCUDA, &sum_lstm_cuda);
 }
+// Empty pybind11 module body; the operators above are registered through TORCH_LIBRARY.
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {}
diff --git a/csrc/custom_ops/vectorization.cuh b/csrc/custom_ops/vectorization.cuh
new file mode 100644
index 000000000..dd5cd063e
--- /dev/null
+++ b/csrc/custom_ops/vectorization.cuh
@@ -0,0 +1,29 @@
+#pragma once
+/**
+ * __device__ datatypes vectorized by 4
+ */
+
+// Include both AMD and NVIDIA fp8 types to avoid circular import
+#include
+#include
+
+namespace vllm {
+
+// Vectorization containers: fixed-size arrays aligned to their total byte size
+template
+struct __align__(vec_size * sizeof(scalar_t)) vec_n_t {
+  scalar_t val[vec_size];
+};
+// Same layout for quantized element types; the static_assert restricts quant_type_t (its type list is stripped in this copy).
+template
+struct __align__(vec_size * sizeof(quant_type_t)) q8_n_t {
+  static_assert(std::is_same_v ||
+                std::is_same_v ||
+                std::is_same_v);
+  quant_type_t val[vec_size];
+};
+// Convenience aliases (their template arguments are not visible in this copy of the patch).
+template using vec4_t = vec_n_t;
+template using q8x4_t = q8_n_t;
+
+} // namespace vllm
diff --git a/csrc/custom_ops/vectorization_utils.cuh b/csrc/custom_ops/vectorization_utils.cuh
new file mode 100644
index 000000000..be42dad34
--- /dev/null
+++ b/csrc/custom_ops/vectorization_utils.cuh
@@ -0,0 +1,176 @@
+#pragma once
+#include "vectorization.cuh"
+
+namespace vllm {
+// Default vector op: applies scalar_op(dst.val[i], src.val[i]) to each of the VEC_SIZE lanes.
+template
+struct DefaultVecOp {
+  ScaOp scalar_op;
+
+  __device__ __forceinline__ void
+  operator()(vec_n_t &dst,
+             const vec_n_t &src) const {
+#pragma unroll
+    for (int i = 0; i < VEC_SIZE; ++i) {
+      scalar_op(dst.val[i], src.val[i]);
+    }
+  }
+};
+/* Element-wise in -> out transform over `len` elements, shared by the threads
+ * that call it: each caller starts at index `tid` and advances by `stride`.
+ *
+ * Fast path: if `in` is WIDTH-byte aligned (WIDTH = VEC_SIZE * sizeof(InT))
+ * and len is a multiple of VEC_SIZE, the whole range is processed as packs
+ * with vec_op.
+ * Otherwise three phases run in order:
+ *   1. scalar_op on the leading elements up to the first WIDTH-aligned
+ *      address of `in` (at most len elements),
+ *   2. vec_op on the full packs that follow,
+ *   3. scalar_op on the remaining tail elements.
+ * Only the input address is inspected; `out` is advanced by the same element
+ * count and reinterpreted as a pack pointer without its own alignment test.
+ */
+template
+__device__ inline void
+vectorize_with_alignment(const InT *in, OutT *out, int len, int tid, int stride,
+                         VecOp &&vec_op,      // vec_n_t -> vec_n_t
+                         ScaOp &&scalar_op) { // InT -> OutT
+  static_assert(VEC_SIZE > 0 && (VEC_SIZE & (VEC_SIZE - 1)) == 0,
+                "VEC_SIZE must be a positive power-of-two");
+  constexpr int WIDTH = VEC_SIZE * sizeof(InT); // eg: 64 B
+  uintptr_t addr = reinterpret_cast(in);
+
+  // fast path when the whole region is already aligned
+  // Note: only `in` is tested for alignment; `out` is cast to a pack pointer
+  // without its own check, so callers must guarantee it is equally aligned.
+  bool can_vec = ((addr & (WIDTH - 1)) == 0) && ((len & (VEC_SIZE - 1)) == 0);
+  if (can_vec) {
+    int num_vec = len / VEC_SIZE;
+
+    using vin_t = vec_n_t;
+    using vout_t = vec_n_t;
+    auto *v_in = reinterpret_cast(in);
+    auto *v_out = reinterpret_cast(out);
+
+    for (int i = tid; i < num_vec; i += stride) {
+      vout_t tmp;
+      // Make a local copy of the entire pack
+      vin_t src = v_in[i]; // <- encourages a single vector ld
+      vec_op(tmp, src);
+      v_out[i] = tmp; // <- encourages a single vector st
+    }
+    return;
+  }
+
+  int misalignment_offset = addr & (WIDTH - 1);      // addr % 64
+  int alignment_bytes = WIDTH - misalignment_offset; // 64 - (addr % 64)
+  int prefix_elems = alignment_bytes & (WIDTH - 1);  // a full WIDTH (already aligned) wraps to 0
+  prefix_elems /= sizeof(InT);
+  prefix_elems = min(prefix_elems, len); // 0 ≤ prefix < 16
+
+  // 1. scalar prefix: elements before the first WIDTH-aligned address of `in`
+  for (int i = tid; i < prefix_elems; i += stride) {
+    scalar_op(out[i], in[i]);
+  }
+
+  in += prefix_elems;
+  out += prefix_elems;
+  len -= prefix_elems;
+
+  int num_vec = len / VEC_SIZE;
+  using vin_t = vec_n_t;
+  using vout_t = vec_n_t;
+  auto *v_in = reinterpret_cast(in);
+  auto *v_out = reinterpret_cast(out);
+
+  // 2. vectorize the main part
+  for (int i = tid; i < num_vec; i += stride) {
+    vout_t tmp;
+    // Make a local copy of the entire pack
+    vin_t src = v_in[i]; // <- encourages a single vector ld
+    vec_op(tmp, src);
+    v_out[i] = tmp; // <- encourages a single vector st
+  }
+
+  // 3. handle the tail
+  int tail_start = num_vec * VEC_SIZE;
+  for (int i = tid + tail_start; i < len; i += stride) {
+    scalar_op(out[i], in[i]);
+  }
+}
+// Overload taking only a scalar op; the vector op is built from it via DefaultVecOp.
+template
+__device__ __forceinline__ void
+vectorize_with_alignment(const InT *in, OutT *out, int len, int tid, int stride,
+                         ScaOp &&scalar_op) {
+  using Vec = DefaultVecOp>;
+  vectorize_with_alignment(in, out, len, tid, stride, Vec{scalar_op},
+                           std::forward(scalar_op));
+}
+// Default read-only vector op: calls scalar_op(src.val[i]) for each of the VEC_SIZE lanes.
+template struct DefaultReadVecOp {
+  ScaOp scalar_op;
+
+  __device__ __forceinline__ void
+  operator()(const vec_n_t &src) const {
+#pragma unroll
+    for (int i = 0; i < VEC_SIZE; ++i) {
+      scalar_op(src.val[i]);
+    }
+  }
+};
+// Same fast path / prefix / packs / tail structure as vectorize_with_alignment, without an output buffer.
+// read-only version: iterate over the input with alignment guarantees
+template
+__device__ inline void
+vectorize_read_with_alignment(const InT *in, int len, int tid, int stride,
+                              VecOp &&vec_op, ScaOp &&scalar_op) {
+  static_assert(VEC_SIZE > 0 && (VEC_SIZE & (VEC_SIZE - 1)) == 0,
+                "VEC_SIZE must be a positive power-of-two");
+  constexpr int WIDTH = VEC_SIZE * sizeof(InT);
+  uintptr_t addr = reinterpret_cast(in);
+
+  // fast path when the whole region is already aligned
+  bool can_vec = ((addr & (WIDTH - 1)) == 0) && ((len & (VEC_SIZE - 1)) == 0);
+  if (can_vec) {
+    int num_vec = len / VEC_SIZE;
+
+    using vin_t = vec_n_t;
+    auto *v_in = reinterpret_cast(in);
+
+    for (int i = tid; i < num_vec; i += stride) {
+      vin_t tmp = v_in[i];
+      vec_op(tmp);
+    }
+    return;
+  }
+
+  int misalignment_offset = addr & (WIDTH - 1);
+  int alignment_bytes = WIDTH - misalignment_offset;
+  int prefix_elems = alignment_bytes & (WIDTH - 1); // a full WIDTH (already aligned) wraps to 0
+  prefix_elems /= sizeof(InT);
+  prefix_elems = min(prefix_elems, len);
+
+  // 1. handle the possibly unaligned prefix with scalar access.
+  for (int i = tid; i < prefix_elems; i += stride) {
+    scalar_op(in[i]);
+  }
+
+  in += prefix_elems;
+  len -= prefix_elems;
+
+  int num_vec = len / VEC_SIZE;
+  using vin_t = vec_n_t;
+  auto *v_in = reinterpret_cast(in);
+
+  // 2. vectorized traversal of the main aligned region.
+  for (int i = tid; i < num_vec; i += stride) {
+    vec_op(v_in[i]);
+  }
+
+  // 3. handle remaining tail elements.
+  int tail_start = num_vec * VEC_SIZE;
+  for (int i = tid + tail_start; i < len; i += stride) {
+    scalar_op(in[i]);
+  }
+}
+
+// overload that requires only a scalar_op
+template
+__device__ __forceinline__ void
+vectorize_read_with_alignment(const InT *in, int len, int tid, int stride,
+                              ScaOp &&scalar_op) {
+  using Vec = DefaultReadVecOp>;
+  vectorize_read_with_alignment(in, len, tid, stride, Vec{scalar_op},
+                                std::forward(scalar_op));
+}
+
+} // namespace vllm
diff --git a/csrc/suffix_cache/pybind.cc b/csrc/suffix_cache/pybind.cc
deleted file mode 100644
index c2a13805e..000000000
--- a/csrc/suffix_cache/pybind.cc
+++ /dev/null
@@ -1,41 +0,0 @@
-// Copyright 2025 Snowflake Inc.
-// SPDX-License-Identifier: Apache-2.0
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
- -#include -#include - -#include "suffix_tree.h" - -namespace py = pybind11; - - -PYBIND11_MODULE(_C, m) { - py::class_(m, "Candidate") - .def_readwrite("token_ids", &Candidate::token_ids) - .def_readwrite("parents", &Candidate::parents) - .def_readwrite("probs", &Candidate::probs) - .def_readwrite("score", &Candidate::score) - .def_readwrite("match_len", &Candidate::match_len); - - py::class_(m, "SuffixTree") - .def(py::init()) - .def("num_seqs", &SuffixTree::num_seqs) - .def("append", &SuffixTree::append) - .def("extend", &SuffixTree::extend) - .def("remove", &SuffixTree::remove) - .def("speculate", &SuffixTree::speculate) - .def("check_integrity", &SuffixTree::check_integrity) - .def("estimate_memory", &SuffixTree::estimate_memory); -} diff --git a/csrc/suffix_cache/suffix_tree.cc b/csrc/suffix_cache/suffix_tree.cc deleted file mode 100644 index ca11a08fb..000000000 --- a/csrc/suffix_cache/suffix_tree.cc +++ /dev/null @@ -1,529 +0,0 @@ -// Copyright 2025 Snowflake Inc. -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include -#include -#include -#include -#include -#include -#include -#include "suffix_tree.h" - -#define CHECK_OR_RETURN(cond, msg) if (!(cond)) return msg; - -SuffixTree::SuffixTree(int max_depth) - : _max_depth(max_depth), _root(new Node()) { -} - -// Append a new element to a new or existing sequence. 
-void SuffixTree::append(int seq_id, int token) { - // Initialize the sequence if it doesn't exist. - _seqs.try_emplace(seq_id); - _active_nodes.try_emplace(seq_id); - - // Insert a new active node at the root. - _active_nodes[seq_id].push_back(_root.get()); - _root->endpoints[seq_id] = static_cast(_seqs[seq_id].size()); - _root->count += 1; - - // Ensure the number of active nodes doesn't exceed max_depth. - if (_active_nodes[seq_id].size() > static_cast(_max_depth)) { - _active_nodes[seq_id].pop_front(); - } - _seqs[seq_id].push_back(token); - - // Iterate over all active nodes for this sequence. - for (size_t i = 0; i < _active_nodes[seq_id].size(); ++i) { - Node* node = _active_nodes[seq_id][i]; - Node* child = nullptr; - if (node->children.contains(token)) { - child = node->children[token].get(); - } - - assert(node->endpoints.contains(seq_id)); - assert(node->endpoints[seq_id] == _seqs[seq_id].size() - 1); - - if (child == nullptr) { - // No existing child node for the new token. - if (node->count == 1 && node != _root.get()) { - // The active node has count = 1, which means the only suffix that ends here is the - // one that's being extended right now. Then this node should be a leaf node, and - // we can simply extend the length of this node. - assert(node->children.empty()); - assert(node->ref_seq == seq_id); - node->length += 1; - node->endpoints[seq_id] += 1; - } else { - // Either this is the root node, or the current suffix is not the only one that - // ends here. Either case, we need to extend the current suffix into a new child. 
- Node* new_child = new Node(); - new_child->token = token; - new_child->parent = node; - new_child->count = 1; - new_child->endpoints[seq_id] = static_cast(_seqs[seq_id].size()); - new_child->ref_seq = seq_id; - new_child->ref_idx = static_cast(_seqs[seq_id].size()) - 1; - new_child->length = 1; - node->children.emplace(token, new_child); - node->endpoints.erase(seq_id); - _active_nodes[seq_id][i] = new_child; - } - } - else if (node->count == child->count + 1 && node != _root.get()) { - // The active node has a child for the new token, and the child's count is exactly one - // fewer than the active node's count. Since the suffix for the active node ends here, - // that means all other suffixes that pass through this node must go to that child. - assert(node->children.size() == 1); // The active node should have only one child. - assert(node->endpoints.size() == 1); // Only the current suffix should end here. - if (child->length == 1) { - // The child only has length 1. If we append the new token to the current suffix, - // then it will perfectly overlap with the child. In this case, we should just fuse - // the current suffix into the child and eliminate the current node. - Node* parent = node->parent; - // Update child to take the place of the current node. - child->token = node->token; - child->count += 1; // Current suffix extends into the child - child->length = node->length + 1; - child->endpoints[seq_id] = static_cast(_seqs[seq_id].size()); - child->ref_seq = seq_id; - child->ref_idx = static_cast(_seqs[seq_id].size()) - child->length; - child->parent = parent; - // Give ownership of child pointer to parent and should also free the current node. - assert(parent->children.contains(child->token)); - assert(parent->children[child->token].get() == node); - Node* tmp = node->children[token].release(); - parent->children[child->token].reset(tmp); - // Replace active node with child node. 
- _active_nodes[seq_id][i] = child; - } else { - // The child has length > 1. If we append the new token to the current suffix, then - // it still does not reach the child node. In this case, we keep both nodes but - // extend the length of the current node by 1 into the child node. - node->length += 1; - node->endpoints[seq_id] += 1; - node->ref_seq = seq_id; - node->ref_idx = static_cast(_seqs[seq_id].size()) - node->length; - child->length -= 1; - child->ref_idx += 1; - // The child node's first token should be updated to its second token. - child->token = _seqs[child->ref_seq][child->ref_idx]; - if (child->token != token) { - Node* tmp = node->children[token].release(); - node->children.emplace(child->token, tmp); - node->children.erase(token); - } - } - } - else { - // There is a child for the new token, and should move the active node into that child. - if (child->length == 1) { - // The child node has length 1, just update the active node pointer to it. - node->endpoints.erase(seq_id); - child->count += 1; - child->endpoints[seq_id] = static_cast(_seqs[seq_id].size()); - child->ref_seq = seq_id; - child->ref_idx = static_cast(_seqs[seq_id].size()) - 1; - _active_nodes[seq_id][i] = child; - } else { - // The child node has length > 1. If we extend the current suffix into it, then it - // must be split into a segment of length 1 and another segment with the remainder. - Node* new_node = new Node(); - new_node->token = token; - new_node->count = child->count + 1; - new_node->parent = node; - new_node->length = 1; - new_node->endpoints[seq_id] = static_cast(_seqs[seq_id].size()); - new_node->ref_seq = seq_id; - new_node->ref_idx = static_cast(_seqs[seq_id].size()) - new_node->length; - // The child node's first token should be updated to its second token. 
- child->token = _seqs[child->ref_seq][child->ref_idx + 1]; - Node* tmp = node->children[token].release(); - new_node->children.emplace(child->token, tmp); - node->children[token].reset(new_node); - node->endpoints.erase(seq_id); - child->parent = new_node; - child->length -= 1; - child->ref_idx += 1; - _active_nodes[seq_id][i] = new_node; - } - } - } -} - -// Extend a new or existing sequence. -void SuffixTree::extend(int seq_id, const std::vector& tokens) { - for (int token : tokens) { - append(seq_id, token); - } -} - -// Remove an existing sequence. -void SuffixTree::remove(int seq_id) { - const std::vector& seq = _seqs[seq_id]; - std::vector path; // Declare here to avoid repeated allocations. - // Loop through all suffix starting indices. - for (int start = 0; start < seq.size(); start++) { - Node *node = _root.get(); - node->count--; - int idx = start; - path.clear(); - // Loop through the nodes for this suffix. - while (idx < seq.size()) { - int token = seq[idx]; - if (!node->children.contains(token)) { - break; - } - Node* child = node->children[token].get(); - assert(child->count > 0); - child->count--; - if (child->count == 0) { - node->children.erase(token); - break; - } - if (child->endpoints.contains(seq_id)) { - child->endpoints.erase(seq_id); - } - idx += child->length; - node = child; - path.push_back(node); - } - // The last visited node may be mergeable with its child. - if (node != _root.get() && node->children.size() == 1) { - const auto& it = *node->children.begin(); - std::unique_ptr& child_uptr = node->children[it.first]; - if (node->count == child_uptr->count) { - // Merge node into child. - child_uptr->token = node->token; - child_uptr->length += node->length; - child_uptr->ref_idx -= node->length; - child_uptr->parent = node->parent; - path.back() = node = child_uptr.release(); - node->parent->children[node->token].reset(node); - } - } - // ref_seq and ref_idx of all nodes in the path may need to be updated. - // 1. 
Go to an arbitrary leaf to get its endpoints. - Node* leaf = node; - int distance = 0; // Distance from node to leaf. - while (!leaf->children.empty()) { - leaf = (*leaf->children.begin()).second.get(); - distance += leaf->length; - } - // 2. Pick an arbitrary endpoint for the reference sequence and index. - if (leaf->endpoints.empty() || leaf->endpoints.contains(seq_id)) { - // Still need to visit this leaf later when removing this sequence. - // We can skip updating the refs until the next time it's visited. - continue; - } - const auto& ref = *leaf->endpoints.begin(); - // 3. Go back up the path to update all nodes' refs. - int32_t ref_seq = ref.first; - int32_t ref_idx = ref.second - distance; - while (!path.empty()) { - Node* n = path.back(); - path.pop_back(); - ref_idx -= n->length; - if (n->ref_seq == seq_id) { - n->ref_seq = ref_seq; - n->ref_idx = ref_idx; - } - } - } - _seqs.erase(seq_id); - _active_nodes.erase(seq_id); -} - -Candidate SuffixTree::speculate(const std::vector& pattern, - int max_spec_tokens, - float max_spec_factor, - float max_spec_offset, - float min_token_prob, - bool use_tree_spec) { - Candidate result; - int start_idx = std::max(static_cast(pattern.size()) - _max_depth, 0); - for ( ; start_idx < pattern.size(); start_idx++) { - auto[node, idx] = _match_pattern(pattern, start_idx); - if (node == nullptr) { - continue; - } - int match_len = static_cast(pattern.size()) - start_idx; - int max_tokens = std::min(max_spec_tokens, - static_cast(match_len * max_spec_factor - + max_spec_offset + 1e-6)); - max_tokens = std::max(max_tokens, 0); - Candidate candidate; - if (use_tree_spec) { - candidate = _speculate_tree(node, idx, max_tokens, min_token_prob); - } else { - candidate = _speculate_path(node, idx, max_tokens, min_token_prob); - } - if (candidate.score > result.score) { - result = std::move(candidate); - result.match_len = match_len; - } - } - return result; -} - -std::string SuffixTree::check_integrity() { - // 1. 
Check structural integrity of all nodes. - std::queue queue; - queue.push(_root.get()); - while (!queue.empty()) { - Node* node = queue.front(); - queue.pop(); - std::string ret = _check_node_integrity(node); - if (!ret.empty()) { - return ret; - } - for (const auto& [token, child] : node->children) { - queue.push(child.get()); - } - } - // 2. Check all sequences are represented in the tree. - std::unordered_map visit_count; - for (int seq_id = 0; seq_id < _seqs.size(); seq_id++) { - const std::vector& seq = _seqs[seq_id]; - // Loop through all suffix starting indices. - for (int start = 0; start < seq.size(); start++) { - int idx = start; - // Traverse the tree along this suffix. - Node* node = _root.get(); - visit_count[node]++; - while (idx < seq.size() && idx - start < _max_depth) { - CHECK_OR_RETURN(node->children.contains(seq[idx]), - "missing child node for sequence"); - node = node->children[seq[idx]].get(); - visit_count[node]++; - CHECK_OR_RETURN(idx + node->length <= seq.size(), - "path exceeds sequence length"); - for (int i = 0; i < node->length; ++i) { - int ref_seq = node->ref_seq; - int ref_idx = node->ref_idx + i; - CHECK_OR_RETURN(seq[idx + i] == _seqs[ref_seq][ref_idx], - "path does not match sequence tokens"); - } - idx += node->length; - } - // The last node on this path should have an endpoint. - CHECK_OR_RETURN(node->endpoints.contains(seq_id), - "missing endpoint for sequence"); - } - } - // 3. Check all nodes were visited the correct number of times. 
- assert(queue.empty()); - queue.push(_root.get()); - while (!queue.empty()) { - Node* node = queue.front(); - queue.pop(); - CHECK_OR_RETURN(node->count == visit_count[node], - "node count does not match visit count"); - for (const auto& [token, child] : node->children) { - queue.push(child.get()); - } - } - return ""; -} - -std::string SuffixTree::_check_node_integrity(Node* node) { - int64_t children_count = 0; - for (const auto& [token, child] : node->children) { - // Do all my children have me as their parent? - CHECK_OR_RETURN(child->parent == node, "child node has incorrect parent pointer"); - children_count++; - } - // Is my counter at least the sum of my childrens' counters? - CHECK_OR_RETURN(children_count <= node->count, "node count is less than sum children counts"); - if (node == _root.get()) { - // Root node can stop here after some simple checks. - CHECK_OR_RETURN(node->count >= 0, "root node has negative count"); - CHECK_OR_RETURN(node->parent == nullptr, "root node has non-null parent pointer"); - CHECK_OR_RETURN(node->length == 0, "root node has non-zero length"); - CHECK_OR_RETURN(node->endpoints.empty(), "root node has non-empty endpoints"); - CHECK_OR_RETURN(node->ref_idx == -1, "root node has invalid ref_idx"); - return ""; - } - // Is my length positive? Otherwise, I shouldn't exist. - CHECK_OR_RETURN(node->length > 0, "internal node has non-positive length"); - // Is my count positive? Otherwise, I shouldn't exist. - CHECK_OR_RETURN(node->count > 0, "internal node has non-positive count"); - // Are all my children's counts less than mine? If equal, then we should have been merged. - for (const auto& [token, child] : node->children) { - CHECK_OR_RETURN( - child->count < node->count, "internal node count is not greater than child count"); - } - // Check my reference sequence and index. 
- CHECK_OR_RETURN(_seqs.count(node->ref_seq), "internal node has invalid ref_seq"); - CHECK_OR_RETURN(node->ref_idx >= 0, "internal node has invalid ref_idx"); - CHECK_OR_RETURN(node->ref_idx + node->length <= _seqs[node->ref_seq].size(), - "internal node has invalid token range"); - // Check my first token is correct. - CHECK_OR_RETURN(node->token == _seqs[node->ref_seq][node->ref_idx], - "internal node has incorrect first token"); - // Check I am my parent's child. - CHECK_OR_RETURN(node->parent->children.contains(node->token), - "internal node is not a child of parent node"); - CHECK_OR_RETURN(node->parent->children[node->token].get() == node, - "parent node has incorrect child pointer"); - // Check all my endpoint references are correct. - for (auto [seq_id, end_idx] : node->endpoints) { - CHECK_OR_RETURN(_seqs.count(seq_id), "node endpoint refers to nonexistent sequence"); - CHECK_OR_RETURN(end_idx > 0 && end_idx <= _seqs[seq_id].size(), "invalid endpoint index"); - // Check all tokens from the start of the suffix to the endpoint. 
- Node* n = node; - int idx = end_idx; - do { - CHECK_OR_RETURN(n->length <= idx, "invalid endpoint length"); - idx -= n->length; - for (int i = 0; i < n->length; ++i) { - int tok = _seqs[n->ref_seq][n->ref_idx + i]; - CHECK_OR_RETURN(_seqs[seq_id][idx + i] == tok, "invalid endpoint token"); - } - n = n->parent; - } while (n != nullptr); - } - return ""; -} - -std::pair SuffixTree::_match_pattern( - const std::vector& pattern, int start_idx) { - Node* node = _root.get(); - int idx = 0; - for (int i = start_idx; i < pattern.size(); i++) { - int c = pattern[i]; - if (idx >= node->length) { - if (!node->children.contains(c)) { - return {nullptr, -1}; - } - node = node->children[c].get(); - idx = 0; - } - assert(idx < node->length); - if (_seqs[node->ref_seq][node->ref_idx + idx] != c) { - return {nullptr, -1}; - } - idx++; - } - return {node, idx}; -} - -Candidate SuffixTree::_speculate_path(Node* node, int idx, - int max_spec_tokens, - float min_token_prob) { - Candidate ret; - float prob = 1.0f; - while (ret.token_ids.size() < max_spec_tokens && prob >= min_token_prob) { - if (idx < node->length) { - // Use previous token index as parent; if none, mark as -1. - ret.parents.push_back(static_cast(ret.token_ids.size()) - 1); - int token = _seqs[node->ref_seq][node->ref_idx + idx]; - ret.token_ids.push_back(token); - ret.probs.push_back(prob); - ret.score += prob; - idx++; - } else { - Node* child = nullptr; - int64_t count = 0; - // Choose the child with the maximum count. - for (const auto& kv : node->children) { - Node* ch = kv.second.get(); - if (ch->count > count) { - child = ch; - count = ch->count; - } - } - if (child == nullptr) { - break; - } - prob *= static_cast(count) / node->count; - node = child; - idx = 0; - } - } - return ret; -} - -struct HeapItem { - float prob; - Node* node; - int idx; - int parent; // index in the candidate token list; -1 if none. 
- - HeapItem(float p, Node* n, int i, int par) - : prob(p), node(n), idx(i), parent(par) {} -}; - -struct HeapItemCompare { - bool operator()(const HeapItem& a, const HeapItem& b) const { - // In C++ priority_queue by default returns the largest element. - // Thus, we compare probabilities so that the highest prob is returned. - return a.prob < b.prob; - } -}; - -// Get a candidate token tree using a priority queue. -Candidate SuffixTree::_speculate_tree(Node* node, int idx, - int max_spec_tokens, - float min_token_prob) { - Candidate ret; - std::priority_queue, HeapItemCompare> queue; - queue.emplace(1.0, node, idx, -1); - while (ret.token_ids.size() < max_spec_tokens && !queue.empty()) { - HeapItem item = queue.top(); - queue.pop(); - if (item.idx < item.node->length) { - int token = _seqs[item.node->ref_seq][item.node->ref_idx + item.idx]; - ret.token_ids.push_back(token); - ret.parents.push_back(item.parent); - ret.probs.push_back(item.prob); - ret.score += item.prob; - queue.emplace(item.prob, item.node, item.idx + 1, - static_cast(ret.token_ids.size()) - 1); - } else { - for (const auto& kv : item.node->children) { - Node* child = kv.second.get(); - float prob = item.prob * child->count / - static_cast(item.node->count); - if (prob >= min_token_prob) { - queue.emplace(prob, child, 0, item.parent); - } - } - } - } - return ret; -} - -size_t SuffixTree::estimate_memory() const { - size_t total = sizeof(*this); - std::vector stack; - stack.push_back(_root.get()); - while (!stack.empty()) { - Node* node = stack.back(); - stack.pop_back(); - total += node->memory_usage(); - for (const auto& [token, child] : node->children) { - stack.push_back(child.get()); - } - } - for (const auto& [seq_id, seq] : _seqs) { - total += sizeof(decltype(seq)::value_type) * seq.capacity(); - } - for (const auto& [seq_id, active_nodes] : _active_nodes) { - total += sizeof(decltype(active_nodes)::value_type) * active_nodes.size(); - } - return total; -} diff --git 
a/csrc/suffix_cache/CMakeLists.txt b/csrc/suffix_decoding/CMakeLists.txt similarity index 64% rename from csrc/suffix_cache/CMakeLists.txt rename to csrc/suffix_decoding/CMakeLists.txt index 92ea3c150..1e0fcd03a 100644 --- a/csrc/suffix_cache/CMakeLists.txt +++ b/csrc/suffix_decoding/CMakeLists.txt @@ -13,10 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -cmake_minimum_required(VERSION 3.14) -project(SuffixCache CXX) +cmake_minimum_required(VERSION 3.18) -set(CMAKE_CXX_STANDARD 17) +project(SuffixDecoding CXX) + +set(CMAKE_CXX_STANDARD 20) set(CMAKE_CXX_STANDARD_REQUIRED ON) if(NOT CMAKE_BUILD_TYPE AND NOT CMAKE_CONFIGURATION_TYPES) set(CMAKE_BUILD_TYPE "Release" CACHE STRING "Choose the type of build." FORCE) @@ -24,10 +25,14 @@ if(NOT CMAKE_BUILD_TYPE AND NOT CMAKE_CONFIGURATION_TYPES) endif() set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -g3 -ggdb") -find_package(pybind11 REQUIRED) +# Python dependencies for nanobind +find_package(Python 3.10 COMPONENTS Interpreter Development.Module REQUIRED) + +# Detect the installed nanobind package and import it into CMake +execute_process( + COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir + OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE nanobind_ROOT) -pybind11_add_module(_C pybind.cc suffix_tree.cc) +find_package(nanobind CONFIG REQUIRED) -# EXAMPLE_VERSION_INFO is defined by setup.py and passed into the C++ code as a -# define (VERSION_INFO) here. -target_compile_definitions(_C PRIVATE VERSION_INFO=${EXAMPLE_VERSION_INFO}) +nanobind_add_module(_C NOMINSIZE bindings.cc suffix_tree.cc) diff --git a/csrc/suffix_decoding/bindings.cc b/csrc/suffix_decoding/bindings.cc new file mode 100644 index 000000000..b1d058cf0 --- /dev/null +++ b/csrc/suffix_decoding/bindings.cc @@ -0,0 +1,103 @@ +// Copyright 2025 Snowflake Inc. 
+// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include "suffix_tree.h" + +namespace nb = nanobind; + +using Int32Array1D = nb::ndarray, + nb::device::cpu, nb::any_contig>; + + +void extend_ndarray(SuffixTree& tree, + int seq_id, + const Int32Array1D& tokens) { + tree.extend( + seq_id, + std::span(tokens.data(), tokens.size())); +} + + +void extend_vector(SuffixTree& tree, + int seq_id, + const std::vector& tokens) { + tree.extend(seq_id, std::span(tokens)); +} + + +Draft speculate_ndarray(SuffixTree& tree, + const Int32Array1D& context, + int max_spec_tokens, + float max_spec_factor, + float max_spec_offset, + float min_token_prob, + bool use_tree_spec) { + return tree.speculate( + std::span(context.data(), context.size()), + max_spec_tokens, + max_spec_factor, + max_spec_offset, + min_token_prob, + use_tree_spec); +} + + +Draft speculate_vector(SuffixTree& tree, + const std::vector& context, + int max_spec_tokens, + float max_spec_factor, + float max_spec_offset, + float min_token_prob, + bool use_tree_spec) { + return tree.speculate( + std::span(context), + max_spec_tokens, + max_spec_factor, + max_spec_offset, + min_token_prob, + use_tree_spec); +} + + +NB_MODULE(_C, m) { + nb::set_leak_warnings(false); + + nb::class_(m, "Draft") + .def_rw("token_ids", &Draft::token_ids) + .def_rw("parents", &Draft::parents) + .def_rw("probs", &Draft::probs) + .def_rw("score", 
&Draft::score) + .def_rw("match_len", &Draft::match_len); + + nb::class_(m, "SuffixTree") + .def(nb::init()) + .def("num_seqs", &SuffixTree::num_seqs) + .def("remove", &SuffixTree::remove) + // Overloads for extend method. Use different names to avoid overload + // resolution overhead at run-time. + .def("extend", &extend_vector) + .def("extend_ndarray", &extend_ndarray) + // Overloads for speculate method. + .def("speculate", &speculate_vector) + .def("speculate_ndarray", &speculate_ndarray) + // Debugging methods, not meant to be used in critical loop. + .def("check_integrity", &SuffixTree::check_integrity) + .def("estimate_memory", &SuffixTree::estimate_memory); +} diff --git a/csrc/suffix_cache/int32_map.h b/csrc/suffix_decoding/int32_map.h similarity index 80% rename from csrc/suffix_cache/int32_map.h rename to csrc/suffix_decoding/int32_map.h index 84fe6b1fd..f1081ad76 100644 --- a/csrc/suffix_cache/int32_map.h +++ b/csrc/suffix_decoding/int32_map.h @@ -24,6 +24,9 @@ #include #include +template +class Int32MapIterator; + /* * A simple hash map with int32_t keys that's designed to be fast and compact: * - Open addressing with triangular probing allows high load factors. 
@@ -33,7 +36,6 @@ template class Int32Map { public: - using const_iterator_value = std::pair; Int32Map() = default; @@ -168,47 +170,21 @@ class Int32Map { return sizeof(*this) + sizeof(Slot) * cap_; } - class const_iterator { - public: - using value_type = const_iterator_value; - using difference_type = std::ptrdiff_t; - using iterator_category = std::forward_iterator_tag; - - const_iterator() : m_(nullptr), i_(0) {} - - const_iterator(const Int32Map* m, uint32_t i) : m_(m), i_(i) { - advance_(); - } - - value_type operator*() const { - const Slot& s = m_->slots_[i_]; - return { s.key, *m_->value_ptr_(s) }; - } + /* Iterators */ - const_iterator& operator++() { - ++i_; - advance_(); - return *this; - } + friend class Int32MapIterator; + friend class Int32MapIterator; - bool operator==(const const_iterator& o) const { - return m_ == o.m_ && i_ == o.i_; - } + using iterator = Int32MapIterator; + using const_iterator = Int32MapIterator; - bool operator!=(const const_iterator& o) const { - return !(*this == o); - } + iterator begin() { + return iterator(this, 0); + } - private: - void advance_() { - const uint32_t c = m_ ? 
m_->cap_ : 0u; - while (m_ && i_ < c && !m_->is_filled_(m_->slots_[i_].key)) { - ++i_; - } - } - const Int32Map* m_; - uint32_t i_; - }; + iterator end() { + return iterator(this, cap_); + } const_iterator begin() const { return const_iterator(this, 0); @@ -219,11 +195,39 @@ class Int32Map { } const_iterator cbegin() const { - return begin(); + return const_iterator(this, 0); } const_iterator cend() const { - return end(); + return const_iterator(this, cap_); + } + + iterator find(int32_t key) { + if (key == KEY_EMPTY || key == KEY_TOMBSTONE) { + throw std::invalid_argument("invalid key"); + } + if (!slots_) { + return end(); + } + uint32_t idx; + if (!probe_insert_or_find_(key, idx)) { + return end(); + } + return iterator(this, idx); + } + + const_iterator find(int32_t key) const { + if (key == KEY_EMPTY || key == KEY_TOMBSTONE) { + throw std::invalid_argument("invalid key"); + } + if (!slots_) { + return end(); + } + uint32_t idx; + if (!probe_insert_or_find_(key, idx)) { + return end(); + } + return const_iterator(this, idx); } private: @@ -386,3 +390,60 @@ class Int32Map { // size_ unchanged } }; + + +template +class Int32MapIterator { +public: + using map_type = std::conditional_t, Int32Map>; + using ref_type = std::conditional_t; + using value_type = std::pair; + using difference_type = std::ptrdiff_t; + using iterator_category = std::forward_iterator_tag; + + struct arrow_proxy { + value_type value; + const value_type* operator->() const { + return &value; + } + }; + + Int32MapIterator() : m_(nullptr), i_(0) {} + Int32MapIterator(map_type* m, uint32_t i) : m_(m), i_(i) { advance_(); } + + value_type operator*() const { + auto& s = m_->slots_[i_]; + return { s.key, *m_->value_ptr_(s) }; + } + + arrow_proxy operator->() const { + auto& s = m_->slots_[i_]; + return arrow_proxy{ { s.key, *m_->value_ptr_(s) } }; + } + + Int32MapIterator& operator++() { + ++i_; + advance_(); + return *this; + } + + template + bool operator==(const Int32MapIterator& o) const { + 
return m_ == o.m_ && i_ == o.i_; + } + + template + bool operator!=(const Int32MapIterator& o) const { + return !(*this == o); + } + +private: + void advance_() { + uint32_t c = m_ ? m_->cap_ : 0u; + while (m_ && i_ < c && !m_->is_filled_(m_->slots_[i_].key)) { + ++i_; + } + } + map_type* m_; + uint32_t i_; +}; diff --git a/csrc/suffix_decoding/suffix_tree.cc b/csrc/suffix_decoding/suffix_tree.cc new file mode 100644 index 000000000..3023983b3 --- /dev/null +++ b/csrc/suffix_decoding/suffix_tree.cc @@ -0,0 +1,933 @@ +// Copyright 2025 Snowflake Inc. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include +#include "suffix_tree.h" + +#define CHECK_OR_RETURN(cond) \ + if (!(cond)) return "Integrity check failed at line " + \ + std::to_string(__LINE__) + ": " + #cond; + +SuffixTree::SuffixTree(int max_depth) + : _max_depth(max_depth), _root(new Node()) { +} + +void _remove_from_siblings(Node* node) { + // Remove a node from the siblings and groups linked lists. + assert(node->parent); // Should only be called on non-root nodes. + // Take care of the groups linked list. + Group* group = node->group.get(); + if (group->head == node) { + if (node->next_sibling && node->next_sibling->count == node->count) { + // There are other nodes in the same group, update its head and + // remove the node from the group. 
+ group->head = node->next_sibling; + node->group.reset(); + } else { + // Otherwise, the node is the only member of its group. Remove the + // group together with the node. + if (group->prev) { + group->prev->next = group->next; + } + if (group->next) { + group->next->prev = group->prev; + } + group->prev = group->next = nullptr; + } + } else { + // The node is not the head of its group, just remove it. + node->group.reset(); + } + // Take care of the siblings linked list. + if (node->next_sibling) { + node->next_sibling->prev_sibling = node->prev_sibling; + } else { + node->parent->tail_child = node->prev_sibling; + } + if (node->prev_sibling) { + node->prev_sibling->next_sibling = node->next_sibling; + } else { + node->parent->head_child = node->next_sibling; + } + node->prev_sibling = node->next_sibling = nullptr; +} + +void _insert_into_siblings_before(Node* node, Node* other) { + // Insert a node before another in the siblings and groups linked lists. + assert(node->parent); // Should only be called on non-root nodes. + assert(node->parent == other->parent); // Should be siblings. + // Take care of the siblings linked list. + if (other->prev_sibling) { + other->prev_sibling->next_sibling = node; + } else { + node->parent->head_child = node; + } + node->next_sibling = other; + node->prev_sibling = other->prev_sibling; + other->prev_sibling = node; + // Take care of the groups linked list. + Node* prev_sibling = node->prev_sibling; + if (prev_sibling && node->count == prev_sibling->count) { + // If the previous sibling has the same count, just join its group. + node->group = prev_sibling->group; // std::shared_ptr assignment + } else if (node->count == other->count) { + // Previous sibling has different count, but next sibling has the same + // count. Join as the head of the next sibling's group. + node->group = other->group; // std::shared_ptr assignment + node->group->head = node; + } else { + // Previous and next siblings both have different counts. 
The node + // belongs in a group by itself. + Group* group = node->group.get(); + if (!group) { + // The node does not come with a group, create a new one. + group = new Group(node); + node->group.reset(group); // std::shared_ptr assignment + } + assert(group->head == node && !group->next && !group->prev); + // Insert the node's group into the linked list. + if (prev_sibling) { + group->prev = prev_sibling->group.get(); + group->prev->next = group; + } + group->next = other->group.get(); + group->next->prev = group; + } +} + +void _insert_into_siblings_after(Node* node, Node* other) { + // Insert a node after another in the siblings and groups linked lists. + assert(node->parent); // Should only be called on non-root nodes. + assert(node->parent == other->parent); // Should be siblings. + // Take care of the siblings linked list. + if (other->next_sibling) { + other->next_sibling->prev_sibling = node; + } else { + node->parent->tail_child = node; + } + node->prev_sibling = other; + node->next_sibling = other->next_sibling; + other->next_sibling = node; + // Take care of the groups linked list. + Node* next_sibling = node->next_sibling; + if (next_sibling && node->count == next_sibling->count) { + // If the next sibling has the same count, join its group and maybe + // update the head of the group. + node->group = next_sibling->group; // std::shared_ptr assignment + if (node->group->head == next_sibling) { + node->group->head = node; + } + } else if (node->count == other->count) { + // Next sibling has different count, but previous sibling has the same + // count. Join as the tail of the previous sibling's group. + node->group = other->group; // std::shared_ptr assignment + } else { + // Previous and next siblings both have different counts. The node + // belongs in a group by itself. + Group* group = node->group.get(); + if (!group) { + // The node does not come with a group, create a new one. 
+ group = new Group(node); + node->group.reset(group); // std::shared_ptr assignment + } + assert(group->head == node && !group->next && !group->prev); + // Insert the node's group into the linked list. + if (next_sibling) { + group->next = next_sibling->group.get(); + group->next->prev = group; + } + group->prev = other->group.get(); + group->prev->next = group; + } +} + +void _replace_in_siblings(Node* old_node, Node* new_node) { + // Replace a node with another in the siblings and groups linked lists. + assert(old_node->count == new_node->count); // Should have the same count. + assert(old_node->parent); // Should only be called on non-root nodes. + // Take care of the siblings linked list. + if (old_node->next_sibling) { + old_node->next_sibling->prev_sibling = new_node; + } else { + old_node->parent->tail_child = new_node; + } + if (old_node->prev_sibling) { + old_node->prev_sibling->next_sibling = new_node; + } else { + old_node->parent->head_child = new_node; + } + new_node->prev_sibling = old_node->prev_sibling; + new_node->next_sibling = old_node->next_sibling; + old_node->prev_sibling = old_node->next_sibling = nullptr; + // Take care of the groups linked list. + Group* group = old_node->group.get(); + if (group->head == old_node) { + group->head = new_node; + } + new_node->group = old_node->group; // std::shared_ptr assignment + old_node->group.reset(); +} + +void _increment_count(Node* node) { + // Increment the count of a node by 1, and update its position in the + // sibling and group linked lists if necessary. + if (!node->parent) { + // Root node has no siblings, update its count and return. + node->count += 1; + return; + } + if (!node->prev_sibling || node->prev_sibling->count > node->count + 1) { + // The node does not need to move, and will not join the previous group + // after its count is incremented. 
+ assert(node->group->head == node); + if (!node->next_sibling || node->next_sibling->count < node->count) { + // The node should be the only member of its group and will not + // join the previous group, so just update its count. + assert(node->group.use_count() == 1); + node->count += 1; + } else { + // The node will split off from its current group to a new group. + assert(node->next_sibling->count == node->count); + Group* orig_group = node->group.get(); + orig_group->head = node->next_sibling; + Group* new_group = new Group(node); + new_group->next = orig_group; + if (orig_group->prev) { + new_group->prev = orig_group->prev; + new_group->prev->next = new_group; + } + orig_group->prev = new_group; + node->group.reset(new_group); // std::shared_ptr assignment + node->count += 1; + } + } else { + // The node needs to be moved. + assert(node->prev_sibling->count >= node->count); + Node* other_node = node->prev_sibling->group->head; + _remove_from_siblings(node); + node->count += 1; + _insert_into_siblings_before(node, other_node); + } +} + +void _decrement_count(Node* node) { + // Decrement the count of a node by 1, and update its position in the + // sibling and group linked lists if necessary. + assert(node->count > 0); + if (!node->parent) { + // Root node has no siblings, update its count and return. + node->count -= 1; + return; + } + if (!node->next_sibling || node->next_sibling->count < node->count - 1) { + // The node does not need to move, and will not join the next group + // after its count is decremented. + if (!node->prev_sibling || node->prev_sibling->count > node->count) { + // The node should be the only member of its group and will not + // join the next group, so just update its count. + assert(node->group.use_count() == 1); + node->count -= 1; + } else { + // The node will split off from its current group to a new group. 
+ assert(node->prev_sibling->count == node->count); + Group* orig_group = node->group.get(); + Group* new_group = new Group(node); + new_group->prev = orig_group; + if (orig_group->next) { + new_group->next = orig_group->next; + new_group->next->prev = new_group; + } + orig_group->next = new_group; + node->group.reset(new_group); // std::shared_ptr assignment + node->count -= 1; + } + } else if (node->next_sibling->count == node->count - 1) { + // The node does not need to move, and will join the same group as its + // next sibling. + assert(node->next_sibling->group->head == node->next_sibling); + node->next_sibling->group->head = node; + if (node->group->head == node) { + // The node is the head of its group, so the group will be removed. + assert(node->group.use_count() == 1); + Group* group = node->group.get(); + if (group->prev) { + group->prev->next = group->next; + } + group->next->prev = group->prev; + } + node->group = node->next_sibling->group; // std::shared_ptr assignment + node->count -= 1; + } else { + // The node needs to be moved to the next group. + assert(node->next_sibling->count == node->count); + Group* other_group = node->group->next; + _remove_from_siblings(node); + node->count -= 1; + if (!other_group) { + // No next group, insert at the end of the siblings list. + _insert_into_siblings_after(node, node->parent->tail_child); + } else { + // Insert as the head of the next group. + _insert_into_siblings_before(node, other_group->head); + } + } +} + +// Append a new element to a new or existing sequence. +void SuffixTree::append(int seq_id, int token) { + // Initialize the sequence if it doesn't exist. + if (!_seqs.contains(seq_id)) { + assert(!_active_nodes.contains(seq_id)); + _seqs.emplace(seq_id); + _active_nodes.emplace(seq_id); + } + + // Keep references to the seq and active nodes for efficiency. + std::vector& seq = _seqs[seq_id]; + std::deque& active_nodes = _active_nodes[seq_id]; + + // Insert a new active node at the root. 
+ active_nodes.push_back(_root.get()); + _root->endpoints[seq_id] = static_cast(seq.size()); + _root->count += 1; + + // Ensure the number of active nodes doesn't exceed max_depth. + if (active_nodes.size() > static_cast(_max_depth)) { + active_nodes.pop_front(); + } + seq.push_back(token); + int32_t seq_len = static_cast(seq.size()); + + // Iterate over all active nodes for this sequence. + for (Node*& active_node : active_nodes) { + Node* node = active_node; + Node* child = nullptr; + auto it = node->children.find(token); + if (it != node->children.end()) { + child = it->second.get(); + } + + assert(node->endpoints.contains(seq_id)); + assert(node->endpoints[seq_id] == seq.size() - 1); + + if (child == nullptr) { + // Case 1: No existing child node for the new token. + if (node->count == 1 && node != _root.get()) { + // Case 1a: The active node has count = 1, which means the only + // suffix that ends here is the one that's being extended right + // now. Then this node should be a leaf node, and we can simply + // extend the length of this node. + assert(node->children.empty()); + assert(node->ref_seq == seq_id); + node->length += 1; + node->endpoints[seq_id] += 1; + } else { + // Case 1b: Either this is the root node, or the current suffix + // is not the only one that passes through this node. Need to + // extend the current suffix into a new child. + + // Create the new child node. + Node* new_child = new Node( + 1, // count + token, + 1, // length + seq_id, + seq_len - 1); + new_child->parent = node; + new_child->endpoints[seq_id] = seq_len; + + // Add new child to active node. + node->children.emplace(token, new_child); + + // Move the endpoint for the sequence from the active node to + // the new child node. + node->endpoints.erase(seq_id); + + // Link the new child node into the siblings list. + if (node->children.size() == 1) { + // This should be the first child being added. 
+ assert(!node->head_child && !node->tail_child); + node->head_child = node->tail_child = new_child; + new_child->group.reset(new Group(new_child)); + } else { + assert(node->tail_child); + _insert_into_siblings_after(new_child, node->tail_child); + } + + // Update the active node to the new child node. + active_node = new_child; + } + } else if (node->count == child->count + 1 && node != _root.get()) { + // Case 2: The active node has a child for the new token, and that + // child's count is exactly one fewer than the active node's count. + // Since the suffix for the active node ends here, then all other + // suffixes that pass through this node must go to that child. + assert(node->children.size() == 1); // Should have only one child. + assert(node->endpoints.size() == 1); // The current seq ends here. + if (child->length == 1) { + // Case 2a: The child has length 1. If we append the new token + // to the current suffix, then it will perfectly overlap with + // that child. Fuse the current active node with that child. + + // Update child to take the place of the current node. + child->count += 1; // Active node extends into the child node. + child->token = node->token; + child->length = node->length + 1; + child->ref_seq = seq_id; + child->ref_idx = seq_len - child->length; + child->endpoints[seq_id] = seq_len; + child->parent = node->parent; + + // Replace the current node with the child in the sibling list. + // Must be done before changing any of the node's pointers. + _replace_in_siblings(node, child); + + // Remove the current node from the suffix tree. + Node* parent = node->parent; + assert(parent->children.contains(node->token)); + assert(parent->children[node->token].get() == node); + // Do it in two steps to avoid undefined evaluation order. + Node* tmp = node->children[token].release(); + parent->children[child->token].reset(tmp); + + // Replace active node with child node. + active_node = child; + } else { + // Case 2b: The child has length > 1. 
If we append the new + // token to the current suffix, then it still does not reach + // the child node. In this case, we keep both nodes but extend + // the length of the current node by 1 into the child node. + + // Extend the length of the current node by 1. + node->length += 1; + node->endpoints[seq_id] += 1; // Advance endpoint for the seq. + node->ref_seq = seq_id; // Need to update the ref sequence. + node->ref_idx = seq_len - node->length; + + // Child should shrink by 1 at the beginning. + child->length -= 1; + child->ref_idx += 1; + + // The child's first token must be updated to its second token. + child->token = _seqs[child->ref_seq][child->ref_idx]; + if (child->token != token) { + // Need to update the key in the parent's children map. + Node* tmp = node->children[token].release(); + node->children.emplace(child->token, tmp); + node->children.erase(token); + } + + // Active node stays the same. + } + } else { + // Case 3: There exists a child node for the new token, and the + // active node should move into that child. + if (child->length == 1) { + // Case 3a: The child node has length 1, just update the active + // node pointer to it. + + // Move the endpoint for the sequence to the child. + node->endpoints.erase(seq_id); + child->endpoints[seq_id] = seq_len; + + // Increment the child count and update siblings list. + _increment_count(child); + + // Replace active node with child node. + active_node = child; + } else { + // Case 3b: The child node has length > 1. If the suffix is + // extended into it, then it must split into a segment of + // length 1 and another segment with the remainder. + + // Create the new intermediate node. + Node* new_node = new Node( + child->count, + token, + 1, // length + seq_id, + seq_len - 1); + new_node->parent = node; + + // Replace the child with the new node in the siblings list. + // Must be done before changing any of the child's pointers. 
+ _replace_in_siblings(child, new_node); + + // Replace child with new node in the children map. + node->children[token].release(); // Should be child. + node->children[token].reset(new_node); + + // Child should shrink by 1 at the beginning. + child->length -= 1; + child->ref_idx += 1; + + // Child's first token must be updated to its second token. + child->token = _seqs[child->ref_seq][child->ref_idx]; + + // Insert the child into the new node's children map. + new_node->children.emplace(child->token, child); + child->parent = new_node; + + // Move the endpoint for the sequence to the new node. + node->endpoints.erase(seq_id); + new_node->endpoints[seq_id] = seq_len; + + // Create a new group for the child node. + new_node->head_child = new_node->tail_child = child; + child->group.reset(new Group(child)); + + // Increment the new node count and update siblings lists. + _increment_count(new_node); + + // Update active node to the new intermediate node. + active_node = new_node; + } + } + } +} + +// Extend a new or existing sequence. +void SuffixTree::extend(int seq_id, std::span tokens) { + for (int token : tokens) { + append(seq_id, token); + } +} + +// Remove an existing sequence. +void SuffixTree::remove(int seq_id) { + const std::vector& seq = _seqs[seq_id]; + std::vector path; // Declare here to avoid repeated allocations. + // Loop through all suffix starting indices. + for (int start = 0; start < seq.size(); start++) { + Node *node = _root.get(); + node->count--; + int idx = start; + path.clear(); + // Loop through the nodes for this suffix. + while (idx < seq.size()) { + int token = seq[idx]; + if (!node->children.contains(token)) { + break; + } + Node* child = node->children[token].get(); + if (child->count > 1) { + _decrement_count(child); + } else { + assert(child->count == 1); + // Remove the child along with its entire subtree. 
+ _remove_from_siblings(child); + node->children.erase(token); + break; + } + if (child->endpoints.contains(seq_id)) { + child->endpoints.erase(seq_id); + } + idx += child->length; + node = child; + path.push_back(node); + } + // The last visited node may be mergeable with its child. + if (node != _root.get() && node->children.size() == 1) { + const auto& it = *node->children.begin(); + std::unique_ptr& child_uptr = node->children[it.first]; + if (node->count == child_uptr->count) { + // Merge node into child and eliminate node. + child_uptr->token = node->token; + child_uptr->length += node->length; + child_uptr->ref_idx -= node->length; + child_uptr->parent = node->parent; + _replace_in_siblings(node, child_uptr.get()); + path.back() = node = child_uptr.release(); + node->parent->children[node->token].reset(node); + } + } + // ref_seq and ref_idx of all nodes in the path may need to be updated. + // 1. Go to an arbitrary leaf to get its endpoints. + Node* leaf = node; + int distance = 0; // Distance from node to leaf. + while (!leaf->children.empty()) { + leaf = (*leaf->children.begin()).second.get(); + distance += leaf->length; + } + // 2. Pick an arbitrary endpoint for the reference sequence and index. + if (leaf->endpoints.empty() || leaf->endpoints.contains(seq_id)) { + // Still need to visit this leaf later when removing this sequence. + // We can skip updating the refs until the next time it's visited. + continue; + } + const auto& ref = *leaf->endpoints.begin(); + // 3. Go back up the path to update all nodes' refs. 
+ int32_t ref_seq = ref.first; + int32_t ref_idx = ref.second - distance; + while (!path.empty()) { + Node* n = path.back(); + path.pop_back(); + ref_idx -= n->length; + if (n->ref_seq == seq_id) { + n->ref_seq = ref_seq; + n->ref_idx = ref_idx; + } + } + } + _seqs.erase(seq_id); + _active_nodes.erase(seq_id); +} + +Draft SuffixTree::speculate(std::span context, + int max_spec_tokens, + float max_spec_factor, + float max_spec_offset, + float min_token_prob, + bool use_tree_spec) { + Draft best_draft; + for (int match_len = 1; match_len < context.size(); match_len++) { + auto[node, idx] = _match_context( + context.subspan(context.size() - match_len, match_len)); + if (node == nullptr) { + break; + } + int max_tokens = std::min(max_spec_tokens, + static_cast(match_len * max_spec_factor + + max_spec_offset + 1e-6)); + max_tokens = std::max(max_tokens, 0); + Draft draft; + if (use_tree_spec) { + draft = _speculate_tree(node, idx, max_tokens, min_token_prob); + } else { + draft = _speculate_path(node, idx, max_tokens, min_token_prob); + } + if (draft.score >= best_draft.score) { + best_draft = std::move(draft); + best_draft.match_len = match_len; + } + } + return best_draft; +} + +std::string SuffixTree::check_integrity() { + // 1. Check structural integrity of all nodes. + std::queue queue; + queue.push(_root.get()); + while (!queue.empty()) { + Node* node = queue.front(); + queue.pop(); + std::string ret = _check_node_integrity(node); + if (!ret.empty()) { + return ret; + } + for (const auto& [token, child] : node->children) { + queue.push(child.get()); + } + } + // 2. Check all sequences are represented in the tree. + std::unordered_map visit_count; + // Sequence ids are arbitrary keys (not 0..size-1), so walk the map itself. + for (const auto& [seq_id, seq] : _seqs) { + // Loop through all suffix starting indices. + for (int start = 0; start < seq.size(); start++) { + int idx = start; + // Traverse the tree along this suffix.
+ Node* node = _root.get(); + visit_count[node]++; + while (idx < seq.size() && idx - start < _max_depth) { + // There should be a child for the next token. + CHECK_OR_RETURN(node->children.contains(seq[idx])); + node = node->children[seq[idx]].get(); + visit_count[node]++; + // Sequence should not end in the middle of a node. + CHECK_OR_RETURN(idx + node->length <= seq.size()); + for (int i = 0; i < node->length; ++i) { + int ref_seq = node->ref_seq; + int ref_idx = node->ref_idx + i; + // Reference tokens should match sequence tokens. + CHECK_OR_RETURN(seq[idx + i] == _seqs[ref_seq][ref_idx]); + } + idx += node->length; + } + // The last node on this path should have an endpoint. + CHECK_OR_RETURN(node->endpoints.contains(seq_id)); + } + } + // 3. Check all nodes were visited the correct number of times. + assert(queue.empty()); + queue.push(_root.get()); + while (!queue.empty()) { + Node* node = queue.front(); + queue.pop(); + // The visit count should match the node count. + CHECK_OR_RETURN(node->count == visit_count[node]); + for (const auto& [token, child] : node->children) { + queue.push(child.get()); + } + } + return ""; +} + +std::string SuffixTree::_check_node_integrity(Node* node) { + int64_t children_count = 0; + for (const auto& [token, child] : node->children) { + // All children should have the correct parent pointer. + CHECK_OR_RETURN(child->parent == node); + children_count += child->count; + } + // Node count should be at least the sum of all children counts, since + // every suffix that reaches a child also passes through this node. + CHECK_OR_RETURN(children_count <= node->count); + if (node == _root.get()) { + // Root node should not contain any tokens, do some basic checks. + CHECK_OR_RETURN(node->count >= 0); + CHECK_OR_RETURN(node->parent == nullptr); + CHECK_OR_RETURN(node->length == 0); + CHECK_OR_RETURN(node->endpoints.empty()); + CHECK_OR_RETURN(node->ref_idx == -1); + } else { + // Node length should be positive. + CHECK_OR_RETURN(node->length > 0); + // Node count should be positive.
+ CHECK_OR_RETURN(node->count > 0); + // Each child count should be strictly less than the node count. + // Otherwise, the node and the child should have been merged into a + // single node. + for (const auto& [token, child] : node->children) { + CHECK_OR_RETURN(child->count < node->count); + } + // Internal nodes must have a valid reference sequence and index. + CHECK_OR_RETURN(_seqs.contains(node->ref_seq)); + CHECK_OR_RETURN(node->ref_idx >= 0); + CHECK_OR_RETURN( + node->ref_idx + node->length <= _seqs[node->ref_seq].size()); + // Check the first token of the node is correct. + CHECK_OR_RETURN(node->token == _seqs[node->ref_seq][node->ref_idx]); + // Check the node is in its parent's children map. + CHECK_OR_RETURN(node->parent->children.contains(node->token)); + CHECK_OR_RETURN(node->parent->children[node->token].get() == node); + // Check all endpoint references are correct. + for (auto [seq_id, end_idx] : node->endpoints) { + // Endpoint should refer to a sequence id that exists. + CHECK_OR_RETURN(_seqs.contains(seq_id)); + // Endpoint index should be within the sequence length. + CHECK_OR_RETURN(end_idx > 0 && end_idx <= _seqs[seq_id].size()); + // Check all tokens from the start of the suffix to the endpoint. + Node* n = node; + int idx = end_idx; + // Walk up the tree and check all tokens agree with the suffix + // ending at this endpoint. + do { + // Check the index in the sequence is not underflowed. + CHECK_OR_RETURN(n->length <= idx); + idx -= n->length; + for (int i = 0; i < n->length; ++i) { + int tok = _seqs[n->ref_seq][n->ref_idx + i]; + // Check each token in this node agrees with the sequence. + CHECK_OR_RETURN(_seqs[seq_id][idx + i] == tok); + } + n = n->parent; + } while (n != nullptr); + } + } + // Check siblings list integrity. + if (!node->head_child && !node->tail_child) { + CHECK_OR_RETURN(node->children.empty()); + } else { + // If there is a child then there must be both a head and a tail child. 
+ CHECK_OR_RETURN(node->head_child && node->tail_child); + // Check head and tail child pointers are correct. + CHECK_OR_RETURN(node->head_child->prev_sibling == nullptr); + CHECK_OR_RETURN(node->tail_child->next_sibling == nullptr); + // Check all children are in the siblings linked list. + int count = 0; + Node* child = node->head_child; + Node* prev_child = nullptr; + while (child != nullptr) { + count++; + // Check the child is in the children map. + CHECK_OR_RETURN(node->children.contains(child->token)); + // Check the group pointer is valid. + CHECK_OR_RETURN(child->group != nullptr); + if (prev_child) { + // Check the siblings are ordered in nonincreasing count. + CHECK_OR_RETURN(child->count <= prev_child->count); + // Check the sibling pointers are correct. + CHECK_OR_RETURN(child->prev_sibling == prev_child); + CHECK_OR_RETURN(prev_child->next_sibling == child); + // Check the group pointers are correct. + if (child->count == prev_child->count) { + // If the next sibling has the same count, they should be + // in the same group. + CHECK_OR_RETURN(child->group == prev_child->group); + } else { + // Otherwise, they should be in different groups. + CHECK_OR_RETURN(child->group != prev_child->group); + // The child should be the head of its group. + CHECK_OR_RETURN(child->group->head == child); + // Check group pointers are correct. + CHECK_OR_RETURN( + child->group->prev == prev_child->group.get()); + CHECK_OR_RETURN( + prev_child->group->next == child->group.get()); + + } + } else { + CHECK_OR_RETURN(child == node->head_child); + } + prev_child = child; + child = child->next_sibling; + } + // Check the last child reached is the tail child. + CHECK_OR_RETURN(prev_child == node->tail_child); + // Check the number of children matches the size of the children map. 
+ CHECK_OR_RETURN(count == node->children.size()); + } + return ""; +} + +std::pair SuffixTree::_match_context( + std::span context) { + Node* node = _root.get(); + int idx = 0; + const int32_t* ref_data = nullptr; + for (int32_t token : context) { + if (idx >= node->length) { + auto it = node->children.find(token); + if (it == node->children.end()) { + return {nullptr, -1}; + } + node = it->second.get(); + // Keep a pointer directly to the reference data for efficiency. + ref_data = _seqs[node->ref_seq].data() + node->ref_idx; + idx = 0; + } + assert(idx < node->length); + if (ref_data[idx] != token) { + return {nullptr, -1}; + } + idx++; + } + return {node, idx}; +} + +Draft SuffixTree::_speculate_path(Node* node, int idx, + int max_spec_tokens, + float min_token_prob) { + Draft ret; + float prob = 1.0f; + const int32_t* ref_data = _seqs[node->ref_seq].data() + node->ref_idx; + while (ret.token_ids.size() < max_spec_tokens && prob >= min_token_prob) { + if (idx < node->length) { + // Use previous token index as parent; if none, mark as -1. + ret.parents.push_back(static_cast(ret.token_ids.size()) - 1); + ret.token_ids.push_back(ref_data[idx]); + ret.probs.push_back(prob); + ret.score += prob; + idx++; + } else { + Node* child = node->head_child; + if (child == nullptr) { + break; + } + int64_t count = child->count; + prob *= static_cast(count) / node->count; + node = child; + // Keep a pointer directly to the reference data for efficiency. + ref_data = _seqs[node->ref_seq].data() + node->ref_idx; + idx = 0; + } + } + return ret; +} + +struct HeapItem { + float prob; + Node* node; + int idx; + int parent; // index in the draft token list; -1 if none. + + HeapItem(float p, Node* n, int i, int par) + : prob(p), node(n), idx(i), parent(par) {} +}; + +struct HeapItemCmp { + bool operator()(const HeapItem& a, const HeapItem& b) const { + // In C++ priority_queue by default returns the largest element. 
+ // Thus, we compare probabilities so that the highest prob is returned. + return a.prob < b.prob; + } +}; + +// Get a draft token tree using a priority queue. +Draft SuffixTree::_speculate_tree(Node* node, int idx, + int max_spec_tokens, + float min_token_prob) { + Draft ret; + std::priority_queue, HeapItemCmp> queue; + queue.emplace(1.0, node, idx, -1); + while (ret.token_ids.size() < max_spec_tokens && !queue.empty()) { + HeapItem it = queue.top(); + queue.pop(); + if (it.idx < it.node->length) { + int32_t token = _seqs[it.node->ref_seq][it.node->ref_idx + it.idx]; + ret.token_ids.push_back(token); + ret.parents.push_back(it.parent); + ret.probs.push_back(it.prob); + ret.score += it.prob; + queue.emplace(it.prob, it.node, it.idx + 1, + static_cast(ret.token_ids.size()) - 1); + } else { + Node* child = it.node->head_child; + while (child) { + float prob = it.prob * child->count / + static_cast(it.node->count); + if (prob < min_token_prob) { + break; + } + queue.emplace(prob, child, 0, it.parent); + child = child->next_sibling; + } + } + } + return ret; +} + +size_t SuffixTree::estimate_memory() const { + size_t total = sizeof(*this); + std::vector stack; + stack.push_back(_root.get()); + while (!stack.empty()) { + Node* node = stack.back(); + stack.pop_back(); + total += node->memory_usage(); + if (node->head_child) { + Group* group = node->head_child->group.get(); + while (group) { + total += sizeof(*group); + group = group->next; + } + } + for (const auto& [token, child] : node->children) { + stack.push_back(child.get()); + } + } + for (const auto& [seq_id, seq] : _seqs) { + total += sizeof(int32_t) * seq.capacity(); // token storage, not sizeof(vector) + } + for (const auto& [seq_id, nodes] : _active_nodes) { + total += sizeof(Node*) * nodes.size(); // one pointer per active node + } + return total; +} diff --git a/csrc/suffix_cache/suffix_tree.h b/csrc/suffix_decoding/suffix_tree.h similarity index 52% rename from csrc/suffix_cache/suffix_tree.h rename to csrc/suffix_decoding/suffix_tree.h index 57c287a6b..6d642fa6b 100644 ---
a/csrc/suffix_cache/suffix_tree.h +++ b/csrc/suffix_decoding/suffix_tree.h @@ -18,35 +18,64 @@ #include #include #include -#include +#include #include #include #include "int32_map.h" +struct Group; + struct Node { + // Number of suffixes from the root that end at or pass through this node. + int64_t count = 0; + // Token referenced by this node. Node can refer to a sequence of tokens, // this is just the ID of the first token. int token = 0; - // Number of suffixes from the root that end at or pass through this node. - int64_t count = 0; + // Number of tokens in this node. + int length = 0; + + // Reference sequence ID and starting index for the tokens in this node. + // Implements the path compression optimization to achieve O(N) memory. + int ref_seq = 0; + int ref_idx = -1; + + // This map tracks all the suffixes that end at this node. Maps seq_id to + // the end index of that suffix (may be truncated due to tree depth). Used + // to find a new ref_seq and ref_idx if the reference sequence is deleted. + Int32Map endpoints; - // Parent node. + // Pointer to parent node. Node* parent = nullptr; // Children nodes, the key should always be the first token of the child. Int32Map> children; - // Maps sequence ID -> index of the end of the suffix in that sequence. - Int32Map endpoints; + // All the children of each node are kept in order of decreasing count in + // a doubly linked list for efficient speculation. head_child points to the + // first child (highest count) and tail_child points to the last child. + Node* head_child = nullptr; + Node* tail_child = nullptr; - // Reference sequence ID and starting index for the tokens in this node. - int ref_seq = 0; - int ref_idx = -1; + // Pointers to the next and previous siblings in the doubly linked list. + Node* next_sibling = nullptr; + Node* prev_sibling = nullptr; - // Number of tokens in this node. 
- int length = 0; + // To enable efficient reordering of the siblings list when counts change, + // nodes with the same count are grouped together. Each node holds a shared + // pointer to its group, and the groups also form a doubly linked list. + std::shared_ptr group = nullptr; + + Node() = default; + + Node(int64_t count, int token, int length, int ref_seq, int ref_idx) + : count(count), + token(token), + length(length), + ref_seq(ref_seq), + ref_idx(ref_idx) {} // Memory usage of this node. size_t memory_usage() const { @@ -57,17 +86,30 @@ struct Node { } }; -struct Candidate { - // The token ids of the speculation candidate. - std::vector token_ids; +struct Group { + // Pointer to the head node of this group. All nodes before the head node + // have a strictly higher count, and all nodes after the head node have a + // lower or equal count. + Node* head = nullptr; + + // Pointers to the next and previous groups in the doubly linked list. + Group* next = nullptr; + Group* prev = nullptr; + + Group(Node* head) : head(head) {} +}; + +struct Draft { + // The token ids of the speculation draft. + std::vector token_ids; // For each token, the index of its parent token (-1 if no parent). - std::vector parents; + std::vector parents; // For each token, the estimated probability of the token. std::vector probs; - // Floating point score of the candidate (sum of all probs). + // Floating point score of the draft (sum of all probs). float score = 0.0; // Length of the prefix match for the speculated tokens. @@ -87,18 +129,18 @@ class SuffixTree { void append(int seq_id, int token); // Append multiple new elements to the sequence with id seq_id. - void extend(int seq_id, const std::vector& tokens); + void extend(int seq_id, std::span tokens); // Remove the sequence with id seq_id. void remove(int seq_id); - // Given a pattern, speculate the next tokens using the suffix tree. 
- Candidate speculate(const std::vector& pattern, - int max_spec_tokens, - float max_spec_factor = 1.0f, - float max_spec_offset = 0.0f, - float min_token_prob = 0.1f, - bool use_tree_spec = false); + // Given a context, speculate the next tokens using the suffix tree. + Draft speculate(std::span context, + int max_spec_tokens, + float max_spec_factor, + float max_spec_offset, + float min_token_prob, + bool use_tree_spec); // Check the integrity of the suffix tree, return empty string if ok, // otherwise return an error message. @@ -116,23 +158,22 @@ class SuffixTree { // The root node of the suffix tree. std::unique_ptr _root; - // Mapping from seq id to its sequence (vector of ints). - std::unordered_map> _seqs; + // Mapping from seq id to its sequence of tokens (vectors of int32_t). + Int32Map> _seqs; // For each sequence, a sliding window of active nodes. Maintains at most // _max_depth active nodes for each sequence. Queue is shifted when a new // token is added to the sequence. Each active node is in the queue for at // most _max_depth iterations before being removed. - std::unordered_map> _active_nodes; + Int32Map> _active_nodes; - std::pair _match_pattern(const std::vector& pattern, - int start_idx = 0); + std::pair _match_context(std::span context); - Candidate _speculate_path(Node* node, int idx, int max_spec_tokens, - float min_token_prob); + Draft _speculate_path(Node* node, int idx, int max_spec_tokens, + float min_token_prob); - Candidate _speculate_tree(Node* node, int idx, int max_spec_tokens, - float min_token_prob); + Draft _speculate_tree(Node* node, int idx, int max_spec_tokens, + float min_token_prob); std::string _check_node_integrity(Node* node); }; diff --git a/docs/arctic-speculator.rst b/docs/arctic-speculator.rst index 3066c22a0..fbf4c2031 100644 --- a/docs/arctic-speculator.rst +++ b/docs/arctic-speculator.rst @@ -145,6 +145,8 @@ Minimal configuration with Arctic Speculator (LSTM-based): .. 
code-block:: bash + export ARCTIC_INFERENCE_ENABLED=1 + vllm serve meta-llama/Llama-3.1-8B-Instruct \ --speculative-config '{ "method": "arctic", @@ -156,6 +158,8 @@ Combined Arctic Speculator with Suffix Decoding: .. code-block:: bash + export ARCTIC_INFERENCE_ENABLED=1 + vllm serve meta-llama/Llama-3.1-8B-Instruct \ --speculative-config '{ "method": "arctic", diff --git a/docs/arctic-ulysses.rst b/docs/arctic-ulysses.rst index 91aac8775..c767a8463 100644 --- a/docs/arctic-ulysses.rst +++ b/docs/arctic-ulysses.rst @@ -36,6 +36,8 @@ tensor and sequence parallelism across 8 GPUs (4 TP, 2 SP) with Arctic Inference .. code-block:: bash + export ARCTIC_INFERENCE_ENABLED=1 + python -m vllm.entrypoints.openai.api_server \ --model meta-llama/Llama-3.3-70B-Instruct \ --tensor-parallel-size 4 \ diff --git a/docs/shift-parallel.rst b/docs/shift-parallel.rst index 95d53cf71..0d8025204 100644 --- a/docs/shift-parallel.rst +++ b/docs/shift-parallel.rst @@ -56,6 +56,8 @@ Here is an example of how run Shift Parallelism with the .. code-block:: bash + export ARCTIC_INFERENCE_ENABLED=1 + python -m vllm.entrypoints.openai.api_server \ --model meta-llama/Llama-3.3-70B-Instruct \ --enable-shift-parallel \ diff --git a/docs/suffix-decoding.rst b/docs/suffix-decoding.rst index 63d01c75e..078fc4bc8 100644 --- a/docs/suffix-decoding.rst +++ b/docs/suffix-decoding.rst @@ -97,6 +97,8 @@ Minimal configuration for suffix-only decoding (for Llama-3.1-8B-Instruct): .. code-block:: bash + export ARCTIC_INFERENCE_ENABLED=1 + vllm serve meta-llama/Llama-3.1-8B-Instruct \ --speculative-config '{ "method": "suffix" @@ -106,6 +108,8 @@ Configuration combining suffix decoding with Arctic Speculator: .. 
code-block:: bash + export ARCTIC_INFERENCE_ENABLED=1 + vllm serve meta-llama/Llama-3.1-8B-Instruct \ --speculative-config '{ "method": "arctic", diff --git a/docs/swiftkv.rst b/docs/swiftkv.rst index f5b4e2ce2..1b54d0b4f 100644 --- a/docs/swiftkv.rst +++ b/docs/swiftkv.rst @@ -39,6 +39,8 @@ you would select the `Snowflake/Llama-3.3-SwiftKV-70B-Instruct .. code-block:: bash + export ARCTIC_INFERENCE_ENABLED=1 + python -m vllm.entrypoints.openai.api_server \ --model Snowflake/Llama-3.3-SwiftKV-70B-Instruct \ --tensor-parallel-size 8 diff --git a/projects/spec_dec/offline_inference_spec_dec.py b/projects/spec_dec/offline_inference_spec_dec.py index 53604410a..06ed3e397 100644 --- a/projects/spec_dec/offline_inference_spec_dec.py +++ b/projects/spec_dec/offline_inference_spec_dec.py @@ -17,7 +17,8 @@ from vllm import LLM, SamplingParams import os -os.environ["VLLM_USE_V1"] = "1" +os.environ["ARCTIC_INFERENCE_ENABLED"] = "1" +os.environ["CUDA_VISIBLE_DEVICES"] = "4,5" vllm.plugins.load_general_plugins() @@ -29,9 +30,11 @@ "method": "arctic", "model": "Snowflake/Arctic-LSTM-Speculator-Llama-3.1-70B-Instruct", "num_speculative_tokens": 3, - "enable_suffix_decoding": True, + "enable_suffix_decoding": False, "disable_by_batch_size": 64, }, + enforce_eager=True, + async_scheduling=True, seed=0, ) diff --git a/projects/swiftkv/offline_inference_swiftkv.py b/projects/swiftkv/offline_inference_swiftkv.py index b3f86de34..e79853cd3 100644 --- a/projects/swiftkv/offline_inference_swiftkv.py +++ b/projects/swiftkv/offline_inference_swiftkv.py @@ -18,7 +18,7 @@ vllm.plugins.load_general_plugins() -llm = LLM(model="Snowflake/Llama-3.1-SwiftKV-8B-Instruct") +llm = LLM(model="Snowflake/Llama-3.1-SwiftKV-8B-Instruct", tensor_parallel_size=2) print("=" * 80) diff --git a/pyproject.toml b/pyproject.toml index 1cc98ac3f..cb9329395 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,42 +4,46 @@ requires = [ "wheel", "ninja", "cmake>=3.12", - "pybind11", + "nanobind==2.9.2", # uv pip 
install vllm==0.9.1 --dry-run 2>&1 | grep protobuf | sed 's/^ + //' "protobuf==5.29.5", - "pybind11_global", "grpcio-tools", - "torch == 2.7.0", + "torch==2.9.1", # need to be same as user env. ] build-backend = "setuptools.build_meta" [project] name = "arctic_inference" -version = "0.0.9" +version = "0.1.3" description = "Snowflake LLM inference library" [project.entry-points."vllm.general_plugins"] -arctic_inference = "arctic_inference.vllm.plugins:arctic_inference_plugin" +arctic_inference = "arctic_inference.vllm.plugin:arctic_inference_plugin" [tool.setuptools] packages = [ "arctic_inference", "arctic_inference.common", - "arctic_inference.common.suffix_cache", "arctic_inference.common.swiftkv", "arctic_inference.dynasor", + "arctic_inference.op_builder", + "arctic_inference.suffix_decoding", "arctic_inference.vllm", "arctic_inference.vllm.swiftkv", "arctic_inference.vllm.spec_dec", "arctic_inference.embedding", + "arctic_inference.csrc", ] +[tool.setuptools.package-dir] +"arctic_inference.csrc" = "csrc" + [tool.setuptools.package-data] "arctic_inference.embedding" = ["arctic_inference/embedding/proto/*.proto"] [project.optional-dependencies] vllm = [ - 'vllm==0.9.2', + 'vllm==0.14.1', ] docs = [ diff --git a/setup.py b/setup.py index 89bbfb234..9cdbc27fd 100644 --- a/setup.py +++ b/setup.py @@ -85,11 +85,6 @@ def build_extension(self, ext: CMakeExtension) -> None: item for item in os.environ["CMAKE_ARGS"].split(" ") if item ] - # In this example, we pass in the version to C++. You might not need to. - cmake_args += [ - f"-DEXAMPLE_VERSION_INFO={self.distribution.get_version()}" - ] - if self.compiler.compiler_type != "msvc": # Using Ninja-build since it a) is available as a wheel and b) # multithreads automatically. 
MSVC would require all variables be @@ -171,12 +166,20 @@ def run(self): _build_py.run(self) -setup( - ext_modules=[ - CMakeExtension("arctic_inference.common.suffix_cache._C", - "csrc/suffix_cache"), +ext_modules=[] + +if os.environ.get("ARCTIC_INFERENCE_PRECOMPILED_OPS", "").lower() in ("1", "true", "on"): + ext_modules.append( CMakeExtension("arctic_inference.custom_ops", "csrc/custom_ops"), - ], + ) + +ext_modules.append( + CMakeExtension("arctic_inference.suffix_decoding._C", + "csrc/suffix_decoding"), +) + +setup( + ext_modules=ext_modules, cmdclass={ "build_ext": CMakeBuild, 'build_py': CompileGrpc diff --git a/tests/benchmarks/benchmark_utils.py b/tests/benchmarks/benchmark_utils.py index a78382802..141fc4723 100644 --- a/tests/benchmarks/benchmark_utils.py +++ b/tests/benchmarks/benchmark_utils.py @@ -3,6 +3,11 @@ import pandas as pd +from vllm.config.compilation import CUDAGraphMode, CompilationConfig +compilation_config = CompilationConfig( + cudagraph_mode=CUDAGraphMode.PIECEWISE +) + @dataclass class BenchmarkTask: @@ -32,6 +37,7 @@ class BenchmarkTask: "model": "Snowflake/Llama-3.1-SwiftKV-8B-Instruct-FP8", "tensor_parallel_size": 4, "enable_prefix_caching": False, + "compilation_config": compilation_config, }, "llama_8b_suffix": { "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8-dynamic", @@ -185,4 +191,4 @@ def get_benchmark_summary(): return _SUMMARY.round(3).dropna(how='all').dropna(axis=1, how='all') -_SUMMARY = init_benchmark_summary() +_SUMMARY = init_benchmark_summary() \ No newline at end of file diff --git a/tests/benchmarks/conftest.py b/tests/benchmarks/conftest.py index dd7d2bc34..32ac98fa1 100644 --- a/tests/benchmarks/conftest.py +++ b/tests/benchmarks/conftest.py @@ -1,24 +1,7 @@ -import argparse import json -import multiprocessing -import os import pathlib -import subprocess -import sys -import time -from typing import Any, Dict, List -import pytest -import requests -import torch -import vllm - -from .benchmark_utils import 
(ACCURACY_TASKS, JSON_MODE_TASKS, - PERFORMANCE_TASKS, VLLM_CONFIGS, - get_benchmark_summary) - -MAX_GPUS = torch.cuda.device_count() -BASE_PORT = 8080 +from .benchmark_utils import get_benchmark_summary def pytest_addoption(parser): @@ -26,11 +9,20 @@ def pytest_addoption(parser): def pytest_terminal_summary(terminalreporter, exitstatus, config): + """ + Add benchmark summary to pytest's terminal summary, and save it to a file + if a benchmark result directory is specified. + """ summary = get_benchmark_summary() + if summary.empty: return + + # Print the summary to the terminal terminalreporter.write_sep("=", "Final Benchmark Summary") terminalreporter.write_line(summary.to_string()) + + # Save the summary to a file if a benchmark result directory is specified benchmark_result_dir = config.option.benchmark_result_dir if benchmark_result_dir is not None: benchmark_result_dir.mkdir(parents=True, exist_ok=True) @@ -41,212 +33,4 @@ def pytest_terminal_summary(terminalreporter, exitstatus, config): for config_name, config_value in value.items(): summary_dict[task][metric][config_name] = config_value with open(benchmark_result_dir / "summary.json", "w") as f: - json.dump(summary_dict, f, indent=4) - - -def _schedule_configs() -> List[List[str]]: - sorted_configs = sorted( - VLLM_CONFIGS.items(), - key=lambda item: item[1].get("tensor_parallel_size", 1) * item[1].get( - "ulysses_sequence_parallel_size", 1), - reverse=True) - batches: List[List[str]] = [] - current_batch: List[str] = [] - gpus_used_in_batch = 0 - for name, config in sorted_configs: - gpus_needed = config.get("tensor_parallel_size", 1) * config.get( - "ulysses_sequence_parallel_size", 1) - if gpus_used_in_batch + gpus_needed <= MAX_GPUS: - current_batch.append(name) - gpus_used_in_batch += gpus_needed - else: - if current_batch: - batches.append(current_batch) - current_batch = [name] - gpus_used_in_batch = gpus_needed - if current_batch: - batches.append(current_batch) - return batches - - -class 
BatchServerManager: - - def __init__(self): - self.current_batch_idx = -1 - self.processes: Dict[str, subprocess.Popen] = {} - self.port_map: Dict[str, int] = {} - - def start_batch(self, batch_idx: int, batch_configs: List[str]): - if self.current_batch_idx == batch_idx: - return - self.teardown_current_batch() - self.current_batch_idx = batch_idx - print(f"\nStarting Batch {batch_idx}: {batch_configs} ---") - gpu_pool = list(range(MAX_GPUS)) - gpus_assigned = 0 - for i, config_name in enumerate(batch_configs): - port = BASE_PORT + i - gpus_needed = VLLM_CONFIGS[config_name].get( - "tensor_parallel_size", 1) * VLLM_CONFIGS[config_name].get( - "ulysses_sequence_parallel_size", 1) - gpu_ids = gpu_pool[gpus_assigned:gpus_assigned + gpus_needed] - gpus_assigned += gpus_needed - self.port_map[config_name] = port - command = self._build_server_command(config_name, port) - env = os.environ.copy() - env["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, gpu_ids)) - print( - f" -> Launching '{config_name}' on port {port} with GPUs {gpu_ids}..." 
- ) - p = subprocess.Popen(command, env=env) - self.processes[config_name] = p - for config_name, process in self.processes.items(): - self._wait_for_server_ready(process, self.port_map[config_name]) - - def teardown_current_batch(self): - if not self.processes: - return - print( - f"\n---Terminating servers for Batch {self.current_batch_idx} ---") - for name, p in self.processes.items(): - if p.poll() is None: - print( - f" -> Terminating '{name}' on port {self.port_map.get(name)}" - ) - p.terminate() - try: - p.wait(timeout=10) - except subprocess.TimeoutExpired: - p.kill() - self.processes.clear() - self.port_map.clear() - - @staticmethod - def _build_server_command(config_name: str, port: int) -> List[str]: - command = [sys.executable, "-m", "vllm.entrypoints.openai.api_server"] - config = VLLM_CONFIGS[config_name] - for key, value in config.items(): - arg_name = f"--{key.replace('_', '-')}" - if isinstance(value, bool): - if value: - command.append(arg_name) - elif isinstance(value, dict): - command.extend([arg_name, json.dumps(value)]) - else: - command.extend([arg_name, str(value)]) - command.extend(["--port", str(port), "--disable-log-requests"]) - return command - - @staticmethod - def _wait_for_server_ready(process: subprocess.Popen, port: int): - url = f"http://localhost:{port}/health" - start_time = time.time() - while True: - if process.poll() is not None: - raise RuntimeError( - f"Server on port {port} terminated unexpectedly. 
" - f"Return code: {process.returncode}") - try: - if requests.get(url, timeout=5).status_code == 200: - print(f"Server on port {port} is ready.") - return - except requests.exceptions.RequestException: - pass - if time.time() - start_time > 3600: - raise TimeoutError(f"Server on port {port} failed to start.") - time.sleep(5) - - -batch_manager = BatchServerManager() - - -def pytest_sessionstart(session): - session.config.vllm_batches = _schedule_configs() - - -def pytest_generate_tests(metafunc): - """ - Generates tests for serial and parallel execution models. - - `batch_spec` is for parallel tests (one test per batch). - - `benchmark_spec` is for serial tests (one test per config). - """ - if "batch_spec" in metafunc.fixturenames: - batches = metafunc.config.vllm_batches - all_specs, all_ids = [], [] - if metafunc.function.__name__ == "test_batch_accuracy": - for batch_idx, configs in enumerate(batches): - for task_name, task_obj in ACCURACY_TASKS.items(): - all_specs.append({ - "batch_idx": batch_idx, - "configs": configs, - "task_name": task_name, - "task_obj": task_obj, - }) - all_ids.append(f"b{batch_idx}-{task_name}") - metafunc.parametrize("batch_spec", - all_specs, - ids=all_ids, - indirect=True) - - if "benchmark_spec" in metafunc.fixturenames: - batches = metafunc.config.vllm_batches - all_specs, all_ids = [], [] - task_map = { - "test_performance": PERFORMANCE_TASKS, - "test_json_mode": JSON_MODE_TASKS - } - test_func_name = metafunc.function.__name__ - if test_func_name in task_map: - for batch_idx, configs in enumerate(batches): - for config_name in configs: - for task_name, task_obj in task_map[test_func_name].items( - ): - all_specs.append({ - "batch_idx": batch_idx, - "config_name": config_name, - "task_name": task_name, - "task_obj": task_obj, - }) - all_ids.append( - f"b{batch_idx}-{config_name}-{task_name}") - metafunc.parametrize("benchmark_spec", - all_specs, - ids=all_ids, - indirect=True) - - -def pytest_collection_modifyitems(session, 
config, items): - - def get_batch_idx(item): - params = getattr(item, "callspec", {}).params - if "batch_spec" in params: - return params["batch_spec"]["batch_idx"] - if "benchmark_spec" in params: - return params["benchmark_spec"]["batch_idx"] - return -1 - - items.sort(key=get_batch_idx) - - -@pytest.fixture(scope="module") -def benchmark_spec(request): - spec = request.param - batch_idx = spec["batch_idx"] - batches = request.config.vllm_batches - batch_manager.start_batch(batch_idx, batches[batch_idx]) - spec["port"] = batch_manager.port_map[spec["config_name"]] - yield spec - - -@pytest.fixture(scope="module") -def batch_spec(request): - spec = request.param - batch_idx = spec["batch_idx"] - batches = request.config.vllm_batches - batch_manager.start_batch(batch_idx, batches[batch_idx]) - spec["port_map"] = batch_manager.port_map.copy() - yield spec - - -def pytest_sessionfinish(session, exitstatus): - batch_manager.teardown_current_batch() + json.dump(summary_dict, f, indent=4) \ No newline at end of file diff --git a/tests/benchmarks/test_benchmarks.py b/tests/benchmarks/test_benchmarks.py index c4a62dd6d..891e83d49 100644 --- a/tests/benchmarks/test_benchmarks.py +++ b/tests/benchmarks/test_benchmarks.py @@ -1,151 +1,222 @@ import argparse import json import multiprocessing -import pathlib import tempfile -import traceback -from typing import Any, Dict +import time import pytest +import requests +import uvloop +from vllm.entrypoints.openai.api_server import ( + make_arg_parser, run_server, validate_parsed_serve_args) +from vllm.utils.argparse_utils import FlexibleArgumentParser -from .benchmark_utils import VLLM_CONFIGS, update_benchmark_summary +from .benchmark_utils import (ACCURACY_TASKS, PERFORMANCE_TASKS, VLLM_CONFIGS, + JSON_MODE_TASKS, update_benchmark_summary) +CUSTOM_PORT = 8080 -def test_performance(benchmark_spec, request): - """Tests vLLM performance (throughput/latency) in serial.""" - config_name = benchmark_spec["config_name"] - task_name = 
benchmark_spec["task_name"] - task = benchmark_spec["task_obj"] - port = benchmark_spec["port"] - vllm_config = VLLM_CONFIGS[config_name] - - from vllm.benchmarks.serve import add_cli_args, main as benchmark_serve_main - parser = argparse.ArgumentParser() - add_cli_args(parser) - args = parser.parse_args(["--model", vllm_config["model"], "--port", str(port)]) +@pytest.fixture(scope="module", params=list(VLLM_CONFIGS.keys())) +def vllm_server(request): + """ + Fixture to start the OpenAI API server for testing. + """ + parser = FlexibleArgumentParser() + parser = make_arg_parser(parser) - result_path = (request.config.option.benchmark_result_dir / config_name / - f"performance-{task_name}.json") - result_path.parent.mkdir(parents=True, exist_ok=True) - - for key, value in task.config.items(): + args = parser.parse_args([]) + args.disable_log_requests = True + args.disable_uvicorn_access_log = True + + setattr(args, 'port', CUSTOM_PORT) + + for key, value in VLLM_CONFIGS[request.param].items(): setattr(args, key, value) - args.save_result = True - args.result_dir = str(result_path.parent) - args.result_filename = str(result_path.name) - - benchmark_serve_main(args) - with open(result_path, "r") as f: - result = json.load(f) - metrics = {name: key(result) if callable(key) else result[key] - for name, key in task.metrics.items()} - update_benchmark_summary(config_name, task_name, metrics) + validate_parsed_serve_args(args) + def _run_process(): + uvloop.run(run_server(args)) -def test_json_mode(benchmark_spec, request): - config_name = benchmark_spec["config_name"] - task_name = benchmark_spec["task_name"] - task = benchmark_spec["task_obj"] - port = benchmark_spec["port"] - vllm_config = VLLM_CONFIGS[config_name] + # Start server process + process = multiprocessing.Process(target=_run_process) + process.start() - if vllm_config.get("speculative_config", {}).get("enable_suffix_decoding"): - pytest.skip("Skipping JSON mode test for spec + suffix decoding.") + 
print("Waiting for server to start...") + timeout = 3600 + interval = 5 + start = time.time() - from .json_mode.evaluate_text_json_mode import main as evaluate_json + health_check_url = f"http://localhost:{CUSTOM_PORT}/v1/models" + + while True: + try: + r = requests.get(health_check_url) + if r.status_code == 200: + break + except requests.exceptions.ConnectionError: + pass + if not process.is_alive(): + raise RuntimeError("Server process terminated unexpectedly") + if time.time() - start > timeout: + raise TimeoutError(f"Server didn't start after {timeout} seconds") + time.sleep(interval) + print("Server process started") + + yield request.param, args + + # Stop server process + print("Terminating server process") + if process.is_alive(): + process.terminate() + process.join() + print("Server process terminated") + + +@pytest.mark.parametrize("task_name", list(PERFORMANCE_TASKS.keys())) +def test_performance(request, vllm_server, task_name): + from vllm.benchmarks.serve import add_cli_args, main + + config_name, vllm_args = vllm_server + task = PERFORMANCE_TASKS[task_name] - result_path = (request.config.option.benchmark_result_dir / config_name / - f"json_mode-{task_name}.json") - result_path.parent.mkdir(parents=True, exist_ok=True) - parser = argparse.ArgumentParser() - parser.add_argument("--model", type=str, default=vllm_config["model"]) - parser.add_argument("--output", type=str, default=str(result_path)) - parser.add_argument("--port", type=int, default=port) - for key, value in task.config.items(): - parser.add_argument(f"--{key}", type=type(value), default=value) + add_cli_args(parser) - evaluate_json(parser.parse_args([])) - - with open(result_path, "r") as f: - result = json.load(f) - result_data = result.get("results", {}) - metrics = {name: key(result_data) if callable(key) else result_data.get(key, {}).get("score") + args = parser.parse_args(["--model", vllm_args.model]) + + setattr(args, 'port', CUSTOM_PORT) + + with tempfile.TemporaryDirectory() 
as tmpdir: + args.save_result = True + args.result_dir = str(tmpdir) + args.result_filename = "result.json" + + for key, value in task.config.items(): + setattr(args, key, value) + + main(args) + + with open(f"{tmpdir}/result.json", "r") as f: + result = json.load(f) + + benchmark_result_dir = request.config.option.benchmark_result_dir + if benchmark_result_dir is not None: + result_path = (benchmark_result_dir / "performance" / + f"{config_name}-{task_name}.json") + result_path.parent.mkdir(parents=True, exist_ok=True) + with open(result_path, "w") as f: + json.dump(result, f, indent=4) + + metrics = {name: key(result) if callable(key) else result[key] for name, key in task.metrics.items()} update_benchmark_summary(config_name, task_name, metrics) -def _run_lm_eval_harness(queue, lm_eval_config, model_name, port): - try: - from lm_eval import evaluator - result = evaluator.simple_evaluate( - model="local-completions", - model_args={"model": model_name, "base_url": f"http://localhost:{port}/v1/completions"}, - **lm_eval_config) - queue.put(result) - except Exception as exc: - queue.put(exc) - -def _run_accuracy_worker(config_name, port, task_name, task_config, - benchmark_result_dir, results_queue): - try: - from lm_eval.utils import handle_non_serializable, make_table - vllm_config = VLLM_CONFIGS[config_name] - queue = multiprocessing.Queue() - eval_process = multiprocessing.Process( - target=_run_lm_eval_harness, - args=(queue, task_config, vllm_config["model"], port)) - eval_process.start() - result_or_exc = queue.get() - eval_process.join() - if isinstance(result_or_exc, Exception): raise result_or_exc - - result = result_or_exc - print(f"Accuracy results for '{config_name}':\n{make_table(result)}") - - result_path = benchmark_result_dir / config_name / f"accuracy-{task_name}.json" +@pytest.mark.parametrize("task_name", list(ACCURACY_TASKS.keys())) +def test_accuracy(request, vllm_server, task_name): + + config_name, vllm_args = vllm_server + task = 
ACCURACY_TASKS[task_name] + + assert len(task.config["tasks"]) == 1, \ + "Accuracy benchmarks should only have one task configured" + + q = multiprocessing.Queue() + + def _run_process(): + # Run lm_eval in a separate process because it imports torch and + # initializes CUDA, which breaks process forking in later tests. + try: + from lm_eval import evaluator + from lm_eval.utils import handle_non_serializable, make_table + + base_url = f"http://localhost:{CUSTOM_PORT}/v1/completions" + + result = evaluator.simple_evaluate( + model="local-completions", + model_args={ + "model": vllm_args.model, + "base_url": base_url, + "num_concurrent": 256, + "timeout": 3600, + }, + **task.config, + ) + print(make_table(result)) + + tmpfile = f"{tmpdir}/result.json" + with open(tmpfile, "w") as f: + json.dump(result, f, indent=4, default=handle_non_serializable) + except Exception as exc: + # If an exception occurs, put it in the queue to be raised later + q.put(exc) + else: + # Send back the temporary file path instead of the result object + # since multiprocessing queue can hang on large objects. 
+ q.put(tmpfile) + + with tempfile.TemporaryDirectory() as tmpdir: + process = multiprocessing.Process(target=_run_process) + process.start() + r = q.get() + process.join() + if isinstance(r, Exception): + raise r + tmpfile = r + with open(tmpfile, "r") as f: + result = json.load(f) + + benchmark_result_dir = request.config.option.benchmark_result_dir + if benchmark_result_dir is not None: + result_path = (benchmark_result_dir / "accuracy" / + f"{config_name}-{task_name}.json") result_path.parent.mkdir(parents=True, exist_ok=True) with open(result_path, "w") as f: - json.dump(result, f, indent=4, default=handle_non_serializable) - results_queue.put({"config_name": config_name, "result_path": result_path}) - except Exception as e: - results_queue.put({"config_name": config_name, "error": str(e), "traceback": traceback.format_exc()}) - -def test_batch_accuracy(batch_spec, request): - """Tests model accuracy for a whole batch in parallel.""" - try: - multiprocessing.set_start_method("spawn", force=True) - except RuntimeError: - pass - - task_name = batch_spec["task_name"] - task_obj = batch_spec["task_obj"] - configs_in_batch = batch_spec["configs"] - port_map = batch_spec["port_map"] - benchmark_result_dir = request.config.option.benchmark_result_dir or pathlib.Path(tempfile.mkdtemp()) - - processes, results_queue = [], multiprocessing.Queue() - for config_name in configs_in_batch: - p = multiprocessing.Process( - target=_run_accuracy_worker, - args=(config_name, port_map[config_name], task_name, - task_obj.config, benchmark_result_dir, results_queue)) - processes.append(p) - p.start() - for p in processes: - p.join() - - while not results_queue.empty(): - result = results_queue.get() - config_name = result["config_name"] - if "error" in result: - pytest.fail(f"Worker for '{config_name}' failed:\n{result['error']}\n{result['traceback']}") - - with open(result["result_path"], "r") as f: - raw_result = json.load(f) - lm_eval_task_name = task_obj.config["tasks"][0] - 
result_data = raw_result["results"][lm_eval_task_name] - metrics = {name: key(result_data) if callable(key) else result_data[key] - for name, key in task_obj.metrics.items()} - update_benchmark_summary(config_name, task_name, metrics) \ No newline at end of file + json.dump(result, f, indent=4) + + result = result["results"][task.config["tasks"][0]] + metrics = {name: key(result) if callable(key) else result[key] + for name, key in task.metrics.items()} + update_benchmark_summary(config_name, task_name, metrics) + + +@pytest.mark.parametrize("task_name", list(JSON_MODE_TASKS.keys())) +def test_json_mode(request, vllm_server, task_name): + """ + Test JSON mode using the evaluate_text_json_mode script. + """ + from .json_mode.evaluate_text_json_mode import main as evaluate_json + + config_name, vllm_args = vllm_server + task = JSON_MODE_TASKS[task_name] + + if (vllm_args.speculative_config and + vllm_args.speculative_config.get('enable_suffix_decoding', False)): + pytest.skip("Skipping JSON mode test for spec + suffix decoding enabled") + + with tempfile.TemporaryDirectory() as tmpdir: + result_path = f"{tmpdir}/result.json" + + args = FlexibleArgumentParser() + args.model = vllm_args.model + args.output = result_path + args.task = task.config["task"] + args.input = task.config["input"] + args.n_samples = task.config["n_samples"] + + args.port = CUSTOM_PORT + + evaluate_json(args) + + with open(result_path, "r") as f: + result = json.load(f) + + result_data = result.get("results", {}) + + metrics = { + name: key(result_data) if callable(key) else result_data.get(key, {}).get('score') + for name, key in task.metrics.items() + } + + update_benchmark_summary(config_name, task_name, metrics) \ No newline at end of file diff --git a/tests/unit_tests/test_arctic_spec_max_len.py b/tests/unit_tests/test_arctic_spec_max_len.py index 43491f86b..591f18aba 100644 --- a/tests/unit_tests/test_arctic_spec_max_len.py +++ b/tests/unit_tests/test_arctic_spec_max_len.py @@ -49,13 
+49,13 @@ def sampling_configs(): @pytest.fixture def model_name(): - return "Snowflake/Llama-3.1-SwiftKV-8B-Instruct" + return "meta-llama/Llama-3.3-70B-Instruct" # Define the speculative configurations that will be tested ARCTIC_SPEC_CONFIG = { "method": "arctic", - "model": "Snowflake/Arctic-LSTM-Speculator-Llama-3.1-8B-Instruct", + "model": "Snowflake/Arctic-LSTM-Speculator-Llama-3.3-70B-Instruct", "num_speculative_tokens": 3, "disable_by_batch_size": 64, "enable_suffix_decoding": True, @@ -85,6 +85,7 @@ def test_speculative_decoding( This test is parameterized to cover 'arctic' and 'suffix' methods. ''' with monkeypatch.context() as m: + m.setenv("ARCTIC_INFERENCE_ENABLED", "1") m.setenv("VLLM_PLUGINS", "arctic_inference") m.setenv("VLLM_USE_V1", "1") @@ -92,11 +93,12 @@ def test_speculative_decoding( spec_llm = LLM( model=model_name, - tensor_parallel_size=1, + tensor_parallel_size=2, quantization="fp8", speculative_config=spec_config, max_model_len=MAX_MODEL_LEN, enforce_eager=True, + trust_remote_code=True, ) for sampling_config in sampling_configs: diff --git a/tests/unit_tests/test_reshape_and_cache_flash_nvfp4.py b/tests/unit_tests/test_reshape_and_cache_flash_nvfp4.py new file mode 100644 index 000000000..4602eb4eb --- /dev/null +++ b/tests/unit_tests/test_reshape_and_cache_flash_nvfp4.py @@ -0,0 +1,256 @@ +import math +import torch + + +def unpack_fp4_e2m1_to_float32(packed_u8: torch.Tensor) -> torch.Tensor: + assert packed_u8.dtype == torch.uint8 + low = packed_u8 & 0x0F + high = packed_u8 >> 4 + nibbles = torch.stack((low, high), dim=-1) # [..., D_bytes, 2] + nibbles = nibbles.reshape(*packed_u8.shape[:-1], + packed_u8.shape[-1] * 2) # [..., D] + + mag_code = (nibbles & 0x07).long() + sign_bit = (nibbles & 0x08) != 0 + + mag_lut = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], + dtype=torch.float32, + device=packed_u8.device) + mag = mag_lut[mag_code] + sgn = torch.where(sign_bit, -1.0, 1.0).to(torch.float32) + return sgn * mag + + +def 
decode_fp8_e4m3(u8: torch.Tensor) -> torch.Tensor: + assert u8.dtype == torch.uint8 + + s = (u8 >> 7).to(torch.int32) + e = ((u8 >> 3) & 0x0F).to(torch.int32) + m = (u8 & 0x07).to(torch.int32) + + sign = torch.where(s != 0, -1.0, 1.0).to(torch.float32) + + is_sub = (e == 0) + exp_norm = (e - 7).to(torch.float32) + exp_sub = torch.full_like(exp_norm, -6.0) + + mant_norm = 1.0 + (m.to(torch.float32) / 8.0) + mant_sub = (m.to(torch.float32) / 8.0) + + base = torch.where(is_sub, exp_sub, exp_norm) + mant = torch.where(is_sub, mant_sub, mant_norm) + + val = sign * torch.pow(torch.tensor(2.0, device=u8.device), base) * mant + return torch.clamp(val, -448.0, 448.0) # finite-range variant + + +def expand_scales_from_bytes(scales_u8_tok: torch.Tensor, H: int, + D: int) -> torch.Tensor: + sf = decode_fp8_e4m3(scales_u8_tok) # [T, H, D/16] as f32 + + return sf.unsqueeze(-1).expand(-1, -1, -1, 16).reshape(-1, H, D) + + +def gather_cache_rows(cache: torch.Tensor, slot_mapping: torch.Tensor, + block_size: int) -> torch.Tensor: + block_idx = torch.div(slot_mapping, block_size, rounding_mode='floor') + block_off = slot_mapping % block_size + + return cache[block_idx, block_off] + + +def dequantize_fp4_cache_to_fp16_using_cache_scales( + packed_bytes: torch.Tensor, # uint8 [B, page, H, D/2] or [B, H, page, D/2] + scale_bytes: torch. 
+ Tensor, # uint8 [B, page, H, D/16] or [B, H, page, D/16] + slot_mapping: torch.Tensor, # int64 [T] + block_size: int, +) -> torch.Tensor: + packed_tok = gather_cache_rows(packed_bytes, slot_mapping, + block_size) # [T, H, D/2] + scales_u8_tok = gather_cache_rows(scale_bytes, slot_mapping, + block_size) # [T, H, D/16] + + T, H, D_half = packed_tok.shape + D = D_half * 2 + + e2m1_vals = unpack_fp4_e2m1_to_float32(packed_tok) + + scales = expand_scales_from_bytes(scales_u8_tok, H, D) + + return (e2m1_vals * scales).to(torch.float16) + + +def parity_check_fp4_vs_fp16( + num_tokens: int = 64, + num_heads: int = 16, + head_size: int = 128, + block_size: int = 16, + device: str = "cuda", + dtype=torch.float16, + seed: int = 0, +): + torch.manual_seed(seed) + T, H, D = num_tokens, num_heads, head_size + assert D % 16 == 0, "head_size must be divisible by 16 for FP4 groups" + + key = torch.randn(T, H, D, device=device, dtype=dtype) * 0.5 + value = torch.randn(T, H, D, device=device, dtype=dtype) * 0.5 + slot_mapping = torch.arange(T, device=device, dtype=torch.long) + num_blocks = math.ceil(T / block_size) + + # [B, page, H, last_dim] + key_cache_fp16 = torch.zeros(num_blocks, + block_size, + H, + D, + device=device, + dtype=torch.float16) + value_cache_fp16 = torch.zeros(num_blocks, + block_size, + H, + D, + device=device, + dtype=torch.float16) + key_cache_fp4 = torch.zeros(num_blocks, + block_size, + H, + D // 2, + device=device, + dtype=torch.uint8) + value_cache_fp4 = torch.zeros(num_blocks, + block_size, + H, + D // 2, + device=device, + dtype=torch.uint8) + key_scale_cache = torch.zeros(num_blocks, + block_size, + H, + D // 16, + device=device, + dtype=torch.uint8) + value_scale_cache = torch.zeros(num_blocks, + block_size, + H, + D // 16, + device=device, + dtype=torch.uint8) + + from vllm import _custom_ops as _ + k_scale = torch.tensor(0.0, device=device, dtype=torch.float32) + v_scale = torch.tensor(0.0, device=device, dtype=torch.float32) + 
torch.ops._C_cache_ops.reshape_and_cache_flash(key, value, key_cache_fp16, + value_cache_fp16, + slot_mapping, "auto", + k_scale, v_scale) + + from arctic_inference.py_custom_ops import (try_load_torch_library, + reshape_and_cache_flash_fp4) + if not try_load_torch_library(): + raise RuntimeError( + "Custom FP4 ops not available (try_load_torch_library() returned False)" + ) + + fp4_called = False + try: + reshape_and_cache_flash_fp4(key, value, key_cache_fp4, value_cache_fp4, + key_scale_cache, value_scale_cache, + slot_mapping, "fp4", k_scale, v_scale) + fp4_called = True + except TypeError: + print("Could not call reshape_and_cache_flash_fp4 with full signature") + + if not fp4_called: + raise RuntimeError( + "Could not invoke reshape_and_cache_flash_fp4 with either signature." + ) + + key_from_fp4 = dequantize_fp4_cache_to_fp16_using_cache_scales( + key_cache_fp4, key_scale_cache, slot_mapping, block_size) + value_from_fp4 = dequantize_fp4_cache_to_fp16_using_cache_scales( + value_cache_fp4, value_scale_cache, slot_mapping, block_size) + + key_fp16_tok = gather_cache_rows(key_cache_fp16, slot_mapping, block_size) + value_fp16_tok = gather_cache_rows(value_cache_fp16, slot_mapping, + block_size) + + def stats(a: torch.Tensor, b: torch.Tensor): + diff = (a.to(torch.float32) - b.to(torch.float32)) + mae = diff.abs().mean().item() + mse = (diff * diff).mean().item() + rmse = math.sqrt(mse) + mx = diff.abs().max().item() + return mae, mse, rmse, mx + + k_mae, k_mse, k_rmse, k_max = stats(key_from_fp4, key_fp16_tok) + v_mae, v_mse, v_rmse, v_max = stats(value_from_fp4, value_fp16_tok) + + print( + f"[K] MAE={k_mae:.6f} RMSE={k_rmse:.6f} MaxAbs={k_max:.6f} MSE={k_mse:.6f}" + ) + print( + f"[V] MAE={v_mae:.6f} RMSE={v_rmse:.6f} MaxAbs={v_max:.6f} MSE={v_mse:.6f}" + ) + + return { + "key_from_fp4": key_from_fp4, + "key_fp16": key_fp16_tok, + "value_from_fp4": value_from_fp4, + "value_fp16": value_fp16_tok, + "metrics": { + "K": { + "MAE": k_mae, + "RMSE": k_rmse, + 
"MaxAbs": k_max, + "MSE": k_mse + }, + "V": { + "MAE": v_mae, + "RMSE": v_rmse, + "MaxAbs": v_max, + "MSE": v_mse + }, + }, + } + + +def show_example_row( + results: dict, + tensor_name: str = "key", + num_vals: int = 64, +): + original_tensor = results[f"{tensor_name}_fp16"] + dequant_tensor = results[f"{tensor_name}_from_fp4"] + + original_slice = original_tensor[:num_vals].to(torch.float32) + dequant_slice = dequant_tensor[:num_vals].to(torch.float32) + diff = (original_slice - dequant_slice).abs() + + print("\n" + "-" * 50) + print( + f"Example Comparison: '{tensor_name.upper()}' (First {num_vals} values)" + ) + print("-" * 50) + + torch.set_printoptions(precision=4, sci_mode=False) + + print(f"Original FP16 : {original_slice.cpu().numpy()}") + print(f"From FP4 : {dequant_slice.cpu().numpy()}") + print(f"Abs Difference: {diff.cpu().numpy()}") + print("-" * 50 + "\n") + + +if __name__ == "__main__": + for block_size in [16, 32]: + for num_tokens in [1, 16, 256, 2048]: + for num_heads in [8, 16]: + for head_size in [64, 128]: + print( + f"Running parity check: num_tokens={num_tokens}, num_heads={num_heads}, head_size={head_size}, block_size={block_size}" + ) + results = parity_check_fp4_vs_fp16(num_tokens=num_tokens, + num_heads=num_heads, + head_size=head_size, + block_size=block_size, + device="cuda") diff --git a/tests/unit_tests/test_speculator_ln.py b/tests/unit_tests/test_speculator_ln.py new file mode 100644 index 000000000..56c61e98d --- /dev/null +++ b/tests/unit_tests/test_speculator_ln.py @@ -0,0 +1,117 @@ +import torch +import torch.nn as nn +import random +import math + +from arctic_inference.py_custom_ops import (try_load_torch_library, + speculator_ln) + + +class MLPSpeculatorLayerNorm(nn.Module): + """ + A L2 normalization implementation + ... + Args + ---- + normalized_shape : int + Dimensionality of input data (size of final tensor axis) + eps : float + Safety term to prevent division by zero. 
Make sure the chosen value + fits in the range of your encoding scheme + (i.e. fp16 requires eps >= 6e-8). + elementwise_scale_and_shift : bool + Include a learned scaling and shift term after normalization. + """ + + def __init__( + self, + normalized_shape, + eps=1e-06, + elementwise_scale_and_shift=True, + ): + super().__init__() + self.elementwise_scale_and_shift = elementwise_scale_and_shift + if self.elementwise_scale_and_shift: + self.weight = nn.Parameter(torch.empty(normalized_shape)) + self.bias = nn.Parameter(torch.empty(normalized_shape)) + self.eps = eps + + assert try_load_torch_library(), "Custom ops library failed to load." + + def forward(self, x): + xf = x + xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps) + x = xf.type_as(x) + if self.elementwise_scale_and_shift: + x = self.weight * x + x = x + self.bias + return x + + def forward_opt(self, x): + return speculator_ln( + x, + self.weight if self.elementwise_scale_and_shift else None, + self.bias if self.elementwise_scale_and_shift else None, + float(self.eps), + ) + + +def run_case(shape, dtype, affine, eps=1e-6): + device = "cuda" + hidden = shape[-1] + + x = torch.randn(*shape, device=device, dtype=dtype) + + if affine: + ref = MLPSpeculatorLayerNorm(hidden, + eps=eps, + elementwise_scale_and_shift=True).to( + device=device, dtype=dtype) + with torch.no_grad(): + ref.weight.copy_(torch.randn(hidden, device=device, dtype=dtype)) + ref.bias.copy_(torch.randn(hidden, device=device, dtype=dtype)) + y_ref = ref(x) + + y = ref.forward_opt(x) + else: + ref = MLPSpeculatorLayerNorm(hidden, + eps=eps, + elementwise_scale_and_shift=False).to( + device=device, dtype=dtype) + y_ref = ref(x) + + y = ref.forward_opt(x) + + max_abs = (y - y_ref).abs().max().item() + denom = y_ref.abs().max().item() + 1e-8 + max_rel = (y - y_ref).abs().max().item() / denom + + return max_abs, max_rel + + +def main(): + torch.manual_seed(0) + if not torch.cuda.is_available(): + raise RuntimeError("CUDA required 
for this parity check.") + + shapes = [ + (32, 128), # multiple of 8 + (16, 260), # multiple of 4, not 8 + (7, 513), # tail + (2, 3, 1024) # higher rank + ] + dtypes = [torch.float16, torch.bfloat16] + affines = [False, True] + + print("Running parity checks against MLPSpeculatorLayerNorm...") + for dtype in dtypes: + for affine in affines: + for shape in shapes: + max_abs, max_rel = run_case(shape, dtype, affine, eps=1e-6) + print( + f"dtype={str(dtype).split('.')[-1]:>9} affine={affine!s:>5} shape={shape!s:<12} " + f"max_abs={max_abs:.3e} max_rel={max_rel:.3e}") + + +if __name__ == "__main__": + main() diff --git a/tests/unit_tests/test_sum_lstm.py b/tests/unit_tests/test_sum_lstm.py new file mode 100644 index 000000000..37a8ee16f --- /dev/null +++ b/tests/unit_tests/test_sum_lstm.py @@ -0,0 +1,208 @@ +import argparse +import math +import os +import random +from typing import Optional, Tuple + +import torch + +from arctic_inference.py_custom_ops import try_load_torch_library + +if try_load_torch_library(): + from arctic_inference.py_custom_ops import sum_lstm +else: + raise RuntimeError( + "The fused CUDA extension is not available. " + "Compile your extension so that arctic_inference.py_custom_ops exposes `sum_lstm`." 
+ ) + + +def rms_norm(x: torch.Tensor, eps: float, weight: Optional[torch.Tensor], + bias: Optional[torch.Tensor]) -> torch.Tensor: + inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps) + y = x * inv_rms + if weight is not None: + y = y * weight + if bias is not None: + y = y + bias + return y + + +def reference_sum_lstm(states4: torch.Tensor, z4: torch.Tensor, + prev_cell: torch.Tensor, w_cell: Optional[torch.Tensor], + b_cell: Optional[torch.Tensor], + w_state: Optional[torch.Tensor], + b_state: Optional[torch.Tensor], alpha: float, + eps_cell: float, eps_state: float, + fast_gelu: bool) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Pure PyTorch reference for: + + added_states = states + alpha * z4 # gatewise + f, i, o = sigmoid(pre_f), sigmoid(pre_i), sigmoid(pre_o) + c_pre = added_states[..., 3D:4D] + c_new = prev_cell * f + GELU(RMSNorm(c_pre)) * i + state = GELU(RMSNorm(c_new)) * o + """ + D = prev_cell.size(-1) + s_f, s_i, s_o, s_c = states4.split(D, dim=-1) + z_f, z_i, z_o, z_c = z4.split(D, dim=-1) + + pre_f = s_f + alpha * z_f + pre_i = s_i + alpha * z_i + pre_o = s_o + alpha * z_o + c_pre = s_c + alpha * z_c + + f = torch.sigmoid(pre_f) + i = torch.sigmoid(pre_i) + + cn = rms_norm(c_pre, eps_cell, w_cell, b_cell) + c_act = torch.nn.functional.gelu( + cn, approximate="tanh" if fast_gelu else "none") + c_new = prev_cell * f + c_act * i + + sn = rms_norm(c_new, eps_state, w_state, b_state) + s_act = torch.nn.functional.gelu( + sn, approximate="tanh" if fast_gelu else "none") + + o = torch.sigmoid(pre_o) + state = s_act * o + return state, c_new + + +def make_pitched_view(rows: int, cols: int, dtype: torch.dtype, + device: torch.device, pad: int) -> torch.Tensor: + """ + Create a 2D tensor view with contiguous last dimension but a *larger row stride* + (so row_stride != cols). This exercises the kernel's stride handling. 
+ """ + base = torch.empty((rows, cols + pad), dtype=dtype, device=device) + view = base[:, :cols] + return view + + +def gen_optional_vec(D: int, dtype: torch.dtype, device: torch.device, + enable: bool) -> Optional[torch.Tensor]: + if not enable: + return None + t = torch.randn(D, dtype=dtype, device=device) + return t.contiguous() + + +def run_one_case(rows: int, D: int, dtype: torch.dtype, device: torch.device, + fast_gelu: bool, use_wb_cell: bool, use_wb_state: bool, + seed: int, pad4: int, padD: int, atol: float, + rtol: float) -> None: + torch.manual_seed(seed) + + states4 = make_pitched_view(rows, 4 * D, dtype, device, pad4) + z = torch.randn(rows, D, dtype=dtype, device=device) + + z4 = make_pitched_view(rows, 4 * D, dtype, device, pad4) + z4.copy_(z.repeat(1, 4)) + + prev_cell = make_pitched_view(rows, D, dtype, device, padD) + + states4.copy_(torch.randn_like(states4)) + prev_cell.copy_(torch.randn_like(prev_cell)) + + w_cell = gen_optional_vec(D, dtype, device, use_wb_cell) + b_cell = gen_optional_vec(D, dtype, device, use_wb_cell) + w_state = gen_optional_vec(D, dtype, device, use_wb_state) + b_state = gen_optional_vec(D, dtype, device, use_wb_state) + + alpha = 0.35 + eps_cell = 1e-6 + eps_state = 1e-6 + + with torch.no_grad(): + ref_state, ref_cell = reference_sum_lstm(states4, z4, prev_cell, + w_cell, b_cell, w_state, + b_state, alpha, eps_cell, + eps_state, fast_gelu) + + with torch.no_grad(): + fused_state, fused_cell = sum_lstm(states4, z4, prev_cell, w_cell, + b_cell, w_state, b_state, + float(alpha), float(eps_cell), + float(eps_state), bool(fast_gelu)) + + diff_state = (ref_state - fused_state).abs() + diff_cell = (ref_cell - fused_cell).abs() + + max_abs_state = diff_state.max().item() + max_abs_cell = diff_cell.max().item() + + max_rel_state = (diff_state / (ref_state.abs() + 1e-6)).max().item() + max_rel_cell = (diff_cell / (ref_cell.abs() + 1e-6)).max().item() + + print( + f"[dtype={str(dtype):>8s}] rows={rows:4d} D={D:5d} 
fast_gelu={fast_gelu} " + f"wb_cell={use_wb_cell} wb_state={use_wb_state} " + f"pad4={pad4:3d} padD={padD:3d} | " + f"state abs={max_abs_state:.3e} rel={max_rel_state:.3e} ; " + f"cell abs={max_abs_cell:.3e} rel={max_rel_cell:.3e}") + + assert max_abs_state <= atol or max_rel_state <= rtol, "State parity check failed" + assert max_abs_cell <= atol or max_rel_cell <= rtol, "Cell parity check failed" + + +def main(): + parser = argparse.ArgumentParser( + description="Parity check for fused sum_lstm kernel") + parser.add_argument("--device", + default="cuda", + help="cuda device, e.g., cuda:0") + parser.add_argument("--seed", type=int, default=1234) + parser.add_argument("--fast_gelu", + action="store_true", + help="use approximate GELU in both paths") + args = parser.parse_args() + + assert torch.cuda.is_available(), "CUDA is required for this parity test" + device = torch.device(args.device) + + dtypes = [torch.float16] + if torch.cuda.is_bf16_supported(): + dtypes.append(torch.bfloat16) + + tol = { + torch.float16: dict(atol=2e-2, rtol=2e-2), + torch.bfloat16: dict(atol=1e-1, rtol=2e-2), + } + + cases = [ + dict(rows=1, D=64, pad4=0, padD=0), + dict(rows=4, D=128, pad4=7, padD=3), + dict(rows=32, D=256, pad4=0, padD=0), + dict(rows=64, D=512, pad4=16, padD=8), + ] + + random.seed(args.seed) + torch.manual_seed(args.seed) + + for dtype in dtypes: + atol = tol[dtype]["atol"] + rtol = tol[dtype]["rtol"] + + for cfg in cases: + for wb_cell in (True, False): + for wb_state in (True, False): + run_one_case( + rows=cfg["rows"], + D=cfg["D"], + dtype=dtype, + device=device, + fast_gelu=args.fast_gelu, + use_wb_cell=wb_cell, + use_wb_state=wb_state, + seed=random.randint(0, 10_000_000), + pad4=cfg["pad4"], + padD=cfg["padD"], + atol=atol, + rtol=rtol, + ) + + +if __name__ == "__main__": + main()