Allow binary operator nodes to select output type
This commit is contained in:
parent
19fc846b90
commit
237b848837
@ -1,5 +1,7 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
from ..Node import Node
|
from ..Node import Node
|
||||||
|
from .common import CtrlNode
|
||||||
|
|
||||||
|
|
||||||
class UniOpNode(Node):
|
class UniOpNode(Node):
|
||||||
"""Generic node for performing any operation like Out = In.fn()"""
|
"""Generic node for performing any operation like Out = In.fn()"""
|
||||||
@ -13,11 +15,22 @@ class UniOpNode(Node):
|
|||||||
def process(self, **args):
|
def process(self, **args):
|
||||||
return {'Out': getattr(args['In'], self.fn)()}
|
return {'Out': getattr(args['In'], self.fn)()}
|
||||||
|
|
||||||
class BinOpNode(Node):
|
class BinOpNode(CtrlNode):
|
||||||
"""Generic node for performing any operation like A.fn(B)"""
|
"""Generic node for performing any operation like A.fn(B)"""
|
||||||
|
|
||||||
|
_dtypes = [
|
||||||
|
'float64', 'float32', 'float16',
|
||||||
|
'int64', 'int32', 'int16', 'int8',
|
||||||
|
'uint64', 'uint32', 'uint16', 'uint8'
|
||||||
|
]
|
||||||
|
|
||||||
|
uiTemplate = [
|
||||||
|
('outputType', 'combo', {'values': ['no change', 'input A', 'input B'] + _dtypes , 'index': 0})
|
||||||
|
]
|
||||||
|
|
||||||
def __init__(self, name, fn):
|
def __init__(self, name, fn):
|
||||||
self.fn = fn
|
self.fn = fn
|
||||||
Node.__init__(self, name, terminals={
|
CtrlNode.__init__(self, name, terminals={
|
||||||
'A': {'io': 'in'},
|
'A': {'io': 'in'},
|
||||||
'B': {'io': 'in'},
|
'B': {'io': 'in'},
|
||||||
'Out': {'io': 'out', 'bypass': 'A'}
|
'Out': {'io': 'out', 'bypass': 'A'}
|
||||||
@ -36,6 +49,18 @@ class BinOpNode(Node):
|
|||||||
out = fn(args['B'])
|
out = fn(args['B'])
|
||||||
if out is NotImplemented:
|
if out is NotImplemented:
|
||||||
raise Exception("Operation %s not implemented between %s and %s" % (fn, str(type(args['A'])), str(type(args['B']))))
|
raise Exception("Operation %s not implemented between %s and %s" % (fn, str(type(args['A'])), str(type(args['B']))))
|
||||||
|
|
||||||
|
# Coerce dtype if requested
|
||||||
|
typ = self.stateGroup.state()['outputType']
|
||||||
|
if typ == 'no change':
|
||||||
|
pass
|
||||||
|
elif typ == 'input A':
|
||||||
|
out = out.astype(args['A'].dtype)
|
||||||
|
elif typ == 'input B':
|
||||||
|
out = out.astype(args['B'].dtype)
|
||||||
|
else:
|
||||||
|
out = out.astype(typ)
|
||||||
|
|
||||||
#print " ", fn, out
|
#print " ", fn, out
|
||||||
return {'Out': out}
|
return {'Out': out}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user