diff --git a/Readme.md b/Readme.md index 6c0be14..9b67df0 100644 --- a/Readme.md +++ b/Readme.md @@ -18,11 +18,8 @@ fn f( *ydot = [y[1], -y[0] * k]; RhsResult::Ok } -// Use the `wrap!` macro to define a `wrapped_f` function callable -// from C (of type `c_wrapping::RhsFCtype`) that wraps `f`. -wrap!(wrapped_f, f, Realtype, 2); //initialize the solver -let mut solver = Solver::new( +let mut solver = cvode::Solver::new( LinearMultistepMethod::Adams, wrapped_f, 0., diff --git a/cvode-wrap/src/c_wrapping.rs b/cvode-wrap/src/c_wrapping.rs deleted file mode 100644 index 681e9cb..0000000 --- a/cvode-wrap/src/c_wrapping.rs +++ /dev/null @@ -1,57 +0,0 @@ -//! Code used to wrap the RHS function for C. -//! -//! The most user friendly way to interact with this module is through the -//! [`wrap`] macro. - -use std::os::raw::c_int; - -use cvode_5_sys::N_VGetArrayPointer; - -use crate::{NVectorSerial, Realtype, RhsF, RhsResult}; - -/// The type of the function pointer for the right hand side that is passed to C. -/// -/// The advised method to declare such a function is to use the [`wrap`] macro. -pub type RhsFCtype = extern "C" fn( - t: Realtype, - y: *const NVectorSerial, - ydot: *mut NVectorSerial, - user_data: *const UserData, -) -> c_int; - -/// The wrapping function. -/// -/// Internally used in [`wrap`]. -pub fn wrap_f( - f: RhsF, - t: Realtype, - y: *const NVectorSerial, - ydot: *mut NVectorSerial, - data: &UserData, -) -> c_int { - let y = unsafe { &*(N_VGetArrayPointer(y as _) as *const [f64; N]) }; - let ydot = unsafe { &mut *(N_VGetArrayPointer(ydot as _) as *mut [f64; N]) }; - let res = f(t, y, ydot, data); - match res { - RhsResult::Ok => 0, - RhsResult::RecoverableError(e) => e as c_int, - RhsResult::NonRecoverableError(e) => -(e as c_int), - } -} - -/// Declares an `extern "C"` function of type [`RhsFCtype`] that wraps a -/// normal Rust `fn` of type [`RhsF`] -#[macro_export] -macro_rules! wrap { - ($wrapped_f_name: ident, $f_name: ident, $user_data_type: ty, $problem_size: expr) => { - extern "C" fn $wrapped_f_name( - t: Realtype, - y: *const NVectorSerial<$problem_size>, - ydot: *mut NVectorSerial<$problem_size>, - data: *const $user_data_type, - ) -> std::os::raw::c_int { - let data = unsafe { &*data }; - c_wrapping::wrap_f($f_name, t, y, ydot, data) - } - }; -} diff --git a/cvode-wrap/src/cvode.rs b/cvode-wrap/src/cvode.rs index e226ba0..de90330 100644 --- a/cvode-wrap/src/cvode.rs +++ b/cvode-wrap/src/cvode.rs @@ -2,7 +2,10 @@ use std::{convert::TryInto, ffi::c_void, os::raw::c_int, pin::Pin, ptr::NonNull} use cvode_5_sys::{SUNLinearSolver, SUNMatrix}; -use crate::{LinearMultistepMethod, NVectorSerialHeapAllocated, Realtype, Result, StepKind, c_wrapping, check_flag_is_succes, check_non_null}; +use crate::{ + check_flag_is_succes, check_non_null, LinearMultistepMethod, NVectorSerial, + NVectorSerialHeapAllocated, Realtype, Result, RhsResult, StepKind, WrappingUserData, +}; #[repr(C)] struct CvodeMemoryBlock { @@ -58,20 +61,48 @@ impl AbsTolerance { /// `N` is the "problem size", that is the dimension of the state space. /// /// See [crate-level](`crate`) documentation for more. -pub struct Solver { +pub struct Solver { mem: CvodeMemoryBlockNonNullPtr, y0: NVectorSerialHeapAllocated, sunmatrix: SUNMatrix, linsolver: SUNLinearSolver, atol: AbsTolerance, - user_data: Pin>, + user_data: Pin>>, } +/// The wrapping function. +/// +/// Internally used in [`wrap`]. +extern "C" fn wrap_f( + t: Realtype, + y: *const NVectorSerial, + ydot: *mut NVectorSerial, + data: *const WrappingUserData, +) -> c_int +where + F: Fn(Realtype, &[Realtype; N], &mut [Realtype; N], &UserData) -> RhsResult, +{ + let y = unsafe { &*y }.as_slice(); + let ydot = unsafe { &mut *ydot }.as_slice_mut(); + let WrappingUserData { + actual_user_data: data, + f, + } = unsafe { &*data }; + let res = f(t, y, ydot, data); + match res { + RhsResult::Ok => 0, + RhsResult::RecoverableError(e) => e as c_int, + RhsResult::NonRecoverableError(e) => -(e as c_int), + } +} -impl Solver { +impl Solver +where + F: Fn(Realtype, &[Realtype; N], &mut [Realtype; N], &UserData) -> RhsResult, +{ pub fn new( method: LinearMultistepMethod, - f: c_wrapping::RhsFCtype, + f: F, t0: Realtype, y0: &[Realtype; N], rtol: Realtype, @@ -94,7 +125,10 @@ impl Solver { let linsolver = unsafe { cvode_5_sys::SUNLinSol_Dense(y0.as_raw(), matrix.as_ptr()) }; check_non_null(linsolver, "SUNDenseLinearSolver")? }; - let user_data = Box::pin(user_data); + let user_data = Box::pin(WrappingUserData { + actual_user_data: user_data, + f, + }); let res = Solver { mem, y0, @@ -104,10 +138,11 @@ impl Solver { user_data, }; { + let fn_ptr = wrap_f:: as extern "C" fn(_, _, _, _) -> _; let flag = unsafe { cvode_5_sys::CVodeInit( mem.as_raw(), - Some(std::mem::transmute(f)), + Some(std::mem::transmute(fn_ptr)), t0, res.y0.as_raw(), ) @@ -163,7 +198,7 @@ impl Solver { } } -impl Drop for Solver { +impl Drop for Solver { fn drop(&mut self) { unsafe { cvode_5_sys::CVodeFree(&mut self.mem.as_raw()) } unsafe { cvode_5_sys::SUNLinSolFree(self.linsolver) }; @@ -173,7 +208,7 @@ impl Drop for Solver { #[cfg(test)] mod tests { - use crate::{RhsResult,wrap, NVectorSerial}; + use crate::RhsResult; use super::*; @@ -187,14 +222,12 @@ mod tests { RhsResult::Ok } - wrap!(wrapped_f, f, (), 2); - #[test] fn create() { let y0 = [0., 1.]; let _solver = Solver::new( LinearMultistepMethod::Adams, - wrapped_f, + f, 0., &y0, 1e-4, @@ -202,4 +235,4 @@ mod tests { (), ); } -} \ No newline at end of file +} diff --git a/cvode-wrap/src/lib.rs b/cvode-wrap/src/lib.rs index c01d26d..b9515f2 100644 --- a/cvode-wrap/src/lib.rs +++ b/cvode-wrap/src/lib.rs @@ -5,8 +5,6 @@ use cvode_5_sys::realtype; mod nvector; pub use nvector::{NVectorSerial, NVectorSerialHeapAllocated}; -pub mod c_wrapping; - pub mod cvode; /// The floatting-point type sundials was compiled with @@ -43,16 +41,6 @@ pub enum RhsResult { NonRecoverableError(u8), } -/// The type of the "rust" Rhs function that can be then wrapped with [`wrap`]. -/// -/// # Type arguments -/// - `UserData` is any stuct representing "parameters" of the system, that is data -/// that doesn't change during the evolution of the state, but is needed to compute -/// the right-hand side. -/// - `N` is the dimension of the system -pub type RhsF = - fn(t: Realtype, y: &[Realtype; N], ydot: &mut [Realtype; N], user_data: &UserData) -> RhsResult; - /// Type of integration step #[repr(u32)] pub enum StepKind { @@ -67,6 +55,11 @@ pub enum StepKind { OneStep = cvode_5_sys::CV_ONE_STEP, } +struct WrappingUserData { + actual_user_data: UserData, + f: F, +} + /// The error type for this crate #[derive(Debug)] pub enum Error { diff --git a/example/src/main.rs b/example/src/main.rs index 94e7836..d611d01 100644 --- a/example/src/main.rs +++ b/example/src/main.rs @@ -7,11 +7,10 @@ fn main() { *ydot = [y[1], -y[0] * k]; RhsResult::Ok } - wrap!(wrapped_f, f, Realtype, 2); //initialize the solver let mut solver = cvode::Solver::new( LinearMultistepMethod::Adams, - wrapped_f, + f, 0., &y0, 1e-4,