From e2ee2b8a93363a9467ed702777514b8d00d7bf5d Mon Sep 17 00:00:00 2001 From: Arthur Carcano Date: Thu, 10 Jun 2021 17:03:00 +0200 Subject: [PATCH] Syntaxic switch, bugs remain --- cvode-wrap/Cargo.toml | 3 ++- cvode-wrap/src/cvode.rs | 30 +++++++++++++---------- cvode-wrap/src/cvode_sens.rs | 47 ++++++++++++++++++------------------ cvode-wrap/src/lib.rs | 16 ++++++------ cvode-wrap/src/nvector.rs | 20 ++++++++------- example/src/main.rs | 2 +- 6 files changed, 63 insertions(+), 55 deletions(-) diff --git a/cvode-wrap/Cargo.toml b/cvode-wrap/Cargo.toml index 7142786..d4aba2e 100644 --- a/cvode-wrap/Cargo.toml +++ b/cvode-wrap/Cargo.toml @@ -7,5 +7,6 @@ edition = "2018" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -cvode-5-sys = {path = "../cvode-5-sys"} +#cvode-5-sys = {path = "../cvode-5-sys"} +sundials-sys = {path = "../../sundials-sys", default-features=false, features=["cvodes"]} array-init = "2.0" \ No newline at end of file diff --git a/cvode-wrap/src/cvode.rs b/cvode-wrap/src/cvode.rs index 1bd9aa1..5fab90f 100644 --- a/cvode-wrap/src/cvode.rs +++ b/cvode-wrap/src/cvode.rs @@ -1,6 +1,6 @@ use std::{convert::TryInto, ffi::c_void, os::raw::c_int, pin::Pin, ptr::NonNull}; -use cvode_5_sys::{SUNLinearSolver, SUNMatrix}; +use sundials_sys::{SUNLinearSolver, SUNMatrix}; use crate::{ check_flag_is_succes, check_non_null, AbsTolerance, LinearMultistepMethod, NVectorSerial, @@ -99,18 +99,18 @@ where ) -> Result { assert_eq!(y0.len(), N); let mem: CvodeMemoryBlockNonNullPtr = { - let mem_maybenull = unsafe { cvode_5_sys::CVodeCreate(method as c_int) }; + let mem_maybenull = unsafe { sundials_sys::CVodeCreate(method as c_int) }; check_non_null(mem_maybenull as *mut CvodeMemoryBlock, "CVodeCreate")?.into() }; let y0 = NVectorSerialHeapAllocated::new_from(y0); let matrix = { let matrix = unsafe { - cvode_5_sys::SUNDenseMatrix(N.try_into().unwrap(), N.try_into().unwrap()) + sundials_sys::SUNDenseMatrix(N.try_into().unwrap(), N.try_into().unwrap()) }; check_non_null(matrix, "SUNDenseMatrix")? }; let linsolver = { - let linsolver = unsafe { cvode_5_sys::SUNLinSol_Dense(y0.as_raw(), matrix.as_ptr()) }; + let linsolver = unsafe { sundials_sys::SUNLinSol_Dense(y0.as_raw(), matrix.as_ptr()) }; check_non_null(linsolver, "SUNDenseLinearSolver")? }; let user_data = Box::pin(WrappingUserData { @@ -128,7 +128,7 @@ where { let fn_ptr = wrap_f:: as extern "C" fn(_, _, _, _) -> _; let flag = unsafe { - cvode_5_sys::CVodeInit( + sundials_sys::CVodeInit( mem.as_raw(), Some(std::mem::transmute(fn_ptr)), t0, @@ -139,24 +139,28 @@ where } match &res.atol { &AbsTolerance::Scalar(atol) => { - let flag = unsafe { cvode_5_sys::CVodeSStolerances(mem.as_raw(), rtol, atol) }; + let flag = unsafe { sundials_sys::CVodeSStolerances(mem.as_raw(), rtol, atol) }; check_flag_is_succes(flag, "CVodeSStolerances")?; } AbsTolerance::Vector(atol) => { let flag = - unsafe { cvode_5_sys::CVodeSVtolerances(mem.as_raw(), rtol, atol.as_raw()) }; + unsafe { sundials_sys::CVodeSVtolerances(mem.as_raw(), rtol, atol.as_raw()) }; check_flag_is_succes(flag, "CVodeSVtolerances")?; } } { let flag = unsafe { - cvode_5_sys::CVodeSetLinearSolver(mem.as_raw(), linsolver.as_ptr(), matrix.as_ptr()) + sundials_sys::CVodeSetLinearSolver( + mem.as_raw(), + linsolver.as_ptr(), + matrix.as_ptr(), + ) }; check_flag_is_succes(flag, "CVodeSetLinearSolver")?; } { let flag = unsafe { - cvode_5_sys::CVodeSetUserData( + sundials_sys::CVodeSetUserData( mem.as_raw(), std::mem::transmute(res.user_data.as_ref().get_ref()), ) @@ -173,7 +177,7 @@ where ) -> Result<(Realtype, &[Realtype; N])> { let mut tret = 0.; let flag = unsafe { - cvode_5_sys::CVode( + sundials_sys::CVode( self.mem.as_raw(), tout, self.y0.as_raw(), @@ -188,9 +192,9 @@ where impl Drop for Solver { fn drop(&mut self) { - unsafe { cvode_5_sys::CVodeFree(&mut self.mem.as_raw()) } - unsafe { cvode_5_sys::SUNLinSolFree(self.linsolver) }; - unsafe { cvode_5_sys::SUNMatDestroy(self.sunmatrix) }; + unsafe { sundials_sys::CVodeFree(&mut self.mem.as_raw()) } + unsafe { sundials_sys::SUNLinSolFree(self.linsolver) }; + unsafe { sundials_sys::SUNMatDestroy(self.sunmatrix) }; } } diff --git a/cvode-wrap/src/cvode_sens.rs b/cvode-wrap/src/cvode_sens.rs index bb1f830..dda0a0d 100644 --- a/cvode-wrap/src/cvode_sens.rs +++ b/cvode-wrap/src/cvode_sens.rs @@ -1,6 +1,6 @@ use std::{convert::TryInto, ffi::c_void, os::raw::c_int, pin::Pin, ptr::NonNull}; -use cvode_5_sys::{N_VPrint, SUNLinearSolver, SUNMatrix, CV_STAGGERED}; +use sundials_sys::{SUNLinearSolver, SUNMatrix, CV_STAGGERED}; use crate::{ check_flag_is_succes, check_non_null, AbsTolerance, LinearMultistepMethod, NVectorSerial, @@ -183,7 +183,7 @@ where ) -> Result { assert_eq!(y0.len(), N); let mem: CvodeMemoryBlockNonNullPtr = { - let mem_maybenull = unsafe { cvode_5_sys::CVodeCreate(method as c_int) }; + let mem_maybenull = unsafe { sundials_sys::CVodeCreate(method as c_int) }; check_non_null(mem_maybenull as *mut CvodeMemoryBlock, "CVodeCreate")?.into() }; let y0 = NVectorSerialHeapAllocated::new_from(y0); @@ -196,12 +196,12 @@ where ); let matrix = { let matrix = unsafe { - cvode_5_sys::SUNDenseMatrix(N.try_into().unwrap(), N.try_into().unwrap()) + sundials_sys::SUNDenseMatrix(N.try_into().unwrap(), N.try_into().unwrap()) }; check_non_null(matrix, "SUNDenseMatrix")? }; let linsolver = { - let linsolver = unsafe { cvode_5_sys::SUNLinSol_Dense(y0.as_raw(), matrix.as_ptr()) }; + let linsolver = unsafe { sundials_sys::SUNLinSol_Dense(y0.as_raw(), matrix.as_ptr()) }; check_non_null(linsolver, "SUNDenseLinearSolver")? }; let user_data = Box::pin(WrappingUserData { @@ -222,20 +222,17 @@ where }; { let flag = unsafe { - cvode_5_sys::CVodeSetUserData( + sundials_sys::CVodeSetUserData( mem.as_raw(), res.user_data.as_ref().get_ref() as *const _ as _, ) }; check_flag_is_succes(flag, "CVodeSetUserData")?; } - for v in res.y_s0.as_ref() { - unsafe { N_VPrint(v.as_raw()) } - } { let fn_ptr = wrap_f:: as extern "C" fn(_, _, _, _) -> _; let flag = unsafe { - cvode_5_sys::CVodeInit( + sundials_sys::CVodeInit( mem.as_raw(), Some(std::mem::transmute(fn_ptr)), t0, @@ -248,7 +245,7 @@ where let fn_ptr = wrap_f_sens:: as extern "C" fn(_, _, _, _, _, _, _, _, _) -> _; let flag = unsafe { - cvode_5_sys::CVodeSensInit( + sundials_sys::CVodeSensInit( mem.as_raw(), N_SENSI as c_int, CV_STAGGERED as _, @@ -256,47 +253,51 @@ where res.y_s0.as_ptr() as _, ) }; - check_flag_is_succes(flag, "CVodeInit")?; + check_flag_is_succes(flag, "CVodeSensInit")?; } match &res.atol { &AbsTolerance::Scalar(atol) => { - let flag = unsafe { cvode_5_sys::CVodeSStolerances(mem.as_raw(), rtol, atol) }; + let flag = unsafe { sundials_sys::CVodeSStolerances(mem.as_raw(), rtol, atol) }; check_flag_is_succes(flag, "CVodeSStolerances")?; } AbsTolerance::Vector(atol) => { let flag = - unsafe { cvode_5_sys::CVodeSVtolerances(mem.as_raw(), rtol, atol.as_raw()) }; + unsafe { sundials_sys::CVodeSVtolerances(mem.as_raw(), rtol, atol.as_raw()) }; check_flag_is_succes(flag, "CVodeSVtolerances")?; } } match &res.atol { &AbsTolerance::Scalar(atol) => { - let flag = unsafe { cvode_5_sys::CVodeSStolerances(mem.as_raw(), rtol, atol) }; + let flag = unsafe { sundials_sys::CVodeSStolerances(mem.as_raw(), rtol, atol) }; check_flag_is_succes(flag, "CVodeSStolerances")?; } AbsTolerance::Vector(atol) => { let flag = - unsafe { cvode_5_sys::CVodeSVtolerances(mem.as_raw(), rtol, atol.as_raw()) }; + unsafe { sundials_sys::CVodeSVtolerances(mem.as_raw(), rtol, atol.as_raw()) }; check_flag_is_succes(flag, "CVodeSVtolerances")?; } } match &res.atol_sens { SensiAbsTolerance::Scalar(atol) => { let flag = unsafe { - cvode_5_sys::CVodeSensSStolerances(mem.as_raw(), rtol, atol.as_ptr() as _) + sundials_sys::CVodeSensSStolerances(mem.as_raw(), rtol, atol.as_ptr() as _) }; check_flag_is_succes(flag, "CVodeSensSStolerances")?; } SensiAbsTolerance::Vector(atol) => { let flag = unsafe { - cvode_5_sys::CVodeSensSVtolerances(mem.as_raw(), rtol, atol.as_ptr() as _) + sundials_sys::CVodeSensSVtolerances(mem.as_raw(), rtol, atol.as_ptr() as _) }; check_flag_is_succes(flag, "CVodeSVtolerances")?; } } { let flag = unsafe { - cvode_5_sys::CVodeSetLinearSolver(mem.as_raw(), linsolver.as_ptr(), matrix.as_ptr()) + sundials_sys::CVodeSetLinearSolver( + mem.as_raw(), + linsolver.as_ptr(), + matrix.as_ptr(), + ) }; check_flag_is_succes(flag, "CVodeSetLinearSolver")?; } @@ -311,7 +312,7 @@ where ) -> Result<(Realtype, &[Realtype; N], [&[Realtype; N]; N_SENSI])> { let mut tret = 0.; let flag = unsafe { - cvode_5_sys::CVode( + sundials_sys::CVode( self.mem.as_raw(), tout, self.y0.as_raw(), @@ -321,7 +322,7 @@ where }; check_flag_is_succes(flag, "CVode")?; let flag = unsafe { - cvode_5_sys::CVodeGetSens( + sundials_sys::CVodeGetSens( self.mem.as_raw(), &mut tret, self.sensi_out_buffer.as_mut_ptr() as _, @@ -338,9 +339,9 @@ impl Drop for Solver { fn drop(&mut self) { - unsafe { cvode_5_sys::CVodeFree(&mut self.mem.as_raw()) } - unsafe { cvode_5_sys::SUNLinSolFree(self.linsolver) }; - unsafe { cvode_5_sys::SUNMatDestroy(self.sunmatrix) }; + unsafe { sundials_sys::CVodeFree(&mut self.mem.as_raw()) } + unsafe { sundials_sys::SUNLinSolFree(self.linsolver) }; + unsafe { sundials_sys::SUNMatDestroy(self.sunmatrix) }; } } diff --git a/cvode-wrap/src/lib.rs b/cvode-wrap/src/lib.rs index 81416ec..ab828a9 100644 --- a/cvode-wrap/src/lib.rs +++ b/cvode-wrap/src/lib.rs @@ -1,6 +1,6 @@ use std::{os::raw::c_int, ptr::NonNull}; -use cvode_5_sys::realtype; +use sundials_sys::realtype; mod nvector; pub use nvector::{NVectorSerial, NVectorSerialHeapAllocated}; @@ -11,14 +11,14 @@ pub mod cvode_sens; /// The floatting-point type sundials was compiled with pub type Realtype = realtype; -#[repr(u32)] +#[repr(i32)] #[derive(Debug)] /// An integration method. pub enum LinearMultistepMethod { /// Recomended for non-stiff problems. - Adams = cvode_5_sys::CV_ADAMS, + Adams = sundials_sys::CV_ADAMS, /// Recommended for stiff problems. - Bdf = cvode_5_sys::CV_BDF, + Bdf = sundials_sys::CV_BDF, } /// A return type for the right-hand-side rust function. @@ -43,17 +43,17 @@ pub enum RhsResult { } /// Type of integration step -#[repr(u32)] +#[repr(i32)] pub enum StepKind { /// The `NORMAL`option causes the solver to take internal steps /// until it has reached or just passed the user-specified time. /// The solver then interpolates in order to return an approximate /// value of y at the desired time. - Normal = cvode_5_sys::CV_NORMAL, + Normal = sundials_sys::CV_NORMAL, /// The `CV_ONE_STEP` option tells the solver to take just one /// internal step and then return thesolution at the point reached /// by that step. - OneStep = cvode_5_sys::CV_ONE_STEP, + OneStep = sundials_sys::CV_ONE_STEP, } /// The error type for this crate @@ -88,7 +88,7 @@ fn check_non_null(ptr: *mut T, func_id: &'static str) -> Result> { } fn check_flag_is_succes(flag: c_int, func_id: &'static str) -> Result<()> { - if flag == cvode_5_sys::CV_SUCCESS as i32 { + if flag == sundials_sys::CV_SUCCESS { Ok(()) } else { Err(Error::ErrorCode { flag, func_id }) diff --git a/cvode-wrap/src/nvector.rs b/cvode-wrap/src/nvector.rs index e7f1e98..67724b5 100644 --- a/cvode-wrap/src/nvector.rs +++ b/cvode-wrap/src/nvector.rs @@ -4,28 +4,30 @@ use std::{ ptr::NonNull, }; -use cvode_5_sys::realtype; +use sundials_sys::realtype; /// A sundials `N_Vector_Serial`. #[repr(transparent)] #[derive(Debug)] pub struct NVectorSerial { - inner: cvode_5_sys::_generic_N_Vector, + inner: sundials_sys::_generic_N_Vector, } impl NVectorSerial { - pub(crate) unsafe fn as_raw(&self) -> cvode_5_sys::N_Vector { + pub(crate) unsafe fn as_raw(&self) -> sundials_sys::N_Vector { std::mem::transmute(&self.inner) } /// Returns a reference to the inner slice of the vector. pub fn as_slice(&self) -> &[realtype; SIZE] { - unsafe { &*(cvode_5_sys::N_VGetArrayPointer_Serial(self.as_raw()) as *const [f64; SIZE]) } + unsafe { &*(sundials_sys::N_VGetArrayPointer_Serial(self.as_raw()) as *const [f64; SIZE]) } } /// Returns a mutable reference to the inner slice of the vector. pub fn as_slice_mut(&mut self) -> &mut [realtype; SIZE] { - unsafe { &mut *(cvode_5_sys::N_VGetArrayPointer_Serial(self.as_raw()) as *mut [f64; SIZE]) } + unsafe { + &mut *(sundials_sys::N_VGetArrayPointer_Serial(self.as_raw()) as *mut [f64; SIZE]) + } } } @@ -52,7 +54,7 @@ impl DerefMut for NVectorSerialHeapAllocated { impl NVectorSerialHeapAllocated { unsafe fn new_inner_uninitialized() -> NonNull> { - let raw_c = cvode_5_sys::N_VNew_Serial(SIZE.try_into().unwrap()); + let raw_c = sundials_sys::N_VNew_Serial(SIZE.try_into().unwrap()); NonNull::new(raw_c as *mut NVectorSerial).unwrap() } @@ -60,7 +62,7 @@ impl NVectorSerialHeapAllocated { pub fn new() -> Self { let inner = unsafe { let x = Self::new_inner_uninitialized(); - let ptr = cvode_5_sys::N_VGetArrayPointer_Serial(x.as_ref().as_raw()); + let ptr = sundials_sys::N_VGetArrayPointer_Serial(x.as_ref().as_raw()); for off in 0..SIZE { *ptr.add(off) = 0.; } @@ -73,7 +75,7 @@ impl NVectorSerialHeapAllocated { pub fn new_from(data: &[realtype; SIZE]) -> Self { let inner = unsafe { let x = Self::new_inner_uninitialized(); - let ptr = cvode_5_sys::N_VGetArrayPointer_Serial(x.as_ref().as_raw()); + let ptr = sundials_sys::N_VGetArrayPointer_Serial(x.as_ref().as_raw()); std::ptr::copy_nonoverlapping(data.as_ptr(), ptr, SIZE); x }; @@ -83,6 +85,6 @@ impl NVectorSerialHeapAllocated { impl Drop for NVectorSerialHeapAllocated { fn drop(&mut self) { - unsafe { cvode_5_sys::N_VDestroy(self.as_raw()) } + unsafe { sundials_sys::N_VDestroy(self.as_raw()) } } } diff --git a/example/src/main.rs b/example/src/main.rs index 7af6696..bc9099b 100644 --- a/example/src/main.rs +++ b/example/src/main.rs @@ -10,7 +10,7 @@ fn main() { RhsResult::Ok } // If there is any command line argument compute the sensitivities, else don't. - if args().nth(1).is_none() { + if false && args().nth(1).is_none() { //initialize the solver let mut solver = cvode::Solver::new( LinearMultistepMethod::Adams,