-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Metal: bound temporary buffer cache and prevent runaway memory usage on large softmax/broadcast/matmul workloads #3197
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
5d344a8
5dc33cc
cfe73fd
e63ea2e
b413f06
5d5f41d
e949fd4
42dccad
0435470
6d71931
5449778
57204f5
5c41d75
f3b39b9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -8,8 +8,12 @@ use candle_metal_kernels::{ | |
| }; | ||
| use objc2_foundation::NSURL; | ||
| use objc2_metal::{MTLCaptureDescriptor, MTLCaptureDestination, MTLCaptureManager}; | ||
| use std::ffi::CStr; | ||
| use std::path::Path; | ||
| use std::sync::{Arc, Mutex, RwLock}; | ||
| use std::sync::{ | ||
| atomic::{AtomicUsize, Ordering}, | ||
| Arc, Mutex, RwLock, | ||
| }; | ||
|
|
||
| use super::MetalError; | ||
|
|
||
|
|
@@ -26,6 +30,100 @@ impl DeviceId { | |
| } | ||
| } | ||
|
|
||
/// Tunable limits that keep the Metal temporary-buffer cache bounded.
#[derive(Clone)]
pub(crate) struct AllocationPolicy {
    /// Bytes that may be freshly allocated between two GPU synchronization
    /// points; crossing this threshold forces a sync so temporaries can be
    /// reclaimed.
    pending_allocation_bytes_limit: usize,
    /// Upper bound, in bytes, on buffers kept around for reuse.
    cache_limit_bytes: usize,
}
|
|
||
| impl Default for AllocationPolicy { | ||
| fn default() -> Self { | ||
| const DEFAULT_PENDING: usize = 4 * 1024 * 1024 * 1024; // 4 GiB | ||
| const MIN_PENDING: usize = 512 * 1024 * 1024; // 512 MiB | ||
| const MAX_PENDING: usize = 12 * 1024 * 1024 * 1024; // 12 GiB | ||
| const MIN_CACHE_LIMIT: usize = 64 * 1024 * 1024; // 64 MiB | ||
| const HW_MEMSIZE_KEY: &CStr = c"hw.memsize"; | ||
| const IOGPU_WIRED_LIMIT_MB_KEY: &CStr = c"iogpu.wired_limit_mb"; | ||
|
|
||
| fn parse_env_mebibytes(var: &str) -> Option<usize> { | ||
| std::env::var(var) | ||
| .ok() | ||
| .and_then(|value| value.trim().parse::<usize>().ok()) | ||
| .and_then(|mb| mb.checked_mul(1024 * 1024)) | ||
| } | ||
| fn sysctl_u64(name: &CStr) -> Option<u64> { | ||
| use libc::c_void; | ||
| unsafe { | ||
| let mut value: u64 = 0; | ||
| let mut len = core::mem::size_of::<u64>(); | ||
| if libc::sysctlbyname( | ||
| name.as_ptr(), | ||
| &mut value as *mut u64 as *mut c_void, | ||
| &mut len as *mut usize, | ||
| std::ptr::null_mut(), | ||
| 0, | ||
| ) != 0 | ||
| { | ||
| return None; | ||
| } | ||
| if len == 0 { | ||
| None | ||
| } else { | ||
| Some(value) | ||
| } | ||
| } | ||
| } | ||
|
|
||
| fn system_memory_bytes() -> Option<usize> { | ||
| const MEBIBYTE: usize = 1024 * 1024; | ||
| const SYSTEM_RESERVE_FRACTION: usize = 4; // Keep at least 25% for the OS. | ||
| const SYSTEM_RESERVE_MIN: usize = 2 * 1024 * 1024 * 1024; // 2 GiB floor. | ||
|
|
||
| let hw_total = sysctl_u64(HW_MEMSIZE_KEY).and_then(|bytes| { | ||
| if bytes == 0 { | ||
| None | ||
| } else { | ||
| Some(bytes as usize) | ||
| } | ||
| })?; | ||
|
|
||
| let reserve = std::cmp::max(hw_total / SYSTEM_RESERVE_FRACTION, SYSTEM_RESERVE_MIN); | ||
| let hw_budget = hw_total.saturating_sub(reserve); | ||
| if hw_budget == 0 { | ||
| return None; | ||
| } | ||
|
|
||
| let wired_limit_bytes = sysctl_u64(IOGPU_WIRED_LIMIT_MB_KEY).and_then(|limit_mb| { | ||
| if limit_mb == 0 { | ||
| return None; | ||
| } | ||
| (limit_mb as usize).checked_mul(MEBIBYTE) | ||
| }); | ||
|
|
||
| if let Some(wired) = wired_limit_bytes { | ||
| Some(std::cmp::min(wired, hw_budget)) | ||
| } else { | ||
| Some(hw_budget) | ||
| } | ||
| } | ||
|
|
||
| let pending_limit = parse_env_mebibytes("CANDLE_METAL_PENDING_LIMIT_MB") | ||
| .or_else(|| system_memory_bytes().map(|mem| (mem / 3).clamp(MIN_PENDING, MAX_PENDING))) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just verifying that I understand the intention here.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, we pick pending_limit from |
||
| .unwrap_or(DEFAULT_PENDING); | ||
|
|
||
| let cache_limit = parse_env_mebibytes("CANDLE_METAL_CACHE_LIMIT_MB") | ||
| .unwrap_or_else(|| std::cmp::max(pending_limit / 2, MIN_CACHE_LIMIT)); | ||
|
|
||
| crate::metal_backend::device::AllocationPolicy { | ||
| pending_allocation_bytes_limit: pending_limit, | ||
| cache_limit_bytes: cache_limit, | ||
| } | ||
| } | ||
| } | ||
|
|
||
| #[derive(Clone)] | ||
| pub struct MetalDevice { | ||
| /// Unique identifier, the registryID is not sufficient as it identifies the GPU rather than | ||
|
|
@@ -57,6 +155,12 @@ pub struct MetalDevice { | |
| pub(crate) kernels: Arc<Kernels>, | ||
| /// Seed for random number generation. | ||
| pub(crate) seed: Arc<Mutex<Buffer>>, | ||
| /// Bytes newly allocated since the last GPU synchronization point. This is | ||
| /// compared against `allocation_policy.pending_allocation_bytes_limit` to | ||
| /// decide when to force a sync and reclaim temporaries. | ||
| pub(crate) pending_allocation_bytes: Arc<AtomicUsize>, | ||
ivarflakstad marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| /// Allocation thresholds and cache budget. | ||
| pub(crate) allocation_policy: AllocationPolicy, | ||
| } | ||
|
|
||
| // Resource options used for creating buffers. Shared storage mode allows both CPU and GPU to access the buffer. | ||
|
|
@@ -112,14 +216,39 @@ impl MetalDevice { | |
| } | ||
|
|
||
| fn drop_unused_buffers(&self) -> Result<()> { | ||
| self.trim_buffer_cache_to(self.allocation_policy.cache_limit_bytes) | ||
| } | ||
|
|
||
| fn trim_buffer_cache_to(&self, limit: usize) -> Result<()> { | ||
| let mut buffers = self.buffers.write().map_err(MetalError::from)?; | ||
| for subbuffers in buffers.values_mut() { | ||
| let newbuffers = subbuffers | ||
| .iter() | ||
| .filter(|s| Arc::strong_count(*s) > 1) | ||
| .map(Arc::clone) | ||
| .collect(); | ||
| *subbuffers = newbuffers; | ||
| let mut cached_bytes = 0usize; | ||
| for (size, subbuffers) in buffers.iter() { | ||
| for buffer in subbuffers.iter() { | ||
| if Arc::strong_count(buffer) == 1 { | ||
| cached_bytes += *size; | ||
| } | ||
| } | ||
| } | ||
| if cached_bytes <= limit { | ||
| return Ok(()); | ||
| } | ||
|
|
||
| let mut bytes_to_drop = cached_bytes - limit; | ||
| for (size, subbuffers) in buffers.iter_mut() { | ||
| if bytes_to_drop == 0 { | ||
| break; | ||
| } | ||
| subbuffers.retain(|buffer| { | ||
| if bytes_to_drop == 0 { | ||
| return true; | ||
| } | ||
| if Arc::strong_count(buffer) == 1 { | ||
ivarflakstad marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| bytes_to_drop = bytes_to_drop.saturating_sub(*size); | ||
| false | ||
| } else { | ||
| true | ||
| } | ||
| }); | ||
| } | ||
| Ok(()) | ||
| } | ||
|
|
@@ -211,6 +340,8 @@ impl MetalDevice { | |
| .map_err(MetalError::from)?; | ||
| let new_buffer = Arc::new(new_buffer); | ||
| subbuffers.push(new_buffer.clone()); | ||
| drop(buffers); | ||
| self.on_new_allocation(size)?; | ||
| Ok(new_buffer) | ||
| } | ||
|
|
||
|
|
@@ -235,6 +366,22 @@ impl MetalDevice { | |
| .map_err(|e| MetalError::from(e.to_string()))?; | ||
| Ok(()) | ||
| } | ||
|
|
||
| fn on_new_allocation(&self, size: usize) -> Result<()> { | ||
| let pending = self | ||
| .pending_allocation_bytes | ||
| .fetch_add(size, Ordering::AcqRel) | ||
| .saturating_add(size); | ||
| if pending >= self.allocation_policy.pending_allocation_bytes_limit { | ||
| // Ensure the GPU processed the backlog so buffers can be reused. | ||
| self.wait_until_completed()?; | ||
| self.pending_allocation_bytes.store(0, Ordering::Release); | ||
| // Drop part of the cache to keep the resident set under control. | ||
| let target = self.allocation_policy.cache_limit_bytes / 2; | ||
| self.trim_buffer_cache_to(target)?; | ||
| } | ||
| Ok(()) | ||
| } | ||
| } | ||
|
|
||
| fn buf_size(size: usize) -> usize { | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.