Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a backend for waitable handles #68

Merged
merged 3 commits into from
Mar 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ jobs:
if: startsWith(matrix.rust, 'nightly')
run: cargo check -Z features=dev_dep
- run: cargo test
- run: cargo test
env:
RUSTFLAGS: ${{ env.RUSTFLAGS }} --cfg async_process_force_signal_backend

msrv:
runs-on: ${{ matrix.os }}
Expand Down
8 changes: 8 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,20 @@ async-lock = "3.0.0"
cfg-if = "1.0"
event-listener = "5.1.0"
futures-lite = "2.0.0"
tracing = { version = "0.1.40", default-features = false }

[target.'cfg(unix)'.dependencies]
async-io = "2.1.0"
async-signal = "0.2.3"
rustix = { version = "0.38", default-features = false, features = ["std", "fs"] }

[target.'cfg(any(target_os = "linux", target_os = "android"))'.dependencies]
async-channel = "2.0.0"
async-task = "4.7.0"

[target.'cfg(all(unix, not(any(target_os = "linux", target_os = "android"))))'.dependencies]
rustix = { version = "0.38", default-features = false, features = ["std", "fs", "process"] }

[target.'cfg(windows)'.dependencies]
async-channel = "2.0.0"
blocking = "1.0.0"
Expand Down
220 changes: 28 additions & 192 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@
use std::convert::Infallible;
use std::ffi::OsStr;
use std::fmt;
use std::mem;
use std::path::Path;
use std::pin::Pin;
use std::sync::atomic::{AtomicUsize, Ordering};
Expand All @@ -75,8 +74,7 @@ use std::os::unix::io::{AsFd, AsRawFd, BorrowedFd, OwnedFd, RawFd};
#[cfg(windows)]
use blocking::Unblock;

use async_lock::{Mutex as AsyncMutex, OnceCell};
use event_listener::Event;
use async_lock::OnceCell;
use futures_lite::{future, io, prelude::*};

#[doc(no_inline)]
Expand All @@ -87,6 +85,8 @@ pub mod unix;
#[cfg(windows)]
pub mod windows;

mod reaper;

mod sealed {
pub trait Sealed {}
}
Expand All @@ -99,17 +99,8 @@ static DRIVER_THREAD_SPAWNED: std::sync::atomic::AtomicBool =
///
/// This structure reaps zombie processes and emits the `SIGCHLD` signal.
struct Reaper {
/// An event delivered every time the SIGCHLD signal occurs.
sigchld: Event,

/// The list of zombie processes.
zombies: Mutex<Vec<std::process::Child>>,

/// The pipe that delivers signal notifications.
pipe: Pipe,

/// Locking this mutex indicates that we are polling the SIGCHLD event.
driver_guard: AsyncMutex<()>,
/// Underlying system reaper.
sys: reaper::Reaper,

/// The number of tasks polling the SIGCHLD event.
///
Expand All @@ -129,10 +120,7 @@ impl Reaper {
static REAPER: OnceCell<Reaper> = OnceCell::new();

REAPER.get_or_init_blocking(|| Reaper {
sigchld: Event::new(),
zombies: Mutex::new(Vec::new()),
pipe: Pipe::new().expect("cannot create SIGCHLD pipe"),
driver_guard: AsyncMutex::new(()),
sys: reaper::Reaper::new(),
drivers: AtomicUsize::new(0),
child_count: AtomicUsize::new(0),
})
Expand Down Expand Up @@ -165,8 +153,8 @@ impl Reaper {
.spawn(move || {
let driver = async move {
// No need to bump self.drivers, it was already bumped in ensure_driven.
let guard = self.driver_guard.lock().await;
self.reap(guard).await
let guard = self.sys.lock().await;
self.sys.reap(guard).await
};

#[cfg(unix)]
Expand All @@ -178,147 +166,20 @@ impl Reaper {
.expect("cannot spawn async-process thread");
}

/// Reap zombie processes forever.
async fn reap(&'static self, _driver_guard: async_lock::MutexGuard<'_, ()>) -> ! {
loop {
// Wait for the next SIGCHLD signal.
self.pipe.wait().await;

// Notify all listeners waiting on the SIGCHLD event.
self.sigchld.notify(std::usize::MAX);

// Reap zombie processes, but make sure we don't hold onto the lock for too long!
let mut zombies = mem::take(&mut *self.zombies.lock().unwrap());
let mut i = 0;
'reap_zombies: loop {
for _ in 0..50 {
if i >= zombies.len() {
break 'reap_zombies;
}

if let Ok(None) = zombies[i].try_wait() {
i += 1;
} else {
zombies.swap_remove(i);
}
}

// Be a good citizen; yield if there are a lot of processes.
//
// After we yield, check if there are more zombie processes.
future::yield_now().await;
zombies.append(&mut self.zombies.lock().unwrap());
}

// Put zombie processes back.
self.zombies.lock().unwrap().append(&mut zombies);
}
}

/// Register a process with this reaper.
fn register(&'static self, child: &std::process::Child) -> io::Result<()> {
fn register(&'static self, child: std::process::Child) -> io::Result<reaper::ChildGuard> {
self.ensure_driven();
self.pipe.register(child)
self.sys.register(child)
}
}

cfg_if::cfg_if! {
if #[cfg(windows)] {
use async_channel::{Sender, Receiver, bounded};
use std::ffi::c_void;
use std::os::windows::io::AsRawHandle;

use windows_sys::Win32::{
Foundation::{BOOLEAN, HANDLE},
System::Threading::{
RegisterWaitForSingleObject, INFINITE, WT_EXECUTEINWAITTHREAD, WT_EXECUTEONLYONCE,
},
};

/// Waits for the next SIGCHLD signal.
struct Pipe {
/// The sender channel for the SIGCHLD signal.
sender: Sender<()>,

/// The receiver channel for the SIGCHLD signal.
receiver: Receiver<()>,
}

impl Pipe {
/// Creates a new pipe.
fn new() -> io::Result<Pipe> {
let (sender, receiver) = bounded(1);
Ok(Pipe {
sender,
receiver
})
}

/// Waits for the next SIGCHLD signal.
async fn wait(&self) {
self.receiver.recv().await.ok();
}

/// Register a process object into this pipe.
fn register(&self, child: &std::process::Child) -> io::Result<()> {
// Called when a child exits.
unsafe extern "system" fn callback(_: *mut c_void, _: BOOLEAN) {
Reaper::get().pipe.sender.try_send(()).ok();
}

// Register this child process to invoke `callback` on exit.
let mut wait_object = 0;
let ret = unsafe {
RegisterWaitForSingleObject(
&mut wait_object,
child.as_raw_handle() as HANDLE,
Some(callback),
std::ptr::null_mut(),
INFINITE,
WT_EXECUTEINWAITTHREAD | WT_EXECUTEONLYONCE,
)
};

if ret == 0 {
Err(io::Error::last_os_error())
} else {
Ok(())
}
}
}

// Wraps a sync I/O type into an async I/O type.
fn wrap<T>(io: T) -> io::Result<Unblock<T>> {
Ok(Unblock::new(io))
}
} else if #[cfg(unix)] {
use async_signal::{Signal, Signals};

/// Waits for the next SIGCHLD signal.
struct Pipe {
/// The iterator over SIGCHLD signals.
signals: Signals,
}

impl Pipe {
/// Creates a new pipe.
fn new() -> io::Result<Pipe> {
Ok(Pipe {
signals: Signals::new(Some(Signal::Child))?,
})
}

/// Waits for the next SIGCHLD signal.
async fn wait(&self) {
(&self.signals).next().await;
}

/// Register a process object into this pipe.
fn register(&self, _child: &std::process::Child) -> io::Result<()> {
Ok(())
}
}

/// Wrap a file descriptor into a non-blocking I/O type.
fn wrap<T: std::os::unix::io::AsFd>(io: T) -> io::Result<Async<T>> {
Async::new(io)
Expand All @@ -328,15 +189,15 @@ cfg_if::cfg_if! {

/// A guard that can kill child processes, or push them into the zombie list.
struct ChildGuard {
inner: Option<std::process::Child>,
inner: reaper::ChildGuard,
reap_on_drop: bool,
kill_on_drop: bool,
reaper: &'static Reaper,
}

impl ChildGuard {
fn get_mut(&mut self) -> &mut std::process::Child {
self.inner.as_mut().unwrap()
self.inner.get_mut()
}
}

Expand All @@ -347,10 +208,7 @@ impl Drop for ChildGuard {
self.get_mut().kill().ok();
}
if self.reap_on_drop {
let mut zombies = self.reaper.zombies.lock().unwrap();
if let Ok(None) = self.get_mut().try_wait() {
zombies.push(self.inner.take().unwrap());
}
self.inner.reap(&self.reaper.sys);
}

// Decrement number of children.
Expand Down Expand Up @@ -409,14 +267,14 @@ impl Child {
reaper.child_count.fetch_add(1, Ordering::Relaxed);

// Register the child process in the global list.
reaper.register(&child)?;
let inner = reaper.register(child)?;

Ok(Child {
stdin,
stdout,
stderr,
child: Arc::new(Mutex::new(ChildGuard {
inner: Some(child),
inner,
reap_on_drop: cmd.reap_on_drop,
kill_on_drop: cmd.kill_on_drop,
reaper,
Expand Down Expand Up @@ -509,25 +367,7 @@ impl Child {
self.stdin.take();
let child = self.child.clone();

async move {
loop {
// Wait on the child process.
if let Some(status) = child.lock().unwrap().get_mut().try_wait()? {
return Ok(status);
}

// Start listening.
event_listener::listener!(Reaper::get().sigchld => listener);

// Try again.
if let Some(status) = child.lock().unwrap().get_mut().try_wait()? {
return Ok(status);
}

// Wait on the listener.
listener.await;
}
}
async move { Reaper::get().sys.status(&child).await }
}

/// Drops the stdin handle and collects the output of the process.
Expand Down Expand Up @@ -872,16 +712,9 @@ impl TryFrom<ChildStderr> for OwnedFd {
/// }).await;
/// # });
/// ```
#[allow(clippy::manual_async_fn)]
#[inline]
pub fn driver() -> impl Future<Output = Infallible> + Send + 'static {
struct CallOnDrop<F: FnMut()>(F);

impl<F: FnMut()> Drop for CallOnDrop<F> {
fn drop(&mut self) {
(self.0)();
}
}

async {
// Get the reaper.
let reaper = Reaper::get();
Expand All @@ -896,20 +729,15 @@ pub fn driver() -> impl Future<Output = Infallible> + Send + 'static {
// If this was the last driver, and there are still resources actively using the
// reaper, make sure that there is a thread driving the reaper.
if prev_count == 1
&& reaper.child_count.load(Ordering::SeqCst) > 0
&& !reaper
.zombies
.lock()
.unwrap_or_else(|x| x.into_inner())
.is_empty()
&& (reaper.child_count.load(Ordering::SeqCst) > 0 || reaper.sys.has_zombies())
{
reaper.ensure_driven();
}
});

// Acquire the reaper lock and start polling the SIGCHLD event.
let guard = reaper.driver_guard.lock().await;
reaper.reap(guard).await
let guard = reaper.sys.lock().await;
reaper.sys.reap(guard).await
}
}

Expand Down Expand Up @@ -1307,6 +1135,14 @@ fn blocking_fd(fd: rustix::fd::BorrowedFd<'_>) -> io::Result<()> {
Ok(())
}

struct CallOnDrop<F: FnMut()>(F);

impl<F: FnMut()> Drop for CallOnDrop<F> {
fn drop(&mut self) {
(self.0)();
}
}

#[cfg(test)]
mod test {
#[test]
Expand Down
Loading
Loading