Better type checking of problem size for C Rhs
This commit is contained in:
parent
b1472edb77
commit
63f08c286b
@ -10,10 +10,9 @@ use cvode_5_sys::{
|
||||
};
|
||||
|
||||
mod nvector;
|
||||
use nvector::NVectorSerialHeapAlloced;
|
||||
pub use nvector::{NVectorSerial, NVectorSerialHeapAlloced};
|
||||
|
||||
pub type Realtype = realtype;
|
||||
pub type CVector = cvode::N_Vector;
|
||||
|
||||
#[repr(u32)]
|
||||
#[derive(Debug)]
|
||||
@ -74,15 +73,21 @@ pub enum RhsResult {
|
||||
NonRecoverableError(u8),
|
||||
}
|
||||
|
||||
pub type RhsF<UserData, const N: usize> = fn(t: Realtype, y: &[Realtype; N], ydot: &mut [Realtype; N], user_data: &UserData) -> RhsResult;
|
||||
pub type RhsF<UserData, const N: usize> =
|
||||
fn(t: Realtype, y: &[Realtype; N], ydot: &mut [Realtype; N], user_data: &UserData) -> RhsResult;
|
||||
|
||||
pub type RhsFCtype<UserData> = extern "C" fn(t: Realtype, y: CVector, ydot: CVector, user_data: *const UserData) -> c_int;
|
||||
pub type RhsFCtype<UserData, const N: usize> = extern "C" fn(
|
||||
t: Realtype,
|
||||
y: *const NVectorSerial<N>,
|
||||
ydot: *mut NVectorSerial<N>,
|
||||
user_data: *const UserData,
|
||||
) -> c_int;
|
||||
|
||||
pub fn wrap_f<UserData, const N: usize>(
|
||||
f: RhsF<UserData, N>,
|
||||
t: Realtype,
|
||||
y: CVector,
|
||||
ydot: CVector,
|
||||
y: *const NVectorSerial<N>,
|
||||
ydot: *mut NVectorSerial<N>,
|
||||
data: &UserData,
|
||||
) -> c_int {
|
||||
let y = unsafe { transmute(N_VGetArrayPointer(y as _)) };
|
||||
@ -97,12 +102,12 @@ pub fn wrap_f<UserData, const N: usize>(
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! wrap {
|
||||
($wrapped_f_name: ident, $f_name: ident, $user_data: ty) => {
|
||||
($wrapped_f_name: ident, $f_name: ident, $user_data: ty, $problem_size: expr) => {
|
||||
extern "C" fn $wrapped_f_name(
|
||||
t: Realtype,
|
||||
y: CVector,
|
||||
ydot: CVector,
|
||||
data: *const $user_data,
|
||||
y: *const NVectorSerial<$problem_size>,
|
||||
ydot: *mut NVectorSerial<$problem_size>,
|
||||
data: *const $user_data,
|
||||
) -> std::os::raw::c_int {
|
||||
let data = unsafe { std::mem::transmute(data) };
|
||||
wrap_f($f_name, t, y, ydot, data)
|
||||
@ -155,7 +160,7 @@ impl<const SIZE: usize> AbsTolerance<SIZE> {
|
||||
impl<UserData, const N: usize> Solver<UserData, N> {
|
||||
pub fn new(
|
||||
method: LinearMultistepMethod,
|
||||
f: RhsFCtype<UserData>,
|
||||
f: RhsFCtype<UserData, N>,
|
||||
t0: Realtype,
|
||||
y0: &[Realtype; N],
|
||||
rtol: Realtype,
|
||||
@ -196,7 +201,14 @@ impl<UserData, const N: usize> Solver<UserData, N> {
|
||||
user_data,
|
||||
};
|
||||
{
|
||||
let flag = unsafe { cvode::CVodeInit(mem.as_raw(), Some(std::mem::transmute(f)), t0, res.y0.as_raw() as _) };
|
||||
let flag = unsafe {
|
||||
cvode::CVodeInit(
|
||||
mem.as_raw(),
|
||||
Some(std::mem::transmute(f)),
|
||||
t0,
|
||||
res.y0.as_raw() as _,
|
||||
)
|
||||
};
|
||||
check_flag_is_succes(flag, "CVodeInit")?;
|
||||
}
|
||||
match &res.atol {
|
||||
@ -232,7 +244,11 @@ impl<UserData, const N: usize> Solver<UserData, N> {
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
pub fn step(&mut self, tout: Realtype, step_kind: StepKind) -> Result<(Realtype, &[Realtype; N])> {
|
||||
pub fn step(
|
||||
&mut self,
|
||||
tout: Realtype,
|
||||
step_kind: StepKind,
|
||||
) -> Result<(Realtype, &[Realtype; N])> {
|
||||
let mut tret = 0.;
|
||||
let flag = unsafe {
|
||||
cvode::CVode(
|
||||
@ -260,12 +276,17 @@ impl<UserData, const N: usize> Drop for Solver<UserData, N> {
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn f(_t: super::Realtype, y: &[Realtype; 2], ydot: &mut [Realtype; 2], _data: &()) -> RhsResult {
|
||||
fn f(
|
||||
_t: super::Realtype,
|
||||
y: &[Realtype; 2],
|
||||
ydot: &mut [Realtype; 2],
|
||||
_data: &(),
|
||||
) -> RhsResult {
|
||||
*ydot = [y[1], -y[0]];
|
||||
RhsResult::Ok
|
||||
}
|
||||
|
||||
wrap!(wrapped_f, f, ());
|
||||
wrap!(wrapped_f, f, (), 2);
|
||||
|
||||
#[test]
|
||||
fn create() {
|
||||
|
@ -1,23 +1,24 @@
|
||||
use std::{convert::TryInto, intrinsics::transmute, ops::Deref, ptr::NonNull};
|
||||
use std::{
|
||||
convert::TryInto,
|
||||
intrinsics::transmute,
|
||||
ops::{Deref, DerefMut},
|
||||
ptr::NonNull,
|
||||
};
|
||||
|
||||
use cvode_5_sys::{cvode::realtype, nvector_serial};
|
||||
|
||||
|
||||
#[repr(transparent)]
|
||||
#[derive(Debug)]
|
||||
pub struct NVectorSerialHeapAlloced<const SIZE: usize> {
|
||||
inner: NonNull<nvector_serial::_generic_N_Vector>,
|
||||
pub struct NVectorSerial<const SIZE: usize> {
|
||||
inner: nvector_serial::_generic_N_Vector,
|
||||
}
|
||||
|
||||
impl<const SIZE: usize> Deref for NVectorSerialHeapAlloced<SIZE> {
|
||||
type Target = nvector_serial::_generic_N_Vector;
|
||||
impl<const SIZE: usize> NVectorSerial<SIZE> {
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
unsafe {self.inner.as_ref()}
|
||||
pub unsafe fn as_raw(&self) -> nvector_serial::N_Vector {
|
||||
std::mem::transmute(&self.inner)
|
||||
}
|
||||
}
|
||||
|
||||
impl<const SIZE: usize> NVectorSerialHeapAlloced<SIZE> {
|
||||
pub fn as_slice(&self) -> &[realtype; SIZE] {
|
||||
unsafe { transmute(nvector_serial::N_VGetArrayPointer_Serial(self.as_raw())) }
|
||||
}
|
||||
@ -25,11 +26,33 @@ impl<const SIZE: usize> NVectorSerialHeapAlloced<SIZE> {
|
||||
pub fn as_slice_mut(&mut self) -> &mut [realtype; SIZE] {
|
||||
unsafe { transmute(nvector_serial::N_VGetArrayPointer_Serial(self.as_raw())) }
|
||||
}
|
||||
}
|
||||
|
||||
#[repr(transparent)]
|
||||
#[derive(Debug)]
|
||||
pub struct NVectorSerialHeapAlloced<const SIZE: usize> {
|
||||
inner: NonNull<NVectorSerial<SIZE>>,
|
||||
}
|
||||
|
||||
impl<const SIZE: usize> Deref for NVectorSerialHeapAlloced<SIZE> {
|
||||
type Target = NVectorSerial<SIZE>;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
unsafe { self.inner.as_ref() }
|
||||
}
|
||||
}
|
||||
|
||||
impl<const SIZE: usize> DerefMut for NVectorSerialHeapAlloced<SIZE> {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
unsafe { self.inner.as_mut() }
|
||||
}
|
||||
}
|
||||
|
||||
impl<const SIZE: usize> NVectorSerialHeapAlloced<SIZE> {
|
||||
pub fn new() -> Self {
|
||||
let raw_c = unsafe { nvector_serial::N_VNew_Serial(SIZE.try_into().unwrap()) };
|
||||
Self {
|
||||
inner: NonNull::new(unsafe { nvector_serial::N_VNew_Serial(SIZE.try_into().unwrap()) })
|
||||
.unwrap(),
|
||||
inner: NonNull::new(raw_c as *mut NVectorSerial<SIZE>).unwrap(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -38,10 +61,6 @@ impl<const SIZE: usize> NVectorSerialHeapAlloced<SIZE> {
|
||||
res.as_slice_mut().copy_from_slice(data);
|
||||
res
|
||||
}
|
||||
|
||||
pub fn as_raw(&self) -> nvector_serial::N_Vector {
|
||||
self.inner.as_ptr()
|
||||
}
|
||||
}
|
||||
|
||||
impl<const SIZE: usize> Drop for NVectorSerialHeapAlloced<SIZE> {
|
||||
|
@ -7,7 +7,7 @@ fn main() {
|
||||
*ydot = [y[1], -y[0] * k];
|
||||
RhsResult::Ok
|
||||
}
|
||||
wrap!(wrapped_f, f, Realtype);
|
||||
wrap!(wrapped_f, f, Realtype, 2);
|
||||
//initialize the solver
|
||||
let mut solver = Solver::new(
|
||||
LinearMultistepMethod::ADAMS,
|
||||
|
Loading…
Reference in New Issue
Block a user