Simplify by removing wrapping at the cost of adding one more indirection

This commit is contained in:
Arthur Carcano 2021-06-09 16:06:15 +02:00
parent 5a24ba0ee5
commit 0b20e6320c
5 changed files with 53 additions and 88 deletions

View File

@ -18,11 +18,8 @@ fn f(
*ydot = [y[1], -y[0] * k];
RhsResult::Ok
}
// Use the `wrap!` macro to define a `wrapped_f` function callable
// from C (of type `c_wrapping::RhsFCtype`) that wraps `f`.
wrap!(wrapped_f, f, Realtype, 2);
//initialize the solver
let mut solver = Solver::new(
let mut solver = cvode::Solver::new(
LinearMultistepMethod::Adams,
wrapped_f,
0.,

View File

@ -1,57 +0,0 @@
//! Code used to wrap the RHS function for C.
//!
//! The most user friendly way to interact with this module is through the
//! [`wrap`] macro.
use std::os::raw::c_int;
use cvode_5_sys::N_VGetArrayPointer;
use crate::{NVectorSerial, Realtype, RhsF, RhsResult};
/// The type of the function pointer for the right hand side that is passed to C.
///
/// The advised method to declare such a function is to use the [`wrap`] macro.
pub type RhsFCtype<UserData, const N: usize> = extern "C" fn(
t: Realtype,
y: *const NVectorSerial<N>,
ydot: *mut NVectorSerial<N>,
user_data: *const UserData,
) -> c_int;
/// The wrapping function.
///
/// Internally used in [`wrap`].
pub fn wrap_f<UserData, const N: usize>(
f: RhsF<UserData, N>,
t: Realtype,
y: *const NVectorSerial<N>,
ydot: *mut NVectorSerial<N>,
data: &UserData,
) -> c_int {
let y = unsafe { &*(N_VGetArrayPointer(y as _) as *const [f64; N]) };
let ydot = unsafe { &mut *(N_VGetArrayPointer(ydot as _) as *mut [f64; N]) };
let res = f(t, y, ydot, data);
match res {
RhsResult::Ok => 0,
RhsResult::RecoverableError(e) => e as c_int,
RhsResult::NonRecoverableError(e) => -(e as c_int),
}
}
/// Declares an `extern "C"` function of type [`RhsFCtype`] that wraps a
/// normal Rust `fn` of type [`RhsF`]
#[macro_export]
macro_rules! wrap {
($wrapped_f_name: ident, $f_name: ident, $user_data_type: ty, $problem_size: expr) => {
extern "C" fn $wrapped_f_name(
t: Realtype,
y: *const NVectorSerial<$problem_size>,
ydot: *mut NVectorSerial<$problem_size>,
data: *const $user_data_type,
) -> std::os::raw::c_int {
let data = unsafe { &*data };
c_wrapping::wrap_f($f_name, t, y, ydot, data)
}
};
}

View File

@ -2,7 +2,10 @@ use std::{convert::TryInto, ffi::c_void, os::raw::c_int, pin::Pin, ptr::NonNull}
use cvode_5_sys::{SUNLinearSolver, SUNMatrix};
use crate::{LinearMultistepMethod, NVectorSerialHeapAllocated, Realtype, Result, StepKind, c_wrapping, check_flag_is_succes, check_non_null};
use crate::{
check_flag_is_succes, check_non_null, LinearMultistepMethod, NVectorSerial,
NVectorSerialHeapAllocated, Realtype, Result, RhsResult, StepKind, WrappingUserData,
};
#[repr(C)]
struct CvodeMemoryBlock {
@ -58,20 +61,48 @@ impl<const SIZE: usize> AbsTolerance<SIZE> {
/// `N` is the "problem size", that is the dimension of the state space.
///
/// See [crate-level](`crate`) documentation for more.
pub struct Solver<UserData, const N: usize> {
pub struct Solver<UserData, F, const N: usize> {
mem: CvodeMemoryBlockNonNullPtr,
y0: NVectorSerialHeapAllocated<N>,
sunmatrix: SUNMatrix,
linsolver: SUNLinearSolver,
atol: AbsTolerance<N>,
user_data: Pin<Box<UserData>>,
user_data: Pin<Box<WrappingUserData<UserData, F>>>,
}
/// The wrapping function.
///
/// Internally used in [`wrap`].
extern "C" fn wrap_f<UserData, F, const N: usize>(
t: Realtype,
y: *const NVectorSerial<N>,
ydot: *mut NVectorSerial<N>,
data: *const WrappingUserData<UserData, F>,
) -> c_int
where
F: Fn(Realtype, &[Realtype; N], &mut [Realtype; N], &UserData) -> RhsResult,
{
let y = unsafe { &*y }.as_slice();
let ydot = unsafe { &mut *ydot }.as_slice_mut();
let WrappingUserData {
actual_user_data: data,
f,
} = unsafe { &*data };
let res = f(t, y, ydot, data);
match res {
RhsResult::Ok => 0,
RhsResult::RecoverableError(e) => e as c_int,
RhsResult::NonRecoverableError(e) => -(e as c_int),
}
}
impl<UserData, const N: usize> Solver<UserData, N> {
impl<UserData, F, const N: usize> Solver<UserData, F, N>
where
F: Fn(Realtype, &[Realtype; N], &mut [Realtype; N], &UserData) -> RhsResult,
{
pub fn new(
method: LinearMultistepMethod,
f: c_wrapping::RhsFCtype<UserData, N>,
f: F,
t0: Realtype,
y0: &[Realtype; N],
rtol: Realtype,
@ -94,7 +125,10 @@ impl<UserData, const N: usize> Solver<UserData, N> {
let linsolver = unsafe { cvode_5_sys::SUNLinSol_Dense(y0.as_raw(), matrix.as_ptr()) };
check_non_null(linsolver, "SUNDenseLinearSolver")?
};
let user_data = Box::pin(user_data);
let user_data = Box::pin(WrappingUserData {
actual_user_data: user_data,
f,
});
let res = Solver {
mem,
y0,
@ -104,10 +138,11 @@ impl<UserData, const N: usize> Solver<UserData, N> {
user_data,
};
{
let fn_ptr = wrap_f::<UserData, F, N> as extern "C" fn(_, _, _, _) -> _;
let flag = unsafe {
cvode_5_sys::CVodeInit(
mem.as_raw(),
Some(std::mem::transmute(f)),
Some(std::mem::transmute(fn_ptr)),
t0,
res.y0.as_raw(),
)
@ -163,7 +198,7 @@ impl<UserData, const N: usize> Solver<UserData, N> {
}
}
impl<UserData, const N: usize> Drop for Solver<UserData, N> {
impl<UserData, F, const N: usize> Drop for Solver<UserData, F, N> {
fn drop(&mut self) {
unsafe { cvode_5_sys::CVodeFree(&mut self.mem.as_raw()) }
unsafe { cvode_5_sys::SUNLinSolFree(self.linsolver) };
@ -173,7 +208,7 @@ impl<UserData, const N: usize> Drop for Solver<UserData, N> {
#[cfg(test)]
mod tests {
use crate::{RhsResult,wrap, NVectorSerial};
use crate::RhsResult;
use super::*;
@ -187,14 +222,12 @@ mod tests {
RhsResult::Ok
}
wrap!(wrapped_f, f, (), 2);
#[test]
fn create() {
let y0 = [0., 1.];
let _solver = Solver::new(
LinearMultistepMethod::Adams,
wrapped_f,
f,
0.,
&y0,
1e-4,
@ -202,4 +235,4 @@ mod tests {
(),
);
}
}
}

View File

@ -5,8 +5,6 @@ use cvode_5_sys::realtype;
mod nvector;
pub use nvector::{NVectorSerial, NVectorSerialHeapAllocated};
pub mod c_wrapping;
pub mod cvode;
/// The floatting-point type sundials was compiled with
@ -43,16 +41,6 @@ pub enum RhsResult {
NonRecoverableError(u8),
}
/// The type of the "rust" Rhs function that can be then wrapped with [`wrap`].
///
/// # Type arguments
/// - `UserData` is any stuct representing "parameters" of the system, that is data
/// that doesn't change during the evolution of the state, but is needed to compute
/// the right-hand side.
/// - `N` is the dimension of the system
pub type RhsF<UserData, const N: usize> =
fn(t: Realtype, y: &[Realtype; N], ydot: &mut [Realtype; N], user_data: &UserData) -> RhsResult;
/// Type of integration step
#[repr(u32)]
pub enum StepKind {
@ -67,6 +55,11 @@ pub enum StepKind {
OneStep = cvode_5_sys::CV_ONE_STEP,
}
struct WrappingUserData<UserData, F> {
actual_user_data: UserData,
f: F,
}
/// The error type for this crate
#[derive(Debug)]
pub enum Error {

View File

@ -7,11 +7,10 @@ fn main() {
*ydot = [y[1], -y[0] * k];
RhsResult::Ok
}
wrap!(wrapped_f, f, Realtype, 2);
//initialize the solver
let mut solver = cvode::Solver::new(
LinearMultistepMethod::Adams,
wrapped_f,
f,
0.,
&y0,
1e-4,