single and multi agent learn.py
JacopoPan committed Nov 20, 2023
1 parent f91f044 commit 7221e95
Showing 1 changed file with 73 additions and 79 deletions.
152 changes: 73 additions & 79 deletions gym_pybullet_drones/examples/learn.py
@@ -14,16 +14,16 @@
 reinforcement learning library `stable-baselines3`.
 """
 import os
 import time
 from datetime import datetime
 import argparse
 import gymnasium as gym
 import numpy as np
 import torch
 from stable_baselines3 import PPO
 from stable_baselines3.common.env_util import make_vec_env
 from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnRewardThreshold
-from stable_baselines3.common.policies import ActorCriticPolicy as a2cppoMlpPolicy
-from stable_baselines3.common.policies import ActorCriticCnnPolicy as a2cppoCnnPolicy
 from stable_baselines3.common.evaluation import evaluate_policy
 
 from gym_pybullet_drones.utils.Logger import Logger
@@ -37,111 +37,105 @@
 DEFAULT_OUTPUT_FOLDER = 'results'
 DEFAULT_COLAB = False
 
-DEFAULT_ALGO = 'ppo'
 DEFAULT_OBS = ObservationType('kin')
 DEFAULT_ACT = ActionType('rpm')
+DEFAULT_AGENTS = 2
+DEFAULT_MA = True
 
 def run(output_folder=DEFAULT_OUTPUT_FOLDER, gui=DEFAULT_GUI, plot=True, colab=DEFAULT_COLAB, record_video=DEFAULT_RECORD_VIDEO):
 
-    MULTI_AGENT = False
     filename = os.path.join(output_folder, 'save-'+datetime.now().strftime("%m.%d.%Y_%H.%M.%S"))
     if not os.path.exists(filename):
         os.makedirs(filename+'/')
 
-    sa_env_kwargs = dict(obs=DEFAULT_OBS, act=DEFAULT_ACT)
 
-    if not MULTI_AGENT:
-        # train_env = gym.make('hover-aviary-v0')
+    if not DEFAULT_MA:
         train_env = make_vec_env(HoverAviary,
-                                 env_kwargs=sa_env_kwargs,
-                                 n_envs=2,
+                                 env_kwargs=dict(obs=DEFAULT_OBS, act=DEFAULT_ACT),
+                                 n_envs=1,
                                  seed=0
                                  )
-        eval_env = gym.make('hover-aviary-v0')
+        eval_env = HoverAviary(obs=DEFAULT_OBS, act=DEFAULT_ACT)
     else:
-        train_env = gym.make('leaderfollower-aviary-v0')
+        train_env = make_vec_env(LeaderFollowerAviary,
+                                 env_kwargs=dict(num_drones=DEFAULT_AGENTS, obs=DEFAULT_OBS, act=DEFAULT_ACT),
+                                 n_envs=1,
+                                 seed=0
+                                 )
+        eval_env = LeaderFollowerAviary(num_drones=DEFAULT_AGENTS, obs=DEFAULT_OBS, act=DEFAULT_ACT)
 
     #### Check the environment's spaces ########################
     print('[INFO] Action space:', train_env.action_space)
     print('[INFO] Observation space:', train_env.observation_space)
 
     #### Train the model #######################################
-    onpolicy_kwargs = dict(activation_fn=torch.nn.ReLU,
-                           net_arch=[512, 512, dict(vf=[256, 128], pi=[256, 128])]
-                           ) # or None
 
-    # model = PPO('MlpPolicy',
-    #             train_env,
-    #             verbose=1
-    #             )
-    model = PPO(a2cppoMlpPolicy, # or a2cppoCnnPolicy
+    model = PPO('MlpPolicy',
                 train_env,
-                # policy_kwargs=onpolicy_kwargs,
+                # policy_kwargs=dict(activation_fn=torch.nn.ReLU, net_arch=[512, 512, dict(vf=[256, 128], pi=[256, 128])]),
                 # tensorboard_log=filename+'/tb/',
-                verbose=1
-                )
-    callback_on_best = StopTrainingOnRewardThreshold(reward_threshold=1000,
-                                                     verbose=1
-                                                     )
+                verbose=1)
+
+    callback_on_best = StopTrainingOnRewardThreshold(reward_threshold=1000000,
+                                                     verbose=1)
    eval_callback = EvalCallback(eval_env,
                                 callback_on_new_best=callback_on_best,
                                 verbose=1,
-                                 # best_model_save_path=filename+'/',
-                                 # log_path=filename+'/',
-                                 eval_freq=int(2000),
+                                 best_model_save_path=filename+'/',
+                                 log_path=filename+'/',
+                                 eval_freq=int(1000),
                                  deterministic=True,
-                                 render=False
-                                 )
+                                 render=False)
     model.learn(total_timesteps=10000, #int(1e12),
                 callback=eval_callback,
-                log_interval=100,
-                )
-    # model.learn(total_timesteps=10000) # Typically not enough
+                log_interval=100)
 
     #### Save the model ########################################
-    # model.save(filename+'/success_model.zip')
-    # print(filename)
-
-    # #### Print training progression ############################
-    # with np.load(filename+'/evaluations.npz') as data:
-    #     for j in range(data['timesteps'].shape[0]):
-    #         print(str(data['timesteps'][j])+","+str(data['results'][j][0]))
-
-
-
-
-
-
-
-
-    # if os.path.isfile(exp+'/success_model.zip'):
-    #     path = exp+'/success_model.zip'
-    # elif os.path.isfile(exp+'/best_model.zip'):
-    #     path = exp+'/best_model.zip'
-    # else:
-    #     print("[ERROR]: no model under the specified path", exp)
-    # model = PPO.load(path)
-
-
+    model.save(filename+'/success_model.zip')
+    print(filename)
+
+    #### Print training progression ############################
+    with np.load(filename+'/evaluations.npz') as data:
+        for j in range(data['timesteps'].shape[0]):
+            print(str(data['timesteps'][j])+","+str(data['results'][j][0]))
+
+    ############################################################
+    ############################################################
+    ############################################################
+    ############################################################
+    ############################################################
+    ############################################################
+    ############################################################
+    ############################################################
+    ############################################################
+    ############################################################
+
+    if os.path.isfile(filename+'/success_model.zip'):
+        path = filename+'/success_model.zip'
+    elif os.path.isfile(filename+'/best_model.zip'):
+        path = filename+'/best_model.zip'
+    else:
+        print("[ERROR]: no model under the specified path", filename)
+    model = PPO.load(path)
 
     #### Show (and record a video of) the model's performance ##
-    if not MULTI_AGENT:
+    if not DEFAULT_MA:
         test_env = HoverAviary(gui=gui,
-                               record=record_video
-                               )
-        test_env_nogui = HoverAviary()
-        logger = Logger(logging_freq_hz=int(test_env.CTRL_FREQ),
-                        num_drones=1,
-                        output_folder=output_folder,
-                        colab=colab
-                        )
+                               obs=DEFAULT_OBS,
+                               act=DEFAULT_ACT,
+                               record=record_video)
+        test_env_nogui = HoverAviary(obs=DEFAULT_OBS, act=DEFAULT_ACT)
     else:
         test_env = LeaderFollowerAviary(gui=gui,
-                                        record=record_video
-                                        )
-        test_env_nogui = LeaderFollowerAviary()
-        logger = Logger(logging_freq_hz=int(test_env.CTRL_FREQ),
-                        num_drones=2,
-                        output_folder=output_folder,
-                        colab=colab
-                        )
-
+                                        num_drones=DEFAULT_AGENTS,
+                                        obs=DEFAULT_OBS,
+                                        act=DEFAULT_ACT,
+                                        record=record_video)
+        test_env_nogui = LeaderFollowerAviary(num_drones=DEFAULT_AGENTS, obs=DEFAULT_OBS, act=DEFAULT_ACT)
+    logger = Logger(logging_freq_hz=int(test_env.CTRL_FREQ),
+                    num_drones=DEFAULT_AGENTS if DEFAULT_MA else 1,
+                    output_folder=output_folder,
+                    colab=colab
+                    )
 
     mean_reward, std_reward = evaluate_policy(model,
                                               test_env_nogui,
                                               n_eval_episodes=10
@@ -158,7 +152,7 @@ def run(output_folder=DEFAULT_OUTPUT_FOLDER, gui=DEFAULT_GUI, plot=True, colab=D
         obs2 = obs.squeeze()
         act2 = action.squeeze()
         print("Obs:", obs, "\tAction", action, "\tReward:", reward, "\tTerminated:", terminated, "\tTruncated:", truncated)
-        if not MULTI_AGENT:
+        if not DEFAULT_MA:
             logger.log(drone=0,
                        timestamp=i/test_env.CTRL_FREQ,
                        state=np.hstack([obs2[0:3],
@@ -169,7 +163,7 @@ def run(output_folder=DEFAULT_OUTPUT_FOLDER, gui=DEFAULT_GUI, plot=True, colab=D
                        control=np.zeros(12)
                        )
         else:
-            for d in range(2):
+            for d in range(DEFAULT_AGENTS):
                 logger.log(drone=d,
                            timestamp=i/test_env.CTRL_FREQ,
                            state=np.hstack([obs2[d][0:3],
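For reference, a minimal sketch (not part of this commit) of how the evaluation half of the new learn.py could be reproduced on its own: it reloads a checkpoint written by the script and scores it with stable-baselines3's evaluate_policy. The results-folder name is a hypothetical placeholder; PPO.load, evaluate_policy, and the HoverAviary constructor arguments follow the diff above.

    import os
    from stable_baselines3 import PPO
    from stable_baselines3.common.evaluation import evaluate_policy
    from gym_pybullet_drones.envs.HoverAviary import HoverAviary
    from gym_pybullet_drones.utils.enums import ObservationType, ActionType

    # Hypothetical results folder written by a previous single-agent run of learn.py
    folder = 'results/save-11.20.2023_12.00.00'

    # Prefer the final model; otherwise fall back to the best checkpoint saved by EvalCallback
    if os.path.isfile(folder+'/success_model.zip'):
        path = folder+'/success_model.zip'
    else:
        path = folder+'/best_model.zip'
    model = PPO.load(path)

    # Re-create the single-agent environment with the script's default observation/action types
    eval_env = HoverAviary(obs=ObservationType('kin'), act=ActionType('rpm'))
    mean_reward, std_reward = evaluate_policy(model, eval_env, n_eval_episodes=10)
    print('Mean reward:', mean_reward, '+-', std_reward)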
