Merge new fixes and features from acq4

This commit is contained in:
Luke Campagnola 2013-02-10 14:10:30 -05:00
parent c7574f9adc
commit 17409bc9a6
24 changed files with 1023 additions and 51 deletions

View File

@ -27,6 +27,7 @@ class ExportDialog(QtGui.QWidget):
self.ui.closeBtn.clicked.connect(self.close) self.ui.closeBtn.clicked.connect(self.close)
self.ui.exportBtn.clicked.connect(self.exportClicked) self.ui.exportBtn.clicked.connect(self.exportClicked)
self.ui.copyBtn.clicked.connect(self.copyClicked)
self.ui.itemTree.currentItemChanged.connect(self.exportItemChanged) self.ui.itemTree.currentItemChanged.connect(self.exportItemChanged)
self.ui.formatList.currentItemChanged.connect(self.exportFormatChanged) self.ui.formatList.currentItemChanged.connect(self.exportFormatChanged)
@ -116,11 +117,16 @@ class ExportDialog(QtGui.QWidget):
else: else:
self.ui.paramTree.setParameters(params) self.ui.paramTree.setParameters(params)
self.currentExporter = exp self.currentExporter = exp
self.ui.copyBtn.setEnabled(exp.allowCopy)
def exportClicked(self): def exportClicked(self):
self.selectBox.hide() self.selectBox.hide()
self.currentExporter.export() self.currentExporter.export()
def copyClicked(self):
self.selectBox.hide()
self.currentExporter.export(copy=True)
def close(self): def close(self):
self.selectBox.setVisible(False) self.selectBox.setVisible(False)
self.setVisible(False) self.setVisible(False)

View File

@ -79,6 +79,13 @@
</property> </property>
</widget> </widget>
</item> </item>
<item row="6" column="0">
<widget class="QPushButton" name="copyBtn">
<property name="text">
<string>Copy</string>
</property>
</widget>
</item>
</layout> </layout>
</widget> </widget>
<customwidgets> <customwidgets>

View File

@ -2,8 +2,8 @@
# Form implementation generated from reading ui file './GraphicsScene/exportDialogTemplate.ui' # Form implementation generated from reading ui file './GraphicsScene/exportDialogTemplate.ui'
# #
# Created: Sun Sep 9 14:41:31 2012 # Created: Wed Jan 30 21:02:28 2013
# by: PyQt4 UI code generator 4.9.1 # by: PyQt4 UI code generator 4.9.3
# #
# WARNING! All changes made in this file will be lost! # WARNING! All changes made in this file will be lost!
@ -49,6 +49,9 @@ class Ui_Form(object):
self.label_3 = QtGui.QLabel(Form) self.label_3 = QtGui.QLabel(Form)
self.label_3.setObjectName(_fromUtf8("label_3")) self.label_3.setObjectName(_fromUtf8("label_3"))
self.gridLayout.addWidget(self.label_3, 4, 0, 1, 3) self.gridLayout.addWidget(self.label_3, 4, 0, 1, 3)
self.copyBtn = QtGui.QPushButton(Form)
self.copyBtn.setObjectName(_fromUtf8("copyBtn"))
self.gridLayout.addWidget(self.copyBtn, 6, 0, 1, 1)
self.retranslateUi(Form) self.retranslateUi(Form)
QtCore.QMetaObject.connectSlotsByName(Form) QtCore.QMetaObject.connectSlotsByName(Form)
@ -60,5 +63,6 @@ class Ui_Form(object):
self.exportBtn.setText(QtGui.QApplication.translate("Form", "Export", None, QtGui.QApplication.UnicodeUTF8)) self.exportBtn.setText(QtGui.QApplication.translate("Form", "Export", None, QtGui.QApplication.UnicodeUTF8))
self.closeBtn.setText(QtGui.QApplication.translate("Form", "Close", None, QtGui.QApplication.UnicodeUTF8)) self.closeBtn.setText(QtGui.QApplication.translate("Form", "Close", None, QtGui.QApplication.UnicodeUTF8))
self.label_3.setText(QtGui.QApplication.translate("Form", "Export options", None, QtGui.QApplication.UnicodeUTF8)) self.label_3.setText(QtGui.QApplication.translate("Form", "Export options", None, QtGui.QApplication.UnicodeUTF8))
self.copyBtn.setText(QtGui.QApplication.translate("Form", "Copy", None, QtGui.QApplication.UnicodeUTF8))
from pyqtgraph.parametertree import ParameterTree from pyqtgraph.parametertree import ParameterTree

View File

@ -2,8 +2,8 @@
# Form implementation generated from reading ui file './GraphicsScene/exportDialogTemplate.ui' # Form implementation generated from reading ui file './GraphicsScene/exportDialogTemplate.ui'
# #
# Created: Sun Sep 9 14:41:31 2012 # Created: Wed Jan 30 21:02:28 2013
# by: pyside-uic 0.2.13 running on PySide 1.1.0 # by: pyside-uic 0.2.13 running on PySide 1.1.1
# #
# WARNING! All changes made in this file will be lost! # WARNING! All changes made in this file will be lost!
@ -44,6 +44,9 @@ class Ui_Form(object):
self.label_3 = QtGui.QLabel(Form) self.label_3 = QtGui.QLabel(Form)
self.label_3.setObjectName("label_3") self.label_3.setObjectName("label_3")
self.gridLayout.addWidget(self.label_3, 4, 0, 1, 3) self.gridLayout.addWidget(self.label_3, 4, 0, 1, 3)
self.copyBtn = QtGui.QPushButton(Form)
self.copyBtn.setObjectName("copyBtn")
self.gridLayout.addWidget(self.copyBtn, 6, 0, 1, 1)
self.retranslateUi(Form) self.retranslateUi(Form)
QtCore.QMetaObject.connectSlotsByName(Form) QtCore.QMetaObject.connectSlotsByName(Form)
@ -55,5 +58,6 @@ class Ui_Form(object):
self.exportBtn.setText(QtGui.QApplication.translate("Form", "Export", None, QtGui.QApplication.UnicodeUTF8)) self.exportBtn.setText(QtGui.QApplication.translate("Form", "Export", None, QtGui.QApplication.UnicodeUTF8))
self.closeBtn.setText(QtGui.QApplication.translate("Form", "Close", None, QtGui.QApplication.UnicodeUTF8)) self.closeBtn.setText(QtGui.QApplication.translate("Form", "Close", None, QtGui.QApplication.UnicodeUTF8))
self.label_3.setText(QtGui.QApplication.translate("Form", "Export options", None, QtGui.QApplication.UnicodeUTF8)) self.label_3.setText(QtGui.QApplication.translate("Form", "Export options", None, QtGui.QApplication.UnicodeUTF8))
self.copyBtn.setText(QtGui.QApplication.translate("Form", "Copy", None, QtGui.QApplication.UnicodeUTF8))
from pyqtgraph.parametertree import ParameterTree from pyqtgraph.parametertree import ParameterTree

55
PlotData.py Normal file
View File

@ -0,0 +1,55 @@
class PlotData(object):
"""
Class used for managing plot data
- allows data sharing between multiple graphics items (curve, scatter, graph..)
- each item may define the columns it needs
- column groupings ('pos' or x, y, z)
- efficiently appendable
- log, fft transformations
- color mode conversion (float/byte/qcolor)
- pen/brush conversion
- per-field cached masking
- allows multiple masking fields (different graphics need to mask on different criteria)
- removal of nan/inf values
- option for single value shared by entire column
- cached downsampling
"""
def __init__(self):
self.fields = {}
self.maxVals = {} ## cache for max/min
self.minVals = {}
def addFields(self, fields):
for f in fields:
if f not in self.fields:
self.fields[f] = None
def hasField(self, f):
return f in self.fields
def __getitem__(self, field):
return self.fields[field]
def __setitem__(self, field, val):
self.fields[field] = val
def max(self, field):
mx = self.maxVals.get(field, None)
if mx is None:
mx = np.max(self[field])
self.maxVals[field] = mx
return mx
def min(self, field):
mn = self.minVals.get(field, None)
if mn is None:
mn = np.min(self[field])
self.minVals[field] = mn
return mn

262
colormap.py Normal file
View File

@ -0,0 +1,262 @@
import numpy as np
import scipy.interpolate
from pyqtgraph.Qt import QtGui, QtCore
class ColorMap(object):
## color interpolation modes
RGB = 1
HSV_POS = 2
HSV_NEG = 3
## boundary modes
CLIP = 1
REPEAT = 2
MIRROR = 3
## return types
BYTE = 1
FLOAT = 2
QCOLOR = 3
enumMap = {
'rgb': RGB,
'hsv+': HSV_POS,
'hsv-': HSV_NEG,
'clip': CLIP,
'repeat': REPEAT,
'mirror': MIRROR,
'byte': BYTE,
'float': FLOAT,
'qcolor': QCOLOR,
}
def __init__(self, pos, color, mode=None):
"""
========= ==============================================================
Arguments
pos Array of positions where each color is defined
color Array of RGBA colors.
Integer data types are interpreted as 0-255; float data types
are interpreted as 0.0-1.0
mode Array of color modes (ColorMap.RGB, HSV_POS, or HSV_NEG)
indicating the color space that should be used when
interpolating between stops. Note that the last mode value is
ignored. By default, the mode is entirely RGB.
========= ==============================================================
"""
self.pos = pos
self.color = color
if mode is None:
mode = np.ones(len(pos))
self.mode = mode
self.stopsCache = {}
def map(self, data, mode='byte'):
"""
Data must be either a scalar position or an array (any shape) of positions.
"""
if isinstance(mode, basestring):
mode = self.enumMap[mode.lower()]
if mode == self.QCOLOR:
pos, color = self.getStops(self.BYTE)
else:
pos, color = self.getStops(mode)
data = np.clip(data, pos.min(), pos.max())
if not isinstance(data, np.ndarray):
interp = scipy.interpolate.griddata(pos, color, np.array([data]))[0]
else:
interp = scipy.interpolate.griddata(pos, color, data)
if mode == self.QCOLOR:
if not isinstance(data, np.ndarray):
return QtGui.QColor(*interp)
else:
return [QtGui.QColor(*x) for x in interp]
else:
return interp
def mapToQColor(self, data):
return self.map(data, mode=self.QCOLOR)
def mapToByte(self, data):
return self.map(data, mode=self.BYTE)
def mapToFloat(self, data):
return self.map(data, mode=self.FLOAT)
def getGradient(self, p1=None, p2=None):
"""Return a QLinearGradient object."""
if p1 == None:
p1 = QtCore.QPointF(0,0)
if p2 == None:
p2 = QtCore.QPointF(self.pos.max()-self.pos.min(),0)
g = QtGui.QLinearGradient(p1, p2)
pos, color = self.getStops(mode=self.BYTE)
color = [QtGui.QColor(*x) for x in color]
g.setStops(zip(pos, color))
#if self.colorMode == 'rgb':
#ticks = self.listTicks()
#g.setStops([(x, QtGui.QColor(t.color)) for t,x in ticks])
#elif self.colorMode == 'hsv': ## HSV mode is approximated for display by interpolating 10 points between each stop
#ticks = self.listTicks()
#stops = []
#stops.append((ticks[0][1], ticks[0][0].color))
#for i in range(1,len(ticks)):
#x1 = ticks[i-1][1]
#x2 = ticks[i][1]
#dx = (x2-x1) / 10.
#for j in range(1,10):
#x = x1 + dx*j
#stops.append((x, self.getColor(x)))
#stops.append((x2, self.getColor(x2)))
#g.setStops(stops)
return g
def getColors(self, mode=None):
"""Return list of all colors converted to the specified mode.
If mode is None, then no conversion is done."""
if isinstance(mode, basestring):
mode = self.enumMap[mode.lower()]
color = self.color
if mode in [self.BYTE, self.QCOLOR] and color.dtype.kind == 'f':
color = (color * 255).astype(np.ubyte)
elif mode == self.FLOAT and color.dtype.kind != 'f':
color = color.astype(float) / 255.
if mode == self.QCOLOR:
color = [QtGui.QColor(*x) for x in color]
return color
def getStops(self, mode):
## Get fully-expanded set of RGBA stops in either float or byte mode.
if mode not in self.stopsCache:
color = self.color
if mode == self.BYTE and color.dtype.kind == 'f':
color = (color * 255).astype(np.ubyte)
elif mode == self.FLOAT and color.dtype.kind != 'f':
color = color.astype(float) / 255.
## to support HSV mode, we need to do a little more work..
#stops = []
#for i in range(len(self.pos)):
#pos = self.pos[i]
#color = color[i]
#imode = self.mode[i]
#if imode == self.RGB:
#stops.append((x,color))
#else:
#ns =
self.stopsCache[mode] = (self.pos, color)
return self.stopsCache[mode]
#def getColor(self, x, toQColor=True):
#"""
#Return a color for a given value.
#============= ==================================================================
#**Arguments**
#x Value (position on gradient) of requested color.
#toQColor If true, returns a QColor object, else returns a (r,g,b,a) tuple.
#============= ==================================================================
#"""
#ticks = self.listTicks()
#if x <= ticks[0][1]:
#c = ticks[0][0].color
#if toQColor:
#return QtGui.QColor(c) # always copy colors before handing them out
#else:
#return (c.red(), c.green(), c.blue(), c.alpha())
#if x >= ticks[-1][1]:
#c = ticks[-1][0].color
#if toQColor:
#return QtGui.QColor(c) # always copy colors before handing them out
#else:
#return (c.red(), c.green(), c.blue(), c.alpha())
#x2 = ticks[0][1]
#for i in range(1,len(ticks)):
#x1 = x2
#x2 = ticks[i][1]
#if x1 <= x and x2 >= x:
#break
#dx = (x2-x1)
#if dx == 0:
#f = 0.
#else:
#f = (x-x1) / dx
#c1 = ticks[i-1][0].color
#c2 = ticks[i][0].color
#if self.colorMode == 'rgb':
#r = c1.red() * (1.-f) + c2.red() * f
#g = c1.green() * (1.-f) + c2.green() * f
#b = c1.blue() * (1.-f) + c2.blue() * f
#a = c1.alpha() * (1.-f) + c2.alpha() * f
#if toQColor:
#return QtGui.QColor(int(r), int(g), int(b), int(a))
#else:
#return (r,g,b,a)
#elif self.colorMode == 'hsv':
#h1,s1,v1,_ = c1.getHsv()
#h2,s2,v2,_ = c2.getHsv()
#h = h1 * (1.-f) + h2 * f
#s = s1 * (1.-f) + s2 * f
#v = v1 * (1.-f) + v2 * f
#c = QtGui.QColor()
#c.setHsv(h,s,v)
#if toQColor:
#return c
#else:
#return (c.red(), c.green(), c.blue(), c.alpha())
def getLookupTable(self, start=0.0, stop=1.0, nPts=512, alpha=None, mode='byte'):
"""
Return an RGB(A) lookup table (ndarray).
============= ============================================================================
**Arguments**
nPts The number of points in the returned lookup table.
alpha True, False, or None - Specifies whether or not alpha values are included
in the table. If alpha is None, it will be automatically determined.
============= ============================================================================
"""
if isinstance(mode, basestring):
mode = self.enumMap[mode.lower()]
if alpha is None:
alpha = self.usesAlpha()
x = np.linspace(start, stop, nPts)
table = self.map(x, mode)
if not alpha:
return table[:,:3]
else:
return table
def usesAlpha(self):
"""Return True if any stops have an alpha < 255"""
max = 1.0 if self.color.dtype.kind == 'f' else 255
return np.any(self.color[:,3] != max)
def isMapTrivial(self):
"""Return True if the gradient has exactly two stops in it: black at 0.0 and white at 1.0"""
if len(self.pos) != 2:
return False
if self.pos[0] != 0.0 or self.pos[1] != 1.0:
return False
if self.color.dtype.kind == 'f':
return np.all(self.color == np.array([[0.,0.,0.,1.], [1.,1.,1.,1.]]))
else:
return np.all(self.color == np.array([[0,0,0,255], [255,255,255,255]]))

View File

@ -917,3 +917,21 @@ def qObjectReport(verbose=False):
for t in typs: for t in typs:
print(count[t], "\t", t) print(count[t], "\t", t)
class PrintDetector(object):
def __init__(self):
self.stdout = sys.stdout
sys.stdout = self
def remove(self):
sys.stdout = self.stdout
def __del__(self):
self.remove()
def write(self, x):
self.stdout.write(x)
traceback.print_stack()
def flush(self):
self.stdout.flush()

View File

@ -9,7 +9,8 @@ class Exporter(object):
""" """
Abstract class used for exporting graphics to file / printer / whatever. Abstract class used for exporting graphics to file / printer / whatever.
""" """
allowCopy = False # subclasses set this to True if they can use the copy buffer
def __init__(self, item): def __init__(self, item):
""" """
Initialize with the item to be exported. Initialize with the item to be exported.
@ -25,10 +26,11 @@ class Exporter(object):
"""Return the parameters used to configure this exporter.""" """Return the parameters used to configure this exporter."""
raise Exception("Abstract method must be overridden in subclass.") raise Exception("Abstract method must be overridden in subclass.")
def export(self, fileName=None, toBytes=False): def export(self, fileName=None, toBytes=False, copy=False):
""" """
If *fileName* is None, pop-up a file dialog. If *fileName* is None, pop-up a file dialog.
If *toString* is True, return a bytes object rather than writing to file. If *toBytes* is True, return a bytes object rather than writing to file.
If *copy* is True, export to the copy buffer rather than writing to file.
""" """
raise Exception("Abstract method must be overridden in subclass.") raise Exception("Abstract method must be overridden in subclass.")

View File

@ -8,6 +8,8 @@ __all__ = ['ImageExporter']
class ImageExporter(Exporter): class ImageExporter(Exporter):
Name = "Image File (PNG, TIF, JPG, ...)" Name = "Image File (PNG, TIF, JPG, ...)"
allowCopy = True
def __init__(self, item): def __init__(self, item):
Exporter.__init__(self, item) Exporter.__init__(self, item)
tr = self.getTargetRect() tr = self.getTargetRect()
@ -38,8 +40,8 @@ class ImageExporter(Exporter):
def parameters(self): def parameters(self):
return self.params return self.params
def export(self, fileName=None): def export(self, fileName=None, toBytes=False, copy=False):
if fileName is None: if fileName is None and not toBytes and not copy:
filter = ["*."+str(f) for f in QtGui.QImageWriter.supportedImageFormats()] filter = ["*."+str(f) for f in QtGui.QImageWriter.supportedImageFormats()]
preferred = ['*.png', '*.tif', '*.jpg'] preferred = ['*.png', '*.tif', '*.jpg']
for p in preferred[::-1]: for p in preferred[::-1]:
@ -78,6 +80,12 @@ class ImageExporter(Exporter):
finally: finally:
self.setExportMode(False) self.setExportMode(False)
painter.end() painter.end()
self.png.save(fileName)
if copy:
QtGui.QApplication.clipboard().setImage(self.png)
elif toBytes:
return self.png
else:
self.png.save(fileName)

View File

@ -11,6 +11,8 @@ __all__ = ['SVGExporter']
class SVGExporter(Exporter): class SVGExporter(Exporter):
Name = "Scalable Vector Graphics (SVG)" Name = "Scalable Vector Graphics (SVG)"
allowCopy=True
def __init__(self, item): def __init__(self, item):
Exporter.__init__(self, item) Exporter.__init__(self, item)
#tr = self.getTargetRect() #tr = self.getTargetRect()
@ -37,8 +39,8 @@ class SVGExporter(Exporter):
def parameters(self): def parameters(self):
return self.params return self.params
def export(self, fileName=None, toBytes=False): def export(self, fileName=None, toBytes=False, copy=False):
if toBytes is False and fileName is None: if toBytes is False and copy is False and fileName is None:
self.fileSaveDialog(filter="Scalable Vector Graphics (*.svg)") self.fileSaveDialog(filter="Scalable Vector Graphics (*.svg)")
return return
#self.svg = QtSvg.QSvgGenerator() #self.svg = QtSvg.QSvgGenerator()
@ -83,11 +85,16 @@ class SVGExporter(Exporter):
xml = generateSvg(self.item) xml = generateSvg(self.item)
if toBytes: if toBytes:
return bytes(xml) return xml.encode('UTF-8')
elif copy:
md = QtCore.QMimeData()
md.setData('image/svg+xml', QtCore.QByteArray(xml.encode('UTF-8')))
QtGui.QApplication.clipboard().setMimeData(md)
else: else:
with open(fileName, 'w') as fh: with open(fileName, 'w') as fh:
fh.write(xml.encode('UTF-8')) fh.write(xml.encode('UTF-8'))
xmlHeader = """\ xmlHeader = """\
<?xml version="1.0" encoding="UTF-8" standalone="no"?> <?xml version="1.0" encoding="UTF-8" standalone="no"?>
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" version="1.2" baseProfile="tiny"> <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" version="1.2" baseProfile="tiny">
@ -148,7 +155,7 @@ def _generateItemSvg(item, nodes=None, root=None):
## ##
## Both 2 and 3 can be addressed by drawing all items in world coordinates. ## Both 2 and 3 can be addressed by drawing all items in world coordinates.
prof = pg.debug.Profiler('generateItemSvg %s' % str(item), disabled=True)
if nodes is None: ## nodes maps all node IDs to their XML element. if nodes is None: ## nodes maps all node IDs to their XML element.
## this allows us to ensure all elements receive unique names. ## this allows us to ensure all elements receive unique names.
@ -170,8 +177,12 @@ def _generateItemSvg(item, nodes=None, root=None):
tr = QtGui.QTransform() tr = QtGui.QTransform()
if isinstance(item, QtGui.QGraphicsScene): if isinstance(item, QtGui.QGraphicsScene):
xmlStr = "<g>\n</g>\n" xmlStr = "<g>\n</g>\n"
childs = [i for i in item.items() if i.parentItem() is None]
doc = xml.parseString(xmlStr) doc = xml.parseString(xmlStr)
childs = [i for i in item.items() if i.parentItem() is None]
elif item.__class__.paint == QtGui.QGraphicsItem.paint:
xmlStr = "<g>\n</g>\n"
doc = xml.parseString(xmlStr)
childs = item.childItems()
else: else:
childs = item.childItems() childs = item.childItems()
tr = itemTransform(item, item.scene()) tr = itemTransform(item, item.scene())
@ -223,11 +234,12 @@ def _generateItemSvg(item, nodes=None, root=None):
print(doc.toxml()) print(doc.toxml())
raise raise
prof.mark('render')
## Get rid of group transformation matrices by applying ## Get rid of group transformation matrices by applying
## transformation to inner coordinates ## transformation to inner coordinates
correctCoordinates(g1, item) correctCoordinates(g1, item)
prof.mark('correct')
## make sure g1 has the transformation matrix ## make sure g1 has the transformation matrix
#m = (tr.m11(), tr.m12(), tr.m21(), tr.m22(), tr.m31(), tr.m32()) #m = (tr.m11(), tr.m12(), tr.m21(), tr.m22(), tr.m31(), tr.m32())
#g1.setAttribute('transform', "matrix(%f,%f,%f,%f,%f,%f)" % m) #g1.setAttribute('transform', "matrix(%f,%f,%f,%f,%f,%f)" % m)
@ -277,6 +289,8 @@ def _generateItemSvg(item, nodes=None, root=None):
childGroup = g1.ownerDocument.createElement('g') childGroup = g1.ownerDocument.createElement('g')
childGroup.setAttribute('clip-path', 'url(#%s)' % clip) childGroup.setAttribute('clip-path', 'url(#%s)' % clip)
g1.appendChild(childGroup) g1.appendChild(childGroup)
prof.mark('clipping')
## Add all child items as sub-elements. ## Add all child items as sub-elements.
childs.sort(key=lambda c: c.zValue()) childs.sort(key=lambda c: c.zValue())
for ch in childs: for ch in childs:
@ -284,7 +298,8 @@ def _generateItemSvg(item, nodes=None, root=None):
if cg is None: if cg is None:
continue continue
childGroup.appendChild(cg) ### this isn't quite right--some items draw below their parent (good enough for now) childGroup.appendChild(cg) ### this isn't quite right--some items draw below their parent (good enough for now)
prof.mark('children')
prof.finish()
return g1 return g1
def correctCoordinates(node, item): def correctCoordinates(node, item):

View File

@ -683,7 +683,7 @@ class AxisItem(GraphicsWidget):
if tickPositions[i][j] is None: if tickPositions[i][j] is None:
strings[j] = None strings[j] = None
textRects.extend([p.boundingRect(QtCore.QRectF(0, 0, 100, 100), QtCore.Qt.AlignCenter, s) for s in strings if s is not None]) textRects.extend([p.boundingRect(QtCore.QRectF(0, 0, 100, 100), QtCore.Qt.AlignCenter, str(s)) for s in strings if s is not None])
if i > 0: ## always draw top level if i > 0: ## always draw top level
## measure all text, make sure there's enough room ## measure all text, make sure there's enough room
if axis == 0: if axis == 0:
@ -699,8 +699,9 @@ class AxisItem(GraphicsWidget):
#strings = self.tickStrings(values, self.scale, spacing) #strings = self.tickStrings(values, self.scale, spacing)
for j in range(len(strings)): for j in range(len(strings)):
vstr = strings[j] vstr = strings[j]
if vstr is None:## this tick was ignored because it is out of bounds if vstr is None: ## this tick was ignored because it is out of bounds
continue continue
vstr = str(vstr)
x = tickPositions[i][j] x = tickPositions[i][j]
textRect = p.boundingRect(QtCore.QRectF(0, 0, 100, 100), QtCore.Qt.AlignCenter, vstr) textRect = p.boundingRect(QtCore.QRectF(0, 0, 100, 100), QtCore.Qt.AlignCenter, vstr)
height = textRect.height() height = textRect.height()

View File

@ -5,6 +5,8 @@ from .GraphicsObject import GraphicsObject
from .GraphicsWidget import GraphicsWidget from .GraphicsWidget import GraphicsWidget
import weakref import weakref
from pyqtgraph.pgcollections import OrderedDict from pyqtgraph.pgcollections import OrderedDict
from pyqtgraph.colormap import ColorMap
import numpy as np import numpy as np
__all__ = ['TickSliderItem', 'GradientEditorItem'] __all__ = ['TickSliderItem', 'GradientEditorItem']
@ -22,6 +24,9 @@ Gradients = OrderedDict([
]) ])
class TickSliderItem(GraphicsWidget): class TickSliderItem(GraphicsWidget):
## public class ## public class
"""**Bases:** :class:`GraphicsWidget <pyqtgraph.GraphicsWidget>` """**Bases:** :class:`GraphicsWidget <pyqtgraph.GraphicsWidget>`
@ -490,6 +495,18 @@ class GradientEditorItem(TickSliderItem):
self.colorMode = cm self.colorMode = cm
self.updateGradient() self.updateGradient()
def colorMap(self):
"""Return a ColorMap object representing the current state of the editor."""
if self.colorMode == 'hsv':
raise NotImplementedError('hsv colormaps not yet supported')
pos = []
color = []
for t,x in self.listTicks():
pos.append(x)
c = t.color
color.append([c.red(), c.green(), c.blue(), c.alpha()])
return ColorMap(np.array(pos), np.array(color, dtype=np.ubyte))
def updateGradient(self): def updateGradient(self):
#private #private
self.gradient = self.getGradient() self.gradient = self.getGradient()
@ -611,7 +628,7 @@ class GradientEditorItem(TickSliderItem):
b = c1.blue() * (1.-f) + c2.blue() * f b = c1.blue() * (1.-f) + c2.blue() * f
a = c1.alpha() * (1.-f) + c2.alpha() * f a = c1.alpha() * (1.-f) + c2.alpha() * f
if toQColor: if toQColor:
return QtGui.QColor(r, g, b,a) return QtGui.QColor(int(r), int(g), int(b), int(a))
else: else:
return (r,g,b,a) return (r,g,b,a)
elif self.colorMode == 'hsv': elif self.colorMode == 'hsv':
@ -751,6 +768,18 @@ class GradientEditorItem(TickSliderItem):
self.addTick(t[0], c, finish=False) self.addTick(t[0], c, finish=False)
self.updateGradient() self.updateGradient()
self.sigGradientChangeFinished.emit(self) self.sigGradientChangeFinished.emit(self)
def setColorMap(self, cm):
self.setColorMode('rgb')
for t in list(self.ticks.keys()):
self.removeTick(t, finish=False)
colors = cm.getColors(mode='qcolor')
for i in range(len(cm.pos)):
x = cm.pos[i]
c = colors[i]
self.addTick(x, c, finish=False)
self.updateGradient()
self.sigGradientChangeFinished.emit(self)
class Tick(GraphicsObject): class Tick(GraphicsObject):

View File

@ -20,7 +20,7 @@ class FiniteCache(OrderedDict):
del self[list(self.keys())[0]] del self[list(self.keys())[0]]
def __getitem__(self, item): def __getitem__(self, item):
val = dict.__getitem__(self, item) val = OrderedDict.__getitem__(self, item)
del self[item] del self[item]
self[item] = val ## promote this key self[item] = val ## promote this key
return val return val
@ -194,6 +194,10 @@ class GraphicsItem(object):
dt = self.deviceTransform() dt = self.deviceTransform()
if dt is None: if dt is None:
return None, None return None, None
## Ignore translation. If the translation is much larger than the scale
## (such as when looking at unix timestamps), we can get floating-point errors.
dt.setMatrix(dt.m11(), dt.m12(), 0, dt.m21(), dt.m22(), 0, 0, 0, 1)
## check local cache ## check local cache
if direction is None and dt == self._pixelVectorCache[0]: if direction is None and dt == self._pixelVectorCache[0]:
@ -213,15 +217,32 @@ class GraphicsItem(object):
raise Exception("Cannot compute pixel length for 0-length vector.") raise Exception("Cannot compute pixel length for 0-length vector.")
## attempt to re-scale direction vector to fit within the precision of the coordinate system ## attempt to re-scale direction vector to fit within the precision of the coordinate system
if direction.x() == 0: ## Here's the problem: we need to map the vector 'direction' from the item to the device, via transform 'dt'.
r = abs(dt.m32())/(abs(dt.m12()) + abs(dt.m22())) ## In some extreme cases, this mapping can fail unless the length of 'direction' is cleverly chosen.
#r = 1.0/(abs(dt.m12()) + abs(dt.m22())) ## Example:
elif direction.y() == 0: ## dt = [ 1, 0, 2
r = abs(dt.m31())/(abs(dt.m11()) + abs(dt.m21())) ## 0, 2, 1e20
#r = 1.0/(abs(dt.m11()) + abs(dt.m21())) ## 0, 0, 1 ]
else: ## Then we map the origin (0,0) and direction (0,1) and get:
r = ((abs(dt.m32())/(abs(dt.m12()) + abs(dt.m22()))) * (abs(dt.m31())/(abs(dt.m11()) + abs(dt.m21()))))**0.5 ## o' = 2,1e20
directionr = direction * r ## d' = 2,1e20 <-- should be 1e20+2, but this can't be represented with a 32-bit float
##
## |o' - d'| == 0 <-- this is the problem.
## Perhaps the easiest solution is to exclude the transformation column from dt. Does this cause any other problems?
#if direction.x() == 0:
#r = abs(dt.m32())/(abs(dt.m12()) + abs(dt.m22()))
##r = 1.0/(abs(dt.m12()) + abs(dt.m22()))
#elif direction.y() == 0:
#r = abs(dt.m31())/(abs(dt.m11()) + abs(dt.m21()))
##r = 1.0/(abs(dt.m11()) + abs(dt.m21()))
#else:
#r = ((abs(dt.m32())/(abs(dt.m12()) + abs(dt.m22()))) * (abs(dt.m31())/(abs(dt.m11()) + abs(dt.m21()))))**0.5
#if r == 0:
#r = 1. ## shouldn't need to do this; probably means the math above is wrong?
#directionr = direction * r
directionr = direction
## map direction vector onto device ## map direction vector onto device
#viewDir = Point(dt.map(directionr) - dt.map(Point(0,0))) #viewDir = Point(dt.map(directionr) - dt.map(Point(0,0)))

View File

@ -104,6 +104,7 @@ class PlotDataItem(GraphicsObject):
self.yData = None self.yData = None
self.xDisp = None self.xDisp = None
self.yDisp = None self.yDisp = None
self.dataMask = None
#self.curves = [] #self.curves = []
#self.scatters = [] #self.scatters = []
self.curve = PlotCurveItem() self.curve = PlotCurveItem()
@ -393,6 +394,7 @@ class PlotDataItem(GraphicsObject):
scatterArgs[v] = self.opts[k] scatterArgs[v] = self.opts[k]
x,y = self.getData() x,y = self.getData()
scatterArgs['mask'] = self.dataMask
if curveArgs['pen'] is not None or (curveArgs['brush'] is not None and curveArgs['fillLevel'] is not None): if curveArgs['pen'] is not None or (curveArgs['brush'] is not None and curveArgs['fillLevel'] is not None):
self.curve.setData(x=x, y=y, **curveArgs) self.curve.setData(x=x, y=y, **curveArgs)
@ -413,11 +415,15 @@ class PlotDataItem(GraphicsObject):
if self.xDisp is None: if self.xDisp is None:
nanMask = np.isnan(self.xData) | np.isnan(self.yData) | np.isinf(self.xData) | np.isinf(self.yData) nanMask = np.isnan(self.xData) | np.isnan(self.yData) | np.isinf(self.xData) | np.isinf(self.yData)
if any(nanMask): if any(nanMask):
x = self.xData[~nanMask] self.dataMask = ~nanMask
y = self.yData[~nanMask] x = self.xData[self.dataMask]
y = self.yData[self.dataMask]
else: else:
self.dataMask = None
x = self.xData x = self.xData
y = self.yData y = self.yData
ds = self.opts['downsample'] ds = self.opts['downsample']
if ds > 1: if ds > 1:
x = x[::ds] x = x[::ds]
@ -435,8 +441,11 @@ class PlotDataItem(GraphicsObject):
if any(self.opts['logMode']): ## re-check for NANs after log if any(self.opts['logMode']): ## re-check for NANs after log
nanMask = np.isinf(x) | np.isinf(y) | np.isnan(x) | np.isnan(y) nanMask = np.isinf(x) | np.isinf(y) | np.isnan(x) | np.isnan(y)
if any(nanMask): if any(nanMask):
x = x[~nanMask] self.dataMask = ~nanMask
y = y[~nanMask] x = x[self.dataMask]
y = y[self.dataMask]
else:
self.dataMask = None
self.xDisp = x self.xDisp = x
self.yDisp = y self.yDisp = y
#print self.yDisp.shape, self.yDisp.min(), self.yDisp.max() #print self.yDisp.shape, self.yDisp.min(), self.yDisp.max()

View File

@ -36,6 +36,7 @@ from .. LabelItem import LabelItem
from .. LegendItem import LegendItem from .. LegendItem import LegendItem
from .. GraphicsWidget import GraphicsWidget from .. GraphicsWidget import GraphicsWidget
from .. ButtonItem import ButtonItem from .. ButtonItem import ButtonItem
from .. InfiniteLine import InfiniteLine
from pyqtgraph.WidgetGroup import WidgetGroup from pyqtgraph.WidgetGroup import WidgetGroup
__all__ = ['PlotItem'] __all__ = ['PlotItem']
@ -548,10 +549,35 @@ class PlotItem(GraphicsWidget):
print("PlotItem.addDataItem is deprecated. Use addItem instead.") print("PlotItem.addDataItem is deprecated. Use addItem instead.")
self.addItem(item, *args) self.addItem(item, *args)
def listDataItems(self):
"""Return a list of all data items (PlotDataItem, PlotCurveItem, ScatterPlotItem, etc)
contained in this PlotItem."""
return self.dataItems[:]
def addCurve(self, c, params=None): def addCurve(self, c, params=None):
print("PlotItem.addCurve is deprecated. Use addItem instead.") print("PlotItem.addCurve is deprecated. Use addItem instead.")
self.addItem(c, params) self.addItem(c, params)
def addLine(self, x=None, y=None, z=None, **kwds):
"""
Create an InfiniteLine and add to the plot.
If *x* is specified,
the line will be vertical. If *y* is specified, the line will be
horizontal. All extra keyword arguments are passed to
:func:`InfiniteLine.__init__() <pyqtgraph.InfiniteLine.__init__>`.
Returns the item created.
"""
angle = 0 if x is None else 90
pos = x if x is not None else y
line = InfiniteLine(pos, angle, **kwds)
self.addItem(line)
if z is not None:
line.setZValue(z)
return line
def removeItem(self, item): def removeItem(self, item):
""" """
Remove an item from the internal ViewBox. Remove an item from the internal ViewBox.

View File

@ -384,7 +384,7 @@ class ScatterPlotItem(GraphicsObject):
for k in ['pen', 'brush', 'symbol', 'size']: for k in ['pen', 'brush', 'symbol', 'size']:
if k in kargs: if k in kargs:
setMethod = getattr(self, 'set' + k[0].upper() + k[1:]) setMethod = getattr(self, 'set' + k[0].upper() + k[1:])
setMethod(kargs[k], update=False, dataSet=newData) setMethod(kargs[k], update=False, dataSet=newData, mask=kargs.get('mask', None))
if 'data' in kargs: if 'data' in kargs:
self.setPointData(kargs['data'], dataSet=newData) self.setPointData(kargs['data'], dataSet=newData)
@ -425,6 +425,8 @@ class ScatterPlotItem(GraphicsObject):
if len(args) == 1 and (isinstance(args[0], np.ndarray) or isinstance(args[0], list)): if len(args) == 1 and (isinstance(args[0], np.ndarray) or isinstance(args[0], list)):
pens = args[0] pens = args[0]
if kargs['mask'] is not None:
pens = pens[kargs['mask']]
if len(pens) != len(dataSet): if len(pens) != len(dataSet):
raise Exception("Number of pens does not match number of points (%d != %d)" % (len(pens), len(dataSet))) raise Exception("Number of pens does not match number of points (%d != %d)" % (len(pens), len(dataSet)))
dataSet['pen'] = pens dataSet['pen'] = pens
@ -445,6 +447,8 @@ class ScatterPlotItem(GraphicsObject):
if len(args) == 1 and (isinstance(args[0], np.ndarray) or isinstance(args[0], list)): if len(args) == 1 and (isinstance(args[0], np.ndarray) or isinstance(args[0], list)):
brushes = args[0] brushes = args[0]
if kargs['mask'] is not None:
brushes = brushes[kargs['mask']]
if len(brushes) != len(dataSet): if len(brushes) != len(dataSet):
raise Exception("Number of brushes does not match number of points (%d != %d)" % (len(brushes), len(dataSet))) raise Exception("Number of brushes does not match number of points (%d != %d)" % (len(brushes), len(dataSet)))
#for i in xrange(len(brushes)): #for i in xrange(len(brushes)):
@ -458,7 +462,7 @@ class ScatterPlotItem(GraphicsObject):
if update: if update:
self.updateSpots(dataSet) self.updateSpots(dataSet)
def setSymbol(self, symbol, update=True, dataSet=None): def setSymbol(self, symbol, update=True, dataSet=None, mask=None):
"""Set the symbol(s) used to draw each spot. """Set the symbol(s) used to draw each spot.
If a list or array is provided, then the symbol for each spot will be set separately. If a list or array is provided, then the symbol for each spot will be set separately.
Otherwise, the argument will be used as the default symbol for Otherwise, the argument will be used as the default symbol for
@ -468,6 +472,8 @@ class ScatterPlotItem(GraphicsObject):
if isinstance(symbol, np.ndarray) or isinstance(symbol, list): if isinstance(symbol, np.ndarray) or isinstance(symbol, list):
symbols = symbol symbols = symbol
if kargs['mask'] is not None:
symbols = symbols[kargs['mask']]
if len(symbols) != len(dataSet): if len(symbols) != len(dataSet):
raise Exception("Number of symbols does not match number of points (%d != %d)" % (len(symbols), len(dataSet))) raise Exception("Number of symbols does not match number of points (%d != %d)" % (len(symbols), len(dataSet)))
dataSet['symbol'] = symbols dataSet['symbol'] = symbols
@ -479,7 +485,7 @@ class ScatterPlotItem(GraphicsObject):
if update: if update:
self.updateSpots(dataSet) self.updateSpots(dataSet)
def setSize(self, size, update=True, dataSet=None): def setSize(self, size, update=True, dataSet=None, mask=None):
"""Set the size(s) used to draw each spot. """Set the size(s) used to draw each spot.
If a list or array is provided, then the size for each spot will be set separately. If a list or array is provided, then the size for each spot will be set separately.
Otherwise, the argument will be used as the default size for Otherwise, the argument will be used as the default size for
@ -489,6 +495,8 @@ class ScatterPlotItem(GraphicsObject):
if isinstance(size, np.ndarray) or isinstance(size, list): if isinstance(size, np.ndarray) or isinstance(size, list):
sizes = size sizes = size
if kargs['mask'] is not None:
sizes = sizes[kargs['mask']]
if len(sizes) != len(dataSet): if len(sizes) != len(dataSet):
raise Exception("Number of sizes does not match number of points (%d != %d)" % (len(sizes), len(dataSet))) raise Exception("Number of sizes does not match number of points (%d != %d)" % (len(sizes), len(dataSet)))
dataSet['size'] = sizes dataSet['size'] = sizes
@ -505,6 +513,8 @@ class ScatterPlotItem(GraphicsObject):
dataSet = self.data dataSet = self.data
if isinstance(data, np.ndarray) or isinstance(data, list): if isinstance(data, np.ndarray) or isinstance(data, list):
if kargs['mask'] is not None:
data = data[kargs['mask']]
if len(data) != len(dataSet): if len(data) != len(dataSet):
raise Exception("Length of meta data does not match number of points (%d != %d)" % (len(data), len(dataSet))) raise Exception("Length of meta data does not match number of points (%d != %d)" % (len(data), len(dataSet)))
@ -881,7 +891,7 @@ class SpotItem(object):
def updateItem(self): def updateItem(self):
self._data['fragCoords'] = None self._data['fragCoords'] = None
self._plot.updateSpots([self._data]) self._plot.updateSpots(self._data.reshape(1))
self._plot.invalidate() self._plot.invalidate()
#class PixmapSpotItem(SpotItem, QtGui.QGraphicsPixmapItem): #class PixmapSpotItem(SpotItem, QtGui.QGraphicsPixmapItem):

View File

@ -576,9 +576,12 @@ class ViewBox(GraphicsWidget):
w2 = (targetRect[ax][1]-targetRect[ax][0]) / 2. w2 = (targetRect[ax][1]-targetRect[ax][0]) / 2.
childRange[ax] = [x-w2, x+w2] childRange[ax] = [x-w2, x+w2]
else: else:
wp = (xr[1] - xr[0]) * 0.02 l = self.width() if ax==0 else self.height()
childRange[ax][0] -= wp if l > 0:
childRange[ax][1] += wp padding = np.clip(1./(l**0.5), 0.02, 0.1)
wp = (xr[1] - xr[0]) * padding
childRange[ax][0] -= wp
childRange[ax][1] += wp
targetRect[ax] = childRange[ax] targetRect[ax] = childRange[ax]
args['xRange' if ax == 0 else 'yRange'] = targetRect[ax] args['xRange' if ax == 0 else 'yRange'] = targetRect[ax]
if len(args) == 0: if len(args) == 0:

View File

@ -525,8 +525,9 @@ class Parameter(QtCore.QObject):
self.removeChild(ch) self.removeChild(ch)
def children(self): def children(self):
"""Return a list of this parameter's children.""" """Return a list of this parameter's children.
## warning -- this overrides QObject.children Warning: this overrides QObject.children
"""
return self.childs[:] return self.childs[:]
def hasChildren(self): def hasChildren(self):
@ -608,13 +609,13 @@ class Parameter(QtCore.QObject):
def __getattr__(self, attr): def __getattr__(self, attr):
## Leaving this undocumented because I might like to remove it in the future.. ## Leaving this undocumented because I might like to remove it in the future..
#print type(self), attr #print type(self), attr
import traceback
traceback.print_stack()
print "Warning: Use of Parameter.subParam is deprecated. Use Parameter.param(name) instead."
if 'names' not in self.__dict__: if 'names' not in self.__dict__:
raise AttributeError(attr) raise AttributeError(attr)
if attr in self.names: if attr in self.names:
import traceback
traceback.print_stack()
print "Warning: Use of Parameter.subParam is deprecated. Use Parameter.param(name) instead."
return self.param(attr) return self.param(attr)
else: else:
raise AttributeError(attr) raise AttributeError(attr)

View File

@ -4,6 +4,7 @@ from .Parameter import Parameter, registerParameterType
from .ParameterItem import ParameterItem from .ParameterItem import ParameterItem
from pyqtgraph.widgets.SpinBox import SpinBox from pyqtgraph.widgets.SpinBox import SpinBox
from pyqtgraph.widgets.ColorButton import ColorButton from pyqtgraph.widgets.ColorButton import ColorButton
#from pyqtgraph.widgets.GradientWidget import GradientWidget ## creates import loop
import pyqtgraph as pg import pyqtgraph as pg
import pyqtgraph.pixmaps as pixmaps import pyqtgraph.pixmaps as pixmaps
import os import os
@ -61,7 +62,11 @@ class WidgetParameterItem(ParameterItem):
w.sigChanging.connect(self.widgetValueChanging) w.sigChanging.connect(self.widgetValueChanging)
## update value shown in widget. ## update value shown in widget.
self.valueChanged(self, opts['value'], force=True) if opts.get('value', None) is not None:
self.valueChanged(self, opts['value'], force=True)
else:
## no starting value was given; use whatever the widget has
self.widgetValueChanged()
def makeWidget(self): def makeWidget(self):
@ -125,6 +130,14 @@ class WidgetParameterItem(ParameterItem):
w.setValue = w.setColor w.setValue = w.setColor
self.hideWidget = False self.hideWidget = False
w.setFlat(True) w.setFlat(True)
elif t == 'colormap':
from pyqtgraph.widgets.GradientWidget import GradientWidget ## need this here to avoid import loop
w = GradientWidget(orientation='bottom')
w.sigChanged = w.sigGradientChangeFinished
w.sigChanging = w.sigGradientChanged
w.value = w.colorMap
w.setValue = w.setColorMap
self.hideWidget = False
else: else:
raise Exception("Unknown type '%s'" % asUnicode(t)) raise Exception("Unknown type '%s'" % asUnicode(t))
return w return w
@ -294,6 +307,7 @@ registerParameterType('float', SimpleParameter, override=True)
registerParameterType('bool', SimpleParameter, override=True) registerParameterType('bool', SimpleParameter, override=True)
registerParameterType('str', SimpleParameter, override=True) registerParameterType('str', SimpleParameter, override=True)
registerParameterType('color', SimpleParameter, override=True) registerParameterType('color', SimpleParameter, override=True)
registerParameterType('colormap', SimpleParameter, override=True)

View File

@ -13,11 +13,11 @@ for path, sd, files in os.walk('.'):
ui = os.path.join(path, f) ui = os.path.join(path, f)
py = os.path.join(path, base + '_pyqt.py') py = os.path.join(path, base + '_pyqt.py')
if os.stat(ui).st_mtime > os.stat(py).st_mtime: if not os.path.exists(py) or os.stat(ui).st_mtime > os.stat(py).st_mtime:
os.system('%s %s > %s' % (pyqtuic, ui, py)) os.system('%s %s > %s' % (pyqtuic, ui, py))
print(py) print(py)
py = os.path.join(path, base + '_pyside.py') py = os.path.join(path, base + '_pyside.py')
if os.stat(ui).st_mtime > os.stat(py).st_mtime: if not os.path.exists(py) or os.stat(ui).st_mtime > os.stat(py).st_mtime:
os.system('%s %s > %s' % (pysideuic, ui, py)) os.system('%s %s > %s' % (pysideuic, ui, py))
print(py) print(py)

View File

@ -77,8 +77,14 @@ class ColorButton(QtGui.QPushButton):
def restoreState(self, state): def restoreState(self, state):
self.setColor(state) self.setColor(state)
def color(self): def color(self, mode='qcolor'):
return functions.mkColor(self._color) color = functions.mkColor(self._color)
if mode == 'qcolor':
return color
elif mode == 'byte':
return (color.red(), color.green(), color.blue(), color.alpha())
elif mode == 'float':
return (color.red()/255., color.green()/255., color.blue()/255., color.alpha()/255.)
def widgetGroupInterface(self): def widgetGroupInterface(self):
return (self.sigColorChanged, ColorButton.saveState, ColorButton.restoreState) return (self.sigColorChanged, ColorButton.saveState, ColorButton.restoreState)

173
widgets/ColorMapWidget.py Normal file
View File

@ -0,0 +1,173 @@
from pyqtgraph.Qt import QtGui, QtCore
import pyqtgraph.parametertree as ptree
import numpy as np
from pyqtgraph.pgcollections import OrderedDict
import pyqtgraph.functions as fn
__all__ = ['ColorMapWidget']
class ColorMapWidget(ptree.ParameterTree):
"""
This class provides a widget allowing the user to customize color mapping
for multi-column data.
"""
sigColorMapChanged = QtCore.Signal(object)
def __init__(self):
ptree.ParameterTree.__init__(self, showHeader=False)
self.params = ColorMapParameter()
self.setParameters(self.params)
self.params.sigTreeStateChanged.connect(self.mapChanged)
## wrap a couple methods
self.setFields = self.params.setFields
self.map = self.params.map
def mapChanged(self):
self.sigColorMapChanged.emit(self)
class ColorMapParameter(ptree.types.GroupParameter):
sigColorMapChanged = QtCore.Signal(object)
def __init__(self):
self.fields = {}
ptree.types.GroupParameter.__init__(self, name='Color Map', addText='Add Mapping..', addList=[])
self.sigTreeStateChanged.connect(self.mapChanged)
def mapChanged(self):
self.sigColorMapChanged.emit(self)
def addNew(self, name):
mode = self.fields[name].get('mode', 'range')
if mode == 'range':
self.addChild(RangeColorMapItem(name, self.fields[name]))
elif mode == 'enum':
self.addChild(EnumColorMapItem(name, self.fields[name]))
def fieldNames(self):
return self.fields.keys()
def setFields(self, fields):
self.fields = OrderedDict(fields)
#self.fields = fields
#self.fields.sort()
names = self.fieldNames()
self.setAddList(names)
def map(self, data, mode='byte'):
colors = np.zeros((len(data),4))
for item in self.children():
if not item['Enabled']:
continue
chans = item.param('Channels..')
mask = np.empty((len(data), 4), dtype=bool)
for i,f in enumerate(['Red', 'Green', 'Blue', 'Alpha']):
mask[:,i] = chans[f]
colors2 = item.map(data)
op = item['Operation']
if op == 'Add':
colors[mask] = colors[mask] + colors2[mask]
elif op == 'Multiply':
colors[mask] *= colors2[mask]
elif op == 'Overlay':
a = colors2[:,3:4]
c3 = colors * (1-a) + colors2 * a
c3[:,3:4] = colors[:,3:4] + (1-colors[:,3:4]) * a
colors = c3
elif op == 'Set':
colors[mask] = colors2[mask]
colors = np.clip(colors, 0, 1)
if mode == 'byte':
colors = (colors * 255).astype(np.ubyte)
return colors
class RangeColorMapItem(ptree.types.SimpleParameter):
def __init__(self, name, opts):
self.fieldName = name
units = opts.get('units', '')
ptree.types.SimpleParameter.__init__(self,
name=name, autoIncrementName=True, type='colormap', removable=True, renamable=True,
children=[
#dict(name="Field", type='list', value=name, values=fields),
dict(name='Min', type='float', value=0.0, suffix=units, siPrefix=True),
dict(name='Max', type='float', value=1.0, suffix=units, siPrefix=True),
dict(name='Operation', type='list', value='Overlay', values=['Overlay', 'Add', 'Multiply', 'Set']),
dict(name='Channels..', type='group', expanded=False, children=[
dict(name='Red', type='bool', value=True),
dict(name='Green', type='bool', value=True),
dict(name='Blue', type='bool', value=True),
dict(name='Alpha', type='bool', value=True),
]),
dict(name='Enabled', type='bool', value=True),
dict(name='NaN', type='color'),
])
def map(self, data):
data = data[self.fieldName]
scaled = np.clip((data-self['Min']) / (self['Max']-self['Min']), 0, 1)
cmap = self.value()
colors = cmap.map(scaled, mode='float')
mask = np.isnan(data) | np.isinf(data)
nanColor = self['NaN']
nanColor = (nanColor.red()/255., nanColor.green()/255., nanColor.blue()/255., nanColor.alpha()/255.)
colors[mask] = nanColor
return colors
class EnumColorMapItem(ptree.types.GroupParameter):
def __init__(self, name, opts):
self.fieldName = name
vals = opts.get('values', [])
childs = [{'name': v, 'type': 'color'} for v in vals]
ptree.types.GroupParameter.__init__(self,
name=name, autoIncrementName=True, removable=True, renamable=True,
children=[
dict(name='Values', type='group', children=childs),
dict(name='Operation', type='list', value='Overlay', values=['Overlay', 'Add', 'Multiply', 'Set']),
dict(name='Channels..', type='group', expanded=False, children=[
dict(name='Red', type='bool', value=True),
dict(name='Green', type='bool', value=True),
dict(name='Blue', type='bool', value=True),
dict(name='Alpha', type='bool', value=True),
]),
dict(name='Enabled', type='bool', value=True),
dict(name='Default', type='color'),
])
def map(self, data):
data = data[self.fieldName]
colors = np.empty((len(data), 4))
default = np.array(fn.colorTuple(self['Default'])) / 255.
colors[:] = default
for v in self.param('Values'):
n = v.name()
mask = data == n
c = np.array(fn.colorTuple(v.value())) / 255.
colors[mask] = c
#scaled = np.clip((data-self['Min']) / (self['Max']-self['Min']), 0, 1)
#cmap = self.value()
#colors = cmap.map(scaled, mode='float')
#mask = np.isnan(data) | np.isinf(data)
#nanColor = self['NaN']
#nanColor = (nanColor.red()/255., nanColor.green()/255., nanColor.blue()/255., nanColor.alpha()/255.)
#colors[mask] = nanColor
return colors

115
widgets/DataFilterWidget.py Normal file
View File

@ -0,0 +1,115 @@
from pyqtgraph.Qt import QtGui, QtCore
import pyqtgraph.parametertree as ptree
import numpy as np
from pyqtgraph.pgcollections import OrderedDict
__all__ = ['DataFilterWidget']
class DataFilterWidget(ptree.ParameterTree):
"""
This class allows the user to filter multi-column data sets by specifying
multiple criteria
"""
sigFilterChanged = QtCore.Signal(object)
def __init__(self):
ptree.ParameterTree.__init__(self, showHeader=False)
self.params = DataFilterParameter()
self.setParameters(self.params)
self.params.sigTreeStateChanged.connect(self.filterChanged)
self.setFields = self.params.setFields
self.filterData = self.params.filterData
def filterChanged(self):
self.sigFilterChanged.emit(self)
def parameters(self):
return self.params
class DataFilterParameter(ptree.types.GroupParameter):
sigFilterChanged = QtCore.Signal(object)
def __init__(self):
self.fields = {}
ptree.types.GroupParameter.__init__(self, name='Data Filter', addText='Add filter..', addList=[])
self.sigTreeStateChanged.connect(self.filterChanged)
def filterChanged(self):
self.sigFilterChanged.emit(self)
def addNew(self, name):
mode = self.fields[name].get('mode', 'range')
if mode == 'range':
self.addChild(RangeFilterItem(name, self.fields[name]))
elif mode == 'enum':
self.addChild(EnumFilterItem(name, self.fields[name]))
def fieldNames(self):
return self.fields.keys()
def setFields(self, fields):
self.fields = OrderedDict(fields)
names = self.fieldNames()
self.setAddList(names)
def filterData(self, data):
if len(data) == 0:
return data
return data[self.generateMask(data)]
def generateMask(self, data):
mask = np.ones(len(data), dtype=bool)
if len(data) == 0:
return mask
for fp in self:
if fp.value() is False:
continue
mask &= fp.generateMask(data)
#key, mn, mx = fp.fieldName, fp['Min'], fp['Max']
#vals = data[key]
#mask &= (vals >= mn)
#mask &= (vals < mx) ## Use inclusive minimum and non-inclusive maximum. This makes it easier to create non-overlapping selections
return mask
class RangeFilterItem(ptree.types.SimpleParameter):
def __init__(self, name, opts):
self.fieldName = name
units = opts.get('units', '')
ptree.types.SimpleParameter.__init__(self,
name=name, autoIncrementName=True, type='bool', value=True, removable=True, renamable=True,
children=[
#dict(name="Field", type='list', value=name, values=fields),
dict(name='Min', type='float', value=0.0, suffix=units, siPrefix=True),
dict(name='Max', type='float', value=1.0, suffix=units, siPrefix=True),
])
def generateMask(self, data):
vals = data[self.fieldName]
return (vals >= mn) & (vals < mx) ## Use inclusive minimum and non-inclusive maximum. This makes it easier to create non-overlapping selections
class EnumFilterItem(ptree.types.SimpleParameter):
def __init__(self, name, opts):
self.fieldName = name
vals = opts.get('values', [])
childs = [{'name': v, 'type': 'bool', 'value': True} for v in vals]
ptree.types.SimpleParameter.__init__(self,
name=name, autoIncrementName=True, type='bool', value=True, removable=True, renamable=True,
children=childs)
def generateMask(self, data):
vals = data[self.fieldName]
mask = np.ones(len(data), dtype=bool)
for c in self:
if c.value() is True:
continue
key = c.name()
mask &= vals != key
return mask

View File

@ -0,0 +1,183 @@
from pyqtgraph.Qt import QtGui, QtCore
from .PlotWidget import PlotWidget
from .DataFilterWidget import DataFilterParameter
from .ColorMapWidget import ColorMapParameter
import pyqtgraph.parametertree as ptree
import pyqtgraph.functions as fn
import numpy as np
from pyqtgraph.pgcollections import OrderedDict
__all__ = ['ScatterPlotWidget']
class ScatterPlotWidget(QtGui.QSplitter):
"""
Given a record array, display a scatter plot of a specific set of data.
This widget includes controls for selecting the columns to plot,
filtering data, and determining symbol color and shape. This widget allows
the user to explore relationships between columns in a record array.
The widget consists of four components:
1) A list of column names from which the user may select 1 or 2 columns
to plot. If one column is selected, the data for that column will be
plotted in a histogram-like manner by using :func:`pseudoScatter()
<pyqtgraph.pseudoScatter>`. If two columns are selected, then the
scatter plot will be generated with x determined by the first column
that was selected and y by the second.
2) A DataFilter that allows the user to select a subset of the data by
specifying multiple selection criteria.
3) A ColorMap that allows the user to determine how points are colored by
specifying multiple criteria.
4) A PlotWidget for displaying the data.
"""
def __init__(self, parent=None):
QtGui.QSplitter.__init__(self, QtCore.Qt.Horizontal)
self.ctrlPanel = QtGui.QSplitter(QtCore.Qt.Vertical)
self.addWidget(self.ctrlPanel)
self.fieldList = QtGui.QListWidget()
self.fieldList.setSelectionMode(self.fieldList.ExtendedSelection)
self.ptree = ptree.ParameterTree(showHeader=False)
self.filter = DataFilterParameter()
self.colorMap = ColorMapParameter()
self.params = ptree.Parameter.create(name='params', type='group', children=[self.filter, self.colorMap])
self.ptree.setParameters(self.params, showTop=False)
self.plot = PlotWidget()
self.ctrlPanel.addWidget(self.fieldList)
self.ctrlPanel.addWidget(self.ptree)
self.addWidget(self.plot)
self.data = None
self.style = dict(pen=None, symbol='o')
self.fieldList.itemSelectionChanged.connect(self.fieldSelectionChanged)
self.filter.sigFilterChanged.connect(self.filterChanged)
self.colorMap.sigColorMapChanged.connect(self.updatePlot)
def setFields(self, fields):
"""
Set the list of field names/units to be processed.
Format is: [(name, units), ...]
"""
self.fields = OrderedDict(fields)
self.fieldList.clear()
for f,opts in fields:
item = QtGui.QListWidgetItem(f)
item.opts = opts
item = self.fieldList.addItem(item)
self.filter.setFields(fields)
self.colorMap.setFields(fields)
def setData(self, data):
"""
Set the data to be processed and displayed.
Argument must be a numpy record array.
"""
self.data = data
self.filtered = None
self.updatePlot()
def fieldSelectionChanged(self):
sel = self.fieldList.selectedItems()
if len(sel) > 2:
self.fieldList.blockSignals(True)
try:
for item in sel[1:-1]:
item.setSelected(False)
finally:
self.fieldList.blockSignals(False)
self.updatePlot()
def filterChanged(self, f):
self.filtered = None
self.updatePlot()
def updatePlot(self):
self.plot.clear()
if self.data is None:
return
if self.filtered is None:
self.filtered = self.filter.filterData(self.data)
data = self.filtered
if len(data) == 0:
return
colors = np.array([fn.mkBrush(*x) for x in self.colorMap.map(data)])
style = self.style.copy()
## Look up selected columns and units
sel = list([str(item.text()) for item in self.fieldList.selectedItems()])
units = list([item.opts.get('units', '') for item in self.fieldList.selectedItems()])
if len(sel) == 0:
self.plot.setTitle('')
return
if len(sel) == 1:
self.plot.setLabels(left=('N', ''), bottom=(sel[0], units[0]), title='')
if len(data) == 0:
return
x = data[sel[0]]
#if x.dtype.kind == 'f':
#mask = ~np.isnan(x)
#else:
#mask = np.ones(len(x), dtype=bool)
#x = x[mask]
#style['symbolBrush'] = colors[mask]
y = None
elif len(sel) == 2:
self.plot.setLabels(left=(sel[1],units[1]), bottom=(sel[0],units[0]))
if len(data) == 0:
return
xydata = []
for ax in [0,1]:
d = data[sel[ax]]
## scatter catecorical values just a bit so they show up better in the scatter plot.
#if sel[ax] in ['MorphologyBSMean', 'MorphologyTDMean', 'FIType']:
#d += np.random.normal(size=len(cells), scale=0.1)
xydata.append(d)
x,y = xydata
#mask = np.ones(len(x), dtype=bool)
#if x.dtype.kind == 'f':
#mask |= ~np.isnan(x)
#if y.dtype.kind == 'f':
#mask |= ~np.isnan(y)
#x = x[mask]
#y = y[mask]
#style['symbolBrush'] = colors[mask]
## convert enum-type fields to float, set axis labels
xy = [x,y]
for i in [0,1]:
axis = self.plot.getAxis(['bottom', 'left'][i])
if xy[i] is not None and xy[i].dtype.kind in ('S', 'O'):
vals = self.fields[sel[i]].get('values', list(set(xy[i])))
xy[i] = np.array([vals.index(x) if x in vals else None for x in xy[i]], dtype=float)
axis.setTicks([list(enumerate(vals))])
else:
axis.setTicks(None) # reset to automatic ticking
x,y = xy
## mask out any nan values
mask = np.ones(len(x), dtype=bool)
if x.dtype.kind == 'f':
mask &= ~np.isnan(x)
if y is not None and y.dtype.kind == 'f':
mask &= ~np.isnan(y)
x = x[mask]
style['symbolBrush'] = colors[mask]
## Scatter y-values for a histogram-like appearance
if y is None:
y = fn.pseudoScatter(x)
else:
y = y[mask]
self.plot.plot(x, y, **style)