Merge pull request #670 from campagnola/scatterplotwidget-updates

Scatterplotwidget updates
This commit is contained in:
Luke Campagnola 2018-04-25 11:56:40 -07:00 committed by GitHub
commit 66c96edbd2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 74 additions and 58 deletions

View File

@ -701,16 +701,12 @@ class ScatterPlotItem(GraphicsObject):
GraphicsObject.setExportMode(self, *args, **kwds) GraphicsObject.setExportMode(self, *args, **kwds)
self.invalidate() self.invalidate()
def mapPointsToDevice(self, pts): def mapPointsToDevice(self, pts):
# Map point locations to device # Map point locations to device
tr = self.deviceTransform() tr = self.deviceTransform()
if tr is None: if tr is None:
return None return None
#pts = np.empty((2,len(self.data['x'])))
#pts[0] = self.data['x']
#pts[1] = self.data['y']
pts = fn.transformCoordinates(tr, pts) pts = fn.transformCoordinates(tr, pts)
pts -= self.data['width'] pts -= self.data['width']
pts = np.clip(pts, -2**30, 2**30) ## prevent Qt segmentation fault. pts = np.clip(pts, -2**30, 2**30) ## prevent Qt segmentation fault.
@ -731,7 +727,6 @@ class ScatterPlotItem(GraphicsObject):
(pts[1] - w < viewBounds.bottom())) ## remove out of view points (pts[1] - w < viewBounds.bottom())) ## remove out of view points
return mask return mask
@debug.warnOnException ## raising an exception here causes crash @debug.warnOnException ## raising an exception here causes crash
def paint(self, p, *args): def paint(self, p, *args):
cmode = self.opts.get('compositionMode', None) cmode = self.opts.get('compositionMode', None)
@ -758,8 +753,6 @@ class ScatterPlotItem(GraphicsObject):
# Cull points that are outside view # Cull points that are outside view
viewMask = self.getViewMask(pts) viewMask = self.getViewMask(pts)
#pts = pts[:,mask]
#data = self.data[mask]
if self.opts['useCache'] and self._exportOpts is False: if self.opts['useCache'] and self._exportOpts is False:
# Draw symbols from pre-rendered atlas # Draw symbols from pre-rendered atlas
@ -804,9 +797,9 @@ class ScatterPlotItem(GraphicsObject):
self.picture.play(p) self.picture.play(p)
def points(self): def points(self):
for rec in self.data: for i,rec in enumerate(self.data):
if rec['item'] is None: if rec['item'] is None:
rec['item'] = SpotItem(rec, self) rec['item'] = SpotItem(rec, self, i)
return self.data['item'] return self.data['item']
def pointsAt(self, pos): def pointsAt(self, pos):
@ -854,16 +847,13 @@ class SpotItem(object):
by connecting to the ScatterPlotItem's click signals. by connecting to the ScatterPlotItem's click signals.
""" """
def __init__(self, data, plot): def __init__(self, data, plot, index):
#GraphicsItem.__init__(self, register=False)
self._data = data self._data = data
self._index = index
# SpotItems are kept in plot.data["items"] numpy object array which # SpotItems are kept in plot.data["items"] numpy object array which
# does not support cyclic garbage collection (numpy issue 6581). # does not support cyclic garbage collection (numpy issue 6581).
# Keeping a strong ref to plot here would leak the cycle # Keeping a strong ref to plot here would leak the cycle
self.__plot_ref = weakref.ref(plot) self.__plot_ref = weakref.ref(plot)
#self.setParentItem(plot)
#self.setPos(QtCore.QPointF(data['x'], data['y']))
#self.updateItem()
@property @property
def _plot(self): def _plot(self):
@ -873,6 +863,10 @@ class SpotItem(object):
"""Return the user data associated with this spot.""" """Return the user data associated with this spot."""
return self._data['data'] return self._data['data']
def index(self):
"""Return the index of this point as given in the scatter plot data."""
return self._index
def size(self): def size(self):
"""Return the size of this spot. """Return the size of this spot.
If the spot has no explicit size set, then return the ScatterPlotItem's default size instead.""" If the spot has no explicit size set, then return the ScatterPlotItem's default size instead."""
@ -956,37 +950,3 @@ class SpotItem(object):
self._data['sourceRect'] = None self._data['sourceRect'] = None
self._plot.updateSpots(self._data.reshape(1)) self._plot.updateSpots(self._data.reshape(1))
self._plot.invalidate() self._plot.invalidate()
#class PixmapSpotItem(SpotItem, QtGui.QGraphicsPixmapItem):
#def __init__(self, data, plot):
#QtGui.QGraphicsPixmapItem.__init__(self)
#self.setFlags(self.flags() | self.ItemIgnoresTransformations)
#SpotItem.__init__(self, data, plot)
#def setPixmap(self, pixmap):
#QtGui.QGraphicsPixmapItem.setPixmap(self, pixmap)
#self.setOffset(-pixmap.width()/2.+0.5, -pixmap.height()/2.)
#def updateItem(self):
#symbolOpts = (self._data['pen'], self._data['brush'], self._data['size'], self._data['symbol'])
### If all symbol options are default, use default pixmap
#if symbolOpts == (None, None, -1, ''):
#pixmap = self._plot.defaultSpotPixmap()
#else:
#pixmap = makeSymbolPixmap(size=self.size(), pen=self.pen(), brush=self.brush(), symbol=self.symbol())
#self.setPixmap(pixmap)
#class PathSpotItem(SpotItem, QtGui.QGraphicsPathItem):
#def __init__(self, data, plot):
#QtGui.QGraphicsPathItem.__init__(self)
#SpotItem.__init__(self, data, plot)
#def updateItem(self):
#QtGui.QGraphicsPathItem.setPath(self, Symbols[self.symbol()])
#QtGui.QGraphicsPathItem.setPen(self, self.pen())
#QtGui.QGraphicsPathItem.setBrush(self, self.brush())
#size = self.size()
#self.resetTransform()
#self.scale(size, size)

View File

@ -52,16 +52,23 @@ class ScatterPlotWidget(QtGui.QSplitter):
self.ctrlPanel.addWidget(self.ptree) self.ctrlPanel.addWidget(self.ptree)
self.addWidget(self.plot) self.addWidget(self.plot)
bg = fn.mkColor(getConfigOption('background')) fg = fn.mkColor(getConfigOption('foreground'))
bg.setAlpha(150) fg.setAlpha(150)
self.filterText = TextItem(border=getConfigOption('foreground'), color=bg) self.filterText = TextItem(border=getConfigOption('foreground'), color=fg)
self.filterText.setPos(60,20) self.filterText.setPos(60,20)
self.filterText.setParentItem(self.plot.plotItem) self.filterText.setParentItem(self.plot.plotItem)
self.data = None self.data = None
self.indices = None
self.mouseOverField = None self.mouseOverField = None
self.scatterPlot = None self.scatterPlot = None
self.selectionScatter = None
self.selectedIndices = []
self.style = dict(pen=None, symbol='o') self.style = dict(pen=None, symbol='o')
self._visibleXY = None # currently plotted points
self._visibleData = None # currently plotted records
self._visibleIndices = None
self._indexMap = None
self.fieldList.itemSelectionChanged.connect(self.fieldSelectionChanged) self.fieldList.itemSelectionChanged.connect(self.fieldSelectionChanged)
self.filter.sigFilterChanged.connect(self.filterChanged) self.filter.sigFilterChanged.connect(self.filterChanged)
@ -102,9 +109,26 @@ class ScatterPlotWidget(QtGui.QSplitter):
Argument must be a numpy record array. Argument must be a numpy record array.
""" """
self.data = data self.data = data
self.indices = np.arange(len(data))
self.filtered = None self.filtered = None
self.filteredIndices = None
self.updatePlot() self.updatePlot()
def setSelectedIndices(self, inds):
"""Mark the specified indices as selected.
Must be a sequence of integers that index into the array given in setData().
"""
self.selectedIndices = inds
self.updateSelected()
def setSelectedPoints(self, points):
"""Mark the specified points as selected.
Must be a list of points as generated by the sigScatterPlotClicked signal.
"""
self.setSelectedIndices([pt.originalIndex for pt in points])
def fieldSelectionChanged(self): def fieldSelectionChanged(self):
sel = self.fieldList.selectedItems() sel = self.fieldList.selectedItems()
if len(sel) > 2: if len(sel) > 2:
@ -129,11 +153,13 @@ class ScatterPlotWidget(QtGui.QSplitter):
def updatePlot(self): def updatePlot(self):
self.plot.clear() self.plot.clear()
if self.data is None: if self.data is None or len(self.data) == 0:
return return
if self.filtered is None: if self.filtered is None:
self.filtered = self.filter.filterData(self.data) mask = self.filter.generateMask(self.data)
self.filtered = self.data[mask]
self.filteredIndices = self.indices[mask]
data = self.filtered data = self.filtered
if len(data) == 0: if len(data) == 0:
return return
@ -194,6 +220,8 @@ class ScatterPlotWidget(QtGui.QSplitter):
xy[0] = xy[0][mask] xy[0] = xy[0][mask]
style['symbolBrush'] = colors[mask] style['symbolBrush'] = colors[mask]
data = data[mask]
indices = self.filteredIndices[mask]
## Scatter y-values for a histogram-like appearance ## Scatter y-values for a histogram-like appearance
if xy[1] is None: if xy[1] is None:
@ -215,16 +243,44 @@ class ScatterPlotWidget(QtGui.QSplitter):
if smax != 0: if smax != 0:
scatter *= 0.2 / smax scatter *= 0.2 / smax
xy[ax][keymask] += scatter xy[ax][keymask] += scatter
if self.scatterPlot is not None: if self.scatterPlot is not None:
try: try:
self.scatterPlot.sigPointsClicked.disconnect(self.plotClicked) self.scatterPlot.sigPointsClicked.disconnect(self.plotClicked)
except: except:
pass pass
self.scatterPlot = self.plot.plot(xy[0], xy[1], data=data[mask], **style)
self.scatterPlot.sigPointsClicked.connect(self.plotClicked)
self._visibleXY = xy
self._visibleData = data
self._visibleIndices = indices
self._indexMap = None
self.scatterPlot = self.plot.plot(xy[0], xy[1], data=data, **style)
self.scatterPlot.sigPointsClicked.connect(self.plotClicked)
self.updateSelected()
def updateSelected(self):
if self._visibleXY is None:
return
# map from global index to visible index
indMap = self._getIndexMap()
inds = [indMap[i] for i in self.selectedIndices if i in indMap]
x,y = self._visibleXY[0][inds], self._visibleXY[1][inds]
if self.selectionScatter is not None:
self.plot.plotItem.removeItem(self.selectionScatter)
if len(x) == 0:
return
self.selectionScatter = self.plot.plot(x, y, pen=None, symbol='s', symbolSize=12, symbolBrush=None, symbolPen='y')
def _getIndexMap(self):
# mapping from original data index to visible point index
if self._indexMap is None:
self._indexMap = {j:i for i,j in enumerate(self._visibleIndices)}
return self._indexMap
def plotClicked(self, plot, points): def plotClicked(self, plot, points):
# Tag each point with its index into the original dataset
for pt in points:
pt.originalIndex = self._visibleIndices[pt.index()]
self.sigScatterPlotClicked.emit(self, points) self.sigScatterPlotClicked.emit(self, points)