Merge pull request #583 from acq4/function-updates

Function updates
This commit is contained in:
Luke Campagnola 2017-10-04 10:30:42 -07:00 committed by GitHub
commit 4880b5849b
2 changed files with 30 additions and 25 deletions

View File

@ -15,7 +15,7 @@ from .python2_3 import asUnicode, basestring
from .Qt import QtGui, QtCore, USE_PYSIDE
from . import getConfigOption, setConfigOptions
from . import debug
from .metaarray import MetaArray
Colors = {
@ -417,7 +417,21 @@ def eq(a, b):
"""
if a is b:
return True
# Avoid comparing large arrays against scalars; this is expensive and we know it should return False.
aIsArr = isinstance(a, (np.ndarray, MetaArray))
bIsArr = isinstance(b, (np.ndarray, MetaArray))
if (aIsArr or bIsArr) and type(a) != type(b):
return False
# If both inputs are arrays, we can speeed up comparison if shapes / dtypes don't match
# NOTE: arrays of dissimilar type should be considered unequal even if they are numerically
# equal because they may behave differently when computed on.
if aIsArr and bIsArr and (a.shape != b.shape or a.dtype != b.dtype):
return False
# Test for equivalence.
# If the test raises a recognized exception, then return Falase
try:
try:
# Sometimes running catch_warnings(module=np) generates AttributeError ???
@ -733,26 +747,17 @@ def subArray(data, offset, shape, stride):
the input in the example above to have shape (10, 7) would cause the
output to have shape (2, 3, 7).
"""
#data = data.flatten()
data = data[offset:]
data = np.ascontiguousarray(data)[offset:]
shape = tuple(shape)
stride = tuple(stride)
extraShape = data.shape[1:]
#print data.shape, offset, shape, stride
for i in range(len(shape)):
mask = (slice(None),) * i + (slice(None, shape[i] * stride[i]),)
newShape = shape[:i+1]
if i < len(shape)-1:
newShape += (stride[i],)
newShape += extraShape
#print i, mask, newShape
#print "start:\n", data.shape, data
data = data[mask]
#print "mask:\n", data.shape, data
data = data.reshape(newShape)
#print "reshape:\n", data.shape, data
strides = list(data.strides[::-1])
itemsize = strides[-1]
for s in stride[1::-1]:
strides.append(itemsize * s)
strides = tuple(strides[::-1])
return data
return np.ndarray(buffer=data, shape=shape+extraShape, strides=strides, dtype=data.dtype)
def transformToArray(tr):

View File

@ -344,14 +344,14 @@ def test_eq():
a2 = a1 + 1
a3 = a2.astype('int')
a4 = np.empty((0, 20))
assert not eq(a1, a2)
assert not eq(a1, a3)
assert not eq(a1, a4)
assert not eq(a1, a2) # same shape/dtype, different values
assert not eq(a1, a3) # same shape, different dtype and values
assert not eq(a1, a4) # different shape (note: np.all gives True if one array has size 0)
assert eq(a2, a3)
assert not eq(a2, a4)
assert not eq(a2, a3) # same values, but different dtype
assert not eq(a2, a4) # different shape
assert not eq(a3, a4)
assert not eq(a3, a4) # different shape and dtype
assert eq(a4, a4.copy())
assert not eq(a4, a4.T)