From f6ea790071463606a0e30c0df66e569b026042ba Mon Sep 17 00:00:00 2001 From: "J.A. de Jong - Redu-Sone B.V., ASCEE V.O.F" Date: Fri, 28 Jun 2024 09:15:55 +0200 Subject: [PATCH] Bugfix for overwriting array initialization functions, such that it can also handle complex numbers and different ordering --- python_src/lasp/lasp_config.py | 27 +++++++++++++++++++++------ 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/python_src/lasp/lasp_config.py b/python_src/lasp/lasp_config.py index de94229..f148be6 100644 --- a/python_src/lasp/lasp_config.py +++ b/python_src/lasp/lasp_config.py @@ -16,13 +16,28 @@ else: LASP_NUMPY_COMPLEX_TYPE = np.float64 -def zeros(shape): - return np.zeros(shape, dtype=LASP_NUMPY_FLOAT_TYPE, order='F') +def zeros(shape, dtype=float, order='F'): + if dtype == float: + return np.zeros(shape, dtype=LASP_NUMPY_FLOAT_TYPE, order=order) + elif dtype == complex: + return np.zeros(shape, dtype=LASP_NUMPY_COMPLEX_TYPE, order=order) + else: + raise RuntimeError(f"Unknown dtype: {dtype}") -def ones(shape): - return np.ones(shape, dtype=LASP_NUMPY_FLOAT_TYPE, order='F') +def ones(shape, dtype=float, order='F'): + if dtype == float: + return np.ones(shape, dtype=LASP_NUMPY_FLOAT_TYPE, order=order) + elif dtype == complex: + return np.ones(shape, dtype=LASP_NUMPY_COMPLEX_TYPE, order=order) + else: + raise RuntimeError(f"Unknown dtype: {dtype}") +def empty(shape, dtype=float, order='F'): + if dtype == float: + return np.empty(shape, dtype=LASP_NUMPY_FLOAT_TYPE, order=order) + elif dtype == complex: + return np.empty(shape, dtype=LASP_NUMPY_COMPLEX_TYPE, order=order) + else: + raise RuntimeError(f"Unknown dtype: {dtype}") -def empty(shape): - return np.empty(shape, dtype=LASP_NUMPY_FLOAT_TYPE, order='F')