Some few more optimization to ScatterPlotItem

This commit is contained in:
Guillaume Poulin 2013-09-23 16:45:43 +08:00
parent f5ee45ac28
commit c3576b1c09

View File

@ -92,9 +92,6 @@ class SymbolAtlas(object):
pm = atlas.getAtlas() pm = atlas.getAtlas()
""" """
class SymbolCoords(list): ## needed because lists are not allowed in weak references.
pass
def __init__(self): def __init__(self):
# symbol key : [x, y, w, h] atlas coordinates # symbol key : [x, y, w, h] atlas coordinates
# note that the coordinate list will always be the same list object as # note that the coordinate list will always be the same list object as
@ -102,9 +99,10 @@ class SymbolAtlas(object):
# change if the atlas is rebuilt. # change if the atlas is rebuilt.
# weak value; if all external refs to this list disappear, # weak value; if all external refs to this list disappear,
# the symbol will be forgotten. # the symbol will be forgotten.
self.symbolMap = weakref.WeakValueDictionary()
self.symbolPen = weakref.WeakValueDictionary() self.symbolPen = weakref.WeakValueDictionary()
self.symbolBrush = weakref.WeakValueDictionary() self.symbolBrush = weakref.WeakValueDictionary()
self.symbolRectSrc = weakref.WeakValueDictionary()
self.symbolRectTarg = weakref.WeakValueDictionary()
self.atlasData = None # numpy array of atlas image self.atlasData = None # numpy array of atlas image
self.atlas = None # atlas as QPixmap self.atlas = None # atlas as QPixmap
@ -115,30 +113,34 @@ class SymbolAtlas(object):
""" """
Given a list of spot records, return an object representing the coordinates of that symbol within the atlas Given a list of spot records, return an object representing the coordinates of that symbol within the atlas
""" """
coords = np.empty(len(opts), dtype=object) rectSrc = np.empty(len(opts), dtype=object)
rectTarg = np.empty(len(opts), dtype=object)
keyi = None keyi = None
coordi = None rectSrci = None
rectTargi = None
for i, rec in enumerate(opts): for i, rec in enumerate(opts):
key = (rec[3], rec[2], id(rec[4]), id(rec[5])) key = (rec[3], rec[2], id(rec[4]), id(rec[5]))
if key == keyi: if key == keyi:
coords[i]=coordi rectSrc[i] = rectSrci
rectTarg[i] = rectTargi
else: else:
try: try:
coords[i] = self.symbolMap[key] rectSrc[i] = self.symbolRectSrc[key]
rectTarg[i] = self.symbolRectTarg[key]
except KeyError: except KeyError:
newCoords = SymbolAtlas.SymbolCoords() newRectSrc = QtCore.QRectF()
self.symbolMap[key] = newCoords newRectTarg = QtCore.QRectF()
self.symbolPen[key] = rec['pen'] self.symbolPen[key] = rec['pen']
self.symbolBrush[key] = rec['brush'] self.symbolBrush[key] = rec['brush']
self.symbolRectSrc[key] = newRectSrc
self.symbolRectTarg[key] = newRectTarg
self.atlasValid = False self.atlasValid = False
#try: rectSrc[i] = self.symbolRectSrc[key]
#self.addToAtlas(key) ## squeeze this into the atlas if there is room rectTarg[i] = self.symbolRectTarg[key]
#except:
#self.buildAtlas() ## otherwise, we need to rebuild
coords[i] = newCoords
keyi = key keyi = key
coordi = newCoords rectSrci = self.symbolRectSrc[key]
return coords rectTargi = self.symbolRectTarg[key]
return rectSrc, rectTarg
def buildAtlas(self): def buildAtlas(self):
# get rendered array for all symbols, keep track of avg/max width # get rendered array for all symbols, keep track of avg/max width
@ -146,15 +148,15 @@ class SymbolAtlas(object):
avgWidth = 0.0 avgWidth = 0.0
maxWidth = 0 maxWidth = 0
images = [] images = []
for key, coords in self.symbolMap.items(): for key, rectSrc in self.symbolRectSrc.items():
if len(coords) == 0: if rectSrc.width() == 0:
pen = self.symbolPen[key] pen = self.symbolPen[key]
brush = self.symbolBrush[key] brush = self.symbolBrush[key]
img = renderSymbol(key[0], key[1], pen, brush) img = renderSymbol(key[0], key[1], pen, brush)
images.append(img) ## we only need this to prevent the images being garbage collected immediately images.append(img) ## we only need this to prevent the images being garbage collected immediately
arr = fn.imageToArray(img, copy=False, transpose=False) arr = fn.imageToArray(img, copy=False, transpose=False)
else: else:
(y,x,h,w) = self.symbolMap[key] (y,x,h,w) = rectSrc.getRect()
arr = self.atlasData[x:x+w, y:y+w] arr = self.atlasData[x:x+w, y:y+w]
rendered[key] = arr rendered[key] = arr
w = arr.shape[0] w = arr.shape[0]
@ -185,14 +187,15 @@ class SymbolAtlas(object):
x = 0 x = 0
rowheight = h rowheight = h
self.atlasRows.append([y, rowheight, 0]) self.atlasRows.append([y, rowheight, 0])
self.symbolMap[key][:] = y, x, h, w self.symbolRectSrc[key].setRect(y, x, h, w)
x += w x += w
self.atlasRows[-1][2] = x self.atlasRows[-1][2] = x
height = y + rowheight height = y + rowheight
self.atlasData = np.zeros((width, height, 4), dtype=np.ubyte) self.atlasData = np.zeros((width, height, 4), dtype=np.ubyte)
for key in symbols: for key in symbols:
y, x, h, w = self.symbolMap[key] y, x, h, w = self.symbolRectSrc[key].getRect()
self.symbolRectTarg[key].setRect(-h/2, -w/2, h, w)
self.atlasData[x:x+w, y:y+h] = rendered[key] self.atlasData[x:x+w, y:y+h] = rendered[key]
self.atlas = None self.atlas = None
self.atlasValid = True self.atlasValid = True
@ -241,9 +244,10 @@ class ScatterPlotItem(GraphicsObject):
self.picture = None # QPicture used for rendering when pxmode==False self.picture = None # QPicture used for rendering when pxmode==False
self.fragments = None # fragment specification for pxmode; updated every time the view changes. self.fragments = None # fragment specification for pxmode; updated every time the view changes.
self.tar = None
self.fragmentAtlas = SymbolAtlas() self.fragmentAtlas = SymbolAtlas()
self.data = np.empty(0, dtype=[('x', float), ('y', float), ('size', float), ('symbol', object), ('pen', object), ('brush', object), ('data', object), ('fragCoords', object), ('item', object)]) self.data = np.empty(0, dtype=[('x', float), ('y', float), ('size', float), ('symbol', object), ('pen', object), ('brush', object), ('data', object), ('item', object), ('rectSrc', object), ('rectTarg', object)])
self.bounds = [None, None] ## caches data bounds self.bounds = [None, None] ## caches data bounds
self._maxSpotWidth = 0 ## maximum size of the scale-variant portion of all spots self._maxSpotWidth = 0 ## maximum size of the scale-variant portion of all spots
self._maxSpotPxWidth = 0 ## maximum size of the scale-invariant portion of all spots self._maxSpotPxWidth = 0 ## maximum size of the scale-invariant portion of all spots
@ -412,6 +416,7 @@ class ScatterPlotItem(GraphicsObject):
## clear any cached drawing state ## clear any cached drawing state
self.picture = None self.picture = None
self.fragments = None self.fragments = None
self.tar = None
self.update() self.update()
def getData(self): def getData(self):
@ -446,7 +451,7 @@ class ScatterPlotItem(GraphicsObject):
else: else:
self.opts['pen'] = fn.mkPen(*args, **kargs) self.opts['pen'] = fn.mkPen(*args, **kargs)
dataSet['fragCoords'] = None dataSet['rectSrc'] = None
if update: if update:
self.updateSpots(dataSet) self.updateSpots(dataSet)
@ -471,7 +476,7 @@ class ScatterPlotItem(GraphicsObject):
self.opts['brush'] = fn.mkBrush(*args, **kargs) self.opts['brush'] = fn.mkBrush(*args, **kargs)
#self._spotPixmap = None #self._spotPixmap = None
dataSet['fragCoords'] = None dataSet['rectSrc'] = None
if update: if update:
self.updateSpots(dataSet) self.updateSpots(dataSet)
@ -494,7 +499,7 @@ class ScatterPlotItem(GraphicsObject):
self.opts['symbol'] = symbol self.opts['symbol'] = symbol
self._spotPixmap = None self._spotPixmap = None
dataSet['fragCoords'] = None dataSet['rectSrc'] = None
if update: if update:
self.updateSpots(dataSet) self.updateSpots(dataSet)
@ -517,7 +522,7 @@ class ScatterPlotItem(GraphicsObject):
self.opts['size'] = size self.opts['size'] = size
self._spotPixmap = None self._spotPixmap = None
dataSet['fragCoords'] = None dataSet['rectSrc'] = None
if update: if update:
self.updateSpots(dataSet) self.updateSpots(dataSet)
@ -552,12 +557,13 @@ class ScatterPlotItem(GraphicsObject):
invalidate = False invalidate = False
if self.opts['pxMode']: if self.opts['pxMode']:
mask = np.equal(dataSet['fragCoords'], None) mask = np.equal(dataSet['rectSrc'], None)
if np.any(mask): if np.any(mask):
invalidate = True invalidate = True
opts = self.getSpotOpts(dataSet[mask]) opts = self.getSpotOpts(dataSet[mask])
coords = self.fragmentAtlas.getSymbolCoords(opts) rectSrc, rectTarg = self.fragmentAtlas.getSymbolCoords(opts)
dataSet['fragCoords'][mask] = coords dataSet['rectSrc'][mask] = rectSrc
dataSet['rectTarg'][mask] = rectTarg
#for rec in dataSet: #for rec in dataSet:
#if rec['fragCoords'] is None: #if rec['fragCoords'] is None:
@ -687,6 +693,7 @@ class ScatterPlotItem(GraphicsObject):
GraphicsObject.viewTransformChanged(self) GraphicsObject.viewTransformChanged(self)
self.bounds = [None, None] self.bounds = [None, None]
self.fragments = None self.fragments = None
self.tar = None
def generateFragments(self): def generateFragments(self):
tr = self.deviceTransform() tr = self.deviceTransform()
@ -705,9 +712,8 @@ class ScatterPlotItem(GraphicsObject):
# x,y,w,h = rec['fragCoords'] # x,y,w,h = rec['fragCoords']
# rect = QtCore.QRectF(y, x, h, w) # rect = QtCore.QRectF(y, x, h, w)
# self.fragments.append(QtGui.QPainter.PixmapFragment.create(pos, rect)) # self.fragments.append(QtGui.QPainter.PixmapFragment.create(pos, rect))
rect = starmap(QtCore.QRectF, self.data['fragCoords'])
pos = imap(QtCore.QPointF, pts[0,:], pts[1,:]) pos = imap(QtCore.QPointF, pts[0,:], pts[1,:])
self.fragments = list(imap(QtGui.QPainter.PixmapFragment.create, pos, rect)) self.fragments = list(imap(QtGui.QPainter.PixmapFragment.create, pos, self.data['rectSrc']))
def setExportMode(self, *args, **kwds): def setExportMode(self, *args, **kwds):
GraphicsObject.setExportMode(self, *args, **kwds) GraphicsObject.setExportMode(self, *args, **kwds)
@ -734,15 +740,28 @@ class ScatterPlotItem(GraphicsObject):
#print "Atlas changed:", arr #print "Atlas changed:", arr
#self.lastAtlas = arr #self.lastAtlas = arr
if self.fragments is None: #if self.fragments is None:
#self.updateSpots() #self.updateSpots()
self.generateFragments() #self.generateFragments()
p.resetTransform() p.resetTransform()
if not USE_PYSIDE and self.opts['useCache'] and self._exportOpts is False: if not USE_PYSIDE and self.opts['useCache'] and self._exportOpts is False:
p.drawPixmapFragments(self.fragments, atlas) tr = self.deviceTransform()
if tr is None:
return
pts = np.empty((2,len(self.data['x'])))
pts[0] = self.data['x']
pts[1] = self.data['y']
pts = fn.transformCoordinates(tr, pts)
pts = np.clip(pts, -2**30, 2**30)
if self.tar == None:
self.tar = list(imap(QtCore.QRectF.translated, self.data['rectTarg'], pts[0,:], pts[1,:]))
p.drawPixmapFragments(self.tar, self.data['rectSrc'].tolist(), atlas)
#p.drawPixmapFragments(self.fragments, atlas)
else: else:
if self.fragments is None:
self.generateFragments()
p.setRenderHint(p.Antialiasing, aa) p.setRenderHint(p.Antialiasing, aa)
for i in range(len(self.data)): for i in range(len(self.data)):
@ -911,7 +930,7 @@ class SpotItem(object):
self._data['data'] = data self._data['data'] = data
def updateItem(self): def updateItem(self):
self._data['fragCoords'] = None self._data['rectSrc'] = None
self._plot.updateSpots(self._data.reshape(1)) self._plot.updateSpots(self._data.reshape(1))
self._plot.invalidate() self._plot.invalidate()