Memory pressure safety valve #1103

Merged
3 changes: 3 additions & 0 deletions inference/core/env.py
@@ -532,6 +532,9 @@
warnings.simplefilter("ignore", ModelDependencyMissing)

DISK_CACHE_CLEANUP = str2bool(os.getenv("DISK_CACHE_CLEANUP", "True"))
MEMORY_FREE_THRESHOLD = float(
    os.getenv("MEMORY_FREE_THRESHOLD", "0.0")
)  # fraction of free memory (0.0-1.0); 0 disables memory pressure detection

# Stream manager configuration
try:
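As a quick orientation (a hedged sketch, not part of the diff): MEMORY_FREE_THRESHOLD is compared against the fraction of free GPU memory, so a value of 0.1 means "start evicting models once less than 10% of GPU memory is free", while the default 0.0 keeps the safety valve off. The free_fraction value below is made up purely for illustration.

# Standalone illustration of how the threshold is interpreted (hypothetical values).
import os

os.environ["MEMORY_FREE_THRESHOLD"] = "0.1"  # opt in: evict once <10% of GPU memory is free

threshold = float(os.getenv("MEMORY_FREE_THRESHOLD", "0.0"))
free_fraction = 0.07  # pretend 7% of GPU memory is currently free

# Mirrors the guard used in add_model(): a threshold of 0.0 (the default) disables the check.
under_pressure = bool(threshold) and free_fraction < threshold
print(under_pressure)  # True: 0.07 < 0.1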
26 changes: 24 additions & 2 deletions inference/core/managers/decorators/fixed_size_cache.py
@@ -4,7 +4,7 @@
from inference.core import logger
from inference.core.entities.requests.inference import InferenceRequest
from inference.core.entities.responses.inference import InferenceResponse
-from inference.core.env import DISK_CACHE_CLEANUP
+from inference.core.env import DISK_CACHE_CLEANUP, MEMORY_FREE_THRESHOLD
from inference.core.managers.base import Model, ModelManager
from inference.core.managers.decorators.base import ModelManagerDecorator
from inference.core.managers.entities import ModelDescription
@@ -43,7 +43,9 @@ def add_model(
            return None

        logger.debug(f"Current capacity of ModelManager: {len(self)}/{self.max_size}")
-        while len(self) >= self.max_size:
+        while len(self) >= self.max_size or (
+            MEMORY_FREE_THRESHOLD and self.memory_pressure_detected()
+        ):
            to_remove_model_id = self._key_queue.popleft()
            super().remove(
                to_remove_model_id, delete_from_disk=DISK_CACHE_CLEANUP
@@ -141,3 +143,23 @@ def _resolve_queue_id(
        self, model_id: str, model_id_alias: Optional[str] = None
    ) -> str:
        return model_id if model_id_alias is None else model_id_alias

    def memory_pressure_detected(self) -> bool:

Review comments on this line:

Collaborator: I would feature flag this and import torch locally here.

Collaborator: Not sure how this would work in a multi-GPU environment, but we can probably wait until someone actually uses it like that (it looks like we are probing the default device here).

Contributor (author): Fixed this.

Contributor (author): (The value of MEMORY_FREE_THRESHOLD is 0 by default, so memory pressure is not checked by default.)

Collaborator: ok

        return_boolean = False
        try:
            # Import torch lazily so environments without it skip the check entirely.
            import torch

            if torch.cuda.is_available():
                free_memory, total_memory = torch.cuda.mem_get_info()
                # Pressure means the free/total ratio has dropped below the configured threshold.
                return_boolean = (
                    float(free_memory / total_memory) < MEMORY_FREE_THRESHOLD
                )
                logger.debug(
                    f"Free memory: {free_memory}, Total memory: {total_memory}, threshold: {MEMORY_FREE_THRESHOLD}, return_boolean: {return_boolean}"
                )
            # TODO: Add memory calculation for other non-CUDA devices
        except Exception as e:
            logger.error(
                f"Failed to check CUDA memory pressure: {e}, returning {return_boolean}"
            )
        return return_boolean
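
Following up on the multi-GPU question in the review thread above: the merged check probes only the default CUDA device. Purely as a hedged sketch (not part of this PR), the same idea could be extended to every visible device by passing a device index to torch.cuda.mem_get_info; the function name memory_pressure_detected_any_device is made up, and MEMORY_FREE_THRESHOLD and logger are assumed to be in scope as in the module above.

# Hypothetical multi-device variant: treat pressure on any visible GPU as pressure overall.
def memory_pressure_detected_any_device(threshold: float = MEMORY_FREE_THRESHOLD) -> bool:
    try:
        import torch

        if not torch.cuda.is_available():
            return False
        for device_index in range(torch.cuda.device_count()):
            free_memory, total_memory = torch.cuda.mem_get_info(device_index)
            if float(free_memory / total_memory) < threshold:
                return True
    except Exception as e:
        logger.error(f"Failed to check CUDA memory pressure: {e}")
    return False

Whether "any device under pressure" or "all devices under pressure" is the right policy would depend on how models are placed across GPUs.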