Skip to content

Commit

Permalink
Add rollout scripts for MujocoUR5eInsert task.
Browse files Browse the repository at this point in the history
  • Loading branch information
mmurooka committed Jan 4, 2025
1 parent f22c196 commit 261d7a0
Show file tree
Hide file tree
Showing 7 changed files with 67 additions and 6 deletions.
2 changes: 1 addition & 1 deletion robo_manip_baselines/act/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ RuntimeError: The size of tensor a (70) must match the size of tensor b (102) at
Run a trained policy:
```console
$ python ./bin/rollout/RolloutActMujocoUR5eCable.py \
--checkpoint ./log/<demo_name>>/policy_last.ckpt \
--checkpoint ./log/<demo_name>/policy_last.ckpt \
--skip 3 --world_idx 0
```

Expand Down
9 changes: 4 additions & 5 deletions robo_manip_baselines/act/bin/TrainAct.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,9 +156,9 @@ def train_bc(self):
min_val_loss,
deepcopy(self.policy.state_dict()),
)
summary_string = f"[TrainAct] val loss: {epoch_val_loss:.3f}"
summary_string = "[TrainAct][val]"
for k, v in epoch_summary.items():
summary_string += f", {k}: {v.item():.3f}"
summary_string += f" {k}: {v.item():.3f}"
print(summary_string)

# training
Expand All @@ -175,10 +175,9 @@ def train_bc(self):
epoch_summary = compute_dict_mean(
train_history[(batch_idx + 1) * epoch : (batch_idx + 1) * (epoch + 1)]
)
epoch_train_loss = epoch_summary["loss"]
summary_string = f"[TrainAct] train loss: {epoch_train_loss:.3f}"
summary_string = "[TrainAct][train]"
for k, v in epoch_summary.items():
summary_string += f", {k}: {v.item():.3f}"
summary_string += f" {k}: {v.item():.3f}"
print(summary_string)

if epoch % 100 == 0:
Expand Down
11 changes: 11 additions & 0 deletions robo_manip_baselines/act/bin/rollout/RolloutActMujocoUR5eInsert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from robo_manip_baselines.act import RolloutAct
from robo_manip_baselines.common.rollout import RolloutMujocoUR5eInsert


class RolloutActMujocoUR5eInsert(RolloutAct, RolloutMujocoUR5eInsert):
pass


if __name__ == "__main__":
rollout = RolloutActMujocoUR5eInsert()
rollout.run()
26 changes: 26 additions & 0 deletions robo_manip_baselines/common/rollout/RolloutMujocoUR5eInsert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import numpy as np
import pinocchio as pin
import gymnasium as gym
from robo_manip_baselines.common import MotionStatus
from .RolloutBase import RolloutBase


class RolloutMujocoUR5eInsert(RolloutBase):
def setup_env(self):
self.env = gym.make(
"robo_manip_baselines/MujocoUR5eInsertEnv-v0", render_mode="human"
)

def set_arm_command(self):
if self.data_manager.status in (MotionStatus.PRE_REACH, MotionStatus.REACH):
target_pos = self.env.unwrapped.get_body_pose("peg")[0:3]
if self.data_manager.status == MotionStatus.PRE_REACH:
target_pos[2] = 1.1 # [m]
elif self.data_manager.status == MotionStatus.REACH:
target_pos[2] = 1.03 # [m]
self.motion_manager.target_se3 = pin.SE3(
np.diag([-1.0, 1.0, -1.0]), target_pos
)
self.motion_manager.inverse_kinematics()
else:
super().set_arm_command()
1 change: 1 addition & 0 deletions robo_manip_baselines/common/rollout/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from .RolloutMujocoUR5eRing import RolloutMujocoUR5eRing
from .RolloutMujocoUR5eParticle import RolloutMujocoUR5eParticle
from .RolloutMujocoUR5eCloth import RolloutMujocoUR5eCloth
from .RolloutMujocoUR5eInsert import RolloutMujocoUR5eInsert

## Xarm7
from .RolloutMujocoXarm7Cable import RolloutMujocoXarm7Cable
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from robo_manip_baselines.diffusion_policy import RolloutDiffusionPolicy
from robo_manip_baselines.common.rollout import RolloutMujocoUR5eInsert


class RolloutDiffusionPolicyMujocoUR5eInsert(
RolloutDiffusionPolicy, RolloutMujocoUR5eInsert
):
pass


if __name__ == "__main__":
rollout = RolloutDiffusionPolicyMujocoUR5eInsert()
rollout.run()
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from robo_manip_baselines.sarnn import RolloutSarnn
from robo_manip_baselines.common.rollout import RolloutMujocoUR5eInsert


class RolloutSarnnMujocoUR5eInsert(RolloutSarnn, RolloutMujocoUR5eInsert):
pass


if __name__ == "__main__":
rollout = RolloutSarnnMujocoUR5eInsert()
rollout.run()

0 comments on commit 261d7a0

Please sign in to comment.