1 // SPDX-License-Identifier: Apache-2.0 OR MIT
2
3 // inspired by <https://github.com/nbdd0121/pin-init/blob/trunk/examples/pthread_mutex.rs>
4 #![allow(clippy::undocumented_unsafe_blocks)]
5 #![cfg_attr(feature = "alloc", feature(allocator_api))]
6 #![cfg_attr(not(RUSTC_LINT_REASONS_IS_STABLE), feature(lint_reasons))]
7
8 #[cfg(not(windows))]
9 mod pthread_mtx {
10 #[cfg(feature = "alloc")]
11 use core::alloc::AllocError;
12 use core::{
13 cell::UnsafeCell,
14 marker::PhantomPinned,
15 mem::MaybeUninit,
16 ops::{Deref, DerefMut},
17 pin::Pin,
18 };
19 use pin_init::*;
20 use std::convert::Infallible;
21
22 #[pin_data(PinnedDrop)]
23 pub struct PThreadMutex<T> {
24 #[pin]
25 raw: UnsafeCell<libc::pthread_mutex_t>,
26 data: UnsafeCell<T>,
27 #[pin]
28 pin: PhantomPinned,
29 }
30
31 unsafe impl<T: Send> Send for PThreadMutex<T> {}
32 unsafe impl<T: Send> Sync for PThreadMutex<T> {}
33
34 #[pinned_drop]
35 impl<T> PinnedDrop for PThreadMutex<T> {
drop(self: Pin<&mut Self>)36 fn drop(self: Pin<&mut Self>) {
37 unsafe {
38 libc::pthread_mutex_destroy(self.raw.get());
39 }
40 }
41 }
42
43 #[derive(Debug)]
44 pub enum Error {
45 #[allow(dead_code)]
46 IO(std::io::Error),
47 #[allow(dead_code)]
48 Alloc,
49 }
50
51 impl From<Infallible> for Error {
from(e: Infallible) -> Self52 fn from(e: Infallible) -> Self {
53 match e {}
54 }
55 }
56
57 #[cfg(feature = "alloc")]
58 impl From<AllocError> for Error {
from(_: AllocError) -> Self59 fn from(_: AllocError) -> Self {
60 Self::Alloc
61 }
62 }
63
64 impl<T> PThreadMutex<T> {
65 #[allow(dead_code)]
new(data: T) -> impl PinInit<Self, Error>66 pub fn new(data: T) -> impl PinInit<Self, Error> {
67 fn init_raw() -> impl PinInit<UnsafeCell<libc::pthread_mutex_t>, Error> {
68 let init = |slot: *mut UnsafeCell<libc::pthread_mutex_t>| {
69 // we can cast, because `UnsafeCell` has the same layout as T.
70 let slot: *mut libc::pthread_mutex_t = slot.cast();
71 let mut attr = MaybeUninit::uninit();
72 let attr = attr.as_mut_ptr();
73 // SAFETY: ptr is valid
74 let ret = unsafe { libc::pthread_mutexattr_init(attr) };
75 if ret != 0 {
76 return Err(Error::IO(std::io::Error::from_raw_os_error(ret)));
77 }
78 // SAFETY: attr is initialized
79 let ret = unsafe {
80 libc::pthread_mutexattr_settype(attr, libc::PTHREAD_MUTEX_NORMAL)
81 };
82 if ret != 0 {
83 // SAFETY: attr is initialized
84 unsafe { libc::pthread_mutexattr_destroy(attr) };
85 return Err(Error::IO(std::io::Error::from_raw_os_error(ret)));
86 }
87 // SAFETY: slot is valid
88 unsafe { slot.write(libc::PTHREAD_MUTEX_INITIALIZER) };
89 // SAFETY: attr and slot are valid ptrs and attr is initialized
90 let ret = unsafe { libc::pthread_mutex_init(slot, attr) };
91 // SAFETY: attr was initialized
92 unsafe { libc::pthread_mutexattr_destroy(attr) };
93 if ret != 0 {
94 return Err(Error::IO(std::io::Error::from_raw_os_error(ret)));
95 }
96 Ok(())
97 };
98 // SAFETY: mutex has been initialized
99 unsafe { pin_init_from_closure(init) }
100 }
101 try_pin_init!(Self {
102 data: UnsafeCell::new(data),
103 raw <- init_raw(),
104 pin: PhantomPinned,
105 }? Error)
106 }
107
108 #[allow(dead_code)]
lock(&self) -> PThreadMutexGuard<'_, T>109 pub fn lock(&self) -> PThreadMutexGuard<'_, T> {
110 // SAFETY: raw is always initialized
111 unsafe { libc::pthread_mutex_lock(self.raw.get()) };
112 PThreadMutexGuard { mtx: self }
113 }
114 }
115
116 pub struct PThreadMutexGuard<'a, T> {
117 mtx: &'a PThreadMutex<T>,
118 }
119
120 impl<T> Drop for PThreadMutexGuard<'_, T> {
drop(&mut self)121 fn drop(&mut self) {
122 // SAFETY: raw is always initialized
123 unsafe { libc::pthread_mutex_unlock(self.mtx.raw.get()) };
124 }
125 }
126
127 impl<T> Deref for PThreadMutexGuard<'_, T> {
128 type Target = T;
129
deref(&self) -> &Self::Target130 fn deref(&self) -> &Self::Target {
131 unsafe { &*self.mtx.data.get() }
132 }
133 }
134
135 impl<T> DerefMut for PThreadMutexGuard<'_, T> {
deref_mut(&mut self) -> &mut Self::Target136 fn deref_mut(&mut self) -> &mut Self::Target {
137 unsafe { &mut *self.mtx.data.get() }
138 }
139 }
140 }
141
#[cfg_attr(test, test)]
#[cfg_attr(all(test, miri), ignore)]
fn main() {
    // Stress demo: many named worker threads increment one shared counter through the
    // pthread-backed mutex, then the final total is checked. Only compiled where `Arc` is
    // available (std/alloc) and the target is not Windows.
    #[cfg(all(any(feature = "std", feature = "alloc"), not(windows)))]
    {
        use core::pin::Pin;
        use pin_init::*;
        use pthread_mtx::*;
        use std::{
            sync::Arc,
            thread::{sleep, Builder},
            time::Duration,
        };

        let num_threads: usize = 20;
        let per_phase: usize = 1_000_000;
        let counter: Pin<Arc<PThreadMutex<usize>>> =
            Arc::try_pin_init(PThreadMutex::new(0)).unwrap();

        let workers: Vec<_> = (0..num_threads)
            .map(|i| {
                let shared = counter.clone();
                Builder::new()
                    .name(format!("worker #{i}"))
                    .spawn(move || {
                        // One phase: `per_phase` increments, re-acquiring the lock each time.
                        let run_phase = || {
                            for _ in 0..per_phase {
                                *shared.lock() += 1;
                            }
                        };
                        run_phase();
                        println!("{i} halfway");
                        // Stagger the workers so their second phases start at different times.
                        sleep(Duration::from_millis((i as u64) * 10));
                        run_phase();
                        println!("{i} finished");
                    })
                    .expect("should not fail")
            })
            .collect();

        workers
            .into_iter()
            .for_each(|worker| worker.join().expect("thread panicked"));

        println!("{:?}", *counter.lock());
        assert_eq!(*counter.lock(), num_threads * per_phase * 2);
    }
}
185