Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 13 additions & 8 deletions vllm/model_executor/layers/attention/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch.nn as nn

import vllm.envs as envs
from vllm.config import CacheConfig, get_current_vllm_config
from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
from vllm.config.vllm import VllmConfig
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
Expand All @@ -25,6 +25,7 @@
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.platforms import current_platform
from vllm.utils.torch_utils import (
TORCH_DTYPE_TO_KV_CACHE_STR,
direct_register_custom_op,
kv_cache_dtype_str_to_dtype,
)
Expand Down Expand Up @@ -193,6 +194,7 @@ def __init__(
alibi_slopes: list[float] | None = None,
use_alibi_sqrt: bool | None = None,
cache_config: CacheConfig | None = None,
model_config: ModelConfig | None = None,
quant_config: QuantizationConfig | None = None,
logits_soft_cap: float | None = None,
per_layer_sliding_window: int | None = None,
Expand All @@ -217,12 +219,14 @@ def __init__(
else:
sliding_window = None

vllm_config = get_current_vllm_config()
if cache_config is not None:
kv_cache_dtype = cache_config.cache_dtype
calculate_kv_scales = cache_config.calculate_kv_scales
else:
kv_cache_dtype = "auto"
assert model_config is not None, (
"model_config is required when cache_config is not provided"
)
kv_cache_dtype = TORCH_DTYPE_TO_KV_CACHE_STR[model_config.dtype]
calculate_kv_scales = False

# llm-compressor mdls need to set cache_dtype to "fp8" manually.
Expand Down Expand Up @@ -256,7 +260,10 @@ def __init__(
if str(layer_idx) in cache_config.kv_cache_dtype_skip_layers:
skip = True
if skip:
kv_cache_dtype = "auto"
assert model_config is not None, (
"model_config is required for kv_cache_dtype_skip_layers"
)
kv_cache_dtype = TORCH_DTYPE_TO_KV_CACHE_STR[model_config.dtype]
calculate_kv_scales = False
logger.info(
"Layer %s: kv_cache_dtype=%s, sliding_window=%s",
Expand All @@ -266,7 +273,7 @@ def __init__(
)

self.kv_cache_torch_dtype = kv_cache_dtype_str_to_dtype(
kv_cache_dtype, vllm_config.model_config
kv_cache_dtype, model_config
)
self.kv_cache_dtype = kv_cache_dtype
self.calculate_kv_scales = calculate_kv_scales
Expand All @@ -285,8 +292,6 @@ def __init__(
self.sliding_window = sliding_window
self.has_sink = extra_impl_args.get("sinks") is not None

# NOTE: model_config may be None during certain tests
model_config = vllm_config.model_config
self.use_mm_prefix = model_config is not None and model_config.is_mm_prefix_lm

# During model initialization, the default dtype is set as the model
Expand Down Expand Up @@ -357,7 +362,7 @@ def __init__(
self.use_direct_call = not current_platform.opaque_attention_op()

self.use_output = self.attn_backend.accept_output_buffer
compilation_config = vllm_config.compilation_config
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@

import torch

from vllm.config import CacheConfig
from vllm.config import CacheConfig, ModelConfig
from vllm.config.vllm import VllmConfig
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.utils.torch_utils import TORCH_DTYPE_TO_KV_CACHE_STR
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionCGSupport,
Expand Down Expand Up @@ -87,6 +88,7 @@ def __init__(
num_kv_heads: int | None = None,
alibi_slopes: list[float] | None = None,
cache_config: CacheConfig | None = None,
model_config: ModelConfig | None = None,
quant_config: QuantizationConfig | None = None,
kv_sharing_target_layer_name: str | None = None,
prefix: str = "",
Expand All @@ -96,7 +98,10 @@ def __init__(
if cache_config is not None:
kv_cache_dtype = cache_config.cache_dtype
else:
kv_cache_dtype = "auto"
assert model_config is not None, (
"model_config is required when cache_config is not provided"
)
kv_cache_dtype = TORCH_DTYPE_TO_KV_CACHE_STR[model_config.dtype]

underlying_attn_backend = get_attn_backend(head_size, dtype, kv_cache_dtype)
attn_backend = create_chunked_local_attention_backend(
Expand Down
9 changes: 7 additions & 2 deletions vllm/model_executor/layers/attention/cross_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@
import numpy as np
import torch

from vllm.config import CacheConfig, VllmConfig
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention
from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import TORCH_DTYPE_TO_KV_CACHE_STR
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionMetadata,
Expand Down Expand Up @@ -181,6 +182,7 @@ def __init__(
head_size: int,
scale: float,
cache_config: CacheConfig | None = None,
model_config: ModelConfig | None = None,
attn_type: str | None = None,
**kwargs,
):
Expand All @@ -189,7 +191,10 @@ def __init__(
if cache_config is not None:
kv_cache_dtype = cache_config.cache_dtype
else:
kv_cache_dtype = "auto"
assert model_config is not None, (
"model_config is required when cache_config is not provided"
)
kv_cache_dtype = TORCH_DTYPE_TO_KV_CACHE_STR[model_config.dtype]

if attn_type is not None:
assert attn_type == AttentionType.ENCODER_DECODER, (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@

import torch

from vllm.config import CacheConfig
from vllm.config import CacheConfig, ModelConfig
from vllm.config.vllm import VllmConfig
from vllm.model_executor.layers.attention import Attention
from vllm.utils.torch_utils import TORCH_DTYPE_TO_KV_CACHE_STR
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionMetadata,
Expand Down Expand Up @@ -59,6 +60,7 @@ def __init__(
head_size: int,
scale: float,
cache_config: CacheConfig | None = None,
model_config: ModelConfig | None = None,
attn_type: str | None = None,
**kwargs,
):
Expand All @@ -67,7 +69,10 @@ def __init__(
if cache_config is not None:
kv_cache_dtype = cache_config.cache_dtype
else:
kv_cache_dtype = "auto"
assert model_config is not None, (
"model_config is required when cache_config is not provided"
)
kv_cache_dtype = TORCH_DTYPE_TO_KV_CACHE_STR[model_config.dtype]

underlying_attn_backend = get_attn_backend(
head_size,
Expand Down
13 changes: 10 additions & 3 deletions vllm/model_executor/layers/attention/static_sink_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,16 @@

import torch

from vllm.config import CacheConfig, VllmConfig
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.attention import Attention
from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.utils.torch_utils import (
TORCH_DTYPE_TO_KV_CACHE_STR,
direct_register_custom_op,
)
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionMetadata,
Expand Down Expand Up @@ -120,14 +123,18 @@ def __init__(
sink_len: int,
attn_backend: type[AttentionBackend] | None = None,
cache_config: CacheConfig | None = None,
model_config: ModelConfig | None = None,
**kwargs,
):
dtype = torch.get_default_dtype()

if cache_config is not None:
kv_cache_dtype = cache_config.cache_dtype
else:
kv_cache_dtype = "auto"
assert model_config is not None, (
"model_config is required when cache_config is not provided"
)
kv_cache_dtype = TORCH_DTYPE_TO_KV_CACHE_STR[model_config.dtype]

if attn_backend is not None:
underlying_attn_backend = attn_backend
Expand Down
1 change: 1 addition & 0 deletions vllm/model_executor/models/AXK1.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,7 @@ def __init__(
self.qk_head_dim,
self.scaling,
num_kv_heads=self.num_local_heads,
model_config=vllm_config.model_config,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
Expand Down
8 changes: 7 additions & 1 deletion vllm/model_executor/models/afmoe.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from torch import nn

from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
from vllm.config import CacheConfig, ModelConfig, VllmConfig, get_current_vllm_config
from vllm.distributed import (
get_ep_group,
get_pp_group,
Expand Down Expand Up @@ -180,6 +180,7 @@ def __init__(
max_position_embeddings: int = 131072,
head_dim: int | None = None,
rms_norm_eps: float = 1e-05,
model_config: ModelConfig | None = None,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
Expand Down Expand Up @@ -259,6 +260,7 @@ def __init__(
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
model_config=model_config,
cache_config=cache_config,
quant_config=quant_config,
per_layer_sliding_window=self.sliding_window,
Expand Down Expand Up @@ -297,6 +299,7 @@ class AfmoeDecoderLayer(nn.Module):
def __init__(
self,
config, # AfmoeConfig
model_config: ModelConfig | None = None,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
Expand All @@ -319,6 +322,7 @@ def __init__(
max_position_embeddings=max_position_embeddings,
head_dim=config.head_dim,
rms_norm_eps=config.rms_norm_eps,
model_config=model_config,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
Expand Down Expand Up @@ -405,10 +409,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
else:
self.embed_tokens = PPMissingLayer()

model_config = vllm_config.model_config
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: AfmoeDecoderLayer(
config=config,
model_config=model_config,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix,
Expand Down
15 changes: 12 additions & 3 deletions vllm/model_executor/models/arctic.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from torch import nn

from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import (
get_pp_group,
get_tensor_model_parallel_rank,
Expand Down Expand Up @@ -230,6 +230,7 @@ class ArcticAttention(nn.Module):
def __init__(
self,
config: ArcticConfig,
model_config: ModelConfig | None = None,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
Expand Down Expand Up @@ -285,6 +286,7 @@ def __init__(
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
model_config=model_config,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
Expand All @@ -307,6 +309,7 @@ class ArcticDecoderLayer(nn.Module):
def __init__(
self,
config: ArcticConfig,
model_config: ModelConfig | None = None,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
Expand All @@ -318,7 +321,8 @@ def __init__(
self.use_residual = config.use_residual and is_moe_layer
self.self_attn = ArcticAttention(
config,
cache_config,
model_config=model_config,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
Expand Down Expand Up @@ -388,10 +392,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.embed_tokens = VocabParallelEmbedding(
self.vocab_size, config.hidden_size, org_num_embeddings=self.vocab_size
)
model_config = vllm_config.model_config
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: ArcticDecoderLayer(
config, cache_config, quant_config, prefix=prefix
config,
model_config=model_config,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix,
),
prefix=f"{prefix}.layers",
)
Expand Down
Loading
Loading