diff --git a/vllm_rbln/v1/worker/rbln_worker.py b/vllm_rbln/v1/worker/rbln_worker.py index 42cad1b07..7d154c0aa 100644 --- a/vllm_rbln/v1/worker/rbln_worker.py +++ b/vllm_rbln/v1/worker/rbln_worker.py @@ -14,6 +14,7 @@ """A RBLN worker class.""" import copy +import math import os from types import NoneType from typing import TYPE_CHECKING @@ -279,6 +280,7 @@ def determine_available_memory(self) -> int: # NOTE - model parallel(tp, dp, ep, pp) # already applied into model params n_model_params = n_model_attentions + n_model_experts + head_size = self.model_config.get_head_size() available_memory_estimate = estimate_available_memory( model_config=self.model_config, @@ -293,7 +295,13 @@ def determine_available_memory(self) -> int: logger.info( "available_memory_estimate = %.2f GB", available_memory_estimate / 10**9 ) - + head_align_ratio = math.ceil(head_size / 64) * 64 / head_size + logger.info("head size align ratio = %s", head_align_ratio) + available_memory_estimate /= head_align_ratio + logger.info( + "available_memory_estimate considering 64B align = %.2f GB", + available_memory_estimate / 10**9 + ) return available_memory_estimate def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: