Better type checking of problem size for C Rhs
This commit is contained in:
parent
b1472edb77
commit
63f08c286b
@ -10,10 +10,9 @@ use cvode_5_sys::{
|
|||||||
};
|
};
|
||||||
|
|
||||||
mod nvector;
|
mod nvector;
|
||||||
use nvector::NVectorSerialHeapAlloced;
|
pub use nvector::{NVectorSerial, NVectorSerialHeapAlloced};
|
||||||
|
|
||||||
pub type Realtype = realtype;
|
pub type Realtype = realtype;
|
||||||
pub type CVector = cvode::N_Vector;
|
|
||||||
|
|
||||||
#[repr(u32)]
|
#[repr(u32)]
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
@ -74,15 +73,21 @@ pub enum RhsResult {
|
|||||||
NonRecoverableError(u8),
|
NonRecoverableError(u8),
|
||||||
}
|
}
|
||||||
|
|
||||||
pub type RhsF<UserData, const N: usize> = fn(t: Realtype, y: &[Realtype; N], ydot: &mut [Realtype; N], user_data: &UserData) -> RhsResult;
|
pub type RhsF<UserData, const N: usize> =
|
||||||
|
fn(t: Realtype, y: &[Realtype; N], ydot: &mut [Realtype; N], user_data: &UserData) -> RhsResult;
|
||||||
|
|
||||||
pub type RhsFCtype<UserData> = extern "C" fn(t: Realtype, y: CVector, ydot: CVector, user_data: *const UserData) -> c_int;
|
pub type RhsFCtype<UserData, const N: usize> = extern "C" fn(
|
||||||
|
t: Realtype,
|
||||||
|
y: *const NVectorSerial<N>,
|
||||||
|
ydot: *mut NVectorSerial<N>,
|
||||||
|
user_data: *const UserData,
|
||||||
|
) -> c_int;
|
||||||
|
|
||||||
pub fn wrap_f<UserData, const N: usize>(
|
pub fn wrap_f<UserData, const N: usize>(
|
||||||
f: RhsF<UserData, N>,
|
f: RhsF<UserData, N>,
|
||||||
t: Realtype,
|
t: Realtype,
|
||||||
y: CVector,
|
y: *const NVectorSerial<N>,
|
||||||
ydot: CVector,
|
ydot: *mut NVectorSerial<N>,
|
||||||
data: &UserData,
|
data: &UserData,
|
||||||
) -> c_int {
|
) -> c_int {
|
||||||
let y = unsafe { transmute(N_VGetArrayPointer(y as _)) };
|
let y = unsafe { transmute(N_VGetArrayPointer(y as _)) };
|
||||||
@ -97,12 +102,12 @@ pub fn wrap_f<UserData, const N: usize>(
|
|||||||
|
|
||||||
#[macro_export]
|
#[macro_export]
|
||||||
macro_rules! wrap {
|
macro_rules! wrap {
|
||||||
($wrapped_f_name: ident, $f_name: ident, $user_data: ty) => {
|
($wrapped_f_name: ident, $f_name: ident, $user_data: ty, $problem_size: expr) => {
|
||||||
extern "C" fn $wrapped_f_name(
|
extern "C" fn $wrapped_f_name(
|
||||||
t: Realtype,
|
t: Realtype,
|
||||||
y: CVector,
|
y: *const NVectorSerial<$problem_size>,
|
||||||
ydot: CVector,
|
ydot: *mut NVectorSerial<$problem_size>,
|
||||||
data: *const $user_data,
|
data: *const $user_data,
|
||||||
) -> std::os::raw::c_int {
|
) -> std::os::raw::c_int {
|
||||||
let data = unsafe { std::mem::transmute(data) };
|
let data = unsafe { std::mem::transmute(data) };
|
||||||
wrap_f($f_name, t, y, ydot, data)
|
wrap_f($f_name, t, y, ydot, data)
|
||||||
@ -155,7 +160,7 @@ impl<const SIZE: usize> AbsTolerance<SIZE> {
|
|||||||
impl<UserData, const N: usize> Solver<UserData, N> {
|
impl<UserData, const N: usize> Solver<UserData, N> {
|
||||||
pub fn new(
|
pub fn new(
|
||||||
method: LinearMultistepMethod,
|
method: LinearMultistepMethod,
|
||||||
f: RhsFCtype<UserData>,
|
f: RhsFCtype<UserData, N>,
|
||||||
t0: Realtype,
|
t0: Realtype,
|
||||||
y0: &[Realtype; N],
|
y0: &[Realtype; N],
|
||||||
rtol: Realtype,
|
rtol: Realtype,
|
||||||
@ -196,7 +201,14 @@ impl<UserData, const N: usize> Solver<UserData, N> {
|
|||||||
user_data,
|
user_data,
|
||||||
};
|
};
|
||||||
{
|
{
|
||||||
let flag = unsafe { cvode::CVodeInit(mem.as_raw(), Some(std::mem::transmute(f)), t0, res.y0.as_raw() as _) };
|
let flag = unsafe {
|
||||||
|
cvode::CVodeInit(
|
||||||
|
mem.as_raw(),
|
||||||
|
Some(std::mem::transmute(f)),
|
||||||
|
t0,
|
||||||
|
res.y0.as_raw() as _,
|
||||||
|
)
|
||||||
|
};
|
||||||
check_flag_is_succes(flag, "CVodeInit")?;
|
check_flag_is_succes(flag, "CVodeInit")?;
|
||||||
}
|
}
|
||||||
match &res.atol {
|
match &res.atol {
|
||||||
@ -232,7 +244,11 @@ impl<UserData, const N: usize> Solver<UserData, N> {
|
|||||||
Ok(res)
|
Ok(res)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn step(&mut self, tout: Realtype, step_kind: StepKind) -> Result<(Realtype, &[Realtype; N])> {
|
pub fn step(
|
||||||
|
&mut self,
|
||||||
|
tout: Realtype,
|
||||||
|
step_kind: StepKind,
|
||||||
|
) -> Result<(Realtype, &[Realtype; N])> {
|
||||||
let mut tret = 0.;
|
let mut tret = 0.;
|
||||||
let flag = unsafe {
|
let flag = unsafe {
|
||||||
cvode::CVode(
|
cvode::CVode(
|
||||||
@ -260,12 +276,17 @@ impl<UserData, const N: usize> Drop for Solver<UserData, N> {
|
|||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
fn f(_t: super::Realtype, y: &[Realtype; 2], ydot: &mut [Realtype; 2], _data: &()) -> RhsResult {
|
fn f(
|
||||||
|
_t: super::Realtype,
|
||||||
|
y: &[Realtype; 2],
|
||||||
|
ydot: &mut [Realtype; 2],
|
||||||
|
_data: &(),
|
||||||
|
) -> RhsResult {
|
||||||
*ydot = [y[1], -y[0]];
|
*ydot = [y[1], -y[0]];
|
||||||
RhsResult::Ok
|
RhsResult::Ok
|
||||||
}
|
}
|
||||||
|
|
||||||
wrap!(wrapped_f, f, ());
|
wrap!(wrapped_f, f, (), 2);
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn create() {
|
fn create() {
|
||||||
|
@ -1,23 +1,24 @@
|
|||||||
use std::{convert::TryInto, intrinsics::transmute, ops::Deref, ptr::NonNull};
|
use std::{
|
||||||
|
convert::TryInto,
|
||||||
|
intrinsics::transmute,
|
||||||
|
ops::{Deref, DerefMut},
|
||||||
|
ptr::NonNull,
|
||||||
|
};
|
||||||
|
|
||||||
use cvode_5_sys::{cvode::realtype, nvector_serial};
|
use cvode_5_sys::{cvode::realtype, nvector_serial};
|
||||||
|
|
||||||
|
|
||||||
#[repr(transparent)]
|
#[repr(transparent)]
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct NVectorSerialHeapAlloced<const SIZE: usize> {
|
pub struct NVectorSerial<const SIZE: usize> {
|
||||||
inner: NonNull<nvector_serial::_generic_N_Vector>,
|
inner: nvector_serial::_generic_N_Vector,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<const SIZE: usize> Deref for NVectorSerialHeapAlloced<SIZE> {
|
impl<const SIZE: usize> NVectorSerial<SIZE> {
|
||||||
type Target = nvector_serial::_generic_N_Vector;
|
|
||||||
|
pub unsafe fn as_raw(&self) -> nvector_serial::N_Vector {
|
||||||
fn deref(&self) -> &Self::Target {
|
std::mem::transmute(&self.inner)
|
||||||
unsafe {self.inner.as_ref()}
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
impl<const SIZE: usize> NVectorSerialHeapAlloced<SIZE> {
|
|
||||||
pub fn as_slice(&self) -> &[realtype; SIZE] {
|
pub fn as_slice(&self) -> &[realtype; SIZE] {
|
||||||
unsafe { transmute(nvector_serial::N_VGetArrayPointer_Serial(self.as_raw())) }
|
unsafe { transmute(nvector_serial::N_VGetArrayPointer_Serial(self.as_raw())) }
|
||||||
}
|
}
|
||||||
@ -25,11 +26,33 @@ impl<const SIZE: usize> NVectorSerialHeapAlloced<SIZE> {
|
|||||||
pub fn as_slice_mut(&mut self) -> &mut [realtype; SIZE] {
|
pub fn as_slice_mut(&mut self) -> &mut [realtype; SIZE] {
|
||||||
unsafe { transmute(nvector_serial::N_VGetArrayPointer_Serial(self.as_raw())) }
|
unsafe { transmute(nvector_serial::N_VGetArrayPointer_Serial(self.as_raw())) }
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[repr(transparent)]
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct NVectorSerialHeapAlloced<const SIZE: usize> {
|
||||||
|
inner: NonNull<NVectorSerial<SIZE>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<const SIZE: usize> Deref for NVectorSerialHeapAlloced<SIZE> {
|
||||||
|
type Target = NVectorSerial<SIZE>;
|
||||||
|
|
||||||
|
fn deref(&self) -> &Self::Target {
|
||||||
|
unsafe { self.inner.as_ref() }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<const SIZE: usize> DerefMut for NVectorSerialHeapAlloced<SIZE> {
|
||||||
|
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||||
|
unsafe { self.inner.as_mut() }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<const SIZE: usize> NVectorSerialHeapAlloced<SIZE> {
|
||||||
pub fn new() -> Self {
|
pub fn new() -> Self {
|
||||||
|
let raw_c = unsafe { nvector_serial::N_VNew_Serial(SIZE.try_into().unwrap()) };
|
||||||
Self {
|
Self {
|
||||||
inner: NonNull::new(unsafe { nvector_serial::N_VNew_Serial(SIZE.try_into().unwrap()) })
|
inner: NonNull::new(raw_c as *mut NVectorSerial<SIZE>).unwrap(),
|
||||||
.unwrap(),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -38,10 +61,6 @@ impl<const SIZE: usize> NVectorSerialHeapAlloced<SIZE> {
|
|||||||
res.as_slice_mut().copy_from_slice(data);
|
res.as_slice_mut().copy_from_slice(data);
|
||||||
res
|
res
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn as_raw(&self) -> nvector_serial::N_Vector {
|
|
||||||
self.inner.as_ptr()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<const SIZE: usize> Drop for NVectorSerialHeapAlloced<SIZE> {
|
impl<const SIZE: usize> Drop for NVectorSerialHeapAlloced<SIZE> {
|
||||||
|
@ -7,7 +7,7 @@ fn main() {
|
|||||||
*ydot = [y[1], -y[0] * k];
|
*ydot = [y[1], -y[0] * k];
|
||||||
RhsResult::Ok
|
RhsResult::Ok
|
||||||
}
|
}
|
||||||
wrap!(wrapped_f, f, Realtype);
|
wrap!(wrapped_f, f, Realtype, 2);
|
||||||
//initialize the solver
|
//initialize the solver
|
||||||
let mut solver = Solver::new(
|
let mut solver = Solver::new(
|
||||||
LinearMultistepMethod::ADAMS,
|
LinearMultistepMethod::ADAMS,
|
||||||
|
Loading…
Reference in New Issue
Block a user