Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 42 additions & 1 deletion prompting/llms/model_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
import gc
import multiprocessing as pymp
from typing import ClassVar
import hashlib
import json
from collections import OrderedDict

import torch
import torch.multiprocessing as mp
Expand Down Expand Up @@ -57,6 +60,8 @@ class ModelManager(BaseModel):
active_models: dict[ModelConfig, ReproducibleVLLM] = {}
used_ram: float = 0.0
lock: ClassVar[AsyncRLock] = AsyncRLock()
logits_cache: OrderedDict = Field(default_factory=OrderedDict)
max_cache_size: int = 150 #Shouldn't need 150 generations per step, and we only need to cache per step

async def load_always_active_models(self):
for model_config in self.always_active_models:
Expand Down Expand Up @@ -221,13 +226,49 @@ async def generate_logits(
seed: int = None,
continue_last_message: bool = False,
):
# Create a hashable key for the cache
if isinstance(model, ModelConfig):
model_key = model.llm_model_id
else: # If model is a string, it's a model ID
model_key = model

# Convert messages to a hashable format (tuple of strings)
messages_key = tuple(messages)

# Convert sampling_params to a hashable format
sampling_params_key = json.dumps(sampling_params, sort_keys=True) if sampling_params else None

# Create a cache key from all parameters
cache_key = (messages_key, model_key, sampling_params_key, seed, continue_last_message)

# Check if result is in cache
if cache_key in self.logits_cache:
logger.debug(f"Cache hit for logits generation with key {hash(cache_key)}")
# Move this entry to the end to mark it as most recently used
result = self.logits_cache.pop(cache_key)
self.logits_cache[cache_key] = result
return result

# Not in cache, generate logits
model_instance: ReproducibleVLLM = await self.get_model(model)
return await model_instance.generate_logits(
result = await model_instance.generate_logits(
messages=messages,
sampling_params=sampling_params,
seed=seed,
continue_last_message=continue_last_message,
)

# Check if cache is at max capacity
if len(self.logits_cache) >= self.max_cache_size:
# Remove the oldest item (first item in OrderedDict)
self.logits_cache.popitem(last=False)
logger.debug(f"Cache limit reached, removed oldest entry. Cache size: {len(self.logits_cache)}")

# Store in cache
self.logits_cache[cache_key] = result
logger.debug(f"Cached logits generation with key {hash(cache_key)}. Cache size: {len(self.logits_cache)}")

return result

async def _vram_cleanup(self):
"""Perform VRAM clean-up."""
Expand Down
Loading