diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py
index 1dc9a26b1..4f5cc8425 100644
--- a/server/lorax_server/models/flash_causal_lm.py
+++ b/server/lorax_server/models/flash_causal_lm.py
@@ -1273,7 +1273,26 @@ def adapter_memory_size(self) -> int:
         total_gpu_memory = torch.cuda.get_device_properties(self.device).total_memory
         return ADAPTER_MEMORY_FRACTION * total_gpu_memory
 
+    def init_graph_wrapper(self, max_total_tokens: int):
+        self.model_graph_wrapper = GraphCache(
+            self.model,
+            self.device,
+            self.kv_cache,
+            self.adapter_layers,
+            self.traced_adapter_layers,
+            self._forward_context,
+            max_total_tokens,
+            self.num_heads,
+            self.num_kv_heads,
+            self.sliding_window_blocks,
+            self.layer_to_lora_weights,
+            self.punica_wrapper,
+        )
+
     def warmup(self, batch: FlashCausalLMBatch, max_new_tokens: int, embedding_model: bool = False):
+        logger.info(f'Pre warmup cuda memory: {get_cuda_free_memory(self.device, 1) / (1024 ** 3):.2f} GB')
+
         # The warmup batch is the biggest batch we could ever receive
         max_total_tokens = batch.max_input_length + max_new_tokens + get_speculative_tokens()
 
@@ -1297,6 +1316,7 @@ def warmup(self, batch: FlashCausalLMBatch, max_new_tokens: int, embedding_model
                 self.kv_dtype,
                 self.device,
             )
+            logger.info(f'Pre warmup kv init cuda memory: {get_cuda_free_memory(self.device, 1) / (1024 ** 3):.2f} GB')
 
             if not embedding_model:
                 with warmup_mode():
@@ -1326,24 +1346,17 @@ def warmup(self, batch: FlashCausalLMBatch, max_new_tokens: int, embedding_model
             # Estimate the memory overhead from CUDA graphs so we can subtract it from the kv cache.
             # Needs to be estimated here rather than fully initialized as the graph cache relies on the
             # cache manager being set.
-            self.model_graph_wrapper = GraphCache(
-                self.model,
-                self.device,
-                self.kv_cache,
-                self.adapter_layers,
-                self.traced_adapter_layers,
-                self._forward_context,
-                max_total_tokens,
-                self.num_heads,
-                self.num_kv_heads,
-                self.sliding_window_blocks,
-                self.layer_to_lora_weights,
-                self.punica_wrapper,
-            )
+            self.init_graph_wrapper(max_total_tokens)
             graph_cache_memory = self.model_graph_wrapper.get_estimated_cache_memory()
             logger.info("Estimated graph cache memory: {} MB", graph_cache_memory / 1024 / 1024)
             torch.cuda.synchronize(self.device)
+        logger.info(f'Post warmup cuda memory: {get_cuda_free_memory(self.device, 1) / (1024 ** 3):.2f} GB')
+        del self.model_graph_wrapper
+        self.kv_cache = []
+        torch.cuda.synchronize(self.device)
+        torch.cuda.empty_cache()
+        logger.info(f'Post warmup empty_cache cuda memory: {get_cuda_free_memory(self.device, 1) / (1024 ** 3):.2f} GB')
 
         # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
         # Calculate the number of blocks that can be allocated with the free memory
         dtype_size = torch.tensor([], dtype=self.dtype).element_size()
@@ -1358,13 +1371,8 @@ def warmup(self, batch: FlashCausalLMBatch, max_new_tokens: int, embedding_model
         free_memory = max(0, free_memory - graph_cache_memory)
         logger.info("Memory remaining for kv cache: {} MB", free_memory / 1024 / 1024)
 
-        batch_num_blocks = batch.num_blocks if batch is not None else 0
-        num_blocks = (
-            # Leave 5% for some wiggle room
-            int((free_memory * MEMORY_WIGGLE_ROOM) // total_cache_size)
-            # Add batch.num_blocks as we allocated it above, so it is included in the peak memory.
-            + batch_num_blocks
-        )
+        num_blocks = int((free_memory * MEMORY_WIGGLE_ROOM) // total_cache_size)
+        logger.info(f"num kv blocks: {num_blocks}, num kv tokens: {num_blocks * BLOCK_SIZE}")
 
         del batch
 
@@ -1379,7 +1387,8 @@ def warmup(self, batch: FlashCausalLMBatch, max_new_tokens: int, embedding_model
 
         torch.cuda.synchronize(self.device)
 
-        if self.model_graph_wrapper is not None:
+        if self.compile:
+            self.init_graph_wrapper(max_total_tokens)
             # Warmup the graph cache. Needs to be done after setting cache manager as
             # tracing will use the static kv cache tensors
             self.model_graph_wrapper.kv_cache = self.kv_cache
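
For reference, the sketch below restates the kv-cache sizing arithmetic that the reworked warmup() performs once the warmup allocations and the temporary graph wrapper have been released: the remaining free memory, minus the estimated graph-cache overhead, is divided by the per-block cache size. The helper name, the BLOCK_SIZE value of 16, and the example shapes are illustrative assumptions rather than part of the patch; MEMORY_WIGGLE_ROOM reflects the 5% headroom mentioned in the removed comment.

# Standalone sketch of the post-warmup kv-cache sizing; hypothetical helper, not part of the patch.
import torch

BLOCK_SIZE = 16               # tokens per kv-cache block (assumed value)
MEMORY_WIGGLE_ROOM = 0.95     # leave 5% headroom, per the original comment

def estimate_num_kv_blocks(
    free_memory: int,          # bytes free after warmup allocations are released
    graph_cache_memory: int,   # estimated CUDA graph overhead, in bytes
    num_layers: int,
    num_kv_heads: int,
    head_size: int,
    dtype: torch.dtype = torch.float16,
) -> int:
    dtype_size = torch.tensor([], dtype=dtype).element_size()
    # One block holds BLOCK_SIZE tokens of keys and values for every layer.
    cache_block_size = BLOCK_SIZE * num_kv_heads * head_size
    total_cache_size = num_layers * cache_block_size * 2 * dtype_size
    budget = max(0, free_memory - graph_cache_memory)
    return int((budget * MEMORY_WIGGLE_ROOM) // total_cache_size)

# Example: 20 GiB free, 1 GiB reserved for graphs, Llama-2-7B-like shapes
# (32 layers, 32 kv heads, head_size 128, fp16) -> about 2300 blocks, ~37k kv tokens.
print(estimate_num_kv_blocks(20 * 1024**3, 1 * 1024**3, 32, 32, 128))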