From 3bc9d3a095e52c6004cbb64e92c189f9339590a7 Mon Sep 17 00:00:00 2001 From: Lukas Wirth Date: Thu, 6 Mar 2025 15:09:34 +0100 Subject: [PATCH] Do not alias fields of tracked_struct Values when updating --- src/input.rs | 8 +- src/interned.rs | 9 +- src/table.rs | 16 +- src/tracked_struct.rs | 276 ++++++++++++++-------------- src/tracked_struct/tracked_field.rs | 5 +- 5 files changed, 158 insertions(+), 156 deletions(-) diff --git a/src/input.rs b/src/input.rs index d555c2dd5..6bd00bf9a 100644 --- a/src/input.rs +++ b/src/input.rs @@ -281,8 +281,12 @@ where C: Configuration, { #[inline(always)] - unsafe fn memos(&self, _current_revision: Revision) -> &crate::table::memo::MemoTable { - &self.memos + unsafe fn memos( + this: *const Self, + _current_revision: Revision, + ) -> *const crate::table::memo::MemoTable { + // SAFETY: Caller obligation demands this pointer to be valid. + unsafe { &raw const (*this).memos } } #[inline(always)] diff --git a/src/interned.rs b/src/interned.rs index 98525f59c..1f9a63326 100644 --- a/src/interned.rs +++ b/src/interned.rs @@ -451,12 +451,15 @@ where C: Configuration, { #[inline(always)] - unsafe fn memos(&self, _current_revision: Revision) -> &MemoTable { - &self.memos + unsafe fn memos( + this: *const Self, + _current_revision: Revision, + ) -> *const crate::table::memo::MemoTable { + unsafe { &raw const (*this).memos } } #[inline(always)] - fn memos_mut(&mut self) -> &mut MemoTable { + fn memos_mut(&mut self) -> &mut crate::table::memo::MemoTable { &mut self.memos } } diff --git a/src/table.rs b/src/table.rs index d8e72ea22..e66078a60 100644 --- a/src/table.rs +++ b/src/table.rs @@ -36,26 +36,26 @@ pub(crate) trait Slot: Any + Send + Sync { /// # Safety condition /// /// The current revision MUST be the current revision of the database containing this slot. - unsafe fn memos(&self, current_revision: Revision) -> &MemoTable; + unsafe fn memos(slot: *const Self, current_revision: Revision) -> *const MemoTable; /// Mutably access the [`MemoTable`] for this slot. fn memos_mut(&mut self) -> &mut MemoTable; } /// [Slot::memos] -type SlotMemosFnRaw = unsafe fn(*const (), current_revision: Revision) -> *const MemoTable; +type SlotMemosFnErased = unsafe fn(*const (), current_revision: Revision) -> *const MemoTable; /// [Slot::memos] -type SlotMemosFn = unsafe fn(&T, current_revision: Revision) -> &MemoTable; +type SlotMemosFn = unsafe fn(*const T, current_revision: Revision) -> *const MemoTable; /// [Slot::memos_mut] -type SlotMemosMutFnRaw = unsafe fn(*mut ()) -> *mut MemoTable; +type SlotMemosMutFnErased = unsafe fn(*mut ()) -> *mut MemoTable; /// [Slot::memos_mut] type SlotMemosMutFn = fn(&mut T) -> &mut MemoTable; struct SlotVTable { layout: Layout, /// [`Slot`] methods - memos: SlotMemosFnRaw, - memos_mut: SlotMemosMutFnRaw, + memos: SlotMemosFnErased, + memos_mut: SlotMemosMutFnErased, /// A drop impl to call when the own page drops /// SAFETY: The caller is required to supply a correct data pointer to a `Box>` and initialized length, /// and correct memo types. @@ -78,10 +78,10 @@ impl SlotVTable { }, layout: Layout::new::(), // SAFETY: The signatures are compatible - memos: unsafe { mem::transmute::, SlotMemosFnRaw>(T::memos) }, + memos: unsafe { mem::transmute::, SlotMemosFnErased>(T::memos) }, // SAFETY: The signatures are compatible memos_mut: unsafe { - mem::transmute::, SlotMemosMutFnRaw>(T::memos_mut) + mem::transmute::, SlotMemosMutFnErased>(T::memos_mut) }, } } diff --git a/src/tracked_struct.rs b/src/tracked_struct.rs index 44d0a6062..5e5e64f80 100644 --- a/src/tracked_struct.rs +++ b/src/tracked_struct.rs @@ -157,7 +157,7 @@ where ingredient_index: IngredientIndex, /// Phantom data: we fetch `Value` out from `Table` - phantom: PhantomData Value>, + phantom: PhantomData ValueWithMetadata>, /// Store freed ids free_list: SegQueue, @@ -262,19 +262,12 @@ impl IdentityMap { } // ANCHOR: ValueStruct -#[derive(Debug)] -pub struct Value +pub struct ValueWithMetadata where C: Configuration, { - /// The minimum durability of all inputs consumed by the creator - /// query prior to creating this tracked struct. If any of those - /// inputs changes, then the creator query may create this struct - /// with different values. - durability: Durability, - /// The revision when this tracked struct was last updated. - /// This field also acts as a kind of "lock". Once it is equal + /// This field also acts as a kind of "lock" over the `value` field. Once it is equal /// to `Some(current_revision)`, the fields are locked and /// cannot change further. This makes it safe to give out `&`-references /// so long as they do not live longer than the current revision @@ -294,14 +287,26 @@ where /// leaked a reference across threads somehow. updated_at: OptionalAtomicRevision, - /// Fields of this tracked struct. They can change across revisions, - /// but they do not change within a particular revision. - fields: C::Fields<'static>, + /// The durability minimum durability of all inputs consumed + /// by the creator query prior to creating this tracked struct. + /// If any of those inputs changes, then the creator query may + /// create this struct with different values. + durability: Durability, /// The revision information for each field: when did this field last change. /// When tracked structs are re-created, this revision may be updated to the /// current revision if the value is different. revisions: C::Revisions, + value: Value, +} + +pub struct Value +where + C: Configuration, +{ + /// Fields of this tracked struct. They can change across revisions, + /// but they do not change within a particular revision. + fields: C::Fields<'static>, /// Memo table storing the results of query functions etc. /*unsafe */ @@ -386,7 +391,6 @@ where disambiguator, }; - let current_revision = zalsa.current_revision(); if let Some(id) = zalsa_local.tracked_struct_id(&identity) { // The struct already exists in the intern map. let index = self.database_key_index(id); @@ -394,8 +398,7 @@ where zalsa_local.add_output(index); // SAFETY: The `id` was present in the interned map, so the value must be initialized. - let update_result = - unsafe { self.update(zalsa, current_revision, id, ¤t_deps, fields) }; + let update_result = unsafe { self.update(zalsa, id, ¤t_deps, fields) }; fields = match update_result { // Overwrite the previous ID if we are reusing the old slot with new fields. @@ -414,7 +417,7 @@ where // We failed to perform the update, or this is a new tracked struct, so allocate a new entry // in the struct map. - let id = self.allocate(zalsa, zalsa_local, current_revision, ¤t_deps, fields); + let id = self.allocate(zalsa, zalsa_local, ¤t_deps, fields); let key = self.database_key_index(id); tracing::trace!("Allocated new tracked struct {key:?}"); zalsa_local.add_output(key); @@ -426,17 +429,19 @@ where &'db self, zalsa: &'db Zalsa, zalsa_local: &'db ZalsaLocal, - current_revision: Revision, current_deps: &StampedValue<()>, fields: C::Fields<'db>, ) -> Id { - let value = |_| Value { + let current_revision = zalsa.current_revision(); + let value = |_| ValueWithMetadata { updated_at: OptionalAtomicRevision::new(Some(current_revision)), durability: current_deps.durability, - // lifetime erase for storage - fields: unsafe { mem::transmute::, C::Fields<'static>>(fields) }, revisions: C::new_revisions(current_deps.changed_at), - memos: Default::default(), + value: Value { + // SAFETY: We just erase the lifetime + fields: unsafe { mem::transmute::, C::Fields<'static>>(fields) }, + memos: Default::default(), + }, }; while let Some(id) = self.free_list.pop() { @@ -454,22 +459,19 @@ where continue; }; - // SAFETY: We just removed `id` from the free-list, so we have exclusive access. - let data = unsafe { &mut *Self::data_raw(zalsa.table(), id) }; - - assert!( - data.updated_at.load().is_none(), + // SAFETY: `data_raw` is a live-unaliased pointer + let data_raw = unsafe { &mut *Self::data_raw(zalsa.table(), id) }; + debug_assert!( + data_raw.updated_at.load().is_none(), "free list entry for `{id:?}` does not have `None` for `updated_at`" ); // Overwrite the free-list entry. Use `*foo = ` because the entry // has been previously initialized and we want to free the old contents. - *data = value(id); - + *data_raw = value(id); return id; } - - zalsa_local.allocate::>(zalsa, self.ingredient_index, value) + zalsa_local.allocate::>(zalsa, self.ingredient_index, value) } /// Get mutable access to the data for `id` -- this holds a write lock for the duration @@ -486,7 +488,6 @@ where unsafe fn update<'db>( &'db self, zalsa: &'db Zalsa, - current_revision: Revision, mut id: Id, current_deps: &StampedValue<()>, fields: C::Fields<'db>, @@ -507,8 +508,9 @@ where // In that case we should not modify or touch it because there may be // `&`-references to its contents floating around. // - // Observing `Some(current_revision)` can happen in two scenarios: leaks (tsk tsk) - // but also the scenario embodied by the test test `test_run_5_then_20` in `specify_tracked_fn_in_rev_1_but_not_2.rs`: + // Observing `Some(current_revision)` can happen in two scenarios: + // - leaks, see tests\preverify-struct-with-leaked-data.rs and tests\preverify-struct-with-leaked-data-2.rs + // - or the following scenario (FIXME verify this? There is no test that covers this behavior): // // * Revision 1: // * Tracked function F creates tracked struct S @@ -528,17 +530,15 @@ where // that is still live. { - // SAFETY: Guaranteed by caller. - let data = unsafe { &*data_raw }; - - let last_updated_at = data.updated_at.load(); + // SAFETY: `updated_at` is never exclusively borrowed, so borrowing it is sound + let last_updated_at = unsafe { (*data_raw).updated_at.load() }; assert!( last_updated_at.is_some(), "two concurrent writers to {id:?}, should not be possible" ); // The value is already read-locked, but we can reuse it safely as per above. - if last_updated_at == Some(current_revision) { + if last_updated_at == Some(zalsa.current_revision()) { return Ok(id); } @@ -557,34 +557,30 @@ where // Acquire the write-lock. This can only fail if there is a parallel thread // reading from this same `id`, which can only happen if the user has leaked it. // Tsk tsk. - let swapped_out = data.updated_at.swap(None); - if swapped_out != last_updated_at { + // SAFETY: `updated_at` is never exclusively borrowed, so borrowing it is sound + let swapped = unsafe { (*data_raw).updated_at.swap(None) }; + if last_updated_at != swapped { panic!( - "failed to acquire write lock, id `{id:?}` must have been leaked across threads" - ); + "failed to acquire write lock, id `{id:?}` \ + must have been leaked across threads" + ); } } - // UNSAFE: Marking as mut requires exclusive access for the duration of - // the `mut`. We have now *claimed* this data by swapping in `None`, - // any attempt to read concurrently will panic. - let data = unsafe { &mut *data_raw }; + // SAFETY: We have now *claimed* mutable access to the `value` field by swapping in `None`, + // any attempt to read concurrently will panic so it is safe to take exclusive references. + let old_fields = unsafe { &raw mut (*data_raw).value.fields }.cast::>(); + + // SAFETY: FIXME: why can this not alias with tracked_field::maybe_changed_after? + let revisions = unsafe { &mut (*data_raw).revisions }; // SAFETY: We assert that the pointer to `data.revisions` // is a pointer into the database referencing a value // from a previous revision. As such, it continues to meet // its validity invariant and any owned content also continues // to meet its safety invariant. - let untracked_update = unsafe { - C::update_fields( - current_deps.changed_at, - &mut data.revisions, - mem::transmute::<*mut C::Fields<'static>, *mut C::Fields<'db>>( - std::ptr::addr_of_mut!(data.fields), - ), - fields, - ) - }; + let untracked_update = + unsafe { C::update_fields(current_deps.changed_at, revisions, old_fields, fields) }; if untracked_update { // Consider this a new tracked-struct when any non-tracked field got updated. @@ -593,7 +589,7 @@ where // Note that we hold the lock and have exclusive access to the tracked struct data, // so there should be no live instances of IDs from the previous generation. We clear // the memos and return a new ID here as if we have allocated a new slot. - let mut table = data.take_memo_table(); + let mut table = unsafe { mem::take(&mut (*data_raw).value.memos) }; // SAFETY: The memo table belongs to a value that we allocated, so it has the // correct type. @@ -604,24 +600,40 @@ where .expect("already verified that generation is not maximum"); } - if current_deps.durability < data.durability { - data.revisions = C::new_revisions(current_deps.changed_at); + let durability = unsafe { &mut (*data_raw).durability }; + if current_deps.durability < *durability { + *revisions = C::new_revisions(current_deps.changed_at); } - data.durability = current_deps.durability; - let swapped_out = data.updated_at.swap(Some(current_revision)); - assert!(swapped_out.is_none()); + *durability = current_deps.durability; + // SAFETY: `updated_at` is never exclusively borrowed, so borrowing it is sound + // release the lock + let swapped_out = unsafe { (*data_raw).updated_at.swap(Some(zalsa.current_revision())) }; + assert!( + swapped_out.is_none(), + "two concurrent writers to {id:?}, should not be possible" + ); Ok(id) } /// Fetch the data for a given id created by this ingredient from the table, /// -giving it the appropriate type. - fn data(table: &Table, id: Id) -> &Value { - table.get(id) + fn data_raw(table: &Table, id: Id) -> *mut ValueWithMetadata { + table.get_raw(id) } - fn data_raw(table: &Table, id: Id) -> *mut Value { - table.get_raw(id) + /// # Safety + /// + /// `data` must be a valid pointer to a `ValueWithMetadata` + unsafe fn fields( + &self, + data: *const ValueWithMetadata, + current_revision: Revision, + ) -> &C::Fields<'_> { + // SAFETY: `data` is a valid pointer + acquire_read_lock(unsafe { &(*data).updated_at }, current_revision); + // SAFETY: We have acquired a read lock, so `values` is not aliased + unsafe { mem::transmute::<&C::Fields<'static>, &C::Fields<'_>>(&(*data).value.fields) } } /// Deletes the given entities. This is used after a query `Q` executes and we can compare @@ -641,32 +653,23 @@ where }) }); - let current_revision = zalsa.current_revision(); - let data_raw = Self::data_raw(zalsa.table(), id); - - { - let data = unsafe { &*data_raw }; + let data = Self::data_raw(zalsa.table(), id); - // We want to set `updated_at` to `None`, signalling that other field values - // cannot be read. The current value should be `Some(R0)` for some older revision. - match data.updated_at.load() { - None => { - panic!("cannot delete write-locked id `{id:?}`; value leaked across threads"); - } - Some(r) if r == current_revision => panic!( - "cannot delete read-locked id `{id:?}`; value leaked across threads or user functions not deterministic" - ), - Some(r) => { - if data.updated_at.compare_exchange(Some(r), None).is_err() { - panic!("race occurred when deleting value `{id:?}`") - } - } + // We want to set `updated_at` to `None`, signalling that other field values + // cannot be read. The current value should be `Some(R0)` for some older revision. + match unsafe { (*data).updated_at.swap(None) }{ + None => { + panic!("cannot delete write-locked id `{id:?}`; value leaked across threads"); } + Some(r) if r == zalsa.current_revision() => panic!( + "cannot delete read-locked id `{id:?}`; value leaked across threads or user functions not deterministic" + ), + Some(_) => () } - // SAFETY: We have acquired the write lock - let data = unsafe { &mut *data_raw }; - let mut memo_table = data.take_memo_table(); + // Take the memo table. This is safe because we have modified `data_ref.updated_at` to `None` + // signalling that we have acquired the write lock + let mut memo_table = unsafe { mem::take(&mut (*data).value.memos) }; // SAFETY: The memo table belongs to a value that we allocated, so it // has the correct type. @@ -726,8 +729,10 @@ where s: C::Struct<'db>, ) -> &'db C::Fields<'db> { let id = AsId::as_id(&s); - let data = Self::data(db.zalsa().table(), id); - data.fields() + let zalsa = db.zalsa(); + let data = Self::data_raw(zalsa.table(), id); + // SAFETY: `data` is a valid pointer acquired from the table. + unsafe { self.fields(data, zalsa.current_revision()) } } /// Access to this tracked field. @@ -743,19 +748,18 @@ where let (zalsa, zalsa_local) = db.zalsas(); let id = AsId::as_id(&s); let field_ingredient_index = self.ingredient_index.successor(relative_tracked_index); - let data = Self::data(zalsa.table(), id); + let data = Self::data_raw(zalsa.table(), id); - data.read_lock(zalsa.current_revision()); - - let field_changed_at = data.revisions[relative_tracked_index]; + let field_changed_at = unsafe { (&(*data).revisions)[relative_tracked_index] }; zalsa_local.report_tracked_read_simple( DatabaseKeyIndex::new(field_ingredient_index, id), - data.durability, + unsafe { (*data).durability }, field_changed_at, ); - data.fields() + // SAFETY: `data` is a valid pointer acquired from the table. + unsafe { self.fields(data, zalsa.current_revision()) } } /// Access to this untracked field. @@ -769,15 +773,13 @@ where ) -> &'db C::Fields<'db> { let zalsa = db.zalsa(); let id = AsId::as_id(&s); - let data = Self::data(zalsa.table(), id); - - data.read_lock(zalsa.current_revision()); + let data = Self::data_raw(zalsa.table(), id); // Note that we do not need to add a dependency on the tracked struct // as IDs that are reused increment their generation, invalidating any // dependent queries directly. - - data.fields() + // SAFETY: `data` is a valid pointer acquired from the table. + unsafe { self.fields(data, zalsa.current_revision()) } } #[cfg(feature = "salsa_unstable")] @@ -785,8 +787,8 @@ where pub fn entries<'db>( &'db self, db: &'db dyn crate::Database, - ) -> impl Iterator> { - db.zalsa().table().slots_of::>() + ) -> impl Iterator> { + db.zalsa().table().slots_of::>() } } @@ -857,7 +859,7 @@ where } } -impl Value +impl ValueWithMetadata where C: Configuration, { @@ -866,60 +868,54 @@ where /// They can change across revisions, but they do not change within /// a particular revision. #[cfg_attr(not(feature = "salsa_unstable"), doc(hidden))] - pub fn fields(&self) -> &C::Fields<'_> { - // SAFETY: We are shrinking the lifetime from storage back to the db lifetime. - unsafe { mem::transmute::<&C::Fields<'static>, &C::Fields<'_>>(&self.fields) } - } - - fn take_memo_table(&mut self) -> MemoTable { - // This fn is only called after `updated_at` has been set to `None`; - // this ensures that there is no concurrent access - // (and that the `&mut self` is accurate...). - assert!(self.updated_at.load().is_none()); - - mem::take(&mut self.memos) + pub fn fields(&self) -> &C::Fields<'static> { + &self.value.fields } +} - fn read_lock(&self, current_revision: Revision) { - loop { - match self.updated_at.load() { - None => { - panic!("access to field whilst the value is being initialized"); - } - Some(r) => { - if r == current_revision { - return; - } - - if self - .updated_at - .compare_exchange(Some(r), Some(current_revision)) - .is_ok() - { - break; - } +#[inline] +fn acquire_read_lock(updated_at: &OptionalAtomicRevision, current_revision: Revision) { + loop { + match updated_at.load() { + None => panic!( + "write lock taken; value leaked across threads or user functions not deterministic" + ), + // the read lock was taken by someone else, so we also succeed + Some(r) if r == current_revision => return, + Some(r) => { + if updated_at + .compare_exchange(Some(r), Some(current_revision)) + .is_ok() + { + break; } } } } } -impl Slot for Value +impl Slot for ValueWithMetadata where C: Configuration, { #[inline(always)] - unsafe fn memos(&self, current_revision: Revision) -> &crate::table::memo::MemoTable { + unsafe fn memos( + this: *const Self, + current_revision: Revision, + ) -> *const crate::table::memo::MemoTable { // Acquiring the read lock here with the current revision // ensures that there is no danger of a race // when deleting a tracked struct. - self.read_lock(current_revision); - &self.memos + // SAFETY: `this` is a valid pointer given the caller obligation + unsafe { acquire_read_lock(&(*this).updated_at, current_revision) }; + // SAFETY: `this` is a valid pointer given the caller obligation and we have acquired a read + // lock, so `values` is not aliased + unsafe { &raw const (*this).value.memos } } #[inline(always)] fn memos_mut(&mut self) -> &mut crate::table::memo::MemoTable { - &mut self.memos + &mut self.value.memos } } diff --git a/src/tracked_struct/tracked_field.rs b/src/tracked_struct/tracked_field.rs index 5ec38c680..c1bab0c07 100644 --- a/src/tracked_struct/tracked_field.rs +++ b/src/tracked_struct/tracked_field.rs @@ -62,9 +62,8 @@ where revision: crate::Revision, _cycle_heads: &mut CycleHeads, ) -> VerifyResult { - let zalsa = db.zalsa(); - let data = >::data(zalsa.table(), input); - let field_changed_at = data.revisions[self.field_index]; + let data = >::data_raw(db.zalsa().table(), input); + let field_changed_at = unsafe { (&(*data).revisions)[self.field_index] }; VerifyResult::changed_if(field_changed_at > revision) }