-
Notifications
You must be signed in to change notification settings - Fork 12
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #16 from instadeepai/feat/replace-cpprb-with-flashbax
feat: Big 'Everything' update
- Loading branch information
Showing
57 changed files
with
2,649 additions
and
1,109 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,6 +5,9 @@ launch.json | |
.vscode | ||
json_logs | ||
vaults | ||
vaults_unprocessed | ||
development | ||
SMAC_Maps | ||
logs | ||
__MACOSX | ||
3.9 | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
FROM nvidia/cuda:11.8.0-cudnn8-runtime-ubuntu22.04 | ||
|
||
# Ensure no installs try to launch interactive screen | ||
ARG DEBIAN_FRONTEND=noninteractive | ||
|
||
# Update packages and install python3.9 and other dependencies | ||
RUN apt-get update -y && \ | ||
apt-get install -y software-properties-common && \ | ||
add-apt-repository -y ppa:deadsnakes/ppa && \ | ||
apt-get install -y python3.9 python3.9-dev python3-pip python3.9-venv python3-dev python3-opencv swig ffmpeg git unzip wget libosmesa6-dev libgl1-mesa-glx libglfw3 patchelf && \ | ||
update-alternatives --install /usr/bin/python python /usr/bin/python3.9 10 && \ | ||
python -m venv og-marl && \ | ||
apt-get clean && \ | ||
rm -rf /var/lib/apt/lists/* | ||
|
||
# Setup virtual env and path | ||
ENV VIRTUAL_ENV /og-marl | ||
ENV PATH /og-marl/bin:$PATH | ||
|
||
# Location of og-marl folder | ||
ARG folder=/home/app/og-marl | ||
|
||
# Set working directory | ||
WORKDIR ${folder} | ||
|
||
# Copy all code needed to install dependencies | ||
COPY ./install_environments ./install_environments | ||
COPY ./og_marl ./og_marl | ||
COPY setup.py . | ||
|
||
RUN echo "Installing requirements..." | ||
RUN pip install --quiet --upgrade pip setuptools wheel && \ | ||
pip install -e . && \ | ||
pip install flashbax==0.1.0 | ||
|
||
ENV SC2PATH /home/app/StarCraftII | ||
# RUN ./install_environments/smacv1.sh | ||
RUN ./install_environments/smacv2.sh | ||
|
||
# ENV LD_LIBRARY_PATH $LD_LIBRARY_PATH:/root/.mujoco/mujoco210/bin:/usr/lib/nvidia | ||
# ENV SUPPRESS_GR_PROMPT 1 | ||
# RUN ./install_environments/mamujoco.sh | ||
|
||
# Copy all code | ||
COPY ./examples ./examples | ||
COPY ./baselines ./baselines |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
# Copyright 2023 InstaDeep Ltd. All rights reserved. | ||
|
||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
|
||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from absl import app, flags | ||
|
||
from og_marl.environments.utils import get_environment | ||
from og_marl.loggers import JsonWriter, WandbLogger | ||
from og_marl.replay_buffers import FlashbaxReplayBuffer | ||
from og_marl.offline_dataset import download_and_unzip_vault | ||
from og_marl.tf2.systems import get_system | ||
from og_marl.tf2.utils import set_growing_gpu_memory | ||
|
||
set_growing_gpu_memory() | ||
|
||
FLAGS = flags.FLAGS | ||
flags.DEFINE_string("env", "smac_v1", "Environment name.") | ||
flags.DEFINE_string("scenario", "3m", "Environment scenario name.") | ||
flags.DEFINE_string("dataset", "Good", "Dataset type.: 'Good', 'Medium', 'Poor' or 'Replay' ") | ||
flags.DEFINE_string("system", "dbc", "System name.") | ||
flags.DEFINE_integer("seed", 42, "Seed.") | ||
flags.DEFINE_float("trainer_steps", 5e4, "Number of training steps.") | ||
flags.DEFINE_integer("batch_size", 64, "Number of training steps.") | ||
|
||
|
||
def main(_): | ||
config = { | ||
"env": FLAGS.env, | ||
"scenario": FLAGS.scenario, | ||
"dataset": FLAGS.dataset, | ||
"system": FLAGS.system, | ||
"backend": "tf2", | ||
} | ||
|
||
env = get_environment(FLAGS.env, FLAGS.scenario) | ||
|
||
buffer = FlashbaxReplayBuffer(sequence_length=20, sample_period=2) | ||
|
||
download_and_unzip_vault(FLAGS.env, FLAGS.scenario) | ||
|
||
is_vault_loaded = buffer.populate_from_vault(FLAGS.env, FLAGS.scenario, FLAGS.dataset) | ||
if not is_vault_loaded: | ||
print("Vault not found. Exiting.") | ||
return | ||
|
||
logger = WandbLogger(project="og-marl-baselines", config=config) | ||
|
||
json_writer = JsonWriter( | ||
"logs", f"{FLAGS.system}", f"{FLAGS.scenario}_{FLAGS.dataset}", FLAGS.env, FLAGS.seed, file_name=f"{FLAGS.scenario}_{FLAGS.dataset}_{FLAGS.seed}.json", save_to_wandb=True | ||
) | ||
|
||
system_kwargs = {"add_agent_id_to_obs": True} | ||
system = get_system(FLAGS.system, env, logger, **system_kwargs) | ||
|
||
system.train_offline(buffer, max_trainer_steps=FLAGS.trainer_steps, json_writer=json_writer) | ||
|
||
|
||
if __name__ == "__main__": | ||
app.run(main) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
# API Reference (Coming Soon) |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# Latest Baseline Results | ||
|
||
<iframe src="https://wandb.ai/off-the-grid-marl-team/og-marl-baselines/reports/OG-MARL-Refactor-Results--Vmlldzo2ODk4NjYw" style="border:none;height:1024px;width:100%"> |
Oops, something went wrong.