Skip to content

Commit 11a77a5

Browse files
committed
Use mutex in set() Change wait(), get_wait()
Use a mutex to ensure only one caller of `set` at a time. Change `wait()` to call `get_wait()` and panic if none Add `get_wait()` which waits until the value is set and checks the state again and returns None if the state is still not Set.
1 parent 76a1dd8 commit 11a77a5

File tree

1 file changed

+74
-25
lines changed

1 file changed

+74
-25
lines changed

tokio/src/sync/set_once.rs

Lines changed: 74 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
use super::{Notify, SetError};
2-
use crate::loom::cell::UnsafeCell;
2+
use crate::{loom::cell::UnsafeCell, pin};
33
use std::fmt;
44
use std::mem::MaybeUninit;
55
use std::ops::Drop;
66
use std::ptr;
77
use std::sync::atomic::{AtomicBool, Ordering};
8+
use std::sync::Mutex;
89

910
// This file contains an implementation of an SetOnce. The value of SetOnce
1011
// can only be modified once during initialization.
@@ -73,6 +74,9 @@ pub struct SetOnce<T> {
7374
value_set: AtomicBool,
7475
value: UnsafeCell<MaybeUninit<T>>,
7576
notify: Notify,
77+
// we lock the mutex inside set to ensure
78+
// only one caller of set can run at a time
79+
lock: Mutex<()>,
7680
}
7781

7882
impl<T> Default for SetOnce<T> {
@@ -105,13 +109,11 @@ impl<T: Eq> Eq for SetOnce<T> {}
105109

106110
impl<T> Drop for SetOnce<T> {
107111
fn drop(&mut self) {
108-
if self.initialized() {
112+
if *self.value_set.get_mut() {
109113
// SAFETY: We're inside the drop implementation of SetOnce
110114
// AND we're also initalized. This is the best way to ensure
111115
// out data gets dropped
112-
unsafe {
113-
let _ = self.value.with_mut(|ptr| ptr::read(ptr).assume_init());
114-
}
116+
unsafe { self.value.with_mut(|ptr| ptr::drop_in_place(ptr as *mut T)) }
115117
// no need to set the flag to false as this set once is being
116118
// dropped
117119
}
@@ -124,6 +126,7 @@ impl<T> From<T> for SetOnce<T> {
124126
value_set: AtomicBool::new(true),
125127
value: UnsafeCell::new(MaybeUninit::new(value)),
126128
notify: Notify::new(),
129+
lock: Mutex::new(()),
127130
}
128131
}
129132
}
@@ -135,6 +138,7 @@ impl<T> SetOnce<T> {
135138
value_set: AtomicBool::new(false),
136139
value: UnsafeCell::new(MaybeUninit::uninit()),
137140
notify: Notify::new(),
141+
lock: Mutex::new(()),
138142
}
139143
}
140144

@@ -177,6 +181,7 @@ impl<T> SetOnce<T> {
177181
value_set: AtomicBool::new(false),
178182
value: UnsafeCell::new(MaybeUninit::uninit()),
179183
notify: Notify::const_new(),
184+
lock: Mutex::new(()),
180185
}
181186
}
182187

@@ -227,6 +232,7 @@ impl<T> SetOnce<T> {
227232
value_set: AtomicBool::new(true),
228233
value: UnsafeCell::new(MaybeUninit::new(value)),
229234
notify: Notify::const_new(),
235+
lock: Mutex::new(()),
230236
}
231237
}
232238

@@ -257,6 +263,9 @@ impl<T> SetOnce<T> {
257263
// called only when the value_set AtomicBool is flipped from FALSE to TRUE
258264
// meaning that the value is being set from uinitialized to initialized via
259265
// this function
266+
//
267+
// The caller also has to ensure writes on `value` are syncronized with a
268+
// external lock to prevent mutliple set_value calls at the same time.
260269
unsafe fn set_value(&self, value: T) {
261270
unsafe {
262271
self.value.with_mut(|ptr| (*ptr).as_mut_ptr().write(value));
@@ -282,19 +291,30 @@ impl<T> SetOnce<T> {
282291
return Err(SetError::AlreadyInitializedError(value));
283292
}
284293

285-
// Using release ordering so any threads that read a true from this
286-
// atomic is able to read the value we just stored.
287-
if !self.value_set.swap(true, Ordering::Release) {
288-
// SAFETY: We are swapping the value_set AtomicBool from FALSE to
289-
// TRUE with it being previously false and followed by that we are
290-
// initializing the unsafe Cell field with the value
291-
unsafe {
292-
self.set_value(value);
293-
}
294+
// SAFETY: lock the mutex to ensure only one caller of set
295+
// can run at a time.
296+
match self.lock.lock() {
297+
Ok(_) => {
298+
// Using release ordering so any threads that read a true from this
299+
// atomic is able to read the value we just stored.
300+
if !self.value_set.swap(true, Ordering::Release) {
301+
// SAFETY: We are swapping the value_set AtomicBool from FALSE to
302+
// TRUE with it being previously false and followed by that we are
303+
// initializing the unsafe Cell field with the value
304+
unsafe {
305+
self.set_value(value);
306+
}
294307

295-
Ok(())
296-
} else {
297-
Err(SetError::InitializingError(value))
308+
Ok(())
309+
} else {
310+
Err(SetError::AlreadyInitializedError(value))
311+
}
312+
}
313+
Err(_) => {
314+
// If we failed to lock the mutex, it means some other task is
315+
// trying to set the value, so we return an error.
316+
Err(SetError::InitializingError(value))
317+
}
298318
}
299319
}
300320

@@ -317,16 +337,45 @@ impl<T> SetOnce<T> {
317337
}
318338
}
319339

320-
/// Waits until the `SetOnce` has been initialized. Once the `SetOnce` is
321-
/// initialized the wakers are notified and the Future returned from this
322-
/// function completes.
340+
/// Waits until set is called. The future returned will keep blocking until
341+
/// the `SetOnce` is initialized.
342+
///
343+
/// If the `SetOnce` is already initialized, it will return the value
344+
// immediately.
345+
///
346+
/// # Panics
347+
///
348+
/// If the `SetOnce` is not initialized after waiting, it will panic. To
349+
/// avoid this, use `get_wait()` which returns an `Option<&T>` instead of
350+
/// `&T`.
351+
pub async fn wait(&self) -> &T {
352+
match self.get_wait().await {
353+
Some(val) => val,
354+
_ => panic!("SetOnce::wait called but the SetOnce is not initialized"),
355+
}
356+
}
357+
358+
/// Waits until set is called.
323359
///
324-
/// If this function is called after the `SetOnce` is initialized then
325-
/// empty future is returned which completes immediately.
326-
pub async fn wait(&self) {
327-
if !self.initialized() {
328-
let _ = self.notify.notified().await;
360+
/// If the state failed to initalize it will return `None`.
361+
pub async fn get_wait(&self) -> Option<&T> {
362+
let notify_fut = self.notify.notified();
363+
pin!(notify_fut);
364+
365+
if self.value_set.load(Ordering::Acquire) {
366+
// SAFETY: the state is initialized
367+
return Some(unsafe { self.get_unchecked() });
368+
}
369+
// wait until the value is set
370+
(&mut notify_fut).await;
371+
372+
// look at the state again
373+
if self.value_set.load(Ordering::Acquire) {
374+
// SAFETY: the state is initialized
375+
return Some(unsafe { self.get_unchecked() });
329376
}
377+
378+
None
330379
}
331380
}
332381

0 commit comments

Comments
 (0)