From 9094261c542684146a1813470012cd1fbfcabe12 Mon Sep 17 00:00:00 2001 From: Luke Campagnola Date: Wed, 2 Aug 2017 15:02:38 -0700 Subject: [PATCH] Fix eq() bug where calling catch_warnings raised an AttributeError, which would cause eq() to return False Add unit test coverage --- pyqtgraph/functions.py | 37 +++++++++++++----- pyqtgraph/tests/test_functions.py | 63 +++++++++++++++++++++++++++++++ 2 files changed, 90 insertions(+), 10 deletions(-) diff --git a/pyqtgraph/functions.py b/pyqtgraph/functions.py index bdbf6d87..1aed6ace 100644 --- a/pyqtgraph/functions.py +++ b/pyqtgraph/functions.py @@ -200,7 +200,7 @@ def mkColor(*args): try: return Colors[c] except KeyError: - raise Exception('No color named "%s"' % c) + raise ValueError('No color named "%s"' % c) if len(c) == 3: r = int(c[0]*2, 16) g = int(c[1]*2, 16) @@ -235,18 +235,18 @@ def mkColor(*args): elif len(args[0]) == 2: return intColor(*args[0]) else: - raise Exception(err) + raise TypeError(err) elif type(args[0]) == int: return intColor(args[0]) else: - raise Exception(err) + raise TypeError(err) elif len(args) == 3: (r, g, b) = args a = 255 elif len(args) == 4: (r, g, b, a) = args else: - raise Exception(err) + raise TypeError(err) args = [r,g,b,a] args = [0 if np.isnan(a) or np.isinf(a) else a for a in args] @@ -404,22 +404,39 @@ def makeArrowPath(headLen=20, tipAngle=20, tailLen=20, tailWidth=3, baseAngle=0) def eq(a, b): - """The great missing equivalence function: Guaranteed evaluation to a single bool value.""" + """The great missing equivalence function: Guaranteed evaluation to a single bool value. + + This function has some important differences from the == operator: + + 1. Returns True if a IS b, even if a==b still evaluates to False, such as with nan values. + 2. Tests for equivalence using ==, but silently ignores some common exceptions that can occur + (AtrtibuteError, ValueError). + 3. When comparing arrays, returns False if the array shapes are not the same. + 4. When comparing arrays of the same shape, returns True only if all elements are equal (whereas + the == operator would return a boolean array). + """ if a is b: return True try: - with warnings.catch_warnings(module=np): # ignore numpy futurewarning (numpy v. 1.10) - e = a==b - except ValueError: - return False - except AttributeError: + try: + # Sometimes running catch_warnings(module=np) generates AttributeError ??? + catcher = warnings.catch_warnings(module=np) # ignore numpy futurewarning (numpy v. 1.10) + catcher.__enter__() + except Exception: + catcher = None + e = a==b + except (ValueError, AttributeError): return False except: print('failed to evaluate equivalence for:') print(" a:", str(type(a)), str(a)) print(" b:", str(type(b)), str(b)) raise + finally: + if catcher is not None: + catcher.__exit__(None, None, None) + t = type(e) if t is bool: return e diff --git a/pyqtgraph/tests/test_functions.py b/pyqtgraph/tests/test_functions.py index 7ad3bf91..eff56635 100644 --- a/pyqtgraph/tests/test_functions.py +++ b/pyqtgraph/tests/test_functions.py @@ -1,5 +1,6 @@ import pyqtgraph as pg import numpy as np +import sys from numpy.testing import assert_array_almost_equal, assert_almost_equal import pytest @@ -293,6 +294,68 @@ def test_makeARGB(): with AssertExc(): # 3d levels not allowed pg.makeARGB(np.zeros((2,2,3), dtype='float'), levels=np.zeros([3, 2, 2])) + +def test_eq(): + eq = pg.functions.eq + + zeros = [0, 0.0, np.float(0), np.int(0)] + if sys.version[0] < '3': + zeros.append(long(0)) + for i,x in enumerate(zeros): + for y in zeros[i:]: + assert eq(x, y) + assert eq(y, x) + + assert eq(np.nan, np.nan) + + # test + class NotEq(object): + def __eq__(self, x): + return False + + noteq = NotEq() + assert eq(noteq, noteq) # passes because they are the same object + assert not eq(noteq, NotEq()) + + + # Should be able to test for equivalence even if the test raises certain + # exceptions + class NoEq(object): + def __init__(self, err): + self.err = err + def __eq__(self, x): + raise self.err + + noeq1 = NoEq(AttributeError()) + noeq2 = NoEq(ValueError()) + noeq3 = NoEq(Exception()) + + assert eq(noeq1, noeq1) + assert not eq(noeq1, noeq2) + assert not eq(noeq2, noeq1) + with pytest.raises(Exception): + eq(noeq3, noeq2) + + # test array equivalence + # note that numpy has a weird behavior here--np.all() always returns True + # if one of the arrays has size=0; eq() will only return True if both arrays + # have the same shape. + a1 = np.zeros((10, 20)).astype('float') + a2 = a1 + 1 + a3 = a2.astype('int') + a4 = np.empty((0, 20)) + assert not eq(a1, a2) + assert not eq(a1, a3) + assert not eq(a1, a4) + + assert eq(a2, a3) + assert not eq(a2, a4) + + assert not eq(a3, a4) + + assert eq(a4, a4.copy()) + assert not eq(a4, a4.T) + if __name__ == '__main__': test_interpolateArray() \ No newline at end of file