Skip to content

Commit

Permalink
[BugFix] Add MultiCategorical support in PettingZoo action masks (#…
Browse files Browse the repository at this point in the history
…2485)

Co-authored-by: Vincent Moens <vmoens@meta.com>
  • Loading branch information
matteobettini and vmoens authored Oct 14, 2024
1 parent 77de5ee commit be15cab
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@ dependencies:
- pyyaml
- autorom[accept-rom-license]
- pettingzoo[all]==1.24.3
- gymnasium<1.0.0
7 changes: 2 additions & 5 deletions torchrl/envs/libs/pettingzoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,10 +390,7 @@ def _make_group_specs(self, group_name: str, agent_names: List[str]):
n=2,
shape=group_action_spec["action"].shape
if not self.categorical_actions
else (
*group_action_spec["action"].shape,
group_action_spec["action"].space.n,
),
else group_action_spec["action"].to_one_hot_spec().shape,
dtype=torch.bool,
device=self.device,
)
Expand Down Expand Up @@ -494,7 +491,7 @@ def _init_env(self):
n=2,
shape=group_action_spec.shape
if not self.categorical_actions
else (*group_action_spec.shape, group_action_spec.space.n),
else group_action_spec.to_one_hot_spec().shape,
dtype=torch.bool,
device=self.device,
)
Expand Down

0 comments on commit be15cab

Please sign in to comment.