diff --git a/cvode-wrap/src/lib.rs b/cvode-wrap/src/lib.rs index 1aa421d..e8f3e07 100644 --- a/cvode-wrap/src/lib.rs +++ b/cvode-wrap/src/lib.rs @@ -10,10 +10,9 @@ use cvode_5_sys::{ }; mod nvector; -use nvector::NVectorSerialHeapAlloced; +pub use nvector::{NVectorSerial, NVectorSerialHeapAlloced}; pub type Realtype = realtype; -pub type CVector = cvode::N_Vector; #[repr(u32)] #[derive(Debug)] @@ -74,15 +73,21 @@ pub enum RhsResult { NonRecoverableError(u8), } -pub type RhsF = fn(t: Realtype, y: &[Realtype; N], ydot: &mut [Realtype; N], user_data: &UserData) -> RhsResult; +pub type RhsF = + fn(t: Realtype, y: &[Realtype; N], ydot: &mut [Realtype; N], user_data: &UserData) -> RhsResult; -pub type RhsFCtype = extern "C" fn(t: Realtype, y: CVector, ydot: CVector, user_data: *const UserData) -> c_int; +pub type RhsFCtype = extern "C" fn( + t: Realtype, + y: *const NVectorSerial, + ydot: *mut NVectorSerial, + user_data: *const UserData, +) -> c_int; pub fn wrap_f( f: RhsF, t: Realtype, - y: CVector, - ydot: CVector, + y: *const NVectorSerial, + ydot: *mut NVectorSerial, data: &UserData, ) -> c_int { let y = unsafe { transmute(N_VGetArrayPointer(y as _)) }; @@ -97,12 +102,12 @@ pub fn wrap_f( #[macro_export] macro_rules! wrap { - ($wrapped_f_name: ident, $f_name: ident, $user_data: ty) => { + ($wrapped_f_name: ident, $f_name: ident, $user_data: ty, $problem_size: expr) => { extern "C" fn $wrapped_f_name( t: Realtype, - y: CVector, - ydot: CVector, - data: *const $user_data, + y: *const NVectorSerial<$problem_size>, + ydot: *mut NVectorSerial<$problem_size>, + data: *const $user_data, ) -> std::os::raw::c_int { let data = unsafe { std::mem::transmute(data) }; wrap_f($f_name, t, y, ydot, data) @@ -155,7 +160,7 @@ impl AbsTolerance { impl Solver { pub fn new( method: LinearMultistepMethod, - f: RhsFCtype, + f: RhsFCtype, t0: Realtype, y0: &[Realtype; N], rtol: Realtype, @@ -196,7 +201,14 @@ impl Solver { user_data, }; { - let flag = unsafe { cvode::CVodeInit(mem.as_raw(), Some(std::mem::transmute(f)), t0, res.y0.as_raw() as _) }; + let flag = unsafe { + cvode::CVodeInit( + mem.as_raw(), + Some(std::mem::transmute(f)), + t0, + res.y0.as_raw() as _, + ) + }; check_flag_is_succes(flag, "CVodeInit")?; } match &res.atol { @@ -232,7 +244,11 @@ impl Solver { Ok(res) } - pub fn step(&mut self, tout: Realtype, step_kind: StepKind) -> Result<(Realtype, &[Realtype; N])> { + pub fn step( + &mut self, + tout: Realtype, + step_kind: StepKind, + ) -> Result<(Realtype, &[Realtype; N])> { let mut tret = 0.; let flag = unsafe { cvode::CVode( @@ -260,12 +276,17 @@ impl Drop for Solver { mod tests { use super::*; - fn f(_t: super::Realtype, y: &[Realtype; 2], ydot: &mut [Realtype; 2], _data: &()) -> RhsResult { + fn f( + _t: super::Realtype, + y: &[Realtype; 2], + ydot: &mut [Realtype; 2], + _data: &(), + ) -> RhsResult { *ydot = [y[1], -y[0]]; RhsResult::Ok } - wrap!(wrapped_f, f, ()); + wrap!(wrapped_f, f, (), 2); #[test] fn create() { diff --git a/cvode-wrap/src/nvector.rs b/cvode-wrap/src/nvector.rs index f3e363e..089a631 100644 --- a/cvode-wrap/src/nvector.rs +++ b/cvode-wrap/src/nvector.rs @@ -1,23 +1,24 @@ -use std::{convert::TryInto, intrinsics::transmute, ops::Deref, ptr::NonNull}; +use std::{ + convert::TryInto, + intrinsics::transmute, + ops::{Deref, DerefMut}, + ptr::NonNull, +}; use cvode_5_sys::{cvode::realtype, nvector_serial}; - #[repr(transparent)] #[derive(Debug)] -pub struct NVectorSerialHeapAlloced { - inner: NonNull, +pub struct NVectorSerial { + inner: nvector_serial::_generic_N_Vector, } -impl Deref for NVectorSerialHeapAlloced { - type Target = nvector_serial::_generic_N_Vector; - - fn deref(&self) -> &Self::Target { - unsafe {self.inner.as_ref()} +impl NVectorSerial { + + pub unsafe fn as_raw(&self) -> nvector_serial::N_Vector { + std::mem::transmute(&self.inner) } - } -impl NVectorSerialHeapAlloced { pub fn as_slice(&self) -> &[realtype; SIZE] { unsafe { transmute(nvector_serial::N_VGetArrayPointer_Serial(self.as_raw())) } } @@ -25,11 +26,33 @@ impl NVectorSerialHeapAlloced { pub fn as_slice_mut(&mut self) -> &mut [realtype; SIZE] { unsafe { transmute(nvector_serial::N_VGetArrayPointer_Serial(self.as_raw())) } } +} +#[repr(transparent)] +#[derive(Debug)] +pub struct NVectorSerialHeapAlloced { + inner: NonNull>, +} + +impl Deref for NVectorSerialHeapAlloced { + type Target = NVectorSerial; + + fn deref(&self) -> &Self::Target { + unsafe { self.inner.as_ref() } + } +} + +impl DerefMut for NVectorSerialHeapAlloced { + fn deref_mut(&mut self) -> &mut Self::Target { + unsafe { self.inner.as_mut() } + } +} + +impl NVectorSerialHeapAlloced { pub fn new() -> Self { + let raw_c = unsafe { nvector_serial::N_VNew_Serial(SIZE.try_into().unwrap()) }; Self { - inner: NonNull::new(unsafe { nvector_serial::N_VNew_Serial(SIZE.try_into().unwrap()) }) - .unwrap(), + inner: NonNull::new(raw_c as *mut NVectorSerial).unwrap(), } } @@ -38,10 +61,6 @@ impl NVectorSerialHeapAlloced { res.as_slice_mut().copy_from_slice(data); res } - - pub fn as_raw(&self) -> nvector_serial::N_Vector { - self.inner.as_ptr() - } } impl Drop for NVectorSerialHeapAlloced { diff --git a/test-solver/src/main.rs b/test-solver/src/main.rs index 7dfa0c0..2399681 100644 --- a/test-solver/src/main.rs +++ b/test-solver/src/main.rs @@ -7,7 +7,7 @@ fn main() { *ydot = [y[1], -y[0] * k]; RhsResult::Ok } - wrap!(wrapped_f, f, Realtype); + wrap!(wrapped_f, f, Realtype, 2); //initialize the solver let mut solver = Solver::new( LinearMultistepMethod::ADAMS,