Merge pull request #46 from OSUrobotics/behavior_cloning_state_dim
Behavior cloning state dim
jimzers authored Aug 13, 2021
2 parents 9d54e01 + b08ce60 commit 5333116
Showing 19 changed files with 1,295 additions and 905 deletions.
41 changes: 39 additions & 2 deletions gym-kinova-gripper/DDPGfD.py
@@ -56,6 +56,8 @@ def forward(self, state, action):

class DDPGfD(object):
def __init__(self, state_dim=82, action_dim=3, max_action=3, n=5, discount=0.995, tau=0.0005, batch_size=64, expert_sampling_proportion=0.7):
print('================================ INITTING DDPGfD with state dim of: ', state_dim,
'===============================')
self.actor = Actor(state_dim, action_dim, max_action).to(device)
self.actor_target = copy.deepcopy(self.actor)
self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=1e-4)
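(A usage sketch, not part of the diff: with the constructor signature shown above, an agent built for a sliced observation would be created as, e.g., agent = DDPGfD(state_dim=72, action_dim=3, max_action=3), so the actor and critic networks are sized to the reduced state; the value 72 is illustrative only.)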
@@ -92,7 +94,7 @@ def select_action(self, state):
return self.actor(state).cpu().data.numpy().flatten()


def train(self, episode_step, expert_replay_buffer, replay_buffer=None, prob=0.7):
def train(self, episode_step, expert_replay_buffer, replay_buffer=None, prob=0.7, mod_state_idx=np.arange(82)):
""" Update policy based on full trajectory of one episode """
self.total_it += 1

@@ -129,6 +131,21 @@ def train(self, episode_step, expert_replay_buffer, replay_buffer=None, prob=0.7
print("IN OG TRAIN: lift_reward_count: ", lift_reward_count)
"""


# # TODO: STATE DIM CODE, FOR NON-BATCHING TRAINING. NEEDS TO BE TESTED BEFORE USE. COMMENTED OUT UNTIL TESTED.
# print('=============== Start printing - BATCH training =======================')
# print('=============== Before modification =======================')
# print('state dimensions: ', state.shape)
# print('next state dimensions: ', next_state.shape)
# # modify state dimensions
# state = state[:, mod_state_idx]
# next_state = next_state[:, mod_state_idx]
#
# print('=============== After modification =======================')
# print('state dimensions: ', state.shape)
# print('next state dimensions: ', next_state.shape)
# print('=============== End printing - BATCH training =======================')

# Target Q network
#print("Target Q")
target_Q = self.critic_target(next_state, self.actor_target(next_state))
@@ -235,7 +252,7 @@ def train(self, episode_step, expert_replay_buffer, replay_buffer=None, prob=0.7
return actor_loss.item(), critic_loss.item(), critic_L1loss.item(), critic_LNloss.item()
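
(For context, a minimal standalone sketch, not part of the commit, of the per-step state slicing that the commented-out TODO block above would perform; the batch size of 64 and the 72-feature subset are illustrative assumptions:

import numpy as np
import torch

state = torch.randn(64, 82)        # hypothetical batch of 82-dim states
mod_state_idx = np.arange(72)      # illustrative: keep the first 72 features
state = state[:, mod_state_idx]    # advanced indexing selects those columns
print(state.shape)                 # torch.Size([64, 72])

PyTorch accepts a NumPy integer array as an index, so the same mod_state_idx default of np.arange(82) leaves the state untouched.)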


def train_batch(self, max_episode_num, episode_num, update_count, expert_replay_buffer, replay_buffer):
def train_batch(self, max_episode_num, episode_num, update_count, expert_replay_buffer, replay_buffer, mod_state_idx=np.arange(82)):
""" Update policy networks based on batch_size of episodes using n-step returns """
self.total_it += 1
agent_batch_size = 0
@@ -283,9 +300,29 @@ def train_batch(self, max_episode_num, episode_num, update_count, expert_replay_
reward = reward.unsqueeze(0)
not_done = not_done.unsqueeze(0)

expert_state = expert_state[:, :, mod_state_idx]
# expert_next_state = expert_next_state[:, :, mod_state_idx]

reward = reward.unsqueeze(-1)
not_done = not_done.unsqueeze(-1)

# STATE DIMENSION MODIFICATION
# print('=============== Start printing - BATCH training =======================')
# print('=============== Before modification =======================')
# print('state dimensions: ', state.shape)
# print('next state dimensions: ', next_state.shape)
# print('sanity check - mod_state_idx length: ', len(mod_state_idx))

# modify state dimensions
state = state[:, :, mod_state_idx]
next_state = next_state[:, :, mod_state_idx]


# print('=============== After modification =======================')
# print('state dimensions: ', state.shape)
# print('next state dimensions: ', next_state.shape)
# print('=============== End printing - BATCH training =======================')

### FOR TESTING:
#assert_batch_size = self.batch_size * num_trajectories
num_timesteps_sampled = len(reward)
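(By contrast, train_batch operates on n-step trajectories, so the state tensors carry an extra time dimension and are sliced on the last axis. A minimal sketch under assumed shapes, not from the commit:

import numpy as np
import torch

# hypothetical (batch_size, n_steps, state_dim) trajectory batch
state = torch.randn(64, 5, 82)
next_state = torch.randn(64, 5, 82)
mod_state_idx = np.arange(72)                  # illustrative feature subset
state = state[:, :, mod_state_idx]             # -> torch.Size([64, 5, 72])
next_state = next_state[:, :, mod_state_idx]   # same slice on the next states

Note the diff above applies the slice to expert_state but leaves expert_next_state commented out.)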

Large diffs are not rendered by default.

1,472 changes: 882 additions & 590 deletions gym-kinova-gripper/gym_kinova_gripper/envs/kinova_gripper_env.py

Large diffs are not rendered by default.

172 changes: 149 additions & 23 deletions gym-kinova-gripper/main_DDPGfD.py

Large diffs are not rendered by default.

Binary file added wiki_figures/DDPGfD_diagram.PNG
Binary file added wiki_figures/all_objects.PNG
Binary file added wiki_figures/all_possible_objects.PNG
Binary file added wiki_figures/grasp_trial.PNG
Binary file added wiki_figures/hov.PNG
Binary file added wiki_figures/input_variations.PNG
Binary file added wiki_figures/orientations.PNG
Binary file added wiki_figures/policy_training.PNG
Binary file added wiki_figures/sample_update.PNG
Binary file added wiki_figures/shapes_with_titles.PNG
Binary file added wiki_figures/sizes_of_the_object.PNG
Binary file added wiki_figures/touch_vel_PID_Variable_Speed.PNG
Binary file added wiki_figures/training_pipeline.PNG
Binary file added wiki_figures/velocity_pid_Variable_Speed.PNG
1 change: 1 addition & 0 deletions wiki_figures/wiki_text.txt
@@ -0,0 +1 @@
GitHub wiki images
