More exposure of classes and enums to Python

This commit is contained in:
Anne de Jong 2024-09-26 20:09:11 +02:00
parent d9fbe25dc1
commit c843c089dd
4 changed files with 109 additions and 45 deletions

View File

@ -62,6 +62,7 @@ fn lasprs(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<ps::FreqWeighting>()?;
m.add_class::<slm::SLMSettings>()?;
m.add_class::<slm::SLM>()?;
m.add_class::<ps::WindowType>()?;
Ok(())
}

View File

@ -1,7 +1,7 @@
use super::timebuffer::TimeBuffer;
use super::CrossPowerSpecra;
use super::*;
use crate::config::*;
use crate::{config::*, TransferFunction, ZPKModel};
use anyhow::{bail, Error, Result};
use derive_builder::Builder;
use freqweighting::FreqWeighting;
@ -70,35 +70,24 @@ impl ApsSettings {
fn get_overlap_keep(&self) -> usize {
self.validate_get_overlap_keep().unwrap()
}
/// Unpack all, returns parts in tuple
pub fn get(self) -> (ApsMode, Overlap, WindowType, FreqWeighting, usize, Flt) {
(
self.mode,
self.overlap,
self.windowType,
self.freqWeightingType,
self.nfft,
self.fs,
)
}
/// Returns the amount of samples to `keep` in the time buffer when
/// overlapping time segments using [TimeBuffer].
pub fn validate_get_overlap_keep(&self) -> Result<usize> {
fn validate_get_overlap_keep(&self) -> Result<usize> {
let nfft = self.nfft;
let overlap_keep = match self.overlap {
Overlap::Number(i) if i >= nfft => {
Overlap::Number { N } if N >= nfft => {
bail!("Invalid overlap number of samples. Should be < nfft, which is {nfft}.")
}
// Keep 1 sample, if overlap is 1 sample etc.
Overlap::Number(i) if i < nfft => i,
Overlap::Number { N } if N < nfft => N,
// If overlap percentage is >= 100, or < 0.0 its an error
Overlap::Percentage(p) if !(0.0..100.).contains(&p) => {
Overlap::Percentage { pct } if !(0.0..100.).contains(&pct) => {
bail!("Invalid overlap percentage. Should be >= 0. And < 100.")
}
// If overlap percentage is 0, this gives
Overlap::Percentage(p) => ((p * nfft as Flt) / 100.) as usize,
Overlap::NoOverlap => 0,
Overlap::Percentage { pct } => ((pct * nfft as Flt) / 100.) as usize,
Overlap::NoOverlap {} => 0,
_ => unreachable!(),
};
if overlap_keep >= nfft {
@ -160,31 +149,47 @@ impl ApsSettings {
/// Provide the overlap of blocks for computing averaged (cross) power spectra.
/// Can be provided as a percentage of the block size, or as a number of
/// samples.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "python-bindings", pyclass)]
#[derive(Clone, Debug, PartialEq)]
pub enum Overlap {
/// Overlap specified as a percentage of the total FFT length. Value should
/// be 0<=pct<100.
Percentage(Flt),
Percentage {
/// Percentage
pct: Flt,
},
/// Number of samples to overlap
Number(usize),
Number {
/// N: Number of samples
N: usize,
},
/// No overlap at all, which is the same as Overlap::Number(0)
NoOverlap,
NoOverlap {},
}
impl Default for Overlap {
fn default() -> Self {
Overlap::Percentage(50.)
Overlap::Percentage { pct: 50. }
}
}
#[cfg(feature = "python-bindings")]
#[cfg_attr(feature = "python-bindings", pymethods)]
impl Overlap {
#[inline]
fn __eq__(&self, other: &Self) -> bool {
self == other
}
}
/// The 'mode' used in computing averaged power spectra. When providing data in
/// blocks to the [AvPowerSpectra] the resulting 'current estimate' responds
/// differently, depending on the model.
#[derive(Default, Copy, Clone)]
#[derive(Copy, Clone, PartialEq)]
#[cfg_attr(feature = "python-bindings", pyclass)]
pub enum ApsMode {
/// Averaged over all data provided. New averages can be created by calling
/// `AvPowerSpectra::reset()`
#[default]
AllAveraging,
AllAveraging {},
/// In this mode, the `AvPowerSpectra` works a bit like a sound level meter,
/// where new data is weighted with old data, and old data exponentially
/// backs off. This mode only makes sense when `tau >> nfft/fs`
@ -195,7 +200,20 @@ pub enum ApsMode {
tau: Flt,
},
/// Spectrogram mode. Only returns the latest estimate(s).
Spectrogram,
Spectrogram {},
}
impl Default for ApsMode {
fn default() -> Self {
ApsMode::AllAveraging {}
}
}
#[cfg(feature = "python-bindings")]
#[cfg_attr(feature = "python-bindings", pymethods)]
impl ApsMode {
#[inline]
fn __eq__(&self, other: &Self) -> bool {
self == other
}
}
/// Averaged power spectra computing engine
@ -221,6 +239,9 @@ pub struct AvPowerSpectra {
/// Storage for sample data.
timebuf: TimeBuffer,
/// Power scaling for of applied frequency weighting.
freqWeighting_pwr: Option<Dcol>,
// Current estimation of the power spectra
cur_est: CPSResult,
}
@ -254,8 +275,7 @@ impl AvPowerSpectra {
/// - When providing invalid sampling frequencies
///
pub fn new_simple_all_averaging(fs: Flt, nfft: usize) -> AvPowerSpectra {
let mut settings =
ApsSettings::reasonableAcousticDefault(fs, ApsMode::AllAveraging).unwrap();
let mut settings = ApsSettings::reasonableAcousticDefault(fs, ApsMode::default()).unwrap();
settings.nfft = nfft;
AvPowerSpectra::new(settings)
}
@ -285,18 +305,38 @@ impl AvPowerSpectra {
let ps = PowerSpectra::newFromWindow(window);
let freq = settings.getFreq();
let freqWeighting_pwr = match settings.freqWeightingType {
FreqWeighting::Z => None,
_ => {
let fw_pwr = ZPKModel::freqWeightingFilter(settings.freqWeightingType)
.tf(0., &freq)
.mapv(|a| a.abs() * a.abs());
Some(fw_pwr)
}
};
AvPowerSpectra {
ps,
overlap_keep,
settings,
N: 0,
freqWeighting_pwr,
cur_est: CPSResult::default((0, 0, 0)),
timebuf: TimeBuffer::new(),
}
}
// Update result for single block
fn update_singleblock(&mut self, timedata: ArrayView2<Flt>) {
let Cpsnew = self.ps.compute(timedata);
let Cpsnew = {
let mut Cpsnew = self.ps.compute(timedata);
if let Some(fw_pwr) = &self.freqWeighting_pwr {
Zip::from(Cpsnew.)
}
Cpsnew
};
// println!("Cpsnew: {:?}", Cpsnew[[0, 0, 0]]);
// Initialize to zero
@ -310,7 +350,7 @@ impl AvPowerSpectra {
// Apply operation based on mode
match self.settings.mode {
ApsMode::AllAveraging => {
ApsMode::AllAveraging {} => {
let Nf = Cflt {
re: self.N as Flt,
im: 0.,
@ -359,7 +399,7 @@ impl AvPowerSpectra {
}
}
ApsMode::Spectrogram => {
ApsMode::Spectrogram {} => {
self.cur_est = Cpsnew;
}
}
@ -395,9 +435,10 @@ impl AvPowerSpectra {
computed_single = true;
}
if computed_single {
return Some(&self.cur_est);
Some(&self.cur_est)
} else {
None
}
None
}
/// Computes average (cross)power spectra, and returns all intermediate
@ -450,11 +491,11 @@ mod test {
#[test]
fn test_overlap_keep() {
let ol = [
Overlap::NoOverlap,
Percentage(50.),
Percentage(50.),
Percentage(25.),
Overlap::Number(10),
Overlap::NoOverlap {},
Percentage { pct: 50. },
Percentage { pct: 50. },
Percentage { pct: 25. },
Overlap::Number { N: 10 },
];
let nffts = [10, 10, 1024, 10];
let expected_keep = [0, 5, 512, 2, 10];
@ -481,7 +522,7 @@ mod test {
let settings = ApsSettingsBuilder::default()
.fs(fs)
.nfft(nfft)
.overlap(Overlap::NoOverlap)
.overlap(Overlap::NoOverlap {})
.mode(ApsMode::ExponentialWeighting { tau })
.build()
.unwrap();

View File

@ -1,12 +1,13 @@
use crate::config::*;
use strum_macros::{Display, EnumMessage};
use strum::IntoEnumIterator;
use strum_macros::{Display, EnumIter, EnumMessage};
/// Sound level frequency weighting type (A, C, Z)
// Do the following when Pyo3 0.22 can finally be used combined with rust-numpy:
// #[cfg_attr(feature = "python-bindings", pyclass(eq, eq_int))]
// For now:
#[cfg_attr(feature = "python-bindings", pyclass)]
#[derive(Display, Debug, EnumMessage, Default, Clone, PartialEq)]
#[derive(Copy, Display, Debug, EnumMessage, Default, Clone, PartialEq, EnumIter)]
pub enum FreqWeighting {
/// A-weighting
A,
@ -16,6 +17,16 @@ pub enum FreqWeighting {
#[default]
Z,
}
#[cfg(feature = "python-bindings")]
#[cfg_attr(feature = "python-bindings", pymethods)]
impl FreqWeighting {
#[staticmethod]
fn all() -> Vec<Self> {
Self::iter().collect()
}
}
#[cfg(test)]
mod test {
use super::*;

View File

@ -1,7 +1,8 @@
#![allow(non_snake_case)]
use crate::config::*;
use strum_macros::Display;
use strum::IntoEnumIterator;
use strum_macros::{Display, EnumMessage, EnumIter};
/// Von Hann window, often misnamed as the 'Hanning' window.
fn hann(nfft: usize) -> Dcol {
@ -72,11 +73,11 @@ fn hamming(N: usize) -> Dcol {
/// * Blackman
///
/// The [WindowType::default] is [WindowType::Hann].
#[derive(Display,Default, Copy, Clone, Debug, PartialEq)]
#[derive(Display, Default, Copy, Clone, Debug, PartialEq, EnumMessage, EnumIter)]
// Do the following when Pyo3 0.22 can finally be used combined with rust-numpy:
// #[cfg_attr(feature = "python-bindings", pyclass(eq))]
// For now:
// #[cfg_attr(feature = "python-bindings", pyclass(eq))]
#[cfg_attr(feature = "python-bindings", pyclass)]
pub enum WindowType {
/// Von Hann window
#[default]
@ -91,6 +92,16 @@ pub enum WindowType {
Blackman = 4,
}
#[cfg(feature = "python-bindings")]
#[cfg_attr(feature = "python-bindings", pymethods)]
impl WindowType {
#[staticmethod]
fn all() -> Vec<WindowType> {
WindowType::iter().collect()
}
}
/// Window (taper) computed from specified window type.
#[derive(Clone)]
pub struct Window {