// lasp/src/lasp/pybind11/lasp_dsp_pybind.cpp
//
// (Original listing metadata: 147 lines, 4.7 KiB, C++)

#include "arma_npy.h"
#include "lasp_avpowerspectra.h"
#include "lasp_biquadbank.h"
#include "lasp_fft.h"
#include "lasp_filter.h"
#include "lasp_slm.h"
#include "lasp_streammgr.h"
#include "lasp_window.h"
#include <iostream>
#include <memory>
#include <optional>
#include <pybind11/pybind11.h>
#include <stdexcept>
using std::cerr;
using std::endl;
namespace py = pybind11;
using rte = std::runtime_error;
/**
* \ingroup pybind
* @{
*
*/
/**
 * @brief Initialize DSP code
 *
 * Registers the DSP classes on the given Python module: `Fft`, `Window`
 * (with its nested `WindowType` enum), `Filter`, `SeriesBiquad`,
 * `BiquadBank`, `PowerSpectra`, `AvPowerSpectra` and `cppSLM`.
 *
 * Exceptions of type std::runtime_error thrown from the bound lambdas are
 * translated by pybind11 into Python `RuntimeError`.
 *
 * @param m The Python module to add classes and methods to
 */
void init_dsp(py::module &m) {
  /// Fft
  py::class_<Fft> fft(m, "Fft");
  fft.def(py::init<us>());
  // Forward transform of real data. A 1-D array is handled as a single
  // column, a 2-D array through the matrix overload; any other rank is
  // rejected.
  fft.def("fft", [](Fft &f, dpyarray dat) {
    if (dat.ndim() == 1) {
      return ColToNpy<c>(f.fft(NpyToCol<d, false>(dat)));
    } else if (dat.ndim() == 2) {
      return MatToNpy<c>(f.fft(NpyToMat<d, false>(dat)));
    } else {
      throw rte("Invalid dimensions of array");
    }
  });
  // Inverse transform of complex data, with the same 1-D / 2-D dispatch.
  fft.def("ifft", [](Fft &f, cpyarray dat) {
    if (dat.ndim() == 1) {
      return ColToNpy<d>(f.ifft(NpyToCol<c, false>(dat)));
    } else if (dat.ndim() == 2) {
      return MatToNpy<d>(f.ifft(NpyToMat<c, false>(dat)));
    } else {
      throw rte("Invalid dimensions of array");
    }
  });
  fft.def_static("load_fft_wisdom", &Fft::load_fft_wisdom);
  fft.def_static("store_fft_wisdom", &Fft::store_fft_wisdom);

  /// Window
  py::class_<Window> w(m, "Window");
  // Each Python name must map to its own C++ enumerator. "Blackman" used to
  // map to Bartlett, so requesting a Blackman window silently gave Bartlett.
  py::enum_<Window::WindowType>(w, "WindowType")
      .value("Hann", Window::WindowType::Hann)
      .value("Hamming", Window::WindowType::Hamming)
      .value("Bartlett", Window::WindowType::Bartlett)
      .value("Blackman", Window::WindowType::Blackman)
      .value("Rectangular", Window::WindowType::Rectangular);
  // Python-side name is "toTxt" (kept for compatibility with existing callers).
  w.def_static("toTxt", &Window::toText);

  /// Filter base class, held by shared_ptr so derived filters can be shared.
  py::class_<Filter, std::shared_ptr<Filter>> filter(m, "Filter");

  /// SeriesBiquad
  py::class_<SeriesBiquad, std::shared_ptr<SeriesBiquad>> sbq(m, "SeriesBiquad",
                                                              filter);
  sbq.def(py::init([](dpyarray coefs) {
    return std::make_shared<SeriesBiquad>(NpyToCol<d, false>(coefs));
  }));
  // Filters a copy of the input and returns it. NOTE(review): the `true`
  // template argument of NpyToCol presumably requests a copy so the caller's
  // array is not modified; confirm in arma_npy.h.
  sbq.def("filter", [](SeriesBiquad &self, dpyarray input) {
    vd res = NpyToCol<d, true>(input);
    self.filter(res);
    return ColToNpy<d>(res);
  });

  /// BiquadBank
  py::class_<BiquadBank, std::shared_ptr<BiquadBank>> bqb(m, "BiquadBank");
  bqb.def(py::init<const dmat &, const vd *>());
  bqb.def("setGains",
          [](BiquadBank &self, dpyarray gains) { self.setGains(NpyToCol(gains)); });
  bqb.def("filter", [](BiquadBank &self, dpyarray input) {
    vd inout = NpyToCol<d, true>(input);
    self.filter(inout);
    return ColToNpy(inout);
  });

  /// PowerSpectra
  py::class_<PowerSpectra> ps(m, "PowerSpectra");
  ps.def(py::init<const us, const Window::WindowType>());
  ps.def("compute", [](PowerSpectra &self, dpyarray input) {
    return CubeToNpy<c>(self.compute(NpyToMat<d, false>(input)));
  });

  /// AvPowerSpectra
  py::class_<AvPowerSpectra> aps(m, "AvPowerSpectra");
  aps.def(py::init<const us, const Window::WindowType, const d, const d>(),
          py::arg("nfft") = 2048,
          py::arg("windowType") = Window::WindowType::Hann,
          py::arg("overlap_percentage") = 50.0, py::arg("time_constant") = -1);
  aps.def("compute", [](AvPowerSpectra &self, dpyarray timedata) {
    // Convert the NumPy array while the GIL is still held. pybind11 forbids
    // touching Python objects after the GIL has been released. Only the
    // numerical work below runs without the GIL.
    const dmat timedata_mat = NpyToMat<d, false>(timedata);
    std::optional<ccube> res;
    {
      py::gil_scoped_release release;
      res = self.compute(timedata_mat);
    }
    // An empty optional is returned to Python as an empty 0x0x0 cube.
    return CubeToNpy<c>(res.value_or(ccube(0, 0, 0)));
  });
  aps.def("get_est", [](const AvPowerSpectra &self) {
    auto est = self.get_est();
    return CubeToNpy<c>(est.value_or(ccube(0, 0, 0)));
  });

  /// SLM (exposed to Python as "cppSLM")
  py::class_<SLM> slm(m, "cppSLM");
  // Overload without a prefilter.
  slm.def_static("fromBiquads", [](const d fs, const d Lref, const us ds,
                                   const d tau, dpyarray bandpass) {
    return SLM::fromBiquads(fs, Lref, ds, tau, NpyToMat<d, false>(bandpass));
  });
  // Overload with a prefilter. `bandpass` uses the same dpyarray type as the
  // overload above, so NpyToMat always receives the same array flavour.
  slm.def_static("fromBiquads", [](const d fs, const d Lref, const us ds,
                                   const d tau, dpyarray prefilter,
                                   dpyarray bandpass) {
    return SLM::fromBiquads(fs, Lref, ds, tau, NpyToCol<d, false>(prefilter),
                            NpyToMat<d, false>(bandpass));
  });
  slm.def("run", [](SLM &self, dpyarray in) {
    return MatToNpy<d>(self.run(NpyToCol<d, false>(in)));
  });
  slm.def("Pm", [](const SLM &self) { return ColToNpy<d>(self.Pm); });
  slm.def("Pmax", [](const SLM &self) { return ColToNpy<d>(self.Pmax); });
  slm.def("Ppeak", [](const SLM &self) { return ColToNpy<d>(self.Ppeak); });
  slm.def("Leq", [](const SLM &self) { return ColToNpy<d>(self.Leq()); });
  slm.def("Lmax", [](const SLM &self) { return ColToNpy<d>(self.Lmax()); });
  slm.def("Lpeak", [](const SLM &self) { return ColToNpy<d>(self.Lpeak()); });
  slm.def_static("suggestedDownSamplingFac", &SLM::suggestedDownSamplingFac);
}
/** @} */