Skip to content

Commit

Permalink
#165 use qunit in trinterp (#166)
Browse files Browse the repository at this point in the history
Co-authored-by: Mark Yeatman <myeatman@theaiinstitute.com>
  • Loading branch information
tweng-bdai and myeatman-bdai authored Feb 19, 2025
1 parent 4c68fa9 commit 550d6fb
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 3 deletions.
6 changes: 3 additions & 3 deletions spatialmath/base/transforms3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
tr2rt,
Ab2M,
)
from spatialmath.base.quaternions import r2q, q2r, qeye, qslerp
from spatialmath.base.quaternions import r2q, q2r, qeye, qslerp, qunit
from spatialmath.base.graphics import plotvol3, axes_logic
from spatialmath.base.animate import Animate
import spatialmath.base.symbolic as sym
Expand Down Expand Up @@ -1675,7 +1675,7 @@ def trinterp(start, end, s, shortest=True):
q1 = r2q(end)
qr = qslerp(q0, q1, s, shortest=shortest)

return q2r(qr)
return q2r(qunit(qr))

elif ismatrix(end, (4, 4)):
# SE(3) case
Expand All @@ -1697,7 +1697,7 @@ def trinterp(start, end, s, shortest=True):
qr = qslerp(q0, q1, s, shortest=shortest)
pr = p0 * (1 - s) + s * p1

return rt2tr(q2r(qr), pr)
return rt2tr(q2r(qunit(qr)), pr)
else:
return ValueError("Argument must be SO(3) or SE(3)")

Expand Down
40 changes: 40 additions & 0 deletions tests/test_pose3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -1389,6 +1389,46 @@ def test_rtvec(self):
nt.assert_equal(rvec, [0, 1, 0])
nt.assert_equal(tvec, [2, 3, 4])

def test_interp(self):
# This data is taken from https://github.com/bdaiinstitute/spatialmath-python/issues/165
se3_1 = SE3()
se3_1.t = np.array(
[0.5705748101710814, 0.29623210833184527, 0.10764106509086407]
)
se3_1.R = np.array(
[
[0.2852875203191073, 0.9581330588259315, -0.024332536551692617],
[0.9582072394229962, -0.28568756930438033, -0.014882844564011068],
[-0.021211248608609852, -0.019069722856395098, -0.9995931315303468],
]
)
assert SE3.isvalid(se3_1.A)

se3_2 = SE3()
se3_2.t = np.array(
[0.5150284150005691, 0.25796537207802533, 0.1558725490743694]
)
se3_2.R = np.array(
[
[0.42058255728234184, 0.9064420651629983, -0.038380919906699236],
[0.9070822373513454, -0.4209501599465646, -0.0016665901233428627],
[-0.01766712176680449, -0.0341137119645545, -0.9992617912561634],
]
)
assert SE3.isvalid(se3_2.A)

path_se3 = se3_1.interp(end=se3_2, s=15, shortest=False)

angle = None
for i in range(len(path_se3) - 1):
assert SE3.isvalid(path_se3[i].A)

if angle is None:
angle = path_se3[i].angdist(path_se3[i + 1])
else:
test_angle = path_se3[i].angdist(path_se3[i + 1])
assert abs(test_angle - angle) < 1e-6


# ---------------------------------------------------------------------------------------#
if __name__ == "__main__":
Expand Down

0 comments on commit 550d6fb

Please sign in to comment.