From 013de4380871e8486b1f67630794e6728bd49666 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Thu, 18 Apr 2024 16:18:04 +0200 Subject: [PATCH] Use structural pattern matching --- src/rod/kinematics/tree_transforms.py | 90 ++++++------- src/rod/utils/frame_convention.py | 185 +++++++++++++------------- 2 files changed, 140 insertions(+), 135 deletions(-) diff --git a/src/rod/kinematics/tree_transforms.py b/src/rod/kinematics/tree_transforms.py index 2783d0d..b5bb337 100644 --- a/src/rod/kinematics/tree_transforms.py +++ b/src/rod/kinematics/tree_transforms.py @@ -31,51 +31,51 @@ def build( ) def transform(self, name: str) -> npt.NDArray: - if name == TreeFrame.WORLD: - return np.eye(4) - - if name in {TreeFrame.MODEL, self.kinematic_tree.model.name}: - relative_to = self.kinematic_tree.model.pose.relative_to - assert relative_to in {None, ""}, (relative_to, name) - return self.kinematic_tree.model.pose.transform() - - if name in self.kinematic_tree.joint_names(): - edge = self.kinematic_tree.joints_dict[name] - assert edge.name() == name - - # Get the pose of the frame in which the node's pose is expressed - assert edge._source.pose.relative_to not in {"", None} - x_H_E = edge._source.pose.transform() - W_H_x = self.transform(name=edge._source.pose.relative_to) - - # Compute the world-to-node transform - # TODO: this assumes all joint positions to be 0 - W_H_E = W_H_x @ x_H_E - - return W_H_E - - if ( - name in self.kinematic_tree.link_names() - or name in self.kinematic_tree.frame_names() - ): - element = ( - self.kinematic_tree.links_dict[name] - if name in self.kinematic_tree.link_names() - else self.kinematic_tree.frames_dict[name] - ) - assert element.name() == name - - # Get the pose of the frame in which the node's pose is expressed - assert element._source.pose.relative_to not in {"", None} - x_H_N = element._source.pose.transform() - W_H_x = self.transform(name=element._source.pose.relative_to) - - # Compute and cache the world-to-node transform - W_H_N = W_H_x @ x_H_N - - return W_H_N - - raise ValueError(name) + match name: + case TreeFrame.WORLD: + return np.eye(4) + + case name if name in {TreeFrame.MODEL, self.kinematic_tree.model.name}: + relative_to = self.kinematic_tree.model.pose.relative_to + assert relative_to in {None, ""}, (relative_to, name) + return self.kinematic_tree.model.pose.transform() + + case name if name in self.kinematic_tree.joint_names(): + edge = self.kinematic_tree.joints_dict[name] + assert edge.name() == name + + # Get the pose of the frame in which the node's pose is expressed + assert edge._source.pose.relative_to not in {"", None} + x_H_E = edge._source.pose.transform() + W_H_x = self.transform(name=edge._source.pose.relative_to) + + # Compute the world-to-node transform + # TODO: this assumes all joint positions to be 0 + W_H_E = W_H_x @ x_H_E + + return W_H_E + + case ( + name + ) if name in self.kinematic_tree.link_names() or name in self.kinematic_tree.frame_names(): + element = ( + self.kinematic_tree.links_dict[name] + if name in self.kinematic_tree.link_names() + else self.kinematic_tree.frames_dict[name] + ) + assert element.name() == name + + # Get the pose of the frame in which the node's pose is expressed + assert element._source.pose.relative_to not in {"", None} + x_H_N = element._source.pose.transform() + W_H_x = self.transform(name=element._source.pose.relative_to) + + # Compute and cache the world-to-node transform + W_H_N = W_H_x @ x_H_N + + return W_H_N + case _: + raise ValueError(name) def relative_transform(self, relative_to: str, name: str) -> npt.NDArray: return np.linalg.inv(self.transform(name=relative_to)) @ self.transform( diff --git a/src/rod/utils/frame_convention.py b/src/rod/utils/frame_convention.py index 8f2eba8..85ac89b 100644 --- a/src/rod/utils/frame_convention.py +++ b/src/rod/utils/frame_convention.py @@ -24,97 +24,102 @@ def switch_frame_convention( # Define the default reference frames of the different elements # ============================================================= - if frame_convention is FrameConvention.World: - reference_frame_model = lambda m: "world" - reference_frame_links = lambda l: "world" - reference_frame_frames = lambda f: "world" - reference_frame_joints = lambda j: "world" - reference_frame_visuals = lambda v: "world" - reference_frame_inertials = lambda i, parent_link: "world" - reference_frame_collisions = lambda c: "world" - reference_frame_link_canonical = "world" - - elif frame_convention is FrameConvention.Model: - reference_frame_model = lambda m: "world" - reference_frame_links = lambda l: "__model__" - reference_frame_frames = lambda f: "__model__" - reference_frame_joints = lambda j: "__model__" - reference_frame_visuals = lambda v: "__model__" - reference_frame_inertials = lambda i, parent_link: "__model__" - reference_frame_collisions = lambda c: "__model__" - reference_frame_link_canonical = "__model__" - - elif frame_convention is FrameConvention.Sdf: - visual_name_to_parent_link = { - visual_name: parent_link - for d in [{v.name: link for v in link.visuals()} for link in model.links()] - for visual_name, parent_link in d.items() - } - - collision_name_to_parent_link = { - collision_name: parent_link - for d in [ - {c.name: link for c in link.collisions()} for link in model.links() - ] - for collision_name, parent_link in d.items() - } - - reference_frame_model = lambda m: "world" - reference_frame_links = lambda l: "__model__" - reference_frame_frames = lambda f: "__model__" - reference_frame_joints = lambda j: joint.child - reference_frame_visuals = lambda v: visual_name_to_parent_link[v.name].name - reference_frame_inertials = lambda i, parent_link: parent_link.name - reference_frame_collisions = lambda c: collision_name_to_parent_link[ - c.name - ].name - reference_frame_link_canonical = "__model__" - - elif frame_convention is FrameConvention.Urdf: - visual_name_to_parent_link = { - visual_name: parent_link - for d in [{v.name: link for v in link.visuals()} for link in model.links()] - for visual_name, parent_link in d.items() - } - - collision_name_to_parent_link = { - collision_name: parent_link - for d in [ - {c.name: link for c in link.collisions()} for link in model.links() - ] - for collision_name, parent_link in d.items() - } - - link_name_to_parent_joint_names = defaultdict(list) - - for j in model.joints(): - if j.child != model.get_canonical_link(): - link_name_to_parent_joint_names[j.child].append(j.name) - else: - # The pose of the canonical link is used to define the origin of - # the URDF joint connecting the world to the robot - assert model.is_fixed_base() - link_name_to_parent_joint_names[j.child].append("world") - - reference_frame_model = lambda m: "world" - reference_frame_links = lambda l: link_name_to_parent_joint_names[l.name][0] - reference_frame_frames = lambda f: "__model__" - reference_frame_joints = lambda j: j.parent - reference_frame_visuals = lambda v: visual_name_to_parent_link[v.name].name - reference_frame_inertials = lambda i, parent_link: parent_link.name - reference_frame_collisions = lambda c: collision_name_to_parent_link[ - c.name - ].name - - if model.is_fixed_base(): - canonical_link = {l.name: l for l in model.links()}[ - model.get_canonical_link() - ] - reference_frame_link_canonical = reference_frame_links(l=canonical_link) - else: + match frame_convention: + case FrameConvention.World: + reference_frame_model = lambda m: "world" + reference_frame_links = lambda l: "world" + reference_frame_frames = lambda f: "world" + reference_frame_joints = lambda j: "world" + reference_frame_visuals = lambda v: "world" + reference_frame_inertials = lambda i, parent_link: "world" + reference_frame_collisions = lambda c: "world" + reference_frame_link_canonical = "world" + + case FrameConvention.Model: + reference_frame_model = lambda m: "world" + reference_frame_links = lambda l: "__model__" + reference_frame_frames = lambda f: "__model__" + reference_frame_joints = lambda j: "__model__" + reference_frame_visuals = lambda v: "__model__" + reference_frame_inertials = lambda i, parent_link: "__model__" + reference_frame_collisions = lambda c: "__model__" reference_frame_link_canonical = "__model__" - else: - raise ValueError(frame_convention) + + case FrameConvention.Sdf: + visual_name_to_parent_link = { + visual_name: parent_link + for d in [ + {v.name: link for v in link.visuals()} for link in model.links() + ] + for visual_name, parent_link in d.items() + } + + collision_name_to_parent_link = { + collision_name: parent_link + for d in [ + {c.name: link for c in link.collisions()} for link in model.links() + ] + for collision_name, parent_link in d.items() + } + + reference_frame_model = lambda m: "world" + reference_frame_links = lambda l: "__model__" + reference_frame_frames = lambda f: "__model__" + reference_frame_joints = lambda j: joint.child + reference_frame_visuals = lambda v: visual_name_to_parent_link[v.name].name + reference_frame_inertials = lambda i, parent_link: parent_link.name + reference_frame_collisions = lambda c: collision_name_to_parent_link[ + c.name + ].name + reference_frame_link_canonical = "__model__" + + case FrameConvention.Urdf: + visual_name_to_parent_link = { + visual_name: parent_link + for d in [ + {v.name: link for v in link.visuals()} for link in model.links() + ] + for visual_name, parent_link in d.items() + } + + collision_name_to_parent_link = { + collision_name: parent_link + for d in [ + {c.name: link for c in link.collisions()} for link in model.links() + ] + for collision_name, parent_link in d.items() + } + + link_name_to_parent_joint_names = defaultdict(list) + + for j in model.joints(): + if j.child != model.get_canonical_link(): + link_name_to_parent_joint_names[j.child].append(j.name) + else: + # The pose of the canonical link is used to define the origin of + # the URDF joint connecting the world to the robot + assert model.is_fixed_base() + link_name_to_parent_joint_names[j.child].append("world") + + reference_frame_model = lambda m: "world" + reference_frame_links = lambda l: link_name_to_parent_joint_names[l.name][0] + reference_frame_frames = lambda f: "__model__" + reference_frame_joints = lambda j: j.parent + reference_frame_visuals = lambda v: visual_name_to_parent_link[v.name].name + reference_frame_inertials = lambda i, parent_link: parent_link.name + reference_frame_collisions = lambda c: collision_name_to_parent_link[ + c.name + ].name + + if model.is_fixed_base(): + canonical_link = {l.name: l for l in model.links()}[ + model.get_canonical_link() + ] + reference_frame_link_canonical = reference_frame_links(l=canonical_link) + else: + reference_frame_link_canonical = "__model__" + case _: + raise ValueError(frame_convention) # ========================================= # Process the reference frames of the model