Skip to content

Commit

Permalink
Merge pull request #21 from instadeepai/feat/pixel-obs-vaults
Browse files Browse the repository at this point in the history
Add support for pixel-obs environments
  • Loading branch information
callumtilbury authored Mar 11, 2024
2 parents bfeb79e + 5492aca commit 112f8d7
Show file tree
Hide file tree
Showing 22 changed files with 583 additions and 433 deletions.
9 changes: 7 additions & 2 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ FROM nvidia/cuda:11.8.0-cudnn8-runtime-ubuntu22.04

# Ensure no installs try to launch interactive screen
ARG DEBIAN_FRONTEND=noninteractive
ENV PYTHONUNBUFFERED=1

# Update packages and install python3.9 and other dependencies
RUN apt-get update -y && \
Expand Down Expand Up @@ -33,14 +34,18 @@ RUN pip install --quiet --upgrade pip setuptools wheel && \
pip install -e . && \
pip install flashbax==0.1.0

ENV SC2PATH /home/app/StarCraftII
# ENV SC2PATH /home/app/StarCraftII
# RUN ./install_environments/smacv1.sh
RUN ./install_environments/smacv2.sh
# RUN ./install_environments/smacv2.sh

# ENV LD_LIBRARY_PATH $LD_LIBRARY_PATH:/root/.mujoco/mujoco210/bin:/usr/lib/nvidia
# ENV SUPPRESS_GR_PROMPT 1
# RUN ./install_environments/mamujoco.sh

RUN ./install_environments/pettingzoo.sh

# RUN ./install_environments/flatland.sh

# Copy all code
COPY ./examples ./examples
COPY ./baselines ./baselines
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -104,12 +104,12 @@ We are in the process of migrating our datasets from TF Records to Flashbax Vaul
| 💣SMAC v2 | terran_5_vs_5 <br/> zerg_5_vs_5 <br/> terran_10_vs_10 | 5 <br/> 5 <br/> 10 | Discrete | Vector | Dense | Heterog | [source](https://github.com/oxwhirl/smacv2) |
| 🚅Flatland | 3 Trains <br/> 5 Trains | 3 <br/> 5 | Discrete | Vector | Sparse | Homog | [source](https://flatland.aicrowd.com/intro.html) |
| 🐜MAMuJoCo | 2-HalfCheetah <br/> 2-Ant <br/> 4-Ant | 2 <br/> 2 <br/> 4 | Cont. | Vector | Dense | Heterog <br/> Homog <br/> Homog | [source](https://github.com/schroederdewitt/multiagent_mujoco) |

| 🐻PettingZoo | Pursuit <br/> Co-op Pong | 8 <br/> 2 | Discrete <br/> Discrete | Pixels <br/> Pixels | Dense | Homog <br/> Heterog | [source](https://pettingzoo.farama.org/) |

### Legacy Datasets (still to be migrated to Vault) 👴
| Environment | Scenario | Agents | Act | Obs | Reward | Types | Repo |
|-----|----|----|-----|-----|----|----|-----|
| 🐻PettingZoo | Pursuit <br/> Co-op Pong <br/> PistonBall <br/> KAZ| 8 <br/> 2 <br/> 15 <br/> 2| Discrete <br/> Discrete <br/> Cont. <br/> Discrete | Pixels <br/> Pixels <br/> Pixels <br/> Vector | Dense | Homog <br/> Heterog <br/> Homog <br/> Heterog| [source](https://pettingzoo.farama.org/) |
| 🐻PettingZoo | PistonBall <br/> KAZ| 15 <br/> 2| Cont. <br/> Discrete | Pixels <br/> Vector | Dense | Homog <br/> Heterog| [source](https://pettingzoo.farama.org/) |
| 🏙️CityLearn | 2022_all_phases | 17 | Cont. | Vector | Dense | Homog | [source](https://github.com/intelligent-environments-lab/CityLearn) |
| 🔌Voltage Control | case33_3min_final | 6 | Cont. | Vector | Dense | Homog | [source](https://github.com/Future-Power-Networks/MAPDN) |
| 🔴MPE | simple_adversary | 3 | Discrete. | Vector | Dense | Competitive | [source](https://pettingzoo.farama.org/environments/mpe/simple_adversary/) |
Expand Down
14 changes: 9 additions & 5 deletions baselines/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,21 @@
# limitations under the License.
from absl import app, flags

from og_marl.environments.utils import get_environment
from og_marl.environments import get_environment
from og_marl.loggers import JsonWriter, WandbLogger
from og_marl.offline_dataset import download_and_unzip_vault
from og_marl.replay_buffers import FlashbaxReplayBuffer
from og_marl.tf2.networks import CNNEmbeddingNetwork
from og_marl.tf2.systems import get_system
from og_marl.tf2.utils import set_growing_gpu_memory

set_growing_gpu_memory()

FLAGS = flags.FLAGS
flags.DEFINE_string("env", "smac_v1", "Environment name.")
flags.DEFINE_string("scenario", "3m", "Environment scenario name.")
flags.DEFINE_string("env", "pettingzoo", "Environment name.")
flags.DEFINE_string("scenario", "pursuit", "Environment scenario name.")
flags.DEFINE_string("dataset", "Good", "Dataset type.: 'Good', 'Medium', 'Poor' or 'Replay' ")
flags.DEFINE_string("system", "dbc", "System name.")
flags.DEFINE_string("system", "qmix", "System name.")
flags.DEFINE_integer("seed", 42, "Seed.")
flags.DEFINE_float("trainer_steps", 5e4, "Number of training steps.")
flags.DEFINE_integer("batch_size", 64, "Number of training steps.")
Expand All @@ -43,7 +44,7 @@ def main(_):

env = get_environment(FLAGS.env, FLAGS.scenario)

buffer = FlashbaxReplayBuffer(sequence_length=20, sample_period=2)
buffer = FlashbaxReplayBuffer(sequence_length=20, sample_period=1)

download_and_unzip_vault(FLAGS.env, FLAGS.scenario)

Expand All @@ -65,6 +66,9 @@ def main(_):
)

system_kwargs = {"add_agent_id_to_obs": True}
if FLAGS.scenario == "pursuit":
system_kwargs["observation_embedding_network"] = CNNEmbeddingNetwork()

system = get_system(FLAGS.system, env, logger, **system_kwargs)

system.train_offline(buffer, max_trainer_steps=FLAGS.trainer_steps, json_writer=json_writer)
Expand Down
45 changes: 33 additions & 12 deletions examples/tf2/run_all_baselines.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,45 @@
import os

# import module
import traceback

from og_marl.environments import get_environment
from og_marl.loggers import JsonWriter, WandbLogger
from og_marl.loggers import TerminalLogger, JsonWriter
from og_marl.replay_buffers import FlashbaxReplayBuffer
from og_marl.tf2.networks import CNNEmbeddingNetwork
from og_marl.tf2.systems import get_system
from og_marl.tf2.utils import set_growing_gpu_memory

set_growing_gpu_memory()

os.environ["SUPPRESS_GR_PROMPT"] = 1
# For MAMuJoCo
os.environ["SUPPRESS_GR_PROMPT"] = "1"

scenario_system_configs = {
"smac_v1": {
"3m": {
"systems": ["idrqn", "idrqn+cql", "idrqn+bcq", "qmix+cql", "qmix+bcq", "maicq"],
"systems": ["idrqn", "idrqn+cql", "idrqn+bcq", "qmix+cql", "qmix+bcq", "maicq", "dbc"],
"datasets": ["Good"],
"trainer_steps": 3000,
"trainer_steps": 2000,
"evaluate_every": 1000,
},
},
"mamujoco": {
"2halfcheetah": {
"systems": ["iddpg", "iddpg+cql", "maddpg+cql", "maddpg", "omar"],
"pettingzoo": {
"pursuit": {
"systems": ["idrqn", "idrqn+cql", "idrqn+bcq", "qmix+cql", "qmix+bcq", "maicq", "dbc"],
"datasets": ["Good"],
"trainer_steps": 3000,
"trainer_steps": 2000,
"evaluate_every": 1000,
},
},
# "mamujoco": {
# "2halfcheetah": {
# "systems": ["iddpg", "iddpg+cql", "maddpg+cql", "maddpg", "omar"],
# "datasets": ["Good"],
# "trainer_steps": 3000,
# "evaluate_every": 1000,
# },
# },
}

seeds = [42]
Expand All @@ -44,7 +57,7 @@
"system": env_name,
"seed": seed,
}
logger = WandbLogger(config, project="og-marl-baselines")
logger = TerminalLogger()
env = get_environment(env_name, scenario_name)

buffer = FlashbaxReplayBuffer(sequence_length=20, sample_period=1)
Expand All @@ -55,10 +68,18 @@
raise ValueError("Vault not found. Exiting.")

json_writer = JsonWriter(
"logs", system_name, f"{scenario_name}_{dataset_name}", env_name, seed
"test_all_baselines",
system_name,
f"{scenario_name}_{dataset_name}",
env_name,
seed,
)

system_kwargs = {"add_agent_id_to_obs": True}

if scenario_name == "pursuit":
system_kwargs["observation_embedding_network"] = CNNEmbeddingNetwork()

system = get_system(system_name, env, logger, **system_kwargs)

trainer_steps = scenario_system_configs[env_name][scenario_name][
Expand All @@ -75,7 +96,7 @@
)
except: # noqa: E722
logger.close()
print()
print("BROKEN")
print("BROKEN:", env_name, scenario_name, system_name)
traceback.print_exc()
print()
continue
2 changes: 1 addition & 1 deletion install_environments/requirements/pettingzoo.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ autorom
gym
numpy
opencv-python
pettingzoo==1.22.0
pettingzoo==1.23.1
pygame
pymunk
scipy
Expand Down
1 change: 0 additions & 1 deletion mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ theme:

nav:
- Home: 'index.md'
- Datasets: 'datasets.md'
- Baseline Results: 'baselines.md'
- Updates: 'updates.md'
- API Reference: 'api.md'
Expand Down
10 changes: 9 additions & 1 deletion og_marl/environments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from og_marl.environments.base import BaseEnvironment


def get_environment(env_name: str, scenario: str) -> BaseEnvironment:
def get_environment(env_name: str, scenario: str) -> BaseEnvironment: # noqa: C901
if env_name == "smac_v1":
from og_marl.environments.smacv1 import SMACv1

Expand All @@ -31,6 +31,14 @@ def get_environment(env_name: str, scenario: str) -> BaseEnvironment:
from og_marl.environments.old_mamujoco import MAMuJoCo

return MAMuJoCo(scenario)
elif scenario == "pursuit":
from og_marl.environments.pursuit import Pursuit

return Pursuit()
elif scenario == "coop_pong":
from og_marl.environments.coop_pong import CooperativePong

return CooperativePong()
elif env_name == "gymnasium_mamujoco":
from og_marl.environments.gymnasium_mamujoco import MAMuJoCo

Expand Down
114 changes: 114 additions & 0 deletions og_marl/environments/coop_pong.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# python3
# Copyright 2021 InstaDeep Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Wrapper for Cooperative Pettingzoo environments."""
from typing import Any, List, Dict

import numpy as np
from pettingzoo.butterfly import cooperative_pong_v5
import supersuit

from og_marl.environments.base import BaseEnvironment
from og_marl.environments.base import Observations, ResetReturn, StepReturn


class CooperativePong(BaseEnvironment):
"""Environment wrapper PettingZoo Cooperative Pong."""

def __init__(
self,
) -> None:
"""Constructor."""
self._environment = cooperative_pong_v5.parallel_env(render_mode="rgb_array")
# Wrap environment with supersuit pre-process wrappers
self._environment = supersuit.color_reduction_v0(self._environment, mode="R")
self._environment = supersuit.resize_v0(self._environment, x_size=145, y_size=84)
self._environment = supersuit.dtype_v0(self._environment, dtype="float32")
self._environment = supersuit.normalize_obs_v0(self._environment)

self._agents = self._environment.possible_agents
self._done = False
self.max_episode_length = 900

def reset(self) -> ResetReturn:
"""Resets the env."""
# Reset the environment
observations, _ = self._environment.reset() # type: ignore

# Convert observations
observations = self._convert_observations(observations)

# Global state
env_state = self._create_state_representation(observations, first=True)

# Infos
info = {"state": env_state, "legals": self._legals}

return observations, info

def step(self, actions: Dict[str, np.ndarray]) -> StepReturn:
"""Steps in env."""
# Step the environment
observations, rewards, terminals, truncations, _ = self._environment.step(actions)

# Convert observations
observations = self._convert_observations(observations)

# Global state
env_state = self._create_state_representation(observations)

# Extra infos
info = {"state": env_state, "legals": self._legals}

return observations, rewards, terminals, truncations, info

def _create_state_representation(self, observations: Observations, first: bool = False) -> Any:
if first:
self._state_history = np.zeros((84, 145, 4), "float32")

state = np.expand_dims(observations["paddle_0"][:, :], axis=-1)

# framestacking
self._state_history = np.concatenate((state, self._state_history[:, :, :3]), axis=-1)

return self._state_history

def _convert_observations(self, observations: List) -> Observations:
"""Make observations partial."""
processed_observations = {}
for agent in self._agents:
if agent == "paddle_0":
agent_obs = observations[agent][:, :110] # hide the other agent
else:
agent_obs = observations[agent][:, 35:] # hide the other agent

agent_obs = np.expand_dims(agent_obs, axis=-1)
processed_observations[agent] = agent_obs

return processed_observations

def __getattr__(self, name: str) -> Any:
"""Expose any other attributes of the underlying environment.
Args:
name (str): attribute.
Returns:
Any: return attribute from env or underlying env.
"""
if hasattr(self.__class__, name):
return self.__getattribute__(name)
else:
return getattr(self._environment, name)
2 changes: 1 addition & 1 deletion og_marl/environments/pettingzoo_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def __init__(self) -> None:
def reset(self) -> ResetReturn:
"""Resets the env."""
# Reset the environment
observations = self._environment.reset() # type: ignore
observations, _ = self._environment.reset() # type: ignore

# Global state
env_state = self._create_state_representation(observations)
Expand Down
Loading

0 comments on commit 112f8d7

Please sign in to comment.