Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 30 additions & 14 deletions crates/cubecl-runtime/src/id.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ use alloc::format;
use alloc::string::String;
use alloc::string::ToString;
use alloc::sync::Arc;
use core::sync::atomic::AtomicU32;
use core::sync::atomic::Ordering;
use core::{
any::{Any, TypeId},
fmt::Display,
Expand Down Expand Up @@ -43,17 +45,33 @@ macro_rules! storage_id_type {
}

/// Reference to a buffer handle.
#[derive(Clone, Debug)]
#[derive(Debug)]
pub struct HandleRef<Id> {
id: Arc<Id>,
all: Arc<()>,
state: Arc<(Id, AtomicU32)>,
}

impl<Id> Clone for HandleRef<Id> {
fn clone(&self) -> Self {
self.state.1.fetch_add(1, Ordering::Relaxed);
Self {
state: self.state.clone(),
}
}
}

impl<Id> Clone for BindingRef<Id> {
fn clone(&self) -> Self {
self.state.1.fetch_add(1, Ordering::Relaxed);
Self {
state: self.state.clone(),
}
}
}

/// Reference to buffer binding.
#[derive(Clone, Debug)]
#[derive(Debug)]
pub struct BindingRef<Id> {
id: Id,
_all: Arc<()>,
state: Arc<(Id, AtomicU32)>,
}

impl<Id> BindingRef<Id>
Expand All @@ -62,7 +80,7 @@ where
{
/// The id associated to the buffer.
pub(crate) fn id(&self) -> &Id {
&self.id
&self.state.0
}
}

Expand All @@ -73,33 +91,31 @@ where
/// Create a new handle.
pub(crate) fn new(id: Id) -> Self {
Self {
id: Arc::new(id),
all: Arc::new(()),
state: Arc::new((id, AtomicU32::new(1))),
}
}

/// The id associated to the handle.
pub(crate) fn id(&self) -> &Id {
&self.id
&self.state.0
}

/// Get the binding.
pub(crate) fn binding(self) -> BindingRef<Id> {
BindingRef {
id: self.id.as_ref().clone(),
_all: self.all,
state: self.state.clone(),
}
}

/// If the handle can be mut.
pub(crate) fn can_mut(&self) -> bool {
// 1 memory management reference with 1 tensor reference.
Arc::strong_count(&self.id) <= 2
self.state.1.load(Ordering::Relaxed) <= 2
}

/// If the resource is free.
pub(crate) fn is_free(&self) -> bool {
Arc::strong_count(&self.all) <= 1
self.state.1.load(Ordering::Relaxed) <= 1
}
}

Expand Down
26 changes: 24 additions & 2 deletions crates/cubecl-runtime/src/server.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use crate::{
id::HandleRef,
kernel::KernelMetadata,
logging::ServerLogger,
memory_management::{
MemoryHandle, MemoryUsage,
MemoryHandle, MemoryUsage, SliceId,
memory_pool::{SliceBinding, SliceHandle},
},
storage::{BindingResource, ComputeStorage},
Expand All @@ -12,8 +13,9 @@ use alloc::collections::BTreeMap;
use alloc::sync::Arc;
use alloc::vec::Vec;
use core::fmt::Debug;
use cubecl_common::{ExecutionMode, benchmark::ProfileDuration, future::DynFut};
use cubecl_common::{ExecutionMode, benchmark::ProfileDuration, future::DynFut, stub::Mutex};
use cubecl_ir::Elem;
use hashbrown::HashMap;

/// The compute server is responsible for handling resources and computations over resources.
///
Expand Down Expand Up @@ -123,6 +125,25 @@ pub struct Handle {
pub offset_end: Option<u64>,
/// Length of the underlying buffer ignoring offsets
size: u64,
views: Views,
}

#[derive(Debug, Default, Clone)]
pub struct Views {
parent: Option<Arc<Mutex<HandleRef<SliceId>>>>,
// Check overlapping slice for the state.
children: Option<Arc<Mutex<HashMap<SliceId, HandleRef<SliceId>>>>>,
}

impl Drop for Views {
fn drop(&mut self) {
// If I'm the parent, I have nothing to do, but to drop the children list.
//
// If I'm the childen, I have to remove myself from the parent's list IF I'm the last clone
// if the same child.
if self.parent.is_some() {
}
}
}

impl Handle {
Expand Down Expand Up @@ -338,6 +359,7 @@ impl Clone for Handle {
offset_start: self.offset_start,
offset_end: self.offset_end,
size: self.size,
views: self.views.clone(),
}
}
}
Expand Down
Loading