diff --git a/pyqtgraph/Transform3D.py b/pyqtgraph/Transform3D.py index 43b12de3..56283351 100644 --- a/pyqtgraph/Transform3D.py +++ b/pyqtgraph/Transform3D.py @@ -1,13 +1,19 @@ # -*- coding: utf-8 -*- from .Qt import QtCore, QtGui from . import functions as fn +from .Vector import Vector import numpy as np + class Transform3D(QtGui.QMatrix4x4): """ Extension of QMatrix4x4 with some helpful methods added. """ def __init__(self, *args): + if len(args) == 1 and isinstance(args[0], (list, tuple, np.ndarray)): + args = [x for y in args[0] for x in y] + if len(args) != 16: + raise TypeError("Single argument to Transform3D must have 16 elements.") QtGui.QMatrix4x4.__init__(self, *args) def matrix(self, nd=3): @@ -25,8 +31,15 @@ class Transform3D(QtGui.QMatrix4x4): """ Extends QMatrix4x4.map() to allow mapping (3, ...) arrays of coordinates """ - if isinstance(obj, np.ndarray) and obj.ndim >= 2 and obj.shape[0] in (2,3): - return fn.transformCoordinates(self, obj) + if isinstance(obj, np.ndarray) and obj.shape[0] in (2,3): + if obj.ndim >= 2: + return fn.transformCoordinates(self, obj) + elif obj.ndim == 1: + v = QtGui.QMatrix4x4.map(self, Vector(obj)) + return np.array([v.x(), v.y(), v.z()])[:obj.shape[0]] + elif isinstance(obj, (list, tuple)): + v = QtGui.QMatrix4x4.map(self, Vector(obj)) + return type(obj)([v.x(), v.y(), v.z()])[:len(obj)] else: return QtGui.QMatrix4x4.map(self, obj)