diff --git a/cvode-wrap/src/cvode.rs b/cvode-wrap/src/cvode.rs new file mode 100644 index 0000000..e226ba0 --- /dev/null +++ b/cvode-wrap/src/cvode.rs @@ -0,0 +1,205 @@ +use std::{convert::TryInto, ffi::c_void, os::raw::c_int, pin::Pin, ptr::NonNull}; + +use cvode_5_sys::{SUNLinearSolver, SUNMatrix}; + +use crate::{LinearMultistepMethod, NVectorSerialHeapAllocated, Realtype, Result, StepKind, c_wrapping, check_flag_is_succes, check_non_null}; + +#[repr(C)] +struct CvodeMemoryBlock { + _private: [u8; 0], +} + +#[repr(transparent)] +#[derive(Debug, Clone, Copy)] +struct CvodeMemoryBlockNonNullPtr { + ptr: NonNull, +} + +impl CvodeMemoryBlockNonNullPtr { + fn new(ptr: NonNull) -> Self { + Self { ptr } + } + + fn as_raw(self) -> *mut c_void { + self.ptr.as_ptr() as *mut c_void + } +} + +impl From> for CvodeMemoryBlockNonNullPtr { + fn from(x: NonNull) -> Self { + Self::new(x) + } +} + +/// An enum representing the choice between a scalar or vector absolute tolerance +pub enum AbsTolerance { + Scalar(Realtype), + Vector(NVectorSerialHeapAllocated), +} + +impl AbsTolerance { + pub fn scalar(atol: Realtype) -> Self { + AbsTolerance::Scalar(atol) + } + + pub fn vector(atol: &[Realtype; SIZE]) -> Self { + let atol = NVectorSerialHeapAllocated::new_from(atol); + AbsTolerance::Vector(atol) + } +} + +/// The main struct of the crate. Wraps a sundials solver. +/// +/// Args +/// ---- +/// `UserData` is the type of the supplementary arguments for the +/// right-hand-side. If unused, should be `()`. +/// +/// `N` is the "problem size", that is the dimension of the state space. +/// +/// See [crate-level](`crate`) documentation for more. +pub struct Solver { + mem: CvodeMemoryBlockNonNullPtr, + y0: NVectorSerialHeapAllocated, + sunmatrix: SUNMatrix, + linsolver: SUNLinearSolver, + atol: AbsTolerance, + user_data: Pin>, +} + + +impl Solver { + pub fn new( + method: LinearMultistepMethod, + f: c_wrapping::RhsFCtype, + t0: Realtype, + y0: &[Realtype; N], + rtol: Realtype, + atol: AbsTolerance, + user_data: UserData, + ) -> Result { + assert_eq!(y0.len(), N); + let mem: CvodeMemoryBlockNonNullPtr = { + let mem_maybenull = unsafe { cvode_5_sys::CVodeCreate(method as c_int) }; + check_non_null(mem_maybenull as *mut CvodeMemoryBlock, "CVodeCreate")?.into() + }; + let y0 = NVectorSerialHeapAllocated::new_from(y0); + let matrix = { + let matrix = unsafe { + cvode_5_sys::SUNDenseMatrix(N.try_into().unwrap(), N.try_into().unwrap()) + }; + check_non_null(matrix, "SUNDenseMatrix")? + }; + let linsolver = { + let linsolver = unsafe { cvode_5_sys::SUNLinSol_Dense(y0.as_raw(), matrix.as_ptr()) }; + check_non_null(linsolver, "SUNDenseLinearSolver")? + }; + let user_data = Box::pin(user_data); + let res = Solver { + mem, + y0, + sunmatrix: matrix.as_ptr(), + linsolver: linsolver.as_ptr(), + atol, + user_data, + }; + { + let flag = unsafe { + cvode_5_sys::CVodeInit( + mem.as_raw(), + Some(std::mem::transmute(f)), + t0, + res.y0.as_raw(), + ) + }; + check_flag_is_succes(flag, "CVodeInit")?; + } + match &res.atol { + &AbsTolerance::Scalar(atol) => { + let flag = unsafe { cvode_5_sys::CVodeSStolerances(mem.as_raw(), rtol, atol) }; + check_flag_is_succes(flag, "CVodeSStolerances")?; + } + AbsTolerance::Vector(atol) => { + let flag = + unsafe { cvode_5_sys::CVodeSVtolerances(mem.as_raw(), rtol, atol.as_raw()) }; + check_flag_is_succes(flag, "CVodeSVtolerances")?; + } + } + { + let flag = unsafe { + cvode_5_sys::CVodeSetLinearSolver(mem.as_raw(), linsolver.as_ptr(), matrix.as_ptr()) + }; + check_flag_is_succes(flag, "CVodeSetLinearSolver")?; + } + { + let flag = unsafe { + cvode_5_sys::CVodeSetUserData( + mem.as_raw(), + std::mem::transmute(res.user_data.as_ref().get_ref()), + ) + }; + check_flag_is_succes(flag, "CVodeSetUserData")?; + } + Ok(res) + } + + pub fn step( + &mut self, + tout: Realtype, + step_kind: StepKind, + ) -> Result<(Realtype, &[Realtype; N])> { + let mut tret = 0.; + let flag = unsafe { + cvode_5_sys::CVode( + self.mem.as_raw(), + tout, + self.y0.as_raw(), + &mut tret, + step_kind as c_int, + ) + }; + check_flag_is_succes(flag, "CVode")?; + Ok((tret, self.y0.as_slice())) + } +} + +impl Drop for Solver { + fn drop(&mut self) { + unsafe { cvode_5_sys::CVodeFree(&mut self.mem.as_raw()) } + unsafe { cvode_5_sys::SUNLinSolFree(self.linsolver) }; + unsafe { cvode_5_sys::SUNMatDestroy(self.sunmatrix) }; + } +} + +#[cfg(test)] +mod tests { + use crate::{RhsResult,wrap, NVectorSerial}; + + use super::*; + + fn f( + _t: super::Realtype, + y: &[Realtype; 2], + ydot: &mut [Realtype; 2], + _data: &(), + ) -> RhsResult { + *ydot = [y[1], -y[0]]; + RhsResult::Ok + } + + wrap!(wrapped_f, f, (), 2); + + #[test] + fn create() { + let y0 = [0., 1.]; + let _solver = Solver::new( + LinearMultistepMethod::Adams, + wrapped_f, + 0., + &y0, + 1e-4, + AbsTolerance::Scalar(1e-4), + (), + ); + } +} \ No newline at end of file diff --git a/cvode-wrap/src/lib.rs b/cvode-wrap/src/lib.rs index 5cd8093..c01d26d 100644 --- a/cvode-wrap/src/lib.rs +++ b/cvode-wrap/src/lib.rs @@ -1,13 +1,14 @@ -use std::{convert::TryInto, pin::Pin}; -use std::{ffi::c_void, os::raw::c_int, ptr::NonNull}; +use std::{os::raw::c_int, ptr::NonNull}; -use cvode_5_sys::{realtype, SUNLinearSolver, SUNMatrix}; +use cvode_5_sys::realtype; mod nvector; pub use nvector::{NVectorSerial, NVectorSerialHeapAllocated}; pub mod c_wrapping; +pub mod cvode; + /// The floatting-point type sundials was compiled with pub type Realtype = realtype; @@ -21,52 +22,6 @@ pub enum LinearMultistepMethod { Bdf = cvode_5_sys::CV_BDF, } -#[repr(C)] -struct CvodeMemoryBlock { - _private: [u8; 0], -} - -#[repr(transparent)] -#[derive(Debug, Clone, Copy)] -struct CvodeMemoryBlockNonNullPtr { - ptr: NonNull, -} - -impl CvodeMemoryBlockNonNullPtr { - fn new(ptr: NonNull) -> Self { - Self { ptr } - } - - fn as_raw(self) -> *mut c_void { - self.ptr.as_ptr() as *mut c_void - } -} - -impl From> for CvodeMemoryBlockNonNullPtr { - fn from(x: NonNull) -> Self { - Self::new(x) - } -} - -/// The main struct of the crate. Wraps a sundials solver. -/// -/// Args -/// ---- -/// `UserData` is the type of the supplementary arguments for the -/// right-hand-side. If unused, should be `()`. -/// -/// `N` is the "problem size", that is the dimension of the state space. -/// -/// See [crate-level](`crate`) documentation for more. -pub struct Solver { - mem: CvodeMemoryBlockNonNullPtr, - y0: NVectorSerialHeapAllocated, - sunmatrix: SUNMatrix, - linsolver: SUNLinearSolver, - atol: AbsTolerance, - user_data: Pin>, -} - /// A return type for the right-hand-side rust function. /// /// Adapted from Sundials cv-ode guide version 5.7 (BSD Licensed), setcion 4.6.1 : @@ -133,154 +88,3 @@ fn check_flag_is_succes(flag: c_int, func_id: &'static str) -> Result<()> { Err(Error::ErrorCode { flag, func_id }) } } - -/// An enum representing the choice between a scalar or vector absolute tolerance -pub enum AbsTolerance { - Scalar(Realtype), - Vector(NVectorSerialHeapAllocated), -} - -impl AbsTolerance { - pub fn scalar(atol: Realtype) -> Self { - AbsTolerance::Scalar(atol) - } - - pub fn vector(atol: &[Realtype; SIZE]) -> Self { - let atol = NVectorSerialHeapAllocated::new_from(atol); - AbsTolerance::Vector(atol) - } -} - -impl Solver { - pub fn new( - method: LinearMultistepMethod, - f: c_wrapping::RhsFCtype, - t0: Realtype, - y0: &[Realtype; N], - rtol: Realtype, - atol: AbsTolerance, - user_data: UserData, - ) -> Result { - assert_eq!(y0.len(), N); - let mem: CvodeMemoryBlockNonNullPtr = { - let mem_maybenull = unsafe { cvode_5_sys::CVodeCreate(method as c_int) }; - check_non_null(mem_maybenull as *mut CvodeMemoryBlock, "CVodeCreate")?.into() - }; - let y0 = NVectorSerialHeapAllocated::new_from(y0); - let matrix = { - let matrix = unsafe { - cvode_5_sys::SUNDenseMatrix(N.try_into().unwrap(), N.try_into().unwrap()) - }; - check_non_null(matrix, "SUNDenseMatrix")? - }; - let linsolver = { - let linsolver = unsafe { cvode_5_sys::SUNLinSol_Dense(y0.as_raw(), matrix.as_ptr()) }; - check_non_null(linsolver, "SUNDenseLinearSolver")? - }; - let user_data = Box::pin(user_data); - let res = Solver { - mem, - y0, - sunmatrix: matrix.as_ptr(), - linsolver: linsolver.as_ptr(), - atol, - user_data, - }; - { - let flag = unsafe { - cvode_5_sys::CVodeInit( - mem.as_raw(), - Some(std::mem::transmute(f)), - t0, - res.y0.as_raw(), - ) - }; - check_flag_is_succes(flag, "CVodeInit")?; - } - match &res.atol { - &AbsTolerance::Scalar(atol) => { - let flag = unsafe { cvode_5_sys::CVodeSStolerances(mem.as_raw(), rtol, atol) }; - check_flag_is_succes(flag, "CVodeSStolerances")?; - } - AbsTolerance::Vector(atol) => { - let flag = - unsafe { cvode_5_sys::CVodeSVtolerances(mem.as_raw(), rtol, atol.as_raw()) }; - check_flag_is_succes(flag, "CVodeSVtolerances")?; - } - } - { - let flag = unsafe { - cvode_5_sys::CVodeSetLinearSolver(mem.as_raw(), linsolver.as_ptr(), matrix.as_ptr()) - }; - check_flag_is_succes(flag, "CVodeSetLinearSolver")?; - } - { - let flag = unsafe { - cvode_5_sys::CVodeSetUserData( - mem.as_raw(), - std::mem::transmute(res.user_data.as_ref().get_ref()), - ) - }; - check_flag_is_succes(flag, "CVodeSetUserData")?; - } - Ok(res) - } - - pub fn step( - &mut self, - tout: Realtype, - step_kind: StepKind, - ) -> Result<(Realtype, &[Realtype; N])> { - let mut tret = 0.; - let flag = unsafe { - cvode_5_sys::CVode( - self.mem.as_raw(), - tout, - self.y0.as_raw(), - &mut tret, - step_kind as c_int, - ) - }; - check_flag_is_succes(flag, "CVode")?; - Ok((tret, self.y0.as_slice())) - } -} - -impl Drop for Solver { - fn drop(&mut self) { - unsafe { cvode_5_sys::CVodeFree(&mut self.mem.as_raw()) } - unsafe { cvode_5_sys::SUNLinSolFree(self.linsolver) }; - unsafe { cvode_5_sys::SUNMatDestroy(self.sunmatrix) }; - } -} - -#[cfg(test)] -mod tests { - use super::*; - - fn f( - _t: super::Realtype, - y: &[Realtype; 2], - ydot: &mut [Realtype; 2], - _data: &(), - ) -> RhsResult { - *ydot = [y[1], -y[0]]; - RhsResult::Ok - } - - wrap!(wrapped_f, f, (), 2); - - #[test] - fn create() { - let y0 = [0., 1.]; - let _solver = Solver::new( - LinearMultistepMethod::Adams, - wrapped_f, - 0., - &y0, - 1e-4, - AbsTolerance::Scalar(1e-4), - (), - ); - } -} diff --git a/example/src/main.rs b/example/src/main.rs index d71e3ec..94e7836 100644 --- a/example/src/main.rs +++ b/example/src/main.rs @@ -9,13 +9,13 @@ fn main() { } wrap!(wrapped_f, f, Realtype, 2); //initialize the solver - let mut solver = Solver::new( + let mut solver = cvode::Solver::new( LinearMultistepMethod::Adams, wrapped_f, 0., &y0, 1e-4, - AbsTolerance::scalar(1e-4), + cvode::AbsTolerance::scalar(1e-4), 1e-2, ) .unwrap();