Clippy fixes. Store SUNContext in a NonNull wrapper

This commit is contained in:
Anne de Jong 2024-12-20 16:26:45 +01:00
parent 343cfbcb7e
commit bf132d928f
4 changed files with 35 additions and 23 deletions

View File

@ -2,12 +2,12 @@
use std::{convert::TryInto, os::raw::c_int, pin::Pin};
use sundials_sys::{SUNComm, SUNContext, SUNLinearSolver, SUNMatrix};
use sundials_sys::{SUNLinearSolver, SUNMatrix};
use crate::{
check_flag_is_succes, check_non_null, sundials_create_context, sundials_free_context,
AbsTolerance, CvodeMemoryBlock, CvodeMemoryBlockNonNullPtr, LinearMultistepMethod,
NVectorSerial, NVectorSerialHeapAllocated, Realtype, Result, RhsResult, StepKind,
NVectorSerial, NVectorSerialHeapAllocated, Realtype, Result, RhsResult, StepKind, SunContext,
};
struct WrappingUserData<UserData, F> {
@ -32,7 +32,7 @@ pub struct Solver<UserData, F, const N: usize> {
linsolver: SUNLinearSolver,
atol: AbsTolerance<N>,
user_data: Pin<Box<WrappingUserData<UserData, F>>>,
context: SUNContext,
context: SunContext,
}
extern "C" fn wrap_f<UserData, F, const N: usize>(
@ -77,19 +77,19 @@ where
assert_eq!(y0.len(), N);
let mem: CvodeMemoryBlockNonNullPtr = {
let mem_maybenull = unsafe { sundials_sys::CVodeCreate(method as c_int, context) };
let mem_maybenull = unsafe { sundials_sys::CVodeCreate(method as c_int, context.as_ptr()) };
check_non_null(mem_maybenull as *mut CvodeMemoryBlock, "CVodeCreate")?.into()
};
let y0 = NVectorSerialHeapAllocated::new_from(y0, context);
let matrix = {
let matrix = unsafe {
sundials_sys::SUNDenseMatrix(N.try_into().unwrap(), N.try_into().unwrap(), context)
sundials_sys::SUNDenseMatrix(N.try_into().unwrap(), N.try_into().unwrap(), context.as_ptr())
};
check_non_null(matrix, "SUNDenseMatrix")?
};
let linsolver = {
let linsolver =
unsafe { sundials_sys::SUNLinSol_Dense(y0.as_raw(), matrix.as_ptr(), context) };
unsafe { sundials_sys::SUNLinSol_Dense(y0.as_raw(), matrix.as_ptr(), context.as_ptr()) };
check_non_null(linsolver, "SUNDenseLinearSolver")?
};
let user_data = Box::pin(WrappingUserData {

View File

@ -125,7 +125,7 @@ where
) -> RhsResult,
{
/// Creates a new solver.
#[allow(clippy::clippy::too_many_arguments)]
#[allow(clippy::too_many_arguments)]
pub fn new(
method: LinearMultistepMethod,
f: F,
@ -141,7 +141,8 @@ where
assert_eq!(y0.len(), N);
let context = sundials_create_context()?;
let mem: CvodeMemoryBlockNonNullPtr = {
let mem_maybenull = unsafe { sundials_sys::CVodeCreate(method as c_int, context) };
let mem_maybenull =
unsafe { sundials_sys::CVodeCreate(method as c_int, context.as_ptr()) };
check_non_null(mem_maybenull as *mut CvodeMemoryBlock, "CVodeCreate")?.into()
};
let y0 = NVectorSerialHeapAllocated::new_from(y0, context);
@ -154,13 +155,18 @@ where
);
let matrix = {
let matrix = unsafe {
sundials_sys::SUNDenseMatrix(N.try_into().unwrap(), N.try_into().unwrap(), context)
sundials_sys::SUNDenseMatrix(
N.try_into().unwrap(),
N.try_into().unwrap(),
context.as_ptr(),
)
};
check_non_null(matrix, "SUNDenseMatrix")?
};
let linsolver = {
let linsolver =
unsafe { sundials_sys::SUNLinSol_Dense(y0.as_raw(), matrix.as_ptr(), context) };
let linsolver = unsafe {
sundials_sys::SUNLinSol_Dense(y0.as_raw(), matrix.as_ptr(), context.as_ptr())
};
check_non_null(linsolver, "SUNDenseLinearSolver")?
};
let user_data = Box::pin(WrappingUserData {
@ -177,7 +183,9 @@ where
atol,
atol_sens,
user_data,
sensi_out_buffer: array_init::array_init(|_| NVectorSerialHeapAllocated::new(context)),
sensi_out_buffer: array_init::array_init(|_| {
NVectorSerialHeapAllocated::new(context.as_ptr())
}),
};
{
let flag = unsafe {
@ -258,7 +266,7 @@ where
/// reached by the solver as dictated by `step_kind`, `y(t_out)` is an
/// array of the state variables at that time, and the i-th `dy_dp(tout)` is an array
/// of the sensitivities of all variables with respect to parameter i.
#[allow(clippy::clippy::type_complexity)]
#[allow(clippy::type_complexity)]
pub fn step(
&mut self,
tout: Realtype,

View File

@ -105,7 +105,7 @@
//! ```
use std::{ffi::c_void, os::raw::c_int, ptr::NonNull};
use sundials_sys::{realtype, SUNComm, SUNContext};
use sundials_sys::{realtype, SUNComm, SUNContext, _SUNContext};
mod nvector;
pub use nvector::{NVectorSerial, NVectorSerialHeapAllocated};
@ -182,7 +182,7 @@ impl<const SIZE: usize> AbsTolerance<SIZE> {
AbsTolerance::Scalar(atol)
}
pub fn vector(atol: &[Realtype; SIZE], context: SUNContext) -> Self {
pub fn vector(atol: &[Realtype; SIZE], context: SunContext) -> Self {
let atol = NVectorSerialHeapAllocated::new_from(atol, context);
AbsTolerance::Vector(atol)
}
@ -200,7 +200,7 @@ impl<const SIZE: usize, const N_SENSI: usize> SensiAbsTolerance<SIZE, N_SENSI> {
SensiAbsTolerance::Scalar(atol)
}
pub fn vector(atol: &[[Realtype; SIZE]; N_SENSI], context: SUNContext) -> Self {
pub fn vector(atol: &[[Realtype; SIZE]; N_SENSI], context: SunContext) -> Self {
SensiAbsTolerance::Vector(
array_init::from_iter(
atol.iter()
@ -226,18 +226,22 @@ fn check_flag_is_succes(flag: c_int, func_id: &'static str) -> Result<()> {
}
}
fn sundials_create_context() -> Result<SUNContext> {
pub type SunContext = std::ptr::NonNull<_SUNContext>;
fn sundials_create_context() -> Result<SunContext> {
let context = unsafe {
let mut context: SUNContext = std::ptr::null_mut();
let ompi_communicator_t: SUNComm = std::ptr::null_mut();
sundials_sys::SUNContext_Create(ompi_communicator_t, &mut context);
check_non_null(context, "SUNContext_Create")?;
context
std::ptr::NonNull::new(context).unwrap()
};
Ok(context)
}
fn sundials_free_context(mut context: SUNContext) -> Result<()> {
unsafe { sundials_sys::SUNContext_Free(&mut context) };
fn sundials_free_context(context: SunContext) -> Result<()> {
let mut ptr = context.as_ptr();
let ptr_ptr: *mut *mut _SUNContext = &mut ptr;
unsafe { sundials_sys::SUNContext_Free(ptr_ptr) };
Ok(())
}

View File

@ -4,7 +4,7 @@ use std::{
ptr::NonNull,
};
use sundials_sys::{realtype, SUNContext};
use sundials_sys::{realtype, SUNContext, _SUNContext};
/// A sundials `N_Vector_Serial`.
#[repr(transparent)]
@ -72,9 +72,9 @@ impl<const SIZE: usize> NVectorSerialHeapAllocated<SIZE> {
}
/// Creates a new vector, filled with data from `data`.
pub fn new_from(data: &[realtype; SIZE], context: SUNContext) -> Self {
pub fn new_from(data: &[realtype; SIZE], context: std::ptr::NonNull<_SUNContext>) -> Self {
let inner = unsafe {
let x = Self::new_inner_uninitialized(context);
let x = Self::new_inner_uninitialized(context.as_ptr());
let ptr = sundials_sys::N_VGetArrayPointer_Serial(x.as_ref().as_raw());
std::ptr::copy_nonoverlapping(data.as_ptr(), ptr, SIZE);
x