Clean up and add doc

This commit is contained in:
Arthur Carcano 2021-06-10 19:34:57 +02:00
parent b48444ef6c
commit 66f0d5fbf8
12 changed files with 211 additions and 248 deletions

View File

@ -1,5 +1,17 @@
[workspace] [package]
members = [ name = "cvode-wrap"
"cvode-wrap", version = "0.1.0"
"example" authors = ["Arthur Carcano <arthur.carcano@inria.fr>"]
] edition = "2018"
license = "BSD-3"
description="A wrapper around cvode and cvodeS from sundials, allowing to solve ordinary differential equations (ODEs) with or without their sensitivities."
repository="https://gitlab.inria.fr/InBio/Public/cvode-rust-wrap/"
readme="Readme.md"
keywords=["sundials","cvode","cvodes","ode","sensitivities"]
categories=["science","simulation","api-bindings"]
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
sundials-sys = {version="0.2.1", default-features=false, features=["cvodes"]}
array-init = "2.0"

View File

@ -1,39 +1,5 @@
A wrapper around the sundials ODE solver. A wrapper around the sundials ODE solver.
# Example # Examples
An oscillatory 2-D system. Examples computing the behavior of an oscillatory system defined by `x'' = -k * x` are included in the examples/ directory. In the example computing the sensitivities, sensitivities are computed with respect to `x(0)`, `x'(0)` and `k`.
```rust
use cvode_wrap::*;
let y0 = [0., 1.];
// define the right-hand-side as a rust function of type RhsF<Realtype, 2>
fn f(
_t: Realtype,
y: &[Realtype; 2],
ydot: &mut [Realtype; 2],
k: &Realtype,
) -> RhsResult {
*ydot = [y[1], -y[0] * k];
RhsResult::Ok
}
//initialize the solver
let mut solver = cvode::Solver::new(
LinearMultistepMethod::Adams,
wrapped_f,
0.,
&y0,
1e-4,
AbsTolerance::scalar(1e-4),
1e-2,
)
.unwrap();
//and solve
let ts: Vec<_> = (1..100).collect();
println!("0,{},{}", y0[0], y0[1]);
for &t in &ts {
let (_tret, &[x, xdot]) = solver.step(t as _, StepKind::Normal).unwrap();
println!("{},{},{}", t, x, xdot);
}
```

View File

@ -1,12 +0,0 @@
[package]
name = "cvode-wrap"
version = "0.1.0"
authors = ["Arthur Carcano <arthur.carcano@inria.fr>"]
edition = "2018"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
#cvode-5-sys = {path = "../cvode-5-sys"}
sundials-sys = {path = "../../sundials-sys", default-features=false, features=["cvodes"]}
array-init = "2.0"

View File

@ -1,10 +0,0 @@
[package]
name = "test-solver"
version = "0.1.0"
authors = ["Arthur Carcano <arthur.carcano@inria.fr>"]
edition = "2018"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
cvode-wrap = {path = "../cvode-wrap"}

View File

@ -1,80 +0,0 @@
use std::env::args;
use cvode_wrap::*;
fn main() {
let y0 = [0., 1.];
//define the right-hand-side
fn f(_t: Realtype, y: &[Realtype; 2], ydot: &mut [Realtype; 2], k: &Realtype) -> RhsResult {
*ydot = [y[1], -y[0] * k];
RhsResult::Ok
}
// If there is any command line argument compute the sensitivities, else don't.
if false && args().nth(1).is_none() {
//initialize the solver
let mut solver = cvode::Solver::new(
LinearMultistepMethod::Adams,
f,
0.,
&y0,
1e-4,
AbsTolerance::scalar(1e-4),
1e-2,
)
.unwrap();
//and solve
let ts: Vec<_> = (1..100).collect();
println!("0,{},{}", y0[0], y0[1]);
for &t in &ts {
let (_tret, &[x, xdot]) = solver.step(t as _, StepKind::Normal).unwrap();
println!("{},{},{}", t, x, xdot);
}
} else {
const N_SENSI: usize = 3;
// the sensitivities in order are d/dy0[0], d/dy0[1] and d/dk
let ys0 = [[1., 0.], [0., 1.], [0., 0.]];
fn fs(
_t: Realtype,
y: &[Realtype; 2],
_ydot: &[Realtype; 2],
ys: [&[Realtype; 2]; N_SENSI],
ysdot: [&mut [Realtype; 2]; N_SENSI],
k: &Realtype,
) -> RhsResult {
*ysdot[0] = [ys[0][1], -ys[0][0] * k];
*ysdot[1] = [ys[1][1], -ys[1][0] * k];
*ysdot[2] = [ys[2][1], -ys[2][0] * k - y[0]];
RhsResult::Ok
}
//initialize the solver
let mut solver = cvode_sens::Solver::new(
LinearMultistepMethod::Adams,
f,
fs,
0.,
&y0,
&ys0,
1e-4,
AbsTolerance::scalar(1e-4),
cvode_sens::SensiAbsTolerance::scalar([1e-4; N_SENSI]),
1e-2,
)
.unwrap();
//and solve
let ts: Vec<_> = (1..100).collect();
println!("0,{},{}", y0[0], y0[1]);
for &t in &ts {
let (
_tret,
&[x, xdot],
[&[dy0_dy00, dy1_dy00], &[dy0_dy01, dy1_dy01], &[dy0_dk, dy1_dk]],
) = solver.step(t as _, StepKind::Normal).unwrap();
println!(
"{},{},{},{},{},{},{},{},{}",
t, x, xdot, dy0_dy00, dy1_dy00, dy0_dy01, dy1_dy01, dy0_dk, dy1_dk
);
}
}
}

View File

@ -0,0 +1,28 @@
use cvode_wrap::*;
fn main() {
let y0 = [0., 1.];
//define the right-hand-side
fn f(_t: Realtype, y: &[Realtype; 2], ydot: &mut [Realtype; 2], k: &Realtype) -> RhsResult {
*ydot = [y[1], -y[0] * k];
RhsResult::Ok
}
//initialize the solver
let mut solver = SolverNoSensi::new(
LinearMultistepMethod::Adams,
f,
0.,
&y0,
1e-4,
AbsTolerance::scalar(1e-4),
1e-2,
)
.unwrap();
//and solve
let ts: Vec<_> = (1..100).collect();
println!("0,{},{}", y0[0], y0[1]);
for &t in &ts {
let (_tret, &[x, xdot]) = solver.step(t as _, StepKind::Normal).unwrap();
println!("{},{},{}", t, x, xdot);
}
}

View File

@ -0,0 +1,58 @@
use cvode_wrap::*;
fn main() {
let y0 = [0., 1.];
//define the right-hand-side
fn f(_t: Realtype, y: &[Realtype; 2], ydot: &mut [Realtype; 2], k: &Realtype) -> RhsResult {
*ydot = [y[1], -y[0] * k];
RhsResult::Ok
}
//define the sensitivity function for the right hand side
fn fs(
_t: Realtype,
y: &[Realtype; 2],
_ydot: &[Realtype; 2],
ys: [&[Realtype; 2]; N_SENSI],
ysdot: [&mut [Realtype; 2]; N_SENSI],
k: &Realtype,
) -> RhsResult {
// Mind that when indexing sensitivities, the first index
// is the parameter index, and the second the state variable
// index
*ysdot[0] = [ys[0][1], -ys[0][0] * k];
*ysdot[1] = [ys[1][1], -ys[1][0] * k];
*ysdot[2] = [ys[2][1], -ys[2][0] * k - y[0]];
RhsResult::Ok
}
const N_SENSI: usize = 3;
// the sensitivities in order are d/dy0[0], d/dy0[1] and d/dk
let ys0 = [[1., 0.], [0., 1.], [0., 0.]];
//initialize the solver
let mut solver = SolverSensi::new(
LinearMultistepMethod::Adams,
f,
fs,
0.,
&y0,
&ys0,
1e-4,
AbsTolerance::scalar(1e-4),
SensiAbsTolerance::scalar([1e-4; N_SENSI]),
1e-2,
)
.unwrap();
//and solve
let ts: Vec<_> = (1..100).collect();
println!("0,{},{}", y0[0], y0[1]);
for &t in &ts {
let (_tret, &[x, xdot], [&[dy0_dy00, dy1_dy00], &[dy0_dy01, dy1_dy01], &[dy0_dk, dy1_dk]]) =
solver.step(t as _, StepKind::Normal).unwrap();
println!(
"{},{},{},{},{},{},{},{},{}",
t, x, xdot, dy0_dy00, dy1_dy00, dy0_dy01, dy1_dy01, dy0_dk, dy1_dk
);
}
}

View File

@ -1,54 +1,30 @@
use std::{convert::TryInto, ffi::c_void, os::raw::c_int, pin::Pin, ptr::NonNull}; //! Wrapper around cvode, without sensitivities
use std::{convert::TryInto, os::raw::c_int, pin::Pin};
use sundials_sys::{SUNLinearSolver, SUNMatrix}; use sundials_sys::{SUNLinearSolver, SUNMatrix};
use crate::{ use crate::{
check_flag_is_succes, check_non_null, AbsTolerance, LinearMultistepMethod, NVectorSerial, check_flag_is_succes, check_non_null, AbsTolerance, CvodeMemoryBlock,
NVectorSerialHeapAllocated, Realtype, Result, RhsResult, StepKind, CvodeMemoryBlockNonNullPtr, LinearMultistepMethod, NVectorSerial, NVectorSerialHeapAllocated,
Realtype, Result, RhsResult, StepKind,
}; };
#[repr(C)]
struct CvodeMemoryBlock {
_private: [u8; 0],
}
#[repr(transparent)]
#[derive(Debug, Clone, Copy)]
struct CvodeMemoryBlockNonNullPtr {
ptr: NonNull<CvodeMemoryBlock>,
}
impl CvodeMemoryBlockNonNullPtr {
fn new(ptr: NonNull<CvodeMemoryBlock>) -> Self {
Self { ptr }
}
fn as_raw(self) -> *mut c_void {
self.ptr.as_ptr() as *mut c_void
}
}
impl From<NonNull<CvodeMemoryBlock>> for CvodeMemoryBlockNonNullPtr {
fn from(x: NonNull<CvodeMemoryBlock>) -> Self {
Self::new(x)
}
}
struct WrappingUserData<UserData, F> { struct WrappingUserData<UserData, F> {
actual_user_data: UserData, actual_user_data: UserData,
f: F, f: F,
} }
/// The main struct of the crate. Wraps a sundials solver. /// The ODE solver without sensitivities.
/// ///
/// Args /// # Type Arguments
/// ---- ///
/// `UserData` is the type of the supplementary arguments for the /// - `F` is the type of the right-hand side function
///
/// - `UserData` is the type of the supplementary arguments for the
/// right-hand-side. If unused, should be `()`. /// right-hand-side. If unused, should be `()`.
/// ///
/// `N` is the "problem size", that is the dimension of the state space. /// - `N` is the "problem size", that is the dimension of the state space.
///
/// See [crate-level](`crate`) documentation for more.
pub struct Solver<UserData, F, const N: usize> { pub struct Solver<UserData, F, const N: usize> {
mem: CvodeMemoryBlockNonNullPtr, mem: CvodeMemoryBlockNonNullPtr,
y0: NVectorSerialHeapAllocated<N>, y0: NVectorSerialHeapAllocated<N>,
@ -58,9 +34,6 @@ pub struct Solver<UserData, F, const N: usize> {
user_data: Pin<Box<WrappingUserData<UserData, F>>>, user_data: Pin<Box<WrappingUserData<UserData, F>>>,
} }
/// The wrapping function.
///
/// Internally used in [`wrap`].
extern "C" fn wrap_f<UserData, F, const N: usize>( extern "C" fn wrap_f<UserData, F, const N: usize>(
t: Realtype, t: Realtype,
y: *const NVectorSerial<N>, y: *const NVectorSerial<N>,
@ -88,6 +61,7 @@ impl<UserData, F, const N: usize> Solver<UserData, F, N>
where where
F: Fn(Realtype, &[Realtype; N], &mut [Realtype; N], &UserData) -> RhsResult, F: Fn(Realtype, &[Realtype; N], &mut [Realtype; N], &UserData) -> RhsResult,
{ {
/// Create a new solver.
pub fn new( pub fn new(
method: LinearMultistepMethod, method: LinearMultistepMethod,
f: F, f: F,
@ -170,6 +144,11 @@ where
Ok(res) Ok(res)
} }
/// Takes a step according to `step_kind` (see [`StepKind`]).
///
/// Returns a tuple `(t_out,&y(t_out))` where `t_out` is the time
/// reached by the solver as dictated by `step_kind`, and `y(t_out)` is an
/// array of the state variables at that time.
pub fn step( pub fn step(
&mut self, &mut self,
tout: Realtype, tout: Realtype,
@ -225,6 +204,7 @@ mod tests {
1e-4, 1e-4,
AbsTolerance::Scalar(1e-4), AbsTolerance::Scalar(1e-4),
(), (),
).unwrap(); )
.unwrap();
} }
} }

View File

@ -1,76 +1,35 @@
use std::{convert::TryInto, ffi::c_void, os::raw::c_int, pin::Pin, ptr::NonNull}; //! Wrapper around cvodeS, with sensitivities
use std::{convert::TryInto, os::raw::c_int, pin::Pin};
use sundials_sys::{SUNLinearSolver, SUNMatrix, CV_STAGGERED}; use sundials_sys::{SUNLinearSolver, SUNMatrix, CV_STAGGERED};
use crate::{ use crate::{
check_flag_is_succes, check_non_null, AbsTolerance, LinearMultistepMethod, NVectorSerial, check_flag_is_succes, check_non_null, AbsTolerance, CvodeMemoryBlock,
NVectorSerialHeapAllocated, Realtype, Result, RhsResult, StepKind, CvodeMemoryBlockNonNullPtr, LinearMultistepMethod, NVectorSerial, NVectorSerialHeapAllocated,
Realtype, Result, RhsResult, SensiAbsTolerance, StepKind,
}; };
#[repr(C)]
struct CvodeMemoryBlock {
_private: [u8; 0],
}
#[repr(transparent)]
#[derive(Debug, Clone, Copy)]
struct CvodeMemoryBlockNonNullPtr {
ptr: NonNull<CvodeMemoryBlock>,
}
impl CvodeMemoryBlockNonNullPtr {
fn new(ptr: NonNull<CvodeMemoryBlock>) -> Self {
Self { ptr }
}
fn as_raw(self) -> *mut c_void {
self.ptr.as_ptr() as *mut c_void
}
}
pub enum SensiAbsTolerance<const SIZE: usize, const N_SENSI: usize> {
Scalar([Realtype; N_SENSI]),
Vector([NVectorSerialHeapAllocated<SIZE>; N_SENSI]),
}
impl<const SIZE: usize, const N_SENSI: usize> SensiAbsTolerance<SIZE, N_SENSI> {
pub fn scalar(atol: [Realtype; N_SENSI]) -> Self {
SensiAbsTolerance::Scalar(atol)
}
pub fn vector(atol: &[[Realtype; SIZE]; N_SENSI]) -> Self {
SensiAbsTolerance::Vector(
array_init::from_iter(
atol.iter()
.map(|arr| NVectorSerialHeapAllocated::new_from(arr)),
)
.unwrap(),
)
}
}
impl From<NonNull<CvodeMemoryBlock>> for CvodeMemoryBlockNonNullPtr {
fn from(x: NonNull<CvodeMemoryBlock>) -> Self {
Self::new(x)
}
}
struct WrappingUserData<UserData, F, FS> { struct WrappingUserData<UserData, F, FS> {
actual_user_data: UserData, actual_user_data: UserData,
f: F, f: F,
fs: FS, fs: FS,
} }
/// The main struct of the crate. Wraps a sundials solver. /// The ODE solver with sensitivities.
/// ///
/// Args /// # Type Arguments
/// ---- ///
/// `UserData` is the type of the supplementary arguments for the /// - `F` is the type of the right-hand side function
///
/// - `FS` is the type of the sensitivities right-hand side function
///
/// - `UserData` is the type of the supplementary arguments for the
/// right-hand-side. If unused, should be `()`. /// right-hand-side. If unused, should be `()`.
/// ///
/// `N` is the "problem size", that is the dimension of the state space. /// - `N` is the "problem size", that is the dimension of the state space.
/// ///
/// See [crate-level](`crate`) documentation for more. /// - `N_SENSI` is the number of sensitivities computed
pub struct Solver<UserData, F, FS, const N: usize, const N_SENSI: usize> { pub struct Solver<UserData, F, FS, const N: usize, const N_SENSI: usize> {
mem: CvodeMemoryBlockNonNullPtr, mem: CvodeMemoryBlockNonNullPtr,
y0: NVectorSerialHeapAllocated<N>, y0: NVectorSerialHeapAllocated<N>,
@ -83,9 +42,6 @@ pub struct Solver<UserData, F, FS, const N: usize, const N_SENSI: usize> {
sensi_out_buffer: [NVectorSerialHeapAllocated<N>; N_SENSI], sensi_out_buffer: [NVectorSerialHeapAllocated<N>; N_SENSI],
} }
/// The wrapping function.
///
/// Internally used in [`wrap`].
extern "C" fn wrap_f<UserData, F, FS, const N: usize>( extern "C" fn wrap_f<UserData, F, FS, const N: usize>(
t: Realtype, t: Realtype,
y: *const NVectorSerial<N>, y: *const NVectorSerial<N>,
@ -168,6 +124,7 @@ where
&UserData, &UserData,
) -> RhsResult, ) -> RhsResult,
{ {
/// Creates a new solver.
#[allow(clippy::clippy::too_many_arguments)] #[allow(clippy::clippy::too_many_arguments)]
pub fn new( pub fn new(
method: LinearMultistepMethod, method: LinearMultistepMethod,
@ -293,6 +250,12 @@ where
Ok(res) Ok(res)
} }
/// Takes a step according to `step_kind` (see [`StepKind`]).
///
/// Returns a tuple `(t_out,&y(t_out),[&dy_dp(tout)])` where `t_out` is the time
/// reached by the solver as dictated by `step_kind`, `y(t_out)` is an
/// array of the state variables at that time, and the i-th `dy_dp(tout)` is an array
/// of the sensitivities of all variables with respect to parameter i.
#[allow(clippy::clippy::type_complexity)] #[allow(clippy::clippy::type_complexity)]
pub fn step( pub fn step(
&mut self, &mut self,
@ -379,6 +342,7 @@ mod tests {
AbsTolerance::scalar(1e-4), AbsTolerance::scalar(1e-4),
SensiAbsTolerance::scalar([1e-4; 4]), SensiAbsTolerance::scalar([1e-4; 4]),
(), (),
).unwrap(); )
.unwrap();
} }
} }

View File

@ -1,12 +1,19 @@
use std::{os::raw::c_int, ptr::NonNull}; //! A wrapper around cvode and cvodes from the sundials tool suite.
//!
//! Users should be mostly interested in [`SolverSensi`] and [`SolverNoSensi`].
use std::{ffi::c_void, os::raw::c_int, ptr::NonNull};
use sundials_sys::realtype; use sundials_sys::realtype;
mod nvector; mod nvector;
pub use nvector::{NVectorSerial, NVectorSerialHeapAllocated}; pub use nvector::{NVectorSerial, NVectorSerialHeapAllocated};
pub mod cvode; mod cvode;
pub mod cvode_sens; mod cvode_sens;
pub use cvode::Solver as SolverNoSensi;
pub use cvode_sens::Solver as SolverSensi;
/// The floatting-point type sundials was compiled with /// The floatting-point type sundials was compiled with
pub type Realtype = realtype; pub type Realtype = realtype;
@ -80,6 +87,29 @@ impl<const SIZE: usize> AbsTolerance<SIZE> {
} }
} }
/// An enum representing the choice between scalars or vectors absolute tolerances
/// for sensitivities.
pub enum SensiAbsTolerance<const SIZE: usize, const N_SENSI: usize> {
Scalar([Realtype; N_SENSI]),
Vector([NVectorSerialHeapAllocated<SIZE>; N_SENSI]),
}
impl<const SIZE: usize, const N_SENSI: usize> SensiAbsTolerance<SIZE, N_SENSI> {
pub fn scalar(atol: [Realtype; N_SENSI]) -> Self {
SensiAbsTolerance::Scalar(atol)
}
pub fn vector(atol: &[[Realtype; SIZE]; N_SENSI]) -> Self {
SensiAbsTolerance::Vector(
array_init::from_iter(
atol.iter()
.map(|arr| NVectorSerialHeapAllocated::new_from(arr)),
)
.unwrap(),
)
}
}
/// A short-hand for `std::result::Result<T, crate::Error>` /// A short-hand for `std::result::Result<T, crate::Error>`
pub type Result<T> = std::result::Result<T, Error>; pub type Result<T> = std::result::Result<T, Error>;
@ -94,3 +124,30 @@ fn check_flag_is_succes(flag: c_int, func_id: &'static str) -> Result<()> {
Err(Error::ErrorCode { flag, func_id }) Err(Error::ErrorCode { flag, func_id })
} }
} }
#[repr(C)]
struct CvodeMemoryBlock {
_private: [u8; 0],
}
#[repr(transparent)]
#[derive(Debug, Clone, Copy)]
struct CvodeMemoryBlockNonNullPtr {
ptr: NonNull<CvodeMemoryBlock>,
}
impl CvodeMemoryBlockNonNullPtr {
fn new(ptr: NonNull<CvodeMemoryBlock>) -> Self {
Self { ptr }
}
fn as_raw(self) -> *mut c_void {
self.ptr.as_ptr() as *mut c_void
}
}
impl From<NonNull<CvodeMemoryBlock>> for CvodeMemoryBlockNonNullPtr {
fn from(x: NonNull<CvodeMemoryBlock>) -> Self {
Self::new(x)
}
}