From bf132d928f13eb7f52b6127d737bdb546e80629b Mon Sep 17 00:00:00 2001 From: "J.A. de Jong - Redu-Sone B.V., ASCEE V.O.F." Date: Fri, 20 Dec 2024 16:26:45 +0100 Subject: [PATCH] Clippy fixes. Store SUNContext in a NonNull wrapper --- src/cvode.rs | 12 ++++++------ src/cvode_sens.rs | 22 +++++++++++++++------- src/lib.rs | 18 +++++++++++------- src/nvector.rs | 6 +++--- 4 files changed, 35 insertions(+), 23 deletions(-) diff --git a/src/cvode.rs b/src/cvode.rs index 52f6cf5..3277d49 100644 --- a/src/cvode.rs +++ b/src/cvode.rs @@ -2,12 +2,12 @@ use std::{convert::TryInto, os::raw::c_int, pin::Pin}; -use sundials_sys::{SUNComm, SUNContext, SUNLinearSolver, SUNMatrix}; +use sundials_sys::{SUNLinearSolver, SUNMatrix}; use crate::{ check_flag_is_succes, check_non_null, sundials_create_context, sundials_free_context, AbsTolerance, CvodeMemoryBlock, CvodeMemoryBlockNonNullPtr, LinearMultistepMethod, - NVectorSerial, NVectorSerialHeapAllocated, Realtype, Result, RhsResult, StepKind, + NVectorSerial, NVectorSerialHeapAllocated, Realtype, Result, RhsResult, StepKind, SunContext, }; struct WrappingUserData { @@ -32,7 +32,7 @@ pub struct Solver { linsolver: SUNLinearSolver, atol: AbsTolerance, user_data: Pin>>, - context: SUNContext, + context: SunContext, } extern "C" fn wrap_f( @@ -77,19 +77,19 @@ where assert_eq!(y0.len(), N); let mem: CvodeMemoryBlockNonNullPtr = { - let mem_maybenull = unsafe { sundials_sys::CVodeCreate(method as c_int, context) }; + let mem_maybenull = unsafe { sundials_sys::CVodeCreate(method as c_int, context.as_ptr()) }; check_non_null(mem_maybenull as *mut CvodeMemoryBlock, "CVodeCreate")?.into() }; let y0 = NVectorSerialHeapAllocated::new_from(y0, context); let matrix = { let matrix = unsafe { - sundials_sys::SUNDenseMatrix(N.try_into().unwrap(), N.try_into().unwrap(), context) + sundials_sys::SUNDenseMatrix(N.try_into().unwrap(), N.try_into().unwrap(), context.as_ptr()) }; check_non_null(matrix, "SUNDenseMatrix")? }; let linsolver = { let linsolver = - unsafe { sundials_sys::SUNLinSol_Dense(y0.as_raw(), matrix.as_ptr(), context) }; + unsafe { sundials_sys::SUNLinSol_Dense(y0.as_raw(), matrix.as_ptr(), context.as_ptr()) }; check_non_null(linsolver, "SUNDenseLinearSolver")? }; let user_data = Box::pin(WrappingUserData { diff --git a/src/cvode_sens.rs b/src/cvode_sens.rs index 9557399..8e868c9 100644 --- a/src/cvode_sens.rs +++ b/src/cvode_sens.rs @@ -125,7 +125,7 @@ where ) -> RhsResult, { /// Creates a new solver. - #[allow(clippy::clippy::too_many_arguments)] + #[allow(clippy::too_many_arguments)] pub fn new( method: LinearMultistepMethod, f: F, @@ -141,7 +141,8 @@ where assert_eq!(y0.len(), N); let context = sundials_create_context()?; let mem: CvodeMemoryBlockNonNullPtr = { - let mem_maybenull = unsafe { sundials_sys::CVodeCreate(method as c_int, context) }; + let mem_maybenull = + unsafe { sundials_sys::CVodeCreate(method as c_int, context.as_ptr()) }; check_non_null(mem_maybenull as *mut CvodeMemoryBlock, "CVodeCreate")?.into() }; let y0 = NVectorSerialHeapAllocated::new_from(y0, context); @@ -154,13 +155,18 @@ where ); let matrix = { let matrix = unsafe { - sundials_sys::SUNDenseMatrix(N.try_into().unwrap(), N.try_into().unwrap(), context) + sundials_sys::SUNDenseMatrix( + N.try_into().unwrap(), + N.try_into().unwrap(), + context.as_ptr(), + ) }; check_non_null(matrix, "SUNDenseMatrix")? }; let linsolver = { - let linsolver = - unsafe { sundials_sys::SUNLinSol_Dense(y0.as_raw(), matrix.as_ptr(), context) }; + let linsolver = unsafe { + sundials_sys::SUNLinSol_Dense(y0.as_raw(), matrix.as_ptr(), context.as_ptr()) + }; check_non_null(linsolver, "SUNDenseLinearSolver")? }; let user_data = Box::pin(WrappingUserData { @@ -177,7 +183,9 @@ where atol, atol_sens, user_data, - sensi_out_buffer: array_init::array_init(|_| NVectorSerialHeapAllocated::new(context)), + sensi_out_buffer: array_init::array_init(|_| { + NVectorSerialHeapAllocated::new(context.as_ptr()) + }), }; { let flag = unsafe { @@ -258,7 +266,7 @@ where /// reached by the solver as dictated by `step_kind`, `y(t_out)` is an /// array of the state variables at that time, and the i-th `dy_dp(tout)` is an array /// of the sensitivities of all variables with respect to parameter i. - #[allow(clippy::clippy::type_complexity)] + #[allow(clippy::type_complexity)] pub fn step( &mut self, tout: Realtype, diff --git a/src/lib.rs b/src/lib.rs index c73b0ec..d195750 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -105,7 +105,7 @@ //! ``` use std::{ffi::c_void, os::raw::c_int, ptr::NonNull}; -use sundials_sys::{realtype, SUNComm, SUNContext}; +use sundials_sys::{realtype, SUNComm, SUNContext, _SUNContext}; mod nvector; pub use nvector::{NVectorSerial, NVectorSerialHeapAllocated}; @@ -182,7 +182,7 @@ impl AbsTolerance { AbsTolerance::Scalar(atol) } - pub fn vector(atol: &[Realtype; SIZE], context: SUNContext) -> Self { + pub fn vector(atol: &[Realtype; SIZE], context: SunContext) -> Self { let atol = NVectorSerialHeapAllocated::new_from(atol, context); AbsTolerance::Vector(atol) } @@ -200,7 +200,7 @@ impl SensiAbsTolerance { SensiAbsTolerance::Scalar(atol) } - pub fn vector(atol: &[[Realtype; SIZE]; N_SENSI], context: SUNContext) -> Self { + pub fn vector(atol: &[[Realtype; SIZE]; N_SENSI], context: SunContext) -> Self { SensiAbsTolerance::Vector( array_init::from_iter( atol.iter() @@ -226,18 +226,22 @@ fn check_flag_is_succes(flag: c_int, func_id: &'static str) -> Result<()> { } } -fn sundials_create_context() -> Result { +pub type SunContext = std::ptr::NonNull<_SUNContext>; + +fn sundials_create_context() -> Result { let context = unsafe { let mut context: SUNContext = std::ptr::null_mut(); let ompi_communicator_t: SUNComm = std::ptr::null_mut(); sundials_sys::SUNContext_Create(ompi_communicator_t, &mut context); check_non_null(context, "SUNContext_Create")?; - context + std::ptr::NonNull::new(context).unwrap() }; Ok(context) } -fn sundials_free_context(mut context: SUNContext) -> Result<()> { - unsafe { sundials_sys::SUNContext_Free(&mut context) }; +fn sundials_free_context(context: SunContext) -> Result<()> { + let mut ptr = context.as_ptr(); + let ptr_ptr: *mut *mut _SUNContext = &mut ptr; + unsafe { sundials_sys::SUNContext_Free(ptr_ptr) }; Ok(()) } diff --git a/src/nvector.rs b/src/nvector.rs index 43b4547..24d0989 100644 --- a/src/nvector.rs +++ b/src/nvector.rs @@ -4,7 +4,7 @@ use std::{ ptr::NonNull, }; -use sundials_sys::{realtype, SUNContext}; +use sundials_sys::{realtype, SUNContext, _SUNContext}; /// A sundials `N_Vector_Serial`. #[repr(transparent)] @@ -72,9 +72,9 @@ impl NVectorSerialHeapAllocated { } /// Creates a new vector, filled with data from `data`. - pub fn new_from(data: &[realtype; SIZE], context: SUNContext) -> Self { + pub fn new_from(data: &[realtype; SIZE], context: std::ptr::NonNull<_SUNContext>) -> Self { let inner = unsafe { - let x = Self::new_inner_uninitialized(context); + let x = Self::new_inner_uninitialized(context.as_ptr()); let ptr = sundials_sys::N_VGetArrayPointer_Serial(x.as_ref().as_raw()); std::ptr::copy_nonoverlapping(data.as_ptr(), ptr, SIZE); x