Fix interpolateArray for order=0

This commit is contained in:
Luke Campagnola 2016-10-12 10:26:54 -07:00
parent 92fc9dbe2f
commit e35f59fcb7
2 changed files with 40 additions and 33 deletions

View File

@ -602,6 +602,17 @@ def interpolateArray(data, x, default=0.0, order=1):
if md > nd: if md > nd:
raise TypeError("x.shape[-1] must be less than or equal to data.ndim") raise TypeError("x.shape[-1] must be less than or equal to data.ndim")
totalMask = np.ones(x.shape[:-1], dtype=bool) # keep track of out-of-bound indexes
if order == 0:
xinds = np.round(x).astype(int) # NOTE: for 0.5 this rounds to the nearest *even* number
for ax in range(md):
mask = (xinds[...,ax] >= 0) & (xinds[...,ax] <= data.shape[ax]-1)
xinds[...,ax][~mask] = 0
# keep track of points that need to be set to default
totalMask &= mask
result = data[tuple([xinds[...,i] for i in range(xinds.shape[-1])])]
elif order == 1:
# First we generate arrays of indexes that are needed to # First we generate arrays of indexes that are needed to
# extract the data surrounding each point # extract the data surrounding each point
fields = np.mgrid[(slice(0,order+1),) * md] fields = np.mgrid[(slice(0,order+1),) * md]
@ -609,7 +620,6 @@ def interpolateArray(data, x, default=0.0, order=1):
xmax = xmin + 1 xmax = xmin + 1
indexes = np.concatenate([xmin[np.newaxis, ...], xmax[np.newaxis, ...]]) indexes = np.concatenate([xmin[np.newaxis, ...], xmax[np.newaxis, ...]])
fieldInds = [] fieldInds = []
totalMask = np.ones(x.shape[:-1], dtype=bool) # keep track of out-of-bound indexes
for ax in range(md): for ax in range(md):
mask = (xmin[...,ax] >= 0) & (x[...,ax] <= data.shape[ax]-1) mask = (xmin[...,ax] >= 0) & (x[...,ax] <= data.shape[ax]-1)
# keep track of points that need to be set to default # keep track of points that need to be set to default
@ -630,9 +640,6 @@ def interpolateArray(data, x, default=0.0, order=1):
prof() prof()
## Interpolate ## Interpolate
if order == 0:
result = fieldData[0,0]
else:
s = np.empty((md,) + fieldData.shape, dtype=float) s = np.empty((md,) + fieldData.shape, dtype=float)
dx = x - xmin dx = x - xmin
# reshape fields for arithmetic against dx # reshape fields for arithmetic against dx

View File

@ -58,8 +58,8 @@ def check_interpolateArray(order):
x = np.array([[ 0.3, 0.6], x = np.array([[ 0.3, 0.6],
[ 1. , 1. ], [ 1. , 1. ],
[ 0.5, 1. ], [ 0.501, 1. ], # NOTE: testing at exactly 0.5 can yield different results from map_coordinates
[ 0.5, 2.5], [ 0.501, 2.501], # due to differences in rounding
[ 10. , 10. ]]) [ 10. , 10. ]])
result = interpolateArray(data, x) result = interpolateArray(data, x)
@ -82,8 +82,8 @@ def check_interpolateArray(order):
# test mapping 2D array of locations # test mapping 2D array of locations
x = np.array([[[0.5, 0.5], [0.5, 1.0], [0.5, 1.5]], x = np.array([[[0.501, 0.501], [0.501, 1.0], [0.501, 1.501]],
[[1.5, 0.5], [1.5, 1.0], [1.5, 1.5]]]) [[1.501, 0.501], [1.501, 1.0], [1.501, 1.501]]])
r1 = interpolateArray(data, x) r1 = interpolateArray(data, x)
r2 = scipy.ndimage.map_coordinates(data, x.transpose(2,0,1), order=order) r2 = scipy.ndimage.map_coordinates(data, x.transpose(2,0,1), order=order)