Make it easier to track data in and out of scatterplotwidget

This commit is contained in:
Luke Campagnola 2018-03-15 11:59:45 -07:00
parent ee0877170d
commit 96a3d216e2
2 changed files with 71 additions and 55 deletions

View File

@ -701,16 +701,12 @@ class ScatterPlotItem(GraphicsObject):
GraphicsObject.setExportMode(self, *args, **kwds)
self.invalidate()
def mapPointsToDevice(self, pts):
# Map point locations to device
tr = self.deviceTransform()
if tr is None:
return None
#pts = np.empty((2,len(self.data['x'])))
#pts[0] = self.data['x']
#pts[1] = self.data['y']
pts = fn.transformCoordinates(tr, pts)
pts -= self.data['width']
pts = np.clip(pts, -2**30, 2**30) ## prevent Qt segmentation fault.
@ -731,7 +727,6 @@ class ScatterPlotItem(GraphicsObject):
(pts[1] - w < viewBounds.bottom())) ## remove out of view points
return mask
@debug.warnOnException ## raising an exception here causes crash
def paint(self, p, *args):
cmode = self.opts.get('compositionMode', None)
@ -758,8 +753,6 @@ class ScatterPlotItem(GraphicsObject):
# Cull points that are outside view
viewMask = self.getViewMask(pts)
#pts = pts[:,mask]
#data = self.data[mask]
if self.opts['useCache'] and self._exportOpts is False:
# Draw symbols from pre-rendered atlas
@ -804,9 +797,9 @@ class ScatterPlotItem(GraphicsObject):
self.picture.play(p)
def points(self):
for rec in self.data:
for i,rec in enumerate(self.data):
if rec['item'] is None:
rec['item'] = SpotItem(rec, self)
rec['item'] = SpotItem(rec, self, i)
return self.data['item']
def pointsAt(self, pos):
@ -854,18 +847,19 @@ class SpotItem(object):
by connecting to the ScatterPlotItem's click signals.
"""
def __init__(self, data, plot):
#GraphicsItem.__init__(self, register=False)
def __init__(self, data, plot, index):
self._data = data
self._plot = plot
#self.setParentItem(plot)
#self.setPos(QtCore.QPointF(data['x'], data['y']))
#self.updateItem()
self._index = index
def data(self):
"""Return the user data associated with this spot."""
return self._data['data']
def index(self):
"""Return the index of this point as given in the scatter plot data."""
return self._index
def size(self):
"""Return the size of this spot.
If the spot has no explicit size set, then return the ScatterPlotItem's default size instead."""
@ -949,37 +943,3 @@ class SpotItem(object):
self._data['sourceRect'] = None
self._plot.updateSpots(self._data.reshape(1))
self._plot.invalidate()
#class PixmapSpotItem(SpotItem, QtGui.QGraphicsPixmapItem):
#def __init__(self, data, plot):
#QtGui.QGraphicsPixmapItem.__init__(self)
#self.setFlags(self.flags() | self.ItemIgnoresTransformations)
#SpotItem.__init__(self, data, plot)
#def setPixmap(self, pixmap):
#QtGui.QGraphicsPixmapItem.setPixmap(self, pixmap)
#self.setOffset(-pixmap.width()/2.+0.5, -pixmap.height()/2.)
#def updateItem(self):
#symbolOpts = (self._data['pen'], self._data['brush'], self._data['size'], self._data['symbol'])
### If all symbol options are default, use default pixmap
#if symbolOpts == (None, None, -1, ''):
#pixmap = self._plot.defaultSpotPixmap()
#else:
#pixmap = makeSymbolPixmap(size=self.size(), pen=self.pen(), brush=self.brush(), symbol=self.symbol())
#self.setPixmap(pixmap)
#class PathSpotItem(SpotItem, QtGui.QGraphicsPathItem):
#def __init__(self, data, plot):
#QtGui.QGraphicsPathItem.__init__(self)
#SpotItem.__init__(self, data, plot)
#def updateItem(self):
#QtGui.QGraphicsPathItem.setPath(self, Symbols[self.symbol()])
#QtGui.QGraphicsPathItem.setPen(self, self.pen())
#QtGui.QGraphicsPathItem.setBrush(self, self.brush())
#size = self.size()
#self.resetTransform()
#self.scale(size, size)

View File

@ -59,9 +59,16 @@ class ScatterPlotWidget(QtGui.QSplitter):
self.filterText.setParentItem(self.plot.plotItem)
self.data = None
self.indices = None
self.mouseOverField = None
self.scatterPlot = None
self.selectionScatter = None
self.selectedIndices = []
self.style = dict(pen=None, symbol='o')
self._visibleXY = None # currently plotted points
self._visibleData = None # currently plotted records
self._visibleIndices = None
self._indexMap = None
self.fieldList.itemSelectionChanged.connect(self.fieldSelectionChanged)
self.filter.sigFilterChanged.connect(self.filterChanged)
@ -102,9 +109,26 @@ class ScatterPlotWidget(QtGui.QSplitter):
Argument must be a numpy record array.
"""
self.data = data
self.indices = np.arange(len(data))
self.filtered = None
self.filteredIndices = None
self.updatePlot()
def setSelectedIndices(self, inds):
"""Mark the specified indices as selected.
Must be a sequence of integers that index into the array given in setData().
"""
self.selectedIndices = inds
self.updateSelected()
def setSelectedPoints(self, points):
"""Mark the specified points as selected.
Must be a list of points as generated by the sigScatterPlotClicked signal.
"""
self.setSelectedIndices([pt.originalIndex for pt in points])
def fieldSelectionChanged(self):
sel = self.fieldList.selectedItems()
if len(sel) > 2:
@ -129,11 +153,13 @@ class ScatterPlotWidget(QtGui.QSplitter):
def updatePlot(self):
self.plot.clear()
if self.data is None:
if self.data is None or len(self.data) == 0:
return
if self.filtered is None:
self.filtered = self.filter.filterData(self.data)
mask = self.filter.generateMask(self.data)
self.filtered = self.data[mask]
self.filteredIndices = self.indices[mask]
data = self.filtered
if len(data) == 0:
return
@ -194,6 +220,8 @@ class ScatterPlotWidget(QtGui.QSplitter):
xy[0] = xy[0][mask]
style['symbolBrush'] = colors[mask]
data = data[mask]
indices = self.filteredIndices[mask]
## Scatter y-values for a histogram-like appearance
if xy[1] is None:
@ -216,15 +244,43 @@ class ScatterPlotWidget(QtGui.QSplitter):
scatter *= 0.2 / smax
xy[ax][keymask] += scatter
if self.scatterPlot is not None:
try:
self.scatterPlot.sigPointsClicked.disconnect(self.plotClicked)
except:
pass
self.scatterPlot = self.plot.plot(xy[0], xy[1], data=data[mask], **style)
self._visibleXY = xy
self._visibleData = data
self._visibleIndices = indices
self._indexMap = None
self.scatterPlot = self.plot.plot(xy[0], xy[1], data=data, **style)
self.scatterPlot.sigPointsClicked.connect(self.plotClicked)
self.updateSelected()
def updateSelected(self):
if self._visibleXY is None:
return
# map from global index to visible index
indMap = self._getIndexMap()
inds = [indMap[i] for i in self.selectedIndices if i in indMap]
x,y = self._visibleXY[0][inds], self._visibleXY[1][inds]
if self.selectionScatter is not None:
self.plot.plotItem.removeItem(self.selectionScatter)
if len(x) == 0:
return
self.selectionScatter = self.plot.plot(x, y, pen=None, symbol='s', symbolSize=12, symbolBrush=None, symbolPen='y')
def _getIndexMap(self):
# mapping from original data index to visible point index
if self._indexMap is None:
self._indexMap = {j:i for i,j in enumerate(self._visibleIndices)}
return self._indexMap
def plotClicked(self, plot, points):
# Tag each point with its index into the original dataset
for pt in points:
pt.originalIndex = self._visibleIndices[pt.index()]
self.sigScatterPlotClicked.emit(self, points)