diff --git a/pyqtgraph/graphicsItems/ScatterPlotItem.py b/pyqtgraph/graphicsItems/ScatterPlotItem.py index 8de985fc..1c11fcf9 100644 --- a/pyqtgraph/graphicsItems/ScatterPlotItem.py +++ b/pyqtgraph/graphicsItems/ScatterPlotItem.py @@ -228,8 +228,6 @@ class ScatterPlotItem(GraphicsObject): GraphicsObject.__init__(self) self.picture = None # QPicture used for rendering when pxmode==False - self.fragments = None # fragment specification for pxmode; updated every time the view changes. - self.target = None self.fragmentAtlas = SymbolAtlas() self.data = np.empty(0, dtype=[('x', float), ('y', float), ('size', float), ('symbol', object), ('pen', object), ('brush', object), ('data', object), ('item', object), ('sourceRect', object), ('targetRect', object), ('width', float)]) @@ -394,6 +392,7 @@ class ScatterPlotItem(GraphicsObject): self.setPointData(kargs['data'], dataSet=newData) self.prepareGeometryChange() + self.informViewBoundsChanged() self.bounds = [None, None] self.invalidate() self.updateSpots(newData) @@ -402,13 +401,10 @@ class ScatterPlotItem(GraphicsObject): def invalidate(self): ## clear any cached drawing state self.picture = None - self.fragments = None - self.target = None self.update() def getData(self): - return self.data['x'], self.data['y'] - + return self.data['x'], self.data['y'] def setPoints(self, *args, **kargs): ##Deprecated; use setData @@ -554,14 +550,10 @@ class ScatterPlotItem(GraphicsObject): sourceRect = self.fragmentAtlas.getSymbolCoords(opts) dataSet['sourceRect'][mask] = sourceRect - - #for rec in dataSet: - #if rec['fragCoords'] is None: - #invalidate = True - #rec['fragCoords'] = self.fragmentAtlas.getSymbolCoords(*self.getSpotOpts(rec)) - self.fragmentAtlas.getAtlas() + self.fragmentAtlas.getAtlas() # generate atlas so source widths are available. + dataSet['width'] = np.array(list(imap(QtCore.QRectF.width, dataSet['sourceRect'])))/2 - dataSet['targetRect'] = list(imap(QtCore.QRectF, repeat(0), repeat(0), dataSet['width']*2, dataSet['width']*2)) + dataSet['targetRect'] = None self._maxSpotPxWidth = self.fragmentAtlas.max_width else: self._maxSpotWidth = 0 @@ -684,40 +676,42 @@ class ScatterPlotItem(GraphicsObject): self.prepareGeometryChange() GraphicsObject.viewTransformChanged(self) self.bounds = [None, None] - self.fragments = None - self.target = None + self.data['targetRect'] = None def setExportMode(self, *args, **kwds): GraphicsObject.setExportMode(self, *args, **kwds) self.invalidate() - def getTransformedPoint(self): - # Map point locations to device - - vb = self.getViewBox() - if vb is None: - return None, None + def mapPointsToDevice(self, pts): + # Map point locations to device tr = self.deviceTransform() if tr is None: - return None, None + return None - pts = np.empty((2,len(self.data['x']))) - pts[0] = self.data['x'] - pts[1] = self.data['y'] + #pts = np.empty((2,len(self.data['x']))) + #pts[0] = self.data['x'] + #pts[1] = self.data['y'] pts = fn.transformCoordinates(tr, pts) pts -= self.data['width'] pts = np.clip(pts, -2**30, 2**30) ## prevent Qt segmentation fault. - ## Remove out of view points + return pts + + def getViewMask(self, pts): + # Return bool mask indicating all points that are within viewbox + # pts is expressed in *device coordiantes* + vb = self.getViewBox() + if vb is None: + return None viewBounds = vb.mapRectToDevice(vb.boundingRect()) w = self.data['width'] mask = ((pts[0] + w > viewBounds.left()) & (pts[0] - w < viewBounds.right()) & (pts[1] + w > viewBounds.top()) & (pts[1] - w < viewBounds.bottom())) ## remove out of view points - print np.sum(mask) - return self.data[mask], pts[:, mask] + return mask + @debug.warnOnException ## raising an exception here causes crash def paint(self, p, *args): @@ -733,27 +727,42 @@ class ScatterPlotItem(GraphicsObject): scale = 1.0 if self.opts['pxMode'] is True: - atlas = self.fragmentAtlas.getAtlas() p.resetTransform() - data, pts = self.getTransformedPoint() - if data is None: + # Map point coordinates to device + pts = np.vstack([self.data['x'], self.data['y']]) + pts = self.mapPointsToDevice(pts) + if pts is None: return + # Cull points that are outside view + viewMask = self.getViewMask(pts) + #pts = pts[:,mask] + #data = self.data[mask] + if self.opts['useCache'] and self._exportOpts is False: + # Draw symbols from pre-rendered atlas + atlas = self.fragmentAtlas.getAtlas() - if self.target == None: - list(imap(QtCore.QRectF.moveTo, data['targetRect'], pts[0,:], pts[1,:])) - self.target = data['targetRect'] + # Update targetRects if necessary + updateMask = viewMask & np.equal(self.data['targetRect'], None) + if np.any(updateMask): + updatePts = pts[:,updateMask] + width = self.data[updateMask]['width']*2 + self.data['targetRect'][updateMask] = list(imap(QtCore.QRectF, updatePts[0,:], updatePts[1,:], width, width)) + + data = self.data[viewMask] if USE_PYSIDE: - list(imap(p.drawPixmap, self.target, repeat(atlas), data['sourceRect'])) + list(imap(p.drawPixmap, data['targetRect'], repeat(atlas), data['sourceRect'])) else: - p.drawPixmapFragments(self.target.tolist(), data['sourceRect'].tolist(), atlas) + p.drawPixmapFragments(data['targetRect'].tolist(), data['sourceRect'].tolist(), atlas) else: + # render each symbol individually p.setRenderHint(p.Antialiasing, aa) - for i in range(len(self.data)): - rec = data[i] + data = self.data[viewMask] + pts = pts[:,viewMask] + for i, rec in enumerate(data): p.resetTransform() p.translate(pts[0,i] + rec['width'], pts[1,i] + rec['width']) drawSymbol(p, *self.getSpotOpts(rec, scale)) diff --git a/pyqtgraph/graphicsItems/tests/ScatterPlotItem.py b/pyqtgraph/graphicsItems/tests/ScatterPlotItem.py new file mode 100644 index 00000000..ef8271bf --- /dev/null +++ b/pyqtgraph/graphicsItems/tests/ScatterPlotItem.py @@ -0,0 +1,23 @@ +import pyqtgraph as pg +import numpy as np +app = pg.mkQApp() +plot = pg.plot() +app.processEvents() + +# set view range equal to its bounding rect. +# This causes plots to look the same regardless of pxMode. +plot.setRange(rect=plot.boundingRect()) + + +def test_modes(): + for i, pxMode in enumerate([True, False]): + for j, useCache in enumerate([True, False]): + s = pg.ScatterPlotItem() + s.opts['useCache'] = useCache + plot.addItem(s) + s.setData(x=np.array([10,40,20,30])+i*100, y=np.array([40,60,10,30])+j*100, pxMode=pxMode) + s.addPoints(x=np.array([60, 70])+i*100, y=np.array([60, 70])+j*100, size=[20, 30]) + + +if __name__ == '__main__': + test_modes()