diff --git a/crates/cubecl-runtime/src/id.rs b/crates/cubecl-runtime/src/id.rs index 03f9c2738..98056db09 100644 --- a/crates/cubecl-runtime/src/id.rs +++ b/crates/cubecl-runtime/src/id.rs @@ -2,6 +2,8 @@ use alloc::format; use alloc::string::String; use alloc::string::ToString; use alloc::sync::Arc; +use core::sync::atomic::AtomicU32; +use core::sync::atomic::Ordering; use core::{ any::{Any, TypeId}, fmt::Display, @@ -43,17 +45,33 @@ macro_rules! storage_id_type { } /// Reference to a buffer handle. -#[derive(Clone, Debug)] +#[derive(Debug)] pub struct HandleRef { - id: Arc, - all: Arc<()>, + state: Arc<(Id, AtomicU32)>, +} + +impl Clone for HandleRef { + fn clone(&self) -> Self { + self.state.1.fetch_add(1, Ordering::Relaxed); + Self { + state: self.state.clone(), + } + } +} + +impl Clone for BindingRef { + fn clone(&self) -> Self { + self.state.1.fetch_add(1, Ordering::Relaxed); + Self { + state: self.state.clone(), + } + } } /// Reference to buffer binding. -#[derive(Clone, Debug)] +#[derive(Debug)] pub struct BindingRef { - id: Id, - _all: Arc<()>, + state: Arc<(Id, AtomicU32)>, } impl BindingRef @@ -62,7 +80,7 @@ where { /// The id associated to the buffer. pub(crate) fn id(&self) -> &Id { - &self.id + &self.state.0 } } @@ -73,33 +91,31 @@ where /// Create a new handle. pub(crate) fn new(id: Id) -> Self { Self { - id: Arc::new(id), - all: Arc::new(()), + state: Arc::new((id, AtomicU32::new(1))), } } /// The id associated to the handle. pub(crate) fn id(&self) -> &Id { - &self.id + &self.state.0 } /// Get the binding. pub(crate) fn binding(self) -> BindingRef { BindingRef { - id: self.id.as_ref().clone(), - _all: self.all, + state: self.state.clone(), } } /// If the handle can be mut. pub(crate) fn can_mut(&self) -> bool { // 1 memory management reference with 1 tensor reference. - Arc::strong_count(&self.id) <= 2 + self.state.1.load(Ordering::Relaxed) <= 2 } /// If the resource is free. pub(crate) fn is_free(&self) -> bool { - Arc::strong_count(&self.all) <= 1 + self.state.1.load(Ordering::Relaxed) <= 1 } } diff --git a/crates/cubecl-runtime/src/server.rs b/crates/cubecl-runtime/src/server.rs index 2260b62c5..8d9b02e1d 100644 --- a/crates/cubecl-runtime/src/server.rs +++ b/crates/cubecl-runtime/src/server.rs @@ -1,8 +1,9 @@ use crate::{ + id::HandleRef, kernel::KernelMetadata, logging::ServerLogger, memory_management::{ - MemoryHandle, MemoryUsage, + MemoryHandle, MemoryUsage, SliceId, memory_pool::{SliceBinding, SliceHandle}, }, storage::{BindingResource, ComputeStorage}, @@ -12,8 +13,9 @@ use alloc::collections::BTreeMap; use alloc::sync::Arc; use alloc::vec::Vec; use core::fmt::Debug; -use cubecl_common::{ExecutionMode, benchmark::ProfileDuration, future::DynFut}; +use cubecl_common::{ExecutionMode, benchmark::ProfileDuration, future::DynFut, stub::Mutex}; use cubecl_ir::Elem; +use hashbrown::HashMap; /// The compute server is responsible for handling resources and computations over resources. /// @@ -123,6 +125,25 @@ pub struct Handle { pub offset_end: Option, /// Length of the underlying buffer ignoring offsets size: u64, + views: Views, +} + +#[derive(Debug, Default, Clone)] +pub struct Views { + parent: Option>>>, + // Check overlapping slice for the state. + children: Option>>>>, +} + +impl Drop for Views { + fn drop(&mut self) { + // If I'm the parent, I have nothing to do, but to drop the children list. + // + // If I'm the childen, I have to remove myself from the parent's list IF I'm the last clone + // if the same child. + if self.parent.is_some() { + } + } } impl Handle { @@ -338,6 +359,7 @@ impl Clone for Handle { offset_start: self.offset_start, offset_end: self.offset_end, size: self.size, + views: self.views.clone(), } } }