479 lines
16 KiB
Python
479 lines
16 KiB
Python
# -*- coding: utf-8 -*-
|
|
from ..Node import Node
|
|
from ...Qt import QtGui, QtCore, QtWidgets
|
|
import numpy as np
|
|
import sys
|
|
from .common import *
|
|
from ...widgets.TreeWidget import TreeWidget
|
|
from ...graphicsItems.LinearRegionItem import LinearRegionItem
|
|
|
|
from . import functions
|
|
|
|
class ColumnSelectNode(Node):
    """Select named columns from a record array or MetaArray.

    Every column checked in the control widget becomes an output terminal of
    the same name, and ``process`` emits that column's data on it.
    """
    nodeName = "ColumnSelect"

    def __init__(self, name):
        Node.__init__(self, name, terminals={'In': {'io': 'in'}})
        # Names of the currently selected columns (one output terminal each).
        self.columns = set()
        # Widget classes live in QtWidgets (consistent with TextEdit in this file).
        self.columnList = QtWidgets.QListWidget()
        # MetaArray axis the columns are read from; updated by updateList().
        self.axis = 0
        self.columnList.itemChanged.connect(self.itemChanged)

    def process(self, In, display=True):
        """Return a dict mapping each selected column name to its data.

        Raises:
            Exception: if *In* is neither a MetaArray nor an ndarray with named
                fields.
        """
        if display:
            self.updateList(In)

        out = {}
        if hasattr(In, 'implements') and In.implements('MetaArray'):
            for c in self.columns:
                out[c] = In[self.axis:c]
        elif isinstance(In, np.ndarray) and In.dtype.fields is not None:
            for c in self.columns:
                out[c] = In[c]
        else:
            self.In.setValueAcceptable(False)
            raise Exception("Input must be MetaArray or ndarray with named fields")

        return out

    def ctrlWidget(self):
        """Return the checkable column list used as this node's control widget."""
        return self.columnList

    def updateList(self, data):
        """Rebuild the checkable column list from the columns present in *data*.

        Previously selected columns that are no longer present are deselected
        and their output terminals removed. Unsupported input (neither a
        MetaArray nor a structured ndarray) is ignored here, so that
        ``process`` can reject it with a descriptive error instead of an
        AttributeError.
        """
        if hasattr(data, 'implements') and data.implements('MetaArray'):
            cols = []
            allCols = data.listColumns()
            for ax in allCols:  ## find first axis with columns
                if len(allCols[ax]) > 0:
                    self.axis = ax
                    # De-duplicate while keeping the array's column order.
                    cols = list(dict.fromkeys(allCols[ax]))
                    break
        elif isinstance(data, np.ndarray) and data.dtype.fields is not None:
            cols = list(data.dtype.fields.keys())
        else:
            return

        available = set(cols)
        stale = {c for c in self.columns if c not in available}
        for c in stale:
            self.removeTerminal(c)
        self.columns -= stale

        # Suppress itemChanged while repopulating so no terminals are added or
        # removed as a side effect of rebuilding the list.
        self.columnList.blockSignals(True)
        try:
            self.columnList.clear()
            for c in cols:
                item = QtWidgets.QListWidgetItem(c)
                item.setFlags(QtCore.Qt.ItemFlag.ItemIsEnabled | QtCore.Qt.ItemFlag.ItemIsUserCheckable)
                if c in self.columns:
                    item.setCheckState(QtCore.Qt.CheckState.Checked)
                else:
                    item.setCheckState(QtCore.Qt.CheckState.Unchecked)
                self.columnList.addItem(item)
        finally:
            self.columnList.blockSignals(False)

    def itemChanged(self, item):
        """Add or remove the output terminal matching a toggled list item, then re-run."""
        col = str(item.text())
        if item.checkState() == QtCore.Qt.CheckState.Checked:
            if col not in self.columns:
                self.columns.add(col)
                self.addOutput(col)
        else:
            if col in self.columns:
                self.columns.remove(col)
                self.removeTerminal(col)
        self.update()

    def saveState(self):
        """Return the base node state plus the list of selected column names."""
        state = Node.saveState(self)
        state['columns'] = list(self.columns)
        return state

    def restoreState(self, state):
        """Restore base state and recreate one output terminal per saved column."""
        Node.restoreState(self, state)
        self.columns = set(state.get('columns', []))
        for c in self.columns:
            self.addOutput(c)
|
|
|
|
|
|
|
|
class RegionSelectNode(CtrlNode):
    """Returns a slice from a 1-D array. Connect the 'widget' output to a plot to display a region-selection widget."""
    nodeName = "RegionSelect"
    uiTemplate = [
        ('start', 'spin', {'value': 0, 'step': 0.1}),
        ('stop', 'spin', {'value': 0.1, 'step': 0.1}),
        ('display', 'check', {'value': True}),
        ('movable', 'check', {'value': True}),
    ]

    def __init__(self, name):
        # Region widgets keyed by the connection they are displayed for.
        # Created before the base initializer, as in the original ordering.
        self.items = {}
        CtrlNode.__init__(self, name, terminals={
            'data': {'io': 'in'},
            'selected': {'io': 'out'},
            'region': {'io': 'out'},
            'widget': {'io': 'out', 'multi': True}
        })
        controls = self.ctrls
        controls['display'].toggled.connect(self.displayToggled)
        controls['movable'].toggled.connect(self.movableToggled)

    def displayToggled(self, b):
        """Show or hide every region widget."""
        for regionItem in self.items.values():
            regionItem.setVisible(b)

    def movableToggled(self, b):
        """Allow or forbid dragging of every region widget."""
        for regionItem in self.items.values():
            regionItem.setMovable(b)

    def process(self, data=None, display=True):
        """Sync region widgets with the current start/stop and slice *data*.

        Returns a dict with the sliced data under 'selected' (None unless that
        terminal is connected and data was given), the widget dict under
        'widget', and ``[start, stop]`` under 'region'.
        """
        state = self.stateGroup.state()
        start = state['start']
        stop = state['stop']
        region = [start, stop]

        if display:
            for conn in self['widget'].connections():
                if conn.node().getPlot() is None:
                    continue
                if conn in self.items:
                    self.items[conn].setRegion(region)
                    continue
                # First time this connection is seen: create its region widget.
                regionItem = LinearRegionItem(values=region)
                self.items[conn] = regionItem
                regionItem.sigRegionChanged.connect(self.rgnChanged)
                regionItem.setVisible(state['display'])
                regionItem.setMovable(state['movable'])

        sliced = None
        if self['selected'].isConnected() and data is not None:
            if hasattr(data, 'implements') and data.implements('MetaArray'):
                sliced = data[0:start:stop]
            else:
                # Non-MetaArray input is filtered on its 'time' field.
                times = data['time']
                sliced = data[(times >= start) * (times < stop)]

        return {'selected': sliced, 'widget': self.items, 'region': region}

    def rgnChanged(self, item):
        """Copy a dragged widget's region back into the start/stop controls and re-run."""
        lo, hi = item.getRegion()[0], item.getRegion()[1]
        self.stateGroup.setState({'start': lo, 'stop': hi})
        self.update()
|
|
|
|
|
|
class TextEdit(QtWidgets.QTextEdit):
    """Text editor that runs a callback on focus-out when its text has changed."""

    def __init__(self, on_update):
        super().__init__()
        # Zero-argument callable invoked after an edit is committed.
        self.on_update = on_update
        # Text at the last focus-out; None means the first focus-out always fires.
        self.lastText = None

    def focusOutEvent(self, ev):
        current = self.toPlainText()
        if current == self.lastText:
            super().focusOutEvent(ev)
            return
        self.lastText = current
        self.on_update()
        super().focusOutEvent(ev)
|
|
|
|
|
|
class EvalNode(Node):
    """Return the output of a string evaluated/executed by the python interpreter.
    The string may be either an expression or a python script, and inputs are accessed as the name of the terminal.
    For expressions, a single value may be evaluated for a single output, or a dict for multiple outputs.
    For a script, the text will be executed as the body of a function."""
    nodeName = 'PythonEval'

    def __init__(self, name):
        Node.__init__(self, name,
            terminals = {
                'input': {'io': 'in', 'renamable': True, 'multiable': True},
                'output': {'io': 'out', 'renamable': True, 'multiable': True},
            },
            allowAddInput=True, allowAddOutput=True)

        # Widget classes live in QtWidgets (consistent with TextEdit above).
        self.ui = QtWidgets.QWidget()
        self.layout = QtWidgets.QGridLayout()
        # Re-run the node whenever an edit is committed (focus-out with changes).
        self.text = TextEdit(self.update)
        # setTabStopWidth() is obsolete since Qt 5.10 and gone in Qt 6;
        # prefer setTabStopDistance() when the binding provides it.
        if hasattr(self.text, 'setTabStopDistance'):
            self.text.setTabStopDistance(30)
        else:
            self.text.setTabStopWidth(30)
        self.text.setPlainText("# Access inputs as args['input_name']\nreturn {'output': None} ## one key per output terminal")
        self.layout.addWidget(self.text, 1, 0, 1, 2)
        self.ui.setLayout(self.layout)

    def ctrlWidget(self):
        """Return the code-editor widget."""
        return self.ui

    def setCode(self, code):
        """Replace the editor contents with *code*, minus its common indentation."""
        # unindent code; this allows nicer inline code specification when
        # calling this method.
        lines = code.split('\n')
        indents = [len(line) - len(line.lstrip()) for line in lines if len(line.lstrip()) > 0]
        if indents:
            shift = min(indents)
            code = '\n'.join([line[shift:] for line in lines])

        self.text.clear()
        self.text.insertPlainText(code)

    def code(self):
        """Return the current code text."""
        return self.text.toPlainText()

    def process(self, display=True, **args):
        """Run the node's code with the current input values.

        The text is first evaluated as one expression (newlines replaced by
        spaces), with ``self``, ``display``, ``args`` and each input name in
        scope. If that is not valid syntax, the text is executed as the body of
        ``def fn(**args)`` and ``fn``'s return value is used.

        Any exception from either path is reported with the node name and
        re-raised.

        NOTE(review): this executes arbitrary user-supplied code by design.
        """
        # Explicit namespace instead of mutating locals(), whose semantics
        # changed in Python 3.13 (PEP 667).
        namespace = {'self': self, 'display': display, 'args': args}
        namespace.update(args)
        source = self.text.toPlainText()
        try:
            ## try eval first, then exec
            try:
                output = eval(source.replace('\n', ' '), globals(), namespace)
            except SyntaxError:
                body = "\n".join([" " + line for line in source.split('\n')])
                script = "def fn(**args):\n" + body + "\noutput=fn(**args)\n"
                scope = {'args': args}
                exec(script, globals(), scope)
                output = scope['output']
        except Exception:
            # Covers failures from both the eval and the exec path.
            print(f"Error processing node: {self.name()}")
            raise
        return output

    def saveState(self):
        """Return the base node state plus the code text."""
        state = Node.saveState(self)
        state['text'] = self.text.toPlainText()
        return state

    def restoreState(self, state):
        """Restore base state, code text and terminals, then re-run."""
        Node.restoreState(self, state)
        self.setCode(state['text'])
        self.restoreTerminals(state['terminals'])
        self.update()
|
|
|
|
|
|
class ColumnJoinNode(Node):
    """Concatenates record arrays and/or adds new columns"""
    nodeName = 'ColumnJoin'

    def __init__(self, name):
        Node.__init__(self, name, terminals = {
            'output': {'io': 'out'},
        })

        # Widget classes live in QtWidgets (consistent with TextEdit in this file).
        self.ui = QtWidgets.QWidget()
        self.layout = QtWidgets.QGridLayout()
        self.ui.setLayout(self.layout)

        # Top-level tree items mirror the input terminals; their order is the
        # order in which inputs are joined (see order()).
        self.tree = TreeWidget()
        self.addInBtn = QtWidgets.QPushButton('+ Input')
        self.remInBtn = QtWidgets.QPushButton('- Input')

        self.layout.addWidget(self.tree, 0, 0, 1, 2)
        self.layout.addWidget(self.addInBtn, 1, 0)
        self.layout.addWidget(self.remInBtn, 1, 1)

        self.addInBtn.clicked.connect(self.addInput)
        self.remInBtn.clicked.connect(self.remInput)
        self.tree.sigItemMoved.connect(self.update)

    def ctrlWidget(self):
        """Return the input-ordering widget."""
        return self.ui

    def addInput(self):
        """Add a removable/renamable input terminal and a matching tree item."""
        term = Node.addInput(self, 'input', renamable=True, removable=True, multiable=True)
        item = QtWidgets.QTreeWidgetItem([term.name()])
        # Cross-link item and terminal so either can find the other.
        item.term = term
        term.joinItem = item
        self.tree.addTopLevelItem(item)

    def remInput(self):
        """Remove the input whose tree item is currently selected, then re-run.

        Does nothing when no item is selected.
        """
        sel = self.tree.currentItem()
        if sel is None:
            return
        term = sel.term
        term.joinItem = None
        sel.term = None
        self.tree.removeTopLevelItem(sel)
        self.removeTerminal(term)
        self.update()

    def process(self, display=True, **args):
        """Join the inputs in tree order.

        Structured arrays (non-empty dtype fields) are passed through as-is.
        Any other value is passed as a ``(name, None, value)`` column spec to
        ``functions.concatenateColumns``.
        """
        order = self.order()
        vals = []
        for name in order:
            if name not in args:
                continue
            val = args[name]
            if isinstance(val, np.ndarray) and len(val.dtype) > 0:
                vals.append(val)
            else:
                vals.append((name, None, val))
        return {'output': functions.concatenateColumns(vals)}

    def order(self):
        """Return the input names in their current tree order."""
        return [str(self.tree.topLevelItem(i).text(0)) for i in range(self.tree.topLevelItemCount())]

    def saveState(self):
        """Return the base node state plus the input order."""
        state = Node.saveState(self)
        state['order'] = self.order()
        return state

    def restoreState(self, state):
        """Restore base state, then rebuild the tree to match the saved input order."""
        Node.restoreState(self, state)
        inputs = self.inputs()

        ## Node.restoreState should have created all of the terminals we need
        ## However: to maintain support for some older flowchart files, we need
        ## to manually add any terminals that were not taken care of.
        for name in [n for n in state['order'] if n not in inputs]:
            Node.addInput(self, name, renamable=True, removable=True, multiable=True)
        inputs = self.inputs()

        # Saved order first, then any inputs the saved order did not mention.
        order = [name for name in state['order'] if name in inputs]
        for name in inputs:
            if name not in order:
                order.append(name)

        self.tree.clear()
        for name in order:
            term = self[name]
            item = QtWidgets.QTreeWidgetItem([name])
            item.term = term
            term.joinItem = item
            self.tree.addTopLevelItem(item)

    def terminalRenamed(self, term, oldName):
        """Keep the tree item's label in sync with a renamed terminal, then re-run."""
        Node.terminalRenamed(self, term, oldName)
        item = term.joinItem
        item.setText(0, term.name())
        self.update()
|
|
|
|
|
|
class Mean(CtrlNode):
    """Calculate the mean of an array across an axis.

    An axis setting of -1 averages over all elements (``axis=None``).
    """
    nodeName = 'Mean'
    uiTemplate = [
        ('axis', 'intSpin', {'value': 0, 'min': -1, 'max': 1000000}),
    ]

    def processData(self, data):
        axis = self.stateGroup.state()['axis']
        if axis == -1:
            axis = None
        return data.mean(axis=axis)
|
|
|
|
|
|
class Max(CtrlNode):
    """Calculate the maximum of an array across an axis.

    An axis setting of -1 reduces over all elements (``axis=None``).
    """
    nodeName = 'Max'
    uiTemplate = [
        ('axis', 'intSpin', {'value': 0, 'min': -1, 'max': 1000000}),
    ]

    def processData(self, data):
        axis = self.stateGroup.state()['axis']
        if axis == -1:
            axis = None
        return data.max(axis=axis)
|
|
|
|
|
|
class Min(CtrlNode):
    """Calculate the minimum of an array across an axis.

    An axis setting of -1 reduces over all elements (``axis=None``).
    """
    nodeName = 'Min'
    uiTemplate = [
        ('axis', 'intSpin', {'value': 0, 'min': -1, 'max': 1000000}),
    ]

    def processData(self, data):
        axis = self.stateGroup.state()['axis']
        if axis == -1:
            axis = None
        return data.min(axis=axis)
|
|
|
|
|
|
class Stdev(CtrlNode):
    """Calculate the standard deviation of an array across an axis.

    An axis setting of -1 reduces over all elements (``axis=None``).
    """
    nodeName = 'Stdev'
    uiTemplate = [
        ('axis', 'intSpin', {'value': 0, 'min': -1, 'max': 1000000}),
    ]

    def processData(self, data):
        axis = self.stateGroup.state()['axis']
        if axis == -1:
            axis = None
        return data.std(axis=axis)
|
|
|
|
|
|
class Index(CtrlNode):
    """Select an index from an array axis.
    """
    nodeName = 'Index'
    uiTemplate = [
        ('axis', 'intSpin', {'value': 0, 'min': 0, 'max': 1000000}),
        ('index', 'intSpin', {'value': 0, 'min': 0, 'max': 1000000}),
    ]

    def processData(self, data):
        state = self.stateGroup.state()
        axis, position = state['axis'], state['index']
        if axis != 0:
            return data.take(position, axis=axis)
        # Plain indexing on axis 0 keeps non-ndarray sequence types working.
        return data[position]
|
|
|
|
|
|
class Slice(CtrlNode):
    """Select a slice from an array axis.
    """
    nodeName = 'Slice'
    uiTemplate = [
        ('axis', 'intSpin', {'value': 0, 'min': 0, 'max': 1e6}),
        ('start', 'intSpin', {'value': 0, 'min': -1e6, 'max': 1e6}),
        ('stop', 'intSpin', {'value': -1, 'min': -1e6, 'max': 1e6}),
        ('step', 'intSpin', {'value': 1, 'min': -1e6, 'max': 1e6}),
    ]

    def processData(self, data):
        """Return ``data`` sliced by start:stop:step along the configured axis.

        A step of 0 raises ValueError from Python/NumPy slicing.
        """
        s = self.stateGroup.state()
        ax = s['axis']
        selection = slice(s['start'], s['stop'], s['step'])
        if ax == 0:
            # allow support for non-ndarray sequence types
            return data[selection]
        # NumPy requires a *tuple* of slices for multidimensional indexing;
        # indexing with a list of slices is rejected by modern NumPy.
        index = [slice(None)] * data.ndim
        index[ax] = selection
        return data[tuple(index)]
|
|
|
|
|
|
class AsType(CtrlNode):
    """Convert an array to a different dtype.

    NOTE(review): 'float128' is not available on every platform; selecting it
    there makes ``astype`` raise.
    """
    nodeName = 'AsType'
    uiTemplate = [
        ('dtype', 'combo', {
            'values': [
                'float', 'int', 'float32', 'float64', 'float128',
                'int8', 'int16', 'int32', 'int64',
                'uint8', 'uint16', 'uint32', 'uint64',
            ],
            'index': 0,
        }),
    ]

    def processData(self, data):
        target = self.stateGroup.state()['dtype']
        return data.astype(target)
|