Simplify by removing wrapping at the cost of adding one more indirection
This commit is contained in:
parent
5a24ba0ee5
commit
0b20e6320c
@ -18,11 +18,8 @@ fn f(
|
|||||||
*ydot = [y[1], -y[0] * k];
|
*ydot = [y[1], -y[0] * k];
|
||||||
RhsResult::Ok
|
RhsResult::Ok
|
||||||
}
|
}
|
||||||
// Use the `wrap!` macro to define a `wrapped_f` function callable
|
|
||||||
// from C (of type `c_wrapping::RhsFCtype`) that wraps `f`.
|
|
||||||
wrap!(wrapped_f, f, Realtype, 2);
|
|
||||||
//initialize the solver
|
//initialize the solver
|
||||||
let mut solver = Solver::new(
|
let mut solver = cvode::Solver::new(
|
||||||
LinearMultistepMethod::Adams,
|
LinearMultistepMethod::Adams,
|
||||||
wrapped_f,
|
wrapped_f,
|
||||||
0.,
|
0.,
|
||||||
|
@ -1,57 +0,0 @@
|
|||||||
//! Code used to wrap the RHS function for C.
|
|
||||||
//!
|
|
||||||
//! The most user friendly way to interact with this module is through the
|
|
||||||
//! [`wrap`] macro.
|
|
||||||
|
|
||||||
use std::os::raw::c_int;
|
|
||||||
|
|
||||||
use cvode_5_sys::N_VGetArrayPointer;
|
|
||||||
|
|
||||||
use crate::{NVectorSerial, Realtype, RhsF, RhsResult};
|
|
||||||
|
|
||||||
/// The type of the function pointer for the right hand side that is passed to C.
|
|
||||||
///
|
|
||||||
/// The advised method to declare such a function is to use the [`wrap`] macro.
|
|
||||||
pub type RhsFCtype<UserData, const N: usize> = extern "C" fn(
|
|
||||||
t: Realtype,
|
|
||||||
y: *const NVectorSerial<N>,
|
|
||||||
ydot: *mut NVectorSerial<N>,
|
|
||||||
user_data: *const UserData,
|
|
||||||
) -> c_int;
|
|
||||||
|
|
||||||
/// The wrapping function.
|
|
||||||
///
|
|
||||||
/// Internally used in [`wrap`].
|
|
||||||
pub fn wrap_f<UserData, const N: usize>(
|
|
||||||
f: RhsF<UserData, N>,
|
|
||||||
t: Realtype,
|
|
||||||
y: *const NVectorSerial<N>,
|
|
||||||
ydot: *mut NVectorSerial<N>,
|
|
||||||
data: &UserData,
|
|
||||||
) -> c_int {
|
|
||||||
let y = unsafe { &*(N_VGetArrayPointer(y as _) as *const [f64; N]) };
|
|
||||||
let ydot = unsafe { &mut *(N_VGetArrayPointer(ydot as _) as *mut [f64; N]) };
|
|
||||||
let res = f(t, y, ydot, data);
|
|
||||||
match res {
|
|
||||||
RhsResult::Ok => 0,
|
|
||||||
RhsResult::RecoverableError(e) => e as c_int,
|
|
||||||
RhsResult::NonRecoverableError(e) => -(e as c_int),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Declares an `extern "C"` function of type [`RhsFCtype`] that wraps a
|
|
||||||
/// normal Rust `fn` of type [`RhsF`]
|
|
||||||
#[macro_export]
|
|
||||||
macro_rules! wrap {
|
|
||||||
($wrapped_f_name: ident, $f_name: ident, $user_data_type: ty, $problem_size: expr) => {
|
|
||||||
extern "C" fn $wrapped_f_name(
|
|
||||||
t: Realtype,
|
|
||||||
y: *const NVectorSerial<$problem_size>,
|
|
||||||
ydot: *mut NVectorSerial<$problem_size>,
|
|
||||||
data: *const $user_data_type,
|
|
||||||
) -> std::os::raw::c_int {
|
|
||||||
let data = unsafe { &*data };
|
|
||||||
c_wrapping::wrap_f($f_name, t, y, ydot, data)
|
|
||||||
}
|
|
||||||
};
|
|
||||||
}
|
|
@ -2,7 +2,10 @@ use std::{convert::TryInto, ffi::c_void, os::raw::c_int, pin::Pin, ptr::NonNull}
|
|||||||
|
|
||||||
use cvode_5_sys::{SUNLinearSolver, SUNMatrix};
|
use cvode_5_sys::{SUNLinearSolver, SUNMatrix};
|
||||||
|
|
||||||
use crate::{LinearMultistepMethod, NVectorSerialHeapAllocated, Realtype, Result, StepKind, c_wrapping, check_flag_is_succes, check_non_null};
|
use crate::{
|
||||||
|
check_flag_is_succes, check_non_null, LinearMultistepMethod, NVectorSerial,
|
||||||
|
NVectorSerialHeapAllocated, Realtype, Result, RhsResult, StepKind, WrappingUserData,
|
||||||
|
};
|
||||||
|
|
||||||
#[repr(C)]
|
#[repr(C)]
|
||||||
struct CvodeMemoryBlock {
|
struct CvodeMemoryBlock {
|
||||||
@ -58,20 +61,48 @@ impl<const SIZE: usize> AbsTolerance<SIZE> {
|
|||||||
/// `N` is the "problem size", that is the dimension of the state space.
|
/// `N` is the "problem size", that is the dimension of the state space.
|
||||||
///
|
///
|
||||||
/// See [crate-level](`crate`) documentation for more.
|
/// See [crate-level](`crate`) documentation for more.
|
||||||
pub struct Solver<UserData, const N: usize> {
|
pub struct Solver<UserData, F, const N: usize> {
|
||||||
mem: CvodeMemoryBlockNonNullPtr,
|
mem: CvodeMemoryBlockNonNullPtr,
|
||||||
y0: NVectorSerialHeapAllocated<N>,
|
y0: NVectorSerialHeapAllocated<N>,
|
||||||
sunmatrix: SUNMatrix,
|
sunmatrix: SUNMatrix,
|
||||||
linsolver: SUNLinearSolver,
|
linsolver: SUNLinearSolver,
|
||||||
atol: AbsTolerance<N>,
|
atol: AbsTolerance<N>,
|
||||||
user_data: Pin<Box<UserData>>,
|
user_data: Pin<Box<WrappingUserData<UserData, F>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// The wrapping function.
|
||||||
|
///
|
||||||
|
/// Internally used in [`wrap`].
|
||||||
|
extern "C" fn wrap_f<UserData, F, const N: usize>(
|
||||||
|
t: Realtype,
|
||||||
|
y: *const NVectorSerial<N>,
|
||||||
|
ydot: *mut NVectorSerial<N>,
|
||||||
|
data: *const WrappingUserData<UserData, F>,
|
||||||
|
) -> c_int
|
||||||
|
where
|
||||||
|
F: Fn(Realtype, &[Realtype; N], &mut [Realtype; N], &UserData) -> RhsResult,
|
||||||
|
{
|
||||||
|
let y = unsafe { &*y }.as_slice();
|
||||||
|
let ydot = unsafe { &mut *ydot }.as_slice_mut();
|
||||||
|
let WrappingUserData {
|
||||||
|
actual_user_data: data,
|
||||||
|
f,
|
||||||
|
} = unsafe { &*data };
|
||||||
|
let res = f(t, y, ydot, data);
|
||||||
|
match res {
|
||||||
|
RhsResult::Ok => 0,
|
||||||
|
RhsResult::RecoverableError(e) => e as c_int,
|
||||||
|
RhsResult::NonRecoverableError(e) => -(e as c_int),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl<UserData, const N: usize> Solver<UserData, N> {
|
impl<UserData, F, const N: usize> Solver<UserData, F, N>
|
||||||
|
where
|
||||||
|
F: Fn(Realtype, &[Realtype; N], &mut [Realtype; N], &UserData) -> RhsResult,
|
||||||
|
{
|
||||||
pub fn new(
|
pub fn new(
|
||||||
method: LinearMultistepMethod,
|
method: LinearMultistepMethod,
|
||||||
f: c_wrapping::RhsFCtype<UserData, N>,
|
f: F,
|
||||||
t0: Realtype,
|
t0: Realtype,
|
||||||
y0: &[Realtype; N],
|
y0: &[Realtype; N],
|
||||||
rtol: Realtype,
|
rtol: Realtype,
|
||||||
@ -94,7 +125,10 @@ impl<UserData, const N: usize> Solver<UserData, N> {
|
|||||||
let linsolver = unsafe { cvode_5_sys::SUNLinSol_Dense(y0.as_raw(), matrix.as_ptr()) };
|
let linsolver = unsafe { cvode_5_sys::SUNLinSol_Dense(y0.as_raw(), matrix.as_ptr()) };
|
||||||
check_non_null(linsolver, "SUNDenseLinearSolver")?
|
check_non_null(linsolver, "SUNDenseLinearSolver")?
|
||||||
};
|
};
|
||||||
let user_data = Box::pin(user_data);
|
let user_data = Box::pin(WrappingUserData {
|
||||||
|
actual_user_data: user_data,
|
||||||
|
f,
|
||||||
|
});
|
||||||
let res = Solver {
|
let res = Solver {
|
||||||
mem,
|
mem,
|
||||||
y0,
|
y0,
|
||||||
@ -104,10 +138,11 @@ impl<UserData, const N: usize> Solver<UserData, N> {
|
|||||||
user_data,
|
user_data,
|
||||||
};
|
};
|
||||||
{
|
{
|
||||||
|
let fn_ptr = wrap_f::<UserData, F, N> as extern "C" fn(_, _, _, _) -> _;
|
||||||
let flag = unsafe {
|
let flag = unsafe {
|
||||||
cvode_5_sys::CVodeInit(
|
cvode_5_sys::CVodeInit(
|
||||||
mem.as_raw(),
|
mem.as_raw(),
|
||||||
Some(std::mem::transmute(f)),
|
Some(std::mem::transmute(fn_ptr)),
|
||||||
t0,
|
t0,
|
||||||
res.y0.as_raw(),
|
res.y0.as_raw(),
|
||||||
)
|
)
|
||||||
@ -163,7 +198,7 @@ impl<UserData, const N: usize> Solver<UserData, N> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<UserData, const N: usize> Drop for Solver<UserData, N> {
|
impl<UserData, F, const N: usize> Drop for Solver<UserData, F, N> {
|
||||||
fn drop(&mut self) {
|
fn drop(&mut self) {
|
||||||
unsafe { cvode_5_sys::CVodeFree(&mut self.mem.as_raw()) }
|
unsafe { cvode_5_sys::CVodeFree(&mut self.mem.as_raw()) }
|
||||||
unsafe { cvode_5_sys::SUNLinSolFree(self.linsolver) };
|
unsafe { cvode_5_sys::SUNLinSolFree(self.linsolver) };
|
||||||
@ -173,7 +208,7 @@ impl<UserData, const N: usize> Drop for Solver<UserData, N> {
|
|||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use crate::{RhsResult,wrap, NVectorSerial};
|
use crate::RhsResult;
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
@ -187,14 +222,12 @@ mod tests {
|
|||||||
RhsResult::Ok
|
RhsResult::Ok
|
||||||
}
|
}
|
||||||
|
|
||||||
wrap!(wrapped_f, f, (), 2);
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn create() {
|
fn create() {
|
||||||
let y0 = [0., 1.];
|
let y0 = [0., 1.];
|
||||||
let _solver = Solver::new(
|
let _solver = Solver::new(
|
||||||
LinearMultistepMethod::Adams,
|
LinearMultistepMethod::Adams,
|
||||||
wrapped_f,
|
f,
|
||||||
0.,
|
0.,
|
||||||
&y0,
|
&y0,
|
||||||
1e-4,
|
1e-4,
|
||||||
|
@ -5,8 +5,6 @@ use cvode_5_sys::realtype;
|
|||||||
mod nvector;
|
mod nvector;
|
||||||
pub use nvector::{NVectorSerial, NVectorSerialHeapAllocated};
|
pub use nvector::{NVectorSerial, NVectorSerialHeapAllocated};
|
||||||
|
|
||||||
pub mod c_wrapping;
|
|
||||||
|
|
||||||
pub mod cvode;
|
pub mod cvode;
|
||||||
|
|
||||||
/// The floatting-point type sundials was compiled with
|
/// The floatting-point type sundials was compiled with
|
||||||
@ -43,16 +41,6 @@ pub enum RhsResult {
|
|||||||
NonRecoverableError(u8),
|
NonRecoverableError(u8),
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The type of the "rust" Rhs function that can be then wrapped with [`wrap`].
|
|
||||||
///
|
|
||||||
/// # Type arguments
|
|
||||||
/// - `UserData` is any stuct representing "parameters" of the system, that is data
|
|
||||||
/// that doesn't change during the evolution of the state, but is needed to compute
|
|
||||||
/// the right-hand side.
|
|
||||||
/// - `N` is the dimension of the system
|
|
||||||
pub type RhsF<UserData, const N: usize> =
|
|
||||||
fn(t: Realtype, y: &[Realtype; N], ydot: &mut [Realtype; N], user_data: &UserData) -> RhsResult;
|
|
||||||
|
|
||||||
/// Type of integration step
|
/// Type of integration step
|
||||||
#[repr(u32)]
|
#[repr(u32)]
|
||||||
pub enum StepKind {
|
pub enum StepKind {
|
||||||
@ -67,6 +55,11 @@ pub enum StepKind {
|
|||||||
OneStep = cvode_5_sys::CV_ONE_STEP,
|
OneStep = cvode_5_sys::CV_ONE_STEP,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct WrappingUserData<UserData, F> {
|
||||||
|
actual_user_data: UserData,
|
||||||
|
f: F,
|
||||||
|
}
|
||||||
|
|
||||||
/// The error type for this crate
|
/// The error type for this crate
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub enum Error {
|
pub enum Error {
|
||||||
|
@ -7,11 +7,10 @@ fn main() {
|
|||||||
*ydot = [y[1], -y[0] * k];
|
*ydot = [y[1], -y[0] * k];
|
||||||
RhsResult::Ok
|
RhsResult::Ok
|
||||||
}
|
}
|
||||||
wrap!(wrapped_f, f, Realtype, 2);
|
|
||||||
//initialize the solver
|
//initialize the solver
|
||||||
let mut solver = cvode::Solver::new(
|
let mut solver = cvode::Solver::new(
|
||||||
LinearMultistepMethod::Adams,
|
LinearMultistepMethod::Adams,
|
||||||
wrapped_f,
|
f,
|
||||||
0.,
|
0.,
|
||||||
&y0,
|
&y0,
|
||||||
1e-4,
|
1e-4,
|
||||||
|
Loading…
Reference in New Issue
Block a user