diff --git a/gym_pybullet_drones/examples/learn.py b/gym_pybullet_drones/examples/learn.py
index de3765c1d..f80b1a7a2 100644
--- a/gym_pybullet_drones/examples/learn.py
+++ b/gym_pybullet_drones/examples/learn.py
@@ -14,7 +14,9 @@
 reinforcement learning library `stable-baselines3`.
 
 """
+import os
 import time
+from datetime import datetime
 import argparse
 import gymnasium as gym
 import numpy as np
@@ -22,8 +24,6 @@
 from stable_baselines3 import PPO
 from stable_baselines3.common.env_util import make_vec_env
 from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnRewardThreshold
-from stable_baselines3.common.policies import ActorCriticPolicy as a2cppoMlpPolicy
-from stable_baselines3.common.policies import ActorCriticCnnPolicy as a2cppoCnnPolicy
 from stable_baselines3.common.evaluation import evaluate_policy
 
 from gym_pybullet_drones.utils.Logger import Logger
@@ -37,111 +37,105 @@
 DEFAULT_OUTPUT_FOLDER = 'results'
 DEFAULT_COLAB = False
 
-DEFAULT_ALGO = 'ppo'
 DEFAULT_OBS = ObservationType('kin')
 DEFAULT_ACT = ActionType('rpm')
+DEFAULT_AGENTS = 2
+DEFAULT_MA = True
 
 def run(output_folder=DEFAULT_OUTPUT_FOLDER, gui=DEFAULT_GUI, plot=True, colab=DEFAULT_COLAB, record_video=DEFAULT_RECORD_VIDEO):
 
-    MULTI_AGENT = False
+    filename = os.path.join(output_folder, 'save-'+datetime.now().strftime("%m.%d.%Y_%H.%M.%S"))
+    if not os.path.exists(filename):
+        os.makedirs(filename+'/')
 
-    sa_env_kwargs = dict(obs=DEFAULT_OBS, act=DEFAULT_ACT)
-
-    if not MULTI_AGENT:
-        # train_env = gym.make('hover-aviary-v0')
+    if not DEFAULT_MA:
         train_env = make_vec_env(HoverAviary,
-                                 env_kwargs=sa_env_kwargs,
-                                 n_envs=2,
+                                 env_kwargs=dict(obs=DEFAULT_OBS, act=DEFAULT_ACT),
+                                 n_envs=1,
                                  seed=0
                                  )
-        eval_env = gym.make('hover-aviary-v0')
+        eval_env = HoverAviary(obs=DEFAULT_OBS, act=DEFAULT_ACT)
     else:
-        train_env = gym.make('leaderfollower-aviary-v0')
+        train_env = make_vec_env(LeaderFollowerAviary,
+                                 env_kwargs=dict(num_drones=DEFAULT_AGENTS, obs=DEFAULT_OBS, act=DEFAULT_ACT),
+                                 n_envs=1,
+                                 seed=0
+                                 )
+        eval_env = LeaderFollowerAviary(num_drones=DEFAULT_AGENTS, obs=DEFAULT_OBS, act=DEFAULT_ACT)
+
     #### Check the environment's spaces ########################
     print('[INFO] Action space:', train_env.action_space)
    print('[INFO] Observation space:', train_env.observation_space)
 
     #### Train the model #######################################
-    onpolicy_kwargs = dict(activation_fn=torch.nn.ReLU,
-                           net_arch=[512, 512, dict(vf=[256, 128], pi=[256, 128])]
-                           ) # or None
-
-    # model = PPO('MlpPolicy',
-    #             train_env,
-    #             verbose=1
-    #             )
-    model = PPO(a2cppoMlpPolicy, # or a2cppoCnnPolicy
+    model = PPO('MlpPolicy',
                 train_env,
-                # policy_kwargs=onpolicy_kwargs,
+                # policy_kwargs=dict(activation_fn=torch.nn.ReLU, net_arch=[512, 512, dict(vf=[256, 128], pi=[256, 128])]),
                 # tensorboard_log=filename+'/tb/',
-                verbose=1
-                )
-    callback_on_best = StopTrainingOnRewardThreshold(reward_threshold=1000,
-                                                     verbose=1
-                                                     )
+                verbose=1)
+
+    callback_on_best = StopTrainingOnRewardThreshold(reward_threshold=1000000,
+                                                     verbose=1)
     eval_callback = EvalCallback(eval_env,
                                  callback_on_new_best=callback_on_best,
                                  verbose=1,
-                                 # best_model_save_path=filename+'/',
-                                 # log_path=filename+'/',
-                                 eval_freq=int(2000),
+                                 best_model_save_path=filename+'/',
+                                 log_path=filename+'/',
+                                 eval_freq=int(1000),
                                  deterministic=True,
-                                 render=False
-                                 )
+                                 render=False)
     model.learn(total_timesteps=10000, #int(1e12),
                 callback=eval_callback,
-                log_interval=100,
-                )
-    # model.learn(total_timesteps=10000) # Typically not enough
+                log_interval=100)
 
     #### Save the model ########################################
-    # model.save(filename+'/success_model.zip')
-    # print(filename)
-
-    # #### Print training progression ############################
-    # with np.load(filename+'/evaluations.npz') as data:
-    #     for j in range(data['timesteps'].shape[0]):
-    #         print(str(data['timesteps'][j])+","+str(data['results'][j][0]))
-
-
-
-
-
-
-
-
-
-    # if os.path.isfile(exp+'/success_model.zip'):
-    #     path = exp+'/success_model.zip'
-    # elif os.path.isfile(exp+'/best_model.zip'):
-    #     path = exp+'/best_model.zip'
-    # else:
-    #     print("[ERROR]: no model under the specified path", exp)
-    # model = PPO.load(path)
-
-
+    model.save(filename+'/success_model.zip')
+    print(filename)
+
+    #### Print training progression ############################
+    with np.load(filename+'/evaluations.npz') as data:
+        for j in range(data['timesteps'].shape[0]):
+            print(str(data['timesteps'][j])+","+str(data['results'][j][0]))
+
+    ############################################################
+    ############################################################
+    ############################################################
+    ############################################################
+    ############################################################
+    ############################################################
+    ############################################################
+    ############################################################
+    ############################################################
+    ############################################################
+
+    if os.path.isfile(filename+'/success_model.zip'):
+        path = filename+'/success_model.zip'
+    elif os.path.isfile(filename+'/best_model.zip'):
+        path = filename+'/best_model.zip'
+    else:
+        print("[ERROR]: no model under the specified path", filename)
+    model = PPO.load(path)
 
     #### Show (and record a video of) the model's performance ##
-    if not MULTI_AGENT:
+    if not DEFAULT_MA:
         test_env = HoverAviary(gui=gui,
-                               record=record_video
-                               )
-        test_env_nogui = HoverAviary()
-        logger = Logger(logging_freq_hz=int(test_env.CTRL_FREQ),
-                        num_drones=1,
-                        output_folder=output_folder,
-                        colab=colab
-                        )
+                               obs=DEFAULT_OBS,
+                               act=DEFAULT_ACT,
+                               record=record_video)
+        test_env_nogui = HoverAviary(obs=DEFAULT_OBS, act=DEFAULT_ACT)
     else:
         test_env = LeaderFollowerAviary(gui=gui,
-                                        record=record_video
-                                        )
-        test_env_nogui = LeaderFollowerAviary()
-        logger = Logger(logging_freq_hz=int(test_env.CTRL_FREQ),
-                        num_drones=2,
-                        output_folder=output_folder,
-                        colab=colab
-                        )
-
+                                        num_drones=DEFAULT_AGENTS,
+                                        obs=DEFAULT_OBS,
+                                        act=DEFAULT_ACT,
+                                        record=record_video)
+        test_env_nogui = LeaderFollowerAviary(num_drones=DEFAULT_AGENTS, obs=DEFAULT_OBS, act=DEFAULT_ACT)
+    logger = Logger(logging_freq_hz=int(test_env.CTRL_FREQ),
+                    num_drones=DEFAULT_AGENTS if DEFAULT_MA else 1,
+                    output_folder=output_folder,
+                    colab=colab
+                    )
+
     mean_reward, std_reward = evaluate_policy(model,
                                               test_env_nogui,
                                               n_eval_episodes=10
@@ -158,7 +152,7 @@ def run(output_folder=DEFAULT_OUTPUT_FOLDER, gui=DEFAULT_GUI, plot=True, colab=D
         obs2 = obs.squeeze()
         act2 = action.squeeze()
         print("Obs:", obs, "\tAction", action, "\tReward:", reward, "\tTerminated:", terminated, "\tTruncated:", truncated)
-        if not MULTI_AGENT:
+        if not DEFAULT_MA:
             logger.log(drone=0,
                        timestamp=i/test_env.CTRL_FREQ,
                        state=np.hstack([obs2[0:3],
@@ -169,7 +163,7 @@ def run(output_folder=DEFAULT_OUTPUT_FOLDER, gui=DEFAULT_GUI, plot=True, colab=D
                        control=np.zeros(12)
                        )
         else:
-            for d in range(2):
+            for d in range(DEFAULT_AGENTS):
                 logger.log(drone=d,
                            timestamp=i/test_env.CTRL_FREQ,
                            state=np.hstack([obs2[d][0:3],
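
Note on the callback change: EvalCallback now writes best_model.zip and evaluations.npz into the timestamped save folder, and the reward threshold is raised to 1000000, which for these aviaries effectively disables early stopping, so training runs for the full total_timesteps. A minimal, self-contained sketch of how this callback wiring behaves, using a stock Gymnasium environment as a stand-in for the aviaries (the env, threshold, and output path below are illustrative, not from the patch):

# Sketch: EvalCallback evaluates the agent every eval_freq training steps;
# on a new best mean evaluation reward it saves best_model.zip and invokes
# callback_on_new_best, and StopTrainingOnRewardThreshold then ends training
# early once that mean reward exceeds reward_threshold.
import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnRewardThreshold

train_env = gym.make('Pendulum-v1')  # stand-in env, not HoverAviary
eval_env = gym.make('Pendulum-v1')

stop_cb = StopTrainingOnRewardThreshold(reward_threshold=-200, verbose=1)
eval_cb = EvalCallback(eval_env,
                       callback_on_new_best=stop_cb,
                       eval_freq=1000,                   # in training-env step() calls
                       best_model_save_path='/tmp/eval_demo/',  # writes best_model.zip
                       log_path='/tmp/eval_demo/',              # writes evaluations.npz
                       deterministic=True,
                       verbose=1)

model = PPO('MlpPolicy', train_env, verbose=0)
model.learn(total_timesteps=20_000, callback=eval_cb)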
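A second note, on the new load-back block: it assumes at least one of the two zips exists; if neither does, only the [ERROR] message is printed and `path` is left unbound, so `PPO.load(path)` raises NameError. A hedged sketch of reloading and re-evaluating such a checkpoint outside the script (the experiment folder name is a placeholder, since real folders are timestamped at run time, and the HoverAviary import path is assumed):

# Sketch: reload a checkpoint written by learn.py and evaluate it.
# 'results/save-...' is a placeholder; actual folders are named with a
# datetime stamp. HoverAviary's defaults are assumed to match the
# DEFAULT_OBS ('kin') and DEFAULT_ACT ('rpm') used during training.
import os
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
from gym_pybullet_drones.envs.HoverAviary import HoverAviary  # import path assumed

exp = 'results/save-01.01.2024_12.00.00'            # placeholder experiment folder
path = os.path.join(exp, 'best_model.zip')          # written by EvalCallback
if not os.path.isfile(path):
    path = os.path.join(exp, 'success_model.zip')   # written by model.save()

model = PPO.load(path)
eval_env = HoverAviary()
mean_reward, std_reward = evaluate_policy(model, eval_env, n_eval_episodes=10)
print(f"mean_reward={mean_reward:.2f} +/- {std_reward:.2f}")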