diff --git a/pyqtgraph/functions.py b/pyqtgraph/functions.py index 1aed6ace..7ad603f7 100644 --- a/pyqtgraph/functions.py +++ b/pyqtgraph/functions.py @@ -15,7 +15,7 @@ from .python2_3 import asUnicode, basestring from .Qt import QtGui, QtCore, USE_PYSIDE from . import getConfigOption, setConfigOptions from . import debug - +from .metaarray import MetaArray Colors = { @@ -417,7 +417,21 @@ def eq(a, b): """ if a is b: return True - + + # Avoid comparing large arrays against scalars; this is expensive and we know it should return False. + aIsArr = isinstance(a, (np.ndarray, MetaArray)) + bIsArr = isinstance(b, (np.ndarray, MetaArray)) + if (aIsArr or bIsArr) and type(a) != type(b): + return False + + # If both inputs are arrays, we can speeed up comparison if shapes / dtypes don't match + # NOTE: arrays of dissimilar type should be considered unequal even if they are numerically + # equal because they may behave differently when computed on. + if aIsArr and bIsArr and (a.shape != b.shape or a.dtype != b.dtype): + return False + + # Test for equivalence. + # If the test raises a recognized exception, then return Falase try: try: # Sometimes running catch_warnings(module=np) generates AttributeError ??? @@ -733,26 +747,17 @@ def subArray(data, offset, shape, stride): the input in the example above to have shape (10, 7) would cause the output to have shape (2, 3, 7). """ - #data = data.flatten() - data = data[offset:] + data = np.ascontiguousarray(data)[offset:] shape = tuple(shape) - stride = tuple(stride) extraShape = data.shape[1:] - #print data.shape, offset, shape, stride - for i in range(len(shape)): - mask = (slice(None),) * i + (slice(None, shape[i] * stride[i]),) - newShape = shape[:i+1] - if i < len(shape)-1: - newShape += (stride[i],) - newShape += extraShape - #print i, mask, newShape - #print "start:\n", data.shape, data - data = data[mask] - #print "mask:\n", data.shape, data - data = data.reshape(newShape) - #print "reshape:\n", data.shape, data + + strides = list(data.strides[::-1]) + itemsize = strides[-1] + for s in stride[1::-1]: + strides.append(itemsize * s) + strides = tuple(strides[::-1]) - return data + return np.ndarray(buffer=data, shape=shape+extraShape, strides=strides, dtype=data.dtype) def transformToArray(tr): diff --git a/pyqtgraph/tests/test_functions.py b/pyqtgraph/tests/test_functions.py index eff56635..68f3dc24 100644 --- a/pyqtgraph/tests/test_functions.py +++ b/pyqtgraph/tests/test_functions.py @@ -344,14 +344,14 @@ def test_eq(): a2 = a1 + 1 a3 = a2.astype('int') a4 = np.empty((0, 20)) - assert not eq(a1, a2) - assert not eq(a1, a3) - assert not eq(a1, a4) + assert not eq(a1, a2) # same shape/dtype, different values + assert not eq(a1, a3) # same shape, different dtype and values + assert not eq(a1, a4) # different shape (note: np.all gives True if one array has size 0) - assert eq(a2, a3) - assert not eq(a2, a4) + assert not eq(a2, a3) # same values, but different dtype + assert not eq(a2, a4) # different shape - assert not eq(a3, a4) + assert not eq(a3, a4) # different shape and dtype assert eq(a4, a4.copy()) assert not eq(a4, a4.T)