From e35f59fcb77d50ab493d7c9db847c24c4fbd6006 Mon Sep 17 00:00:00 2001 From: Luke Campagnola Date: Wed, 12 Oct 2016 10:26:54 -0700 Subject: [PATCH] Fix interpolateArray for order=0 --- pyqtgraph/functions.py | 65 +++++++++++++++++-------------- pyqtgraph/tests/test_functions.py | 8 ++-- 2 files changed, 40 insertions(+), 33 deletions(-) diff --git a/pyqtgraph/functions.py b/pyqtgraph/functions.py index 32d9f2bf..8593241e 100644 --- a/pyqtgraph/functions.py +++ b/pyqtgraph/functions.py @@ -602,37 +602,44 @@ def interpolateArray(data, x, default=0.0, order=1): if md > nd: raise TypeError("x.shape[-1] must be less than or equal to data.ndim") - # First we generate arrays of indexes that are needed to - # extract the data surrounding each point - fields = np.mgrid[(slice(0,order+1),) * md] - xmin = np.floor(x).astype(int) - xmax = xmin + 1 - indexes = np.concatenate([xmin[np.newaxis, ...], xmax[np.newaxis, ...]]) - fieldInds = [] totalMask = np.ones(x.shape[:-1], dtype=bool) # keep track of out-of-bound indexes - for ax in range(md): - mask = (xmin[...,ax] >= 0) & (x[...,ax] <= data.shape[ax]-1) - # keep track of points that need to be set to default - totalMask &= mask - - # ..and keep track of indexes that are out of bounds - # (note that when x[...,ax] == data.shape[ax], then xmax[...,ax] will be out - # of bounds, but the interpolation will work anyway) - mask &= (xmax[...,ax] < data.shape[ax]) - axisIndex = indexes[...,ax][fields[ax]] - axisIndex[axisIndex < 0] = 0 - axisIndex[axisIndex >= data.shape[ax]] = 0 - fieldInds.append(axisIndex) - prof() - - # Get data values surrounding each requested point - fieldData = data[tuple(fieldInds)] - prof() - - ## Interpolate if order == 0: - result = fieldData[0,0] - else: + xinds = np.round(x).astype(int) # NOTE: for 0.5 this rounds to the nearest *even* number + for ax in range(md): + mask = (xinds[...,ax] >= 0) & (xinds[...,ax] <= data.shape[ax]-1) + xinds[...,ax][~mask] = 0 + # keep track of points that need to be set to default + totalMask &= mask + result = data[tuple([xinds[...,i] for i in range(xinds.shape[-1])])] + + elif order == 1: + # First we generate arrays of indexes that are needed to + # extract the data surrounding each point + fields = np.mgrid[(slice(0,order+1),) * md] + xmin = np.floor(x).astype(int) + xmax = xmin + 1 + indexes = np.concatenate([xmin[np.newaxis, ...], xmax[np.newaxis, ...]]) + fieldInds = [] + for ax in range(md): + mask = (xmin[...,ax] >= 0) & (x[...,ax] <= data.shape[ax]-1) + # keep track of points that need to be set to default + totalMask &= mask + + # ..and keep track of indexes that are out of bounds + # (note that when x[...,ax] == data.shape[ax], then xmax[...,ax] will be out + # of bounds, but the interpolation will work anyway) + mask &= (xmax[...,ax] < data.shape[ax]) + axisIndex = indexes[...,ax][fields[ax]] + axisIndex[axisIndex < 0] = 0 + axisIndex[axisIndex >= data.shape[ax]] = 0 + fieldInds.append(axisIndex) + prof() + + # Get data values surrounding each requested point + fieldData = data[tuple(fieldInds)] + prof() + + ## Interpolate s = np.empty((md,) + fieldData.shape, dtype=float) dx = x - xmin # reshape fields for arithmetic against dx diff --git a/pyqtgraph/tests/test_functions.py b/pyqtgraph/tests/test_functions.py index 4c9cabfe..7ad3bf91 100644 --- a/pyqtgraph/tests/test_functions.py +++ b/pyqtgraph/tests/test_functions.py @@ -58,8 +58,8 @@ def check_interpolateArray(order): x = np.array([[ 0.3, 0.6], [ 1. , 1. ], - [ 0.5, 1. ], - [ 0.5, 2.5], + [ 0.501, 1. ], # NOTE: testing at exactly 0.5 can yield different results from map_coordinates + [ 0.501, 2.501], # due to differences in rounding [ 10. , 10. ]]) result = interpolateArray(data, x) @@ -82,8 +82,8 @@ def check_interpolateArray(order): # test mapping 2D array of locations - x = np.array([[[0.5, 0.5], [0.5, 1.0], [0.5, 1.5]], - [[1.5, 0.5], [1.5, 1.0], [1.5, 1.5]]]) + x = np.array([[[0.501, 0.501], [0.501, 1.0], [0.501, 1.501]], + [[1.501, 0.501], [1.501, 1.0], [1.501, 1.501]]]) r1 = interpolateArray(data, x) r2 = scipy.ndimage.map_coordinates(data, x.transpose(2,0,1), order=order)