diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index fd0c397..ef995d0 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -78,6 +78,18 @@ jobs: run: rustup update stable - run: cargo clippy --all-features --all-targets + loom: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Install Rust + run: rustup update stable + - name: Loom tests + run: cargo test --release --test loom --features loom + env: + RUSTFLAGS: "--cfg=loom" + LOOM_MAX_PREEMPTIONS: 4 + fmt: runs-on: ubuntu-latest steps: diff --git a/Cargo.toml b/Cargo.toml index 88a24be..0a5ad1f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,10 +18,25 @@ exclude = ["/.*"] event-listener = { version = "5.0.0", default-features = false } event-listener-strategy = { version = "0.5.0", default-features = false } pin-project-lite = "0.2.11" +portable-atomic-util = { version = "0.1.4", default-features = false, optional = true, features = ["alloc"] } + +[dependencies.portable_atomic_crate] +package = "portable-atomic" +version = "1.2.0" +default-features = false +optional = true + +[target.'cfg(loom)'.dependencies] +loom = { version = "0.7", optional = true } [features] default = ["std"] +portable-atomic = ["portable-atomic-util", "portable_atomic_crate"] std = ["event-listener/std", "event-listener-strategy/std"] +loom = ["event-listener/loom", "dep:loom"] + +[lints.rust] +unexpected_cfgs = { level = "warn", check-cfg = ['cfg(loom)'] } [dev-dependencies] async-channel = "2.2.0" diff --git a/src/barrier.rs b/src/barrier.rs index 8d06fc1..ff6eba8 100644 --- a/src/barrier.rs +++ b/src/barrier.rs @@ -23,28 +23,31 @@ struct State { } impl Barrier { - /// Creates a barrier that can block the given number of tasks. - /// - /// A barrier will block `n`-1 tasks which call [`wait()`] and then wake up all tasks - /// at once when the `n`th task calls [`wait()`]. - /// - /// [`wait()`]: `Barrier::wait()` - /// - /// # Examples - /// - /// ``` - /// use async_lock::Barrier; - /// - /// let barrier = Barrier::new(5); - /// ``` - pub const fn new(n: usize) -> Barrier { - Barrier { - n, - state: Mutex::new(State { - count: 0, - generation_id: 0, - }), - event: Event::new(), + const_fn! { + const_if: #[cfg(not(loom))]; + /// Creates a barrier that can block the given number of tasks. + /// + /// A barrier will block `n`-1 tasks which call [`wait()`] and then wake up all tasks + /// at once when the `n`th task calls [`wait()`]. + /// + /// [`wait()`]: `Barrier::wait()` + /// + /// # Examples + /// + /// ``` + /// use async_lock::Barrier; + /// + /// let barrier = Barrier::new(5); + /// ``` + pub const fn new(n: usize) -> Barrier { + Barrier { + n, + state: Mutex::new(State { + count: 0, + generation_id: 0, + }), + event: Event::new(), + } } } diff --git a/src/lib.rs b/src/lib.rs index 6d64aa3..8a62869 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -71,6 +71,22 @@ macro_rules! pin { } } +/// Make the given function const if the given condition is true. +macro_rules! const_fn { + ( + const_if: #[cfg($($cfg:tt)+)]; + $(#[$($attr:tt)*])* + $vis:vis const fn $($rest:tt)* + ) => { + #[cfg($($cfg)+)] + $(#[$($attr)*])* + $vis const fn $($rest)* + #[cfg(not($($cfg)+))] + $(#[$($attr)*])* + $vis fn $($rest)* + }; +} + mod barrier; mod mutex; mod once_cell; @@ -97,6 +113,38 @@ pub mod futures { pub use crate::semaphore::{Acquire, AcquireArc}; } +#[cfg(not(loom))] +/// Synchronization primitive implementation. +mod sync { + pub(super) use core::sync::atomic; + + pub(super) trait WithMut { + type Output; + + fn with_mut(&mut self, f: F) -> R + where + F: FnOnce(&mut Self::Output) -> R; + } + + impl WithMut for atomic::AtomicUsize { + type Output = usize; + + #[inline] + fn with_mut(&mut self, f: F) -> R + where + F: FnOnce(&mut Self::Output) -> R, + { + f(self.get_mut()) + } + } +} + +#[cfg(loom)] +/// Synchronization primitive implementation. +mod sync { + pub(super) use loom::sync::atomic; +} + #[cold] fn abort() -> ! { // For no_std targets, panicking while panicking is defined as an abort diff --git a/src/mutex.rs b/src/mutex.rs index 4dc6bac..85c43c5 100644 --- a/src/mutex.rs +++ b/src/mutex.rs @@ -4,12 +4,14 @@ use core::fmt; use core::marker::{PhantomData, PhantomPinned}; use core::ops::{Deref, DerefMut}; use core::pin::Pin; -use core::sync::atomic::{AtomicUsize, Ordering}; use core::task::Poll; use core::usize; use alloc::sync::Arc; +// We don't use loom::UnsafeCell as that doesn't work with the Mutex API. +use crate::sync::atomic::{AtomicUsize, Ordering}; + #[cfg(all(feature = "std", not(target_family = "wasm")))] use std::time::{Duration, Instant}; @@ -56,20 +58,23 @@ unsafe impl Send for Mutex {} unsafe impl Sync for Mutex {} impl Mutex { - /// Creates a new async mutex. - /// - /// # Examples - /// - /// ``` - /// use async_lock::Mutex; - /// - /// let mutex = Mutex::new(0); - /// ``` - pub const fn new(data: T) -> Mutex { - Mutex { - state: AtomicUsize::new(0), - lock_ops: Event::new(), - data: UnsafeCell::new(data), + const_fn! { + const_if: #[cfg(not(loom))]; + /// Creates a new async mutex. + /// + /// # Examples + /// + /// ``` + /// use async_lock::Mutex; + /// + /// let mutex = Mutex::new(0); + /// ``` + pub const fn new(data: T) -> Mutex { + Mutex { + state: AtomicUsize::new(0), + lock_ops: Event::new(), + data: UnsafeCell::new(data), + } } } @@ -186,7 +191,7 @@ impl Mutex { /// # }) /// ``` pub fn get_mut(&mut self) -> &mut T { - unsafe { &mut *self.data.get() } + self.data.get_mut() } /// Unlocks the mutex directly. @@ -756,12 +761,3 @@ impl DerefMut for MutexGuardArc { unsafe { &mut *self.0.data.get() } } } - -/// Calls a function when dropped. -struct CallOnDrop(F); - -impl Drop for CallOnDrop { - fn drop(&mut self) { - (self.0)(); - } -} diff --git a/src/once_cell.rs b/src/once_cell.rs index 8d9485d..6554292 100644 --- a/src/once_cell.rs +++ b/src/once_cell.rs @@ -4,7 +4,11 @@ use core::fmt; use core::future::Future; use core::mem::{forget, MaybeUninit}; use core::ptr; -use core::sync::atomic::{AtomicUsize, Ordering}; + +use crate::sync::atomic::{AtomicUsize, Ordering}; + +#[cfg(not(loom))] +use crate::sync::WithMut; #[cfg(all(feature = "std", not(target_family = "wasm")))] use core::task::{Context, Poll, RawWaker, RawWakerVTable, Waker}; @@ -107,22 +111,25 @@ unsafe impl Send for OnceCell {} unsafe impl Sync for OnceCell {} impl OnceCell { - /// Create a new, uninitialized `OnceCell`. - /// - /// # Example - /// - /// ```rust - /// use async_lock::OnceCell; - /// - /// let cell = OnceCell::new(); - /// # cell.set_blocking(1); - /// ``` - pub const fn new() -> Self { - Self { - active_initializers: Event::new(), - passive_waiters: Event::new(), - state: AtomicUsize::new(State::Uninitialized as _), - value: UnsafeCell::new(MaybeUninit::uninit()), + const_fn! { + const_if: #[cfg(not(loom))]; + /// Create a new, uninitialized `OnceCell`. + /// + /// # Example + /// + /// ```rust + /// use async_lock::OnceCell; + /// + /// let cell = OnceCell::new(); + /// # cell.set_blocking(1); + /// ``` + pub const fn new() -> Self { + Self { + active_initializers: Event::new(), + passive_waiters: Event::new(), + state: AtomicUsize::new(State::Uninitialized as _), + value: UnsafeCell::new(MaybeUninit::uninit()), + } } } @@ -194,13 +201,15 @@ impl OnceCell { /// # }); /// ``` pub fn get_mut(&mut self) -> Option<&mut T> { - if State::from(*self.state.get_mut()) == State::Initialized { - // SAFETY: We know that the value is initialized, so it is safe to - // read it. - Some(unsafe { &mut *self.value.get().cast() }) - } else { - None - } + self.state.with_mut(|state| { + if State::from(*state) == State::Initialized { + // SAFETY: We know that the value is initialized, so it is safe to + // read it. + Some(unsafe { &mut *self.value.get().cast() }) + } else { + None + } + }) } /// Take the value out of this `OnceCell`, moving it back to the uninitialized @@ -219,15 +228,17 @@ impl OnceCell { /// # }); /// ``` pub fn take(&mut self) -> Option { - if State::from(*self.state.get_mut()) == State::Initialized { - // SAFETY: We know that the value is initialized, so it is safe to - // read it. - let value = unsafe { ptr::read(self.value.get().cast()) }; - *self.state.get_mut() = State::Uninitialized.into(); - Some(value) - } else { - None - } + self.state.with_mut(|state| { + if State::from(*state) == State::Initialized { + // SAFETY: We know that the value is initialized, so it is safe to + // read it. + let value = unsafe { ptr::read(self.value.get().cast()) }; + *state = State::Uninitialized.into(); + Some(value) + } else { + None + } + }) } /// Convert this `OnceCell` into the inner value, if it is initialized. @@ -754,11 +765,13 @@ impl fmt::Debug for OnceCell { impl Drop for OnceCell { fn drop(&mut self) { - if State::from(*self.state.get_mut()) == State::Initialized { - // SAFETY: We know that the value is initialized, so it is safe to - // drop it. - unsafe { self.value.get().cast::().drop_in_place() } - } + self.state.with_mut(|state| { + if State::from(*state) == State::Initialized { + // SAFETY: We know that the value is initialized, so it is safe to + // drop it. + unsafe { self.value.get().cast::().drop_in_place() } + } + }); } } diff --git a/src/rwlock.rs b/src/rwlock.rs index b8907d2..67b06d5 100644 --- a/src/rwlock.rs +++ b/src/rwlock.rs @@ -55,21 +55,24 @@ unsafe impl Send for RwLock {} unsafe impl Sync for RwLock {} impl RwLock { - /// Creates a new reader-writer lock. - /// - /// # Examples - /// - /// ``` - /// use async_lock::RwLock; - /// - /// let lock = RwLock::new(0); - /// ``` - #[must_use] - #[inline] - pub const fn new(t: T) -> RwLock { - RwLock { - raw: RawRwLock::new(), - value: UnsafeCell::new(t), + const_fn! { + const_if: #[cfg(not(loom))]; + /// Creates a new reader-writer lock. + /// + /// # Examples + /// + /// ``` + /// use async_lock::RwLock; + /// + /// let lock = RwLock::new(0); + /// ``` + #[must_use] + #[inline] + pub const fn new(t: T) -> RwLock { + RwLock { + raw: RawRwLock::new(), + value: UnsafeCell::new(t), + } } } diff --git a/src/rwlock/raw.rs b/src/rwlock/raw.rs index 6e96a53..c867ed9 100644 --- a/src/rwlock/raw.rs +++ b/src/rwlock/raw.rs @@ -9,9 +9,10 @@ use core::marker::PhantomPinned; use core::mem::forget; use core::pin::Pin; -use core::sync::atomic::{AtomicUsize, Ordering}; use core::task::Poll; +use crate::sync::atomic::{AtomicUsize, Ordering}; + use event_listener::{Event, EventListener}; use event_listener_strategy::{EventListenerFuture, Strategy}; @@ -43,13 +44,16 @@ pub(super) struct RawRwLock { } impl RawRwLock { - #[inline] - pub(super) const fn new() -> Self { - RawRwLock { - mutex: Mutex::new(()), - no_readers: Event::new(), - no_writer: Event::new(), - state: AtomicUsize::new(0), + const_fn! { + const_if: #[cfg(not(loom))]; + #[inline] + pub(super) const fn new() -> Self { + RawRwLock { + mutex: Mutex::new(()), + no_readers: Event::new(), + no_writer: Event::new(), + state: AtomicUsize::new(0), + } } } diff --git a/src/semaphore.rs b/src/semaphore.rs index cd9aa7a..739fc6a 100644 --- a/src/semaphore.rs +++ b/src/semaphore.rs @@ -2,9 +2,10 @@ use core::fmt; use core::marker::PhantomPinned; use core::mem; use core::pin::Pin; -use core::sync::atomic::{AtomicUsize, Ordering}; use core::task::Poll; +use crate::sync::atomic::{AtomicUsize, Ordering}; + use alloc::sync::Arc; use event_listener::{Event, EventListener}; @@ -18,19 +19,22 @@ pub struct Semaphore { } impl Semaphore { - /// Creates a new semaphore with a limit of `n` concurrent operations. - /// - /// # Examples - /// - /// ``` - /// use async_lock::Semaphore; - /// - /// let s = Semaphore::new(5); - /// ``` - pub const fn new(n: usize) -> Semaphore { - Semaphore { - count: AtomicUsize::new(n), - event: Event::new(), + const_fn! { + const_if: #[cfg(not(loom))]; + /// Creates a new semaphore with a limit of `n` concurrent operations. + /// + /// # Examples + /// + /// ``` + /// use async_lock::Semaphore; + /// + /// let s = Semaphore::new(5); + /// ``` + pub const fn new(n: usize) -> Semaphore { + Semaphore { + count: AtomicUsize::new(n), + event: Event::new(), + } } } diff --git a/tests/loom.rs b/tests/loom.rs new file mode 100644 index 0000000..b27c277 --- /dev/null +++ b/tests/loom.rs @@ -0,0 +1,46 @@ +#![cfg(loom)] + +use loom::sync::{mpsc, Arc}; +use loom::thread; + +use async_lock::Barrier; + +#[ignore] +#[test] +fn barrier_smoke() { + loom::model(|| { + const N: usize = 10; + + let barrier = Arc::new(Barrier::new(N)); + + for _ in 0..10 { + let (tx, rx) = mpsc::channel(); + + for _ in 0..loom::MAX_THREADS - 1 { + let c = barrier.clone(); + let tx = tx.clone(); + + thread::spawn(move || { + let res = c.wait_blocking(); + tx.send(res.is_leader()).unwrap(); + }); + } + + // At this point, all spawned threads should be blocked, + // so we shouldn't get anything from the cahnnel. + let res = rx.try_recv(); + assert!(res.is_err()); + + let mut leader_found = barrier.wait_blocking().is_leader(); + + // Now, the barrier is cleared and we should get data. + for _ in 0..N - 1 { + if rx.recv().unwrap() { + assert!(!leader_found); + leader_found = true; + } + } + assert!(leader_found); + } + }); +}