Add user_data

This commit is contained in:
Arthur Carcano 2021-05-09 10:13:03 +02:00
parent 18c7c0ab02
commit a1231f7d85
3 changed files with 96 additions and 47 deletions

View File

@ -1,4 +1,4 @@
use std::convert::TryInto; use std::{convert::TryInto, pin::Pin};
use std::{ffi::c_void, intrinsics::transmute, os::raw::c_int, ptr::NonNull}; use std::{ffi::c_void, intrinsics::transmute, os::raw::c_int, ptr::NonNull};
use cvode::SUNMatrix; use cvode::SUNMatrix;
@ -47,12 +47,13 @@ impl From<NonNull<CvodeMemoryBlock>> for CvodeMemoryBlockNonNullPtr {
} }
} }
pub struct Solver<const N: usize> { pub struct Solver<UserData, const N: usize> {
mem: CvodeMemoryBlockNonNullPtr, mem: CvodeMemoryBlockNonNullPtr,
y0: NVectorSerial<N>, y0: NVectorSerial<N>,
sunmatrix: SUNMatrix, sunmatrix: SUNMatrix,
linsolver: SUNLinearSolver, linsolver: SUNLinearSolver,
_atol: AbsTolerance<N>, atol: AbsTolerance<N>,
user_data: Pin<Box<UserData>>,
} }
pub enum RhsResult { pub enum RhsResult {
@ -61,14 +62,14 @@ pub enum RhsResult {
NonRecoverableError(u8), NonRecoverableError(u8),
} }
type RhsF<const N: usize> = fn(F, &[F; N], &mut [F; N], *mut c_void) -> RhsResult; type RhsF<UserData, const N: usize> = fn(F, &[F; N], &mut [F; N], &UserData) -> RhsResult;
pub fn wrap_f<const N: usize>( pub fn wrap_f<UserData, const N: usize>(
f: RhsF<N>, f: RhsF<UserData, N>,
t: F, t: F,
y: CVector, y: CVector,
ydot: CVector, ydot: CVector,
data: *mut c_void, data: &UserData,
) -> c_int { ) -> c_int {
let y = unsafe { transmute(N_VGetArrayPointer(y as _)) }; let y = unsafe { transmute(N_VGetArrayPointer(y as _)) };
let ydot = unsafe { transmute(N_VGetArrayPointer(ydot as _)) }; let ydot = unsafe { transmute(N_VGetArrayPointer(ydot as _)) };
@ -82,19 +83,20 @@ pub fn wrap_f<const N: usize>(
#[macro_export] #[macro_export]
macro_rules! wrap { macro_rules! wrap {
($wrapped_f_name: ident, $f_name: ident) => { ($wrapped_f_name: ident, $f_name: ident, $user_data: ty) => {
extern "C" fn $wrapped_f_name( extern "C" fn $wrapped_f_name(
t: F, t: F,
y: CVector, y: CVector,
ydot: CVector, ydot: CVector,
data: *mut std::ffi::c_void, data: *const $user_data,
) -> std::os::raw::c_int { ) -> std::os::raw::c_int {
let data = unsafe { std::mem::transmute(data) };
wrap_f($f_name, t, y, ydot, data) wrap_f($f_name, t, y, ydot, data)
} }
}; };
} }
type RhsFCtype = extern "C" fn(F, CVector, CVector, *mut c_void) -> c_int; type RhsFCtype<UserData> = extern "C" fn(F, CVector, CVector, *const UserData) -> c_int;
#[repr(u32)] #[repr(u32)]
pub enum StepKind { pub enum StepKind {
@ -138,23 +140,54 @@ impl<const SIZE: usize> AbsTolerance<SIZE> {
} }
} }
impl<const N: usize> Solver<N> { impl<UserData, const N: usize> Solver<UserData, N> {
pub fn new( pub fn new(
method: LinearMultistepMethod, method: LinearMultistepMethod,
f: RhsFCtype, f: RhsFCtype<UserData>,
t0: F, t0: F,
y0: &[F; N], y0: &[F; N],
rtol: F, rtol: F,
atol: AbsTolerance<N>, atol: AbsTolerance<N>,
user_data: UserData,
) -> Result<Self> { ) -> Result<Self> {
assert_eq!(y0.len(), N); assert_eq!(y0.len(), N);
let mem: CvodeMemoryBlockNonNullPtr = {
let mem_maybenull = unsafe { cvode::CVodeCreate(method as c_int) }; let mem_maybenull = unsafe { cvode::CVodeCreate(method as c_int) };
let mem: CvodeMemoryBlockNonNullPtr = check_non_null(mem_maybenull as *mut CvodeMemoryBlock, "CVodeCreate")?.into()
check_non_null(mem_maybenull as *mut CvodeMemoryBlock, "CVodeCreate")?.into(); };
let y0 = NVectorSerial::new_from(y0); let y0 = NVectorSerial::new_from(y0);
let flag = unsafe { cvode::CVodeInit(mem.as_raw(), Some(f), t0, y0.as_raw() as _) }; let matrix = {
let matrix = unsafe {
cvode_5_sys::sunmatrix_dense::SUNDenseMatrix(
N.try_into().unwrap(),
N.try_into().unwrap(),
)
};
check_non_null(matrix, "SUNDenseMatrix")?
};
let linsolver = {
let linsolver = unsafe {
cvode_5_sys::sunlinsol_dense::SUNDenseLinearSolver(
y0.as_raw() as _,
matrix.as_ptr() as _,
)
};
check_non_null(linsolver, "SUNDenseLinearSolver")?
};
let user_data = Box::pin(user_data);
let res = Solver {
mem,
y0,
sunmatrix: matrix.as_ptr() as _,
linsolver: linsolver.as_ptr() as _,
atol,
user_data,
};
{
let flag = unsafe { cvode::CVodeInit(mem.as_raw(), Some(std::mem::transmute(f)), t0, res.y0.as_raw() as _) };
check_flag_is_succes(flag, "CVodeInit")?; check_flag_is_succes(flag, "CVodeInit")?;
match &atol { }
match &res.atol {
&AbsTolerance::Scalar(atol) => { &AbsTolerance::Scalar(atol) => {
let flag = unsafe { cvode::CVodeSStolerances(mem.as_raw(), rtol, atol) }; let flag = unsafe { cvode::CVodeSStolerances(mem.as_raw(), rtol, atol) };
check_flag_is_succes(flag, "CVodeSStolerances")?; check_flag_is_succes(flag, "CVodeSStolerances")?;
@ -165,27 +198,26 @@ impl<const N: usize> Solver<N> {
check_flag_is_succes(flag, "CVodeSVtolerances")?; check_flag_is_succes(flag, "CVodeSVtolerances")?;
} }
} }
let matrix = unsafe { {
cvode_5_sys::sunmatrix_dense::SUNDenseMatrix( let flag = unsafe {
N.try_into().unwrap(), cvode::CVodeSetLinearSolver(
N.try_into().unwrap(), mem.as_raw(),
linsolver.as_ptr() as _,
matrix.as_ptr() as _,
) )
}; };
check_non_null(matrix, "SUNDenseMatrix")?;
let linsolver = unsafe {
cvode_5_sys::sunlinsol_dense::SUNDenseLinearSolver(y0.as_raw() as _, matrix as _)
};
check_non_null(linsolver, "SUNDenseLinearSolver")?;
let flag =
unsafe { cvode::CVodeSetLinearSolver(mem.as_raw(), linsolver as _, matrix as _) };
check_flag_is_succes(flag, "CVodeSetLinearSolver")?; check_flag_is_succes(flag, "CVodeSetLinearSolver")?;
Ok(Solver { }
mem, {
y0, let flag = unsafe {
sunmatrix: matrix as _, cvode::CVodeSetUserData(
linsolver: linsolver as _, mem.as_raw(),
_atol: atol as _, std::mem::transmute(res.user_data.as_ref().get_ref()),
}) )
};
check_flag_is_succes(flag, "CVodeSetUserData")?;
}
Ok(res)
} }
pub fn step(&mut self, tout: F, step_kind: StepKind) -> Result<(F, &[F; N])> { pub fn step(&mut self, tout: F, step_kind: StepKind) -> Result<(F, &[F; N])> {
@ -204,7 +236,7 @@ impl<const N: usize> Solver<N> {
} }
} }
impl<const N: usize> Drop for Solver<N> { impl<UserData, const N: usize> Drop for Solver<UserData, N> {
fn drop(&mut self) { fn drop(&mut self) {
unsafe { cvode::CVodeFree(&mut self.mem.as_raw()) } unsafe { cvode::CVodeFree(&mut self.mem.as_raw()) }
unsafe { cvode::SUNLinSolFree(self.linsolver) }; unsafe { cvode::SUNLinSolFree(self.linsolver) };
@ -216,16 +248,24 @@ impl<const N: usize> Drop for Solver<N> {
mod tests { mod tests {
use super::*; use super::*;
fn f(_t: super::F, y: &[F; 2], ydot: &mut [F; 2], _data: *mut c_void) -> RhsResult { fn f(_t: super::F, y: &[F; 2], ydot: &mut [F; 2], _data: &()) -> RhsResult {
*ydot = [y[1], -y[0]]; *ydot = [y[1], -y[0]];
RhsResult::Ok RhsResult::Ok
} }
wrap!(wrapped_f, f); wrap!(wrapped_f, f, ());
#[test] #[test]
fn create() { fn create() {
let y0 = [0., 1.]; let y0 = [0., 1.];
let _solver = Solver::new(LinearMultistepMethod::ADAMS, wrapped_f, 0., &y0, 1e-4, AbsTolerance::Scalar(1e-4)); let _solver = Solver::new(
LinearMultistepMethod::ADAMS,
wrapped_f,
0.,
&y0,
1e-4,
AbsTolerance::Scalar(1e-4),
(),
);
} }
} }

View File

@ -1,4 +1,4 @@
use std::{convert::TryInto, intrinsics::transmute, ptr::NonNull}; use std::{convert::TryInto, intrinsics::transmute, ops::Deref, ptr::NonNull};
use cvode_5_sys::{cvode::realtype, nvector_serial}; use cvode_5_sys::{cvode::realtype, nvector_serial};
@ -8,6 +8,14 @@ pub struct NVectorSerial<const SIZE: usize> {
inner: NonNull<nvector_serial::_generic_N_Vector>, inner: NonNull<nvector_serial::_generic_N_Vector>,
} }
impl<const SIZE: usize> Deref for NVectorSerial<SIZE> {
type Target = nvector_serial::_generic_N_Vector;
fn deref(&self) -> &Self::Target {
unsafe {self.inner.as_ref()}
}
}
impl<const SIZE: usize> NVectorSerial<SIZE> { impl<const SIZE: usize> NVectorSerial<SIZE> {
pub fn as_ref(&self) -> &[realtype; SIZE] { pub fn as_ref(&self) -> &[realtype; SIZE] {
unsafe { transmute(nvector_serial::N_VGetArrayPointer_Serial(self.as_raw())) } unsafe { transmute(nvector_serial::N_VGetArrayPointer_Serial(self.as_raw())) }

View File

@ -5,11 +5,11 @@ use cvode_wrap::*;
fn main() { fn main() {
let y0 = [0., 1.]; let y0 = [0., 1.];
//define the right-hand-side //define the right-hand-side
fn f(_t: F, y: &[F; 2], ydot: &mut [F; 2], _data: *mut c_void) -> RhsResult { fn f(_t: F, y: &[F; 2], ydot: &mut [F; 2], k: &F) -> RhsResult {
*ydot = [y[1], -y[0] / 10.]; *ydot = [y[1], -y[0] * k];
RhsResult::Ok RhsResult::Ok
} }
wrap!(wrapped_f, f); wrap!(wrapped_f, f, F);
//initialize the solver //initialize the solver
let mut solver = Solver::new( let mut solver = Solver::new(
LinearMultistepMethod::ADAMS, LinearMultistepMethod::ADAMS,
@ -18,6 +18,7 @@ fn main() {
&y0, &y0,
1e-4, 1e-4,
AbsTolerance::scalar(1e-4), AbsTolerance::scalar(1e-4),
1e-2
) )
.unwrap(); .unwrap();
//and solve //and solve