reload tests pass in python 3

This commit is contained in:
Luke Campagnola 2018-04-26 13:22:47 -07:00
parent bc2c3232e2
commit 60a48ed2e4
3 changed files with 79 additions and 34 deletions

View File

@ -2444,7 +2444,7 @@ def disconnect(signal, slot):
try: try:
signal.disconnect(slot) signal.disconnect(slot)
return True return True
except TypeError, RuntimeError: except (TypeError, RuntimeError):
slot = reload.getPreviousVersion(slot) slot = reload.getPreviousVersion(slot)
if slot is None: if slot is None:
return False return False

View File

@ -21,15 +21,18 @@ Does NOT:
print module.someObject print module.someObject
""" """
from __future__ import print_function
import inspect, os, sys, gc, traceback, types import inspect, os, sys, gc, traceback, types
try: try:
import __builtin__ as builtins from builtins import reload as orig_reload
except ImportError: except ImportError:
import builtins from importlib import reload as orig_reload
from .debug import printExc from .debug import printExc
py3 = sys.version_info >= (3,)
def reloadAll(prefix=None, debug=False): def reloadAll(prefix=None, debug=False):
"""Automatically reload everything whose __file__ begins with prefix. """Automatically reload everything whose __file__ begins with prefix.
- Skips reload if the file has not been updated (if .pyc is newer than .py) - Skips reload if the file has not been updated (if .pyc is newer than .py)
@ -80,7 +83,7 @@ def reload(module, debug=False, lists=False, dicts=False):
## make a copy of the old module dictionary, reload, then grab the new module dictionary for comparison ## make a copy of the old module dictionary, reload, then grab the new module dictionary for comparison
oldDict = module.__dict__.copy() oldDict = module.__dict__.copy()
builtins.reload(module) orig_reload(module)
newDict = module.__dict__ newDict = module.__dict__
## Allow modules access to the old dictionary after they reload ## Allow modules access to the old dictionary after they reload
@ -130,6 +133,9 @@ def updateFunction(old, new, debug, depth=0, visited=None):
old.__code__ = new.__code__ old.__code__ = new.__code__
old.__defaults__ = new.__defaults__ old.__defaults__ = new.__defaults__
if hasattr(old, '__kwdefaults'):
old.__kwdefaults__ = new.__kwdefaults__
old.__doc__ = new.__doc__
if visited is None: if visited is None:
visited = [] visited = []
@ -154,6 +160,8 @@ def updateFunction(old, new, debug, depth=0, visited=None):
## For classes: ## For classes:
## 1) find all instances of the old class and set instance.__class__ to the new class ## 1) find all instances of the old class and set instance.__class__ to the new class
## 2) update all old class methods to use code from the new class methods ## 2) update all old class methods to use code from the new class methods
def updateClass(old, new, debug): def updateClass(old, new, debug):
## Track town all instances and subclasses of old ## Track town all instances and subclasses of old
refs = gc.get_referrers(old) refs = gc.get_referrers(old)
@ -198,7 +206,8 @@ def updateClass(old, new, debug):
## but it fixes a few specific cases (pyqt signals, for one) ## but it fixes a few specific cases (pyqt signals, for one)
for attr in dir(old): for attr in dir(old):
oa = getattr(old, attr) oa = getattr(old, attr)
if inspect.ismethod(oa): if (py3 and inspect.isfunction(oa)) or inspect.ismethod(oa):
# note python2 has unbound methods, whereas python3 just uses plain functions
try: try:
na = getattr(new, attr) na = getattr(new, attr)
except AttributeError: except AttributeError:
@ -206,11 +215,14 @@ def updateClass(old, new, debug):
print(" Skipping method update for %s; new class does not have this attribute" % attr) print(" Skipping method update for %s; new class does not have this attribute" % attr)
continue continue
if hasattr(oa, 'im_func') and hasattr(na, 'im_func') and oa.__func__ is not na.__func__: ofunc = getattr(oa, '__func__', oa) # in py2 we have to get the __func__ from unbound method,
depth = updateFunction(oa.__func__, na.__func__, debug) nfunc = getattr(na, '__func__', na) # in py3 the attribute IS the function
if not hasattr(na.__func__, '__previous_reload_method__'):
na.__func__.__previous_reload_method__ = oa # important for managing signal connection if ofunc is not nfunc:
#oa.im_class = new ## bind old method to new class ## not allowed depth = updateFunction(ofunc, nfunc, debug)
if not hasattr(nfunc, '__previous_reload_method__'):
nfunc.__previous_reload_method__ = oa # important for managing signal connection
#oa.__class__ = new ## bind old method to new class ## not allowed
if debug: if debug:
extra = "" extra = ""
if depth > 0: if depth > 0:
@ -251,16 +263,22 @@ def getPreviousVersion(obj):
if isinstance(obj, type) or inspect.isfunction(obj): if isinstance(obj, type) or inspect.isfunction(obj):
return getattr(obj, '__previous_reload_version__', None) return getattr(obj, '__previous_reload_version__', None)
elif inspect.ismethod(obj): elif inspect.ismethod(obj):
if obj.im_self is None: if obj.__self__ is None:
# unbound method # unbound method
return getattr(obj.__func__, '__previous_reload_method__', None) return getattr(obj.__func__, '__previous_reload_method__', None)
else: else:
oldmethod = getattr(obj.__func__, '__previous_reload_method__', None) oldmethod = getattr(obj.__func__, '__previous_reload_method__', None)
if oldmethod is None: if oldmethod is None:
return None return None
self = obj.im_self self = obj.__self__
cls = oldmethod.im_class oldfunc = getattr(oldmethod, '__func__', oldmethod)
return types.MethodType(oldmethod.__func__, self, cls) if hasattr(oldmethod, 'im_class'):
# python 2
cls = oldmethod.im_class
return types.MethodType(oldfunc, self, cls)
else:
# python 3
return types.MethodType(oldfunc, self)

View File

@ -1,14 +1,23 @@
import tempfile, os, sys, shutil, atexit import tempfile, os, sys, shutil
import pyqtgraph as pg import pyqtgraph as pg
import pyqtgraph.reload import pyqtgraph.reload
pgpath = os.path.join(os.path.dirname(pg.__file__), '..') pgpath = os.path.join(os.path.dirname(pg.__file__), '..')
# make temporary directory to write module code # make temporary directory to write module code
path = tempfile.mkdtemp() path = None
sys.path.insert(0, path)
def cleanup(): def setup_module():
# make temporary directory to write module code
global path
path = tempfile.mkdtemp()
sys.path.insert(0, path)
def teardown_module():
global path
shutil.rmtree(path) shutil.rmtree(path)
atexit.register(cleanup) sys.path.remove(path)
code = """ code = """
@ -33,6 +42,8 @@ def remove_cache(mod):
def test_reload(): def test_reload():
py3 = sys.version_info >= (3,)
# write a module # write a module
mod = os.path.join(path, 'reload_test.py') mod = os.path.join(path, 'reload_test.py')
open(mod, 'w').write(code.format(path=path, msg="C.fn() Version1")) open(mod, 'w').write(code.format(path=path, msg="C.fn() Version1"))
@ -42,7 +53,10 @@ def test_reload():
c = reload_test.C() c = reload_test.C()
c.sig.connect(c.fn) c.sig.connect(c.fn)
v1 = (reload_test.C, reload_test.C.sig, reload_test.C.fn, reload_test.C.fn.__func__, c.sig, c.fn, c.fn.__func__) if py3:
v1 = (reload_test.C, reload_test.C.sig, reload_test.C.fn, c.sig, c.fn, c.fn.__func__)
else:
v1 = (reload_test.C, reload_test.C.sig, reload_test.C.fn, reload_test.C.fn.__func__, c.sig, c.fn, c.fn.__func__)
@ -50,25 +64,34 @@ def test_reload():
open(mod, 'w').write(code.format(path=path, msg="C.fn() Version2")) open(mod, 'w').write(code.format(path=path, msg="C.fn() Version2"))
remove_cache(mod) remove_cache(mod)
pg.reload.reloadAll(path, debug=True) pg.reload.reloadAll(path, debug=True)
v2 = (reload_test.C, reload_test.C.sig, reload_test.C.fn, reload_test.C.fn.__func__, c.sig, c.fn, c.fn.__func__) if py3:
v2 = (reload_test.C, reload_test.C.sig, reload_test.C.fn, c.sig, c.fn, c.fn.__func__)
else:
v2 = (reload_test.C, reload_test.C.sig, reload_test.C.fn, reload_test.C.fn.__func__, c.sig, c.fn, c.fn.__func__)
assert c.fn.im_class is v2[0] if not py3:
assert c.fn.im_class is v2[0]
oldcfn = pg.reload.getPreviousVersion(c.fn) oldcfn = pg.reload.getPreviousVersion(c.fn)
if oldcfn is None: if oldcfn is None:
# Function did not reload; are we using pytest's assertion rewriting? # Function did not reload; are we using pytest's assertion rewriting?
raise Exception("Function did not reload. (This can happen when using py.test" raise Exception("Function did not reload. (This can happen when using py.test"
" with assertion rewriting; use --assert=plain for this test.)") " with assertion rewriting; use --assert=plain for this test.)")
assert oldcfn.im_class is v1[0] if py3:
assert oldcfn.im_func is v1[2].im_func assert oldcfn.__func__ is v1[2]
assert oldcfn.im_self is c else:
assert oldcfn.im_class is v1[0]
assert oldcfn.__func__ is v1[2].__func__
assert oldcfn.__self__ is c
# write again and reload # write again and reload
open(mod, 'w').write(code.format(path=path, msg="C.fn() Version2")) open(mod, 'w').write(code.format(path=path, msg="C.fn() Version2"))
remove_cache(mod) remove_cache(mod)
pg.reload.reloadAll(path, debug=True) pg.reload.reloadAll(path, debug=True)
v3 = (reload_test.C, reload_test.C.sig, reload_test.C.fn, reload_test.C.fn.__func__, c.sig, c.fn, c.fn.__func__) if py3:
v3 = (reload_test.C, reload_test.C.sig, reload_test.C.fn, c.sig, c.fn, c.fn.__func__)
else:
v3 = (reload_test.C, reload_test.C.sig, reload_test.C.fn, reload_test.C.fn.__func__, c.sig, c.fn, c.fn.__func__)
#for i in range(len(old)): #for i in range(len(old)):
#print id(old[i]), id(new1[i]), id(new2[i]), old[i], new1[i] #print id(old[i]), id(new1[i]), id(new2[i]), old[i], new1[i]
@ -76,11 +99,15 @@ def test_reload():
cfn1 = pg.reload.getPreviousVersion(c.fn) cfn1 = pg.reload.getPreviousVersion(c.fn)
cfn2 = pg.reload.getPreviousVersion(cfn1) cfn2 = pg.reload.getPreviousVersion(cfn1)
assert cfn1.im_class is v2[0] if py3:
assert cfn1.im_func is v2[2].im_func assert cfn1.__func__ is v2[2]
assert cfn1.im_self is c assert cfn2.__func__ is v1[2]
assert cfn2.im_class is v1[0] else:
assert cfn2.im_func is v1[2].im_func assert cfn1.__func__ is v2[2].__func__
assert cfn2.im_self is c assert cfn2.__func__ is v1[2].__func__
assert cfn1.im_class is v2[0]
assert cfn2.im_class is v1[0]
assert cfn1.__self__ is c
assert cfn2.__self__ is c
c.sig.disconnect(cfn2) c.sig.disconnect(cfn2)