Add user_data
This commit is contained in:
parent
18c7c0ab02
commit
a1231f7d85
@ -1,4 +1,4 @@
|
|||||||
use std::convert::TryInto;
|
use std::{convert::TryInto, pin::Pin};
|
||||||
use std::{ffi::c_void, intrinsics::transmute, os::raw::c_int, ptr::NonNull};
|
use std::{ffi::c_void, intrinsics::transmute, os::raw::c_int, ptr::NonNull};
|
||||||
|
|
||||||
use cvode::SUNMatrix;
|
use cvode::SUNMatrix;
|
||||||
@ -47,12 +47,13 @@ impl From<NonNull<CvodeMemoryBlock>> for CvodeMemoryBlockNonNullPtr {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct Solver<const N: usize> {
|
pub struct Solver<UserData, const N: usize> {
|
||||||
mem: CvodeMemoryBlockNonNullPtr,
|
mem: CvodeMemoryBlockNonNullPtr,
|
||||||
y0: NVectorSerial<N>,
|
y0: NVectorSerial<N>,
|
||||||
sunmatrix: SUNMatrix,
|
sunmatrix: SUNMatrix,
|
||||||
linsolver: SUNLinearSolver,
|
linsolver: SUNLinearSolver,
|
||||||
_atol: AbsTolerance<N>,
|
atol: AbsTolerance<N>,
|
||||||
|
user_data: Pin<Box<UserData>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub enum RhsResult {
|
pub enum RhsResult {
|
||||||
@ -61,14 +62,14 @@ pub enum RhsResult {
|
|||||||
NonRecoverableError(u8),
|
NonRecoverableError(u8),
|
||||||
}
|
}
|
||||||
|
|
||||||
type RhsF<const N: usize> = fn(F, &[F; N], &mut [F; N], *mut c_void) -> RhsResult;
|
type RhsF<UserData, const N: usize> = fn(F, &[F; N], &mut [F; N], &UserData) -> RhsResult;
|
||||||
|
|
||||||
pub fn wrap_f<const N: usize>(
|
pub fn wrap_f<UserData, const N: usize>(
|
||||||
f: RhsF<N>,
|
f: RhsF<UserData, N>,
|
||||||
t: F,
|
t: F,
|
||||||
y: CVector,
|
y: CVector,
|
||||||
ydot: CVector,
|
ydot: CVector,
|
||||||
data: *mut c_void,
|
data: &UserData,
|
||||||
) -> c_int {
|
) -> c_int {
|
||||||
let y = unsafe { transmute(N_VGetArrayPointer(y as _)) };
|
let y = unsafe { transmute(N_VGetArrayPointer(y as _)) };
|
||||||
let ydot = unsafe { transmute(N_VGetArrayPointer(ydot as _)) };
|
let ydot = unsafe { transmute(N_VGetArrayPointer(ydot as _)) };
|
||||||
@ -82,19 +83,20 @@ pub fn wrap_f<const N: usize>(
|
|||||||
|
|
||||||
#[macro_export]
|
#[macro_export]
|
||||||
macro_rules! wrap {
|
macro_rules! wrap {
|
||||||
($wrapped_f_name: ident, $f_name: ident) => {
|
($wrapped_f_name: ident, $f_name: ident, $user_data: ty) => {
|
||||||
extern "C" fn $wrapped_f_name(
|
extern "C" fn $wrapped_f_name(
|
||||||
t: F,
|
t: F,
|
||||||
y: CVector,
|
y: CVector,
|
||||||
ydot: CVector,
|
ydot: CVector,
|
||||||
data: *mut std::ffi::c_void,
|
data: *const $user_data,
|
||||||
) -> std::os::raw::c_int {
|
) -> std::os::raw::c_int {
|
||||||
|
let data = unsafe { std::mem::transmute(data) };
|
||||||
wrap_f($f_name, t, y, ydot, data)
|
wrap_f($f_name, t, y, ydot, data)
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
type RhsFCtype = extern "C" fn(F, CVector, CVector, *mut c_void) -> c_int;
|
type RhsFCtype<UserData> = extern "C" fn(F, CVector, CVector, *const UserData) -> c_int;
|
||||||
|
|
||||||
#[repr(u32)]
|
#[repr(u32)]
|
||||||
pub enum StepKind {
|
pub enum StepKind {
|
||||||
@ -138,23 +140,54 @@ impl<const SIZE: usize> AbsTolerance<SIZE> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<const N: usize> Solver<N> {
|
impl<UserData, const N: usize> Solver<UserData, N> {
|
||||||
pub fn new(
|
pub fn new(
|
||||||
method: LinearMultistepMethod,
|
method: LinearMultistepMethod,
|
||||||
f: RhsFCtype,
|
f: RhsFCtype<UserData>,
|
||||||
t0: F,
|
t0: F,
|
||||||
y0: &[F; N],
|
y0: &[F; N],
|
||||||
rtol: F,
|
rtol: F,
|
||||||
atol: AbsTolerance<N>,
|
atol: AbsTolerance<N>,
|
||||||
|
user_data: UserData,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
assert_eq!(y0.len(), N);
|
assert_eq!(y0.len(), N);
|
||||||
let mem_maybenull = unsafe { cvode::CVodeCreate(method as c_int) };
|
let mem: CvodeMemoryBlockNonNullPtr = {
|
||||||
let mem: CvodeMemoryBlockNonNullPtr =
|
let mem_maybenull = unsafe { cvode::CVodeCreate(method as c_int) };
|
||||||
check_non_null(mem_maybenull as *mut CvodeMemoryBlock, "CVodeCreate")?.into();
|
check_non_null(mem_maybenull as *mut CvodeMemoryBlock, "CVodeCreate")?.into()
|
||||||
|
};
|
||||||
let y0 = NVectorSerial::new_from(y0);
|
let y0 = NVectorSerial::new_from(y0);
|
||||||
let flag = unsafe { cvode::CVodeInit(mem.as_raw(), Some(f), t0, y0.as_raw() as _) };
|
let matrix = {
|
||||||
check_flag_is_succes(flag, "CVodeInit")?;
|
let matrix = unsafe {
|
||||||
match &atol {
|
cvode_5_sys::sunmatrix_dense::SUNDenseMatrix(
|
||||||
|
N.try_into().unwrap(),
|
||||||
|
N.try_into().unwrap(),
|
||||||
|
)
|
||||||
|
};
|
||||||
|
check_non_null(matrix, "SUNDenseMatrix")?
|
||||||
|
};
|
||||||
|
let linsolver = {
|
||||||
|
let linsolver = unsafe {
|
||||||
|
cvode_5_sys::sunlinsol_dense::SUNDenseLinearSolver(
|
||||||
|
y0.as_raw() as _,
|
||||||
|
matrix.as_ptr() as _,
|
||||||
|
)
|
||||||
|
};
|
||||||
|
check_non_null(linsolver, "SUNDenseLinearSolver")?
|
||||||
|
};
|
||||||
|
let user_data = Box::pin(user_data);
|
||||||
|
let res = Solver {
|
||||||
|
mem,
|
||||||
|
y0,
|
||||||
|
sunmatrix: matrix.as_ptr() as _,
|
||||||
|
linsolver: linsolver.as_ptr() as _,
|
||||||
|
atol,
|
||||||
|
user_data,
|
||||||
|
};
|
||||||
|
{
|
||||||
|
let flag = unsafe { cvode::CVodeInit(mem.as_raw(), Some(std::mem::transmute(f)), t0, res.y0.as_raw() as _) };
|
||||||
|
check_flag_is_succes(flag, "CVodeInit")?;
|
||||||
|
}
|
||||||
|
match &res.atol {
|
||||||
&AbsTolerance::Scalar(atol) => {
|
&AbsTolerance::Scalar(atol) => {
|
||||||
let flag = unsafe { cvode::CVodeSStolerances(mem.as_raw(), rtol, atol) };
|
let flag = unsafe { cvode::CVodeSStolerances(mem.as_raw(), rtol, atol) };
|
||||||
check_flag_is_succes(flag, "CVodeSStolerances")?;
|
check_flag_is_succes(flag, "CVodeSStolerances")?;
|
||||||
@ -165,27 +198,26 @@ impl<const N: usize> Solver<N> {
|
|||||||
check_flag_is_succes(flag, "CVodeSVtolerances")?;
|
check_flag_is_succes(flag, "CVodeSVtolerances")?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
let matrix = unsafe {
|
{
|
||||||
cvode_5_sys::sunmatrix_dense::SUNDenseMatrix(
|
let flag = unsafe {
|
||||||
N.try_into().unwrap(),
|
cvode::CVodeSetLinearSolver(
|
||||||
N.try_into().unwrap(),
|
mem.as_raw(),
|
||||||
)
|
linsolver.as_ptr() as _,
|
||||||
};
|
matrix.as_ptr() as _,
|
||||||
check_non_null(matrix, "SUNDenseMatrix")?;
|
)
|
||||||
let linsolver = unsafe {
|
};
|
||||||
cvode_5_sys::sunlinsol_dense::SUNDenseLinearSolver(y0.as_raw() as _, matrix as _)
|
check_flag_is_succes(flag, "CVodeSetLinearSolver")?;
|
||||||
};
|
}
|
||||||
check_non_null(linsolver, "SUNDenseLinearSolver")?;
|
{
|
||||||
let flag =
|
let flag = unsafe {
|
||||||
unsafe { cvode::CVodeSetLinearSolver(mem.as_raw(), linsolver as _, matrix as _) };
|
cvode::CVodeSetUserData(
|
||||||
check_flag_is_succes(flag, "CVodeSetLinearSolver")?;
|
mem.as_raw(),
|
||||||
Ok(Solver {
|
std::mem::transmute(res.user_data.as_ref().get_ref()),
|
||||||
mem,
|
)
|
||||||
y0,
|
};
|
||||||
sunmatrix: matrix as _,
|
check_flag_is_succes(flag, "CVodeSetUserData")?;
|
||||||
linsolver: linsolver as _,
|
}
|
||||||
_atol: atol as _,
|
Ok(res)
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn step(&mut self, tout: F, step_kind: StepKind) -> Result<(F, &[F; N])> {
|
pub fn step(&mut self, tout: F, step_kind: StepKind) -> Result<(F, &[F; N])> {
|
||||||
@ -204,7 +236,7 @@ impl<const N: usize> Solver<N> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<const N: usize> Drop for Solver<N> {
|
impl<UserData, const N: usize> Drop for Solver<UserData, N> {
|
||||||
fn drop(&mut self) {
|
fn drop(&mut self) {
|
||||||
unsafe { cvode::CVodeFree(&mut self.mem.as_raw()) }
|
unsafe { cvode::CVodeFree(&mut self.mem.as_raw()) }
|
||||||
unsafe { cvode::SUNLinSolFree(self.linsolver) };
|
unsafe { cvode::SUNLinSolFree(self.linsolver) };
|
||||||
@ -216,16 +248,24 @@ impl<const N: usize> Drop for Solver<N> {
|
|||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
fn f(_t: super::F, y: &[F; 2], ydot: &mut [F; 2], _data: *mut c_void) -> RhsResult {
|
fn f(_t: super::F, y: &[F; 2], ydot: &mut [F; 2], _data: &()) -> RhsResult {
|
||||||
*ydot = [y[1], -y[0]];
|
*ydot = [y[1], -y[0]];
|
||||||
RhsResult::Ok
|
RhsResult::Ok
|
||||||
}
|
}
|
||||||
|
|
||||||
wrap!(wrapped_f, f);
|
wrap!(wrapped_f, f, ());
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn create() {
|
fn create() {
|
||||||
let y0 = [0., 1.];
|
let y0 = [0., 1.];
|
||||||
let _solver = Solver::new(LinearMultistepMethod::ADAMS, wrapped_f, 0., &y0, 1e-4, AbsTolerance::Scalar(1e-4));
|
let _solver = Solver::new(
|
||||||
|
LinearMultistepMethod::ADAMS,
|
||||||
|
wrapped_f,
|
||||||
|
0.,
|
||||||
|
&y0,
|
||||||
|
1e-4,
|
||||||
|
AbsTolerance::Scalar(1e-4),
|
||||||
|
(),
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
use std::{convert::TryInto, intrinsics::transmute, ptr::NonNull};
|
use std::{convert::TryInto, intrinsics::transmute, ops::Deref, ptr::NonNull};
|
||||||
|
|
||||||
use cvode_5_sys::{cvode::realtype, nvector_serial};
|
use cvode_5_sys::{cvode::realtype, nvector_serial};
|
||||||
|
|
||||||
@ -8,6 +8,14 @@ pub struct NVectorSerial<const SIZE: usize> {
|
|||||||
inner: NonNull<nvector_serial::_generic_N_Vector>,
|
inner: NonNull<nvector_serial::_generic_N_Vector>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<const SIZE: usize> Deref for NVectorSerial<SIZE> {
|
||||||
|
type Target = nvector_serial::_generic_N_Vector;
|
||||||
|
|
||||||
|
fn deref(&self) -> &Self::Target {
|
||||||
|
unsafe {self.inner.as_ref()}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl<const SIZE: usize> NVectorSerial<SIZE> {
|
impl<const SIZE: usize> NVectorSerial<SIZE> {
|
||||||
pub fn as_ref(&self) -> &[realtype; SIZE] {
|
pub fn as_ref(&self) -> &[realtype; SIZE] {
|
||||||
unsafe { transmute(nvector_serial::N_VGetArrayPointer_Serial(self.as_raw())) }
|
unsafe { transmute(nvector_serial::N_VGetArrayPointer_Serial(self.as_raw())) }
|
||||||
|
@ -5,11 +5,11 @@ use cvode_wrap::*;
|
|||||||
fn main() {
|
fn main() {
|
||||||
let y0 = [0., 1.];
|
let y0 = [0., 1.];
|
||||||
//define the right-hand-side
|
//define the right-hand-side
|
||||||
fn f(_t: F, y: &[F; 2], ydot: &mut [F; 2], _data: *mut c_void) -> RhsResult {
|
fn f(_t: F, y: &[F; 2], ydot: &mut [F; 2], k: &F) -> RhsResult {
|
||||||
*ydot = [y[1], -y[0] / 10.];
|
*ydot = [y[1], -y[0] * k];
|
||||||
RhsResult::Ok
|
RhsResult::Ok
|
||||||
}
|
}
|
||||||
wrap!(wrapped_f, f);
|
wrap!(wrapped_f, f, F);
|
||||||
//initialize the solver
|
//initialize the solver
|
||||||
let mut solver = Solver::new(
|
let mut solver = Solver::new(
|
||||||
LinearMultistepMethod::ADAMS,
|
LinearMultistepMethod::ADAMS,
|
||||||
@ -18,6 +18,7 @@ fn main() {
|
|||||||
&y0,
|
&y0,
|
||||||
1e-4,
|
1e-4,
|
||||||
AbsTolerance::scalar(1e-4),
|
AbsTolerance::scalar(1e-4),
|
||||||
|
1e-2
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
//and solve
|
//and solve
|
||||||
|
Loading…
Reference in New Issue
Block a user