From c4019b900d408469fa29fdc3e1f49825178d5473 Mon Sep 17 00:00:00 2001 From: Luke Campagnola <> Date: Thu, 18 Oct 2012 22:48:36 -0400 Subject: [PATCH] Overhaul of ScatterPlotItem to improve performance. (API should be mostly unchanged) Much more efficient at rapid updates. --- examples/ScatterPlotSpeedTest.py | 5 +- functions.py | 119 +++++++-- graphicsItems/PlotDataItem.py | 51 +++- graphicsItems/ScatterPlotItem.py | 441 +++++++++++++++++++++++-------- 4 files changed, 480 insertions(+), 136 deletions(-) diff --git a/examples/ScatterPlotSpeedTest.py b/examples/ScatterPlotSpeedTest.py index 386522d1..a44e58e3 100644 --- a/examples/ScatterPlotSpeedTest.py +++ b/examples/ScatterPlotSpeedTest.py @@ -26,7 +26,7 @@ win.show() p = ui.plot data = np.random.normal(size=(50,500), scale=100) -sizeArray = np.random.random(500) * 20. +sizeArray = (np.random.random(500) * 20.).astype(int) ptr = 0 lastTime = time() fps = None @@ -49,7 +49,8 @@ def update(): s = np.clip(dt*3., 0, 1) fps = fps * (1-s) + (1.0/dt) * s p.setTitle('%0.2f fps' % fps) - app.processEvents() ## force complete redraw for every plot + p.repaint() + #app.processEvents() ## force complete redraw for every plot timer = QtCore.QTimer() timer.timeout.connect(update) timer.start(0) diff --git a/functions.py b/functions.py index e70e72fc..4a318b4e 100644 --- a/functions.py +++ b/functions.py @@ -25,6 +25,7 @@ SI_PREFIXES_ASCII = 'yzafpnum kMGTPEZY' from .Qt import QtGui, QtCore import numpy as np import decimal, re +import ctypes try: import scipy.ndimage @@ -223,13 +224,15 @@ def mkColor(*args): return QtGui.QColor(*args) -def mkBrush(*args): +def mkBrush(*args, **kwds): """ | Convenience function for constructing Brush. | This function always constructs a solid brush and accepts the same arguments as :func:`mkColor() ` | Calling mkBrush(None) returns an invisible brush. """ - if len(args) == 1: + if 'color' in kwds: + color = kwds['color'] + elif len(args) == 1: arg = args[0] if arg is None: return QtGui.QBrush(QtCore.Qt.NoBrush) @@ -237,7 +240,7 @@ def mkBrush(*args): return QtGui.QBrush(arg) else: color = arg - if len(args) > 1: + elif len(args) > 1: color = args return QtGui.QBrush(mkColor(color)) @@ -579,7 +582,10 @@ def solveBilinearTransform(points1, points2): - +def makeRGBA(*args, **kwds): + """Equivalent to makeARGB(..., useRGBA=True)""" + kwds['useRGBA'] = True + return makeARGB(*args, **kwds) def makeARGB(data, lut=None, levels=None, useRGBA=False): """ @@ -605,7 +611,7 @@ def makeARGB(data, lut=None, levels=None, useRGBA=False): Lookup tables can be built using GradientWidget. levels - List [min, max]; optionally rescale data before converting through the lookup table. rescaled = (data-min) * len(lut) / (max-min) - useRGBA - If True, the data is returned in RGBA order. The default is + useRGBA - If True, the data is returned in RGBA order (useful for building OpenGL textures). The default is False, which returns in BGRA order for use with QImage. """ @@ -779,30 +785,107 @@ def makeARGB(data, lut=None, levels=None, useRGBA=False): return imgData, alpha -def makeQImage(imgData, alpha): - """Turn an ARGB array into QImage""" +def makeQImage(imgData, alpha=None, copy=True, transpose=True): + """ + Turn an ARGB array into QImage. + By default, the data is copied; changes to the array will not + be reflected in the image. The image will be given a 'data' attribute + pointing to the array which shares its data to prevent python + freeing that memory while the image is in use. + + =========== =================================================================== + Arguments: + imgData Array of data to convert. Must have shape (width, height, 3 or 4) + and dtype=ubyte. The order of values in the 3rd axis must be + (b, g, r, a). + alpha If True, the QImage returned will have format ARGB32. If False, + the format will be RGB32. By default, _alpha_ is True if + array.shape[2] == 4. + copy If True, the data is copied before converting to QImage. + If False, the new QImage points directly to the data in the array. + Note that the array must be contiguous for this to work. + transpose If True (the default), the array x/y axes are transposed before + creating the image. Note that Qt expects the axes to be in + (height, width) order whereas pyqtgraph usually prefers the + opposite. + =========== =================================================================== + """ ## create QImage from buffer prof = debug.Profiler('functions.makeQImage', disabled=True) + ## If we didn't explicitly specify alpha, check the array shape. + if alpha is None: + alpha = (imgData.shape[2] == 4) + + copied = False + if imgData.shape[2] == 3: ## need to make alpha channel (even if alpha==False; QImage requires 32 bpp) + if copy is True: + d2 = np.empty(imgData.shape[:2] + (4,), dtype=imgData.dtype) + d2[:,:,:3] = imgData + d2[:,:,3] = 255 + imgData = d2 + copied = True + else: + raise Exception('Array has only 3 channels; cannot make QImage without copying.') + if alpha: imgFormat = QtGui.QImage.Format_ARGB32 else: imgFormat = QtGui.QImage.Format_RGB32 - imgData = imgData.transpose((1, 0, 2)) ## QImage expects the row/column order to be opposite - try: - buf = imgData.data - except AttributeError: ## happens when image data is non-contiguous + if transpose: + imgData = imgData.transpose((1, 0, 2)) ## QImage expects the row/column order to be opposite + + if not imgData.flags['C_CONTIGUOUS']: + if copy is False: + extra = ' (try setting transpose=False)' if transpose else '' + raise Exception('Array is not contiguous; cannot make QImage without copying.'+extra) imgData = np.ascontiguousarray(imgData) - buf = imgData.data + copied = True - prof.mark('1') - qimage = QtGui.QImage(buf, imgData.shape[1], imgData.shape[0], imgFormat) - prof.mark('2') - qimage.data = imgData - prof.finish() - return qimage + if copy is True and copied is False: + imgData = imgData.copy() + + addr = ctypes.addressof(ctypes.c_char.from_buffer(imgData, 0)) + img = QtGui.QImage(addr, imgData.shape[1], imgData.shape[0], imgFormat) + img.data = imgData + return img + #try: + #buf = imgData.data + #except AttributeError: ## happens when image data is non-contiguous + #buf = imgData.data + + #prof.mark('1') + #qimage = QtGui.QImage(buf, imgData.shape[1], imgData.shape[0], imgFormat) + #prof.mark('2') + #qimage.data = imgData + #prof.finish() + #return qimage +def imageToArray(img, copy=False, transpose=True): + """ + Convert a QImage into numpy array. The image must have format RGB32, ARGB32, or ARGB32_Premultiplied. + By default, the image is not copied; changes made to the array will appear in the QImage as well (beware: if + the QImage is collected before the array, there may be trouble). + The array will have shape (width, height, (b,g,r,a)). + """ + ptr = img.bits() + ptr.setsize(img.byteCount()) + fmt = img.format() + if fmt == img.Format_RGB32: + arr = np.asarray(ptr).reshape(img.height(), img.width(), 3) + elif fmt == img.Format_ARGB32 or fmt == img.Format_ARGB32_Premultiplied: + arr = np.asarray(ptr).reshape(img.height(), img.width(), 4) + if copy: + arr = arr.copy() + + if transpose: + return arr.transpose((1,0,2)) + else: + return arr + + + def rescaleData(data, scale, offset): newData = np.empty((data.size,), dtype=np.int) diff --git a/graphicsItems/PlotDataItem.py b/graphicsItems/PlotDataItem.py index d30d737b..34af641a 100644 --- a/graphicsItems/PlotDataItem.py +++ b/graphicsItems/PlotDataItem.py @@ -130,6 +130,8 @@ class PlotDataItem(GraphicsObject): 'symbolBrush': (50, 50, 150), 'pxMode': True, + 'pointMode': None, + 'data': None, } self.setData(*args, **kargs) @@ -144,22 +146,30 @@ class PlotDataItem(GraphicsObject): return QtCore.QRectF() ## let child items handle this def setAlpha(self, alpha, auto): + if self.opts['alphaHint'] == alpha and self.opts['alphaMode'] == auto: + return self.opts['alphaHint'] = alpha self.opts['alphaMode'] = auto self.setOpacity(alpha) #self.update() def setFftMode(self, mode): + if self.opts['fftMode'] == mode: + return self.opts['fftMode'] = mode self.xDisp = self.yDisp = None self.updateItems() def setLogMode(self, xMode, yMode): - self.opts['logMode'] = (xMode, yMode) + if self.opts['logMode'] == [xMode, yMode]: + return + self.opts['logMode'] = [xMode, yMode] self.xDisp = self.yDisp = None self.updateItems() def setPointMode(self, mode): + if self.opts['pointMode'] == mode: + return self.opts['pointMode'] = mode self.update() @@ -193,6 +203,8 @@ class PlotDataItem(GraphicsObject): def setFillBrush(self, *args, **kargs): brush = fn.mkBrush(*args, **kargs) + if self.opts['fillBrush'] == brush: + return self.opts['fillBrush'] = brush self.updateItems() @@ -200,16 +212,22 @@ class PlotDataItem(GraphicsObject): return self.setFillBrush(*args, **kargs) def setFillLevel(self, level): + if self.opts['fillLevel'] == level: + return self.opts['fillLevel'] = level self.updateItems() def setSymbol(self, symbol): + if self.opts['symbol'] == symbol: + return self.opts['symbol'] = symbol #self.scatter.setSymbol(symbol) self.updateItems() def setSymbolPen(self, *args, **kargs): pen = fn.mkPen(*args, **kargs) + if self.opts['symbolPen'] == pen: + return self.opts['symbolPen'] = pen #self.scatter.setSymbolPen(pen) self.updateItems() @@ -218,21 +236,26 @@ class PlotDataItem(GraphicsObject): def setSymbolBrush(self, *args, **kargs): brush = fn.mkBrush(*args, **kargs) + if self.opts['symbolBrush'] == brush: + return self.opts['symbolBrush'] = brush #self.scatter.setSymbolBrush(brush) self.updateItems() def setSymbolSize(self, size): + if self.opts['symbolSize'] == size: + return self.opts['symbolSize'] = size #self.scatter.setSymbolSize(symbolSize) self.updateItems() def setDownsampling(self, ds): - if self.opts['downsample'] != ds: - self.opts['downsample'] = ds - self.xDisp = self.yDisp = None - self.updateItems() + if self.opts['downsample'] == ds: + return + self.opts['downsample'] = ds + self.xDisp = self.yDisp = None + self.updateItems() def setData(self, *args, **kargs): """ @@ -436,9 +459,12 @@ class PlotDataItem(GraphicsObject): and max) =============== ============================================================= """ + if frac <= 0.0: + raise Exception("Value for parameter 'frac' must be > 0. (got %s)" % str(frac)) + (x, y) = self.getData() if x is None or len(x) == 0: - return (0, 0) + return None if ax == 0: d = x @@ -450,14 +476,15 @@ class PlotDataItem(GraphicsObject): if orthoRange is not None: mask = (d2 >= orthoRange[0]) * (d2 <= orthoRange[1]) d = d[mask] - d2 = d2[mask] + #d2 = d2[mask] - if frac >= 1.0: - return (np.min(d), np.max(d)) - elif frac <= 0.0: - raise Exception("Value for parameter 'frac' must be > 0. (got %s)" % str(frac)) + if len(d) > 0: + if frac >= 1.0: + return (np.min(d), np.max(d)) + else: + return (scipy.stats.scoreatpercentile(d, 50 - (frac * 50)), scipy.stats.scoreatpercentile(d, 50 + (frac * 50))) else: - return (scipy.stats.scoreatpercentile(d, 50 - (frac * 50)), scipy.stats.scoreatpercentile(d, 50 + (frac * 50))) + return None def clear(self): diff --git a/graphicsItems/ScatterPlotItem.py b/graphicsItems/ScatterPlotItem.py index 625eb0b6..d816f2a6 100644 --- a/graphicsItems/ScatterPlotItem.py +++ b/graphicsItems/ScatterPlotItem.py @@ -32,26 +32,171 @@ for k, c in coords.items(): Symbols[k].lineTo(x, y) Symbols[k].closeSubpath() + +def drawSymbol(painter, symbol, size, pen, brush): + painter.scale(size, size) + painter.setPen(pen) + painter.setBrush(brush) + if isinstance(symbol, basestring): + symbol = Symbols[symbol] + if np.isscalar(symbol): + symbol = Symbols.values()[symbol % len(Symbols)] + painter.drawPath(symbol) -def makeSymbolPixmap(size, pen, brush, symbol): + +def renderSymbol(symbol, size, pen, brush, device=None): + """ + Render a symbol specification to QImage. + Symbol may be either a QPainterPath or one of the keys in the Symbols dict. + If *device* is None, a new QPixmap will be returned. Otherwise, + the symbol will be rendered into the device specified (See QPainter documentation + for more information). + """ + ## see if this pixmap is already cached + #global SymbolPixmapCache + #key = (symbol, size, fn.colorTuple(pen.color()), pen.width(), pen.style(), fn.colorTuple(brush.color())) + #if key in SymbolPixmapCache: + #return SymbolPixmapCache[key] + ## Render a spot with the given parameters to a pixmap penPxWidth = max(np.ceil(pen.width()), 1) - image = QtGui.QImage(int(size+penPxWidth), int(size+penPxWidth), QtGui.QImage.Format_ARGB32_Premultiplied) + image = QtGui.QImage(int(size+penPxWidth), int(size+penPxWidth), QtGui.QImage.Format_ARGB32) image.fill(0) p = QtGui.QPainter(image) p.setRenderHint(p.Antialiasing) p.translate(image.width()*0.5, image.height()*0.5) - p.scale(size, size) - p.setPen(pen) - p.setBrush(brush) - if isinstance(symbol, basestring): - symbol = Symbols[symbol] - p.drawPath(symbol) + drawSymbol(p, symbol, size, pen, brush) p.end() - return QtGui.QPixmap(image) + return image + #pixmap = QtGui.QPixmap(image) + #SymbolPixmapCache[key] = pixmap + #return pixmap +def makeSymbolPixmap(size, pen, brush, symbol): + ## deprecated + img = renderSymbol(symbol, size, pen, brush) + return QtGui.QPixmap(img) + +class SymbolAtlas: + """ + Used to efficiently construct a single QPixmap containing all rendered symbols + for a ScatterPlotItem. This is required for fragment rendering. + + Use example: + atlas = SymbolAtlas() + sc1 = atlas.getSymbolCoords('o', 5, QPen(..), QBrush(..)) + sc2 = atlas.getSymbolCoords('t', 10, QPen(..), QBrush(..)) + pm = atlas.getAtlas() + + """ + class SymbolCoords(list): ## needed because lists are not allowed in weak references. + pass + + def __init__(self): + # symbol key : [x, y, w, h] atlas coordinates + # note that the coordinate list will always be the same list object as + # long as the symbol is in the atlas, but the coordinates may + # change if the atlas is rebuilt. + # weak value; if all external refs to this list disappear, + # the symbol will be forgotten. + self.symbolMap = weakref.WeakValueDictionary() + + self.atlasData = None # numpy array of atlas image + self.atlas = None # atlas as QPixmap + self.atlasValid = False + + def getSymbolCoords(self, opts): + """ + Given a list of spot records, return an object representing the coordinates of that symbol within the atlas + """ + coords = np.empty(len(opts), dtype=object) + for i, rec in enumerate(opts): + symbol, size, pen, brush = rec['symbol'], rec['size'], rec['pen'], rec['brush'] + pen = fn.mkPen(pen) if not isinstance(pen, QtGui.QPen) else pen + brush = fn.mkBrush(brush) if not isinstance(pen, QtGui.QBrush) else brush + key = (symbol, size, fn.colorTuple(pen.color()), pen.width(), pen.style(), fn.colorTuple(brush.color())) + if key not in self.symbolMap: + newCoords = SymbolAtlas.SymbolCoords() + self.symbolMap[key] = newCoords + self.atlasValid = False + #try: + #self.addToAtlas(key) ## squeeze this into the atlas if there is room + #except: + #self.buildAtlas() ## otherwise, we need to rebuild + + coords[i] = self.symbolMap[key] + return coords + + def buildAtlas(self): + # get rendered array for all symbols, keep track of avg/max width + rendered = {} + avgWidth = 0.0 + maxWidth = 0 + images = [] + for key, coords in self.symbolMap.items(): + if len(coords) == 0: + pen = fn.mkPen(color=key[2], width=key[3], style=key[4]) + brush = fn.mkBrush(color=key[5]) + img = renderSymbol(key[0], key[1], pen, brush) + images.append(img) ## we only need this to prevent the images being garbage collected immediately + arr = fn.imageToArray(img, copy=False, transpose=False) + else: + (x,y,w,h) = self.symbolMap[key] + arr = self.atlasData[x:x+w, y:y+w] + rendered[key] = arr + w = arr.shape[0] + avgWidth += w + maxWidth = max(maxWidth, w) + + nSymbols = len(rendered) + if nSymbols > 0: + avgWidth /= nSymbols + width = max(maxWidth, avgWidth * (nSymbols**0.5)) + else: + avgWidth = 0 + width = 0 + + # sort symbols by height + symbols = sorted(rendered.keys(), key=lambda x: rendered[x].shape[1], reverse=True) + + self.atlasRows = [] + x = width + y = 0 + rowheight = 0 + for key in symbols: + arr = rendered[key] + w,h = arr.shape[:2] + if x+w > width: + y += rowheight + x = 0 + rowheight = h + self.atlasRows.append([y, rowheight, 0]) + self.symbolMap[key][:] = x, y, w, h + x += w + self.atlasRows[-1][2] = x + height = y + rowheight + self.atlasData = np.zeros((width, height, 4), dtype=np.ubyte) + for key in symbols: + x, y, w, h = self.symbolMap[key] + self.atlasData[x:x+w, y:y+h] = rendered[key] + self.atlas = None + self.atlasValid = True + + def getAtlas(self): + if not self.atlasValid: + self.buildAtlas() + if self.atlas is None: + if len(self.atlasData) == 0: + return QtGui.QPixmap(0,0) + img = fn.makeQImage(self.atlasData, copy=False, transpose=False) + self.atlas = QtGui.QPixmap(img) + return self.atlas + + + + class ScatterPlotItem(GraphicsObject): """ Displays a set of x/y points. Instances of this class are created @@ -79,13 +224,16 @@ class ScatterPlotItem(GraphicsObject): """ prof = debug.Profiler('ScatterPlotItem.__init__', disabled=True) GraphicsObject.__init__(self) - self.setFlag(self.ItemHasNoContents, True) - self.data = np.empty(0, dtype=[('x', float), ('y', float), ('size', float), ('symbol', object), ('pen', object), ('brush', object), ('item', object), ('data', object)]) + + self.picture = None # QPicture used for rendering when pxmode==False + self.fragments = None # fragment specification for pxmode; updated every time the view changes. + self.fragmentAtlas = SymbolAtlas() + + self.data = np.empty(0, dtype=[('x', float), ('y', float), ('size', float), ('symbol', object), ('pen', object), ('brush', object), ('data', object), ('fragCoords', object), ('item', object)]) self.bounds = [None, None] ## caches data bounds self._maxSpotWidth = 0 ## maximum size of the scale-variant portion of all spots self._maxSpotPxWidth = 0 ## maximum size of the scale-invariant portion of all spots - self._spotPixmap = None - self.opts = {'pxMode': True} + self.opts = {'pxMode': True, 'useCache': True} ## If useCache is False, symbols are re-drawn on every paint. self.setPen(200,200,200, update=False) self.setBrush(100,100,150, update=False) @@ -96,6 +244,8 @@ class ScatterPlotItem(GraphicsObject): prof.mark('setData') prof.finish() + #self.setCacheMode(self.DeviceCoordinateCache) + def setData(self, *args, **kargs): """ **Ordered Arguments:** @@ -130,6 +280,7 @@ class ScatterPlotItem(GraphicsObject): *identical* *Deprecated*. This functionality is handled automatically now. ====================== =============================================================================================== """ + oldData = self.data ## this causes cached pixmaps to be preserved while new data is registered. self.clear() ## clear out all old data self.addPoints(*args, **kargs) @@ -183,8 +334,8 @@ class ScatterPlotItem(GraphicsObject): ## note that np.empty initializes object fields to None and string fields to '' self.data[:len(oldData)] = oldData - for i in range(len(oldData)): - oldData[i]['item']._data = self.data[i] ## Make sure items have proper reference to new array + #for i in range(len(oldData)): + #oldData[i]['item']._data = self.data[i] ## Make sure items have proper reference to new array newData = self.data[len(oldData):] newData['size'] = -1 ## indicates to use default size @@ -217,7 +368,7 @@ class ScatterPlotItem(GraphicsObject): newData['y'] = kargs['y'] if 'pxMode' in kargs: - self.setPxMode(kargs['pxMode'], update=False) + self.setPxMode(kargs['pxMode']) ## Set any extra parameters provided in keyword arguments for k in ['pen', 'brush', 'symbol', 'size']: @@ -228,12 +379,18 @@ class ScatterPlotItem(GraphicsObject): if 'data' in kargs: self.setPointData(kargs['data'], dataSet=newData) - #self.updateSpots() self.prepareGeometryChange() self.bounds = [None, None] - self.generateSpotItems() + self.invalidate() + self.updateSpots(newData) self.sigPlotChanged.emit(self) + def invalidate(self): + ## clear any cached drawing state + self.picture = None + self.fragments = None + self.update() + def getData(self): return self.data['x'], self.data['y'] @@ -263,8 +420,8 @@ class ScatterPlotItem(GraphicsObject): dataSet['pen'] = pens else: self.opts['pen'] = fn.mkPen(*args, **kargs) - self._spotPixmap = None + dataSet['fragCoords'] = None if update: self.updateSpots(dataSet) @@ -285,8 +442,9 @@ class ScatterPlotItem(GraphicsObject): dataSet['brush'] = brushes else: self.opts['brush'] = fn.mkBrush(*args, **kargs) - self._spotPixmap = None + #self._spotPixmap = None + dataSet['fragCoords'] = None if update: self.updateSpots(dataSet) @@ -307,6 +465,7 @@ class ScatterPlotItem(GraphicsObject): self.opts['symbol'] = symbol self._spotPixmap = None + dataSet['fragCoords'] = None if update: self.updateSpots(dataSet) @@ -327,6 +486,7 @@ class ScatterPlotItem(GraphicsObject): self.opts['size'] = size self._spotPixmap = None + dataSet['fragCoords'] = None if update: self.updateSpots(dataSet) @@ -346,34 +506,71 @@ class ScatterPlotItem(GraphicsObject): else: dataSet['data'] = data - def setPxMode(self, mode, update=True): + def setPxMode(self, mode): if self.opts['pxMode'] == mode: return self.opts['pxMode'] = mode - self.clearItems() - if update: - self.generateSpotItems() + self.invalidate() def updateSpots(self, dataSet=None): if dataSet is None: dataSet = self.data self._maxSpotWidth = 0 self._maxSpotPxWidth = 0 - for spot in dataSet['item']: - spot.updateItem() + invalidate = False self.measureSpotSizes(dataSet) + if self.opts['pxMode']: + mask = np.equal(dataSet['fragCoords'], None) + if np.any(mask): + invalidate = True + opts = self.getSpotOpts(dataSet[mask]) + coords = self.fragmentAtlas.getSymbolCoords(opts) + dataSet['fragCoords'][mask] = coords + + #for rec in dataSet: + #if rec['fragCoords'] is None: + #invalidate = True + #rec['fragCoords'] = self.fragmentAtlas.getSymbolCoords(*self.getSpotOpts(rec)) + if invalidate: + self.invalidate() + def getSpotOpts(self, recs): + if recs.ndim == 0: + rec = recs + symbol = rec['symbol'] + if symbol is None: + symbol = self.opts['symbol'] + size = rec['size'] + if size < 0: + size = self.opts['size'] + pen = rec['pen'] + if pen is None: + pen = self.opts['pen'] + brush = rec['brush'] + if brush is None: + brush = self.opts['brush'] + return (symbol, size, fn.mkPen(pen), fn.mkBrush(brush)) + else: + recs = recs.copy() + recs['symbol'][np.equal(recs['symbol'], None)] = self.opts['symbol'] + recs['size'][np.equal(recs['size'], -1)] = self.opts['size'] + recs['pen'][np.equal(recs['pen'], None)] = fn.mkPen(self.opts['pen']) + recs['brush'][np.equal(recs['brush'], None)] = fn.mkBrush(self.opts['brush']) + return recs + + + def measureSpotSizes(self, dataSet): - for spot in dataSet['item']: + for rec in dataSet: ## keep track of the maximum spot size and pixel size + symbol, size, pen, brush = self.getSpotOpts(rec) width = 0 pxWidth = 0 - pen = spot.pen() if self.opts['pxMode']: - pxWidth = spot.size() + pen.width() + pxWidth = size + pen.width() else: - width = spot.size() + width = size if pen.isCosmetic(): pxWidth += pen.width() else: @@ -385,20 +582,11 @@ class ScatterPlotItem(GraphicsObject): def clear(self): """Remove all spots from the scatter plot""" - self.clearItems() + #self.clearItems() self.data = np.empty(0, dtype=self.data.dtype) self.bounds = [None, None] + self.invalidate() - def clearItems(self): - for i in self.data['item']: - if i is None: - continue - i.setParentItem(None) - s = i.scene() - if s is not None: - s.removeItem(i) - self.data['item'] = None - def dataBounds(self, ax, frac=1.0, orthoRange=None): if frac >= 1.0 and self.bounds[ax] is not None: return self.bounds[ax] @@ -436,28 +624,12 @@ class ScatterPlotItem(GraphicsObject): else: return (scipy.stats.scoreatpercentile(d, 50 - (frac * 50)), scipy.stats.scoreatpercentile(d, 50 + (frac * 50))) - - - - - def generateSpotItems(self): - if self.opts['pxMode']: - for rec in self.data: - if rec['item'] is None: - rec['item'] = PixmapSpotItem(rec, self) - else: - for rec in self.data: - if rec['item'] is None: - rec['item'] = PathSpotItem(rec, self) - self.measureSpotSizes(self.data) - self.sigPlotChanged.emit(self) - - def defaultSpotPixmap(self): - ## Return the default spot pixmap - if self._spotPixmap is None: - self._spotPixmap = makeSymbolPixmap(size=self.opts['size'], brush=self.opts['brush'], pen=self.opts['pen'], symbol=self.opts['symbol']) - return self._spotPixmap + #def defaultSpotPixmap(self): + ### Return the default spot pixmap + #if self._spotPixmap is None: + #self._spotPixmap = makeSymbolPixmap(size=self.opts['size'], brush=self.opts['brush'], pen=self.opts['pen'], symbol=self.opts['symbol']) + #return self._spotPixmap def boundingRect(self): (xmn, xmx) = self.dataBounds(ax=0) @@ -474,15 +646,68 @@ class ScatterPlotItem(GraphicsObject): self.prepareGeometryChange() GraphicsObject.viewRangeChanged(self) self.bounds = [None, None] + self.fragments = None + def generateFragments(self): + tr = self.deviceTransform() + if tr is None: + return + pts = np.empty((2,len(self.data['x']))) + pts[0] = self.data['x'] + pts[1] = self.data['y'] + pts = fn.transformCoordinates(tr, pts) + self.fragments = [] + for i in xrange(len(self.data)): + rec = self.data[i] + pos = QtCore.QPointF(pts[0,i], pts[1,i]) + x,y,w,h = rec['fragCoords'] + rect = QtCore.QRectF(y, x, h, w) + self.fragments.append(QtGui.QPainter.PixmapFragment.create(pos, rect)) + def paint(self, p, *args): - ## NOTE: self.paint is disabled by this line in __init__: - ## self.setFlag(self.ItemHasNoContents, True) - p.setPen(fn.mkPen('r')) - p.drawRect(self.boundingRect()) + #p.setPen(fn.mkPen('r')) + #p.drawRect(self.boundingRect()) + if self.opts['pxMode']: + atlas = self.fragmentAtlas.getAtlas() + #arr = fn.imageToArray(atlas.toImage(), copy=True) + #if hasattr(self, 'lastAtlas'): + #if np.any(self.lastAtlas != arr): + #print "Atlas changed:", arr + #self.lastAtlas = arr + + if self.fragments is None: + self.updateSpots() + self.generateFragments() + + p.resetTransform() + + if self.opts['useCache']: + p.drawPixmapFragments(self.fragments, atlas) + else: + for i in range(len(self.data)): + rec = self.data[i] + frag = self.fragments[i] + p.resetTransform() + p.translate(frag.x, frag.y) + drawSymbol(p, *self.getSpotOpts(rec)) + else: + if self.picture is None: + self.picture = QtGui.QPicture() + p2 = QtGui.QPainter(self.picture) + for rec in self.data: + + p2.resetTransform() + p2.translate(rec['x'], rec['y']) + drawSymbol(p2, *self.getSpotOpts(rec)) + p2.end() + + self.picture.play(p) def points(self): + for rec in self.data: + if rec['item'] is None: + rec['item'] = SpotItem(rec, self) return self.data['item'] def pointsAt(self, pos): @@ -506,8 +731,8 @@ class ScatterPlotItem(GraphicsObject): #else: #print "No hit:", (x, y), (sx, sy) #print " ", (sx-s2x, sy-s2y), (sx+s2x, sy+s2y) - pts.sort(lambda a,b: cmp(b.zValue(), a.zValue())) - return pts + #pts.sort(lambda a,b: cmp(b.zValue(), a.zValue())) + return pts[::-1] def mouseClickEvent(self, ev): @@ -524,7 +749,7 @@ class ScatterPlotItem(GraphicsObject): ev.ignore() -class SpotItem(GraphicsItem): +class SpotItem(object): """ Class referring to individual spots in a scatter plot. These can be retrieved by calling ScatterPlotItem.points() or @@ -532,14 +757,12 @@ class SpotItem(GraphicsItem): """ def __init__(self, data, plot): - GraphicsItem.__init__(self, register=False) + #GraphicsItem.__init__(self, register=False) self._data = data self._plot = plot - #self._viewBox = None - #self._viewWidget = None - self.setParentItem(plot) - self.setPos(QtCore.QPointF(data['x'], data['y'])) - self.updateItem() + #self.setParentItem(plot) + #self.setPos(QtCore.QPointF(data['x'], data['y'])) + #self.updateItem() def data(self): """Return the user data associated with this spot.""" @@ -553,6 +776,12 @@ class SpotItem(GraphicsItem): else: return self._data['size'] + def pos(self): + return Point(self._data['x'], self._data['y']) + + def viewPos(self): + return self._plot.mapToView(self.pos()) + def setSize(self, size): """Set the size of this spot. If the size is set to -1, then the ScatterPlotItem's default size @@ -618,37 +847,41 @@ class SpotItem(GraphicsItem): """Set the user-data associated with this spot""" self._data['data'] = data - -class PixmapSpotItem(SpotItem, QtGui.QGraphicsPixmapItem): - def __init__(self, data, plot): - QtGui.QGraphicsPixmapItem.__init__(self) - self.setFlags(self.flags() | self.ItemIgnoresTransformations) - SpotItem.__init__(self, data, plot) - - def setPixmap(self, pixmap): - QtGui.QGraphicsPixmapItem.setPixmap(self, pixmap) - self.setOffset(-pixmap.width()/2.+0.5, -pixmap.height()/2.) - def updateItem(self): - symbolOpts = (self._data['pen'], self._data['brush'], self._data['size'], self._data['symbol']) + self._data['fragCoords'] = None + self._plot.updateSpots([self._data]) + self._plot.invalidate() + +#class PixmapSpotItem(SpotItem, QtGui.QGraphicsPixmapItem): + #def __init__(self, data, plot): + #QtGui.QGraphicsPixmapItem.__init__(self) + #self.setFlags(self.flags() | self.ItemIgnoresTransformations) + #SpotItem.__init__(self, data, plot) + + #def setPixmap(self, pixmap): + #QtGui.QGraphicsPixmapItem.setPixmap(self, pixmap) + #self.setOffset(-pixmap.width()/2.+0.5, -pixmap.height()/2.) + + #def updateItem(self): + #symbolOpts = (self._data['pen'], self._data['brush'], self._data['size'], self._data['symbol']) - ## If all symbol options are default, use default pixmap - if symbolOpts == (None, None, -1, ''): - pixmap = self._plot.defaultSpotPixmap() - else: - pixmap = makeSymbolPixmap(size=self.size(), pen=self.pen(), brush=self.brush(), symbol=self.symbol()) - self.setPixmap(pixmap) + ### If all symbol options are default, use default pixmap + #if symbolOpts == (None, None, -1, ''): + #pixmap = self._plot.defaultSpotPixmap() + #else: + #pixmap = makeSymbolPixmap(size=self.size(), pen=self.pen(), brush=self.brush(), symbol=self.symbol()) + #self.setPixmap(pixmap) -class PathSpotItem(SpotItem, QtGui.QGraphicsPathItem): - def __init__(self, data, plot): - QtGui.QGraphicsPathItem.__init__(self) - SpotItem.__init__(self, data, plot) +#class PathSpotItem(SpotItem, QtGui.QGraphicsPathItem): + #def __init__(self, data, plot): + #QtGui.QGraphicsPathItem.__init__(self) + #SpotItem.__init__(self, data, plot) - def updateItem(self): - QtGui.QGraphicsPathItem.setPath(self, Symbols[self.symbol()]) - QtGui.QGraphicsPathItem.setPen(self, self.pen()) - QtGui.QGraphicsPathItem.setBrush(self, self.brush()) - size = self.size() - self.resetTransform() - self.scale(size, size) + #def updateItem(self): + #QtGui.QGraphicsPathItem.setPath(self, Symbols[self.symbol()]) + #QtGui.QGraphicsPathItem.setPen(self, self.pen()) + #QtGui.QGraphicsPathItem.setBrush(self, self.brush()) + #size = self.size() + #self.resetTransform() + #self.scale(size, size)