Syntaxic switch, bugs remain

This commit is contained in:
Arthur Carcano 2021-06-10 17:03:00 +02:00
parent 29f503861f
commit e2ee2b8a93
6 changed files with 63 additions and 55 deletions

View File

@ -7,5 +7,6 @@ edition = "2018"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
cvode-5-sys = {path = "../cvode-5-sys"}
#cvode-5-sys = {path = "../cvode-5-sys"}
sundials-sys = {path = "../../sundials-sys", default-features=false, features=["cvodes"]}
array-init = "2.0"

View File

@ -1,6 +1,6 @@
use std::{convert::TryInto, ffi::c_void, os::raw::c_int, pin::Pin, ptr::NonNull};
use cvode_5_sys::{SUNLinearSolver, SUNMatrix};
use sundials_sys::{SUNLinearSolver, SUNMatrix};
use crate::{
check_flag_is_succes, check_non_null, AbsTolerance, LinearMultistepMethod, NVectorSerial,
@ -99,18 +99,18 @@ where
) -> Result<Self> {
assert_eq!(y0.len(), N);
let mem: CvodeMemoryBlockNonNullPtr = {
let mem_maybenull = unsafe { cvode_5_sys::CVodeCreate(method as c_int) };
let mem_maybenull = unsafe { sundials_sys::CVodeCreate(method as c_int) };
check_non_null(mem_maybenull as *mut CvodeMemoryBlock, "CVodeCreate")?.into()
};
let y0 = NVectorSerialHeapAllocated::new_from(y0);
let matrix = {
let matrix = unsafe {
cvode_5_sys::SUNDenseMatrix(N.try_into().unwrap(), N.try_into().unwrap())
sundials_sys::SUNDenseMatrix(N.try_into().unwrap(), N.try_into().unwrap())
};
check_non_null(matrix, "SUNDenseMatrix")?
};
let linsolver = {
let linsolver = unsafe { cvode_5_sys::SUNLinSol_Dense(y0.as_raw(), matrix.as_ptr()) };
let linsolver = unsafe { sundials_sys::SUNLinSol_Dense(y0.as_raw(), matrix.as_ptr()) };
check_non_null(linsolver, "SUNDenseLinearSolver")?
};
let user_data = Box::pin(WrappingUserData {
@ -128,7 +128,7 @@ where
{
let fn_ptr = wrap_f::<UserData, F, N> as extern "C" fn(_, _, _, _) -> _;
let flag = unsafe {
cvode_5_sys::CVodeInit(
sundials_sys::CVodeInit(
mem.as_raw(),
Some(std::mem::transmute(fn_ptr)),
t0,
@ -139,24 +139,28 @@ where
}
match &res.atol {
&AbsTolerance::Scalar(atol) => {
let flag = unsafe { cvode_5_sys::CVodeSStolerances(mem.as_raw(), rtol, atol) };
let flag = unsafe { sundials_sys::CVodeSStolerances(mem.as_raw(), rtol, atol) };
check_flag_is_succes(flag, "CVodeSStolerances")?;
}
AbsTolerance::Vector(atol) => {
let flag =
unsafe { cvode_5_sys::CVodeSVtolerances(mem.as_raw(), rtol, atol.as_raw()) };
unsafe { sundials_sys::CVodeSVtolerances(mem.as_raw(), rtol, atol.as_raw()) };
check_flag_is_succes(flag, "CVodeSVtolerances")?;
}
}
{
let flag = unsafe {
cvode_5_sys::CVodeSetLinearSolver(mem.as_raw(), linsolver.as_ptr(), matrix.as_ptr())
sundials_sys::CVodeSetLinearSolver(
mem.as_raw(),
linsolver.as_ptr(),
matrix.as_ptr(),
)
};
check_flag_is_succes(flag, "CVodeSetLinearSolver")?;
}
{
let flag = unsafe {
cvode_5_sys::CVodeSetUserData(
sundials_sys::CVodeSetUserData(
mem.as_raw(),
std::mem::transmute(res.user_data.as_ref().get_ref()),
)
@ -173,7 +177,7 @@ where
) -> Result<(Realtype, &[Realtype; N])> {
let mut tret = 0.;
let flag = unsafe {
cvode_5_sys::CVode(
sundials_sys::CVode(
self.mem.as_raw(),
tout,
self.y0.as_raw(),
@ -188,9 +192,9 @@ where
impl<UserData, F, const N: usize> Drop for Solver<UserData, F, N> {
fn drop(&mut self) {
unsafe { cvode_5_sys::CVodeFree(&mut self.mem.as_raw()) }
unsafe { cvode_5_sys::SUNLinSolFree(self.linsolver) };
unsafe { cvode_5_sys::SUNMatDestroy(self.sunmatrix) };
unsafe { sundials_sys::CVodeFree(&mut self.mem.as_raw()) }
unsafe { sundials_sys::SUNLinSolFree(self.linsolver) };
unsafe { sundials_sys::SUNMatDestroy(self.sunmatrix) };
}
}

View File

@ -1,6 +1,6 @@
use std::{convert::TryInto, ffi::c_void, os::raw::c_int, pin::Pin, ptr::NonNull};
use cvode_5_sys::{N_VPrint, SUNLinearSolver, SUNMatrix, CV_STAGGERED};
use sundials_sys::{SUNLinearSolver, SUNMatrix, CV_STAGGERED};
use crate::{
check_flag_is_succes, check_non_null, AbsTolerance, LinearMultistepMethod, NVectorSerial,
@ -183,7 +183,7 @@ where
) -> Result<Self> {
assert_eq!(y0.len(), N);
let mem: CvodeMemoryBlockNonNullPtr = {
let mem_maybenull = unsafe { cvode_5_sys::CVodeCreate(method as c_int) };
let mem_maybenull = unsafe { sundials_sys::CVodeCreate(method as c_int) };
check_non_null(mem_maybenull as *mut CvodeMemoryBlock, "CVodeCreate")?.into()
};
let y0 = NVectorSerialHeapAllocated::new_from(y0);
@ -196,12 +196,12 @@ where
);
let matrix = {
let matrix = unsafe {
cvode_5_sys::SUNDenseMatrix(N.try_into().unwrap(), N.try_into().unwrap())
sundials_sys::SUNDenseMatrix(N.try_into().unwrap(), N.try_into().unwrap())
};
check_non_null(matrix, "SUNDenseMatrix")?
};
let linsolver = {
let linsolver = unsafe { cvode_5_sys::SUNLinSol_Dense(y0.as_raw(), matrix.as_ptr()) };
let linsolver = unsafe { sundials_sys::SUNLinSol_Dense(y0.as_raw(), matrix.as_ptr()) };
check_non_null(linsolver, "SUNDenseLinearSolver")?
};
let user_data = Box::pin(WrappingUserData {
@ -222,20 +222,17 @@ where
};
{
let flag = unsafe {
cvode_5_sys::CVodeSetUserData(
sundials_sys::CVodeSetUserData(
mem.as_raw(),
res.user_data.as_ref().get_ref() as *const _ as _,
)
};
check_flag_is_succes(flag, "CVodeSetUserData")?;
}
for v in res.y_s0.as_ref() {
unsafe { N_VPrint(v.as_raw()) }
}
{
let fn_ptr = wrap_f::<UserData, F, FS, N> as extern "C" fn(_, _, _, _) -> _;
let flag = unsafe {
cvode_5_sys::CVodeInit(
sundials_sys::CVodeInit(
mem.as_raw(),
Some(std::mem::transmute(fn_ptr)),
t0,
@ -248,7 +245,7 @@ where
let fn_ptr = wrap_f_sens::<UserData, F, FS, N, N_SENSI>
as extern "C" fn(_, _, _, _, _, _, _, _, _) -> _;
let flag = unsafe {
cvode_5_sys::CVodeSensInit(
sundials_sys::CVodeSensInit(
mem.as_raw(),
N_SENSI as c_int,
CV_STAGGERED as _,
@ -256,47 +253,51 @@ where
res.y_s0.as_ptr() as _,
)
};
check_flag_is_succes(flag, "CVodeInit")?;
check_flag_is_succes(flag, "CVodeSensInit")?;
}
match &res.atol {
&AbsTolerance::Scalar(atol) => {
let flag = unsafe { cvode_5_sys::CVodeSStolerances(mem.as_raw(), rtol, atol) };
let flag = unsafe { sundials_sys::CVodeSStolerances(mem.as_raw(), rtol, atol) };
check_flag_is_succes(flag, "CVodeSStolerances")?;
}
AbsTolerance::Vector(atol) => {
let flag =
unsafe { cvode_5_sys::CVodeSVtolerances(mem.as_raw(), rtol, atol.as_raw()) };
unsafe { sundials_sys::CVodeSVtolerances(mem.as_raw(), rtol, atol.as_raw()) };
check_flag_is_succes(flag, "CVodeSVtolerances")?;
}
}
match &res.atol {
&AbsTolerance::Scalar(atol) => {
let flag = unsafe { cvode_5_sys::CVodeSStolerances(mem.as_raw(), rtol, atol) };
let flag = unsafe { sundials_sys::CVodeSStolerances(mem.as_raw(), rtol, atol) };
check_flag_is_succes(flag, "CVodeSStolerances")?;
}
AbsTolerance::Vector(atol) => {
let flag =
unsafe { cvode_5_sys::CVodeSVtolerances(mem.as_raw(), rtol, atol.as_raw()) };
unsafe { sundials_sys::CVodeSVtolerances(mem.as_raw(), rtol, atol.as_raw()) };
check_flag_is_succes(flag, "CVodeSVtolerances")?;
}
}
match &res.atol_sens {
SensiAbsTolerance::Scalar(atol) => {
let flag = unsafe {
cvode_5_sys::CVodeSensSStolerances(mem.as_raw(), rtol, atol.as_ptr() as _)
sundials_sys::CVodeSensSStolerances(mem.as_raw(), rtol, atol.as_ptr() as _)
};
check_flag_is_succes(flag, "CVodeSensSStolerances")?;
}
SensiAbsTolerance::Vector(atol) => {
let flag = unsafe {
cvode_5_sys::CVodeSensSVtolerances(mem.as_raw(), rtol, atol.as_ptr() as _)
sundials_sys::CVodeSensSVtolerances(mem.as_raw(), rtol, atol.as_ptr() as _)
};
check_flag_is_succes(flag, "CVodeSVtolerances")?;
}
}
{
let flag = unsafe {
cvode_5_sys::CVodeSetLinearSolver(mem.as_raw(), linsolver.as_ptr(), matrix.as_ptr())
sundials_sys::CVodeSetLinearSolver(
mem.as_raw(),
linsolver.as_ptr(),
matrix.as_ptr(),
)
};
check_flag_is_succes(flag, "CVodeSetLinearSolver")?;
}
@ -311,7 +312,7 @@ where
) -> Result<(Realtype, &[Realtype; N], [&[Realtype; N]; N_SENSI])> {
let mut tret = 0.;
let flag = unsafe {
cvode_5_sys::CVode(
sundials_sys::CVode(
self.mem.as_raw(),
tout,
self.y0.as_raw(),
@ -321,7 +322,7 @@ where
};
check_flag_is_succes(flag, "CVode")?;
let flag = unsafe {
cvode_5_sys::CVodeGetSens(
sundials_sys::CVodeGetSens(
self.mem.as_raw(),
&mut tret,
self.sensi_out_buffer.as_mut_ptr() as _,
@ -338,9 +339,9 @@ impl<UserData, F, FS, const N: usize, const N_SENSI: usize> Drop
for Solver<UserData, F, FS, N, N_SENSI>
{
fn drop(&mut self) {
unsafe { cvode_5_sys::CVodeFree(&mut self.mem.as_raw()) }
unsafe { cvode_5_sys::SUNLinSolFree(self.linsolver) };
unsafe { cvode_5_sys::SUNMatDestroy(self.sunmatrix) };
unsafe { sundials_sys::CVodeFree(&mut self.mem.as_raw()) }
unsafe { sundials_sys::SUNLinSolFree(self.linsolver) };
unsafe { sundials_sys::SUNMatDestroy(self.sunmatrix) };
}
}

View File

@ -1,6 +1,6 @@
use std::{os::raw::c_int, ptr::NonNull};
use cvode_5_sys::realtype;
use sundials_sys::realtype;
mod nvector;
pub use nvector::{NVectorSerial, NVectorSerialHeapAllocated};
@ -11,14 +11,14 @@ pub mod cvode_sens;
/// The floatting-point type sundials was compiled with
pub type Realtype = realtype;
#[repr(u32)]
#[repr(i32)]
#[derive(Debug)]
/// An integration method.
pub enum LinearMultistepMethod {
/// Recomended for non-stiff problems.
Adams = cvode_5_sys::CV_ADAMS,
Adams = sundials_sys::CV_ADAMS,
/// Recommended for stiff problems.
Bdf = cvode_5_sys::CV_BDF,
Bdf = sundials_sys::CV_BDF,
}
/// A return type for the right-hand-side rust function.
@ -43,17 +43,17 @@ pub enum RhsResult {
}
/// Type of integration step
#[repr(u32)]
#[repr(i32)]
pub enum StepKind {
/// The `NORMAL`option causes the solver to take internal steps
/// until it has reached or just passed the user-specified time.
/// The solver then interpolates in order to return an approximate
/// value of y at the desired time.
Normal = cvode_5_sys::CV_NORMAL,
Normal = sundials_sys::CV_NORMAL,
/// The `CV_ONE_STEP` option tells the solver to take just one
/// internal step and then return thesolution at the point reached
/// by that step.
OneStep = cvode_5_sys::CV_ONE_STEP,
OneStep = sundials_sys::CV_ONE_STEP,
}
/// The error type for this crate
@ -88,7 +88,7 @@ fn check_non_null<T>(ptr: *mut T, func_id: &'static str) -> Result<NonNull<T>> {
}
fn check_flag_is_succes(flag: c_int, func_id: &'static str) -> Result<()> {
if flag == cvode_5_sys::CV_SUCCESS as i32 {
if flag == sundials_sys::CV_SUCCESS {
Ok(())
} else {
Err(Error::ErrorCode { flag, func_id })

View File

@ -4,28 +4,30 @@ use std::{
ptr::NonNull,
};
use cvode_5_sys::realtype;
use sundials_sys::realtype;
/// A sundials `N_Vector_Serial`.
#[repr(transparent)]
#[derive(Debug)]
pub struct NVectorSerial<const SIZE: usize> {
inner: cvode_5_sys::_generic_N_Vector,
inner: sundials_sys::_generic_N_Vector,
}
impl<const SIZE: usize> NVectorSerial<SIZE> {
pub(crate) unsafe fn as_raw(&self) -> cvode_5_sys::N_Vector {
pub(crate) unsafe fn as_raw(&self) -> sundials_sys::N_Vector {
std::mem::transmute(&self.inner)
}
/// Returns a reference to the inner slice of the vector.
pub fn as_slice(&self) -> &[realtype; SIZE] {
unsafe { &*(cvode_5_sys::N_VGetArrayPointer_Serial(self.as_raw()) as *const [f64; SIZE]) }
unsafe { &*(sundials_sys::N_VGetArrayPointer_Serial(self.as_raw()) as *const [f64; SIZE]) }
}
/// Returns a mutable reference to the inner slice of the vector.
pub fn as_slice_mut(&mut self) -> &mut [realtype; SIZE] {
unsafe { &mut *(cvode_5_sys::N_VGetArrayPointer_Serial(self.as_raw()) as *mut [f64; SIZE]) }
unsafe {
&mut *(sundials_sys::N_VGetArrayPointer_Serial(self.as_raw()) as *mut [f64; SIZE])
}
}
}
@ -52,7 +54,7 @@ impl<const SIZE: usize> DerefMut for NVectorSerialHeapAllocated<SIZE> {
impl<const SIZE: usize> NVectorSerialHeapAllocated<SIZE> {
unsafe fn new_inner_uninitialized() -> NonNull<NVectorSerial<SIZE>> {
let raw_c = cvode_5_sys::N_VNew_Serial(SIZE.try_into().unwrap());
let raw_c = sundials_sys::N_VNew_Serial(SIZE.try_into().unwrap());
NonNull::new(raw_c as *mut NVectorSerial<SIZE>).unwrap()
}
@ -60,7 +62,7 @@ impl<const SIZE: usize> NVectorSerialHeapAllocated<SIZE> {
pub fn new() -> Self {
let inner = unsafe {
let x = Self::new_inner_uninitialized();
let ptr = cvode_5_sys::N_VGetArrayPointer_Serial(x.as_ref().as_raw());
let ptr = sundials_sys::N_VGetArrayPointer_Serial(x.as_ref().as_raw());
for off in 0..SIZE {
*ptr.add(off) = 0.;
}
@ -73,7 +75,7 @@ impl<const SIZE: usize> NVectorSerialHeapAllocated<SIZE> {
pub fn new_from(data: &[realtype; SIZE]) -> Self {
let inner = unsafe {
let x = Self::new_inner_uninitialized();
let ptr = cvode_5_sys::N_VGetArrayPointer_Serial(x.as_ref().as_raw());
let ptr = sundials_sys::N_VGetArrayPointer_Serial(x.as_ref().as_raw());
std::ptr::copy_nonoverlapping(data.as_ptr(), ptr, SIZE);
x
};
@ -83,6 +85,6 @@ impl<const SIZE: usize> NVectorSerialHeapAllocated<SIZE> {
impl<const SIZE: usize> Drop for NVectorSerialHeapAllocated<SIZE> {
fn drop(&mut self) {
unsafe { cvode_5_sys::N_VDestroy(self.as_raw()) }
unsafe { sundials_sys::N_VDestroy(self.as_raw()) }
}
}

View File

@ -10,7 +10,7 @@ fn main() {
RhsResult::Ok
}
// If there is any command line argument compute the sensitivities, else don't.
if args().nth(1).is_none() {
if false && args().nth(1).is_none() {
//initialize the solver
let mut solver = cvode::Solver::new(
LinearMultistepMethod::Adams,