Improve ScatterPlotItem.py

Add optimization for PySide, Plot only visible symbole, cache rectTarg
This commit is contained in:
Guillaume Poulin 2013-09-24 16:12:29 +08:00
parent bd43a7508a
commit 73a079a649

View File

@ -3,7 +3,7 @@ from pyqtgraph.Point import Point
import pyqtgraph.functions as fn import pyqtgraph.functions as fn
from .GraphicsItem import GraphicsItem from .GraphicsItem import GraphicsItem
from .GraphicsObject import GraphicsObject from .GraphicsObject import GraphicsObject
from itertools import starmap from itertools import starmap, repeat
try: try:
from itertools import imap from itertools import imap
except ImportError: except ImportError:
@ -102,7 +102,6 @@ class SymbolAtlas(object):
self.symbolPen = weakref.WeakValueDictionary() self.symbolPen = weakref.WeakValueDictionary()
self.symbolBrush = weakref.WeakValueDictionary() self.symbolBrush = weakref.WeakValueDictionary()
self.symbolRectSrc = weakref.WeakValueDictionary() self.symbolRectSrc = weakref.WeakValueDictionary()
self.symbolRectTarg = weakref.WeakValueDictionary()
self.atlasData = None # numpy array of atlas image self.atlasData = None # numpy array of atlas image
self.atlas = None # atlas as QPixmap self.atlas = None # atlas as QPixmap
@ -114,33 +113,25 @@ class SymbolAtlas(object):
Given a list of spot records, return an object representing the coordinates of that symbol within the atlas Given a list of spot records, return an object representing the coordinates of that symbol within the atlas
""" """
rectSrc = np.empty(len(opts), dtype=object) rectSrc = np.empty(len(opts), dtype=object)
rectTarg = np.empty(len(opts), dtype=object)
keyi = None keyi = None
rectSrci = None rectSrci = None
rectTargi = None
for i, rec in enumerate(opts): for i, rec in enumerate(opts):
key = (rec[3], rec[2], id(rec[4]), id(rec[5])) key = (rec[3], rec[2], id(rec[4]), id(rec[5]))
if key == keyi: if key == keyi:
rectSrc[i] = rectSrci rectSrc[i] = rectSrci
rectTarg[i] = rectTargi
else: else:
try: try:
rectSrc[i] = self.symbolRectSrc[key] rectSrc[i] = self.symbolRectSrc[key]
rectTarg[i] = self.symbolRectTarg[key]
except KeyError: except KeyError:
newRectSrc = QtCore.QRectF() newRectSrc = QtCore.QRectF()
newRectTarg = QtCore.QRectF()
self.symbolPen[key] = rec['pen'] self.symbolPen[key] = rec['pen']
self.symbolBrush[key] = rec['brush'] self.symbolBrush[key] = rec['brush']
self.symbolRectSrc[key] = newRectSrc self.symbolRectSrc[key] = newRectSrc
self.symbolRectTarg[key] = newRectTarg
self.atlasValid = False self.atlasValid = False
rectSrc[i] = self.symbolRectSrc[key] rectSrc[i] = self.symbolRectSrc[key]
rectTarg[i] = self.symbolRectTarg[key]
keyi = key keyi = key
rectSrci = self.symbolRectSrc[key] rectSrci = self.symbolRectSrc[key]
rectTargi = self.symbolRectTarg[key] return rectSrc
return rectSrc, rectTarg
def buildAtlas(self): def buildAtlas(self):
# get rendered array for all symbols, keep track of avg/max width # get rendered array for all symbols, keep track of avg/max width
@ -195,7 +186,6 @@ class SymbolAtlas(object):
self.atlasData = np.zeros((width, height, 4), dtype=np.ubyte) self.atlasData = np.zeros((width, height, 4), dtype=np.ubyte)
for key in symbols: for key in symbols:
y, x, h, w = self.symbolRectSrc[key].getRect() y, x, h, w = self.symbolRectSrc[key].getRect()
self.symbolRectTarg[key].setRect(-h/2, -w/2, h, w)
self.atlasData[x:x+w, y:y+h] = rendered[key] self.atlasData[x:x+w, y:y+h] = rendered[key]
self.atlas = None self.atlas = None
self.atlasValid = True self.atlasValid = True
@ -247,7 +237,7 @@ class ScatterPlotItem(GraphicsObject):
self.target = None self.target = None
self.fragmentAtlas = SymbolAtlas() self.fragmentAtlas = SymbolAtlas()
self.data = np.empty(0, dtype=[('x', float), ('y', float), ('size', float), ('symbol', object), ('pen', object), ('brush', object), ('data', object), ('item', object), ('rectSrc', object), ('rectTarg', object)]) self.data = np.empty(0, dtype=[('x', float), ('y', float), ('size', float), ('symbol', object), ('pen', object), ('brush', object), ('data', object), ('item', object), ('rectSrc', object), ('rectTarg', object), ('width', float)])
self.bounds = [None, None] ## caches data bounds self.bounds = [None, None] ## caches data bounds
self._maxSpotWidth = 0 ## maximum size of the scale-variant portion of all spots self._maxSpotWidth = 0 ## maximum size of the scale-variant portion of all spots
self._maxSpotPxWidth = 0 ## maximum size of the scale-invariant portion of all spots self._maxSpotPxWidth = 0 ## maximum size of the scale-invariant portion of all spots
@ -561,15 +551,17 @@ class ScatterPlotItem(GraphicsObject):
if np.any(mask): if np.any(mask):
invalidate = True invalidate = True
opts = self.getSpotOpts(dataSet[mask]) opts = self.getSpotOpts(dataSet[mask])
rectSrc, rectTarg = self.fragmentAtlas.getSymbolCoords(opts) rectSrc = self.fragmentAtlas.getSymbolCoords(opts)
dataSet['rectSrc'][mask] = rectSrc dataSet['rectSrc'][mask] = rectSrc
dataSet['rectTarg'][mask] = rectTarg
#for rec in dataSet: #for rec in dataSet:
#if rec['fragCoords'] is None: #if rec['fragCoords'] is None:
#invalidate = True #invalidate = True
#rec['fragCoords'] = self.fragmentAtlas.getSymbolCoords(*self.getSpotOpts(rec)) #rec['fragCoords'] = self.fragmentAtlas.getSymbolCoords(*self.getSpotOpts(rec))
self.fragmentAtlas.getAtlas() self.fragmentAtlas.getAtlas()
dataSet['width'] = np.array(list(imap(QtCore.QRectF.width, dataSet['rectSrc'])))/2
dataSet['rectTarg'] = list(imap(QtCore.QRectF, repeat(0), repeat(0), dataSet['width']*2, dataSet['width']*2))
self._maxSpotPxWidth=self.fragmentAtlas.max_width self._maxSpotPxWidth=self.fragmentAtlas.max_width
else: else:
self._maxSpotWidth = 0 self._maxSpotWidth = 0
@ -699,9 +691,15 @@ class ScatterPlotItem(GraphicsObject):
tr = self.deviceTransform() tr = self.deviceTransform()
if tr is None: if tr is None:
return return
pts = np.empty((2,len(self.data['x']))) mask = np.logical_and(
pts[0] = self.data['x'] np.logical_and(self.data['x'] - self.data['width'] > range[0][0],
pts[1] = self.data['y'] self.data['x'] + self.data['width'] < range[0][1]),
np.logical_and(self.data['y'] - self.data['width'] > range[1][0],
self.data['y'] + self.data['width'] < range[1][1])) ## remove out of view points
data = self.data[mask]
pts = np.empty((2,len(data['x'])))
pts[0] = data['x']
pts[1] = data['y']
pts = fn.transformCoordinates(tr, pts) pts = fn.transformCoordinates(tr, pts)
self.fragments = [] self.fragments = []
pts = np.clip(pts, -2**30, 2**30) ## prevent Qt segmentation fault. pts = np.clip(pts, -2**30, 2**30) ## prevent Qt segmentation fault.
@ -746,18 +744,39 @@ class ScatterPlotItem(GraphicsObject):
p.resetTransform() p.resetTransform()
if not USE_PYSIDE and self.opts['useCache'] and self._exportOpts is False: if self.opts['useCache'] and self._exportOpts is False:
tr = self.deviceTransform() tr = self.deviceTransform()
if tr is None: if tr is None:
return return
pts = np.empty((2,len(self.data['x']))) w = np.empty((2,len(self.data['width'])))
pts[0] = self.data['x'] w[0] = self.data['width']
pts[1] = self.data['y'] w[1] = self.data['width']
q, intv = tr.inverted()
if intv:
w = fn.transformCoordinates(q, w)
w=np.abs(w)
range = self.getViewBox().viewRange()
mask = np.logical_and(
np.logical_and(self.data['x'] + w[0,:] > range[0][0],
self.data['x'] - w[0,:] < range[0][1]),
np.logical_and(self.data['y'] + w[0,:] > range[1][0],
self.data['y'] - w[0,:] < range[1][1])) ## remove out of view points
data = self.data[mask]
else:
data = self.data
pts = np.empty((2,len(data['x'])))
pts[0] = data['x']
pts[1] = data['y']
pts = fn.transformCoordinates(tr, pts) pts = fn.transformCoordinates(tr, pts)
pts -= data['width']
pts = np.clip(pts, -2**30, 2**30) pts = np.clip(pts, -2**30, 2**30)
if self.target == None: if self.target == None:
self.target = list(imap(QtCore.QRectF.translated, self.data['rectTarg'], pts[0,:], pts[1,:])) list(imap(QtCore.QRectF.moveTo, data['rectTarg'], pts[0,:], pts[1,:]))
p.drawPixmapFragments(self.target, self.data['rectSrc'].tolist(), atlas) self.target=data['rectTarg']
if USE_PYSIDE:
list(imap(p.drawPixmap, self.target, repeat(atlas), data['rectSrc']))
else:
p.drawPixmapFragments(self.target.tolist(), data['rectSrc'].tolist(), atlas)
#p.drawPixmapFragments(self.fragments, atlas) #p.drawPixmapFragments(self.fragments, atlas)
else: else:
if self.fragments is None: if self.fragments is None: