Simplify by removing wrapping at the cost of adding one more indirection
This commit is contained in:
parent
5a24ba0ee5
commit
0b20e6320c
@ -18,11 +18,8 @@ fn f(
|
||||
*ydot = [y[1], -y[0] * k];
|
||||
RhsResult::Ok
|
||||
}
|
||||
// Use the `wrap!` macro to define a `wrapped_f` function callable
|
||||
// from C (of type `c_wrapping::RhsFCtype`) that wraps `f`.
|
||||
wrap!(wrapped_f, f, Realtype, 2);
|
||||
//initialize the solver
|
||||
let mut solver = Solver::new(
|
||||
let mut solver = cvode::Solver::new(
|
||||
LinearMultistepMethod::Adams,
|
||||
wrapped_f,
|
||||
0.,
|
||||
|
@ -1,57 +0,0 @@
|
||||
//! Code used to wrap the RHS function for C.
|
||||
//!
|
||||
//! The most user friendly way to interact with this module is through the
|
||||
//! [`wrap`] macro.
|
||||
|
||||
use std::os::raw::c_int;
|
||||
|
||||
use cvode_5_sys::N_VGetArrayPointer;
|
||||
|
||||
use crate::{NVectorSerial, Realtype, RhsF, RhsResult};
|
||||
|
||||
/// The type of the function pointer for the right hand side that is passed to C.
|
||||
///
|
||||
/// The advised method to declare such a function is to use the [`wrap`] macro.
|
||||
pub type RhsFCtype<UserData, const N: usize> = extern "C" fn(
|
||||
t: Realtype,
|
||||
y: *const NVectorSerial<N>,
|
||||
ydot: *mut NVectorSerial<N>,
|
||||
user_data: *const UserData,
|
||||
) -> c_int;
|
||||
|
||||
/// The wrapping function.
|
||||
///
|
||||
/// Internally used in [`wrap`].
|
||||
pub fn wrap_f<UserData, const N: usize>(
|
||||
f: RhsF<UserData, N>,
|
||||
t: Realtype,
|
||||
y: *const NVectorSerial<N>,
|
||||
ydot: *mut NVectorSerial<N>,
|
||||
data: &UserData,
|
||||
) -> c_int {
|
||||
let y = unsafe { &*(N_VGetArrayPointer(y as _) as *const [f64; N]) };
|
||||
let ydot = unsafe { &mut *(N_VGetArrayPointer(ydot as _) as *mut [f64; N]) };
|
||||
let res = f(t, y, ydot, data);
|
||||
match res {
|
||||
RhsResult::Ok => 0,
|
||||
RhsResult::RecoverableError(e) => e as c_int,
|
||||
RhsResult::NonRecoverableError(e) => -(e as c_int),
|
||||
}
|
||||
}
|
||||
|
||||
/// Declares an `extern "C"` function of type [`RhsFCtype`] that wraps a
|
||||
/// normal Rust `fn` of type [`RhsF`]
|
||||
#[macro_export]
|
||||
macro_rules! wrap {
|
||||
($wrapped_f_name: ident, $f_name: ident, $user_data_type: ty, $problem_size: expr) => {
|
||||
extern "C" fn $wrapped_f_name(
|
||||
t: Realtype,
|
||||
y: *const NVectorSerial<$problem_size>,
|
||||
ydot: *mut NVectorSerial<$problem_size>,
|
||||
data: *const $user_data_type,
|
||||
) -> std::os::raw::c_int {
|
||||
let data = unsafe { &*data };
|
||||
c_wrapping::wrap_f($f_name, t, y, ydot, data)
|
||||
}
|
||||
};
|
||||
}
|
@ -2,7 +2,10 @@ use std::{convert::TryInto, ffi::c_void, os::raw::c_int, pin::Pin, ptr::NonNull}
|
||||
|
||||
use cvode_5_sys::{SUNLinearSolver, SUNMatrix};
|
||||
|
||||
use crate::{LinearMultistepMethod, NVectorSerialHeapAllocated, Realtype, Result, StepKind, c_wrapping, check_flag_is_succes, check_non_null};
|
||||
use crate::{
|
||||
check_flag_is_succes, check_non_null, LinearMultistepMethod, NVectorSerial,
|
||||
NVectorSerialHeapAllocated, Realtype, Result, RhsResult, StepKind, WrappingUserData,
|
||||
};
|
||||
|
||||
#[repr(C)]
|
||||
struct CvodeMemoryBlock {
|
||||
@ -58,20 +61,48 @@ impl<const SIZE: usize> AbsTolerance<SIZE> {
|
||||
/// `N` is the "problem size", that is the dimension of the state space.
|
||||
///
|
||||
/// See [crate-level](`crate`) documentation for more.
|
||||
pub struct Solver<UserData, const N: usize> {
|
||||
pub struct Solver<UserData, F, const N: usize> {
|
||||
mem: CvodeMemoryBlockNonNullPtr,
|
||||
y0: NVectorSerialHeapAllocated<N>,
|
||||
sunmatrix: SUNMatrix,
|
||||
linsolver: SUNLinearSolver,
|
||||
atol: AbsTolerance<N>,
|
||||
user_data: Pin<Box<UserData>>,
|
||||
user_data: Pin<Box<WrappingUserData<UserData, F>>>,
|
||||
}
|
||||
|
||||
/// The wrapping function.
|
||||
///
|
||||
/// Internally used in [`wrap`].
|
||||
extern "C" fn wrap_f<UserData, F, const N: usize>(
|
||||
t: Realtype,
|
||||
y: *const NVectorSerial<N>,
|
||||
ydot: *mut NVectorSerial<N>,
|
||||
data: *const WrappingUserData<UserData, F>,
|
||||
) -> c_int
|
||||
where
|
||||
F: Fn(Realtype, &[Realtype; N], &mut [Realtype; N], &UserData) -> RhsResult,
|
||||
{
|
||||
let y = unsafe { &*y }.as_slice();
|
||||
let ydot = unsafe { &mut *ydot }.as_slice_mut();
|
||||
let WrappingUserData {
|
||||
actual_user_data: data,
|
||||
f,
|
||||
} = unsafe { &*data };
|
||||
let res = f(t, y, ydot, data);
|
||||
match res {
|
||||
RhsResult::Ok => 0,
|
||||
RhsResult::RecoverableError(e) => e as c_int,
|
||||
RhsResult::NonRecoverableError(e) => -(e as c_int),
|
||||
}
|
||||
}
|
||||
|
||||
impl<UserData, const N: usize> Solver<UserData, N> {
|
||||
impl<UserData, F, const N: usize> Solver<UserData, F, N>
|
||||
where
|
||||
F: Fn(Realtype, &[Realtype; N], &mut [Realtype; N], &UserData) -> RhsResult,
|
||||
{
|
||||
pub fn new(
|
||||
method: LinearMultistepMethod,
|
||||
f: c_wrapping::RhsFCtype<UserData, N>,
|
||||
f: F,
|
||||
t0: Realtype,
|
||||
y0: &[Realtype; N],
|
||||
rtol: Realtype,
|
||||
@ -94,7 +125,10 @@ impl<UserData, const N: usize> Solver<UserData, N> {
|
||||
let linsolver = unsafe { cvode_5_sys::SUNLinSol_Dense(y0.as_raw(), matrix.as_ptr()) };
|
||||
check_non_null(linsolver, "SUNDenseLinearSolver")?
|
||||
};
|
||||
let user_data = Box::pin(user_data);
|
||||
let user_data = Box::pin(WrappingUserData {
|
||||
actual_user_data: user_data,
|
||||
f,
|
||||
});
|
||||
let res = Solver {
|
||||
mem,
|
||||
y0,
|
||||
@ -104,10 +138,11 @@ impl<UserData, const N: usize> Solver<UserData, N> {
|
||||
user_data,
|
||||
};
|
||||
{
|
||||
let fn_ptr = wrap_f::<UserData, F, N> as extern "C" fn(_, _, _, _) -> _;
|
||||
let flag = unsafe {
|
||||
cvode_5_sys::CVodeInit(
|
||||
mem.as_raw(),
|
||||
Some(std::mem::transmute(f)),
|
||||
Some(std::mem::transmute(fn_ptr)),
|
||||
t0,
|
||||
res.y0.as_raw(),
|
||||
)
|
||||
@ -163,7 +198,7 @@ impl<UserData, const N: usize> Solver<UserData, N> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<UserData, const N: usize> Drop for Solver<UserData, N> {
|
||||
impl<UserData, F, const N: usize> Drop for Solver<UserData, F, N> {
|
||||
fn drop(&mut self) {
|
||||
unsafe { cvode_5_sys::CVodeFree(&mut self.mem.as_raw()) }
|
||||
unsafe { cvode_5_sys::SUNLinSolFree(self.linsolver) };
|
||||
@ -173,7 +208,7 @@ impl<UserData, const N: usize> Drop for Solver<UserData, N> {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::{RhsResult,wrap, NVectorSerial};
|
||||
use crate::RhsResult;
|
||||
|
||||
use super::*;
|
||||
|
||||
@ -187,14 +222,12 @@ mod tests {
|
||||
RhsResult::Ok
|
||||
}
|
||||
|
||||
wrap!(wrapped_f, f, (), 2);
|
||||
|
||||
#[test]
|
||||
fn create() {
|
||||
let y0 = [0., 1.];
|
||||
let _solver = Solver::new(
|
||||
LinearMultistepMethod::Adams,
|
||||
wrapped_f,
|
||||
f,
|
||||
0.,
|
||||
&y0,
|
||||
1e-4,
|
||||
|
@ -5,8 +5,6 @@ use cvode_5_sys::realtype;
|
||||
mod nvector;
|
||||
pub use nvector::{NVectorSerial, NVectorSerialHeapAllocated};
|
||||
|
||||
pub mod c_wrapping;
|
||||
|
||||
pub mod cvode;
|
||||
|
||||
/// The floatting-point type sundials was compiled with
|
||||
@ -43,16 +41,6 @@ pub enum RhsResult {
|
||||
NonRecoverableError(u8),
|
||||
}
|
||||
|
||||
/// The type of the "rust" Rhs function that can be then wrapped with [`wrap`].
|
||||
///
|
||||
/// # Type arguments
|
||||
/// - `UserData` is any stuct representing "parameters" of the system, that is data
|
||||
/// that doesn't change during the evolution of the state, but is needed to compute
|
||||
/// the right-hand side.
|
||||
/// - `N` is the dimension of the system
|
||||
pub type RhsF<UserData, const N: usize> =
|
||||
fn(t: Realtype, y: &[Realtype; N], ydot: &mut [Realtype; N], user_data: &UserData) -> RhsResult;
|
||||
|
||||
/// Type of integration step
|
||||
#[repr(u32)]
|
||||
pub enum StepKind {
|
||||
@ -67,6 +55,11 @@ pub enum StepKind {
|
||||
OneStep = cvode_5_sys::CV_ONE_STEP,
|
||||
}
|
||||
|
||||
struct WrappingUserData<UserData, F> {
|
||||
actual_user_data: UserData,
|
||||
f: F,
|
||||
}
|
||||
|
||||
/// The error type for this crate
|
||||
#[derive(Debug)]
|
||||
pub enum Error {
|
||||
|
@ -7,11 +7,10 @@ fn main() {
|
||||
*ydot = [y[1], -y[0] * k];
|
||||
RhsResult::Ok
|
||||
}
|
||||
wrap!(wrapped_f, f, Realtype, 2);
|
||||
//initialize the solver
|
||||
let mut solver = cvode::Solver::new(
|
||||
LinearMultistepMethod::Adams,
|
||||
wrapped_f,
|
||||
f,
|
||||
0.,
|
||||
&y0,
|
||||
1e-4,
|
||||
|
Loading…
Reference in New Issue
Block a user