Add user_data

This commit is contained in:
Arthur Carcano 2021-05-09 10:13:03 +02:00
parent 18c7c0ab02
commit a1231f7d85
3 changed files with 96 additions and 47 deletions

View File

@ -1,4 +1,4 @@
use std::convert::TryInto;
use std::{convert::TryInto, pin::Pin};
use std::{ffi::c_void, intrinsics::transmute, os::raw::c_int, ptr::NonNull};
use cvode::SUNMatrix;
@ -47,12 +47,13 @@ impl From<NonNull<CvodeMemoryBlock>> for CvodeMemoryBlockNonNullPtr {
}
}
pub struct Solver<const N: usize> {
pub struct Solver<UserData, const N: usize> {
mem: CvodeMemoryBlockNonNullPtr,
y0: NVectorSerial<N>,
sunmatrix: SUNMatrix,
linsolver: SUNLinearSolver,
_atol: AbsTolerance<N>,
atol: AbsTolerance<N>,
user_data: Pin<Box<UserData>>,
}
pub enum RhsResult {
@ -61,14 +62,14 @@ pub enum RhsResult {
NonRecoverableError(u8),
}
type RhsF<const N: usize> = fn(F, &[F; N], &mut [F; N], *mut c_void) -> RhsResult;
type RhsF<UserData, const N: usize> = fn(F, &[F; N], &mut [F; N], &UserData) -> RhsResult;
pub fn wrap_f<const N: usize>(
f: RhsF<N>,
pub fn wrap_f<UserData, const N: usize>(
f: RhsF<UserData, N>,
t: F,
y: CVector,
ydot: CVector,
data: *mut c_void,
data: &UserData,
) -> c_int {
let y = unsafe { transmute(N_VGetArrayPointer(y as _)) };
let ydot = unsafe { transmute(N_VGetArrayPointer(ydot as _)) };
@ -82,19 +83,20 @@ pub fn wrap_f<const N: usize>(
#[macro_export]
macro_rules! wrap {
($wrapped_f_name: ident, $f_name: ident) => {
($wrapped_f_name: ident, $f_name: ident, $user_data: ty) => {
extern "C" fn $wrapped_f_name(
t: F,
y: CVector,
ydot: CVector,
data: *mut std::ffi::c_void,
data: *const $user_data,
) -> std::os::raw::c_int {
let data = unsafe { std::mem::transmute(data) };
wrap_f($f_name, t, y, ydot, data)
}
};
}
type RhsFCtype = extern "C" fn(F, CVector, CVector, *mut c_void) -> c_int;
type RhsFCtype<UserData> = extern "C" fn(F, CVector, CVector, *const UserData) -> c_int;
#[repr(u32)]
pub enum StepKind {
@ -138,23 +140,54 @@ impl<const SIZE: usize> AbsTolerance<SIZE> {
}
}
impl<const N: usize> Solver<N> {
impl<UserData, const N: usize> Solver<UserData, N> {
pub fn new(
method: LinearMultistepMethod,
f: RhsFCtype,
f: RhsFCtype<UserData>,
t0: F,
y0: &[F; N],
rtol: F,
atol: AbsTolerance<N>,
user_data: UserData,
) -> Result<Self> {
assert_eq!(y0.len(), N);
let mem_maybenull = unsafe { cvode::CVodeCreate(method as c_int) };
let mem: CvodeMemoryBlockNonNullPtr =
check_non_null(mem_maybenull as *mut CvodeMemoryBlock, "CVodeCreate")?.into();
let mem: CvodeMemoryBlockNonNullPtr = {
let mem_maybenull = unsafe { cvode::CVodeCreate(method as c_int) };
check_non_null(mem_maybenull as *mut CvodeMemoryBlock, "CVodeCreate")?.into()
};
let y0 = NVectorSerial::new_from(y0);
let flag = unsafe { cvode::CVodeInit(mem.as_raw(), Some(f), t0, y0.as_raw() as _) };
check_flag_is_succes(flag, "CVodeInit")?;
match &atol {
let matrix = {
let matrix = unsafe {
cvode_5_sys::sunmatrix_dense::SUNDenseMatrix(
N.try_into().unwrap(),
N.try_into().unwrap(),
)
};
check_non_null(matrix, "SUNDenseMatrix")?
};
let linsolver = {
let linsolver = unsafe {
cvode_5_sys::sunlinsol_dense::SUNDenseLinearSolver(
y0.as_raw() as _,
matrix.as_ptr() as _,
)
};
check_non_null(linsolver, "SUNDenseLinearSolver")?
};
let user_data = Box::pin(user_data);
let res = Solver {
mem,
y0,
sunmatrix: matrix.as_ptr() as _,
linsolver: linsolver.as_ptr() as _,
atol,
user_data,
};
{
let flag = unsafe { cvode::CVodeInit(mem.as_raw(), Some(std::mem::transmute(f)), t0, res.y0.as_raw() as _) };
check_flag_is_succes(flag, "CVodeInit")?;
}
match &res.atol {
&AbsTolerance::Scalar(atol) => {
let flag = unsafe { cvode::CVodeSStolerances(mem.as_raw(), rtol, atol) };
check_flag_is_succes(flag, "CVodeSStolerances")?;
@ -165,27 +198,26 @@ impl<const N: usize> Solver<N> {
check_flag_is_succes(flag, "CVodeSVtolerances")?;
}
}
let matrix = unsafe {
cvode_5_sys::sunmatrix_dense::SUNDenseMatrix(
N.try_into().unwrap(),
N.try_into().unwrap(),
)
};
check_non_null(matrix, "SUNDenseMatrix")?;
let linsolver = unsafe {
cvode_5_sys::sunlinsol_dense::SUNDenseLinearSolver(y0.as_raw() as _, matrix as _)
};
check_non_null(linsolver, "SUNDenseLinearSolver")?;
let flag =
unsafe { cvode::CVodeSetLinearSolver(mem.as_raw(), linsolver as _, matrix as _) };
check_flag_is_succes(flag, "CVodeSetLinearSolver")?;
Ok(Solver {
mem,
y0,
sunmatrix: matrix as _,
linsolver: linsolver as _,
_atol: atol as _,
})
{
let flag = unsafe {
cvode::CVodeSetLinearSolver(
mem.as_raw(),
linsolver.as_ptr() as _,
matrix.as_ptr() as _,
)
};
check_flag_is_succes(flag, "CVodeSetLinearSolver")?;
}
{
let flag = unsafe {
cvode::CVodeSetUserData(
mem.as_raw(),
std::mem::transmute(res.user_data.as_ref().get_ref()),
)
};
check_flag_is_succes(flag, "CVodeSetUserData")?;
}
Ok(res)
}
pub fn step(&mut self, tout: F, step_kind: StepKind) -> Result<(F, &[F; N])> {
@ -204,7 +236,7 @@ impl<const N: usize> Solver<N> {
}
}
impl<const N: usize> Drop for Solver<N> {
impl<UserData, const N: usize> Drop for Solver<UserData, N> {
fn drop(&mut self) {
unsafe { cvode::CVodeFree(&mut self.mem.as_raw()) }
unsafe { cvode::SUNLinSolFree(self.linsolver) };
@ -216,16 +248,24 @@ impl<const N: usize> Drop for Solver<N> {
mod tests {
use super::*;
fn f(_t: super::F, y: &[F; 2], ydot: &mut [F; 2], _data: *mut c_void) -> RhsResult {
fn f(_t: super::F, y: &[F; 2], ydot: &mut [F; 2], _data: &()) -> RhsResult {
*ydot = [y[1], -y[0]];
RhsResult::Ok
}
wrap!(wrapped_f, f);
wrap!(wrapped_f, f, ());
#[test]
fn create() {
let y0 = [0., 1.];
let _solver = Solver::new(LinearMultistepMethod::ADAMS, wrapped_f, 0., &y0, 1e-4, AbsTolerance::Scalar(1e-4));
let _solver = Solver::new(
LinearMultistepMethod::ADAMS,
wrapped_f,
0.,
&y0,
1e-4,
AbsTolerance::Scalar(1e-4),
(),
);
}
}

View File

@ -1,4 +1,4 @@
use std::{convert::TryInto, intrinsics::transmute, ptr::NonNull};
use std::{convert::TryInto, intrinsics::transmute, ops::Deref, ptr::NonNull};
use cvode_5_sys::{cvode::realtype, nvector_serial};
@ -8,6 +8,14 @@ pub struct NVectorSerial<const SIZE: usize> {
inner: NonNull<nvector_serial::_generic_N_Vector>,
}
impl<const SIZE: usize> Deref for NVectorSerial<SIZE> {
type Target = nvector_serial::_generic_N_Vector;
fn deref(&self) -> &Self::Target {
unsafe {self.inner.as_ref()}
}
}
impl<const SIZE: usize> NVectorSerial<SIZE> {
pub fn as_ref(&self) -> &[realtype; SIZE] {
unsafe { transmute(nvector_serial::N_VGetArrayPointer_Serial(self.as_raw())) }

View File

@ -5,11 +5,11 @@ use cvode_wrap::*;
fn main() {
let y0 = [0., 1.];
//define the right-hand-side
fn f(_t: F, y: &[F; 2], ydot: &mut [F; 2], _data: *mut c_void) -> RhsResult {
*ydot = [y[1], -y[0] / 10.];
fn f(_t: F, y: &[F; 2], ydot: &mut [F; 2], k: &F) -> RhsResult {
*ydot = [y[1], -y[0] * k];
RhsResult::Ok
}
wrap!(wrapped_f, f);
wrap!(wrapped_f, f, F);
//initialize the solver
let mut solver = Solver::new(
LinearMultistepMethod::ADAMS,
@ -18,6 +18,7 @@ fn main() {
&y0,
1e-4,
AbsTolerance::scalar(1e-4),
1e-2
)
.unwrap();
//and solve