Clippy fixes. Store SUNContext in a NonNull wrapper

This commit is contained in:
Anne de Jong 2024-12-20 16:26:45 +01:00
parent 343cfbcb7e
commit bf132d928f
4 changed files with 35 additions and 23 deletions

View File

@ -2,12 +2,12 @@
use std::{convert::TryInto, os::raw::c_int, pin::Pin}; use std::{convert::TryInto, os::raw::c_int, pin::Pin};
use sundials_sys::{SUNComm, SUNContext, SUNLinearSolver, SUNMatrix}; use sundials_sys::{SUNLinearSolver, SUNMatrix};
use crate::{ use crate::{
check_flag_is_succes, check_non_null, sundials_create_context, sundials_free_context, check_flag_is_succes, check_non_null, sundials_create_context, sundials_free_context,
AbsTolerance, CvodeMemoryBlock, CvodeMemoryBlockNonNullPtr, LinearMultistepMethod, AbsTolerance, CvodeMemoryBlock, CvodeMemoryBlockNonNullPtr, LinearMultistepMethod,
NVectorSerial, NVectorSerialHeapAllocated, Realtype, Result, RhsResult, StepKind, NVectorSerial, NVectorSerialHeapAllocated, Realtype, Result, RhsResult, StepKind, SunContext,
}; };
struct WrappingUserData<UserData, F> { struct WrappingUserData<UserData, F> {
@ -32,7 +32,7 @@ pub struct Solver<UserData, F, const N: usize> {
linsolver: SUNLinearSolver, linsolver: SUNLinearSolver,
atol: AbsTolerance<N>, atol: AbsTolerance<N>,
user_data: Pin<Box<WrappingUserData<UserData, F>>>, user_data: Pin<Box<WrappingUserData<UserData, F>>>,
context: SUNContext, context: SunContext,
} }
extern "C" fn wrap_f<UserData, F, const N: usize>( extern "C" fn wrap_f<UserData, F, const N: usize>(
@ -77,19 +77,19 @@ where
assert_eq!(y0.len(), N); assert_eq!(y0.len(), N);
let mem: CvodeMemoryBlockNonNullPtr = { let mem: CvodeMemoryBlockNonNullPtr = {
let mem_maybenull = unsafe { sundials_sys::CVodeCreate(method as c_int, context) }; let mem_maybenull = unsafe { sundials_sys::CVodeCreate(method as c_int, context.as_ptr()) };
check_non_null(mem_maybenull as *mut CvodeMemoryBlock, "CVodeCreate")?.into() check_non_null(mem_maybenull as *mut CvodeMemoryBlock, "CVodeCreate")?.into()
}; };
let y0 = NVectorSerialHeapAllocated::new_from(y0, context); let y0 = NVectorSerialHeapAllocated::new_from(y0, context);
let matrix = { let matrix = {
let matrix = unsafe { let matrix = unsafe {
sundials_sys::SUNDenseMatrix(N.try_into().unwrap(), N.try_into().unwrap(), context) sundials_sys::SUNDenseMatrix(N.try_into().unwrap(), N.try_into().unwrap(), context.as_ptr())
}; };
check_non_null(matrix, "SUNDenseMatrix")? check_non_null(matrix, "SUNDenseMatrix")?
}; };
let linsolver = { let linsolver = {
let linsolver = let linsolver =
unsafe { sundials_sys::SUNLinSol_Dense(y0.as_raw(), matrix.as_ptr(), context) }; unsafe { sundials_sys::SUNLinSol_Dense(y0.as_raw(), matrix.as_ptr(), context.as_ptr()) };
check_non_null(linsolver, "SUNDenseLinearSolver")? check_non_null(linsolver, "SUNDenseLinearSolver")?
}; };
let user_data = Box::pin(WrappingUserData { let user_data = Box::pin(WrappingUserData {

View File

@ -125,7 +125,7 @@ where
) -> RhsResult, ) -> RhsResult,
{ {
/// Creates a new solver. /// Creates a new solver.
#[allow(clippy::clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
pub fn new( pub fn new(
method: LinearMultistepMethod, method: LinearMultistepMethod,
f: F, f: F,
@ -141,7 +141,8 @@ where
assert_eq!(y0.len(), N); assert_eq!(y0.len(), N);
let context = sundials_create_context()?; let context = sundials_create_context()?;
let mem: CvodeMemoryBlockNonNullPtr = { let mem: CvodeMemoryBlockNonNullPtr = {
let mem_maybenull = unsafe { sundials_sys::CVodeCreate(method as c_int, context) }; let mem_maybenull =
unsafe { sundials_sys::CVodeCreate(method as c_int, context.as_ptr()) };
check_non_null(mem_maybenull as *mut CvodeMemoryBlock, "CVodeCreate")?.into() check_non_null(mem_maybenull as *mut CvodeMemoryBlock, "CVodeCreate")?.into()
}; };
let y0 = NVectorSerialHeapAllocated::new_from(y0, context); let y0 = NVectorSerialHeapAllocated::new_from(y0, context);
@ -154,13 +155,18 @@ where
); );
let matrix = { let matrix = {
let matrix = unsafe { let matrix = unsafe {
sundials_sys::SUNDenseMatrix(N.try_into().unwrap(), N.try_into().unwrap(), context) sundials_sys::SUNDenseMatrix(
N.try_into().unwrap(),
N.try_into().unwrap(),
context.as_ptr(),
)
}; };
check_non_null(matrix, "SUNDenseMatrix")? check_non_null(matrix, "SUNDenseMatrix")?
}; };
let linsolver = { let linsolver = {
let linsolver = let linsolver = unsafe {
unsafe { sundials_sys::SUNLinSol_Dense(y0.as_raw(), matrix.as_ptr(), context) }; sundials_sys::SUNLinSol_Dense(y0.as_raw(), matrix.as_ptr(), context.as_ptr())
};
check_non_null(linsolver, "SUNDenseLinearSolver")? check_non_null(linsolver, "SUNDenseLinearSolver")?
}; };
let user_data = Box::pin(WrappingUserData { let user_data = Box::pin(WrappingUserData {
@ -177,7 +183,9 @@ where
atol, atol,
atol_sens, atol_sens,
user_data, user_data,
sensi_out_buffer: array_init::array_init(|_| NVectorSerialHeapAllocated::new(context)), sensi_out_buffer: array_init::array_init(|_| {
NVectorSerialHeapAllocated::new(context.as_ptr())
}),
}; };
{ {
let flag = unsafe { let flag = unsafe {
@ -258,7 +266,7 @@ where
/// reached by the solver as dictated by `step_kind`, `y(t_out)` is an /// reached by the solver as dictated by `step_kind`, `y(t_out)` is an
/// array of the state variables at that time, and the i-th `dy_dp(tout)` is an array /// array of the state variables at that time, and the i-th `dy_dp(tout)` is an array
/// of the sensitivities of all variables with respect to parameter i. /// of the sensitivities of all variables with respect to parameter i.
#[allow(clippy::clippy::type_complexity)] #[allow(clippy::type_complexity)]
pub fn step( pub fn step(
&mut self, &mut self,
tout: Realtype, tout: Realtype,

View File

@ -105,7 +105,7 @@
//! ``` //! ```
use std::{ffi::c_void, os::raw::c_int, ptr::NonNull}; use std::{ffi::c_void, os::raw::c_int, ptr::NonNull};
use sundials_sys::{realtype, SUNComm, SUNContext}; use sundials_sys::{realtype, SUNComm, SUNContext, _SUNContext};
mod nvector; mod nvector;
pub use nvector::{NVectorSerial, NVectorSerialHeapAllocated}; pub use nvector::{NVectorSerial, NVectorSerialHeapAllocated};
@ -182,7 +182,7 @@ impl<const SIZE: usize> AbsTolerance<SIZE> {
AbsTolerance::Scalar(atol) AbsTolerance::Scalar(atol)
} }
pub fn vector(atol: &[Realtype; SIZE], context: SUNContext) -> Self { pub fn vector(atol: &[Realtype; SIZE], context: SunContext) -> Self {
let atol = NVectorSerialHeapAllocated::new_from(atol, context); let atol = NVectorSerialHeapAllocated::new_from(atol, context);
AbsTolerance::Vector(atol) AbsTolerance::Vector(atol)
} }
@ -200,7 +200,7 @@ impl<const SIZE: usize, const N_SENSI: usize> SensiAbsTolerance<SIZE, N_SENSI> {
SensiAbsTolerance::Scalar(atol) SensiAbsTolerance::Scalar(atol)
} }
pub fn vector(atol: &[[Realtype; SIZE]; N_SENSI], context: SUNContext) -> Self { pub fn vector(atol: &[[Realtype; SIZE]; N_SENSI], context: SunContext) -> Self {
SensiAbsTolerance::Vector( SensiAbsTolerance::Vector(
array_init::from_iter( array_init::from_iter(
atol.iter() atol.iter()
@ -226,18 +226,22 @@ fn check_flag_is_succes(flag: c_int, func_id: &'static str) -> Result<()> {
} }
} }
fn sundials_create_context() -> Result<SUNContext> { pub type SunContext = std::ptr::NonNull<_SUNContext>;
fn sundials_create_context() -> Result<SunContext> {
let context = unsafe { let context = unsafe {
let mut context: SUNContext = std::ptr::null_mut(); let mut context: SUNContext = std::ptr::null_mut();
let ompi_communicator_t: SUNComm = std::ptr::null_mut(); let ompi_communicator_t: SUNComm = std::ptr::null_mut();
sundials_sys::SUNContext_Create(ompi_communicator_t, &mut context); sundials_sys::SUNContext_Create(ompi_communicator_t, &mut context);
check_non_null(context, "SUNContext_Create")?; check_non_null(context, "SUNContext_Create")?;
context std::ptr::NonNull::new(context).unwrap()
}; };
Ok(context) Ok(context)
} }
fn sundials_free_context(mut context: SUNContext) -> Result<()> { fn sundials_free_context(context: SunContext) -> Result<()> {
unsafe { sundials_sys::SUNContext_Free(&mut context) }; let mut ptr = context.as_ptr();
let ptr_ptr: *mut *mut _SUNContext = &mut ptr;
unsafe { sundials_sys::SUNContext_Free(ptr_ptr) };
Ok(()) Ok(())
} }

View File

@ -4,7 +4,7 @@ use std::{
ptr::NonNull, ptr::NonNull,
}; };
use sundials_sys::{realtype, SUNContext}; use sundials_sys::{realtype, SUNContext, _SUNContext};
/// A sundials `N_Vector_Serial`. /// A sundials `N_Vector_Serial`.
#[repr(transparent)] #[repr(transparent)]
@ -72,9 +72,9 @@ impl<const SIZE: usize> NVectorSerialHeapAllocated<SIZE> {
} }
/// Creates a new vector, filled with data from `data`. /// Creates a new vector, filled with data from `data`.
pub fn new_from(data: &[realtype; SIZE], context: SUNContext) -> Self { pub fn new_from(data: &[realtype; SIZE], context: std::ptr::NonNull<_SUNContext>) -> Self {
let inner = unsafe { let inner = unsafe {
let x = Self::new_inner_uninitialized(context); let x = Self::new_inner_uninitialized(context.as_ptr());
let ptr = sundials_sys::N_VGetArrayPointer_Serial(x.as_ref().as_raw()); let ptr = sundials_sys::N_VGetArrayPointer_Serial(x.as_ref().as_raw());
std::ptr::copy_nonoverlapping(data.as_ptr(), ptr, SIZE); std::ptr::copy_nonoverlapping(data.as_ptr(), ptr, SIZE);
x x