eq(): better performance by avoiding array comparison when shapes do not match
This commit is contained in:
parent
f627a6a447
commit
aad1c737c3
@ -418,6 +418,20 @@ def eq(a, b):
|
|||||||
if a is b:
|
if a is b:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
# Avoid comparing large arrays against scalars; this is expensive and we know it should return False.
|
||||||
|
aIsArr = isinstance(a, (np.ndarray, MetaArray))
|
||||||
|
bIsArr = isinstance(b, (np.ndarray, MetaArray))
|
||||||
|
if (aIsArr or bIsArr) and type(a) != type(b):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# If both inputs are arrays, we can speeed up comparison if shapes / dtypes don't match
|
||||||
|
# NOTE: arrays of dissimilar type should be considered unequal even if they are numerically
|
||||||
|
# equal because they may behave differently when computed on.
|
||||||
|
if aIsArr and bIsArr and (a.shape != b.shape or a.dtype != b.dtype):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Test for equivalence.
|
||||||
|
# If the test raises a recognized exception, then return Falase
|
||||||
try:
|
try:
|
||||||
try:
|
try:
|
||||||
# Sometimes running catch_warnings(module=np) generates AttributeError ???
|
# Sometimes running catch_warnings(module=np) generates AttributeError ???
|
||||||
|
Loading…
Reference in New Issue
Block a user