Add user_data
This commit is contained in:
parent
18c7c0ab02
commit
a1231f7d85
@ -1,4 +1,4 @@
|
||||
use std::convert::TryInto;
|
||||
use std::{convert::TryInto, pin::Pin};
|
||||
use std::{ffi::c_void, intrinsics::transmute, os::raw::c_int, ptr::NonNull};
|
||||
|
||||
use cvode::SUNMatrix;
|
||||
@ -47,12 +47,13 @@ impl From<NonNull<CvodeMemoryBlock>> for CvodeMemoryBlockNonNullPtr {
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Solver<const N: usize> {
|
||||
pub struct Solver<UserData, const N: usize> {
|
||||
mem: CvodeMemoryBlockNonNullPtr,
|
||||
y0: NVectorSerial<N>,
|
||||
sunmatrix: SUNMatrix,
|
||||
linsolver: SUNLinearSolver,
|
||||
_atol: AbsTolerance<N>,
|
||||
atol: AbsTolerance<N>,
|
||||
user_data: Pin<Box<UserData>>,
|
||||
}
|
||||
|
||||
pub enum RhsResult {
|
||||
@ -61,14 +62,14 @@ pub enum RhsResult {
|
||||
NonRecoverableError(u8),
|
||||
}
|
||||
|
||||
type RhsF<const N: usize> = fn(F, &[F; N], &mut [F; N], *mut c_void) -> RhsResult;
|
||||
type RhsF<UserData, const N: usize> = fn(F, &[F; N], &mut [F; N], &UserData) -> RhsResult;
|
||||
|
||||
pub fn wrap_f<const N: usize>(
|
||||
f: RhsF<N>,
|
||||
pub fn wrap_f<UserData, const N: usize>(
|
||||
f: RhsF<UserData, N>,
|
||||
t: F,
|
||||
y: CVector,
|
||||
ydot: CVector,
|
||||
data: *mut c_void,
|
||||
data: &UserData,
|
||||
) -> c_int {
|
||||
let y = unsafe { transmute(N_VGetArrayPointer(y as _)) };
|
||||
let ydot = unsafe { transmute(N_VGetArrayPointer(ydot as _)) };
|
||||
@ -82,19 +83,20 @@ pub fn wrap_f<const N: usize>(
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! wrap {
|
||||
($wrapped_f_name: ident, $f_name: ident) => {
|
||||
($wrapped_f_name: ident, $f_name: ident, $user_data: ty) => {
|
||||
extern "C" fn $wrapped_f_name(
|
||||
t: F,
|
||||
y: CVector,
|
||||
ydot: CVector,
|
||||
data: *mut std::ffi::c_void,
|
||||
data: *const $user_data,
|
||||
) -> std::os::raw::c_int {
|
||||
let data = unsafe { std::mem::transmute(data) };
|
||||
wrap_f($f_name, t, y, ydot, data)
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
type RhsFCtype = extern "C" fn(F, CVector, CVector, *mut c_void) -> c_int;
|
||||
type RhsFCtype<UserData> = extern "C" fn(F, CVector, CVector, *const UserData) -> c_int;
|
||||
|
||||
#[repr(u32)]
|
||||
pub enum StepKind {
|
||||
@ -138,23 +140,54 @@ impl<const SIZE: usize> AbsTolerance<SIZE> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<const N: usize> Solver<N> {
|
||||
impl<UserData, const N: usize> Solver<UserData, N> {
|
||||
pub fn new(
|
||||
method: LinearMultistepMethod,
|
||||
f: RhsFCtype,
|
||||
f: RhsFCtype<UserData>,
|
||||
t0: F,
|
||||
y0: &[F; N],
|
||||
rtol: F,
|
||||
atol: AbsTolerance<N>,
|
||||
user_data: UserData,
|
||||
) -> Result<Self> {
|
||||
assert_eq!(y0.len(), N);
|
||||
let mem: CvodeMemoryBlockNonNullPtr = {
|
||||
let mem_maybenull = unsafe { cvode::CVodeCreate(method as c_int) };
|
||||
let mem: CvodeMemoryBlockNonNullPtr =
|
||||
check_non_null(mem_maybenull as *mut CvodeMemoryBlock, "CVodeCreate")?.into();
|
||||
check_non_null(mem_maybenull as *mut CvodeMemoryBlock, "CVodeCreate")?.into()
|
||||
};
|
||||
let y0 = NVectorSerial::new_from(y0);
|
||||
let flag = unsafe { cvode::CVodeInit(mem.as_raw(), Some(f), t0, y0.as_raw() as _) };
|
||||
let matrix = {
|
||||
let matrix = unsafe {
|
||||
cvode_5_sys::sunmatrix_dense::SUNDenseMatrix(
|
||||
N.try_into().unwrap(),
|
||||
N.try_into().unwrap(),
|
||||
)
|
||||
};
|
||||
check_non_null(matrix, "SUNDenseMatrix")?
|
||||
};
|
||||
let linsolver = {
|
||||
let linsolver = unsafe {
|
||||
cvode_5_sys::sunlinsol_dense::SUNDenseLinearSolver(
|
||||
y0.as_raw() as _,
|
||||
matrix.as_ptr() as _,
|
||||
)
|
||||
};
|
||||
check_non_null(linsolver, "SUNDenseLinearSolver")?
|
||||
};
|
||||
let user_data = Box::pin(user_data);
|
||||
let res = Solver {
|
||||
mem,
|
||||
y0,
|
||||
sunmatrix: matrix.as_ptr() as _,
|
||||
linsolver: linsolver.as_ptr() as _,
|
||||
atol,
|
||||
user_data,
|
||||
};
|
||||
{
|
||||
let flag = unsafe { cvode::CVodeInit(mem.as_raw(), Some(std::mem::transmute(f)), t0, res.y0.as_raw() as _) };
|
||||
check_flag_is_succes(flag, "CVodeInit")?;
|
||||
match &atol {
|
||||
}
|
||||
match &res.atol {
|
||||
&AbsTolerance::Scalar(atol) => {
|
||||
let flag = unsafe { cvode::CVodeSStolerances(mem.as_raw(), rtol, atol) };
|
||||
check_flag_is_succes(flag, "CVodeSStolerances")?;
|
||||
@ -165,27 +198,26 @@ impl<const N: usize> Solver<N> {
|
||||
check_flag_is_succes(flag, "CVodeSVtolerances")?;
|
||||
}
|
||||
}
|
||||
let matrix = unsafe {
|
||||
cvode_5_sys::sunmatrix_dense::SUNDenseMatrix(
|
||||
N.try_into().unwrap(),
|
||||
N.try_into().unwrap(),
|
||||
{
|
||||
let flag = unsafe {
|
||||
cvode::CVodeSetLinearSolver(
|
||||
mem.as_raw(),
|
||||
linsolver.as_ptr() as _,
|
||||
matrix.as_ptr() as _,
|
||||
)
|
||||
};
|
||||
check_non_null(matrix, "SUNDenseMatrix")?;
|
||||
let linsolver = unsafe {
|
||||
cvode_5_sys::sunlinsol_dense::SUNDenseLinearSolver(y0.as_raw() as _, matrix as _)
|
||||
};
|
||||
check_non_null(linsolver, "SUNDenseLinearSolver")?;
|
||||
let flag =
|
||||
unsafe { cvode::CVodeSetLinearSolver(mem.as_raw(), linsolver as _, matrix as _) };
|
||||
check_flag_is_succes(flag, "CVodeSetLinearSolver")?;
|
||||
Ok(Solver {
|
||||
mem,
|
||||
y0,
|
||||
sunmatrix: matrix as _,
|
||||
linsolver: linsolver as _,
|
||||
_atol: atol as _,
|
||||
})
|
||||
}
|
||||
{
|
||||
let flag = unsafe {
|
||||
cvode::CVodeSetUserData(
|
||||
mem.as_raw(),
|
||||
std::mem::transmute(res.user_data.as_ref().get_ref()),
|
||||
)
|
||||
};
|
||||
check_flag_is_succes(flag, "CVodeSetUserData")?;
|
||||
}
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
pub fn step(&mut self, tout: F, step_kind: StepKind) -> Result<(F, &[F; N])> {
|
||||
@ -204,7 +236,7 @@ impl<const N: usize> Solver<N> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<const N: usize> Drop for Solver<N> {
|
||||
impl<UserData, const N: usize> Drop for Solver<UserData, N> {
|
||||
fn drop(&mut self) {
|
||||
unsafe { cvode::CVodeFree(&mut self.mem.as_raw()) }
|
||||
unsafe { cvode::SUNLinSolFree(self.linsolver) };
|
||||
@ -216,16 +248,24 @@ impl<const N: usize> Drop for Solver<N> {
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn f(_t: super::F, y: &[F; 2], ydot: &mut [F; 2], _data: *mut c_void) -> RhsResult {
|
||||
fn f(_t: super::F, y: &[F; 2], ydot: &mut [F; 2], _data: &()) -> RhsResult {
|
||||
*ydot = [y[1], -y[0]];
|
||||
RhsResult::Ok
|
||||
}
|
||||
|
||||
wrap!(wrapped_f, f);
|
||||
wrap!(wrapped_f, f, ());
|
||||
|
||||
#[test]
|
||||
fn create() {
|
||||
let y0 = [0., 1.];
|
||||
let _solver = Solver::new(LinearMultistepMethod::ADAMS, wrapped_f, 0., &y0, 1e-4, AbsTolerance::Scalar(1e-4));
|
||||
let _solver = Solver::new(
|
||||
LinearMultistepMethod::ADAMS,
|
||||
wrapped_f,
|
||||
0.,
|
||||
&y0,
|
||||
1e-4,
|
||||
AbsTolerance::Scalar(1e-4),
|
||||
(),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
@ -1,4 +1,4 @@
|
||||
use std::{convert::TryInto, intrinsics::transmute, ptr::NonNull};
|
||||
use std::{convert::TryInto, intrinsics::transmute, ops::Deref, ptr::NonNull};
|
||||
|
||||
use cvode_5_sys::{cvode::realtype, nvector_serial};
|
||||
|
||||
@ -8,6 +8,14 @@ pub struct NVectorSerial<const SIZE: usize> {
|
||||
inner: NonNull<nvector_serial::_generic_N_Vector>,
|
||||
}
|
||||
|
||||
impl<const SIZE: usize> Deref for NVectorSerial<SIZE> {
|
||||
type Target = nvector_serial::_generic_N_Vector;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
unsafe {self.inner.as_ref()}
|
||||
}
|
||||
}
|
||||
|
||||
impl<const SIZE: usize> NVectorSerial<SIZE> {
|
||||
pub fn as_ref(&self) -> &[realtype; SIZE] {
|
||||
unsafe { transmute(nvector_serial::N_VGetArrayPointer_Serial(self.as_raw())) }
|
||||
|
@ -5,11 +5,11 @@ use cvode_wrap::*;
|
||||
fn main() {
|
||||
let y0 = [0., 1.];
|
||||
//define the right-hand-side
|
||||
fn f(_t: F, y: &[F; 2], ydot: &mut [F; 2], _data: *mut c_void) -> RhsResult {
|
||||
*ydot = [y[1], -y[0] / 10.];
|
||||
fn f(_t: F, y: &[F; 2], ydot: &mut [F; 2], k: &F) -> RhsResult {
|
||||
*ydot = [y[1], -y[0] * k];
|
||||
RhsResult::Ok
|
||||
}
|
||||
wrap!(wrapped_f, f);
|
||||
wrap!(wrapped_f, f, F);
|
||||
//initialize the solver
|
||||
let mut solver = Solver::new(
|
||||
LinearMultistepMethod::ADAMS,
|
||||
@ -18,6 +18,7 @@ fn main() {
|
||||
&y0,
|
||||
1e-4,
|
||||
AbsTolerance::scalar(1e-4),
|
||||
1e-2
|
||||
)
|
||||
.unwrap();
|
||||
//and solve
|
||||
|
Loading…
Reference in New Issue
Block a user