53 changes: 31 additions & 22 deletions server/lorax_server/models/flash_causal_lm.py
@@ -1273,7 +1273,26 @@ def adapter_memory_size(self) -> int:
         total_gpu_memory = torch.cuda.get_device_properties(self.device).total_memory
         return ADAPTER_MEMORY_FRACTION * total_gpu_memory
 
+
+    def init_graph_wrapper(self, max_total_tokens: int):
+        self.model_graph_wrapper = GraphCache(
+            self.model,
+            self.device,
+            self.kv_cache,
+            self.adapter_layers,
+            self.traced_adapter_layers,
+            self._forward_context,
+            max_total_tokens,
+            self.num_heads,
+            self.num_kv_heads,
+            self.sliding_window_blocks,
+            self.layer_to_lora_weights,
+            self.punica_wrapper,
+        )
+
     def warmup(self, batch: FlashCausalLMBatch, max_new_tokens: int, embedding_model: bool = False):
+        logger.info(f'Pre warmup cuda memory: {get_cuda_free_memory(self.device, 1) / (1024 ** 3):.2f} GB')
+
         # The warmup batch is the biggest batch we could ever receive
         max_total_tokens = batch.max_input_length + max_new_tokens + get_speculative_tokens()
 
@@ -1297,6 +1316,7 @@ def warmup(self, batch: FlashCausalLMBatch, max_new_tokens: int, embedding_model
             self.kv_dtype,
             self.device,
         )
+        logger.info(f'Pre warmup kv init cuda memory: {get_cuda_free_memory(self.device, 1) / (1024 ** 3):.2f} GB')
 
         if not embedding_model:
             with warmup_mode():
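Note on the added logging: get_cuda_free_memory is a lorax helper used here to bracket each warmup phase with the free device memory in GiB. A minimal stand-in built on torch's public torch.cuda.mem_get_info API might look like the sketch below; treating only a fraction of total memory as usable is an assumption about the helper, not confirmed from the source.

import torch

# Hypothetical stand-in for lorax's get_cuda_free_memory(device, fraction).
# Assumption: it reports currently free device memory, counting only
# `fraction` of total memory as usable (fraction=1 means all of it).
def get_cuda_free_memory(device: torch.device, memory_fraction: float) -> int:
    free, total = torch.cuda.mem_get_info(device)
    return int(free - (1 - memory_fraction) * total)

# Usage mirroring the log lines in this diff:
# logger.info(f"Pre warmup cuda memory: {get_cuda_free_memory(device, 1) / (1024 ** 3):.2f} GB")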
@@ -1326,24 +1346,17 @@ def warmup(self, batch: FlashCausalLMBatch, max_new_tokens: int, embedding_model
             # Estimate the memory overhead from CUDA graphs so we can subtract it from the kv cache.
             # Needs to be estimated here rather than fully initialized as the graph cache relies on the
             # cache manager being set.
-            self.model_graph_wrapper = GraphCache(
-                self.model,
-                self.device,
-                self.kv_cache,
-                self.adapter_layers,
-                self.traced_adapter_layers,
-                self._forward_context,
-                max_total_tokens,
-                self.num_heads,
-                self.num_kv_heads,
-                self.sliding_window_blocks,
-                self.layer_to_lora_weights,
-                self.punica_wrapper,
-            )
+            self.init_graph_wrapper(max_total_tokens)
             graph_cache_memory = self.model_graph_wrapper.get_estimated_cache_memory()
             logger.info("Estimated graph cache memory: {} MB", graph_cache_memory / 1024 / 1024)
             torch.cuda.synchronize(self.device)
 
+        logger.info(f'Post warmup cuda memory: {get_cuda_free_memory(self.device, 1) / (1024 ** 3):.2f} GB')
+        del self.model_graph_wrapper
+        self.kv_cache = []
+        torch.cuda.synchronize(self.device)
+        torch.cuda.empty_cache()
+        logger.info(f'Post warmup empty_cache cuda memory: {get_cuda_free_memory(self.device, 1) / (1024 ** 3):.2f} GB')
         # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
         # Calculate the number of blocks that can be allocated with the free memory
         dtype_size = torch.tensor([], dtype=self.dtype).element_size()
@@ -1358,13 +1371,8 @@ def warmup(self, batch: FlashCausalLMBatch, max_new_tokens: int, embedding_model
         free_memory = max(0, free_memory - graph_cache_memory)
         logger.info("Memory remaining for kv cache: {} MB", free_memory / 1024 / 1024)
 
-        batch_num_blocks = batch.num_blocks if batch is not None else 0
-        num_blocks = (
-            # Leave 5% for some wiggle room
-            int((free_memory * MEMORY_WIGGLE_ROOM) // total_cache_size)
-            # Add batch.num_blocks as we allocated it above, so it is included in the peak memory.
-            + batch_num_blocks
-        )
+        num_blocks = int((free_memory * MEMORY_WIGGLE_ROOM) // total_cache_size)
+        logger.info(f"num kv blocks: {num_blocks}, num kv tokens: {num_blocks * BLOCK_SIZE}")
 
         del batch
 
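The simplified num_blocks computation divides the remaining free memory (minus 5% headroom) by the bytes one kv-cache block occupies. A minimal sketch of that arithmetic with assumed example values, where the model shapes, BLOCK_SIZE, and MEMORY_WIGGLE_ROOM are illustrative rather than taken from this PR:

# Sketch of the kv-cache sizing above, with assumed example values.
dtype_size = 2                                     # bytes per element for fp16/bf16
num_layers, num_kv_heads, head_size = 32, 8, 128   # assumed model shapes
BLOCK_SIZE = 16                                    # assumed tokens per block
MEMORY_WIGGLE_ROOM = 0.95                          # leave 5% of free memory as headroom

# One block holds K and V (factor 2) for BLOCK_SIZE tokens in every layer.
total_cache_size = 2 * num_layers * num_kv_heads * head_size * dtype_size * BLOCK_SIZE

free_memory = 20 * 1024 ** 3                       # e.g. 20 GiB left after warmup
num_blocks = int((free_memory * MEMORY_WIGGLE_ROOM) // total_cache_size)
print(num_blocks, num_blocks * BLOCK_SIZE)         # 9728 blocks, 155648 tokens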
@@ -1379,7 +1387,8 @@ def warmup(self, batch: FlashCausalLMBatch, max_new_tokens: int, embedding_model
 
         torch.cuda.synchronize(self.device)
 
-        if self.model_graph_wrapper is not None:
+        if self.compile:
+            self.init_graph_wrapper(max_total_tokens)
             # Warmup the graph cache. Needs to be done after setting cache manager as
             # tracing will use the static kv cache tensors
             self.model_graph_wrapper.kv_cache = self.kv_cache
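Taken together, factoring GraphCache construction into init_graph_wrapper lets warmup build the graph cache twice: once before the kv cache is sized, purely to estimate graph memory, after which it is discarded, and once after the real kv cache is allocated, so tracing runs against the static kv tensors. A paraphrased sketch of that control flow, in which get_free_memory and allocate_kv_cache are hypothetical helpers standing in for the surrounding warmup code:

# Paraphrased two-phase warmup flow (a sketch, not the actual lorax code).
def warmup_flow(model, max_total_tokens: int):
    if model.compile:
        # Phase 1: build the graph cache only to size it, then drop it so
        # its memory can be subtracted from the kv cache budget.
        model.init_graph_wrapper(max_total_tokens)
        graph_cache_memory = model.model_graph_wrapper.get_estimated_cache_memory()
        del model.model_graph_wrapper
    else:
        graph_cache_memory = 0

    free_memory = get_free_memory() - graph_cache_memory  # hypothetical helper
    allocate_kv_cache(free_memory)                        # hypothetical helper

    if model.compile:
        # Phase 2: rebuild against the now-allocated static kv cache tensors.
        model.init_graph_wrapper(max_total_tokens)
        model.model_graph_wrapper.kv_cache = model.kv_cache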