1D rpm kin PPO example tested
JacopoPan committed Nov 20, 2023
1 parent ec44b64 commit b5d3549
Showing 4 changed files with 18 additions and 23 deletions.
4 changes: 2 additions & 2 deletions gym_pybullet_drones/envs/BaseRLAviary.py
@@ -226,14 +226,14 @@ def _preprocessAction(self,
                 rpm[k,:] = np.repeat(self.HOVER_RPM * (1+0.05*target), 4)
             elif self.ACT_TYPE == ActionType.ONE_D_PID:
                 state = self._getDroneStateVector(k)
-                rpm, _, _ = self.ctrl[k].computeControl(control_timestep=self.CTRL_TIMESTEP,
+                res, _, _ = self.ctrl[k].computeControl(control_timestep=self.CTRL_TIMESTEP,
                                                         cur_pos=state[0:3],
                                                         cur_quat=state[3:7],
                                                         cur_vel=state[10:13],
                                                         cur_ang_vel=state[13:16],
                                                         target_pos=state[0:3]+0.1*np.array([0,0,target[0]])
                                                         )
-                rpm[k,:] = rpm
+                rpm[k,:] = res
             else:
                 print("[ERROR] in BaseRLAviary._preprocessAction()")
                 exit()
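Note: the rename from `rpm` to `res` fixes a variable-shadowing bug in the ONE_D_PID branch: the controller's 4-element output was bound to the same name as the (NUM_DRONES, 4) action array, so the row assignment `rpm[k,:] = rpm` then indexed the wrong object. A minimal sketch of the pattern, with a hypothetical `compute_control` standing in for `self.ctrl[k].computeControl`:

import numpy as np

def buggy(num_drones, compute_control):
    rpm = np.zeros((num_drones, 4))
    for k in range(num_drones):
        rpm = compute_control(k)   # rebinds `rpm` to a 4-vector, losing the array
        rpm[k, :] = rpm            # indexing a 1-D array with [k, :] raises IndexError
    return rpm

def fixed(num_drones, compute_control):
    rpm = np.zeros((num_drones, 4))
    for k in range(num_drones):
        res = compute_control(k)   # distinct name keeps the action array intact
        rpm[k, :] = res
    return rpm

# fixed(2, lambda k: np.full(4, 14468.0)) returns a (2, 4) array of identical RPMs;
# buggy(2, lambda k: np.full(4, 14468.0)) raises IndexError on the first iteration.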
6 changes: 3 additions & 3 deletions gym_pybullet_drones/envs/HoverAviary.py
@@ -14,7 +14,7 @@ def __init__(self,
                  initial_rpys=None,
                  physics: Physics=Physics.PYB,
                  pyb_freq: int = 240,
-                 ctrl_freq: int = 240,
+                 ctrl_freq: int = 30,
                  gui=False,
                  record=False,
                  obs: ObservationType=ObservationType.KIN,
@@ -74,7 +74,7 @@ def _computeReward(self):
"""
state = self._getDroneStateVector(0)
ret = max(0, 500 - np.linalg.norm(self.TARGET_POS-state[0:3])**2)
ret = max(0, 10 - np.linalg.norm(self.TARGET_POS-state[0:3])**4)
return ret

################################################################################
@@ -89,7 +89,7 @@ def _computeTerminated(self):
"""
state = self._getDroneStateVector(0)
if np.linalg.norm(self.TARGET_POS-state[0:3]) < .001:
if np.linalg.norm(self.TARGET_POS-state[0:3]) < .0001:
return True
else:
return False
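Note: with `pyb_freq` left at 240 Hz, lowering `ctrl_freq` to 30 Hz gives each RL action 8 physics substeps. The reward switches from a quadratic penalty capped at 500 to a quartic penalty capped at 10, so it is nonzero only within 10**0.25 ≈ 1.78 m of the target (versus ≈ 22.4 m before), and `_computeTerminated()` tightens its distance threshold tenfold. A quick check of the new reward, with made-up positions and an assumed target:

import numpy as np

TARGET_POS = np.array([0., 0., 1.])  # assumed hover target, for illustration only
for pos in ([0., 0., 1.], [0., 0., .5], [0., 0., 2.], [0., 0., 3.]):
    d = np.linalg.norm(TARGET_POS - np.array(pos))
    print(d, max(0, 10 - d**4))
# 0.0 -> 10.0, 0.5 -> 9.9375, 1.0 -> 9.0, 2.0 -> 0 (outside the ~1.78 m support)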
6 changes: 3 additions & 3 deletions gym_pybullet_drones/envs/LeaderFollowerAviary.py
@@ -16,7 +16,7 @@ def __init__(self,
                  initial_rpys=None,
                  physics: Physics=Physics.PYB,
                  pyb_freq: int = 240,
-                 ctrl_freq: int = 240,
+                 ctrl_freq: int = 30,
                  gui=False,
                  record=False,
                  obs: ObservationType=ObservationType.KIN,
@@ -81,9 +81,9 @@ def _computeReward(self):
"""
states = np.array([self._getDroneStateVector(i) for i in range(self.NUM_DRONES)])
ret = max(0, 500 - np.linalg.norm(self.TARGET_POS-states[0, 0:3])**2)
ret = max(0, 100 - np.linalg.norm(self.TARGET_POS-states[0, 0:3])**2)
for i in range(1, self.NUM_DRONES):
ret += max(0, 100 - np.linalg.norm(states[i-1, 3]-states[i, 3])**2)
ret += max(0, 10 - np.linalg.norm(states[i-1, 3]-states[i, 3])**2)
return ret

################################################################################
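Note: the rescaling keeps the leader's target-tracking term (now capped at 100) dominant over the per-follower terms (now capped at 10 each). A toy evaluation of the new reward; the 20-element state layout and all values are illustrative only:

import numpy as np

TARGET_POS = np.array([0., 0., 1.])    # illustrative target
NUM_DRONES = 3
states = np.zeros((NUM_DRONES, 20))    # toy kinematic state vectors (layout assumed)
states[0, 0:3] = [0., 0., .9]          # leader 0.1 m below the target
ret = max(0, 100 - np.linalg.norm(TARGET_POS - states[0, 0:3])**2)
for i in range(1, NUM_DRONES):
    ret += max(0, 10 - np.linalg.norm(states[i-1, 3] - states[i, 3])**2)
print(ret)                             # 99.99 + 10 + 10 = 119.99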
25 changes: 10 additions & 15 deletions gym_pybullet_drones/examples/learn.py
@@ -1,6 +1,6 @@
"""Script demonstrating the use of `gym_pybullet_drones`'s Gymnasium interface.
Class HoverAviary is used as a learning env for the PPO algorithm.
Classes HoverAviary and LeaderFollowerAviary are used as learning envs for the PPO algorithm.
Example
-------
@@ -38,7 +38,7 @@
 DEFAULT_COLAB = False
 
 DEFAULT_OBS = ObservationType('kin') # 'kin' or 'rgb'
-DEFAULT_ACT = ActionType('vel') # 'rpm' or 'pid' or 'vel' or 'one_d_rpm' / TO BE FIXED: 'one_d_pid'
+DEFAULT_ACT = ActionType('one_d_rpm') # 'rpm' or 'pid' or 'vel' or 'one_d_rpm' or 'one_d_pid'
 DEFAULT_AGENTS = 3
 DEFAULT_MA = False
 
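Note: with the default action switched to `one_d_rpm`, the policy outputs a single scalar per drone and `BaseRLAviary._preprocessAction()` maps it to four identical motor speeds around hover (see the first hunk above). A sketch with a placeholder hover RPM:

import numpy as np

HOVER_RPM = 14468.0  # placeholder; the real value comes from the drone model
target = 0.2         # scalar action from the policy
rpm = np.repeat(HOVER_RPM * (1 + 0.05 * target), 4)
print(rpm)           # four identical RPMs, 1% above hover for target=0.2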
@@ -74,17 +74,17 @@ def run(output_folder=DEFAULT_OUTPUT_FOLDER, gui=DEFAULT_GUI, plot=True, colab=D
                 # tensorboard_log=filename+'/tb/',
                 verbose=1)
 
-    callback_on_best = StopTrainingOnRewardThreshold(reward_threshold=1000000,
+    callback_on_best = StopTrainingOnRewardThreshold(reward_threshold=np.inf,
                                                      verbose=1)
     eval_callback = EvalCallback(eval_env,
                                  callback_on_new_best=callback_on_best,
                                  verbose=1,
                                  best_model_save_path=filename+'/',
                                  log_path=filename+'/',
-                                 eval_freq=int(1000),
+                                 eval_freq=int(2000),
                                  deterministic=True,
                                  render=False)
-    model.learn(total_timesteps=10000, #int(1e12),
+    model.learn(total_timesteps=int(1e6),
                 callback=eval_callback,
                 log_interval=100)
 
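Note: `reward_threshold=np.inf` means `StopTrainingOnRewardThreshold` never halts training early, so the run always lasts the full `int(1e6)` steps while `EvalCallback` evaluates every 2000 steps and checkpoints the best model. A condensed sketch of the pattern; the environment constructor arguments and import paths are assumed from this commit, and 'results/' stands in for the script's `filename`:

import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnRewardThreshold
from gym_pybullet_drones.envs.HoverAviary import HoverAviary
from gym_pybullet_drones.utils.enums import ObservationType, ActionType

train_env = HoverAviary(obs=ObservationType.KIN, act=ActionType.ONE_D_RPM)
eval_env = HoverAviary(obs=ObservationType.KIN, act=ActionType.ONE_D_RPM)
model = PPO('MlpPolicy', train_env, verbose=1)
stop_cb = StopTrainingOnRewardThreshold(reward_threshold=np.inf, verbose=1)  # never triggers
eval_cb = EvalCallback(eval_env,
                       callback_on_new_best=stop_cb,
                       best_model_save_path='results/',  # writes best_model.zip on each new best
                       log_path='results/',
                       eval_freq=2000,
                       deterministic=True,
                       render=False)
model.learn(total_timesteps=int(1e6), callback=eval_cb, log_interval=100)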
@@ -102,11 +102,6 @@ def run(output_folder=DEFAULT_OUTPUT_FOLDER, gui=DEFAULT_GUI, plot=True, colab=D
     ############################################################
     ############################################################
     ############################################################
-    ############################################################
-    ############################################################
-    ############################################################
-    ############################################################
-    ############################################################
 
     if os.path.isfile(filename+'/success_model.zip'):
         path = filename+'/success_model.zip'
@@ -141,7 +136,7 @@ def run(output_folder=DEFAULT_OUTPUT_FOLDER, gui=DEFAULT_GUI, plot=True, colab=D
                                                n_eval_episodes=10
                                                )
     print("\n\n\nMean reward ", mean_reward, " +- ", std_reward, "\n\n")
-
+
     obs, info = test_env.reset(seed=42, options={})
     start = time.time()
     for i in range(3*test_env.CTRL_FREQ):
@@ -186,11 +181,11 @@ def run(output_folder=DEFAULT_OUTPUT_FOLDER, gui=DEFAULT_GUI, plot=True, colab=D

 if __name__ == '__main__':
     #### Define and parse (optional) arguments for the script ##
-    parser = argparse.ArgumentParser(description='Single agent reinforcement learning example script using HoverAviary')
-    parser.add_argument('--gui', default=DEFAULT_GUI, type=str2bool, help='Whether to use PyBullet GUI (default: True)', metavar='')
-    parser.add_argument('--record_video', default=DEFAULT_RECORD_VIDEO, type=str2bool, help='Whether to record a video (default: False)', metavar='')
+    parser = argparse.ArgumentParser(description='Single agent reinforcement learning example script')
+    parser.add_argument('--gui', default=DEFAULT_GUI, type=str2bool, help='Whether to use PyBullet GUI (default: True)', metavar='')
+    parser.add_argument('--record_video', default=DEFAULT_RECORD_VIDEO, type=str2bool, help='Whether to record a video (default: False)', metavar='')
     parser.add_argument('--output_folder', default=DEFAULT_OUTPUT_FOLDER, type=str, help='Folder where to save logs (default: "results")', metavar='')
-    parser.add_argument('--colab', default=DEFAULT_COLAB, type=bool, help='Whether example is being run by a notebook (default: "False")', metavar='')
+    parser.add_argument('--colab', default=DEFAULT_COLAB, type=bool, help='Whether example is being run by a notebook (default: "False")', metavar='')
     ARGS = parser.parse_args()
 
     run(**vars(ARGS))
