From a1231f7d85e4a9026ffaf5a965569de4a6477c62 Mon Sep 17 00:00:00 2001 From: Arthur Carcano <53921575+krtab@users.noreply.github.com> Date: Sun, 9 May 2021 10:13:03 +0200 Subject: [PATCH] Add user_data --- cvode-wrap/src/lib.rs | 126 +++++++++++++++++++++++++------------- cvode-wrap/src/nvector.rs | 10 ++- test-solver/src/main.rs | 7 ++- 3 files changed, 96 insertions(+), 47 deletions(-) diff --git a/cvode-wrap/src/lib.rs b/cvode-wrap/src/lib.rs index 82d4a97..0dd4387 100644 --- a/cvode-wrap/src/lib.rs +++ b/cvode-wrap/src/lib.rs @@ -1,4 +1,4 @@ -use std::convert::TryInto; +use std::{convert::TryInto, pin::Pin}; use std::{ffi::c_void, intrinsics::transmute, os::raw::c_int, ptr::NonNull}; use cvode::SUNMatrix; @@ -47,12 +47,13 @@ impl From> for CvodeMemoryBlockNonNullPtr { } } -pub struct Solver { +pub struct Solver { mem: CvodeMemoryBlockNonNullPtr, y0: NVectorSerial, sunmatrix: SUNMatrix, linsolver: SUNLinearSolver, - _atol: AbsTolerance, + atol: AbsTolerance, + user_data: Pin>, } pub enum RhsResult { @@ -61,14 +62,14 @@ pub enum RhsResult { NonRecoverableError(u8), } -type RhsF = fn(F, &[F; N], &mut [F; N], *mut c_void) -> RhsResult; +type RhsF = fn(F, &[F; N], &mut [F; N], &UserData) -> RhsResult; -pub fn wrap_f( - f: RhsF, +pub fn wrap_f( + f: RhsF, t: F, y: CVector, ydot: CVector, - data: *mut c_void, + data: &UserData, ) -> c_int { let y = unsafe { transmute(N_VGetArrayPointer(y as _)) }; let ydot = unsafe { transmute(N_VGetArrayPointer(ydot as _)) }; @@ -82,19 +83,20 @@ pub fn wrap_f( #[macro_export] macro_rules! wrap { - ($wrapped_f_name: ident, $f_name: ident) => { + ($wrapped_f_name: ident, $f_name: ident, $user_data: ty) => { extern "C" fn $wrapped_f_name( t: F, y: CVector, ydot: CVector, - data: *mut std::ffi::c_void, + data: *const $user_data, ) -> std::os::raw::c_int { + let data = unsafe { std::mem::transmute(data) }; wrap_f($f_name, t, y, ydot, data) } }; } -type RhsFCtype = extern "C" fn(F, CVector, CVector, *mut c_void) -> c_int; +type RhsFCtype = extern "C" fn(F, CVector, CVector, *const UserData) -> c_int; #[repr(u32)] pub enum StepKind { @@ -138,23 +140,54 @@ impl AbsTolerance { } } -impl Solver { +impl Solver { pub fn new( method: LinearMultistepMethod, - f: RhsFCtype, + f: RhsFCtype, t0: F, y0: &[F; N], rtol: F, atol: AbsTolerance, + user_data: UserData, ) -> Result { assert_eq!(y0.len(), N); - let mem_maybenull = unsafe { cvode::CVodeCreate(method as c_int) }; - let mem: CvodeMemoryBlockNonNullPtr = - check_non_null(mem_maybenull as *mut CvodeMemoryBlock, "CVodeCreate")?.into(); + let mem: CvodeMemoryBlockNonNullPtr = { + let mem_maybenull = unsafe { cvode::CVodeCreate(method as c_int) }; + check_non_null(mem_maybenull as *mut CvodeMemoryBlock, "CVodeCreate")?.into() + }; let y0 = NVectorSerial::new_from(y0); - let flag = unsafe { cvode::CVodeInit(mem.as_raw(), Some(f), t0, y0.as_raw() as _) }; - check_flag_is_succes(flag, "CVodeInit")?; - match &atol { + let matrix = { + let matrix = unsafe { + cvode_5_sys::sunmatrix_dense::SUNDenseMatrix( + N.try_into().unwrap(), + N.try_into().unwrap(), + ) + }; + check_non_null(matrix, "SUNDenseMatrix")? + }; + let linsolver = { + let linsolver = unsafe { + cvode_5_sys::sunlinsol_dense::SUNDenseLinearSolver( + y0.as_raw() as _, + matrix.as_ptr() as _, + ) + }; + check_non_null(linsolver, "SUNDenseLinearSolver")? + }; + let user_data = Box::pin(user_data); + let res = Solver { + mem, + y0, + sunmatrix: matrix.as_ptr() as _, + linsolver: linsolver.as_ptr() as _, + atol, + user_data, + }; + { + let flag = unsafe { cvode::CVodeInit(mem.as_raw(), Some(std::mem::transmute(f)), t0, res.y0.as_raw() as _) }; + check_flag_is_succes(flag, "CVodeInit")?; + } + match &res.atol { &AbsTolerance::Scalar(atol) => { let flag = unsafe { cvode::CVodeSStolerances(mem.as_raw(), rtol, atol) }; check_flag_is_succes(flag, "CVodeSStolerances")?; @@ -165,27 +198,26 @@ impl Solver { check_flag_is_succes(flag, "CVodeSVtolerances")?; } } - let matrix = unsafe { - cvode_5_sys::sunmatrix_dense::SUNDenseMatrix( - N.try_into().unwrap(), - N.try_into().unwrap(), - ) - }; - check_non_null(matrix, "SUNDenseMatrix")?; - let linsolver = unsafe { - cvode_5_sys::sunlinsol_dense::SUNDenseLinearSolver(y0.as_raw() as _, matrix as _) - }; - check_non_null(linsolver, "SUNDenseLinearSolver")?; - let flag = - unsafe { cvode::CVodeSetLinearSolver(mem.as_raw(), linsolver as _, matrix as _) }; - check_flag_is_succes(flag, "CVodeSetLinearSolver")?; - Ok(Solver { - mem, - y0, - sunmatrix: matrix as _, - linsolver: linsolver as _, - _atol: atol as _, - }) + { + let flag = unsafe { + cvode::CVodeSetLinearSolver( + mem.as_raw(), + linsolver.as_ptr() as _, + matrix.as_ptr() as _, + ) + }; + check_flag_is_succes(flag, "CVodeSetLinearSolver")?; + } + { + let flag = unsafe { + cvode::CVodeSetUserData( + mem.as_raw(), + std::mem::transmute(res.user_data.as_ref().get_ref()), + ) + }; + check_flag_is_succes(flag, "CVodeSetUserData")?; + } + Ok(res) } pub fn step(&mut self, tout: F, step_kind: StepKind) -> Result<(F, &[F; N])> { @@ -204,7 +236,7 @@ impl Solver { } } -impl Drop for Solver { +impl Drop for Solver { fn drop(&mut self) { unsafe { cvode::CVodeFree(&mut self.mem.as_raw()) } unsafe { cvode::SUNLinSolFree(self.linsolver) }; @@ -216,16 +248,24 @@ impl Drop for Solver { mod tests { use super::*; - fn f(_t: super::F, y: &[F; 2], ydot: &mut [F; 2], _data: *mut c_void) -> RhsResult { + fn f(_t: super::F, y: &[F; 2], ydot: &mut [F; 2], _data: &()) -> RhsResult { *ydot = [y[1], -y[0]]; RhsResult::Ok } - wrap!(wrapped_f, f); + wrap!(wrapped_f, f, ()); #[test] fn create() { let y0 = [0., 1.]; - let _solver = Solver::new(LinearMultistepMethod::ADAMS, wrapped_f, 0., &y0, 1e-4, AbsTolerance::Scalar(1e-4)); + let _solver = Solver::new( + LinearMultistepMethod::ADAMS, + wrapped_f, + 0., + &y0, + 1e-4, + AbsTolerance::Scalar(1e-4), + (), + ); } } diff --git a/cvode-wrap/src/nvector.rs b/cvode-wrap/src/nvector.rs index ec0bfed..a5c28fc 100644 --- a/cvode-wrap/src/nvector.rs +++ b/cvode-wrap/src/nvector.rs @@ -1,4 +1,4 @@ -use std::{convert::TryInto, intrinsics::transmute, ptr::NonNull}; +use std::{convert::TryInto, intrinsics::transmute, ops::Deref, ptr::NonNull}; use cvode_5_sys::{cvode::realtype, nvector_serial}; @@ -8,6 +8,14 @@ pub struct NVectorSerial { inner: NonNull, } +impl Deref for NVectorSerial { + type Target = nvector_serial::_generic_N_Vector; + + fn deref(&self) -> &Self::Target { + unsafe {self.inner.as_ref()} + } + } + impl NVectorSerial { pub fn as_ref(&self) -> &[realtype; SIZE] { unsafe { transmute(nvector_serial::N_VGetArrayPointer_Serial(self.as_raw())) } diff --git a/test-solver/src/main.rs b/test-solver/src/main.rs index 8a41a2a..61d119e 100644 --- a/test-solver/src/main.rs +++ b/test-solver/src/main.rs @@ -5,11 +5,11 @@ use cvode_wrap::*; fn main() { let y0 = [0., 1.]; //define the right-hand-side - fn f(_t: F, y: &[F; 2], ydot: &mut [F; 2], _data: *mut c_void) -> RhsResult { - *ydot = [y[1], -y[0] / 10.]; + fn f(_t: F, y: &[F; 2], ydot: &mut [F; 2], k: &F) -> RhsResult { + *ydot = [y[1], -y[0] * k]; RhsResult::Ok } - wrap!(wrapped_f, f); + wrap!(wrapped_f, f, F); //initialize the solver let mut solver = Solver::new( LinearMultistepMethod::ADAMS, @@ -18,6 +18,7 @@ fn main() { &y0, 1e-4, AbsTolerance::scalar(1e-4), + 1e-2 ) .unwrap(); //and solve