Finished change around of wrapping SUNContext in a NonNull

This commit is contained in:
Anne de Jong 2024-12-25 10:38:07 +01:00
parent bf132d928f
commit 0f0098d473
3 changed files with 12 additions and 12 deletions

View File

@ -184,7 +184,7 @@ where
atol_sens, atol_sens,
user_data, user_data,
sensi_out_buffer: array_init::array_init(|_| { sensi_out_buffer: array_init::array_init(|_| {
NVectorSerialHeapAllocated::new(context.as_ptr()) NVectorSerialHeapAllocated::new(context)
}), }),
}; };
{ {

View File

@ -105,7 +105,7 @@
//! ``` //! ```
use std::{ffi::c_void, os::raw::c_int, ptr::NonNull}; use std::{ffi::c_void, os::raw::c_int, ptr::NonNull};
use sundials_sys::{realtype, SUNComm, SUNContext, _SUNContext}; use sundials_sys::{realtype, SUNComm, SUNContext, SUNContext_};
mod nvector; mod nvector;
pub use nvector::{NVectorSerial, NVectorSerialHeapAllocated}; pub use nvector::{NVectorSerial, NVectorSerialHeapAllocated};
@ -226,12 +226,12 @@ fn check_flag_is_succes(flag: c_int, func_id: &'static str) -> Result<()> {
} }
} }
pub type SunContext = std::ptr::NonNull<_SUNContext>; pub type SunContext = std::ptr::NonNull<SUNContext_>;
fn sundials_create_context() -> Result<SunContext> { fn sundials_create_context() -> Result<SunContext> {
let context = unsafe { let context = unsafe {
let mut context: SUNContext = std::ptr::null_mut(); let mut context: SUNContext = std::ptr::null_mut();
let ompi_communicator_t: SUNComm = std::ptr::null_mut(); let ompi_communicator_t: SUNComm = 0;
sundials_sys::SUNContext_Create(ompi_communicator_t, &mut context); sundials_sys::SUNContext_Create(ompi_communicator_t, &mut context);
check_non_null(context, "SUNContext_Create")?; check_non_null(context, "SUNContext_Create")?;
std::ptr::NonNull::new(context).unwrap() std::ptr::NonNull::new(context).unwrap()
@ -240,7 +240,7 @@ fn sundials_create_context() -> Result<SunContext> {
} }
fn sundials_free_context(context: SunContext) -> Result<()> { fn sundials_free_context(context: SunContext) -> Result<()> {
let mut ptr = context.as_ptr(); let mut ptr = context.as_ptr();
let ptr_ptr: *mut *mut _SUNContext = &mut ptr; let ptr_ptr: *mut *mut SUNContext_ = &mut ptr;
unsafe { sundials_sys::SUNContext_Free(ptr_ptr) }; unsafe { sundials_sys::SUNContext_Free(ptr_ptr) };
Ok(()) Ok(())
} }

View File

@ -3,8 +3,8 @@ use std::{
ops::{Deref, DerefMut}, ops::{Deref, DerefMut},
ptr::NonNull, ptr::NonNull,
}; };
use sundials_sys::realtype;
use sundials_sys::{realtype, SUNContext, _SUNContext}; use crate::SunContext;
/// A sundials `N_Vector_Serial`. /// A sundials `N_Vector_Serial`.
#[repr(transparent)] #[repr(transparent)]
@ -53,13 +53,13 @@ impl<const SIZE: usize> DerefMut for NVectorSerialHeapAllocated<SIZE> {
} }
impl<const SIZE: usize> NVectorSerialHeapAllocated<SIZE> { impl<const SIZE: usize> NVectorSerialHeapAllocated<SIZE> {
unsafe fn new_inner_uninitialized(context: SUNContext) -> NonNull<NVectorSerial<SIZE>> { unsafe fn new_inner_uninitialized(context: SunContext) -> NonNull<NVectorSerial<SIZE>> {
let raw_c = sundials_sys::N_VNew_Serial(SIZE.try_into().unwrap(), context); let raw_c = sundials_sys::N_VNew_Serial(SIZE.try_into().unwrap(), context.as_ptr());
NonNull::new(raw_c as *mut NVectorSerial<SIZE>).unwrap() NonNull::new(raw_c as *mut NVectorSerial<SIZE>).unwrap()
} }
/// Creates a new vector, filled with 0. /// Creates a new vector, filled with 0.
pub fn new(context: SUNContext) -> Self { pub fn new(context: SunContext) -> Self {
let inner = unsafe { let inner = unsafe {
let x = Self::new_inner_uninitialized(context); let x = Self::new_inner_uninitialized(context);
let ptr = sundials_sys::N_VGetArrayPointer_Serial(x.as_ref().as_raw()); let ptr = sundials_sys::N_VGetArrayPointer_Serial(x.as_ref().as_raw());
@ -72,9 +72,9 @@ impl<const SIZE: usize> NVectorSerialHeapAllocated<SIZE> {
} }
/// Creates a new vector, filled with data from `data`. /// Creates a new vector, filled with data from `data`.
pub fn new_from(data: &[realtype; SIZE], context: std::ptr::NonNull<_SUNContext>) -> Self { pub fn new_from(data: &[realtype; SIZE], context: SunContext) -> Self {
let inner = unsafe { let inner = unsafe {
let x = Self::new_inner_uninitialized(context.as_ptr()); let x = Self::new_inner_uninitialized(context);
let ptr = sundials_sys::N_VGetArrayPointer_Serial(x.as_ref().as_raw()); let ptr = sundials_sys::N_VGetArrayPointer_Serial(x.as_ref().as_raw());
std::ptr::copy_nonoverlapping(data.as_ptr(), ptr, SIZE); std::ptr::copy_nonoverlapping(data.as_ptr(), ptr, SIZE);
x x