diff --git a/python_src/lasp/lasp_config.py b/python_src/lasp/lasp_config.py index de94229..f148be6 100644 --- a/python_src/lasp/lasp_config.py +++ b/python_src/lasp/lasp_config.py @@ -16,13 +16,28 @@ else: LASP_NUMPY_COMPLEX_TYPE = np.float64 -def zeros(shape): - return np.zeros(shape, dtype=LASP_NUMPY_FLOAT_TYPE, order='F') +def zeros(shape, dtype=float, order='F'): + if dtype == float: + return np.zeros(shape, dtype=LASP_NUMPY_FLOAT_TYPE, order=order) + elif dtype == complex: + return np.zeros(shape, dtype=LASP_NUMPY_COMPLEX_TYPE, order=order) + else: + raise RuntimeError(f"Unknown dtype: {dtype}") -def ones(shape): - return np.ones(shape, dtype=LASP_NUMPY_FLOAT_TYPE, order='F') +def ones(shape, dtype=float, order='F'): + if dtype == float: + return np.ones(shape, dtype=LASP_NUMPY_FLOAT_TYPE, order=order) + elif dtype == complex: + return np.ones(shape, dtype=LASP_NUMPY_COMPLEX_TYPE, order=order) + else: + raise RuntimeError(f"Unknown dtype: {dtype}") +def empty(shape, dtype=float, order='F'): + if dtype == float: + return np.empty(shape, dtype=LASP_NUMPY_FLOAT_TYPE, order=order) + elif dtype == complex: + return np.empty(shape, dtype=LASP_NUMPY_COMPLEX_TYPE, order=order) + else: + raise RuntimeError(f"Unknown dtype: {dtype}") -def empty(shape): - return np.empty(shape, dtype=LASP_NUMPY_FLOAT_TYPE, order='F')