diff --git a/examples/tf2/online/idrqn_smacv1.py b/examples/tf2/online/idrqn_smacv1.py
index 0a94c53c..e45bd942 100644
--- a/examples/tf2/online/idrqn_smacv1.py
+++ b/examples/tf2/online/idrqn_smacv1.py
@@ -19,7 +19,7 @@
 
 env = SMACv1("3m")
 
-logger = WandbLogger(entity="claude_formanek")
+logger = WandbLogger()
 
 system = IDRQNSystem(env, logger, eps_decay_timesteps=10_000)
 
diff --git a/examples/tf2/online/idrqn_smax.py b/examples/tf2/online/idrqn_smax.py
index 2b2c9a47..501dd229 100644
--- a/examples/tf2/online/idrqn_smax.py
+++ b/examples/tf2/online/idrqn_smax.py
@@ -19,7 +19,7 @@
 
 env = SMAX("3m")
 
-logger = WandbLogger(entity="claude_formanek")
+logger = WandbLogger()
 
 system = QMIXSystem(env, logger, eps_decay_timesteps=50_000)
 
diff --git a/examples/tf2/online/qmix_pursuit.py b/examples/tf2/online/qmix_pursuit.py
index b0a5b73b..e726969b 100644
--- a/examples/tf2/online/qmix_pursuit.py
+++ b/examples/tf2/online/qmix_pursuit.py
@@ -5,7 +5,7 @@
 
 env = Pursuit()
 
-logger = WandbLogger(entity="claude_formanek")
+logger = WandbLogger()
 
 system = QMIXSystem(env, logger, add_agent_id_to_obs=True, target_update_rate=0.00005)
 
diff --git a/examples/tf2/online/qmix_smacv2.py b/examples/tf2/online/qmix_smacv2.py
index d549c509..e7e7b0ca 100644
--- a/examples/tf2/online/qmix_smacv2.py
+++ b/examples/tf2/online/qmix_smacv2.py
@@ -5,7 +5,7 @@
 
 env = SMACv2("terran_5_vs_5")
 
-logger = WandbLogger(entity="claude_formanek")
+logger = WandbLogger()
 
 system = QMIXSystem(
     env, logger, add_agent_id_to_obs=True, learning_rate=0.0005, eps_decay_timesteps=100_000
diff --git a/install_environments/smacv1.sh b/install_environments/smacv1.sh
index 60fa131c..57805e16 100755
--- a/install_environments/smacv1.sh
+++ b/install_environments/smacv1.sh
@@ -2,23 +2,23 @@
 # Install SC2 and add the custom maps
 # Script adapted from https://github.com/oxwhirl/pymarl
 
-# export SC2PATH=~/StarCraftII
+export SC2PATH=~/StarCraftII
 
 echo 'StarCraftII is not installed. Installing now ...';
-wget --progress=dot:mega http://blzdistsc2-a.akamaihd.net/Linux/SC2.4.10.zip
-unzip -oP iagreetotheeula SC2.4.10.zip
-mv StarCraftII $SC2PATH
-rm -rf SC2.4.10.zip
+#wget --progress=dot:mega http://blzdistsc2-a.akamaihd.net/Linux/SC2.4.10.zip
+#unzip -oP iagreetotheeula SC2.4.10.zip
+#mv StarCraftII $SC2PATH
+#rm -rf SC2.4.10.zip
 
 echo 'Adding SMAC maps.'
 MAP_DIR="$SC2PATH/Maps/"
 echo 'MAP_DIR is set to '$MAP_DIR
 mkdir -p $MAP_DIR
 
-wget https://github.com/oxwhirl/smac/releases/download/v0.1-beta1/SMAC_Maps.zip
-unzip SMAC_Maps.zip
-mv SMAC_Maps $MAP_DIR
-rm -rf SMAC_Maps.zip
+#wget https://github.com/oxwhirl/smac/releases/download/v0.1-beta1/SMAC_Maps.zip
+#unzip SMAC_Maps.zip
+#mv SMAC_Maps $MAP_DIR
+#rm -rf SMAC_Maps.zip
 
 echo 'StarCraft II is installed.'
diff --git a/og_marl/replay_buffers.py b/og_marl/replay_buffers.py
index dcd64325..5ffcd6ba 100644
--- a/og_marl/replay_buffers.py
+++ b/og_marl/replay_buffers.py
@@ -64,13 +64,20 @@ def add(
         truncations: Dict[str, np.ndarray],
         infos: Dict[str, Any],
     ) -> None:
+        stacked_infos = {}
+        for key, value in infos.items():
+            if isinstance(value, dict):
+                stacked_infos[key] = np.stack(list(value.values()), axis=0)
+            else:
+                stacked_infos[key] = value
+
         timestep = {
-            "observations": observations,
-            "actions": actions,
-            "rewards": rewards,
-            "terminals": terminals,
-            "truncations": truncations,
-            "infos": infos,
+            "observations": np.stack(list(observations.values()), axis=0),
+            "actions": np.stack(list(actions.values()), axis=0),
+            "rewards": np.stack(list(rewards.values()), axis=0),
+            "terminals": np.stack(list(terminals.values()), axis=0),
+            "truncations": np.stack(list(truncations.values()), axis=0),
+            "infos": stacked_infos,
         }
 
         if self._buffer_state is None:
diff --git a/og_marl/tf2/systems/idrqn.py b/og_marl/tf2/systems/idrqn.py
index 19c9c91f..bb95f785 100644
--- a/og_marl/tf2/systems/idrqn.py
+++ b/og_marl/tf2/systems/idrqn.py
@@ -182,8 +182,8 @@ def _tf_train_step(self, train_step_ctr: int, experience: Dict[str, Any]) -> Dic
         observations = experience["observations"]  # (B,T,N,O)
         actions = experience["actions"]  # (B,T,N)
         rewards = experience["rewards"]  # (B,T,N)
-        truncations = experience["truncations"]  # (B,T,N)
-        terminals = experience["terminals"]  # (B,T,N)
+        truncations = tf.cast(experience["truncations"], "float32")  # (B,T,N)
+        terminals = tf.cast(experience["terminals"], "float32")  # (B,T,N)
         legal_actions = experience["infos"]["legals"]  # (B,T,N,A)
 
         # When to reset the RNN hidden state
diff --git a/og_marl/tf2/systems/idrqn_cql.py b/og_marl/tf2/systems/idrqn_cql.py
index 1c5d1dfb..dfab94fa 100644
--- a/og_marl/tf2/systems/idrqn_cql.py
+++ b/og_marl/tf2/systems/idrqn_cql.py
@@ -71,8 +71,8 @@ def _tf_train_step(self, train_step: int, experience: Dict[str, Any]) -> Dict[st
         observations = experience["observations"]  # (B,T,N,O)
         actions = experience["actions"]  # (B,T,N)
         rewards = experience["rewards"]  # (B,T,N)
-        truncations = experience["truncations"]  # (B,T,N)
-        terminals = experience["terminals"]  # (B,T,N)
+        truncations = tf.cast(experience["truncations"], "float32")  # (B,T,N)
+        terminals = tf.cast(experience["terminals"], "float32")  # (B,T,N)
         legal_actions = experience["infos"]["legals"]  # (B,T,N,A)
 
         # When to reset the RNN hidden state