Added conv net embedding to all discrete action envs.
jcformanek committed Mar 8, 2024
1 parent ee79d6e commit 3414723
Showing 12 changed files with 391 additions and 217 deletions.
58 changes: 41 additions & 17 deletions examples/tf2/run_all_baselines.py
@@ -1,32 +1,46 @@
import os

# import module
import traceback

from og_marl.environments.utils import get_environment
from og_marl.loggers import JsonWriter, WandbLogger
from og_marl.loggers import TerminalLogger, JsonWriter
from og_marl.replay_buffers import FlashbaxReplayBuffer
from og_marl.tf2.networks import CNNEmbeddingNetwork
from og_marl.tf2.systems import get_system
from og_marl.tf2.utils import set_growing_gpu_memory

set_growing_gpu_memory()

os.environ["SUPPRESS_GR_PROMPT"] = 1
# For MAMuJoCo
os.environ["SUPPRESS_GR_PROMPT"] = "1"

scenario_system_configs = {
"smac_v1": {
"3m": {
"systems": ["idrqn", "idrqn+cql", "idrqn+bcq", "qmix+cql", "qmix+bcq", "maicq"],
"datasets": ["Good"],
"trainer_steps": 3000,
"evaluate_every": 1000,
},
},
"mamujoco": {
"2halfcheetah": {
"systems": ["iddpg", "iddpg+cql", "maddpg+cql", "maddpg", "omar"],
# "smac_v1": {
# "3m": {
# "systems": ["idrqn", "idrqn+cql", "idrqn+bcq", "qmix+cql", "qmix+bcq",
# "maicq", "dbc"],
# "datasets": ["Good"],
# "trainer_steps": 3000,
# "evaluate_every": 1000,
# },
# },
"pettingzoo": {
"pursuit": {
"systems": ["idrqn", "idrqn+cql", "idrqn+bcq", "qmix+cql", "qmix+bcq", "maicq", "dbc"],
"datasets": ["Good"],
"trainer_steps": 3000,
"evaluate_every": 1000,
},
},
# "mamujoco": {
# "2halfcheetah": {
# "systems": ["iddpg", "iddpg+cql", "maddpg+cql", "maddpg", "omar"],
# "datasets": ["Good"],
# "trainer_steps": 3000,
# "evaluate_every": 1000,
# },
# },
}

seeds = [42]
@@ -44,7 +58,7 @@
"system": env_name,
"seed": seed,
}
logger = WandbLogger(config, project="og-marl-baselines")
logger = TerminalLogger()
env = get_environment(env_name, scenario_name)

buffer = FlashbaxReplayBuffer(sequence_length=20, sample_period=1)
@@ -55,10 +69,20 @@
raise ValueError("Vault not found. Exiting.")

json_writer = JsonWriter(
"logs", system_name, f"{scenario_name}_{dataset_name}", env_name, seed
"test_all_baselines",
system_name,
f"{scenario_name}_{dataset_name}",
env_name,
seed,
)

system_kwargs = {"add_agent_id_to_obs": True}

if scenario_name == "pursuit":
system_kwargs["observation_embedding_network"] = CNNEmbeddingNetwork()
if system_name in ["qmix", "qmix+cql", "qmix+bcq", "maicq"]:
system_kwargs["state_embedding_network"] = CNNEmbeddingNetwork()

system = get_system(system_name, env, logger, **system_kwargs)

trainer_steps = scenario_system_configs[env_name][scenario_name][
@@ -75,7 +99,7 @@
)
except: # noqa: E722
logger.close()
print()
print("BROKEN")
print("BROKEN:", env_name, scenario_name, system_name)
traceback.print_exc()
print()
continue
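
For reference, a minimal sketch of how the new embedding hook is wired up for a pixel-observation environment, using only the entry points that appear in this diff (the scenario and system names below are illustrative, not prescriptive):

from og_marl.environments.utils import get_environment
from og_marl.loggers import TerminalLogger
from og_marl.tf2.networks import CNNEmbeddingNetwork
from og_marl.tf2.systems import get_system

env = get_environment("pettingzoo", "pursuit")  # Pursuit has (7, 7, 3) pixel observations
logger = TerminalLogger()

system_kwargs = {"add_agent_id_to_obs": True}
# Pixel observations are embedded by a small conv net before the recurrent core.
system_kwargs["observation_embedding_network"] = CNNEmbeddingNetwork()
# Mixing-based systems (QMIX variants, MAICQ) also embed the global state.
system_kwargs["state_embedding_network"] = CNNEmbeddingNetwork()

system = get_system("qmix+cql", env, logger, **system_kwargs)
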
10 changes: 5 additions & 5 deletions og_marl/environments/pursuit.py
@@ -11,15 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict
from typing import Dict

import numpy as np
from gymnasium.spaces import Box, Discrete
from pettingzoo.sisl import pursuit_v4
from supersuit import black_death_v3

from og_marl.environments.base import BaseEnvironment, Observations, ResetReturn, StepReturn
from og_marl.environments.pettingzoo_base import PettingZooBase


class Pursuit(BaseEnvironment):
@@ -28,7 +26,7 @@ class Pursuit(BaseEnvironment):

def __init__(self) -> None:
"""Constructor for Pursuit"""
self._environment = black_death_v3(pursuit_v4.parallel_env())
self._environment = pursuit_v4.parallel_env()
self.possible_agents = self._environment.possible_agents
self._num_actions = 5
self._obs_dim = (7, 7, 3)
@@ -38,7 +36,9 @@ def __init__(self) -> None:
agent: Box(-np.inf, np.inf, (*self._obs_dim,)) for agent in self.possible_agents
}

self._legals = {agent: np.ones((self._num_actions,), "int32") for agent in self.possible_agents}
self._legals = {
agent: np.ones((self._num_actions,), "int32") for agent in self.possible_agents
}

self.info_spec = {"state": np.zeros(8 * 2 + 30 * 2, "float32"), "legals": self._legals}

142 changes: 142 additions & 0 deletions og_marl/tf2/networks.py
@@ -0,0 +1,142 @@
from typing import Sequence

import tensorflow as tf
from tensorflow import Tensor
import sonnet as snt


class QMixer(snt.Module):

"""QMIX mixing network."""

def __init__(
self,
num_agents: int,
embed_dim: int = 32,
hypernet_embed: int = 64,
non_monotonic: bool = False,
):
"""Initialise QMIX mixing network
Args:
----
num_agents: Number of agents in the environment
state_dim: Dimensions of the global environment state
embed_dim: The dimension of the output of the first layer
of the mixer.
hypernet_embed: Number of units in the hyper network
"""
super().__init__()
self.num_agents = num_agents
self.embed_dim = embed_dim
self.hypernet_embed = hypernet_embed
self._non_monotonic = non_monotonic

self.hyper_w_1 = snt.Sequential(
[
snt.Linear(self.hypernet_embed),
tf.nn.relu,
snt.Linear(self.embed_dim * self.num_agents),
]
)

self.hyper_w_final = snt.Sequential(
[snt.Linear(self.hypernet_embed), tf.nn.relu, snt.Linear(self.embed_dim)]
)

# State dependent bias for hidden layer
self.hyper_b_1 = snt.Linear(self.embed_dim)

# V(s) instead of a bias for the last layers
self.V = snt.Sequential([snt.Linear(self.embed_dim), tf.nn.relu, snt.Linear(1)])

def __call__(self, agent_qs: Tensor, states: Tensor) -> Tensor:
"""Forward method."""
B = agent_qs.shape[0] # batch size
state_dim = states.shape[2:]

agent_qs = tf.reshape(agent_qs, (-1, 1, self.num_agents))

states = tf.reshape(states, (-1, *state_dim))

# First layer
w1 = self.hyper_w_1(states)
if not self._non_monotonic:
w1 = tf.abs(w1)
b1 = self.hyper_b_1(states)
w1 = tf.reshape(w1, (-1, self.num_agents, self.embed_dim))
b1 = tf.reshape(b1, (-1, 1, self.embed_dim))
hidden = tf.nn.elu(tf.matmul(agent_qs, w1) + b1)

# Second layer
w_final = self.hyper_w_final(states)
if not self._non_monotonic:
w_final = tf.abs(w_final)
w_final = tf.reshape(w_final, (-1, self.embed_dim, 1))

# State-dependent bias
v = tf.reshape(self.V(states), (-1, 1, 1))

# Compute final output
y = tf.matmul(hidden, w_final) + v

# Reshape and return
q_tot = tf.reshape(y, (B, -1, 1))

return q_tot

def k(self, states: Tensor) -> Tensor:
"""Method used by MAICQ."""
B, T = states.shape[:2]

w1 = tf.math.abs(self.hyper_w_1(states))
w_final = tf.math.abs(self.hyper_w_final(states))
w1 = tf.reshape(w1, shape=(-1, self.num_agents, self.embed_dim))
w_final = tf.reshape(w_final, shape=(-1, self.embed_dim, 1))
k = tf.matmul(w1, w_final)
k = tf.reshape(k, shape=(B, -1, self.num_agents))
k = k / (tf.reduce_sum(k, axis=2, keepdims=True) + 1e-10)
return k


@snt.allow_empty_variables
class IdentityNetwork(snt.Module):
def __init__(self) -> None:
super().__init__()
return

def __call__(self, x: Tensor) -> Tensor:
return x


class CNNEmbeddingNetwork(snt.Module):
def __init__(
self, output_channels: Sequence[int] = (8, 16), kernel_sizes: Sequence[int] = (3, 2)
) -> None:
super().__init__()
assert len(output_channels) == len(kernel_sizes)

layers = []
for layer_i in range(len(output_channels)):
layers.append(snt.Conv2D(output_channels[layer_i], kernel_sizes[layer_i]))
layers.append(tf.nn.relu)
layers.append(tf.keras.layers.Flatten())

self.conv_net = snt.Sequential(layers)

def __call__(self, x: Tensor) -> Tensor:
"""Embed a pixel-styled input into a vector using a conv net.
We assume the input has leading batch, time and agent dims, with the trailing
dims being the width, height and channel dimensions of the input.
The output shape is then (B, T, N, Embed).
"""
leading_dims = x.shape[:-3] # B,T,N
trailing_dims = x.shape[-3:] # W,H,C

x = tf.reshape(x, shape=(-1, *trailing_dims))
embed = self.conv_net(x)
embed = tf.reshape(embed, shape=(*leading_dims, -1))
return embed
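
As a quick sanity check on the new module's shape contract (a sketch, not part of the commit; the shapes are illustrative and match the Pursuit observation spec above):

import numpy as np
import tensorflow as tf

from og_marl.tf2.networks import CNNEmbeddingNetwork

# Dummy pixel observations with leading batch, time and agent dims: (B, T, N, W, H, C).
obs = tf.convert_to_tensor(np.zeros((2, 5, 8, 7, 7, 3), dtype=np.float32))

embed_net = CNNEmbeddingNetwork()  # default output_channels (8, 16), kernel_sizes (3, 2)
embeddings = embed_net(obs)

# The conv net flattens the trailing (W, H, C) dims into a single embedding dim,
# while the leading (B, T, N) dims are preserved.
print(embeddings.shape)  # (2, 5, 8, embed_dim)
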
19 changes: 16 additions & 3 deletions og_marl/tf2/systems/bc.py
@@ -19,6 +19,7 @@
switch_two_leading_dims,
unroll_rnn,
)
from og_marl.tf2.networks import IdentityNetwork


class DicreteActionBehaviourCloning(BaseMARLSystem):
@@ -33,6 +34,7 @@ def __init__(
discount: float = 0.99,
learning_rate: float = 1e-3,
add_agent_id_to_obs: bool = True,
observation_embedding_network: Optional[snt.Module] = None,
):
super().__init__(
environment, logger, discount=discount, add_agent_id_to_obs=add_agent_id_to_obs
@@ -48,6 +50,9 @@ def __init__(
snt.Linear(self._environment._num_actions),
]
) # shared network for all agents
if observation_embedding_network is None:
observation_embedding_network = IdentityNetwork()
self._policy_embedding_network = observation_embedding_network

self._optimizer = snt.optimizers.RMSProp(learning_rate=learning_rate)

@@ -147,11 +152,16 @@ def _tf_train_step(self, experience: Dict[str, Any]) -> Dict[str, Numeric]:
resets = switch_two_leading_dims(resets)
actions = switch_two_leading_dims(actions)

# Merge batch_dim and agent_dim
observations = merge_batch_and_agent_dim_of_time_major_sequence(observations)
resets = merge_batch_and_agent_dim_of_time_major_sequence(resets)

with tf.GradientTape() as tape:
embeddings = self._policy_embedding_network(observations)
probs_out = unroll_rnn(
self._policy_network,
merge_batch_and_agent_dim_of_time_major_sequence(observations),
merge_batch_and_agent_dim_of_time_major_sequence(resets),
embeddings,
resets,
)
probs_out = expand_batch_and_agent_dim_of_time_major_sequence(probs_out, B, N)

@@ -163,7 +173,10 @@ def _tf_train_step(self, experience: Dict[str, Any]) -> Dict[str, Numeric]:
bc_loss = tf.reduce_mean(bc_loss)

# Apply gradients to policy
variables = (*self._policy_network.trainable_variables,) # Get trainable variables
variables = (
*self._policy_network.trainable_variables,
*self._policy_embedding_network.trainable_variables,
) # Get trainable variables

gradients = tape.gradient(bc_loss, variables) # Compute gradients.
self._optimizer.apply(gradients, variables)
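
When no embedding network is passed, the system falls back to IdentityNetwork, so existing vector-observation setups are unaffected. A tiny sketch of that default path (shapes are illustrative):

import tensorflow as tf

from og_marl.tf2.networks import IdentityNetwork

embedding = IdentityNetwork()   # used when no observation_embedding_network is given
obs = tf.ones((20, 64, 48))     # (T, B*N, obs_dim), time-major with merged batch/agent dims
assert embedding(obs) is obs    # identity: observations pass straight through to the RNN
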
