- Arctic Ulysses (blog,
- paper)
+ Arctic Ulysses (blog)
- Shift Parallelism (blog)
+ Shift Parallelism (blog,
+ paper)
|
Arctic Speculator (blog)
@@ -105,7 +105,7 @@ By using the examples below, you can get benefits from Shift Parallelism, Specul
#### Serving
```console
-vllm serve Snowflake/Llama-3.1-SwiftKV-8B-Instruct \
+ARCTIC_INFERENCE_ENABLED=1 vllm serve Snowflake/Llama-3.1-SwiftKV-8B-Instruct \
--quantization "fp8" \
--tensor-parallel-size 1 \
--ulysses-sequence-parallel-size 2 \
@@ -121,6 +121,8 @@ vllm serve Snowflake/Llama-3.1-SwiftKV-8B-Instruct \
#### Offline
+Save the following script to `arctic_example.py`:
+
```python
import vllm
from vllm import LLM, SamplingParams
@@ -156,6 +158,12 @@ outputs = llm.chat(conversation, sampling_params=sampling_params)
print(outputs[0].outputs[0].text)
```
+Run the script with Arctic Inference enabled:
+
+```console
+ARCTIC_INFERENCE_ENABLED=1 python arctic_example.py
+```
+
## Citation
```
@misc{arcticinference2025,
diff --git a/arctic_inference/envs.py b/arctic_inference/envs.py
index 02056a158..6b3da367a 100644
--- a/arctic_inference/envs.py
+++ b/arctic_inference/envs.py
@@ -20,10 +20,18 @@
ARCTIC_INFERENCE_SKIP_SPEC_MODEL_CHECK: bool = False
environment_variables: dict[str, Callable[[], Any]] = {
+ "ARCTIC_INFERENCE_ENABLED":
+ lambda: os.getenv("ARCTIC_INFERENCE_ENABLED", "0") == "1",
+ "ARCTIC_INFERENCE_SKIP_PLATFORM_CHECK":
+ lambda: os.getenv("ARCTIC_INFERENCE_SKIP_PLATFORM_CHECK", "0") == "1",
"ARCTIC_INFERENCE_SKIP_SPEC_MODEL_CHECK":
lambda: os.getenv("ARCTIC_INFERENCE_SKIP_SPEC_MODEL_CHECK", "0") == "1",
+ "ARCTIC_INFERENCE_SKIP_VERSION_CHECK":
+ lambda: os.getenv("ARCTIC_INFERENCE_SKIP_VERSION_CHECK", "0") == "1",
}
+# temporary workaround for gpt-oss model
+ARCTIC_INFERENCE_SKIP_SPEC_MODEL_CHECK = 1
def __getattr__(name: str) -> Any:
if name in environment_variables:
diff --git a/arctic_inference/op_builder/__init__.py b/arctic_inference/op_builder/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/arctic_inference/op_builder/builder.py b/arctic_inference/op_builder/builder.py
new file mode 100644
index 000000000..3206c3f1c
--- /dev/null
+++ b/arctic_inference/op_builder/builder.py
@@ -0,0 +1,545 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+import os
+import re
+import sys
+import time
+import importlib
+from pathlib import Path
+import subprocess
+import shlex
+import shutil
+import tempfile
+import distutils.ccompiler
+import distutils.log
+import distutils.sysconfig
+from distutils.errors import CompileError, LinkError
+from abc import ABC, abstractmethod
+from typing import List
+
+YELLOW = '\033[93m'
+END = '\033[0m'
+WARNING = f"{YELLOW} [WARNING] {END}"
+
+DEFAULT_TORCH_EXTENSION_PATH = "/tmp/torch_extensions"
+DEFAULT_COMPUTE_CAPABILITIES = "6.0;6.1;7.0"
+
+try:
+ import torch
+except ImportError:
+ print(f"{WARNING} unable to import torch, please install it if you want to pre-compile any deepspeed ops.")
+else:
+ TORCH_MAJOR = int(torch.__version__.split('.')[0])
+ TORCH_MINOR = int(torch.__version__.split('.')[1])
+
+
+class MissingCUDAException(Exception):
+ pass
+
+
+class CUDAMismatchException(Exception):
+ pass
+
+
+def installed_cuda_version(name=""):
+ import torch.utils.cpp_extension
+ cuda_home = torch.utils.cpp_extension.CUDA_HOME
+ if cuda_home is None:
+ raise MissingCUDAException("CUDA_HOME does not exist, unable to compile CUDA op(s)")
+ # Ensure there is not a cuda version mismatch between torch and nvcc compiler
+ output = subprocess.check_output([cuda_home + "/bin/nvcc", "-V"], universal_newlines=True)
+ output_split = output.split()
+ release_idx = output_split.index("release")
+ release = output_split[release_idx + 1].replace(',', '').split(".")
+ # Ignore patch versions, only look at major + minor
+ cuda_major, cuda_minor = release[:2]
+ return int(cuda_major), int(cuda_minor)
+
+
+def get_default_compute_capabilities():
+ compute_caps = DEFAULT_COMPUTE_CAPABILITIES
+ # Update compute capability according to: https://en.wikipedia.org/wiki/CUDA#GPUs_supported
+ import torch.utils.cpp_extension
+ if torch.utils.cpp_extension.CUDA_HOME is not None:
+ if installed_cuda_version()[0] == 11:
+ if installed_cuda_version()[1] >= 0:
+ compute_caps += ";8.0"
+ if installed_cuda_version()[1] >= 1:
+ compute_caps += ";8.6"
+ if installed_cuda_version()[1] >= 8:
+ compute_caps += ";9.0"
+ elif installed_cuda_version()[0] == 12:
+ compute_caps += ";8.0;8.6;9.0"
+ if installed_cuda_version()[1] >= 8:
+ compute_caps += ";10.0;12.0"
+ return compute_caps
+
+
+# list compatible minor CUDA versions - so that for example pytorch built with cuda-11.0 can be used
+# to build deepspeed and system-wide installed cuda 11.2
+cuda_minor_mismatch_ok = {
+ 10: ["10.0", "10.1", "10.2"],
+ 11: ["11.0", "11.1", "11.2", "11.3", "11.4", "11.5", "11.6", "11.7", "11.8"],
+ 12: ["12.0", "12.1", "12.2", "12.3", "12.4", "12.5", "12.6", "12.8", "12.9"], # There is no CUDATk 12.7
+}
+
+
+def assert_no_cuda_mismatch(name=""):
+ cuda_major, cuda_minor = installed_cuda_version(name)
+ sys_cuda_version = f'{cuda_major}.{cuda_minor}'
+ torch_cuda_version = ".".join(torch.version.cuda.split('.')[:2])
+ # This is a show-stopping error, should probably not proceed past this
+ if sys_cuda_version != torch_cuda_version:
+ if (cuda_major in cuda_minor_mismatch_ok and sys_cuda_version in cuda_minor_mismatch_ok[cuda_major]
+ and torch_cuda_version in cuda_minor_mismatch_ok[cuda_major]):
+ print(f"Installed CUDA version {sys_cuda_version} does not match the "
+ f"version torch was compiled with {torch.version.cuda} "
+ "but since the APIs are compatible, accepting this combination")
+ return True
+ elif os.getenv("DS_SKIP_CUDA_CHECK", "0") == "1":
+ print(
+ f"{WARNING} DeepSpeed Op Builder: Installed CUDA version {sys_cuda_version} does not match the "
+ f"version torch was compiled with {torch.version.cuda}."
+ "Detected `DS_SKIP_CUDA_CHECK=1`: Allowing this combination of CUDA, but it may result in unexpected behavior."
+ )
+ return True
+ raise CUDAMismatchException(
+ f">- DeepSpeed Op Builder: Installed CUDA version {sys_cuda_version} does not match the "
+ f"version torch was compiled with {torch.version.cuda}, unable to compile "
+ "cuda/cpp extensions without a matching cuda version.")
+ return True
+
+
+class OpBuilder(ABC):
+ _loaded_ops = {}
+
+ def __init__(self, name):
+ self.name = name
+ self.jit_mode = False
+ self.enable_bf16 = False
+ self.error_log = None
+
+ @abstractmethod
+ def absolute_name(self):
+ '''
+ Returns absolute build path for cases where the op is pre-installed will be installed.
+ '''
+ pass
+
+ @abstractmethod
+ def sources(self):
+ '''
+ Returns list of source files for your op, relative to root of deepspeed package (i.e., DeepSpeed/deepspeed)
+ '''
+ pass
+
+ @staticmethod
+ def validate_torch_version(torch_info):
+ install_torch_version = torch_info['version']
+ current_torch_version = ".".join(torch.__version__.split('.')[:2])
+ if install_torch_version != current_torch_version:
+ raise RuntimeError("PyTorch version mismatch! DeepSpeed ops were compiled and installed "
+ "with a different version than what is being used at runtime. "
+ f"Please re-install DeepSpeed or switch torch versions. "
+ f"Install torch version={install_torch_version}, "
+ f"Runtime torch version={current_torch_version}")
+
+ @staticmethod
+ def validate_torch_op_version(torch_info):
+
+ current_hip_version = ".".join(torch.version.hip.split('.')[:2])
+ install_hip_version = torch_info['hip_version']
+ if install_hip_version != current_hip_version:
+ raise RuntimeError("HIP version mismatch! DeepSpeed ops were compiled and installed "
+ "with a different version than what is being used at runtime. "
+ f"Please re-install DeepSpeed or switch torch versions. "
+ f"Install HIP version={install_hip_version}, "
+ f"Runtime HIP version={current_hip_version}")
+
+ def include_paths(self):
+ '''
+ Returns list of include paths, relative to root of deepspeed package (i.e., DeepSpeed/deepspeed)
+ '''
+ return []
+
+ def nvcc_args(self):
+ '''
+ Returns optional list of compiler flags to forward to nvcc when building CUDA sources
+ '''
+ return []
+
+ def cxx_args(self):
+ '''
+ Returns optional list of compiler flags to forward to the build
+ '''
+ return []
+
+ def is_compatible(self, verbose=False):
+ '''
+ Check if all non-python dependencies are satisfied to build this op
+ '''
+ return True
+
+ def extra_ldflags(self):
+ return []
+
+ def has_function(self, funcname, libraries, library_dirs=None, verbose=False):
+ '''
+ Test for existence of a function within a tuple of libraries.
+
+ This is used as a smoke test to check whether a certain library is available.
+ As a test, this creates a simple C program that calls the specified function,
+ and then distutils is used to compile that program and link it with the specified libraries.
+ Returns True if both the compile and link are successful, False otherwise.
+ '''
+ tempdir = None # we create a temporary directory to hold various files
+ filestderr = None # handle to open file to which we redirect stderr
+ oldstderr = None # file descriptor for stderr
+ try:
+ # Echo compile and link commands that are used.
+ if verbose:
+ distutils.log.set_verbosity(1)
+
+ # Create a compiler object.
+ compiler = distutils.ccompiler.new_compiler(verbose=verbose)
+
+ # Configure compiler and linker to build according to Python install.
+ distutils.sysconfig.customize_compiler(compiler)
+
+ # Create a temporary directory to hold test files.
+ tempdir = tempfile.mkdtemp()
+
+ # Define a simple C program that calls the function in question
+ prog = "void %s(void); int main(int argc, char** argv) { %s(); return 0; }" % (funcname, funcname)
+
+ # Write the test program to a file.
+ filename = os.path.join(tempdir, 'test.c')
+ with open(filename, 'w') as f:
+ f.write(prog)
+
+ # Redirect stderr file descriptor to a file to silence compile/link warnings.
+ if not verbose:
+ filestderr = open(os.path.join(tempdir, 'stderr.txt'), 'w')
+ oldstderr = os.dup(sys.stderr.fileno())
+ os.dup2(filestderr.fileno(), sys.stderr.fileno())
+
+ # Workaround for behavior in distutils.ccompiler.CCompiler.object_filenames()
+ # Otherwise, a local directory will be used instead of tempdir
+ drive, driveless_filename = os.path.splitdrive(filename)
+ root_dir = driveless_filename[0] if os.path.isabs(driveless_filename) else ''
+ output_dir = os.path.join(drive, root_dir)
+
+ # Attempt to compile the C program into an object file.
+ cflags = shlex.split(os.environ.get('CFLAGS', ""))
+ objs = compiler.compile([filename], output_dir=output_dir, extra_preargs=self.strip_empty_entries(cflags))
+
+ # Attempt to link the object file into an executable.
+ # Be sure to tack on any libraries that have been specified.
+ ldflags = shlex.split(os.environ.get('LDFLAGS', ""))
+ compiler.link_executable(objs,
+ os.path.join(tempdir, 'a.out'),
+ extra_preargs=self.strip_empty_entries(ldflags),
+ libraries=libraries,
+ library_dirs=library_dirs)
+
+ # Compile and link succeeded
+ return True
+
+ except CompileError:
+ return False
+
+ except LinkError:
+ return False
+
+ except:
+ return False
+
+ finally:
+ # Restore stderr file descriptor and close the stderr redirect file.
+ if oldstderr is not None:
+ os.dup2(oldstderr, sys.stderr.fileno())
+ if filestderr is not None:
+ filestderr.close()
+
+ # Delete the temporary directory holding the test program and stderr files.
+ if tempdir is not None:
+ shutil.rmtree(tempdir)
+
+ def strip_empty_entries(self, args):
+ '''
+ Drop any empty strings from the list of compile and link flags
+ '''
+ return [x for x in args if len(x) > 0]
+
+ def get_cuda_compile_flag(self):
+ try:
+ assert_no_cuda_mismatch(self.name)
+ return "-D__ENABLE_CUDA__"
+ except MissingCUDAException:
+ print(f"{WARNING} {self.name} cuda is missing or is incompatible with installed torch, "
+ "only cpu ops can be compiled!")
+ return '-D__DISABLE_CUDA__'
+
+ def command_exists(self, cmd):
+ if '|' in cmd:
+ cmds = cmd.split("|")
+ else:
+ cmds = [cmd]
+ valid = False
+ for cmd in cmds:
+ safe_cmd = ["bash", "-c", f"type {cmd}"]
+ result = subprocess.Popen(safe_cmd, stdout=subprocess.PIPE)
+ valid = valid or result.wait() == 0
+
+ if not valid and len(cmds) > 1:
+ print(f"{WARNING} {self.name} requires one of the following commands '{cmds}', but it does not exist!")
+ elif not valid and len(cmds) == 1:
+ print(f"{WARNING} {self.name} requires the '{cmd}' command, but it does not exist!")
+ return valid
+
+ def warning(self, msg):
+ self.error_log = f"{msg}"
+ print(f"{WARNING} {msg}")
+
+ def _src_path(self, code_path):
+ if os.path.isabs(code_path):
+ return code_path
+ else:
+ return os.path.join(Path(__file__).parent.parent.absolute(), code_path)
+
+ def builder(self):
+ from torch.utils.cpp_extension import CppExtension
+ include_dirs = [os.path.abspath(x) for x in self.strip_empty_entries(self.include_paths())]
+ return CppExtension(name=self.absolute_name(),
+ sources=self.strip_empty_entries(self.sources()),
+ include_dirs=include_dirs,
+ extra_compile_args={'cxx': self.strip_empty_entries(self.cxx_args())},
+ extra_link_args=self.strip_empty_entries(self.extra_ldflags()))
+
+ def load(self, verbose=False):
+ if self.name in __class__._loaded_ops:
+ return __class__._loaded_ops[self.name]
+
+ from deepspeed.git_version_info import installed_ops, torch_info, accelerator_name
+ from deepspeed.accelerator import get_accelerator
+ if installed_ops.get(self.name, False) and accelerator_name == get_accelerator()._name:
+ # Ensure the op we're about to load was compiled with the same
+ # torch/cuda versions we are currently using at runtime.
+ self.validate_torch_version(torch_info)
+ if torch.cuda.is_available() and isinstance(self, CUDAOpBuilder):
+ self.validate_torch_op_version(torch_info)
+
+ op_module = importlib.import_module(self.absolute_name())
+ __class__._loaded_ops[self.name] = op_module
+ return op_module
+ else:
+ return self.jit_load(verbose)
+
+ def jit_load(self, verbose=True):
+ if not self.is_compatible(verbose):
+ raise RuntimeError(
+ f"Unable to JIT load the {self.name} op due to it not being compatible due to hardware/software issue. {self.error_log}"
+ )
+ try:
+ import ninja # noqa: F401 # type: ignore
+ except ImportError:
+ raise RuntimeError(f"Unable to JIT load the {self.name} op due to ninja not being installed.")
+
+ self.jit_mode = True
+ from torch.utils.cpp_extension import load
+
+ start_build = time.time()
+ sources = [os.path.abspath(self._src_path(path)) for path in self.sources()]
+ extra_include_paths = [os.path.abspath(self._src_path(path)) for path in self.include_paths()]
+
+ # Torch will try and apply whatever CCs are in the arch list at compile time,
+ # we have already set the intended targets ourselves we know that will be
+ # needed at runtime. This prevents CC collisions such as multiple __half
+ # implementations. Stash arch list to reset after build.
+ torch_arch_list = None
+ if "TORCH_CUDA_ARCH_LIST" in os.environ:
+ torch_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST")
+ os.environ["TORCH_CUDA_ARCH_LIST"] = ""
+
+ nvcc_args = self.strip_empty_entries(self.nvcc_args())
+ cxx_args = self.strip_empty_entries(self.cxx_args())
+
+ cxx_args.append("-UC10_USE_GLOG")
+ nvcc_args.append("-UC10_USE_GLOG")
+ if isinstance(self, CUDAOpBuilder):
+ if self.enable_bf16:
+ cxx_args.append("-DBF16_AVAILABLE")
+ nvcc_args.append("-DBF16_AVAILABLE")
+ nvcc_args.append("-U__CUDA_NO_BFLOAT16_OPERATORS__")
+ nvcc_args.append("-U__CUDA_NO_BFLOAT162_OPERATORS__")
+ nvcc_args.append("-U__CUDA_NO_BFLOAT16_CONVERSIONS__")
+
+ op_module = load(name=self.name,
+ sources=self.strip_empty_entries(sources),
+ extra_include_paths=self.strip_empty_entries(extra_include_paths),
+ extra_cflags=cxx_args,
+ extra_cuda_cflags=nvcc_args,
+ extra_ldflags=self.strip_empty_entries(self.extra_ldflags()),
+ with_cuda=True if (isinstance(self, CUDAOpBuilder)) else None,
+ verbose=verbose)
+
+ build_duration = time.time() - start_build
+ if verbose:
+ print(f"Time to load {self.name} op: {build_duration} seconds")
+
+ # Reset arch list so we are not silently removing it for other possible use cases
+ if torch_arch_list:
+ os.environ["TORCH_CUDA_ARCH_LIST"] = torch_arch_list
+
+ __class__._loaded_ops[self.name] = op_module
+
+ return op_module
+
+
+class CUDAOpBuilder(OpBuilder):
+
+ def compute_capability_args(self, cross_compile_archs=None):
+ """
+ Returns nvcc compute capability compile flags.
+
+ 1. `TORCH_CUDA_ARCH_LIST` takes priority over `cross_compile_archs`.
+ 2. If neither is set default compute capabilities will be used
+ 3. Under `jit_mode` compute capabilities of all visible cards will be used plus PTX
+
+ Format:
+
+ - `TORCH_CUDA_ARCH_LIST` may use ; or whitespace separators. Examples:
+
+ TORCH_CUDA_ARCH_LIST="6.1;7.5;8.6;9.0;10.0" pip install ...
+ TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0 7.5 8.0 8.6 9.0 10.0+PTX" pip install ...
+
+ - `cross_compile_archs` uses ; separator.
+
+ """
+ ccs = []
+ if self.jit_mode:
+ # Compile for underlying architectures since we know those at runtime
+ for i in range(torch.cuda.device_count()):
+ CC_MAJOR, CC_MINOR = torch.cuda.get_device_capability(i)
+ cc = f"{CC_MAJOR}.{CC_MINOR}"
+ if cc not in ccs:
+ ccs.append(cc)
+ ccs = sorted(ccs)
+ ccs[-1] += '+PTX'
+ else:
+ # Cross-compile mode, compile for various architectures
+ # env override takes priority
+ cross_compile_archs_env = os.environ.get('TORCH_CUDA_ARCH_LIST', None)
+ if cross_compile_archs_env is not None:
+ if cross_compile_archs is not None:
+ print(
+ f"{WARNING} env var TORCH_CUDA_ARCH_LIST={cross_compile_archs_env} overrides cross_compile_archs={cross_compile_archs}"
+ )
+ cross_compile_archs = cross_compile_archs_env.replace(' ', ';')
+ else:
+ if cross_compile_archs is None:
+ cross_compile_archs = get_default_compute_capabilities()
+ ccs = cross_compile_archs.split(';')
+
+ ccs = self.filter_ccs(ccs)
+ if len(ccs) == 0:
+ raise RuntimeError(
+ f"Unable to load {self.name} op due to no compute capabilities remaining after filtering")
+
+ args = []
+ self.enable_bf16 = True
+ for cc in ccs:
+ num = cc[0] + cc[1].split('+')[0]
+ args.append(f'-gencode=arch=compute_{num},code=sm_{num}')
+ if cc[1].endswith('+PTX'):
+ args.append(f'-gencode=arch=compute_{num},code=compute_{num}')
+
+ if int(cc[0]) <= 7:
+ self.enable_bf16 = False
+
+ return args
+
+ def filter_ccs(self, ccs: List[str]):
+ """
+ Prune any compute capabilities that are not compatible with the builder. Should log
+ which CCs have been pruned.
+ """
+ return [cc.split('.') for cc in ccs]
+
+ def version_dependent_macros(self):
+ # Fix from apex that might be relevant for us as well, related to https://github.com/NVIDIA/apex/issues/456
+ version_ge_1_1 = []
+ if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 0):
+ version_ge_1_1 = ['-DVERSION_GE_1_1']
+ version_ge_1_3 = []
+ if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 2):
+ version_ge_1_3 = ['-DVERSION_GE_1_3']
+ version_ge_1_5 = []
+ if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 4):
+ version_ge_1_5 = ['-DVERSION_GE_1_5']
+ return version_ge_1_1 + version_ge_1_3 + version_ge_1_5
+
+ def is_compatible(self, verbose=False):
+ return super().is_compatible(verbose)
+
+ def builder(self):
+ from torch.utils.cpp_extension import CUDAExtension as ExtensionBuilder
+ include_dirs = [os.path.abspath(x) for x in self.strip_empty_entries(self.include_paths())]
+ compile_args = {'cxx': self.strip_empty_entries(self.cxx_args())} if self.build_for_cpu else \
+ {'cxx': self.strip_empty_entries(self.cxx_args()), \
+ 'nvcc': self.strip_empty_entries(self.nvcc_args())}
+
+ if self.enable_bf16:
+ compile_args['cxx'].append("-DBF16_AVAILABLE")
+ compile_args['nvcc'].append("-DBF16_AVAILABLE")
+
+ cuda_ext = ExtensionBuilder(name=self.absolute_name(),
+ sources=self.strip_empty_entries(self.sources()),
+ include_dirs=include_dirs,
+ libraries=self.strip_empty_entries(self.libraries_args()),
+ extra_compile_args=compile_args,
+ extra_link_args=self.strip_empty_entries(self.extra_ldflags()))
+
+ return cuda_ext
+
+ def cxx_args(self):
+ if sys.platform == "win32":
+ return ['-O2']
+ else:
+ return ['-O3', '-std=c++17', '-g', '-Wno-reorder']
+
+ def nvcc_args(self):
+ args = ['-O3']
+ try:
+ nvcc_threads = int(os.getenv("DS_NVCC_THREADS", ""))
+ if nvcc_threads <= 0:
+ raise ValueError("")
+ except ValueError:
+ nvcc_threads = min(os.cpu_count(), 8)
+ cuda_major, cuda_minor = installed_cuda_version()
+ if cuda_major > 10:
+ if cuda_major == 12 and cuda_minor >= 5:
+ std_lib = '-std=c++20'
+ else:
+ std_lib = '-std=c++17'
+ else:
+ std_lib = '-std=c++14'
+ args += [
+ '-allow-unsupported-compiler' if sys.platform == "win32" else '', '--use_fast_math', std_lib,
+ '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
+ f'--threads={nvcc_threads}'
+ ]
+ if os.environ.get('DS_DEBUG_CUDA_BUILD', '0') == '1':
+ args.append('--ptxas-options=-v')
+ args += self.compute_capability_args()
+ return args
+
+ def libraries_args(self):
+ if sys.platform == "win32":
+ return ['cublas', 'curand']
+ else:
+ return []
+
diff --git a/arctic_inference/op_builder/swiftkv_ops_builder.py b/arctic_inference/op_builder/swiftkv_ops_builder.py
new file mode 100644
index 000000000..77ebc8abe
--- /dev/null
+++ b/arctic_inference/op_builder/swiftkv_ops_builder.py
@@ -0,0 +1,29 @@
+import os
+from .builder import CUDAOpBuilder
+
+class SwiftKVOpsBuilder(CUDAOpBuilder):
+ def __init__(self):
+ super().__init__(name="reshape_and_cache_flash_bulk")
+
+ def absolute_name(self):
+ return f'arctic_inference.swiftkv_ops.{self.name}'
+
+ def get_prefix(self):
+ # borrowed from moe_op. refactor later
+ ai_path = self._src_path("arctic_inference")
+ return "arctic_inference" if os.path.isdir(ai_path) else ".."
+
+ def sources(self):
+ sources = [
+ 'csrc/custom_ops/torch_bindings.cpp',
+ 'csrc/custom_ops/reshape_and_cache_flash_bulk.cu',
+ ]
+ prefix = self.get_prefix()
+ sources = [os.path.join(prefix, src) for src in sources]
+ return sources
+
+ def include_paths(self):
+ sources = ['csrc/custom_ops']
+ prefix = self.get_prefix()
+ sources = [os.path.join(prefix, src) for src in sources]
+ return sources
diff --git a/arctic_inference/py_custom_ops.py b/arctic_inference/py_custom_ops.py
index 89a05d92b..2df4ab397 100644
--- a/arctic_inference/py_custom_ops.py
+++ b/arctic_inference/py_custom_ops.py
@@ -22,8 +22,8 @@ def try_load_torch_library() -> bool:
return False
try:
- logger.info(f"Attempting to load custom ops from {library_path}...")
torch.ops.load_library(library_path)
+ logger.info(f"Successfully loaded custom ops from {library_path}.")
return True
except RuntimeError as e:
logger.info(
@@ -37,6 +37,25 @@ def try_load_torch_library() -> bool:
return False
+def try_load_jit_library() -> bool:
+ try:
+ from arctic_inference.op_builder.swiftkv_ops_builder import SwiftKVOpsBuilder
+ swiftkv_ops_module = SwiftKVOpsBuilder().load()
+
+ logger.info("Successfully loaded SwiftKVOpsBuilder JIT library.")
+ return True
+ except ImportError as e:
+ logger.info(
+ f"Unable to import SwiftKVOpsBuilder. ImportError: {e}. Falling back to original implementation."
+ )
+ return False
+ except Exception as e:
+ logger.info(
+ f"Unable to load JIT library. Exception: {e}. Falling back to original implementation."
+ )
+ return False
+
+
def reshape_and_cache_flash_bulk(
keys: list[torch.Tensor],
values: list[torch.Tensor],
@@ -52,3 +71,48 @@ def reshape_and_cache_flash_bulk(
torch.ops.arctic_inference.reshape_and_cache_flash_bulk(
keys, values, key_caches, value_caches, slot_mapping, kv_cache_dtype,
k_scales, v_scales, num_heads, head_size)
+
+
+def reshape_and_cache_flash_fp4(
+ keys: torch.Tensor,
+ values: torch.Tensor,
+ key_cache: torch.Tensor,
+ value_cache: torch.Tensor,
+ key_cache_scales: torch.Tensor,
+ value_cache_scales: torch.Tensor,
+ slot_mapping: torch.Tensor,
+ kv_cache_dtype: str,
+ k_scale: torch.Tensor,
+ v_scale: torch.Tensor,
+) -> None:
+ torch.ops.arctic_inference.reshape_and_cache_flash_fp4(
+ keys, values, key_cache, value_cache, slot_mapping, kv_cache_dtype,
+ k_scale, v_scale, key_cache_scales, value_cache_scales)
+
+
+def speculator_ln(
+ input: torch.Tensor,
+ weight: torch.Tensor,
+ bias: torch.Tensor,
+ eps: float,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ return torch.ops.arctic_inference.speculator_ln_cuda(
+ input, weight, bias, eps)
+
+
+def sum_lstm(
+ states_4d: torch.Tensor,
+ z4_4d: torch.Tensor,
+ prev_cell_d: torch.Tensor,
+ w_cell: torch.Tensor | None,
+ b_cell: torch.Tensor | None,
+ w_state: torch.Tensor | None,
+ b_state: torch.Tensor | None,
+ alpha: float,
+ eps_cell: float,
+ eps_state: float,
+ use_fast_gelu: bool,
+) -> tuple[torch.Tensor, torch.Tensor]:
+ return torch.ops.arctic_inference.sum_lstm_cuda(
+ states_4d, z4_4d, prev_cell_d, w_cell, b_cell, w_state, b_state, alpha,
+ eps_cell, eps_state, use_fast_gelu)
\ No newline at end of file
diff --git a/arctic_inference/common/suffix_cache/__init__.py b/arctic_inference/suffix_decoding/__init__.py
similarity index 83%
rename from arctic_inference/common/suffix_cache/__init__.py
rename to arctic_inference/suffix_decoding/__init__.py
index 59161c330..2e2ad3384 100644
--- a/arctic_inference/common/suffix_cache/__init__.py
+++ b/arctic_inference/suffix_decoding/__init__.py
@@ -13,6 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from .suffix_cache import SuffixCache, SuffixSpecResult
+from .cache import SuffixDecodingCache, SuffixDecodingDraft
-__all__ = ["SuffixCache", "SuffixSpecResult"]
+__all__ = ["SuffixDecodingCache", "SuffixDecodingDraft"]
diff --git a/arctic_inference/common/suffix_cache/suffix_cache.py b/arctic_inference/suffix_decoding/cache.py
similarity index 61%
rename from arctic_inference/common/suffix_cache/suffix_cache.py
rename to arctic_inference/suffix_decoding/cache.py
index c06cf19e5..2c0cd32a0 100644
--- a/arctic_inference/common/suffix_cache/suffix_cache.py
+++ b/arctic_inference/suffix_decoding/cache.py
@@ -16,13 +16,15 @@
from __future__ import annotations
from dataclasses import dataclass, field
-from typing import Hashable, KeysView, List, Optional, Sequence, Union
+from typing import Hashable, KeysView, List, Optional, Sequence
-from arctic_inference.common.suffix_cache._C import SuffixTree, Candidate
+import numpy as np
+
+from arctic_inference.suffix_decoding._C import SuffixTree, Draft
@dataclass
-class SuffixSpecResult:
+class SuffixDecodingDraft:
"""
A dataclass representing the result of a speculation using SuffixDecoding.
@@ -34,7 +36,7 @@ class SuffixSpecResult:
probs (List[float]): List of estimated probabilities for each token.
score (float): The overall score of the suffix match computed as the
sum of the estimated probabilities of each speculated token.
- match_len (int): The length of the pattern match that yielded this
+ match_len (int): The length of the context match that yielded this
speculation result.
"""
token_ids: List[int] = field(default_factory=list)
@@ -44,30 +46,33 @@ class SuffixSpecResult:
match_len: int = 0
@staticmethod
- def from_candidate(candidate: Candidate) -> SuffixSpecResult:
- return SuffixSpecResult(
- token_ids=candidate.token_ids,
- parents=candidate.parents,
- probs=candidate.probs,
- score=candidate.score,
- match_len=candidate.match_len,
+ def from_native(draft: Draft) -> SuffixDecodingDraft:
+ return SuffixDecodingDraft(
+ token_ids=draft.token_ids,
+ parents=draft.parents,
+ probs=draft.probs,
+ score=draft.score,
+ match_len=draft.match_len,
)
-class SuffixCache:
+class SuffixDecodingCache:
def __init__(self,
max_tree_depth: int = 64,
- max_cached_requests: Optional[int] = None):
+ max_cached_requests: int = -1):
"""
- Initialize the SuffixCache.
+ Initialize the SuffixDecodingCache.
Args:
max_tree_depth (int): The maximum depth of the suffix trees.
max_cached_requests (int, optional): The maximum number of cached
- requests. Cache eviction is used when the limit is reached. If
- `None`, there is no limit on the number of cached requests.
+ requests. Eviction is triggered when the limit is reached. `-1`
+ means no limit on the number of cached requests.
"""
+ if max_cached_requests > 0x7FFFFFFF:
+ raise ValueError("max_cached_requests must be at most 2^31")
+
self._max_tree_depth = max_tree_depth
self._max_cached_requests = max_cached_requests
@@ -78,7 +83,7 @@ def __init__(self,
self._local_trees = {}
# Maps between Python request ID and int32_t sequence ID. Tracks all
- # request IDs that are in the global tree or one of the local trees.
+ # request IDs that are in the global tree.
self._req_to_seq_id = {}
self._seq_to_req_id = {}
@@ -112,33 +117,54 @@ def cached_requests(self) -> KeysView:
"""
return self._req_to_seq_id.keys()
- def start_request(self, req_id: Hashable, prompt_token_ids: Sequence[int]):
+ def start_request(
+ self,
+ req_id: Hashable,
+ prompt_token_ids: np.ndarray | Sequence[int],
+ ):
"""
This method should be called when starting to process a new request. It
will store the prompt for the request, allowing future speculations for
the same request to use the prompt context. The prompt will be stored
- until `stop_request` is called.
+ until `stop_request` is called. If `max_cached_requests != 0`, then a
+ new slot is allocated in the global cache for the response, triggering
+ cache eviction (FIFO order) if needed.
Args:
req_id (Hashable): The request identifier. Must be a hashable value
that uniquely identifies the request.
- prompt_token_ids (Sequence[int]): A sequence of token IDs
- representing the prompt of the request.
+ prompt_token_ids (np.ndarray | Sequence[int]): A sequence of token
+ IDs representing the prompt of the request.
Raises:
ValueError: If a request with the same `req_id` is already active
or cached.
"""
- if req_id in self._req_to_seq_id:
- raise ValueError(f"Request '{req_id}' is already active or cached")
- seq_id = self._generate_seq_id(req_id)
+ if req_id in self._local_trees:
+ raise ValueError(f"Request '{req_id}' is already active")
+
+ if isinstance(prompt_token_ids, np.ndarray):
+ # If input is a numpy array, use the zero-copy ndarray overload.
+ self._validate_ndarray(prompt_token_ids)
+ extend_func = SuffixTree.extend_ndarray
+ else:
+ extend_func = SuffixTree.extend
+
self._local_trees[req_id] = SuffixTree(self._max_tree_depth)
- self._local_trees[req_id].extend(seq_id, prompt_token_ids)
+ extend_func(self._local_trees[req_id], 0, prompt_token_ids)
+ if self._max_cached_requests != 0:
+ # Global cache is enabled.
+ if req_id in self._req_to_seq_id:
+ # Evict existing cached response for the request if present.
+ self.evict_cached_response(req_id)
+ # Allocate a new seq_id for the request.
+ self._generate_seq_id(req_id)
def stop_request(self, req_id: Hashable):
"""
This method should be called when a request is completed. It will evict
- the prompt for the request, freeing up memory.
+ the prompt for the request, freeing up memory. The request's response
+ may still be cached in the global cache until it is evicted.
Args:
req_id (Hashable): The request identifier. Must be a hashable value
@@ -154,7 +180,7 @@ def stop_request(self, req_id: Hashable):
def add_active_response(
self,
req_id: Hashable,
- token_ids: Union[int, Sequence[int]],
+ token_ids: np.ndarray | Sequence[int],
):
"""
Update the cached response for a given request by appending token(s) to
@@ -163,51 +189,35 @@ def add_active_response(
Args:
req_id (Hashable): The unique identifier for the request.
- token_ids (Union[int, Sequence[int]]): Either a single token ID
- (int) or a sequence of token IDs to be appended to the response
- for the given request.
+ token_ids (np.ndarray | Sequence[int]): A sequence of token IDs to
+ be appended to the response for the given request.
Raises:
ValueError: If the request with the given `req_id` is not active.
"""
if req_id not in self._local_trees:
raise ValueError(f"Request '{req_id}' is not active")
- seq_id = self._req_to_seq_id[req_id]
- if isinstance(token_ids, int):
- self._global_tree.append(seq_id, token_ids)
- self._local_trees[req_id].append(seq_id, token_ids)
- else:
- self._global_tree.extend(seq_id, token_ids)
- self._local_trees[req_id].extend(seq_id, token_ids)
- def insert_new_response(
- self,
- req_id: Hashable,
- token_ids: Union[int, Sequence[int]],
- ):
- """
- Insert a complete response to the global cache for a request that is
- not active and is not already cached.
+ if isinstance(token_ids, np.ndarray):
+ # If input is a numpy array, use the zero-copy ndarray overload.
+ self._validate_ndarray(token_ids)
+ extend_func = SuffixTree.extend_ndarray
+ else:
+ extend_func = SuffixTree.extend
- Args:
- req_id (Hashable): The unique identifier for the request.
- token_ids (Sequence[int]): A sequence of token IDs to be inserted
- as the response for the given request.
+ # Update the local tree for the active request.
+ extend_func(self._local_trees[req_id], 0, token_ids)
- Raises:
- ValueError: If a request with the same `req_id` is already active
- or cached.
- """
+ # Also update the response if the request is in the global cache (it
+ # may be evicted from the global cache before the request is stopped).
if req_id in self._req_to_seq_id:
- raise ValueError(f"Request '{req_id}' is already active or cached")
- seq_id = self._generate_seq_id(req_id)
- self._global_tree.extend(seq_id, token_ids)
+ seq_id = self._req_to_seq_id[req_id]
+ extend_func(self._global_tree, seq_id, token_ids)
- def evict_request(self, req_id: Hashable):
+ def evict_cached_response(self, req_id: Hashable):
"""
- Evicts the given request's prompt and response from the cache. If the
- request is active, it becomes inactive. The `req_id` can then be reused
- after eviction.
+ Evicts the given request's response from the global cache. `req_id` can
+ be safely reused for a new request after eviction.
Args:
req_id (Hashable): The unique identifier for the request that
@@ -217,9 +227,7 @@ def evict_request(self, req_id: Hashable):
ValueError: If no response exists for the given request identifier.
"""
if req_id not in self._req_to_seq_id:
- raise ValueError(f"Request '{req_id}' is not active or cached")
- if req_id in self._local_trees:
- del self._local_trees[req_id]
+ raise ValueError(f"Request '{req_id}' is not cached")
seq_id = self._req_to_seq_id.pop(req_id)
self._seq_to_req_id.pop(seq_id)
self._global_tree.remove(seq_id)
@@ -227,29 +235,33 @@ def evict_request(self, req_id: Hashable):
def speculate(
self,
req_id: Hashable,
- pattern: Sequence[int],
+ context: np.ndarray | Sequence[int],
max_spec_tokens: Optional[int] = None,
max_spec_factor: float = 1.0,
max_spec_offset: float = 0.0,
min_token_prob: float = 0.1,
use_tree_spec: bool = False,
- ) -> SuffixSpecResult:
+ ) -> SuffixDecodingDraft:
"""
Speculates and returns the most likely continuation of a given token
- pattern using the request's prompt and the global cache of previous
+ context using the request's prompt and the global cache of previous
responses. This method can only be called for active requests (i.e.
after calling `start_request` and before calling `stop_request`).
Args:
req_id (Hashable): The unique identifier for the request.
- pattern (Sequence[int]): The sequence of token IDs to match and
- continue from.
+ context (np.ndarray | Sequence[int]): A sequence of token IDs to
+ match and speculate subsequent tokens from.
max_spec_tokens (int): Maximum number of tokens to speculate. If 0,
uses the cache's max_depth.
max_spec_factor (float): Factor that limits speculation based on
- matched pattern length.
+ matched context length. The number of speculated tokens is
+ limited by `max_spec_factor * match_length + max_spec_offset`.
+ max_spec_offset (float): Offset that limits speculation based on
+ matched context length. The number of speculated tokens is
+ limited by `max_spec_factor * match_length + max_spec_offset`.
min_token_prob (float): Minimum estimated probability threshold for
- candidate tokens.
+ draft tokens.
use_tree_spec (bool): If True, uses tree-based speculation.
Returns:
@@ -262,32 +274,40 @@ def speculate(
if req_id not in self._local_trees:
raise ValueError(f"Request '{req_id}' is not active")
+ if isinstance(context, np.ndarray):
+ # If input is a numpy array, use the zero-copy ndarray overload.
+ self._validate_ndarray(context)
+ spec_func = SuffixTree.speculate_ndarray
+ else:
+ spec_func = SuffixTree.speculate
+
if max_spec_tokens is None:
max_spec_tokens = self.max_depth
- if len(pattern) > self._max_tree_depth:
- pattern = pattern[-self._max_tree_depth :]
+ if len(context) > self._max_tree_depth:
+ context = context[-self._max_tree_depth :]
- candidate = self._local_trees[req_id].speculate(
- pattern,
+ draft1 = spec_func(
+ self._local_trees[req_id],
+ context,
max_spec_tokens,
max_spec_factor,
max_spec_offset,
min_token_prob,
use_tree_spec)
- result = SuffixSpecResult.from_candidate(candidate)
- candidate = self._global_tree.speculate(
- pattern,
+ draft2 = spec_func(
+ self._global_tree,
+ context,
max_spec_tokens,
max_spec_factor,
max_spec_offset,
min_token_prob,
use_tree_spec)
- if candidate.score > result.score:
- result = SuffixSpecResult.from_candidate(candidate)
- return result
+ draft = draft1 if draft1.score >= draft2.score else draft2
+
+ return SuffixDecodingDraft.from_native(draft)
def _generate_seq_id(self, req_id: Hashable) -> int:
# Find the next available seq_id not used by an active request.
@@ -313,16 +333,25 @@ def _generate_seq_id(self, req_id: Hashable) -> int:
return seq_id
def _maybe_evict_requests(self, new_seq_id: int):
- if self._max_cached_requests is None:
+ if self._max_cached_requests < 0:
+ # Negative value means no global cache size limit.
return
+ assert self._max_cached_requests != 0 # Global cache must be enabled.
while len(self._req_to_seq_id) > self._max_cached_requests:
- # Evict the first request that is not active. Should be FIFO order
- # in python 3.7+ as dict preserves insertion order. We also want to
- # avoid evicting the request that was just added (new_seq_id).
+ # Evict the first eligible request. Should be FIFO order in Python
+ # 3.7+ since dict preserves insertion order. Avoid evicting the
+ # request that was just added (new_seq_id).
for req_id, seq_id in self._req_to_seq_id.items():
- if seq_id != new_seq_id and req_id not in self._local_trees:
- self.evict_request(req_id)
+ if seq_id != new_seq_id:
+ self.evict_cached_response(req_id)
break
- else:
- # All previously cached requests are active, cannot evict any.
- break
+
+ def _validate_ndarray(self, arr: np.ndarray):
+ if arr.ndim != 1:
+ raise ValueError(f"ndarray input must have ndim=1, "
+ f"got ndim={arr.ndim}")
+ if arr.dtype != np.int32:
+ raise ValueError(f"ndarray input must have dtype=int32, "
+ f"got dtype={arr.dtype.name}")
+ if not arr.flags["CONTIGUOUS"]:
+ raise ValueError(f"ndarray input must be contiguous")
diff --git a/arctic_inference/common/suffix_cache/simulator.py b/arctic_inference/suffix_decoding/simulator.py
similarity index 93%
rename from arctic_inference/common/suffix_cache/simulator.py
rename to arctic_inference/suffix_decoding/simulator.py
index 19a158c9f..0e70ef630 100644
--- a/arctic_inference/common/suffix_cache/simulator.py
+++ b/arctic_inference/suffix_decoding/simulator.py
@@ -22,20 +22,21 @@
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple
+import numpy as np
import pandas as pd
from tqdm import tqdm
from transformers import AutoTokenizer
-from arctic_inference.common.suffix_cache import SuffixCache
+from arctic_inference.suffix_decoding import SuffixDecodingCache
os.environ["TOKENIZERS_PARALLELISM"] = "false"
def suffix_decode(
- suffix_cache: SuffixCache,
+ suffix_cache: SuffixDecodingCache,
request_id: int,
- prompt: List[int],
- ground_truth_response: List[int],
+ prompt: np.ndarray,
+ ground_truth_response: np.ndarray,
max_spec_tokens: int,
max_spec_factor: float,
min_token_prob: float,
@@ -47,17 +48,21 @@ def suffix_decode(
suffix_cache.start_request(request_id, prompt if use_cached_prompt else [])
- assert isinstance(prompt, list) and isinstance(ground_truth_response, list)
+ assert isinstance(prompt, np.ndarray)
+ assert isinstance(ground_truth_response, np.ndarray)
results = []
response = []
while len(response) < len(ground_truth_response):
- text = prompt + response
+ if response:
+ sequence = np.concatenate([prompt, response], dtype=np.int32)
+ else:
+ sequence = prompt
start_time = time.perf_counter()
result = suffix_cache.speculate(
request_id,
- text,
+ sequence,
max_spec_tokens=max_spec_tokens,
max_spec_factor=max_spec_factor,
min_token_prob=min_token_prob,
@@ -105,7 +110,7 @@ def suffix_decode(
"update_ms": update_time * 1000,
})
- assert response == ground_truth_response
+ assert np.array_equal(response, ground_truth_response)
suffix_cache.stop_request(request_id)
@@ -159,7 +164,7 @@ def process_task(
use_cached_prompt: bool,
evict_fraction: float,
evict_strategy: str,
- max_cached_requests: Optional[int],
+ max_cached_requests: int,
) -> List[Dict]:
eval_subset, train_subset = sample_data(
dataset,
@@ -168,8 +173,8 @@ def process_task(
num_train,
seed,
)
- suffix_cache = SuffixCache(max_tree_depth=max_depth,
- max_cached_requests=max_cached_requests)
+ suffix_cache = SuffixDecodingCache(max_tree_depth=max_depth,
+ max_cached_requests=max_cached_requests)
train_request_ids = []
num_cached_tokens = {} # request_id -> num tokens
for request_id, example in tqdm(train_subset.iterrows(),
@@ -178,7 +183,9 @@ def process_task(
# Use negative request_id to indicate training examples and avoid
# conflicts with eval request_ids numbered 0, .., num_eval - 1.
train_request_id = -1 - request_id
- suffix_cache.insert_new_response(train_request_id, example["response"])
+ suffix_cache.start_request(train_request_id, example["prompt"])
+ suffix_cache.add_active_response(train_request_id, example["response"])
+ suffix_cache.stop_request(train_request_id)
train_request_ids.append(train_request_id)
num_cached_tokens[train_request_id] = len(example["response"])
@@ -194,7 +201,7 @@ def process_task(
rng = random.Random(seed)
evict_ids = rng.sample(cached_request_ids, num_evict)
for request_id in tqdm(evict_ids, desc="Evicting cached responses"):
- suffix_cache.evict_request(request_id)
+ suffix_cache.evict_cached_response(request_id)
print("Checking cache integrity...", end=" ", flush=True)
if ret := suffix_cache._global_tree.check_integrity():
@@ -321,8 +328,10 @@ def tokenize_data(dataset: pd.DataFrame, tokenizer_name: str) -> pd.DataFrame:
responses = []
for _, row in tqdm(dataset.iterrows(), total=len(dataset),
desc="Tokenizing dataset"):
- prompts.append(tokenizer.encode(row["prompt"]))
- responses.append(tokenizer.encode(row["response"]))
+ prompt = tokenizer.encode(row["prompt"], return_tensors="np")
+ prompts.append(prompt.astype(np.int32).flatten())
+ response = tokenizer.encode(row["response"], return_tensors="np")
+ responses.append(response.astype(np.int32).flatten())
return pd.DataFrame({
"prompt": prompts,
"response": responses,
@@ -550,8 +559,8 @@ def get_parser():
"--max-cached-requests",
type=int,
nargs="+",
- default=[None],
- help="Max number of cached requests (if None, unlimited)",
+ default=[-1],
+ help="Max number of cached requests (if -1, unlimited)",
)
parser.add_argument(
"--evict-fraction",
diff --git a/arctic_inference/vllm/args.py b/arctic_inference/vllm/args.py
index 8407e4bec..9cba50abd 100644
--- a/arctic_inference/vllm/args.py
+++ b/arctic_inference/vllm/args.py
@@ -20,7 +20,7 @@
from vllm.config import ParallelConfig
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
-from vllm.utils import FlexibleArgumentParser
+from vllm.utils.argparse_utils import FlexibleArgumentParser
from arctic_inference.patching import ArcticPatch
from arctic_inference.vllm.config import ArcticParallelConfig
@@ -50,7 +50,6 @@ class EngineArgsPatch(ArcticPatch[EngineArgs]):
_orig_add_cli_args = EngineArgs.add_cli_args
_orig_from_cli_args = EngineArgs.__dict__["from_cli_args"].__wrapped__
_orig_create_engine_config = EngineArgs.create_engine_config
- _orig_is_v1_supported_oracle = EngineArgs._is_v1_supported_oracle
def __new__(cls, *args, **kwargs):
# Override __new__ to return an ArcticEngineArgs instead of an
@@ -109,6 +108,11 @@ def create_engine_config(self, *args, **kwargs):
if (self.ulysses_sequence_parallel_size > 1 and
self.distributed_executor_backend is None):
self.distributed_executor_backend = "mp"
+
+ # Store ulysses_sequence_parallel_size for access during config initialization
+ from arctic_inference.vllm import ulysses
+ ulysses._ulysses_sp_size = self.ulysses_sequence_parallel_size
+
vllm_config = self._orig_create_engine_config(*args, **kwargs)
# Recreate the parallel config with Arctic parameters since they might
# not be passed to the parallel config __init__ when first initialized.
@@ -121,21 +125,6 @@ def create_engine_config(self, *args, **kwargs):
vllm_config.parallel_config = ArcticParallelConfig(**kwargs)
return vllm_config
- def _is_v1_supported_oracle(self, *args, **kwargs):
- orig_speculative_config = self.speculative_config
-
- # Since Arctic Inference is only compatible with v1 and we already
- # check it earlier, we can just disable this check altogether.
- if (self.speculative_config is not None and
- self.speculative_config.get("method") in ("arctic", "suffix")):
- self.speculative_config = None
-
- res = self._orig_is_v1_supported_oracle(*args, **kwargs)
-
- self.speculative_config = orig_speculative_config
-
- return res
-
class AsyncEngineArgsPatch(ArcticPatch[AsyncEngineArgs]):
diff --git a/arctic_inference/vllm/config.py b/arctic_inference/vllm/config.py
index 4cda01436..b310e1f4f 100644
--- a/arctic_inference/vllm/config.py
+++ b/arctic_inference/vllm/config.py
@@ -13,9 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from dataclasses import dataclass
+from pydantic.dataclasses import dataclass
import logging
+import vllm
from vllm.config import ParallelConfig, SpeculativeConfig, VllmConfig
from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig
@@ -55,8 +56,10 @@ def world_size(self, value: int) -> None:
@dataclass
class ArcticSpeculativeConfig(SpeculativeConfig):
+ method: str | None = None
enable_suffix_decoding: bool = False
suffix_cache_max_depth: int = 64
+ suffix_speculative_tokens: int = 0
suffix_cache_max_requests: int = 100000
suffix_max_spec_factor: float = 1.0
suffix_max_spec_offset: float = 0.0
@@ -76,46 +79,66 @@ def __new__(cls, *args, **kwargs):
class SpeculativeConfigPatch(ArcticPatch[SpeculativeConfig]):
- _orig_from_dict = SpeculativeConfig.__dict__["from_dict"].__wrapped__
_orig_post_init = SpeculativeConfig.__post_init__
def __new__(cls, *args, **kwargs):
- # Override __new__ to return an ArcticSpeculativeConfig instead of a
- # SpeculativeConfig when creating a new instance of the class.
if cls is SpeculativeConfig:
return ArcticSpeculativeConfig.__new__(ArcticSpeculativeConfig,
*args, **kwargs)
return super(SpeculativeConfig, cls).__new__(cls)
def __post_init__(self):
- use_suffix = (self.method
- == "suffix") or (self.method is None
- and self.enable_suffix_decoding)
- if (use_suffix or self.method == "arctic") and \
- self.disable_by_batch_size is None:
+ is_arctic_method = self.method in ("arctic", "mlp_speculator")
+ use_suffix = (self.method == "suffix") or (self.method is None
+ and self.enable_suffix_decoding)
+ use_hybrid = (self.method == "arctic" and self.enable_suffix_decoding)
+
+ if (use_suffix or is_arctic_method) and self.disable_by_batch_size is None:
logger.info("Defaulting disable_by_batch_size to 64")
self.disable_by_batch_size = 64
+
+ if use_hybrid:
+ self.suffix_speculative_tokens = self.suffix_cache_max_depth
if use_suffix:
self.method = "suffix"
self.enable_suffix_decoding = True
- self.num_speculative_tokens = self.suffix_cache_max_depth
+ # Use suffix_speculative_tokens if explicitly set, otherwise
+ # default to 16 (not suffix_cache_max_depth which can be very
+ # large and makes every step process 1+N tokens even when the
+ # suffix cache has no matches).
+ # NOTE: num_speculative_tokens defaults to None (not 0).
+ if self.suffix_speculative_tokens > 0:
+ self.num_speculative_tokens = self.suffix_speculative_tokens
+ elif self.num_speculative_tokens is None:
+ self.num_speculative_tokens = 16
self._verify_args()
+ return
+
+ if is_arctic_method:
+ actual_draft_model = getattr(self, "draft_model", None)
+
+ self.draft_model = None
+
+ try:
+ self._orig_post_init()
+ finally:
+ self.draft_model = actual_draft_model
+
+ if self.num_speculative_tokens == 0:
+ self.num_speculative_tokens = getattr(self, "num_lookahead_slots", 1)
else:
self._orig_post_init()
- @classmethod
- def from_dict(cls, dict_value: dict) -> SpeculativeConfig:
- """Parse the CLI value for the speculative config."""
- if cls is SpeculativeConfig:
- return SpeculativeConfigPatch._orig_from_dict(
- ArcticSpeculativeConfig, dict_value)
- return SpeculativeConfigPatch._orig_from_dict(cls, dict_value)
-
class VllmConfigPatch(ArcticPatch[VllmConfig]):
_orig_str = VllmConfig.__str__
+ _orig_post_init = VllmConfig.__post_init__
+
+ from typing import Literal
+ OldEagleModelTypes = vllm.config.speculative.EagleModelTypes
+ NewEagleModelTypes = Literal["arctic", "suffix", OldEagleModelTypes]
def __str__(self, *args, **kwargs):
string = self._orig_str(*args, **kwargs)
@@ -124,6 +147,24 @@ def __str__(self, *args, **kwargs):
string += f", shift_parallel_threshold={self.parallel_config.shift_parallel_threshold}"
return string
+ def __post_init__(self, *args, **kwargs):
+ # if self.speculative_config is not None:
+ # if self.speculative_config.method not in get_args(EagleModelTypes):
+ # raise ValueError(
+ # "Currently, async scheduling is only supported "
+ # "with EAGLE/MTP kind of speculative decoding"
+ # )
+ import sys
+ from typing import Literal
+ target_module = sys.modules[VllmConfig.__module__]
+ original_types = getattr(target_module, "EagleModelTypes")
+ NewEagleModelTypes = Literal["mlp_speculator", "suffix", original_types]
+ setattr(target_module, "EagleModelTypes", NewEagleModelTypes)
+ try:
+ self._orig_post_init(*args, **kwargs)
+ finally:
+ setattr(target_module, "EagleModelTypes", original_types)
+
class MLPSpeculatorConfigPatch(ArcticPatch[MLPSpeculatorConfig]):
@@ -132,3 +173,16 @@ class MLPSpeculatorConfigPatch(ArcticPatch[MLPSpeculatorConfig]):
def __init__(self, *args, **kwargs):
self.base_model_arch = kwargs.pop("base_model_arch", "")
self._orig_init(*args, **kwargs)
+
+ # Inject dummy attributes required by vLLM's ModelArchConfigConvertor
+ # The convertor tries to calculate head_size = hidden_size // num_attention_heads
+ if not hasattr(self, "num_attention_heads"):
+ self.num_attention_heads = 1
+
+ if not hasattr(self, "hidden_size"):
+ # Fallback to n_embd if present, otherwise default to a safe dummy value
+ self.hidden_size = getattr(self, "n_embd", 1024)
+
+ # Ensure hidden_size is an integer to prevent TypeError during division
+ if hasattr(self, "hidden_size"):
+ self.hidden_size = int(self.hidden_size)
diff --git a/arctic_inference/vllm/model_runner.py b/arctic_inference/vllm/model_runner.py
index b5530ca1d..2cff488ef 100644
--- a/arctic_inference/vllm/model_runner.py
+++ b/arctic_inference/vllm/model_runner.py
@@ -5,7 +5,7 @@
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
-# http://www.apache.org/licenses/LICENSE-2.0
+# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
@@ -15,47 +15,58 @@
import contextlib
import copy
+import gc
import time
-from typing import Any, Union, Optional, TYPE_CHECKING
-from itertools import tee
+from typing import Any, Optional, TYPE_CHECKING, Union
import numpy as np
import torch
+from tqdm import tqdm
+
import vllm.distributed.parallel_state as parallel_state
import vllm.envs as envs
-from tqdm import tqdm
-from vllm.attention.layer import Attention
from vllm.compilation.counter import compilation_counter
-from vllm.config import CompilationLevel
+from vllm.compilation.monitor import set_cudagraph_capturing_enabled
+from vllm.config import CUDAGraphMode, VllmConfig
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
has_kv_transfer_group)
from vllm.distributed.parallel_state import (get_pp_group, get_tp_group,
is_global_first_rank)
-from vllm.forward_context import set_forward_context
-from vllm.config import VllmConfig
+from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.model_executor.model_loader import get_model
from vllm.sequence import IntermediateTensors
-from vllm.utils import round_up
-from vllm.v1.kv_cache_interface import KVCacheConfig
-from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
+from vllm.utils.math_utils import round_up, cdiv
+from vllm.v1.attention.backend import CommonAttentionMetadata
+from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput,
+ SamplerOutput)
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import MAX_SPEC_LEN, RejectionSampler
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
-from vllm.v1.worker.gpu_model_runner import GPUModelRunner, logger
+from vllm.v1.utils import record_function_or_nullcontext
+from vllm.v1.worker.gpu_model_runner import (
+ GPUModelRunner,
+ logger,
+ AsyncGPUModelRunnerOutput,
+)
+from vllm.v1.structured_output.utils import apply_grammar_bitmask
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
-from arctic_inference.common.suffix_cache import SuffixCache
from arctic_inference.patching import ArcticPatch
-from arctic_inference.vllm.spec_dec.arctic_proposer import ArcticProposer
-from arctic_inference.common.suffix_cache import SuffixSpecResult
+from arctic_inference.suffix_decoding import (SuffixDecodingCache,
+ SuffixDecodingDraft)
+from arctic_inference.vllm.spec_dec.arctic_proposer import (ArcticProposer,
+ SuffixProposer)
SP_TP_MODE = None
@contextlib.contextmanager
def set_shift_parallel_mode(mode: Optional[bool]):
+ """
+ Swap the tensor-parallel group to an SP-compatible variant when 'mode' is True.
+ """
if mode is None:
yield
return
@@ -76,7 +87,6 @@ def set_shift_parallel_mode(mode: Optional[bool]):
try:
yield
finally:
- # restore the original state
SP_TP_MODE = old_mode
parallel_state._TP = old_tp_group
@@ -88,113 +98,299 @@ def is_shift_parallel_mode() -> bool:
class GPUModelRunnerPatch(ArcticPatch[GPUModelRunner]):
+ """
+ Rebased GPUModelRunnerPatch for vLLM v14.
+ """
- _orig_initialize_kv_cache = GPUModelRunner.initialize_kv_cache
- _orig_prepare_inputs = GPUModelRunner._prepare_inputs
+ _orig_capture_cudagraphs = GPUModelRunner._capture_cudagraphs
_orig_profile_run = GPUModelRunner.profile_run
_orig_load_model = GPUModelRunner.load_model
_orig_propose_draft_token_ids = GPUModelRunner.propose_draft_token_ids
+ _orig_dummy_run = GPUModelRunner._dummy_run
_orig_init = GPUModelRunner.__init__
+ _orig_build_attention_metadata = GPUModelRunner._build_attention_metadata
+ _orig_execute_model = GPUModelRunner.execute_model
+ _orig_bookkeeping_sync = GPUModelRunner._bookkeeping_sync
+ _orig_sample_tokens = GPUModelRunner.sample_tokens
+ _orig_initialize_kv_cache = GPUModelRunner.initialize_kv_cache
+ # _orig_pad_for_sequence_parallelism = GPUModelRunner._pad_for_sequence_parallelism
def __init__(
self,
vllm_config: VllmConfig,
device: torch.device,
):
- # Ulysses sequence parallelism
if vllm_config.parallel_config.ulysses_sequence_parallel_size > 1:
self.use_ulysses = True
pass_config = vllm_config.compilation_config.pass_config
- if pass_config.enable_sequence_parallelism:
+ if pass_config.enable_sp:
raise ValueError(
"Ulysses sequence parallelism is incompatible with native "
"sequence parallelism. Set enable_sequence_parallelism "
- "to False in the pass config to use Ulysses.")
+ "to False in the pass config to use Ulysses."
+ )
else:
self.use_ulysses = False
- # Speculative decoding
- # TODO: Use "arctic" as an umbrella method that also covers the Arctic
- # Inverence version of "mlp_speculator".
- if (vllm_config.speculative_config is not None and \
- vllm_config.speculative_config.method in (
- "arctic", "suffix", "mlp_speculator")):
- # Delay the creation of the drafter until
- # after the child class has been initialized.
+ arctic_methods = ("arctic", "suffix", "mlp_speculator")
+ is_arctic_spec = (vllm_config.speculative_config is not None and
+ vllm_config.speculative_config.method in arctic_methods)
+
+ arctic_speculative_config = None
+ if is_arctic_spec:
arctic_speculative_config = vllm_config.speculative_config
vllm_config.speculative_config = None
- else:
- arctic_speculative_config = None
self._orig_init(vllm_config, device)
- # Set up speculative decoding.
- self._suffix_cache = None
- if arctic_speculative_config is not None:
- # Restore the speculative config.
+ self._suffix_cache: Optional[SuffixDecodingCache] = None
+
+ if is_arctic_spec:
self.vllm_config.speculative_config = arctic_speculative_config
self.speculative_config = arctic_speculative_config
+ self.num_spec_tokens = getattr(self.speculative_config,
+ "num_speculative_tokens", 0)
+ self.uniform_decode_query_len = 1 + self.num_spec_tokens
+
+ if not hasattr(self, "draft_token_ids_cpu") or self.draft_token_ids_cpu is None:
+ self.draft_token_ids_event = torch.Event()
+ self.draft_token_ids_copy_stream = torch.cuda.Stream()
+ self.draft_token_ids_cpu = torch.empty(
+ (self.max_num_reqs, self.num_spec_tokens),
+ dtype=torch.int64,
+ device="cpu",
+ pin_memory=self.pin_memory,
+ )
+
+ if (self.use_async_scheduling
+ and self.speculative_config.method in ("arctic", "mlp_speculator", "suffix")):
+ if not hasattr(self, "valid_sampled_token_count_cpu") or self.valid_sampled_token_count_cpu is None:
+ self.valid_sampled_token_count_event = torch.Event()
+ self.valid_sampled_token_count_copy_stream = torch.cuda.Stream()
+ self.valid_sampled_token_count_cpu = torch.empty(
+ self.max_num_reqs,
+ dtype=torch.int64,
+ device="cpu",
+ pin_memory=self.pin_memory,
+ )
if get_pp_group().is_last_rank:
- if (self.speculative_config.method == "arctic"
- or self.speculative_config.method == "mlp_speculator"):
+ if self.speculative_config.method in ("arctic", "mlp_speculator"):
self.drafter = ArcticProposer(self.vllm_config)
- elif self.speculative_config.method != "suffix":
- raise ValueError("Unknown speculative decoding method: "
- f"{self.speculative_config.method}")
-
- self.rejection_sampler = RejectionSampler()
-
- if (self.speculative_config is not None
- and self.speculative_config.enable_suffix_decoding):
- if self.speculative_config.method not in ("arctic", "suffix",
- "mlp_speculator"):
+ elif self.speculative_config.method == "suffix":
+ self.drafter = SuffixProposer()
+ else:
+ raise ValueError(f"Unknown speculative decoding method: {self.speculative_config.method}")
+
+ self.rejection_sampler = RejectionSampler(self.sampler)
+
+ if (self.speculative_config is not None and
+ getattr(self.speculative_config, "enable_suffix_decoding", False)):
+
+ if self.speculative_config.method not in arctic_methods:
raise ValueError(
"Suffix decoding is only supported with the 'arctic', "
- "'mlp_speculator' or 'suffix' spec decoding methods.")
+ "'mlp_speculator' or 'suffix' spec decoding methods."
+ )
spec_cfg = self.speculative_config
- self._suffix_cache = SuffixCache(
+ self._suffix_cache = SuffixDecodingCache(
max_tree_depth=spec_cfg.suffix_cache_max_depth,
- max_cached_requests=spec_cfg.suffix_cache_max_requests)
+ max_cached_requests=spec_cfg.suffix_cache_max_requests
+ )
+
+ # Async suffix decoding infrastructure: a dedicated CUDA stream and
+ # pinned buffer for copying sampled token IDs to CPU *without*
+ # serialising behind Arctic GPU drafting work on the default stream.
+ if self._suffix_cache is not None and self.use_async_scheduling:
+ self.suffix_copy_stream = torch.cuda.Stream()
+ self.suffix_copy_done_event = torch.Event()
+ max_gen_len = 1 + self.num_spec_tokens
+ self.suffix_sampled_ids_pinned = torch.empty(
+ (self.max_num_reqs, max_gen_len),
+ dtype=torch.int64,
+ device="cpu",
+ pin_memory=self.pin_memory,
+ )
+ # Pinned buffer for suffix merge results. Using pinned memory
+ # for H2C copies avoids the implicit default-stream
+ # synchronisation that cudaMemcpyAsync performs with pageable
+ # (non-pinned) source memory. Without this, the merge step
+ # blocks the CPU until ALL pending GPU work (including Arctic
+ # drafting) completes, destroying the async overlap.
+ self._suffix_merge_pinned = torch.zeros(
+ (self.max_num_reqs, self.num_spec_tokens),
+ dtype=torch.int64,
+ device="cpu",
+ pin_memory=self.pin_memory,
+ )
+
+ # Pre-allocated GPU buffer for the merged draft tensor.
+ # Avoids per-step F.pad / torch.zeros allocations in the async
+ # Arctic drafting path (propose_draft_token_ids + suffix merge).
+ # Shape: [max_num_reqs, num_spec_tokens], int64 (matches
+ # draft_token_ids_cpu for zero-cost _copy_draft_token_ids_to_cpu).
+ if (self.speculative_config is not None
+ and self.use_async_scheduling
+ and self.speculative_config.method
+ in ("arctic", "mlp_speculator")):
+ self._draft_merged_gpu = torch.zeros(
+ (self.max_num_reqs, self.num_spec_tokens),
+ dtype=torch.int64, device=self.device,
+ )
+
+ # Pre-allocated pinned index buffer for suffix merge overlay.
+ # Avoids per-step torch.tensor(...).pin_memory() allocations.
+ if self._suffix_cache is not None and self.use_async_scheduling:
+ self._suffix_index_pinned = torch.empty(
+ self.max_num_reqs, dtype=torch.long,
+ device="cpu", pin_memory=self.pin_memory,
+ )
+
+ # Per-request response tokens for suffix pattern building in async
+ # mode. In async scheduling, _bookkeeping_sync writes -1 placeholders
+ # to token_ids_cpu instead of real values, corrupting the pattern that
+ # propose_suffix_draft_token_ids reads. We keep a clean copy here.
+ self._suffix_response_tokens: dict[str, list[int]] = {}
+
+ # Actual draft lengths per request from the previous step. Used
+ # by execute_model to trim the scheduler's spec token allocation
+ # down to the real draft width, and communicated back to the
+ # scheduler (via scheduler_output._actual_draft_lens) so
+ # _update_after_schedule can set dynamic placeholder counts.
+ self._prev_actual_draft_lens: dict[str, int] = {}
+
+ # Backup-token buffer used by suffix-only async rejection sampling.
+ # The arctic proposer has its own buffer; this one covers the case
+ # where no arctic drafter is present.
+ self._suffix_backup_tokens_gpu: Optional[torch.Tensor] = None
+ if (self._suffix_cache is not None
+ and self.use_async_scheduling
+ and self.speculative_config.method not in ("arctic",
+ "mlp_speculator")):
+ self._suffix_backup_tokens_gpu = torch.zeros(
+ self.max_num_reqs, dtype=torch.int32, device=self.device,
+ )
+
+ def _suffix_only_rejection_sample(
+ self,
+ sampled_token_ids: torch.Tensor,
+ common_attn_metadata: "CommonAttentionMetadata",
+ ) -> None:
+ """Rejection-sample accepted tokens for suffix-only async scheduling.
+
+ EAGLE / arctic do this inside propose_draft_token_ids via
+ prepare_next_token_ids_padded. For suffix-only there is no
+ drafter with that method, so we call the same Triton kernel
+ directly and feed the results into _copy_valid_sampled_token_count.
+ """
+ from vllm.triton_utils import triton
+ from vllm.v1.spec_decode.utils import (
+ eagle_prepare_next_token_padded_kernel,
+ )
+
+ num_reqs = self.input_batch.num_reqs
+ batch_size, num_tokens = sampled_token_ids.shape
+ device = sampled_token_ids.device
+
+ # Compute backup tokens (last accepted token per request) on CPU,
+ # then copy to GPU in one shot to avoid per-element synchronisation.
+ backup = self._suffix_backup_tokens_gpu
+ assert backup is not None
+ backup_np = np.empty(num_reqs, dtype=np.int32)
+ for i in range(num_reqs):
+ req_id = self.input_batch.req_ids[i]
+ seq_len = int(common_attn_metadata.seq_lens_cpu[i].item())
+ backup_np[i] = self.requests[req_id].get_token_id(seq_len)
+ # Copy directly from CPU numpy-backed tensor to GPU; avoids
+ # creating an intermediate GPU tensor via .to(device).
+ backup[:num_reqs].copy_(
+ torch.from_numpy(backup_np), non_blocking=True,
+ )
+
+ next_token_ids = torch.empty(batch_size, dtype=torch.int32,
+ device=device)
+ valid_counts = torch.empty(batch_size, dtype=torch.int32,
+ device=device)
+
+ BLOCK_SIZE_TOKENS = triton.next_power_of_2(num_tokens)
+ eagle_prepare_next_token_padded_kernel[(batch_size,)](
+ sampled_token_ids,
+ self.discard_request_mask.gpu,
+ backup,
+ next_token_ids,
+ valid_counts,
+ self.model_config.get_vocab_size(),
+ num_tokens,
+ batch_size,
+ sampled_token_ids.stride(0),
+ BLOCK_SIZE_TOKENS=BLOCK_SIZE_TOKENS,
+ )
+
+ self._copy_valid_sampled_token_count(next_token_ids, valid_counts)
+
+ def _build_attention_metadata(self, *args, **kwargs):
+ attn_metadata, spec_decode_common_attn_metadata = \
+ self._orig_build_attention_metadata(*args, **kwargs)
+
+ logits_indices = kwargs.get("logits_indices", None)
+ if logits_indices is not None:
+ if isinstance(attn_metadata, list):
+ for ub in attn_metadata:
+ for meta in ub.values():
+ meta.swiftkv_logits_indices = logits_indices
+ else:
+ for meta in attn_metadata.values():
+ meta.swiftkv_logits_indices = logits_indices
+
+ return attn_metadata, spec_decode_common_attn_metadata
+
+ # set padding for SP here
+ def _pad_for_sequence_parallelism(self, num_scheduled_tokens: int) -> int:
+
+ sp_size = self.parallel_config.ulysses_sequence_parallel_size
+ num_input_tokens = round_up(num_scheduled_tokens, sp_size)
+
+ #if torch.distributed.get_rank() == 0:
+ # print(f"padding num_scheduled_tokens {num_scheduled_tokens} -> num_input_tokens {num_input_tokens}")
+
+ return num_input_tokens
+
def profile_run(self) -> None:
self._orig_profile_run()
- if self.shift_model is not None:
- # Run the shift model to trigger compilation.
+ if getattr(self, "shift_model", None) is not None:
orig_model, self.model = self.model, self.shift_model
+ cc = self.vllm_config.compilation_config
+ base_ctx = cc.static_forward_context
+ shift_ctx = getattr(self, 'shift_forward_context', None)
try:
+ if shift_ctx is not None:
+ cc.static_forward_context = shift_ctx
with set_shift_parallel_mode(True):
self._dummy_run(self.max_num_tokens, is_profile=True)
finally:
self.model = orig_model
+ cc.static_forward_context = base_ctx
- def _prepare_inputs(self, *args, **kwargs):
- attn_metadata, attention_cuda_graphs, logits_indices, *rest = (
- self._orig_prepare_inputs(*args, **kwargs))
- # SwiftKV requires knowing the logits indices from inside the model
- # definition in order to early-stop the prefill tokens.
- for meta in attn_metadata.values():
- meta.swiftkv_logits_indices = logits_indices
- return attn_metadata, attention_cuda_graphs, logits_indices, *rest
def monkeypatch_forward(self: GPUModelRunner):
+ """
+ Slice the batch across Ulysses SP ranks for forward, then all-gather.
+ """
sp_size = parallel_state._SP.world_size
sp_rank = parallel_state._SP.rank_in_group
device_group = parallel_state._SP.device_group
model_forward = self.model.forward
- input_key = 'inputs_embeds' if self.is_multimodal_model else 'input_ids'
+ input_key = 'inputs_embeds' if self.supports_mm_inputs else 'input_ids'
def ulysses_forward(*args, **kwargs):
- # update inputs
input_tensor = kwargs[input_key]
positions = kwargs['positions']
- # Ulysses parameters
- N = input_tensor.shape[0]
+ N = input_tensor.shape[0]
N_ulysses = N // sp_size
N_offset = N_ulysses * sp_rank
- # narrow the input
kwargs[input_key] = input_tensor[N_offset:N_offset + N_ulysses]
kwargs['positions'] = positions[N_offset:N_offset + N_ulysses]
@@ -202,459 +398,629 @@ def ulysses_forward(*args, **kwargs):
output = model_forward(*args, **kwargs)
if output.size(0) == N_ulysses:
- # all-gather model_output
- model_output = torch.empty((N, self.hidden_size),
+ model_output = torch.empty((N, output.shape[1]),
dtype=output.dtype,
device=output.device)
torch.distributed.all_gather_into_tensor(model_output,
output,
group=device_group)
else:
- # SwiftKV models will already have all-gathered the output.
assert output.size(0) == N
model_output = output
return model_output
- self.model.forward = ulysses_forward
+ self.get_model().forward = ulysses_forward
@torch.inference_mode()
- def execute_model(
+ def _dummy_run(
self,
- scheduler_output: "SchedulerOutput",
- intermediate_tensors: Optional[IntermediateTensors] = None,
- ) -> Union[ModelRunnerOutput, IntermediateTensors]:
- self._update_states(scheduler_output)
- if not scheduler_output.total_num_scheduled_tokens:
- if not has_kv_transfer_group():
- # Return empty ModelRunnerOutput if there's no work to do.
- return EMPTY_MODEL_RUNNER_OUTPUT
-
- return self.kv_connector_no_forward(scheduler_output)
-
- # Prepare the decoder inputs.
- (attn_metadata, attention_cuda_graphs, logits_indices,
- spec_decode_metadata,
- num_scheduled_tokens_np) = (self._prepare_inputs(scheduler_output))
- num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
- use_shift_model = (self.use_ulysses and self.shift_model is not None
- and num_scheduled_tokens
- <= self.shift_parallel_threshold)
- if self.use_ulysses and not use_shift_model:
- # add padding to the batch size to make it a multiple of SP
- sp_size = self.parallel_config.ulysses_sequence_parallel_size
- num_input_tokens = round_up(num_scheduled_tokens, sp_size)
- if (self.use_cuda_graph and num_input_tokens // sp_size
- <= self.cudagraph_batch_sizes[-1]):
- num_input_tokens = self.vllm_config.pad_for_cudagraph(
- num_input_tokens // sp_size) * sp_size
- elif (self.use_cuda_graph
- and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
- # Use piecewise CUDA graphs.
- # Add padding to the batch size.
- num_input_tokens = self.vllm_config.pad_for_cudagraph(
- num_scheduled_tokens)
+ num_tokens: int,
+ cudagraph_runtime_mode: CUDAGraphMode | None = None,
+ force_attention: bool = False,
+ uniform_decode: bool = False,
+ allow_microbatching: bool = True,
+ skip_eplb: bool = False,
+ is_profile: bool = False,
+ create_mixed_batch: bool = False,
+ remove_lora: bool = True,
+ activate_lora: bool = False,
+ is_graph_capturing: bool = False,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+
+ from vllm.v1.worker.gpu_model_runner import supports_mm_encoder_only
+ if supports_mm_encoder_only(self.model):
+ return torch.tensor([]), torch.tensor([])
+
+ max_query_len = self.uniform_decode_query_len if uniform_decode else num_tokens
+ max_num_reqs = self.scheduler_config.max_num_seqs
+
+ if create_mixed_batch:
+ num_decode_tokens = min(max_num_reqs - 1, num_tokens // 2)
+ num_prefill_tokens = num_tokens - num_decode_tokens
+ num_reqs = num_decode_tokens + 1
+ num_scheduled_tokens_list = [1] * num_decode_tokens + [num_prefill_tokens]
+ max_query_len = num_prefill_tokens
+ elif uniform_decode:
+ num_reqs = min(max_num_reqs, cdiv(num_tokens, max_query_len))
+ num_scheduled_tokens_list = [max_query_len] * num_reqs
+ if num_tokens % max_query_len != 0:
+ num_scheduled_tokens_list[-1] = num_tokens % max_query_len
else:
- # Eager mode.
- # Pad tokens to multiple of tensor_parallel_size when
- # enabled collective fusion for SP
- tp_size = self.vllm_config.parallel_config.tensor_parallel_size
- if self.compilation_config.pass_config. \
- enable_sequence_parallelism and tp_size > 1:
- num_input_tokens = round_up(num_scheduled_tokens, tp_size)
- else:
- num_input_tokens = num_scheduled_tokens
-
- # Padding for DP
- num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens)
- num_input_tokens += num_pad
-
- # _prepare_inputs may reorder the batch, so we must gather multi
- # modal outputs after that to ensure the correct order
- if self.is_multimodal_model:
- # Run the multimodal encoder if any.
- self._execute_mm_encoder(scheduler_output)
- mm_embeds = self._gather_mm_embeddings(scheduler_output)
- else:
- mm_embeds = []
-
- if self.is_multimodal_model and get_pp_group().is_first_rank:
- # NOTE(woosuk): To unify token ids and soft tokens (vision
- # embeddings), we always use embeddings (rather than token ids)
- # as input to the multimodal model, even when the input is text.
- input_ids = self.input_ids[:num_scheduled_tokens]
- if mm_embeds:
- inputs_embeds = self.model.get_input_embeddings(
- input_ids, mm_embeds)
+ num_reqs = min(num_tokens, max_num_reqs)
+ min_tokens_per_req = num_tokens // num_reqs
+ num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
+ num_scheduled_tokens_list[-1] += num_tokens % num_reqs
+
+ num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32)
+ num_tokens_unpadded = int(num_scheduled_tokens.sum())
+ num_sampled_tokens = np.ones(num_reqs, dtype=np.int32)
+
+ if torch.distributed.get_rank() == 0:
+ print(f"num_tokens_unpadded: {num_tokens_unpadded}, num_reqs: {num_reqs}")
+
+ _cg_mode, batch_desc, should_ubatch, num_tokens_across_dp, _ = (
+ self._determine_batch_execution_and_padding(
+ num_tokens=num_tokens_unpadded,
+ num_reqs=num_reqs,
+ num_scheduled_tokens_np=num_scheduled_tokens,
+ max_num_scheduled_tokens=max_query_len,
+ use_cascade_attn=False,
+ allow_microbatching=allow_microbatching,
+ force_eager=is_profile or (cudagraph_runtime_mode == CUDAGraphMode.NONE),
+ force_uniform_decode=uniform_decode,
+ force_has_lora=activate_lora,
+ )
+ )
+
+ if cudagraph_runtime_mode is None:
+ cudagraph_runtime_mode = _cg_mode
+
+ num_tokens_padded = batch_desc.num_tokens
+ num_reqs_padded = batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs
+
+ from vllm.v1.worker.gpu_model_runner import maybe_create_ubatch_slices
+ ubatch_slices, ubatch_slices_padded = maybe_create_ubatch_slices(
+ should_ubatch,
+ num_scheduled_tokens,
+ num_tokens_padded,
+ num_reqs_padded,
+ self.vllm_config.parallel_config.num_ubatches,
+ )
+
+ logits_indices_cpu = np.cumsum(num_scheduled_tokens) - 1
+ logits_indices = torch.from_numpy(logits_indices_cpu).to(self.device)
+
+ attn_metadata = None
+ if force_attention or cudagraph_runtime_mode == CUDAGraphMode.FULL:
+ if create_mixed_batch:
+ seq_lens_list = [1] * num_decode_tokens + [num_prefill_tokens + 1]
else:
- inputs_embeds = self.model.get_input_embeddings(input_ids)
- # TODO(woosuk): Avoid the copy. Optimize.
- self.inputs_embeds[:num_scheduled_tokens].copy_(inputs_embeds)
- inputs_embeds = self.inputs_embeds[:num_input_tokens]
- input_ids = None
- else:
- # For text-only models, we use token ids as input.
- # While it is possible to use embeddings as input just like the
- # multimodal models, it is not desirable for performance since
- # then the embedding layer is not included in the CUDA graph.
- input_ids = self.input_ids[:num_input_tokens]
- inputs_embeds = None
- if self.uses_mrope:
- positions = self.mrope_positions[:, :num_input_tokens]
- else:
- positions = self.positions[:num_input_tokens]
+ seq_lens_list = [max_query_len] * num_reqs # simplified
+
+ self.seq_lens.np[:num_reqs] = seq_lens_list
+ self.seq_lens.np[num_reqs:] = 0
+ self.seq_lens.copy_to_gpu()
+
+ cum_num_tokens, _ = self._get_cumsum_and_arange(num_scheduled_tokens)
+ self.query_start_loc.np[1 : num_reqs + 1] = cum_num_tokens
+ self.query_start_loc.copy_to_gpu()
+
+ pad_attn = (cudagraph_runtime_mode == CUDAGraphMode.FULL)
+ attn_metadata, _ = self._build_attention_metadata(
+ num_tokens=num_tokens_unpadded,
+ num_reqs=num_reqs_padded,
+ max_query_len=max_query_len,
+ ubatch_slices=ubatch_slices_padded if pad_attn else ubatch_slices,
+ for_cudagraph_capture=is_graph_capturing,
+ )
- if get_pp_group().is_first_rank:
- intermediate_tensors = None
- else:
- intermediate_tensors = self.sync_and_slice_intermediate_tensors(
- num_input_tokens, intermediate_tensors, True)
-
- # Some attention backends only support CUDA Graphs in pure decode.
- # If attention doesn't support CUDA Graphs for this batch, but we
- # compiled with full CUDA graphs, we have to skip them entirely.
- skip_cuda_graphs = self.full_cuda_graph and not attention_cuda_graphs
-
- # Run the model.
- # Use persistent buffers for CUDA graphs.
- with set_forward_context(
- attn_metadata,
- self.vllm_config,
- num_tokens=num_input_tokens,
- num_tokens_across_dp=num_tokens_across_dp,
- skip_cuda_graphs=skip_cuda_graphs,
- ):
- self.maybe_setup_kv_connector(scheduler_output)
-
- model = self.shift_model if use_shift_model else self.model
- with set_shift_parallel_mode(use_shift_model):
- model_output = model(
- input_ids=input_ids,
- positions=positions,
- intermediate_tensors=intermediate_tensors,
- inputs_embeds=inputs_embeds,
- )
+ if attn_metadata is not None:
+ if isinstance(attn_metadata, list):
+ for ub_meta in attn_metadata:
+ for meta in ub_meta.values():
+ meta.swiftkv_logits_indices = logits_indices
+ else:
+ for meta in attn_metadata.values():
+ meta.swiftkv_logits_indices = logits_indices
+
+ with self.maybe_dummy_run_with_lora(self.lora_config, num_scheduled_tokens,
+ num_sampled_tokens, activate_lora, remove_lora):
+
+ model_kwargs = self._init_model_kwargs()
+ if self.supports_mm_inputs and not self.model_config.is_encoder_decoder:
+ input_ids, inputs_embeds = self._prepare_mm_inputs(num_tokens_padded)
+ model_kwargs.update(self._dummy_mm_kwargs(num_reqs))
+ elif self.enable_prompt_embeds:
+ input_ids = None
+ inputs_embeds = self.inputs_embeds.gpu[:num_tokens_padded]
+ else:
+ input_ids = self.input_ids.gpu[:num_tokens_padded]
+ inputs_embeds = None
- self.maybe_wait_for_kv_save()
- finished_sending, finished_recving = (
- self.get_finished_kv_transfers(scheduler_output))
+ positions = self.positions.gpu[:num_tokens_padded]
+ if self.uses_mrope:
+ positions = self.mrope_positions.gpu[:, :num_tokens_padded]
- if self.use_aux_hidden_state_outputs:
- hidden_states, aux_hidden_states = model_output
- else:
- hidden_states = model_output
- aux_hidden_states = None
-
- # Broadcast PP output for external_launcher (torchrun)
- # to make sure we are synced across pp ranks
- # TODO: Support overlapping mirco-batches
- # https://github.com/vllm-project/vllm/issues/18019
- broadcast_pp_output = \
- self.parallel_config.distributed_executor_backend \
- == "external_launcher" and len(get_pp_group().ranks) > 0
- if not get_pp_group().is_last_rank:
- # For mid-pipeline stages, return the hidden states.
- if not broadcast_pp_output:
- return hidden_states
- assert isinstance(hidden_states, IntermediateTensors)
- get_pp_group().send_tensor_dict(hidden_states.tensors,
- all_gather_group=get_tp_group())
- logits = None
- else:
- if self.input_batch.pooling_params:
- return self._pool(hidden_states, num_scheduled_tokens,
- num_scheduled_tokens_np, finished_sending,
- finished_recving)
-
- sample_hidden_states = hidden_states[logits_indices]
- logits = self.model.compute_logits(sample_hidden_states, None)
- if broadcast_pp_output:
- model_output_broadcast_data = {
- "logits": logits.contiguous(),
- } if logits is not None else {}
- model_output_broadcast_data = get_pp_group().broadcast_tensor_dict(
- model_output_broadcast_data, src=len(get_pp_group().ranks) - 1)
- assert model_output_broadcast_data is not None
- logits = model_output_broadcast_data["logits"]
-
- # Apply structured output bitmasks if present
- if scheduler_output.grammar_bitmask is not None:
- self.apply_grammar_bitmask(scheduler_output, logits)
-
- # Sample the next token and get logprobs if needed.
+ intermediate_tensors = None
+ if not get_pp_group().is_first_rank:
+ if self.intermediate_tensors is None:
+ self.intermediate_tensors = self.model.make_empty_intermediate_tensors(
+ batch_size=self.max_num_tokens, dtype=self.model_config.dtype, device=self.device)
+ intermediate_tensors = self.sync_and_slice_intermediate_tensors(num_tokens_padded, None, False)
+
+ target_num_tokens = num_tokens_padded
+ if ubatch_slices_padded is not None:
+ target_num_tokens = ubatch_slices_padded[0].num_tokens
+ if num_tokens_across_dp is not None:
+ num_tokens_across_dp[:] = target_num_tokens
+
+ with self.maybe_randomize_inputs(input_ids, inputs_embeds), set_forward_context(
+ attn_metadata, self.vllm_config, num_tokens=target_num_tokens,
+ num_tokens_across_dp=num_tokens_across_dp, cudagraph_runtime_mode=cudagraph_runtime_mode,
+ batch_descriptor=batch_desc, ubatch_slices=ubatch_slices_padded):
+
+ outputs = self.model(input_ids=input_ids, positions=positions,
+ intermediate_tensors=intermediate_tensors,
+ inputs_embeds=inputs_embeds, **model_kwargs)
+
+ hidden_states = outputs[0] if self.use_aux_hidden_state_outputs else outputs
+
+ if self.speculative_config and self.speculative_config.use_eagle():
+ self.drafter.dummy_run(num_tokens, use_cudagraphs=False, is_graph_capturing=is_graph_capturing)
+
+ if not skip_eplb:
+ self.eplb_step(is_dummy=True, is_profile=is_profile)
+
+ return hidden_states, hidden_states[logits_indices]
+
+ # ------------------------------------------------------------------
+ # _sample: inline the base GPUModelRunner._sample logic here because
+ # the class is monkey-patched at runtime, making both super() and
+ # GPUModelRunner._sample(self, ...) resolve back to this method.
+ # ------------------------------------------------------------------
+ def _sample(
+ self,
+ logits: torch.Tensor | None,
+ spec_decode_metadata: SpecDecodeMetadata | None,
+ ) -> SamplerOutput:
sampling_metadata = self.input_batch.sampling_metadata
+ self.input_batch.update_async_output_token_ids()
if spec_decode_metadata is None:
- sampler_output = self.sampler(
- logits=logits,
- sampling_metadata=sampling_metadata,
- )
- else:
- # When indexing with a tensor (bonus_logits_indices), PyTorch
- # creates a new tensor with separate storage from the original
- # logits tensor. This means any in-place operations on bonus_logits
- # won't affect the original logits tensor.
- assert logits is not None
- bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
- sampler_output = self.sampler(
- logits=bonus_logits,
- sampling_metadata=sampling_metadata,
- )
- bonus_token_ids = sampler_output.sampled_token_ids
-
- # Just like `bonus_logits`, `target_logits` is a new tensor with
- # separate storage from the original `logits` tensor. Therefore,
- # it is safe to update `target_logits` in place.
- target_logits = logits[spec_decode_metadata.target_logits_indices]
- output_token_ids = self.rejection_sampler(
- spec_decode_metadata,
- None, # draft_probs
- target_logits,
- bonus_token_ids,
- sampling_metadata,
- )
- sampler_output.sampled_token_ids = output_token_ids
-
- num_nans_in_logits = {}
- if envs.VLLM_COMPUTE_NANS_IN_LOGITS:
- num_nans_in_logits = self._get_nans_in_logits(logits)
-
- # TODO(woosuk): The following loop can be slow since it iterates over
- # the requests one by one. Optimize.
- discard_sampled_tokens_req_indices = []
- for i, req_id in enumerate(self.input_batch.req_ids):
- req_state = self.requests[req_id]
- seq_len = (req_state.num_computed_tokens +
- scheduler_output.num_scheduled_tokens[req_id])
- if seq_len < req_state.num_tokens:
- # Ignore the sampled token for partial prefills.
- # Rewind the generator state as if the token was not sampled.
- # This relies on cuda-specific torch-internal impl details
- generator = self.input_batch.generators.get(i)
- if generator is not None:
- generator.set_offset(generator.get_offset() - 4)
- # Record the index of the request that should not be sampled,
- # so that we could clear the sampled tokens before returning.
- discard_sampled_tokens_req_indices.append(i)
-
- # NOTE: GPU -> CPU Sync happens here.
- # Move as many CPU operations as possible before this sync point.
- logprobs_tensors = sampler_output.logprobs_tensors
- logprobs_lists = logprobs_tensors.tolists() \
- if logprobs_tensors is not None else None
-
- # Compute prompt logprobs if needed.
- prompt_logprobs_dict = self._get_prompt_logprobs_dict(
- hidden_states[:num_scheduled_tokens],
- scheduler_output,
+ return self.sampler(
+ logits=logits, sampling_metadata=sampling_metadata)
+
+ if (self.use_async_scheduling
+ and self._draft_token_req_ids is not None):
+ draft_token_ids_cpu, _ = self._get_draft_token_ids_cpu()
+ self.input_batch.update_async_spec_token_ids(
+ draft_token_ids_cpu)
+
+ sampler_output = self.rejection_sampler(
+ spec_decode_metadata,
+ None, # draft_probs
+ logits,
+ sampling_metadata,
)
+ self._update_states_after_model_execute(
+ sampler_output.sampled_token_ids)
+ return sampler_output
- # Get the valid generated tokens.
- sampled_token_ids = sampler_output.sampled_token_ids
- max_gen_len = sampled_token_ids.shape[-1]
- if max_gen_len == 1:
- # No spec decode tokens.
- valid_sampled_token_ids = sampled_token_ids.tolist()
- else:
- # Includes spec decode tokens.
- valid_sampled_token_ids = self.rejection_sampler.parse_output(
- sampled_token_ids,
- self.input_batch.vocab_size,
+ @torch.inference_mode()
+ def execute_model(
+ self,
+ scheduler_output: "SchedulerOutput",
+ intermediate_tensors: Optional[IntermediateTensors] = None,
+ ) -> Union[
+ ModelRunnerOutput, AsyncGPUModelRunnerOutput, IntermediateTensors
+ ]:
+ num_scheduled_tokens = getattr(scheduler_output, "total_num_scheduled_tokens", None)
+ if num_scheduled_tokens is None:
+ try:
+ num_scheduled_tokens = int(
+ sum(scheduler_output.num_scheduled_tokens.values())
+ )
+ except Exception:
+ num_scheduled_tokens = 0
+
+ use_shift_model = (
+ getattr(self, "use_ulysses", False)
+ and getattr(self, "shift_model", None) is not None
+ and num_scheduled_tokens <= int(getattr(self, "shift_parallel_threshold", 0))
+ )
+
+ if not use_shift_model:
+ return self._orig_execute_model(scheduler_output, intermediate_tensors)
+
+ orig_model = self.model
+ cc = self.vllm_config.compilation_config
+ base_ctx = cc.static_forward_context
+ shift_ctx = getattr(self, 'shift_forward_context', None)
+ try:
+ self.model = self.shift_model
+ if shift_ctx is not None:
+ cc.static_forward_context = shift_ctx
+ with set_shift_parallel_mode(True), \
+ self._use_shift_cudagraph_tables():
+ result = self._orig_execute_model(scheduler_output, intermediate_tensors)
+ finally:
+ self.model = orig_model
+ cc.static_forward_context = base_ctx
+ return result
+
+ @torch.inference_mode
+ def sample_tokens(self, grammar_output):
+ """Wrapper around base sample_tokens for arctic async spec decode.
+
+ Saves execute_model_state before the base clears it, then handles
+ the 'not-fits-in-drafter' case that the base only handles for Eagle.
+ """
+ _arctic_saved_state = None
+ if (self.execute_model_state is not None
+ and self.speculative_config is not None
+ and self.speculative_config.method
+ in ("arctic", "mlp_speculator", "suffix")
+ and self.use_async_scheduling):
+ _arctic_saved_state = (
+ self.execute_model_state.scheduler_output,
+ self.execute_model_state.spec_decode_common_attn_metadata,
)
- # Mask out the sampled tokens that should not be sampled.
- for i in discard_sampled_tokens_req_indices:
- valid_sampled_token_ids[i].clear()
- # Cache the sampled tokens in the model runner, so that the scheduler
- # doesn't need to send them back.
- # NOTE(woosuk): As an exception, when using PP, the scheduler sends
- # the sampled tokens back, because there's no direct communication
- # between the first-stage worker and the last-stage worker.
- for req_idx, sampled_ids in enumerate(valid_sampled_token_ids):
- if not sampled_ids:
- continue
+ result = self._orig_sample_tokens(grammar_output)
- start_idx = self.input_batch.num_tokens_no_spec[req_idx]
- end_idx = start_idx + len(sampled_ids)
- assert end_idx <= self.max_model_len, (
- "Sampled token IDs exceed the max model length. "
- f"Total number of tokens: {end_idx} > max_model_len: "
- f"{self.max_model_len}")
-
- self.input_batch.token_ids_cpu[req_idx,
- start_idx:end_idx] = sampled_ids
- self.input_batch.num_tokens_no_spec[req_idx] = end_idx
- self.input_batch.num_tokens[req_idx] = end_idx
- req_id = self.input_batch.req_ids[req_idx]
- req_state = self.requests[req_id]
- req_state.output_token_ids.extend(sampled_ids)
+ # If _arctic_async_sampled_tensor was stashed by _bookkeeping_sync
+ # but never consumed by propose_draft_token_ids, this is the
+ # not-fits-in-drafter case. Mirror Eagle's handling: call
+ # _copy_valid_sampled_token_count and set draft tokens to zeros.
+ stashed = getattr(self, '_arctic_async_sampled_tensor', None)
+ if stashed is not None:
+ del self._arctic_async_sampled_tensor
+ if _arctic_saved_state is not None:
+ scheduler_output, common_attn_meta = _arctic_saved_state
+ self._arctic_handle_not_fits(
+ stashed, scheduler_output, common_attn_meta)
- if self._suffix_cache is not None:
- self._update_suffix_cache(valid_sampled_token_ids)
+ return result
- if not self.speculative_config:
- # Speculative decoding is not enabled.
- spec_token_ids = None
+ def _arctic_handle_not_fits(
+ self,
+ sampled_token_ids: torch.Tensor,
+ scheduler_output: "SchedulerOutput",
+ common_attn_metadata,
+ ) -> None:
+ """Mirror Eagle's not-fits-in-drafter path for arctic async.
+
+ When the input is too long for the drafter but spec decode is
+ active, Eagle still calls prepare_next_token_ids_padded /
+ _copy_valid_sampled_token_count and sets draft tokens to zeros.
+ Without this, _get_valid_sampled_token_count returns stale counts
+ and _prepare_input_ids scatters -1 placeholders into the
+ embedding layer.
+ """
+ if (hasattr(self, 'drafter')
+ and hasattr(self.drafter, 'prepare_next_token_ids_padded')
+ and common_attn_metadata is not None):
+ next_token_ids, valid_sampled_tokens_count = (
+ self.drafter.prepare_next_token_ids_padded(
+ common_attn_metadata,
+ sampled_token_ids,
+ self.requests,
+ self.input_batch,
+ self.discard_request_mask.gpu,
+ )
+ )
+ self._copy_valid_sampled_token_count(
+ next_token_ids, valid_sampled_tokens_count)
else:
- spec_token_ids = self.propose_draft_token_ids(
- scheduler_output,
- valid_sampled_token_ids,
- sampler_output.sampled_token_ids,
- sampling_metadata,
- hidden_states,
- sample_hidden_states,
- aux_hidden_states,
- spec_decode_metadata,
- attn_metadata,
- )
-
- # Clear KVConnector state after all KVs are generated.
- if has_kv_transfer_group():
- get_kv_transfer_group().clear_connector_metadata()
-
- self.eplb_step()
-
- return ModelRunnerOutput(
- req_ids=self.input_batch.req_ids,
- req_id_to_index=self.input_batch.req_id_to_index,
- sampled_token_ids=valid_sampled_token_ids,
- spec_token_ids=spec_token_ids,
- logprobs=logprobs_lists,
- prompt_logprobs_dict=prompt_logprobs_dict,
- pooler_output=[],
- finished_sending=finished_sending,
- finished_recving=finished_recving,
- num_nans_in_logits=num_nans_in_logits,
- )
+ # Fallback for drafters without prepare_next_token_ids_padded
+ # (e.g. suffix-only). Compute valid counts with PyTorch ops.
+ mask = sampled_token_ids != -1
+ valid_counts = mask.sum(dim=1)
+ batch_size = sampled_token_ids.shape[0]
+ col_indices = torch.arange(
+ sampled_token_ids.shape[1],
+ device=sampled_token_ids.device,
+ ).unsqueeze(0).expand_as(sampled_token_ids)
+ last_valid_col = (
+ col_indices.masked_fill(~mask, -1).max(dim=1).values)
+ last_valid_col = last_valid_col.clamp(min=0)
+ next_token_ids = sampled_token_ids[
+ torch.arange(batch_size,
+ device=sampled_token_ids.device),
+ last_valid_col,
+ ]
+ self._copy_valid_sampled_token_count(
+ next_token_ids, valid_counts)
+
+ # Zero draft tokens -- same as Eagle's not-fits path.
+ self._draft_token_ids = torch.zeros(
+ 1, device=self.device, dtype=torch.int32,
+ ).expand(len(self.input_batch.req_ids), self.num_spec_tokens)
+ self._copy_draft_token_ids_to_cpu(
+ scheduler_output, zeros_only=True)
+
+ def _bookkeeping_sync(
+ self,
+ scheduler_output: "SchedulerOutput",
+ sampler_output: SamplerOutput,
+ logits: torch.Tensor | None,
+ hidden_states: torch.Tensor,
+ num_scheduled_tokens: int,
+ spec_decode_metadata: SpecDecodeMetadata | None,
+ ):
+ """Wrap base _bookkeeping_sync to handle arctic async spec decode.
+
+ In the base vLLM code, only Eagle-style drafters run *before*
+ bookkeeping (setting prev_sampled_token_ids via
+ _copy_valid_sampled_token_count). Arctic/suffix drafting runs
+ *after* bookkeeping, so prev_sampled_token_ids is still None
+ when bookkeeping checks ``assert sampled_token_ids.shape[-1] == 1``.
+
+ We fix this by:
+ 1. Saving the GPU sampled tensor for propose_draft_token_ids.
+ 2. Setting prev_sampled_token_ids to a placeholder so the
+ assertion is skipped. The real value will be written by
+ _copy_valid_sampled_token_count inside propose_draft_token_ids
+ (fits case) or sample_tokens (not-fits case).
+ """
+ sampled_token_ids = sampler_output.sampled_token_ids
+ if (self.use_async_scheduling
+ and self.speculative_config is not None
+ and self.speculative_config.method
+ in ("arctic", "mlp_speculator", "suffix")
+ and spec_decode_metadata is not None
+ and sampled_token_ids.shape[-1] > 1
+ and self.input_batch.prev_sampled_token_ids is None):
+ # Stash the full GPU tensor so propose_draft_token_ids can
+ # pick it up later (it normally only receives an empty list
+ # in the post-bookkeeping path).
+ self._arctic_async_sampled_tensor = sampled_token_ids
+ # Placeholder: first column only (bonus token per request).
+ # Prevents the assertion from firing; the correct value will
+ # be overwritten by _copy_valid_sampled_token_count shortly.
+ self.input_batch.prev_sampled_token_ids = (
+ sampled_token_ids[:, :1])
+
+ return self._orig_bookkeeping_sync(
+ scheduler_output, sampler_output, logits, hidden_states,
+ num_scheduled_tokens, spec_decode_metadata)
def propose_draft_token_ids(
self,
scheduler_output: "SchedulerOutput",
- sampled_token_ids: list[list[int]],
- original_sampled_token_ids: np.ndarray,
+ sampled_token_ids: torch.Tensor | list[list[int]],
sampling_metadata: SamplingMetadata,
hidden_states: torch.Tensor,
sample_hidden_states: torch.Tensor,
- aux_hidden_states: Optional[torch.Tensor],
- spec_decode_metadata: Optional[SpecDecodeMetadata],
- attn_metadata: dict[str, Any],
- ) -> list[list[int]]:
- disable_spec_decode = (self.speculative_config and
- self.speculative_config.disable_by_batch_size
- and len(self.input_batch.req_ids)
- > self.speculative_config.disable_by_batch_size)
- if disable_spec_decode:
- # No speculative decoding is enabled.
- return [[] for _ in sampled_token_ids]
+ aux_hidden_states: list[torch.Tensor] | None,
+ spec_decode_metadata: SpecDecodeMetadata | None,
+ common_attn_metadata: CommonAttentionMetadata,
+ ) -> list[list[int]] | torch.Tensor:
+ # In async mode, the base vLLM dispatches arctic to the
+ # post-bookkeeping path which passes valid_sampled_token_ids
+ # (an empty list for async). Recover the stashed GPU tensor
+ # so the fast async drafting path below can activate.
+ if (isinstance(sampled_token_ids, list)
+ and len(sampled_token_ids) == 0
+ and hasattr(self, '_arctic_async_sampled_tensor')):
+ sampled_token_ids = self._arctic_async_sampled_tensor
+ del self._arctic_async_sampled_tensor
+
+ # Compute the maximum number of requests to draft for.
+ # When disable_by_batch_size is set and the batch exceeds it,
+ # we still draft for the first N requests instead of disabling
+ # entirely. This avoids the stale-data crash that occurs when
+ # drafting is fully disabled one step and re-enabled the next.
+ batch_size = len(self.input_batch.req_ids)
+ draft_limit = batch_size # default: draft for all
+ if (
+ self.speculative_config
+ and self.speculative_config.disable_by_batch_size
+ and batch_size > self.speculative_config.disable_by_batch_size
+ ):
+ draft_limit = self.speculative_config.disable_by_batch_size
+
+ use_async_path = (
+ self.speculative_config.method in ("arctic", "mlp_speculator")
+ and isinstance(sampled_token_ids, torch.Tensor)
+ and self.use_async_scheduling
+ and common_attn_metadata is not None
+ )
+
+ if use_async_path:
+ assert isinstance(sampled_token_ids, torch.Tensor)
+
+ next_token_ids, valid_sampled_tokens_count = (
+ self.drafter.prepare_next_token_ids_padded(
+ common_attn_metadata,
+ sampled_token_ids,
+ self.requests,
+ self.input_batch,
+ self.discard_request_mask.gpu,
+ )
+ )
+ self._copy_valid_sampled_token_count(
+ next_token_ids, valid_sampled_tokens_count
+ )
+
+ target_hidden_states = self.drafter.prepare_hidden_states(
+ sample_hidden_states=sample_hidden_states,
+ sampled_token_ids=sampled_token_ids,
+ spec_decode_metadata=spec_decode_metadata,
+ )
+
+ # Only draft for the first draft_limit requests.
+ raw_draft = self.drafter.propose(
+ context_token_ids=next_token_ids[:draft_limit],
+ previous_hidden_states=target_hidden_states[:draft_limit],
+ num_predict_tokens=self.drafter.model.n_predict,
+ )
+ # Use pre-allocated GPU buffer when available. This avoids
+ # per-step F.pad + torch.zeros allocations. The buffer is
+ # [max_num_reqs, num_spec_tokens] so a single zero_() +
+ # copy_() handles both width and batch padding in one shot.
+ merged_buf = getattr(self, '_draft_merged_gpu', None)
+ if merged_buf is not None:
+ draft = merged_buf[:batch_size]
+ draft.zero_()
+ rd_rows, rd_cols = raw_draft.shape
+ draft[:rd_rows, :rd_cols].copy_(raw_draft)
+ else:
+ draft = raw_draft
+ if draft.shape[1] < self.num_spec_tokens:
+ draft = torch.nn.functional.pad(
+ draft,
+ (0, self.num_spec_tokens - draft.shape[1]),
+ value=0,
+ )
+ if draft_limit < batch_size:
+ full_draft = torch.zeros(
+ batch_size, draft.shape[1],
+ dtype=draft.dtype, device=draft.device,
+ )
+ full_draft[:draft_limit] = draft
+ draft = full_draft
+ return draft
+
+ if isinstance(sampled_token_ids, torch.Tensor):
+ vocab_size = self.model_config.get_vocab_size()
+ sampled_token_ids_list = [
+ [t for t in seq if t != -1 and t < vocab_size]
+ for seq in sampled_token_ids.tolist()
+ ]
+ sampled_token_ids_tensor = sampled_token_ids
+ else:
+ sampled_token_ids_list = sampled_token_ids
+ sampled_token_ids_tensor = None
+
+ arctic_spec_token_ids = None
suffix_spec_token_ids = None
- new_sampled_token_ids = sampled_token_ids.copy()
+
+ if self.speculative_config.method in ("arctic", "mlp_speculator"):
+ if sampled_token_ids_tensor is None:
+ import numpy as np
+ sampled_token_ids_tensor = torch.tensor(sampled_token_ids_list, device=self.device)
+
+ previous_hidden_states = self.drafter.prepare_hidden_states(
+ sample_hidden_states=sample_hidden_states,
+ sampled_token_ids=sampled_token_ids_tensor,
+ spec_decode_metadata=spec_decode_metadata,
+ )
+
+ next_token_ids = self.drafter.prepare_next_token_ids_cpu(
+ sampled_token_ids_list,
+ self.requests,
+ self.input_batch,
+ scheduler_output.num_scheduled_tokens,
+ )
+
+ # Only draft for the first draft_limit requests.
+ arctic_output_tensor = self.drafter.propose(
+ context_token_ids=next_token_ids[:draft_limit],
+ previous_hidden_states=previous_hidden_states[:draft_limit],
+ num_predict_tokens=self.drafter.model.n_predict,
+ )
+
+ arctic_spec_token_ids = arctic_output_tensor.tolist()
+ # Pad with empty lists for requests beyond draft_limit.
+ if draft_limit < batch_size:
+ arctic_spec_token_ids.extend(
+ [] for _ in range(batch_size - draft_limit)
+ )
+
if self._suffix_cache is not None:
- results = self.propose_suffix_draft_token_ids(
- new_sampled_token_ids)
+ self._update_suffix_cache(sampled_token_ids_list)
+ results = self.propose_suffix_draft_token_ids(sampled_token_ids_list)
+
suffix_spec_token_ids = []
- # The score is an estimate of the acceptance length. Thus, the
- # heuristic is to use the suffix decoded tokens if the score is
- # greater than the # of tokens we would speculate otherwise.
- min_score = (self.speculative_config.num_speculative_tokens
- if self.speculative_config.method != "suffix" else 0)
- min_score = (0 if self.speculative_config.method == "suffix" else
- self.speculative_config.num_speculative_tokens)
- for i, result in enumerate(results):
+ min_score = 0 if self.speculative_config.method == "suffix" \
+ else self.drafter.model.n_predict
+
+ for result in results:
if result.score >= min_score:
- # Use suffix decoded tokens, disable other speculation
- # methods for this request.
- new_sampled_token_ids[i] = []
suffix_spec_token_ids.append(result.token_ids)
else:
suffix_spec_token_ids.append([])
spec_token_ids = None
- if self.speculative_config.method == "suffix":
- pass
- elif (self.speculative_config.method == "arctic"
- or self.speculative_config.method == "mlp_speculator"):
- assert isinstance(self.drafter, ArcticProposer)
- previous_hidden_states = self.drafter.prepare_hidden_states(
- sample_hidden_states=sample_hidden_states,
- sampled_token_ids=original_sampled_token_ids,
- spec_decode_metadata=spec_decode_metadata,
- )
- spec_token_ids = self.propose_arctic_draft_token_ids(
- scheduler_output,
- new_sampled_token_ids,
- previous_hidden_states=previous_hidden_states)
+ if suffix_spec_token_ids is not None and arctic_spec_token_ids is not None:
+ spec_token_ids = [
+ s_tokens if s_tokens else a_tokens
+ for s_tokens, a_tokens in zip(suffix_spec_token_ids, arctic_spec_token_ids)
+ ]
+ elif suffix_spec_token_ids is not None:
+ spec_token_ids = suffix_spec_token_ids
+ elif arctic_spec_token_ids is not None:
+ spec_token_ids = arctic_spec_token_ids
else:
spec_token_ids = self._orig_propose_draft_token_ids(
scheduler_output,
- new_sampled_token_ids,
+ sampled_token_ids_list,
sampling_metadata,
hidden_states,
sample_hidden_states,
aux_hidden_states,
spec_decode_metadata,
- attn_metadata,
+ common_attn_metadata,
)
if spec_token_ids is None:
- spec_token_ids = suffix_spec_token_ids
- elif suffix_spec_token_ids is not None:
- spec_token_ids = [
- suffix_spec_token_ids[i] or spec_token_ids[i]
- for i in range(len(suffix_spec_token_ids))
- ]
-
- return spec_token_ids
-
- def propose_arctic_draft_token_ids(
- self,
- scheduler_output: "SchedulerOutput",
- sampled_token_ids: list[list[int]],
- previous_hidden_states: Optional[torch.Tensor] = None,
- ) -> list[list[int]]:
- last_tokens: list[int] = []
- max_spec_tokens = self.speculative_config.num_speculative_tokens
- for i, sampled_ids in enumerate(sampled_token_ids):
- num_sampled_ids = len(sampled_ids)
-
- if (num_sampled_ids == 0):
- if self.speculative_config.enable_suffix_decoding:
- return [[]] * len(sampled_token_ids)
- req_id = self.input_batch.req_ids[i]
- req_state = self.requests[req_id]
- seq_len = (req_state.num_computed_tokens +
- scheduler_output.num_scheduled_tokens[req_id])
- sampled_ids = [req_state.get_token_id(seq_len)]
-
- # Add sampled_token_ids to token_ids_cpu.
- start_idx = self.input_batch.num_tokens_no_spec[i]
- end_idx = start_idx + num_sampled_ids
-
- max_spec_tokens = min(
- max_spec_tokens,
- self.max_model_len - end_idx - 1,
+ spec_token_ids = [[] for _ in range(len(self.input_batch.req_ids))]
+
+ # For async scheduling the base _prepare_input_ids asserts that
+ # _draft_token_ids is a torch.Tensor and uses it to scatter draft
+ # tokens into the input. If we reached here (non-async code path)
+ # while async scheduling is active we must:
+ # 1. Convert the list-of-lists draft tokens to a padded tensor.
+ # 2. Call _copy_valid_sampled_token_count so the next step's
+ # _get_valid_sampled_token_count returns correct counts.
+ if (self.use_async_scheduling
+ and isinstance(spec_token_ids, list)
+ and isinstance(sampled_token_ids, torch.Tensor)):
+ # --- _copy_valid_sampled_token_count ---
+ if (hasattr(self, 'drafter')
+ and hasattr(self.drafter, 'prepare_next_token_ids_padded')
+ and common_attn_metadata is not None):
+ next_tok, valid_cnt = (
+ self.drafter.prepare_next_token_ids_padded(
+ common_attn_metadata,
+ sampled_token_ids,
+ self.requests,
+ self.input_batch,
+ self.discard_request_mask.gpu,
+ ))
+ self._copy_valid_sampled_token_count(next_tok, valid_cnt)
+ else:
+ # Manual fallback (suffix-only drafter).
+ mask = sampled_token_ids != -1
+ valid_cnt = mask.sum(dim=1)
+ _bs = sampled_token_ids.shape[0]
+ cols = torch.arange(
+ sampled_token_ids.shape[1],
+ device=sampled_token_ids.device,
+ ).unsqueeze(0).expand_as(sampled_token_ids)
+ last_col = cols.masked_fill(~mask, -1).max(dim=1).values
+ last_col = last_col.clamp(min=0)
+ next_tok = sampled_token_ids[
+ torch.arange(_bs, device=sampled_token_ids.device),
+ last_col,
+ ]
+ self._copy_valid_sampled_token_count(next_tok, valid_cnt)
+
+ # --- Convert list[list[int]] -> padded tensor ---
+ padded = torch.zeros(
+ batch_size, self.num_spec_tokens,
+ dtype=torch.int32, device=self.device,
)
- if max_spec_tokens <= 0:
- continue
-
- self.input_batch.token_ids_cpu[i,
- start_idx:end_idx] = sampled_ids[-1]
- last_tokens.append(self.input_batch.token_ids_cpu[i, end_idx - 1])
+ for i, tokens in enumerate(spec_token_ids):
+ length = min(len(tokens), self.num_spec_tokens)
+ if length > 0:
+ padded[i, :length] = torch.tensor(
+ tokens[:length], dtype=torch.int32,
+ device=self.device)
+ spec_token_ids = padded
- if max_spec_tokens <= 0:
- return [[] for _ in sampled_token_ids]
-
- drafter_output = self.drafter.propose(
- last_tokens,
- previous_hidden_states=previous_hidden_states,
- num_predict_tokens=max_spec_tokens,
- )
-
- draft_token_ids = drafter_output.tolist()
-
- for i, sampled_ids in enumerate(sampled_token_ids):
- if not sampled_ids:
- draft_token_ids[i] = []
-
- return draft_token_ids
+ return spec_token_ids
def _update_suffix_cache(self, sampled_token_ids: list[list[int]]) -> None:
seen_req_ids = set()
@@ -666,207 +1032,842 @@ def _update_suffix_cache(self, sampled_token_ids: list[list[int]]) -> None:
continue
index = self.input_batch.req_id_to_index[req_id]
- if req_id not in self._suffix_cache.active_requests:
+ is_new = req_id not in self._suffix_cache.active_requests
+ if is_new:
if req_id in self._suffix_cache.cached_requests:
- # Reset the suffix cache for this request.
- self._suffix_cache.evict_request(req_id)
+ self._suffix_cache.evict_cached_response(req_id)
num_prompt_tokens = self.input_batch.num_prompt_tokens[index]
- prompt_token_ids = (
- self.input_batch.token_ids_cpu[index, :num_prompt_tokens])
- self._suffix_cache.start_request(req_id, prompt_token_ids)
+ prompt_token_ids = self.input_batch.token_ids_cpu[index, :num_prompt_tokens]
+ self._suffix_cache.start_request(req_id, prompt_token_ids.tolist())
+ self._suffix_response_tokens[req_id] = []
self._suffix_cache.add_active_response(req_id, sampled_ids)
+ self._suffix_response_tokens[req_id].extend(sampled_ids)
- # Stop requests that are not seen
+ stopped_ids = []
for req_id in list(self._suffix_cache.active_requests):
if req_id not in seen_req_ids:
self._suffix_cache.stop_request(req_id)
+ self._suffix_response_tokens.pop(req_id, None)
+ stopped_ids.append(req_id)
def propose_suffix_draft_token_ids(
self,
sampled_token_ids: list[list[int]],
- spec_token_ids: Optional[list[list[int]]] = None,
- ) -> list[list[int]]:
+ ) -> list[SuffixDecodingDraft]:
config = self.speculative_config
results = []
for i, sampled_ids in enumerate(sampled_token_ids):
- spec_ids = spec_token_ids[i] if spec_token_ids is not None else []
num_sampled_ids = len(sampled_ids)
if not num_sampled_ids:
- # Skip speculative decoding.
- results.append(SuffixSpecResult())
+ results.append(SuffixDecodingDraft())
continue
req_id = self.input_batch.req_ids[i]
+ index = self.input_batch.req_id_to_index[req_id]
- # Add sampled_token_ids to token_ids_cpu.
- start_idx = self.input_batch.num_tokens_no_spec[i]
- end_idx = start_idx + len(sampled_ids)
-
- if end_idx >= self.max_model_len:
- results.append(SuffixSpecResult())
- self.input_batch.token_ids_cpu[
- i, start_idx:self.
- max_model_len] = sampled_ids[:self.max_model_len -
- start_idx]
- continue
-
- self.input_batch.token_ids_cpu[i, start_idx:end_idx] = sampled_ids
-
- size = min(end_idx, config.suffix_cache_max_depth)
- pattern = self.input_batch.token_ids_cpu[i, end_idx - size:end_idx]
- pattern = pattern.tolist() + spec_ids
- if len(pattern) > config.suffix_cache_max_depth:
- pattern = pattern[-config.suffix_cache_max_depth:]
- max_spec_tokens = min(MAX_SPEC_LEN - len(spec_ids),
- config.suffix_cache_max_depth,
- self.max_model_len - end_idx - 1)
- # max_spec_offset is modified to mimic the behavior of the original
- # max_spec_factor and max_spec_offset as if the speculative tokens
- # were generated by suffix decoding. For example, if:
- # - max_spec_factor = 2
- # - max_spec_offset = -1
- # - we've already speculated 3 tokens
- # - and the suffix match length is 6
- # Then:
- # - The match length before the already-speculated tokens is 3
- # - The original config allow up to 5 speculated tokens total
- # - Already speculated 3 tokens, so should allow 2 more tokens
- # So the new config should map match length 6 to 2 max spec tokens.
- max_spec_factor = config.suffix_max_spec_factor
- max_spec_offset = (config.suffix_max_spec_offset - len(spec_ids) *
- (max_spec_factor + 1))
+ # In async mode, token_ids_cpu contains -1 placeholders at
+ # decoded positions (written by _bookkeeping_sync). Build the
+ # pattern from the clean _suffix_response_tokens instead.
+ if (self.use_async_scheduling
+ and req_id in self._suffix_response_tokens):
+ response = self._suffix_response_tokens[req_id]
+ num_prompt = int(
+ self.input_batch.num_prompt_tokens[index])
+ num_tokens = num_prompt + len(response)
+ if num_tokens >= self.max_model_len:
+ results.append(SuffixDecodingDraft())
+ continue
+ # Take up to suffix_cache_max_depth tokens from the tail.
+ depth = config.suffix_cache_max_depth
+ if len(response) >= depth:
+ pattern = response[-depth:]
+ else:
+ need = depth - len(response)
+ prompt_start = max(0, num_prompt - need)
+ prompt_part = self.input_batch.token_ids_cpu[
+ index, prompt_start:num_prompt].tolist()
+ pattern = prompt_part + response
+ else:
+ num_tokens = self.input_batch.num_tokens_no_spec[i]
+ if num_tokens >= self.max_model_len:
+ results.append(SuffixDecodingDraft())
+ continue
+ start = max(0, num_tokens - config.suffix_cache_max_depth)
+ pattern = self.input_batch.token_ids_cpu[
+ i, start:num_tokens].tolist()
+
+ max_spec = min(
+ MAX_SPEC_LEN, self.max_model_len - num_tokens - 1
+ )
result = self._suffix_cache.speculate(
req_id,
pattern,
- max_spec_tokens=max_spec_tokens,
- max_spec_factor=max_spec_factor,
- max_spec_offset=max_spec_offset,
- min_token_prob=config.suffix_min_token_prob)
+ max_spec_tokens=max_spec,
+ max_spec_factor=config.suffix_max_spec_factor,
+ max_spec_offset=config.suffix_max_spec_offset,
+ min_token_prob=config.suffix_min_token_prob,
+ )
results.append(result)
return results
- def load_model(self) -> None:
+
+ def _start_suffix_copy(
+ self,
+ sampled_token_ids: torch.Tensor,
+ ) -> None:
+ """Initiate an async D2H copy of sampled token IDs for suffix decoding.
+
+ Copies ``sampled_token_ids`` to a pinned CPU buffer on a dedicated
+ CUDA stream (``suffix_copy_stream``) that only waits for prior work
+ on the default stream (i.e. sampling).
+
+ **This MUST be called BEFORE launching Arctic GPU work on the default
+ stream** so that the copy is not ordered behind Arctic kernels. The
+ resulting timeline is::
+
+ Default stream: [sample] -> [arctic prepare / propose ...]
+ Suffix stream : [sample] -> [D2H copy] -> [event]
+ CPU : wait -> suffix logic
+
+ The companion ``_finish_suffix_copy`` synchronises on the copy event
+ and returns the materialised CPU list.
+ """
+ n_rows = sampled_token_ids.shape[0]
+ n_cols = sampled_token_ids.shape[-1]
+ default_stream = torch.cuda.current_stream()
+ with torch.cuda.stream(self.suffix_copy_stream):
+ self.suffix_copy_stream.wait_stream(default_stream)
+ self.suffix_sampled_ids_pinned[:n_rows, :n_cols].copy_(
+ sampled_token_ids, non_blocking=True,
+ )
+ self.suffix_copy_done_event.record()
+ self._suffix_copy_shape = (n_rows, n_cols)
+
+ def _finish_suffix_copy(self) -> list[list[int]]:
+ """Wait for the suffix copy and return sampled token IDs as CPU lists.
+
+ Synchronises on the copy event recorded by ``_start_suffix_copy``,
+ then parses the pinned buffer into ``list[list[int]]``, applying
+ rejection-sampling for spec-decode batches and masking discarded
+ (still-in-prefill) requests.
+ """
+ self.suffix_copy_done_event.synchronize()
+ n_rows, n_cols = self._suffix_copy_shape
+ pinned = self.suffix_sampled_ids_pinned[:n_rows, :n_cols]
+
+ num_reqs = self.input_batch.num_reqs
+ discard_indices = np.nonzero(
+ self.discard_request_mask.np[:num_reqs]
+ )[0]
+
+ if n_cols == 1:
+ result = pinned.tolist()
+ for i in discard_indices:
+ result[int(i)].clear()
+ else:
+ result, _ = RejectionSampler.parse_output(
+ pinned,
+ self.input_batch.vocab_size,
+ discard_indices,
+ )
+
+ return result
+
+ @torch.inference_mode
+ def sample_tokens(
+ self, grammar_output: "GrammarOutput | None"
+ ) -> Union[ModelRunnerOutput, AsyncGPUModelRunnerOutput, IntermediateTensors]:
+ kv_connector_output = self.kv_connector_output
+ self.kv_connector_output = None
+
+ if self.execute_model_state is None:
+ # Nothing to do (PP non-final rank case), output isn't used.
+ if not kv_connector_output:
+ return None # noqa
+
+ # In case of PP with kv transfer, we need to pass through the
+ # kv_connector_output
+ if kv_connector_output.is_empty():
+ return EMPTY_MODEL_RUNNER_OUTPUT
+
+ output = copy(EMPTY_MODEL_RUNNER_OUTPUT)
+ output.kv_connector_output = kv_connector_output
+ return output
+
+ # Unpack ephemeral state.
+ (
+ scheduler_output,
+ logits,
+ spec_decode_metadata,
+ spec_decode_common_attn_metadata,
+ hidden_states,
+ sample_hidden_states,
+ aux_hidden_states,
+ ec_connector_output,
+ cudagraph_stats,
+ ) = self.execute_model_state
+ # Clear ephemeral state.
+ self.execute_model_state = None
+
+ # Apply structured output bitmasks if present.
+ if grammar_output is not None:
+ apply_grammar_bitmask(
+ scheduler_output, grammar_output, self.input_batch, logits
+ )
+
+ with record_function_or_nullcontext("gpu_model_runner: sample"):
+ sampler_output = self._sample(logits, spec_decode_metadata)
+
+ self._draft_token_ids = None
+ self._draft_token_req_ids = None
+ self.input_batch.prev_sampled_token_ids = None
+
+ def propose_draft_token_ids(
+ sampled_token_ids: torch.Tensor | list[np.ndarray],
+ ) -> None:
+ assert spec_decode_common_attn_metadata is not None
+ with record_function_or_nullcontext("gpu_model_runner: draft"):
+ self._draft_token_ids = self.propose_draft_token_ids(
+ scheduler_output,
+ sampled_token_ids,
+ self.input_batch.sampling_metadata,
+ hidden_states,
+ sample_hidden_states,
+ aux_hidden_states,
+ spec_decode_metadata,
+ spec_decode_common_attn_metadata,
+ )
+ self._copy_draft_token_ids_to_cpu(scheduler_output)
+
+ # --- Draft proposal orchestration ---
+ #
+ # There are three drafting modes depending on async scheduling and
+ # which drafters are active:
+ #
+ # A) Arctic-only + async:
+ # Pre-bookkeeping. Arctic runs entirely on GPU tensors
+ # (the existing EAGLE-like async path).
+ #
+ # B) Suffix (with or without arctic) + async:
+ # Pre-bookkeeping. Uses a *dedicated* suffix_copy_stream so
+ # that the D2H transfer of sampled token IDs only waits on
+ # "sampling complete" -- NOT on Arctic GPU work. Timeline:
+ #
+ # Default stream: [sample] -> [arctic GPU work ----------->]
+ # Suffix stream: [sample] -> [D2H copy] -> [event]
+ # CPU : wait -> suffix
+ #
+ # At merge time, if any requests need arctic fallback, we
+ # sync on the arctic result (.tolist()); by then the GPU
+ # kernels have had the full suffix-CPU window to finish.
+ #
+ # C) Non-async (any combination):
+ # Post-bookkeeping. propose_draft_token_ids handles everything
+ # using CPU valid_sampled_token_ids from bookkeeping.
+
+ has_suffix = self._suffix_cache is not None
+ is_arctic_method = (
+ self.speculative_config is not None
+ and self.speculative_config.method in ("arctic", "mlp_speculator")
+ )
+ input_fits_in_drafter = spec_decode_common_attn_metadata is not None
+ sampled_token_ids = sampler_output.sampled_token_ids
+
+ # Determine if we should use async pre-bookkeeping drafting.
+ use_async_spec = (
+ self.use_async_scheduling
+ and self.speculative_config is not None
+ and (is_arctic_method or has_suffix)
+ and input_fits_in_drafter
+ )
+
+ if use_async_spec:
+ if is_arctic_method and not has_suffix:
+ # (A) Arctic-only async: use existing closure which calls
+ # propose_draft_token_ids (the method) with a GPU tensor
+ # and triggers the async GPU path internally.
+ propose_draft_token_ids(sampled_token_ids)
+
+ # Track actual draft lengths for next step's allocation.
+ _n_predict = self.drafter.model.n_predict
+ _batch_size = len(self.input_batch.req_ids)
+ _disable_bs = (
+ self.speculative_config.disable_by_batch_size
+ if self.speculative_config else None
+ )
+ _draft_limit = _batch_size
+ if _disable_bs and _batch_size > _disable_bs:
+ _draft_limit = _disable_bs
+ self._prev_actual_draft_lens = {
+ req_id: _n_predict if i < _draft_limit else 0
+ for i, req_id in enumerate(self.input_batch.req_ids)
+ }
+ scheduler_output._actual_draft_lens = (
+ self._prev_actual_draft_lens
+ )
+
+ elif is_arctic_method:
+ # (B) Arctic + suffix async.
+ # D2H copy of sampled tokens runs on a dedicated
+ # suffix_copy_stream that only waits for sampling,
+ # NOT for Arctic GPU kernels enqueued afterwards.
+ # Suffix CPU work then overlaps with Arctic GPU.
+
+ # Step 1: Initiate async D2H copy before Arctic GPU work.
+ self._start_suffix_copy(sampled_token_ids)
+
+ # Step 2: Launch arctic on the default stream.
+ with record_function_or_nullcontext(
+ "gpu_model_runner: draft (arctic)"
+ ):
+ arctic_draft = self.propose_draft_token_ids(
+ scheduler_output,
+ sampled_token_ids,
+ self.input_batch.sampling_metadata,
+ hidden_states,
+ sample_hidden_states,
+ aux_hidden_states,
+ spec_decode_metadata,
+ spec_decode_common_attn_metadata,
+ )
+
+ # Step 3: Wait for D2H copy, then run suffix on CPU.
+ # This overlaps with remaining Arctic GPU work.
+ with record_function_or_nullcontext(
+ "gpu_model_runner: draft (suffix)"
+ ):
+ sampled_cpu = self._finish_suffix_copy()
+ self._update_suffix_cache(sampled_cpu)
+ suffix_results = self.propose_suffix_draft_token_ids(
+ sampled_cpu
+ )
+
+ min_score = self.drafter.model.n_predict
+ suffix_draft = [
+ result.token_ids if result.score >= min_score
+ else []
+ for result in suffix_results
+ ]
+
+ # Step 4: Merge arctic + suffix results.
+ # Suffix takes priority when available.
+ #
+ # The merged result MUST be a torch.Tensor on GPU
+ # because upstream _prepare_input_ids (next iteration)
+ # asserts isinstance(self._draft_token_ids, torch.Tensor)
+ # and uses it for on-device scatter.
+
+ # Collect suffix rows that have results.
+ suffix_indices = [
+ i for i, s in enumerate(suffix_draft) if s
+ ]
+
+ # arctic_draft is already a [batch, num_spec_tokens]
+ # GPU tensor (from the pre-allocated _draft_merged_gpu
+ # buffer when available, or F.pad fallback).
+ if isinstance(arctic_draft, torch.Tensor):
+ merged = arctic_draft
+ else:
+ # Shouldn't happen in async path, but handle
+ # gracefully.
+ k = self.num_spec_tokens
+ n = len(arctic_draft)
+ pin = self._suffix_merge_pinned[:n, :k]
+ pin_np = pin.numpy()
+ pin_np[:] = 0
+ for i, row in enumerate(arctic_draft):
+ t = row[:k]
+ if t:
+ pin_np[i, :len(t)] = t
+ merged = pin.to(
+ device=self.device, non_blocking=True)
+
+ if merged.shape[1] < self.num_spec_tokens:
+ merged = torch.nn.functional.pad(
+ merged,
+ (0, self.num_spec_tokens - merged.shape[1]),
+ value=0,
+ )
+
+ # Batch-overwrite rows that have suffix results.
+ # CRITICAL: use the pre-allocated *pinned* merge
+ # buffer so that cudaMemcpyAsync does NOT synchronise
+ # the default stream. With pageable (non-pinned)
+ # memory CUDA must sync the stream first, blocking the
+ # CPU until Arctic GPU work finishes -- destroying
+ # the async overlap.
+ width = merged.shape[1]
+ if suffix_indices:
+ n_sfx = len(suffix_indices)
+ overlay_pin = \
+ self._suffix_merge_pinned[:n_sfx, :width]
+ overlay_np = overlay_pin.numpy()
+ overlay_np[:] = 0
+ for j, idx in enumerate(suffix_indices):
+ s = suffix_draft[idx]
+ slen = min(len(s), width)
+ overlay_np[j, :slen] = s[:slen]
+ # Pinned H2C -- truly non-blocking.
+ overlay_t = overlay_pin.to(
+ device=self.device, non_blocking=True)
+ # Use pre-allocated pinned index buffer when
+ # available to avoid per-step allocation.
+ idx_pinned = getattr(
+ self, '_suffix_index_pinned', None)
+ if idx_pinned is not None:
+ idx_pin = idx_pinned[:n_sfx]
+ idx_pin[:] = torch.tensor(
+ suffix_indices, dtype=torch.long)
+ else:
+ idx_pin = torch.tensor(
+ suffix_indices, dtype=torch.long
+ ).pin_memory()
+ idx_t = idx_pin.to(
+ device=self.device, non_blocking=True)
+ merged.index_copy_(0, idx_t, overlay_t)
+
+ self._draft_token_ids = merged
+ self._copy_draft_token_ids_to_cpu(scheduler_output)
+
+ # Track actual draft lengths for next step's allocation.
+ _n_predict = self.drafter.model.n_predict
+ _batch_size = len(self.input_batch.req_ids)
+ _disable_bs = (
+ self.speculative_config.disable_by_batch_size
+ if self.speculative_config else None
+ )
+ _draft_limit = _batch_size
+ if _disable_bs and _batch_size > _disable_bs:
+ _draft_limit = _disable_bs
+ _actual_lens: dict[str, int] = {}
+ for _i, _req_id in enumerate(self.input_batch.req_ids):
+ _s = (len(suffix_draft[_i])
+ if _i < len(suffix_draft) else 0)
+ _a = _n_predict if _i < _draft_limit else 0
+ # Suffix takes priority when available.
+ _actual_lens[_req_id] = _s if _s > 0 else _a
+ self._prev_actual_draft_lens = _actual_lens
+ scheduler_output._actual_draft_lens = _actual_lens
+
+ else:
+ # (B2) Suffix-only async.
+ # No Arctic GPU work to overlap with, so skip the
+ # copy stream (its overhead exceeds the marginal
+ # overlap with the lightweight rejection kernel).
+ # Instead: rejection sample → parse tokens directly
+ # → suffix CPU work → build GPU tensor via pinned buf.
+ self._suffix_only_rejection_sample(
+ sampled_token_ids,
+ spec_decode_common_attn_metadata,
+ )
+
+ # Parse sampled tokens to CPU lists. .cpu() syncs the
+ # default stream (waits for the rejection kernel, which
+ # is lightweight), then parse_output performs CPU-side
+ # rejection to extract accepted token IDs.
+ with record_function_or_nullcontext(
+ "gpu_model_runner: draft (suffix)"
+ ):
+ num_reqs = self.input_batch.num_reqs
+ discard_indices = np.nonzero(
+ self.discard_request_mask.np[:num_reqs]
+ )[0]
+ n_cols = sampled_token_ids.shape[-1]
+ if n_cols == 1:
+ sampled_cpu = sampled_token_ids.tolist()
+ for idx in discard_indices:
+ sampled_cpu[int(idx)].clear()
+ else:
+ sampled_cpu, _ = RejectionSampler.parse_output(
+ sampled_token_ids.cpu(),
+ self.input_batch.vocab_size,
+ discard_indices,
+ )
+
+ self._update_suffix_cache(sampled_cpu)
+ suffix_results = self.propose_suffix_draft_token_ids(
+ sampled_cpu
+ )
+ suffix_draft = [
+ result.token_ids if result.score >= 0
+ else []
+ for result in suffix_results
+ ]
+
+ # Build GPU tensor from suffix lists via pinned buffer.
+ k = self.num_spec_tokens
+ n = len(suffix_draft)
+ pin = self._suffix_merge_pinned[:n, :k]
+ pin_np = pin.numpy()
+ pin_np[:] = 0
+ for i, s in enumerate(suffix_draft):
+ if s:
+ slen = min(len(s), k)
+ pin_np[i, :slen] = s[:slen]
+ self._draft_token_ids = pin.to(
+ device=self.device, non_blocking=True)
+ self._copy_draft_token_ids_to_cpu(scheduler_output)
+
+ # Track actual draft lengths for next step's allocation.
+ _actual_lens_b2: dict[str, int] = {}
+ for _i, _req_id in enumerate(self.input_batch.req_ids):
+ _s = (len(suffix_draft[_i])
+ if _i < len(suffix_draft) else 0)
+ _actual_lens_b2[_req_id] = _s
+ self._prev_actual_draft_lens = _actual_lens_b2
+ scheduler_output._actual_draft_lens = _actual_lens_b2
+
+ # --- Bookkeeping ---
+ with record_function_or_nullcontext("gpu_model_runner: bookkeep"):
+ (
+ num_nans_in_logits,
+ logprobs_lists,
+ valid_sampled_token_ids,
+ prompt_logprobs_dict,
+ req_ids_output_copy,
+ req_id_to_index_output_copy,
+ invalid_req_indices,
+ ) = self._bookkeeping_sync(
+ scheduler_output,
+ sampler_output,
+ logits,
+ hidden_states,
+ scheduler_output.total_num_scheduled_tokens,
+ spec_decode_metadata,
+ )
+
+ # (C) Non-async drafting: run after bookkeeping.
+ # Pass the raw GPU sampled_token_ids tensor (not the cleaned
+ # valid_sampled_token_ids list) so that propose_draft_token_ids
+ # gets a properly shaped 2-D tensor with -1 markers for rejected
+ # tokens. The method already handles tensors correctly: it
+ # filters out -1 to build the CPU list for
+ # prepare_next_token_ids_cpu and keeps the raw tensor for
+ # prepare_hidden_states.
+ if (
+ self.speculative_config is not None
+ and not use_async_spec
+ and input_fits_in_drafter
+ ):
+ propose_draft_token_ids(sampled_token_ids)
+
+ with record_function_or_nullcontext("gpu_model_runner: eplb"):
+ self.eplb_step()
+
+ with record_function_or_nullcontext("gpu_model_runner: ModelRunnerOutput"):
+ if self.model_config.enable_return_routed_experts:
+ capturer = RoutedExpertsCapturer.get_instance()
+ if capturer is not None:
+ capturer.save_captured_experts(indices=self.slot_mapping) # noqa
+ else:
+ logger.error("RoutedExpertsCapturer not initialized.")
+
+ output = ModelRunnerOutput(
+ req_ids=req_ids_output_copy,
+ req_id_to_index=req_id_to_index_output_copy,
+ sampled_token_ids=valid_sampled_token_ids,
+ logprobs=logprobs_lists,
+ prompt_logprobs_dict=prompt_logprobs_dict,
+ kv_connector_output=kv_connector_output,
+ ec_connector_output=ec_connector_output
+ if self.supports_mm_inputs
+ else None,
+ num_nans_in_logits=num_nans_in_logits,
+ cudagraph_stats=cudagraph_stats,
+ )
+ # Attach actual draft lengths to the ModelRunnerOutput so the
+ # scheduler can read them reliably in update_from_output.
+ # This survives the async pipeline (scheduler_output attrs
+ # may not due to object lifecycle in the batch queue).
+ output._actual_draft_lens = getattr(
+ scheduler_output, '_actual_draft_lens', None)
+
+ if not self.use_async_scheduling:
+ return output
+
+ with record_function_or_nullcontext(
+ "gpu_model_runner: AsyncGPUModelRunnerOutput"
+ ):
+ async_output = AsyncGPUModelRunnerOutput(
+ model_runner_output=output,
+ sampled_token_ids=sampler_output.sampled_token_ids,
+ logprobs_tensors=sampler_output.logprobs_tensors,
+ invalid_req_indices=invalid_req_indices,
+ async_output_copy_stream=self.async_output_copy_stream,
+ vocab_size=self.input_batch.vocab_size,
+ )
+ with record_function_or_nullcontext(
+ "gpu_model_runner: set_async_sampled_token_ids"
+ ):
+ # Save ref of sampled_token_ids CPU tensor if the batch contains
+ # any requests with sampling params that require output ids.
+ self.input_batch.set_async_sampled_token_ids(
+ async_output.sampled_token_ids_cpu,
+ async_output.async_copy_ready_event,
+ )
+
+ return async_output
+
+
+ def load_model(self, eep_scale_up: bool = False) -> None:
load_shift_model = (
- self.vllm_config.parallel_config.enable_shift_parallel)
+ self.vllm_config.parallel_config.enable_shift_parallel
+ )
if load_shift_model:
- # Make a deep copy of the config before loading the model.
shift_config = copy.deepcopy(self.vllm_config)
- self._orig_load_model()
+ self._orig_load_model(eep_scale_up)
if self.parallel_config.ulysses_sequence_parallel_size > 1:
self.monkeypatch_forward()
if load_shift_model:
shift_config.parallel_config.tensor_parallel_size *= (
- shift_config.parallel_config.ulysses_sequence_parallel_size)
+ shift_config.parallel_config.ulysses_sequence_parallel_size
+ )
shift_config.parallel_config.ulysses_sequence_parallel_size = 1
with set_shift_parallel_mode(True):
self.shift_model = get_model(vllm_config=shift_config)
self.shift_parallel_threshold = (
- shift_config.parallel_config.shift_parallel_threshold)
+ shift_config.parallel_config.shift_parallel_threshold
+ )
+ self.shift_forward_context = (
+ shift_config.compilation_config.static_forward_context
+ )
+
if "SwiftKV" in self.model.__class__.__name__:
- # HACK: Replace the decode-runner since it always runs in full
- # TP, but the original model is captured using SP * BATCH_SIZE,
- # which does not cover all its cuda graph sizes. The shift-mode
- # model should have all its cuda graphs captured correctly.
- self.model.model.decode_runner = (
- self.shift_model.model.decode_runner)
+ if hasattr(self.model, "model") and hasattr(self.model.model, "decode_runner"):
+ self.model.model.decode_runner = self.shift_model.model.decode_runner
+ else:
+ logger.warning("Could not apply SwiftKV HACK: "
+ "model.model.decode_runner not found.")
+
+ cudagraph_mode = self.compilation_config.cudagraph_mode
+ if (cudagraph_mode is not None
+ and cudagraph_mode.has_full_cudagraphs()
+ and not self.parallel_config.use_ubatching):
+ from vllm.compilation.cuda_graph import CUDAGraphWrapper
+ self.shift_model = CUDAGraphWrapper(
+ self.shift_model, self.vllm_config,
+ runtime_mode=CUDAGraphMode.FULL,
+ )
else:
self.shift_model = None
self.shift_parallel_threshold = 0
+ self.shift_forward_context = None
+
- def capture_model(self) -> None:
- if not self.use_cuda_graph:
- logger.warning(
- "Skipping CUDA graph capture. To turn on CUDA graph capture, "
- "set -O %s and ensure `use_cudagraph` was not manually set to "
- "False", CompilationLevel.PIECEWISE)
+ def initialize_kv_cache(self, kv_cache_config) -> None:
+ self._orig_initialize_kv_cache(kv_cache_config)
+ shift_ctx = getattr(self, 'shift_forward_context', None)
+ if shift_ctx is None:
return
+ base_ctx = self.compilation_config.static_forward_context
+ bound = 0
+ for name, shift_attn in shift_ctx.items():
+ base_attn = base_ctx.get(name)
+ if base_attn is not None and hasattr(base_attn, 'kv_cache'):
+ shift_attn.kv_cache = base_attn.kv_cache
+ bound += 1
+ if is_global_first_rank():
+ logger.info("Bound KV cache to %d shift model attention layers",
+ bound)
+
+ from vllm.forward_context import BatchDescriptor
+ def _case_bs(self, case) -> int:
+ # vLLM can pass ints, tuples, or sometimes BatchDescriptor-like objects
+ if isinstance(case, int):
+ return case
+ if isinstance(case, BatchDescriptor):
+ return int(case.num_tokens)
+ if isinstance(case, tuple):
+ return int(case[0])
+ # last resort
+ return int(getattr(case, "num_tokens"))
+
+ def _with_bs(self, case, new_bs: int):
+ if isinstance(case, tuple):
+ return (new_bs, *case[1:])
+ if isinstance(case, BatchDescriptor):
+ # Best-effort reconstruction; adjust if your BatchDescriptor signature differs.
+ return BatchDescriptor(
+ num_tokens=new_bs,
+ num_reqs=case.num_reqs,
+ uniform=case.uniform,
+ has_lora=case.has_lora,
+ )
+ return new_bs
- compilation_counter.num_gpu_runner_capture_triggers += 1
-
- start_time = time.perf_counter()
- start_free_gpu_memory = torch.cuda.mem_get_info()[0]
-
- # Trigger CUDA graph capture for specific shapes.
- # Capture the large shapes first so that the smaller shapes
- # can reuse the memory pool allocated for the large shapes.
- with parallel_state.graph_capture(device=self.device):
- sp_size = self.parallel_config.ulysses_sequence_parallel_size
- full_cg = self.full_cuda_graph
- # capture original model shapes
- compilation_cases = (
- shape for shape in reversed(self.cudagraph_batch_sizes)
- if shape * sp_size > self.shift_parallel_threshold and shape *
- sp_size <= self.max_num_tokens)
- # Only rank 0 should print progress bar during capture
- if is_global_first_rank():
- print_cases, compilation_cases = tee(compilation_cases)
- logger.info(f"original model shapes {list(print_cases)}")
- compilation_cases = tqdm(
- list(compilation_cases),
- desc="Capturing CUDA graph shapes of original model")
- for num_tokens in compilation_cases:
- # We skip EPLB here since we don't want to record dummy metrics
- for _ in range(self.vllm_config.compilation_config.
- cudagraph_num_of_warmups):
- self._dummy_run(num_tokens * sp_size,
- capture_attn_cudagraph=full_cg,
- skip_eplb=True)
- self._dummy_run(num_tokens * sp_size,
- capture_attn_cudagraph=full_cg,
- skip_eplb=True)
-
- # Capture shift model shapes
- if self.shift_model is not None:
- orig_model, self.model = self.model, self.shift_model
- # Reset compilation cases
- compilation_cases = (
- shape for shape in reversed(self.cudagraph_batch_sizes)
- if shape <= self.shift_parallel_threshold
- or "SwiftKV" in self.model.__class__.__name__)
- # Note: We want to capture all shapes for the SwiftKV shift model.
- # This is necessary since SwiftKV always uses full TP for the decode runner.
- # For all other models, we only capture necessary shapes for the SP_TP mode,
- # yielding less setup time.
- if is_global_first_rank():
- print_cases, compilation_cases = tee(compilation_cases)
- logger.info(f"shift model shapes {list(print_cases)}")
- compilation_cases = tqdm(
- list(compilation_cases),
- desc="Capturing CUDA graph shapes of shift model")
- with set_shift_parallel_mode(True):
- for num_tokens in compilation_cases:
- for _ in range(self.vllm_config.compilation_config.
- cudagraph_num_of_warmups):
- self._dummy_run(num_tokens,
- capture_attn_cudagraph=full_cg,
- skip_eplb=True)
- self._dummy_run(num_tokens,
- capture_attn_cudagraph=full_cg,
- skip_eplb=True)
- self.model = orig_model
- end_time = time.perf_counter()
- end_free_gpu_memory = torch.cuda.mem_get_info()[0]
- elapsed_time = end_time - start_time
- cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory
- # This usually takes 5~20 seconds.
- logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
- elapsed_time, cuda_graph_size / (1 << 30))
+ def _register_shift_cudagraph_keys(
+ self,
+ compilation_cases,
+ cudagraph_runtime_mode: CUDAGraphMode,
+ ):
+ """Register shift model batch sizes in the cudagraph dispatcher so
+ that runtime dispatch correctly routes to captured FULL/PIECEWISE
+ graphs."""
+ dispatcher = getattr(self, 'cudagraph_dispatcher', None)
+ if dispatcher is None:
+ return
- def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
- self._orig_initialize_kv_cache(kv_cache_config)
+ uniform = cudagraph_runtime_mode == CUDAGraphMode.FULL
+ added = 0
+ for case in compilation_cases:
+ bs = self._case_bs(case)
+ bd = dispatcher._create_padded_batch_descriptor(
+ bs, uniform, False,
+ )
+ if not uniform:
+ bd = bd.relax_for_mixed_batch_cudagraphs()
+ dispatcher.add_cudagraph_key(cudagraph_runtime_mode, bd)
+ added += 1
+
+ @contextlib.contextmanager
+ def _shift_graph_capture_context(self):
+ """Enable ca_comm for shift model graph capture."""
+ yield
- if self.shift_model is not None:
- # Bind the KV caches to the shift parallel model.
- forward_context = (
- self.vllm_config.compilation_config.static_forward_context)
- for mod in self.shift_model.modules():
- if isinstance(mod, Attention):
- mod.kv_cache = forward_context[mod.layer_name].kv_cache
+ @contextlib.contextmanager
+ def _use_shift_cudagraph_tables(self):
+ """Temporarily swap compilation_config sizes to the shift (unscaled)
+ lookup table so that vLLM internals (dispatcher, pad_for_cudagraph,
+ bounds checks) all see the shift model's sizes."""
+ cc = self.compilation_config
+ saved_sizes = cc.cudagraph_capture_sizes
+ saved_max = cc.max_cudagraph_capture_size
+ saved_table = cc.bs_to_padded_graph_size
+
+ shift_sizes = self.vllm_config._shift_cudagraph_capture_sizes
+ shift_max = self.vllm_config._shift_max_cudagraph_capture_size
+ shift_table = self.vllm_config._shift_bs_to_padded_graph_size
+
+ cc.cudagraph_capture_sizes = shift_sizes
+ cc.max_cudagraph_capture_size = shift_max
+ cc.bs_to_padded_graph_size = shift_table
+ try:
+ yield
+ finally:
+ cc.cudagraph_capture_sizes = saved_sizes
+ cc.max_cudagraph_capture_size = saved_max
+ cc.bs_to_padded_graph_size = saved_table
+
+ def _capture_cudagraphs(
+ self,
+ compilation_cases: list[tuple[int, bool]],
+ cudagraph_runtime_mode: CUDAGraphMode,
+ uniform_decode: bool,
+ ):
+ """
+ Capture CUDA graphs for both base (SP) and shift (TP) variants, splitting
+ shapes by threshold so both models have required graphs captured.
+ """
+ assert cudagraph_runtime_mode != CUDAGraphMode.NONE and \
+ cudagraph_runtime_mode in [CUDAGraphMode.FULL, CUDAGraphMode.PIECEWISE]
+
+ sp_size = parallel_state._SP.world_size
+ tp_size = parallel_state._TP.world_size
+ threshold = int(getattr(self, "shift_parallel_threshold", 0))
+ has_shift = getattr(self, "shift_model", None) is not None
+ is_swiftkv = "SwiftKV" in self.model.__class__.__name__
+
+ # --- Base model (Ulysses SP): uses the scaled lookup table (default) ---
+ # Exclude sizes at or below the shift threshold -- those batches
+ # are routed to the shift model at runtime. Capturing them for the
+ # base model would deadlock because the Ulysses all-to-all collectives
+ # diverge across ranks at small batch sizes.
+ if has_shift and not is_swiftkv:
+ compilation_cases_base = [
+ case for case in compilation_cases
+ if self._case_bs(case) > threshold
+ ]
+ else:
+ compilation_cases_base = list(compilation_cases)
+
+ if is_global_first_rank():
+ logger.info(
+ "base model (SP=%s, TP=%s) cudagraph mode %s shapes %s",
+ sp_size, tp_size, cudagraph_runtime_mode,
+ [self._case_bs(c) for c in compilation_cases_base],
+ )
+
+ if compilation_cases_base:
+ self._orig_capture_cudagraphs(
+ compilation_cases_base, cudagraph_runtime_mode, uniform_decode
+ )
+
+ # --- Shift model (SP*TP fused as TP-only): uses the unscaled lookup table ---
+ # The incoming compilation_cases contain *scaled* base sizes (e.g.
+ # [4, 8, ..., 2048] with sp_size=4). The shift model needs the
+ # *unscaled* sizes from its own capture list (e.g. [1, 2, ..., 512]).
+ # We rebuild the cases from _shift_cudagraph_capture_sizes, copying
+ # the non-bs fields (like has_lora) from the first matching base case.
+ if has_shift:
+ shift_sizes = self.vllm_config._shift_cudagraph_capture_sizes
+ # Use the first base case as a template for non-bs fields
+ template = compilation_cases[0] if compilation_cases else None
+ compilation_cases_shift = [
+ self._with_bs(template, bs) if template is not None else bs
+ for bs in reversed(shift_sizes)
+ ]
+
+ if is_global_first_rank():
+ logger.info(
+ "shift model (SPxTP=%s) shapes %s",
+ sp_size * tp_size,
+ [self._case_bs(c) for c in compilation_cases_shift],
+ )
+
+ if compilation_cases_shift:
+ orig_model, self.model = self.model, self.shift_model
+ cc = self.vllm_config.compilation_config
+ base_ctx = cc.static_forward_context
+ shift_ctx = getattr(self, 'shift_forward_context', None)
+ try:
+ if shift_ctx is not None:
+ cc.static_forward_context = shift_ctx
+ _CA_MIN_BS = 8
+ compilation_cases_shift = [
+ c for c in compilation_cases_shift
+ if self._case_bs(c) >= _CA_MIN_BS
+ ]
+ shift_sizes = [
+ s for s in shift_sizes if s >= _CA_MIN_BS
+ ]
+ self.vllm_config._shift_cudagraph_capture_sizes = (
+ shift_sizes)
+ self.vllm_config._shift_max_cudagraph_capture_size = (
+ max(shift_sizes) if shift_sizes else 0)
+ self.vllm_config._shift_bs_to_padded_graph_size = {
+ bs: bs for bs in shift_sizes
+ }
+ for bs in range(1, _CA_MIN_BS):
+ self.vllm_config._shift_bs_to_padded_graph_size[
+ bs] = _CA_MIN_BS
+
+ if is_global_first_rank():
+ logger.info(
+ "shift model: skipping bs < %d for "
+ "ca_comm graph capture (will pad to %d)",
+ _CA_MIN_BS, _CA_MIN_BS,
+ )
+
+ with set_shift_parallel_mode(True), \
+ self._use_shift_cudagraph_tables(), \
+ self._shift_graph_capture_context():
+ self._register_shift_cudagraph_keys(
+ compilation_cases_shift,
+ cudagraph_runtime_mode,
+ )
+ self._orig_capture_cudagraphs(
+ compilation_cases_shift,
+ cudagraph_runtime_mode,
+ uniform_decode,
+ )
+ finally:
+ self.model = orig_model
+ cc.static_forward_context = base_ctx
diff --git a/arctic_inference/vllm/patches.py b/arctic_inference/vllm/patches.py
new file mode 100644
index 000000000..0d13baed7
--- /dev/null
+++ b/arctic_inference/vllm/patches.py
@@ -0,0 +1,302 @@
+# Copyright 2025 Snowflake Inc.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import vllm
+from vllm.logger import init_logger
+from vllm.v1.core.sched.async_scheduler import AsyncScheduler
+from vllm.v1.core.sched.scheduler import Scheduler
+from vllm.v1.engine.core import EngineCoreProc
+from vllm.v1.worker.worker_base import WorkerBase
+
+from arctic_inference.patching import ArcticPatch
+from arctic_inference.utils import get_compatible_vllm_version
+from arctic_inference.vllm.args import EngineArgsPatch, AsyncEngineArgsPatch
+from arctic_inference.vllm.config import (ParallelConfigPatch,
+ SpeculativeConfigPatch,
+ VllmConfigPatch,
+ MLPSpeculatorConfigPatch)
+from arctic_inference.vllm.stats import (SpecDecodingStatsPatch,
+ SpecDecodingLoggingPatch)
+from arctic_inference.vllm.structured_output import XgrammarBackendPatch
+from arctic_inference.vllm.ulysses import apply_shift_parallel_patches
+
+
+logger = init_logger(__name__)
+
+
+class AsyncSchedulerPatch(ArcticPatch[AsyncScheduler]):
+ """Patch AsyncScheduler to:
+ 1. Respect ``disable_by_batch_size`` when allocating spec token
+ placeholders (the worker only drafts for the first N requests).
+ 2. Use the previous step's actual draft length for dynamic placeholder
+ allocation, avoiding wasted verification compute when the real draft
+ width (e.g. Arctic n_predict=3) is much smaller than
+ num_speculative_tokens (e.g. 12).
+ 3. Store ``_scheduled_spec_count`` so that the post-fix in
+ ``update_from_output`` can compensate for worker-side trimming.
+ """
+
+ _orig_update_after_schedule = AsyncScheduler._update_after_schedule
+
+ def _update_after_schedule(self, scheduler_output):
+ # Call the base Scheduler._update_after_schedule (NOT the
+ # AsyncScheduler override which we are replacing).
+ Scheduler._update_after_schedule(self, scheduler_output)
+
+ has_structured_output_requests = False
+ pending_structured_output_tokens = False
+ spec_decode_tokens = scheduler_output.scheduled_spec_decode_tokens
+
+ # Respect disable_by_batch_size: only add spec token placeholders
+ # for the first N decode requests (matching the worker's draft_limit
+ # in propose_draft_token_ids).
+ spec_config = getattr(self.vllm_config, 'speculative_config', None)
+ disable_bs = (
+ spec_config.disable_by_batch_size if spec_config else None
+ )
+ decode_with_spec_count = 0
+ for req_id in scheduler_output.num_scheduled_tokens:
+ request = self.requests[req_id]
+ has_structured_output_requests |= request.use_structured_output
+ pending_structured_output_tokens |= (
+ request.use_structured_output
+ and request.num_output_placeholders > 0
+ )
+ cur_num_spec_tokens = len(spec_decode_tokens.get(req_id, ()))
+ # Store the originally-scheduled spec count so that
+ # update_from_output can compensate for worker-side trimming.
+ request._scheduled_spec_count = cur_num_spec_tokens
+
+ if (
+ request.num_computed_tokens
+ == request.num_tokens
+ + request.num_output_placeholders
+ + cur_num_spec_tokens
+ ):
+ # The request will generate a new token + spec tokens.
+ request.num_output_placeholders += 1 + cur_num_spec_tokens
+
+ # Check if beyond the disable_by_batch_size limit.
+ decode_with_spec_count += 1
+ if disable_bs and decode_with_spec_count > disable_bs:
+ # Beyond limit: no spec token placeholders.
+ request.spec_token_ids = []
+ continue
+
+ # Use previous step's actual draft length to size
+ # placeholders. When suffix had a good match (actual
+ # > n_predict), allocate the full width so the next
+ # step can verify all suffix tokens. When suffix
+ # didn't match (actual = n_predict from arctic),
+ # allocate only that many to avoid wasting attention
+ # compute on zero-padded positions.
+ # Cold start: allocate full width (generous).
+ prev_actual = getattr(
+ request, '_prev_actual_draft_len', None)
+ if prev_actual is not None:
+ num_placeholders = min(
+ max(prev_actual, 1), self.num_spec_tokens)
+ else:
+ num_placeholders = self.num_spec_tokens
+
+ request.spec_token_ids = [-1] * num_placeholders
+
+ scheduler_output.has_structured_output_requests = (
+ has_structured_output_requests)
+ scheduler_output.pending_structured_output_tokens = (
+ pending_structured_output_tokens)
+
+ def update_from_output(self, scheduler_output, model_runner_output):
+ """Wrap Scheduler.update_from_output to store actual draft counts.
+
+ We infer the drafter's real capability from the acceptance
+ results so that the next ``_update_after_schedule`` can size
+ placeholders correctly. This works even when the worker runs
+ in a separate process (where scheduler_output._actual_draft_lens
+ set by the worker doesn't survive serialisation back to the
+ scheduler).
+
+ Strategy:
+ 1. **Primary path**: read ``_actual_draft_lens`` from the
+ ``model_runner_output`` object (attached by the model runner
+ to the ``ModelRunnerOutput`` dataclass, which reliably
+ survives the async pipeline).
+ 2. **Legacy path**: read from ``scheduler_output._actual_draft_lens``
+ (works in same-process non-async mode).
+ 3. **Fallback**: infer from acceptance results with exponential
+ growth — when all drafted tokens are accepted, double the
+ allocation so suffix decoding reaches full capacity in
+ O(log n) steps instead of O(n).
+ """
+ sampled_token_ids = model_runner_output.sampled_token_ids
+ req_id_to_index = model_runner_output.req_id_to_index
+
+ result = Scheduler.update_from_output(
+ self, scheduler_output, model_runner_output)
+
+ # Primary path: read from model_runner_output (most reliable
+ # for async scheduling — the ModelRunnerOutput object is
+ # returned by get_output() and guaranteed to survive).
+ actual_lens = getattr(
+ model_runner_output, '_actual_draft_lens', None)
+
+ # Legacy path: read from scheduler_output (works for non-async
+ # or same-process setups where the attribute is preserved).
+ if not actual_lens:
+ actual_lens = getattr(
+ scheduler_output, '_actual_draft_lens', None)
+
+ if actual_lens:
+ for req_id, actual_len in actual_lens.items():
+ request = self.requests.get(req_id)
+ if request is not None:
+ request._prev_actual_draft_len = actual_len
+ return result
+
+ # Fallback: infer from acceptance results (multi-process case).
+ # Uses exponential growth when all drafted tokens are accepted
+ # (indicating the drafter / suffix cache can handle more), so
+ # the allocation converges to num_spec_tokens in O(log n)
+ # steps:
+ # step 0: 1 position → accept 1/1 → prev = 2
+ # step 1: 2 positions → accept 2/2 → prev = 4
+ # step 2: 4 positions → accept 4/4 → prev = 8
+ # ...
+ # When not all are accepted (normal drafter), linear growth:
+ # step 0: 1 position → accept 1 → prev = 2
+ # step 1: 2 positions → accept 2 → prev = 3
+ # step 2: 3 positions → steady state (n_predict = 3)
+ if not sampled_token_ids:
+ return result
+
+ for req_id in scheduler_output.num_scheduled_tokens:
+ scheduled_spec = (
+ scheduler_output.scheduled_spec_decode_tokens.get(req_id))
+ if not scheduled_spec:
+ continue
+
+ req_index = req_id_to_index.get(req_id)
+ if req_index is None:
+ continue
+
+ request = self.requests.get(req_id)
+ if request is None:
+ continue
+
+ generated = sampled_token_ids[req_index]
+ num_accepted = (len(generated) - 1) if generated else 0
+ num_draft = len(scheduled_spec)
+
+ prev = getattr(request, '_prev_actual_draft_len', None)
+
+ if num_accepted > 0:
+ if num_accepted >= num_draft and num_draft > 0:
+ # All drafted tokens accepted — the drafter (or
+ # suffix cache) could produce more if given room.
+ # Double the allocation for exponential ramp-up.
+ new_val = min(
+ num_draft * 2, self.num_spec_tokens)
+ else:
+ # Partial acceptance: grow linearly.
+ new_val = min(
+ num_accepted + 1, self.num_spec_tokens)
+ request._prev_actual_draft_len = max(
+ prev or 0, new_val)
+ elif prev is None:
+ # First step for this request, zero acceptance.
+ # Seed with 1 so _update_after_schedule doesn't
+ # fall back to the full num_spec_tokens next time.
+ request._prev_actual_draft_len = 1
+
+ return result
+
+
+class EngineCoreProcPatch(ArcticPatch[EngineCoreProc]):
+
+ _orig_run_engine_core = EngineCoreProc.run_engine_core
+
+ @staticmethod
+ def run_engine_core(*args, **kwargs):
+ # When starting the API server, it will spawn a new process to run the
+ # EngineCore. We need to load the plugins in the new process before it
+ # initializes the Executor.
+ vllm.plugins.load_general_plugins()
+ return EngineCoreProcPatch._orig_run_engine_core(*args, **kwargs)
+
+
+class WorkerBasePatch(ArcticPatch[WorkerBase]):
+
+ _orig_init = WorkerBase.__init__
+
+ def __init__(self, *args, **kwargs):
+ # Some patches like the GPUModelRunner will import CUDA libraries when
+ # they are initialized, which will cause process forking to fail. For
+ # these patches, we need to delay the initialization until after the
+ # process has been forked (i.e., in the WorkerBase initializer).
+ from arctic_inference.vllm.model_runner import GPUModelRunnerPatch
+
+ GPUModelRunnerPatch.apply_patch()
+
+ return self._orig_init(*args, **kwargs)
+
+
+def apply_arctic_patches():
+
+ from transformers import AutoConfig
+ from arctic_inference.common.swiftkv import LlamaSwiftKVConfig
+
+ # Register SwiftKV model configurations to transformers.
+ AutoConfig.register("llama_swiftkv", LlamaSwiftKVConfig)
+
+ from vllm import ModelRegistry
+ #from arctic_inference.vllm.swiftkv import LlamaSwiftKVForCausalLM
+
+ # Register SwiftKV model definitions to vLLM.
+ ModelRegistry.register_model(
+ "LlamaSwiftKVForCausalLM",
+ "arctic_inference.vllm.swiftkv:LlamaSwiftKVForCausalLM")
+
+ # Register ArcticSpeculator models to vLLM.
+ from arctic_inference.vllm.spec_dec.arctic_speculator import (
+ ArcticMLPSpeculator, ArcticLSTMSpeculator)
+ ModelRegistry.register_model("ArcticMLPSpeculatorPreTrainedModel",
+ ArcticMLPSpeculator)
+ ModelRegistry.register_model("ArcticLSTMSpeculatorPreTrainedModel",
+ ArcticLSTMSpeculator)
+ # This name is currently used in corvo
+ ModelRegistry.register_model("MLPVariantSpeculatorPreTrainedModel",
+ ArcticLSTMSpeculator)
+
+ # Patches that make later patches work properly.
+ EngineCoreProcPatch.apply_patch()
+ WorkerBasePatch.apply_patch()
+
+ # Async scheduler patches for spec decode (disable_by_batch_size
+ # interaction + dynamic draft width allocation).
+ AsyncSchedulerPatch.apply_patch()
+
+ # Patches to vLLM arguments and configuration objects.
+ EngineArgsPatch.apply_patch()
+ AsyncEngineArgsPatch.apply_patch()
+ ParallelConfigPatch.apply_patch()
+ SpeculativeConfigPatch.apply_patch()
+ SpecDecodingStatsPatch.apply_patch()
+ SpecDecodingLoggingPatch.apply_patch()
+ VllmConfigPatch.apply_patch()
+ XgrammarBackendPatch.apply_patch()
+ MLPSpeculatorConfigPatch.apply_patch()
+
+ # Main optimization patches.
+ apply_shift_parallel_patches()
diff --git a/arctic_inference/vllm/plugin.py b/arctic_inference/vllm/plugin.py
new file mode 100644
index 000000000..afe100ce4
--- /dev/null
+++ b/arctic_inference/vllm/plugin.py
@@ -0,0 +1,45 @@
+# Copyright 2025 Snowflake Inc.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+
+import vllm
+
+import arctic_inference.envs as envs
+from arctic_inference.utils import get_compatible_vllm_version
+
+
+def arctic_inference_plugin():
+ if not envs.ARCTIC_INFERENCE_ENABLED:
+ return
+
+ if not envs.ARCTIC_INFERENCE_SKIP_VERSION_CHECK:
+ compatible_version = get_compatible_vllm_version()
+ if vllm.__version__ != compatible_version:
+ raise RuntimeError(
+ f"Arctic Inference plugin requires vllm=={compatible_version} "
+ f"but found vllm=={vllm.__version__}!")
+
+ if not envs.ARCTIC_INFERENCE_SKIP_PLATFORM_CHECK:
+ if not vllm.platforms.current_platform.is_cuda():
+ raise RuntimeError(
+ f"Arctic Inference plugin requires the cuda platform!")
+
+ print("\x1b[36;1mArctic Inference plugin is enabled!\x1b[0m",
+ file=sys.stderr)
+
+ # Lazy import to avoid potential errors when the plugin is disabled.
+ from .patches import apply_arctic_patches
+ apply_arctic_patches()
diff --git a/arctic_inference/vllm/plugins.py b/arctic_inference/vllm/plugins.py
deleted file mode 100644
index 0f34a3999..000000000
--- a/arctic_inference/vllm/plugins.py
+++ /dev/null
@@ -1,126 +0,0 @@
-# Copyright 2025 Snowflake Inc.
-# SPDX-License-Identifier: Apache-2.0
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-import vllm
-from vllm.logger import init_logger
-from vllm.v1.engine.core import EngineCoreProc
-from vllm.v1.worker.worker_base import WorkerBase
-
-from arctic_inference.patching import ArcticPatch
-from arctic_inference.utils import get_compatible_vllm_version
-from arctic_inference.vllm.args import EngineArgsPatch, AsyncEngineArgsPatch
-from arctic_inference.vllm.config import (ParallelConfigPatch,
- SpeculativeConfigPatch,
- VllmConfigPatch,
- MLPSpeculatorConfigPatch)
-from arctic_inference.vllm.stats import (SpecDecodingStatsPatch,
- SpecDecodingLoggingPatch)
-from arctic_inference.vllm.ulysses import apply_shift_parallel_patches
-
-
-logger = init_logger(__name__)
-
-
-class EngineCoreProcPatch(ArcticPatch[EngineCoreProc]):
-
- _orig_run_engine_core = EngineCoreProc.run_engine_core
-
- @staticmethod
- def run_engine_core(*args, **kwargs):
- # When starting the API server, it will spawn a new process to run the
- # EngineCore. We need to load the plugins in the new process before it
- # initializes the Executor.
- vllm.plugins.load_general_plugins()
- return EngineCoreProcPatch._orig_run_engine_core(*args, **kwargs)
-
-
-class WorkerBasePatch(ArcticPatch[WorkerBase]):
-
- _orig_init = WorkerBase.__init__
-
- def __init__(self, *args, **kwargs):
- # Some patches like the GPUModelRunner will import CUDA libraries when
- # they are initialized, which will cause process forking to fail. For
- # these patches, we need to delay the initialization until after the
- # process has been forked (i.e., in the WorkerBase initializer).
- from arctic_inference.vllm.model_runner import GPUModelRunnerPatch
-
- GPUModelRunnerPatch.apply_patch()
-
- return self._orig_init(*args, **kwargs)
-
-
-def arctic_inference_plugin():
- if (vllm.__version__ != get_compatible_vllm_version() and not
- vllm.__version__.startswith("0.1.dev")): # Make it work with dev
- logger.warning(
- f"ArcticInference requires vllm=={get_compatible_vllm_version()} "
- f"but found vllm=={vllm.__version__}. Ignoring plugin!")
- return
-
- if not vllm.platforms.current_platform.is_cuda():
- logger.warning(
- f"ArcticInference requires the cuda platform. Ignoring plugin!")
- return
-
- if os.getenv("VLLM_USE_V1") == "0":
- logger.warning("ArcticInference only supports vLLM V1, but detected V0 engine. "
- "Ignoring plugin!\n"
- "Hint: To strictly enforce the V1 vLLM engine, please set "
- "VLLM_USE_V1=1.")
- return
-
- from transformers import AutoConfig
- from arctic_inference.common.swiftkv import LlamaSwiftKVConfig
-
- # Register SwiftKV model configurations to transformers.
- AutoConfig.register("llama_swiftkv", LlamaSwiftKVConfig)
-
- from vllm import ModelRegistry
- #from arctic_inference.vllm.swiftkv import LlamaSwiftKVForCausalLM
-
- # Register SwiftKV model definitions to vLLM.
- ModelRegistry.register_model(
- "LlamaSwiftKVForCausalLM",
- "arctic_inference.vllm.swiftkv:LlamaSwiftKVForCausalLM")
-
- # Register ArcticSpeculator models to vLLM.
- from arctic_inference.vllm.spec_dec.arctic_speculator import (
- ArcticMLPSpeculator, ArcticLSTMSpeculator)
- ModelRegistry.register_model("ArcticMLPSpeculatorPreTrainedModel",
- ArcticMLPSpeculator)
- ModelRegistry.register_model("ArcticLSTMSpeculatorPreTrainedModel",
- ArcticLSTMSpeculator)
- # This name is currently used in corvo
- ModelRegistry.register_model("MLPVariantSpeculatorPreTrainedModel",
- ArcticLSTMSpeculator)
-
- # Patches that make later patches work properly.
- EngineCoreProcPatch.apply_patch()
- WorkerBasePatch.apply_patch()
-
- # Patches to vLLM arguments and configuration objects.
- EngineArgsPatch.apply_patch()
- AsyncEngineArgsPatch.apply_patch()
- ParallelConfigPatch.apply_patch()
- SpeculativeConfigPatch.apply_patch()
- SpecDecodingStatsPatch.apply_patch()
- SpecDecodingLoggingPatch.apply_patch()
- VllmConfigPatch.apply_patch()
- MLPSpeculatorConfigPatch.apply_patch()
-
- # Main optimization patches.
- apply_shift_parallel_patches()
diff --git a/arctic_inference/vllm/spec_dec/arctic_proposer.py b/arctic_inference/vllm/spec_dec/arctic_proposer.py
index 5d7612a33..4f471f66d 100644
--- a/arctic_inference/vllm/spec_dec/arctic_proposer.py
+++ b/arctic_inference/vllm/spec_dec/arctic_proposer.py
@@ -13,12 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Optional, Union
+from typing import Optional, Union, List
from vllm.config import VllmConfig
from vllm.model_executor.model_loader import get_model
+from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.worker.gpu_model_runner import logger
+from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
+from vllm.v1.utils import CpuGpuBuffer
+from vllm.utils.platform_utils import is_pin_memory_available
import numpy as np
import torch
@@ -39,6 +43,9 @@ def __init__(
self.model = None
self.device = None
+ self.max_batch_size = vllm_config.scheduler_config.max_num_seqs
+ self.backup_next_token_ids = None # type: Optional[CpuGpuBuffer]
+
def load_model(
self,
model: Union[ArcticMLPSpeculator, ArcticLSTMSpeculator],
@@ -100,67 +107,195 @@ def load_model(
model_config=draft_config_model_config,
quant_config=draft_config_quant_config,
parallel_config=draft_config_parallel_config,
+ scheduler_config=self.vllm_config.scheduler_config,
+ speculative_config=self.speculative_config,
load_config=self.vllm_config.load_config,
device_config=self.vllm_config.device_config,
)
self.model = get_model(vllm_config=draft_worker_config)
- self.device = next(model.parameters()).device
+ self.device = next(self.model.parameters()).device
self.input_hidden_dim = self.model.input_hidden_dim if isinstance(
self.model, ArcticLSTMSpeculator) else self.model.emb_dim
+ self.backup_next_token_ids = CpuGpuBuffer(
+ self.max_batch_size,
+ dtype=torch.int32,
+ pin_memory=is_pin_memory_available(),
+ device=self.device,
+ with_numpy=True,
+ )
+
def prepare_hidden_states(
self,
sample_hidden_states: torch.Tensor,
- sampled_token_ids: Union[np.ndarray, list[list[int]]],
+ sampled_token_ids: Union[torch.Tensor, np.ndarray, List[List[int]]],
spec_decode_metadata: SpecDecodeMetadata,
) -> torch.Tensor:
- if sample_hidden_states is not None:
- assert sample_hidden_states.shape[-1] == self.input_hidden_dim, \
- f"hidden_states shape mismatch: {sample_hidden_states.shape[-1]} != {self.input_hidden_dim}. \
- Please make sure spec model is trained using the same base model."
-
- # if isinstance(sampled_token_ids, list):
- # # Pad the list of lists to create a uniform tensor
- # max_len = max(len(x) for x in sampled_token_ids) if sampled_token_ids else 0
- # if max_len == 0:
- # return sample_hidden_states
- # padded_ids = [l + [-1] * (max_len - len(l)) for l in sampled_token_ids]
- # sampled_token_ids = torch.tensor(padded_ids,
- # device=sample_hidden_states.device)
+ assert sample_hidden_states is not None, "sample_hidden_states must be provided"
+
+ if isinstance(sampled_token_ids, np.ndarray):
+ sampled_token_ids = torch.as_tensor(
+ sampled_token_ids, device=sample_hidden_states.device, dtype=torch.long
+ )
+ elif isinstance(sampled_token_ids, list):
+ sampled_token_ids = torch.as_tensor(
+ sampled_token_ids, device=sample_hidden_states.device, dtype=torch.long
+ )
+ elif sampled_token_ids.device != sample_hidden_states.device:
+ sampled_token_ids = sampled_token_ids.to(sample_hidden_states.device, non_blocking=True)
max_gen_len = sampled_token_ids.shape[-1]
- if max_gen_len == 1:
+ num_requests = sampled_token_ids.shape[0]
+ if max_gen_len == 1 and sample_hidden_states.shape[0] == num_requests:
+ # Fast path: one row per request, no index-select needed.
return sample_hidden_states
assert spec_decode_metadata is not None
- valid_mask = sampled_token_ids != -1
- gen_lens = valid_mask.sum(dim=1)
- num_sampled_tokens = np.array(spec_decode_metadata.num_draft_tokens)
- num_sampled_tokens = torch.tensor(num_sampled_tokens,
- device=gen_lens.device) + 1
- hidden_states_idx = (gen_lens - 1) + torch.cumsum(
- num_sampled_tokens, 0) - num_sampled_tokens
- previous_hidden_states = sample_hidden_states[hidden_states_idx]
+ if hasattr(spec_decode_metadata, "cu_num_draft_tokens") and spec_decode_metadata.cu_num_draft_tokens is not None:
+ cu = spec_decode_metadata.cu_num_draft_tokens
+ num_draft_tokens_gpu = torch.cat([cu[0:1], cu[1:] - cu[:-1]])
+ else:
+ num_draft_tokens_gpu = torch.as_tensor(
+ spec_decode_metadata.num_draft_tokens,
+ device=sample_hidden_states.device,
+ dtype=torch.int64
+ )
+
+ num_processed_tokens_per_req = num_draft_tokens_gpu + 1
+
+ offsets = torch.cumsum(num_processed_tokens_per_req, dim=0) - num_processed_tokens_per_req
+
+ vocab_size = self.vllm_config.model_config.get_vocab_size()
+ valid_mask = (sampled_token_ids != -1) & (sampled_token_ids < vocab_size)
+ gen_lens = valid_mask.sum(dim=1).to(dtype=torch.int64)
+
+ last_valid = torch.clamp(gen_lens - 1, min=0)
+ hidden_states_idx = offsets + last_valid
+
+ previous_hidden_states = sample_hidden_states.index_select(
+ dim=0, index=hidden_states_idx
+ )
+
+ assert previous_hidden_states.size(-1) == self.input_hidden_dim, (
+ f"hidden_states dim {previous_hidden_states.size(-1)} != speculator expected {self.input_hidden_dim}. "
+ "Make sure the spec model is trained with the same base model."
+ )
+
return previous_hidden_states
def propose(
self,
- context_token_ids: np.ndarray,
- previous_hidden_states: torch.Tensor,
+ context_token_ids: Union[torch.Tensor, np.ndarray, List[int]],
+ previous_hidden_states: Optional[torch.Tensor],
num_predict_tokens: int,
- ) -> Optional[np.ndarray]:
- assert num_predict_tokens > 0, \
- f"num_predict_tokens must be greater than 0, got {num_predict_tokens}."
-
- input_ids = torch.tensor(context_token_ids, device=self.device)
+ ) -> Optional[torch.Tensor]:
+ assert num_predict_tokens > 0
+ if isinstance(context_token_ids, torch.Tensor):
+ if context_token_ids.device != self.device:
+ input_ids = context_token_ids.to(self.device, non_blocking=True)
+ else:
+ input_ids = context_token_ids
+ else:
+ input_ids = torch.as_tensor(context_token_ids, device=self.device, dtype=torch.long)
next_tokens = self.model.generate_proposals(
input_ids=input_ids,
previous_hidden_states=previous_hidden_states,
num_predict_tokens=num_predict_tokens,
)
+ return next_tokens
- return next_tokens.cpu().numpy()
+ # Borrow from eagle
+ def prepare_next_token_ids_cpu(
+ self,
+ sampled_token_ids: list[list[int]],
+ requests: dict[str, CachedRequestState],
+ gpu_input_batch: InputBatch,
+ num_scheduled_tokens: dict[str, int],
+ ) -> torch.Tensor:
+ req_ids = gpu_input_batch.req_ids
+ next_token_ids: list[int] = []
+ for i, token_ids in enumerate(sampled_token_ids):
+ if token_ids:
+ # Common case.
+ next_token_id = token_ids[-1]
+ else:
+ # Partial prefill (rare case).
+ # Get the next token id from the request state.
+ req_id = req_ids[i]
+ req_state = requests[req_id]
+ seq_len = req_state.num_computed_tokens + num_scheduled_tokens[req_id]
+ next_token_id = req_state.get_token_id(seq_len)
+ next_token_ids.append(next_token_id)
+ next_token_ids = torch.tensor(
+ next_token_ids, dtype=torch.int32, device=self.device
+ )
+ return next_token_ids
+
+
+ def prepare_next_token_ids_padded(
+ self,
+ common_attn_metadata: CommonAttentionMetadata,
+ sampled_token_ids: torch.Tensor,
+ requests: dict[str, CachedRequestState],
+ gpu_input_batch: InputBatch,
+ discard_request_mask: torch.Tensor,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ from vllm.triton_utils import triton
+ from vllm.v1.spec_decode.utils import eagle_prepare_next_token_padded_kernel
+
+ num_reqs = gpu_input_batch.num_reqs
+ self.backup_next_token_ids.np[:num_reqs] = np.array(
+ [
+ requests[gpu_input_batch.req_ids[i]].get_token_id(
+ common_attn_metadata.seq_lens_cpu[i].item()
+ )
+ for i in range(num_reqs)
+ ],
+ dtype=np.int32,
+ )
+ self.backup_next_token_ids.copy_to_gpu(num_reqs)
+ backup_tokens_gpu = self.backup_next_token_ids.gpu
+
+ batch_size, num_tokens = sampled_token_ids.shape
+ device = sampled_token_ids.device
+
+ assert discard_request_mask.dtype == torch.bool
+ assert backup_tokens_gpu.dtype == torch.int32
+
+ next_token_ids = torch.empty(batch_size, dtype=torch.int32, device=device)
+ valid_sampled_tokens_count = next_token_ids.new_empty(batch_size)
+
+ # Kernel grid: one program per request (row)
+ grid = (batch_size,)
+
+ # Find the next power of 2 for block sizes
+ BLOCK_SIZE_TOKENS = triton.next_power_of_2(num_tokens)
+ eagle_prepare_next_token_padded_kernel[grid](
+ sampled_token_ids,
+ discard_request_mask,
+ backup_tokens_gpu,
+ next_token_ids,
+ valid_sampled_tokens_count,
+ gpu_input_batch.vocab_size,
+ num_tokens,
+ batch_size,
+ sampled_token_ids.stride(0),
+ BLOCK_SIZE_TOKENS=BLOCK_SIZE_TOKENS,
+ )
+
+ return next_token_ids, valid_sampled_tokens_count
+
+
+class SuffixProposer:
+ def __init__(self):
+ pass
+
+ def load_model(
+ self,
+ model: None,
+ ):
+ pass
diff --git a/arctic_inference/vllm/spec_dec/arctic_speculator.py b/arctic_inference/vllm/spec_dec/arctic_speculator.py
index ecd857e53..22c4a09d2 100644
--- a/arctic_inference/vllm/spec_dec/arctic_speculator.py
+++ b/arctic_inference/vllm/spec_dec/arctic_speculator.py
@@ -22,7 +22,7 @@
from vllm.config import VllmConfig
from arctic_inference.vllm.spec_dec.logits_processor_opt import LogitsProcessorOpt
-from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
+from vllm.v1.outputs import SamplerOutput
from arctic_inference.vllm.spec_dec.fp8 import (Fp8ConfigWithEmbedding,
OriginalFp8LinearMethod)
@@ -33,7 +33,11 @@
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from arctic_inference.py_custom_ops import (try_load_torch_library,
+ speculator_ln, sum_lstm)
+
SQRT2 = 2**0.5
+USE_CUSTOM_OP = try_load_torch_library()
def padding_size(size: int) -> int:
@@ -85,7 +89,7 @@ def __init__(
self.bias = nn.Parameter(torch.empty(normalized_shape))
self.eps = eps
- def forward(self, x):
+ def forward_fallback(self, x):
xf = x
xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps)
x = xf.type_as(x)
@@ -94,6 +98,22 @@ def forward(self, x):
x = x + self.bias
return x
+ def forward_opt(self, x):
+ return speculator_ln(
+ x,
+ self.weight if self.elementwise_scale_and_shift else None,
+ self.bias if self.elementwise_scale_and_shift else None,
+ float(self.eps),
+ )
+
+ def forward(self, x):
+ if USE_CUSTOM_OP:
+ return self.forward_opt(x)
+ else:
+ return self.forward_fallback(x)
+
+
+
def _generate_cg_key(padding_size: int, head_index: int):
return (padding_size << 16) + head_index
@@ -222,13 +242,19 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
scale=1.0,
skip_last_gather=True,
)
- self.sampler = get_sampler()
self.cuda_graph_max_batch_size = 0
self.cuda_graph_mode = False
if not vllm_config.model_config.enforce_eager:
self.cuda_graph_mode = True
self.cuda_graphs = {}
+ spec_config = vllm_config.speculative_config
+ disable_by_batch_size = (
+ spec_config.disable_by_batch_size
+ if spec_config is not None and spec_config.disable_by_batch_size is not None
+ else vllm_config.scheduler_config.max_num_seqs
+ )
+ self.cuda_graph_max_batch_size = padding_size(disable_by_batch_size)
self.cuda_graph_max_batch_size = padding_size(
vllm_config.scheduler_config.max_num_seqs)
self.static_cuda_buffers = {
@@ -315,6 +341,7 @@ def generate_token_ids(
argidx = torch.argmax(vals, -1).reshape(batch_size, -1)
last_tokens = torch.gather(indices, -1, argidx)
+ last_tokens.clamp_(0, self.vocab_size - 1)
if next_tokens_tensors[head_index] == None:
next_tokens_tensors[head_index] = last_tokens
else:
@@ -421,11 +448,24 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
self.n_predict = config.n_predict
self.vocab_size = config.vocab_size
self.input_hidden_dim = config.input_hidden_dim
- config.inner_dim = [int(i) for i in config.inner_dim.split(".")]
+
+ def _parse_dim(value):
+ """Helper to normalize dimension config into a list of ints."""
+ if isinstance(value, str):
+ return [int(i) for i in value.split(".")]
+ elif isinstance(value, int):
+ return [value]
+ elif isinstance(value, list):
+ return [int(i) for i in value]
+ return value
+
+ config.inner_dim = _parse_dim(config.inner_dim)
self.inner_dim = config.inner_dim
- config.emb_dim = [int(i) for i in config.emb_dim.split(".")]
- self.emb_dim = config.emb_dim
- config.proj_dim = [int(i) for i in config.proj_dim.split(".")]
+
+ config.emb_dim = _parse_dim(config.emb_dim)
+ self.emb_dim = config.emb_dim
+
+ config.proj_dim = _parse_dim(config.proj_dim)
self.proj_dim = config.proj_dim
self.max_speculative_tokens = config.num_lookahead_tokens
@@ -582,13 +622,18 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
scale=1.0,
skip_last_gather=True,
)
- self.sampler = get_sampler()
self.cuda_graph_max_batch_size = 0
self.cuda_graph_mode = False
- self.cuda_graph_max_batch_size = padding_size(
- vllm_config.scheduler_config.max_num_seqs)
+ spec_config = vllm_config.speculative_config
+ disable_by_batch_size = (
+ spec_config.disable_by_batch_size
+ if spec_config is not None and spec_config.disable_by_batch_size is not None
+ else vllm_config.scheduler_config.max_num_seqs
+ )
+ self.cuda_graph_max_batch_size = padding_size(disable_by_batch_size)
+
self.static_cuda_buffers = {
"last_tokens":
torch.empty(self.cuda_graph_max_batch_size, 1, dtype=torch.long),
@@ -622,6 +667,7 @@ def _prepare_cuda_graph_ios(
cell_states=None,
use_lstm=False,
):
+ # TODO: optimize this
self.static_cuda_buffers["last_tokens"][:size] = last_tokens
if cell_states is not None:
self.static_cuda_buffers["cell_states"][:size] = cell_states
@@ -664,32 +710,69 @@ def generate_states(
prev_state = previous_hidden_states
- z = self.forget_emb[actual_i](last_tokens).repeat(1, 1, 4) # b n d
+ z4 = self.forget_emb[actual_i](last_tokens).repeat(1, 1, 4) # b n d
states = self.projs[actual_proj_i](prev_state)
- added_states = torch.add(states,
- z,
- alpha=self.emb_weight / self.state_weight)
- forget_input_output, cell_candidate = added_states.split(
- [self.proj_dim[0] * 3, self.proj_dim[0]], dim=-1)
- forget_gate, input_gate, output_gate = torch.sigmoid(
- forget_input_output).split(
- [self.proj_dim[0], self.proj_dim[0], self.proj_dim[0]],
- dim=-1)
+ if USE_CUSTOM_OP:
+ # Shapes:
+ # prev_state: [B, 1, D_eff] (e.g., 2880 in the first round and 4096 later)
+ # states: [B, 1, 4*D_gate] (e.g., 4*4096)
+ # z4: [B, 1, 4*D_gate]
+ states_4d = states.flatten(0, 1).contiguous() # [B, 4*D_gate]
+ z4_4d = z4.flatten(0, 1).contiguous() # [B, 4*D_gate]
+
+ orig_cell_shape = cell_states.shape # [B, 1, D_gate]
+ pc_d = cell_states.flatten(0, 1).contiguous() # [B, D_gate]
+
+ # Optional precondition checks that mirror the kernel's TORCH_CHECKs:
+ assert states_4d.size(-1) % 4 == 0
+ assert z4_4d.size(-1) == states_4d.size(-1)
+ assert pc_d.size(-1) == states_4d.size(-1) // 4
+
+ w_cell = self.cell_ln[actual_i].weight
+ b_cell = self.cell_ln[actual_i].bias
+ w_state = self.state_ln[actual_i].weight
+ b_state = self.state_ln[actual_i].bias
+
+ alpha = float(self.emb_weight / self.state_weight)
+ eps_cell = float(self.cell_ln[actual_i].eps)
+ eps_state = float(self.state_ln[actual_i].eps)
+ use_fast_gelu = False
+
+ state_d, cell_d = sum_lstm(
+ states_4d, z4_4d, pc_d,
+ w_cell, b_cell, w_state, b_state,
+ alpha, eps_cell, eps_state, use_fast_gelu
+ )
+
+ state = state_d.reshape(orig_cell_shape) # [B, 1, D_gate]
+ cell_states = cell_d.reshape(orig_cell_shape) # [B, 1, D_gate]
+
+ return state, cell_states
+ else:
+ added_states = torch.add(states,
+ z4,
+ alpha=self.emb_weight / self.state_weight)
- cell_candidate = self.activation(
- self.cell_ln[actual_i](cell_candidate)) # b n d
- cell_candidate = cell_candidate * input_gate
+ forget_input_output, cell_candidate = added_states.split(
+ [self.proj_dim[0] * 3, self.proj_dim[0]], dim=-1)
+ forget_gate, input_gate, output_gate = torch.sigmoid(
+ forget_input_output).split(
+ [self.proj_dim[0], self.proj_dim[0], self.proj_dim[0]],
+ dim=-1)
- cell_states = cell_states * forget_gate
- cell_states = cell_states + cell_candidate
+ cell_candidate = self.activation(
+ self.cell_ln[actual_i](cell_candidate)) # b n d
+ cell_candidate = cell_candidate * input_gate
- state_candidate = self.activation(
- self.state_ln[actual_i](cell_states))
- state = state_candidate * output_gate
+ cell_states = cell_states * forget_gate
+ cell_states = cell_states + cell_candidate
- return state, cell_states
+ state_candidate = self.activation(
+ self.state_ln[actual_i](cell_states))
+ state = state_candidate * output_gate
+ return state, cell_states
else:
# Project and predict
z = self.emb[actual_i](last_tokens) # b k d
@@ -712,6 +795,7 @@ def generate_token_ids(
next_tokens_tensors: List[torch.Tensor],
cell_states: torch.Tensor = None,
) -> torch.Tensor:
+ last_tokens.clamp_(0, self.vocab_size - 1)
for head_index in range(num_predict_tokens):
if self.method == "sum_lstm":
states, cell_states = self.generate_states(
@@ -731,6 +815,7 @@ def generate_token_ids(
last_tokens = torch.argmax(logits,
dim=-1).reshape(batch_size, -1)
else:
+ # TODO: fuse topk + all_gather
vals, indices = torch.topk(logits, 1, dim=-1)
indices = indices + self.tp_rank * logits.shape[-1]
@@ -743,6 +828,7 @@ def generate_token_ids(
argidx = torch.argmax(vals, -1).reshape(batch_size, -1)
last_tokens = torch.gather(indices, -1, argidx)
+ last_tokens.clamp_(0, self.vocab_size - 1)
if next_tokens_tensors[head_index] == None:
next_tokens_tensors[head_index] = last_tokens
else:
@@ -817,6 +903,9 @@ def generate_proposals(
if g is None:
device = torch.cuda.current_device()
+ for i in range(num_predict_tokens):
+ self.static_cuda_buffers["next_tokens"][i][:padded_size] = torch.zeros(
+ (padded_size, 1), dtype=torch.long, device=device)
with graph_capture(device=device) as capture_context:
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g, stream=capture_context.stream):
@@ -874,9 +963,15 @@ def maybe_load_weight(self, param, loaded_weight):
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
weights = collections.OrderedDict(weights)
if self.method == "sum_lstm" and self.tie_lstm_embs:
- weights.pop("input_emb.0.weight")
- weights.pop("cell_emb.0.weight")
- weights.pop("output_emb.0.weight")
+ try:
+ weights.pop("input_emb.0.weight")
+ weights.pop("cell_emb.0.weight")
+ weights.pop("output_emb.0.weight")
+ except KeyError:
+ # If the weights are not present, it means they are not tied
+ # and we should not try to pop them.
+ print("No tied LSTM embeddings found, skipping.")
+ pass
for name, param in self.named_parameters():
if "projs." in name:
print(f"REPLACING {name}")
diff --git a/arctic_inference/vllm/spec_dec/fp8.py b/arctic_inference/vllm/spec_dec/fp8.py
index 7366d9127..b2a749694 100644
--- a/arctic_inference/vllm/spec_dec/fp8.py
+++ b/arctic_inference/vllm/spec_dec/fp8.py
@@ -66,9 +66,7 @@ def __init__(self, quant_config: Fp8Config):
# Marlin doesn't support block-wise fp8
self.use_marlin = False
- self.fp8_linear = Fp8LinearOp(
- # Default to using per_token quantization if cutlass is supported
- use_per_token_if_dynamic=cutlass_fp8_supported())
+ self.fp8_linear = Fp8LinearOp(act_quant_static=False)
def create_weights(
self,
diff --git a/arctic_inference/vllm/spec_dec/logits_processor_opt.py b/arctic_inference/vllm/spec_dec/logits_processor_opt.py
index 4516a8656..b90c9ba86 100644
--- a/arctic_inference/vllm/spec_dec/logits_processor_opt.py
+++ b/arctic_inference/vllm/spec_dec/logits_processor_opt.py
@@ -44,8 +44,7 @@ def __init__(self,
self.soft_cap = soft_cap
# Whether to use gather or all-gather to gather the logits.
- self.use_gather = not current_platform.is_tpu(
- ) and not envs.VLLM_USE_V1
+ self.use_gather = False
self.skip_last_gather = skip_last_gather
diff --git a/arctic_inference/vllm/spec_dec/vocab_parallel_embedding.py b/arctic_inference/vllm/spec_dec/vocab_parallel_embedding.py
index 149f40a3d..0b766229d 100644
--- a/arctic_inference/vllm/spec_dec/vocab_parallel_embedding.py
+++ b/arctic_inference/vllm/spec_dec/vocab_parallel_embedding.py
@@ -408,19 +408,8 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
# Copy the data. Select chunk corresponding to current shard.
loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
- if current_platform.is_hpu():
- # FIXME(kzawora): Weight copy with slicing bugs out on Gaudi here,
- # so we're using a workaround. Remove this when fixed in
- # HPU PT bridge.
- padded_weight = torch.cat([
- loaded_weight,
- torch.zeros(param.shape[0] - loaded_weight.shape[0],
- *loaded_weight.shape[1:])
- ])
- param.data.copy_(padded_weight)
- else:
- param[:loaded_weight.shape[0]].data.copy_(loaded_weight)
- param[loaded_weight.shape[0]:].data.fill_(0)
+ param[:loaded_weight.shape[0]].data.copy_(loaded_weight)
+ param[loaded_weight.shape[0]:].data.fill_(0)
def forward(self, input_):
if self.tp_size > 1:
@@ -432,7 +421,12 @@ def forward(self, input_):
self.shard_indices.added_vocab_start_index,
self.shard_indices.added_vocab_end_index)
else:
- masked_input = input_
+ # For tp_size==1 there is no masking, so clamp to the
+ # valid embedding range. Async scheduling + spec decode
+ # can leave -1 sentinel tokens in input_ids that would
+ # otherwise crash F.embedding.
+ masked_input = input_.clamp(
+ 0, self.num_embeddings_per_partition - 1)
# Get the embeddings.
output_parallel = self.quant_method.embedding(self,
masked_input.long())
diff --git a/arctic_inference/vllm/structured_output.py b/arctic_inference/vllm/structured_output.py
new file mode 100644
index 000000000..9409ea784
--- /dev/null
+++ b/arctic_inference/vllm/structured_output.py
@@ -0,0 +1,38 @@
+# Copyright 2025 Snowflake Inc.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from vllm.logger import init_logger
+from vllm.v1.structured_output.backend_xgrammar import XgrammarBackend
+
+from arctic_inference.patching import ArcticPatch
+
+logger = init_logger(__name__)
+
+
+class XgrammarBackendPatch(ArcticPatch[XgrammarBackend]):
+ """Patch for XgrammarBackend to handle additional structured output."""
+
+ _orig_post_init = XgrammarBackend.__post_init__
+
+ def __post_init__(self):
+ self._orig_post_init()
+
+ if self.vllm_config.speculative_config is not None:
+ self.num_speculative_tokens = \
+ max(self.vllm_config.speculative_config.num_speculative_tokens,
+ self.vllm_config.speculative_config.suffix_speculative_tokens)
+
+ logger.info(f"XgrammarBackendPatch: num_speculative_tokens="
+ f"{self.num_speculative_tokens}")
diff --git a/arctic_inference/vllm/swiftkv/llama_swiftkv.py b/arctic_inference/vllm/swiftkv/llama_swiftkv.py
index 93f003579..675dc4aa8 100644
--- a/arctic_inference/vllm/swiftkv/llama_swiftkv.py
+++ b/arctic_inference/vllm/swiftkv/llama_swiftkv.py
@@ -20,10 +20,12 @@
from torch import nn
import vllm.distributed.parallel_state as parallel_state
-from vllm.attention.backends.abstract import AttentionType
+from vllm.v1.attention.backend import AttentionType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
-from vllm.forward_context import ForwardContext, get_forward_context
+from vllm.config.compilation import CUDAGraphMode
+from vllm.forward_context import (BatchDescriptor, ForwardContext,
+ get_forward_context)
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -40,19 +42,20 @@
LlamaMLP)
from vllm.model_executor.models.utils import (AutoWeightsLoader,
maybe_prefix)
-from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.v1.sample.metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
# Add FlashInfer backend detection
try:
from vllm.v1.attention.backends.flashinfer import FlashInferMetadata
FLASHINFER_AVAILABLE = True
-except ImportError:
+except (ImportError, RuntimeError):
FLASHINFER_AVAILABLE = False
FlashInferMetadata = None
import arctic_inference.vllm.model_runner as model_runner
from arctic_inference.common.swiftkv.configs import LlamaSwiftKVConfig
+import arctic_inference.envs as envs
logger = init_logger(__name__)
@@ -75,8 +78,6 @@ def __init__(
hidden_size: int,
num_heads: int,
num_kv_heads: int,
- rope_theta: float = 10000,
- rope_scaling: Optional[dict[str, Any]] = None,
max_position_embeddings: int = 8192,
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
@@ -90,8 +91,6 @@ def __init__(
hidden_size=hidden_size,
num_heads=num_heads,
num_kv_heads=num_kv_heads,
- rope_theta=rope_theta,
- rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
bias=bias,
@@ -147,16 +146,8 @@ def __init__(
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
- rope_theta = getattr(config, "rope_theta", 10000)
- rope_scaling = getattr(config, "rope_scaling", None)
- if rope_scaling is not None and getattr(
- config, "original_max_position_embeddings", None):
- rope_scaling["original_max_position_embeddings"] = (
- config.original_max_position_embeddings)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
- # Support abacusai/Smaug-72B-v0.1 with attention_bias
- # Support internlm/internlm-7b with bias
attention_bias = getattr(config, "attention_bias", False) or getattr(
config, "bias", False)
self.self_attn = LlamaSwiftKVAttention(
@@ -165,8 +156,6 @@ def __init__(
num_heads=config.num_attention_heads,
num_kv_heads=getattr(config, "num_key_value_heads",
config.num_attention_heads),
- rope_theta=rope_theta,
- rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
bias=attention_bias,
@@ -343,10 +332,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
quant_config=self.quant_config,
)
self.layers = torch.nn.ModuleList([
- LlamaDecoderLayer(config=config,
- cache_config=vllm_config.cache_config,
- quant_config=vllm_config.quant_config,
- prefix=f"{prefix}.layers.{idx}")
+ LlamaDecoderLayer(vllm_config=vllm_config,
+ prefix=f"{prefix}.layers.{idx}",
+ config=config,)
for idx in range(config.num_key_value_layers)
])
with model_runner.set_shift_parallel_mode(True):
@@ -367,12 +355,18 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self._init_prefill_runner(vllm_config)
self._init_decode_runner(vllm_config)
- from arctic_inference.py_custom_ops import try_load_torch_library
- self.use_custom_ops = True if try_load_torch_library() else False
+ from arctic_inference.py_custom_ops import (try_load_torch_library,
+ try_load_jit_library)
+
+ self.use_custom_ops = try_load_torch_library() or try_load_jit_library()
+
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
+ def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+ return self.embed_tokens(input_ids)
+
def _init_prefill_runner(self, vllm_config: VllmConfig):
vllm_config.compilation_config = copy.copy(
vllm_config.compilation_config)
@@ -609,7 +603,7 @@ def swiftkv_select(
kv_cache = attn.kv_cache[forward_context.virtual_engine]
if kv_cache.numel():
# different cache layouts
- if isinstance(attn_metadata, FlashInferMetadata):
+ if FLASHINFER_AVAILABLE and isinstance(attn_metadata, FlashInferMetadata):
# FlashInfer: [num_blocks, 2, block_size, num_kv_heads, head_size]
key_caches.append(kv_cache[:, 0])
value_caches.append(kv_cache[:, 1])
@@ -637,7 +631,7 @@ def swiftkv_select(
attn = layer.self_attn.attn
kv_cache = attn.kv_cache[forward_context.virtual_engine]
if kv_cache.numel():
- if isinstance(attn_metadata, FlashInferMetadata):
+ if FLASHINFER_AVAILABLE and isinstance(attn_metadata, FlashInferMetadata):
# FlashInfer: [num_blocks, 2, block_size, num_kv_heads, head_size]
k_cache, v_cache = kv_cache.unbind(1)
else:
@@ -658,7 +652,7 @@ def swiftkv_select(
logits_indices = attn_metadata.swiftkv_logits_indices
num_surviving_tokens = logits_indices.numel()
- if isinstance(attn_metadata, FlashInferMetadata):
+ if FLASHINFER_AVAILABLE and isinstance(attn_metadata, FlashInferMetadata):
# Handle FlashInfer metadata
final_logits_indices = self._fix_flashinfer_metadata(attn_metadata, logits_indices, num_surviving_tokens)
else:
@@ -702,6 +696,28 @@ def forward(
k_states,
v_states))
+ # When swiftkv_select filters tokens (mixed prefill-decode batches),
+ # the decode runner processes fewer tokens than the original batch.
+ # Piecewise CUDA graphs captured for the original batch size cannot
+ # be replayed with modified attention metadata (stale FA3 scheduler
+ # metadata, changed query_start_loc, etc.), so we fall back to eager
+ # compiled execution for the decode runner on these batches.
+ # For decode-only batches all tokens survive, so CUDA graphs are
+ # used normally -- preserving decode throughput.
+ fwd_ctx = get_forward_context()
+ saved_batch_descriptor = fwd_ctx.batch_descriptor
+ saved_cudagraph_mode = fwd_ctx.cudagraph_runtime_mode
+ decode_num_tokens = hidden_states.shape[0]
+ if (saved_batch_descriptor is not None
+ and saved_batch_descriptor.num_tokens != decode_num_tokens):
+ fwd_ctx.batch_descriptor = BatchDescriptor(
+ num_tokens=decode_num_tokens,
+ num_reqs=saved_batch_descriptor.num_reqs,
+ uniform=saved_batch_descriptor.uniform,
+ has_lora=saved_batch_descriptor.has_lora,
+ )
+ fwd_ctx.cudagraph_runtime_mode = CUDAGraphMode.NONE
+
with model_runner.set_shift_parallel_mode(True):
hidden_states = self.decode_runner(
hidden_states,
@@ -711,12 +727,15 @@ def forward(
v_states,
)
+ fwd_ctx.batch_descriptor = saved_batch_descriptor
+ fwd_ctx.cudagraph_runtime_mode = saved_cudagraph_mode
+
attn_metadata = get_attn_metadata_for_swiftkv()
if attn_metadata is not None:
logits_indices = attn_metadata.swiftkv_logits_indices
batch_size = logits_indices.numel()
- if isinstance(attn_metadata, FlashInferMetadata):
+ if FLASHINFER_AVAILABLE and isinstance(attn_metadata, FlashInferMetadata):
inverse_sort_indices = attn_metadata.swiftkv_inverse_sort_indices
orig_hidden_states[logits_indices] = hidden_states[inverse_sort_indices][:batch_size]
else:
@@ -839,6 +858,9 @@ def _init_model(self,
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
+ def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+ return self.model.embed_input_ids(input_ids)
+
def forward(
self,
input_ids: torch.Tensor,
@@ -853,10 +875,8 @@ def forward(
def compute_logits(
self,
hidden_states: torch.Tensor,
- sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
- logits = self.logits_processor(self.lm_head, hidden_states,
- sampling_metadata)
+ logits = self.logits_processor(self.lm_head, hidden_states)
return logits
def load_weights(self, weights: Iterable[tuple[str,
diff --git a/arctic_inference/vllm/ulysses.py b/arctic_inference/vllm/ulysses.py
index 23210a015..7de12536a 100644
--- a/arctic_inference/vllm/ulysses.py
+++ b/arctic_inference/vllm/ulysses.py
@@ -16,44 +16,54 @@
import threading
import weakref
from contextlib import contextmanager
-from concurrent.futures import ThreadPoolExecutor
-from typing import Optional, Any
+from concurrent.futures import Future
+from collections import deque
+from collections.abc import Callable
+from typing import Optional, cast
+import time
import torch
import vllm.distributed.parallel_state as parallel_state
import vllm.envs as envs
from vllm.attention.layer import Attention
-from vllm.config import ModelConfig, ParallelConfig
+from vllm.config import ModelConfig, ParallelConfig, CUDAGraphMode, VllmConfig
from vllm.distributed.device_communicators.shm_broadcast import MessageQueue
from vllm.distributed.parallel_state import (init_model_parallel_group,
get_world_group,
destroy_model_parallel,
destroy_distributed_environment)
-from vllm.executor.multiproc_worker_utils import (
+from vllm.v1.executor.multiproc_executor import (
set_multiprocessing_worker_envs)
-from vllm.utils import get_distributed_init_method, get_open_port
+from vllm.utils.network_utils import get_distributed_init_method, get_open_port, get_loopback_ip
+from vllm.utils.system_utils import get_mp_context
from vllm.v1.executor.abstract import FailureCallback
from vllm.v1.executor.multiproc_executor import (MultiprocExecutor, WorkerProc,
- UnreadyWorkerProcHandle)
-from vllm.platforms import current_platform
-from vllm.utils import resolve_obj_by_qualname
-from vllm.compilation.backends import PiecewiseCompileInterpreter
-from vllm.model_executor.layers.fused_moe import FusedMoE
+ UnreadyWorkerProcHandle,
+ FutureWrapper)
+from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
+from vllm.config.compilation import CompilationConfig
+from vllm.v1.engine.core import EngineCore, EngineCoreOutputs
+from vllm.v1.outputs import ModelRunnerOutput
+
from arctic_inference.patching import ArcticPatch
+# global variable to hack compilation config
+_ulysses_sp_size = 1
def apply_shift_parallel_patches():
- UlyssesModelConfigPatch.apply_patch()
- UlyssesParallelStatePatch.apply_patch()
- UlyssesWorkerProcPatch.apply_patch()
- UlyssesMultiprocExecutorPatch.apply_patch()
- UlyssesAttentionPatch.apply_patch()
- PiecewiseCompileInterpreterPatch.apply_patch()
- UlyssesFusedMoEPatch.apply_patch()
+ UlyssesModelConfig.apply_patch()
+ UlyssesParallelState.apply_patch()
+ UlyssesWorkerProc.apply_patch()
+ UlyssesMultiprocExecutor.apply_patch()
+ UlyssesAttention.apply_patch()
+ UlyssesCudagraphDispatcher.apply_patch()
+ UlyssesCompilationConfig.apply_patch()
+ UlyssesVllmConfig.apply_patch()
+ UlyssesEngineCore.apply_patch()
-class UlyssesModelConfigPatch(ArcticPatch[ModelConfig]):
+class UlyssesModelConfig(ArcticPatch[ModelConfig]):
_orig_get_num_kv_heads = ModelConfig.get_num_kv_heads
_orig_get_num_attention_heads = ModelConfig.get_num_attention_heads
@@ -74,7 +84,8 @@ def get_layers_start_end_indices(
self, parallel_config: "ParallelConfig") -> tuple[int, int]:
from vllm.distributed.utils import get_pp_indices
if (self.hf_text_config.model_type == "deepseek_mtp"
- or self.hf_config.model_type == "mimo_mtp"):
+ or self.hf_config.model_type == "mimo_mtp"
+ or self.hf_config.model_type == "glm4_moe_mtp"):
total_num_hidden_layers = getattr(self.hf_text_config,
"num_nextn_predict_layers", 0)
else:
@@ -90,54 +101,27 @@ def get_layers_start_end_indices(
return start, end
-class UlyssesParallelStatePatch(ArcticPatch[parallel_state]):
+class UlyssesParallelState(ArcticPatch[parallel_state]):
_SP = None
_SP_TP = None
- _SP_AA = None
- _SP_AG = None
- # Rationale for SP_AA and SP_AG groups:
- # When num_kv_heads > SP, the kv heads are distributed and replicated as in TP.
- # To implement the logic, the distributed kv heads are exchanged with a local
- # all-to-all within SP_AA group followed by an local all-gather within SP_AG
- # group. The SP_AA and SP_AG groups partitions the SP group into two orthogonal
- # sub-groups and will not be initialized if max(1, num_kv_heads / TP) < SP.
- # See the figure in PR #126 https://github.com/snowflakedb/ArcticInference/pull/126
def initialize_model_parallel(
tensor_model_parallel_size: int = 1,
pipeline_model_parallel_size: int = 1,
+ prefill_context_model_parallel_size: int = 1,
+ decode_context_model_parallel_size: Optional[int] = 1,
backend: Optional[str] = None,
) -> None:
- """
- Initialize model parallel groups.
-
- Arguments:
- tensor_model_parallel_size: number of GPUs used for tensor model
- parallelism.
- pipeline_model_parallel_size: number of GPUs used for pipeline model
- parallelism.
-
- Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
- use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
- the model pipeline. The present function will
- create 4 tensor model-parallel groups and 2 pipeline model-parallel groups:
- 4 tensor model-parallel groups:
- [g0, g1], [g2, g3], [g4, g5], [g6, g7]
- 2 pipeline model-parallel groups:
- [g0, g2, g4, g6], [g1, g3, g5, g7]
- Note that for efficiency, the caller should make sure adjacent ranks
- are on the same DGX box. For example if we are using 2 DGX-1 boxes
- with a total of 16 GPUs, rank 0 to 7 belong to the first box and
- ranks 8 to 15 belong to the second box.
- """
- from vllm.distributed.parallel_state import _DP, _EP, _PP, _TP
- # Get world size and rank. Ensure some consistencies.
+
+ from vllm.distributed.parallel_state import _DP, _EP, _PP, _TP, _DCP, _PCP
+
assert torch.distributed.is_initialized()
world_size: int = torch.distributed.get_world_size()
rank = torch.distributed.get_rank()
backend = backend or torch.distributed.get_backend(
- get_world_group().device_group)
+ get_world_group().device_group
+ )
data_parallel_size = 1
from vllm.config import get_current_vllm_config
@@ -145,154 +129,180 @@ def initialize_model_parallel(
if config is not None:
data_parallel_size = config.parallel_config.data_parallel_size
- sequence_parallel_size = \
- config.parallel_config.ulysses_sequence_parallel_size
-
- # the layout order is: ExternalDP x DP x PP x SP x TP
- # ExternalDP is the data parallel group that is not part of the model,
- # every dp rank can generate independently (in verl integration).
- # DP is the data parallel group that is part of the model,
- # all the ranks in the same DP group should generate simultaneously,
- # i.e. the `generate` call in the same DP group should be called together,
- # otherwise it will cause deadlock.
- # to get group_ranks for each dimension, transpose that dimension to the
- # last dimension, then reshape to 2D, then unbind the last dimension
- all_ranks = torch.arange(world_size).reshape(
- -1, data_parallel_size, pipeline_model_parallel_size,
- sequence_parallel_size, tensor_model_parallel_size) # noqa
+ sequence_parallel_size = config.parallel_config.ulysses_sequence_parallel_size
+
+ # vLLM types allow None, but group building needs an int
+ if decode_context_model_parallel_size is None:
+ # treat "no DCP" as DCP==TP (common interpretation)
+ decode_context_model_parallel_size = tensor_model_parallel_size
- # Build the tensor model-parallel groups.
- assert _TP is None, ("tensor model parallel group is already initialized")
+ # Layout order (extended from vLLM's): ExternalDP x DP x PP x PCP x SP x TP
+ all_ranks = torch.arange(world_size).reshape(
+ -1,
+ data_parallel_size,
+ pipeline_model_parallel_size,
+ prefill_context_model_parallel_size,
+ sequence_parallel_size,
+ tensor_model_parallel_size,
+ )
+
+ assert _TP is None, "tensor model parallel group is already initialized"
group_ranks = all_ranks.view(-1, tensor_model_parallel_size).unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
TP_group_ranks = group_ranks
- # message queue broadcaster is only used in tensor model parallel group
- _TP = init_model_parallel_group(group_ranks,
- get_world_group().local_rank,
- backend,
- use_message_queue_broadcaster=True,
- group_name="tp")
-
- # Build the pipeline model-parallel groups.
- assert _PP is None, (
- "pipeline model parallel group is already initialized")
- group_ranks = all_ranks.transpose(2, 4).reshape(
- -1, pipeline_model_parallel_size).unbind(0)
+ _TP = init_model_parallel_group(
+ group_ranks,
+ get_world_group().local_rank,
+ backend,
+ use_message_queue_broadcaster=True,
+ group_name="tp",
+ )
+
+ assert _DCP is None, "decode context model parallel group is already initialized"
+ group_ranks = all_ranks.reshape(-1, decode_context_model_parallel_size).unbind(0)
+ group_ranks = [x.tolist() for x in group_ranks]
+ DCP_group_ranks = group_ranks
+ _DCP = init_model_parallel_group(
+ group_ranks,
+ get_world_group().local_rank,
+ backend,
+ use_message_queue_broadcaster=True,
+ group_name="dcp",
+ )
+
+ assert _PCP is None, "prefill context parallel group is already initialized"
+ group_ranks = (
+ all_ranks.transpose(3, 5)
+ .reshape(-1, prefill_context_model_parallel_size)
+ .unbind(0)
+ )
+ group_ranks = [x.tolist() for x in group_ranks]
+ PCP_group_ranks = group_ranks
+ _PCP = init_model_parallel_group(
+ group_ranks,
+ get_world_group().local_rank,
+ backend,
+ group_name="pcp",
+ )
+
+ assert _PP is None, "pipeline model parallel group is already initialized"
+ group_ranks = (
+ all_ranks.transpose(2, 5)
+ .reshape(-1, pipeline_model_parallel_size)
+ .unbind(0)
+ )
group_ranks = [x.tolist() for x in group_ranks]
PP_group_ranks = group_ranks
- _PP = init_model_parallel_group(group_ranks,
- get_world_group().local_rank,
- backend,
- group_name="pp")
-
- assert _DP is None, ("data parallel group is already initialized")
- group_ranks = all_ranks.transpose(1,
- 4).reshape(-1,
- data_parallel_size).unbind(0)
+ _PP = init_model_parallel_group(
+ group_ranks,
+ get_world_group().local_rank,
+ backend,
+ group_name="pp",
+ )
+
+ assert _DP is None, "data parallel group is already initialized"
+ group_ranks = (
+ all_ranks.transpose(1, 5)
+ .reshape(-1, data_parallel_size)
+ .unbind(0)
+ )
group_ranks = [x.tolist() for x in group_ranks]
DP_group_ranks = group_ranks
- _DP = init_model_parallel_group(group_ranks,
- get_world_group().local_rank,
- backend,
- group_name="dp")
-
- assert _EP is None, ("expert parallel group is already initialized")
- group_ranks = all_ranks.transpose(1, 3).reshape(
- -1, data_parallel_size * tensor_model_parallel_size).unbind(0)
+ _DP = init_model_parallel_group(
+ group_ranks,
+ get_world_group().local_rank,
+ backend,
+ group_name="dp",
+ )
+
+ assert _EP is None, "expert parallel group is already initialized"
+ group_ranks = (
+ all_ranks.permute(0, 4, 2, 1, 3, 5) # ExternalDP, SP, PP, DP, PCP, TP
+ .reshape(-1, data_parallel_size * prefill_context_model_parallel_size * tensor_model_parallel_size)
+ .unbind(0)
+ )
group_ranks = [x.tolist() for x in group_ranks]
EP_group_ranks = group_ranks
- _EP = init_model_parallel_group(group_ranks,
- get_world_group().local_rank,
- backend,
- group_name="ep")
-
- # Build the sequence parallel groups.
- assert parallel_state._SP is None, (
- "sequence parallel group is already initialized")
- group_ranks = all_ranks.transpose(3, 4).reshape(
- -1, sequence_parallel_size).unbind(0)
+ _EP = init_model_parallel_group(
+ group_ranks,
+ get_world_group().local_rank,
+ backend,
+ group_name="ep",
+ )
+
+ assert parallel_state._SP is None, "sequence parallel group is already initialized"
+ group_ranks = (
+ all_ranks.transpose(4, 5)
+ .reshape(-1, sequence_parallel_size)
+ .unbind(0)
+ )
group_ranks = [x.tolist() for x in group_ranks]
SP_group_ranks = group_ranks
- _SP = init_model_parallel_group(group_ranks,
- get_world_group().local_rank,
- backend,
- group_name="sp")
-
- # Build full-TP groups for ShiftParallel
- shift_parallel_size = (tensor_model_parallel_size *
- sequence_parallel_size)
- assert parallel_state._SP_TP is None, (
- "full-TP group is already initialized")
- # transpose(3, 4) for obtaining the correct attn head order
- group_ranks = all_ranks.transpose(3, 4).reshape(
- -1, shift_parallel_size).unbind(0)
+ _SP = init_model_parallel_group(
+ group_ranks,
+ get_world_group().local_rank,
+ backend,
+ group_name="sp",
+ )
+
+ shift_parallel_size = tensor_model_parallel_size * sequence_parallel_size
+ assert parallel_state._SP_TP is None, "full-TP group is already initialized"
+ group_ranks = (
+ all_ranks.transpose(4, 5) # keep same head-order trick as your old transpose(3,4)
+ .reshape(-1, shift_parallel_size)
+ .unbind(0)
+ )
group_ranks = [x.tolist() for x in group_ranks]
SP_TP_group_ranks = group_ranks
- _SP_TP = init_model_parallel_group(group_ranks,
- get_world_group().local_rank,
- backend,
- group_name="sp_tp")
+ _SP_TP = init_model_parallel_group(
+ group_ranks,
+ get_world_group().local_rank,
+ backend,
+ group_name="sp_tp",
+ )
parallel_state.logger.info(
"rank %s in world size %s is assigned as DP rank %s, PP rank %s, "
- "TP rank %s, EP rank %s, SP rank %s, SP_TP rank %s", rank,
- world_size, _DP.rank_in_group, _PP.rank_in_group,
- _TP.rank_in_group, _EP.rank_in_group, _SP.rank_in_group,
- _SP_TP.rank_in_group)
+ "PCP rank %s, TP rank %s, DCP rank %s, EP rank %s, SP rank %s, SP_TP rank %s",
+ rank,
+ world_size,
+ _DP.rank_in_group,
+ _PP.rank_in_group,
+ _PCP.rank_in_group,
+ _TP.rank_in_group,
+ _DCP.rank_in_group,
+ _EP.rank_in_group,
+ _SP.rank_in_group,
+ _SP_TP.rank_in_group,
+ )
parallel_state._TP = _TP
+ parallel_state._DCP = _DCP
+ parallel_state._PCP = _PCP
parallel_state._PP = _PP
+ parallel_state._DP = _DP
+ parallel_state._EP = _EP
parallel_state._SP = _SP
parallel_state._SP_TP = _SP_TP
- parallel_state._DP = _DP
-
- # check if SP requires kv replication
- num_kv_heads = config.model_config._orig_get_num_kv_heads(config.parallel_config)
- if num_kv_heads < sequence_parallel_size:
-
- # divide SP group into two orthogonal sub-groups:
- sp_aa_size = num_kv_heads
- sp_ag_size = sequence_parallel_size // num_kv_heads
- all_ranks_ = torch.arange(world_size).reshape(
- -1, data_parallel_size, pipeline_model_parallel_size,
- sp_aa_size, sp_ag_size, tensor_model_parallel_size)
-
- group_ranks = all_ranks_.transpose(3, 5).reshape(
- -1, sp_aa_size).unbind(0)
- group_ranks = [x.tolist() for x in group_ranks]
- SP_AA_group_ranks = group_ranks
- # SP_AA group is used for all-to-all communication of kv heads
- _SP_AA = init_model_parallel_group(group_ranks,
- get_world_group().local_rank,
- backend,
- group_name="sp_aa")
-
- group_ranks = all_ranks_.transpose(4, 5).reshape(
- -1, sp_ag_size).unbind(0)
- group_ranks = [x.tolist() for x in group_ranks]
- SP_AG_group_ranks = group_ranks
- # SP_AG group is used for all-gather communication of kv heads
- _SP_AG = init_model_parallel_group(group_ranks,
- get_world_group().local_rank,
- backend,
- group_name="sp_ag")
-
- parallel_state._SP_AA = _SP_AA
- parallel_state._SP_AG = _SP_AG
if get_world_group().local_rank == 0:
parallel_state.logger.info(
- f"UlyssesParallelStatePatch initialized:\n"
- f" PP {_PP.world_size} ranks {PP_group_ranks}\n"
- f" TP {_TP.world_size} ranks {TP_group_ranks}\n"
- f" SP {_SP.world_size} ranks {SP_group_ranks}\n"
- f" DP {_DP.world_size} ranks {DP_group_ranks}\n"
- f" EP {_EP.world_size} ranks {EP_group_ranks}\n"
- f" SP_TP {_SP_TP.world_size} ranks {SP_TP_group_ranks}")
- if num_kv_heads < sequence_parallel_size:
- parallel_state.logger.info(
- f" SP_AA {parallel_state._SP_AA.world_size} ranks {SP_AA_group_ranks}\n"
- f" SP_AG {parallel_state._SP_AG.world_size} ranks {SP_AG_group_ranks}\n")
+ "UlyssesParallelState initialized:\n"
+ f" PP {_PP.world_size} ranks {PP_group_ranks}\n"
+ f" TP {_TP.world_size} ranks {TP_group_ranks}\n"
+ f" DCP {_DCP.world_size} ranks {DCP_group_ranks}\n"
+ f" PCP {_PCP.world_size} ranks {PCP_group_ranks}\n"
+ f" SP {_SP.world_size} ranks {SP_group_ranks}\n"
+ f" DP {_DP.world_size} ranks {DP_group_ranks}\n"
+ f" EP {_EP.world_size} ranks {EP_group_ranks}\n"
+ f" SP_TP {_SP_TP.world_size} ranks {SP_TP_group_ranks}"
+ )
+
+ num_kv_heads = config.model_config._orig_get_num_kv_heads(config.parallel_config)
+ if get_world_group().local_rank == 0 and num_kv_heads < sequence_parallel_size:
+ parallel_state.logger.info(
+ f"KV cache is replicated by factor {sequence_parallel_size // num_kv_heads}"
+ )
@contextmanager
def graph_capture(device: torch.device):
@@ -316,22 +326,16 @@ def graph_capture(device: torch.device):
yield context
-class UlyssesWorkerProcPatch(ArcticPatch[WorkerProc]):
+class UlyssesWorkerProc(ArcticPatch[WorkerProc]):
def destroy_model_parallel(self):
- from vllm.distributed.parallel_state import _SP, _SP_TP, _SP_AA, _SP_AG
+ from vllm.distributed.parallel_state import _SP, _SP_TP
if _SP:
_SP.destroy()
_SP = None
if _SP_TP:
_SP_TP.destroy()
_SP_TP = None
- if _SP_AA:
- _SP_AA.destroy()
- _SP_AA = None
- if _SP_AG:
- _SP_AG.destroy()
- _SP_AG = None
def shutdown(self):
self.rpc_broadcast_mq = None
@@ -342,7 +346,7 @@ def shutdown(self):
destroy_distributed_environment()
-class UlyssesMultiprocExecutorPatch(ArcticPatch[MultiprocExecutor]):
+class UlyssesMultiprocExecutor(ArcticPatch[MultiprocExecutor]):
def _init_executor(self) -> None:
# Call self.shutdown at exit to clean up
@@ -350,81 +354,125 @@ def _init_executor(self) -> None:
self._finalizer = weakref.finalize(self, self.shutdown)
self.is_failed = False
self.shutdown_event = threading.Event()
- self.failure_callback: Optional[FailureCallback] = None
- self.io_thread_pool: Optional[ThreadPoolExecutor] = None
+ self.failure_callback: FailureCallback | None = None
self.world_size = self.parallel_config.world_size
- tensor_parallel_size = self.parallel_config.tensor_parallel_size
- pp_parallel_size = self.parallel_config.pipeline_parallel_size
- sp_parallel_size = self.parallel_config.ulysses_sequence_parallel_size
- assert (self.world_size ==
- tensor_parallel_size * pp_parallel_size * sp_parallel_size), (
+ assert self.world_size % self.parallel_config.nnodes_within_dp == 0, (
+ f"global world_size ({self.parallel_config.world_size}) must be "
+ f"divisible by nnodes_within_dp "
+ f"({self.parallel_config.nnodes_within_dp}). "
+ )
+ self.local_world_size = self.parallel_config.local_world_size
+ tp_size = self.parallel_config.tensor_parallel_size
+ pp_size = self.parallel_config.pipeline_parallel_size
+ pcp_size = self.parallel_config.prefill_context_parallel_size
+ sp_size = self.parallel_config.ulysses_sequence_parallel_size
+
+ assert self.world_size == tp_size * pp_size * pcp_size * sp_size, (
f"world_size ({self.world_size}) must be equal to the "
- f"tensor_parallel_size ({tensor_parallel_size}) x pipeline"
- f"_parallel_size ({pp_parallel_size}) x ulysses_sequence_parallel"
- f"_size ({sp_parallel_size}).")
+ f"tensor_parallel_size ({tp_size}) x pipeline"
+ f"_parallel_size ({pp_size}) x prefill_context"
+ f"_parallel_size ({pcp_size}) x ulysses_sequence_parallel"
+ f"_size ({sp_size})."
+ )
- # Set multiprocessing envs that are common to V0 and V1
- set_multiprocessing_worker_envs(self.parallel_config)
+ # Set multiprocessing envs
+ set_multiprocessing_worker_envs()
- # Multiprocessing-based executor does not support multi-node setting.
- # Since it only works for single node, we can use the loopback address
- # 127.0.0.1 for communication.
+ # use the loopback address get_loopback_ip() for communication.
distributed_init_method = get_distributed_init_method(
- "127.0.0.1", get_open_port())
-
+ get_loopback_ip(), get_open_port()
+ )
+ self.rpc_broadcast_mq: MessageQueue | None = None
+ scheduler_output_handle: Handle | None = None
# Initialize worker and set up message queues for SchedulerOutputs
# and ModelRunnerOutputs
- max_chunk_bytes = envs.VLLM_MQ_MAX_CHUNK_BYTES_MB * 1024 * 1024
- self.rpc_broadcast_mq = MessageQueue(self.world_size,
- self.world_size,
- max_chunk_bytes=max_chunk_bytes)
- scheduler_output_handle = self.rpc_broadcast_mq.export_handle()
-
+ if self.parallel_config.node_rank_within_dp == 0:
+ # For leader node within each dp rank,
+ # each dp will have its own leader multiproc executor.
+ max_chunk_bytes = envs.VLLM_MQ_MAX_CHUNK_BYTES_MB * 1024 * 1024
+ self.rpc_broadcast_mq = MessageQueue(
+ self.world_size,
+ self.local_world_size,
+ max_chunk_bytes=max_chunk_bytes,
+ connect_ip=self.parallel_config.master_addr,
+ )
+ scheduler_output_handle = self.rpc_broadcast_mq.export_handle()
+
# Create workers
+ # FIX: Removed duplicate initialization and local import that caused UnboundLocalError
+ context = get_mp_context()
+ shared_worker_lock = context.Lock()
unready_workers: list[UnreadyWorkerProcHandle] = []
+
success = False
try:
- for rank in range(self.world_size):
+ global_start_rank = (
+ self.local_world_size * self.parallel_config.node_rank_within_dp
+ )
+ for local_rank in range(self.local_world_size):
+ global_rank = global_start_rank + local_rank
unready_workers.append(
WorkerProc.make_worker_process(
vllm_config=self.vllm_config,
- local_rank=rank,
- rank=rank,
+ local_rank=local_rank,
+ rank=global_rank,
distributed_init_method=distributed_init_method,
input_shm_handle=scheduler_output_handle,
- ))
+ shared_worker_lock=shared_worker_lock,
+ )
+ )
# Workers must be created before wait_for_ready to avoid
# deadlock, since worker.init_device() does a device sync.
+
+ # Wait for all local workers to be ready.
self.workers = WorkerProc.wait_for_ready(unready_workers)
+ # Start background thread to monitor worker health if not in headless mode.
+ if self.monitor_workers:
+ self.start_worker_monitor()
+
+ self.response_mqs = []
+ # Only leader node have remote response mqs
+ if self.parallel_config.node_rank_within_dp == 0:
+ for rank in range(self.world_size):
+ if rank < self.local_world_size:
+ local_message_queue = self.workers[rank].worker_response_mq
+ assert local_message_queue is not None
+ self.response_mqs.append(local_message_queue)
+ else:
+ remote_message_queue = self.workers[0].peer_worker_response_mqs[
+ rank
+ ]
+ assert remote_message_queue is not None
+ self.response_mqs.append(remote_message_queue)
+
# Ensure message queues are ready. Will deadlock if re-ordered
# Must be kept consistent with the WorkerProc.
- self.rpc_broadcast_mq.wait_until_ready()
- for w in self.workers:
- w.worker_response_mq.wait_until_ready()
- self.start_worker_monitor()
+ # Wait for all input mqs to be ready.
+ if self.rpc_broadcast_mq is not None:
+ self.rpc_broadcast_mq.wait_until_ready()
+ # Wait for all remote response mqs to be ready.
+ for response_mq in self.response_mqs:
+ response_mq.wait_until_ready()
success = True
finally:
if not success:
# Clean up the worker procs if there was a failure.
- self._ensure_worker_termination(
- [w.proc for w in unready_workers])
+ # Close death_writers first to signal workers to exit
+ for uw in unready_workers:
+ if uw.death_writer is not None:
+ uw.death_writer.close()
+ self._ensure_worker_termination([uw.proc for uw in unready_workers])
- # For pipeline parallel, we use a thread pool for asynchronous
- # execute_model.
- if self.max_concurrent_batches > 1:
- # Note: must use only 1 IO thread to keep dequeue sequence
- # from the response queue
- self.io_thread_pool = ThreadPoolExecutor(
- max_workers=1, thread_name_prefix="mp_exec_io")
+ self.futures_queue = deque[tuple[FutureWrapper, Callable]]()
self.output_rank = self._get_output_rank()
-class UlyssesAttentionPatch(ArcticPatch[Attention]):
+class UlyssesAttention(ArcticPatch[Attention]):
_orig_init = Attention.__init__
_orig_forward = Attention.forward
@@ -438,17 +486,8 @@ def __init__(self, num_heads, *args, **kwargs):
num_kv_heads = kwargs["num_kv_heads"]
self.is_kv_replicated = True if num_kv_heads < self.sp_size else False
if self.is_kv_replicated:
+ self.replication_factor = self.sp_size // num_kv_heads
num_kv_heads = 1
- assert parallel_state._SP_AA is not None and parallel_state._SP_AG is not None, (
- "UlyssesAttentionPatch requires SP_AA and SP_AG groups to be initialized.")
- self.sp_aa_device_group = parallel_state._SP_AA.device_group
- self.sp_ag_device_group = parallel_state._SP_AG.device_group
- self.sp_aa_size = parallel_state._SP_AA.world_size
- self.sp_ag_size = parallel_state._SP_AG.world_size
- # this reorders the all-gathered sequence
- self.order = [j * self.sp_aa_size + i
- for i in range(self.sp_aa_size)
- for j in range(self.sp_ag_size)]
else:
num_kv_heads //= self.sp_size
kwargs["num_kv_heads"] = num_kv_heads
@@ -459,52 +498,29 @@ def forward(self, query, key, value, **kwargs):
if self.sp_size == 1 or is_shift_parallel_mode():
return self._orig_forward(query, key, value, **kwargs)
+ # prepare
+ q = query.view(-1, self.sp_size, self.num_heads * self.head_size)
if self.is_kv_replicated:
- # Ulysses all-to-all 1/2 (query)
- q = query.view(-1,
- self.sp_size, self.num_heads * self.head_size).transpose(
- 0, 1).reshape(-1,
- self.num_heads * self.head_size)
- q_ = torch.empty_like(q)
- torch.distributed.all_to_all_single(q_, q, group=self.sp_device_group)
- # Ulysses pack (key, value)
- kv = torch.cat((key.view(-1, self.sp_aa_size, self.num_kv_heads * self.head_size),
- value.view(-1, self.sp_aa_size, self.num_kv_heads * self.head_size)),
- dim=-1).transpose(0, 1).reshape(
- -1, 2 * self.num_kv_heads * self.head_size)
- # Ulysses all-to-all (key, value)
- kv_part = torch.empty_like(kv)
- torch.distributed.all_to_all_single(kv_part, kv, group=self.sp_aa_device_group)
- # Ulysses all-gather (key, value)
- kv_ = torch.empty(q_.shape[0],
- 2 * self.num_kv_heads * self.head_size,
- dtype=query.dtype,
- device=query.device)
- torch.distributed.all_gather_into_tensor(kv_,
- kv_part,
- group=self.sp_ag_device_group)
- # reorder
- kv_chunk = kv_.chunk(self.sp_size)
- kv_ordered = torch.cat([kv_chunk[i] for i in self.order])
- # unpack (key, value)
- k_, v_ = kv_ordered.split([self.num_kv_heads * self.head_size] * 2, dim=-1)
+ k = key.view(-1, self.sp_size // self.replication_factor, self.head_size).repeat_interleave(self.replication_factor, dim=1)
+ v = value.view(-1, self.sp_size // self.replication_factor, self.head_size).repeat_interleave(self.replication_factor, dim=1)
else:
- # pack
- qkv = (torch.cat(
- (query.view(-1, self.sp_size, self.num_heads * self.head_size),
- key.view(-1, self.sp_size, self.num_kv_heads * self.head_size),
- value.view(-1, self.sp_size, self.num_kv_heads * self.head_size)),
- dim=-1)
- .transpose(0, 1)
- .reshape(-1, (self.num_heads + 2 * self.num_kv_heads) * self.head_size))
- # Ulysses all-to-all 1/2
- qkv_ = torch.empty_like(qkv)
- torch.distributed.all_to_all_single(qkv_, qkv, group=self.sp_device_group)
- # unpack
- q_, k_, v_ = qkv_.split([
- self.num_heads * self.head_size, self.num_kv_heads *
- self.head_size, self.num_kv_heads * self.head_size
- ], dim=-1)
+ k = key.view(-1, self.sp_size, self.num_kv_heads * self.head_size)
+ v = value.view(-1, self.sp_size, self.num_kv_heads * self.head_size)
+
+ # pack
+ qkv = torch.cat((q, k, v), dim=-1).transpose(0, 1).reshape(
+ -1, (self.num_heads + 2 * self.num_kv_heads) * self.head_size)
+
+ # Ulysses all-to-all 1/2
+ qkv_ = torch.empty_like(qkv)
+ torch.distributed.all_to_all_single(qkv_, qkv, group=self.sp_device_group)
+
+ # unpack
+ q_, k_, v_ = qkv_.split([
+ self.num_heads * self.head_size,
+ self.num_kv_heads * self.head_size,
+ self.num_kv_heads * self.head_size
+ ], dim=-1)
# original attention
c_ = self._orig_forward(q_, k_, v_, **kwargs)
@@ -519,81 +535,288 @@ def forward(self, query, key, value, **kwargs):
return output
-class PiecewiseCompileInterpreterPatch(ArcticPatch[PiecewiseCompileInterpreter]):
-
- # find the symbolic shape of the subgraph
- def find_symbolic_shape(self, args: tuple[torch.fx.node.Argument,
- ...]) -> torch.SymInt:
- symbols = set()
- for x in args:
- if isinstance(x, torch._subclasses.fake_tensor.FakeTensor):
- for dim in x.shape:
- if isinstance(dim, torch.SymInt):
- symbols.update(dim.node.expr.free_symbols)
- assert len(symbols) == 1, (
- f"Expected exactly one symbolic shape, but found {len(symbols)}: {symbols}")
- return list(symbols)[0]
-
- def call_module(self, target: torch.fx.node.Target,
- args: tuple[torch.fx.node.Argument,
- ...], kwargs: dict[str, Any]) -> Any:
- assert isinstance(target, str)
- # [Arctic Inference]
- # Since monkeypatching inherits the original class
- # through ArcticPatch class, we lose the access to the original class'
- # super() function. Instead of using super(), we directly invoke call_module
- # from the super class torch.fx.Interpreter of PiecewiseCompileInterpreter.
- # see - v0.9.0.1/compilation/backends.py#L241
- output = torch.fx.Interpreter.call_module(self, target, args, kwargs)
-
- if target in self.compile_submod_names:
- index = self.compile_submod_names.index(target)
- submod = self.fetch_attr(target)
- # [Arctic Inference]
- # Compiler may create subgraphs with certain symbolic
- # integer values that violates vllm's assumption here:
- # - v0.9.0.1/compilation/base_piecewise_backend.py#L64
- # The index of the significant symbol determines the runtime shape here:
- # - v0.9.0.1/compilation/cuda_piecewise_backend.py#L112
- # The fix is relaxing vllm's original assumption that there is only a
- # single symbolic that determines the shape.We then find the matching
- # symbol indices.
- sym_shape = self.find_symbolic_shape(args)
- sym_shape_indices = []
- for i, x in enumerate(args):
- if isinstance(x, torch.SymInt):
- if sym_shape == x:
- sym_shape_indices.append(i)
-
- global compilation_start_time
- compiled_graph_for_general_shape = self.vllm_backend.\
- compiler_manager.compile(
- submod,
- args,
- self.compilation_config.inductor_compile_config,
- self.compilation_config,
- graph_index=index,
- num_graphs=len(self.compile_submod_names),
- runtime_shape=None)
-
- piecewise_backend = resolve_obj_by_qualname(
- current_platform.get_piecewise_backend_cls())
- self.module.__dict__[target] = piecewise_backend(
- submod, self.vllm_config, self.graph_pool, index,
- len(self.compile_submod_names), sym_shape_indices,
- compiled_graph_for_general_shape, self.vllm_backend)
-
- from vllm.compilation.counter import compilation_counter
- compilation_counter.num_piecewise_capturable_graphs_seen += 1
+class UlyssesCudagraphDispatcher(ArcticPatch[CudagraphDispatcher]):
+
+ _orig_initialize_cudagraph_keys = CudagraphDispatcher.initialize_cudagraph_keys
+
+ def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode,
+ uniform_decode_query_len: int):
+ self._orig_initialize_cudagraph_keys(cudagraph_mode, uniform_decode_query_len)
+
+ # sp_group = getattr(parallel_state, "_SP", None)
+ # sp_size = sp_group.world_size if sp_group is not None else 1
+ # if sp_size <= 1:
+ # return
- return output
+ # if self.vllm_config.lora_config:
+ # if self.compilation_config.cudagraph_specialize_lora:
+ # lora_cases = [True, False]
+ # else:
+ # lora_cases = [True]
+ # else:
+ # lora_cases = [False]
+
+ # if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE:
+ # for bs, has_lora in product(
+ # self.compilation_config.cudagraph_capture_sizes, lora_cases
+ # ):
+ # bd = self._create_padded_batch_descriptor(
+ # num_tokens=bs, # * sp_size,
+ # uniform_decode=False,
+ # has_lora=has_lora,
+ # ).relax_for_mixed_batch_cudagraphs()
+
+ # self.add_cudagraph_key(cudagraph_mode.mixed_mode(), bd)
+
+ # if (cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
+ # and cudagraph_mode.separate_routine()):
+ # max_num_tokens = (
+ # uniform_decode_query_len
+ # * self.vllm_config.scheduler_config.max_num_seqs
+ # )
+ # cudagraph_capture_sizes_for_decode = [
+ # x for x in self.compilation_config.cudagraph_capture_sizes
+ # if uniform_decode_query_len <= x <= max_num_tokens
+ # ]
+ # for bs, has_lora in product(cudagraph_capture_sizes_for_decode, lora_cases):
+ # bd = self._create_padded_batch_descriptor(
+ # num_tokens=bs, # * sp_size,
+ # uniform_decode=True,
+ # has_lora=has_lora,
+ # )
+ # self.add_cudagraph_key(CUDAGraphMode.FULL, bd)
+
+
+class UlyssesCompilationConfig(ArcticPatch[CompilationConfig]):
+
+ _orig_post_init_cudagraph_sizes = CompilationConfig.post_init_cudagraph_sizes
+
+ def post_init_cudagraph_sizes(self) -> None:
+
+# # print(f"Before post_init_cudagraph_sizes: max_cudagraph_capture_size={self.max_cudagraph_capture_size}, cudagraph_capture_sizes={self.cudagraph_capture_sizes}")
+
+# # Access the module-level variable set during engine config creation
+# sp_size = _ulysses_sp_size
+
+# # scale sizes by Ulysses sequence parallel size
+# self.max_cudagraph_capture_size *= sp_size
+# self.cudagraph_capture_sizes = [size * sp_size for size in self.cudagraph_capture_sizes]
+
+# # print(f"After scaling for SP size {sp_size}: max_cudagraph_capture_size={self.max_cudagraph_capture_size}, cudagraph_capture_sizes={self.cudagraph_capture_sizes}")
+
+ self._orig_post_init_cudagraph_sizes()
+
+# # revert back to original shapes
+# self.max_cudagraph_capture_size //= sp_size
+# self.cudagraph_capture_sizes = [size // sp_size for size in self.cudagraph_capture_sizes]
+
+# # print(f"self.bs_to_padded_graph_size {self.bs_to_padded_graph_size}")
+
+# # import traceback
+# # traceback.print_stack()
+
+class UlyssesVllmConfig(ArcticPatch[VllmConfig]):
+
+ _orig_set_cudagraph_sizes = VllmConfig._set_cudagraph_sizes
+
+ @staticmethod
+ def _generate_capture_sizes(max_size: int) -> list[int]:
+ sizes = [i for i in [1, 2, 4] if i <= max_size]
+ if max_size >= 8:
+ sizes += list(range(8, min(max_size + 1, 256), 8))
+ if max_size >= 256:
+ sizes += list(range(256, min(max_size + 1, 512), 16))
+ if max_size >= 512:
+ sizes += list(range(512, max_size + 1, 32))
+ return sizes
+
+ @staticmethod
+ def _build_bs_to_padded(capture_sizes: list[int],
+ max_capture_size: int) -> list[int]:
+ table = [0] * (max_capture_size + 1)
+ for end, start in zip(
+ capture_sizes + [max_capture_size + 1],
+ [0] + capture_sizes,
+ ):
+ for bs in range(start, end):
+ table[bs] = start if bs == start else end
+ return table
+
+ def _set_cudagraph_sizes(self):
+ sp_size = _ulysses_sp_size
+
+ max_cudagraph_capture_size = self.compilation_config.max_cudagraph_capture_size
+ cudagraph_capture_sizes = self.compilation_config.cudagraph_capture_sizes
+
+ if cudagraph_capture_sizes is None:
+ if max_cudagraph_capture_size is None:
+ max_cudagraph_capture_size = 512
+ # Canonical (unscaled) sizes: [1, 2, 4, 8, ..., 512]
+ canonical_sizes = self._generate_capture_sizes(
+ max_cudagraph_capture_size)
+
+ # Base model (Ulysses): scale by sp_size
+ self.compilation_config.cudagraph_capture_sizes = [
+ s * sp_size for s in canonical_sizes
+ ]
+ self.compilation_config.max_cudagraph_capture_size = (
+ max_cudagraph_capture_size * sp_size
+ )
+
+ # Shift model: scale by 1 (use canonical sizes as-is)
+ shift_sizes = list(canonical_sizes)
+ shift_max = max_cudagraph_capture_size
+ else:
+ shift_sizes = list(cudagraph_capture_sizes)
+ shift_max = max(shift_sizes) if shift_sizes else 0
+ self._shift_cudagraph_capture_sizes = shift_sizes
+ self._shift_max_cudagraph_capture_size = shift_max
+ self._shift_bs_to_padded_graph_size = self._build_bs_to_padded(
+ shift_sizes, shift_max) if shift_sizes else []
-class UlyssesFusedMoEPatch(ArcticPatch[FusedMoE]):
+ print(
+ f"UlyssesVllmConfig: base max_cudagraph_capture_size="
+ f"{self.compilation_config.max_cudagraph_capture_size}, "
+ f"base sizes={self.compilation_config.cudagraph_capture_sizes}, "
+ f"shift max={shift_max}, shift sizes={shift_sizes}"
+ )
- def forward(self, hidden_states: torch.Tensor,
- router_logits: torch.Tensor):
- # directly call forward_impl to bypass custom opt
- # custom opt prevents using the shift model
- # we will expand this function to fuse SP with EP
- return self.forward_impl(hidden_states, router_logits)
+ self._orig_set_cudagraph_sizes()
+
+ def pad_for_cudagraph(self, batch_size: int) -> int:
+ from .model_runner import is_shift_parallel_mode
+ if is_shift_parallel_mode() and self._shift_bs_to_padded_graph_size:
+ return self._shift_bs_to_padded_graph_size[batch_size]
+ return self.compilation_config.bs_to_padded_graph_size[batch_size]
+
+
+class UlyssesEngineCore(ArcticPatch[EngineCore]):
+
+ iteration = 0
+
+ def step_with_batch_queue(
+ self,
+ ) -> tuple[dict[int, EngineCoreOutputs] | None, bool]:
+ """Schedule and execute batches with the batch queue.
+ Note that if nothing to output in this step, None is returned.
+
+ The execution flow is as follows:
+ 1. Try to schedule a new batch if the batch queue is not full.
+ If a new batch is scheduled, directly return an empty engine core
+ output. In other words, fulfilling the batch queue has a higher priority
+ than getting model outputs.
+ 2. If there is no new scheduled batch, meaning that the batch queue
+ is full or no other requests can be scheduled, we block until the first
+ batch in the job queue is finished.
+ 3. Update the scheduler from the output.
+ """
+ batch_queue = self.batch_queue
+ assert batch_queue is not None
+
+ # Try to schedule a new batch if the batch queue is not full, but
+ # the scheduler may return an empty batch if all requests are scheduled.
+ # Note that this is not blocking.
+ assert len(batch_queue) < self.batch_queue_size
+
+ step_start_time = time.monotonic()
+
+ model_executed = False
+ deferred_scheduler_output = None
+ if self.scheduler.has_requests():
+ scheduler_output = self.scheduler.schedule()
+ exec_future = self.model_executor.execute_model(
+ scheduler_output, non_block=True
+ )
+ if not self.is_ec_producer:
+ model_executed = scheduler_output.total_num_scheduled_tokens > 0
+
+ if self.is_pooling_model or not model_executed:
+ # No sampling required (no requests scheduled).
+ future = cast(Future[ModelRunnerOutput], exec_future)
+ else:
+ if not scheduler_output.pending_structured_output_tokens:
+ # We aren't waiting for any tokens, get any grammar output
+ # and sample immediately.
+ grammar_output = self.scheduler.get_grammar_bitmask(
+ scheduler_output
+ )
+ future = self.model_executor.sample_tokens(
+ grammar_output, non_block=True
+ )
+ else:
+ # We need to defer sampling until we have processed the model output
+ # from the prior step.
+ deferred_scheduler_output = scheduler_output
+
+ if not deferred_scheduler_output:
+ # Add this step's future to the queue.
+ batch_queue.appendleft((future, scheduler_output, exec_future))
+ if (
+ model_executed
+ and len(batch_queue) < self.batch_queue_size
+ and not batch_queue[-1][0].done()
+ ):
+ # Don't block on next worker response unless the queue is full
+ # or there are no more requests to schedule.
+ return None, True
+
+ elif not batch_queue:
+ # Queue is empty. We should not reach here since this method should
+ # only be called when the scheduler contains requests or the queue
+ # is non-empty.
+ return None, False
+
+ # Block until the next result is available.
+ future, scheduler_output, exec_model_fut = batch_queue.pop()
+ with (
+ self.log_error_detail(scheduler_output),
+ self.log_iteration_details(scheduler_output),
+ ):
+ model_output = future.result()
+ if model_output is None:
+ # None from sample_tokens() implies that the original execute_model()
+ # call failed - raise that exception.
+ exec_model_fut.result()
+ raise RuntimeError("unexpected error")
+
+ # Before processing the model output, process any aborts that happened
+ # during the model execution.
+ self._process_aborts_queue()
+ engine_core_outputs = self.scheduler.update_from_output(
+ scheduler_output, model_output
+ )
+
+ # NOTE(nick): We can either handle the deferred tasks here or save
+ # in a field and do it immediately once step_with_batch_queue is
+ # re-called. The latter slightly favors TTFT over TPOT/throughput.
+ if deferred_scheduler_output:
+ # If we are doing speculative decoding with structured output,
+ # we need to get the draft token ids from the prior step before
+ # we can compute the grammar bitmask for the deferred request.
+ if self.use_spec_decode:
+ draft_token_ids = self.model_executor.take_draft_token_ids()
+ assert draft_token_ids is not None
+ # Update the draft token ids in the scheduler output to
+ # filter out the invalid spec tokens, which will be padded
+ # with -1 and skipped by the grammar bitmask computation.
+ self.scheduler.update_draft_token_ids_in_output(
+ draft_token_ids, deferred_scheduler_output
+ )
+ # We now have the tokens needed to compute the bitmask for the
+ # deferred request. Get the bitmask and call sample tokens.
+ grammar_output = self.scheduler.get_grammar_bitmask(
+ deferred_scheduler_output
+ )
+ future = self.model_executor.sample_tokens(grammar_output, non_block=True)
+ batch_queue.appendleft((future, deferred_scheduler_output, exec_future))
+
+ total_time_ms = (time.monotonic() - step_start_time) * 1000
+
+ running, waiting = self.scheduler.get_request_counts()
+ scheduled_tokens = scheduler_output.total_num_scheduled_tokens
+ concurrency = len(scheduler_output.num_scheduled_tokens.keys())
+ # print(f"iteration {self.iteration}, running: {running}, waiting: {waiting}, scheduled tokens: {scheduled_tokens}, concurrency: {concurrency}, total_time_ms: {total_time_ms:.2f}")
+ self.iteration += 1
+
+ return engine_core_outputs, model_executed
\ No newline at end of file
diff --git a/benchmark/reproducibility/README.md b/benchmark/reproducibility/README.md
new file mode 100644
index 000000000..eaa2bcc80
--- /dev/null
+++ b/benchmark/reproducibility/README.md
@@ -0,0 +1,147 @@
+
+## Reproducibility
+
+This page is on reproducibility of the [Shift Parallelism](https://arxiv.org/pdf/2509.16495) paper. Please see the Artifact Appendix.
+
+## Instructions
+
+### Step 1: Making vLLM work
+
+1. Create a clean environment
+
+- If python 3.10 is already installed, create a clean virtual environment
+```console
+python3.10 -m venv myvenv
+source myvenv/bin/activate
+```
+
+- If not, use conda to create an environment with python 3.10
+```console
+conda create -n myenv python=3.10
+conda activate myenv
+```
+
+2. Install vLLM
+```console
+pip install vllm==v0.10.1
+```
+
+3. Download the models
+```console
+huggingface-cli download RedHatAI/Llama-3.3-70B-Instruct-FP8-dynamic --local-dir RedHatAI/Llama-3.3-70B-Instruct-FP8-dynamic
+huggingface-cli download Qwen/Qwen3-32B-FP8 --local-dir Qwen/Qwen3-32B-FP8
+```
+
+### Step 2: Install ArcticInference
+Checkout to the right commit for v0.10.1.
+```console
+git clone https://github.com/snowflakedb/ArcticInference.git
+cd ArcticInference
+git checkout d096fdf
+pip install .
+```
+
+Once installed, Arctic Inference automatically patches vLLM to use Arctic Inference with Shift Parallelism, and users can continue to use their familiar vLLM APIs and CLI.
+
+### Step 3: Vibe test
+`vibe_test.py`
+```python
+import vllm
+from vllm import LLM, SamplingParams
+
+vllm.plugins.load_general_plugins()
+
+llm = LLM(
+ model="RedHatAI/Llama-3.3-70B-Instruct-FP8-dynamic",
+ ulysses_sequence_parallel_size=8,
+ enable_shift_parallel=True,
+ shift_parallel_threshold=8,
+)
+
+conversation = [
+ {
+ "role": "user",
+ "content": "Write an essay about the importance of higher education.",
+ },
+]
+
+sampling_params = SamplingParams(temperature=0.0, max_tokens=800)
+
+outputs = llm.chat(conversation, sampling_params=sampling_params)
+
+print(outputs[0].outputs[0].text)
+```
+
+```console
+ARCTIC_INFERENCE_ENABLED=1 VLLM_DISABLE_COMPILE_CACHE=1 python vibe_test.py
+```
+We turn off compile cache for now to prevent complications.
+
+### Step 4: Patch vLLM bench
+
+This patch is necessary to run traces for running traces.
+
+Please replace your `vllm/bencmarks/serve.py` with `patches/serve.py` and `vllm/benchmarks/datasets.py` with `patches/datasets.py`. The vLLM path can be found by `pip show vllm`.
+
+### Step 5: Run Traces
+
+The traces are available: https://doi.org/10.5281/zenodo.18240909
+
+```console
+wget https://zenodo.org/records/18240909/files/AzureLLMInferenceTrace_code_15mins.jsonl
+wget https://zenodo.org/records/18240909/files/conversation_trace_15mins.jsonl
+```
+
+There are two traces and three parallelisms in full reproduction. For each case, the server and the client is run in separate terminals. For example,
+
+1. Azure LLM code (Figure 9)
+
+server:
+```console
+ARCTIC_INFERENCE_ENABLED=1 VLLM_DISABLE_COMPILE_CACHE=1 vllm serve RedHatAI/Llama-3.3-70B-Instruct-FP8-dynamic \
+ --disable-log-requests \
+ --no-enable-prefix-caching \
+ --ulysses-sequence-parallel-size 8 \
+ --enable-shift-parallel \
+ --max-num-batched-tokens 131072
+```
+
+client:
+```console
+vllm bench serve --model RedHatAI/Llama-3.3-70B-Instruct-FP8-dynamic --trace-dataset-path AzureLLMInferenceTrace_code_15mins.jsonl --ignore-eos --trace-output-path code_output.csv
+```
+
+2. Mooncake conversation (Figure 10)
+
+For the Mooncake dataset, we first increase the context length with YaRN: https://huggingface.co/Qwen/Qwen3-32B-FP8#processing-long-texts
+
+server:
+```console
+ARCTIC_INFERENCE_ENABLED=1 VLLM_DISABLE_COMPILE_CACHE=1 vllm serve Qwen/Qwen3-32B-FP8 \
+ --disable-log-requests \
+ --no-enable-prefix-caching \
+ --ulysses-sequence-parallel-size 8 \
+ --enable-shift-parallel \
+ --max-num-batched-tokens 131072
+```
+
+client:
+```console
+vllm bench serve --model Qwen/Qwen3-32B-FP8 --trace-dataset-path conversation_trace_15mins.jsonl --ignore-eos --trace-output-path conversation_output.csv
+```
+
+Ready-to-use `server.sh` and `client.sh` are included.
+
+The key results involves two figures (Figure 9—10). The breakdown of reproduction times are given below:
+
+Figure 9: DP (15 mins), TP (15 mins), SP (15 mins), Shift Parallel (15 mins)
+Figure 10: DP (~2.5 hrs), TP (~1 hr), SP (15 mins), Shift Parallel (15 mins)
+
+### Step 6: Plotting
+
+Install plotting library.
+```console
+pip install matplotlib
+```
+
+run `plot.py` with appropriate trace output path.
diff --git a/benchmark/reproducibility/client.sh b/benchmark/reproducibility/client.sh
new file mode 100644
index 000000000..7e54661b4
--- /dev/null
+++ b/benchmark/reproducibility/client.sh
@@ -0,0 +1,6 @@
+
+model="RedHatAI/Llama-3.3-70B-Instruct-FP8-dynamic"
+vllm bench serve --model $model --trace-dataset-path AzureLLMInferenceTrace_code_15mins.jsonl --ignore-eos --trace-output-path code_output.csv
+
+# model="Qwen/Qwen3-32B-FP8"
+# vllm bench serve --model $model --trace-dataset-path conversation_trace_15mins.jsonl --ignore-eos --trace-output-path conversation_output.csv
diff --git a/benchmark/reproducibility/patches_v0.10.1/datasets.py b/benchmark/reproducibility/patches_v0.10.1/datasets.py
new file mode 100644
index 000000000..f5812f064
--- /dev/null
+++ b/benchmark/reproducibility/patches_v0.10.1/datasets.py
@@ -0,0 +1,1675 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+This module defines a framework for sampling benchmark requests from various
+datasets. Each dataset subclass of BenchmarkDataset must implement sample
+generation. Supported dataset types include:
+ - ShareGPT
+ - Random (synthetic)
+ - Sonnet
+ - BurstGPT
+ - HuggingFace
+ - VisionArena
+"""
+import base64
+import io
+import json
+import logging
+import random
+from abc import ABC, abstractmethod
+from collections.abc import Mapping
+from dataclasses import dataclass
+from functools import cache
+from io import BytesIO
+from typing import Any, Callable, Optional, Union
+
+import numpy as np
+from PIL import Image
+from transformers import PreTrainedTokenizerBase
+from typing_extensions import deprecated
+
+from vllm.lora.request import LoRARequest
+from vllm.lora.utils import get_adapter_absolute_path
+from vllm.multimodal import MultiModalDataDict
+from vllm.multimodal.image import convert_image_mode
+from vllm.transformers_utils.tokenizer import AnyTokenizer, get_lora_tokenizer
+from vllm.utils import PlaceholderModule
+
+try:
+ from datasets import load_dataset
+except ImportError:
+ datasets = PlaceholderModule("datasets")
+ load_dataset = datasets.placeholder_attr("load_dataset")
+
+try:
+ import pandas as pd
+except ImportError:
+ pd = PlaceholderModule("pandas")
+
+try:
+ import librosa
+except ImportError:
+ librosa = PlaceholderModule("librosa")
+
+try:
+ from vllm.utils import FlexibleArgumentParser
+except ImportError:
+ from argparse import ArgumentParser as FlexibleArgumentParser
+
+logger = logging.getLogger(__name__)
+
+# -----------------------------------------------------------------------------
+# Data Classes
+# -----------------------------------------------------------------------------
+
+
+@dataclass
+class SampleRequest:
+ """
+ Represents a single inference request for benchmarking.
+ """
+
+ prompt: Union[str, Any]
+ prompt_len: int
+ expected_output_len: int
+ multi_modal_data: Optional[
+ Union[MultiModalDataDict, dict, list[dict]]
+ ] = None
+ lora_request: Optional[LoRARequest] = None
+ trace_timestamp: int = 0
+
+
+# -----------------------------------------------------------------------------
+# Benchmark Dataset Base Class
+# -----------------------------------------------------------------------------
+
+
+class BenchmarkDataset(ABC):
+ DEFAULT_SEED = 0
+ IS_MULTIMODAL = False
+
+ def __init__(
+ self,
+ dataset_path: Optional[str] = None,
+ random_seed: int = DEFAULT_SEED,
+ ) -> None:
+ """
+ Initialize the BenchmarkDataset with an optional dataset path and random
+ seed.
+
+ Args:
+ dataset_path (Optional[str]): Path to the dataset. If None, it
+ indicates that a default or random dataset might be used.
+ random_seed (int): Seed value for reproducible shuffling or
+ sampling. Defaults to DEFAULT_SEED.
+ """
+ self.dataset_path = dataset_path
+ # Set the random seed, ensuring that a None value is replaced with the
+ # default seed.
+ self.random_seed = (random_seed
+ if random_seed is not None else self.DEFAULT_SEED)
+ self.data = None
+
+ def apply_multimodal_chat_transformation(
+ self,
+ prompt: str,
+ mm_content: Optional[MultiModalDataDict] = None) -> list[dict]:
+ """
+ Transform a prompt and optional multimodal content into a chat format.
+ This method is used for chat models that expect a specific conversation
+ format.
+ """
+ content = [{"text": prompt, "type": "text"}]
+ if mm_content is not None:
+ content.append(mm_content)
+ return [{"role": "user", "content": content}]
+
+ def load_data(self) -> None:
+ """
+ Load data from the dataset path into self.data.
+
+ This method must be overridden by subclasses since the method to load
+ data will vary depending on the dataset format and source.
+
+ Raises:
+ NotImplementedError: If a subclass does not implement this method.
+ """
+ # TODO (jenniferzhao): add support for downloading data
+ raise NotImplementedError(
+ "load_data must be implemented in subclasses.")
+
+ def get_random_lora_request(
+ self,
+ tokenizer: PreTrainedTokenizerBase,
+ max_loras: Optional[int] = None,
+ lora_path: Optional[str] = None,
+ ) -> tuple[Optional[LoRARequest], AnyTokenizer]:
+ """
+ Optionally select a random LoRA request and return its associated
+ tokenizer.
+
+ This method is used when LoRA parameters are provided. It randomly
+ selects a LoRA based on max_loras and retrieves a cached tokenizer for
+ that LoRA if available. Otherwise, it returns the base tokenizer.
+
+ Args:
+ tokenizer (PreTrainedTokenizerBase): The base tokenizer to use if no
+ LoRA is selected.
+ max_loras (Optional[int]): The maximum number of LoRAs available.
+ If `None`, LoRA is not used.
+ lora_path (Optional[str]): Path to the LoRA parameters on disk.
+ If `None`, LoRA is not used.
+
+ Returns:
+ A tuple with the following elements:
+ - A new [LoRARequest][] (or `None` if not applicable).
+ - The tokenizer associated with the LoRA request
+ (or the base tokenizer).
+ """
+ if max_loras is None or lora_path is None:
+ return None, tokenizer
+
+ # Generate a random LoRA ID in the range [1, max_loras].
+ lora_id = random.randint(1, max_loras)
+ lora_request = LoRARequest(
+ lora_name=str(lora_id),
+ lora_int_id=lora_id,
+ lora_path=lora_path_on_disk(lora_path),
+ )
+ if lora_id not in lora_tokenizer_cache:
+ lora_tokenizer_cache[lora_id] = get_lora_tokenizer(lora_request)
+ # Return lora_request and the cached tokenizer if available; otherwise,
+ # return the base tokenizer
+ return lora_request, lora_tokenizer_cache[lora_id] or tokenizer
+
+ @abstractmethod
+ def sample(self, tokenizer: PreTrainedTokenizerBase,
+ num_requests: int) -> list[SampleRequest]:
+ """
+ Abstract method to generate sample requests from the dataset.
+
+ Subclasses must override this method to implement dataset-specific logic
+ for generating a list of SampleRequest objects.
+
+ Args:
+ tokenizer (PreTrainedTokenizerBase): The tokenizer to be used
+ for processing the dataset's text.
+ num_requests (int): The number of sample requests to generate.
+
+ Returns:
+ list[SampleRequest]: A list of sample requests generated from the
+ dataset.
+ """
+ raise NotImplementedError("sample must be implemented in subclasses.")
+
+ def maybe_oversample_requests(self, requests: list[SampleRequest],
+ num_requests: int) -> None:
+ """
+ Oversamples the list of requests if its size is less than the desired
+ number.
+
+ Args:
+ requests (List[SampleRequest]): The current list of sampled
+ requests.
+ num_requests (int): The target number of requests.
+ """
+ if len(requests) < num_requests:
+ random.seed(self.random_seed)
+ additional = random.choices(requests,
+ k=num_requests - len(requests))
+ requests.extend(additional)
+ logger.info("Oversampled requests to reach %d total samples.",
+ num_requests)
+
+
+# -----------------------------------------------------------------------------
+# Utility Functions and Global Caches
+# -----------------------------------------------------------------------------
+
+
+def is_valid_sequence(
+ prompt_len: int,
+ output_len: int,
+ min_len: int = 4,
+ max_prompt_len: int = 1024,
+ max_total_len: int = 2048,
+ skip_min_output_len_check: bool = False,
+) -> bool:
+ """
+ Validate a sequence based on prompt and output lengths.
+
+ Default pruning criteria are copied from the original `sample_hf_requests`
+ and `sample_sharegpt_requests` functions in benchmark_serving.py, as well as
+ from `sample_requests` in benchmark_throughput.py.
+ """
+ # Check for invalid conditions
+ prompt_too_short = prompt_len < min_len
+ output_too_short = (not skip_min_output_len_check) and (output_len
+ < min_len)
+ prompt_too_long = prompt_len > max_prompt_len
+ combined_too_long = (prompt_len + output_len) > max_total_len
+
+ # Return True if none of the invalid conditions are met
+ return not (prompt_too_short or output_too_short or prompt_too_long
+ or combined_too_long)
+
+
+@cache
+def lora_path_on_disk(lora_path: str) -> str:
+ return get_adapter_absolute_path(lora_path)
+
+
+# Global cache for LoRA tokenizers.
+lora_tokenizer_cache: dict[int, AnyTokenizer] = {}
+
+
+def process_image(image: Any) -> Mapping[str, Any]:
+ """
+ Process a single image input and return a multimedia content dictionary.
+
+ Supports three input types:
+
+ 1. Dictionary with raw image bytes: - Expects a dict with a 'bytes' key
+ containing raw image data. - Loads the bytes as a PIL.Image.Image.
+
+ 2. PIL.Image.Image input: - Converts the image to RGB. - Saves the image as
+ a JPEG in memory. - Encodes the JPEG data as a base64 string. - Returns
+ a dictionary with the image as a base64 data URL.
+
+ 3. String input: - Treats the string as a URL or local file path. -
+ Prepends "file://" if the string doesn't start with "http://" or
+ "file://". - Returns a dictionary with the image URL.
+
+ Raises:
+ ValueError: If the input is not a supported type.
+ """
+ if isinstance(image, dict) and 'bytes' in image:
+ image = Image.open(BytesIO(image['bytes']))
+ if isinstance(image, Image.Image):
+ image = convert_image_mode(image, "RGB")
+ with io.BytesIO() as image_data:
+ image.save(image_data, format="JPEG")
+ image_base64 = base64.b64encode(
+ image_data.getvalue()).decode("utf-8")
+ return {
+ "type": "image_url",
+ "image_url": {
+ "url": f"data:image/jpeg;base64,{image_base64}"
+ },
+ }
+
+ if isinstance(image, str):
+ image_url = (image if image.startswith(
+ ("http://", "file://")) else f"file://{image}")
+ return {"type": "image_url", "image_url": {"url": image_url}}
+
+ raise ValueError(f"Invalid image input {image}. Must be a PIL.Image.Image"
+ " or str or dictionary with raw image bytes.")
+
+
+# -----------------------------------------------------------------------------
+# Random Dataset Implementation (Synthetic Data)
+# -----------------------------------------------------------------------------
+
+
+class RandomDataset(BenchmarkDataset):
+ # Default values copied from benchmark_serving.py for the random dataset.
+ DEFAULT_PREFIX_LEN = 0
+ DEFAULT_RANGE_RATIO = 0.0
+ DEFAULT_INPUT_LEN = 1024
+ DEFAULT_OUTPUT_LEN = 128
+
+ def __init__(
+ self,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ random.seed(self.random_seed)
+ np.random.seed(self.random_seed)
+
+ def sample(
+ self,
+ tokenizer: PreTrainedTokenizerBase,
+ num_requests: int,
+ prefix_len: int = DEFAULT_PREFIX_LEN,
+ range_ratio: float = DEFAULT_RANGE_RATIO,
+ input_len: int = DEFAULT_INPUT_LEN,
+ output_len: int = DEFAULT_OUTPUT_LEN,
+ trace_dataset_path = None,
+ **kwargs,
+ ) -> list[SampleRequest]:
+ # Enforce range_ratio < 1
+ assert range_ratio < 1.0, (
+ "random_range_ratio must be < 1.0 to ensure a valid sampling range"
+ )
+
+ vocab_size = tokenizer.vocab_size
+ num_special_tokens = tokenizer.num_special_tokens_to_add()
+ real_input_len = input_len - num_special_tokens
+
+ prefix_token_ids = (np.random.randint(
+ 0, vocab_size, size=prefix_len).tolist() if prefix_len > 0 else [])
+
+ print(f"RandomDataset sample {trace_dataset_path}")
+
+ events = []
+ with open(trace_dataset_path, "r") as f:
+ for line in f:
+ if not line.strip():
+ continue
+ obj = json.loads(line)
+ timestamp = obj["timestamp"]
+ input_length = obj["input_length"]
+ output_length = obj["output_length"]
+ print(f"read trace timestamp {timestamp} input_length {input_length} output_length {output_length}")
+ events.append((timestamp, input_length, output_length))
+ # Ensure chronological order
+ events.sort(key=lambda x: x[0])
+
+ num_requests = len(events)
+ timestamps = [i[0] for i in events]
+ input_lens = [i[1] for i in events]
+ output_lens = [i[2] for i in events]
+ offsets = np.random.randint(0, vocab_size, size=num_requests)
+
+ requests = []
+ for i in range(num_requests):
+ inner_seq = ((offsets[i] + i + np.arange(input_lens[i])) %
+ vocab_size).tolist()
+ token_sequence = prefix_token_ids + inner_seq
+ prompt = tokenizer.decode(token_sequence)
+ # After decoding the prompt we have to encode and decode it again.
+ # This is done because in some cases N consecutive tokens
+ # give a string tokenized into != N number of tokens.
+ # For example for GPT2Tokenizer:
+ # [6880, 6881] -> ['Ä calls', 'here'] ->
+ # [1650, 939, 486] -> ['Ä call', 'sh', 'ere']
+ # To avoid uncontrolled change of the prompt length,
+ # the encoded sequence is truncated before being decode again.
+ total_input_len = prefix_len + int(input_lens[i])
+ re_encoded_sequence = tokenizer.encode(
+ prompt, add_special_tokens=False)[:total_input_len]
+ prompt = tokenizer.decode(re_encoded_sequence)
+ total_input_len = len(re_encoded_sequence)
+ requests.append(
+ SampleRequest(
+ prompt=prompt,
+ prompt_len=total_input_len,
+ expected_output_len=int(output_lens[i]),
+ trace_timestamp=timestamps[i]
+ ))
+ return requests
+
+
+# -----------------------------------------------------------------------------
+# ShareGPT Dataset Implementation
+# -----------------------------------------------------------------------------
+
+
+class ShareGPTDataset(BenchmarkDataset):
+ """
+ Implements the ShareGPT dataset. Loads data from a JSON file and generates
+ sample requests based on conversation turns.
+ """
+
+ def __init__(self, **kwargs) -> None:
+ super().__init__(**kwargs)
+ self.load_data()
+
+ def load_data(self) -> None:
+ if self.dataset_path is None:
+ raise ValueError("dataset_path must be provided for loading data.")
+
+ with open(self.dataset_path, encoding="utf-8") as f:
+ self.data = json.load(f)
+ # Filter entries with at least two conversation turns.
+ self.data = [
+ entry for entry in self.data
+ if "conversations" in entry and len(entry["conversations"]) >= 2
+ ]
+ random.seed(self.random_seed)
+ random.shuffle(self.data)
+
+ def sample(
+ self,
+ tokenizer: PreTrainedTokenizerBase,
+ num_requests: int,
+ lora_path: Optional[str] = None,
+ max_loras: Optional[int] = None,
+ output_len: Optional[int] = None,
+ enable_multimodal_chat: bool = False,
+ **kwargs,
+ ) -> list:
+ samples: list = []
+ for entry in self.data:
+ if len(samples) >= num_requests:
+ break
+ prompt, completion = (
+ entry["conversations"][0]["value"],
+ entry["conversations"][1]["value"],
+ )
+
+ lora_request, tokenizer = self.get_random_lora_request(
+ tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path)
+ prompt_ids = tokenizer(prompt).input_ids
+ completion_ids = tokenizer(completion).input_ids
+ prompt_len = len(prompt_ids)
+ new_output_len = (len(completion_ids)
+ if output_len is None else output_len)
+ if not is_valid_sequence(prompt_len,
+ new_output_len,
+ skip_min_output_len_check=output_len
+ is not None):
+ continue
+ # TODO: Also support ShareGPT4Video.
+ if image_path := entry.get("image"):
+ mm_content = process_image(image_path)
+ else:
+ mm_content = None
+ if enable_multimodal_chat:
+ prompt = self.apply_multimodal_chat_transformation(
+ prompt, mm_content)
+ samples.append(
+ SampleRequest(
+ prompt=prompt,
+ prompt_len=prompt_len,
+ expected_output_len=new_output_len,
+ lora_request=lora_request,
+ multi_modal_data=mm_content,
+ ))
+ self.maybe_oversample_requests(samples, num_requests)
+ return samples
+
+
+def add_dataset_parser(parser: FlexibleArgumentParser):
+ parser.add_argument("--seed", type=int, default=0)
+ parser.add_argument(
+ "--num-prompts",
+ type=int,
+ default=1000,
+ help="Number of prompts to process.",
+ )
+ parser.add_argument(
+ "--dataset-name",
+ type=str,
+ default="random",
+ choices=[
+ "sharegpt", "burstgpt", "sonnet", "random", "hf", "custom",
+ "prefix_repetition"
+ ],
+ help="Name of the dataset to benchmark on.",
+ )
+ parser.add_argument(
+ "--no-stream",
+ action="store_true",
+ help="Do not load the dataset in streaming mode.",
+ )
+ parser.add_argument(
+ "--dataset-path",
+ type=str,
+ default=None,
+ help="Path to the sharegpt/sonnet dataset. "
+ "Or the huggingface dataset ID if using HF dataset.",
+ )
+
+ # group for dataset specific arguments
+ custom_group = parser.add_argument_group("custom dataset options")
+ custom_group.add_argument(
+ "--custom-output-len",
+ type=int,
+ default=256,
+ help=
+ "Number of output tokens per request, used only for custom dataset.",
+ )
+ custom_group.add_argument(
+ "--custom-skip-chat-template",
+ action="store_true",
+ help=
+ "Skip applying chat template to prompt, used only for custom dataset.",
+ )
+
+ sonnet_group = parser.add_argument_group("sonnet dataset options")
+ sonnet_group.add_argument(
+ "--sonnet-input-len",
+ type=int,
+ default=550,
+ help=
+ "Number of input tokens per request, used only for sonnet dataset.",
+ )
+ sonnet_group.add_argument(
+ "--sonnet-output-len",
+ type=int,
+ default=150,
+ help=
+ "Number of output tokens per request, used only for sonnet dataset.",
+ )
+ sonnet_group.add_argument(
+ "--sonnet-prefix-len",
+ type=int,
+ default=200,
+ help=
+ "Number of prefix tokens per request, used only for sonnet dataset.",
+ )
+
+ sharegpt_group = parser.add_argument_group("sharegpt dataset options")
+ sharegpt_group.add_argument(
+ "--sharegpt-output-len",
+ type=int,
+ default=None,
+ help="Output length for each request. Overrides the output length "
+ "from the ShareGPT dataset.",
+ )
+
+ random_group = parser.add_argument_group("random dataset options")
+ random_group.add_argument(
+ "--random-input-len",
+ type=int,
+ default=1024,
+ help=
+ "Number of input tokens per request, used only for random sampling.",
+ )
+ random_group.add_argument(
+ "--random-output-len",
+ type=int,
+ default=128,
+ help=
+ "Number of output tokens per request, used only for random sampling.",
+ )
+ random_group.add_argument(
+ "--random-range-ratio",
+ type=float,
+ default=0.0,
+ help="Range ratio for sampling input/output length, "
+ "used only for random sampling. Must be in the range [0, 1) to define "
+ "a symmetric sampling range"
+ "[length * (1 - range_ratio), length * (1 + range_ratio)].",
+ )
+ random_group.add_argument(
+ "--random-prefix-len",
+ type=int,
+ default=0,
+ help=("Number of fixed prefix tokens before the random context "
+ "in a request. "
+ "The total input length is the sum of `random-prefix-len` and "
+ "a random "
+ "context length sampled from [input_len * (1 - range_ratio), "
+ "input_len * (1 + range_ratio)]."),
+ )
+
+ hf_group = parser.add_argument_group("hf dataset options")
+ hf_group.add_argument("--hf-subset",
+ type=str,
+ default=None,
+ help="Subset of the HF dataset.")
+ hf_group.add_argument("--hf-split",
+ type=str,
+ default=None,
+ help="Split of the HF dataset.")
+ hf_group.add_argument(
+ "--hf-output-len",
+ type=int,
+ default=None,
+ help="Output length for each request. Overrides the output lengths "
+ "from the sampled HF dataset.",
+ )
+
+ prefix_repetition_group = parser.add_argument_group(
+ "prefix repetition dataset options")
+ prefix_repetition_group.add_argument(
+ "--prefix-repetition-prefix-len",
+ type=int,
+ default=256,
+ help="Number of prefix tokens per request, used only for prefix "
+ "repetition dataset.",
+ )
+ prefix_repetition_group.add_argument(
+ "--prefix-repetition-suffix-len",
+ type=int,
+ default=256,
+ help="Number of suffix tokens per request, used only for prefix "
+ "repetition dataset. Total input length is prefix_len + suffix_len.",
+ )
+ prefix_repetition_group.add_argument(
+ "--prefix-repetition-num-prefixes",
+ type=int,
+ default=10,
+ help="Number of prefixes to generate, used only for prefix repetition "
+ "dataset. Prompts per prefix is num_requests // num_prefixes.",
+ )
+ prefix_repetition_group.add_argument(
+ "--prefix-repetition-output-len",
+ type=int,
+ default=128,
+ help="Number of output tokens per request, used only for prefix "
+ "repetition dataset.",
+ )
+
+
+def get_samples(args, tokenizer) -> list[SampleRequest]:
+ if args.dataset_name == "custom":
+ dataset = CustomDataset(dataset_path=args.dataset_path)
+ input_requests = dataset.sample(
+ num_requests=args.num_prompts,
+ tokenizer=tokenizer,
+ output_len=args.custom_output_len,
+ skip_chat_template=args.custom_skip_chat_template,
+ )
+
+ elif args.dataset_name == "sonnet":
+ dataset = SonnetDataset(dataset_path=args.dataset_path)
+ # For the "sonnet" dataset, formatting depends on the backend.
+ if args.endpoint_type == "openai-chat":
+ input_requests = dataset.sample(
+ num_requests=args.num_prompts,
+ input_len=args.sonnet_input_len,
+ output_len=args.sonnet_output_len,
+ prefix_len=args.sonnet_prefix_len,
+ tokenizer=tokenizer,
+ return_prompt_formatted=False,
+ )
+ else:
+ assert tokenizer.chat_template or tokenizer.default_chat_template, (
+ "Tokenizer/model must have chat template for sonnet dataset.")
+ input_requests = dataset.sample(
+ num_requests=args.num_prompts,
+ input_len=args.sonnet_input_len,
+ output_len=args.sonnet_output_len,
+ prefix_len=args.sonnet_prefix_len,
+ tokenizer=tokenizer,
+ return_prompt_formatted=True,
+ )
+
+ elif args.dataset_name == "hf":
+ # all following datasets are implemented from the
+ # HuggingFaceDataset base class
+ if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS:
+ dataset_class = VisionArenaDataset
+ args.hf_split = "train"
+ args.hf_subset = None
+ elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS:
+ dataset_class = InstructCoderDataset
+ args.hf_split = "train"
+ elif args.dataset_path in MTBenchDataset.SUPPORTED_DATASET_PATHS:
+ dataset_class = MTBenchDataset
+ args.hf_split = "train"
+ elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS:
+ dataset_class = ConversationDataset
+ elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS:
+ dataset_class = AIMODataset
+ args.hf_split = "train"
+ elif args.dataset_path in NextEditPredictionDataset.SUPPORTED_DATASET_PATHS: # noqa: E501
+ dataset_class = NextEditPredictionDataset
+ args.hf_split = "train"
+ elif args.dataset_path in ASRDataset.SUPPORTED_DATASET_PATHS:
+ dataset_class = ASRDataset
+ args.hf_split = "train"
+ elif args.dataset_path in MLPerfDataset.SUPPORTED_DATASET_PATHS:
+ dataset_class = MLPerfDataset
+ args.hf_split = "train"
+ else:
+ supported_datasets = set([
+ dataset_name for cls in HuggingFaceDataset.__subclasses__()
+ for dataset_name in cls.SUPPORTED_DATASET_PATHS
+ ])
+ raise ValueError(
+ f"Unsupported dataset path: {args.dataset_path}. "
+ "Huggingface dataset only supports dataset_path"
+ f" from one of following: {supported_datasets}. "
+ "Please consider contributing if you would "
+ "like to add support for additional dataset formats.")
+
+ if dataset_class.IS_MULTIMODAL and args.endpoint_type not in [
+ "openai-chat",
+ "openai-audio",
+ ]:
+ # multi-modal benchmark is only available on OpenAI Chat backend.
+ raise ValueError(
+ "Multi-modal content is only supported on 'openai-chat' and "
+ "'openai-audio' backend.")
+ input_requests = dataset_class(
+ dataset_path=args.dataset_path,
+ dataset_subset=args.hf_subset,
+ dataset_split=args.hf_split,
+ random_seed=args.seed,
+ no_stream=args.no_stream,
+ ).sample(
+ num_requests=args.num_prompts,
+ tokenizer=tokenizer,
+ output_len=args.hf_output_len,
+ )
+
+ else:
+ # For datasets that follow a similar structure, use a mapping.
+ dataset_mapping = {
+ "sharegpt":
+ lambda: ShareGPTDataset(random_seed=args.seed,
+ dataset_path=args.dataset_path).sample(
+ tokenizer=tokenizer,
+ num_requests=args.num_prompts,
+ output_len=args.sharegpt_output_len,
+ ),
+ "burstgpt":
+ lambda: BurstGPTDataset(random_seed=args.seed,
+ dataset_path=args.dataset_path).
+ sample(tokenizer=tokenizer, num_requests=args.num_prompts),
+ "random":
+ lambda: RandomDataset(random_seed=args.seed,
+ dataset_path=args.dataset_path).sample(
+ tokenizer=tokenizer,
+ num_requests=args.num_prompts,
+ prefix_len=args.random_prefix_len,
+ input_len=args.random_input_len,
+ output_len=args.random_output_len,
+ range_ratio=args.random_range_ratio,
+ trace_dataset_path=args.trace_dataset_path,
+ ),
+ "prefix_repetition":
+ lambda: PrefixRepetitionRandomDataset(
+ random_seed=args.seed, dataset_path=args.dataset_path
+ ).sample(
+ tokenizer=tokenizer,
+ num_requests=args.num_prompts,
+ prefix_len=args.prefix_repetition_prefix_len,
+ suffix_len=args.prefix_repetition_suffix_len,
+ num_prefixes=args.prefix_repetition_num_prefixes,
+ output_len=args.prefix_repetition_output_len,
+ ),
+ }
+
+ try:
+ input_requests = dataset_mapping[args.dataset_name]()
+ except KeyError as err:
+ raise ValueError(f"Unknown dataset: {args.dataset_name}") from err
+
+ return input_requests
+
+
+# -----------------------------------------------------------------------------
+# Custom Dataset Implementation
+# -----------------------------------------------------------------------------
+
+
+class CustomDataset(BenchmarkDataset):
+ """
+ Implements the Custom dataset. Loads data from a JSONL file and generates
+ sample requests based on conversation turns. E.g.,
+ ```
+ {"prompt": "What is the capital of India?"}
+ {"prompt": "What is the capital of Iran?"}
+ {"prompt": "What is the capital of China?"}
+ ```
+ """
+
+ def __init__(self, **kwargs) -> None:
+ super().__init__(**kwargs)
+ self.load_data()
+
+ def load_data(self) -> None:
+ if self.dataset_path is None:
+ raise ValueError("dataset_path must be provided for loading data.")
+
+ # self.data will be a list of dictionaries
+ # e.g., [{"prompt": "What is the capital of India?"}, ...]
+ # This will be the standardized format which load_data()
+ # has to convert into depending on the filetype of dataset_path.
+ # sample() will assume this standardized format of self.data
+ self.data = []
+
+ # Load the JSONL file
+ if self.dataset_path.endswith(".jsonl"):
+ jsonl_data = pd.read_json(path_or_buf=self.dataset_path,
+ lines=True)
+
+ # check if the JSONL file has a 'prompt' column
+ if "prompt" not in jsonl_data.columns:
+ raise ValueError("JSONL file must contain a 'prompt' column.")
+
+ # Convert each row to a dictionary and append to self.data
+ # This will convert the DataFrame to a list of dictionaries
+ # where each dictionary corresponds to a row in the DataFrame.
+ # This is the standardized format we want for self.data
+ for _, row in jsonl_data.iterrows():
+ self.data.append(row.to_dict())
+ else:
+ raise NotImplementedError(
+ "Only JSONL format is supported for CustomDataset.")
+
+ random.seed(self.random_seed)
+ random.shuffle(self.data)
+
+ def sample(
+ self,
+ tokenizer: PreTrainedTokenizerBase,
+ num_requests: int,
+ lora_path: Optional[str] = None,
+ max_loras: Optional[int] = None,
+ output_len: Optional[int] = None,
+ enable_multimodal_chat: bool = False,
+ skip_chat_template: bool = False,
+ **kwargs,
+ ) -> list:
+ sampled_requests = []
+ for item in self.data:
+ if len(sampled_requests) >= num_requests:
+ break
+ prompt = item["prompt"]
+
+ # apply template
+ if not skip_chat_template:
+ prompt = tokenizer.apply_chat_template(
+ [{
+ "role": "user",
+ "content": prompt
+ }],
+ add_generation_prompt=True,
+ tokenize=False,
+ )
+
+ prompt_len = len(tokenizer(prompt).input_ids)
+ sampled_requests.append(
+ SampleRequest(
+ prompt=prompt,
+ prompt_len=prompt_len,
+ expected_output_len=output_len,
+ ))
+ self.maybe_oversample_requests(sampled_requests, num_requests)
+
+ return sampled_requests
+
+
+# -----------------------------------------------------------------------------
+# Sonnet Dataset Implementation
+# -----------------------------------------------------------------------------
+
+@deprecated(
+ "SonnetDataset is deprecated and will be removed in a future version.",
+)
+class SonnetDataset(BenchmarkDataset):
+ """
+ Simplified implementation of the Sonnet dataset. Loads poem lines from a
+ text file and generates sample requests. Default values here copied from
+ `benchmark_serving.py` for the sonnet dataset.
+ """
+
+ DEFAULT_PREFIX_LEN = 200
+ DEFAULT_INPUT_LEN = 550
+ DEFAULT_OUTPUT_LEN = 150
+
+ def __init__(
+ self,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.load_data()
+
+ def load_data(self) -> None:
+ if not self.dataset_path:
+ raise ValueError("dataset_path must be provided.")
+ with open(self.dataset_path, encoding="utf-8") as f:
+ self.data = f.readlines()
+
+ def sample(
+ self,
+ tokenizer,
+ num_requests: int,
+ prefix_len: int = DEFAULT_PREFIX_LEN,
+ input_len: int = DEFAULT_INPUT_LEN,
+ output_len: int = DEFAULT_OUTPUT_LEN,
+ return_prompt_formatted: bool = False,
+ **kwargs,
+ ) -> list:
+ # Calculate average token length for a poem line.
+ tokenized_lines = [tokenizer(line).input_ids for line in self.data]
+ avg_len = sum(len(tokens)
+ for tokens in tokenized_lines) / len(tokenized_lines)
+
+ # Build the base prompt.
+ base_prompt = "Pick as many lines as you can from these poem lines:\n"
+ base_msg = [{"role": "user", "content": base_prompt}]
+ base_fmt = tokenizer.apply_chat_template(base_msg,
+ add_generation_prompt=True,
+ tokenize=False)
+ base_offset = len(tokenizer(base_fmt).input_ids)
+ if input_len <= base_offset:
+ raise ValueError(
+ f"'input_len' must be higher than the base prompt length "
+ f"({base_offset}).")
+
+ # Determine how many poem lines to use.
+ num_input_lines = round((input_len - base_offset) / avg_len)
+ num_prefix_lines = max(round((prefix_len - base_offset) / avg_len), 0)
+ prefix_lines = self.data[:num_prefix_lines]
+
+ samples = []
+ while len(samples) < num_requests:
+ extra_lines = random.choices(self.data,
+ k=num_input_lines - num_prefix_lines)
+ prompt = f"{base_prompt}{''.join(prefix_lines + extra_lines)}"
+ msg = [{"role": "user", "content": prompt}]
+ prompt_formatted = tokenizer.apply_chat_template(
+ msg, add_generation_prompt=True, tokenize=False)
+ prompt_len = len(tokenizer(prompt_formatted).input_ids)
+ if prompt_len <= input_len:
+ samples.append(
+ SampleRequest(
+ prompt=prompt_formatted
+ if return_prompt_formatted else prompt,
+ prompt_len=prompt_len,
+ expected_output_len=output_len,
+ ))
+ return samples
+
+
+# -----------------------------------------------------------------------------
+# BurstGPT Dataset Implementation
+# -----------------------------------------------------------------------------
+
+
+class BurstGPTDataset(BenchmarkDataset):
+ """
+ Implements the BurstGPT dataset. Loads data from a CSV file and generates
+ sample requests based on synthetic prompt generation. Only rows with Model
+ "GPT-4" and positive response tokens are used.
+ """
+
+ def __init__(self, **kwargs) -> None:
+ super().__init__(**kwargs)
+ self.load_data()
+
+ def load_data(self, ):
+ if self.dataset_path is None:
+ raise ValueError("dataset_path must be provided for loading data.")
+
+ df = pd.read_csv(self.dataset_path)
+ # Filter to keep only GPT-4 rows.
+ gpt4_df = df[df["Model"] == "GPT-4"]
+ # Remove failed requests (where Response tokens is 0 or less).
+ gpt4_df = gpt4_df[gpt4_df["Response tokens"] > 0]
+ # Sample the desired number of rows.
+ self.data = gpt4_df
+
+ def _sample_loaded_data(self, num_requests: int) -> list:
+ if num_requests <= len(self.data):
+ data = self.data.sample(n=num_requests,
+ random_state=self.random_seed)
+ else:
+ data = self.data.sample(
+ n=num_requests,
+ random_state=self.random_seed,
+ replace=True,
+ )
+ # Convert the dataframe to a list of lists.
+ return data.values.tolist()
+
+ def sample(
+ self,
+ tokenizer: PreTrainedTokenizerBase,
+ num_requests: int,
+ max_loras: Optional[int] = None,
+ lora_path: Optional[str] = None,
+ **kwargs,
+ ) -> list[SampleRequest]:
+ samples = []
+ data = self._sample_loaded_data(num_requests=num_requests)
+ for i in range(num_requests):
+ input_len = int(data[i][2])
+ output_len = int(data[i][3])
+ lora_req, tokenizer = self.get_random_lora_request(
+ tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path)
+ vocab_size = tokenizer.vocab_size
+ # Generate a synthetic prompt: a list of token IDs computed as (i +
+ # j) modulo vocab_size.
+ token_ids = [(i + j) % vocab_size for j in range(input_len)]
+ prompt = tokenizer.decode(token_ids)
+ samples.append(
+ SampleRequest(
+ prompt=prompt,
+ prompt_len=input_len,
+ expected_output_len=output_len,
+ lora_request=lora_req,
+ ))
+ return samples
+
+
+# -----------------------------------------------------------------------------
+# HuggingFace Dataset Base Implementation
+# -----------------------------------------------------------------------------
+class HuggingFaceDataset(BenchmarkDataset):
+ """Base class for datasets hosted on HuggingFace."""
+
+ SUPPORTED_DATASET_PATHS: Union[set[str], dict[str, Callable]] = set()
+
+ def __init__(
+ self,
+ dataset_path: str,
+ dataset_split: str,
+ no_stream: bool = False,
+ dataset_subset: Optional[str] = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(dataset_path=dataset_path, **kwargs)
+
+ self.dataset_split = dataset_split
+ self.dataset_subset = dataset_subset
+ self.load_stream = not no_stream
+ self.load_data()
+
+ def load_data(self) -> None:
+ """Load data from HuggingFace datasets."""
+ self.data = load_dataset(
+ self.dataset_path,
+ name=self.dataset_subset,
+ split=self.dataset_split,
+ streaming=self.load_stream,
+ )
+ self.data = self.data.shuffle(seed=self.random_seed)
+
+
+# -----------------------------------------------------------------------------
+# Conversation Dataset Implementation
+# -----------------------------------------------------------------------------
+
+
+class ConversationDataset(HuggingFaceDataset):
+ """Dataset for conversation data with multimodal support."""
+ SUPPORTED_DATASET_PATHS = {
+ 'lmms-lab/LLaVA-OneVision-Data', 'Aeala/ShareGPT_Vicuna_unfiltered'
+ }
+ IS_MULTIMODAL = True
+
+ def sample(self,
+ tokenizer: PreTrainedTokenizerBase,
+ num_requests: int,
+ output_len: Optional[int] = None,
+ enable_multimodal_chat: bool = False,
+ **kwargs) -> list:
+ # Filter examples with at least 2 conversations
+ filtered_data = self.data.filter(
+ lambda x: len(x["conversations"]) >= 2)
+ sampled_requests = []
+ dynamic_output = output_len is None
+
+ for item in filtered_data:
+ if len(sampled_requests) >= num_requests:
+ break
+ conv = item["conversations"]
+ prompt, completion = conv[0]["value"], conv[1]["value"]
+
+ prompt_ids = tokenizer(prompt).input_ids
+ completion_ids = tokenizer(completion).input_ids
+ prompt_len = len(prompt_ids)
+ completion_len = len(completion_ids)
+ output_len = completion_len if dynamic_output else output_len
+ assert isinstance(output_len, int) and output_len > 0
+ if dynamic_output and not is_valid_sequence(
+ prompt_len, completion_len):
+ continue
+ mm_content = process_image(
+ item["image"]) if "image" in item else None
+ if enable_multimodal_chat:
+ # Note: when chat is enabled the request prompt_len is no longer
+ # accurate and we will be using request output to count the
+ # actual prompt len and output len
+ prompt = self.apply_multimodal_chat_transformation(
+ prompt, mm_content)
+ sampled_requests.append(
+ SampleRequest(
+ prompt=prompt,
+ prompt_len=prompt_len,
+ expected_output_len=output_len,
+ multi_modal_data=mm_content,
+ ))
+ self.maybe_oversample_requests(sampled_requests, num_requests)
+ return sampled_requests
+
+
+# -----------------------------------------------------------------------------
+# Vision Arena Dataset Implementation
+# -----------------------------------------------------------------------------
+
+
+class VisionArenaDataset(HuggingFaceDataset):
+ """
+ Vision Arena Dataset.
+ """
+
+ DEFAULT_OUTPUT_LEN = 128
+ SUPPORTED_DATASET_PATHS = {
+ "lmarena-ai/VisionArena-Chat":
+ lambda x: x["conversation"][0][0]["content"],
+ "lmarena-ai/vision-arena-bench-v0.1":
+ lambda x: x["turns"][0][0]["content"]
+ }
+ IS_MULTIMODAL = True
+
+ def sample(
+ self,
+ tokenizer: PreTrainedTokenizerBase,
+ num_requests: int,
+ output_len: Optional[int] = None,
+ enable_multimodal_chat: bool = False,
+ **kwargs,
+ ) -> list:
+ output_len = (output_len
+ if output_len is not None else self.DEFAULT_OUTPUT_LEN)
+ sampled_requests = []
+ for item in self.data:
+ if len(sampled_requests) >= num_requests:
+ break
+ parser_fn = self.SUPPORTED_DATASET_PATHS.get(self.dataset_path)
+ if parser_fn is None:
+ raise ValueError(
+ f"Unsupported dataset path: {self.dataset_path}")
+ prompt = parser_fn(item)
+ mm_content = process_image(item["images"][0])
+ prompt_len = len(tokenizer(prompt).input_ids)
+ if enable_multimodal_chat:
+ # Note: when chat is enabled the request prompt_len is no longer
+ # accurate and we will be using request output to count the
+ # actual prompt len
+ prompt = self.apply_multimodal_chat_transformation(
+ prompt, mm_content)
+ sampled_requests.append(
+ SampleRequest(
+ prompt=prompt,
+ prompt_len=prompt_len,
+ expected_output_len=output_len,
+ multi_modal_data=mm_content,
+ ))
+ self.maybe_oversample_requests(sampled_requests, num_requests)
+ return sampled_requests
+
+
+# -----------------------------------------------------------------------------
+# Instruct Coder Dataset Implementation
+# -----------------------------------------------------------------------------
+
+
+class InstructCoderDataset(HuggingFaceDataset):
+ """
+ InstructCoder Dataset.
+ https://huggingface.co/datasets/likaixin/InstructCoder
+
+ InstructCoder is the dataset designed for general code editing. It consists
+ of 114,239 instruction-input-output triplets, and covers multiple distinct
+ code editing scenario.
+ """
+
+ DEFAULT_OUTPUT_LEN = 200 # this is the average default output length
+ SUPPORTED_DATASET_PATHS = {
+ "likaixin/InstructCoder",
+ }
+
+ def sample(self,
+ tokenizer: PreTrainedTokenizerBase,
+ num_requests: int,
+ output_len: Optional[int] = None,
+ enable_multimodal_chat: bool = False,
+ **kwargs) -> list:
+ output_len = (output_len
+ if output_len is not None else self.DEFAULT_OUTPUT_LEN)
+ sampled_requests = []
+ for item in self.data:
+ if len(sampled_requests) >= num_requests:
+ break
+ prompt = f"{item['input']}\n\n{item['instruction']} Just output \
+ the code, do not include any explanation."
+
+ # apply template
+ prompt = tokenizer.apply_chat_template(
+ [{
+ "role": "user",
+ "content": prompt
+ }],
+ add_generation_prompt=True,
+ tokenize=False,
+ )
+
+ prompt_len = len(tokenizer(prompt).input_ids)
+ sampled_requests.append(
+ SampleRequest(
+ prompt=prompt,
+ prompt_len=prompt_len,
+ expected_output_len=output_len,
+ ))
+ self.maybe_oversample_requests(sampled_requests, num_requests)
+ return sampled_requests
+
+
+# -----------------------------------------------------------------------------
+# MT-Bench Dataset Implementation
+# -----------------------------------------------------------------------------
+
+
+class MTBenchDataset(HuggingFaceDataset):
+ """
+ MT-Bench Dataset.
+ https://huggingface.co/datasets/philschmid/mt-bench
+
+ We create a single turn dataset for MT-Bench.
+ This is similar to Spec decoding benchmark setup in vLLM
+ https://github.com/vllm-project/vllm/blob/9d98ab5ec/examples/offline_inference/eagle.py#L14-L18
+ """ # noqa: E501
+
+ DEFAULT_OUTPUT_LEN = 256 # avg len used in SD bench in vLLM
+ SUPPORTED_DATASET_PATHS = {
+ "philschmid/mt-bench",
+ }
+
+ def sample(
+ self,
+ tokenizer: PreTrainedTokenizerBase,
+ num_requests: int,
+ output_len: Optional[int] = None,
+ enable_multimodal_chat: bool = False,
+ **kwargs,
+ ) -> list:
+ output_len = (output_len
+ if output_len is not None else self.DEFAULT_OUTPUT_LEN)
+ sampled_requests = []
+
+ for item in self.data:
+ if len(sampled_requests) >= num_requests:
+ break
+ prompt = item["turns"][0]
+
+ # apply template
+ prompt = tokenizer.apply_chat_template(
+ [{
+ "role": "user",
+ "content": prompt
+ }],
+ add_generation_prompt=True,
+ tokenize=False,
+ )
+
+ prompt_len = len(tokenizer(prompt).input_ids)
+ sampled_requests.append(
+ SampleRequest(
+ prompt=prompt,
+ prompt_len=prompt_len,
+ expected_output_len=output_len,
+ ))
+ self.maybe_oversample_requests(sampled_requests, num_requests)
+ return sampled_requests
+
+
+# -----------------------------------------------------------------------------
+# AIMO Dataset Implementation
+# -----------------------------------------------------------------------------
+
+
+class AIMODataset(HuggingFaceDataset):
+ """
+ Dataset class for processing a AIMO dataset with reasoning questions.
+ """
+ SUPPORTED_DATASET_PATHS = {
+ "AI-MO/aimo-validation-aime", "AI-MO/NuminaMath-1.5",
+ "AI-MO/NuminaMath-CoT"
+ }
+
+ def sample(self,
+ tokenizer: PreTrainedTokenizerBase,
+ num_requests: int,
+ output_len: Optional[int] = None,
+ **kwargs) -> list:
+ sampled_requests = []
+ dynamic_output = output_len is None
+
+ for item in self.data:
+ if len(sampled_requests) >= num_requests:
+ break
+ prompt, completion = item['problem'], item["solution"]
+
+ prompt_ids = tokenizer(prompt).input_ids
+ completion_ids = tokenizer(completion).input_ids
+ prompt_len = len(prompt_ids)
+ completion_len = len(completion_ids)
+ output_len = completion_len if dynamic_output else output_len
+ assert isinstance(output_len, int) and output_len > 0
+ if dynamic_output and not is_valid_sequence(prompt_len,
+ completion_len,
+ max_prompt_len=2048,
+ max_total_len=32000):
+ continue
+ sampled_requests.append(
+ SampleRequest(
+ prompt=prompt,
+ prompt_len=prompt_len,
+ expected_output_len=output_len,
+ multi_modal_data=None,
+ ))
+ self.maybe_oversample_requests(sampled_requests, num_requests)
+ return sampled_requests
+
+
+# -----------------------------------------------------------------------------
+# Next Edit Prediction Dataset Implementation
+# -----------------------------------------------------------------------------
+
+
+zeta_prompt = """### Instruction:
+You are a code completion assistant and your task is to analyze user edits and then rewrite an excerpt that the user provides, suggesting the appropriate edits within the excerpt, taking into account the cursor location.
+
+### User Edits:
+
+{}
+
+### User Excerpt:
+
+{}
+
+### Response:
+
+""" # noqa: E501
+
+
+def _format_zeta_prompt(
+ sample: dict,
+ original_start_marker: str = "<|editable_region_start|>") -> dict:
+ """Format the zeta prompt for the Next Edit Prediction (NEP) dataset.
+
+ This function formats examples from the NEP dataset
+ into prompts and expected outputs. It could be
+ further extended to support more NEP datasets.
+
+ Args:
+ sample: The dataset sample containing events,
+ inputs, and outputs.
+ original_start_marker: The marker indicating the
+ start of the editable region. Defaults to
+ "<|editable_region_start|>".
+
+ Returns:
+ A dictionary with the formatted prompts and expected outputs.
+ """
+ events = sample["events"]
+ input = sample["input"]
+ output = sample["output"]
+ prompt = zeta_prompt.format(events, input)
+
+ # following the original implementation, extract the focused region
+ # from the raw output
+ output_start_index = output.find(original_start_marker)
+ output_focused_region = output[output_start_index:]
+ expected_output = output_focused_region
+
+ return {"prompt": prompt, "expected_output": expected_output}
+
+
+class NextEditPredictionDataset(HuggingFaceDataset):
+ """
+ Dataset class for processing a Next Edit Prediction dataset.
+ """
+
+ SUPPORTED_DATASET_PATHS = {
+ "zed-industries/zeta",
+ }
+ MAPPING_PROMPT_FUNCS = {
+ "zed-industries/zeta": _format_zeta_prompt,
+ }
+
+ def sample(self, tokenizer: PreTrainedTokenizerBase, num_requests: int,
+ **kwargs):
+ formatting_prompt_func = self.MAPPING_PROMPT_FUNCS.get(
+ self.dataset_path)
+ if formatting_prompt_func is None:
+ raise ValueError(f"Unsupported dataset path: {self.dataset_path}")
+ samples = []
+ for sample in self.data:
+ sample = formatting_prompt_func(sample)
+ samples.append(
+ SampleRequest(
+ prompt=sample["prompt"],
+ prompt_len=len(tokenizer(sample["prompt"]).input_ids),
+ expected_output_len=len(
+ tokenizer(sample["expected_output"]).input_ids),
+ ))
+ if len(samples) >= num_requests:
+ break
+ self.maybe_oversample_requests(samples, num_requests)
+ return samples
+
+
+# -----------------------------------------------------------------------------
+# ASR Dataset Implementation
+# -----------------------------------------------------------------------------
+
+
+class ASRDataset(HuggingFaceDataset):
+ """
+ Dataset class for processing a ASR dataset for transcription.
+ Tested on the following set:
+
+ +----------------+----------------------------------------+--------------------------+-----------------------------+
+ | Dataset | Domain | Speaking Style | hf-subset |
+ +----------------+----------------------------------------+--------------------------+-----------------------------+
+ | TED-LIUM | TED talks | Oratory | release1, release2, release3|
+ | | | | release3-speaker-adaptation |
+ | VoxPopuli | European Parliament | Oratory | en, de, it, fr, ... |
+ | LibriSpeech | Audiobook | Narrated | "LIUM/tedlium" |
+ | GigaSpeech | Audiobook, podcast, YouTube | Narrated, spontaneous | xs, s, m, l, xl, dev, test |
+ | SPGISpeech | Financial meetings | Oratory, spontaneous | S, M, L, dev, test |
+ | AMI | Meetings | Spontaneous | ihm, sdm |
+ +----------------+----------------------------------------+--------------------------+-----------------------------+
+
+ """ # noqa: E501
+
+ SUPPORTED_DATASET_PATHS = {
+ "openslr/librispeech_asr",
+ "facebook/voxpopuli",
+ "LIUM/tedlium",
+ "edinburghcstr/ami",
+ "speechcolab/gigaspeech",
+ "kensho/spgispeech",
+ }
+
+ DEFAULT_OUTPUT_LEN = 128
+ IS_MULTIMODAL = True
+
+ # TODO Whisper-specific. Abstract interface when more models are supported.
+ TRANSCRIPTION_PREAMBLE = (
+ "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>")
+ skip_long_audios: bool = True
+
+ def sample(
+ self,
+ tokenizer: PreTrainedTokenizerBase,
+ num_requests: int,
+ output_len: Optional[int] = None,
+ **kwargs,
+ ) -> list:
+ output_len = (output_len
+ if output_len is not None else self.DEFAULT_OUTPUT_LEN)
+ prompt = ASRDataset.TRANSCRIPTION_PREAMBLE
+ prompt_len = len(tokenizer(prompt).input_ids)
+ sampled_requests = []
+ skipped = 0
+ for item in self.data:
+ if len(sampled_requests) >= num_requests:
+ break
+ audio = item["audio"]
+ y, sr = audio["array"], audio["sampling_rate"]
+ duration_s = librosa.get_duration(y=y, sr=sr)
+ # Whisper max supported duration
+ if self.skip_long_audios and duration_s > 30:
+ skipped += 1
+ continue
+
+ mm_content = {"audio": (y, sr)}
+ sampled_requests.append(
+ SampleRequest(
+ prompt=prompt,
+ prompt_len=prompt_len,
+ expected_output_len=output_len,
+ multi_modal_data=mm_content,
+ ))
+ if skipped:
+ logger.warning(
+ "%d samples discarded from dataset due to"
+ " their length being greater than"
+ " what Whisper supports.",
+ skipped,
+ )
+ self.maybe_oversample_requests(sampled_requests, num_requests)
+ return sampled_requests
+
+
+# -----------------------------------------------------------------------------
+# MLPerf Dataset Implementation
+# -----------------------------------------------------------------------------
+
+
+class MLPerfDataset(HuggingFaceDataset):
+ """
+ MLPerf Inference Dataset.
+
+ Dataset on HF:
+ https://huggingface.co/datasets/mgoin/mlperf-inference-llama2-data
+ https://huggingface.co/datasets/mgoin/mlperf-inference-llama3.1-data
+
+ Each record contains:
+ - "system_prompt": system role instruction.
+ - "question": user question.
+ - "output": reference answer.
+
+ We combine the system prompt and question into a chat-formatted prompt
+ (using the tokenizer's chat template) and set the expected output length to
+ the tokenized length of the provided reference answer.
+ """
+
+ SUPPORTED_DATASET_PATHS = {
+ "mgoin/mlperf-inference-llama2-data",
+ "mgoin/mlperf-inference-llama3.1-data",
+ }
+
+ def sample(
+ self,
+ tokenizer: PreTrainedTokenizerBase,
+ num_requests: int,
+ output_len: Optional[int] = None,
+ **kwargs,
+ ) -> list[SampleRequest]:
+ # Force dynamic output length based on reference completion.
+ dynamic_output = output_len is None
+ sampled_requests: list[SampleRequest] = []
+
+ for item in self.data:
+ if len(sampled_requests) >= num_requests:
+ break
+
+ system_prompt = item["system_prompt"]
+ question = item["question"]
+ reference_answer = item["output"]
+
+ # Build chat-style prompt using tokenizer template, if available.
+ messages = [
+ {"role": "system", "content": system_prompt},
+ {"role": "user", "content": question},
+ ]
+ prompt_formatted = tokenizer.apply_chat_template(
+ messages, add_generation_prompt=True, tokenize=False
+ )
+ prompt_len = len(tokenizer(prompt_formatted).input_ids)
+
+ # Determine output length from reference answer tokens.
+ ref_out_len = len(
+ tokenizer(reference_answer, add_special_tokens=False).input_ids
+ )
+ expected_output_len = ref_out_len if dynamic_output else output_len
+
+ # Validate sequence lengths.
+ if not is_valid_sequence(prompt_len, expected_output_len):
+ continue
+
+ sampled_requests.append(
+ SampleRequest(
+ prompt=prompt_formatted,
+ prompt_len=prompt_len,
+ expected_output_len=expected_output_len,
+ )
+ )
+
+ self.maybe_oversample_requests(sampled_requests, num_requests)
+ return sampled_requests
+
+
+# -----------------------------------------------------------------------------
+# Prefix Repetition Dataset Implementation
+# -----------------------------------------------------------------------------
+
+
+class PrefixRepetitionRandomDataset(BenchmarkDataset):
+ # Default values copied from benchmark_serving.py for the repeated prefix
+ # dataset.
+ DEFAULT_PREFIX_LEN = 256
+ DEFAULT_SUFFIX_LEN = 256
+ DEFAULT_NUM_PREFIXES = 10
+ DEFAULT_OUTPUT_LEN = 128
+
+ def __init__(
+ self,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ random.seed(self.random_seed)
+ np.random.seed(self.random_seed)
+
+ def sample(
+ self,
+ tokenizer: PreTrainedTokenizerBase,
+ num_requests: int,
+ prefix_len: int = DEFAULT_PREFIX_LEN,
+ suffix_len: int = DEFAULT_SUFFIX_LEN,
+ num_prefixes: int = DEFAULT_NUM_PREFIXES,
+ output_len: int = DEFAULT_OUTPUT_LEN,
+ **kwargs,
+ ) -> list[SampleRequest]:
+ vocab_size = tokenizer.vocab_size
+ prompts_per_prefix = num_requests // num_prefixes
+ if prompts_per_prefix == 0:
+ raise ValueError(
+ f"num_requests ({num_requests}) must be greater than or equal "
+ f"to num_prefixes ({num_prefixes})"
+ )
+
+ def _generate_exact_length_tokens(target_length: int) -> list[int]:
+ """Generate tokens that decode and re-encode to exactly
+ target_length."""
+ # Generate random tokens
+ tokens = np.random.randint(
+ 0, vocab_size, size=target_length).tolist()
+ text = tokenizer.decode(tokens)
+ re_encoded = tokenizer.encode(text, add_special_tokens=False)
+
+ if len(re_encoded) == target_length:
+ return re_encoded
+ elif len(re_encoded) < target_length:
+ # Recursively generate additional consistent tokens
+ needed = target_length - len(re_encoded)
+ extra_tokens = _generate_exact_length_tokens(needed)
+ return re_encoded + extra_tokens
+ else:
+ # Truncate to target length
+ return re_encoded[:target_length]
+
+ requests = []
+ for _ in range(num_prefixes):
+ prefix_tokens = _generate_exact_length_tokens(prefix_len)
+
+ for _ in range(prompts_per_prefix):
+ suffix_tokens = _generate_exact_length_tokens(suffix_len)
+
+ combined_tokens = prefix_tokens + suffix_tokens
+ prompt = tokenizer.decode(combined_tokens)
+ prompt_len = len(combined_tokens)
+ requests.append(
+ SampleRequest(
+ prompt=prompt,
+ prompt_len=prompt_len,
+ expected_output_len=output_len,
+ )
+ )
+
+ random.shuffle(requests)
+ return requests
diff --git a/benchmark/reproducibility/patches_v0.10.1/serve.py b/benchmark/reproducibility/patches_v0.10.1/serve.py
new file mode 100644
index 000000000..928dd27b0
--- /dev/null
+++ b/benchmark/reproducibility/patches_v0.10.1/serve.py
@@ -0,0 +1,1199 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+r"""Benchmark online serving throughput.
+
+On the server side, run one of the following commands
+to launch the vLLM OpenAI API server:
+ vllm serve
+
+On the client side, run:
+ vllm bench serve \
+ --endpoint-type \
+ --label \
+ --model \
+ --dataset-name \
+ --request-rate \
+ --num-prompts
+"""
+import argparse
+import asyncio
+import gc
+import json
+import os
+import random
+import time
+import warnings
+from collections.abc import AsyncGenerator, Iterable
+from dataclasses import dataclass
+from datetime import datetime
+from typing import Any, Literal, Optional
+
+import aiohttp
+import numpy as np
+from tqdm.asyncio import tqdm
+from transformers import PreTrainedTokenizerBase
+
+from vllm.benchmarks.datasets import (SampleRequest, add_dataset_parser,
+ get_samples)
+from vllm.benchmarks.lib.endpoint_request_func import (
+ ASYNC_REQUEST_FUNCS, OPENAI_COMPATIBLE_BACKENDS, RequestFuncInput,
+ RequestFuncOutput)
+from vllm.benchmarks.lib.ready_checker import wait_for_endpoint
+from vllm.benchmarks.lib.utils import (convert_to_pytorch_benchmark_format,
+ write_to_json)
+from vllm.transformers_utils.tokenizer import get_tokenizer
+
+MILLISECONDS_TO_SECONDS_CONVERSION = 1000
+
+
+@dataclass
+class BenchmarkMetrics:
+ completed: int
+ total_input: int
+ total_output: int
+ request_throughput: float
+ request_goodput: float
+ output_throughput: float
+ total_token_throughput: float
+ mean_ttft_ms: float
+ median_ttft_ms: float
+ std_ttft_ms: float
+ percentiles_ttft_ms: list[tuple[float, float]]
+ mean_tpot_ms: float
+ median_tpot_ms: float
+ std_tpot_ms: float
+ percentiles_tpot_ms: list[tuple[float, float]]
+ mean_itl_ms: float
+ median_itl_ms: float
+ std_itl_ms: float
+ percentiles_itl_ms: list[tuple[float, float]]
+ # E2EL stands for end-to-end latency per request.
+ # It is the time taken on the client side from sending
+ # a request to receiving a complete response.
+ mean_e2el_ms: float
+ median_e2el_ms: float
+ std_e2el_ms: float
+ percentiles_e2el_ms: list[tuple[float, float]]
+
+
+def _get_current_request_rate(
+ ramp_up_strategy: Optional[Literal["linear", "exponential"]],
+ ramp_up_start_rps: Optional[int],
+ ramp_up_end_rps: Optional[int],
+ request_index: int,
+ total_requests: int,
+ request_rate: float,
+) -> float:
+ if (ramp_up_strategy and ramp_up_start_rps is not None
+ and ramp_up_end_rps is not None):
+ progress = request_index / max(total_requests - 1, 1)
+ if ramp_up_strategy == "linear":
+ increase = (ramp_up_end_rps - ramp_up_start_rps) * progress
+ return ramp_up_start_rps + increase
+ elif ramp_up_strategy == "exponential":
+ ratio = ramp_up_end_rps / ramp_up_start_rps
+ return ramp_up_start_rps * (ratio**progress)
+ else:
+ raise ValueError(f"Unknown ramp-up strategy: {ramp_up_strategy}")
+ return request_rate
+
+
+async def get_request_trace(input_requests: list[SampleRequest]) -> AsyncGenerator[tuple[SampleRequest, float], None]:
+
+ print(f"get_request_trace")
+
+ prev_ts = 0
+ for i in range(len(input_requests)):
+ request = input_requests[i]
+ curr_ts = request.trace_timestamp
+ delay = curr_ts - prev_ts
+ print(f"request {i} prompt_len {request.prompt_len} expected_output_len {request.expected_output_len} timestamp {request.trace_timestamp} delay {delay} ms")
+ if delay > 0:
+ await asyncio.sleep(delay/1000)
+ prev_ts = curr_ts
+ yield request, 1
+
+
+async def get_request(
+ input_requests: list[SampleRequest],
+ request_rate: float,
+ burstiness: float = 1.0,
+ ramp_up_strategy: Optional[Literal["linear", "exponential"]] = None,
+ ramp_up_start_rps: Optional[int] = None,
+ ramp_up_end_rps: Optional[int] = None,
+) -> AsyncGenerator[tuple[SampleRequest, float], None]:
+ """
+ Asynchronously generates requests at a specified rate
+ with OPTIONAL burstiness and OPTIONAL ramp-up strategy.
+
+ Args:
+ input_requests:
+ A list of input requests, each represented as a SampleRequest.
+ request_rate:
+ The rate at which requests are generated (requests/s).
+ burstiness (optional):
+ The burstiness factor of the request generation.
+ Only takes effect when request_rate is not inf.
+ Default value is 1, which follows a Poisson process.
+ Otherwise, the request intervals follow a gamma distribution.
+ A lower burstiness value (0 < burstiness < 1) results
+ in more bursty requests, while a higher burstiness value
+ (burstiness > 1) results in a more uniform arrival of requests.
+ ramp_up_strategy (optional):
+ The ramp-up strategy. Can be "linear" or "exponential".
+ If None, uses constant request rate (specified by request_rate).
+ ramp_up_start_rps (optional):
+ The starting request rate for ramp-up.
+ ramp_up_end_rps (optional):
+ The ending request rate for ramp-up.
+ """
+ assert burstiness > 0, (
+ f"A positive burstiness factor is expected, but given {burstiness}.")
+ # Convert to list to get length for ramp-up calculations
+ if isinstance(input_requests, Iterable) and not isinstance(
+ input_requests, list):
+ input_requests = list(input_requests)
+
+ total_requests = len(input_requests)
+ assert total_requests > 0, "No requests provided."
+
+ # Precompute delays among requests to minimize request send laggings
+ request_rates = []
+ delay_ts = []
+ for request_index, request in enumerate(input_requests):
+ current_request_rate = _get_current_request_rate(ramp_up_strategy,
+ ramp_up_start_rps,
+ ramp_up_end_rps,
+ request_index,
+ total_requests,
+ request_rate)
+ request_rates.append(current_request_rate)
+ if current_request_rate == float("inf"):
+ delay_ts.append(0)
+ else:
+ theta = 1.0 / (current_request_rate * burstiness)
+
+ # Sample the request interval from the gamma distribution.
+ # If burstiness is 1, it follows exponential distribution.
+ delay_ts.append(np.random.gamma(shape=burstiness, scale=theta))
+
+ # Calculate the cumulative delay time from the first sent out requests.
+ for i in range(1, len(delay_ts)):
+ delay_ts[i] += delay_ts[i - 1]
+ if ramp_up_strategy is None and delay_ts[-1] != 0:
+ # When ramp_up_strategy is not set, we assume the request rate is fixed
+ # and all requests should be sent in target_total_delay_s, the following
+ # logic would re-scale delay time to ensure the final delay_ts
+ # align with target_total_delay_s.
+ #
+ # NOTE: If we simply accumulate the random delta values
+ # from the gamma distribution, their sum would have 1-2% gap
+ # from target_total_delay_s. The purpose of the following logic is to
+ # close the gap for stablizing the throughput data
+ # from different random seeds.
+ target_total_delay_s = total_requests / request_rate
+ normalize_factor = target_total_delay_s / delay_ts[-1]
+ delay_ts = [delay * normalize_factor for delay in delay_ts]
+
+ start_ts = time.time()
+ for request_index, request in enumerate(input_requests):
+ if delay_ts[request_index] > 0:
+ current_ts = time.time()
+ sleep_interval_s = start_ts + delay_ts[request_index] - current_ts
+ if sleep_interval_s > 0:
+ await asyncio.sleep(sleep_interval_s)
+ yield request, request_rates[request_index]
+
+
+def calculate_metrics(
+ input_requests: list[SampleRequest],
+ outputs: list[RequestFuncOutput],
+ dur_s: float,
+ tokenizer: PreTrainedTokenizerBase,
+ selected_percentiles: list[float],
+ goodput_config_dict: dict[str, float],
+ trace_output_path: str | None = None,
+) -> tuple[BenchmarkMetrics, list[int]]:
+ """Calculate the metrics for the benchmark.
+
+ Args:
+ input_requests: The input requests.
+ outputs: The outputs of the requests.
+ dur_s: The duration of the benchmark.
+ tokenizer: The tokenizer to use.
+ selected_percentiles: The percentiles to select.
+ goodput_config_dict: The goodput configuration.
+
+ Returns:
+ A tuple of the benchmark metrics and the actual output lengths.
+ """
+ actual_output_lens: list[int] = []
+ total_input = 0
+ completed = 0
+ good_completed = 0
+ itls: list[float] = []
+ tpots: list[float] = []
+ all_tpots: list[float] = []
+ ttfts: list[float] = []
+ e2els: list[float] = []
+ for i in range(len(outputs)):
+ if outputs[i].success:
+ output_len = outputs[i].output_tokens
+
+ if not output_len:
+ # We use the tokenizer to count the number of output tokens
+ # for some serving backends instead of looking at
+ # len(outputs[i].itl) since multiple output tokens may be
+ # bundled together
+ # Note : this may inflate the output token count slightly
+ output_len = len(
+ tokenizer(outputs[i].generated_text,
+ add_special_tokens=False).input_ids)
+ actual_output_lens.append(output_len)
+ total_input += input_requests[i].prompt_len
+ tpot = 0
+ if output_len > 1:
+ latency_minus_ttft = outputs[i].latency - outputs[i].ttft
+ tpot = latency_minus_ttft / (output_len - 1)
+ tpots.append(tpot)
+ # Note: if output_len <= 1, we regard tpot as 0 for goodput
+ all_tpots.append(tpot)
+ itls += outputs[i].itl
+ ttfts.append(outputs[i].ttft)
+ e2els.append(outputs[i].latency)
+ completed += 1
+ else:
+ actual_output_lens.append(0)
+
+ if trace_output_path is not None:
+ import csv
+ header = ["request", "timestamp", "input_len", "output_len", "ttft_ms", "tpot_ms", "e2el_ms"]
+ with open(trace_output_path, 'w', newline='') as f:
+ writer = csv.writer(f)
+ writer.writerow(header)
+ for i in range(len(outputs)):
+ row = [
+ i,
+ input_requests[i].trace_timestamp,
+ input_requests[i].prompt_len,
+ actual_output_lens[i],
+ round(ttfts[i] * 1000, 2),
+ round(all_tpots[i] * 1000, 2),
+ round(e2els[i] * 1000, 2)
+ ]
+ writer.writerow(row)
+ print(f"Metrics saved to {trace_output_path}")
+ else:
+ print(f"requsest, timestamp, input_len, output_len, ttft_ms, tpot_ms, e2el_ms")
+ for i in range(len(outputs)):
+ print(f"{i} {input_requests[i].trace_timestamp} {input_requests[i].prompt_len} {actual_output_lens[i]} {ttfts[i]*1000:.2f} {all_tpots[i]*1000:.2f} {e2els[i]*1000:.2f}")
+
+ if goodput_config_dict:
+ valid_metrics = []
+ slo_values = []
+
+ if "ttft" in goodput_config_dict:
+ valid_metrics.append(ttfts)
+ slo_values.append(goodput_config_dict["ttft"] /
+ MILLISECONDS_TO_SECONDS_CONVERSION)
+ if "tpot" in goodput_config_dict:
+ valid_metrics.append(all_tpots)
+ slo_values.append(goodput_config_dict["tpot"] /
+ MILLISECONDS_TO_SECONDS_CONVERSION)
+ if "e2el" in goodput_config_dict:
+ valid_metrics.append(e2els)
+ slo_values.append(goodput_config_dict["e2el"] /
+ MILLISECONDS_TO_SECONDS_CONVERSION)
+
+ for req_metric in zip(*valid_metrics):
+ is_good_req = all([s >= r for s, r in zip(slo_values, req_metric)])
+ if is_good_req:
+ good_completed += 1
+
+ if completed == 0:
+ warnings.warn(
+ "All requests failed. This is likely due to a misconfiguration "
+ "on the benchmark arguments.",
+ stacklevel=2)
+ metrics = BenchmarkMetrics(
+ completed=completed,
+ total_input=total_input,
+ total_output=sum(actual_output_lens),
+ request_throughput=completed / dur_s,
+ request_goodput=good_completed / dur_s,
+ output_throughput=sum(actual_output_lens) / dur_s,
+ total_token_throughput=(total_input + sum(actual_output_lens)) / dur_s,
+ mean_ttft_ms=np.mean(ttfts or 0) *
+ 1000, # ttfts is empty if streaming is not supported by the endpoint
+ std_ttft_ms=np.std(ttfts or 0) * 1000,
+ median_ttft_ms=np.median(ttfts or 0) * 1000,
+ percentiles_ttft_ms=[(p, np.percentile(ttfts or 0, p) * 1000)
+ for p in selected_percentiles],
+ mean_tpot_ms=np.mean(tpots or 0) * 1000,
+ std_tpot_ms=np.std(tpots or 0) * 1000,
+ median_tpot_ms=np.median(tpots or 0) * 1000,
+ percentiles_tpot_ms=[(p, np.percentile(tpots or 0, p) * 1000)
+ for p in selected_percentiles],
+ mean_itl_ms=np.mean(itls or 0) * 1000,
+ std_itl_ms=np.std(itls or 0) * 1000,
+ median_itl_ms=np.median(itls or 0) * 1000,
+ percentiles_itl_ms=[(p, np.percentile(itls or 0, p) * 1000)
+ for p in selected_percentiles],
+ mean_e2el_ms=np.mean(e2els or 0) * 1000,
+ std_e2el_ms=np.std(e2els or 0) * 1000,
+ median_e2el_ms=np.median(e2els or 0) * 1000,
+ percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000)
+ for p in selected_percentiles],
+ )
+
+ return metrics, actual_output_lens
+
+
+async def benchmark(
+ endpoint_type: str,
+ api_url: str,
+ base_url: str,
+ model_id: str,
+ model_name: str,
+ tokenizer: PreTrainedTokenizerBase,
+ input_requests: list[SampleRequest],
+ logprobs: Optional[int],
+ request_rate: float,
+ burstiness: float,
+ disable_tqdm: bool,
+ profile: bool,
+ selected_percentile_metrics: list[str],
+ selected_percentiles: list[float],
+ ignore_eos: bool,
+ goodput_config_dict: dict[str, float],
+ max_concurrency: Optional[int],
+ lora_modules: Optional[Iterable[str]],
+ extra_body: Optional[dict],
+ ramp_up_strategy: Optional[Literal["linear", "exponential"]] = None,
+ ramp_up_start_rps: Optional[int] = None,
+ ramp_up_end_rps: Optional[int] = None,
+ ready_check_timeout_sec: int = 600,
+ trace_dataset_path: str | None = None,
+ trace_output_path: str | None = None,
+):
+ if endpoint_type in ASYNC_REQUEST_FUNCS:
+ request_func = ASYNC_REQUEST_FUNCS[endpoint_type]
+ else:
+ raise ValueError(f"Unknown endpoint_type: {endpoint_type}")
+
+ # Reuses connections across requests to reduce TLS handshake overhead.
+ connector = aiohttp.TCPConnector(
+ limit=max_concurrency or 0,
+ limit_per_host=max_concurrency or 0,
+ ttl_dns_cache=300,
+ use_dns_cache=True,
+ keepalive_timeout=60,
+ enable_cleanup_closed=True,
+ force_close=False,
+ ssl=("https://" in api_url),
+ )
+
+ session = aiohttp.ClientSession(
+ connector=connector,
+ trust_env=True,
+ timeout=aiohttp.ClientTimeout(total=6 * 60 * 60),
+ )
+
+ print("Starting initial single prompt test run...")
+ test_prompt, test_prompt_len, test_output_len, test_mm_content = (
+ input_requests[0].prompt,
+ input_requests[0].prompt_len,
+ input_requests[0].expected_output_len,
+ input_requests[0].multi_modal_data,
+ )
+
+ assert (
+ test_mm_content is None
+ or isinstance(test_mm_content, dict)
+ or (
+ isinstance(test_mm_content, list)
+ and all(isinstance(item, dict) for item in test_mm_content)
+ )
+ ), "multi_modal_data must be a dict or list[dict]"
+ test_input = RequestFuncInput(
+ model=model_id,
+ model_name=model_name,
+ prompt=test_prompt,
+ api_url=api_url,
+ prompt_len=test_prompt_len,
+ output_len=test_output_len,
+ logprobs=logprobs,
+ multi_modal_content=test_mm_content,
+ ignore_eos=ignore_eos,
+ extra_body=extra_body,
+ )
+
+ test_output = await wait_for_endpoint(
+ request_func,
+ test_input,
+ session,
+ timeout_seconds=ready_check_timeout_sec,
+ )
+ if not test_output.success:
+ raise ValueError(
+ "Initial test run failed - Please make sure benchmark arguments "
+ f"are correctly specified. Error: {test_output.error}")
+ else:
+ print("Initial test run completed. Starting main benchmark run...")
+
+ if lora_modules:
+ # For each input request, choose a LoRA module at random.
+ lora_modules = iter(
+ [random.choice(lora_modules) for _ in range(len(input_requests))])
+
+ if profile:
+ print("Starting profiler...")
+ profile_input = RequestFuncInput(model=model_id,
+ model_name=model_name,
+ prompt=test_prompt,
+ api_url=base_url + "/start_profile",
+ prompt_len=test_prompt_len,
+ output_len=test_output_len,
+ logprobs=logprobs,
+ multi_modal_content=test_mm_content,
+ ignore_eos=ignore_eos,
+ extra_body=extra_body)
+ profile_output = await request_func(
+ request_func_input=profile_input, session=session)
+ if profile_output.success:
+ print("Profiler started")
+
+ distribution = ("Poisson process" if burstiness == 1.0
+ else "Gamma distribution")
+
+ if ramp_up_strategy is not None:
+ print(f"Traffic ramp-up strategy: {ramp_up_strategy}.")
+ print(f"Will increase RPS from {ramp_up_start_rps} to "
+ f"{ramp_up_end_rps} RPS over the duration of the benchmark.")
+ else:
+ print(f"Traffic request rate: {request_rate}")
+
+ print(f"Burstiness factor: {burstiness} ({distribution})")
+ print(f"Maximum request concurrency: {max_concurrency}")
+
+ pbar = None if disable_tqdm else tqdm(total=len(input_requests))
+
+ # This can be used once the minimum Python version is 3.10 or higher,
+ # and it will simplify the code in limited_request_func.
+ # semaphore = (asyncio.Semaphore(max_concurrency)
+ # if max_concurrency else contextlib.nullcontext())
+ semaphore = (asyncio.Semaphore(max_concurrency)
+ if max_concurrency else None)
+
+ async def limited_request_func(request_func_input, session, pbar):
+ if semaphore is None:
+ return await request_func(request_func_input=request_func_input,
+ session=session,
+ pbar=pbar)
+ async with semaphore:
+ return await request_func(request_func_input=request_func_input,
+ session=session,
+ pbar=pbar)
+
+ if trace_dataset_path is not None:
+ iterator = get_request_trace(input_requests)
+ else:
+ iterator = get_request(
+ input_requests,
+ request_rate,
+ burstiness,
+ ramp_up_strategy,
+ ramp_up_start_rps,
+ ramp_up_end_rps,
+ )
+
+ benchmark_start_time = time.perf_counter()
+ tasks: list[asyncio.Task] = []
+
+ rps_change_events = []
+ last_int_rps = -1
+ if ramp_up_strategy is not None and ramp_up_start_rps is not None:
+ last_int_rps = ramp_up_start_rps
+ rps_change_events.append({
+ "rps": last_int_rps,
+ "timestamp": datetime.now().isoformat(),
+ })
+
+ async for request, current_request_rate in iterator:
+ if ramp_up_strategy is not None:
+ current_int_rps = int(current_request_rate)
+ if current_int_rps > last_int_rps:
+ timestamp = datetime.now().isoformat()
+ for rps_val in range(last_int_rps + 1, current_int_rps + 1):
+ rps_change_events.append({
+ "rps": rps_val,
+ "timestamp": timestamp
+ })
+ last_int_rps = current_int_rps
+ prompt, prompt_len, output_len, mm_content = (
+ request.prompt,
+ request.prompt_len,
+ request.expected_output_len,
+ request.multi_modal_data,
+ )
+ req_model_id, req_model_name = model_id, model_name
+ if lora_modules:
+ req_lora_module = next(lora_modules)
+ req_model_id, req_model_name = req_lora_module, req_lora_module
+
+ request_func_input = RequestFuncInput(model=req_model_id,
+ model_name=req_model_name,
+ prompt=prompt,
+ api_url=api_url,
+ prompt_len=prompt_len,
+ output_len=output_len,
+ logprobs=logprobs,
+ multi_modal_content=mm_content,
+ ignore_eos=ignore_eos,
+ extra_body=extra_body)
+ tasks.append(
+ asyncio.create_task(
+ limited_request_func(request_func_input=request_func_input,
+ session=session,
+ pbar=pbar)))
+ outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks)
+
+ if pbar is not None:
+ pbar.close()
+
+ benchmark_duration = time.perf_counter() - benchmark_start_time
+
+ metrics, actual_output_lens = calculate_metrics(
+ input_requests=input_requests,
+ outputs=outputs,
+ dur_s=benchmark_duration,
+ tokenizer=tokenizer,
+ selected_percentiles=selected_percentiles,
+ goodput_config_dict=goodput_config_dict,
+ trace_output_path=trace_output_path,
+ )
+
+ print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='='))
+ print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
+ if max_concurrency is not None:
+ print("{:<40} {:<10}".format("Maximum request concurrency:",
+ max_concurrency))
+ if request_rate != float('inf'):
+ print("{:<40} {:<10.2f}".format("Request rate configured (RPS):",
+ request_rate ))
+ print("{:<40} {:<10.2f}".format("Benchmark duration (s):",
+ benchmark_duration))
+ print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
+ print("{:<40} {:<10}".format("Total generated tokens:",
+ metrics.total_output))
+ print("{:<40} {:<10.2f}".format("Request throughput (req/s):",
+ metrics.request_throughput))
+ if goodput_config_dict:
+ print("{:<40} {:<10.2f}".format("Request goodput (req/s):",
+ metrics.request_goodput))
+ print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):",
+ metrics.output_throughput))
+ print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):",
+ metrics.total_token_throughput))
+
+ result = {
+ "duration": benchmark_duration,
+ "completed": metrics.completed,
+ "total_input_tokens": metrics.total_input,
+ "total_output_tokens": metrics.total_output,
+ "request_throughput": metrics.request_throughput,
+ "request_goodput":
+ metrics.request_goodput if goodput_config_dict else None,
+ "output_throughput": metrics.output_throughput,
+ "total_token_throughput": metrics.total_token_throughput,
+ "input_lens": [output.prompt_len for output in outputs],
+ "output_lens": actual_output_lens,
+ "ttfts": [output.ttft for output in outputs],
+ "itls": [output.itl for output in outputs],
+ "generated_texts": [output.generated_text for output in outputs],
+ "errors": [output.error for output in outputs],
+ }
+
+ if rps_change_events:
+ result["rps_change_events"] = rps_change_events
+
+ def process_one_metric(
+ # E.g., "ttft"
+ metric_attribute_name: str,
+ # E.g., "TTFT"
+ metric_name: str,
+ # E.g., "Time to First Token"
+ metric_header: str,
+ ):
+ # This function prints and adds statistics of the specified
+ # metric.
+ if metric_attribute_name not in selected_percentile_metrics:
+ return
+ print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-'))
+ print("{:<40} {:<10.2f}".format(
+ f"Mean {metric_name} (ms):",
+ getattr(metrics, f"mean_{metric_attribute_name}_ms")))
+ print("{:<40} {:<10.2f}".format(
+ f"Median {metric_name} (ms):",
+ getattr(metrics, f"median_{metric_attribute_name}_ms")))
+ result[f"mean_{metric_attribute_name}_ms"] = getattr(
+ metrics, f"mean_{metric_attribute_name}_ms")
+ result[f"median_{metric_attribute_name}_ms"] = getattr(
+ metrics, f"median_{metric_attribute_name}_ms")
+ result[f"std_{metric_attribute_name}_ms"] = getattr(
+ metrics, f"std_{metric_attribute_name}_ms")
+ for p, value in getattr(metrics,
+ f"percentiles_{metric_attribute_name}_ms"):
+ p_word = str(int(p)) if int(p) == p else str(p)
+ print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):",
+ value))
+ result[f"p{p_word}_{metric_attribute_name}_ms"] = value
+
+ process_one_metric("ttft", "TTFT", "Time to First Token")
+ process_one_metric("tpot", "TPOT",
+ "Time per Output Token (excl. 1st token)")
+ process_one_metric("itl", "ITL", "Inter-token Latency")
+ process_one_metric("e2el", "E2EL", "End-to-end Latency")
+
+ print("=" * 50)
+
+ if profile:
+ print("Stopping profiler...")
+ profile_input = RequestFuncInput(
+ model=model_id,
+ prompt=test_prompt,
+ api_url=base_url + "/stop_profile",
+ prompt_len=test_prompt_len,
+ output_len=test_output_len,
+ logprobs=logprobs,
+ )
+ profile_output = await request_func(
+ request_func_input=profile_input, session=session)
+ if profile_output.success:
+ print("Profiler stopped")
+
+ await session.close()
+ return result
+
+
+def check_goodput_args(args):
+ # Check and parse goodput arguments
+ goodput_config_dict = {}
+ VALID_NAMES = ["ttft", "tpot", "e2el"]
+ if args.goodput:
+ goodput_config_dict = parse_goodput(args.goodput)
+ for slo_name, slo_val in goodput_config_dict.items():
+ if slo_name not in VALID_NAMES:
+ raise ValueError(
+ f"Invalid metric name found, {slo_name}: {slo_val}. "
+ "The service level objective name should be one of "
+ f"{str(VALID_NAMES)}. ")
+ if slo_val < 0:
+ raise ValueError(
+ f"Invalid value found, {slo_name}: {slo_val}. "
+ "The service level objective value should be "
+ "non-negative.")
+ return goodput_config_dict
+
+
+def parse_goodput(slo_pairs):
+ goodput_config_dict = {}
+ try:
+ for slo_pair in slo_pairs:
+ slo_name, slo_val = slo_pair.split(":")
+ goodput_config_dict[slo_name] = float(slo_val)
+ except ValueError as err:
+ raise argparse.ArgumentTypeError(
+ "Invalid format found for service level objectives. "
+ "Specify service level objectives for goodput as \"KEY:VALUE\" "
+ "pairs, where the key is a metric name, and the value is a "
+ "number in milliseconds.") from err
+ return goodput_config_dict
+
+
+def save_to_pytorch_benchmark_format(args: argparse.Namespace,
+ results: dict[str, Any],
+ file_name: str) -> None:
+ metrics = [
+ "median_ttft_ms", "mean_ttft_ms", "std_ttft_ms", "p99_ttft_ms",
+ "mean_tpot_ms", "median_tpot_ms", "std_tpot_ms", "p99_tpot_ms",
+ "median_itl_ms", "mean_itl_ms", "std_itl_ms", "p99_itl_ms"
+ ]
+ # These raw data might be useful, but they are rather big. They can be added
+ # later if needed
+ ignored_metrics = ["ttfts", "itls", "generated_texts", "errors"]
+ pt_records = convert_to_pytorch_benchmark_format(
+ args=args,
+ metrics={k: [results[k]]
+ for k in metrics if k in results},
+ extra_info={
+ k: results[k]
+ for k in results if k not in metrics and k not in ignored_metrics
+ })
+ if pt_records:
+ # Don't use json suffix here as we don't want CI to pick it up
+ pt_file = f"{os.path.splitext(file_name)[0]}.pytorch.json"
+ write_to_json(pt_file, pt_records)
+
+
+def add_cli_args(parser: argparse.ArgumentParser):
+ add_dataset_parser(parser)
+ parser.add_argument(
+ "--endpoint-type",
+ type=str,
+ default="openai",
+ choices=list(ASYNC_REQUEST_FUNCS.keys()),
+ )
+ parser.add_argument(
+ "--label",
+ type=str,
+ default=None,
+ help="The label (prefix) of the benchmark results. If not specified, "
+ "the endpoint type will be used as the label.",
+ )
+ parser.add_argument(
+ "--backend",
+ type=str,
+ default="vllm",
+ choices=list(ASYNC_REQUEST_FUNCS.keys()),
+ )
+ parser.add_argument(
+ "--base-url",
+ type=str,
+ default=None,
+ help="Server or API base url if not using http host and port.",
+ )
+ # Use 127.0.0.1 here instead of localhost to force the use of ipv4
+ parser.add_argument("--host", type=str, default="127.0.0.1")
+ parser.add_argument("--port", type=int, default=8000)
+ parser.add_argument(
+ "--endpoint",
+ type=str,
+ default="/v1/completions",
+ help="API endpoint.",
+ )
+ parser.add_argument(
+ "--max-concurrency",
+ type=int,
+ default=None,
+ help="Maximum number of concurrent requests. This can be used "
+ "to help simulate an environment where a higher level component "
+ "is enforcing a maximum number of concurrent requests. While the "
+ "--request-rate argument controls the rate at which requests are "
+ "initiated, this argument will control how many are actually allowed "
+ "to execute at a time. This means that when used in combination, the "
+ "actual request rate may be lower than specified with --request-rate, "
+ "if the server is not processing requests fast enough to keep up.")
+
+ parser.add_argument(
+ "--model",
+ type=str,
+ required=True,
+ help="Name of the model.",
+ )
+ parser.add_argument(
+ "--tokenizer",
+ type=str,
+ help=
+ "Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501
+ )
+ parser.add_argument("--use-beam-search", action="store_true")
+ parser.add_argument(
+ "--logprobs",
+ type=int,
+ default=None,
+ help=("Number of logprobs-per-token to compute & return as part of "
+ "the request. If unspecified, then either (1) if beam search "
+ "is disabled, no logprobs are computed & a single dummy "
+ "logprob is returned for each token; or (2) if beam search "
+ "is enabled 1 logprob per token is computed"),
+ )
+ parser.add_argument(
+ "--request-rate",
+ type=float,
+ default=float("inf"),
+ help="Number of requests per second. If this is inf, "
+ "then all the requests are sent at time 0. "
+ "Otherwise, we use Poisson process or gamma distribution "
+ "to synthesize the request arrival times.",
+ )
+ parser.add_argument(
+ "--burstiness",
+ type=float,
+ default=1.0,
+ help="Burstiness factor of the request generation. "
+ "Only take effect when request_rate is not inf. "
+ "Default value is 1, which follows Poisson process. "
+ "Otherwise, the request intervals follow a gamma distribution. "
+ "A lower burstiness value (0 < burstiness < 1) results in more "
+ "bursty requests. A higher burstiness value (burstiness > 1) "
+ "results in a more uniform arrival of requests.",
+ )
+ parser.add_argument(
+ "--trust-remote-code",
+ action="store_true",
+ help="Trust remote code from huggingface",
+ )
+ parser.add_argument(
+ "--disable-tqdm",
+ action="store_true",
+ help="Specify to disable tqdm progress bar.",
+ )
+ parser.add_argument(
+ "--profile",
+ action="store_true",
+ help="Use Torch Profiler. The endpoint must be launched with "
+ "VLLM_TORCH_PROFILER_DIR to enable profiler.",
+ )
+ parser.add_argument(
+ "--save-result",
+ action="store_true",
+ help="Specify to save benchmark results to a json file",
+ )
+ parser.add_argument(
+ "--save-detailed",
+ action="store_true",
+ help="When saving the results, whether to include per request "
+ "information such as response, error, ttfs, tpots, etc.",
+ )
+ parser.add_argument(
+ "--append-result",
+ action="store_true",
+ help="Append the benchmark result to the existing json file.",
+ )
+ parser.add_argument(
+ "--metadata",
+ metavar="KEY=VALUE",
+ nargs="*",
+ help="Key-value pairs (e.g, --metadata version=0.3.3 tp=1) "
+ "for metadata of this run to be saved in the result JSON file "
+ "for record keeping purposes.",
+ )
+ parser.add_argument(
+ "--result-dir",
+ type=str,
+ default=None,
+ help="Specify directory to save benchmark json results."
+ "If not specified, results are saved in the current directory.",
+ )
+ parser.add_argument(
+ "--result-filename",
+ type=str,
+ default=None,
+ help="Specify the filename to save benchmark json results."
+ "If not specified, results will be saved in "
+ "{label}-{args.request_rate}qps-{base_model_id}-{current_dt}.json" # noqa
+ " format.",
+ )
+ parser.add_argument(
+ "--ignore-eos",
+ action="store_true",
+ help="Set ignore_eos flag when sending the benchmark request."
+ "Warning: ignore_eos is not supported in deepspeed_mii and tgi.")
+ parser.add_argument(
+ "--percentile-metrics",
+ type=str,
+ default="ttft,tpot,itl",
+ help="Comma-separated list of selected metrics to report percentils. "
+ "This argument specifies the metrics to report percentiles. "
+ "Allowed metric names are \"ttft\", \"tpot\", \"itl\", \"e2el\". ")
+ parser.add_argument(
+ "--metric-percentiles",
+ type=str,
+ default="99",
+ help="Comma-separated list of percentiles for selected metrics. "
+ "To report 25-th, 50-th, and 75-th percentiles, use \"25,50,75\". "
+ "Default value is \"99\"."
+ "Use \"--percentile-metrics\" to select metrics.",
+ )
+ parser.add_argument(
+ "--goodput",
+ nargs="+",
+ required=False,
+ help="Specify service level objectives for goodput as \"KEY:VALUE\" "
+ "pairs, where the key is a metric name, and the value is in "
+ "milliseconds. Multiple \"KEY:VALUE\" pairs can be provided, "
+ "separated by spaces. Allowed request level metric names are "
+ "\"ttft\", \"tpot\", \"e2el\". For more context on the definition of "
+ "goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 "
+ "and the blog: https://hao-ai-lab.github.io/blogs/distserve",
+ )
+
+ sampling_group = parser.add_argument_group("sampling parameters")
+ sampling_group.add_argument(
+ "--top-p",
+ type=float,
+ default=None,
+ help="Top-p sampling parameter. Only has effect on "
+ "openai-compatible backends.",
+ )
+ sampling_group.add_argument(
+ "--top-k",
+ type=int,
+ default=None,
+ help="Top-k sampling parameter. Only has effect on "
+ "openai-compatible backends.",
+ )
+ sampling_group.add_argument(
+ "--min-p",
+ type=float,
+ default=None,
+ help="Min-p sampling parameter. Only has effect on "
+ "openai-compatible backends.",
+ )
+ sampling_group.add_argument(
+ "--temperature",
+ type=float,
+ default=None,
+ help="Temperature sampling parameter. Only has effect on "
+ "openai-compatible backends. If not specified, default to greedy "
+ "decoding (i.e. temperature==0.0).",
+ )
+
+ parser.add_argument(
+ '--tokenizer-mode',
+ type=str,
+ default="auto",
+ choices=['auto', 'slow', 'mistral', 'custom'],
+ help='The tokenizer mode.\n\n* "auto" will use the '
+ 'fast tokenizer if available.\n* "slow" will '
+ 'always use the slow tokenizer. \n* '
+ '"mistral" will always use the `mistral_common` tokenizer. \n*'
+ '"custom" will use --tokenizer to select the preregistered tokenizer.')
+
+ parser.add_argument("--served-model-name",
+ type=str,
+ default=None,
+ help="The model name used in the API. "
+ "If not specified, the model name will be the "
+ "same as the ``--model`` argument. ")
+
+ parser.add_argument("--lora-modules",
+ nargs='+',
+ default=None,
+ help="A subset of LoRA module names passed in when "
+ "launching the server. For each request, the "
+ "script chooses a LoRA module at random.")
+
+ parser.add_argument(
+ "--ramp-up-strategy",
+ type=str,
+ default=None,
+ choices=["linear", "exponential"],
+ help="The ramp-up strategy. This would be used to "
+ "ramp up the request rate from initial RPS to final "
+ "RPS rate (specified by --ramp-up-start-rps and "
+ "--ramp-up-end-rps.) over the duration of the benchmark."
+ )
+ parser.add_argument(
+ "--ramp-up-start-rps",
+ type=int,
+ default=None,
+ help="The starting request rate for ramp-up (RPS). "
+ "Needs to be specified when --ramp-up-strategy is used.",
+ )
+ parser.add_argument(
+ "--ramp-up-end-rps",
+ type=int,
+ default=None,
+ help="The ending request rate for ramp-up (RPS). "
+ "Needs to be specified when --ramp-up-strategy is used.",
+ )
+ parser.add_argument(
+ "--ready-check-timeout-sec",
+ type=int,
+ default=600,
+ help="Maximum time to wait for the endpoint to become ready "
+ "in seconds (default: 600 seconds / 10 minutes).",
+ )
+ parser.add_argument("--trace-dataset-path", type=str, default=None)
+ parser.add_argument("--trace-output-path", type=str, default=None)
+
+
+def main(args: argparse.Namespace) -> dict[str, Any]:
+ return asyncio.run(main_async(args))
+
+async def main_async(args: argparse.Namespace) -> dict[str, Any]:
+ print(args)
+ random.seed(args.seed)
+ np.random.seed(args.seed)
+
+ # Validate ramp-up arguments
+ if args.ramp_up_strategy is not None:
+ if args.request_rate != float("inf"):
+ raise ValueError(
+ "When using ramp-up, do not specify --request-rate. "
+ "The request rate will be controlled by ramp-up parameters. "
+ "Please remove the --request-rate argument."
+ )
+ if args.ramp_up_start_rps is None or args.ramp_up_end_rps is None:
+ raise ValueError(
+ "When using --ramp-up-strategy, both --ramp-up-start-rps and "
+ "--ramp-up-end-rps must be specified"
+ )
+ if args.ramp_up_start_rps < 0 or args.ramp_up_end_rps < 0:
+ raise ValueError("Ramp-up start and end RPS must be non-negative")
+ if args.ramp_up_start_rps > args.ramp_up_end_rps:
+ raise ValueError("Ramp-up start RPS must be less than end RPS")
+ if (args.ramp_up_strategy == "exponential"
+ and args.ramp_up_start_rps == 0):
+ raise ValueError(
+ "For exponential ramp-up, the start RPS cannot be 0.")
+
+ endpoint_type = args.endpoint_type
+ label = args.label
+ model_id = args.model
+ model_name = args.served_model_name
+ tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model
+ tokenizer_mode = args.tokenizer_mode
+
+ if args.base_url is not None:
+ api_url = f"{args.base_url}{args.endpoint}"
+ base_url = f"{args.base_url}"
+ else:
+ api_url = f"http://{args.host}:{args.port}{args.endpoint}"
+ base_url = f"http://{args.host}:{args.port}"
+
+ tokenizer = get_tokenizer(tokenizer_id,
+ tokenizer_mode=tokenizer_mode,
+ trust_remote_code=args.trust_remote_code)
+
+ if args.dataset_name is None:
+ raise ValueError(
+ "Please specify '--dataset-name' and the corresponding "
+ "'--dataset-path' if required.")
+
+ # Load the dataset.
+ input_requests = get_samples(args, tokenizer)
+ goodput_config_dict = check_goodput_args(args)
+
+ # Collect the sampling parameters.
+ sampling_params = {
+ k: v
+ for k, v in {
+ "top_p": args.top_p,
+ "top_k": args.top_k,
+ "min_p": args.min_p,
+ "temperature": args.temperature,
+ }.items() if v is not None
+ }
+
+ # Sampling parameters are only supported by openai-compatible backend.
+ if sampling_params and args.backend not in OPENAI_COMPATIBLE_BACKENDS:
+ raise ValueError("Sampling parameters are only supported by "
+ "openai-compatible backends.")
+
+ if "temperature" not in sampling_params:
+ sampling_params["temperature"] = 0.0 # Default to greedy decoding.
+
+ # Avoid GC processing "static" data - reduce pause times.
+ gc.collect()
+ gc.freeze()
+
+ benchmark_result = await benchmark(
+ endpoint_type=args.endpoint_type,
+ api_url=api_url,
+ base_url=base_url,
+ model_id=model_id,
+ model_name=model_name,
+ tokenizer=tokenizer,
+ input_requests=input_requests,
+ logprobs=args.logprobs,
+ request_rate=args.request_rate,
+ burstiness=args.burstiness,
+ disable_tqdm=args.disable_tqdm,
+ profile=args.profile,
+ selected_percentile_metrics=args.percentile_metrics.split(","),
+ selected_percentiles=[
+ float(p) for p in args.metric_percentiles.split(",")
+ ],
+ ignore_eos=args.ignore_eos,
+ goodput_config_dict=goodput_config_dict,
+ max_concurrency=args.max_concurrency,
+ lora_modules=args.lora_modules,
+ extra_body=sampling_params,
+ ramp_up_strategy=args.ramp_up_strategy,
+ ramp_up_start_rps=args.ramp_up_start_rps,
+ ramp_up_end_rps=args.ramp_up_end_rps,
+ ready_check_timeout_sec=args.ready_check_timeout_sec,
+ trace_dataset_path=args.trace_dataset_path,
+ trace_output_path=args.trace_output_path,
+ )
+
+ # Save config and results to json
+ result_json: dict[str, Any] = {}
+
+ # Setup
+ current_dt = datetime.now().strftime("%Y%m%d-%H%M%S")
+ result_json["date"] = current_dt
+ result_json["endpoint_type"] = args.endpoint_type
+ result_json["label"] = label
+ result_json["model_id"] = model_id
+ result_json["tokenizer_id"] = tokenizer_id
+ result_json["num_prompts"] = args.num_prompts
+
+ # Metadata
+ if args.metadata:
+ for item in args.metadata:
+ if "=" in item:
+ kvstring = item.split("=")
+ result_json[kvstring[0].strip()] = kvstring[1].strip()
+ else:
+ raise ValueError(
+ "Invalid metadata format. Please use KEY=VALUE format."
+ )
+
+ # Traffic
+ result_json["request_rate"] = (args.request_rate if args.request_rate
+ < float("inf") else "inf")
+ result_json["burstiness"] = args.burstiness
+ result_json["max_concurrency"] = args.max_concurrency
+
+ if args.ramp_up_strategy is not None:
+ result_json["ramp_up_strategy"] = args.ramp_up_strategy
+ result_json["ramp_up_start_rps"] = args.ramp_up_start_rps
+ result_json["ramp_up_end_rps"] = args.ramp_up_end_rps
+
+ # Merge with benchmark result
+ result_json = {**result_json, **benchmark_result}
+
+ if not args.save_detailed:
+ # Remove fields with too many data points
+ for field in [
+ "input_lens",
+ "output_lens",
+ "ttfts",
+ "itls",
+ "generated_texts",
+ "errors",
+ ]:
+ if field in result_json:
+ del result_json[field]
+ if field in benchmark_result:
+ del benchmark_result[field]
+
+ # Save to file
+ if args.save_result or args.append_result:
+ base_model_id = model_id.split("/")[-1]
+ max_concurrency_str = (f"-concurrency{args.max_concurrency}"
+ if args.max_concurrency is not None else "")
+ label = label or endpoint_type
+ if args.ramp_up_strategy is not None:
+ file_name = f"{label}-ramp-up-{args.ramp_up_strategy}-{args.ramp_up_start_rps}qps-{args.ramp_up_end_rps}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa
+ else:
+ file_name = f"{label}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa
+ if args.result_filename:
+ file_name = args.result_filename
+ if args.result_dir:
+ os.makedirs(args.result_dir, exist_ok=True)
+ file_name = os.path.join(args.result_dir, file_name)
+ with open(file_name,
+ mode="a+" if args.append_result else "w",
+ encoding="utf-8") as outfile:
+ # Append a newline.
+ if args.append_result and outfile.tell() != 0:
+ outfile.write("\n")
+ json.dump(result_json, outfile)
+ save_to_pytorch_benchmark_format(args, result_json, file_name)
+
+ return result_json
diff --git a/benchmark/reproducibility/plot.py b/benchmark/reproducibility/plot.py
new file mode 100644
index 000000000..8e286b71b
--- /dev/null
+++ b/benchmark/reproducibility/plot.py
@@ -0,0 +1,39 @@
+import pandas as pd
+import matplotlib.pyplot as plt
+
+# modify the path to your CSV file as needed
+df = pd.read_csv("code_output.csv")
+
+request = df["request"]
+timestamp = df["timestamp"]
+input_length = df["input_len"]
+output_length = df["output_len"]
+ttft = df["ttft_ms"]
+tpot = df["tpot_ms"]
+latency = df["e2el_ms"]
+
+# TTFT
+plt.figure()
+plt.plot(ttft)
+plt.xlabel("Request Id")
+plt.ylabel("Milliseconds")
+plt.title("TTFT")
+plt.savefig("ttft.png")
+
+# TPOT
+plt.figure()
+plt.plot(tpot, 'o')
+plt.yscale('log')
+plt.xlabel("Request Id")
+plt.ylabel("Milliseconds")
+plt.title("TPOT")
+plt.savefig("tpot.png")
+
+# Latency
+plt.figure()
+plt.plot(latency, 'o')
+plt.xlabel("Request Id")
+plt.ylabel("Milliseconds")
+plt.title("Completion Time")
+plt.savefig("latency.png")
+
diff --git a/benchmark/reproducibility/read_trace.py b/benchmark/reproducibility/read_trace.py
new file mode 100644
index 000000000..134809b59
--- /dev/null
+++ b/benchmark/reproducibility/read_trace.py
@@ -0,0 +1,185 @@
+import json
+import csv
+from datetime import datetime
+from collections import defaultdict
+
+def parse_timestamp(timestamp_str):
+ """Parse timestamp string and return datetime object."""
+ try:
+ # Truncate microseconds to 6 digits if longer
+ if '.' in timestamp_str:
+ parts = timestamp_str.split('.')
+ if len(parts[1]) > 6:
+ timestamp_str = parts[0] + '.' + parts[1][:6]
+ return datetime.strptime(timestamp_str, "%Y-%m-%d %H:%M:%S.%f")
+ except ValueError:
+ # Try without microseconds if the format doesn't match
+ return datetime.strptime(timestamp_str, "%Y-%m-%d %H:%M:%S")
+
+def convert_to_relative_ms(timestamps):
+ """Convert list of timestamp strings to milliseconds relative to first entry."""
+ if not timestamps:
+ return []
+
+ first_time = parse_timestamp(timestamps[0])
+ relative_ms = []
+
+ for ts in timestamps:
+ current_time = parse_timestamp(ts)
+ delta = (current_time - first_time).total_seconds() * 1000
+ relative_ms.append(delta)
+
+ return relative_ms
+
+def read_csv(file_path):
+ data = []
+ timestamps = []
+ with open(file_path, 'r') as f:
+ reader = csv.DictReader(f)
+ for row in reader:
+ timestamps.append(row.get("TIMESTAMP"))
+ entry = {
+ "timestamp": row.get("TIMESTAMP"),
+ "input_length": int(row.get("ContextTokens")),
+ "output_length": int(row.get("GeneratedTokens"))
+ }
+ data.append(entry)
+
+ # Convert timestamps to relative milliseconds
+ relative_ms = convert_to_relative_ms(timestamps)
+ for i, entry in enumerate(data):
+ entry["timestamp"] = relative_ms[i]
+
+ return data
+
+def read_jsonl(file_path):
+ data = []
+ with open(file_path, 'r') as f:
+ for line in f:
+ obj = json.loads(line)
+ # Extract required columns
+ entry = {
+ "timestamp": obj.get("timestamp"),
+ "input_length": obj.get("input_length"),
+ "output_length": obj.get("output_length")
+ }
+ data.append(entry)
+ return data
+
+# Example usage:
+# file_path = "synthetic_trace.jsonl"
+# file_path = "toolagent_trace.jsonl"
+file_path = "conversation_trace.jsonl"
+records = read_jsonl(file_path)
+
+# file_path = "AzureLLMInferenceTrace_conv.csv"
+# file_path = "AzureLLMInferenceTrace_code.csv"
+# records = read_csv(file_path)
+
+# file_path = "AzureLLMInferenceTrace_code_1min_section.jsonl"
+# records = read_jsonl(file_path)
+
+
+# threshold_low = 2 * 60 * 1000
+# threshold_high = 17 * 60 * 1000
+threshold_low = 0 * 60 * 1000
+threshold_high = 15 * 60 * 1000
+records = [record for record in records if threshold_low <= record["timestamp"] <= threshold_high]
+
+# print(f"Filtered records between {threshold_low} ms and {threshold_high} ms: {len(records)} entries.")
+
+num_prompts = len(records)
+timestamp = [int(record["timestamp"]) for record in records]
+timestamp = [ts - timestamp[0] for ts in timestamp] # Normalize to start from 0 ms
+input_length = [record["input_length"] for record in records]
+output_length = [record["output_length"] for record in records]
+
+
+# Calculate sum of input_length per batch (prompts with same timestamp)
+
+batch_sums = defaultdict(int)
+batch_counts = defaultdict(int)
+
+for req in range(num_prompts):
+ ts = timestamp[req]
+ batch_sums[ts] += input_length[req]
+ batch_counts[ts] += 1
+
+# Print batch statistics
+print(f"\nBatch statistics (prompts with same timestamp):")
+print(f"Number of unique timestamps (batches): {len(batch_sums)}")
+for ts in sorted(batch_sums.keys()):
+ print(f"Timestamp {ts} ms: {batch_counts[ts]} prompts, total input_length = {batch_sums[ts]}")
+print(f"max input_length in a batch: {max(batch_sums.values())}\n")
+
+# # Write to jsonl file
+# output_file = file_path.replace('.jsonl', '_processed_15mins.jsonl').replace('.csv', '_processed_15mins.jsonl')
+# with open(output_file, 'w') as f:
+# for req in range(num_prompts):
+# entry = {
+# "timestamp": timestamp[req],
+# "input_length": input_length[req],
+# "output_length": output_length[req]
+# }
+# f.write(json.dumps(entry) + '\n')
+print(f"timestamp input_length output_length")
+for req in range(num_prompts):
+ # print(f"Record {req}: Timestamp={timestamp[req]}, Input Length={input_length[req]}, Output Length={output_length[req]}")
+ print(f"{timestamp[req]} {input_length[req]} {output_length[req]}")
+
+import matplotlib.pyplot as plt
+
+window_size = 80000 # in ms
+request_rate = []
+time_axis = [start_time/1000 for start_time in range(0, max(timestamp)+window_size, window_size)]
+for start_time in range(0, max(timestamp)+window_size, window_size):
+ new_requests = [sample for sample in timestamp if start_time <= sample < start_time+window_size]
+ request_rate.append(len(new_requests)/(window_size/1000))
+
+plt.figure(figsize=(10, 6))
+plt.plot(time_axis, request_rate, linestyle='-')
+plt.xlabel('Time (s)')
+plt.ylabel('Request Rate (requests/second)')
+plt.title(f'Request Rate over Time {file_path}')
+plt.grid(True)
+plt.tight_layout()
+plt.show()
+
+# plt.figure(figsize=(10, 6))
+# plt.plot(range(num_prompts), input_length, linestyle='-')
+# plt.xlabel('Request Index')
+# plt.ylabel('Input Length')
+# plt.title(f'Input Length over Requests {file_path}')
+# plt.grid(True)
+# plt.tight_layout()
+# plt.show()
+
+plt.figure(figsize=(10, 6))
+plt.plot(range(num_prompts), output_length, linestyle='-')
+plt.xlabel('Request Index')
+plt.ylabel('Output Length')
+plt.title(f'Output Length over Requests {file_path}')
+plt.grid(True)
+plt.tight_layout()
+plt.show()
+
+plt.figure(figsize=(10, 6))
+plt.plot(range(num_prompts), timestamp, linestyle='-')
+plt.xlabel('Request Index')
+plt.ylabel('Timestamp')
+plt.title(f'Request Timestamps {file_path}')
+plt.grid(True)
+plt.tight_layout()
+plt.show()
+
+plt.figure(figsize=(10, 6))
+plt.scatter(output_length, input_length, alpha=0.5)
+plt.xlabel('Output Length')
+plt.ylabel('Input Length')
+plt.yscale('log')
+plt.title(f'Input Length vs Output Length {file_path}')
+plt.grid(True, alpha=0.3)
+plt.tight_layout()
+plt.show()
+
+print(f"Total number of prompts: {num_prompts} in {file_path}")
\ No newline at end of file
diff --git a/benchmark/reproducibility/server.sh b/benchmark/reproducibility/server.sh
new file mode 100644
index 000000000..ab5be3048
--- /dev/null
+++ b/benchmark/reproducibility/server.sh
@@ -0,0 +1,14 @@
+
+export VLLM_DISABLE_COMPILE_CACHE=1
+
+# two models, three parallelisms, six experiments
+
+model="RedHatAI/Llama-3.3-70B-Instruct-FP8-dynamic"
+ARCTIC_INFERENCE_ENABLED=1 vllm serve $model --disable-log-requests --no-enable-prefix-caching --ulysses-sequence-parallel-size 8 --enable-shift-parallel --max-num-batched-tokens 131072
+# vllm serve $model --disable-log-requests --no-enable-prefix-caching --tensor-parallel-size 8 --max-num-batched-tokens 131072
+# vllm serve $model --disable-log-requests --no-enable-prefix-caching --data-parallel-size 8 --max-num-batched-tokens 131072
+
+# model="Qwen/Qwen3-32B-FP8"
+# ARCTIC_INFERENCE_ENABLED=1 vllm serve $model --disable-log-requests --no-enable-prefix-caching --ulysses-sequence-parallel-size 8 --enable-shift-parallel --max-num-batched-tokens 131072 --kv-cache-dtype fp8
+# vllm serve $model --disable-log-requests --no-enable-prefix-caching --tensor-parallel-size 8 --max-num-batched-tokens 131072 --kv-cache-dtype fp8
+# vllm serve $model --disable-log-requests --no-enable-prefix-caching --data-parallel-size 8 --max-num-batched-tokens 131072 --kv-cache-dtype fp8
diff --git a/benchmark/reproducibility/test.jsonl b/benchmark/reproducibility/test.jsonl
new file mode 100644
index 000000000..e5566daf6
--- /dev/null
+++ b/benchmark/reproducibility/test.jsonl
@@ -0,0 +1,26 @@
+{"timestamp": 0, "input_length": 6758, "output_length": 500}
+{"timestamp": 0, "input_length": 7322, "output_length": 490}
+{"timestamp": 0, "input_length": 7236, "output_length": 794}
+{"timestamp": 0, "input_length": 2290, "output_length": 316}
+{"timestamp": 0, "input_length": 6760, "output_length": 3}
+{"timestamp": 0, "input_length": 4834, "output_length": 173}
+{"timestamp": 0, "input_length": 23141, "output_length": 453}
+{"timestamp": 0, "input_length": 26888, "output_length": 458}
+{"timestamp": 0, "input_length": 10498, "output_length": 402}
+{"timestamp": 0, "input_length": 17450, "output_length": 610}
+{"timestamp": 3000, "input_length": 13544, "output_length": 71}
+{"timestamp": 3000, "input_length": 87169, "output_length": 402}
+{"timestamp": 3000, "input_length": 6324, "output_length": 548}
+{"timestamp": 3000, "input_length": 2012, "output_length": 354}
+{"timestamp": 3000, "input_length": 7324, "output_length": 14}
+{"timestamp": 3000, "input_length": 9418, "output_length": 145}
+{"timestamp": 3000, "input_length": 915, "output_length": 355}
+{"timestamp": 3000, "input_length": 12846, "output_length": 466}
+{"timestamp": 3000, "input_length": 20506, "output_length": 929}
+{"timestamp": 3000, "input_length": 16609, "output_length": 349}
+{"timestamp": 3000, "input_length": 26353, "output_length": 370}
+{"timestamp": 3000, "input_length": 6059, "output_length": 475}
+{"timestamp": 3000, "input_length": 5954, "output_length": 420}
+{"timestamp": 3000, "input_length": 11339, "output_length": 848}
+{"timestamp": 3000, "input_length": 15172, "output_length": 80}
+{"timestamp": 3000, "input_length": 45922, "output_length": 265}
diff --git a/benchmark/reproducibility/vibe_test.py b/benchmark/reproducibility/vibe_test.py
new file mode 100644
index 000000000..46a15d459
--- /dev/null
+++ b/benchmark/reproducibility/vibe_test.py
@@ -0,0 +1,25 @@
+import vllm
+from vllm import LLM, SamplingParams
+
+vllm.plugins.load_general_plugins()
+
+llm = LLM(
+ model="RedHatAI/Llama-3.3-70B-Instruct-FP8-dynamic",
+ ulysses_sequence_parallel_size=8,
+ enable_shift_parallel=True,
+ shift_parallel_threshold=8,
+)
+
+conversation = [
+ {
+ "role": "user",
+ "content": "Write an essay about the importance of higher education.",
+ },
+]
+
+sampling_params = SamplingParams(temperature=0.0, max_tokens=800)
+
+outputs = llm.chat(conversation, sampling_params=sampling_params)
+
+print(outputs[0].outputs[0].text)
+
diff --git a/benchmark/rollout/README.md b/benchmark/rollout/README.md
new file mode 100644
index 000000000..96fc91d47
--- /dev/null
+++ b/benchmark/rollout/README.md
@@ -0,0 +1,51 @@
+## Rollout Replay Patch for v0.14.1
+
+This patch extencds SamplingParams to specify the length of each sequence when n > 1.
+
+The patch is applied as `source patch_sampling.sh`.
+
+As a result, you can specify `max_tokens_n` as a list in sampling params and set `ignore_eos` so that each sequence generates exactly specified number of tokens.
+
+```
+# Sample prompts.
+prompts = [
+ "Hello, my name is",
+ "The president of the United States is",
+ "The capital of France is",
+ # "The future of AI is",
+]
+
+sampling_params = [SamplingParams(n=2,
+ temperature=0.8,
+ top_p=1.0,
+ max_tokens_n=[25, 50],
+ ignore_eos=True,
+ ),
+ SamplingParams(n=3,
+ temperature=0.8,
+ top_p=1.0,
+ max_tokens_n=[5, 10, 15],
+ ignore_eos=True,
+ ),
+ SamplingParams(n=1,
+ temperature=0.8,
+ top_p=1.0,
+ max_tokens=100,
+ ignore_eos=True,
+ # max_tokens_n=[100], this will be ineffective since n = 1
+ ),
+ ]
+
+outputs = llm.generate(prompts, sampling_params=sampling_params)
+```
+
+The number of resulting input and output tokens per sequence:
+```
+prompt 0 seq 0: input 5 output 25
+prompt 0 seq 1: input 5 output 50
+prompt 1 seq 0: input 7 output 5
+prompt 1 seq 1: input 7 output 10
+prompt 1 seq 2: input 7 output 15
+prompt 2 seq 0: input 5 output 100
+```
+
diff --git a/benchmark/rollout/parallel_sampling.patch b/benchmark/rollout/parallel_sampling.patch
new file mode 100644
index 000000000..8c7bbb7a5
--- /dev/null
+++ b/benchmark/rollout/parallel_sampling.patch
@@ -0,0 +1,18 @@
+--- /home/yak/myenv_v14/lib/python3.12/site-packages/vllm/v1/engine/parallel_sampling.py 2026-01-31 22:11:45.340329657 +0000
++++ parallel_sampling.py 2026-01-31 22:08:27.000000000 +0000
+@@ -66,11 +66,13 @@
+ Child `sampling_params` instance.
+ """
+ seed = self.sampling_params.seed
+- if self.cached_child_sampling_params:
++ # if self.cached_child_sampling_params:
+ # Reuse child sampling_params data structure
+- return self.cached_child_sampling_params
++ # return self.cached_child_sampling_params
+ # Build child sampling_params
+ child_sampling_params = copy(self.sampling_params)
++ print(f"child index: {index}, max_tokens_n: {child_sampling_params.max_tokens_n}")
++ child_sampling_params.max_tokens = child_sampling_params.max_tokens_n[index]
+ child_sampling_params.n = 1
+ if seed is None:
+ # Cache child sampling_params for later reuse
diff --git a/benchmark/rollout/patch_sampling.sh b/benchmark/rollout/patch_sampling.sh
new file mode 100755
index 000000000..531bed562
--- /dev/null
+++ b/benchmark/rollout/patch_sampling.sh
@@ -0,0 +1,12 @@
+
+VLLM_PATH="$(pip show vllm | awk '/^Location: /{print $2}')"
+
+if [ -z "$VLLM_PATH" ]; then
+ echo "Error: could not find VLLM in current env"
+ exit 1
+else
+ echo "VLLM path is: $VLLM_PATH"
+fi
+
+patch $VLLM_PATH/vllm/sampling_params.py < sampling_params.patch
+patch $VLLM_PATH/vllm/v1/engine/parallel_sampling.py < parallel_sampling.patch
diff --git a/benchmark/rollout/sampling_params.patch b/benchmark/rollout/sampling_params.patch
new file mode 100644
index 000000000..7647b8b6d
--- /dev/null
+++ b/benchmark/rollout/sampling_params.patch
@@ -0,0 +1,26 @@
+--- /home/yak/myenv_v14/lib/python3.12/site-packages/vllm/sampling_params.py 2026-01-31 22:11:45.180329435 +0000
++++ sampling_params.py 2026-01-31 22:16:23.000000000 +0000
+@@ -124,6 +124,7 @@
+ """
+
+ n: int = 1
++ max_tokens_n: list[int] | None = None
+ """Number of outputs to return for the given prompt request.
+
+ NOTE:
+@@ -250,6 +251,7 @@
+ @staticmethod
+ def from_optional(
+ n: int | None = 1,
++ max_tokens_n: list[int] | None = None,
+ presence_penalty: float | None = 0.0,
+ frequency_penalty: float | None = 0.0,
+ repetition_penalty: float | None = 1.0,
+@@ -289,6 +291,7 @@
+
+ return SamplingParams(
+ n=1 if n is None else n,
++ max_tokens_n=max_tokens_n,
+ presence_penalty=0.0 if presence_penalty is None else presence_penalty,
+ frequency_penalty=0.0 if frequency_penalty is None else frequency_penalty,
+ repetition_penalty=1.0
diff --git a/benchmark/trace/README.md b/benchmark/trace/README.md
new file mode 100644
index 000000000..29b0b9c73
--- /dev/null
+++ b/benchmark/trace/README.md
@@ -0,0 +1,18 @@
+The patch allows running datasets with the following json format:
+```jsonl
+{"timestamp": 15, "input_length": 1000, "output_length": 128}
+...
+```
+where each line corresponds to a request with random prompt to be sent at the timestamp (in ms) from t = 0.
+
+The patch is applied as
+```
+bash apply_vllm_bench_patch_v10p1.sh
+```
+This will modify a couple of files of `vllm bench serve`. Please make sure vllm versions match.
+
+Then use serve as
+```
+vllm bench serve --model $model --trace-dataset-path example_trace.jsonl --ignore-eos
+```
+
diff --git a/benchmark/trace/apply_vllm_bench_patch_v10p1.sh b/benchmark/trace/apply_vllm_bench_patch_v10p1.sh
new file mode 100644
index 000000000..397135015
--- /dev/null
+++ b/benchmark/trace/apply_vllm_bench_patch_v10p1.sh
@@ -0,0 +1,115 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+SERVE_PY="/home/yak/myvenv/lib/python3.10/site-packages/vllm/benchmarks/serve.py"
+DATASETS_PY="/home/yak/myvenv/lib/python3.10/site-packages/vllm/benchmarks/datasets.py"
+
+patch "$DATASETS_PY" << 'EOF'
+79d78
+< trace_timestamp: int = 0
+338d336
+< trace_dataset_path = None,
+353,373c351,367
+< print(f"RandomDataset sample {trace_dataset_path}")
+<
+< events = []
+< with open(trace_dataset_path, "r") as f:
+< for line in f:
+< if not line.strip():
+< continue
+< obj = json.loads(line)
+< timestamp = obj["timestamp"]
+< input_length = obj["input_length"]
+< output_length = obj["output_length"]
+< print(f"read trace timestamp {timestamp} input_length {input_length} output_length {output_length}")
+< events.append((timestamp, input_length, output_length))
+< # Ensure chronological order
+< events.sort(key=lambda x: x[0])
+< print(f"events {events}")
+<
+< num_requests = len(events)
+< timestamps = [i[0] for i in events]
+< input_lens = [i[1] for i in events]
+< output_lens = [i[2] for i in events]
+---
+> # New sampling logic: [X * (1 - b), X * (1 + b)]
+> input_low = int(real_input_len * (1 - range_ratio))
+> input_high = int(real_input_len * (1 + range_ratio))
+> output_low = int(output_len * (1 - range_ratio))
+> output_high = int(output_len * (1 + range_ratio))
+>
+> # Add logging for debugging
+> logger.info(
+> "Sampling input_len from [%s, %s] and output_len from [%s, %s]",
+> input_low, input_high, output_low, output_high)
+>
+> input_lens = np.random.randint(input_low,
+> input_high + 1,
+> size=num_requests)
+> output_lens = np.random.randint(output_low,
+> output_high + 1,
+> size=num_requests)
+400d393
+< trace_timestamp=timestamps[i]
+765d757
+< trace_dataset_path=args.trace_dataset_path,
+EOF
+
+patch "$SERVE_PY" << 'EOF'
+101,116d100
+< async def get_request_trace(input_requests: list[SampleRequest]) -> AsyncGenerator[tuple[SampleRequest, float], None]:
+<
+< print(f"get_request_trace")
+<
+< prev_ts = 0
+< for i in range(len(input_requests)):
+< request = input_requests[i]
+< curr_ts = request.trace_timestamp
+< delay = curr_ts - prev_ts
+< print(f"request {i} prompt_len {request.prompt_len} expected_output_len {request.expected_output_len} timestamp {request.trace_timestamp} delay {delay} ms")
+< if delay > 0:
+< await asyncio.sleep(delay/1000)
+< prev_ts = curr_ts
+< yield request, 1
+<
+<
+267,270d250
+< print(f"requsest, timestamp, input_len, output_len, ttft_ms, tpot_ms, e2el_ms")
+< for i in range(len(outputs)):
+< print(f"{i} {input_requests[i].trace_timestamp} {input_requests[i].prompt_len} {actual_output_lens[i]} {ttfts[i]*1000:.2f} {all_tpots[i]*1000:.2f} {e2els[i]*1000:.2f}")
+<
+356d335
+< trace_dataset_path: str | None = None,
+477,488d455
+< if trace_dataset_path is not None:
+< iterator = get_request_trace(input_requests)
+< else:
+< iterator = get_request(
+< input_requests,
+< request_rate,
+< burstiness,
+< ramp_up_strategy,
+< ramp_up_start_rps,
+< ramp_up_end_rps,
+< )
+<
+501c468,470
+< async for request, current_request_rate in iterator:
+---
+> async for request, current_request_rate in get_request(
+> input_requests, request_rate, burstiness, ramp_up_strategy,
+> ramp_up_start_rps, ramp_up_end_rps):
+987d955
+< parser.add_argument("--trace-dataset-path", type=str, default=None)
+1096d1063
+< trace_dataset_path=args.trace_dataset_path,
+EOF
+
+echo "Patch applied successfully!"
+
+echo "You can now use:"
+echo ""
+echo " vllm bench serve \\"
+echo " --model RedHatAI/Meta-Llama-3.1-70B-Instruct-FP8 \\"
+echo " --trace-dataset-path example_trace.jsonl"
+echo ""
diff --git a/benchmark/trace/example_trace.jsonl b/benchmark/trace/example_trace.jsonl
new file mode 100644
index 000000000..7ffb0b673
--- /dev/null
+++ b/benchmark/trace/example_trace.jsonl
@@ -0,0 +1,6 @@
+{"timestamp": 0, "input_length": 6758, "output_length": 500}
+{"timestamp": 0, "input_length": 7322, "output_length": 490}
+{"timestamp": 0, "input_length": 7236, "output_length": 794}
+{"timestamp": 0, "input_length": 2290, "output_length": 316}
+{"timestamp": 0, "input_length": 6760, "output_length": 3}
+{"timestamp": 15000, "input_length": 15, "output_length": 3}
diff --git a/benchmark/trace/run_client.sh b/benchmark/trace/run_client.sh
new file mode 100644
index 000000000..6512b735a
--- /dev/null
+++ b/benchmark/trace/run_client.sh
@@ -0,0 +1,6 @@
+
+model="/data-fast/RedHatAI/Llama-3.3-70B-Instruct-FP8-dynamic"
+dataset=example_trace.jsonl
+
+vllm bench serve --model $model --trace-dataset-path $dataset --ignore-eos
+
diff --git a/csrc/custom_ops/CMakeLists.txt b/csrc/custom_ops/CMakeLists.txt
index 353c895bb..fa101e1e5 100644
--- a/csrc/custom_ops/CMakeLists.txt
+++ b/csrc/custom_ops/CMakeLists.txt
@@ -16,7 +16,10 @@ find_package(pybind11 REQUIRED)
find_package(Torch REQUIRED)
pybind11_add_module(custom_ops
- kernels.cu
+ reshape_and_cache_flash_fp4.cu
+ reshape_and_cache_flash_bulk.cu
+ speculator_ln.cu
+ sum_lstm.cu
torch_bindings.cpp
)
diff --git a/csrc/custom_ops/attention_generic.cuh b/csrc/custom_ops/attention_generic.cuh
index 607a062fb..6605da3b6 100644
--- a/csrc/custom_ops/attention_generic.cuh
+++ b/csrc/custom_ops/attention_generic.cuh
@@ -23,32 +23,26 @@
namespace vllm {
// A vector type to store Q, K, V elements.
-template
-struct Vec {};
+template struct Vec {};
// A vector type to store FP32 accumulators.
-template
-struct FloatVec {};
+template struct FloatVec {};
// Template vector operations.
template
inline __device__ Acc mul(A a, B b);
-template
-inline __device__ float sum(T v);
+template inline __device__ float sum(T v);
-template
-inline __device__ float dot(T a, T b) {
+template inline __device__ float dot(T a, T b) {
return sum(mul(a, b));
}
-template
-inline __device__ float dot(T a, T b) {
+template inline __device__ float dot(T a, T b) {
return sum(mul(a, b));
}
-template
-inline __device__ void zero(T& dst) {
+template inline __device__ void zero(T &dst) {
constexpr int WORDS = sizeof(T) / 4;
union {
T raw;
@@ -62,5 +56,4 @@ inline __device__ void zero(T& dst) {
dst = tmp.raw;
}
-} // namespace vllm
-
+} // namespace vllm
diff --git a/csrc/custom_ops/custom_ops.h b/csrc/custom_ops/custom_ops.h
index 3edf44ece..336f59370 100644
--- a/csrc/custom_ops/custom_ops.h
+++ b/csrc/custom_ops/custom_ops.h
@@ -7,14 +7,35 @@
#include
void reshape_and_cache_flash_bulk(
- torch::Tensor& keys,
- torch::Tensor& values,
- std::vector const& key_caches,
- std::vector const& value_caches,
- torch::Tensor& slot_mapping,
- const std::string& kv_cache_dtype,
- std::vector const& k_scales,
- std::vector const& v_scales,
- int64_t num_heads,
- int64_t head_size
-);
\ No newline at end of file
+ torch::Tensor &keys, torch::Tensor &values,
+ std::vector const &key_caches,
+ std::vector const &value_caches, torch::Tensor &slot_mapping,
+ const std::string &kv_cache_dtype,
+ std::vector const &k_scales,
+ std::vector const &v_scales, int64_t num_heads,
+ int64_t head_size);
+
+void reshape_and_cache_flash_fp4(torch::Tensor &key, torch::Tensor &value,
+ torch::Tensor &key_cache,
+ torch::Tensor &value_cache,
+ torch::Tensor &slot_mapping,
+ const std::string &kv_cache_dtype,
+ torch::Tensor &k_scale, torch::Tensor &v_scale,
+ torch::Tensor &key_scale_cache,
+ torch::Tensor &value_scale_cache);
+
+torch::Tensor speculator_ln_cuda(const torch::Tensor &input,
+ const c10::optional &weight,
+ const c10::optional &bias,
+ double eps);
+
+std::tuple sum_lstm_cuda(
+ const torch::Tensor& states_4d, // [..., 4D]
+ const torch::Tensor& z4_4d, // [..., 4D] (repeat along last dim)
+ const torch::Tensor& prev_cell_d, // [..., D]
+ const c10::optional& w_cell,
+ const c10::optional& b_cell,
+ const c10::optional& w_state,
+ const c10::optional& b_state,
+ double alpha, double eps_cell, double eps_state,
+ bool use_fast_gelu);
\ No newline at end of file
diff --git a/csrc/custom_ops/dispatch_utils.h b/csrc/custom_ops/dispatch_utils.h
index 3361922f0..1a778131b 100644
--- a/csrc/custom_ops/dispatch_utils.h
+++ b/csrc/custom_ops/dispatch_utils.h
@@ -8,63 +8,62 @@
// Need a special dispatch case macro since we will nest the FP8 dispatch.
// Instead of the usual 'scalar_t', this names the dispatched type 'fp8_t'.
-#define AT_DISPATCH_FP8_CASE(enum_type, ...) \
+#define AT_DISPATCH_FP8_CASE(enum_type, ...) \
AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, fp8_t, __VA_ARGS__)
-#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
- AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
- AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
+#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
+ AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
+ AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
-#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
+#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
// ROCm devices might use either fn or fnuz, so set up dispatch table for both.
// A host-based check at runtime will create a preferred FP8 type for ROCm
// such that the correct kernel is dispatched.
#ifdef USE_ROCM
- #define VLLM_DISPATCH_CASE_FP8_TYPES(...) \
- AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
- AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e4m3fnuz, __VA_ARGS__)
+#define VLLM_DISPATCH_CASE_FP8_TYPES(...) \
+ AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
+ AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e4m3fnuz, __VA_ARGS__)
- #define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \
- AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
- AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fnuz, __VA_ARGS__) \
- AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)
+#define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \
+ AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
+ AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fnuz, __VA_ARGS__) \
+ AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)
#else
- #define VLLM_DISPATCH_CASE_FP8_TYPES(...) \
- AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__)
+#define VLLM_DISPATCH_CASE_FP8_TYPES(...) \
+ AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__)
- #define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \
- AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
- AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)
+#define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \
+ AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
+ AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)
#endif
// When using this dispatch macro, the type is 'fp8_t' not 'scalar_t'.
// See AT_DISPATCH_FP8_CASE above.
-#define VLLM_DISPATCH_FP8_TYPES(TYPE, NAME, ...) \
+#define VLLM_DISPATCH_FP8_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FP8_TYPES(__VA_ARGS__))
-#define VLLM_DISPATCH_QUANT_TYPES(TYPE, NAME, ...) \
+#define VLLM_DISPATCH_QUANT_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_QUANT_TYPES(__VA_ARGS__))
-#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \
- AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
- AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
- AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
+#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \
+ AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
+ AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
+ AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)
-#define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \
- AT_DISPATCH_SWITCH(TYPE, NAME, \
+#define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \
+ AT_DISPATCH_SWITCH(TYPE, NAME, \
VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__))
-#define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...) \
- AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
- AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
- AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
- AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
+#define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...) \
+ AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
+ AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
+ AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
+ AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)
-#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
+#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
-
diff --git a/csrc/custom_ops/dtype_common.cuh b/csrc/custom_ops/dtype_common.cuh
new file mode 100644
index 000000000..0ea648f56
--- /dev/null
+++ b/csrc/custom_ops/dtype_common.cuh
@@ -0,0 +1,72 @@
+#pragma once
+
+#include
+#include
+#include
+#include
+#include
+
+template struct IsSupported : std::false_type {};
+template <> struct IsSupported : std::true_type {};
+template <> struct IsSupported : std::true_type {};
+
+__device__ __forceinline__ float to_float_device(__half h) {
+ return __half2float(h);
+}
+
+__device__ __forceinline__ __half from_float_device(float f) {
+ return __float2half(f);
+}
+
+__device__ __forceinline__ float to_float_device(__nv_bfloat16 h) {
+#if !defined(__CUDA_NO_BF16__)
+ return __bfloat162float(h);
+#else
+ uint16_t raw = *reinterpret_cast(&h);
+ uint32_t u32 = (uint32_t)raw << 16;
+ float out = __uint_as_float(u32);
+ return out;
+#endif
+}
+
+__device__ __forceinline__ __nv_bfloat16 from_float_device_bf16(float f) {
+#if !defined(__CUDA_NO_BF16__)
+ return __float2bfloat16(f);
+#else
+ uint32_t u = __float_as_uint(f);
+ uint16_t hi = (uint16_t)(u >> 16);
+ __nv_bfloat16 h;
+ *reinterpret_cast(&h) = hi;
+ return h;
+#endif
+}
+
+template struct DevHalf;
+
+template <> struct DevHalf {
+ using type = __half;
+ static __device__ __forceinline__ float to_float(__half h) {
+ return to_float_device(h);
+ }
+ static __device__ __forceinline__ __half from_float(float f) {
+ return from_float_device(f);
+ }
+};
+
+template <> struct DevHalf {
+ using type = __nv_bfloat16;
+ static __device__ __forceinline__ float to_float(__nv_bfloat16 h) {
+ return to_float_device(h);
+ }
+ static __device__ __forceinline__ __nv_bfloat16 from_float(float f) {
+ return from_float_device_bf16(f);
+ }
+};
+
+template struct alignas(sizeof(T) * N) Pack {
+ T v[N];
+};
+
+static inline bool is_aligned(const void *p, size_t bytes) {
+ return (reinterpret_cast(p) % bytes) == 0;
+}
diff --git a/csrc/custom_ops/dtype_fp8.cuh b/csrc/custom_ops/dtype_fp8.cuh
index 3bb195524..86f15adb5 100644
--- a/csrc/custom_ops/dtype_fp8.cuh
+++ b/csrc/custom_ops/dtype_fp8.cuh
@@ -4,10 +4,10 @@
#include
#ifdef ENABLE_FP8
- #ifndef USE_ROCM
- #include
- #endif // USE_ROCM
-#endif // ENABLE_FP8
+#ifndef USE_ROCM
+#include
+#endif // USE_ROCM
+#endif // ENABLE_FP8
namespace vllm {
@@ -18,25 +18,20 @@ enum class Fp8KVCacheDataType {
};
// fp8 vector types for quantization of kv cache
-template <>
-struct Vec {
+template <> struct Vec {
using Type = uint8_t;
};
-template <>
-struct Vec {
+template <> struct Vec {
using Type = uint16_t;
};
-template <>
-struct Vec {
+template <> struct Vec {
using Type = uint32_t;
};
-template <>
-struct Vec {
+template <> struct Vec {
using Type = uint2;
};
-} // namespace vllm
-
+} // namespace vllm
diff --git a/csrc/custom_ops/quant_utils.cuh b/csrc/custom_ops/quant_utils.cuh
index 1f2b13690..e23bbed64 100644
--- a/csrc/custom_ops/quant_utils.cuh
+++ b/csrc/custom_ops/quant_utils.cuh
@@ -10,9 +10,9 @@ namespace vllm {
#ifndef USE_ROCM
namespace fp8 {
- #ifdef ENABLE_FP8
+#ifdef ENABLE_FP8
- #if 0 // Disable the following code to reduce the binary size.
+#if 0 // Disable the following code to reduce the binary size.
template
__inline__ __device__ Tout
vec_conversion(const Tin &x, const __nv_fp8_interpretation_t fp8_type) {
@@ -177,13 +177,13 @@ __inline__ __device__ uint8_t vec_conversion(
template <>
__inline__ __device__ uint8_t vec_conversion(
const __nv_bfloat16 &a, const __nv_fp8_interpretation_t fp8_type) {
- #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false);
- #else
+#else
__nv_fp8_storage_t res = __nv_cvt_bfloat16raw_to_fp8(
__nv_bfloat16_raw(a), __NV_SATFINITE, fp8_type);
return (uint8_t)res;
- #endif
+#endif
}
// float -> fp8
@@ -276,7 +276,7 @@ __inline__ __device__ bf16_8_t vec_conversion(
from_float(b, a);
return b;
}
- #endif
+#endif
/* Scaled and vectorized conversions, for data exchange between high and low
precision domains Convention of the scale in API, e.g: FP8_data =
@@ -286,14 +286,14 @@ __inline__ __device__ bf16_8_t vec_conversion(
template
__inline__ __device__ Tout scaled_vec_conversion(
- const Tin& x, const float scale, const __nv_fp8_interpretation_t fp8_type) {
+ const Tin &x, const float scale, const __nv_fp8_interpretation_t fp8_type) {
return x;
}
// fp8 -> half
template <>
__inline__ __device__ uint16_t scaled_vec_conversion(
- const uint8_t& a, const float scale,
+ const uint8_t &a, const float scale,
const __nv_fp8_interpretation_t fp8_type) {
__half_raw tmp = __nv_cvt_fp8_to_halfraw(a, fp8_type);
return float_to_half(half_to_float(tmp.x) * scale);
@@ -302,7 +302,7 @@ __inline__ __device__ uint16_t scaled_vec_conversion(
// fp8x2 -> half2
template <>
__inline__ __device__ uint32_t scaled_vec_conversion(
- const uint16_t& a, const float scale,
+ const uint16_t &a, const float scale,
const __nv_fp8_interpretation_t fp8_type) {
union {
uint16_t u16[2];
@@ -317,7 +317,7 @@ __inline__ __device__ uint32_t scaled_vec_conversion(
// fp8x4 -> half2x2
template <>
__inline__ __device__ uint2 scaled_vec_conversion(
- const uint32_t& a, const float scale,
+ const uint32_t &a, const float scale,
const __nv_fp8_interpretation_t fp8_type) {
union {
uint2 u32x2;
@@ -333,7 +333,7 @@ __inline__ __device__ uint2 scaled_vec_conversion(
// fp8x8 -> half2x4
template <>
__inline__ __device__ uint4
-scaled_vec_conversion(const uint2& a, const float scale,
+scaled_vec_conversion(const uint2 &a, const float scale,
const __nv_fp8_interpretation_t fp8_type) {
union {
uint4 u64x2;
@@ -348,7 +348,7 @@ scaled_vec_conversion(const uint2& a, const float scale,
template <>
__inline__ __device__ __nv_bfloat16
scaled_vec_conversion<__nv_bfloat16, uint8_t>(
- const uint8_t& a, const float scale,
+ const uint8_t &a, const float scale,
const __nv_fp8_interpretation_t fp8_type) {
// Note there is no direct convert function from fp8 to bf16.
// fp8 -> half
@@ -362,7 +362,7 @@ scaled_vec_conversion<__nv_bfloat16, uint8_t>(
template <>
__inline__ __device__ __nv_bfloat162
scaled_vec_conversion<__nv_bfloat162, uint16_t>(
- const uint16_t& a, const float scale,
+ const uint16_t &a, const float scale,
const __nv_fp8_interpretation_t fp8_type) {
__nv_bfloat162 res;
res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale,
@@ -375,7 +375,7 @@ scaled_vec_conversion<__nv_bfloat162, uint16_t>(
// fp8x4 -> bf16_4_t
template <>
__inline__ __device__ bf16_4_t scaled_vec_conversion(
- const uint32_t& a, const float scale,
+ const uint32_t &a, const float scale,
const __nv_fp8_interpretation_t fp8_type) {
bf16_4_t res;
res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale,
@@ -388,7 +388,7 @@ __inline__ __device__ bf16_4_t scaled_vec_conversion(
// fp8x8 -> bf16_8_t
template <>
__inline__ __device__ bf16_8_t scaled_vec_conversion(
- const uint2& a, const float scale,
+ const uint2 &a, const float scale,
const __nv_fp8_interpretation_t fp8_type) {
bf16_4_t tmp1, tmp2;
tmp1 = scaled_vec_conversion(a.x, scale, fp8_type);
@@ -404,7 +404,7 @@ __inline__ __device__ bf16_8_t scaled_vec_conversion(
// fp8 -> float
template <>
__inline__ __device__ float scaled_vec_conversion(
- const uint8_t& a, const float scale,
+ const uint8_t &a, const float scale,
const __nv_fp8_interpretation_t fp8_type) {
// fp8 -> half
__half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type);
@@ -417,7 +417,7 @@ __inline__ __device__ float scaled_vec_conversion(
// fp8x2 -> float2
template <>
__inline__ __device__ float2 scaled_vec_conversion(
- const uint16_t& a, const float scale,
+ const uint16_t &a, const float scale,
const __nv_fp8_interpretation_t fp8_type) {
// fp8x2 -> half2
uint32_t tmp = scaled_vec_conversion(a, scale, fp8_type);
@@ -428,7 +428,7 @@ __inline__ __device__ float2 scaled_vec_conversion(
// fp8x4 -> float4
template <>
__inline__ __device__ Float4_ scaled_vec_conversion(
- const uint32_t& a, const float scale,
+ const uint32_t &a, const float scale,
const __nv_fp8_interpretation_t fp8_type) {
Float4_ res;
res.x = scaled_vec_conversion((uint16_t)a, scale, fp8_type);
@@ -440,7 +440,7 @@ __inline__ __device__ Float4_ scaled_vec_conversion(
// fp8x8 -> float8
template <>
__inline__ __device__ Float8_ scaled_vec_conversion(
- const uint2& a, const float scale,
+ const uint2 &a, const float scale,
const __nv_fp8_interpretation_t fp8_type) {
Float4_ tmp1, tmp2;
tmp1 = scaled_vec_conversion(a.x, scale, fp8_type);
@@ -456,7 +456,7 @@ __inline__ __device__ Float8_ scaled_vec_conversion(
// half -> fp8
template <>
__inline__ __device__ uint8_t scaled_vec_conversion(
- const uint16_t& a, const float scale,
+ const uint16_t &a, const float scale,
const __nv_fp8_interpretation_t fp8_type) {
__nv_fp8_storage_t res =
__nv_cvt_float_to_fp8(half_to_float(a) / scale, __NV_SATFINITE, fp8_type);
@@ -466,22 +466,22 @@ __inline__ __device__ uint8_t scaled_vec_conversion(
// bf16 -> fp8
template <>
__inline__ __device__ uint8_t scaled_vec_conversion(
- const __nv_bfloat16& a, const float scale,
+ const __nv_bfloat16 &a, const float scale,
const __nv_fp8_interpretation_t fp8_type) {
- #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false);
- #else
+#else
__nv_fp8_storage_t res = __nv_cvt_float_to_fp8(__bfloat162float(a) / scale,
__NV_SATFINITE, fp8_type);
return (uint8_t)res;
- #endif
- __builtin_unreachable(); // Suppress missing return statement warning
+#endif
+ __builtin_unreachable(); // Suppress missing return statement warning
}
// float -> fp8
template <>
__inline__ __device__ uint8_t scaled_vec_conversion(
- const float& a, const float scale,
+ const float &a, const float scale,
const __nv_fp8_interpretation_t fp8_type) {
__nv_fp8_storage_t res =
__nv_cvt_float_to_fp8(a / scale, __NV_SATFINITE, fp8_type);
@@ -491,84 +491,81 @@ __inline__ __device__ uint8_t scaled_vec_conversion(
// fp8x4 -> float4
template <>
__inline__ __device__ float4 scaled_vec_conversion(
- const uint32_t& a, const float scale,
+ const uint32_t &a, const float scale,
const __nv_fp8_interpretation_t fp8_type) {
Float4_ tmp = scaled_vec_conversion(a, scale, fp8_type);
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
return res;
}
- #endif // ENABLE_FP8
+#endif // ENABLE_FP8
template
-__inline__ __device__ Tout convert(const Tin& x) {
- #if 0 // Disable the following code to reduce the binary size.
+__inline__ __device__ Tout convert(const Tin &x) {
+#if 0 // Disable the following code to reduce the binary size.
if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
return vec_conversion(x, __NV_E4M3);
} else if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5M2) {
return vec_conversion(x, __NV_E5M2);
}
- #endif
+#endif
assert(false);
- __builtin_unreachable(); // Suppress missing return statement warning
+ __builtin_unreachable(); // Suppress missing return statement warning
}
template
-__inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
- #ifdef ENABLE_FP8
+__inline__ __device__ Tout scaled_convert(const Tin &x, const float scale) {
+#ifdef ENABLE_FP8
if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
return scaled_vec_conversion(x, scale, __NV_E4M3);
} else if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5M2) {
return scaled_vec_conversion(x, scale, __NV_E5M2);
}
- #endif
+#endif
assert(false);
- __builtin_unreachable(); // Suppress missing return statement warning
-}
-
- // The following macro is used to dispatch the conversion function based on
- // the data type of the key and value cache. The FN is a macro that calls a
- // function with template.
- #define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \
- if (KV_DTYPE == "auto") { \
+ __builtin_unreachable(); // Suppress missing return statement warning
+}
+
+// The following macro is used to dispatch the conversion function based on
+// the data type of the key and value cache. The FN is a macro that calls a
+// function with template.
+#define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \
+ if (KV_DTYPE == "auto") { \
+ if (SRC_DTYPE == at::ScalarType::Float) { \
+ FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \
+ } else if (SRC_DTYPE == at::ScalarType::Half) { \
+ FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); \
+ } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
+ FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); \
+ } else { \
+ TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
+ } \
+ } else { \
+ if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") { \
if (SRC_DTYPE == at::ScalarType::Float) { \
- FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \
+ FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else if (SRC_DTYPE == at::ScalarType::Half) { \
- FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); \
+ FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
- FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); \
+ FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else { \
TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
- } else { \
- if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") { \
- if (SRC_DTYPE == at::ScalarType::Float) { \
- FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
- } else if (SRC_DTYPE == at::ScalarType::Half) { \
- FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
- } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
- FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
- } else { \
- TORCH_CHECK(false, \
- "Unsupported input type of kv cache: ", SRC_DTYPE); \
- } \
- } else if (KV_DTYPE == "fp8_e5m2") { \
- if (SRC_DTYPE == at::ScalarType::Float) { \
- FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
- } else if (SRC_DTYPE == at::ScalarType::Half) { \
- FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
- } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
- FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
- } else { \
- TORCH_CHECK(false, \
- "Unsupported input type of kv cache: ", SRC_DTYPE); \
- } \
+ } else if (KV_DTYPE == "fp8_e5m2") { \
+ if (SRC_DTYPE == at::ScalarType::Float) { \
+ FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
+ } else if (SRC_DTYPE == at::ScalarType::Half) { \
+ FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
+ } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
+ FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
} else { \
- TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \
+ TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
- }
-
-} // namespace fp8
-#endif // not USE_ROCM
-} // namespace vllm
+ } else { \
+ TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \
+ } \
+ }
+} // namespace fp8
+#endif // not USE_ROCM
+} // namespace vllm
diff --git a/csrc/custom_ops/kernels.cu b/csrc/custom_ops/reshape_and_cache_flash_bulk.cu
similarity index 73%
rename from csrc/custom_ops/kernels.cu
rename to csrc/custom_ops/reshape_and_cache_flash_bulk.cu
index fe6ee9ba1..cf9dce1ed 100644
--- a/csrc/custom_ops/kernels.cu
+++ b/csrc/custom_ops/reshape_and_cache_flash_bulk.cu
@@ -2,30 +2,21 @@
#include "dispatch_utils.h"
#include "quant_utils.cuh"
-#include
#include
+#include
#include
namespace vllm {
-template
+template
__global__ void reshape_and_cache_flash_bulk_kernel(
- const scalar_t* __restrict__ keys,
- const scalar_t* __restrict__ values,
- int64_t* key_cache_ptrs,
- int64_t* value_cache_ptrs,
- const int64_t* __restrict__ slot_mapping,
- const int block_stride,
- const int key_stride,
- const int value_stride,
- const int num_heads,
- const int head_size,
- const int block_size,
- int64_t* k_scale_ptrs,
- int64_t* v_scale_ptrs) {
+ const scalar_t *__restrict__ keys, const scalar_t *__restrict__ values,
+ int64_t *key_cache_ptrs, int64_t *value_cache_ptrs,
+ const int64_t *__restrict__ slot_mapping, const int block_stride,
+ const int key_stride, const int value_stride, const int num_heads,
+ const int head_size, const int block_size, int64_t *k_scale_ptrs,
+ int64_t *v_scale_ptrs) {
const int64_t layer_idx = blockIdx.x;
const int64_t token_idx = blockIdx.y;
const int64_t slot_idx = slot_mapping[token_idx];
@@ -37,14 +28,14 @@ __global__ void reshape_and_cache_flash_bulk_kernel(
const int64_t block_offset = slot_idx % block_size;
const int n = num_heads * head_size;
- cache_t* __restrict__ key_cache =
- reinterpret_cast(key_cache_ptrs[layer_idx]);
- cache_t* __restrict__ value_cache =
- reinterpret_cast(value_cache_ptrs[layer_idx]);
- const float* __restrict__ k_scale =
- reinterpret_cast(k_scale_ptrs[layer_idx]);
- const float* __restrict__ v_scale =
- reinterpret_cast(v_scale_ptrs[layer_idx]);
+ cache_t *__restrict__ key_cache =
+ reinterpret_cast(key_cache_ptrs[layer_idx]);
+ cache_t *__restrict__ value_cache =
+ reinterpret_cast(value_cache_ptrs[layer_idx]);
+ const float *__restrict__ k_scale =
+ reinterpret_cast(k_scale_ptrs[layer_idx]);
+ const float *__restrict__ v_scale =
+ reinterpret_cast(v_scale_ptrs[layer_idx]);
for (int i = threadIdx.x; i < n; i += blockDim.x) {
const int64_t src_key_idx = token_idx * key_stride + layer_idx * n + i;
@@ -70,29 +61,26 @@ __global__ void reshape_and_cache_flash_bulk_kernel(
} // namespace vllm
-#define CALL_RESHAPE_AND_CACHE_FLASH_BULK(KV_T, CACHE_T, KV_DTYPE) \
- vllm::reshape_and_cache_flash_bulk_kernel \
- <<>>( \
- reinterpret_cast(keys.data_ptr()), \
- reinterpret_cast(values.data_ptr()), \
- key_cache_ptrs_tensor.data_ptr(), \
- value_cache_ptrs_tensor.data_ptr(), \
- slot_mapping.data_ptr(), block_stride, key_stride, \
- value_stride, static_cast(num_heads), \
- static_cast(head_size), block_size, \
- k_scale_ptrs_tensor.data_ptr(), \
+#define CALL_RESHAPE_AND_CACHE_FLASH_BULK(KV_T, CACHE_T, KV_DTYPE) \
+ vllm::reshape_and_cache_flash_bulk_kernel \
+ <<>>( \
+ reinterpret_cast(keys.data_ptr()), \
+ reinterpret_cast(values.data_ptr()), \
+ key_cache_ptrs_tensor.data_ptr(), \
+ value_cache_ptrs_tensor.data_ptr(), \
+ slot_mapping.data_ptr(), block_stride, key_stride, \
+ value_stride, static_cast(num_heads), \
+ static_cast(head_size), block_size, \
+ k_scale_ptrs_tensor.data_ptr(), \
v_scale_ptrs_tensor.data_ptr());
void reshape_and_cache_flash_bulk(
- torch::Tensor& keys,
- torch::Tensor& values,
- std::vector const& key_caches,
- std::vector const& value_caches,
- torch::Tensor& slot_mapping,
- const std::string& kv_cache_dtype,
- std::vector const& k_scales,
- std::vector const& v_scales,
- int64_t num_heads,
+ torch::Tensor &keys, torch::Tensor &values,
+ std::vector const &key_caches,
+ std::vector const &value_caches, torch::Tensor &slot_mapping,
+ const std::string &kv_cache_dtype,
+ std::vector const &k_scales,
+ std::vector const &v_scales, int64_t num_heads,
int64_t head_size) {
int num_layers = key_caches.size();
@@ -146,7 +134,8 @@ void reshape_and_cache_flash_bulk(
.to(device_of_key);
dim3 grid(num_layers, num_tokens);
- dim3 block(std::min(static_cast(num_heads) * static_cast(head_size), 512));
+ dim3 block(
+ std::min(static_cast(num_heads) * static_cast(head_size), 512));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
diff --git a/csrc/custom_ops/reshape_and_cache_flash_fp4.cu b/csrc/custom_ops/reshape_and_cache_flash_fp4.cu
new file mode 100644
index 000000000..ce0ffffc5
--- /dev/null
+++ b/csrc/custom_ops/reshape_and_cache_flash_fp4.cu
@@ -0,0 +1,338 @@
+#include "custom_ops.h"
+#include "dispatch_utils.h"
+#include "quant_utils.cuh"
+#include "vectorization_utils.cuh"
+
+#include
+#include
+#include
+
+#include
+
+namespace vllm {
+
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
+template struct TypeConverter {
+ using Type = half2;
+};
+template <> struct TypeConverter {
+ using Type = half;
+};
+template <> struct TypeConverter {
+ using Type = half2;
+};
+template <> struct TypeConverter<__nv_bfloat162> {
+ using Type = __nv_bfloat16;
+};
+template <> struct TypeConverter<__nv_bfloat16> {
+ using Type = __nv_bfloat162;
+};
+
+template struct PackedVec8 {
+ typename TypeConverter::Type elts[4];
+};
+
+#if __CUDA_ARCH__ < 1000
+__device__ uint8_t float_to_e2m1_rn(float val) {
+ if (isnan(val)) {
+ return 0x0;
+ }
+ if (isinf(val)) {
+ val = val < 0.f ? -6.f : 6.f;
+ }
+ uint32_t sign_bit = (reinterpret_cast(val) & 0x80000000) >> 28;
+ float x = fabsf(val);
+ uint8_t magnitude_bits;
+ if (x > 5.0f)
+ magnitude_bits = 0x7; // 6.0
+ else if (x > 3.5f)
+ magnitude_bits = 0x6; // 4.0
+ else if (x > 2.5f)
+ magnitude_bits = 0x5; // 3.0
+ else if (x > 1.75f)
+ magnitude_bits = 0x4; // 2.0
+ else if (x > 1.25f)
+ magnitude_bits = 0x3; // 1.5
+ else if (x > 0.75f)
+ magnitude_bits = 0x2; // 1.0
+ else if (x > 0.25f)
+ magnitude_bits = 0x1; // 0.5
+ else
+ magnitude_bits = 0x0; // 0.0
+ return sign_bit | magnitude_bits;
+}
+#endif
+
+// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t).
+inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) {
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
+ uint32_t val;
+ asm volatile("{\n"
+ ".reg .b8 byte0, byte1, byte2, byte3;\n"
+ "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n"
+ "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n"
+ "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n"
+ "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n"
+ "mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
+ "}"
+ : "=r"(val)
+ : "f"(array[0].x), "f"(array[0].y), "f"(array[1].x),
+ "f"(array[1].y), "f"(array[2].x), "f"(array[2].y),
+ "f"(array[3].x), "f"(array[3].y));
+ return val;
+#else
+ uint32_t result = 0;
+ uint8_t *result_bytes = reinterpret_cast(&result);
+#pragma unroll
+ for (int i = 0; i < 4; ++i) {
+ uint8_t val1 = float_to_e2m1_rn(array[i].x);
+ uint8_t val2 = float_to_e2m1_rn(array[i].y);
+ result_bytes[i] = (val2 << 4) | (val1 & 0x0F);
+ }
+ return result;
+#endif
+}
+
+inline __device__ float reciprocal_approximate_ftz(float a) {
+ float b;
+ asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a));
+ return b;
+}
+
+template
+__device__ __forceinline__ void
+quantize_16_to_fp4(const scalar_t *__restrict__ src_base,
+ uint32_t *__restrict__ dst_base,
+ uint8_t *__restrict__ scale_dst, const int tid_in_block) {
+
+ const int lane_in_pair = tid_in_block % 2;
+
+ using PackedVec = PackedVec8;
+ PackedVec vec =
+ *reinterpret_cast(src_base + lane_in_pair * 8);
+
+ auto localMax = __habs2(vec.elts[0]);
+#pragma unroll
+ for (int i = 1; i < 4; i++) {
+ localMax = __hmax2(localMax, __habs2(vec.elts[i]));
+ }
+
+ localMax = __hmax2(__shfl_xor_sync(0xffffffff, localMax, 1), localMax);
+ float vecMax = float(__hmax(localMax.x, localMax.y));
+ float SFValue = vecMax * reciprocal_approximate_ftz(6.0f);
+
+ if (lane_in_pair == 0) {
+ __nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue);
+ *scale_dst = reinterpret_cast(tmp);
+ SFValue = float(tmp);
+ }
+
+ SFValue = __shfl_sync(0xffffffff, SFValue, tid_in_block & ~1);
+ float outputScale =
+ (SFValue != 0.0f) ? reciprocal_approximate_ftz(SFValue) : 0.0f;
+
+ float2 fp2Vals[4];
+#pragma unroll
+ for (int i = 0; i < 4; i++) {
+ if constexpr (std::is_same_v) {
+ fp2Vals[i] = __half22float2(vec.elts[i]);
+ } else {
+ fp2Vals[i] = __bfloat1622float2(vec.elts[i]);
+ }
+ fp2Vals[i].x *= outputScale;
+ fp2Vals[i].y *= outputScale;
+ }
+
+ dst_base[lane_in_pair] = fp32_vec_to_e2m1(fp2Vals);
+}
+#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
+
+template
+__global__ void reshape_and_cache_flash_kernel_fp4(
+ const scalar_t *__restrict__ key, // [num_tokens, num_heads, head_size]
+ const scalar_t *__restrict__ value, // [num_tokens, num_heads, head_size]
+ cache_t *__restrict__ key_cache, // NHD or HND
+ cache_t *__restrict__ value_cache,
+ const int64_t *__restrict__ slot_mapping, // [num_tokens]
+ const int64_t block_stride, const int64_t page_stride,
+ const int64_t head_stride, const int64_t key_stride,
+ const int64_t value_stride, const int num_heads, const int head_size,
+ const int block_size, const float *k_scale, const float *v_scale,
+ uint8_t *__restrict__ key_scale_cache,
+ uint8_t *__restrict__ value_scale_cache, const bool is_nhd) {
+ const int64_t token_idx = blockIdx.x;
+ const int64_t slot_idx = slot_mapping[token_idx];
+ if (slot_idx < 0) {
+ return;
+ }
+ const int64_t block_idx = slot_idx / block_size;
+ const int64_t block_offset = slot_idx % block_size;
+ const int n_elems = num_heads * head_size;
+
+ const scalar_t *__restrict__ key_src = key + token_idx * key_stride;
+ const scalar_t *__restrict__ value_src = value + token_idx * value_stride;
+
+ cache_t *__restrict__ key_dst =
+ key_cache + block_idx * block_stride + block_offset * page_stride;
+ cache_t *__restrict__ value_dst =
+ value_cache + block_idx * block_stride + block_offset * page_stride;
+
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
+ uint32_t *__restrict__ key_dst_fp4 = reinterpret_cast(key_dst);
+ uint32_t *__restrict__ value_dst_fp4 =
+ reinterpret_cast(value_dst);
+
+ const int64_t scale_block_stride = block_stride / 8;
+ const int64_t scale_page_stride = page_stride / 8;
+ const int64_t scale_head_stride = head_stride / 8;
+
+ uint8_t *__restrict__ key_scale_dst = key_scale_cache +
+ block_idx * scale_block_stride +
+ block_offset * scale_page_stride;
+ uint8_t *__restrict__ value_scale_dst = value_scale_cache +
+ block_idx * scale_block_stride +
+ block_offset * scale_page_stride;
+
+ constexpr int CHUNK_SIZE = 16;
+ constexpr int THREADS_PER_CHUNK = 2;
+
+ const int lane = threadIdx.x % 32;
+ const int warp_id = threadIdx.x / 32;
+ const int warps_in_block = blockDim.x / 32;
+ const int chunks_per_warp = 32 / THREADS_PER_CHUNK;
+
+ if (is_nhd) {
+ // NHD layout
+ const int num_chunks = n_elems / CHUNK_SIZE;
+ for (int chunk_base_idx = warp_id * chunks_per_warp;
+ chunk_base_idx < num_chunks;
+ chunk_base_idx += warps_in_block * chunks_per_warp) {
+ const int chunk_idx = chunk_base_idx + (lane / THREADS_PER_CHUNK);
+ if (chunk_idx < num_chunks) {
+ quantize_16_to_fp4(key_src + chunk_idx * CHUNK_SIZE,
+ key_dst_fp4 + chunk_idx * (CHUNK_SIZE / 8),
+ key_scale_dst + chunk_idx, threadIdx.x);
+ quantize_16_to_fp4(value_src + chunk_idx * CHUNK_SIZE,
+ value_dst_fp4 + chunk_idx * (CHUNK_SIZE / 8),
+ value_scale_dst + chunk_idx, threadIdx.x);
+ }
+ }
+ } else {
+ // HND layout
+ const int num_chunks_per_head = head_size / CHUNK_SIZE;
+ for (int head = warp_id; head < num_heads; head += warps_in_block) {
+ const scalar_t *__restrict__ k_src_h = key_src + head * head_size;
+ const scalar_t *__restrict__ v_src_h = value_src + head * head_size;
+
+ cache_t *__restrict__ k_dst_head_u8 =
+ key_dst + static_cast(head) * head_stride;
+ cache_t *__restrict__ v_dst_head_u8 =
+ value_dst + static_cast(head) * head_stride;
+
+ uint32_t *__restrict__ k_dst_h =
+ reinterpret_cast(k_dst_head_u8);
+ uint32_t *__restrict__ v_dst_h =
+ reinterpret_cast(v_dst_head_u8);
+
+ uint8_t *__restrict__ k_scale_dst_h =
+ key_scale_dst + static_cast(head) * scale_head_stride;
+ uint8_t *__restrict__ v_scale_dst_h =
+ value_scale_dst + static_cast(head) * scale_head_stride;
+
+ for (int chunk_idx = lane / THREADS_PER_CHUNK;
+ chunk_idx < num_chunks_per_head; chunk_idx += chunks_per_warp) {
+ quantize_16_to_fp4(k_src_h + chunk_idx * CHUNK_SIZE,
+ k_dst_h + chunk_idx * (CHUNK_SIZE / 8),
+ k_scale_dst_h + chunk_idx, threadIdx.x);
+ quantize_16_to_fp4(v_src_h + chunk_idx * CHUNK_SIZE,
+ v_dst_h + chunk_idx * (CHUNK_SIZE / 8),
+ v_scale_dst_h + chunk_idx, threadIdx.x);
+ }
+ }
+ }
+#endif // __CUDA_ARCH__ >= 900
+}
+
+} // namespace vllm
+
+void reshape_and_cache_flash_fp4(
+ torch::Tensor &key, // [num_tokens, num_heads, head_size]
+ torch::Tensor &value, // [num_tokens, num_heads, head_size]
+ torch::Tensor &key_cache, torch::Tensor &value_cache,
+ torch::Tensor &slot_mapping, const std::string &kv_cache_dtype,
+ torch::Tensor &k_scale, torch::Tensor &v_scale,
+ torch::Tensor &key_scale_cache, torch::Tensor &value_scale_cache) {
+ const int64_t num_tokens = slot_mapping.size(0);
+ const int64_t num_heads = key.size(1);
+ const int64_t head_size = key.size(2);
+
+ TORCH_CHECK(key_cache.dim() == 4 && value_cache.dim() == 4,
+ "KV cache must be rank-4");
+
+ const bool is_nhd = (key_cache.size(2) == num_heads);
+
+ const int64_t block_stride = key_cache.stride(0);
+ const int64_t page_stride =
+ is_nhd ? key_cache.stride(1) : key_cache.stride(2);
+ const int64_t head_stride =
+ is_nhd ? key_cache.stride(2) : key_cache.stride(1);
+ const int block_size = is_nhd ? key_cache.size(1) : key_cache.size(2);
+
+ TORCH_CHECK(value_cache.stride(0) == block_stride, "block_stride mismatch");
+ TORCH_CHECK((is_nhd ? value_cache.stride(1) : value_cache.stride(2)) ==
+ page_stride,
+ "page_stride mismatch");
+ TORCH_CHECK((is_nhd ? value_cache.stride(2) : value_cache.stride(1)) ==
+ head_stride,
+ "head_stride mismatch");
+ TORCH_CHECK((is_nhd ? value_cache.size(1) : value_cache.size(2)) ==
+ block_size,
+ "block_size mismatch between key_cache and value_cache");
+
+ int64_t key_stride = key.stride(0);
+ int64_t value_stride = value.stride(0);
+ TORCH_CHECK(key_cache.stride(0) == value_cache.stride(0));
+
+ dim3 grid(num_tokens);
+ int threads = std::min(num_heads * head_size, 512);
+ threads = std::max(32, ((threads + 31) / 32) * 32);
+ dim3 block(threads);
+
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
+ const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+ switch (key.scalar_type()) {
+ case torch::kHalf: {
+ vllm::reshape_and_cache_flash_kernel_fp4
+ <<>>(
+ reinterpret_cast(key.data_ptr()),
+ reinterpret_cast(value.data_ptr()),
+ reinterpret_cast(key_cache.data_ptr()),
+ reinterpret_cast(value_cache.data_ptr()),
+ slot_mapping.data_ptr(), block_stride, page_stride,
+ head_stride, key_stride, value_stride, num_heads, head_size,
+ block_size, k_scale.data_ptr(), v_scale.data_ptr(),
+ key_scale_cache.data_ptr(),
+ value_scale_cache.data_ptr(), is_nhd);
+ break;
+ }
+ case torch::kBFloat16: {
+ vllm::reshape_and_cache_flash_kernel_fp4<__nv_bfloat16, uint8_t>
+ <<>>(
+ reinterpret_cast<__nv_bfloat16 *>(key.data_ptr()),
+ reinterpret_cast<__nv_bfloat16 *>(value.data_ptr()),
+ reinterpret_cast(key_cache.data_ptr()),
+ reinterpret_cast(value_cache.data_ptr()),
+ slot_mapping.data_ptr(), block_stride, page_stride,
+ head_stride, key_stride, value_stride, num_heads, head_size,
+ block_size, k_scale.data_ptr(), v_scale.data_ptr(),
+ key_scale_cache.data_ptr(),
+ value_scale_cache.data_ptr(), is_nhd);
+ break;
+ }
+ default: {
+ TORCH_CHECK(false, "Unsupported input dtype for reshape_and_cache_fp4. "
+ "Must be half or bfloat16.");
+ }
+ }
+}
\ No newline at end of file
diff --git a/csrc/custom_ops/speculator_ln.cu b/csrc/custom_ops/speculator_ln.cu
new file mode 100644
index 000000000..e3c376747
--- /dev/null
+++ b/csrc/custom_ops/speculator_ln.cu
@@ -0,0 +1,252 @@
+#include
+#include
+#include
+
+#include
+#include
+#include
+
+#include
+#include
+
+#include "dtype_common.cuh"
+
+struct SumOp {
+ __device__ __forceinline__ float operator()(const float &a,
+ const float &b) const {
+ return a + b;
+ }
+};
+
+template
+__global__ void spec_ln_vec_kernel(
+ typename DevHalf::type *__restrict__ out,
+ const typename DevHalf::type *__restrict__ in,
+ const typename DevHalf::type *__restrict__ weight,
+ const typename DevHalf::type *__restrict__ bias,
+ int64_t row_stride, int hidden, float eps) {
+
+ using DH = DevHalf;
+ using H = typename DH::type;
+ using VPack = Pack;
+
+ extern __shared__ char smem[];
+ __shared__ float s_inv_rms;
+
+ const int row = blockIdx.x;
+ const H *row_in = in + row * row_stride;
+ H *row_out = out + row * hidden;
+
+ const int vec_len = hidden / VEC;
+ const int tail = hidden - vec_len * VEC;
+
+ float local_ss = 0.f;
+
+ const VPack *in_vec = reinterpret_cast(row_in);
+
+ for (int i = threadIdx.x; i < vec_len; i += blockDim.x) {
+ VPack p = in_vec[i];
+#pragma unroll
+ for (int k = 0; k < VEC; ++k) {
+ float x = DH::to_float(p.v[k]);
+ local_ss += x * x;
+ }
+ }
+
+ for (int j = threadIdx.x + vec_len * VEC; j < hidden; j += blockDim.x) {
+ float x = DH::to_float(row_in[j]);
+ local_ss += x * x;
+ }
+
+ using BlockReduce = cub::BlockReduce;
+ __shared__ typename BlockReduce::TempStorage temp_storage;
+ float sumsq = BlockReduce(temp_storage).Reduce(local_ss, SumOp{}, blockDim.x);
+
+ if (threadIdx.x == 0) {
+ float mean = sumsq / static_cast(hidden);
+ s_inv_rms = rsqrtf(mean + eps);
+ }
+ __syncthreads();
+
+ const VPack *w_vec = reinterpret_cast(weight);
+ const VPack *b_vec = reinterpret_cast(bias);
+ VPack *out_vec = reinterpret_cast(row_out);
+
+ for (int i = threadIdx.x; i < vec_len; i += blockDim.x) {
+ VPack px = reinterpret_cast(row_in)[i];
+ VPack py;
+ VPack pw;
+ VPack pb;
+ bool use_w = weight != nullptr;
+ bool use_b = bias != nullptr;
+ if (use_w)
+ pw = w_vec[i];
+ if (use_b)
+ pb = b_vec[i];
+#pragma unroll
+ for (int k = 0; k < VEC; ++k) {
+ float y = DH::to_float(px.v[k]);
+ y = y * s_inv_rms;
+ if (use_w)
+ y *= DH::to_float(pw.v[k]);
+ if (use_b)
+ y += DH::to_float(pb.v[k]);
+ py.v[k] = DH::from_float(y);
+ }
+ out_vec[i] = py;
+ }
+
+ for (int j = threadIdx.x + vec_len * VEC; j < hidden; j += blockDim.x) {
+ float y = DH::to_float(row_in[j]) * s_inv_rms;
+ if (weight)
+ y *= DH::to_float(weight[j]);
+ if (bias)
+ y += DH::to_float(bias[j]);
+ row_out[j] = DH::from_float(y);
+ }
+}
+
+template
+__global__ void spec_ln_scalar_kernel(
+ typename DevHalf::type *__restrict__ out,
+ const typename DevHalf::type *__restrict__ in,
+ const typename DevHalf::type *__restrict__ weight,
+ const typename DevHalf::type *__restrict__ bias,
+ int64_t row_stride, int hidden, float eps) {
+
+ using DH = DevHalf;
+ using H = typename DH::type;
+
+ __shared__ float s_inv_rms;
+
+ const int row = blockIdx.x;
+ const H *row_in = in + row * row_stride;
+ H *row_out = out + row * hidden;
+
+ float local_ss = 0.f;
+ for (int j = threadIdx.x; j < hidden; j += blockDim.x) {
+ float x = DH::to_float(row_in[j]);
+ local_ss += x * x;
+ }
+
+ using BlockReduce = cub::BlockReduce;
+ __shared__ typename BlockReduce::TempStorage temp_storage;
+ float sumsq = BlockReduce(temp_storage).Reduce(local_ss, SumOp{}, blockDim.x);
+
+ if (threadIdx.x == 0) {
+ float mean = sumsq / static_cast(hidden);
+ s_inv_rms = rsqrtf(mean + eps);
+ }
+ __syncthreads();
+
+ for (int j = threadIdx.x; j < hidden; j += blockDim.x) {
+ float y = DH::to_float(row_in[j]) * s_inv_rms;
+ if (weight)
+ y *= DH::to_float(weight[j]);
+ if (bias)
+ y += DH::to_float(bias[j]);
+ row_out[j] = DH::from_float(y);
+ }
+}
+
+template
+torch::Tensor speculator_ln_cuda_impl(
+ const torch::Tensor &input, // [..., hidden]
+ const c10::optional &w, // [hidden] or None
+ const c10::optional &b, // [hidden] or None
+ double eps_d) {
+
+ TORCH_CHECK(input.is_cuda(), "speculator_ln: input must be CUDA");
+ TORCH_CHECK(IsSupported::value,
+ "speculator_ln: dtype must be fp16 or bf16");
+
+ const auto dtype = input.scalar_type();
+ const auto device = input.device();
+ const float eps = static_cast(eps_d);
+
+ const int64_t hidden = input.size(-1);
+ TORCH_CHECK(hidden > 0, "speculator_ln: hidden size must be > 0");
+ TORCH_CHECK(
+ input.stride(-1) == 1,
+ "speculator_ln: last dimension must be contiguous (stride -1 == 1)");
+
+ const typename DevHalf::type *w_ptr = nullptr;
+ const typename DevHalf::type *b_ptr = nullptr;
+
+ if (w.has_value() && w->defined()) {
+ TORCH_CHECK(w->is_cuda(), "weight must be CUDA");
+ TORCH_CHECK(w->scalar_type() == dtype,
+ "weight dtype must match input dtype");
+ TORCH_CHECK(w->dim() == 1 && w->numel() == hidden,
+ "weight must be 1D of size hidden");
+ TORCH_CHECK(w->is_contiguous(), "weight must be contiguous");
+ w_ptr = reinterpret_cast::type *>(
+ w->data_ptr());
+ }
+ if (b.has_value() && b->defined()) {
+ TORCH_CHECK(b->is_cuda(), "bias must be CUDA");
+ TORCH_CHECK(b->scalar_type() == dtype, "bias dtype must match input dtype");
+ TORCH_CHECK(b->dim() == 1 && b->numel() == hidden,
+ "bias must be 1D of size hidden");
+ TORCH_CHECK(b->is_contiguous(), "bias must be contiguous");
+ b_ptr = reinterpret_cast::type *>(
+ b->data_ptr());
+ }
+
+ auto in_2d = input.view({-1, hidden});
+ const int64_t num_rows = in_2d.size(0);
+ const int64_t row_stride = in_2d.stride(0);
+
+ auto out = at::empty_like(input);
+ auto out_2d = out.view({-1, hidden});
+
+ const auto stream = at::cuda::getCurrentCUDAStream();
+ const int BLOCK = (num_rows < 256) ? 1024 : 256;
+ dim3 grid(num_rows);
+ dim3 block(std::min(hidden, BLOCK));
+
+ using H = typename DevHalf::type;
+ const H *in_ptr = reinterpret_cast(in_2d.data_ptr());
+ H *out_ptr = reinterpret_cast(out_2d.data_ptr());
+
+ const bool can_vec8 = (hidden % 8 == 0) && (row_stride % 8 == 0) &&
+ is_aligned(in_ptr, 16) && is_aligned(out_ptr, 16) &&
+ (!w_ptr || is_aligned(w_ptr, 16)) &&
+ (!b_ptr || is_aligned(b_ptr, 16));
+
+ const bool can_vec4 = (hidden % 4 == 0) && (row_stride % 4 == 0) &&
+ is_aligned(in_ptr, 8) && is_aligned(out_ptr, 8) &&
+ (!w_ptr || is_aligned(w_ptr, 8)) &&
+ (!b_ptr || is_aligned(b_ptr, 8));
+
+ if (can_vec8) {
+ spec_ln_vec_kernel<<>>(
+ out_ptr, in_ptr, w_ptr, b_ptr, row_stride, (int)hidden, eps);
+ } else if (can_vec4) {
+ spec_ln_vec_kernel<<>>(
+ out_ptr, in_ptr, w_ptr, b_ptr, row_stride, (int)hidden, eps);
+ } else {
+ spec_ln_scalar_kernel<<>>(
+ out_ptr, in_ptr, w_ptr, b_ptr, row_stride, (int)hidden, eps);
+ }
+
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
+ return out;
+}
+
+torch::Tensor speculator_ln_cuda(const torch::Tensor &input,
+ const c10::optional |