Skip to content

Commit

Permalink
Merge pull request #16 from instadeepai/feat/replace-cpprb-with-flashbax
Browse files Browse the repository at this point in the history
feat: Big 'Everything' update
  • Loading branch information
callumtilbury authored Feb 23, 2024
2 parents 6de4f92 + 0eea06d commit 273c679
Show file tree
Hide file tree
Showing 57 changed files with 2,649 additions and 1,109 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ launch.json
.vscode
json_logs
vaults
vaults_unprocessed
development
SMAC_Maps
logs
__MACOSX
3.9
Expand Down
20 changes: 10 additions & 10 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,13 @@ repos:
stages: [ commit-msg ]
additional_dependencies: [ '@commitlint/config-conventional' ]

- repo: https://github.com/Lucas-C/pre-commit-hooks
rev: v1.3.0
hooks:
- id: insert-license
name: "License inserter"
files: .*py$
args:
- --license-filepath=docs/license_header.txt
- --comment-style=#
exclude: .npy$ # Do not add license to .npy files (the standard binary file format in NumPy)
# - repo: https://github.com/Lucas-C/pre-commit-hooks
# rev: v1.3.0
# hooks:
# - id: insert-license
# name: "License inserter"
# files: .*py$
# args:
# - --license-filepath=docs/license_header.txt
# - --comment-style=#
# exclude: .npy$ # Do not add license to .npy files (the standard binary file format in NumPy)
46 changes: 46 additions & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
FROM nvidia/cuda:11.8.0-cudnn8-runtime-ubuntu22.04

# Ensure no installs try to launch interactive screen
ARG DEBIAN_FRONTEND=noninteractive

# Update packages and install python3.9 and other dependencies
RUN apt-get update -y && \
apt-get install -y software-properties-common && \
add-apt-repository -y ppa:deadsnakes/ppa && \
apt-get install -y python3.9 python3.9-dev python3-pip python3.9-venv python3-dev python3-opencv swig ffmpeg git unzip wget libosmesa6-dev libgl1-mesa-glx libglfw3 patchelf && \
update-alternatives --install /usr/bin/python python /usr/bin/python3.9 10 && \
python -m venv og-marl && \
apt-get clean && \
rm -rf /var/lib/apt/lists/*

# Setup virtual env and path
ENV VIRTUAL_ENV /og-marl
ENV PATH /og-marl/bin:$PATH

# Location of og-marl folder
ARG folder=/home/app/og-marl

# Set working directory
WORKDIR ${folder}

# Copy all code needed to install dependencies
COPY ./install_environments ./install_environments
COPY ./og_marl ./og_marl
COPY setup.py .

RUN echo "Installing requirements..."
RUN pip install --quiet --upgrade pip setuptools wheel && \
pip install -e . && \
pip install flashbax==0.1.0

ENV SC2PATH /home/app/StarCraftII
# RUN ./install_environments/smacv1.sh
RUN ./install_environments/smacv2.sh

# ENV LD_LIBRARY_PATH $LD_LIBRARY_PATH:/root/.mujoco/mujoco210/bin:/usr/lib/nvidia
# ENV SUPPRESS_GR_PROMPT 1
# RUN ./install_environments/mamujoco.sh

# Copy all code
COPY ./examples ./examples
COPY ./baselines ./baselines
150 changes: 43 additions & 107 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
</p>

<h2 align="center">
<p>Off-the-Grid MARL: Offline Multi-Agent Reinforcement Learning made easy</p>
<p>Off-the-Grid MARL: Offline Multi-Agent Reinforcement Learning Datasets and Baselines</p>
</h2>
<p align="center">
<a href="https://www.python.org/doc/versions/">
Expand Down Expand Up @@ -36,41 +36,28 @@ OG-MARL forms part of the [InstaDeep](https://www.instadeep.com/) MARL [ecosyste
community. To join us in these efforts, reach out, raise issues or just
🌟 to stay up to date with the latest developments!

## Updates [06/12/2023] 📰
## Quickstart 🏎️
Clone this repository.

OG-MARL is a research tool that is under active development and therefore evolving quickly. We have several very exciting new features on the roadmap but sometimes when we introduce a new feature we may abruptly change how things work in OG-MARL.
But in the interest of moving quickly, we believe this is an acceptable trade-off and ask our users to kindly be aware of this.

The following is a list of the latest updates to OG-MARL:

✅ We have **removed several cumbersome dependencies** from OG-MARL, including `reverb` and `launchpad`. This means that its significantly easier to install and use OG-MARL.
`git clone https://github.com/instadeepai/og-marl.git`

✅ We added **functionality to pre-load the TF Record datasets into a [Cpprb](https://ymd_h.gitlab.io/cpprb/) replay buffer**. This speeds up the time to sample the replay buffer by several orders of magnitude.
Install `og-marl` and its dependencies. We tested `og-marl` with Python 3.9. Consider using a `conda` virtual environment.

✅ We have implemented our **first set of JAX-based systems in OG-MARL**. Our JAX systems use [Flashbax](https://github.com/instadeepai/flashbax) as the replay buffer backend. Flashbax buffers are completely jit-able, which means that our JAX systems have fully integrated and jitted training and data sampling.
`pip install -e .`

✅ We have **integrated [MARL-eval](https://github.com/instadeepai/marl-eval/tree/main)** into OG-MARL to standardise and simplify the reporting of experimental results.
`pip install flashbax==0.1.0`

## Need for Speed 🏎️
Download environment dependencies. We will use SMACv1 in this example.

We have made our TF2 systems compatible with jit compilation. This combined with our new `cpprb` replay buffers have made our systems significantly faster. Furthermore, our JAX systems with tightly integrated replay sampling and training using Flashbax are even faster.
`bash install_environments/smacv1.sh`

**Speed Comparison**: for each setup, we trained MAICQ on the 8m Good dataset for 10k training steps and evaluated every 1k training steps for 4 episodes using a batch size of 256.
Download a dataset.

<div class="collage">
<div class="row" align="center">
<img src="docs/assets/system_speed_comparison.png" alt="OG-MARL Speed Comparison" width="65%"/>
</div>
</div>
`python examples/download_vault.py --env=smac_v1 --scenario=3m`

**Performance Comparison**: In order to make sure performance between the TF2 system and the JAX system is the same, we trained both variants on each of the three datasets for 8m (Good, Medium and Poor). We then normalised the scores and aggregated the results using MARL-eval. The sample efficiency curves and the performance profiles are given below.
Run a baseline. In this example we will run MAICQ.

<div class="collage">
<div class="row" align="center">
<img src="docs/assets/sample_efficiency.png" alt="Sample Efficiency" width="45%"/>
<img src="docs/assets/performance_profile.png" alt="Performance Profile" width="35%"/>
</div>
</div>
`python baselines/main.py --env=smac_v1 --scenario=3m --dataset=Good --system=maicq`

## Datasets 🎥

Expand All @@ -96,65 +83,53 @@ We have generated datasets on a diverse set of popular MARL environments. A list

<br/>

## Dataset Backends 🔌
We are in the process of migrating our datasets from TF Records to Flashbax Vaults. Flashbax Vaults have the advantage of being significantly more flexible than the TF Record Datasets.

### Flashbax Vaults ⚡
| Environment | Scenario | Agents | Act | Obs | Reward | Types | Repo |
|-----|----|----|-----|-----|----|----|-----|
| 🔫SMAC v1 | 3m <br/> 8m <br/> 2s3z <br/> 5m_vs_6m <br/> 27m_vs_30m <br/> 3s5z_vs_3s6z <br/> 2c_vs_64zg| 3 <br/> 8 <br/> 5 <br/> 5 <br/> 27 <br/> 8 <br/> 2 | Discrete | Vector | Dense | Homog <br/> Homog <br/> Heterog <br/> Homog <br/> Homog <br/> Heterog <br/> Homog |[source](https://github.com/oxwhirl/smac) |
| 💣SMAC v2 | terran_5_vs_5 <br/> zerg_5_vs_5 <br/> terran_10_vs_10 | 5 <br/> 5 <br/> 10 | Discrete | Vector | Dense | Heterog | [source](https://github.com/oxwhirl/smacv2) |
| 🐻PettingZoo | Pursuit <br/> Co-op Pong <br/> PistonBall <br/> KAZ| 8 <br/> 2 <br/> 15 <br/> 2| Discrete <br/> Discrete <br/> Cont. <br/> Discrete | Pixels <br/> Pixels <br/> Pixels <br/> Vector | Dense | Homog <br/> Heterog <br/> Homog <br/> Heterog| [source](https://pettingzoo.farama.org/) |
| 🚅Flatland | 3 Trains <br/> 5 Trains | 3 <br/> 5 | Discrete | Vector | Sparse | Homog | [source](https://flatland.aicrowd.com/intro.html) |
| 🐜MAMuJoCo | 2-HalfCheetah <br/> 2-Ant <br/> 4-Ant | 2 <br/> 2 <br/> 4 | Cont. | Vector | Dense | Heterog <br/> Homog <br/> Homog | [source](https://github.com/schroederdewitt/multiagent_mujoco) |


### Legacy Datasets (still to be migrated to Vault) 👴
| Environment | Scenario | Agents | Act | Obs | Reward | Types | Repo |
|-----|----|----|-----|-----|----|----|-----|
| 🐻PettingZoo | Pursuit <br/> Co-op Pong <br/> PistonBall <br/> KAZ| 8 <br/> 2 <br/> 15 <br/> 2| Discrete <br/> Discrete <br/> Cont. <br/> Discrete | Pixels <br/> Pixels <br/> Pixels <br/> Vector | Dense | Homog <br/> Heterog <br/> Homog <br/> Heterog| [source](https://pettingzoo.farama.org/) |
| 🏙️CityLearn | 2022_all_phases | 17 | Cont. | Vector | Dense | Homog | [source](https://github.com/intelligent-environments-lab/CityLearn) |
| 🔌Voltage Control | case33_3min_final | 6 | Cont. | Vector | Dense | Homog | [source](https://github.com/Future-Power-Networks/MAPDN) |
| 🔴MPE | simple_adversary | 3 | Discrete. | Vector | Dense | Competitive | [source](https://pettingzoo.farama.org/environments/mpe/simple_adversary/) |

**Note:** The dataset on KAZ was generated by recording experience from human game players.

## Quickstart 🏁

### Installation 🛠️

Start by cloning this repository:
`git clone https://github.com/instadeepai/og-marl.git`

Navigate into the og-marl directory.

`cd og-marl`

To install og-marl and its dependencies run the following command in a virtual environment (e.g. conda).

`pip install -e .`

To run the JAX based systems include the extra requirements.

`pip install -e .[jax]`

### Environments ⛰️

Depending on the environment you want to use, you should install that environments dependencies. We provide convenient shell scripts for this.
## Dataset API

`bash install_environments/<environment_name>.sh`
We provide a simple demonstrative notebook of how to use OG-MARL's dataset API here:
[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/instadeepai/og-marl/blob/main/examples/dataset_api_demo.ipynb)

You should replace `<environment_name>` with the name of the environment you want to install.

Installing several different environments dependencies in the same python virtual environment (or conda environment) may work in some cases but in others, they may have conflicting requirements. So we recommend maintaining a different virtual environment for each environment.

### Downloading Datasets ⏬

Next you need to download the dataset you want to use and add it to the correct file path. We provided a utility for easily downloading and extracting datasets. Below is an example of how to download the dataset for the "3m" map in SMACv1.

```python
from og_marl.offline_dataset import download_and_unzip_dataset

download_and_unzip_dataset("smac_v1", "3m")
```

After running the download function you should check that the datasets were extracted to the correct location, as below. Alternatively, go to the OG-MARL [website](<https://sites.google.com/view/og-marl>) and download the dataset manually. Once the zip file is downloaded, extract the datasets from it and add them to a directory called `datasets` on the same level as the `og-marl/` directory. The folder structure should look like this:
### Dataset and Vault Locations
For OG-MARL's systems, we require the following dataset storage structure:

```
examples/
|_> ...
og_marl/
|_> ...
vaults/
|_> smac_v1/
|_> 3m.vlt/
| |_> Good/
| |_> Medium/
| |_> Poor/
|_> ...
|_> smac_v2/
|_> terran_5_vs_5.vlt/
| |_> Good/
| |_> Medium/
| |_> Poor/
|_> ...
datasets/
|_> smac_v1/
|_> 3m/
Expand All @@ -171,47 +146,6 @@ datasets/
...
```

### Launching Experiments 🚀

We include scripts (`examples/tf2/main.py` and `examples/jax/main.py`) for easily launching experiments using the command below:

`python examples/<backend>/main.py --system=<system_name> --env=<env_name> --scenario=<scenario_name>`

Example options for each placeholder are given below:

* `<backend>` : {`jax`, `tf2`}
* `<system_name>` : {`maicq`, `qmix`, `qmix+cql`, `qmix+bcq`, `idrqn`, `iddpg`, ...}
* `<env_name>` : {`smac_v1`, `smac_v2`, `mamujoco`, ...}
* `<scenario_name>`: {`3m`, `8m`, `terran_5_vs_5`, `2halfcheetah`, ...}

**Note:** We have not implemented any checks to make sure the combination of `env`, `scenario` and `system` is valid. For example, certain algorithms can only be run on discrete action environments. We hope to implement more guard rails in the future. For now, please refer to the code and the paper for clarification. We are also still in the process of migrating all the experiments to this unified launcher. So if some configuration is not supported yet, please reach out in the issues and we will be happy to help.

### Code Snippet 🧙‍♂️

```python
from og_marl.offline_dataset import download_flashbax_dataset
from og_marl.environments.smacv1 import SMACv1
from og_marl.jax.systems.maicq import train_maicq_system
from og_marl.loggers import TerminalLogger

# Download the dataset
download_flashbax_dataset(
env_name="smac_v1",
scenario_name="8m",
base_dir="datasets/flashbax"
)
dataset_path = "datasets/flashbax/smac_v1/8m/Good"

# Instantiate environment for evaluation
env = SMACv1("8m")

# Setup a logger to write to terminal
logger = TerminalLogger()

# Train system
train_maicq_system(env, logger, dataset_path)
```

## See Also 🔎

**InstaDeep's MARL ecosystem in JAX.** In particular, we suggest users check out the following sister repositories:
Expand Down Expand Up @@ -242,6 +176,8 @@ If you use OG-MARL in your work, please cite the library using:
}
```

<img src="docs/assets/aamas2023.png" alt="AAMAS 2023" width="20%"/>

## Acknowledgements 🙏

The development of this library was supported with Cloud TPUs
Expand Down
68 changes: 68 additions & 0 deletions baselines/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# Copyright 2023 InstaDeep Ltd. All rights reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl import app, flags

from og_marl.environments.utils import get_environment
from og_marl.loggers import JsonWriter, WandbLogger
from og_marl.replay_buffers import FlashbaxReplayBuffer
from og_marl.offline_dataset import download_and_unzip_vault
from og_marl.tf2.systems import get_system
from og_marl.tf2.utils import set_growing_gpu_memory

set_growing_gpu_memory()

FLAGS = flags.FLAGS
flags.DEFINE_string("env", "smac_v1", "Environment name.")
flags.DEFINE_string("scenario", "3m", "Environment scenario name.")
flags.DEFINE_string("dataset", "Good", "Dataset type.: 'Good', 'Medium', 'Poor' or 'Replay' ")
flags.DEFINE_string("system", "dbc", "System name.")
flags.DEFINE_integer("seed", 42, "Seed.")
flags.DEFINE_float("trainer_steps", 5e4, "Number of training steps.")
flags.DEFINE_integer("batch_size", 64, "Number of training steps.")


def main(_):
config = {
"env": FLAGS.env,
"scenario": FLAGS.scenario,
"dataset": FLAGS.dataset,
"system": FLAGS.system,
"backend": "tf2",
}

env = get_environment(FLAGS.env, FLAGS.scenario)

buffer = FlashbaxReplayBuffer(sequence_length=20, sample_period=2)

download_and_unzip_vault(FLAGS.env, FLAGS.scenario)

is_vault_loaded = buffer.populate_from_vault(FLAGS.env, FLAGS.scenario, FLAGS.dataset)
if not is_vault_loaded:
print("Vault not found. Exiting.")
return

logger = WandbLogger(project="og-marl-baselines", config=config)

json_writer = JsonWriter(
"logs", f"{FLAGS.system}", f"{FLAGS.scenario}_{FLAGS.dataset}", FLAGS.env, FLAGS.seed, file_name=f"{FLAGS.scenario}_{FLAGS.dataset}_{FLAGS.seed}.json", save_to_wandb=True
)

system_kwargs = {"add_agent_id_to_obs": True}
system = get_system(FLAGS.system, env, logger, **system_kwargs)

system.train_offline(buffer, max_trainer_steps=FLAGS.trainer_steps, json_writer=json_writer)


if __name__ == "__main__":
app.run(main)
1 change: 1 addition & 0 deletions docs/api.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# API Reference (Coming Soon)
Binary file added docs/assets/aamas2023.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
3 changes: 3 additions & 0 deletions docs/baselines.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Latest Baseline Results

<iframe src="https://wandb.ai/off-the-grid-marl-team/og-marl-baselines/reports/OG-MARL-Refactor-Results--Vmlldzo2ODk4NjYw" style="border:none;height:1024px;width:100%">
Loading

0 comments on commit 273c679

Please sign in to comment.