1 // SPDX-License-Identifier: Apache-2.0 OR MIT
2 
3 #![allow(clippy::undocumented_unsafe_blocks)]
4 #![cfg_attr(feature = "alloc", feature(allocator_api))]
5 #![cfg_attr(not(RUSTC_LINT_REASONS_IS_STABLE), feature(lint_reasons))]
6 #![allow(clippy::missing_safety_doc)]
7 
8 use core::{
9     cell::{Cell, UnsafeCell},
10     marker::PhantomPinned,
11     ops::{Deref, DerefMut},
12     pin::Pin,
13     sync::atomic::{AtomicBool, Ordering},
14 };
15 #[cfg(feature = "std")]
16 use std::{
17     sync::Arc,
18     thread::{self, sleep, Builder, Thread},
19     time::Duration,
20 };
21 
22 use pin_init::*;
23 #[allow(unused_attributes)]
24 #[path = "./linked_list.rs"]
25 pub mod linked_list;
26 use linked_list::*;
27 
/// A minimal spin lock backed by a single [`AtomicBool`].
///
/// The flag is `true` while the lock is held and `false` while it is free.
pub struct SpinLock {
    inner: AtomicBool,
}

impl SpinLock {
    /// Busy-waits until the lock has been taken and returns a guard that releases it on drop.
    ///
    /// While the lock is contended, the flag is only *read* between acquisition attempts.
    /// With the `std` feature the waiting thread yields to the scheduler on every read;
    /// without it the CPU is given a spin-loop hint instead of retrying the compare-exchange
    /// back to back.
    #[inline]
    pub fn acquire(&self) -> SpinLockGuard<'_> {
        while self
            .inner
            .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
            .is_err()
        {
            // Wait on a plain load until the lock looks free, then retry the exchange.
            while self.inner.load(Ordering::Relaxed) {
                #[cfg(feature = "std")]
                thread::yield_now();
                #[cfg(not(feature = "std"))]
                core::hint::spin_loop();
            }
        }
        SpinLockGuard(self)
    }

    /// Creates a new, unlocked spin lock.
    #[inline]
    #[allow(clippy::new_without_default)]
    pub const fn new() -> Self {
        Self {
            inner: AtomicBool::new(false),
        }
    }
}

/// Proof that a [`SpinLock`] is held; dropping it releases the lock.
pub struct SpinLockGuard<'a>(&'a SpinLock);

impl Drop for SpinLockGuard<'_> {
    /// Releases the lock by clearing the flag with `Release` ordering.
    #[inline]
    fn drop(&mut self) {
        self.0.inner.store(false, Ordering::Release);
    }
}
65 
66 #[pin_data]
67 pub struct CMutex<T> {
68     #[pin]
69     wait_list: ListHead,
70     spin_lock: SpinLock,
71     locked: Cell<bool>,
72     #[pin]
73     data: UnsafeCell<T>,
74 }
75 
76 impl<T> CMutex<T> {
77     #[inline]
new(val: impl PinInit<T>) -> impl PinInit<Self>78     pub fn new(val: impl PinInit<T>) -> impl PinInit<Self> {
79         pin_init!(CMutex {
80             wait_list <- ListHead::new(),
81             spin_lock: SpinLock::new(),
82             locked: Cell::new(false),
83             data <- unsafe {
84                 pin_init_from_closure(|slot: *mut UnsafeCell<T>| {
85                     val.__pinned_init(slot.cast::<T>())
86                 })
87             },
88         })
89     }
90 
91     #[inline]
lock(&self) -> Pin<CMutexGuard<'_, T>>92     pub fn lock(&self) -> Pin<CMutexGuard<'_, T>> {
93         let mut sguard = self.spin_lock.acquire();
94         if self.locked.get() {
95             stack_pin_init!(let wait_entry = WaitEntry::insert_new(&self.wait_list));
96             // println!("wait list length: {}", self.wait_list.size());
97             while self.locked.get() {
98                 drop(sguard);
99                 #[cfg(feature = "std")]
100                 thread::park();
101                 sguard = self.spin_lock.acquire();
102             }
103             // This does have an effect, as the ListHead inside wait_entry implements Drop!
104             #[expect(clippy::drop_non_drop)]
105             drop(wait_entry);
106         }
107         self.locked.set(true);
108         unsafe {
109             Pin::new_unchecked(CMutexGuard {
110                 mtx: self,
111                 _pin: PhantomPinned,
112             })
113         }
114     }
115 
116     #[allow(dead_code)]
get_data_mut(self: Pin<&mut Self>) -> &mut T117     pub fn get_data_mut(self: Pin<&mut Self>) -> &mut T {
118         // SAFETY: we have an exclusive reference and thus nobody has access to data.
119         unsafe { &mut *self.data.get() }
120     }
121 }
122 
123 unsafe impl<T: Send> Send for CMutex<T> {}
124 unsafe impl<T: Send> Sync for CMutex<T> {}
125 
126 pub struct CMutexGuard<'a, T> {
127     mtx: &'a CMutex<T>,
128     _pin: PhantomPinned,
129 }
130 
131 impl<T> Drop for CMutexGuard<'_, T> {
132     #[inline]
drop(&mut self)133     fn drop(&mut self) {
134         let sguard = self.mtx.spin_lock.acquire();
135         self.mtx.locked.set(false);
136         if let Some(list_field) = self.mtx.wait_list.next() {
137             let _wait_entry = list_field.as_ptr().cast::<WaitEntry>();
138             #[cfg(feature = "std")]
139             unsafe {
140                 (*_wait_entry).thread.unpark()
141             };
142         }
143         drop(sguard);
144     }
145 }
146 
147 impl<T> Deref for CMutexGuard<'_, T> {
148     type Target = T;
149 
150     #[inline]
deref(&self) -> &Self::Target151     fn deref(&self) -> &Self::Target {
152         unsafe { &*self.mtx.data.get() }
153     }
154 }
155 
156 impl<T> DerefMut for CMutexGuard<'_, T> {
157     #[inline]
deref_mut(&mut self) -> &mut Self::Target158     fn deref_mut(&mut self) -> &mut Self::Target {
159         unsafe { &mut *self.mtx.data.get() }
160     }
161 }
162 
163 #[pin_data]
164 #[repr(C)]
165 struct WaitEntry {
166     #[pin]
167     wait_list: ListHead,
168     #[cfg(feature = "std")]
169     thread: Thread,
170 }
171 
172 impl WaitEntry {
173     #[inline]
insert_new(list: &ListHead) -> impl PinInit<Self> + '_174     fn insert_new(list: &ListHead) -> impl PinInit<Self> + '_ {
175         #[cfg(feature = "std")]
176         {
177             pin_init!(Self {
178                 thread: thread::current(),
179                 wait_list <- ListHead::insert_prev(list),
180             })
181         }
182         #[cfg(not(feature = "std"))]
183         {
184             pin_init!(Self {
185                 wait_list <- ListHead::insert_prev(list),
186             })
187         }
188     }
189 }
190 
/// Stress-tests `CMutex`: 20 named worker threads each increment a shared counter in two
/// bursts separated by a sleep, and the final value is checked against the expected total.
///
/// The body is compiled only with the `std` feature; without it this function does nothing.
#[cfg_attr(test, test)]
#[allow(dead_code)]
fn main() {
    #[cfg(feature = "std")]
    {
        let thread_count = 20;
        // Keep the run short under Miri.
        let workload = if cfg!(miri) { 100 } else { 1_000 };
        let mtx: Pin<Arc<CMutex<usize>>> = Arc::pin_init(CMutex::new(0)).unwrap();

        // All workers are spawned before any of them is joined.
        let handles: Vec<_> = (0..thread_count)
            .map(|i| {
                let mtx = mtx.clone();
                Builder::new()
                    .name(format!("worker #{i}"))
                    .spawn(move || {
                        let add_workload = || {
                            for _ in 0..workload {
                                *mtx.lock() += 1;
                            }
                        };
                        add_workload();
                        println!("{i} halfway");
                        // Stagger the second burst so the workers finish at different times.
                        sleep(Duration::from_millis((i as u64) * 10));
                        add_workload();
                        println!("{i} finished");
                    })
                    .expect("should not fail")
            })
            .collect();

        for worker in handles {
            worker.join().expect("thread panicked");
        }

        println!("{:?}", &*mtx.lock());
        assert_eq!(*mtx.lock(), workload * thread_count * 2);
    }
}
226