Fix eq() bug where calling catch_warnings raised an AttributeError, which would cause eq() to return False

Add unit test coverage
This commit is contained in:
Luke Campagnola 2017-08-02 15:02:38 -07:00
parent b4e722f07b
commit 9094261c54
2 changed files with 90 additions and 10 deletions

View File

@ -200,7 +200,7 @@ def mkColor(*args):
try:
return Colors[c]
except KeyError:
raise Exception('No color named "%s"' % c)
raise ValueError('No color named "%s"' % c)
if len(c) == 3:
r = int(c[0]*2, 16)
g = int(c[1]*2, 16)
@ -235,18 +235,18 @@ def mkColor(*args):
elif len(args[0]) == 2:
return intColor(*args[0])
else:
raise Exception(err)
raise TypeError(err)
elif type(args[0]) == int:
return intColor(args[0])
else:
raise Exception(err)
raise TypeError(err)
elif len(args) == 3:
(r, g, b) = args
a = 255
elif len(args) == 4:
(r, g, b, a) = args
else:
raise Exception(err)
raise TypeError(err)
args = [r,g,b,a]
args = [0 if np.isnan(a) or np.isinf(a) else a for a in args]
@ -404,22 +404,39 @@ def makeArrowPath(headLen=20, tipAngle=20, tailLen=20, tailWidth=3, baseAngle=0)
def eq(a, b):
"""The great missing equivalence function: Guaranteed evaluation to a single bool value."""
"""The great missing equivalence function: Guaranteed evaluation to a single bool value.
This function has some important differences from the == operator:
1. Returns True if a IS b, even if a==b still evaluates to False, such as with nan values.
2. Tests for equivalence using ==, but silently ignores some common exceptions that can occur
(AtrtibuteError, ValueError).
3. When comparing arrays, returns False if the array shapes are not the same.
4. When comparing arrays of the same shape, returns True only if all elements are equal (whereas
the == operator would return a boolean array).
"""
if a is b:
return True
try:
with warnings.catch_warnings(module=np): # ignore numpy futurewarning (numpy v. 1.10)
e = a==b
except ValueError:
return False
except AttributeError:
try:
# Sometimes running catch_warnings(module=np) generates AttributeError ???
catcher = warnings.catch_warnings(module=np) # ignore numpy futurewarning (numpy v. 1.10)
catcher.__enter__()
except Exception:
catcher = None
e = a==b
except (ValueError, AttributeError):
return False
except:
print('failed to evaluate equivalence for:')
print(" a:", str(type(a)), str(a))
print(" b:", str(type(b)), str(b))
raise
finally:
if catcher is not None:
catcher.__exit__(None, None, None)
t = type(e)
if t is bool:
return e

View File

@ -1,5 +1,6 @@
import pyqtgraph as pg
import numpy as np
import sys
from numpy.testing import assert_array_almost_equal, assert_almost_equal
import pytest
@ -293,6 +294,68 @@ def test_makeARGB():
with AssertExc(): # 3d levels not allowed
pg.makeARGB(np.zeros((2,2,3), dtype='float'), levels=np.zeros([3, 2, 2]))
def test_eq():
eq = pg.functions.eq
zeros = [0, 0.0, np.float(0), np.int(0)]
if sys.version[0] < '3':
zeros.append(long(0))
for i,x in enumerate(zeros):
for y in zeros[i:]:
assert eq(x, y)
assert eq(y, x)
assert eq(np.nan, np.nan)
# test
class NotEq(object):
def __eq__(self, x):
return False
noteq = NotEq()
assert eq(noteq, noteq) # passes because they are the same object
assert not eq(noteq, NotEq())
# Should be able to test for equivalence even if the test raises certain
# exceptions
class NoEq(object):
def __init__(self, err):
self.err = err
def __eq__(self, x):
raise self.err
noeq1 = NoEq(AttributeError())
noeq2 = NoEq(ValueError())
noeq3 = NoEq(Exception())
assert eq(noeq1, noeq1)
assert not eq(noeq1, noeq2)
assert not eq(noeq2, noeq1)
with pytest.raises(Exception):
eq(noeq3, noeq2)
# test array equivalence
# note that numpy has a weird behavior here--np.all() always returns True
# if one of the arrays has size=0; eq() will only return True if both arrays
# have the same shape.
a1 = np.zeros((10, 20)).astype('float')
a2 = a1 + 1
a3 = a2.astype('int')
a4 = np.empty((0, 20))
assert not eq(a1, a2)
assert not eq(a1, a3)
assert not eq(a1, a4)
assert eq(a2, a3)
assert not eq(a2, a4)
assert not eq(a3, a4)
assert eq(a4, a4.copy())
assert not eq(a4, a4.T)
if __name__ == '__main__':
test_interpolateArray()