Updated to sundials-sys v0.6.1.

This commit is contained in:
Anne de Jong 2024-12-20 15:59:53 +01:00
parent be40717df5
commit a6db3259c5
5 changed files with 54 additions and 30 deletions

View File

@ -1,6 +1,6 @@
[package]
name = "cvode-wrap"
version = "0.1.3"
version = "0.1.4"
authors = ["Arthur Carcano <arthur.carcano@inria.fr>"]
edition = "2018"
license = "BSD-3-Clause"
@ -13,9 +13,9 @@ categories=["science","simulation","api-bindings"]
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
sundials-sys = {version="0.2.3", default-features=false, features=["cvodes"]}
sundials-sys = {version="0.6.1", default-features=false, features=["cvodes"]}
array-init = "2.0"
[package.metadata.docs.rs]
features = ["sundials-sys/build_libraries"]
features = ["sundials-sys/build_libraries"]

View File

@ -2,12 +2,12 @@
use std::{convert::TryInto, os::raw::c_int, pin::Pin};
use sundials_sys::{SUNLinearSolver, SUNMatrix};
use sundials_sys::{SUNComm, SUNContext, SUNLinearSolver, SUNMatrix};
use crate::{
check_flag_is_succes, check_non_null, AbsTolerance, CvodeMemoryBlock,
CvodeMemoryBlockNonNullPtr, LinearMultistepMethod, NVectorSerial, NVectorSerialHeapAllocated,
Realtype, Result, RhsResult, StepKind,
check_flag_is_succes, check_non_null, sundials_create_context, sundials_free_context,
AbsTolerance, CvodeMemoryBlock, CvodeMemoryBlockNonNullPtr, LinearMultistepMethod,
NVectorSerial, NVectorSerialHeapAllocated, Realtype, Result, RhsResult, StepKind,
};
struct WrappingUserData<UserData, F> {
@ -32,6 +32,7 @@ pub struct Solver<UserData, F, const N: usize> {
linsolver: SUNLinearSolver,
atol: AbsTolerance<N>,
user_data: Pin<Box<WrappingUserData<UserData, F>>>,
context: SUNContext,
}
extern "C" fn wrap_f<UserData, F, const N: usize>(
@ -71,20 +72,24 @@ where
atol: AbsTolerance<N>,
user_data: UserData,
) -> Result<Self> {
// Create context, required from version 6 and above on
let context = sundials_create_context()?;
assert_eq!(y0.len(), N);
let mem: CvodeMemoryBlockNonNullPtr = {
let mem_maybenull = unsafe { sundials_sys::CVodeCreate(method as c_int) };
let mem_maybenull = unsafe { sundials_sys::CVodeCreate(method as c_int, context) };
check_non_null(mem_maybenull as *mut CvodeMemoryBlock, "CVodeCreate")?.into()
};
let y0 = NVectorSerialHeapAllocated::new_from(y0);
let y0 = NVectorSerialHeapAllocated::new_from(y0, context);
let matrix = {
let matrix = unsafe {
sundials_sys::SUNDenseMatrix(N.try_into().unwrap(), N.try_into().unwrap())
sundials_sys::SUNDenseMatrix(N.try_into().unwrap(), N.try_into().unwrap(), context)
};
check_non_null(matrix, "SUNDenseMatrix")?
};
let linsolver = {
let linsolver = unsafe { sundials_sys::SUNLinSol_Dense(y0.as_raw(), matrix.as_ptr()) };
let linsolver =
unsafe { sundials_sys::SUNLinSol_Dense(y0.as_raw(), matrix.as_ptr(), context) };
check_non_null(linsolver, "SUNDenseLinearSolver")?
};
let user_data = Box::pin(WrappingUserData {
@ -98,6 +103,7 @@ where
linsolver: linsolver.as_ptr(),
atol,
user_data,
context,
};
{
let fn_ptr = wrap_f::<UserData, F, N> as extern "C" fn(_, _, _, _) -> _;
@ -174,6 +180,7 @@ impl<UserData, F, const N: usize> Drop for Solver<UserData, F, N> {
unsafe { sundials_sys::CVodeFree(&mut self.mem.as_raw()) }
unsafe { sundials_sys::SUNLinSolFree(self.linsolver) };
unsafe { sundials_sys::SUNMatDestroy(self.sunmatrix) };
sundials_free_context(self.context);
}
}

View File

@ -5,7 +5,7 @@ use std::{convert::TryInto, os::raw::c_int, pin::Pin};
use sundials_sys::{SUNLinearSolver, SUNMatrix, CV_STAGGERED};
use crate::{
check_flag_is_succes, check_non_null, AbsTolerance, CvodeMemoryBlock,
check_flag_is_succes, check_non_null, sundials_create_context, AbsTolerance, CvodeMemoryBlock,
CvodeMemoryBlockNonNullPtr, LinearMultistepMethod, NVectorSerial, NVectorSerialHeapAllocated,
Realtype, Result, RhsResult, SensiAbsTolerance, StepKind,
};
@ -139,26 +139,28 @@ where
user_data: UserData,
) -> Result<Self> {
assert_eq!(y0.len(), N);
let context = sundials_create_context()?;
let mem: CvodeMemoryBlockNonNullPtr = {
let mem_maybenull = unsafe { sundials_sys::CVodeCreate(method as c_int) };
let mem_maybenull = unsafe { sundials_sys::CVodeCreate(method as c_int, context) };
check_non_null(mem_maybenull as *mut CvodeMemoryBlock, "CVodeCreate")?.into()
};
let y0 = NVectorSerialHeapAllocated::new_from(y0);
let y0 = NVectorSerialHeapAllocated::new_from(y0, context);
let y_s0 = Box::new(
array_init::from_iter(
y_s0.iter()
.map(|arr| NVectorSerialHeapAllocated::new_from(arr)),
.map(|arr| NVectorSerialHeapAllocated::new_from(arr, context)),
)
.unwrap(),
);
let matrix = {
let matrix = unsafe {
sundials_sys::SUNDenseMatrix(N.try_into().unwrap(), N.try_into().unwrap())
sundials_sys::SUNDenseMatrix(N.try_into().unwrap(), N.try_into().unwrap(), context)
};
check_non_null(matrix, "SUNDenseMatrix")?
};
let linsolver = {
let linsolver = unsafe { sundials_sys::SUNLinSol_Dense(y0.as_raw(), matrix.as_ptr()) };
let linsolver =
unsafe { sundials_sys::SUNLinSol_Dense(y0.as_raw(), matrix.as_ptr(), context) };
check_non_null(linsolver, "SUNDenseLinearSolver")?
};
let user_data = Box::pin(WrappingUserData {
@ -175,7 +177,7 @@ where
atol,
atol_sens,
user_data,
sensi_out_buffer: array_init::array_init(|_| NVectorSerialHeapAllocated::new()),
sensi_out_buffer: array_init::array_init(|_| NVectorSerialHeapAllocated::new(context)),
};
{
let flag = unsafe {

View File

@ -105,7 +105,7 @@
//! ```
use std::{ffi::c_void, os::raw::c_int, ptr::NonNull};
use sundials_sys::realtype;
use sundials_sys::{realtype, SUNComm, SUNContext};
mod nvector;
pub use nvector::{NVectorSerial, NVectorSerialHeapAllocated};
@ -182,8 +182,8 @@ impl<const SIZE: usize> AbsTolerance<SIZE> {
AbsTolerance::Scalar(atol)
}
pub fn vector(atol: &[Realtype; SIZE]) -> Self {
let atol = NVectorSerialHeapAllocated::new_from(atol);
pub fn vector(atol: &[Realtype; SIZE], context: SUNContext) -> Self {
let atol = NVectorSerialHeapAllocated::new_from(atol, context);
AbsTolerance::Vector(atol)
}
}
@ -200,11 +200,11 @@ impl<const SIZE: usize, const N_SENSI: usize> SensiAbsTolerance<SIZE, N_SENSI> {
SensiAbsTolerance::Scalar(atol)
}
pub fn vector(atol: &[[Realtype; SIZE]; N_SENSI]) -> Self {
pub fn vector(atol: &[[Realtype; SIZE]; N_SENSI], context: SUNContext) -> Self {
SensiAbsTolerance::Vector(
array_init::from_iter(
atol.iter()
.map(|arr| NVectorSerialHeapAllocated::new_from(arr)),
.map(|arr| NVectorSerialHeapAllocated::new_from(arr, context)),
)
.unwrap(),
)
@ -226,6 +226,21 @@ fn check_flag_is_succes(flag: c_int, func_id: &'static str) -> Result<()> {
}
}
fn sundials_create_context() -> Result<SUNContext> {
let context = unsafe {
let mut context: SUNContext = std::ptr::null_mut();
let ompi_communicator_t: SUNComm = std::ptr::null_mut();
sundials_sys::SUNContext_Create(ompi_communicator_t, &mut context);
check_non_null(context, "SUNContext_Create")?;
context
};
Ok(context)
}
fn sundials_free_context(mut context: SUNContext) -> Result<()> {
unsafe { sundials_sys::SUNContext_Free(&mut context) };
Ok(())
}
#[repr(C)]
struct CvodeMemoryBlock {
_private: [u8; 0],

View File

@ -4,7 +4,7 @@ use std::{
ptr::NonNull,
};
use sundials_sys::realtype;
use sundials_sys::{realtype, SUNContext};
/// A sundials `N_Vector_Serial`.
#[repr(transparent)]
@ -53,15 +53,15 @@ impl<const SIZE: usize> DerefMut for NVectorSerialHeapAllocated<SIZE> {
}
impl<const SIZE: usize> NVectorSerialHeapAllocated<SIZE> {
unsafe fn new_inner_uninitialized() -> NonNull<NVectorSerial<SIZE>> {
let raw_c = sundials_sys::N_VNew_Serial(SIZE.try_into().unwrap());
unsafe fn new_inner_uninitialized(context: SUNContext) -> NonNull<NVectorSerial<SIZE>> {
let raw_c = sundials_sys::N_VNew_Serial(SIZE.try_into().unwrap(), context);
NonNull::new(raw_c as *mut NVectorSerial<SIZE>).unwrap()
}
/// Creates a new vector, filled with 0.
pub fn new() -> Self {
pub fn new(context: SUNContext) -> Self {
let inner = unsafe {
let x = Self::new_inner_uninitialized();
let x = Self::new_inner_uninitialized(context);
let ptr = sundials_sys::N_VGetArrayPointer_Serial(x.as_ref().as_raw());
for off in 0..SIZE {
*ptr.add(off) = 0.;
@ -72,9 +72,9 @@ impl<const SIZE: usize> NVectorSerialHeapAllocated<SIZE> {
}
/// Creates a new vector, filled with data from `data`.
pub fn new_from(data: &[realtype; SIZE]) -> Self {
pub fn new_from(data: &[realtype; SIZE], context: SUNContext) -> Self {
let inner = unsafe {
let x = Self::new_inner_uninitialized();
let x = Self::new_inner_uninitialized(context);
let ptr = sundials_sys::N_VGetArrayPointer_Serial(x.as_ref().as_raw());
std::ptr::copy_nonoverlapping(data.as_ptr(), ptr, SIZE);
x