Bugfix for overwriting array initialization functions, such that it can also handle complex numbers and different ordering

This commit is contained in:
Anne de Jong 2024-06-28 09:15:55 +02:00
parent 7cd3dcffa8
commit f6ea790071

View File

@ -16,13 +16,28 @@ else:
LASP_NUMPY_COMPLEX_TYPE = np.float64 LASP_NUMPY_COMPLEX_TYPE = np.float64
def zeros(shape): def zeros(shape, dtype=float, order='F'):
return np.zeros(shape, dtype=LASP_NUMPY_FLOAT_TYPE, order='F') if dtype == float:
return np.zeros(shape, dtype=LASP_NUMPY_FLOAT_TYPE, order=order)
elif dtype == complex:
return np.zeros(shape, dtype=LASP_NUMPY_COMPLEX_TYPE, order=order)
else:
raise RuntimeError(f"Unknown dtype: {dtype}")
def ones(shape): def ones(shape, dtype=float, order='F'):
return np.ones(shape, dtype=LASP_NUMPY_FLOAT_TYPE, order='F') if dtype == float:
return np.ones(shape, dtype=LASP_NUMPY_FLOAT_TYPE, order=order)
elif dtype == complex:
return np.ones(shape, dtype=LASP_NUMPY_COMPLEX_TYPE, order=order)
else:
raise RuntimeError(f"Unknown dtype: {dtype}")
def empty(shape, dtype=float, order='F'):
if dtype == float:
return np.empty(shape, dtype=LASP_NUMPY_FLOAT_TYPE, order=order)
elif dtype == complex:
return np.empty(shape, dtype=LASP_NUMPY_COMPLEX_TYPE, order=order)
else:
raise RuntimeError(f"Unknown dtype: {dtype}")
def empty(shape):
return np.empty(shape, dtype=LASP_NUMPY_FLOAT_TYPE, order='F')