improve SymbolAtlas.getSymbolCoords and ScatterPlotItem.plot performance (#1198)
This commit is contained in:
parent
99c43613f3
commit
6194245322
@ -15,6 +15,7 @@ from ..pgcollections import OrderedDict
|
|||||||
from .. import debug
|
from .. import debug
|
||||||
from ..python2_3 import basestring
|
from ..python2_3 import basestring
|
||||||
|
|
||||||
|
|
||||||
__all__ = ['ScatterPlotItem', 'SpotItem']
|
__all__ = ['ScatterPlotItem', 'SpotItem']
|
||||||
|
|
||||||
|
|
||||||
@ -128,8 +129,12 @@ class SymbolAtlas(object):
|
|||||||
sourceRecti = None
|
sourceRecti = None
|
||||||
symbol_map = self.symbolMap
|
symbol_map = self.symbolMap
|
||||||
|
|
||||||
for i, rec in enumerate(opts.tolist()):
|
symbols = opts['symbol'].tolist()
|
||||||
size, symbol, pen, brush = rec[2: 6]
|
sizes = opts['size'].tolist()
|
||||||
|
pens = opts['pen'].tolist()
|
||||||
|
brushes = opts['brush'].tolist()
|
||||||
|
|
||||||
|
for symbol, size, pen, brush in zip(symbols, sizes, pens, brushes):
|
||||||
|
|
||||||
key = id(symbol), size, id(pen), id(brush)
|
key = id(symbol), size, id(pen), id(brush)
|
||||||
if key == keyi:
|
if key == keyi:
|
||||||
@ -560,6 +565,7 @@ class ScatterPlotItem(GraphicsObject):
|
|||||||
self.invalidate()
|
self.invalidate()
|
||||||
|
|
||||||
def updateSpots(self, dataSet=None):
|
def updateSpots(self, dataSet=None):
|
||||||
|
|
||||||
if dataSet is None:
|
if dataSet is None:
|
||||||
dataSet = self.data
|
dataSet = self.data
|
||||||
|
|
||||||
@ -610,8 +616,6 @@ class ScatterPlotItem(GraphicsObject):
|
|||||||
recs['brush'][np.equal(recs['brush'], None)] = fn.mkBrush(self.opts['brush'])
|
recs['brush'][np.equal(recs['brush'], None)] = fn.mkBrush(self.opts['brush'])
|
||||||
return recs
|
return recs
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def measureSpotSizes(self, dataSet):
|
def measureSpotSizes(self, dataSet):
|
||||||
for rec in dataSet:
|
for rec in dataSet:
|
||||||
## keep track of the maximum spot size and pixel size
|
## keep track of the maximum spot size and pixel size
|
||||||
@ -630,7 +634,6 @@ class ScatterPlotItem(GraphicsObject):
|
|||||||
self._maxSpotPxWidth = max(self._maxSpotPxWidth, pxWidth)
|
self._maxSpotPxWidth = max(self._maxSpotPxWidth, pxWidth)
|
||||||
self.bounds = [None, None]
|
self.bounds = [None, None]
|
||||||
|
|
||||||
|
|
||||||
def clear(self):
|
def clear(self):
|
||||||
"""Remove all spots from the scatter plot"""
|
"""Remove all spots from the scatter plot"""
|
||||||
#self.clearItems()
|
#self.clearItems()
|
||||||
@ -757,8 +760,10 @@ class ScatterPlotItem(GraphicsObject):
|
|||||||
if self.opts['pxMode'] is True:
|
if self.opts['pxMode'] is True:
|
||||||
p.resetTransform()
|
p.resetTransform()
|
||||||
|
|
||||||
|
data = self.data
|
||||||
|
|
||||||
# Map point coordinates to device
|
# Map point coordinates to device
|
||||||
pts = np.vstack([self.data['x'], self.data['y']])
|
pts = np.vstack([data['x'], data['y']])
|
||||||
pts = self.mapPointsToDevice(pts)
|
pts = self.mapPointsToDevice(pts)
|
||||||
if pts is None:
|
if pts is None:
|
||||||
return
|
return
|
||||||
@ -770,25 +775,31 @@ class ScatterPlotItem(GraphicsObject):
|
|||||||
# Draw symbols from pre-rendered atlas
|
# Draw symbols from pre-rendered atlas
|
||||||
atlas = self.fragmentAtlas.getAtlas()
|
atlas = self.fragmentAtlas.getAtlas()
|
||||||
|
|
||||||
|
target_rect = data['targetRect']
|
||||||
|
source_rect = data['sourceRect']
|
||||||
|
widths = data['width']
|
||||||
|
|
||||||
# Update targetRects if necessary
|
# Update targetRects if necessary
|
||||||
updateMask = viewMask & np.equal(self.data['targetRect'], None)
|
updateMask = viewMask & np.equal(target_rect, None)
|
||||||
if np.any(updateMask):
|
if np.any(updateMask):
|
||||||
updatePts = pts[:,updateMask]
|
updatePts = pts[:,updateMask]
|
||||||
width = self.data[updateMask]['width']*2
|
width = widths[updateMask] * 2
|
||||||
self.data['targetRect'][updateMask] = list(imap(QtCore.QRectF, updatePts[0,:], updatePts[1,:], width, width))
|
target_rect[updateMask] = list(imap(QtCore.QRectF, updatePts[0,:], updatePts[1,:], width, width))
|
||||||
|
|
||||||
data = self.data[viewMask]
|
|
||||||
if QT_LIB == 'PyQt4':
|
if QT_LIB == 'PyQt4':
|
||||||
p.drawPixmapFragments(data['targetRect'].tolist(), data['sourceRect'].tolist(), atlas)
|
p.drawPixmapFragments(
|
||||||
|
target_rect[viewMask].tolist(),
|
||||||
|
source_rect[viewMask].tolist(),
|
||||||
|
atlas
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
list(imap(p.drawPixmap, data['targetRect'], repeat(atlas), data['sourceRect']))
|
list(imap(p.drawPixmap, target_rect[viewMask].tolist(), repeat(atlas), source_rect[viewMask].tolist()))
|
||||||
else:
|
else:
|
||||||
# render each symbol individually
|
# render each symbol individually
|
||||||
p.setRenderHint(p.Antialiasing, aa)
|
p.setRenderHint(p.Antialiasing, aa)
|
||||||
|
|
||||||
data = self.data[viewMask]
|
|
||||||
pts = pts[:,viewMask]
|
pts = pts[:,viewMask]
|
||||||
for i, rec in enumerate(data):
|
for i, rec in enumerate(data[viewMask]):
|
||||||
p.resetTransform()
|
p.resetTransform()
|
||||||
p.translate(pts[0,i] + rec['width']/2, pts[1,i] + rec['width']/2)
|
p.translate(pts[0,i] + rec['width']/2, pts[1,i] + rec['width']/2)
|
||||||
drawSymbol(p, *self.getSpotOpts(rec, scale))
|
drawSymbol(p, *self.getSpotOpts(rec, scale))
|
||||||
|
Loading…
Reference in New Issue
Block a user