Skip to content

Commit

Permalink
Use structural pattern matching
Browse files Browse the repository at this point in the history
  • Loading branch information
flferretti committed Apr 18, 2024
1 parent f63965c commit 013de43
Show file tree
Hide file tree
Showing 2 changed files with 140 additions and 135 deletions.
90 changes: 45 additions & 45 deletions src/rod/kinematics/tree_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,51 +31,51 @@ def build(
)

def transform(self, name: str) -> npt.NDArray:
if name == TreeFrame.WORLD:
return np.eye(4)

if name in {TreeFrame.MODEL, self.kinematic_tree.model.name}:
relative_to = self.kinematic_tree.model.pose.relative_to
assert relative_to in {None, ""}, (relative_to, name)
return self.kinematic_tree.model.pose.transform()

if name in self.kinematic_tree.joint_names():
edge = self.kinematic_tree.joints_dict[name]
assert edge.name() == name

# Get the pose of the frame in which the node's pose is expressed
assert edge._source.pose.relative_to not in {"", None}
x_H_E = edge._source.pose.transform()
W_H_x = self.transform(name=edge._source.pose.relative_to)

# Compute the world-to-node transform
# TODO: this assumes all joint positions to be 0
W_H_E = W_H_x @ x_H_E

return W_H_E

if (
name in self.kinematic_tree.link_names()
or name in self.kinematic_tree.frame_names()
):
element = (
self.kinematic_tree.links_dict[name]
if name in self.kinematic_tree.link_names()
else self.kinematic_tree.frames_dict[name]
)
assert element.name() == name

# Get the pose of the frame in which the node's pose is expressed
assert element._source.pose.relative_to not in {"", None}
x_H_N = element._source.pose.transform()
W_H_x = self.transform(name=element._source.pose.relative_to)

# Compute and cache the world-to-node transform
W_H_N = W_H_x @ x_H_N

return W_H_N

raise ValueError(name)
match name:
case TreeFrame.WORLD:
return np.eye(4)

case name if name in {TreeFrame.MODEL, self.kinematic_tree.model.name}:
relative_to = self.kinematic_tree.model.pose.relative_to
assert relative_to in {None, ""}, (relative_to, name)
return self.kinematic_tree.model.pose.transform()

case name if name in self.kinematic_tree.joint_names():
edge = self.kinematic_tree.joints_dict[name]
assert edge.name() == name

# Get the pose of the frame in which the node's pose is expressed
assert edge._source.pose.relative_to not in {"", None}
x_H_E = edge._source.pose.transform()
W_H_x = self.transform(name=edge._source.pose.relative_to)

# Compute the world-to-node transform
# TODO: this assumes all joint positions to be 0
W_H_E = W_H_x @ x_H_E

return W_H_E

case (
name
) if name in self.kinematic_tree.link_names() or name in self.kinematic_tree.frame_names():
element = (
self.kinematic_tree.links_dict[name]
if name in self.kinematic_tree.link_names()
else self.kinematic_tree.frames_dict[name]
)
assert element.name() == name

# Get the pose of the frame in which the node's pose is expressed
assert element._source.pose.relative_to not in {"", None}
x_H_N = element._source.pose.transform()
W_H_x = self.transform(name=element._source.pose.relative_to)

# Compute and cache the world-to-node transform
W_H_N = W_H_x @ x_H_N

return W_H_N
case _:
raise ValueError(name)

def relative_transform(self, relative_to: str, name: str) -> npt.NDArray:
return np.linalg.inv(self.transform(name=relative_to)) @ self.transform(
Expand Down
185 changes: 95 additions & 90 deletions src/rod/utils/frame_convention.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,97 +24,102 @@ def switch_frame_convention(
# Define the default reference frames of the different elements
# =============================================================

if frame_convention is FrameConvention.World:
reference_frame_model = lambda m: "world"
reference_frame_links = lambda l: "world"
reference_frame_frames = lambda f: "world"
reference_frame_joints = lambda j: "world"
reference_frame_visuals = lambda v: "world"
reference_frame_inertials = lambda i, parent_link: "world"
reference_frame_collisions = lambda c: "world"
reference_frame_link_canonical = "world"

elif frame_convention is FrameConvention.Model:
reference_frame_model = lambda m: "world"
reference_frame_links = lambda l: "__model__"
reference_frame_frames = lambda f: "__model__"
reference_frame_joints = lambda j: "__model__"
reference_frame_visuals = lambda v: "__model__"
reference_frame_inertials = lambda i, parent_link: "__model__"
reference_frame_collisions = lambda c: "__model__"
reference_frame_link_canonical = "__model__"

elif frame_convention is FrameConvention.Sdf:
visual_name_to_parent_link = {
visual_name: parent_link
for d in [{v.name: link for v in link.visuals()} for link in model.links()]
for visual_name, parent_link in d.items()
}

collision_name_to_parent_link = {
collision_name: parent_link
for d in [
{c.name: link for c in link.collisions()} for link in model.links()
]
for collision_name, parent_link in d.items()
}

reference_frame_model = lambda m: "world"
reference_frame_links = lambda l: "__model__"
reference_frame_frames = lambda f: "__model__"
reference_frame_joints = lambda j: joint.child
reference_frame_visuals = lambda v: visual_name_to_parent_link[v.name].name
reference_frame_inertials = lambda i, parent_link: parent_link.name
reference_frame_collisions = lambda c: collision_name_to_parent_link[
c.name
].name
reference_frame_link_canonical = "__model__"

elif frame_convention is FrameConvention.Urdf:
visual_name_to_parent_link = {
visual_name: parent_link
for d in [{v.name: link for v in link.visuals()} for link in model.links()]
for visual_name, parent_link in d.items()
}

collision_name_to_parent_link = {
collision_name: parent_link
for d in [
{c.name: link for c in link.collisions()} for link in model.links()
]
for collision_name, parent_link in d.items()
}

link_name_to_parent_joint_names = defaultdict(list)

for j in model.joints():
if j.child != model.get_canonical_link():
link_name_to_parent_joint_names[j.child].append(j.name)
else:
# The pose of the canonical link is used to define the origin of
# the URDF joint connecting the world to the robot
assert model.is_fixed_base()
link_name_to_parent_joint_names[j.child].append("world")

reference_frame_model = lambda m: "world"
reference_frame_links = lambda l: link_name_to_parent_joint_names[l.name][0]
reference_frame_frames = lambda f: "__model__"
reference_frame_joints = lambda j: j.parent
reference_frame_visuals = lambda v: visual_name_to_parent_link[v.name].name
reference_frame_inertials = lambda i, parent_link: parent_link.name
reference_frame_collisions = lambda c: collision_name_to_parent_link[
c.name
].name

if model.is_fixed_base():
canonical_link = {l.name: l for l in model.links()}[
model.get_canonical_link()
]
reference_frame_link_canonical = reference_frame_links(l=canonical_link)
else:
match frame_convention:
case FrameConvention.World:
reference_frame_model = lambda m: "world"
reference_frame_links = lambda l: "world"
reference_frame_frames = lambda f: "world"
reference_frame_joints = lambda j: "world"
reference_frame_visuals = lambda v: "world"
reference_frame_inertials = lambda i, parent_link: "world"
reference_frame_collisions = lambda c: "world"
reference_frame_link_canonical = "world"

case FrameConvention.Model:
reference_frame_model = lambda m: "world"
reference_frame_links = lambda l: "__model__"
reference_frame_frames = lambda f: "__model__"
reference_frame_joints = lambda j: "__model__"
reference_frame_visuals = lambda v: "__model__"
reference_frame_inertials = lambda i, parent_link: "__model__"
reference_frame_collisions = lambda c: "__model__"
reference_frame_link_canonical = "__model__"
else:
raise ValueError(frame_convention)

case FrameConvention.Sdf:
visual_name_to_parent_link = {
visual_name: parent_link
for d in [
{v.name: link for v in link.visuals()} for link in model.links()
]
for visual_name, parent_link in d.items()
}

collision_name_to_parent_link = {
collision_name: parent_link
for d in [
{c.name: link for c in link.collisions()} for link in model.links()
]
for collision_name, parent_link in d.items()
}

reference_frame_model = lambda m: "world"
reference_frame_links = lambda l: "__model__"
reference_frame_frames = lambda f: "__model__"
reference_frame_joints = lambda j: joint.child
reference_frame_visuals = lambda v: visual_name_to_parent_link[v.name].name
reference_frame_inertials = lambda i, parent_link: parent_link.name
reference_frame_collisions = lambda c: collision_name_to_parent_link[
c.name
].name
reference_frame_link_canonical = "__model__"

case FrameConvention.Urdf:
visual_name_to_parent_link = {
visual_name: parent_link
for d in [
{v.name: link for v in link.visuals()} for link in model.links()
]
for visual_name, parent_link in d.items()
}

collision_name_to_parent_link = {
collision_name: parent_link
for d in [
{c.name: link for c in link.collisions()} for link in model.links()
]
for collision_name, parent_link in d.items()
}

link_name_to_parent_joint_names = defaultdict(list)

for j in model.joints():
if j.child != model.get_canonical_link():
link_name_to_parent_joint_names[j.child].append(j.name)
else:
# The pose of the canonical link is used to define the origin of
# the URDF joint connecting the world to the robot
assert model.is_fixed_base()
link_name_to_parent_joint_names[j.child].append("world")

reference_frame_model = lambda m: "world"
reference_frame_links = lambda l: link_name_to_parent_joint_names[l.name][0]
reference_frame_frames = lambda f: "__model__"
reference_frame_joints = lambda j: j.parent
reference_frame_visuals = lambda v: visual_name_to_parent_link[v.name].name
reference_frame_inertials = lambda i, parent_link: parent_link.name
reference_frame_collisions = lambda c: collision_name_to_parent_link[
c.name
].name

if model.is_fixed_base():
canonical_link = {l.name: l for l in model.links()}[
model.get_canonical_link()
]
reference_frame_link_canonical = reference_frame_links(l=canonical_link)
else:
reference_frame_link_canonical = "__model__"
case _:
raise ValueError(frame_convention)

# =========================================
# Process the reference frames of the model
Expand Down

0 comments on commit 013de43

Please sign in to comment.