This commit is contained in:
Arthur Carcano 2021-05-07 18:29:56 +02:00
commit 1ea23c85a2
15 changed files with 22462 additions and 0 deletions

2
.gitignore vendored Normal file
View File

@ -0,0 +1,2 @@
target/
Cargo.lock

6
Cargo.toml Normal file
View File

@ -0,0 +1,6 @@
[workspace]
members = [
"cvode-5-sys",
"cvode-wrap",
"test-solver"
]

9
cvode-5-sys/Cargo.toml Normal file
View File

@ -0,0 +1,9 @@
[package]
name = "cvode-5-sys"
version = "0.1.0"
authors = ["Arthur Carcano <arthur.carcano@inria.fr>"]
edition = "2018"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]

5
cvode-5-sys/build.rs Normal file
View File

@ -0,0 +1,5 @@
fn main() -> () {
for lib in &["sundials_cvodes", "sundials_nvecserial"] {
println!("cargo:rustc-link-lib={}", lib);
}
}

6591
cvode-5-sys/src/cvode.rs Normal file

File diff suppressed because it is too large Load Diff

6
cvode-5-sys/src/lib.rs Normal file
View File

@ -0,0 +1,6 @@
#![allow(non_upper_case_globals, non_camel_case_types, non_snake_case, improper_ctypes)]
pub mod cvode;
pub mod nvector_serial;
pub mod sunmatrix_dense;
pub mod sunlinsol_dense;

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

10
cvode-wrap/Cargo.toml Normal file
View File

@ -0,0 +1,10 @@
[package]
name = "cvode-wrap"
version = "0.1.0"
authors = ["Arthur Carcano <arthur.carcano@inria.fr>"]
edition = "2018"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
cvode-5-sys = {path = "../cvode-5-sys"}

204
cvode-wrap/src/lib.rs Normal file
View File

@ -0,0 +1,204 @@
use std::convert::TryInto;
use std::{ffi::c_void, intrinsics::transmute, os::raw::c_int, ptr::NonNull};
use cvode::SUNMatrix;
use cvode_5_sys::{
cvode::{self, realtype, SUNLinearSolver},
nvector_serial::N_VGetArrayPointer,
};
mod nvector;
pub use nvector::NVectorSerial;
pub type F = realtype;
pub type CVector = cvode::N_Vector;
#[repr(u32)]
#[derive(Debug)]
pub enum LinearMultistepMethod {
ADAMS = cvode::CV_ADAMS,
BDF = cvode::CV_BDF,
}
#[repr(C)]
struct CvodeMemoryBlock {
_private: [u8; 0],
}
#[repr(transparent)]
#[derive(Debug, Clone, Copy)]
struct CvodeMemoryBlockNonNullPtr {
ptr: NonNull<CvodeMemoryBlock>,
}
impl CvodeMemoryBlockNonNullPtr {
fn new(ptr: NonNull<CvodeMemoryBlock>) -> Self {
Self { ptr }
}
fn as_raw(self) -> *mut c_void {
self.ptr.as_ptr() as *mut c_void
}
}
impl From<NonNull<CvodeMemoryBlock>> for CvodeMemoryBlockNonNullPtr {
fn from(x: NonNull<CvodeMemoryBlock>) -> Self {
Self::new(x)
}
}
pub struct Solver<const N: usize> {
mem: CvodeMemoryBlockNonNullPtr,
_y0: NVectorSerial<N>,
_sunmatrix: SUNMatrix,
_linsolver: SUNLinearSolver,
}
pub enum RhsResult {
Ok,
RecoverableError(u8),
NonRecoverableError(u8),
}
type RhsF<const N: usize> = fn(F, &[F; N], &mut [F; N], *mut c_void) -> RhsResult;
pub fn wrap_f<const N: usize>(
f: RhsF<N>,
t: F,
y: CVector,
ydot: CVector,
data: *mut c_void,
) -> c_int {
let y = unsafe { transmute(N_VGetArrayPointer(y as _)) };
let ydot = unsafe { transmute(N_VGetArrayPointer(ydot as _)) };
let res = f(t, y, ydot, data);
match res {
RhsResult::Ok => 0,
RhsResult::RecoverableError(e) => e as c_int,
RhsResult::NonRecoverableError(e) => -(e as c_int),
}
}
#[macro_export]
macro_rules! wrap {
($wrapped_f_name: ident, $f_name: ident) => {
extern "C" fn $wrapped_f_name(
t: F,
y: CVector,
ydot: CVector,
data: *mut std::ffi::c_void,
) -> std::os::raw::c_int {
wrap_f($f_name, t, y, ydot, data)
}
};
}
type RhsFCtype = extern "C" fn(F, CVector, CVector, *mut c_void) -> c_int;
#[repr(u32)]
pub enum StepKind {
Normal = cvode::CV_NORMAL,
OneStep = cvode::CV_ONE_STEP,
}
#[derive(Debug)]
pub enum Error {
NullPointerError { func_id: &'static str },
ErrorCode { func_id: &'static str, flag: c_int },
}
pub type Result<T> = std::result::Result<T, Error>;
fn check_non_null<T>(ptr: *mut T, func_id: &'static str) -> Result<NonNull<T>> {
NonNull::new(ptr).ok_or_else(|| Error::NullPointerError { func_id })
}
fn check_flag_is_succes(flag: c_int, func_id: &'static str) -> Result<()> {
if flag == cvode::CV_SUCCESS as i32 {
Ok(())
} else {
Err(Error::ErrorCode { flag, func_id })
}
}
impl<const N: usize> Solver<N> {
pub fn new(
method: LinearMultistepMethod,
f: RhsFCtype,
t0: F,
y0: &[F; N],
atol: F,
rtol: F,
) -> Result<Self> {
assert_eq!(y0.len(), N);
let mem_maybenull = unsafe { cvode::CVodeCreate(method as c_int) };
let mem: CvodeMemoryBlockNonNullPtr =
check_non_null(mem_maybenull as *mut CvodeMemoryBlock, "CVodeCreate")?.into();
let y0 = NVectorSerial::new_from(y0);
let flag = unsafe { cvode::CVodeInit(mem.as_raw(), Some(f), t0, y0.as_raw() as _) };
check_flag_is_succes(flag, "CVodeInit")?;
let flag = unsafe { cvode::CVodeSStolerances(mem.as_raw(), atol, rtol) };
check_flag_is_succes(flag, "CVodeSStolerances")?;
let matrix = unsafe {
cvode_5_sys::sunmatrix_dense::SUNDenseMatrix(
N.try_into().unwrap(),
N.try_into().unwrap(),
)
};
check_non_null(matrix, "SUNDenseMatrix")?;
let linsolver = unsafe {
cvode_5_sys::sunlinsol_dense::SUNDenseLinearSolver(y0.as_raw() as _, matrix as _)
};
check_non_null(linsolver, "SUNDenseLinearSolver")?;
let flag =
unsafe { cvode::CVodeSetLinearSolver(mem.as_raw(), linsolver as _, matrix as _) };
check_flag_is_succes(flag, "CVodeSetLinearSolver")?;
Ok(Solver {
mem,
_y0: y0,
_sunmatrix: matrix as _,
_linsolver: linsolver as _,
})
}
pub fn step(&mut self, tout: F, step_kind: StepKind) -> Result<(F, &[F; N])> {
let mut tret = 0.;
let flag = unsafe {
cvode::CVode(
self.mem.as_raw(),
tout,
self._y0.as_raw() as _,
&mut tret,
step_kind as c_int,
)
};
check_flag_is_succes(flag, "CVode")?;
Ok((tret, self._y0.as_ref()))
}
}
impl<const N: usize> Drop for Solver<N> {
fn drop(&mut self) {
unsafe { cvode::CVodeFree(&mut self.mem.as_raw()) }
unsafe { cvode::SUNLinSolFree(self._linsolver) };
unsafe { cvode::SUNMatDestroy(self._sunmatrix) };
}
}
#[cfg(test)]
mod tests {
use super::*;
fn f(_t: super::F, y: &[F; 2], ydot: &mut [F; 2], _data: *mut c_void) -> RhsResult {
*ydot = [y[1], -y[0]];
RhsResult::Ok
}
wrap!(wrapped_f, f);
#[test]
fn create() {
let y0 = [0., 1.];
let _solver = Solver::new(LinearMultistepMethod::ADAMS, wrapped_f, 0., &y0, 1e-4, 1e-4);
}
}

51
cvode-wrap/src/nvector.rs Normal file
View File

@ -0,0 +1,51 @@
use std::{convert::TryInto, intrinsics::transmute, ptr::NonNull};
use cvode_5_sys::{cvode::realtype, nvector_serial};
#[repr(transparent)]
#[derive(Debug)]
pub struct NVectorSerial<const SIZE: usize> {
inner: NonNull<nvector_serial::_generic_N_Vector>,
}
impl<const SIZE: usize> NVectorSerial<SIZE> {
pub fn as_ref(&self) -> &[realtype; SIZE] {
unsafe { transmute(nvector_serial::N_VGetArrayPointer_Serial(self.as_raw())) }
}
pub fn as_mut(&mut self) -> &mut [realtype; SIZE] {
unsafe { transmute(nvector_serial::N_VGetArrayPointer_Serial(self.as_raw())) }
}
pub fn new() -> Self {
Self {
inner: NonNull::new(unsafe { nvector_serial::N_VNew_Serial(SIZE.try_into().unwrap()) })
.unwrap(),
}
}
pub fn new_from(data: &[realtype; SIZE]) -> Self {
let mut res = Self::new();
res.as_mut().copy_from_slice(data);
res
}
pub fn make(data: &mut [realtype; SIZE]) -> Self {
Self {
inner: NonNull::new(unsafe {
nvector_serial::N_VMake_Serial(SIZE.try_into().unwrap(), data.as_mut_ptr())
})
.unwrap(),
}
}
pub fn as_raw(&self) -> nvector_serial::N_Vector {
self.inner.as_ptr()
}
}
impl<const SIZE: usize> Drop for NVectorSerial<SIZE> {
fn drop(&mut self) {
unsafe { nvector_serial::N_VDestroy(self.as_raw()) }
}
}

10
test-solver/Cargo.toml Normal file
View File

@ -0,0 +1,10 @@
[package]
name = "test-solver"
version = "0.1.0"
authors = ["Arthur Carcano <arthur.carcano@inria.fr>"]
edition = "2018"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
cvode-wrap = {path = "../cvode-wrap"}

7
test-solver/plot.py Normal file
View File

@ -0,0 +1,7 @@
import pandas as pd
import sys
import matplotlib.pyplot as plt
df = pd.read_csv(sys.stdin,names=['t','x',r'\dot{x}'],index_col='t')
ax = df.plot()
plt.show()

23
test-solver/src/main.rs Normal file
View File

@ -0,0 +1,23 @@
use std::ffi::c_void;
use cvode_wrap::*;
fn main() {
let y0 = [0., 1.];
//define the right-hand-side
fn f(_t: F, y: &[F; 2], ydot: &mut [F; 2], _data: *mut c_void) -> RhsResult {
*ydot = [y[1], -y[0] / 10.];
RhsResult::Ok
}
wrap!(wrapped_f, f);
//initialize the solver
let mut solver =
Solver::new(LinearMultistepMethod::ADAMS, wrapped_f, 0., &y0, 1e-4, 1e-4).unwrap();
//and solve
let ts: Vec<_> = (1..100).collect();
println!("0,{},{}", y0[0], y0[1]);
for &t in &ts {
let (_tret, &[x, xdot]) = solver.step(t as _, StepKind::Normal).unwrap();
println!("{},{},{}", t, x, xdot);
}
}