Skip to content

Commit

Permalink
Merge pull request #36 from henryraubenheimer/main
Browse files Browse the repository at this point in the history
Fixed online system methods
  • Loading branch information
jcformanek authored Aug 22, 2024
2 parents 261543b + c08b661 commit 7228ba4
Show file tree
Hide file tree
Showing 8 changed files with 30 additions and 23 deletions.
2 changes: 1 addition & 1 deletion examples/tf2/online/idrqn_smacv1.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

env = SMACv1("3m")

logger = WandbLogger(entity="claude_formanek")
logger = WandbLogger()

system = IDRQNSystem(env, logger, eps_decay_timesteps=10_000)

Expand Down
2 changes: 1 addition & 1 deletion examples/tf2/online/idrqn_smax.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

env = SMAX("3m")

logger = WandbLogger(entity="claude_formanek")
logger = WandbLogger()

system = QMIXSystem(env, logger, eps_decay_timesteps=50_000)

Expand Down
2 changes: 1 addition & 1 deletion examples/tf2/online/qmix_pursuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

env = Pursuit()

logger = WandbLogger(entity="claude_formanek")
logger = WandbLogger()

system = QMIXSystem(env, logger, add_agent_id_to_obs=True, target_update_rate=0.00005)

Expand Down
2 changes: 1 addition & 1 deletion examples/tf2/online/qmix_smacv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

env = SMACv2("terran_5_vs_5")

logger = WandbLogger(entity="claude_formanek")
logger = WandbLogger()

system = QMIXSystem(
env, logger, add_agent_id_to_obs=True, learning_rate=0.0005, eps_decay_timesteps=100_000
Expand Down
18 changes: 9 additions & 9 deletions install_environments/smacv1.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,23 @@
# Install SC2 and add the custom maps
# Script adapted from https://github.com/oxwhirl/pymarl

# export SC2PATH=~/StarCraftII
export SC2PATH=~/StarCraftII

echo 'StarCraftII is not installed. Installing now ...';
wget --progress=dot:mega http://blzdistsc2-a.akamaihd.net/Linux/SC2.4.10.zip
unzip -oP iagreetotheeula SC2.4.10.zip
mv StarCraftII $SC2PATH
rm -rf SC2.4.10.zip
#wget --progress=dot:mega http://blzdistsc2-a.akamaihd.net/Linux/SC2.4.10.zip
#unzip -oP iagreetotheeula SC2.4.10.zip
#mv StarCraftII $SC2PATH
#rm -rf SC2.4.10.zip

echo 'Adding SMAC maps.'
MAP_DIR="$SC2PATH/Maps/"
echo 'MAP_DIR is set to '$MAP_DIR
mkdir -p $MAP_DIR

wget https://github.com/oxwhirl/smac/releases/download/v0.1-beta1/SMAC_Maps.zip
unzip SMAC_Maps.zip
mv SMAC_Maps $MAP_DIR
rm -rf SMAC_Maps.zip
#wget https://github.com/oxwhirl/smac/releases/download/v0.1-beta1/SMAC_Maps.zip
#unzip SMAC_Maps.zip
#mv SMAC_Maps $MAP_DIR
#rm -rf SMAC_Maps.zip

echo 'StarCraft II is installed.'

Expand Down
19 changes: 13 additions & 6 deletions og_marl/replay_buffers.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,20 @@ def add(
truncations: Dict[str, np.ndarray],
infos: Dict[str, Any],
) -> None:
stacked_infos = {}
for key, value in infos.items():
if isinstance(value, dict):
stacked_infos[key] = np.stack(list(value.values()), axis=0)
else:
stacked_infos[key] = value

timestep = {
"observations": observations,
"actions": actions,
"rewards": rewards,
"terminals": terminals,
"truncations": truncations,
"infos": infos,
"observations": np.stack(list(observations.values()), axis=0),
"actions": np.stack(list(actions.values()), axis=0),
"rewards": np.stack(list(rewards.values()), axis=0),
"terminals": np.stack(list(terminals.values()), axis=0),
"truncations": np.stack(list(truncations.values()), axis=0),
"infos": stacked_infos,
}

if self._buffer_state is None:
Expand Down
4 changes: 2 additions & 2 deletions og_marl/tf2/systems/idrqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,8 +182,8 @@ def _tf_train_step(self, train_step_ctr: int, experience: Dict[str, Any]) -> Dic
observations = experience["observations"] # (B,T,N,O)
actions = experience["actions"] # (B,T,N)
rewards = experience["rewards"] # (B,T,N)
truncations = experience["truncations"] # (B,T,N)
terminals = experience["terminals"] # (B,T,N)
truncations = tf.cast(experience["truncations"], "float32") # (B,T,N)
terminals = tf.cast(experience["terminals"], "float32") # (B,T,N)
legal_actions = experience["infos"]["legals"] # (B,T,N,A)

# When to reset the RNN hidden state
Expand Down
4 changes: 2 additions & 2 deletions og_marl/tf2/systems/idrqn_cql.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ def _tf_train_step(self, train_step: int, experience: Dict[str, Any]) -> Dict[st
observations = experience["observations"] # (B,T,N,O)
actions = experience["actions"] # (B,T,N)
rewards = experience["rewards"] # (B,T,N)
truncations = experience["truncations"] # (B,T,N)
terminals = experience["terminals"] # (B,T,N)
truncations = tf.cast(experience["truncations"], "float32") # (B,T,N)
terminals = tf.cast(experience["terminals"], "float32") # (B,T,N)
legal_actions = experience["infos"]["legals"] # (B,T,N,A)

# When to reset the RNN hidden state
Expand Down

0 comments on commit 7228ba4

Please sign in to comment.