Metal: bound temporary buffer cache and prevent runaway memory usage on large softmax/broadcast/matmul workloads #3197
```diff
@@ -9,7 +9,10 @@ use candle_metal_kernels::{
 use objc2_foundation::NSURL;
 use objc2_metal::{MTLCaptureDescriptor, MTLCaptureDestination, MTLCaptureManager};
 use std::path::Path;
-use std::sync::{Arc, Mutex, RwLock};
+use std::sync::{
+    atomic::{AtomicUsize, Ordering},
+    Arc, Mutex, RwLock,
+};

 use super::MetalError;
```
```diff
@@ -26,6 +29,61 @@ impl DeviceId {
     }
 }

+#[derive(Clone)]
+pub(crate) struct AllocationPolicy {
+    /// Total bytes we can allocate before forcing a sync to reclaim temporaries.
+    pending_limit_bytes: usize,
+    /// Maximum bytes to keep cached for reuse.
+    cache_limit_bytes: usize,
+}
+
+impl Default for AllocationPolicy {
+    fn default() -> Self {
+        const DEFAULT_PENDING: usize = 4 * 1024 * 1024 * 1024; // 4 GiB
+        const MIN_PENDING: usize = 512 * 1024 * 1024; // 512 MiB
+        const MAX_PENDING: usize = 12 * 1024 * 1024 * 1024; // 12 GiB
+
+        fn parse_env_mebibytes(var: &str) -> Option<usize> {
+            std::env::var(var)
+                .ok()
+                .and_then(|value| value.trim().parse::<usize>().ok())
+                .map(|mb| mb * 1024 * 1024)
+        }
+
+        fn system_memory_bytes() -> Option<usize> {
+            use libc::c_void;
+            let mut value: u64 = 0;
+            let mut len = core::mem::size_of::<u64>();
+            let ret = unsafe {
+                libc::sysctlbyname(
+                    b"hw.memsize\0".as_ptr() as *const libc::c_char,
+                    &mut value as *mut u64 as *mut c_void,
+                    &mut len as *mut usize,
+                    std::ptr::null_mut(),
+                    0,
+                )
+            };
+            if ret == 0 {
+                Some(value as usize)
+            } else {
+                None
+            }
+        }
+
+        let pending_limit = parse_env_mebibytes("CANDLE_METAL_PENDING_LIMIT_MB")
+            .or_else(|| system_memory_bytes().map(|mem| (mem / 3).clamp(MIN_PENDING, MAX_PENDING)))
+            .unwrap_or(DEFAULT_PENDING);
+
+        let cache_limit = parse_env_mebibytes("CANDLE_METAL_CACHE_LIMIT_MB")
+            .unwrap_or_else(|| std::cmp::max(pending_limit / 2, 64 * 1024 * 1024));
+
+        crate::metal_backend::device::AllocationPolicy {
+            pending_limit_bytes: pending_limit,
+            cache_limit_bytes: cache_limit,
+        }
+    }
+}
+
 #[derive(Clone)]
 pub struct MetalDevice {
     /// Unique identifier, the registryID is not sufficient as it identifies the GPU rather than
```

Review thread on the `pending_limit` fallback (the `.or_else(...)` line above):

Member: Just verifying that I understand the intention here.

Author (Contributor): Yes, we pick `pending_limit` from `CANDLE_METAL_PENDING_LIMIT_MB` when it is set; otherwise we take a third of system memory (via the `hw.memsize` sysctl) clamped to [512 MiB, 12 GiB], and fall back to 4 GiB if the sysctl fails.
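To make the fallback chain concrete, here is a standalone sketch (not part of the PR; the 36 GiB machine and the `main` wrapper are illustrative, and the `hw.memsize` sysctl is replaced by a constant) showing how the two limits end up being derived:

```rust
// Sketch of AllocationPolicy::default's limit derivation, assuming a 36 GiB
// machine and no environment overrides.
fn parse_env_mebibytes(var: &str) -> Option<usize> {
    std::env::var(var)
        .ok()
        .and_then(|value| value.trim().parse::<usize>().ok())
        .map(|mb| mb * 1024 * 1024)
}

fn main() {
    const DEFAULT_PENDING: usize = 4 * 1024 * 1024 * 1024; // 4 GiB
    const MIN_PENDING: usize = 512 * 1024 * 1024; // 512 MiB
    const MAX_PENDING: usize = 12 * 1024 * 1024 * 1024; // 12 GiB

    // Stand-in for the hw.memsize sysctl used in the PR.
    let system_memory_bytes: Option<usize> = Some(36 * 1024 * 1024 * 1024);

    let pending_limit = parse_env_mebibytes("CANDLE_METAL_PENDING_LIMIT_MB")
        .or_else(|| system_memory_bytes.map(|mem| (mem / 3).clamp(MIN_PENDING, MAX_PENDING)))
        .unwrap_or(DEFAULT_PENDING);

    let cache_limit = parse_env_mebibytes("CANDLE_METAL_CACHE_LIMIT_MB")
        .unwrap_or_else(|| std::cmp::max(pending_limit / 2, 64 * 1024 * 1024));

    // With 36 GiB of RAM, mem / 3 = 12 GiB, which is exactly the clamp ceiling,
    // so pending_limit = 12 GiB and cache_limit defaults to 6 GiB.
    println!("pending_limit = {} MiB", pending_limit / (1024 * 1024));
    println!("cache_limit   = {} MiB", cache_limit / (1024 * 1024));
}
```

Setting either environment variable (in MiB) overrides the corresponding computed value.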
```diff
@@ -57,6 +115,10 @@ pub struct MetalDevice {
     pub(crate) kernels: Arc<Kernels>,
     /// Seed for random number generation.
     pub(crate) seed: Arc<Mutex<Buffer>>,
+    /// Bytes allocated since the last synchronization point.
+    pub(crate) pending_allocation_bytes: Arc<AtomicUsize>,
+    /// Allocation thresholds and cache budget.
+    pub(crate) allocation_policy: AllocationPolicy,
 }

 // Resource options used for creating buffers. Shared storage mode allows both CPU and GPU to access the buffer.
```
```diff
@@ -112,14 +174,46 @@ impl MetalDevice {
     }

     fn drop_unused_buffers(&self) -> Result<()> {
+        self.trim_buffer_cache_to(self.allocation_policy.cache_limit_bytes)
+    }
+
+    fn trim_buffer_cache_to(&self, limit: usize) -> Result<()> {
         let mut buffers = self.buffers.write().map_err(MetalError::from)?;
-        for subbuffers in buffers.values_mut() {
-            let newbuffers = subbuffers
-                .iter()
-                .filter(|s| Arc::strong_count(*s) > 1)
-                .map(Arc::clone)
-                .collect();
-            *subbuffers = newbuffers;
-        }
+        let mut cached_bytes = 0usize;
+        for (size, subbuffers) in buffers.iter() {
+            for buffer in subbuffers.iter() {
+                if Arc::strong_count(buffer) == 1 {
+                    cached_bytes += *size;
+                }
+            }
+        }
+        if cached_bytes <= limit {
+            return Ok(());
+        }
+
+        let mut bytes_to_drop = cached_bytes - limit;
+        let mut empty_keys = Vec::new();
+        for (size, subbuffers) in buffers.iter_mut() {
+            if bytes_to_drop == 0 {
+                break;
+            }
+            subbuffers.retain(|buffer| {
+                if bytes_to_drop == 0 {
+                    return true;
+                }
+                if Arc::strong_count(buffer) == 1 {
+                    bytes_to_drop = bytes_to_drop.saturating_sub(*size);
+                    false
+                } else {
+                    true
+                }
+            });
+            if subbuffers.is_empty() {
+                empty_keys.push(*size);
+            }
+        }
+        for key in empty_keys {
+            buffers.remove(&key);
+        }
         Ok(())
     }
```
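For readers who want to experiment with the trimming policy outside the Metal backend, here is a self-contained model of the same idea on a plain map keyed by buffer size (the `BTreeMap`, `Vec<u8>` payloads, and sizes are made up for illustration; the real code operates on `Arc<Buffer>` under the device's `RwLock`, and iterates in whatever order its map yields). Only subbuffers whose `Arc::strong_count` is 1, i.e. held by nothing but the cache, are counted and eligible to be dropped:

```rust
use std::collections::BTreeMap;
use std::sync::Arc;

// Simplified model of trim_buffer_cache_to: drop idle cached buffers
// (strong_count == 1) until at most `limit` bytes remain cached.
// Buffers still referenced elsewhere are never touched.
fn trim_cache_to(buffers: &mut BTreeMap<usize, Vec<Arc<Vec<u8>>>>, limit: usize) {
    let mut cached_bytes = 0usize;
    for (size, subs) in buffers.iter() {
        for b in subs.iter() {
            if Arc::strong_count(b) == 1 {
                cached_bytes += *size;
            }
        }
    }
    if cached_bytes <= limit {
        return;
    }
    let mut bytes_to_drop = cached_bytes - limit;
    let mut empty_keys = Vec::new();
    for (size, subs) in buffers.iter_mut() {
        if bytes_to_drop == 0 {
            break;
        }
        subs.retain(|b| {
            if bytes_to_drop == 0 || Arc::strong_count(b) > 1 {
                true
            } else {
                bytes_to_drop = bytes_to_drop.saturating_sub(*size);
                false
            }
        });
        if subs.is_empty() {
            empty_keys.push(*size);
        }
    }
    for key in empty_keys {
        buffers.remove(&key);
    }
}

fn main() {
    let mut cache: BTreeMap<usize, Vec<Arc<Vec<u8>>>> = BTreeMap::new();
    // Eight idle 1 KiB buffers: only the cache holds them.
    cache.insert(1024, (0..8).map(|_| Arc::new(vec![0u8; 1024])).collect());
    // One 4 KiB buffer that is still referenced by a caller.
    let in_use = Arc::new(vec![0u8; 4096]);
    cache.insert(4096, vec![in_use.clone()]);

    trim_cache_to(&mut cache, 2048);

    assert_eq!(cache[&1024].len(), 2); // 6 KiB of idle buffers dropped
    assert_eq!(Arc::strong_count(&in_use), 2); // in-use buffer survives in the cache
}
```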
```diff
@@ -211,6 +305,8 @@ impl MetalDevice {
             .map_err(MetalError::from)?;
         let new_buffer = Arc::new(new_buffer);
         subbuffers.push(new_buffer.clone());
+        drop(buffers);
+        self.on_new_allocation(size)?;
         Ok(new_buffer)
     }

```
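Note the explicit `drop(buffers)` before the accounting call: `on_new_allocation` can end up in `trim_buffer_cache_to`, which takes the same `buffers` write lock, so holding the guard across the call would risk self-deadlock on the non-reentrant `RwLock`. A minimal illustration of the pattern, with simplified, hypothetical types standing in for the device's buffer map:

```rust
use std::sync::RwLock;

// Hypothetical stand-in for the buffer map guarded by the device's RwLock.
fn allocate(buffers: &RwLock<Vec<usize>>, size: usize) {
    let mut guard = buffers.write().unwrap();
    guard.push(size);
    // Release the write guard before calling anything that locks `buffers` again.
    drop(guard);
    maybe_trim(buffers);
}

// Re-acquires the same lock, like trim_buffer_cache_to does in the PR.
fn maybe_trim(buffers: &RwLock<Vec<usize>>) {
    let mut guard = buffers.write().unwrap();
    guard.clear();
}

fn main() {
    let buffers = RwLock::new(Vec::new());
    allocate(&buffers, 1024);
    assert!(buffers.read().unwrap().is_empty());
}
```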
```diff
@@ -235,6 +331,22 @@ impl MetalDevice {
             .map_err(|e| MetalError::from(e.to_string()))?;
         Ok(())
     }
+
+    fn on_new_allocation(&self, size: usize) -> Result<()> {
+        let pending = self
+            .pending_allocation_bytes
+            .fetch_add(size, Ordering::AcqRel)
+            .saturating_add(size);
+        if pending >= self.allocation_policy.pending_limit_bytes {
+            // Ensure the GPU processed the backlog so buffers can be reused.
+            self.wait_until_completed()?;
+            self.pending_allocation_bytes.store(0, Ordering::Release);
+            // Drop part of the cache to keep the resident set under control.
+            let target = self.allocation_policy.cache_limit_bytes / 2;
+            self.trim_buffer_cache_to(target)?;
+        }
+        Ok(())
+    }
 }

 fn buf_size(size: usize) -> usize {
```
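The accounting itself is a saturating byte counter on an `AtomicUsize`: every allocation adds its size, and once the running total crosses `pending_limit_bytes` the device synchronizes, trims the cache, and resets the counter. A sketch of that pattern with the GPU sync replaced by a print (the `Tracker` type and the 1 KiB limit are illustrative, not part of the PR):

```rust
use std::sync::atomic::{AtomicUsize, Ordering};

// Illustrative tracker mirroring on_new_allocation's bookkeeping.
struct Tracker {
    pending: AtomicUsize,
    limit: usize,
}

impl Tracker {
    fn on_new_allocation(&self, size: usize) {
        // fetch_add returns the *previous* value, so add `size` to get the new total.
        let pending = self
            .pending
            .fetch_add(size, Ordering::AcqRel)
            .saturating_add(size);
        if pending >= self.limit {
            // In the PR this is wait_until_completed() followed by trimming the
            // buffer cache down to half of cache_limit_bytes.
            println!("sync point after {pending} pending bytes");
            self.pending.store(0, Ordering::Release);
        }
    }
}

fn main() {
    let tracker = Tracker {
        pending: AtomicUsize::new(0),
        limit: 1024, // illustrative; the real default is derived from system memory
    };
    for _ in 0..8 {
        tracker.on_new_allocation(300); // crosses the threshold roughly every 4 calls
    }
}
```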