1 // SPDX-License-Identifier: Apache-2.0 OR MIT
2
3 #![allow(clippy::undocumented_unsafe_blocks)]
4 #![cfg_attr(feature = "alloc", feature(allocator_api))]
5 #![cfg_attr(not(RUSTC_LINT_REASONS_IS_STABLE), feature(lint_reasons))]
6 #![allow(clippy::missing_safety_doc)]
7
8 use core::{
9 cell::{Cell, UnsafeCell},
10 marker::PhantomPinned,
11 ops::{Deref, DerefMut},
12 pin::Pin,
13 sync::atomic::{AtomicBool, Ordering},
14 };
15 #[cfg(feature = "std")]
16 use std::{
17 sync::Arc,
18 thread::{self, sleep, Builder, Thread},
19 time::Duration,
20 };
21
22 use pin_init::*;
23 #[allow(unused_attributes)]
24 #[path = "./linked_list.rs"]
25 pub mod linked_list;
26 use linked_list::*;
27
/// A minimal test-and-test-and-set spin lock.
///
/// It protects no data of its own; callers (such as [`CMutex`]) use it to serialize access to
/// state stored next to it.
pub struct SpinLock {
    /// `true` while a [`SpinLockGuard`] for this lock is alive.
    inner: AtomicBool,
}

impl SpinLock {
    /// Spins until the lock is acquired and returns a guard that releases it when dropped.
    ///
    /// With the `std` feature a contended caller yields its time slice while waiting; without it
    /// the caller busy-waits, hinting the CPU that it is in a spin loop.
    #[inline]
    pub fn acquire(&self) -> SpinLockGuard<'_> {
        // A spurious failure of the weak CAS is harmless because we retry anyway; on LL/SC
        // architectures the weak form avoids a nested retry loop.
        while self
            .inner
            .compare_exchange_weak(false, true, Ordering::Acquire, Ordering::Relaxed)
            .is_err()
        {
            // Wait with plain loads until the lock looks free, so contending threads do not keep
            // bouncing the cache line with failed read-modify-write operations.
            while self.inner.load(Ordering::Relaxed) {
                #[cfg(feature = "std")]
                thread::yield_now();
                #[cfg(not(feature = "std"))]
                core::hint::spin_loop();
            }
        }
        SpinLockGuard(self)
    }

    /// Creates a new, unlocked spin lock.
    #[inline]
    #[allow(clippy::new_without_default)] // Only the inherent `const fn` constructor is provided.
    pub const fn new() -> Self {
        Self {
            inner: AtomicBool::new(false),
        }
    }
}

/// Proof that the referenced [`SpinLock`] is held; releases the lock when dropped.
#[must_use = "the spin lock is released as soon as the guard is dropped"]
pub struct SpinLockGuard<'a>(&'a SpinLock);

impl Drop for SpinLockGuard<'_> {
    /// Releases the lock. `Release` pairs with the `Acquire` in [`SpinLock::acquire`], publishing
    /// writes made while the lock was held to the next acquirer.
    #[inline]
    fn drop(&mut self) {
        self.0.inner.store(false, Ordering::Release);
    }
}
65
/// A sleeping mutex built from a [`SpinLock`] and an intrusive wait list.
///
/// Threads that find the mutex locked enqueue a [`WaitEntry`] on `wait_list` and, with the `std`
/// feature, park until the holder unlocks. Because `wait_list` is pinned and waiters link into
/// it, the mutex must stay pinned; it is initialized in place via [`CMutex::new`].
#[pin_data]
pub struct CMutex<T> {
    /// Intrusive list of waiting threads; only accessed while `spin_lock` is held.
    #[pin]
    wait_list: ListHead,
    /// Serializes access to `wait_list` and `locked`.
    spin_lock: SpinLock,
    /// Whether a [`CMutexGuard`] currently holds the mutex; only accessed while `spin_lock` is
    /// held.
    locked: Cell<bool>,
    /// The protected value (structurally pinned).
    #[pin]
    data: UnsafeCell<T>,
}
75
76 impl<T> CMutex<T> {
77 #[inline]
new(val: impl PinInit<T>) -> impl PinInit<Self>78 pub fn new(val: impl PinInit<T>) -> impl PinInit<Self> {
79 pin_init!(CMutex {
80 wait_list <- ListHead::new(),
81 spin_lock: SpinLock::new(),
82 locked: Cell::new(false),
83 data <- unsafe {
84 pin_init_from_closure(|slot: *mut UnsafeCell<T>| {
85 val.__pinned_init(slot.cast::<T>())
86 })
87 },
88 })
89 }
90
91 #[inline]
lock(&self) -> Pin<CMutexGuard<'_, T>>92 pub fn lock(&self) -> Pin<CMutexGuard<'_, T>> {
93 let mut sguard = self.spin_lock.acquire();
94 if self.locked.get() {
95 stack_pin_init!(let wait_entry = WaitEntry::insert_new(&self.wait_list));
96 // println!("wait list length: {}", self.wait_list.size());
97 while self.locked.get() {
98 drop(sguard);
99 #[cfg(feature = "std")]
100 thread::park();
101 sguard = self.spin_lock.acquire();
102 }
103 // This does have an effect, as the ListHead inside wait_entry implements Drop!
104 #[expect(clippy::drop_non_drop)]
105 drop(wait_entry);
106 }
107 self.locked.set(true);
108 unsafe {
109 Pin::new_unchecked(CMutexGuard {
110 mtx: self,
111 _pin: PhantomPinned,
112 })
113 }
114 }
115
116 #[allow(dead_code)]
get_data_mut(self: Pin<&mut Self>) -> &mut T117 pub fn get_data_mut(self: Pin<&mut Self>) -> &mut T {
118 // SAFETY: we have an exclusive reference and thus nobody has access to data.
119 unsafe { &mut *self.data.get() }
120 }
121 }
122
123 unsafe impl<T: Send> Send for CMutex<T> {}
124 unsafe impl<T: Send> Sync for CMutex<T> {}
125
126 pub struct CMutexGuard<'a, T> {
127 mtx: &'a CMutex<T>,
128 _pin: PhantomPinned,
129 }
130
131 impl<T> Drop for CMutexGuard<'_, T> {
132 #[inline]
drop(&mut self)133 fn drop(&mut self) {
134 let sguard = self.mtx.spin_lock.acquire();
135 self.mtx.locked.set(false);
136 if let Some(list_field) = self.mtx.wait_list.next() {
137 let _wait_entry = list_field.as_ptr().cast::<WaitEntry>();
138 #[cfg(feature = "std")]
139 unsafe {
140 (*_wait_entry).thread.unpark()
141 };
142 }
143 drop(sguard);
144 }
145 }
146
147 impl<T> Deref for CMutexGuard<'_, T> {
148 type Target = T;
149
150 #[inline]
deref(&self) -> &Self::Target151 fn deref(&self) -> &Self::Target {
152 unsafe { &*self.mtx.data.get() }
153 }
154 }
155
156 impl<T> DerefMut for CMutexGuard<'_, T> {
157 #[inline]
deref_mut(&mut self) -> &mut Self::Target158 fn deref_mut(&mut self) -> &mut Self::Target {
159 unsafe { &mut *self.mtx.data.get() }
160 }
161 }
162
/// A node in a [`CMutex`]'s wait list, created on the waiting thread's stack by `CMutex::lock`.
///
/// `#[repr(C)]` with `wait_list` as the first field is load-bearing: the unlock path casts a
/// pointer to the list node back into a pointer to the whole `WaitEntry`.
#[pin_data]
#[repr(C)]
struct WaitEntry {
    /// Links this entry into the mutex's `wait_list`. Must remain the first field.
    #[pin]
    wait_list: ListHead,
    /// The waiting thread. The unlock path unparks it when this entry is first in the list.
    #[cfg(feature = "std")]
    thread: Thread,
}
171
172 impl WaitEntry {
173 #[inline]
insert_new(list: &ListHead) -> impl PinInit<Self> + '_174 fn insert_new(list: &ListHead) -> impl PinInit<Self> + '_ {
175 #[cfg(feature = "std")]
176 {
177 pin_init!(Self {
178 thread: thread::current(),
179 wait_list <- ListHead::insert_prev(list),
180 })
181 }
182 #[cfg(not(feature = "std"))]
183 {
184 pin_init!(Self {
185 wait_list <- ListHead::insert_prev(list),
186 })
187 }
188 }
189 }
190
/// Stress test for [`CMutex`].
///
/// Named worker threads increment a shared counter in two bursts separated by a staggered sleep;
/// afterwards the total is printed and checked. Does nothing without the `std` feature.
#[cfg_attr(test, test)]
#[allow(dead_code)]
fn main() {
    #[cfg(feature = "std")]
    {
        let thread_count = 20;
        let workload = if cfg!(miri) { 100 } else { 1_000 };
        let mtx: Pin<Arc<CMutex<usize>>> = Arc::pin_init(CMutex::new(0)).unwrap();

        let handles: Vec<_> = (0..thread_count)
            .map(|i| {
                let mtx = mtx.clone();
                Builder::new()
                    .name(format!("worker #{i}"))
                    .spawn(move || {
                        let bump = || {
                            for _ in 0..workload {
                                *mtx.lock() += 1;
                            }
                        };
                        bump();
                        println!("{i} halfway");
                        sleep(Duration::from_millis((i as u64) * 10));
                        bump();
                        println!("{i} finished");
                    })
                    .expect("should not fail")
            })
            .collect();

        handles
            .into_iter()
            .for_each(|h| h.join().expect("thread panicked"));

        let total = *mtx.lock();
        println!("{:?}", &total);
        assert_eq!(total, workload * thread_count * 2);
    }
}
226