diff --git a/spatialmath/base/transforms3d.py b/spatialmath/base/transforms3d.py index bc2ceb05..08d3be26 100644 --- a/spatialmath/base/transforms3d.py +++ b/spatialmath/base/transforms3d.py @@ -43,7 +43,7 @@ tr2rt, Ab2M, ) -from spatialmath.base.quaternions import r2q, q2r, qeye, qslerp +from spatialmath.base.quaternions import r2q, q2r, qeye, qslerp, qunit from spatialmath.base.graphics import plotvol3, axes_logic from spatialmath.base.animate import Animate import spatialmath.base.symbolic as sym @@ -1675,7 +1675,7 @@ def trinterp(start, end, s, shortest=True): q1 = r2q(end) qr = qslerp(q0, q1, s, shortest=shortest) - return q2r(qr) + return q2r(qunit(qr)) elif ismatrix(end, (4, 4)): # SE(3) case @@ -1697,7 +1697,7 @@ def trinterp(start, end, s, shortest=True): qr = qslerp(q0, q1, s, shortest=shortest) pr = p0 * (1 - s) + s * p1 - return rt2tr(q2r(qr), pr) + return rt2tr(q2r(qunit(qr)), pr) else: return ValueError("Argument must be SO(3) or SE(3)") diff --git a/tests/test_pose3d.py b/tests/test_pose3d.py index 70b33ce0..fc9daf93 100755 --- a/tests/test_pose3d.py +++ b/tests/test_pose3d.py @@ -1389,6 +1389,46 @@ def test_rtvec(self): nt.assert_equal(rvec, [0, 1, 0]) nt.assert_equal(tvec, [2, 3, 4]) + def test_interp(self): + # This data is taken from https://github.com/bdaiinstitute/spatialmath-python/issues/165 + se3_1 = SE3() + se3_1.t = np.array( + [0.5705748101710814, 0.29623210833184527, 0.10764106509086407] + ) + se3_1.R = np.array( + [ + [0.2852875203191073, 0.9581330588259315, -0.024332536551692617], + [0.9582072394229962, -0.28568756930438033, -0.014882844564011068], + [-0.021211248608609852, -0.019069722856395098, -0.9995931315303468], + ] + ) + assert SE3.isvalid(se3_1.A) + + se3_2 = SE3() + se3_2.t = np.array( + [0.5150284150005691, 0.25796537207802533, 0.1558725490743694] + ) + se3_2.R = np.array( + [ + [0.42058255728234184, 0.9064420651629983, -0.038380919906699236], + [0.9070822373513454, -0.4209501599465646, -0.0016665901233428627], + [-0.01766712176680449, -0.0341137119645545, -0.9992617912561634], + ] + ) + assert SE3.isvalid(se3_2.A) + + path_se3 = se3_1.interp(end=se3_2, s=15, shortest=False) + + angle = None + for i in range(len(path_se3) - 1): + assert SE3.isvalid(path_se3[i].A) + + if angle is None: + angle = path_se3[i].angdist(path_se3[i + 1]) + else: + test_angle = path_se3[i].angdist(path_se3[i + 1]) + assert abs(test_angle - angle) < 1e-6 + # ---------------------------------------------------------------------------------------# if __name__ == "__main__":