Merge pull request #583 from acq4/function-updates

Function updates
This commit is contained in:
Luke Campagnola 2017-10-04 10:30:42 -07:00 committed by GitHub
commit 4880b5849b
2 changed files with 30 additions and 25 deletions

View File

@ -15,7 +15,7 @@ from .python2_3 import asUnicode, basestring
from .Qt import QtGui, QtCore, USE_PYSIDE from .Qt import QtGui, QtCore, USE_PYSIDE
from . import getConfigOption, setConfigOptions from . import getConfigOption, setConfigOptions
from . import debug from . import debug
from .metaarray import MetaArray
Colors = { Colors = {
@ -417,7 +417,21 @@ def eq(a, b):
""" """
if a is b: if a is b:
return True return True
# Avoid comparing large arrays against scalars; this is expensive and we know it should return False.
aIsArr = isinstance(a, (np.ndarray, MetaArray))
bIsArr = isinstance(b, (np.ndarray, MetaArray))
if (aIsArr or bIsArr) and type(a) != type(b):
return False
# If both inputs are arrays, we can speeed up comparison if shapes / dtypes don't match
# NOTE: arrays of dissimilar type should be considered unequal even if they are numerically
# equal because they may behave differently when computed on.
if aIsArr and bIsArr and (a.shape != b.shape or a.dtype != b.dtype):
return False
# Test for equivalence.
# If the test raises a recognized exception, then return Falase
try: try:
try: try:
# Sometimes running catch_warnings(module=np) generates AttributeError ??? # Sometimes running catch_warnings(module=np) generates AttributeError ???
@ -733,26 +747,17 @@ def subArray(data, offset, shape, stride):
the input in the example above to have shape (10, 7) would cause the the input in the example above to have shape (10, 7) would cause the
output to have shape (2, 3, 7). output to have shape (2, 3, 7).
""" """
#data = data.flatten() data = np.ascontiguousarray(data)[offset:]
data = data[offset:]
shape = tuple(shape) shape = tuple(shape)
stride = tuple(stride)
extraShape = data.shape[1:] extraShape = data.shape[1:]
#print data.shape, offset, shape, stride
for i in range(len(shape)): strides = list(data.strides[::-1])
mask = (slice(None),) * i + (slice(None, shape[i] * stride[i]),) itemsize = strides[-1]
newShape = shape[:i+1] for s in stride[1::-1]:
if i < len(shape)-1: strides.append(itemsize * s)
newShape += (stride[i],) strides = tuple(strides[::-1])
newShape += extraShape
#print i, mask, newShape
#print "start:\n", data.shape, data
data = data[mask]
#print "mask:\n", data.shape, data
data = data.reshape(newShape)
#print "reshape:\n", data.shape, data
return data return np.ndarray(buffer=data, shape=shape+extraShape, strides=strides, dtype=data.dtype)
def transformToArray(tr): def transformToArray(tr):

View File

@ -344,14 +344,14 @@ def test_eq():
a2 = a1 + 1 a2 = a1 + 1
a3 = a2.astype('int') a3 = a2.astype('int')
a4 = np.empty((0, 20)) a4 = np.empty((0, 20))
assert not eq(a1, a2) assert not eq(a1, a2) # same shape/dtype, different values
assert not eq(a1, a3) assert not eq(a1, a3) # same shape, different dtype and values
assert not eq(a1, a4) assert not eq(a1, a4) # different shape (note: np.all gives True if one array has size 0)
assert eq(a2, a3) assert not eq(a2, a3) # same values, but different dtype
assert not eq(a2, a4) assert not eq(a2, a4) # different shape
assert not eq(a3, a4) assert not eq(a3, a4) # different shape and dtype
assert eq(a4, a4.copy()) assert eq(a4, a4.copy())
assert not eq(a4, a4.T) assert not eq(a4, a4.T)