2022-10-11 07:50:15 +00:00
|
|
|
#pragma once
|
|
|
|
#include "lasp_mathtypes.h"
|
|
|
|
#include <cstring>
|
|
|
|
#include <pybind11/numpy.h>
|
|
|
|
#include <type_traits>
|
|
|
|
|
|
|
|
namespace py = pybind11;
|
|
|
|
|
|
|
|
template<typename T>
|
|
|
|
using pyarray = py::array_t<T, py::array::f_style | py::array::forcecast>;
|
|
|
|
using dpyarray = pyarray<d>;
|
2022-10-11 12:50:44 +00:00
|
|
|
/// Numpy array whose elements are of the project scalar alias `c`
/// (presumably declared in lasp_mathtypes.h — confirm)
using cpyarray = pyarray<c>;
|
2022-10-11 07:50:15 +00:00
|
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
* @brief Convert Armadillo column vector to Numpy 1D array
|
|
|
|
*
|
|
|
|
* @tparam T Type (double, d, c, etc)
|
|
|
|
* @param data Input Armadillo array
|
|
|
|
*
|
|
|
|
* @return Numpy array with copied data
|
|
|
|
*/
|
|
|
|
template <typename T> py::array_t<T> ColToNpy(const arma::Col<T> &data) {
|
|
|
|
|
|
|
|
const us Tsize = sizeof(T);
|
|
|
|
|
|
|
|
const auto nrows = static_cast<ssize_t>(data.n_rows);
|
|
|
|
|
|
|
|
const us size = nrows;
|
|
|
|
|
|
|
|
py::array_t<T> result({nrows}, {Tsize});
|
|
|
|
|
|
|
|
// Copy over memory
|
|
|
|
if (size > 0) {
|
|
|
|
memcpy(result.mutable_data(0), data.memptr(), sizeof(T) * nrows);
|
|
|
|
}
|
|
|
|
|
|
|
|
return result;
|
|
|
|
}
|
|
|
|
|
|
|
|
/**
|
|
|
|
* @brief Convert Armadillo 2D array Numpy 2D array
|
|
|
|
*
|
|
|
|
* @tparam T Type (double, d, c, etc)
|
|
|
|
* @param data Input Armadillo array
|
|
|
|
*
|
|
|
|
* @return Numpy array with copied data
|
|
|
|
*/
|
|
|
|
template <typename T> py::array_t<T> MatToNpy(const arma::Mat<T> &data) {
|
|
|
|
|
|
|
|
const us Tsize = sizeof(T);
|
|
|
|
|
|
|
|
const auto nrows = static_cast<ssize_t>(data.n_rows);
|
|
|
|
const auto ncols = static_cast<ssize_t>(data.n_cols);
|
|
|
|
|
|
|
|
const us size = nrows;
|
|
|
|
|
|
|
|
py::array_t<T> result({nrows, ncols}, {Tsize, nrows * Tsize});
|
|
|
|
|
|
|
|
// Copy over memory
|
|
|
|
if (size > 0) {
|
|
|
|
memcpy(result.mutable_data(0, 0), data.memptr(), sizeof(T) * nrows * ncols);
|
|
|
|
}
|
|
|
|
|
|
|
|
return result;
|
|
|
|
}
|
|
|
|
|
|
|
|
/**
|
|
|
|
* @brief Convert Armadillo Cube to Numpy 3D array
|
|
|
|
*
|
|
|
|
* @tparam T Type (double, d, c, etc)
|
|
|
|
* @param data Input Armadillo Array
|
|
|
|
*
|
|
|
|
* @return Numpy array with copied data
|
|
|
|
*/
|
|
|
|
template <typename T> py::array_t<T> CubeToNpy(const arma::Cube<T> &data) {
|
|
|
|
|
|
|
|
const us Tsize = sizeof(T);
|
|
|
|
|
|
|
|
const auto nrows = static_cast<ssize_t>(data.n_rows);
|
|
|
|
const auto ncols = static_cast<ssize_t>(data.n_cols);
|
|
|
|
const auto nslices = static_cast<ssize_t>(data.n_slices);
|
|
|
|
|
|
|
|
const us size = nrows * ncols * nslices;
|
|
|
|
|
|
|
|
py::array_t<T> result({nrows, ncols, nslices},
|
|
|
|
{Tsize, nrows * Tsize, nrows * ncols * Tsize});
|
|
|
|
|
|
|
|
// Copy over memory
|
|
|
|
if (size > 0) {
|
|
|
|
memcpy(result.mutable_data(0, 0, 0), data.memptr(),
|
|
|
|
sizeof(T) * nrows * ncols * nslices);
|
|
|
|
}
|
|
|
|
|
|
|
|
return result;
|
|
|
|
}
|
|
|
|
|
|
|
|
// Converters from Numpy arrays back to Armadillo types
|
|
|
|
|
|
|
|
/**
|
|
|
|
* @brief Wrap Numpy array to 2D Armadillo Mat
|
|
|
|
*
|
|
|
|
* @tparam T Type
|
|
|
|
* @tparam copy Whether to copy the data. If true, creates a copy
|
|
|
|
* @param data The Numpy array to convert
|
|
|
|
*
|
|
|
|
* @return Mat instance
|
|
|
|
*/
|
|
|
|
template <typename T, bool copy = true>
|
|
|
|
arma::Mat<T> NpyToMat(pyarray<T> data) {
|
|
|
|
if (data.ndim() != 2) {
|
|
|
|
throw std::runtime_error("Expects a 2D array");
|
|
|
|
}
|
|
|
|
return arma::Mat<T>(data.mutable_data(0,0), data.shape(0), data.shape(1), copy);
|
|
|
|
}
|
|
|
|
|
|
|
|
/**
|
|
|
|
* @brief Wrap Numpy array to 1D Armadillo column vector
|
|
|
|
*
|
|
|
|
* @tparam T Type
|
|
|
|
* @tparam copy Whether to copy the data. If true, creates a copy
|
|
|
|
* @param data The Numpy array to convert
|
|
|
|
*
|
|
|
|
* @return Armadillo column instance
|
|
|
|
*/
|
|
|
|
template <typename T, bool copy = true>
|
|
|
|
arma::Mat<T> NpyToCol(pyarray<T> data) {
|
|
|
|
if (data.ndim() != 1) {
|
|
|
|
throw std::runtime_error("Expects a 1D array");
|
|
|
|
}
|
2022-10-11 12:50:44 +00:00
|
|
|
return arma::Col<T>(data.mutable_data(0), data.size(), copy);
|
2022-10-11 07:50:15 +00:00
|
|
|
}
|
|
|
|
|