diff --git a/pyqtgraph/functions.py b/pyqtgraph/functions.py index 62df1ce3..74d1f8a5 100644 --- a/pyqtgraph/functions.py +++ b/pyqtgraph/functions.py @@ -568,16 +568,25 @@ def transformCoordinates(tr, coords, transpose=False): def solve3DTransform(points1, points2): """ Find a 3D transformation matrix that maps points1 onto points2. - Points must be specified as a list of 4 Vectors. + Points must be specified as either lists of 4 Vectors or + (4, 3) arrays. """ import numpy.linalg - A = np.array([[points1[i].x(), points1[i].y(), points1[i].z(), 1] for i in range(4)]) - B = np.array([[points2[i].x(), points2[i].y(), points2[i].z(), 1] for i in range(4)]) + pts = [] + for inp in (points1, points2): + if isinstance(inp, np.ndarray): + A = np.empty((4,4), dtype=float) + A[:,:3] = inp[:,:3] + A[:,3] = 1.0 + else: + A = np.array([[inp[i].x(), inp[i].y(), inp[i].z(), 1] for i in range(4)]) + pts.append(A) ## solve 3 sets of linear equations to determine transformation matrix elements matrix = np.zeros((4,4)) for i in range(3): - matrix[i] = numpy.linalg.solve(A, B[:,i]) ## solve Ax = B; x is one row of the desired transformation matrix + ## solve Ax = B; x is one row of the desired transformation matrix + matrix[i] = numpy.linalg.solve(pts[0], pts[1][:,i]) return matrix diff --git a/pyqtgraph/tests/test_functions.py b/pyqtgraph/tests/test_functions.py new file mode 100644 index 00000000..da9beca2 --- /dev/null +++ b/pyqtgraph/tests/test_functions.py @@ -0,0 +1,21 @@ +import pyqtgraph as pg +import numpy as np +from numpy.testing import assert_array_almost_equal, assert_almost_equal + +np.random.seed(12345) + +def testSolve3D(): + p1 = np.array([[0,0,0,1], + [1,0,0,1], + [0,1,0,1], + [0,0,1,1]], dtype=float) + + # transform points through random matrix + tr = np.random.normal(size=(4, 4)) + tr[3] = (0,0,0,1) + p2 = np.dot(tr, p1.T).T[:,:3] + + # solve to see if we can recover the transformation matrix. + tr2 = pg.solve3DTransform(p1, p2) + + assert_array_almost_equal(tr[:3], tr2[:3])