Better type checking of problem size for C Rhs

This commit is contained in:
Arthur Carcano 2021-05-09 12:35:27 +02:00
parent b1472edb77
commit 63f08c286b
3 changed files with 73 additions and 33 deletions

View File

@ -10,10 +10,9 @@ use cvode_5_sys::{
};
mod nvector;
use nvector::NVectorSerialHeapAlloced;
pub use nvector::{NVectorSerial, NVectorSerialHeapAlloced};
pub type Realtype = realtype;
pub type CVector = cvode::N_Vector;
#[repr(u32)]
#[derive(Debug)]
@ -74,15 +73,21 @@ pub enum RhsResult {
NonRecoverableError(u8),
}
pub type RhsF<UserData, const N: usize> = fn(t: Realtype, y: &[Realtype; N], ydot: &mut [Realtype; N], user_data: &UserData) -> RhsResult;
pub type RhsF<UserData, const N: usize> =
fn(t: Realtype, y: &[Realtype; N], ydot: &mut [Realtype; N], user_data: &UserData) -> RhsResult;
pub type RhsFCtype<UserData> = extern "C" fn(t: Realtype, y: CVector, ydot: CVector, user_data: *const UserData) -> c_int;
pub type RhsFCtype<UserData, const N: usize> = extern "C" fn(
t: Realtype,
y: *const NVectorSerial<N>,
ydot: *mut NVectorSerial<N>,
user_data: *const UserData,
) -> c_int;
pub fn wrap_f<UserData, const N: usize>(
f: RhsF<UserData, N>,
t: Realtype,
y: CVector,
ydot: CVector,
y: *const NVectorSerial<N>,
ydot: *mut NVectorSerial<N>,
data: &UserData,
) -> c_int {
let y = unsafe { transmute(N_VGetArrayPointer(y as _)) };
@ -97,11 +102,11 @@ pub fn wrap_f<UserData, const N: usize>(
#[macro_export]
macro_rules! wrap {
($wrapped_f_name: ident, $f_name: ident, $user_data: ty) => {
($wrapped_f_name: ident, $f_name: ident, $user_data: ty, $problem_size: expr) => {
extern "C" fn $wrapped_f_name(
t: Realtype,
y: CVector,
ydot: CVector,
y: *const NVectorSerial<$problem_size>,
ydot: *mut NVectorSerial<$problem_size>,
data: *const $user_data,
) -> std::os::raw::c_int {
let data = unsafe { std::mem::transmute(data) };
@ -155,7 +160,7 @@ impl<const SIZE: usize> AbsTolerance<SIZE> {
impl<UserData, const N: usize> Solver<UserData, N> {
pub fn new(
method: LinearMultistepMethod,
f: RhsFCtype<UserData>,
f: RhsFCtype<UserData, N>,
t0: Realtype,
y0: &[Realtype; N],
rtol: Realtype,
@ -196,7 +201,14 @@ impl<UserData, const N: usize> Solver<UserData, N> {
user_data,
};
{
let flag = unsafe { cvode::CVodeInit(mem.as_raw(), Some(std::mem::transmute(f)), t0, res.y0.as_raw() as _) };
let flag = unsafe {
cvode::CVodeInit(
mem.as_raw(),
Some(std::mem::transmute(f)),
t0,
res.y0.as_raw() as _,
)
};
check_flag_is_succes(flag, "CVodeInit")?;
}
match &res.atol {
@ -232,7 +244,11 @@ impl<UserData, const N: usize> Solver<UserData, N> {
Ok(res)
}
pub fn step(&mut self, tout: Realtype, step_kind: StepKind) -> Result<(Realtype, &[Realtype; N])> {
pub fn step(
&mut self,
tout: Realtype,
step_kind: StepKind,
) -> Result<(Realtype, &[Realtype; N])> {
let mut tret = 0.;
let flag = unsafe {
cvode::CVode(
@ -260,12 +276,17 @@ impl<UserData, const N: usize> Drop for Solver<UserData, N> {
mod tests {
use super::*;
fn f(_t: super::Realtype, y: &[Realtype; 2], ydot: &mut [Realtype; 2], _data: &()) -> RhsResult {
fn f(
_t: super::Realtype,
y: &[Realtype; 2],
ydot: &mut [Realtype; 2],
_data: &(),
) -> RhsResult {
*ydot = [y[1], -y[0]];
RhsResult::Ok
}
wrap!(wrapped_f, f, ());
wrap!(wrapped_f, f, (), 2);
#[test]
fn create() {

View File

@ -1,23 +1,24 @@
use std::{convert::TryInto, intrinsics::transmute, ops::Deref, ptr::NonNull};
use std::{
convert::TryInto,
intrinsics::transmute,
ops::{Deref, DerefMut},
ptr::NonNull,
};
use cvode_5_sys::{cvode::realtype, nvector_serial};
#[repr(transparent)]
#[derive(Debug)]
pub struct NVectorSerialHeapAlloced<const SIZE: usize> {
inner: NonNull<nvector_serial::_generic_N_Vector>,
pub struct NVectorSerial<const SIZE: usize> {
inner: nvector_serial::_generic_N_Vector,
}
impl<const SIZE: usize> Deref for NVectorSerialHeapAlloced<SIZE> {
type Target = nvector_serial::_generic_N_Vector;
impl<const SIZE: usize> NVectorSerial<SIZE> {
fn deref(&self) -> &Self::Target {
unsafe {self.inner.as_ref()}
}
pub unsafe fn as_raw(&self) -> nvector_serial::N_Vector {
std::mem::transmute(&self.inner)
}
impl<const SIZE: usize> NVectorSerialHeapAlloced<SIZE> {
pub fn as_slice(&self) -> &[realtype; SIZE] {
unsafe { transmute(nvector_serial::N_VGetArrayPointer_Serial(self.as_raw())) }
}
@ -25,11 +26,33 @@ impl<const SIZE: usize> NVectorSerialHeapAlloced<SIZE> {
pub fn as_slice_mut(&mut self) -> &mut [realtype; SIZE] {
unsafe { transmute(nvector_serial::N_VGetArrayPointer_Serial(self.as_raw())) }
}
}
#[repr(transparent)]
#[derive(Debug)]
pub struct NVectorSerialHeapAlloced<const SIZE: usize> {
inner: NonNull<NVectorSerial<SIZE>>,
}
impl<const SIZE: usize> Deref for NVectorSerialHeapAlloced<SIZE> {
type Target = NVectorSerial<SIZE>;
fn deref(&self) -> &Self::Target {
unsafe { self.inner.as_ref() }
}
}
impl<const SIZE: usize> DerefMut for NVectorSerialHeapAlloced<SIZE> {
fn deref_mut(&mut self) -> &mut Self::Target {
unsafe { self.inner.as_mut() }
}
}
impl<const SIZE: usize> NVectorSerialHeapAlloced<SIZE> {
pub fn new() -> Self {
let raw_c = unsafe { nvector_serial::N_VNew_Serial(SIZE.try_into().unwrap()) };
Self {
inner: NonNull::new(unsafe { nvector_serial::N_VNew_Serial(SIZE.try_into().unwrap()) })
.unwrap(),
inner: NonNull::new(raw_c as *mut NVectorSerial<SIZE>).unwrap(),
}
}
@ -38,10 +61,6 @@ impl<const SIZE: usize> NVectorSerialHeapAlloced<SIZE> {
res.as_slice_mut().copy_from_slice(data);
res
}
pub fn as_raw(&self) -> nvector_serial::N_Vector {
self.inner.as_ptr()
}
}
impl<const SIZE: usize> Drop for NVectorSerialHeapAlloced<SIZE> {

View File

@ -7,7 +7,7 @@ fn main() {
*ydot = [y[1], -y[0] * k];
RhsResult::Ok
}
wrap!(wrapped_f, f, Realtype);
wrap!(wrapped_f, f, Realtype, 2);
//initialize the solver
let mut solver = Solver::new(
LinearMultistepMethod::ADAMS,