lasp/cpp_src/pybind11/arma_npy.h

134 lines
3.2 KiB
C
Raw Permalink Normal View History

#pragma once
#include <cstring>
#include <stdexcept>
#include <type_traits>

#include <pybind11/numpy.h>

#include "lasp_mathtypes.h"
namespace py = pybind11;

/// NumPy array type taken by the back-converters in this header: requests
/// Fortran (column-major) ordering via `f_style`, and `forcecast` lets
/// pybind11 cast inputs of a different dtype.
template <typename T>
using pyarray = py::array_t<T, py::array::forcecast | py::array::f_style>;

/// pyarray of element type `d` (declared in lasp_mathtypes.h; presumably a
/// real floating-point type — confirm there).
using dpyarray = pyarray<d>;
/// pyarray of element type `c` (declared in lasp_mathtypes.h; presumably a
/// complex type — confirm there).
using cpyarray = pyarray<c>;
/**
 * @brief Copy an Armadillo column vector into a new 1D Numpy array.
 *
 * @tparam T Element type (double, d, c, etc)
 * @param data Armadillo column to copy from
 *
 * @return Freshly allocated Numpy array holding a copy of the elements
 */
template <typename T> py::array_t<T> ColToNpy(const arma::Col<T> &data) {
  const auto len = static_cast<ssize_t>(data.n_rows);
  const us elem_bytes = sizeof(T);

  // One dimension with a stride of a single element: a plain contiguous
  // vector, identical to Armadillo's storage.
  py::array_t<T> out({len}, {elem_bytes});

  // Empty column: nothing to copy, so skip touching either buffer.
  if (len == 0) {
    return out;
  }
  std::memcpy(out.mutable_data(), data.memptr(), sizeof(T) * len);
  return out;
}
/**
 * @brief Convert an Armadillo 2D matrix to a Numpy 2D array
 *
 * The resulting array is given column-major (Fortran) strides, so
 * Armadillo's element buffer can be copied into it verbatim.
 *
 * @tparam T Type (double, d, c, etc)
 * @param data Input Armadillo matrix
 *
 * @return Numpy array with copied data
 */
template <typename T> py::array_t<T> MatToNpy(const arma::Mat<T> &data) {
  const us Tsize = sizeof(T);
  const auto nrows = static_cast<ssize_t>(data.n_rows);
  const auto ncols = static_cast<ssize_t>(data.n_cols);
  // Total element count. Checking only nrows would treat an N x 0 matrix
  // as non-empty and then index out of bounds below.
  const us size = nrows * ncols;
  // Strides: next row is one element away, next column is nrows elements.
  py::array_t<T> result({nrows, ncols}, {Tsize, nrows * Tsize});
  // Copy over memory (skipped for empty matrices).
  if (size > 0) {
    // Index-free mutable_data() yields the buffer start without a bounds
    // check on element (0, 0).
    std::memcpy(result.mutable_data(), data.memptr(), sizeof(T) * size);
  }
  return result;
}
/**
 * @brief Copy an Armadillo Cube into a new 3D Numpy array.
 *
 * @tparam T Element type (double, d, c, etc)
 * @param data Armadillo cube to copy from
 *
 * @return Freshly allocated Numpy array holding a copy of the elements
 */
template <typename T> py::array_t<T> CubeToNpy(const arma::Cube<T> &data) {
  const auto rows = static_cast<ssize_t>(data.n_rows);
  const auto cols = static_cast<ssize_t>(data.n_cols);
  const auto slices = static_cast<ssize_t>(data.n_slices);
  const us itemsize = sizeof(T);

  // Column-major strides mirroring Armadillo's layout: consecutive rows are
  // adjacent, then whole columns, then whole slices.
  const us row_stride = itemsize;
  const us col_stride = rows * itemsize;
  const us slice_stride = rows * cols * itemsize;
  py::array_t<T> out({rows, cols, slices},
                     {row_stride, col_stride, slice_stride});

  const us n_elem = rows * cols * slices;
  if (n_elem != 0) {
    std::memcpy(out.mutable_data(), data.memptr(), sizeof(T) * n_elem);
  }
  return out;
}
/// Back-converters: Numpy arrays to Armadillo objects
/**
 * @brief Wrap a 2D Numpy array as an Armadillo Mat
 *
 * The parameter type (pyarray) requests Fortran order, which matches
 * Armadillo's column-major storage, so the buffer is passed to the Mat
 * without reordering.
 *
 * @tparam T Type
 * @tparam copy Whether to copy the data. If true, the Mat gets its own copy.
 * If false, the Mat is built on the array's buffer directly. NOTE: the Mat
 * holds only a raw pointer and does not keep the Numpy array alive, so the
 * caller must ensure the array outlives the returned Mat.
 * @param data The Numpy array to convert
 *
 * @throws std::runtime_error if data is not two-dimensional
 *
 * @return Mat instance of shape data.shape(0) x data.shape(1)
 */
template <typename T, bool copy = true>
arma::Mat<T> NpyToMat(pyarray<T> data) {
  if (data.ndim() != 2) {
    throw std::runtime_error("Expects a 2D array");
  }
  // Use the index-free mutable_data(). Indexing element (0, 0) is
  // bounds-checked by pybind11 and would throw for arrays with a
  // zero-length dimension.
  return arma::Mat<T>(data.mutable_data(), data.shape(0), data.shape(1), copy);
}
/**
 * @brief Wrap a 1D Numpy array as an Armadillo column vector
 *
 * @tparam T Type
 * @tparam copy Whether to copy the data. If true, the result gets its own
 * copy. If false, the column is built on the array's buffer directly. NOTE:
 * the result holds only a raw pointer and does not keep the Numpy array
 * alive, so the caller must ensure the array outlives it.
 * @param data The Numpy array to convert
 *
 * @throws std::runtime_error if data is not one-dimensional
 *
 * @return An arma::Col<T> is constructed, but the declared return type is
 * arma::Mat<T>, so callers receive a data.size() x 1 Mat.
 * NOTE(review): confirm that this Col -> Mat conversion keeps the buffer
 * aliasing when copy == false.
 */
template <typename T, bool copy = true>
arma::Mat<T> NpyToCol(pyarray<T> data) {
  if (data.ndim() != 1) {
    throw std::runtime_error("Expects a 1D array");
  }
  // Use the index-free mutable_data(). mutable_data(0) is bounds-checked
  // by pybind11 and would throw for an empty (shape (0,)) array.
  return arma::Col<T>(data.mutable_data(), data.size(), copy);
}