diff --git a/candle-core/src/metal_backend/device.rs b/candle-core/src/metal_backend/device.rs
index f5f78bb271..0a59c34ff4 100644
--- a/candle-core/src/metal_backend/device.rs
+++ b/candle-core/src/metal_backend/device.rs
@@ -8,8 +8,12 @@ use candle_metal_kernels::{
 };
 use objc2_foundation::NSURL;
 use objc2_metal::{MTLCaptureDescriptor, MTLCaptureDestination, MTLCaptureManager};
+use std::ffi::CStr;
 use std::path::Path;
-use std::sync::{Arc, Mutex, RwLock};
+use std::sync::{
+    atomic::{AtomicUsize, Ordering},
+    Arc, Mutex, RwLock,
+};
 
 use super::MetalError;
 
@@ -26,6 +30,96 @@ impl DeviceId {
     }
 }
 
+#[derive(Clone)]
+pub(crate) struct AllocationPolicy {
+    /// Maximum number of bytes we allow to be newly allocated since the last
+    /// synchronization point before forcing a sync to reclaim temporaries.
+    pending_allocation_bytes_limit: usize,
+    /// Maximum bytes to keep cached for reuse.
+    cache_limit_bytes: usize,
+}
+
+impl Default for AllocationPolicy {
+    fn default() -> Self {
+        const DEFAULT_PENDING: usize = 4 * 1024 * 1024 * 1024; // 4 GiB
+        const MIN_PENDING: usize = 512 * 1024 * 1024; // 512 MiB
+        const MAX_PENDING: usize = 12 * 1024 * 1024 * 1024; // 12 GiB
+        const MIN_CACHE_LIMIT: usize = 64 * 1024 * 1024; // 64 MiB
+        const HW_MEMSIZE_KEY: &CStr = c"hw.memsize";
+        const IOGPU_WIRED_LIMIT_MB_KEY: &CStr = c"iogpu.wired_limit_mb";
+
+        fn parse_env_mebibytes(var: &str) -> Option<usize> {
+            std::env::var(var)
+                .ok()
+                .and_then(|value| value.trim().parse::<usize>().ok())
+                .and_then(|mb| mb.checked_mul(1024 * 1024))
+        }
+
+        fn sysctl_usize(name: &CStr) -> Option<usize> {
+            use libc::c_void;
+            unsafe {
+                let mut value: u64 = 0;
+                let mut len = core::mem::size_of::<u64>();
+                if libc::sysctlbyname(
+                    name.as_ptr(),
+                    &mut value as *mut u64 as *mut c_void,
+                    &mut len as *mut usize,
+                    std::ptr::null_mut(),
+                    0,
+                ) != 0
+                {
+                    return None;
+                }
+                if len == 0 {
+                    None
+                } else {
+                    Some(value as usize)
+                }
+            }
+        }
+
+        fn system_memory_bytes() -> Option<usize> {
+            const MEBIBYTE: usize = 1024 * 1024;
+            const SYSTEM_RESERVE_FRACTION: usize = 4; // Keep at least 25% for the OS.
+            const SYSTEM_RESERVE_MIN: usize = 2 * 1024 * 1024 * 1024; // 2 GiB floor.
+
+            let hw_total = sysctl_usize(HW_MEMSIZE_KEY)?;
+
+            let reserve = std::cmp::max(hw_total / SYSTEM_RESERVE_FRACTION, SYSTEM_RESERVE_MIN);
+            let hw_budget = hw_total.saturating_sub(reserve);
+            if hw_budget == 0 {
+                return None;
+            }
+
+            let wired_limit_bytes = sysctl_usize(IOGPU_WIRED_LIMIT_MB_KEY).and_then(|limit_mb| {
+                if limit_mb == 0 {
+                    None
+                } else {
+                    limit_mb.checked_mul(MEBIBYTE)
+                }
+            });
+
+            if let Some(wired) = wired_limit_bytes {
+                Some(std::cmp::min(wired, hw_budget))
+            } else {
+                Some(hw_budget)
+            }
+        }
+
+        let pending_limit = parse_env_mebibytes("CANDLE_METAL_PENDING_LIMIT_MB")
+            .or_else(|| system_memory_bytes().map(|mem| (mem / 3).clamp(MIN_PENDING, MAX_PENDING)))
+            .unwrap_or(DEFAULT_PENDING);
+
+        let cache_limit = parse_env_mebibytes("CANDLE_METAL_CACHE_LIMIT_MB")
+            .unwrap_or_else(|| std::cmp::max(pending_limit / 2, MIN_CACHE_LIMIT));
+
+        AllocationPolicy {
+            pending_allocation_bytes_limit: pending_limit,
+            cache_limit_bytes: cache_limit,
+        }
+    }
+}
+
 #[derive(Clone)]
 pub struct MetalDevice {
     /// Unique identifier, the registryID is not sufficient as it identifies the GPU rather than
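To make the default derivation above concrete, here is a worked example (an editorial annotation, not part of the patch) assuming `hw.memsize` reports 32 GiB, `iogpu.wired_limit_mb` is unset, and neither environment override is present:

```rust
const GIB: usize = 1024 * 1024 * 1024;

fn main() {
    let hw_total = 32 * GIB;
    // Reserve max(25% of RAM, 2 GiB) for the OS: 8 GiB here.
    let reserve = std::cmp::max(hw_total / 4, 2 * GIB);
    let hw_budget = hw_total - reserve; // 24 GiB left as the GPU budget.
    // Pending limit: a third of the budget, clamped to [512 MiB, 12 GiB].
    let pending_limit = (hw_budget / 3).clamp(512 * 1024 * 1024, 12 * GIB);
    // Cache limit: half the pending limit, floored at 64 MiB.
    let cache_limit = std::cmp::max(pending_limit / 2, 64 * 1024 * 1024);
    assert_eq!(pending_limit, 8 * GIB);
    assert_eq!(cache_limit, 4 * GIB);
}
```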
@@ -57,6 +151,12 @@ pub struct MetalDevice {
     pub(crate) kernels: Arc<Kernels>,
     /// Seed for random number generation.
     pub(crate) seed: Arc<Mutex<Buffer>>,
+    /// Bytes newly allocated since the last GPU synchronization point. This is
+    /// compared against `allocation_policy.pending_allocation_bytes_limit` to
+    /// decide when to force a sync and reclaim temporaries.
+    pub(crate) pending_allocation_bytes: Arc<AtomicUsize>,
+    /// Allocation thresholds and cache budget.
+    pub(crate) allocation_policy: AllocationPolicy,
 }
 
 // Resource options used for creating buffers. Shared storage mode allows both CPU and GPU to access the buffer.
@@ -112,14 +212,39 @@ impl MetalDevice {
     }
 
     fn drop_unused_buffers(&self) -> Result<()> {
+        self.trim_buffer_cache_to(self.allocation_policy.cache_limit_bytes)
+    }
+
+    fn trim_buffer_cache_to(&self, limit: usize) -> Result<()> {
         let mut buffers = self.buffers.write().map_err(MetalError::from)?;
-        for subbuffers in buffers.values_mut() {
-            let newbuffers = subbuffers
-                .iter()
-                .filter(|s| Arc::strong_count(*s) > 1)
-                .map(Arc::clone)
-                .collect();
-            *subbuffers = newbuffers;
+        let mut cached_bytes = 0usize;
+        for (size, subbuffers) in buffers.iter() {
+            for buffer in subbuffers.iter() {
+                if Arc::strong_count(buffer) == 1 {
+                    cached_bytes += *size;
+                }
+            }
+        }
+        if cached_bytes <= limit {
+            return Ok(());
+        }
+
+        let mut bytes_to_drop = cached_bytes - limit;
+        for (size, subbuffers) in buffers.iter_mut() {
+            if bytes_to_drop == 0 {
+                break;
+            }
+            subbuffers.retain(|buffer| {
+                if bytes_to_drop == 0 {
+                    return true;
+                }
+                if Arc::strong_count(buffer) == 1 {
+                    bytes_to_drop = bytes_to_drop.saturating_sub(*size);
+                    false
+                } else {
+                    true
+                }
+            });
         }
         Ok(())
     }
@@ -211,6 +336,8 @@ impl MetalDevice {
             .map_err(MetalError::from)?;
         let new_buffer = Arc::new(new_buffer);
         subbuffers.push(new_buffer.clone());
+        drop(buffers);
+        self.on_new_allocation(size)?;
         Ok(new_buffer)
     }
 
@@ -235,6 +362,22 @@ impl MetalDevice {
             .map_err(|e| MetalError::from(e.to_string()))?;
         Ok(())
     }
+
+    fn on_new_allocation(&self, size: usize) -> Result<()> {
+        let pending = self
+            .pending_allocation_bytes
+            .fetch_add(size, Ordering::AcqRel)
+            .saturating_add(size);
+        if pending >= self.allocation_policy.pending_allocation_bytes_limit {
+            // Ensure the GPU processed the backlog so buffers can be reused.
+            self.wait_until_completed()?;
+            self.pending_allocation_bytes.store(0, Ordering::Release);
+            // Drop part of the cache to keep the resident set under control.
+            let target = self.allocation_policy.cache_limit_bytes / 2;
+            self.trim_buffer_cache_to(target)?;
+        }
+        Ok(())
+    }
 }
 
 fn buf_size(size: usize) -> usize {
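The trimming pass above is the subtle part of the patch: it walks the size-keyed cache twice, first summing the bytes held by unused buffers (`Arc::strong_count == 1`, i.e. only the cache holds them), then dropping unused entries until the total fits the limit. Below is a standalone model of that accounting (an editorial annotation, not candle code): `Arc<Vec<u8>>` stands in for the Metal buffer type, and the map is keyed by buffer size, mirroring `self.buffers`.

```rust
use std::collections::HashMap;
use std::sync::Arc;

fn trim_to(cache: &mut HashMap<usize, Vec<Arc<Vec<u8>>>>, limit: usize) {
    // First pass: total bytes held by unused cached buffers.
    let cached: usize = cache
        .iter()
        .map(|(size, subs)| size * subs.iter().filter(|b| Arc::strong_count(b) == 1).count())
        .sum();
    if cached <= limit {
        return;
    }
    // Second pass: drop unused buffers until we are back under the limit.
    let mut to_drop = cached - limit;
    for (size, subs) in cache.iter_mut() {
        if to_drop == 0 {
            break;
        }
        subs.retain(|b| {
            if to_drop == 0 || Arc::strong_count(b) > 1 {
                return true; // Keep in-use buffers, and everything once under budget.
            }
            to_drop = to_drop.saturating_sub(*size);
            false // Drop this unused buffer.
        });
    }
}

fn main() {
    let mut cache: HashMap<usize, Vec<Arc<Vec<u8>>>> = HashMap::new();
    // Four distinct, unused 1 KiB buffers: 4 KiB cached in total.
    cache.insert(1024, (0..4).map(|_| Arc::new(vec![0u8; 1024])).collect());
    trim_to(&mut cache, 2048);
    let remaining: usize = cache.values().map(Vec::len).sum();
    assert_eq!(remaining, 2); // Trimmed back down to the 2 KiB budget.
}
```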
diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs
index e7a3324a3a..3dec8bf121 100644
--- a/candle-core/src/metal_backend/mod.rs
+++ b/candle-core/src/metal_backend/mod.rs
@@ -2,6 +2,7 @@
 //!
 use crate::backend::{BackendDevice, BackendStorage};
 use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose1D, ParamsConvTranspose2D};
+use crate::metal_backend::device::AllocationPolicy;
 use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
 use crate::{CpuStorage, CpuStorageRef, DType, Layout, Result, Shape};
 use candle_metal_kernels::{
@@ -11,8 +12,7 @@ use candle_metal_kernels::{
 use objc2_foundation::NSRange;
 use std::collections::HashMap;
 use std::ffi::c_void;
-use std::sync::{Arc, Mutex, PoisonError, RwLock, TryLockError};
-
+use std::sync::{atomic::AtomicUsize, Arc, Mutex, PoisonError, RwLock, TryLockError};
 mod device;
 pub use device::{DeviceId, MetalDevice};
 
@@ -2099,6 +2099,8 @@ impl BackendDevice for MetalDevice {
             buffers: Arc::new(RwLock::new(HashMap::new())),
             kernels,
             seed,
+            pending_allocation_bytes: Arc::new(AtomicUsize::new(0)),
+            allocation_policy: AllocationPolicy::default(),
         })
     }
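As a usage note: both thresholds can be overridden through the environment. The variable names come straight from `AllocationPolicy::default`; the values below are only an illustration for a memory-constrained machine, and exporting the variables in the shell before launching works just as well. A minimal sketch:

```rust
fn main() {
    // AllocationPolicy::default() reads these (in MiB) when a MetalDevice is
    // created, so they must be set before the device exists.
    std::env::set_var("CANDLE_METAL_PENDING_LIMIT_MB", "2048"); // Sync after 2 GiB.
    std::env::set_var("CANDLE_METAL_CACHE_LIMIT_MB", "1024"); // Cache at most 1 GiB.
    // ... construct the Metal device and run the workload afterwards.
}
```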