From 17409bc9a6b48f9ca8a1764f4a4e4b2013c09d69 Mon Sep 17 00:00:00 2001 From: Luke Campagnola <> Date: Sun, 10 Feb 2013 14:10:30 -0500 Subject: [PATCH] Merge new fixes and features from acq4 --- GraphicsScene/exportDialog.py | 6 + GraphicsScene/exportDialogTemplate.ui | 7 + GraphicsScene/exportDialogTemplate_pyqt.py | 8 +- GraphicsScene/exportDialogTemplate_pyside.py | 8 +- PlotData.py | 55 ++++ colormap.py | 262 +++++++++++++++++++ debug.py | 18 ++ exporters/Exporter.py | 8 +- exporters/ImageExporter.py | 14 +- exporters/SVGExporter.py | 29 +- graphicsItems/AxisItem.py | 5 +- graphicsItems/GradientEditorItem.py | 31 ++- graphicsItems/GraphicsItem.py | 41 ++- graphicsItems/PlotDataItem.py | 17 +- graphicsItems/PlotItem/PlotItem.py | 26 ++ graphicsItems/ScatterPlotItem.py | 18 +- graphicsItems/ViewBox/ViewBox.py | 9 +- parametertree/Parameter.py | 11 +- parametertree/parameterTypes.py | 16 +- rebuildUi.py | 4 +- widgets/ColorButton.py | 10 +- widgets/ColorMapWidget.py | 173 ++++++++++++ widgets/DataFilterWidget.py | 115 ++++++++ widgets/ScatterPlotWidget.py | 183 +++++++++++++ 24 files changed, 1023 insertions(+), 51 deletions(-) create mode 100644 PlotData.py create mode 100644 colormap.py create mode 100644 widgets/ColorMapWidget.py create mode 100644 widgets/DataFilterWidget.py create mode 100644 widgets/ScatterPlotWidget.py diff --git a/GraphicsScene/exportDialog.py b/GraphicsScene/exportDialog.py index dafcd501..73a8c83f 100644 --- a/GraphicsScene/exportDialog.py +++ b/GraphicsScene/exportDialog.py @@ -27,6 +27,7 @@ class ExportDialog(QtGui.QWidget): self.ui.closeBtn.clicked.connect(self.close) self.ui.exportBtn.clicked.connect(self.exportClicked) + self.ui.copyBtn.clicked.connect(self.copyClicked) self.ui.itemTree.currentItemChanged.connect(self.exportItemChanged) self.ui.formatList.currentItemChanged.connect(self.exportFormatChanged) @@ -116,11 +117,16 @@ class ExportDialog(QtGui.QWidget): else: self.ui.paramTree.setParameters(params) self.currentExporter = exp + self.ui.copyBtn.setEnabled(exp.allowCopy) def exportClicked(self): self.selectBox.hide() self.currentExporter.export() + def copyClicked(self): + self.selectBox.hide() + self.currentExporter.export(copy=True) + def close(self): self.selectBox.setVisible(False) self.setVisible(False) diff --git a/GraphicsScene/exportDialogTemplate.ui b/GraphicsScene/exportDialogTemplate.ui index c81c8831..c91fbc3f 100644 --- a/GraphicsScene/exportDialogTemplate.ui +++ b/GraphicsScene/exportDialogTemplate.ui @@ -79,6 +79,13 @@ + + + + Copy + + + diff --git a/GraphicsScene/exportDialogTemplate_pyqt.py b/GraphicsScene/exportDialogTemplate_pyqt.py index 20609b51..c3056d1c 100644 --- a/GraphicsScene/exportDialogTemplate_pyqt.py +++ b/GraphicsScene/exportDialogTemplate_pyqt.py @@ -2,8 +2,8 @@ # Form implementation generated from reading ui file './GraphicsScene/exportDialogTemplate.ui' # -# Created: Sun Sep 9 14:41:31 2012 -# by: PyQt4 UI code generator 4.9.1 +# Created: Wed Jan 30 21:02:28 2013 +# by: PyQt4 UI code generator 4.9.3 # # WARNING! All changes made in this file will be lost! @@ -49,6 +49,9 @@ class Ui_Form(object): self.label_3 = QtGui.QLabel(Form) self.label_3.setObjectName(_fromUtf8("label_3")) self.gridLayout.addWidget(self.label_3, 4, 0, 1, 3) + self.copyBtn = QtGui.QPushButton(Form) + self.copyBtn.setObjectName(_fromUtf8("copyBtn")) + self.gridLayout.addWidget(self.copyBtn, 6, 0, 1, 1) self.retranslateUi(Form) QtCore.QMetaObject.connectSlotsByName(Form) @@ -60,5 +63,6 @@ class Ui_Form(object): self.exportBtn.setText(QtGui.QApplication.translate("Form", "Export", None, QtGui.QApplication.UnicodeUTF8)) self.closeBtn.setText(QtGui.QApplication.translate("Form", "Close", None, QtGui.QApplication.UnicodeUTF8)) self.label_3.setText(QtGui.QApplication.translate("Form", "Export options", None, QtGui.QApplication.UnicodeUTF8)) + self.copyBtn.setText(QtGui.QApplication.translate("Form", "Copy", None, QtGui.QApplication.UnicodeUTF8)) from pyqtgraph.parametertree import ParameterTree diff --git a/GraphicsScene/exportDialogTemplate_pyside.py b/GraphicsScene/exportDialogTemplate_pyside.py index 4ffc0b9a..cf27f60a 100644 --- a/GraphicsScene/exportDialogTemplate_pyside.py +++ b/GraphicsScene/exportDialogTemplate_pyside.py @@ -2,8 +2,8 @@ # Form implementation generated from reading ui file './GraphicsScene/exportDialogTemplate.ui' # -# Created: Sun Sep 9 14:41:31 2012 -# by: pyside-uic 0.2.13 running on PySide 1.1.0 +# Created: Wed Jan 30 21:02:28 2013 +# by: pyside-uic 0.2.13 running on PySide 1.1.1 # # WARNING! All changes made in this file will be lost! @@ -44,6 +44,9 @@ class Ui_Form(object): self.label_3 = QtGui.QLabel(Form) self.label_3.setObjectName("label_3") self.gridLayout.addWidget(self.label_3, 4, 0, 1, 3) + self.copyBtn = QtGui.QPushButton(Form) + self.copyBtn.setObjectName("copyBtn") + self.gridLayout.addWidget(self.copyBtn, 6, 0, 1, 1) self.retranslateUi(Form) QtCore.QMetaObject.connectSlotsByName(Form) @@ -55,5 +58,6 @@ class Ui_Form(object): self.exportBtn.setText(QtGui.QApplication.translate("Form", "Export", None, QtGui.QApplication.UnicodeUTF8)) self.closeBtn.setText(QtGui.QApplication.translate("Form", "Close", None, QtGui.QApplication.UnicodeUTF8)) self.label_3.setText(QtGui.QApplication.translate("Form", "Export options", None, QtGui.QApplication.UnicodeUTF8)) + self.copyBtn.setText(QtGui.QApplication.translate("Form", "Copy", None, QtGui.QApplication.UnicodeUTF8)) from pyqtgraph.parametertree import ParameterTree diff --git a/PlotData.py b/PlotData.py new file mode 100644 index 00000000..18531c14 --- /dev/null +++ b/PlotData.py @@ -0,0 +1,55 @@ + + +class PlotData(object): + """ + Class used for managing plot data + - allows data sharing between multiple graphics items (curve, scatter, graph..) + - each item may define the columns it needs + - column groupings ('pos' or x, y, z) + - efficiently appendable + - log, fft transformations + - color mode conversion (float/byte/qcolor) + - pen/brush conversion + - per-field cached masking + - allows multiple masking fields (different graphics need to mask on different criteria) + - removal of nan/inf values + - option for single value shared by entire column + - cached downsampling + """ + def __init__(self): + self.fields = {} + + self.maxVals = {} ## cache for max/min + self.minVals = {} + + def addFields(self, fields): + for f in fields: + if f not in self.fields: + self.fields[f] = None + + def hasField(self, f): + return f in self.fields + + def __getitem__(self, field): + return self.fields[field] + + def __setitem__(self, field, val): + self.fields[field] = val + + def max(self, field): + mx = self.maxVals.get(field, None) + if mx is None: + mx = np.max(self[field]) + self.maxVals[field] = mx + return mx + + def min(self, field): + mn = self.minVals.get(field, None) + if mn is None: + mn = np.min(self[field]) + self.minVals[field] = mn + return mn + + + + \ No newline at end of file diff --git a/colormap.py b/colormap.py new file mode 100644 index 00000000..c7e683fb --- /dev/null +++ b/colormap.py @@ -0,0 +1,262 @@ +import numpy as np +import scipy.interpolate +from pyqtgraph.Qt import QtGui, QtCore + +class ColorMap(object): + + ## color interpolation modes + RGB = 1 + HSV_POS = 2 + HSV_NEG = 3 + + ## boundary modes + CLIP = 1 + REPEAT = 2 + MIRROR = 3 + + ## return types + BYTE = 1 + FLOAT = 2 + QCOLOR = 3 + + enumMap = { + 'rgb': RGB, + 'hsv+': HSV_POS, + 'hsv-': HSV_NEG, + 'clip': CLIP, + 'repeat': REPEAT, + 'mirror': MIRROR, + 'byte': BYTE, + 'float': FLOAT, + 'qcolor': QCOLOR, + } + + def __init__(self, pos, color, mode=None): + """ + ========= ============================================================== + Arguments + pos Array of positions where each color is defined + color Array of RGBA colors. + Integer data types are interpreted as 0-255; float data types + are interpreted as 0.0-1.0 + mode Array of color modes (ColorMap.RGB, HSV_POS, or HSV_NEG) + indicating the color space that should be used when + interpolating between stops. Note that the last mode value is + ignored. By default, the mode is entirely RGB. + ========= ============================================================== + """ + self.pos = pos + self.color = color + if mode is None: + mode = np.ones(len(pos)) + self.mode = mode + self.stopsCache = {} + + def map(self, data, mode='byte'): + """ + Data must be either a scalar position or an array (any shape) of positions. + """ + if isinstance(mode, basestring): + mode = self.enumMap[mode.lower()] + + if mode == self.QCOLOR: + pos, color = self.getStops(self.BYTE) + else: + pos, color = self.getStops(mode) + + data = np.clip(data, pos.min(), pos.max()) + + if not isinstance(data, np.ndarray): + interp = scipy.interpolate.griddata(pos, color, np.array([data]))[0] + else: + interp = scipy.interpolate.griddata(pos, color, data) + + if mode == self.QCOLOR: + if not isinstance(data, np.ndarray): + return QtGui.QColor(*interp) + else: + return [QtGui.QColor(*x) for x in interp] + else: + return interp + + def mapToQColor(self, data): + return self.map(data, mode=self.QCOLOR) + + def mapToByte(self, data): + return self.map(data, mode=self.BYTE) + + def mapToFloat(self, data): + return self.map(data, mode=self.FLOAT) + + def getGradient(self, p1=None, p2=None): + """Return a QLinearGradient object.""" + if p1 == None: + p1 = QtCore.QPointF(0,0) + if p2 == None: + p2 = QtCore.QPointF(self.pos.max()-self.pos.min(),0) + g = QtGui.QLinearGradient(p1, p2) + + pos, color = self.getStops(mode=self.BYTE) + color = [QtGui.QColor(*x) for x in color] + g.setStops(zip(pos, color)) + + #if self.colorMode == 'rgb': + #ticks = self.listTicks() + #g.setStops([(x, QtGui.QColor(t.color)) for t,x in ticks]) + #elif self.colorMode == 'hsv': ## HSV mode is approximated for display by interpolating 10 points between each stop + #ticks = self.listTicks() + #stops = [] + #stops.append((ticks[0][1], ticks[0][0].color)) + #for i in range(1,len(ticks)): + #x1 = ticks[i-1][1] + #x2 = ticks[i][1] + #dx = (x2-x1) / 10. + #for j in range(1,10): + #x = x1 + dx*j + #stops.append((x, self.getColor(x))) + #stops.append((x2, self.getColor(x2))) + #g.setStops(stops) + return g + + def getColors(self, mode=None): + """Return list of all colors converted to the specified mode. + If mode is None, then no conversion is done.""" + if isinstance(mode, basestring): + mode = self.enumMap[mode.lower()] + + color = self.color + if mode in [self.BYTE, self.QCOLOR] and color.dtype.kind == 'f': + color = (color * 255).astype(np.ubyte) + elif mode == self.FLOAT and color.dtype.kind != 'f': + color = color.astype(float) / 255. + + if mode == self.QCOLOR: + color = [QtGui.QColor(*x) for x in color] + + return color + + def getStops(self, mode): + ## Get fully-expanded set of RGBA stops in either float or byte mode. + if mode not in self.stopsCache: + color = self.color + if mode == self.BYTE and color.dtype.kind == 'f': + color = (color * 255).astype(np.ubyte) + elif mode == self.FLOAT and color.dtype.kind != 'f': + color = color.astype(float) / 255. + + ## to support HSV mode, we need to do a little more work.. + #stops = [] + #for i in range(len(self.pos)): + #pos = self.pos[i] + #color = color[i] + + #imode = self.mode[i] + #if imode == self.RGB: + #stops.append((x,color)) + #else: + #ns = + self.stopsCache[mode] = (self.pos, color) + return self.stopsCache[mode] + + #def getColor(self, x, toQColor=True): + #""" + #Return a color for a given value. + + #============= ================================================================== + #**Arguments** + #x Value (position on gradient) of requested color. + #toQColor If true, returns a QColor object, else returns a (r,g,b,a) tuple. + #============= ================================================================== + #""" + #ticks = self.listTicks() + #if x <= ticks[0][1]: + #c = ticks[0][0].color + #if toQColor: + #return QtGui.QColor(c) # always copy colors before handing them out + #else: + #return (c.red(), c.green(), c.blue(), c.alpha()) + #if x >= ticks[-1][1]: + #c = ticks[-1][0].color + #if toQColor: + #return QtGui.QColor(c) # always copy colors before handing them out + #else: + #return (c.red(), c.green(), c.blue(), c.alpha()) + + #x2 = ticks[0][1] + #for i in range(1,len(ticks)): + #x1 = x2 + #x2 = ticks[i][1] + #if x1 <= x and x2 >= x: + #break + + #dx = (x2-x1) + #if dx == 0: + #f = 0. + #else: + #f = (x-x1) / dx + #c1 = ticks[i-1][0].color + #c2 = ticks[i][0].color + #if self.colorMode == 'rgb': + #r = c1.red() * (1.-f) + c2.red() * f + #g = c1.green() * (1.-f) + c2.green() * f + #b = c1.blue() * (1.-f) + c2.blue() * f + #a = c1.alpha() * (1.-f) + c2.alpha() * f + #if toQColor: + #return QtGui.QColor(int(r), int(g), int(b), int(a)) + #else: + #return (r,g,b,a) + #elif self.colorMode == 'hsv': + #h1,s1,v1,_ = c1.getHsv() + #h2,s2,v2,_ = c2.getHsv() + #h = h1 * (1.-f) + h2 * f + #s = s1 * (1.-f) + s2 * f + #v = v1 * (1.-f) + v2 * f + #c = QtGui.QColor() + #c.setHsv(h,s,v) + #if toQColor: + #return c + #else: + #return (c.red(), c.green(), c.blue(), c.alpha()) + + def getLookupTable(self, start=0.0, stop=1.0, nPts=512, alpha=None, mode='byte'): + """ + Return an RGB(A) lookup table (ndarray). + + ============= ============================================================================ + **Arguments** + nPts The number of points in the returned lookup table. + alpha True, False, or None - Specifies whether or not alpha values are included + in the table. If alpha is None, it will be automatically determined. + ============= ============================================================================ + """ + if isinstance(mode, basestring): + mode = self.enumMap[mode.lower()] + + if alpha is None: + alpha = self.usesAlpha() + + x = np.linspace(start, stop, nPts) + table = self.map(x, mode) + + if not alpha: + return table[:,:3] + else: + return table + + def usesAlpha(self): + """Return True if any stops have an alpha < 255""" + max = 1.0 if self.color.dtype.kind == 'f' else 255 + return np.any(self.color[:,3] != max) + + def isMapTrivial(self): + """Return True if the gradient has exactly two stops in it: black at 0.0 and white at 1.0""" + if len(self.pos) != 2: + return False + if self.pos[0] != 0.0 or self.pos[1] != 1.0: + return False + if self.color.dtype.kind == 'f': + return np.all(self.color == np.array([[0.,0.,0.,1.], [1.,1.,1.,1.]])) + else: + return np.all(self.color == np.array([[0,0,0,255], [255,255,255,255]])) + + diff --git a/debug.py b/debug.py index 7fa169a4..ae2b21ac 100644 --- a/debug.py +++ b/debug.py @@ -917,3 +917,21 @@ def qObjectReport(verbose=False): for t in typs: print(count[t], "\t", t) + +class PrintDetector(object): + def __init__(self): + self.stdout = sys.stdout + sys.stdout = self + + def remove(self): + sys.stdout = self.stdout + + def __del__(self): + self.remove() + + def write(self, x): + self.stdout.write(x) + traceback.print_stack() + + def flush(self): + self.stdout.flush() \ No newline at end of file diff --git a/exporters/Exporter.py b/exporters/Exporter.py index b1a663bc..81930670 100644 --- a/exporters/Exporter.py +++ b/exporters/Exporter.py @@ -9,7 +9,8 @@ class Exporter(object): """ Abstract class used for exporting graphics to file / printer / whatever. """ - + allowCopy = False # subclasses set this to True if they can use the copy buffer + def __init__(self, item): """ Initialize with the item to be exported. @@ -25,10 +26,11 @@ class Exporter(object): """Return the parameters used to configure this exporter.""" raise Exception("Abstract method must be overridden in subclass.") - def export(self, fileName=None, toBytes=False): + def export(self, fileName=None, toBytes=False, copy=False): """ If *fileName* is None, pop-up a file dialog. - If *toString* is True, return a bytes object rather than writing to file. + If *toBytes* is True, return a bytes object rather than writing to file. + If *copy* is True, export to the copy buffer rather than writing to file. """ raise Exception("Abstract method must be overridden in subclass.") diff --git a/exporters/ImageExporter.py b/exporters/ImageExporter.py index cb6cf396..bdb8b9be 100644 --- a/exporters/ImageExporter.py +++ b/exporters/ImageExporter.py @@ -8,6 +8,8 @@ __all__ = ['ImageExporter'] class ImageExporter(Exporter): Name = "Image File (PNG, TIF, JPG, ...)" + allowCopy = True + def __init__(self, item): Exporter.__init__(self, item) tr = self.getTargetRect() @@ -38,8 +40,8 @@ class ImageExporter(Exporter): def parameters(self): return self.params - def export(self, fileName=None): - if fileName is None: + def export(self, fileName=None, toBytes=False, copy=False): + if fileName is None and not toBytes and not copy: filter = ["*."+str(f) for f in QtGui.QImageWriter.supportedImageFormats()] preferred = ['*.png', '*.tif', '*.jpg'] for p in preferred[::-1]: @@ -78,6 +80,12 @@ class ImageExporter(Exporter): finally: self.setExportMode(False) painter.end() - self.png.save(fileName) + + if copy: + QtGui.QApplication.clipboard().setImage(self.png) + elif toBytes: + return self.png + else: + self.png.save(fileName) \ No newline at end of file diff --git a/exporters/SVGExporter.py b/exporters/SVGExporter.py index 587282e0..b284db89 100644 --- a/exporters/SVGExporter.py +++ b/exporters/SVGExporter.py @@ -11,6 +11,8 @@ __all__ = ['SVGExporter'] class SVGExporter(Exporter): Name = "Scalable Vector Graphics (SVG)" + allowCopy=True + def __init__(self, item): Exporter.__init__(self, item) #tr = self.getTargetRect() @@ -37,8 +39,8 @@ class SVGExporter(Exporter): def parameters(self): return self.params - def export(self, fileName=None, toBytes=False): - if toBytes is False and fileName is None: + def export(self, fileName=None, toBytes=False, copy=False): + if toBytes is False and copy is False and fileName is None: self.fileSaveDialog(filter="Scalable Vector Graphics (*.svg)") return #self.svg = QtSvg.QSvgGenerator() @@ -83,11 +85,16 @@ class SVGExporter(Exporter): xml = generateSvg(self.item) if toBytes: - return bytes(xml) + return xml.encode('UTF-8') + elif copy: + md = QtCore.QMimeData() + md.setData('image/svg+xml', QtCore.QByteArray(xml.encode('UTF-8'))) + QtGui.QApplication.clipboard().setMimeData(md) else: with open(fileName, 'w') as fh: fh.write(xml.encode('UTF-8')) + xmlHeader = """\ @@ -148,7 +155,7 @@ def _generateItemSvg(item, nodes=None, root=None): ## ## Both 2 and 3 can be addressed by drawing all items in world coordinates. - + prof = pg.debug.Profiler('generateItemSvg %s' % str(item), disabled=True) if nodes is None: ## nodes maps all node IDs to their XML element. ## this allows us to ensure all elements receive unique names. @@ -170,8 +177,12 @@ def _generateItemSvg(item, nodes=None, root=None): tr = QtGui.QTransform() if isinstance(item, QtGui.QGraphicsScene): xmlStr = "\n\n" - childs = [i for i in item.items() if i.parentItem() is None] doc = xml.parseString(xmlStr) + childs = [i for i in item.items() if i.parentItem() is None] + elif item.__class__.paint == QtGui.QGraphicsItem.paint: + xmlStr = "\n\n" + doc = xml.parseString(xmlStr) + childs = item.childItems() else: childs = item.childItems() tr = itemTransform(item, item.scene()) @@ -223,11 +234,12 @@ def _generateItemSvg(item, nodes=None, root=None): print(doc.toxml()) raise + prof.mark('render') ## Get rid of group transformation matrices by applying ## transformation to inner coordinates correctCoordinates(g1, item) - + prof.mark('correct') ## make sure g1 has the transformation matrix #m = (tr.m11(), tr.m12(), tr.m21(), tr.m22(), tr.m31(), tr.m32()) #g1.setAttribute('transform', "matrix(%f,%f,%f,%f,%f,%f)" % m) @@ -277,6 +289,8 @@ def _generateItemSvg(item, nodes=None, root=None): childGroup = g1.ownerDocument.createElement('g') childGroup.setAttribute('clip-path', 'url(#%s)' % clip) g1.appendChild(childGroup) + prof.mark('clipping') + ## Add all child items as sub-elements. childs.sort(key=lambda c: c.zValue()) for ch in childs: @@ -284,7 +298,8 @@ def _generateItemSvg(item, nodes=None, root=None): if cg is None: continue childGroup.appendChild(cg) ### this isn't quite right--some items draw below their parent (good enough for now) - + prof.mark('children') + prof.finish() return g1 def correctCoordinates(node, item): diff --git a/graphicsItems/AxisItem.py b/graphicsItems/AxisItem.py index d5b09915..9ef64763 100644 --- a/graphicsItems/AxisItem.py +++ b/graphicsItems/AxisItem.py @@ -683,7 +683,7 @@ class AxisItem(GraphicsWidget): if tickPositions[i][j] is None: strings[j] = None - textRects.extend([p.boundingRect(QtCore.QRectF(0, 0, 100, 100), QtCore.Qt.AlignCenter, s) for s in strings if s is not None]) + textRects.extend([p.boundingRect(QtCore.QRectF(0, 0, 100, 100), QtCore.Qt.AlignCenter, str(s)) for s in strings if s is not None]) if i > 0: ## always draw top level ## measure all text, make sure there's enough room if axis == 0: @@ -699,8 +699,9 @@ class AxisItem(GraphicsWidget): #strings = self.tickStrings(values, self.scale, spacing) for j in range(len(strings)): vstr = strings[j] - if vstr is None:## this tick was ignored because it is out of bounds + if vstr is None: ## this tick was ignored because it is out of bounds continue + vstr = str(vstr) x = tickPositions[i][j] textRect = p.boundingRect(QtCore.QRectF(0, 0, 100, 100), QtCore.Qt.AlignCenter, vstr) height = textRect.height() diff --git a/graphicsItems/GradientEditorItem.py b/graphicsItems/GradientEditorItem.py index 3c078ede..5439c731 100644 --- a/graphicsItems/GradientEditorItem.py +++ b/graphicsItems/GradientEditorItem.py @@ -5,6 +5,8 @@ from .GraphicsObject import GraphicsObject from .GraphicsWidget import GraphicsWidget import weakref from pyqtgraph.pgcollections import OrderedDict +from pyqtgraph.colormap import ColorMap + import numpy as np __all__ = ['TickSliderItem', 'GradientEditorItem'] @@ -22,6 +24,9 @@ Gradients = OrderedDict([ ]) + + + class TickSliderItem(GraphicsWidget): ## public class """**Bases:** :class:`GraphicsWidget ` @@ -490,6 +495,18 @@ class GradientEditorItem(TickSliderItem): self.colorMode = cm self.updateGradient() + def colorMap(self): + """Return a ColorMap object representing the current state of the editor.""" + if self.colorMode == 'hsv': + raise NotImplementedError('hsv colormaps not yet supported') + pos = [] + color = [] + for t,x in self.listTicks(): + pos.append(x) + c = t.color + color.append([c.red(), c.green(), c.blue(), c.alpha()]) + return ColorMap(np.array(pos), np.array(color, dtype=np.ubyte)) + def updateGradient(self): #private self.gradient = self.getGradient() @@ -611,7 +628,7 @@ class GradientEditorItem(TickSliderItem): b = c1.blue() * (1.-f) + c2.blue() * f a = c1.alpha() * (1.-f) + c2.alpha() * f if toQColor: - return QtGui.QColor(r, g, b,a) + return QtGui.QColor(int(r), int(g), int(b), int(a)) else: return (r,g,b,a) elif self.colorMode == 'hsv': @@ -751,6 +768,18 @@ class GradientEditorItem(TickSliderItem): self.addTick(t[0], c, finish=False) self.updateGradient() self.sigGradientChangeFinished.emit(self) + + def setColorMap(self, cm): + self.setColorMode('rgb') + for t in list(self.ticks.keys()): + self.removeTick(t, finish=False) + colors = cm.getColors(mode='qcolor') + for i in range(len(cm.pos)): + x = cm.pos[i] + c = colors[i] + self.addTick(x, c, finish=False) + self.updateGradient() + self.sigGradientChangeFinished.emit(self) class Tick(GraphicsObject): diff --git a/graphicsItems/GraphicsItem.py b/graphicsItems/GraphicsItem.py index 75e72177..1795e79e 100644 --- a/graphicsItems/GraphicsItem.py +++ b/graphicsItems/GraphicsItem.py @@ -20,7 +20,7 @@ class FiniteCache(OrderedDict): del self[list(self.keys())[0]] def __getitem__(self, item): - val = dict.__getitem__(self, item) + val = OrderedDict.__getitem__(self, item) del self[item] self[item] = val ## promote this key return val @@ -194,6 +194,10 @@ class GraphicsItem(object): dt = self.deviceTransform() if dt is None: return None, None + + ## Ignore translation. If the translation is much larger than the scale + ## (such as when looking at unix timestamps), we can get floating-point errors. + dt.setMatrix(dt.m11(), dt.m12(), 0, dt.m21(), dt.m22(), 0, 0, 0, 1) ## check local cache if direction is None and dt == self._pixelVectorCache[0]: @@ -213,15 +217,32 @@ class GraphicsItem(object): raise Exception("Cannot compute pixel length for 0-length vector.") ## attempt to re-scale direction vector to fit within the precision of the coordinate system - if direction.x() == 0: - r = abs(dt.m32())/(abs(dt.m12()) + abs(dt.m22())) - #r = 1.0/(abs(dt.m12()) + abs(dt.m22())) - elif direction.y() == 0: - r = abs(dt.m31())/(abs(dt.m11()) + abs(dt.m21())) - #r = 1.0/(abs(dt.m11()) + abs(dt.m21())) - else: - r = ((abs(dt.m32())/(abs(dt.m12()) + abs(dt.m22()))) * (abs(dt.m31())/(abs(dt.m11()) + abs(dt.m21()))))**0.5 - directionr = direction * r + ## Here's the problem: we need to map the vector 'direction' from the item to the device, via transform 'dt'. + ## In some extreme cases, this mapping can fail unless the length of 'direction' is cleverly chosen. + ## Example: + ## dt = [ 1, 0, 2 + ## 0, 2, 1e20 + ## 0, 0, 1 ] + ## Then we map the origin (0,0) and direction (0,1) and get: + ## o' = 2,1e20 + ## d' = 2,1e20 <-- should be 1e20+2, but this can't be represented with a 32-bit float + ## + ## |o' - d'| == 0 <-- this is the problem. + + ## Perhaps the easiest solution is to exclude the transformation column from dt. Does this cause any other problems? + + #if direction.x() == 0: + #r = abs(dt.m32())/(abs(dt.m12()) + abs(dt.m22())) + ##r = 1.0/(abs(dt.m12()) + abs(dt.m22())) + #elif direction.y() == 0: + #r = abs(dt.m31())/(abs(dt.m11()) + abs(dt.m21())) + ##r = 1.0/(abs(dt.m11()) + abs(dt.m21())) + #else: + #r = ((abs(dt.m32())/(abs(dt.m12()) + abs(dt.m22()))) * (abs(dt.m31())/(abs(dt.m11()) + abs(dt.m21()))))**0.5 + #if r == 0: + #r = 1. ## shouldn't need to do this; probably means the math above is wrong? + #directionr = direction * r + directionr = direction ## map direction vector onto device #viewDir = Point(dt.map(directionr) - dt.map(Point(0,0))) diff --git a/graphicsItems/PlotDataItem.py b/graphicsItems/PlotDataItem.py index 8e6162f2..83afbbfe 100644 --- a/graphicsItems/PlotDataItem.py +++ b/graphicsItems/PlotDataItem.py @@ -104,6 +104,7 @@ class PlotDataItem(GraphicsObject): self.yData = None self.xDisp = None self.yDisp = None + self.dataMask = None #self.curves = [] #self.scatters = [] self.curve = PlotCurveItem() @@ -393,6 +394,7 @@ class PlotDataItem(GraphicsObject): scatterArgs[v] = self.opts[k] x,y = self.getData() + scatterArgs['mask'] = self.dataMask if curveArgs['pen'] is not None or (curveArgs['brush'] is not None and curveArgs['fillLevel'] is not None): self.curve.setData(x=x, y=y, **curveArgs) @@ -413,11 +415,15 @@ class PlotDataItem(GraphicsObject): if self.xDisp is None: nanMask = np.isnan(self.xData) | np.isnan(self.yData) | np.isinf(self.xData) | np.isinf(self.yData) if any(nanMask): - x = self.xData[~nanMask] - y = self.yData[~nanMask] + self.dataMask = ~nanMask + x = self.xData[self.dataMask] + y = self.yData[self.dataMask] else: + self.dataMask = None x = self.xData y = self.yData + + ds = self.opts['downsample'] if ds > 1: x = x[::ds] @@ -435,8 +441,11 @@ class PlotDataItem(GraphicsObject): if any(self.opts['logMode']): ## re-check for NANs after log nanMask = np.isinf(x) | np.isinf(y) | np.isnan(x) | np.isnan(y) if any(nanMask): - x = x[~nanMask] - y = y[~nanMask] + self.dataMask = ~nanMask + x = x[self.dataMask] + y = y[self.dataMask] + else: + self.dataMask = None self.xDisp = x self.yDisp = y #print self.yDisp.shape, self.yDisp.min(), self.yDisp.max() diff --git a/graphicsItems/PlotItem/PlotItem.py b/graphicsItems/PlotItem/PlotItem.py index 63b4bf03..3100087a 100644 --- a/graphicsItems/PlotItem/PlotItem.py +++ b/graphicsItems/PlotItem/PlotItem.py @@ -36,6 +36,7 @@ from .. LabelItem import LabelItem from .. LegendItem import LegendItem from .. GraphicsWidget import GraphicsWidget from .. ButtonItem import ButtonItem +from .. InfiniteLine import InfiniteLine from pyqtgraph.WidgetGroup import WidgetGroup __all__ = ['PlotItem'] @@ -548,10 +549,35 @@ class PlotItem(GraphicsWidget): print("PlotItem.addDataItem is deprecated. Use addItem instead.") self.addItem(item, *args) + def listDataItems(self): + """Return a list of all data items (PlotDataItem, PlotCurveItem, ScatterPlotItem, etc) + contained in this PlotItem.""" + return self.dataItems[:] + def addCurve(self, c, params=None): print("PlotItem.addCurve is deprecated. Use addItem instead.") self.addItem(c, params) + def addLine(self, x=None, y=None, z=None, **kwds): + """ + Create an InfiniteLine and add to the plot. + + If *x* is specified, + the line will be vertical. If *y* is specified, the line will be + horizontal. All extra keyword arguments are passed to + :func:`InfiniteLine.__init__() `. + Returns the item created. + """ + angle = 0 if x is None else 90 + pos = x if x is not None else y + line = InfiniteLine(pos, angle, **kwds) + self.addItem(line) + if z is not None: + line.setZValue(z) + return line + + + def removeItem(self, item): """ Remove an item from the internal ViewBox. diff --git a/graphicsItems/ScatterPlotItem.py b/graphicsItems/ScatterPlotItem.py index 45adcf4d..5af82a00 100644 --- a/graphicsItems/ScatterPlotItem.py +++ b/graphicsItems/ScatterPlotItem.py @@ -384,7 +384,7 @@ class ScatterPlotItem(GraphicsObject): for k in ['pen', 'brush', 'symbol', 'size']: if k in kargs: setMethod = getattr(self, 'set' + k[0].upper() + k[1:]) - setMethod(kargs[k], update=False, dataSet=newData) + setMethod(kargs[k], update=False, dataSet=newData, mask=kargs.get('mask', None)) if 'data' in kargs: self.setPointData(kargs['data'], dataSet=newData) @@ -425,6 +425,8 @@ class ScatterPlotItem(GraphicsObject): if len(args) == 1 and (isinstance(args[0], np.ndarray) or isinstance(args[0], list)): pens = args[0] + if kargs['mask'] is not None: + pens = pens[kargs['mask']] if len(pens) != len(dataSet): raise Exception("Number of pens does not match number of points (%d != %d)" % (len(pens), len(dataSet))) dataSet['pen'] = pens @@ -445,6 +447,8 @@ class ScatterPlotItem(GraphicsObject): if len(args) == 1 and (isinstance(args[0], np.ndarray) or isinstance(args[0], list)): brushes = args[0] + if kargs['mask'] is not None: + brushes = brushes[kargs['mask']] if len(brushes) != len(dataSet): raise Exception("Number of brushes does not match number of points (%d != %d)" % (len(brushes), len(dataSet))) #for i in xrange(len(brushes)): @@ -458,7 +462,7 @@ class ScatterPlotItem(GraphicsObject): if update: self.updateSpots(dataSet) - def setSymbol(self, symbol, update=True, dataSet=None): + def setSymbol(self, symbol, update=True, dataSet=None, mask=None): """Set the symbol(s) used to draw each spot. If a list or array is provided, then the symbol for each spot will be set separately. Otherwise, the argument will be used as the default symbol for @@ -468,6 +472,8 @@ class ScatterPlotItem(GraphicsObject): if isinstance(symbol, np.ndarray) or isinstance(symbol, list): symbols = symbol + if kargs['mask'] is not None: + symbols = symbols[kargs['mask']] if len(symbols) != len(dataSet): raise Exception("Number of symbols does not match number of points (%d != %d)" % (len(symbols), len(dataSet))) dataSet['symbol'] = symbols @@ -479,7 +485,7 @@ class ScatterPlotItem(GraphicsObject): if update: self.updateSpots(dataSet) - def setSize(self, size, update=True, dataSet=None): + def setSize(self, size, update=True, dataSet=None, mask=None): """Set the size(s) used to draw each spot. If a list or array is provided, then the size for each spot will be set separately. Otherwise, the argument will be used as the default size for @@ -489,6 +495,8 @@ class ScatterPlotItem(GraphicsObject): if isinstance(size, np.ndarray) or isinstance(size, list): sizes = size + if kargs['mask'] is not None: + sizes = sizes[kargs['mask']] if len(sizes) != len(dataSet): raise Exception("Number of sizes does not match number of points (%d != %d)" % (len(sizes), len(dataSet))) dataSet['size'] = sizes @@ -505,6 +513,8 @@ class ScatterPlotItem(GraphicsObject): dataSet = self.data if isinstance(data, np.ndarray) or isinstance(data, list): + if kargs['mask'] is not None: + data = data[kargs['mask']] if len(data) != len(dataSet): raise Exception("Length of meta data does not match number of points (%d != %d)" % (len(data), len(dataSet))) @@ -881,7 +891,7 @@ class SpotItem(object): def updateItem(self): self._data['fragCoords'] = None - self._plot.updateSpots([self._data]) + self._plot.updateSpots(self._data.reshape(1)) self._plot.invalidate() #class PixmapSpotItem(SpotItem, QtGui.QGraphicsPixmapItem): diff --git a/graphicsItems/ViewBox/ViewBox.py b/graphicsItems/ViewBox/ViewBox.py index 8b4ba2af..37f21182 100644 --- a/graphicsItems/ViewBox/ViewBox.py +++ b/graphicsItems/ViewBox/ViewBox.py @@ -576,9 +576,12 @@ class ViewBox(GraphicsWidget): w2 = (targetRect[ax][1]-targetRect[ax][0]) / 2. childRange[ax] = [x-w2, x+w2] else: - wp = (xr[1] - xr[0]) * 0.02 - childRange[ax][0] -= wp - childRange[ax][1] += wp + l = self.width() if ax==0 else self.height() + if l > 0: + padding = np.clip(1./(l**0.5), 0.02, 0.1) + wp = (xr[1] - xr[0]) * padding + childRange[ax][0] -= wp + childRange[ax][1] += wp targetRect[ax] = childRange[ax] args['xRange' if ax == 0 else 'yRange'] = targetRect[ax] if len(args) == 0: diff --git a/parametertree/Parameter.py b/parametertree/Parameter.py index f7da0dbe..c8e19f16 100644 --- a/parametertree/Parameter.py +++ b/parametertree/Parameter.py @@ -525,8 +525,9 @@ class Parameter(QtCore.QObject): self.removeChild(ch) def children(self): - """Return a list of this parameter's children.""" - ## warning -- this overrides QObject.children + """Return a list of this parameter's children. + Warning: this overrides QObject.children + """ return self.childs[:] def hasChildren(self): @@ -608,13 +609,13 @@ class Parameter(QtCore.QObject): def __getattr__(self, attr): ## Leaving this undocumented because I might like to remove it in the future.. #print type(self), attr - import traceback - traceback.print_stack() - print "Warning: Use of Parameter.subParam is deprecated. Use Parameter.param(name) instead." if 'names' not in self.__dict__: raise AttributeError(attr) if attr in self.names: + import traceback + traceback.print_stack() + print "Warning: Use of Parameter.subParam is deprecated. Use Parameter.param(name) instead." return self.param(attr) else: raise AttributeError(attr) diff --git a/parametertree/parameterTypes.py b/parametertree/parameterTypes.py index 3aab5a6d..84db9f06 100644 --- a/parametertree/parameterTypes.py +++ b/parametertree/parameterTypes.py @@ -4,6 +4,7 @@ from .Parameter import Parameter, registerParameterType from .ParameterItem import ParameterItem from pyqtgraph.widgets.SpinBox import SpinBox from pyqtgraph.widgets.ColorButton import ColorButton +#from pyqtgraph.widgets.GradientWidget import GradientWidget ## creates import loop import pyqtgraph as pg import pyqtgraph.pixmaps as pixmaps import os @@ -61,7 +62,11 @@ class WidgetParameterItem(ParameterItem): w.sigChanging.connect(self.widgetValueChanging) ## update value shown in widget. - self.valueChanged(self, opts['value'], force=True) + if opts.get('value', None) is not None: + self.valueChanged(self, opts['value'], force=True) + else: + ## no starting value was given; use whatever the widget has + self.widgetValueChanged() def makeWidget(self): @@ -125,6 +130,14 @@ class WidgetParameterItem(ParameterItem): w.setValue = w.setColor self.hideWidget = False w.setFlat(True) + elif t == 'colormap': + from pyqtgraph.widgets.GradientWidget import GradientWidget ## need this here to avoid import loop + w = GradientWidget(orientation='bottom') + w.sigChanged = w.sigGradientChangeFinished + w.sigChanging = w.sigGradientChanged + w.value = w.colorMap + w.setValue = w.setColorMap + self.hideWidget = False else: raise Exception("Unknown type '%s'" % asUnicode(t)) return w @@ -294,6 +307,7 @@ registerParameterType('float', SimpleParameter, override=True) registerParameterType('bool', SimpleParameter, override=True) registerParameterType('str', SimpleParameter, override=True) registerParameterType('color', SimpleParameter, override=True) +registerParameterType('colormap', SimpleParameter, override=True) diff --git a/rebuildUi.py b/rebuildUi.py index 92d5991a..1e4cbf9c 100644 --- a/rebuildUi.py +++ b/rebuildUi.py @@ -13,11 +13,11 @@ for path, sd, files in os.walk('.'): ui = os.path.join(path, f) py = os.path.join(path, base + '_pyqt.py') - if os.stat(ui).st_mtime > os.stat(py).st_mtime: + if not os.path.exists(py) or os.stat(ui).st_mtime > os.stat(py).st_mtime: os.system('%s %s > %s' % (pyqtuic, ui, py)) print(py) py = os.path.join(path, base + '_pyside.py') - if os.stat(ui).st_mtime > os.stat(py).st_mtime: + if not os.path.exists(py) or os.stat(ui).st_mtime > os.stat(py).st_mtime: os.system('%s %s > %s' % (pysideuic, ui, py)) print(py) diff --git a/widgets/ColorButton.py b/widgets/ColorButton.py index fafe2ae7..ee91801a 100644 --- a/widgets/ColorButton.py +++ b/widgets/ColorButton.py @@ -77,8 +77,14 @@ class ColorButton(QtGui.QPushButton): def restoreState(self, state): self.setColor(state) - def color(self): - return functions.mkColor(self._color) + def color(self, mode='qcolor'): + color = functions.mkColor(self._color) + if mode == 'qcolor': + return color + elif mode == 'byte': + return (color.red(), color.green(), color.blue(), color.alpha()) + elif mode == 'float': + return (color.red()/255., color.green()/255., color.blue()/255., color.alpha()/255.) def widgetGroupInterface(self): return (self.sigColorChanged, ColorButton.saveState, ColorButton.restoreState) diff --git a/widgets/ColorMapWidget.py b/widgets/ColorMapWidget.py new file mode 100644 index 00000000..69a5e10a --- /dev/null +++ b/widgets/ColorMapWidget.py @@ -0,0 +1,173 @@ +from pyqtgraph.Qt import QtGui, QtCore +import pyqtgraph.parametertree as ptree +import numpy as np +from pyqtgraph.pgcollections import OrderedDict +import pyqtgraph.functions as fn + +__all__ = ['ColorMapWidget'] + +class ColorMapWidget(ptree.ParameterTree): + """ + This class provides a widget allowing the user to customize color mapping + for multi-column data. + """ + + sigColorMapChanged = QtCore.Signal(object) + + def __init__(self): + ptree.ParameterTree.__init__(self, showHeader=False) + + self.params = ColorMapParameter() + self.setParameters(self.params) + self.params.sigTreeStateChanged.connect(self.mapChanged) + + ## wrap a couple methods + self.setFields = self.params.setFields + self.map = self.params.map + + def mapChanged(self): + self.sigColorMapChanged.emit(self) + + +class ColorMapParameter(ptree.types.GroupParameter): + sigColorMapChanged = QtCore.Signal(object) + + def __init__(self): + self.fields = {} + ptree.types.GroupParameter.__init__(self, name='Color Map', addText='Add Mapping..', addList=[]) + self.sigTreeStateChanged.connect(self.mapChanged) + + def mapChanged(self): + self.sigColorMapChanged.emit(self) + + def addNew(self, name): + mode = self.fields[name].get('mode', 'range') + if mode == 'range': + self.addChild(RangeColorMapItem(name, self.fields[name])) + elif mode == 'enum': + self.addChild(EnumColorMapItem(name, self.fields[name])) + + def fieldNames(self): + return self.fields.keys() + + def setFields(self, fields): + self.fields = OrderedDict(fields) + #self.fields = fields + #self.fields.sort() + names = self.fieldNames() + self.setAddList(names) + + def map(self, data, mode='byte'): + colors = np.zeros((len(data),4)) + for item in self.children(): + if not item['Enabled']: + continue + chans = item.param('Channels..') + mask = np.empty((len(data), 4), dtype=bool) + for i,f in enumerate(['Red', 'Green', 'Blue', 'Alpha']): + mask[:,i] = chans[f] + + colors2 = item.map(data) + + op = item['Operation'] + if op == 'Add': + colors[mask] = colors[mask] + colors2[mask] + elif op == 'Multiply': + colors[mask] *= colors2[mask] + elif op == 'Overlay': + a = colors2[:,3:4] + c3 = colors * (1-a) + colors2 * a + c3[:,3:4] = colors[:,3:4] + (1-colors[:,3:4]) * a + colors = c3 + elif op == 'Set': + colors[mask] = colors2[mask] + + + colors = np.clip(colors, 0, 1) + if mode == 'byte': + colors = (colors * 255).astype(np.ubyte) + + return colors + + +class RangeColorMapItem(ptree.types.SimpleParameter): + def __init__(self, name, opts): + self.fieldName = name + units = opts.get('units', '') + ptree.types.SimpleParameter.__init__(self, + name=name, autoIncrementName=True, type='colormap', removable=True, renamable=True, + children=[ + #dict(name="Field", type='list', value=name, values=fields), + dict(name='Min', type='float', value=0.0, suffix=units, siPrefix=True), + dict(name='Max', type='float', value=1.0, suffix=units, siPrefix=True), + dict(name='Operation', type='list', value='Overlay', values=['Overlay', 'Add', 'Multiply', 'Set']), + dict(name='Channels..', type='group', expanded=False, children=[ + dict(name='Red', type='bool', value=True), + dict(name='Green', type='bool', value=True), + dict(name='Blue', type='bool', value=True), + dict(name='Alpha', type='bool', value=True), + ]), + dict(name='Enabled', type='bool', value=True), + dict(name='NaN', type='color'), + ]) + + def map(self, data): + data = data[self.fieldName] + + + + scaled = np.clip((data-self['Min']) / (self['Max']-self['Min']), 0, 1) + cmap = self.value() + colors = cmap.map(scaled, mode='float') + + mask = np.isnan(data) | np.isinf(data) + nanColor = self['NaN'] + nanColor = (nanColor.red()/255., nanColor.green()/255., nanColor.blue()/255., nanColor.alpha()/255.) + colors[mask] = nanColor + + return colors + + +class EnumColorMapItem(ptree.types.GroupParameter): + def __init__(self, name, opts): + self.fieldName = name + vals = opts.get('values', []) + childs = [{'name': v, 'type': 'color'} for v in vals] + ptree.types.GroupParameter.__init__(self, + name=name, autoIncrementName=True, removable=True, renamable=True, + children=[ + dict(name='Values', type='group', children=childs), + dict(name='Operation', type='list', value='Overlay', values=['Overlay', 'Add', 'Multiply', 'Set']), + dict(name='Channels..', type='group', expanded=False, children=[ + dict(name='Red', type='bool', value=True), + dict(name='Green', type='bool', value=True), + dict(name='Blue', type='bool', value=True), + dict(name='Alpha', type='bool', value=True), + ]), + dict(name='Enabled', type='bool', value=True), + dict(name='Default', type='color'), + ]) + + def map(self, data): + data = data[self.fieldName] + colors = np.empty((len(data), 4)) + default = np.array(fn.colorTuple(self['Default'])) / 255. + colors[:] = default + + for v in self.param('Values'): + n = v.name() + mask = data == n + c = np.array(fn.colorTuple(v.value())) / 255. + colors[mask] = c + #scaled = np.clip((data-self['Min']) / (self['Max']-self['Min']), 0, 1) + #cmap = self.value() + #colors = cmap.map(scaled, mode='float') + + #mask = np.isnan(data) | np.isinf(data) + #nanColor = self['NaN'] + #nanColor = (nanColor.red()/255., nanColor.green()/255., nanColor.blue()/255., nanColor.alpha()/255.) + #colors[mask] = nanColor + + return colors + + diff --git a/widgets/DataFilterWidget.py b/widgets/DataFilterWidget.py new file mode 100644 index 00000000..a2e1a7b8 --- /dev/null +++ b/widgets/DataFilterWidget.py @@ -0,0 +1,115 @@ +from pyqtgraph.Qt import QtGui, QtCore +import pyqtgraph.parametertree as ptree +import numpy as np +from pyqtgraph.pgcollections import OrderedDict + +__all__ = ['DataFilterWidget'] + +class DataFilterWidget(ptree.ParameterTree): + """ + This class allows the user to filter multi-column data sets by specifying + multiple criteria + """ + + sigFilterChanged = QtCore.Signal(object) + + def __init__(self): + ptree.ParameterTree.__init__(self, showHeader=False) + self.params = DataFilterParameter() + + self.setParameters(self.params) + self.params.sigTreeStateChanged.connect(self.filterChanged) + + self.setFields = self.params.setFields + self.filterData = self.params.filterData + + def filterChanged(self): + self.sigFilterChanged.emit(self) + + def parameters(self): + return self.params + + +class DataFilterParameter(ptree.types.GroupParameter): + + sigFilterChanged = QtCore.Signal(object) + + def __init__(self): + self.fields = {} + ptree.types.GroupParameter.__init__(self, name='Data Filter', addText='Add filter..', addList=[]) + self.sigTreeStateChanged.connect(self.filterChanged) + + def filterChanged(self): + self.sigFilterChanged.emit(self) + + def addNew(self, name): + mode = self.fields[name].get('mode', 'range') + if mode == 'range': + self.addChild(RangeFilterItem(name, self.fields[name])) + elif mode == 'enum': + self.addChild(EnumFilterItem(name, self.fields[name])) + + + def fieldNames(self): + return self.fields.keys() + + def setFields(self, fields): + self.fields = OrderedDict(fields) + names = self.fieldNames() + self.setAddList(names) + + def filterData(self, data): + if len(data) == 0: + return data + return data[self.generateMask(data)] + + def generateMask(self, data): + mask = np.ones(len(data), dtype=bool) + if len(data) == 0: + return mask + for fp in self: + if fp.value() is False: + continue + mask &= fp.generateMask(data) + #key, mn, mx = fp.fieldName, fp['Min'], fp['Max'] + + #vals = data[key] + #mask &= (vals >= mn) + #mask &= (vals < mx) ## Use inclusive minimum and non-inclusive maximum. This makes it easier to create non-overlapping selections + return mask + +class RangeFilterItem(ptree.types.SimpleParameter): + def __init__(self, name, opts): + self.fieldName = name + units = opts.get('units', '') + ptree.types.SimpleParameter.__init__(self, + name=name, autoIncrementName=True, type='bool', value=True, removable=True, renamable=True, + children=[ + #dict(name="Field", type='list', value=name, values=fields), + dict(name='Min', type='float', value=0.0, suffix=units, siPrefix=True), + dict(name='Max', type='float', value=1.0, suffix=units, siPrefix=True), + ]) + + def generateMask(self, data): + vals = data[self.fieldName] + return (vals >= mn) & (vals < mx) ## Use inclusive minimum and non-inclusive maximum. This makes it easier to create non-overlapping selections + + +class EnumFilterItem(ptree.types.SimpleParameter): + def __init__(self, name, opts): + self.fieldName = name + vals = opts.get('values', []) + childs = [{'name': v, 'type': 'bool', 'value': True} for v in vals] + ptree.types.SimpleParameter.__init__(self, + name=name, autoIncrementName=True, type='bool', value=True, removable=True, renamable=True, + children=childs) + + def generateMask(self, data): + vals = data[self.fieldName] + mask = np.ones(len(data), dtype=bool) + for c in self: + if c.value() is True: + continue + key = c.name() + mask &= vals != key + return mask diff --git a/widgets/ScatterPlotWidget.py b/widgets/ScatterPlotWidget.py new file mode 100644 index 00000000..85f5489a --- /dev/null +++ b/widgets/ScatterPlotWidget.py @@ -0,0 +1,183 @@ +from pyqtgraph.Qt import QtGui, QtCore +from .PlotWidget import PlotWidget +from .DataFilterWidget import DataFilterParameter +from .ColorMapWidget import ColorMapParameter +import pyqtgraph.parametertree as ptree +import pyqtgraph.functions as fn +import numpy as np +from pyqtgraph.pgcollections import OrderedDict + +__all__ = ['ScatterPlotWidget'] + +class ScatterPlotWidget(QtGui.QSplitter): + """ + Given a record array, display a scatter plot of a specific set of data. + This widget includes controls for selecting the columns to plot, + filtering data, and determining symbol color and shape. This widget allows + the user to explore relationships between columns in a record array. + + The widget consists of four components: + + 1) A list of column names from which the user may select 1 or 2 columns + to plot. If one column is selected, the data for that column will be + plotted in a histogram-like manner by using :func:`pseudoScatter() + `. If two columns are selected, then the + scatter plot will be generated with x determined by the first column + that was selected and y by the second. + 2) A DataFilter that allows the user to select a subset of the data by + specifying multiple selection criteria. + 3) A ColorMap that allows the user to determine how points are colored by + specifying multiple criteria. + 4) A PlotWidget for displaying the data. + """ + def __init__(self, parent=None): + QtGui.QSplitter.__init__(self, QtCore.Qt.Horizontal) + self.ctrlPanel = QtGui.QSplitter(QtCore.Qt.Vertical) + self.addWidget(self.ctrlPanel) + self.fieldList = QtGui.QListWidget() + self.fieldList.setSelectionMode(self.fieldList.ExtendedSelection) + self.ptree = ptree.ParameterTree(showHeader=False) + self.filter = DataFilterParameter() + self.colorMap = ColorMapParameter() + self.params = ptree.Parameter.create(name='params', type='group', children=[self.filter, self.colorMap]) + self.ptree.setParameters(self.params, showTop=False) + + self.plot = PlotWidget() + self.ctrlPanel.addWidget(self.fieldList) + self.ctrlPanel.addWidget(self.ptree) + self.addWidget(self.plot) + + self.data = None + self.style = dict(pen=None, symbol='o') + + self.fieldList.itemSelectionChanged.connect(self.fieldSelectionChanged) + self.filter.sigFilterChanged.connect(self.filterChanged) + self.colorMap.sigColorMapChanged.connect(self.updatePlot) + + def setFields(self, fields): + """ + Set the list of field names/units to be processed. + Format is: [(name, units), ...] + """ + self.fields = OrderedDict(fields) + self.fieldList.clear() + for f,opts in fields: + item = QtGui.QListWidgetItem(f) + item.opts = opts + item = self.fieldList.addItem(item) + self.filter.setFields(fields) + self.colorMap.setFields(fields) + + def setData(self, data): + """ + Set the data to be processed and displayed. + Argument must be a numpy record array. + """ + self.data = data + self.filtered = None + self.updatePlot() + + def fieldSelectionChanged(self): + sel = self.fieldList.selectedItems() + if len(sel) > 2: + self.fieldList.blockSignals(True) + try: + for item in sel[1:-1]: + item.setSelected(False) + finally: + self.fieldList.blockSignals(False) + + self.updatePlot() + + def filterChanged(self, f): + self.filtered = None + self.updatePlot() + + def updatePlot(self): + self.plot.clear() + if self.data is None: + return + + if self.filtered is None: + self.filtered = self.filter.filterData(self.data) + data = self.filtered + if len(data) == 0: + return + + colors = np.array([fn.mkBrush(*x) for x in self.colorMap.map(data)]) + + style = self.style.copy() + + ## Look up selected columns and units + sel = list([str(item.text()) for item in self.fieldList.selectedItems()]) + units = list([item.opts.get('units', '') for item in self.fieldList.selectedItems()]) + if len(sel) == 0: + self.plot.setTitle('') + return + + + if len(sel) == 1: + self.plot.setLabels(left=('N', ''), bottom=(sel[0], units[0]), title='') + if len(data) == 0: + return + x = data[sel[0]] + #if x.dtype.kind == 'f': + #mask = ~np.isnan(x) + #else: + #mask = np.ones(len(x), dtype=bool) + #x = x[mask] + #style['symbolBrush'] = colors[mask] + y = None + elif len(sel) == 2: + self.plot.setLabels(left=(sel[1],units[1]), bottom=(sel[0],units[0])) + if len(data) == 0: + return + + xydata = [] + for ax in [0,1]: + d = data[sel[ax]] + ## scatter catecorical values just a bit so they show up better in the scatter plot. + #if sel[ax] in ['MorphologyBSMean', 'MorphologyTDMean', 'FIType']: + #d += np.random.normal(size=len(cells), scale=0.1) + xydata.append(d) + x,y = xydata + #mask = np.ones(len(x), dtype=bool) + #if x.dtype.kind == 'f': + #mask |= ~np.isnan(x) + #if y.dtype.kind == 'f': + #mask |= ~np.isnan(y) + #x = x[mask] + #y = y[mask] + #style['symbolBrush'] = colors[mask] + + ## convert enum-type fields to float, set axis labels + xy = [x,y] + for i in [0,1]: + axis = self.plot.getAxis(['bottom', 'left'][i]) + if xy[i] is not None and xy[i].dtype.kind in ('S', 'O'): + vals = self.fields[sel[i]].get('values', list(set(xy[i]))) + xy[i] = np.array([vals.index(x) if x in vals else None for x in xy[i]], dtype=float) + axis.setTicks([list(enumerate(vals))]) + else: + axis.setTicks(None) # reset to automatic ticking + x,y = xy + + ## mask out any nan values + mask = np.ones(len(x), dtype=bool) + if x.dtype.kind == 'f': + mask &= ~np.isnan(x) + if y is not None and y.dtype.kind == 'f': + mask &= ~np.isnan(y) + x = x[mask] + style['symbolBrush'] = colors[mask] + + ## Scatter y-values for a histogram-like appearance + if y is None: + y = fn.pseudoScatter(x) + else: + y = y[mask] + + + self.plot.plot(x, y, **style) + +