Skip to content

Commit

Permalink
update pose inverse for goalset
Browse files Browse the repository at this point in the history
  • Loading branch information
balakumar-s committed Feb 24, 2024
1 parent f25281e commit 286b382
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 23 deletions.
8 changes: 2 additions & 6 deletions src/curobo/geom/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -941,7 +941,6 @@ def forward(
adj_position: torch.Tensor,
adj_quaternion: torch.Tensor,
):
b, _ = position.shape

if out_position is None:
out_position = torch.zeros_like(position)
Expand All @@ -951,7 +950,8 @@ def forward(
adj_position = torch.zeros_like(position)
if adj_quaternion is None:
adj_quaternion = torch.zeros_like(quaternion)

b, _ = position.view(-1, 3).shape
ctx.b = b
init_warp()
ctx.save_for_backward(
position,
Expand All @@ -961,7 +961,6 @@ def forward(
adj_position,
adj_quaternion,
)
ctx.b = b

wp.launch(
kernel=compute_pose_inverse,
Expand All @@ -976,9 +975,6 @@ def forward(
],
stream=wp.stream_from_torch(position.device),
)
# remove close to zero values:
# out_position[torch.abs(out_position)<1e-8] = 0.0
# out_quaternion[torch.abs(out_quaternion)<1e-8] = 0.0

return out_position, out_quaternion

Expand Down
4 changes: 3 additions & 1 deletion src/curobo/types/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,9 @@ def linear_distance(self, other_pose: Pose):

@profiler.record_function("pose/multiply")
def multiply(self, other_pose: Pose):
if self.shape == other_pose.shape or (self.shape[0] == 1 and other_pose.shape[0] > 1):
if self.shape == other_pose.shape or (
(self.shape[0] == 1 and other_pose.shape[0] > 1) and len(other_pose.shape) == 2
):
p3, q3 = pose_multiply(
self.position, self.quaternion, other_pose.position, other_pose.quaternion
)
Expand Down
50 changes: 34 additions & 16 deletions tests/motion_gen_constrained_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,23 +30,29 @@ def motion_gen(request):
tensor_args = TensorDeviceType()
world_file = "collision_table.yml"
robot_file = "franka.yml"

motion_gen_config = MotionGenConfig.load_from_robot_config(
robot_file,
world_file,
tensor_args,
use_cuda_graph=True,
project_pose_to_goal_frame=request.param,
project_pose_to_goal_frame=request.param[0],
)
motion_gen_instance = MotionGen(motion_gen_config)
motion_gen_instance.warmup(enable_graph=False, warmup_js_trajopt=False)

motion_gen_instance.warmup(
enable_graph=False, warmup_js_trajopt=False, n_goalset=request.param[1]
)
return motion_gen_instance


@pytest.mark.parametrize(
"motion_gen",
[
(True),
(False),
([True, -1]),
([False, -1]),
([True, 10]),
([False, 10]),
],
indirect=True,
)
Expand Down Expand Up @@ -77,8 +83,10 @@ def test_approach_grasp_pose(motion_gen):
@pytest.mark.parametrize(
"motion_gen",
[
(True),
(False),
([True, -1]),
([False, -1]),
([True, 10]),
([False, 10]),
],
indirect=True,
)
Expand Down Expand Up @@ -112,8 +120,10 @@ def test_reach_only_position(motion_gen):
@pytest.mark.parametrize(
"motion_gen",
[
(True),
(False),
([True, -1]),
([False, -1]),
([True, 10]),
([False, 10]),
],
indirect=True,
)
Expand Down Expand Up @@ -147,8 +157,10 @@ def test_reach_only_orientation(motion_gen):
@pytest.mark.parametrize(
"motion_gen",
[
(True),
(False),
([True, -1]),
([False, -1]),
([True, 10]),
([False, 10]),
],
indirect=True,
)
Expand Down Expand Up @@ -186,8 +198,10 @@ def test_hold_orientation(motion_gen):
@pytest.mark.parametrize(
"motion_gen",
[
(True),
(False),
([True, -1]),
([False, -1]),
([True, 10]),
([False, 10]),
],
indirect=True,
)
Expand Down Expand Up @@ -224,8 +238,10 @@ def test_hold_position(motion_gen):
@pytest.mark.parametrize(
"motion_gen",
[
(False),
(True),
([True, -1]),
([False, -1]),
([True, 10]),
([False, 10]),
],
indirect=True,
)
Expand Down Expand Up @@ -276,8 +292,10 @@ def test_hold_partial_pose(motion_gen):
@pytest.mark.parametrize(
"motion_gen",
[
(False),
(True),
([True, -1]),
([False, -1]),
([True, 10]),
([False, 10]),
],
indirect=True,
)
Expand Down

0 comments on commit 286b382

Please sign in to comment.