88 changes: 88 additions & 0 deletions vllm_spyre/platform.py
@@ -19,6 +19,7 @@

import torch
from transformers.models.granite import GraniteConfig
from transformers.models.llama import LlamaConfig
from vllm.inputs import ProcessorInputs, PromptType, TokenInputs
from vllm.logger import init_logger

@@ -214,6 +215,9 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
vllm_config.model_config
):
cls.configure_granite_3_8b(vllm_config)

if cls.is_llama_3_1_8b(vllm_config.model_config):
cls.configure_llama_3_1_8b(vllm_config)

# To disable any paged attention ops in the base scheduler, we:
# - Set the block size (in tokens) to the maximum sequence length
@@ -678,6 +682,73 @@ def _set_env_with_validation(cls, env_var: str, default_value: int) -> None:
user_value,
default_value,
)

@classmethod
def configure_llama_3_1_8b(cls, vllm_config: VllmConfig):
Collaborator: It looks like this is all copy-pasted, which I would rather not do. I think @tjohnson31415 has been working on cleaning this up a bit to be more reusable; we should sync up on that.
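For illustration only, a minimal sketch of the kind of reusable helper this comment alludes to. The `ModelOverrides` dataclass, the `apply_overrides` name, and its fields are assumptions made for this sketch; they are not part of this PR or of @tjohnson31415's work.

```python
# Illustrative sketch only: one possible shape for a shared per-model
# override helper. Names and fields here are assumptions, not this PR's API.
import os
from dataclasses import dataclass


@dataclass
class ModelOverrides:
    tkv_limit: int        # default for VLLM_DT_MAX_BATCH_TKV_LIMIT
    hdma_p2p_size: int    # default for FLEX_HDMA_P2PSIZE
    hdma_coll_size: int   # default for FLEX_HDMA_COLLSIZE


def _set_env_default(env_var: str, value: int) -> None:
    # Respect a user-provided value; only fill in the default when unset.
    if not os.getenv(env_var):
        os.environ[env_var] = str(value)


def apply_overrides(overrides: ModelOverrides) -> None:
    _set_env_default("VLLM_DT_MAX_BATCH_TKV_LIMIT", overrides.tkv_limit)
    _set_env_default("FLEX_HDMA_P2PSIZE", overrides.hdma_p2p_size)
    _set_env_default("FLEX_HDMA_COLLSIZE", overrides.hdma_coll_size)
```

With something along these lines, `configure_granite_3_8b` and `configure_llama_3_1_8b` could each reduce to building a `ModelOverrides` and calling the shared helper instead of duplicating the env-var checks.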

"""
Configure hard-coded values for the model
https://huggingface.co/meta-llama/Llama-3.1-8B and other dense 8B variants.
"""
parallel_config = vllm_config.parallel_config

if parallel_config.world_size != 4:
# only override configs for TP=4
return

# Log once upfront that we detected the model
logger.info(
"Llama 3.1 8b dense model with tensor parallel size 4 detected. "
Collaborator: Is this a dense model?

Author: Yes
"Applying model-specific configuration overrides."
)

tkv_128k = 128 * 1024
if not os.getenv("VLLM_DT_MAX_BATCH_TKV_LIMIT"):
os.environ["VLLM_DT_MAX_BATCH_TKV_LIMIT"] = str(tkv_128k)
logger.info("Using VLLM_DT_MAX_BATCH_TKV_LIMIT = %d", tkv_128k)
elif os.getenv("VLLM_DT_MAX_BATCH_TKV_LIMIT") != str(tkv_128k):
logger.warning(
"VLLM_DT_MAX_BATCH_TKV_LIMIT was set to %s, not overriding to default of %d",
os.getenv("VLLM_DT_MAX_BATCH_TKV_LIMIT"),
tkv_128k,
)

# Set HDMA environment variables with validation
cls._set_env_with_validation("FLEX_HDMA_P2PSIZE", 256 * 1024 * 1024) # 256MB
cls._set_env_with_validation("FLEX_HDMA_COLLSIZE", 32 * 1024 * 1024) # 32MB

# Override the total number of KV cache blocks based on what we know
# will fit. (Unless user already set `--num-gpu-blocks-override`)
# TODO: remove this once we have correct free memory info available
if cls.sendnn_configured() and ((0, 0, 0) < cls.sendnn_version() < (1, 0, 3)):
# Older versions of torch_sendnn use the previous override of ~2k
# blocks.
# NB: A version of (0, 0, 0) means that the version of torch_sendnn
# could not be determined, and we assume this means we have a dev
# install of newer code.
blocks_override = 2080
else:
# If torch_sendnn is not configured or we have a newer torch_sendnn
# install, use the newer 8k override.
blocks_override = 8192

if vllm_config.cache_config.num_gpu_blocks_override is None:
vllm_config.cache_config.num_gpu_blocks_override = blocks_override
logger.info("Overriding available KV Cache blocks to %d", blocks_override)
elif vllm_config.cache_config.num_gpu_blocks_override != blocks_override:
logger.warning(
"--num-gpu-blocks-override was set to %d, not using default of %d",
vllm_config.cache_config.num_gpu_blocks_override,
blocks_override,
)

# hard-coded value for max_num_batched_tokens with chunked prefill
if (
envs_spyre.VLLM_SPYRE_USE_CHUNKED_PREFILL
and envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND == "sendnn"
and os.getenv("VLLM_DT_CHUNK_LEN") is None
):
logger.info("Setting --max-num-batched-tokens to 1024 for chunked prefill")
vllm_config.scheduler_config.max_num_batched_tokens = 1024

@classmethod
def configure_granite_3_8b(cls, vllm_config: VllmConfig):
@@ -763,6 +834,23 @@ def is_granite_3_8b(cls, model_config: ModelConfig):
and model_config.hf_config.num_attention_heads == 32
)

@classmethod
def is_llama_3_1_8b(cls, model_config: ModelConfig):
"""Returns true if we have a model that looks like
meta-llama/Llama-3.1-8B-Instruct"""
if not isinstance(model_config.hf_config, LlamaConfig):
# Not llama 3 at all
return False

return (
model_config.hf_config.num_hidden_layers == 32
and model_config.hf_config.max_position_embeddings == 131072
and model_config.hf_config.hidden_size == 4096
and model_config.hf_config.vocab_size == 128256
and model_config.hf_config.num_key_value_heads == 8
and model_config.hf_config.num_attention_heads == 32
)

@classmethod
def is_granite_4_8b_dense(cls, model_config: ModelConfig):
"""Returns true if we have a dense granite 4 model with the same architecture as granite 3.3
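For context, a minimal sketch of how the new shape-based detection heuristic could be exercised in isolation. It assumes the classmethods above live on `SpyrePlatform` in `vllm_spyre/platform.py`, and uses a `SimpleNamespace` as a simplified stand-in for vLLM's `ModelConfig`, which only needs an `hf_config` attribute here.

```python
# Illustrative only: exercising the shape-based Llama 3.1 8B detection.
# `SpyrePlatform` is assumed to be the class defined in vllm_spyre/platform.py;
# the SimpleNamespace below is a stand-in for vLLM's ModelConfig.
from types import SimpleNamespace

from transformers.models.llama import LlamaConfig
from vllm_spyre.platform import SpyrePlatform

hf_config = LlamaConfig(
    num_hidden_layers=32,
    max_position_embeddings=131072,
    hidden_size=4096,
    vocab_size=128256,
    num_key_value_heads=8,
    num_attention_heads=32,
)
model_config = SimpleNamespace(hf_config=hf_config)

# Expected to print True for the 8B-shaped config above; any other
# combination of layer count, hidden size, or head counts returns False.
print(SpyrePlatform.is_llama_3_1_8b(model_config))
```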