Skip to content

Commit

Permalink
Fix get_environment function to work with 3rd party datasets.
Browse files Browse the repository at this point in the history
  • Loading branch information
jcformanek committed Aug 26, 2024
1 parent ad0428d commit 5c3e01a
Show file tree
Hide file tree
Showing 12 changed files with 17 additions and 17 deletions.
12 changes: 6 additions & 6 deletions og_marl/environments.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,28 @@
from og_marl.environment_wrappers.base import BaseEnvironment


def get_environment(env_name: str, scenario: str, seed: int = 42) -> BaseEnvironment: # noqa: C901
if env_name in ["smac_v1", "smac_v1_cfcql"]:
def get_environment(source: str, env_name: str, scenario: str, seed: int = 42) -> BaseEnvironment: # noqa: C901
if env_name=="smac_v1" and source in ["cfcql", "og_marl"]:
from og_marl.environment_wrappers.smacv1 import SMACv1

return SMACv1(scenario, seed=seed)
elif env_name == "smac_v1_omiga":
elif env_name == "smac_v1" and source=="omiga":
from og_marl.environment_wrappers.smacv1_omiga import SMACv1OMIGA

return SMACv1OMIGA(scenario, seed=seed)
elif env_name == "smac_v2":
from og_marl.environment_wrappers.smacv2 import SMACv2

return SMACv2(scenario, seed=seed)
elif env_name == "mpe_omar":
elif env_name == "mpe" and source=="omar":
from og_marl.environment_wrappers.mpe_omar import MPEOMAR

return MPEOMAR(scenario, seed=seed)
elif env_name == "mamujoco":
elif env_name == "mamujoco" and source == "og_marl":
from og_marl.environment_wrappers.mamujoco import MAMuJoCo

return MAMuJoCo(scenario, seed=seed)
elif env_name == "mamujoco_omiga":
elif env_name == "mamujoco" and source=="omiga":
from og_marl.environment_wrappers.mamujoco_omiga import MAMuJoCoOMIGA

return MAMuJoCoOMIGA(scenario, seed=seed)
Expand Down
2 changes: 1 addition & 1 deletion og_marl/tf2/systems/continuous_bc.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def _tf_train_step(self, experience: Dict[str, Any]) -> Dict[str, Numeric]:
def run_experiment(cfg: DictConfig) -> None:
print(cfg)

env = get_environment(cfg["task"]["env"], cfg["task"]["scenario"], seed=cfg["seed"])
env = get_environment(cfg["task"]["source"], cfg["task"]["env"], cfg["task"]["scenario"], seed=cfg["seed"])

buffer = FlashbaxReplayBuffer(
sequence_length=cfg["replay"]["sequence_length"],
Expand Down
2 changes: 1 addition & 1 deletion og_marl/tf2/systems/discrete_bc.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def _tf_train_step(self, experience: Dict[str, Any]) -> Dict[str, Numeric]:
def run_experiment(cfg: DictConfig) -> None:
print(cfg)

env = get_environment(cfg["task"]["env"], cfg["task"]["scenario"], seed=cfg["seed"])
env = get_environment(cfg["task"]["source"], cfg["task"]["env"], cfg["task"]["scenario"], seed=cfg["seed"])

buffer = FlashbaxReplayBuffer(
sequence_length=cfg["replay"]["sequence_length"],
Expand Down
2 changes: 1 addition & 1 deletion og_marl/tf2/systems/iddpg_bc.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ def _tf_train_step(self, experience: Dict[str, Any]) -> Dict[str, Numeric]:
def run_experiment(cfg: DictConfig) -> None:
print(cfg)

env = get_environment(cfg["task"]["env"], cfg["task"]["scenario"], seed=cfg["seed"])
env = get_environment(cfg["task"]["source"], cfg["task"]["env"], cfg["task"]["scenario"], seed=cfg["seed"])

buffer = FlashbaxReplayBuffer(
sequence_length=cfg["replay"]["sequence_length"],
Expand Down
2 changes: 1 addition & 1 deletion og_marl/tf2/systems/iddpg_cql.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,7 @@ def _tf_train_step(self, experience: Dict[str, Any]) -> Dict[str, Numeric]:
def run_experiment(cfg: DictConfig) -> None:
print(cfg)

env = get_environment(cfg["task"]["env"], cfg["task"]["scenario"], seed=cfg["seed"])
env = get_environment(cfg["task"]["source"], cfg["task"]["env"], cfg["task"]["scenario"], seed=cfg["seed"])

buffer = FlashbaxReplayBuffer(
sequence_length=cfg["replay"]["sequence_length"],
Expand Down
2 changes: 1 addition & 1 deletion og_marl/tf2/systems/iql_bcq.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ def _tf_train_step(self, train_step: int, experience: Dict[str, Any]) -> Dict[st
def run_experiment(cfg: DictConfig) -> None:
print(cfg)

env = get_environment(cfg["task"]["env"], cfg["task"]["scenario"], seed=cfg["seed"])
env = get_environment(cfg["task"]["source"], cfg["task"]["env"], cfg["task"]["scenario"], seed=cfg["seed"])

buffer = FlashbaxReplayBuffer(
sequence_length=cfg["replay"]["sequence_length"],
Expand Down
2 changes: 1 addition & 1 deletion og_marl/tf2/systems/iql_cql.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ def _tf_train_step(self, train_step: int, experience: Dict[str, Any]) -> Dict[st
def run_experiment(cfg: DictConfig) -> None:
print(cfg)

env = get_environment(cfg["task"]["env"], cfg["task"]["scenario"], seed=cfg["seed"])
env = get_environment(cfg["task"]["source"], cfg["task"]["env"], cfg["task"]["scenario"], seed=cfg["seed"])

buffer = FlashbaxReplayBuffer(
sequence_length=cfg["replay"]["sequence_length"],
Expand Down
2 changes: 1 addition & 1 deletion og_marl/tf2/systems/maddpg_cql.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,7 @@ def _tf_train_step(self, experience: Dict[str, Any]) -> Dict[str, Numeric]:
def run_experiment(cfg: DictConfig) -> None:
print(cfg)

env = get_environment(cfg["task"]["env"], cfg["task"]["scenario"], seed=cfg["seed"])
env = get_environment(cfg["task"]["source"], cfg["task"]["env"], cfg["task"]["scenario"], seed=cfg["seed"])

buffer = FlashbaxReplayBuffer(
sequence_length=cfg["replay"]["sequence_length"],
Expand Down
2 changes: 1 addition & 1 deletion og_marl/tf2/systems/maicq.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ def _tf_train_step(
def run_experiment(cfg: DictConfig) -> None:
print(cfg)

env = get_environment(cfg["task"]["env"], cfg["task"]["scenario"], seed=cfg["seed"])
env = get_environment(cfg["task"]["source"], cfg["task"]["env"], cfg["task"]["scenario"], seed=cfg["seed"])

buffer = FlashbaxReplayBuffer(
sequence_length=cfg["replay"]["sequence_length"],
Expand Down
2 changes: 1 addition & 1 deletion og_marl/tf2/systems/omar.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,7 @@ def _tf_train_step(self, experience: Dict[str, Any]) -> Dict[str, Numeric]:
def run_experiment(cfg: DictConfig) -> None:
print(cfg)

env = get_environment(cfg["task"]["env"], cfg["task"]["scenario"], seed=cfg["seed"])
env = get_environment(cfg["task"]["source"], cfg["task"]["env"], cfg["task"]["scenario"], seed=cfg["seed"])

buffer = FlashbaxReplayBuffer(
sequence_length=cfg["replay"]["sequence_length"],
Expand Down
2 changes: 1 addition & 1 deletion og_marl/tf2/systems/qmix_bcq.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ def mixing(
def run_experiment(cfg: DictConfig) -> None:
print(cfg)

env = get_environment(cfg["task"]["env"], cfg["task"]["scenario"], seed=cfg["seed"])
env = get_environment(cfg["task"]["source"], cfg["task"]["env"], cfg["task"]["scenario"], seed=cfg["seed"])

buffer = FlashbaxReplayBuffer(
sequence_length=cfg["replay"]["sequence_length"],
Expand Down
2 changes: 1 addition & 1 deletion og_marl/tf2/systems/qmix_cql.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ def mixing(
def run_experiment(cfg: DictConfig) -> None:
print(cfg)

env = get_environment(cfg["task"]["env"], cfg["task"]["scenario"], seed=cfg["seed"])
env = get_environment(cfg["task"]["source"], cfg["task"]["env"], cfg["task"]["scenario"], seed=cfg["seed"])

buffer = FlashbaxReplayBuffer(
sequence_length=cfg["replay"]["sequence_length"],
Expand Down

0 comments on commit 5c3e01a

Please sign in to comment.