diff --git a/pyqtgraph/functions_numba.py b/pyqtgraph/functions_numba.py index d9e2f232..3df568fb 100644 --- a/pyqtgraph/functions_numba.py +++ b/pyqtgraph/functions_numba.py @@ -20,3 +20,24 @@ def rescaleData(data, scale, offset, dtype, clip): rescale_functions[key] = func func(data, scale, offset, clip[0], clip[1], out=data_out) return data_out + +@numba.jit(nopython=True) +def _rescale_and_lookup1d_function(data, scale, offset, lut, out): + vmin, vmax = 0, lut.shape[0] - 1 + for r in range(data.shape[0]): + for c in range(data.shape[1]): + val = (data[r, c] - offset) * scale + val = min(max(val, vmin), vmax) + out[r, c] = lut[int(val)] + +def rescale_and_lookup1d(data, scale, offset, lut): + # data should be floating point and 2d + # lut is 1d + data_out = np.empty_like(data, dtype=lut.dtype) + _rescale_and_lookup1d_function(data, float(scale), float(offset), lut, data_out) + return data_out + +@numba.jit(nopython=True) +def numba_take(lut, data): + # numba supports only the 1st two arguments of np.take + return np.take(lut, data) diff --git a/pyqtgraph/graphicsItems/ImageItem.py b/pyqtgraph/graphicsItems/ImageItem.py index 71e1de3a..01c5a605 100644 --- a/pyqtgraph/graphicsItems/ImageItem.py +++ b/pyqtgraph/graphicsItems/ImageItem.py @@ -603,16 +603,25 @@ class ImageItem(GraphicsObject): maxVal = xp.nextafter(maxVal, 2*maxVal) rng = maxVal - minVal rng = 1 if rng == 0 else rng - image = fn.rescaleData(image, scale/rng, offset=minVal, dtype=dtype, clip=(0, num_colors-1)) - levels = None + fn_numba = fn.getNumbaFunctions() + if xp == numpy and image.flags.c_contiguous and dtype == xp.uint16 and fn_numba is not None: + lut, augmented_alpha = self._convert_2dlut_to_1dlut(lut) + image = fn_numba.rescale_and_lookup1d(image, scale/rng, minVal, lut) + if image.dtype == xp.uint32: + image = image[..., xp.newaxis].view(xp.uint8) + return image, None, None, augmented_alpha + else: + image = fn.rescaleData(image, scale/rng, offset=minVal, dtype=dtype, clip=(0, num_colors-1)) - if image.dtype == xp.uint16 and image.ndim == 2: - image, augmented_alpha = self._apply_lut_for_uint16_mono(image, lut) - lut = None + levels = None - # image is now of type uint8 - return image, levels, lut, augmented_alpha + if image.dtype == xp.uint16 and image.ndim == 2: + image, augmented_alpha = self._apply_lut_for_uint16_mono(image, lut) + lut = None + + # image is now of type uint8 + return image, levels, lut, augmented_alpha def _try_combine_lut(self, image, levels, lut): augmented_alpha = False @@ -722,22 +731,40 @@ class ImageItem(GraphicsObject): # if we are contiguous, we can take a faster codepath where we # ensure that the lut is 1d - if lut.ndim == 2: - if lut.shape[1] == 3: # rgb - # convert rgb lut to rgba so that it is 32-bits - lut = xp.column_stack([lut, xp.full(lut.shape[0], 255, dtype=xp.uint8)]) - augmented_alpha = True - if lut.shape[1] == 4: # rgba - lut = lut.view(xp.uint32) + lut, augmented_alpha = self._convert_2dlut_to_1dlut(lut) + + fn_numba = fn.getNumbaFunctions() + if xp == numpy and fn_numba is not None: + image = fn_numba.numba_take(lut, image) + else: + image = lut[image] - image = lut.ravel()[image] - lut = None - # now both levels and lut are None if image.dtype == xp.uint32: - image = image.view(xp.uint8).reshape(image.shape + (4,)) + image = image[..., xp.newaxis].view(xp.uint8) return image, augmented_alpha + def _convert_2dlut_to_1dlut(self, lut): + # converts: + # - uint8 (N, 1) to uint8 (N,) + # - uint8 (N, 3) or (N, 4) to uint32 (N,) + # this allows faster lookup as 1d lookup is faster + xp = self._xp + augmented_alpha = False + + if lut.ndim == 1: + return lut, augmented_alpha + + if lut.shape[1] == 3: # rgb + # convert rgb lut to rgba so that it is 32-bits + lut = xp.column_stack([lut, xp.full(lut.shape[0], 255, dtype=xp.uint8)]) + augmented_alpha = True + if lut.shape[1] == 4: # rgba + lut = lut.view(xp.uint32) + lut = lut.ravel() + + return lut, augmented_alpha + def _try_make_qimage(self, image, levels, lut, augmented_alpha): xp = self._xp