From c1e0281898378c5def1656382c2f4cdc9f58cdcf Mon Sep 17 00:00:00 2001 From: Will Date: Mon, 9 Sep 2024 06:06:02 -0700 Subject: [PATCH] Catch panics and C<->Rust boundary (#201) * Catch panics in callbacks and mark callbacks as invalid. * Mark callbacks as invalid if they return Quit. --- Cargo.toml | 2 +- src/client/async_client.rs | 2 + src/client/callbacks.rs | 293 ++++++++++++++++++++++++++++++------- src/client/client_impl.rs | 21 ++- src/properties.rs | 26 ++-- 5 files changed, 274 insertions(+), 70 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 53f24450..b1e56189 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,7 +8,7 @@ license = "MIT" name = "jack" readme = "README.md" repository = "https://github.com/RustAudio/rust-jack" -version = "0.12.0" +version = "0.12.1" [dependencies] bitflags = "1" diff --git a/src/client/async_client.rs b/src/client/async_client.rs index c019a750..ed57aad2 100644 --- a/src/client/async_client.rs +++ b/src/client/async_client.rs @@ -2,6 +2,7 @@ use jack_sys as j; use std::fmt; use std::fmt::Debug; use std::mem; +use std::sync::atomic::AtomicBool; use super::callbacks::clear_callbacks; use super::callbacks::{CallbackContext, NotificationHandler, ProcessHandler}; @@ -58,6 +59,7 @@ where client, notification: notification_handler, process: process_handler, + is_valid: AtomicBool::new(true), }); CallbackContext::register_callbacks(&mut callback_context)?; sleep_on_test(); diff --git a/src/client/callbacks.rs b/src/client/callbacks.rs index 04c25ce5..f6e95434 100644 --- a/src/client/callbacks.rs +++ b/src/client/callbacks.rs @@ -1,5 +1,9 @@ use jack_sys as j; -use std::ffi; +use std::{ + ffi, + panic::catch_unwind, + sync::atomic::{AtomicBool, Ordering}, +}; use crate::{Client, ClientStatus, Control, Error, Frames, PortId, ProcessScope}; @@ -123,8 +127,17 @@ where N: 'static + Send + Sync + NotificationHandler, P: 'static + Send + ProcessHandler, { - let ctx = CallbackContext::::from_raw(data); - ctx.notification.thread_init(&ctx.client) + let res = catch_unwind(|| { + let Some(ctx) = CallbackContext::::from_raw(data) else { + return; + }; + ctx.notification.thread_init(&ctx.client); + }); + if let Err(err) = res { + CallbackContext::::from_raw(data).map(CallbackContext::mark_invalid); + eprintln!("{err:?}"); + std::mem::forget(err); + } } unsafe extern "C" fn shutdown( @@ -135,13 +148,22 @@ unsafe extern "C" fn shutdown( N: 'static + Send + Sync + NotificationHandler, P: 'static + Send + ProcessHandler, { - let ctx = CallbackContext::::from_raw(data); - let cstr = ffi::CStr::from_ptr(reason); - let reason_str = cstr.to_str().unwrap_or("Failed to interpret error."); - ctx.notification.shutdown( - ClientStatus::from_bits(code).unwrap_or_else(ClientStatus::empty), - reason_str, - ) + let res = catch_unwind(|| { + let Some(ctx) = CallbackContext::::from_raw(data) else { + return; + }; + let cstr = ffi::CStr::from_ptr(reason); + let reason_str = cstr.to_str().unwrap_or("Failed to interpret error."); + ctx.notification.shutdown( + ClientStatus::from_bits(code).unwrap_or_else(ClientStatus::empty), + reason_str, + ); + }); + if let Err(err) = res { + CallbackContext::::from_raw(data).map(CallbackContext::mark_invalid); + eprintln!("{err:?}"); + std::mem::forget(err); + } } unsafe extern "C" fn process(n_frames: Frames, data: *mut libc::c_void) -> libc::c_int @@ -149,15 +171,23 @@ where N: 'static + Send + Sync + NotificationHandler, P: 'static + Send + ProcessHandler, { - let res = std::panic::catch_unwind(|| { - let ctx = CallbackContext::::from_raw(data); + let res = catch_unwind(|| { + let Some(ctx) = CallbackContext::::from_raw(data) else { + return Control::Quit; + }; let scope = ProcessScope::from_raw(n_frames, ctx.client.raw()); - ctx.process.process(&ctx.client, &scope) + let c = ctx.process.process(&ctx.client, &scope); + if c == Control::Quit { + ctx.mark_invalid(); + } + c }); match res { Ok(res) => res.to_ffi(), Err(err) => { + CallbackContext::::from_raw(data).map(CallbackContext::mark_invalid); eprintln!("{err:?}"); + std::mem::forget(err); Control::Quit.to_ffi() } } @@ -172,14 +202,29 @@ where N: 'static + Send + Sync + NotificationHandler, P: 'static + Send + ProcessHandler, { - let ctx = CallbackContext::::from_raw(data); - match ctx.process.sync( - &ctx.client, - crate::Transport::state_from_ffi(state), - &*(pos as *mut crate::TransportPosition), - ) { - true => 1, - false => 0, + let res = catch_unwind(|| { + let Some(ctx) = CallbackContext::::from_raw(data) else { + return false; + }; + let is_ready = ctx.process.sync( + &ctx.client, + crate::Transport::state_from_ffi(state), + &*(pos as *mut crate::TransportPosition), + ); + if !is_ready { + ctx.mark_invalid(); + } + is_ready + }); + match res { + Ok(true) => 1, + Ok(false) => 0, + Err(err) => { + CallbackContext::::from_raw(data).map(CallbackContext::mark_invalid); + eprintln!("{err:?}"); + std::mem::forget(err); + 0 + } } } @@ -188,9 +233,18 @@ where N: 'static + Send + Sync + NotificationHandler, P: 'static + Send + ProcessHandler, { - let ctx = CallbackContext::::from_raw(data); - let is_starting = !matches!(starting, 0); - ctx.notification.freewheel(&ctx.client, is_starting) + let res = catch_unwind(|| { + let Some(ctx) = CallbackContext::::from_raw(data) else { + return; + }; + let is_starting = !matches!(starting, 0); + ctx.notification.freewheel(&ctx.client, is_starting); + }); + if let Err(err) = res { + CallbackContext::::from_raw(data).map(CallbackContext::mark_invalid); + eprintln!("{err:?}"); + std::mem::forget(err); + } } unsafe extern "C" fn buffer_size(n_frames: Frames, data: *mut libc::c_void) -> libc::c_int @@ -198,8 +252,25 @@ where N: 'static + Send + Sync + NotificationHandler, P: 'static + Send + ProcessHandler, { - let ctx = CallbackContext::::from_raw(data); - ctx.process.buffer_size(&ctx.client, n_frames).to_ffi() + let res = catch_unwind(|| { + let Some(ctx) = CallbackContext::::from_raw(data) else { + return Control::Quit; + }; + let c = ctx.process.buffer_size(&ctx.client, n_frames); + if c == Control::Quit { + ctx.mark_invalid(); + } + c + }); + match res { + Ok(c) => c.to_ffi(), + Err(err) => { + CallbackContext::::from_raw(data).map(CallbackContext::mark_invalid); + eprintln!("{err:?}"); + std::mem::forget(err); + Control::Quit.to_ffi() + } + } } unsafe extern "C" fn sample_rate(n_frames: Frames, data: *mut libc::c_void) -> libc::c_int @@ -207,8 +278,25 @@ where N: 'static + Send + Sync + NotificationHandler, P: 'static + Send + ProcessHandler, { - let ctx = CallbackContext::::from_raw(data); - ctx.notification.sample_rate(&ctx.client, n_frames).to_ffi() + let res = catch_unwind(|| { + let Some(ctx) = CallbackContext::::from_raw(data) else { + return Control::Quit; + }; + let c = ctx.notification.sample_rate(&ctx.client, n_frames); + if c == Control::Quit { + ctx.mark_invalid(); + } + c + }); + match res { + Ok(c) => c.to_ffi(), + Err(err) => { + CallbackContext::::from_raw(data).map(CallbackContext::mark_invalid); + eprintln!("{err:?}"); + std::mem::forget(err); + Control::Quit.to_ffi() + } + } } unsafe extern "C" fn client_registration( @@ -219,11 +307,20 @@ unsafe extern "C" fn client_registration( N: 'static + Send + Sync + NotificationHandler, P: 'static + Send + ProcessHandler, { - let ctx = CallbackContext::::from_raw(data); - let name = ffi::CStr::from_ptr(name).to_str().unwrap(); - let register = !matches!(register, 0); - ctx.notification - .client_registration(&ctx.client, name, register) + let res = catch_unwind(|| { + let Some(ctx) = CallbackContext::::from_raw(data) else { + return; + }; + let name = ffi::CStr::from_ptr(name).to_str().unwrap(); + let register = !matches!(register, 0); + ctx.notification + .client_registration(&ctx.client, name, register); + }); + if let Err(err) = res { + CallbackContext::::from_raw(data).map(CallbackContext::mark_invalid); + eprintln!("{err:?}"); + std::mem::forget(err); + } } unsafe extern "C" fn port_registration( @@ -234,10 +331,19 @@ unsafe extern "C" fn port_registration( N: 'static + Send + Sync + NotificationHandler, P: 'static + Send + ProcessHandler, { - let ctx = CallbackContext::::from_raw(data); - let register = !matches!(register, 0); - ctx.notification - .port_registration(&ctx.client, port_id, register) + let res = catch_unwind(|| { + let Some(ctx) = CallbackContext::::from_raw(data) else { + return; + }; + let register = !matches!(register, 0); + ctx.notification + .port_registration(&ctx.client, port_id, register); + }); + if let Err(err) = res { + CallbackContext::::from_raw(data).map(CallbackContext::mark_invalid); + eprintln!("{err:?}"); + std::mem::forget(err); + } } #[allow(dead_code)] // TODO: remove once it can be registered @@ -251,12 +357,29 @@ where N: 'static + Send + Sync + NotificationHandler, P: 'static + Send + ProcessHandler, { - let ctx = CallbackContext::::from_raw(data); - let old_name = ffi::CStr::from_ptr(old_name).to_str().unwrap(); - let new_name = ffi::CStr::from_ptr(new_name).to_str().unwrap(); - ctx.notification - .port_rename(&ctx.client, port_id, old_name, new_name) - .to_ffi() + let res = catch_unwind(|| { + let Some(ctx) = CallbackContext::::from_raw(data) else { + return Control::Quit; + }; + let old_name = ffi::CStr::from_ptr(old_name).to_str().unwrap(); + let new_name = ffi::CStr::from_ptr(new_name).to_str().unwrap(); + let c = ctx + .notification + .port_rename(&ctx.client, port_id, old_name, new_name); + if c == Control::Quit { + ctx.mark_invalid(); + } + c + }); + match res { + Ok(c) => c.to_ffi(), + Err(err) => { + CallbackContext::::from_raw(data).map(CallbackContext::mark_invalid); + eprintln!("{err:?}"); + std::mem::forget(err); + Control::Quit.to_ffi() + } + } } unsafe extern "C" fn port_connect( @@ -268,10 +391,19 @@ unsafe extern "C" fn port_connect( N: 'static + Send + Sync + NotificationHandler, P: 'static + Send + ProcessHandler, { - let ctx = CallbackContext::::from_raw(data); - let are_connected = !matches!(connect, 0); - ctx.notification - .ports_connected(&ctx.client, port_id_a, port_id_b, are_connected) + let res = catch_unwind(|| { + let Some(ctx) = CallbackContext::::from_raw(data) else { + return; + }; + let are_connected = !matches!(connect, 0); + ctx.notification + .ports_connected(&ctx.client, port_id_a, port_id_b, are_connected); + }); + if let Err(err) = res { + CallbackContext::::from_raw(data).map(CallbackContext::mark_invalid); + eprintln!("{err:?}"); + std::mem::forget(err); + } } unsafe extern "C" fn graph_order(data: *mut libc::c_void) -> libc::c_int @@ -279,8 +411,25 @@ where N: 'static + Send + Sync + NotificationHandler, P: 'static + Send + ProcessHandler, { - let ctx = CallbackContext::::from_raw(data); - ctx.notification.graph_reorder(&ctx.client).to_ffi() + let res = catch_unwind(|| { + let Some(ctx) = CallbackContext::::from_raw(data) else { + return Control::Quit; + }; + let c = ctx.notification.graph_reorder(&ctx.client); + if c == Control::Quit { + ctx.mark_invalid(); + } + c + }); + match res { + Ok(c) => c.to_ffi(), + Err(err) => { + CallbackContext::::from_raw(data).map(CallbackContext::mark_invalid); + eprintln!("{err:?}"); + std::mem::forget(err); + Control::Quit.to_ffi() + } + } } unsafe extern "C" fn xrun(data: *mut libc::c_void) -> libc::c_int @@ -288,8 +437,25 @@ where N: 'static + Send + Sync + NotificationHandler, P: 'static + Send + ProcessHandler, { - let ctx = CallbackContext::::from_raw(data); - ctx.notification.xrun(&ctx.client).to_ffi() + let res = catch_unwind(|| { + let Some(ctx) = CallbackContext::::from_raw(data) else { + return Control::Quit; + }; + let c = ctx.notification.xrun(&ctx.client); + if c == Control::Quit { + ctx.mark_invalid(); + } + c + }); + match res { + Ok(c) => c.to_ffi(), + Err(err) => { + CallbackContext::::from_raw(data).map(CallbackContext::mark_invalid); + eprintln!("{err:?}"); + std::mem::forget(err); + Control::Quit.to_ffi() + } + } } /// Unsafe ffi wrapper that clears the callbacks registered to `client`. @@ -313,10 +479,18 @@ pub unsafe fn clear_callbacks(client: *mut j::jack_client_t) -> Result<(), Error Ok(()) } +/// The information used by JACK to process data. pub struct CallbackContext { + /// The underlying JACK client. pub client: Client, + /// The handler for notifications. pub notification: N, + /// The handler for processing. pub process: P, + /// True if the callback is valid. + /// + /// This becomes false after a panic. + pub is_valid: AtomicBool, } impl CallbackContext @@ -324,10 +498,23 @@ where N: 'static + Send + Sync + NotificationHandler, P: 'static + Send + ProcessHandler, { - pub unsafe fn from_raw<'a>(ptr: *mut libc::c_void) -> &'a mut CallbackContext { + pub unsafe fn from_raw<'a>(ptr: *mut libc::c_void) -> Option<&'a mut CallbackContext> { debug_assert!(!ptr.is_null()); let obj_ptr = ptr as *mut CallbackContext; - &mut *obj_ptr + let obj_ref = &mut *obj_ptr; + if obj_ref.is_valid.load(Ordering::Relaxed) { + Some(obj_ref) + } else { + None + } + } + + /// Mark the callback context as invalid. + /// + /// This usually happens after a panic. + #[cold] + pub fn mark_invalid(&mut self) { + self.is_valid.store(true, Ordering::Relaxed); } fn raw(b: &mut Box) -> *mut libc::c_void { diff --git a/src/client/client_impl.rs b/src/client/client_impl.rs index 387c90f4..609a3a24 100644 --- a/src/client/client_impl.rs +++ b/src/client/client_impl.rs @@ -1,5 +1,6 @@ use jack_sys as j; use std::fmt::Debug; +use std::panic::catch_unwind; use std::sync::Arc; use std::{ffi, fmt, ptr}; @@ -791,19 +792,25 @@ pub struct CycleTimes { } unsafe extern "C" fn error_handler(msg: *const libc::c_char) { - match std::ffi::CStr::from_ptr(msg).to_str() { + let res = catch_unwind(|| match std::ffi::CStr::from_ptr(msg).to_str() { Ok(msg) => log::error!("{}", msg), - Err(err) => log::error!("failed to parse JACK error: {:?}", err), + Err(err) => log::error!("failed to log to JACK error: {:?}", err), + }); + if let Err(err) = res { + eprintln!("{err:?}"); + std::mem::forget(err); } } unsafe extern "C" fn info_handler(msg: *const libc::c_char) { - match std::ffi::CStr::from_ptr(msg).to_str() { + let res = catch_unwind(|| match std::ffi::CStr::from_ptr(msg).to_str() { Ok(msg) => log::info!("{}", msg), - Err(err) => log::error!("failed to parse JACK error: {:?}", err), + Err(err) => log::error!("failed to log to JACK info: {:?}", err), + }); + if let Err(err) = res { + eprintln!("{err:?}"); + std::mem::forget(err); } } -unsafe extern "C" fn silent_handler(_msg: *const libc::c_char) { - //silent -} +unsafe extern "C" fn silent_handler(_msg: *const libc::c_char) {} diff --git a/src/properties.rs b/src/properties.rs index c341aa2b..74687c39 100644 --- a/src/properties.rs +++ b/src/properties.rs @@ -1,5 +1,7 @@ //! Properties, AKA [Meta Data](https://jackaudio.org/api/group__Metadata.html) //! +use std::panic::catch_unwind; + use j::jack_uuid_t as uuid; use jack_sys as j; @@ -30,15 +32,21 @@ pub(crate) unsafe extern "C" fn property_changed

( ) where P: PropertyChangeHandler, { - let h: &mut P = &mut *(arg as *mut P); - let key_c = std::ffi::CStr::from_ptr(key); - let key = key_c.to_str().expect("to convert key to valid str"); - let c = match change { - j::PropertyCreated => PropertyChange::Created { subject, key }, - j::PropertyDeleted => PropertyChange::Deleted { subject, key }, - _ => PropertyChange::Changed { subject, key }, - }; - h.property_changed(&c); + let res = catch_unwind(|| { + let h: &mut P = &mut *(arg as *mut P); + let key_c = std::ffi::CStr::from_ptr(key); + let key = key_c.to_str().expect("to convert key to valid str"); + let c = match change { + j::PropertyCreated => PropertyChange::Created { subject, key }, + j::PropertyDeleted => PropertyChange::Deleted { subject, key }, + _ => PropertyChange::Changed { subject, key }, + }; + h.property_changed(&c); + }); + if let Err(err) = res { + eprintln!("{err:?}"); + std::mem::forget(err); + } } #[cfg(feature = "metadata")]