diff --git a/CHANGELOG.md b/CHANGELOG.md
index 6de6318f..cc2aa18d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,4 +1,16 @@
-### [v0.0.1] - DATE
+### [v0.0.4]
+
+#### Added
+ - Reset to specific state within environment base class
+ - JaxNav environment
+
+
+### [v0.0.3]
+
+#### Added
+ - Hanabi bug fixes
+
+### [v0.0.1]
##### Added
- base set of environments and algorithms
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index df19be2b..eb69e8a9 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -35,4 +35,5 @@ All contributions will fall under the project's original license.
## Roadmap
Some improvements we would like to see implemented:
-- [ ] improved RNN implementations. In the current implementation, the hidden size is dependent on "NUM_STEPS", it should be made independent. Speed could also be improved with an S5 architecture.
+- [x] improved RNN implementations. In the current implementation, the hidden size is dependent on "NUM_STEPS", it should be made independent.
+- [ ] S5 RNN architecture.
diff --git a/Dockerfile b/Dockerfile
index 7ca0eee8..b3299296 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -10,14 +10,15 @@ USER ${MYUSER}
WORKDIR /home/${MYUSER}/
COPY --chown=${MYUSER} --chmod=765 . .
-#jaxmarl from source if needed, all the requirements
USER root
-RUN pip install -e .
# install tmux
RUN apt-get update && \
apt-get install -y tmux
+#jaxmarl from source if needed, all the requirements
+RUN pip install -e .
+
USER ${MYUSER}
#disabling preallocation
diff --git a/README.md b/README.md
index b6569bae..64766d4d 100644
--- a/README.md
+++ b/README.md
@@ -28,7 +28,7 @@
@@ -50,6 +50,7 @@ For more details, take a look at our [blog post](https://blog.foersterlab.com/ja
| 🎆 Hanabi | [Paper](https://arxiv.org/abs/1902.00506) | [Source](https://github.com/FLAIROx/JaxMARL/tree/main/jaxmarl/environments/hanabi) | Fully-cooperative partially-observable multiplayer card game |
| 👾 SMAX | Novel | [Source](https://github.com/FLAIROx/JaxMARL/tree/main/jaxmarl/environments/smax) | Simplified cooperative StarCraft micro-management environment |
| 🧮 STORM: Spatial-Temporal Representations of Matrix Games | [Paper](https://openreview.net/forum?id=54F8woU8vhq) | [Source](https://github.com/FLAIROx/JaxMARL/tree/main/jaxmarl/environments/storm) | Matrix games represented as grid world scenarios
+| 🧠JaxNav | Paper coming | [Source](https://github.com/FLAIROx/JaxMARL/tree/main/jaxmarl/environments/jaxnav) | 2D geometric navigation for differential drive robots
| 🪙 Coin Game | [Paper](https://arxiv.org/abs/1802.09640) | [Source](https://github.com/FLAIROx/JaxMARL/tree/main/jaxmarl/environments/coin_game) | Two-player grid world environment which emulates social dilemmas
| 💡 Switch Riddle | [Paper](https://proceedings.neurips.cc/paper_files/paper/2016/hash/c7635bfd99248a2cdef8249ef7bfbef4-Abstract.html) | [Source](https://github.com/FLAIROx/JaxMARL/tree/main/jaxmarl/environments/switch_riddle) | Simple cooperative communication game included for debugging
diff --git a/docs/imgs/jaxnav-ma.gif b/docs/imgs/jaxnav-ma.gif
new file mode 100644
index 00000000..f4996d5d
Binary files /dev/null and b/docs/imgs/jaxnav-ma.gif differ
diff --git a/jaxmarl/__init__.py b/jaxmarl/__init__.py
index 58e75ae8..fdf4a908 100644
--- a/jaxmarl/__init__.py
+++ b/jaxmarl/__init__.py
@@ -1,4 +1,4 @@
from .registration import make, registered_envs
__all__ = ["make", "registered_envs"]
-__version__ = "0.0.3"
+__version__ = "0.0.4"
diff --git a/jaxmarl/environments/__init__.py b/jaxmarl/environments/__init__.py
index 0415dd8b..8c9b2cf4 100644
--- a/jaxmarl/environments/__init__.py
+++ b/jaxmarl/environments/__init__.py
@@ -21,4 +21,5 @@
from .hanabi import Hanabi
from .storm import InTheGrid, InTheGrid_2p
from .coin_game import CoinGame
+from .jaxnav import JaxNav
diff --git a/jaxmarl/environments/jaxnav/README.md b/jaxmarl/environments/jaxnav/README.md
new file mode 100644
index 00000000..1f10b28f
--- /dev/null
+++ b/jaxmarl/environments/jaxnav/README.md
@@ -0,0 +1,5 @@
+# 🧠JaxNav
+
+2D geometric navigation for differential drive robots. Using distances readings to nearby obstacles (mimicing LiDAR readings), the direction to their goal and their current velocity, robots must navigate to their goal without colliding with obstacles.
+
+MORE TO COME.
\ No newline at end of file
diff --git a/jaxmarl/environments/jaxnav/__init__.py b/jaxmarl/environments/jaxnav/__init__.py
new file mode 100644
index 00000000..09f805df
--- /dev/null
+++ b/jaxmarl/environments/jaxnav/__init__.py
@@ -0,0 +1,2 @@
+from .jaxnav_env import JaxNav
+from .jaxnav_singletons import make_jaxnav_singleton, make_jaxnav_singleton_collection, JaxNavSingleton
\ No newline at end of file
diff --git a/jaxmarl/environments/jaxnav/jaxnav_env.py b/jaxmarl/environments/jaxnav/jaxnav_env.py
new file mode 100644
index 00000000..91bc2734
--- /dev/null
+++ b/jaxmarl/environments/jaxnav/jaxnav_env.py
@@ -0,0 +1,774 @@
+"""
+Rob sim that follows the JaxMARL interface
+"""
+
+import jax
+import jax.numpy as jnp
+from jax import random, jit, vmap
+import numpy as np
+from functools import partial
+import chex
+from flax import struct
+from typing import Tuple, Dict
+#from gymnax.environments import spaces
+import os, pathlib
+import matplotlib.pyplot as plt
+import matplotlib.axes._axes as axes
+
+from jaxmarl.environments import MultiAgentEnv
+from jaxmarl.environments.spaces import Box, Discrete
+
+from .maps import make_map, Map
+from .jaxnav_utils import pol2cart, wrap, unitvec, cart2pol
+import jaxmarl.environments.jaxnav.jaxnav_graph_utils as _graph_utils
+
+
+NUM_REWARD_COMPONENTS = 2
+REWARD_COMPONENT_SPARSE = 0
+REWARD_COMPONENT_DENSE = 1
+@struct.dataclass
+class Reward:
+ sparse: jnp.ndarray
+ dense: jnp.ndarray
+
+def listify_reward(reward: Reward, do_batchify=False): # returns shape of (*batch, num_agents, 2)
+ ans = jnp.stack(
+ [reward.sparse, reward.dense],
+ axis=-1
+ )
+ # batchify stacks the agents first and then does reshape, which is why we need the swapaxes.
+ if do_batchify:
+ ans = jnp.swapaxes(ans, 0, 1).reshape(-1, *ans.shape[2:]) # shape of (batch * num_agents, 2)
+ return ans
+
+@struct.dataclass
+class State:
+ pos: chex.Array # [n, [x, y, theta]]
+ theta: chex.Array # [n, theta]
+ vel: chex.Array # [n, [speed, omega]]
+ done: chex.Array # [n, bool] whether an agent has terminated
+ term: chex.Array # [n, bool] whether an agent acted in this step
+ goal_reached: chex.Array # [n, bool] whether an agent has reached goal
+ move_term: chex.Array # [n, bool] whether an agent has crashed
+ step: int # step count
+ ep_done: bool # whether epsiode has terminated
+ goal: chex.Array # [n, x, y]
+ map_data: chex.Array # occupancy grid for environment map
+ rew_lambda: float # linear interpolation between individual and team rewards
+
+@struct.dataclass
+class EnvInstance:
+ agent_pos: chex.Array
+ agent_theta: chex.Array
+ goal_pos: chex.Array
+ map_data: chex.Array
+ rew_lambda: chex.Array
+
+### ---- Discrete action constants ----
+DISCRETE_ACTS = jnp.array([
+ jnp.array([0.0, 0.5]),
+ jnp.array([0.0, 0.25]),
+ jnp.array([0.0, 0.0]),
+ jnp.array([0.0, -0.25]),
+ jnp.array([0.0, -0.5]),
+ jnp.array([0.5, 0.5]),
+ jnp.array([0.5, 0.25]),
+ jnp.array([0.5, 0.0]),
+ jnp.array([0.5, -0.25]),
+ jnp.array([0.5, -0.5]),
+ jnp.array([1.0, 0.5]),
+ jnp.array([1.0, 0.25]),
+ jnp.array([1.0, 0.0]),
+ jnp.array([1.0, -0.25]),
+ jnp.array([1.0, -0.5]),
+], dtype=jnp.float32)
+
+
+@partial(jax.vmap, in_axes=[0])
+def discrete_act_map(action: int) -> jnp.ndarray:
+ print('action', action, action.shape)
+ return DISCRETE_ACTS[action]
+
+## ---- Environment defaults ----
+AGENT_BASE = "agent"
+MAP_PARAMS = {
+ "map_size": (7, 7),
+ "fill": 0.3,
+}
+
+## ---- Environment ----
+class JaxNav(MultiAgentEnv):
+
+ def __init__(self,
+ num_agents: int, # Number of agents
+ act_type="Continuous", # Action type, either Continuous or Discrete
+ normalise_obs=True,
+ rad=0.3, # Agent radius
+ evaporating=False, # Whether agents evaporate (dissapeare) when they reach the goal
+ map_id="Grid-Rand-Poly", # Map type
+ map_params=MAP_PARAMS, # Map parameters
+ lidar_num_beams=200,
+ lidar_range_resolution=0.05,
+ lidar_max_range=6.0,
+ lidar_min_range=0.0,
+ lidar_angle_factor=1.0,
+ min_v=0.0,
+ max_v=1.0,
+ max_v_acc=1.0,
+ max_w=1.0,
+ max_w_acc=1.0,
+ max_steps=500,
+ dt=0.1,
+ fixed_lambda=True,
+ rew_lambda=1.0, # linear interpolation between individual and team rewards
+ lambda_range=[0.0, 1.0],
+ goal_radius=0.3,
+ goal_rew=4.0,
+ weight_g=0.25,
+ lim_w=0.7,
+ weight_w=-0.0,
+ dt_rew=-0.01,
+ coll_rew=-5.0,
+ lidar_thresh=0.1,
+ lidar_rew=-0.1,
+ do_sep_reward=False,
+ share_only_sparse=False,
+ info_by_agent=False,
+ ):
+ super().__init__(num_agents)
+
+ assert rad < 1, "current code assumes radius of less than 1"
+ self.rad = rad
+ self.agents = ["agent_{}".format(i) for i in range(num_agents)]
+ self.agent_range = jnp.arange(0, num_agents)
+ self.evaporating = evaporating
+
+ self._map_obj = make_map(map_id, self.num_agents, self.rad, **map_params)
+ self._act_type = act_type
+ if self._act_type == "Discrete":
+ assert min_v == 0.0, "min_v must be 0.0 for Discrete actions"
+
+ # Lidar parameters
+ self.normalise_obs = normalise_obs
+ self.lidar_num_beams = lidar_num_beams
+ self.lidar_max_range = lidar_max_range
+ self.lidar_min_range = lidar_min_range
+ assert self.lidar_min_range == 0.0, "lidar_min_range must be 0.0 FOR NOW"
+ self.lidar_range_resolution = lidar_range_resolution
+ self.lidar_angle_factor = lidar_angle_factor
+ self.lidar_max_angle = jnp.pi * self.lidar_angle_factor
+ self.lidar_angles = jnp.linspace(-jnp.pi * self.lidar_angle_factor, jnp.pi * self.lidar_angle_factor, self.lidar_num_beams)
+ #self.lidar_ranges = jnp.arange(self.lidar_min_range, self.lidar_max_range, self.lidar_range_resolution)
+ num_lidar_samples = int((self.lidar_max_range - self.lidar_min_range) / self.lidar_range_resolution)
+ self.lidar_ranges = jnp.linspace(self.lidar_min_range, self.lidar_max_range, num_lidar_samples)
+
+ assert min_v < max_v, "min_v must be less than max_v"
+ if min_v != 0.0: print(f"WARNING: min_v is not 0.0, it is {min_v}")
+ self.min_v = min_v # min linear velocity (m/s)
+ self.max_v = max_v # max linear velocity (m/s)
+ self.max_v_acc = max_v_acc # max linear acceleration (m/s^2)
+ self.max_w = max_w # max angular velocity (rad/s)
+ self.max_w_acc = max_w_acc # max angular acceleration (rad/s^2)
+ self.max_steps = max_steps # max environment steps within an episode
+ self.dt = dt # seconds per step (s)
+
+ # Rewards
+ # if share_only_sparse:
+ # do_sep_reward, f"If share_only_sparse is True, do_sep_reward must be True for it to work, it is current: {do_sep_reward}"
+ self.do_sep_reward = do_sep_reward
+ self.share_only_sparse = share_only_sparse
+ self.fixed_lambda = fixed_lambda
+ self.rew_lambda = rew_lambda # linear interpolation between individual and team rewards
+ self.lambda_range = lambda_range
+ if self.fixed_lambda: assert self.rew_lambda is not None, "If fixed_lambda is True, rew_lambda must be set"
+ self.goal_radius = goal_radius # goal radius (m)
+ self.goal_rew = goal_rew
+ self.weight_g = weight_g
+ self.lim_w = lim_w
+ self.weight_w = weight_w
+ self.dt_rew = dt_rew
+ self.coll_rew = coll_rew
+ self.lidar_thresh = lidar_thresh
+ self.lidar_rew = lidar_rew
+
+ self.info_by_agent = info_by_agent
+ self.eval_solved_rate = self.get_eval_solved_rate_fn()
+
+ self.action_spaces = {a: self.agent_action_space() for a in self.agents}
+ self.observation_spaces = {a: self.agent_observation_space() for a in self.agents}
+
+ @property
+ def map_obj(self) -> Map:
+ """ Return map object """
+ return self._map_obj
+
+
+ @partial(jax.jit, static_argnums=[0])
+ def reset(self, key: chex.PRNGKey) -> Tuple[Dict[str, chex.Array], State]:
+ """ Reset environment. Returns initial agent observations, states and the enviornment state """
+
+ state = self.sample_test_case(key)
+ obs = self._get_obs(state)
+ return {a: obs[i] for i, a in enumerate(self.agents)}, state
+
+ @partial(jax.jit, static_argnums=[0])
+ def step_env(
+ self,
+ key: chex.PRNGKey,
+ agent_states: State,
+ actions: Dict[str, chex.Array]
+ ):
+ actions = jnp.array([actions[a] for a in self.agents]) # Batchify
+
+ # 1) Update agent states
+ if self._act_type == "Discrete": actions = discrete_act_map(actions).reshape(actions.shape[0], 2)
+ old_pos = agent_states.pos
+ update_state_valid = agent_states.done
+ if not self.evaporating:
+ update_state_valid = update_state_valid | agent_states.move_term
+ new_pos, new_theta, new_vel = self.update_state(agent_states.pos, agent_states.theta, agent_states.vel, actions, update_state_valid)
+ step = agent_states.step+1
+
+ # 2) Check collisions, goal and time
+ old_goal_reached = agent_states.goal_reached
+ old_move_term = agent_states.move_term
+ map_collisions = self._check_map_collisions(new_pos, new_theta, agent_states.map_data)*(1-agent_states.done).astype(bool)
+ agent_collisions = self._check_agent_collisions(jnp.arange(agent_states.pos.shape[0]), new_pos, agent_states.done)*(1- agent_states.done).astype(bool)
+ collisions = map_collisions | agent_collisions
+ goal_reached = (self._check_goal_reached(new_pos, agent_states.goal)*(1-agent_states.done)).astype(bool)
+ time_up = jnp.full((self.num_agents,), (step >= self.max_steps))
+
+ # 3) Compute rewards and done values
+ old_done = agent_states.done
+ if self.evaporating:
+ dones = collisions | goal_reached | time_up | agent_states.done # OR operation over agent status
+ ep_done = jnp.all(dones)
+ else:
+ goal_reached = goal_reached | old_goal_reached
+ collisions = collisions | old_move_term
+ ep_done = jnp.all(goal_reached | collisions | time_up)
+ dones = jnp.full((self.num_agents,), ep_done)
+
+ # 4) Update JAX state
+ agent_states = agent_states.replace(
+ pos=new_pos,
+ theta=new_theta,
+ vel=new_vel,
+ move_term=collisions,
+ goal_reached=goal_reached,
+ done=dones,
+ term=old_done,
+ step=step,
+ ep_done=ep_done,
+ )
+
+ dones = {a: agent_states.done[i] for i, a in enumerate(self.agents)}
+ dones["__all__"] = ep_done
+
+ # 5) Compute observations
+ obs_batch = self._get_obs(agent_states)
+
+ # 6) Reward
+ rew_individual, individual_rew_sep = self.compute_reward(
+ obs_batch,
+ agent_states.pos,
+ old_pos,
+ actions,
+ agent_states.goal,
+ collisions,
+ goal_reached,
+ old_done,
+ old_goal_reached,
+ old_move_term
+ )
+ avg_rew = rew_individual.mean()
+ if self.share_only_sparse:
+ shared_rew = individual_rew_sep.sparse.mean()
+ else:
+ shared_rew = avg_rew
+
+ if self.do_sep_reward:
+ rew_batch = self.rew_lambda * rew_individual + (1 - self.rew_lambda) * shared_rew
+ else:
+ rew_batch = self.rew_lambda * rew_individual + (1 - self.rew_lambda) * shared_rew
+
+ rew = {a: rew_batch[i] for i, a in enumerate(self.agents)}
+
+ obs = {a: obs_batch[i] for i, a in enumerate(self.agents)}
+
+ if self.evaporating:
+ num_c = jnp.sum(collisions | goal_reached | time_up)
+ time_o = time_up & ~old_done
+ else:
+ num_c = jax.lax.select(ep_done, self.num_agents, 0)
+ time_o = time_up & ~(collisions | goal_reached)
+
+ goal_r = goal_reached * (1 - old_goal_reached)
+ agent_c = agent_collisions * (1 - old_move_term)
+ map_c = map_collisions * (1 - old_move_term)
+ rew_info = avg_rew
+ if not self.info_by_agent:
+ goal_r = jnp.sum(goal_r)
+ agent_c = jnp.sum(agent_c)
+ map_c = jnp.sum(map_c)
+ time_o = jnp.sum(time_o)
+ term = {a: old_done[i] for i, a in enumerate(self.agents)}
+ else:
+ num_c = jnp.full((self.num_agents,), ep_done, dtype=jnp.int32)
+ rew_info = rew_batch
+ term = old_done
+
+ info = {
+ # outcomes
+ "NumC": num_c,
+ "GoalR": goal_r,
+ "AgentC": agent_c,
+ "MapC": map_c,
+ "TimeO": time_o,
+ # reward
+ "Return": rew_info,
+ # whether action was valid
+ "terminated": term,
+ }
+ if self.do_sep_reward:
+ raise NotImplementedError("Separate reward not implemented")
+ return obs, agent_states, individual_rew_sep, dones, info # NOTE no sharing ..?
+ else:
+ return obs, agent_states, rew, dones, info
+
+ def _lidar_sense(self, idx: int, state: State) -> chex.Array:
+ """ Return observation for an agent given the current world state """
+
+ pos = state.pos[idx]
+ theta = state.theta[idx]
+
+ point_fn = jax.vmap(self._map_obj.check_point_map_collision, in_axes=(0, None))
+
+ angles = self.lidar_angles + theta
+
+ angles_ranges_mesh = jnp.meshgrid(angles, self.lidar_ranges) # value mesh
+ angles_ranges = jnp.dstack(angles_ranges_mesh) # reformat array [num_points_per_beam, num_beams, 2]
+ beam_coords_x = (angles_ranges[:,:,1]*jnp.cos(angles_ranges[:,:,0])).T + pos[0]
+ beam_coords_y = (angles_ranges[:,:,1]*jnp.sin(angles_ranges[:,:,0])).T + pos[1]
+ beam_coords = jnp.dstack((beam_coords_x, beam_coords_y)) # [num_beams, num_points_per_beam, 2]
+
+ agent_c = self._lidar_agent_check(self.agent_range, state.pos, state.theta, beam_coords, idx)
+ rc_range = jnp.where(agent_c==-1, jnp.inf, agent_c)
+ rc_m = jnp.min(rc_range, axis=0)
+ rc = jnp.where(rc_m==jnp.inf, -1, rc_m).astype(int)
+
+ lidar_hits = point_fn(beam_coords.reshape(-1, 2), state.map_data).reshape(beam_coords.shape[0], beam_coords.shape[1], -1)
+
+ idxs = jnp.arange(0, beam_coords.shape[0])
+ lidar_hits = lidar_hits.at[idxs, rc].set(1)
+ fh_idx = jnp.argmax(lidar_hits>0, axis=1)
+ return self.lidar_ranges[fh_idx]
+
+ @partial(jax.vmap, in_axes=[None, 0, 0, 0, None, None])
+ def _lidar_agent_check(self, other_idx, other_pos, other_theta, beam_coords, host_idx):
+ """ Compute lidar collisions with other robots, vectorised across other agent indicies
+ Returns:
+ chex.Array: index of lidar particle for which the hit occured, direction to goal in global frame,
+ distance to goal
+ """
+
+ i = jax.vmap(
+ self._map_obj.check_agent_beam_intersect,
+ in_axes=(0, None, None, None)
+ )(beam_coords, other_pos, other_theta, self.lidar_range_resolution)
+ return jax.lax.select(other_idx==host_idx, jnp.full(i.shape, -1), i)
+
+
+ def normalise_lidar(self, ranges):
+ return ranges/self.lidar_max_range - 0.5
+
+ def unnormalise_lidar(self, ranges):
+ return (ranges + 0.5) * self.lidar_max_range
+
+ def get_avail_actions(self, state: State):
+
+ return {a: jnp.array([1.0, 1.0]) for a in self.agents}
+
+ def sample_test_case(self, key: chex.PRNGKey) -> State:
+
+ key_tc, key_lambda = jax.random.split(key)
+ map_data, test_case = self._map_obj.sample_test_case(key_tc)
+
+ states = State(
+ pos=test_case[:, 0, :2],
+ theta=test_case[:, 0, 2],
+ vel=jnp.zeros((self.num_agents, 2)),
+ done=jnp.full((self.num_agents,), False),
+ term=jnp.full((self.num_agents,), False), # TODO don't think this is needed
+ goal_reached=jnp.full((self.num_agents,), False),
+ move_term=jnp.full((self.num_agents,), False),
+ step=0,
+ ep_done=False,
+ goal=test_case[:, 1, :2],
+ map_data=map_data,
+ rew_lambda=self.sample_lambda(key_lambda),
+ )
+
+ return states
+
+ @partial(jax.jit, static_argnums=[0])
+ def sample_lambda(self, key):
+ if self.fixed_lambda:
+ rew_lambda = self.rew_lambda
+ else:
+ rew_lambda = jax.random.uniform(key, (1,), minval=self.lambda_range[0], maxval=self.lambda_range[1])
+ return rew_lambda
+
+ @partial(vmap, in_axes=(None, 0, 0, None))
+ def _check_map_collisions(self, pos: chex.Array, theta: chex.Array, map_data: chex.Array) -> bool:
+ return self._map_obj.check_agent_map_collision(pos, theta, map_data)
+
+ @partial(vmap, in_axes=(None, 0, 0))
+ def _check_goal_reached(self, pos: chex.Array, goal_pos: chex.Array) -> bool:
+ return jnp.sqrt(jnp.sum((pos - goal_pos)**2)) <= self.goal_radius
+
+ @partial(vmap, in_axes=(None, 0, None, None))
+ def _check_agent_collisions(self, agent_idx: int, agent_positions: chex.Array, dones: chex.Array) -> bool:
+ # TODO this function is a little clunky FIX
+ z = jnp.zeros(agent_positions.shape)
+ z = z.at[agent_idx,:].set(jnp.ones(2)*self.rad*2.1)
+ x = agent_positions + z
+ return jnp.any(jnp.sqrt(jnp.sum((x - agent_positions[agent_idx,:])**2, axis=1)) <= self.rad*2)
+
+ @partial(jax.jit, static_argnums=[0])
+ def get_obs(self, state: State) -> chex.Array:
+ obs_batch = self._get_obs(state)
+ return {a: obs_batch[i] for i, a in enumerate(self.agents)}
+
+ @partial(jax.jit, static_argnums=[0])
+ def _get_obs(self, state: State) -> chex.Array:
+ """ Return observation for an agent given the current world state
+
+ obs: [lidar (num lidar beams), speeds (2), goal (2), lambda (1)]
+ """
+
+ @partial(jax.vmap, in_axes=[0, None])
+ def _observation(idx: int, state: State) -> jnp.ndarray:
+ """Return observation for agent i."""
+
+ lidar = self._lidar_sense(idx, state).squeeze()
+
+ vel_obs = state.vel[idx]
+ goal_dir = state.goal[idx] - state.pos[idx]
+ goal_obs = cart2pol(*goal_dir)
+ goal_dist = jnp.clip(goal_obs[0], 0, self.lidar_max_range)
+ goal_orient = wrap(goal_obs[1]-state.theta[idx])
+
+ if self.normalise_obs:
+ lidar = self.normalise_lidar(lidar)
+ vel_obs = vel_obs / jnp.array([self.max_v, self.max_w]) - jnp.array([0.5, 0.0])
+ goal_dist = goal_dist/self.lidar_max_range - 0.5
+ goal_orient = goal_orient/jnp.pi
+ rew_lambda = state.rew_lambda - 0.5
+ vel_goal = jnp.concatenate([vel_obs, goal_dist[None], goal_orient[None], jnp.array([rew_lambda]).reshape(1)])
+ return jnp.concatenate((lidar, vel_goal))
+
+ return _observation(self.agent_range, state)
+
+ @partial(jax.jit, static_argnums=[0])
+ def get_world_state(self, state: State) -> chex.Array:
+ walls = state.map_data.at[1:-1, 1:-1].get().flatten()
+ pos = (state.pos / jnp.array([self._map_obj.width, self._map_obj.height]) - 0.5).flatten()
+ theta = (state.theta / jnp.pi - 0.5).flatten()
+ goal = (state.goal / jnp.array([self._map_obj.width, self._map_obj.height]) - 0.5).flatten()
+ vel = (state.vel / jnp.array([self.max_v, self.max_w]) - 0.5).flatten()
+ step = jnp.array(state.step / self.max_steps - 0.5)[None]
+ concat = (jnp.concatenate([walls, pos, theta, goal, vel, step])[None]).repeat(self.num_agents, axis=0)
+ agent_idx = jnp.eye(self.num_agents)
+
+ obs = self._get_obs(state)
+
+ return jnp.concatenate([agent_idx, concat, obs], axis=1)
+
+ @partial(vmap, in_axes=(None, 0, 0, 0, 0, 0))
+ def update_state(self, pos: chex.Array, theta: float, speed: chex.Array, action: chex.Array, done: chex.Array) -> chex.Array:
+ """ Update agent's state, if `done` the current position and velocity are returned"""
+ if self.evaporating:
+ out_done = (jnp.array([0.0,0.0]), theta, jnp.array([0.0,0.0])) # "Evaporating" agents
+ else:
+ out_done = (pos, theta, jnp.array([0.0, 0.0]))
+
+ # check if action within limits
+ v_acc = jnp.clip((action[0] - speed[0])/self.dt, -self.max_v_acc, self.max_v_acc)
+ w_acc = jnp.clip((action[1] - speed[1])/self.dt, -self.max_w_acc, self.max_w_acc)
+
+ v = jnp.clip(speed[0] + v_acc*self.dt, self.min_v, self.max_v)
+ w = jnp.clip(speed[1] + w_acc*self.dt, -self.max_w, self.max_w)
+
+ dx = v * jnp.cos(theta) * self.dt
+ dy = v * jnp.sin(theta) * self.dt
+ pos = pos + jnp.array([dx, dy])
+ theta = wrap(theta + w*self.dt)
+
+ out = (pos, theta, jnp.array([v, w], dtype=jnp.float32))
+ return jax.tree_map(lambda x, y: jax.lax.select(done, x, y), out_done, out)
+
+ @partial(jax.vmap, in_axes=(None, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0))
+ def compute_reward(
+ self,
+ obs,
+ new_pos,
+ old_pos,
+ act,
+ goal,
+ collision,
+ goal_reached,
+ done,
+ old_goal_reached,
+ old_move_term,
+ ):
+ rga = self.weight_g * (jnp.linalg.norm(old_pos - goal) - jnp.linalg.norm(new_pos - goal))
+ rg = jnp.where(goal_reached, self.goal_rew, rga) * (1 - old_goal_reached)# goal reward
+ rc = collision * self.coll_rew * (1 - old_move_term) # collision reward
+ rw = jax.lax.select(jnp.abs(act[1]) > self.lim_w, self.weight_w * jnp.abs(act[1]), 0.0) # angular velocity magnitue penalty
+ rt = self.dt_rew * (1 - (old_goal_reached | old_move_term)) # time penalty
+ rl = jnp.any(self.unnormalise_lidar(obs[:self.lidar_num_beams]) <= (self.lidar_thresh + self.rad)) * self.lidar_rew # lidar proximity reward
+
+ ret = Reward((jnp.where(goal_reached, rg, 0.0) + rc + rt)*(1 - done),
+ (jnp.where(goal_reached, 0.0, rg) + rw + rl)*(1 - done))
+
+ # {
+ # # Sparse reward is goal reward if it was reached & collision reward.
+ # 'sparse': jnp.where(goal_reached, rg, 0.0) + rc + rt,
+ # #
+ # 'dense': jnp.where(goal_reached, 0.0, rg) + rw + rl,
+ # }
+
+ return (rg + rc + rw + rt + rl)*(1 - done), ret
+
+ def set_state(
+ self,
+ state: State
+ ) -> Tuple[Dict[str, chex.ArrayTree], State]:
+ """
+ Implemented for basic envs.
+ """
+ obs = self._get_obs(state)
+ return {a: obs[i] for i, a in enumerate(self.agents)}, state
+
+ def set_env_instance(
+ self,
+ encoding: EnvInstance
+ ) -> Tuple[Dict[str, chex.ArrayTree], State]:
+ """
+ Instance is encoded as a PyTree containing the following fields:
+ agent_pos, agent_theta, goal_pos, map_data
+ """
+ state = State(
+ pos=encoding.agent_pos,
+ theta=encoding.agent_theta,
+ vel=jnp.zeros((self.num_agents, 2)),
+ done=jnp.full((self.num_agents,), False),
+ term=jnp.full((self.num_agents,), False),
+ goal_reached=jnp.full((self.num_agents,), False),
+ move_term=jnp.full((self.num_agents,), False),
+ step=0,
+ ep_done=False,
+ goal=encoding.goal_pos,
+ map_data=encoding.map_data,
+ rew_lambda=encoding.rew_lambda
+ )
+ obs = self._get_obs(state)
+ return {a: obs[i] for i, a in enumerate(self.agents)}, state
+
+ @partial(jax.jit, static_argnums=(0))
+ def reset_to_level(self, level: Tuple[chex.Array, chex.Array]) -> Tuple[chex.Array, State]:
+ print(' ** WARNING ** reset_to_level in JaxNav is deprecated, use set_state instead')
+ map_data, test_case = level
+
+ state = State(
+ pos=test_case[:, 0, :2],
+ theta=test_case[:, 0, 2],
+ vel=jnp.zeros((self.num_agents, 2)),
+ done=jnp.full((self.num_agents,), False),
+ term=jnp.full((self.num_agents,), False),
+ step=0,
+ ep_done=False,
+ goal=test_case[:, 1, :2],
+ map_data=map_data,
+ rew_lambda=self.rew_lambda,
+ )
+ obs = self._get_obs(state)
+ return {a: obs[i] for i, a in enumerate(self.agents)}, state
+
+ @partial(jax.jit, static_argnums=(0,))
+ def step_plr(
+ self,
+ key: chex.PRNGKey,
+ state: State,
+ actions: chex.Array,
+ level: Tuple,
+ ):
+ """ Resets to PLR level rather than a random one."""
+ print(' ** WARNING ** step_plr in JaxNav is deprecated ')
+ obs_st, state_st, rewards, dones, infos = self.step_env(
+ key, state, actions
+ )
+ obs_re, state_re = self.reset_to_level(level) # todo maybe should be set state depending on PLR code
+ state = jax.tree_map(
+ lambda x, y: jax.lax.select(state_st.ep_done, x, y), state_re, state_st
+ )
+ obs = jax.tree_map(
+ lambda x, y: jax.lax.select(state_st.ep_done, x, y), obs_re, obs_st
+ )
+ #obs = jax.lax.select(state_st.ep_done, obs_re, obs_st)
+ return obs, state, rewards, dones, infos
+
+ @partial(jax.jit, static_argnums=[0])
+ def unnormalise_obs(self, obs_batch: chex.Array) -> chex.Array:
+ lidar = self.unnormalise_lidar(obs_batch[:, :self.lidar_num_beams])
+ vel_obs = (obs_batch[:, self.lidar_num_beams:self.lidar_num_beams+2] + jnp.array([0.5, 0.0])) * jnp.array([self.max_v, self.max_w])
+ goal_dist = (obs_batch[:, self.lidar_num_beams+2:self.lidar_num_beams+3] + 0.5) * self.lidar_max_range
+ goal_orient = obs_batch[:, self.lidar_num_beams+3:self.lidar_num_beams+4] * jnp.pi
+ rew_lambda = obs_batch[:, -1] + 0.5
+ vel_goal = jnp.concatenate([vel_obs, goal_dist, goal_orient, rew_lambda[:, None]], axis=1)
+ o = jnp.concatenate([lidar, vel_goal], axis=1)
+ return o
+
+ def get_monitored_metrics(self):
+ return ["NumC", "GoalR", "AgentC", "MapC", "TimeO", "Return"]
+
+ def get_eval_solved_rate_fn(self):
+ def _fn(ep_stats):
+ return ep_stats["GoalR"] / ep_stats["NumC"]
+
+ return _fn
+
+ def agent_action_space(self):
+ if self._act_type == "Discrete":
+ return Discrete(15)
+ low = jnp.array(
+ [self.min_v, -jnp.pi/6], dtype=jnp.float32 # NOTE hard coded heading angle
+ )
+ high = jnp.array(
+ [self.max_v, jnp.pi/6], dtype=jnp.float32
+ )
+ return Box(low, high, (2,), jnp.float32)
+
+
+ def agent_observation_space(self):
+ return Box(-jnp.inf, jnp.inf, (self.lidar_num_beams+5,)) # NOTE hardcoded
+
+ # def action_space(self, agent=None): # NOTE assuming homogenous observation spaces, NOTE I think jnp.empty is fine
+ # aa = self.agent_action_space()
+ # return jnp.empty((self.num_agents, *aa.shape))
+
+ # def observation_space(self, agent=None):
+ # oo = self.agent_observation_space()
+ # return jnp.empty((self.num_agents, *oo.shape))
+
+ @partial(jax.jit, static_argnums=[0])
+ def generate_scenario(self, key):
+ """ Sample map grid and agent start/goal poses """
+ return self._map_obj.sample_scenario(key)
+
+ def get_env_metrics(self, state: State) -> dict:
+ """ NOTE only valid for grid map type"""
+ # n_walls = state.map_data.sum() - state.map_data.shape[0]*2 - state.map_data.shape[1]*2 + 4
+ inside = state.map_data.astype(jnp.bool_)[1:-1, 1:-1]
+ n_walls = jnp.sum(inside)
+ passability = jax.vmap(
+ self.map_obj.passable_check,
+ in_axes=(0, 0, None)
+ )(
+ state.pos,
+ state.goal,
+ state.map_data,
+ )
+
+ # shortest_path_lengths = jax.vmap( # BUG in the minimax code somewhere
+ # _graph_util.shortest_path_len,
+ # in_axes=(None, 0, 0),
+ # )(
+ # inside.astype(jnp.bool_),
+ # jnp.floor(state.pos-1).astype(jnp.int32),
+ # jnp.floor(state.goal-1).astype(jnp.int32),
+ # )
+
+ return dict(
+ n_walls=n_walls,
+ # shortest_path_length_mean=jnp.mean(shortest_path_lengths),
+ # shortest_path_lengths_stderr=jnp.std(shortest_path_lengths)/jnp.sqrt(self.num_agents),
+ passable=jnp.mean(passability),
+ )
+
+ ### === VISULISATION === ###
+
+ def plot_lidar(self, ax: axes.Axes, obs: Dict, state: State, num_to_plot: int=10):
+
+ @partial(jax.vmap, in_axes=(0, 0, 0, None))
+ def lidar_scatter(
+ pos: chex.Array,
+ theta: float,
+ lidar_ranges: chex.Array,
+ idx,
+ ):
+ """Return lidar ranges as points ready to be plotted with `ax.scatter()`
+
+ Args:
+ state (EnvSingleAState): agent state
+ params (EnvParams): environment parameters
+ ranges (chex.Array): reported lidar ranges
+
+ Returns:
+ Tuple[List, List]: lists of x and y coordinates respectively of lidar ranges for plotting
+ """
+ ranges = self.unnormalise_lidar(lidar_ranges) # (lidar_ranges+0.5)*self.lidar_max_range # correct normalisation
+ x = [ranges[i]*jnp.cos(self.lidar_angles[idx[i]]+theta) + pos[0] for i in range(ranges.shape[0])]
+ y = [ranges[i]*jnp.sin(self.lidar_angles[idx[i]]+theta) + pos[1] for i in range(ranges.shape[0])]
+ return jnp.array([x, y])
+
+ if self.lidar_num_beams>10:
+ if num_to_plot > self.lidar_num_beams:
+ num_to_plot = self.lidar_num_beams
+ print('Warning: num_to_plot > lidar_num_beams, setting num_to_plot to lidar_num_beams')
+ idx = jnp.round(jnp.linspace(0, self.lidar_num_beams-1, num_to_plot)).astype(int)
+ else:
+ idx = range(self.lidar_num_beams)
+
+ obs_batch = jnp.stack([obs[a] for a in self.agents])
+
+ lidar_scat = lidar_scatter(state.pos, state.theta, obs_batch[:, idx], idx)
+ lidar_scat = jnp.swapaxes(lidar_scat, 1, 2).reshape((-1, 2))
+ lidar_scat = self._map_obj.scale_coords(lidar_scat)
+ ax.scatter(lidar_scat[:, 0], lidar_scat[:, 1], c='b', s=2)
+
+ # Plotting by SMAX style
+ def init_render(self,
+ ax: axes.Axes,
+ state: State,
+ obs: Dict=None,
+ lidar=True, # plot lidar?
+ agent=True, # plot agents?
+ goal=True, # plot goals?
+ rew_lambda=False, # plot lambda?
+ ticks_off=False, # turn off axis ticks?
+ num_to_plot=10, # number of lidar beams to plot
+ colour_agents_by_idx=False,
+ ):
+ """ Render environment. """
+
+ ax.set_aspect('equal', 'box')
+
+ self.map_obj.plot_map(ax, state.map_data)
+ if agent:
+ self.map_obj.plot_agents(ax, state.pos, state.theta, state.goal, state.done, plot_line_to_goal=goal, colour_agents_by_idx=colour_agents_by_idx)
+ if lidar:
+ assert obs is not None, "lidar is True but no obs provided, TODO make it not obs dependent"
+ self.plot_lidar(ax, obs, state, num_to_plot=num_to_plot)
+
+ if rew_lambda:
+ ax.text(0.5, 0.5, f"lambda: {state.rew_lambda}", fontsize=12, ha='center', va='center', c='white')
+ if ticks_off:
+ ax.set_xticks([])
+ ax.set_yticks([])
+
+ canvas = ax.figure.canvas
+ canvas.draw()
+
+
diff --git a/jaxmarl/environments/jaxnav/jaxnav_graph_utils.py b/jaxmarl/environments/jaxnav/jaxnav_graph_utils.py
new file mode 100644
index 00000000..7bb25000
--- /dev/null
+++ b/jaxmarl/environments/jaxnav/jaxnav_graph_utils.py
@@ -0,0 +1,272 @@
+"""
+Copyright (c) Meta Platforms, Inc. and affiliates.
+All rights reserved.
+
+This source code is licensed under the license found in the
+LICENSE file in the root directory of this source tree.
+"""
+
+from functools import partial
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+
+
+@partial(jax.jit, static_argnums=(1,))
+def apsp(A, n=None):
+ """
+ Compute APSP for adjacency matrix A
+ using Seidel's algorithm.
+ """
+ if n is None:
+ n = A.shape[0]
+ assert(n == A.shape[0]), 'n must equal dim of A.'
+
+ n_steps = int(np.ceil(np.log(n)/np.log(2)))
+ A_cache = jnp.zeros((n_steps, n, n), dtype=jnp.uint32)
+ steps_to_reduce = jnp.array(1, dtype=jnp.int32)
+
+ def _scan_fwd_step(carry, step):
+ i = step
+ A, A_cache, steps_to_reduce = carry
+ A_cache = A_cache.at[i].set(A)
+
+ Z = A@A
+ B = jnp.logical_or(
+ A == 1,
+ Z > 0
+ ).astype(jnp.uint32) \
+ .at[jnp.diag_indices(n)].set(0)
+ A = B
+
+ complete = B.sum() - jnp.diagonal(B).sum() == n*(n-1)
+ steps_to_reduce += ~complete
+
+ return (A, A_cache, steps_to_reduce), None
+
+ (B, A_cache, steps_to_reduce), _ = jax.lax.scan(
+ _scan_fwd_step,
+ (A, A_cache, 1),
+ jnp.arange(n_steps),
+ length=n_steps
+ )
+
+ D = 2*B - A_cache[steps_to_reduce-1]
+
+ def _scan_bkwd_step(carry, step):
+ i = step
+ (T, A_cache,steps_to_reduce) = carry
+
+ A = A_cache[steps_to_reduce - i - 1]
+ X = T@A
+
+ thresh = T*(jnp.tile(A.sum(0, keepdims=True), (n, 1)))
+ D = 2*T*(X >= thresh) + (2*T - 1)*(X < thresh)
+ T = D*(i < steps_to_reduce) + T*(i >= steps_to_reduce)
+
+ return (T, A_cache, steps_to_reduce), None
+
+ (D, _, _), _ = jax.lax.scan(
+ _scan_bkwd_step,
+ (D, A_cache, steps_to_reduce),
+ jnp.arange(1, n_steps),
+ length=n_steps-1
+ )
+
+ return D
+
+
+@jax.jit
+def grid_to_graph(grid):
+ """
+ Transform a binary grid (True == wall) into a
+ graph.
+ """
+ h, w = grid.shape
+ nodes = grid.flatten()
+ print('nodes', nodes)
+ n = len(nodes)
+ A = jnp.zeros((n,n), dtype=jnp.uint32)
+
+ all_idx = jnp.arange(n)
+ # jax.debug.print('dumneigh idx {x}', x=~nodes)
+ dum_neighbor_idx = jnp.argmax(~nodes)
+ dum_neighbor_mask = jnp.zeros(n, dtype=jnp.bool_)
+ dum_neighbor_mask = \
+ dum_neighbor_mask.at[dum_neighbor_idx].set(True)
+
+ def _get_neigbors(idx):
+ # Return length n boolean mask of neighbors
+ # We then vmap this function over all n
+ r = idx + 1
+ l = idx - 1
+ t = idx - w
+ b = idx + w
+
+ l_border = jnp.logical_or(
+ idx % w == 0,
+ nodes[l]
+ )
+ r_border = jnp.logical_or(
+ r % w == 0,
+ nodes[r]
+ )
+ t_border = jnp.logical_or(
+ idx // w == 0,
+ nodes[t],
+ )
+ b_border = jnp.logical_or(
+ idx // w == h - 1,
+ nodes[b]
+ )
+
+ l_ignore = jnp.logical_or(
+ l_border,
+ nodes[idx]
+ )
+ r_ignore = jnp.logical_or(
+ r_border,
+ nodes[idx]
+ )
+ t_ignore = jnp.logical_or(
+ t_border,
+ nodes[idx]
+ )
+ b_ignore = jnp.logical_or(
+ b_border,
+ nodes[idx]
+ )
+
+ left = l*(1-l_ignore) + idx*(l_ignore)
+ right = r*(1-r_ignore) + idx*(r_ignore)
+ top = t*(1-t_ignore) + idx*(t_ignore)
+ bottom = b*(1-b_ignore) + idx*(b_ignore)
+
+ neighbor_mask = jnp.zeros(n, dtype=jnp.bool_)
+ # jax.debug.print('idx {x}, neigh {n}', x=idx, n=jnp.array([left, right, top, bottom]))
+ neighbor_mask = neighbor_mask.at[jnp.array([left, right, top, bottom])].set(True)
+
+ neighbor_mask = (1-nodes[idx])*neighbor_mask + nodes[idx]*dum_neighbor_mask
+
+ neighbor_mask = neighbor_mask.at[idx].set(False)
+ # jax.debug.print('idx {x} mask {m}', x=idx, m=neighbor_mask)
+
+ return neighbor_mask
+
+ A = jax.vmap(_get_neigbors)(all_idx).astype(dtype=jnp.uint32)
+ A = jnp.maximum(A, A.transpose())
+
+ return A
+
+
+NEIGHBOR_OFFSETS = jnp.array([
+ [1,0], # right
+ [0,1], # bottom
+ [-1,0], # left
+ [0,-1], # top
+ [0,0] # self
+], dtype=jnp.int32)
+
+
+@jax.jit
+def component_mask_with_pos(grid, pos_a):
+ """
+ Return a mask set to True in all cells that are
+ a part of the connected component containing pos_a.
+ pos_a in format [x, y].
+ """
+ h,w = grid.shape
+ visited_mask = grid
+
+ pos = pos_a
+ visited_mask = visited_mask.at[
+ pos[1],pos[0]
+ ].set(True)
+ vstack = jnp.zeros((h*w, 2), dtype=jnp.uint32)
+ vstack = vstack.at[:2].set(pos)
+ vstack_size = 2
+
+ def _scan_dfs(carry, step):
+ (visited_mask, vstack, vstack_size) = carry
+
+ pos = vstack[vstack_size-1]
+
+ neighbors = \
+ jnp.minimum(
+ jnp.maximum(
+ pos + NEIGHBOR_OFFSETS, 0
+ ), jnp.array([[h,w]])
+ ).astype(jnp.uint32)
+
+ neighbors_mask = visited_mask.at[
+ neighbors[:,1],neighbors[:,0]
+ ].get()
+ n_neighbor_visited = neighbors_mask[:4].sum()
+ all_visited = n_neighbor_visited == 4
+ all_visited_post = n_neighbor_visited >= 3
+ neighbors_mask = neighbors_mask.at[-1].set(~all_visited)
+
+ next_neighbor_idx = jnp.argmax(~neighbors_mask)
+ next_neighbor = neighbors[next_neighbor_idx]
+
+ visited_mask = visited_mask.at[
+ next_neighbor[1],next_neighbor[0]
+ ].set(True)
+
+ vstack_size -= all_visited_post
+
+ vstack = vstack.at[vstack_size].set(next_neighbor)
+ vstack_size += ~all_visited
+
+ pos = next_neighbor
+
+ return (visited_mask, vstack, vstack_size), None
+
+ max_n_steps = 2*h*w
+ (visited_mask, vstack, vstack_size), _ = jax.lax.scan(
+ _scan_dfs,
+ (visited_mask, vstack, vstack_size),
+ jnp.arange(max_n_steps),
+ length=max_n_steps
+ )
+
+ visited_mask = visited_mask ^ grid
+ return visited_mask
+
+
+@jax.jit
+def shortest_path_len(grid, pos_a, pos_b):
+ # false should equal free space
+ # jax.debug.print('pos_a {x} pos_b {y} grid {g}', x=pos_a, y=pos_b, g=grid)
+ grid = ~component_mask_with_pos(grid, pos_a)
+ # jax.debug.print('component pos_a {x} pos_b {y} grid {g}', x=pos_a, y=pos_b, g=grid)
+ A = grid_to_graph(grid)
+ D = apsp(A, n=A.shape[0])
+ # jax.debug.print('A {x} {x2}', x=A, x2=A.shape)
+ # jax.debug.print('D {x} {x2}', x=D, x2=D.shape)
+ # jax.debug.print('pos_b shape {x}', x=pos_b.shape)
+ if len(pos_b.shape) == 2: # batch eval
+ return jax.vmap(_shortest_path_len, in_axes=(None, None, 0, None))(
+ grid, pos_a, pos_b, D
+ )
+ else:
+ return _shortest_path_len(grid, pos_a, pos_b, D)
+
+
+@jax.jit
+def _shortest_path_len(grid, pos_a, pos_b, D):
+ h,w = grid.shape
+
+ a_idx = pos_a[1]*w + pos_a[0]
+ b_idx = pos_b[1]*w + pos_b[0]
+ d = D[a_idx][b_idx]
+
+ mhttn_d = jnp.sum(jnp.abs(jnp.maximum(pos_a,pos_b)- jnp.minimum(pos_a,pos_b)))
+
+ impossible = jnp.logical_and(
+ d == 1,
+ mhttn_d > 1
+ )
+ return jax.lax.select(jnp.all(pos_a == pos_b), 1, (d*(1-impossible)).astype(jnp.int32))
+ # return d*(1-impossible)
diff --git a/jaxmarl/environments/jaxnav/jaxnav_singletons.py b/jaxmarl/environments/jaxnav/jaxnav_singletons.py
new file mode 100644
index 00000000..b07ae287
--- /dev/null
+++ b/jaxmarl/environments/jaxnav/jaxnav_singletons.py
@@ -0,0 +1,729 @@
+import jax
+import jax.numpy as jnp
+import chex
+from typing import NamedTuple, List, Tuple
+import matplotlib.pyplot as plt
+
+from .jaxnav_env import JaxNav, State
+from .maps import make_map
+
+class TestCase(NamedTuple):
+ map_data: list
+ start_pose: tuple
+ goal_pose: tuple
+
+class JaxNavSingleton(JaxNav):
+ def __init__(self,
+ num_agents: int, # Number of agents
+ test_case=None,
+ fixed_lambda=False,
+ rew_lambda=0.0,
+ map_id="Grid-Rand",
+ **env_kwargs
+ ):
+ assert len(test_case.start_pose) == num_agents, f"len start_pose: {len(test_case.start_pose)} != num_agents: {num_agents}"
+ assert len(test_case.goal_pose) == num_agents
+ assert map_id.startswith("Grid"), f"map_id: {map_id} does not start with Grid"
+
+ super().__init__(num_agents,
+ map_id=map_id,
+ **env_kwargs)
+
+ if fixed_lambda is True:
+ self.rew_lambda = rew_lambda
+ else:
+ self.rew_lambda = 0.0
+
+ if test_case is None:
+ raise NotImplementedError
+ else:
+ if map_id == "Grid-Rand-Poly-Single":
+ map_id = "Grid-Rand-Poly"
+ self.map_data = jnp.array(
+ [[int(x) for x in row.split()] for row in test_case.map_data],
+ dtype=jnp.int32
+ )
+ height, width = self.map_data.shape
+ self.goal_pose = jnp.array(test_case.goal_pose, dtype=jnp.float32)
+ self.start_pose = jnp.array(test_case.start_pose, dtype=jnp.float32)
+
+ map_kwargs = {
+ "map_size": (width, height),
+ "fill": 0.5,
+ }
+ self._map_obj = make_map(map_id, num_agents, self.rad, **map_kwargs)
+
+ def reset(
+ self,
+ key: chex.PRNGKey=None,
+ ):
+
+ state = State(
+ pos=self.start_pose[:, :2],
+ theta=self.start_pose[:, 2],
+ vel=jnp.zeros((self.num_agents, 2)),
+ done=jnp.full((self.num_agents,), False),
+ term=jnp.full((self.num_agents,), False),
+ goal_reached=jnp.full((self.num_agents,), False),
+ move_term=jnp.full((self.num_agents,), False),
+ goal=self.goal_pose[:, :2],
+ ep_done=False,
+ step=0,
+ map_data=self.map_data,
+ rew_lambda=self.rew_lambda,
+ )
+ obs_batch = self._get_obs(state)
+ obs = {a: obs_batch[i] for i, a in enumerate(self.agents)}
+ return obs, state
+
+ def get_monitored_metrics(self):
+ return ["NumC", "GoalR", "AgentC", "MapC", "TimeO", "Return"]
+
+ def viz_testcase(self, save=True, show=False, plot_lidar=True):
+
+ obs, state = self.reset()
+ fig, ax = plt.subplots(figsize=(5,5))
+
+ ax.set_aspect('equal', 'box')
+ self._map_obj.plot_map(ax, state.map_data)
+ ax.scatter(state.goal[:, 0], state.goal[:, 1], marker='+')
+ self._map_obj.plot_agents(ax, state.pos, state.theta, state.goal, state.done)
+ if plot_lidar:
+ self.plot_lidar(ax, obs, state, 100)
+
+ # plot a line from start to goal for each agent
+ for i in range(self.num_agents):
+ ax.plot(jnp.concatenate([state.pos[i, 0][None], state.goal[i, 0][None]]),
+ jnp.concatenate([state.pos[i, 1][None], state.goal[i, 1][None]]),
+ color='gray', alpha=0.2)
+
+ if save:
+ plt.savefig(f'{self.name}.png')
+ if show:
+ plt.show()
+
+ @property
+ def name(self) -> str:
+ return self.__class__.__name__
+
+
+## SINGLE AGENT
+# blank map
+class BlankTest(JaxNavSingleton):
+ def __init__(
+ self,
+ num_agents=1,
+ **env_kwargs
+ ):
+ test_case = TestCase(
+ map_data = [
+ "1 1 1 1 1 1 1",
+ "1 0 0 0 0 0 1",
+ "1 0 0 0 0 0 1",
+ "1 0 0 0 0 0 1",
+ "1 0 0 0 0 0 1",
+ "1 0 0 0 0 0 1",
+ "1 1 1 1 1 1 1",
+ ],
+ start_pose=[(1.5, 1.5, 0.0)],
+ goal_pose=[(5.5, 5.5, 0.0)],
+ )
+
+ super().__init__(num_agents,
+ test_case=test_case,
+ **env_kwargs)
+
+class MiddleTest(JaxNavSingleton):
+ def __init__(
+ self,
+ num_agents=1,
+ **env_kwargs
+ ):
+ test_case = TestCase(
+ map_data = [
+ "1 1 1 1 1 1 1",
+ "1 0 0 0 0 0 1",
+ "1 0 1 1 1 0 1",
+ "1 0 1 1 1 0 1",
+ "1 0 1 1 1 0 1",
+ "1 0 0 0 0 0 1",
+ "1 1 1 1 1 1 1",
+ ],
+ start_pose=[(1.5, 1.5, 0.0)],
+ goal_pose=[(5.5, 5.5, 0.0)],
+ )
+
+ super().__init__(num_agents,
+ test_case=test_case,
+ **env_kwargs)
+
+## MULTI-AGENT
+class BlankCross2(JaxNavSingleton):
+ def __init__(
+ self,
+ num_agents=2,
+ **env_kwargs
+ ):
+ test_case = TestCase(
+ map_data = [
+ "1 1 1 1 1 1 1",
+ "1 0 0 0 0 0 1",
+ "1 0 0 0 0 0 1",
+ "1 0 0 0 0 0 1",
+ "1 0 0 0 0 0 1",
+ "1 0 0 0 0 0 1",
+ "1 1 1 1 1 1 1",
+ ],
+ start_pose=[(1.5, 1.5, 0.78),
+ (5.5, 5.5, 3.92)],
+ goal_pose=[(5.5, 5.5, 0.0),
+ (1.5, 1.5, 0.0)],
+ )
+
+ super().__init__(num_agents,
+ test_case=test_case,
+ **env_kwargs)
+
+class BlankCross4(JaxNavSingleton):
+ def __init__(
+ self,
+ num_agents=4,
+ **env_kwargs
+ ):
+ test_case = TestCase(
+ map_data = [
+ "1 1 1 1 1 1 1",
+ "1 0 0 0 0 0 1",
+ "1 0 0 0 0 0 1",
+ "1 0 0 0 0 0 1",
+ "1 0 0 0 0 0 1",
+ "1 0 0 0 0 0 1",
+ "1 1 1 1 1 1 1",
+ ],
+ start_pose=[(1.5, 1.5, 0.78),
+ (5.5, 5.5, 3.92),
+ (1.5, 5.5, -0.78),
+ (5.5, 1.5, -3.92)],
+ goal_pose=[(5.5, 5.5, 0.0),
+ (1.5, 1.5, 0.0),
+ (5.5, 1.5, 0.0),
+ (1.5, 5.5, 0.0)],
+ )
+
+ super().__init__(num_agents,
+ test_case=test_case,
+ **env_kwargs)
+
+class CircleCross(JaxNavSingleton):
+ def __init__(
+ self,
+ num_agents=10,
+ circle_rad=6,
+ **env_kwargs
+ ):
+
+ width, height = circle_rad*3, circle_rad*3
+ centre_x = width/2
+ centre_y = height/2
+ top = "1 " * int(width)
+ row = "1 " + "0 " * int(width-2) + "1"
+ rows = [row for _ in range(int(height)-2)]
+ map_data = [top] + rows + [top]
+
+ start_pose = []
+ goal_pose = []
+ for i in range(num_agents):
+ theta = 2*jnp.pi * i / num_agents
+ to_center_theta = jnp.pi + theta
+ start_pose.append((circle_rad*jnp.cos(theta)+centre_y, circle_rad*jnp.sin(theta)+centre_x, to_center_theta))
+ goal_theta = theta + jnp.pi
+ goal_pose.append((circle_rad*jnp.cos(goal_theta)+centre_y, circle_rad*jnp.sin(goal_theta)+centre_x, goal_theta))
+
+ test_case = TestCase(
+ map_data=map_data,
+ start_pose=start_pose,
+ goal_pose=goal_pose,
+ )
+
+ super().__init__(num_agents,
+ test_case=test_case,
+ **env_kwargs)
+
+class BlankCrossUneven2(JaxNavSingleton):
+ def __init__(
+ self,
+ num_agents=2,
+ **env_kwargs
+ ):
+ test_case = TestCase(
+ map_data = [
+ "1 1 1 1 1 1 1",
+ "1 0 0 0 0 0 1",
+ "1 0 0 0 0 0 1",
+ "1 0 0 0 0 0 1",
+ "1 0 0 0 0 0 1",
+ "1 0 0 0 0 0 1",
+ "1 1 1 1 1 1 1",
+ ],
+ start_pose=[(2.0, 2.5, 0.78),
+ (5.5, 5.5, 3.92)],
+ goal_pose=[(5.5, 5.5, 0.0),
+ (1.5, 1.5, 0.0)],
+ )
+
+ super().__init__(num_agents,
+ test_case=test_case,
+ **env_kwargs)
+
+class SingleNav1(JaxNavSingleton):
+ def __init__(
+ self,
+ num_agents=1,
+ **env_kwargs
+ ):
+ test_case = TestCase(
+ map_data = [
+ "1 1 1 1 1 1 1 1 1 1 1",
+ "1 1 0 0 0 0 0 0 0 0 1",
+ "1 0 0 1 1 0 1 1 1 0 1",
+ "1 1 0 0 0 0 0 0 0 0 1",
+ "1 1 0 1 1 0 1 1 1 0 1",
+ "1 0 0 0 0 0 0 1 0 0 1",
+ "1 1 1 1 1 1 1 1 1 1 1",
+ ],
+ start_pose=[(9.5, 1.5, 0.78)],
+ goal_pose=[(1.5, 5.5, 0.0)],
+ )
+
+ super().__init__(num_agents,
+ test_case=test_case,
+ **env_kwargs)
+
+class SingleNav2(JaxNavSingleton):
+ def __init__(
+ self,
+ num_agents=1,
+ **env_kwargs
+ ):
+ test_case = TestCase(
+ map_data = [
+ "1 1 1 1 1 1 1 1 1 1 1",
+ "1 0 0 1 1 1 1 1 0 0 1",
+ "1 0 0 0 1 0 0 1 1 0 1",
+ "1 0 1 0 1 0 0 0 0 0 1",
+ "1 0 1 0 0 0 0 1 0 0 1",
+ "1 0 1 0 1 0 0 0 0 0 1",
+ "1 1 1 1 1 1 1 1 1 1 1",
+ ],
+ start_pose=[(8.5, 1.5, 3.14)],
+ goal_pose=[(1.5, 5.5, 0.0)],
+ )
+
+ super().__init__(num_agents,
+ test_case=test_case,
+ **env_kwargs)
+
+class SingleNav3(JaxNavSingleton):
+ def __init__(
+ self,
+ num_agents=1,
+ **env_kwargs
+ ):
+ test_case = TestCase(
+ map_data = [
+ "1 1 1 1 1 1 1 1 1 1 1",
+ "1 0 0 1 1 1 1 1 0 0 1",
+ "1 0 0 0 1 0 0 1 1 0 1",
+ "1 0 1 0 1 0 0 0 0 0 1",
+ "1 0 1 0 0 0 0 1 0 0 1",
+ "1 0 1 0 1 0 0 0 0 0 1",
+ "1 1 1 1 1 1 1 1 1 1 1",
+ ],
+ start_pose=[(9.1, 1.5, 3.14)],
+ goal_pose=[(1.5, 5.5, 0.0)],
+ )
+
+ super().__init__(num_agents,
+ test_case=test_case,
+ **env_kwargs)
+
+class LongCorridor2(JaxNavSingleton):
+ def __init__(
+ self,
+ num_agents=2,
+ **env_kwargs
+ ):
+ test_case = TestCase(
+ map_data = [
+ "1 1 1 1 1 1 1 1 1 1",
+ "1 1 1 1 1 1 1 0 0 1",
+ "1 0 0 1 1 1 1 0 0 1",
+ "1 0 0 0 0 0 0 0 0 1",
+ "1 0 0 1 1 1 1 0 0 1",
+ "1 1 1 1 1 1 1 0 0 1",
+ "1 1 1 1 1 1 1 1 1 1",
+ ],
+ start_pose=[(1.5, 3.5, 0.0),
+ (8.0, 3.5, 3.14),],
+ goal_pose=[(6.0, 3.5, 0.0),
+ (1.5, 3.5, 0.0),],
+ )
+
+ super().__init__(num_agents,
+ test_case=test_case,
+ **env_kwargs)
+
+class Corridor4(JaxNavSingleton):
+ def __init__(
+ self,
+ num_agents=4,
+ **env_kwargs
+ ):
+ test_case = TestCase(
+ map_data = [
+ "1 1 1 1 1 1 1 1 1 1",
+ "1 0 0 0 1 1 0 0 0 1",
+ "1 0 0 0 1 1 0 0 0 1",
+ "1 0 0 0 0 0 0 0 0 1",
+ "1 0 0 0 1 1 0 0 0 1",
+ "1 0 0 0 1 1 0 0 0 1",
+ "1 1 1 1 1 1 1 1 1 1",
+ ],
+ start_pose=[(2.0, 2.5, 0.0),
+ (2.0, 4.5, 0.0),
+ (8.0, 2.5, 3.14),
+ (8.0, 4.5, 3.14),],
+ goal_pose=[(8.0, 2.5, 3.14),
+ (8.0, 4.5, 3.14),
+ (2.0, 2.5, 0.0),
+ (2.0, 4.5, 0.0),],
+ )
+
+ super().__init__(num_agents,
+ test_case=test_case,
+ **env_kwargs)
+
+class Corridor8(JaxNavSingleton):
+ def __init__(
+ self,
+ num_agents=8,
+ **env_kwargs
+ ):
+ test_case = TestCase(
+ map_data = [
+ "1 1 1 1 1 1 1 1 1 1",
+ "1 0 0 0 1 1 0 0 0 1",
+ "1 0 0 0 1 1 0 0 0 1",
+ "1 0 0 0 1 1 0 0 0 1",
+ "1 0 0 0 0 0 0 0 0 1",
+ "1 0 0 0 1 1 0 0 0 1",
+ "1 0 0 0 1 1 0 0 0 1",
+ "1 0 0 0 1 1 0 0 0 1",
+ "1 1 1 1 1 1 1 1 1 1",
+ ],
+ start_pose=[(2.0, 2.0, 0.0),
+ (2.0, 3.5, 0.0),
+ (2.0, 5.5, 0.0),
+ (2.0, 7.0, 0.0),
+ (8.0, 2.0, 3.14),
+ (8.0, 3.5, 3.14),
+ (8.0, 5.5, 3.14),
+ (8.0, 7.0, 3.14),],
+ goal_pose=[(8.0, 2.0, 3.14),
+ (8.0, 3.5, 3.14),
+ (8.0, 5.5, 3.14),
+ (8.0, 7.0, 3.14),
+ (2.0, 2.0, 0.0),
+ (2.0, 3.5, 0.0),
+ (2.0, 5.5, 0.0),
+ (2.0, 7.0, 0.0),],
+ )
+
+ super().__init__(num_agents,
+ test_case=test_case,
+ **env_kwargs)
+
+class Layby4(JaxNavSingleton):
+ def __init__(
+ self,
+ num_agents=4,
+ **env_kwargs
+ ):
+ test_case = TestCase(
+ map_data = [
+ "1 1 1 1 1 1 1 1 1 1",
+ "1 0 0 1 1 1 1 0 0 1",
+ "1 0 0 1 0 1 1 0 0 1",
+ "1 0 0 0 0 0 0 0 0 1",
+ "1 0 0 1 1 0 1 0 0 1",
+ "1 0 0 1 1 1 1 0 0 1",
+ "1 1 1 1 1 1 1 1 1 1",
+ ],
+ start_pose=[(2.0, 2.5, 0.0),
+ (2.0, 4.5, 0.0),
+ (8.0, 2.5, 3.14),
+ (8.0, 4.5, 3.14),],
+ goal_pose=[(8.0, 2.5, 3.14),
+ (8.0, 4.5, 3.14),
+ (2.0, 2.5, 0.0),
+ (2.0, 4.5, 0.0),],
+ )
+
+ super().__init__(num_agents,
+ test_case=test_case,
+ **env_kwargs)
+
+class Corner2(JaxNavSingleton):
+ def __init__(
+ self,
+ num_agents=2,
+ **env_kwargs
+ ):
+ test_case = TestCase(
+ map_data = [
+ "1 1 1 1 1 1 1",
+ "1 0 0 0 0 0 1",
+ "1 1 1 1 0 0 1",
+ "1 1 1 1 1 0 1",
+ "1 1 1 1 1 0 1",
+ "1 1 1 1 1 0 1",
+ "1 1 1 1 1 1 1",
+ ],
+ start_pose=[(1.5, 1.5, 0.0),
+ (5.5, 5.5, -1.57),],
+ goal_pose=[(5.5, 5.5, 3.14),
+ (1.5, 1.5, 3.14),],
+ )
+
+ super().__init__(num_agents,
+ test_case=test_case,
+ **env_kwargs)
+
+class Chicane2(JaxNavSingleton):
+ def __init__(
+ self,
+ num_agents=2,
+ **env_kwargs
+ ):
+ test_case = TestCase(
+ map_data = [
+ "1 1 1 1 1 1 1 1 1 1 1 1 1 1",
+ "1 0 0 0 0 0 0 0 1 1 0 0 0 1",
+ "1 0 0 0 1 1 0 0 1 1 0 0 0 1",
+ "1 0 0 0 1 1 0 0 0 0 0 0 0 1",
+ "1 1 1 1 1 1 1 1 1 1 1 1 1 1",
+ ],
+ start_pose=[(2.5, 2.5, 0.0),
+ (11.5, 2.5, 3.14),],
+ goal_pose=[(9.5, 3.5, 3.14),
+ (2.5, 2.5, 3.14),],
+ )
+
+ super().__init__(num_agents,
+ test_case=test_case,
+ **env_kwargs)
+
+class NarrowChicane2a(JaxNavSingleton):
+ def __init__(
+ self,
+ num_agents=2,
+ **env_kwargs
+ ):
+ test_case = TestCase(
+ map_data = [
+ "1 1 1 1 1 1 1 1 1 1 1 1 1 1",
+ "1 0 0 0 0 0 0 0 1 1 0 0 0 1",
+ "1 0 0 0 1 1 1 0 1 1 0 0 0 1",
+ "1 0 0 0 1 1 1 0 0 0 0 0 0 1",
+ "1 1 1 1 1 1 1 1 1 1 1 1 1 1",
+ ],
+ start_pose=[(2.5, 2.5, 0.0),
+ (11.5, 2.5, 3.14),],
+ goal_pose=[(9.5, 3.5, 3.14),
+ (2.5, 2.5, 3.14),],
+ )
+
+ super().__init__(num_agents,
+ test_case=test_case,
+ **env_kwargs)
+
+class NarrowChicane2b(JaxNavSingleton):
+ def __init__(
+ self,
+ num_agents=2,
+ **env_kwargs
+ ):
+ test_case = TestCase(
+ map_data = [
+ "1 1 1 1 1 1 1 1 1 1 1 1 1 1",
+ "1 0 0 0 0 0 0 0 1 1 0 0 0 1",
+ "1 0 0 0 1 1 1 0 1 1 0 0 0 1",
+ "1 0 0 0 1 1 1 0 0 0 0 0 0 1",
+ "1 1 1 1 1 1 1 1 1 1 1 1 1 1",
+ ],
+ start_pose=[(2.5, 1.5, 0.0),
+ (11.5, 2.5, 3.14),],
+ goal_pose=[(7.5, 1.5, 3.14),
+ (2.5, 2.5, 3.14),],
+ )
+
+ super().__init__(num_agents,
+ test_case=test_case,
+ **env_kwargs)
+
+class Chicane4(JaxNavSingleton):
+ def __init__(
+ self,
+ num_agents=4,
+ **env_kwargs
+ ):
+ test_case = TestCase(
+ map_data = [
+ "1 1 1 1 1 1 1 1 1 1 1 1 1 1",
+ "1 0 0 0 0 0 0 0 1 1 0 0 0 1",
+ "1 0 0 0 1 1 0 0 1 1 0 0 0 1",
+ "1 0 0 0 1 1 0 0 0 0 0 0 0 1",
+ "1 1 1 1 1 1 1 1 1 1 1 1 1 1",
+ ],
+ start_pose=[(2.5, 1.75, 0.0),
+ (2.5, 3.25, 0.0),
+ (11.5, 1.75, 3.14),
+ (11.5, 3.25, 3.14),],
+ goal_pose=[(9.5, 3.5, 3.14),
+ (11.5, 1.75, 3.14),
+ (2.5, 3.25, 3.14),
+ (4.5, 1.5, 3.14),],
+ )
+
+ super().__init__(num_agents,
+ test_case=test_case,
+ **env_kwargs)
+
+# REGISTRATION
+def make_jaxnav_singleton(env_id: str, **env_kwargs) -> JaxNavSingleton:
+ if env_id not in registered_singletons:
+ raise ValueError(f"Singleton env_id: {env_id} not registered!")
+ if env_id == "BlankTest":
+ return BlankTest(**env_kwargs)
+ if env_id == "MiddleTest":
+ return MiddleTest(**env_kwargs)
+
+ if env_id == "BlankCross2":
+ return BlankCross2(**env_kwargs)
+ if env_id == "BlankCross4":
+ return BlankCross4(**env_kwargs)
+ if env_id == "BlankCrossUneven2":
+ return BlankCrossUneven2(**env_kwargs)
+ if env_id == "CircleCross":
+ return CircleCross(**env_kwargs)
+ if env_id == "Corridor4":
+ return Corridor4(**env_kwargs)
+ if env_id == "Corridor8":
+ return Corridor8(**env_kwargs)
+ if env_id == "LongCorridor2":
+ return LongCorridor2(**env_kwargs)
+ if env_id == "Layby4":
+ return Layby4(**env_kwargs)
+ if env_id == "Corner2":
+ return Corner2(**env_kwargs)
+ if env_id == "Chicane2":
+ return Chicane2(**env_kwargs)
+ if env_id == "Chicane4":
+ return Chicane4(**env_kwargs)
+ if env_id == "NarrowChicane2a":
+ return NarrowChicane2a(**env_kwargs)
+ if env_id == "NarrowChicane2b":
+ return NarrowChicane2b(**env_kwargs)
+
+ if env_id == "SingleNav1":
+ return SingleNav1(**env_kwargs)
+ if env_id == "SingleNav2":
+ return SingleNav2(**env_kwargs)
+ if env_id == "SingleNav3":
+ return SingleNav3(**env_kwargs)
+
+ raise ValueError(f"Map: {env_id} not registered correctly!")
+
+registered_singletons = [
+ "BlankTest",
+ "MiddleTest",
+ "BlankCross2",
+ "BlankCross4",
+ "BlankCrossUneven2",
+ "CircleCross",
+ "SingleNav1",
+ "SingleNav2",
+ "SingleNav3",
+ "Corridor4",
+ "LongCorridor2",
+ "Layby4",
+ "Corner2",
+ "Corridor8",
+ "Chicane2",
+ "Chicane4",
+ "NarrowChicane2a",
+ "NarrowChicane2b",
+]
+
+def make_jaxnav_singleton_collection(collection_id: str, **env_kwargs) -> Tuple[List[JaxNavSingleton], List[str]]:
+
+ env_ids = registered_singleton_collections[collection_id]
+ envs = []
+ for env_id in env_ids:
+ envs.append(make_jaxnav_singleton(env_id, **env_kwargs))
+
+ return envs, env_ids
+
+registered_singleton_collections = {
+ "test": [
+ "BlankTest",
+ ],
+ "multi": [
+ "CircleCross",
+ "BlankCross4",
+ "BlankCrossUneven2",
+ "Corridor4",
+ "LongCorridor2",
+ "Layby4",
+ "Corner2",
+ "SingleNav1",
+ "SingleNav2",
+ "SingleNav3",
+ ],
+ "single": [
+ "BlankTest",
+ "MiddleTest",
+ "SingleNav1",
+ "SingleNav2",
+ "SingleNav3"
+ ],
+ "hard": [
+ "SingleNav2",
+ "Layby4",
+ "Corner2",
+ "Corridor4",
+ "Corridor8",
+ "CircleCross",
+ "NarrowChicane2a",
+ "NarrowChicane2b",
+ "Chicane4",
+ ],
+ "new": [
+ "NarrowChicane2a",
+ "NarrowChicane2b",
+ "Chicane4",
+ ],
+ "corridor": [
+ "BlankCross4",
+ "LongCorridor2",
+ "Corridor4",
+ "Corner2",
+ ],
+ "just-long-corridor": [
+ "LongCorridor2",
+ ],
+ "just-single2": [
+ "SingleNav2",
+ ],
+}
+
\ No newline at end of file
diff --git a/jaxmarl/environments/jaxnav/jaxnav_ued_utils.py b/jaxmarl/environments/jaxnav/jaxnav_ued_utils.py
new file mode 100644
index 00000000..0c4fb6c1
--- /dev/null
+++ b/jaxmarl/environments/jaxnav/jaxnav_ued_utils.py
@@ -0,0 +1,130 @@
+import jax
+import jax.numpy as jnp
+import chex
+# from .level import Level
+# from .env import DIR_TO_VEC
+from .jaxnav_env import State
+from .maps import GridMapPolygonAgents
+from enum import IntEnum
+import numpy as np
+from typing import Callable
+from functools import partial
+
+
+
+"""
+Mutation strategy:
+1. flip wall: compute free cells, choose one at random, flip it to a wall
+- should probably switch to a MapData class rather than just an array
+ as it would be good to store free cells. Ignore for now
+
+"""
+
+
+def make_level_mutator(max_num_edits: int, map: GridMapPolygonAgents):
+
+ class Mutations(IntEnum):
+ # Turn left, turn right, move forward
+ NO_OP = 0
+ FLIP_WALL = 1
+ MOVE_GOAL = 2
+
+
+ def move_goal_flip_walls(rng, state: State, n: int = 1):
+ def _mutate(carry, step):
+ state = carry
+ rng, mutation = step
+
+ def _apply(rng, state):
+ rng, arng, brng = jax.random.split(rng, 3)
+
+ is_flip_wall = jnp.equal(mutation, Mutations.FLIP_WALL.value)
+ mutated_state = flip_wall(arng, map, state)
+ next_state = jax.tree_map(lambda x,y: jax.lax.select(is_flip_wall, x, y), mutated_state, state)
+
+ is_move_goal = jnp.equal(mutation, Mutations.MOVE_GOAL.value)
+ mutated_state = move_goal(brng, map, state)
+ next_state = jax.tree_map(lambda x,y: jax.lax.select(is_move_goal, x, y), mutated_state, next_state)
+
+ return next_state
+
+ return jax.lax.cond(mutation != -1, _apply, lambda *_: state, rng, state), None
+
+
+ rng, nrng, *mrngs = jax.random.split(rng, max_num_edits+2)
+ mutations = jax.random.choice(nrng, np.arange(len(Mutations)), (max_num_edits,))
+ mutations = jnp.where(jnp.arange(max_num_edits) < n, mutations, -1) # mask out extra mutations
+
+ new_level, _ = jax.lax.scan(_mutate, state, (jnp.array(mrngs), mutations))
+
+ return new_level
+
+ return move_goal_flip_walls
+
+def flip_wall(rng, map: GridMapPolygonAgents, state: State):
+ wall_map = state.map_data
+ h,w = wall_map.shape
+
+ # goal_map_mask = jnp.any(jax.vmap(
+ # map.get_circle_map_occupancy_mask,
+ # in_axes=(0, None, None)
+ # )(state.goal, wall_map, 0.3), axis=0)
+
+ goal_map_mask = wall_map
+
+ start_map_mask = jnp.any(jax.vmap(
+ map.get_agent_map_occupancy_mask, in_axes=(0,0,None)
+ )(state.pos, state.theta, wall_map), axis=0)
+
+
+ goal = jnp.floor(state.goal).astype(jnp.int32)
+ goal_map_mask = goal_map_mask.at[goal[:, 1], goal[:, 0]].set(1)
+
+ map_mask = start_map_mask | map._gen_base_grid() | goal_map_mask
+
+
+ flip_idx = jax.random.choice(rng, np.arange(h*w), p=jnp.logical_not(map_mask.flatten()))
+
+ flip_y = flip_idx // w
+ flip_x = flip_idx % w
+
+ flip_val = 1-wall_map.at[flip_y, flip_x].get()
+ next_wall_map = wall_map.at[flip_y, flip_x].set(flip_val)
+ print('next_wall_map', next_wall_map)
+ return state.replace(map_data=next_wall_map)
+
+
+def move_goal(rng, map:GridMapPolygonAgents, state:State):
+ wall_map = state.map_data
+ h,w = wall_map.shape
+
+ rng, _rng = jax.random.split(rng)
+ agent_idx = jax.random.choice(_rng, np.arange(state.pos.shape[0]))
+
+ # goal_map_masks = jax.vmap(
+ # map.get_circle_map_occupancy_mask,
+ # in_axes=(0, None, None)
+ # )(state.goal, wall_map, 1.0)
+ goal_map_masks = jnp.repeat(wall_map[None], state.pos.shape[0], axis=0)
+ # goal_map_masks = goal_map_masks.at[agent_idx].set(jnp.zeros(wall_map.shape, dtype=jnp.int32))
+ goal_map_mask = jnp.any(goal_map_masks, axis=0)
+ current_goal = jnp.floor(state.goal[agent_idx]).astype(jnp.int32)
+ goal_map_mask = goal_map_mask.at[current_goal[1], current_goal[0]].set(0)
+
+ start_map_mask = jnp.any(jax.vmap(
+ map.get_agent_map_occupancy_mask, in_axes=(0,0,None)
+ )(state.pos, state.theta, wall_map), axis=0)
+
+ map_mask = start_map_mask | map._gen_base_grid() | goal_map_mask
+
+ next_idx = jax.random.choice(rng, np.arange(h*w), p=jnp.logical_not(map_mask.flatten()))
+ next_goal_y = next_idx // w
+ next_goal_x = next_idx % w
+
+ goals = state.goal.at[agent_idx].set(jnp.array([next_goal_x, next_goal_y]) + 0.5) # Make the goal in the center of the cell.
+
+ return state.replace(goal=goals)
+
+
+
+
\ No newline at end of file
diff --git a/jaxmarl/environments/jaxnav/jaxnav_utils.py b/jaxmarl/environments/jaxnav/jaxnav_utils.py
new file mode 100644
index 00000000..563bcba8
--- /dev/null
+++ b/jaxmarl/environments/jaxnav/jaxnav_utils.py
@@ -0,0 +1,193 @@
+'''
+Utility functions for simulators
+'''
+
+import jax
+import jax.numpy as jnp
+import chex
+import os, pathlib
+import numpy as np
+from functools import partial
+
+
+# map names cannot include an '_'
+MAP_PATHS = {
+ "blank": "blank_map.npy",
+ "blank-small" : "blank-small_map.npy",
+ "blank-15": "blank-15.npy",
+ "central-square" : "central-square_map.npy",
+ "central-square-easy" : "central-square_map.npy", # sample cases have a lower treshold
+ "cross-20": "cross-20_map.npy",
+ "circle-20": "circle-20_map.npy",
+ "circle-sym-20": "circle-20_map.npy",
+ "corridor-10": "corridor-10_map.npy",
+ "corridor-15": "corridor-15_map.npy",
+ "barn-test" : "barn/barn-test_map.npy",
+ "barn-20": "barn/barn-20_map.npy",
+ "barn-25": "barn/barn-25_map.npy",
+ "barn-30": "barn/barn-30_map.npy",
+ "1-wide-c": "corridor/1-wide-c_map.npy",
+ "2-wide-c": "corridor/2-wide-c_map.npy",
+ "1-wide-b": "corridor/1-wide-b_map.npy",
+ "5-wide-chicane": "corridor/5-wide-chicane_map.npy",
+}
+
+GRID_HALF_HEIGHT = 0.5
+
+### --- MATHS UTILS ---
+def pol2cart(rho: float, phi: float) -> chex.Array:
+ ''' Convert polar coordinates into cartesian '''
+ x = rho * jnp.cos(phi)
+ y = rho * jnp.sin(phi)
+ return jnp.array([x,y])
+
+def cart2pol(x, y) -> chex.Array:
+ rho = jnp.sqrt(x**2 + y**2)
+ phi = jnp.arctan2(y, x)
+ return jnp.array([rho, phi])
+
+def unitvec(theta) -> chex.Array:
+ return jnp.array([jnp.cos(theta), jnp.sin(theta)])
+
+def wrap(angle):
+ ''' Ensure angle lies in the range [-pi, pi] '''
+ large = lambda x: x - 2*jnp.pi
+ small = lambda x: x + 2*jnp.pi
+ noChange = lambda x: x
+ wrapped_angle = jax.lax.cond(angle >= jnp.pi,
+ large, noChange, angle)
+ wrapped_angle = jax.lax.cond(angle < -jnp.pi,
+ small, noChange, wrapped_angle)
+
+ return wrapped_angle
+
+def euclid_dist(x, y):
+ return jnp.norm(x-y)
+
+def rot_mat(theta):
+ """ 2x2 rotation matrix for 2D about the origin """
+ return jnp.array([[jnp.cos(theta), -jnp.sin(theta)], [jnp.sin(theta), jnp.cos(theta)]]).squeeze()
+
+
+### --- ENV UTILS
+@jax.jit
+def map_collision(pos: chex.Array, map_grid: chex.Array, radius: float) -> bool:
+ """ For a circle agent, ASSUMES radius<1 and grids of size 1x1 """
+ # Calculates which grid square robot overlaps in
+ min_x, min_y = jnp.floor(jnp.maximum(jnp.zeros(2), pos-radius)).astype(int)
+ max_x, max_y = jnp.floor(jnp.minimum(jnp.array(map_grid.shape), pos+radius)).astype(int)
+
+ map_c_list = jnp.array([
+ [min_x, min_y],
+ [max_x, min_y],
+ [min_x, max_y],
+ [max_x, max_y],
+ ]) + GRID_HALF_HEIGHT
+
+ grid_check = check_grid(map_c_list, pos, radius)
+
+ map_occ = jnp.array([
+ map_grid[min_y, min_x],
+ map_grid[min_y, max_x],
+ map_grid[max_y, min_x],
+ map_grid[max_y, max_x],
+ ]).astype(int)
+ return jnp.any((map_occ+grid_check)>1)
+
+@partial(jax.vmap, in_axes=[0, None, None])
+def check_grid(c, pos, radius):
+ p = jnp.clip(pos - c, -GRID_HALF_HEIGHT, GRID_HALF_HEIGHT)
+ p = p + c
+ return jnp.linalg.norm(p - pos) <= radius
+
+'''#@partial(jax.jit)
+def check_square_plot(pos, c, radius):
+ from matplotlib import pyplot as plt
+ from matplotlib.patches import Circle
+ fig, ax = plt.subplots()
+ hh = 0.5
+
+ p = jnp.clip(pos - c, -hh, hh)
+
+ print('p', p)
+ p = p + c
+ print('p', p)
+ d = p - pos
+ print('d', jnp.linalg.norm(d))
+
+
+ square = plt.Rectangle((c[0] - hh, c[1] - hh), hh*2, hh*2, linewidth=1, edgecolor='r', facecolor='none')
+ ax.add_patch(square)
+
+ circle = Circle(pos, radius)
+ ax.add_patch(circle)
+
+ plt.xlim([0, 5])
+ plt.ylim([0, 5])
+ plt.show()
+'''
+
+@partial(jax.jit)
+def map_collision_square(pos: chex.Array, map_grid: chex.Array, radius: float) -> bool:
+ """ This is for a square agent that cannot rotate """
+ # calculate which grid cells robot overlaps in, assumes radius<1
+ min_x, min_y = jnp.floor(jnp.maximum(jnp.zeros(2), pos-radius)).astype(int)
+ max_x, max_y = jnp.floor(jnp.minimum(jnp.array(map_grid.shape), pos+radius)).astype(int)
+ rg = jnp.zeros(map_grid.shape, dtype=int)
+
+ rg = rg.at[min_y, min_x].set(1)
+ rg = rg.at[min_y, max_x].set(1)
+ rg = rg.at[max_y, min_x].set(1)
+ rg = rg.at[max_y, max_x].set(1)
+
+ return jnp.any((rg+map_grid)>1)
+
+
+### --- LOADING FILES ---
+def load_map(name: str) -> chex.Array:
+ ''' Load map using jnp
+ Possible map names
+ - blank: 30x30 blank map
+ '''
+ print('load map: ', name)
+ return load_map_array(MAP_PATHS[name])
+
+
+def load_max_cases(map_name: str, num_agents: int):
+ prefix = f"{map_name}_{num_agents}_agents_"
+ parent_dir_path = pathlib.Path(__file__).parent.resolve()
+ dir_path = os.path.join(parent_dir_path, pathlib.Path(f"sample_cases/{map_name}/"))
+ prefixed = [filename for filename in os.listdir(dir_path) if filename.startswith(prefix)]
+
+ num_cases = max([p.split("_")[3] for p in prefixed])
+
+ return load_cases(map_name, num_agents, num_cases)
+
+
+def load_cases(map_name: str, num_agents: int, num_cases: int):
+
+ filename = f"sample_cases/{map_name}/{map_name}_{num_agents}_agents_{num_cases}_cases.npy"
+ parent_dir_path = pathlib.Path(__file__).parent.resolve()
+ filepath = os.path.join(parent_dir_path, filename)
+ return jnp.load(filepath)
+
+
+def load_map_array(filename: str) -> chex.Array:
+ parent_dir_path = pathlib.Path(__file__).parent.resolve()
+ return jnp.load(os.path.join(parent_dir_path, pathlib.Path("maps/" + filename)))
+
+
+
+
+if __name__ == "__main__":
+
+ pos = jnp.array([0.9, 0.9])
+ c = jnp.array([1.5, 1.5])
+ rad = 0.3
+
+ map_grid = jnp.zeros((5, 5))
+ map_grid = map_grid.at[1:,1:].set(1)
+
+ map_collision(pos, map_grid, rad)
+
+
\ No newline at end of file
diff --git a/jaxmarl/environments/jaxnav/jaxnav_viz.py b/jaxmarl/environments/jaxnav/jaxnav_viz.py
new file mode 100644
index 00000000..7cd74801
--- /dev/null
+++ b/jaxmarl/environments/jaxnav/jaxnav_viz.py
@@ -0,0 +1,108 @@
+""" Built off gymnax vizualizer.py"""
+
+import matplotlib.pyplot as plt
+import matplotlib.animation as animation
+from typing import Optional, List
+
+from .jaxnav_env import JaxNav
+import jax.numpy as jnp
+
+class JaxNavVisualizer(object):
+ def __init__(self,
+ env: JaxNav,
+ obs_seq: List,
+ state_seq: List,
+ reward_seq: List=None,
+ done_frames=None,
+ title_text: str=None,
+ plot_lidar=True,
+ plot_path=True,
+ plot_agent=True,
+ plot_reward=True,
+ plot_line_to_goal=True,):
+ self.env = env
+
+ self.interval = 15
+ self.obs_seq = obs_seq
+ self.state_seq = state_seq
+ self.reward_seq = reward_seq
+ self.done_frames = done_frames
+ self.reward = 0.0
+ self.plot_lidar = plot_lidar
+ self.plot_agent = plot_agent
+ self.plot_path = plot_path
+ self.plot_line_to_goal = plot_line_to_goal
+ self.title_text = title_text
+ if (plot_reward) and (reward_seq is not None):
+ self.plot_reward=True
+ else:
+ self.plot_reward=False
+
+ self.fig, self.ax = plt.subplots(1, 1, figsize=(5, 5))
+
+ if self.plot_path:
+ self.path_seq = jnp.empty((len(self.state_seq), env.num_agents, 2))
+ for i in range(len(self.state_seq)):
+ self.path_seq = self.path_seq.at[i].set(self.state_seq[i].pos)
+
+
+ def animate(
+ self,
+ save_fname: Optional[str] = None,
+ view: bool = False,
+ ):
+ """Anim for 2D fct - x (#steps, #pop, 2) & fitness (#steps, #pop)"""
+ ani = animation.FuncAnimation(
+ self.fig,
+ self.update,
+ frames=len(self.state_seq),
+ init_func=self.init,
+ blit=False,
+ interval=self.interval,
+ )
+ # Save the animation to a gif
+ if save_fname is not None:
+ ani.save(save_fname)
+ # Simply view it 3 times
+ if view:
+ plt.show(block=False)
+ plt.pause(3)
+ plt.close()
+
+ def init(self):
+ self.env.init_render(self.ax, self.state_seq[0], self.obs_seq[0], lidar=self.plot_lidar, agent=self.plot_agent, goal=False)
+
+ def update(self, frame):
+ self.ax.cla()
+ if self.plot_path:
+ for a in range(self.env.num_agents):
+ plot_frame = frame
+ if self.done_frames[a] < frame:
+ plot_frame = self.done_frames[a]
+ self.env.map_obj.plot_agent_path(self.ax, self.path_seq[:plot_frame, a, 0], self.path_seq[:plot_frame, a, 1])
+ # self.ax.plot(self.path_seq[:frame, 0], self.path_seq[:frame, 1], color='b', linewidth=2.0, zorder=1)
+ self.env.init_render(self.ax, self.state_seq[frame], self.obs_seq[frame], lidar=self.plot_lidar, agent=self.plot_agent)
+ txt_to_plot = []
+ txt_to_plot.append(f"Time: {frame*self.env.dt:.2f} s")
+ # self.ax.text(0.05, 0.95, f"Time: {frame*self.env.dt:.2f} s", transform=self.ax.transAxes, fontsize=12, verticalalignment='top', c='w')
+ if self.plot_reward:
+ self.reward += self.reward_seq[frame]
+ txt_to_plot.append(f"R: {self.reward:.2f}")
+ if self.title_text is not None:
+ title_text = self.title_text + ' ' + ' '.join(txt_to_plot)
+ else:
+ title_text = ' '.join(txt_to_plot)
+ self.ax.set_title(title_text)
+ # self.ax.text(0.05, 0.9, f"R: {self.reward:.2f}", transform=self.ax.transAxes, fontsize=12, verticalalignment='top', c='w')
+ # if len(txt_to_plot) > 0:
+ # self.ax.text(0.05, 0.95, ' '.join(txt_to_plot), transform=self.ax.transAxes, fontsize=12, verticalalignment='top', c='w')
+
+ # if self.plot_line_to_goal:
+ # for a in range(self.env.num_agents):
+ # plot_frame = frame
+ # if self.done_frames[a] < frame:
+ # plot_frame = self.done_frames[a]
+ # x = jnp.concatenate([self.state_seq[plot_frame].pos[a, 0][None], self.state_seq[plot_frame].goal[a, 0][None]])
+ # y = jnp.concatenate([self.state_seq[plot_frame].pos[a, 1][None], self.state_seq[plot_frame].goal[a, 1][None]]) self.ax.plot(,
+ # jnp.concatenate([self.state_seq[plot_frame].pos[a, 1][None], self.state_seq[plot_frame].goal[a, 1][None]]),
+ # color='gray', alpha=0.4)
diff --git a/jaxmarl/environments/jaxnav/maps/__init__.py b/jaxmarl/environments/jaxnav/maps/__init__.py
new file mode 100644
index 00000000..c7e5aabf
--- /dev/null
+++ b/jaxmarl/environments/jaxnav/maps/__init__.py
@@ -0,0 +1,3 @@
+from .map import Map
+from .grid_map import GridMapPolygonAgents
+from .map_registration import make_map
\ No newline at end of file
diff --git a/jaxmarl/environments/jaxnav/maps/grid_map.py b/jaxmarl/environments/jaxnav/maps/grid_map.py
new file mode 100644
index 00000000..39bb7a7e
--- /dev/null
+++ b/jaxmarl/environments/jaxnav/maps/grid_map.py
@@ -0,0 +1,1006 @@
+import jax
+import jax.numpy as jnp
+from functools import partial
+import os
+import pickle
+from typing import Tuple, List
+from .map import Map
+
+import numpy as np
+import chex
+from enum import IntEnum
+import matplotlib.axes._axes as axes
+
+import jaxmarl.environments.jaxnav.jaxnav_graph_utils as _graph_utils
+
+def rotation_matrix(theta: float) -> jnp.ndarray:
+ """ Rotate about the z axis. Assume theta in radians """
+ return jnp.array([
+ [jnp.cos(theta), -jnp.sin(theta)],
+ [jnp.sin(theta), jnp.cos(theta)]
+ ])
+
+class SampleTestCaseTypes(IntEnum):
+ RANDOM = 0
+ GRID = 1
+
+class GridMapCircleAgents(Map):
+
+ def __init__(self,
+ num_agents: int,
+ rad,
+ map_size: Tuple[int, int],
+ fill: float=0.4,
+ cell_size: float=1.0,
+ sample_test_case_type='random',
+ **map_kwargs):
+ super().__init__(num_agents, rad, map_size, **map_kwargs)
+ assert self.rad < 1 # collision code only works for radius <1
+
+ self.width = map_size[0]
+ self.height = map_size[1]
+ self.length = self.width*self.height
+ self.pos_offset = jnp.full((2,), 0.5)
+
+ self.cell_size = cell_size
+ self.scaled_rad = self.scale_coords(rad)
+ self.circle_check_window = jnp.ceil(self.scaled_rad).astype(jnp.int32)
+ idxs = jnp.arange(-self.circle_check_window-1, self.circle_check_window+1)
+ self.cc_x_idx, self.cc_y_idx = jnp.meshgrid(idxs, idxs)
+ self.cell_half_height = self.cell_size / 2
+
+ # determine max number of blocks
+ if sample_test_case_type == 'random':
+ self.sample_test_case_type = SampleTestCaseTypes.RANDOM
+ elif sample_test_case_type == 'grid':
+ self.sample_test_case_type = SampleTestCaseTypes.GRID
+ else:
+ raise ValueError(f"Invalid sample_test_case_type: {sample_test_case_type}")
+ self.fill = fill
+ self.free_grids = (self.width-2)*(self.height-2)
+ self.n_clutter = jnp.floor(self.free_grids*self.fill).astype(int)
+
+ @partial(jax.jit, static_argnums=[0])
+ def sample_test_case(self, rng: chex.PRNGKey):
+
+ return jax.lax.switch(
+ self.sample_test_case_type,
+ [
+ super().sample_test_case,
+ self.grid_sample_test_case,
+ ],
+ rng
+ )
+
+ def grid_sample_test_case(self, key):
+ """ NOTE this won't throw an error if it's not possible, will just loop forever"""
+ assert self.cell_size == 1.0
+
+ key, _key = jax.random.split(key)
+ map_data = self.sample_map(_key)
+ inside_grid = map_data.at[1:-1, 1:-1].get()
+ iwidth = self.width - 2
+
+ def _sample_pair(key, start_masks, goal_masks):
+
+ flat_occ = start_masks.flatten()
+ key, _key = jax.random.split(key)
+ start_idx = jax.random.choice(_key, len(flat_occ), (1,), p=(1-flat_occ))[0]
+ start = jnp.array([start_idx % iwidth, start_idx // iwidth]) # [x, y]
+ actual_idx = (start + 1).astype(jnp.int32)
+ # connected_region = _graph_utils.component_mask_with_pos(inside_grid, start_idx) # BUG not working on inside grid
+ if self.valid_path_check:
+ connected_region = _graph_utils.component_mask_with_pos(map_data, actual_idx).at[1:-1, 1:-1].get()
+ else:
+ connected_region = 1-inside_grid
+ masked_start = connected_region.at[start[1], start[0]].set(0)
+ goal_possibilities = masked_start & (1 - goal_masks)
+ valid = jnp.any(goal_possibilities) # only valid if possible goal locations
+
+ goal_idx = jax.random.choice(key, len(flat_occ), (1,), p=goal_possibilities.flatten())[0]
+ goal = jnp.array([goal_idx % iwidth, goal_idx // iwidth]) # [x, y]
+
+ return valid, start, goal
+
+ def scan_fn(carry, rng):
+ i, pos, start_mask, goal_mask = carry
+ def _cond_fn(val):
+ return val[0]
+
+ def _body_fn(val):
+ valid, rng, pos = val
+
+ rng, _rng_pair = jax.random.split(rng)
+ valid, start, goal = _sample_pair(_rng_pair, start_mask, goal_mask)
+
+ positions = jnp.concatenate([start[None], goal[None]], axis=0)
+ pos = pos.at[i].set(positions)
+
+ return jnp.bitwise_not(valid), rng, pos
+
+ (_, rng, pos) = jax.lax.while_loop(
+ _cond_fn,
+ _body_fn,
+ (True, rng, pos)
+ )
+
+ start = pos.at[i, 0].get().astype(jnp.int32)
+ goal = pos.at[i, 1].get().astype(jnp.int32)
+ start_mask = start_mask.at[start[1], start[0]].set(1)
+ goal_mask = goal_mask.at[goal[1], goal[0]].set(1)
+ return (i+1, pos, start_mask, goal_mask), None
+
+
+ fill_max = jnp.max(jnp.array(self.map_size)) + self.rad*2
+ pos = jnp.full((self.num_agents, 2, 2), fill_max) # [num_agents, [start_pose, goal_pose]]
+
+ key, key_scan = jax.random.split(key)
+ key_scan = jax.random.split(key_scan, self.num_agents)
+ (_, pos, _, _), _ = jax.lax.scan(scan_fn, (0, pos, inside_grid, inside_grid), key_scan)
+ theta = jax.random.uniform(key, (self.num_agents, 2, 1), minval=-jnp.pi, maxval=jnp.pi)
+ cases = jnp.concatenate([pos + 1.5, theta], axis=2)
+
+ return map_data, cases
+
+ @partial(jax.jit, static_argnums=[0])
+ def sample_map(self, key):
+ """ Sample map grid from uniform distribution """
+
+ key_fill, key_shuff = jax.random.split(key)
+
+ base_map = self._gen_base_grid()
+
+ free_idx = jnp.arange(0, self.free_grids)
+ num_fill = jax.random.randint(key_fill, (1,), 0, self.n_clutter)[0]
+
+ map_within = jnp.where(free_idx1)
+ if rad is None: rad = self.rad
+
+ return jax.lax.switch(
+ int(self.cell_size == 1.0),
+ [
+ _variable_grid_size_check,
+ _grid_size_of_1_check,
+ ],
+ *(pos, rad, map_grid)
+ )
+
+
+ def get_circle_map_occupancy_mask(self, pos, map_grid, rad=None):
+ if rad is None: rad=self.rad
+
+ wall_map = jnp.zeros(map_grid.shape, dtype=jnp.int32)
+
+ theta = jnp.linspace(0, 2*jnp.pi, 100)
+ pos_to_check = jnp.array([jnp.cos(theta), jnp.sin(theta)]).T * rad + pos
+ print('pos to check shape', pos_to_check.shape)
+ idxs = jnp.floor(pos_to_check).astype(int)
+ wall_map = wall_map.at[idxs[:, 1], idxs[:, 0]].set(1)
+
+ x_mesh, y_mesh = jnp.meshgrid(jnp.arange(0, self.width), jnp.arange(0, self.height))
+ mesh = jnp.dstack([x_mesh, y_mesh]).reshape((-1,2))
+ cc = jnp.linalg.norm(mesh - pos, axis=1) < rad
+ inside_mask = cc.reshape((self.height, self.width)).astype(int)
+ print('inside mask', inside_mask)
+ return wall_map | inside_mask
+
+
+ def check_agent_beam_intersect(self, beam, pos, theta, range_resolution, rad=None):
+ """ Check for intersection between a lidar beam and an agent. """
+ if rad is None: rad = self.rad
+ d = beam[-1] - beam[0]
+ f = beam[0] - pos
+
+ a = jnp.dot(d, d)
+ b = 2*jnp.dot(f, d)
+ c = jnp.dot(f, f) - self.rad**2
+
+ descrim = b**2 - 4*a*c
+
+ t1 = (-b - jnp.sqrt(descrim))/(2*a)
+ # t2 = (-b + jnp.sqrt(descrim))/(2*a)
+
+ miss = (descrim < 0) | (t1 < 0) | (t1 > 1) #| (host_idx==other_idx) # | (t2 < 0) | (t2 > 1)
+
+ intersect = beam[0] + t1*d
+ idx = jnp.floor(jnp.linalg.norm(intersect - beam[0])/range_resolution).astype(int)
+ return jax.lax.select(miss, -1, idx)
+
+ @partial(jax.jit, static_argnums=[0])
+ def check_point_map_collision(self, pos, map_grid):
+ """ For a point """
+ pos = jnp.floor(self.scale_coords(pos)).astype(int)
+ return map_grid.at[pos[1], pos[0]].get() == 1
+
+ @partial(jax.jit, static_argnums=[0])
+ def _gen_base_grid(self):
+ """ Generate base grid map with walls around border """
+
+ map = jnp.zeros((self.height, self.width), dtype=int)
+ map = map.at[0,:].set(1)
+ map = map.at[-1,:].set(1)
+ map = map.at[:, 0].set(1)
+ map = map.at[:, -1].set(1)
+
+ return map
+
+ @partial(jax.vmap, in_axes=[None, 0, None, None])
+ def _check_grid(self, c, pos, radius):
+ p = jnp.clip(pos - c, -self.cell_half_height, self.cell_half_height)
+ p = p + c
+ return jnp.linalg.norm(p - pos) <= radius
+
+ @partial(jax.jit, static_argnums=[0, 4])
+ def check_line_collision(self, pos1, pos2, map_data, max_length=5.0):
+ """ uses same method as lidar (ray tracing) """
+ resolution = 0.05
+ line_length = jnp.linalg.norm(pos2-pos1)
+ angle = jnp.arctan2(pos2[1]-pos1[1], pos2[0]-pos1[0])
+
+ points = jnp.arange(0, max_length, resolution)
+ points = points[:, None] * jnp.array([jnp.cos(angle), jnp.sin(angle)]) + pos1
+ coords = jnp.floor(points).astype(int)
+ lidar_hits = map_data[coords[:, 1], coords[:, 0]]
+
+ num_points = jnp.floor(line_length/resolution).astype(int)
+ idx_range = jnp.arange(points.shape[0])
+ lidar_mask = jnp.where(idx_range len(colors)) and colour_agents_by_idx:
+ print('Too many agents to colour by index')
+ colour_agents_by_idx = False
+
+ colours = ['black' if done else 'red' for done in done]
+ if colour_agents_by_idx:
+ colours = ['black' if done else colors[i] for i, done in enumerate(done)]
+
+ for i in range(done.shape[0]):
+ circle = Circle(pos[i], rad, facecolor=colours[i])
+ ax.add_patch(circle)
+
+ x = pos[i][0] + rad * np.cos(theta[i])
+ y = pos[i][1] + rad * np.sin(theta[i])
+ ax.plot([pos[i][0], x], [pos[i][1], y], color='black')
+
+ if plot_line_to_goal:
+ ax.plot([pos[i][0], goal[i][0]], [pos[i][1], goal[i][1]], color='black', alpha=0.5)
+
+ def plot_agent_path(self,
+ ax: axes.Axes,
+ x_seq: jnp.ndarray,
+ y_seq: jnp.ndarray,):
+ """ Plot agent path """
+ x = self.scale_coords(x_seq)
+ y = self.scale_coords(y_seq)
+ ax.plot(x, y, c='b', linewidth=2.0, zorder=1)
+
+default_coords = jnp.array([
+ [-0.25, -0.25],
+ [-0.25, 0.25],
+ [0.25, 0.25],
+ [0.25, -0.25],
+])
+jackal_coords = jnp.array([
+ [-0.254, -0.215],
+ [-0.254, 0.215],
+ [0.254, 0.215],
+ [0.254, -0.215],
+])
+middle_line = jnp.array([
+ [0.0, 0.0],
+ [0.254, 0.0],
+])
+
+class GridMapPolygonAgents(GridMapCircleAgents):
+
+ def __init__(self,
+ num_agents: int,
+ rad,
+ map_size,
+ agent_coords=default_coords,
+ middle_line=middle_line,
+ **map_kwargs):
+ super().__init__(num_agents, rad, map_size, **map_kwargs)
+
+ self.agent_coords=agent_coords
+ self.middle_line=middle_line
+
+ # define window around agent to check for map collisions
+ min_x = jnp.min(agent_coords[:, 0])
+ max_x = jnp.max(agent_coords[:, 0])
+ min_y = jnp.min(agent_coords[:, 1])
+ max_y = jnp.max(agent_coords[:, 1])
+ side_length = jnp.maximum(max_x - min_x, max_y - min_y)
+
+ cell_ratio = side_length / self.cell_size
+ # print('cell ratio', cell_ratio)
+
+ self.agent_window = jnp.ceil(cell_ratio*2).astype(int) # NOTE times 2 should be enough
+ self.idx_offsets = jnp.arange(-self.agent_window, self.agent_window+1)
+ # print('side length', side_length)
+ # print('agent window', self.agent_window)
+ # print('idx offsets', self.idx_offsets)
+
+
+ # 2D with one set of coords for all agents
+ assert (len(self.agent_coords.shape) == 2)
+ # or \
+ # ((self.agent_coords.shape[0] == self.num_agents) and \
+ # (len(self.agent_coords.shape) == 3))
+
+ @partial(jax.jit, static_argnums=[0])
+ def transform_coords(self, pos, theta, coords):
+ r = rotation_matrix(theta)
+ return jnp.matmul(r, coords.T).T + pos
+
+ @partial(jax.jit, static_argnums=[0])
+ def check_agent_map_collision(self, pos, theta, map_grid, agent_coords=None):
+ """ Check for collision between agent and map. For polygon agents.
+ For now assuming all agents have the same shape and that side lengths
+ are less than the grid size. """
+
+ if agent_coords is None: agent_coords = self.agent_coords
+
+ idx_to_check = jnp.floor(pos / self.cell_size).squeeze() # [x, y]
+ idx_0 = (idx_to_check[0] + self.idx_offsets).astype(int)
+ idx_1 = (idx_to_check[1] + self.idx_offsets).astype(int)
+
+ idx_pairs = jax.vmap(
+ lambda x, y: jax.vmap(lambda a, b: jnp.array([a, b]), in_axes=(None, 0))(x, y),
+ in_axes=(0, None)
+ )(idx_1, idx_0).reshape((-1, 2))
+
+ # need to scale to take account of grid size
+ transformed_coords = self.transform_coords(pos, theta.squeeze(), agent_coords)
+ scaled_coords = transformed_coords / self.cell_size
+ scaled_coords_appended = jnp.concatenate([scaled_coords, scaled_coords[0, :][None]], axis=0)
+
+ cut = jnp.any(
+ jax.vmap(self._checkGrid,
+ in_axes=(None, None, 0, None))(scaled_coords_appended[:-1], scaled_coords_appended[1:], idx_pairs, map_grid)
+ )
+
+ inside = jnp.any(
+ jax.vmap(self._checkInsideGrid,
+ in_axes=(None, 0, None))(scaled_coords, idx_pairs, map_grid)
+ )
+ return cut | inside
+
+ @partial(jax.jit, static_argnums=[0])
+ def get_agent_map_occupancy_mask(self, pos, theta, map_grid, agent_coords=None):
+
+ if agent_coords is None: agent_coords = self.agent_coords
+
+ map_mask = jnp.ones(map_grid.shape, dtype=jnp.int32)
+
+ idx_to_check = jnp.floor(pos / self.cell_size).squeeze()
+ idx_0 = (idx_to_check[0] + self.idx_offsets).astype(int)
+ idx_1 = (idx_to_check[1] + self.idx_offsets).astype(int)
+
+ idx_pairs = jax.vmap(
+ lambda x, y: jax.vmap(lambda a, b: jnp.array([a, b]), in_axes=(None, 0))(x, y),
+ in_axes=(0, None)
+ )(idx_1, idx_0).reshape((-1, 2))
+
+ # need to scale to take account of grid size
+ transformed_coords = self.transform_coords(pos, theta.squeeze(), agent_coords)
+ scaled_coords = transformed_coords / self.cell_size
+ scaled_coords_appended = jnp.concatenate([scaled_coords, scaled_coords[0, :][None]], axis=0)
+
+ cut = jax.vmap(self._checkGrid,
+ in_axes=(None, None, 0, None))(scaled_coords_appended[:-1], scaled_coords_appended[1:], idx_pairs, map_mask)
+ inside = jax.vmap(self._checkInsideGrid,
+ in_axes=(None, 0, None))(scaled_coords, idx_pairs, map_mask)
+
+ collisions = cut | inside
+ valid_idx = (idx_pairs[:, 0] >= 0) & (idx_pairs[:, 0] < self.height) & (idx_pairs[:, 1] >= 0) & (idx_pairs[:, 1] < self.width) & collisions
+ idx_pairs = jnp.where(jnp.repeat(valid_idx[:, None], 1, 1), idx_pairs, jnp.zeros(idx_pairs.shape)).astype(int)
+ map_mask = jnp.zeros(map_grid.shape, dtype=jnp.int32)
+
+ return map_mask.at[idx_pairs[:, 0], idx_pairs[:, 1]].set(1)
+
+
+ def _checkGrid(self, x1y1, x2y2, grid_idx, map_grid):
+
+ def _checkLineLine(x1, y1, x2, y2, x3, y3, x4, y4):
+ """ Check collision between line (x1, y1) -- (x2, y2) and line (x3, y3) -- (x4, y4) """
+ # TODO understand this
+ uA = ((x4-x3)*(y1-y3) - (y4-y3)*(x1-x3)) / ((y4-y3)*(x2-x1) - (x4-x3)*(y2-y1))
+ uB = ((x2-x1)*(y1-y3) - (y2-y1)*(x1-x3)) / ((y4-y3)*(x2-x1) - (x4-x3)*(y2-y1))
+ c = (uA >= 0) & (uA <= 1) & (uB >= 0) & (uB <= 1)
+ return c.astype(jnp.bool_)
+
+ def _checkLineRect(x1, y1, x2, y2, rx, ry):
+ """ Check collision between line (x1, y1) -- (x2, y2) and rectangle with bottom left corner at (rx, ry)
+ and width and height of 1."""
+ vmap_checkLineLine = jax.vmap(_checkLineLine, in_axes=(None, None, None, None, 0, 0, 0, 0))
+ x3 = jnp.array([0, 1, 0, 0]) + rx
+ y3 = jnp.array([0, 0, 0, 1]) + ry
+ x4 = jnp.array([0, 1, 1, 1]) + rx
+ y4 = jnp.array([1, 1, 0, 1]) + ry
+ checks = vmap_checkLineLine(x1, y1, x2, y2, x3, y3, x4, y4)
+ return jnp.any(checks)
+
+ @partial(jax.vmap, in_axes=(0, 0, None))
+ def _checkSide(x1y1, x2y2, grid_idx):
+ x1, y1 = x1y1
+ x2, y2 = x2y2
+ ry, rx = grid_idx[0], grid_idx[1]
+ filled = map_grid[ry, rx]
+ c = _checkLineRect(x1, y1, x2, y2, rx, ry)
+ return c & filled
+
+ valid_idx = (grid_idx[0] >= 0) & (grid_idx[0] < self.height) & (grid_idx[1] >= 0) & (grid_idx[1] < self.width)
+ return jnp.any(_checkSide(x1y1, x2y2, grid_idx)) & valid_idx
+
+ def _checkInsideGrid(self, sides, grid_idx, map_grid):
+
+ def _checkPolyWithinRect(sides, rx, ry):
+ """ Check if polygon is within rectangle with bottom left corner at (rx, ry) and width and height of 1."""
+
+ def _checkPointRect(x, y, rx, ry):
+ """ Check if point (x, y) is within rectangle with bottom left corner at (rx, ry) and width and height of 1."""
+ return (x >= rx) & (x <= rx+1) & (y >= ry) & (y <= ry+1)
+
+ vmap_checkPointRect = jax.vmap(_checkPointRect, in_axes=(0, 0, None, None))
+ checks = vmap_checkPointRect(sides[:, 0], sides[:, 1], rx, ry)
+ return jnp.all(checks)
+
+ valid_idx = (grid_idx[0] >= 0) & (grid_idx[0] < self.height) & (grid_idx[1] >= 0) & (grid_idx[1] < self.width)
+ inside = _checkPolyWithinRect(sides, grid_idx[1], grid_idx[0])
+ return inside & map_grid[grid_idx[0], grid_idx[1]] & valid_idx
+
+ def check_agent_beam_intersect(self, beam, pos, theta, range_resolution, agent_coords=None):
+ """ Check for intersection between a lidar beam and an agent. """
+ if agent_coords is None: agent_coords = self.agent_coords
+
+ @partial(jax.vmap, in_axes=(None, None, 0, 0))
+ def _checkSide(beam_start, beam_end, side_start, side_end):
+ """ Check collision between line (x1, y1) -- (x2, y2) and line (x3, y3) -- (x4, y4) """
+ # TODO understand this
+ x1, y1 = beam_start
+ x2, y2 = beam_end
+ x3, y3 = side_start
+ x4, y4 = side_end
+
+ uA = ((x4-x3)*(y1-y3) - (y4-y3)*(x1-x3)) / ((y4-y3)*(x2-x1) - (x4-x3)*(y2-y1))
+ uB = ((x2-x1)*(y1-y3) - (y2-y1)*(x1-x3)) / ((y4-y3)*(x2-x1) - (x4-x3)*(y2-y1))
+ c = (uA >= 0) & (uA <= 1) & (uB >= 0) & (uB <= 1)
+
+ ix = x1 + (uA * (x2-x1))
+ iy = y1 + (uA * (y2-y1))
+ intersect = jnp.array([ix, iy])
+ idx = jnp.floor(jnp.linalg.norm(intersect - beam[0])/range_resolution)
+
+ return jax.lax.select(c, idx, jnp.inf)
+
+ tc = self.transform_coords(pos, theta, agent_coords)
+ tc = jnp.concatenate([tc, tc[0, :][None]], axis=0)
+
+ idxs = _checkSide(beam[0], beam[-1], tc[:-1], tc[1:])
+ idx = jnp.min(idxs)
+ return jax.lax.select(idx==jnp.inf, -1.0, idx).astype(int)
+
+ def plot_agents(
+ self,
+ ax: axes.Axes,
+ pos: jnp.ndarray,
+ theta: jnp.ndarray,
+ goal: jnp.ndarray,
+ done: jnp.ndarray,
+ plot_line_to_goal=True,
+ agent_coords=None,
+ middle_line=None,
+ colour_agents_by_idx=False,
+ ):
+ """ Plot agents """
+ from matplotlib.patches import Polygon
+ if agent_coords is None: agent_coords = self.agent_coords
+ if middle_line is None: middle_line = self.middle_line
+ colors = ['red', 'blue', 'green', 'purple', 'orange', 'pink', 'yellow', 'brown', 'grey', 'cyan']
+ if (done.shape[0] > len(colors)) and colour_agents_by_idx:
+ print('Too many agents to colour by index')
+ colour_agents_by_idx = False
+
+ colours = ['black' if done else 'red' for done in done]
+ if colour_agents_by_idx:
+ colours = ['black' if done else colors[i] for i, done in enumerate(done)]
+
+ for i in range(done.shape[0]):
+ transformed_coords = self.transform_coords(pos[i], theta[i], agent_coords) / self.cell_size
+
+ poly = Polygon(transformed_coords, facecolor=colours[i])
+ ax.add_patch(poly)
+
+ # middle line
+ transformed_middle_line = self.transform_coords(pos[i], theta[i], self.middle_line) / self.cell_size
+ ax.plot(transformed_middle_line[:, 0], transformed_middle_line[:, 1], color='black')
+
+ if plot_line_to_goal:
+ pos_scaled = self.scale_coords(pos[i])
+ goal_scaled = self.scale_coords(goal[i])
+ ax.scatter(goal_scaled[0], goal_scaled[1], marker='+', c='g')
+ ax.plot([pos_scaled[0], goal_scaled[0]], [pos_scaled[1], goal_scaled[1]], color='black', alpha=0.5)
+
+SMOOTHING_IDX_OFFSETS = jnp.array([[0, 1], [0, -1], [1, 0], [-1, 0], [1, 1], [1, -1], [-1, 1], [-1, -1]])
+INFLATE_IDX_OFFSETS_3 = SMOOTHING_IDX_OFFSETS
+INFLATE_IDX_OFFSETS_5 = jnp.concatenate([INFLATE_IDX_OFFSETS_3, jnp.array([[2, 0], [2, 1], [2, 2], [1, 2], [0, 2], [-1, 2], [-2, 2], [-2, 1], [-2, 0], [-2, -1], [-2, -2], [-1, -2], [0, -2], [1, -2], [2, -2], [2, -1],])], axis=0)
+
+
+class GridMapBarn(GridMapPolygonAgents):
+
+ def __init__(self,
+ num_agents,
+ rad,
+ map_size,
+ smoothing_iters=4,
+ smoothing_upper_count=3,
+ smoothing_lower_count=1,
+ agent_coords=jackal_coords,
+ cell_size=0.15,
+ **map_kwargs):
+
+ super().__init__(num_agents, rad, map_size, agent_coords=agent_coords, cell_size=cell_size, **map_kwargs)
+
+ self.smoothing_iters = smoothing_iters
+ self.smoothing_upper_count = smoothing_upper_count
+ self.smoothing_lower_count = smoothing_lower_count
+ assert self.smoothing_upper_count > self.smoothing_lower_count, 'smoothing upper count must be greater than lower count'
+
+ self.inner_idx = jnp.array([[i, j] for i in range(1, self.height-1) for j in range(1, self.width-1)])
+ self.outer_idx = jnp.array([[i, j] for i in range(self.height) for j in range(self.width) if (i==0) or (i==self.height-1) or (j==0) or (j==self.width-1)])
+
+ @partial(jax.jit, static_argnums=[0])
+ def sample_test_case(self, rng: chex.PRNGKey):
+
+ return self.sample_barn_test_case(rng)
+
+ @partial(jax.jit, static_argnums=[0])
+ def sample_barn_test_case(self, rng):
+
+ def _cond_fun(val):
+ valid, test_case, rng = val
+ # jax.debug.print('valid {v}', v=valid)
+ return ~valid
+
+ def _body_fun(val):
+ valid, test_case, rng = val
+ rng, _rng = jax.random.split(rng)
+ valid, test_case = self.barn_test_case(_rng)
+ return (valid, test_case, rng)
+
+ init_test_case = (jnp.zeros((2,)), jnp.zeros((2,)), jnp.zeros((self.height, self.width), dtype=jnp.int32))
+
+ test_case = jax.lax.while_loop(
+ _cond_fun, _body_fun, (False, init_test_case, rng)
+ )
+ valid, test_case, rng = test_case
+
+ start, goal, smoothed_map = test_case
+
+ theta = (jnp.pi/2) * jax.random.choice(rng, jnp.arange(4), (2,))
+ test_case = jnp.vstack([start, goal])
+ test_case = jnp.concatenate([test_case, theta[:,None]], axis=1)
+
+ return smoothed_map, test_case[None]
+
+ def barn_test_case(self, rng):
+ # from matplotlib import pyplot as plt
+ rng, _rng = jax.random.split(rng)
+ base_map = self.sample_map(_rng)
+
+ def _smooth_fn(map_data, unused):
+ inner_map = jax.vmap(self._smooth, in_axes=(0, None))(self.inner_idx, map_data)
+ inner_map = inner_map.reshape((self.height-2, self.width-2))
+ map_data = map_data.at[1:-1, 1:-1].set(inner_map)
+ return map_data, None
+
+ smoothed_map, _ = jax.lax.scan(
+ _smooth_fn, base_map, None, length=self.smoothing_iters
+ )
+
+ def _inflate_obs(idx, map):
+ value = map.at[idx[0], idx[1]].get()
+ around = map.at[INFLATE_IDX_OFFSETS_5[:,0]+idx[0], INFLATE_IDX_OFFSETS_5[:,1]+idx[1]].set(value)
+ return around
+
+ inner_inflated_map = jax.vmap(_inflate_obs, in_axes=(0, None))(self.inner_idx, smoothed_map).any(axis=0)
+ outer_inflated_map = jax.vmap(_inflate_obs, in_axes=(0, None))(self.outer_idx, smoothed_map).any(axis=0)
+ inflated_map = inner_inflated_map | outer_inflated_map
+ # print('inflated map\n', inflated_map)
+ # print('valid cells', (1-inflated_map).sum())
+
+ rng, _rng = jax.random.split(rng)
+ start = jax.random.choice(rng, jnp.arange(self.height*self.width), p=~inflated_map.flatten())
+ start = jnp.array([start % self.width, start // self.width]) # [x, y]
+ # print('start:', start)
+ # fig, ax = plt.subplots()
+ # inflated_to_plot = inflated_map * 0.1 + smoothed_map * 0.9
+ # ax.imshow(inflated_to_plot, cmap='binary')
+ # plt.savefig('barn-inflated-map.png')
+
+
+ # print('start_idx:', start)
+ # with jax.disable_jit(False):
+ connected_region = _graph_utils.component_mask_with_pos(inflated_map, start)
+
+
+ min_dist = 19
+ # empty = jnp.zeros((self.height, self.width), dtype=jnp.int32)
+ x_lim = jnp.clip(jnp.array([start[0]-min_dist, start[0]+min_dist+1]), 0, self.width)
+ y_lim = jnp.clip(jnp.array([start[1]-min_dist, start[1]+min_dist+1]), 0, self.height)
+ print('x_lim:', x_lim, 'y_lim:', y_lim)
+
+ too_close_mask = jnp.ones((self.height, self.width))
+ xrange = jnp.arange(self.width)
+ too_close_mask_x = jnp.where((xrange >= x_lim[0]) & (xrange < x_lim[1]), 1, 0)
+ yrange = jnp.arange(self.width)
+ too_close_mask_y = jnp.where((yrange >= y_lim[0]) & (yrange < y_lim[1]), 1, 0)
+
+ too_close_mask = jnp.meshgrid(too_close_mask_x, too_close_mask_y)
+ valid_mask = 1-jnp.dstack(too_close_mask).all(axis=-1)
+ # print('too_close_mask:', valid_mask, valid_mask.shape)
+
+ masked_connected_region = connected_region * valid_mask
+
+ # fig, ax = plt.subplots()
+ # ax.imshow(masked_connected_region, cmap='binary')
+ # plt.savefig('barn-masked-map.png')
+
+ goal = jax.random.choice(rng, jnp.arange(self.height*self.width), p=masked_connected_region.flatten())
+ goal = jnp.array([goal % self.width, goal // self.width]) # [x, y]
+
+
+ # fig, ax = plt.subplots()
+ # ax.imshow(inflated_to_plot, cmap='binary')
+ # ax.plot([start[0], goal[0]], [start[1], goal[1]], c='blue', linestyle='--')
+ # ax.scatter(start[0], start[1], c='red', marker='x')
+ # # ax.scatter(goal % cols, goal // cols, c='green', marker='x')
+ # ax.scatter(goal[0], goal[1], c='green', marker='x')
+ # plt.savefig('barn-final-map.png')
+
+ valid = ((1-inflated_map).sum() > 0) & (masked_connected_region.sum() > 0)
+ # print('test case', (start, goal))
+ return valid, (start * self.cell_size, goal * self.cell_size, smoothed_map)
+
+
+ def _smooth(self, idx, map_data):
+
+ idx_offsets = SMOOTHING_IDX_OFFSETS + idx
+ valid = (idx_offsets[:, 0] > 0) & (idx_offsets[:, 0] < self.height-1) & (idx_offsets[:, 1] > 0) & (idx_offsets[:, 1] < self.width-1)
+ n_full = map_data.at[idx_offsets[:,0], idx_offsets[:,1]].get() * valid
+ n_full = jnp.sum(n_full)
+
+ fill = ((n_full>self.smoothing_upper_count) | map_data.at[idx[0], idx[1]].get()) & (n_full>self.smoothing_lower_count)
+ return jax.lax.select(fill, 1, 0)
+
+class GridMapPolygonAgentsSingleMap(GridMapPolygonAgents):
+
+ def __init__(self,
+ num_agents: int,
+ rad,
+ map_data: List,
+ agent_coords=default_coords,
+ middle_line=middle_line,
+ **map_kwargs):
+
+ self._map_data = jnp.array(
+ [[int(x) for x in row.split()] for row in map_data],
+ dtype=jnp.int32
+ )
+ height, width = self._map_data.shape
+ map_size = (width, height)
+ super().__init__(num_agents=num_agents,
+ rad=rad,
+ map_size=map_size,
+ agent_coords=agent_coords,
+ middle_line=middle_line,
+ **map_kwargs)
+
+ @partial(jax.jit, static_argnums=[0])
+ def sample_map(self, key):
+ return self._map_data
+
+class GridMapFromBuffer(GridMapCircleAgents):
+
+ def __init__(self,
+ num_agents,
+ rad,
+ map_size,
+ map_grids=None,
+ dir_path=None,
+ file_prefix="map_buffer_"):
+ """
+ saved map buffers expected in format: (map_data, starts, theta, goals)
+ """
+ print('** Super old code beware **')
+ super().__init__(num_agents, rad, map_size, fill=0.1)
+ if map_grids is None and dir_path is None:
+ raise ValueError("Must specify either map_grids or dir_path")
+ if map_grids is not None and dir_path is not None:
+ raise ValueError("Cannot specify both map_grids and dir_path")
+
+ if dir_path is not None:
+ # list files in dir_path
+ files = [filename for filename in os.listdir(dir_path) if filename.startswith(file_prefix)]
+ print('files', files)
+ test_cases = (
+ jnp.empty((0, self.height, self.width), dtype=jnp.int32),
+ jnp.empty((0, 2), dtype=jnp.float32),
+ jnp.empty((0, 1), dtype=jnp.float32),
+ jnp.empty((0, 2), dtype=jnp.float32),
+ )
+ for filename in files:
+ # load pkls
+ filepath = os.path.join(dir_path, filename)
+ with open(filepath, "rb") as f:
+ tc = pickle.load(f)
+ print('tc c', tc)
+ test_cases = jax.tree_map(lambda x, y: jnp.concatenate((x, y), axis=0), test_cases, tc)
+ self.test_cases = test_cases
+ self.num_test_cases = test_cases[0].shape[0]
+ print('test cases', test_cases)
+
+ if map_grids is not None:
+ raise NotImplementedError("map_grids not implemented yet")
+
+ @partial(jax.jit, static_argnums=[0])
+ def sample_scenario(self, key):
+ print('-- sampling scenarios -- ')
+ idx = jax.random.randint(key, (1,), minval=0, maxval=self.num_test_cases)[0]
+ tc = jax.tree_map(lambda x: x[idx], self.test_cases)
+ print('tc ', tc)
+ map_data = tc[0]
+ print('map data', map_data.shape)
+
+ theta = jnp.array([tc[2], 0])
+ print('theta', theta)
+
+ case = jnp.array([tc[1], tc[3]])
+ print('case', case)
+ test_case = jnp.concatenate([case, theta], axis=-1)
+
+ return map_data, test_case
+
+ '''@partial(jax.jit, static_argnums=[0])
+ def sample_map(self, key):
+ """ Sample map grid from pre-specified map grids list """
+ if self.map_grids.shape[0]>1:
+ map_idx = jax.random.randint(key, (1,), minval=0, maxval=len(self.map_grids))[0]
+ map_grid = self.map_grids[map_idx]
+ else:
+ map_grid = self.map_grids[0]
+ return map_grid'''
+
+def rrt_reward(new_pos, pos, goal):
+ goal_reached = jnp.linalg.norm(new_pos - goal) <= 0.3
+ #if goal_reached: print('goal reached')
+ weight_g = 0.2
+ goal_rew = 1
+ rga = weight_g * (jnp.linalg.norm(pos - goal) - jnp.linalg.norm(new_pos - goal))
+ rg = jnp.where(goal_reached, goal_rew, rga)
+ return rg
+
+if __name__ == "__main__":
+ import matplotlib.pyplot as plt
+ import numpy as np
+
+ key = jax.random.PRNGKey(3) # 3, 7, 9
+
+ file_path = "/home/alex/repos/jax-multirobsim/failure_maps/cosmic-waterfall-17"
+ map_gen = GridMapFromBuffer(1, 0.3, (7, 7), dir_path=file_path)
+
+
+ s = map_gen.sample_scenario(key)
+ raise
+ key, key_rrt = jax.random.split(key)
+
+ map_gen = GridMapCircleAgents(1, 0.3, (10, 10), 0.5)
+ map_data, case = map_gen.sample_scenario(key)
+
+ start = case[:, 0, :2].flatten()
+ goal = case[:, 1, :2].flatten()
+ print('case', case, 'start', start, 'goal', goal)
+
+ '''gr, parent = map_gen.a_star(map_data, start, goal)
+ print('parent', parent)
+
+ fig, ax = plt.subplots()
+
+ ax.imshow(map_data, extent=(0, map_data.shape[1], 0, map_data.shape[0]), origin="lower", cmap='binary', alpha=0.8)
+
+ zero_grid = np.zeros((10, 10))
+ x, y = jnp.meshgrid(jnp.arange(map_data.shape[0]), jnp.arange(map_data.shape[1]))#.reshape(-1, 2)
+ coords = jnp.dstack((y.flatten(), x.flatten())).squeeze()
+ for i in range(parent.shape[0]):
+ if parent[i] == -1: continue
+ node = coords[parent[i]]
+ print('node', node)
+ zero_grid[node[0], node[1]] = 1
+
+ ax.imshow(zero_grid, extent=(0, 10, 0, 10), origin="lower", cmap='binary', alpha=0.2)
+
+ map_gen.plot_pose(ax, case)
+ plt.show()
+
+ raise'''
+ #print('map_data', map_data)
+ tree, goalr = map_gen.rrt_star(key_rrt, map_data, start, goal)
+ print('tree', tree, 'goalr', goalr)
+ print('case', case)
+
+ fig, ax = plt.subplots()
+
+ map_gen.plot(ax, map_data)
+
+ #ax.scatter(case[:, 0, 0], case[:, 0, 1], c='r')
+ for n in range(tree.shape[0]):
+ if tree[n, 0] == 0.0: break
+ ax.scatter(tree[n, 0], tree[n, 1], c='gray')
+ pi = tree[n, 2]
+ if pi == -1: continue
+ pi = int(pi)
+ ax.plot([tree[n, 0], tree[pi, 0]], [tree[n, 1], tree[pi, 1]], c='gray', marker='+', alpha=0.75)
+
+
+
+ if goalr:
+ goal_idx = jnp.argwhere(tree[:,-1]==1)
+ print('goal_idx', goal_idx)
+
+ for g_idx in goal_idx:
+ c_idx = g_idx[0]
+ path_length = 0.0
+ rew = 0.0
+ while c_idx != 0:
+ print('cidx', c_idx, 'tree row', tree[c_idx])
+ c_pos = tree[c_idx, :2]
+ p_idx = int(tree[c_idx, 2])
+ p_pos = tree[p_idx, :2]
+ path_length += jnp.linalg.norm(c_pos - p_pos)
+ rew += rrt_reward(c_pos, p_pos, goal)
+ ax.plot([c_pos[0], p_pos[0]], [c_pos[1], p_pos[1]], c='r', alpha=0.25)
+ print('p_pos', p_pos, 'c_pos', c_pos, 'rew', rew)
+ c_idx = p_idx
+ print('path_length', path_length, 'rew', rew)
+
+ map_gen.plot_pose(ax, case)
+ plt.show()
+ '''raise
+ p_idx = parent[goal_idx]
+ print('p_idx', p_idx)
+ while p_idx != -1:
+ print('corrds', coords[p_idx])
+ ax.plot( [coords[p_idx, 1]+0.5, coords[goal_idx, 1]+0.5], [coords[p_idx, 0]+0.5, coords[goal_idx, 0]+0.5], c='r')
+ goal_idx = p_idx
+ p_idx = parent[goal_idx]
+
+
+ plt.show()
+
+ raise
+ pos = jnp.array([1.5, 1.5])
+ x = map_gen.check_agent_collision(pos, map_data)
+ print(x)
+ print(map_data)'''
+
+
\ No newline at end of file
diff --git a/jaxmarl/environments/jaxnav/maps/map.py b/jaxmarl/environments/jaxnav/maps/map.py
new file mode 100644
index 00000000..f1db3b8f
--- /dev/null
+++ b/jaxmarl/environments/jaxnav/maps/map.py
@@ -0,0 +1,435 @@
+import jax
+from jax import numpy as jnp
+import numpy as np
+from functools import partial
+import matplotlib.axes._axes as axes
+import chex
+
+class Map(object):
+ """ Base class for a map """
+
+ def __init__(
+ self,
+ num_agents,
+ rad,
+ map_size,
+ start_pad=1.5,
+ valid_path_check=False,
+ ):
+ assert start_pad>=1.0, 'start_pad must be greater than or equal to 1.0'
+
+ self.num_agents = num_agents
+ self.rad = rad
+ self.map_size = map_size
+ self.start_pad=start_pad
+ self.valid_path_check = valid_path_check
+
+ # Test case sampling TODO fix this to be not hard coded
+ self.dist_to_goal = 50 # 5
+ self.rrt_samples = 1000
+ self.rrt_step_size = 0.25
+ self.goal_radius = 0.3
+
+ @partial(jax.jit, static_argnums=[0])
+ def sample_scenario(self, key):
+ """ Sample map grid and agent start/goal positions """
+
+ key_map, key_case = jax.random.split(key)
+ map_data = self.sample_map(key_map)
+ test_case = self.sample_test_case(key_case, map_data)
+ return map_data, test_case
+
+ @partial(jax.vmap, in_axes=[None, 0])
+ def sample_scenarios(self, key):
+ return self.sample_scenario(key)
+
+ @partial(jax.jit, static_argnums=[0])
+ def sample_map(self, key):
+ raise NotImplementedError
+
+ @partial(jax.jit, static_argnums=[0])
+ def sample_test_case(self, key):
+ """ Sample a test case for a given map """
+ key, _key = jax.random.split(key)
+ map_data = self.sample_map(_key)
+ radii = jnp.array([self.rad*self.start_pad, self.goal_radius])
+
+ def _sample_pair(key: chex.PRNGKey):
+ """ Sample a start and goal pose for an agent """
+ key_s, key_g, key_t = jax.random.split(key, 3)
+ low_lim = 1 + self.rad
+ high_lim = self.map_size[1] - 1 - self.rad
+ start = jax.random.uniform(key_s, (1, 2), minval=low_lim, maxval=high_lim)
+ g_low_lim = jnp.clip(start - self.dist_to_goal, low_lim, high_lim)
+ g_high_lim = jnp.clip(start + self.dist_to_goal, low_lim, high_lim)
+ goal = jax.random.uniform(key_g, (1, 2), minval=g_low_lim, maxval=g_high_lim)
+ theta = jax.random.uniform(key_t, (2, 1), minval=-jnp.pi, maxval=jnp.pi)
+ positions = jnp.concatenate([start, goal], axis=0)
+ poses = jnp.concatenate([positions, theta], axis=1)
+ return poses
+
+ def _agent_collision(pos, test_case, rad):
+ dists = jnp.linalg.norm(test_case-pos, axis=1) <= rad*2
+ return jnp.any(dists)
+
+ def _cond_idx(val):
+ """ true while i is less than the number of agents """
+ key, i, case = val
+ return i < case.shape[0]
+
+ def _body_idx(val, key):
+ """ samples a start and goal pair for an agent, taking into account pairs sampled previously """
+ i, case = val
+ def _cond_pos(val):
+ """ Check if the sampled pair is valid. Checks
+ 1. check start and goal do not collide with the map
+ 3. check the start is not within a radius with other agents' starts
+ 4. check the start is not within a radius with other agents' starts
+ 5. check if start and goal are too close
+ """
+
+ key, pair, case = val
+ temp_case = case.at[i].set(pair+self.rad*3) # ensure no conflict
+
+ map_collisions = jnp.any(jax.vmap(self.check_circle_map_collision, in_axes=[0, None, 0])(pair[:, :2], map_data, radii))
+ agent_collisions = jnp.any(jax.vmap(_agent_collision, in_axes=[0, 1, None])(pair[:, :2], temp_case[:, :, :2], self.rad*self.start_pad))
+
+ too_close = (jnp.linalg.norm(pair[0, :2] - pair[1, :2]) <= 2*self.rad).astype(jnp.bool_)
+
+ check = map_collisions | agent_collisions | too_close
+
+ if self.valid_path_check:
+ valid_path = self.passable_check(pair[0, :2], pair[1, :2], map_data) # WARNING can make code too slow
+ check = check | ~valid_path
+
+ return check
+
+ # return map_collisions | agent_collisions | too_close | ~valid_path
+
+ #print('p map', jnp.any(pmap_collision(pair, map_grid, rad)))
+ #jax.debug.print('p {pair} map {p}, s ag {s}, g ag {g}, dist f{d} dist {c}', pair=pair, p=jnp.any(pmap_collision(pair, map_grid, rad)), s=agent_collision(pair[0], temp_case[:, 0, :], rad), g=agent_collision(pair[1], temp_case[:, 1, :], rad), d=jnp.linalg.norm(pair[0] - pair[1]) >= dist_to_goal, c=(jnp.linalg.norm(pair[0] - pair[1]) <= 2*rad).astype(jnp.bool_))
+ """ true while pos is not valid """
+ # 1. check if start collides with map
+ # 2. check if goal collides with map
+ # 3. check if start collides with other agents
+ # 4. check if goal collides with other agents's goals
+ # 5. check if start and goal are too close
+
+ return self.check_agent_map_collision(pair[0], map_data, self.rad*self.start_pad) \
+ | self.check_agent_map_collision(pair[1], map_data) \
+ | _agent_collision(pair[0], temp_case[:, 0, :], self.rad) \
+ | _agent_collision(pair[1], temp_case[:, 1, :], self.rad) \
+ | (jnp.linalg.norm(pair[0] - pair[1]) >= self.dist_to_goal).astype(jnp.bool_) \
+ | (jnp.linalg.norm(pair[0] - pair[1]) <= 2*self.rad).astype(jnp.bool_) #\ | ~(self.rrt(key_rrt, map_data, pair[0], pair[1]))
+
+ def _body_pos(val):
+ """ Sample a start and goal pair """
+ key, pair, case = val
+ key, key_point = jax.random.split(key)
+ pair = _sample_pair(key_point)
+ case = case.at[i].set(pair)
+ #jax.debug.print('case {c}', c=case)
+ return key, pair, case
+
+ key, key_point = jax.random.split(key)
+ pair = _sample_pair(key_point)
+ case = case.at[i].set(pair)
+
+ key, pair, case = jax.lax.while_loop(
+ _cond_pos,
+ _body_pos,
+ (key, pair, case),
+ )
+
+ i += 1
+ return (i, case), None
+
+ fill_max = jnp.max(jnp.array(self.map_size)) + self.rad*2
+ case = jnp.full((self.num_agents, 2, 3), fill_max) # [num_agents, [start_pose, goal_pose]]
+
+ i = 0
+
+ key_scan = jax.random.split(key, self.num_agents)
+ (_, case), _ = jax.lax.scan(_body_idx, (i, case), key_scan)
+
+ # Add intial orientation
+ #theta = jax.random.uniform(key, (self.num_agents, 2, 1), minval=-jnp.pi, maxval=jnp.pi)
+ #case = jnp.concatenate([case, theta], axis=-1)
+ return map_data, case
+
+ @partial(jax.jit, static_argnums=[0])
+ def check_circle_map_collision(self, pos, map_data, rad=None):
+ """ Check collision between a circle at position `pos` of radius `rad` and the map.
+ If rad is None, use the class rad. """
+ raise NotImplementedError
+
+ @partial(jax.jit, static_argnums=[0])
+ def check_agent_map_collision(self, pos, theta, map_data, **agent_kwargs):
+ """ Check collision between an agent at position `pos` and the map"""
+ # NOTE should we switch these functions to use pose, i.e. [pos_x, pos_y, theta]?
+ raise NotImplementedError
+
+ def check_agent_beam_intersect(self, beam, pos, theta, range_resolution, **agent_kwargs):
+ """ Check for intersection between a lidar beam and an agent. """
+ raise NotImplementedError
+
+ @partial(jax.jit, static_argnums=[0])
+ def check_point_map_collision(self, pos, map_data):
+ """ Check collision between `pos` and the map"""
+ raise NotImplementedError
+
+ @partial(jax.jit, static_argnums=[0])
+ def check_line_collision(self, pos1, pos2, map_data):
+ """ Check collision between line (pos1) -- (pos2) the map"""
+ raise NotImplementedError
+
+ def passable_check(self, pos1, pos2, map_data):
+ """ Check whether a path exists between pos1 and pos2.
+ Note, this does not return the path, only whether a path exists. """
+ raise NotImplementedError
+
+ @partial(jax.jit, static_argnums=[0])
+ def check_agent_translation(self, start, end, map_data):
+ """ True for valid translation, False for invalid translation """
+
+ l_slope = (end[1] - start[1]) / (end[0] - start[0])
+ perpendicular_slope = -1 / l_slope
+ delta_x = self.rad / (1 + perpendicular_slope**2)**0.5
+ delta_y = perpendicular_slope * delta_x
+ #jax.debug.print('per slope {p}', p=perpendicular_slope)
+ # Calculate the coordinates of the two points
+ delta = jnp.array([delta_x, delta_y])
+ start_lower = start + delta
+ start_upper = start - delta
+ end_lower = end + delta
+ end_upper = end - delta
+
+ lower = self.check_line_collision(start_lower, end_lower, map_data)
+ upper = self.check_line_collision(start_upper, end_upper, map_data)
+ #jax.debug.print('lower {l}, upper {u}', l=lower, u=upper)
+ return ~lower & ~upper
+
+
+ @partial(jax.jit, static_argnums=[0])
+ def rrt(self, key, map_data, start, goal):
+ """ Run RRT algorithm to find a path between start and goal """
+
+ INF = 10000
+
+ print('key')
+ low_lim = 1 + self.rad
+ high_lim = self.map_size[1] - 1 - self.rad
+ goal_square = jnp.floor(goal).astype(jnp.int32).squeeze()
+ #print('goal square', goal_square)
+
+ tree = jnp.empty((self.rrt_samples, 3)) # [sample, [x, y, parent]]
+ tree = tree.at[0].set(jnp.append(start, -1))
+ rrt_idx = 1
+ gr = False
+
+ def _cond_fun(val):
+ i, gr = val[0], val[1]
+ return (i < self.rrt_samples) & ~gr
+
+ def _body_fun(val):
+ i, gr, rrt_idx, tree, key = val
+ key, key_s = jax.random.split(key)
+ # Sample position, find closest idx and increment towards sampled pos
+ sampled_pos = jax.random.uniform(key_s, (2,), minval=low_lim, maxval=high_lim)
+ #closest_idx = jnp.argmin(jnp.linalg.norm(tree[:, :2] - sampled_pos, axis=1))
+ distances = jnp.linalg.norm(tree[:, :2] - sampled_pos, axis=1)
+ distance_mask = jnp.where(jnp.arange(self.rrt_samples) < rrt_idx, 0, 1) * INF
+ closest_idx = jnp.argmin(distances + distance_mask)
+
+ tree_pos = tree[closest_idx, :2]
+ step_size = jnp.minimum(self.rrt_step_size, jnp.linalg.norm(sampled_pos - tree_pos))
+
+ angle = jnp.arctan2(sampled_pos[1] - tree_pos[1], sampled_pos[0] - tree_pos[0])
+ test_pos = tree_pos + jnp.array([jnp.cos(angle), jnp.sin(angle)]) * step_size
+
+ # Check free space, line collision
+ free_space = ~self.check_agent_map_collision(test_pos, map_data)
+ line_collision = self.check_agent_translation(tree_pos, test_pos, map_data)
+ valid = free_space & line_collision
+
+ #goal_square_reached = jnp.array_equal(jnp.floor(test_pos), goal_square)
+ goal_reached = jnp.linalg.norm(test_pos - goal) < self.goal_radius
+ new_node = jax.lax.select(valid, jnp.concatenate([test_pos, jnp.array([closest_idx])]), jnp.zeros((3,)))
+ tree = tree.at[rrt_idx].set(new_node)
+
+ rrt_idx += 1*valid
+ gr = gr | (goal_reached & valid)
+ return (i+1, gr, rrt_idx, tree, key)
+
+ val = jax.lax.while_loop(_cond_fun, _body_fun, (0, gr, rrt_idx, tree, key))
+ gr, tree = val[1], val[3]
+
+ return tree, gr
+
+ @partial(jax.jit, static_argnums=[0])
+ def rrt_star(self, key, map_data, start, goal):
+ """ Run RRT* algorithm to find an optimal path between start and goal """
+
+ INF = 10000
+
+ low_lim = 1 + self.rad
+ high_lim = self.map_size[1] - 1 - self.rad
+ goal_square = jnp.floor(goal).astype(jnp.int32).squeeze()
+
+ check_point_connections = jax.vmap(self.check_agent_translation, in_axes=(None, 0, None))
+
+ tree = jnp.empty((self.rrt_samples, 5)) # [sample, [x, y, parent, cost, goal]]
+ #print('start',start)
+ tree = tree.at[0].set(jnp.append(start, jnp.array([-1, 0.0, 0.0])))
+ rrt_idx = 1
+ goal_reached = False
+ goal_idx = jnp.full((30,), -1, dtype=jnp.int32)
+
+ rrt_star_neighbours = 20
+
+ def _cond_fun(val):
+ i, goal_reached = val[0], val[1]
+ return (i < self.rrt_samples) # & ~goal_reached
+
+ def _body_fun(val):
+ #jax.debug.print('body start')
+ i, goal_reached, rrt_idx, tree, key = val
+ key, key_s = jax.random.split(key)
+ #jax.debug.print('rrt idx {r}', r=rrt_idx)
+ # Sample position, find closest idx and increment towards sampled pos
+ sampled_pos = jax.random.uniform(key_s, (2,), minval=low_lim, maxval=high_lim)
+ distances = jnp.linalg.norm(tree[:, :2] - sampled_pos, axis=1)
+ distance_mask = jnp.where(jnp.arange(self.rrt_samples) < rrt_idx, 0, 1) * INF
+
+ closest_idx = jnp.argmin(distances + distance_mask)
+ #jax.debug.print('closest idx {c}, d {d}', c=closest_idx, d=distances + distance_mask)
+ tree_pos = tree[closest_idx, :2]
+ step_size = jnp.minimum(self.rrt_step_size, jnp.linalg.norm(sampled_pos - tree_pos))
+
+ angle = jnp.arctan2(sampled_pos[1] - tree_pos[1], sampled_pos[0] - tree_pos[0])
+ test_pos = tree_pos + jnp.array([jnp.cos(angle), jnp.sin(angle)]) * step_size
+
+ # Check free space, line collision
+ free_space = ~self.check_agent_map_collision(test_pos, map_data)
+ circle_trans = self.check_agent_translation(test_pos, tree_pos, map_data) # NOTE do we need this for our valid check?
+ valid = free_space & circle_trans
+ #goal_reached = jnp.array_equal(jnp.floor(test_pos), goal_square)
+ goal_just_reached = jnp.linalg.norm(test_pos - goal) < self.goal_radius
+ #jax.debug.print('valid {v}, free space {f}, circle trans {c}', v=valid, f=free_space, c=circle_trans)
+
+ # Find parent
+ tree_dist = jnp.linalg.norm(tree[:, :2] - test_pos, axis=1) # todo correct for zeros
+ distance_mask = jnp.where(jnp.arange(self.rrt_samples) < rrt_idx, 0, 1) * INF
+ #jax.debug.print('tree dist {t}', t=tree_dist.shape)
+ parent_poss = jnp.argsort(tree_dist+distance_mask)[:rrt_star_neighbours]
+
+ #jax.debug.print('parent poss {p}', p=parent_poss)
+ in_range = jnp.where(parent_poss < rrt_idx, True, False)
+ invalid_parent = ~check_point_connections(test_pos, tree[parent_poss, :2], map_data) | ~in_range
+ #print('invalid parent', invalid_parent)
+ #jax.debug.print('invalid parent {i}, not in range {r}', i=invalid_parent, r=~in_range)
+ invalid_parent = invalid_parent | ~in_range
+ parent_cost = tree[parent_poss, 3] + tree_dist[parent_poss] + INF*invalid_parent
+ parent_rel_idx = jnp.argmin(parent_cost)
+ parent_idx = parent_poss[parent_rel_idx]
+ cost = jnp.min(parent_cost)
+ valid = valid & ~invalid_parent[parent_rel_idx]
+ #jax.debug.print('valid {v}, parent valid {p}', v=valid, p=~invalid_parent[parent_rel_idx])
+ new_node = jax.lax.select(valid, jnp.concatenate([test_pos, jnp.array([parent_idx, cost, goal_just_reached.astype(jnp.int32)])]), jnp.zeros(tree.shape[1]))
+ #jax.debug.print('new node {n}, goal reached {g}', n=new_node, g=goal_just_reached)
+ tree = tree.at[rrt_idx].set(new_node)
+
+ # rewire
+ invalid_child = invalid_parent.at[parent_rel_idx].set(True)
+
+ new_cost = cost + tree_dist[parent_poss] + INF*invalid_child
+ rewire = (new_cost < tree[parent_poss, 3]) & valid
+
+ new_nodes = jnp.where(rewire[:, None], jnp.concatenate([tree[parent_poss, :2], jnp.full((rrt_star_neighbours, 1), rrt_idx), new_cost[:, None], tree[parent_poss, 4][:, None]], axis=1), tree[parent_poss])
+ #jax.debug.print('new nodes {n}', n=new_nodes)
+ tree = tree.at[parent_poss].set(new_nodes)
+
+ new_goal_node = goal_just_reached & valid
+
+ rrt_idx += 1*valid
+ goal_reached = goal_reached | new_goal_node
+ return (i+1, goal_reached, rrt_idx, tree, key)
+
+ val = jax.lax.while_loop(_cond_fun, _body_fun, (0, goal_reached, rrt_idx, tree, key))
+ goal_reached, tree = val[1], val[3]
+
+ return tree, goal_reached
+
+ def plot_rrt_tree(self, ax, tree, goal_reached=False, goal=None, rrt_reward=None, name="env"):
+
+ for n in range(tree.shape[0]):
+ if tree[n, 0] == 0.0: break
+ ax.scatter(tree[n, 0], tree[n, 1], c='gray')
+ parent_idx = tree[n, 2]
+ if parent_idx == -1: continue
+ parent_idx = int(parent_idx)
+ ax.plot([tree[n, 0], tree[parent_idx, 0]], [tree[n, 1], tree[parent_idx, 1]], c='gray', marker='+', alpha=0.75)
+
+
+ if goal_reached:
+ if rrt_reward is not None:
+ assert goal is not None
+ rewards = []
+
+ path_lengths = []
+ goal_idx = jnp.argwhere(tree[:,-1]==1)
+
+ for g_idx in goal_idx:
+ c_idx = g_idx[0]
+ path_length = 0.0
+ if rrt_reward is not None: rew = 0.0
+ while c_idx != 0:
+ #print('cidx', c_idx, 'tree row', tree[c_idx])
+ c_pos = tree[c_idx, :2]
+ p_idx = int(tree[c_idx, 2])
+ p_pos = tree[p_idx, :2]
+ path_length += jnp.linalg.norm(c_pos - p_pos)
+ if rrt_reward is not None: rew += rrt_reward(c_pos, p_pos, goal)
+ ax.plot([c_pos[0], p_pos[0]], [c_pos[1], p_pos[1]], c='r', alpha=0.25)
+ c_idx = p_idx
+ path_lengths.append(path_length)
+ #print('path_length:', path_length)
+ if rrt_reward is not None:
+ #print('reward:', rew)
+ rewards.append(rew)
+
+ if rrt_reward is not None:
+ max_rew_idx = jnp.argmax(jnp.array(rewards))
+ print(name, ' max reward:', rewards[max_rew_idx], 'path length:', path_lengths[max_rew_idx])
+ else:
+ print(name, ' min path length:', jnp.min(jnp.array(path_lengths)))
+
+ def plot_map(self,
+ ax: axes.Axes,
+ map_data: jnp.ndarray,) -> None:
+ raise NotImplementedError
+
+ def plot_agents(
+ self,
+ ax: axes.Axes,
+ pos: jnp.ndarray,
+ theta: jnp.ndarray,
+ goal: jnp.ndarray,
+ done: jnp.ndarray,
+ plot_line_to_goal=True,
+ colour_agents_by_idx=False,
+ ) -> None:
+ raise NotImplementedError
+
+ def plot_agent_path(
+ self,
+ ax: axes.Axes,
+ x_seq: jnp.ndarray,
+ y_seq: jnp.ndarray,
+ ) -> None:
+ raise NotImplementedError
+
+
+
+
+
diff --git a/jaxmarl/environments/jaxnav/maps/map_registration.py b/jaxmarl/environments/jaxnav/maps/map_registration.py
new file mode 100644
index 00000000..f181911a
--- /dev/null
+++ b/jaxmarl/environments/jaxnav/maps/map_registration.py
@@ -0,0 +1,25 @@
+from .grid_map import GridMapCircleAgents, GridMapPolygonAgents, GridMapBarn, GridMapPolygonAgentsSingleMap, GridMapFromBuffer
+from .map import Map
+
+def make_map(map_id: str, num_agents: int, rad: float, **map_kwargs) -> GridMapCircleAgents: # note this type hint technically should be Map
+
+ if map_id not in registered_maps:
+ raise ValueError(f"Map: {map_id} not registered!")
+ if map_id == "Grid-Rand":
+ return GridMapCircleAgents(num_agents=num_agents, rad=rad, **map_kwargs)
+ if map_id == "Grid-Rand-Poly":
+ return GridMapPolygonAgents(num_agents=num_agents, rad=rad, **map_kwargs)
+ if map_id == "Grid-Rand-Barn":
+ return GridMapBarn(num_agents=num_agents, rad=rad, **map_kwargs)
+ if map_id == "Grid-Rand-Poly-Single":
+ return GridMapPolygonAgentsSingleMap(num_agents=num_agents, rad=rad, **map_kwargs)
+ if map_id == "Grid-Buffer":
+ return GridMapFromBuffer(num_agents=num_agents, rad=rad, **map_kwargs)
+
+registered_maps = [
+ "Grid-Rand",
+ "Grid-Rand-Poly",
+ "Grid-Rand-Barn",
+ "Grid-Rand-Poly-Single",
+ "Grid-Buffer",
+]
\ No newline at end of file
diff --git a/jaxmarl/environments/jaxnav/maps/polygon_map.py b/jaxmarl/environments/jaxnav/maps/polygon_map.py
new file mode 100644
index 00000000..7732933c
--- /dev/null
+++ b/jaxmarl/environments/jaxnav/maps/polygon_map.py
@@ -0,0 +1,224 @@
+"""
+NOT USED IN A LONG TIME
+
+collision code drawn from: https://www.jeffreythompson.org/collision-detection/poly-circle.php
+"""
+
+import jax
+import jax.numpy as jnp
+from functools import partial
+
+from .map import Map
+from jaxmarl.environments.jaxnav.jaxnav_utils import rot_mat
+
+class PolygonMap(Map):
+
+ def __init__(self,
+ num_agents,
+ rad,
+ map_size,
+ num_sides=4,
+ max_num_shapes=40,
+ min_edge_length=0.2,
+ max_edge_length=4.0):
+ super().__init__(num_agents, rad, map_size)
+ print('Super old code beware')
+ self.num_sides = num_sides
+ self.max_num_shapes = max_num_shapes
+ self.min_edge_length = min_edge_length
+ self.max_edge_length = max_edge_length
+
+ if self.num_sides != 4:
+ raise NotImplementedError('Only 4-sided polygons are currently supported')
+ self.shape_fn = self._sample_parrallelogram
+
+ ### === MAP GENERATION === ###
+ @partial(jax.jit, static_argnums=[0])
+ def sample_map(self, key):
+ """ Sample polygon map, returns coordinates of vertices """
+ key_num, key_coords = jax.random.split(key)
+
+ # rectangle for map bounds
+ bounds = jnp.array([
+ [[0.0, 0.15], [self.map_size[0], 0.15], [self.map_size[0], 0.0], [0.0, 0.0]],
+ [[0.15, 0.0], [0.15, self.map_size[1]], [0.0, self.map_size[1]], [0.0, 0.0]],
+ [[0.0, self.map_size[1]], [self.map_size[0], self.map_size[1]], [self.map_size[0], self.map_size[1]-0.15], [0.0, self.map_size[1]-0.15]],
+ [[self.map_size[0]-0.15, 0.0], [self.map_size[0]-0.15, self.map_size[1]], [self.map_size[0],self.map_size[1]], [self.map_size[0], 0.0]],
+ ])
+
+
+ # out of bounds constant
+ oobounds = jnp.full((self.num_sides, 2), 1.0, dtype=jnp.float32) + jnp.array([self.map_size[0], self.map_size[1]])
+
+ num_shapes = jax.random.randint(key_num, (1,), minval=0, maxval=self.max_num_shapes)
+ p_idx = jnp.arange(0, self.max_num_shapes)
+ p_mask = jnp.where(p_idx<=num_shapes, 0, 1)
+
+ key_coords = jax.random.split(key_coords, self.max_num_shapes)
+ coords = self.shape_fn(key_coords)
+
+ mask = p_mask[:, None, None] * jnp.repeat(oobounds[None], self.max_num_shapes, axis=0)
+ coords = coords + mask
+ return jnp.append(coords, bounds, axis=0)
+
+ ### === COLLISION DETECTION === ###
+ @partial(jax.jit, static_argnums=[0])
+ def check_agent_map_collision(self, pos, coords):
+ """ Check collision between a line and a circle subject to a map boundary """
+ #print('agent check coords shape', coords.shape)
+
+ @partial(jax.vmap, in_axes=(0, 0, None))
+ def _check_ac(p1, p2, pos):
+
+ def _in_map(p1, p2, pos, r):
+ return (p1[0] >= 0) & (p1[0] <= self.map_size[0]) & (p1[1] >= 0) & (p1[1] <= self.map_size[1]) & \
+ (p2[0] >= 0) & (p2[0] <= self.map_size[0]) & (p2[1] >= 0) & (p2[1] <= self.map_size[1])
+
+ return jax.lax.cond(_in_map(p1, p2, pos, self.rad), lambda _: line_circle_collision(p1, p2, pos, self.rad), lambda _: False, None)
+
+ edges = self.gen_edges(coords)
+ return jnp.any(_check_ac(edges[:, 0], edges[:, 1], pos))
+
+ @partial(jax.jit, static_argnums=[0])
+ def check_map_collision(self, pos, coords, radius):
+ """ For a circle agent """
+ l = self.check_agent_map_collision(pos, coords)
+ i = jnp.any(self.check_point_map_collision(pos, coords))
+ return l | i
+
+ # NOTE need to add ray tracing for lidar - constrains size of usable polygon
+
+ @partial(jax.jit, static_argnums=[0])
+ def check_point_map_collision(self, pos, coords):
+ """ Check if a point is inside a polygon within a map,
+ NOTE: all map coords should be passed"""
+
+ @partial(jax.vmap, in_axes=[None, 0, 0])
+ def _vc(pos, vc, vn):
+ c = (((vc[1] > pos[1]) & (vn[1] < pos[1])) |\
+ ((vc[1] < pos[1]) & (vn[1] > pos[1]))) & \
+ (pos[0] < (vn[0] - vc[0]) * (pos[1] - vc[1]) / (vn[1] - vc[1]) + vc[0])
+
+ return c
+
+ coords = jnp.concatenate((coords, coords[:, 0][:, None]), axis=1)
+ return jnp.sum(_vc(pos, coords[:, :-1].reshape(-1, 2), coords[:, 1:].reshape(-1, 2))) % 2 != 0
+
+ ### === UTILS === ###
+ def gen_edges(self, coords):
+ """ Generate edges from coordinates and flatten array"""
+ coordsp = jnp.append(coords, coords[:, 0].reshape(-1, 1, 2), axis=1)
+ return jnp.column_stack((coordsp[:, :-1], coordsp[:, 1:])).reshape((-1, 2, 2))
+
+
+ @partial(jax.vmap, in_axes=[None, 0])
+ def _sample_parrallelogram(self, key):
+ """ Sample a parralelogram from a uniform distribution """
+
+ # Sample parralelogram parameters
+ key_c, key_o, key_s, key_a = jax.random.split(key, 4)
+ centre = jax.random.uniform(key_c, (1, 2), minval=0.0, maxval=self.map_size[1])
+ orientation = jax.random.uniform(key_o, (1,), minval=-jnp.pi, maxval=jnp.pi)
+ a, b = jax.random.uniform(key_s, (2,), minval=self.min_edge_length, maxval=self.max_edge_length)
+ alpha = jax.random.uniform(key_a, (1,), minval=0.1, maxval=jnp.pi)
+
+ # Calculate diagonals & angles
+ q = jnp.sqrt(a**2 + b**2 + 2*a*b*jnp.cos(alpha))
+ p = jnp.sqrt(a**2 + b**2 - 2*a*b*jnp.cos(alpha))
+ beta = jnp.arcsin(jnp.sin(alpha)*a/p)
+ gamma = jnp.arcsin(jnp.sin(beta)/(q/2)*(p/2))
+
+ # Calculate coordinates
+ d = - p/2 * jnp.array([jnp.cos(beta), jnp.sin(-beta)]).reshape((1,2))
+ b = + p/2 * jnp.array([jnp.cos(-beta), jnp.sin(-beta)]).reshape((1,2))
+ c = + q/2 * jnp.array([jnp.cos(gamma), jnp.sin(gamma)]).reshape((1,2))
+ a = - q/2 * jnp.array([jnp.cos(gamma), jnp.sin(gamma)]).reshape((1,2))
+
+ return jnp.dot(jnp.array([a, b, c, d]), rot_mat(orientation)).squeeze() + centre
+
+ ### === VISUALISATION === ###
+ def plot_map(self, ax, coord):
+ coord = jnp.append(coord, coord[:, 0].reshape(-1, 1, 2), axis=1)
+ for c in coord:
+ xs, ys = zip(*c)
+ ax.plot(xs, ys, color='black')
+
+
+### === COLLISION DETECTION UTILS === ### NOTE likely a better file for these
+@partial(jax.vmap, in_axes=(0, 0, None, None, None))
+def line_circle_map_collision(p1, p2, c, r, map_size):
+ """ Check collision between a line and a circle subject to a map boundary """
+
+ def _in_map(p1, p2, c, r):
+ return (p1[0] >= 0) & (p1[0] < map_size[0]) & (p1[1] >= 0) & (p1[1] < map_size[1]) & \
+ (p2[0] >= 0) & (p2[0] < map_size[0]) & (p2[1] >= 0) & (p2[1] < map_size[1])
+
+ return jax.lax.cond(_in_map(p1, p2, c, r), lambda _: line_circle_collision(p1, p2, c, r), lambda _: False, None)
+
+
+def line_circle_collision(p1, p2, c, r):
+ """ Check collision between a line and a circle """
+
+ cc = point_circle_collision(p1, c, r) | point_circle_collision(p2, c, r)
+
+ def _on_segement(p1, p2, c, r):
+ d1 = jnp.linalg.norm(p1 - c)
+ d2 = jnp.linalg.norm(p2 - c)
+ dl = jnp.linalg.norm(p1 - p2)
+
+ return (d1+d2 >= dl-r) & (d1+d2 <= dl+r)
+
+ def _line_circle(p2, p1, c, r):
+
+ dx, dy = p2 - p1
+ len = jnp.sqrt(dx**2 + dy**2)
+ dot = ((c[0] - p1[0]) * dx + (c[1] - p1[1]) * dy) / len**2
+ closest = jnp.array([p1[0] + dot * dx, p1[1] + dot * dy])
+
+ # check if closest on line segment
+ on_seg = _on_segement(p1, p2, c, r)
+
+ dist = closest - c
+ return (jnp.sqrt(dist[0]**2 + dist[1]**2) <= r) & (on_seg)
+
+ return jax.lax.cond(cc, lambda _: True, lambda _: _line_circle(p1, p2, c, r), None)
+
+def point_circle_collision(p, c, r):
+ """ Check collision between a point and a circle """
+ dist = p - c
+ return jnp.sqrt(dist[0]**2 + dist[1]**2) <= r
+
+
+if __name__=="__main__":
+ key = jax.random.PRNGKey(10)
+
+ map_size = (10, 10)
+ rad = 0.3
+
+ map_gen = PolygonMap(2, rad, map_size)
+ map_coords = map_gen.sample_map(key)
+
+ print('map coords', map_coords[0])
+
+ pos = jnp.array([5.3, 6.5])
+ #c = map_gen.check_point_collision(pos, map_coords[0])
+ c = map_gen.check_agent_map_collision(pos, map_coords)
+ print('c', c)
+ c = map_gen.check_map_collision(pos, map_coords, None)
+ print('c', c)
+
+ import matplotlib.pyplot as plt
+ # from jax_multirobsim.env.sample_cases.create_sample_cases import jax_sample_case
+
+ # case = jax_sample_case(key, 2, 0.3, map_size, map_coords, map_fn=map_gen.check_map_collision)
+
+ # fig, ax = plt.subplots()
+
+ # map_gen.plot_map(ax, map_coords)
+ # #plot_sample_case()
+ # ax.scatter(pos[0], pos[1], color='red')
+
+ # plt.show()
+
+
+
\ No newline at end of file
diff --git a/jaxmarl/environments/multi_agent_env.py b/jaxmarl/environments/multi_agent_env.py
index a4940171..e7bc19be 100644
--- a/jaxmarl/environments/multi_agent_env.py
+++ b/jaxmarl/environments/multi_agent_env.py
@@ -44,13 +44,19 @@ def step(
key: chex.PRNGKey,
state: State,
actions: Dict[str, chex.Array],
+ reset_state: Optional[State] = None,
) -> Tuple[Dict[str, chex.Array], State, Dict[str, float], Dict[str, bool], Dict]:
- """Performs step transitions in the environment."""
+ """Performs step transitions in the environment. Resets the environment if done.
+ To control the reset state, pass `reset_state`. Otherwise, the environment will reset randomly."""
key, key_reset = jax.random.split(key)
obs_st, states_st, rewards, dones, infos = self.step_env(key, state, actions)
- obs_re, states_re = self.reset(key_reset)
+ if reset_state is None:
+ obs_re, states_re = self.reset(key_reset)
+ else:
+ states_re = reset_state
+ obs_re = self.get_obs(states_re)
# Auto-reset environment based on termination
states = jax.tree_map(
@@ -79,6 +85,11 @@ def action_space(self, agent: str):
"""Action space for a given agent."""
return self.action_spaces[agent]
+ @partial(jax.jit, static_argnums=(0,))
+ def get_avail_actions(self, state: State) -> Dict[str, chex.Array]:
+ """Returns the available actions for each agent."""
+ raise NotImplementedError
+
@property
def name(self) -> str:
"""Environment name."""
diff --git a/jaxmarl/registration.py b/jaxmarl/registration.py
index eeeafd97..c96c6b70 100644
--- a/jaxmarl/registration.py
+++ b/jaxmarl/registration.py
@@ -26,6 +26,7 @@
Hanabi,
Overcooked,
CoinGame,
+ JaxNav,
)
@@ -105,6 +106,10 @@ def make(env_id: str, **env_kwargs):
# 8. Coin Game
elif env_id == "coin_game":
env = CoinGame(**env_kwargs)
+
+ # 9. JaxNav
+ elif env_id == "jaxnav":
+ env = JaxNav(**env_kwargs)
return env
@@ -136,4 +141,5 @@ def make(env_id: str, **env_kwargs):
"hanabi",
"overcooked",
"coin_game",
+ "jaxnav",
]
diff --git a/requirements/requirements.txt b/requirements/requirements.txt
index 9f9c12ae..2e0fd4b8 100644
--- a/requirements/requirements.txt
+++ b/requirements/requirements.txt
@@ -23,3 +23,4 @@ matplotlib>=3.8.3
pillow>=10.2.0
pettingzoo>=1.24.3
tqdm>=4.66.0
+scipy<=1.12
diff --git a/tests/baselines/test_ippo_mabrax.py b/tests/baselines/test_ippo_mabrax.py
new file mode 100644
index 00000000..10624055
--- /dev/null
+++ b/tests/baselines/test_ippo_mabrax.py
@@ -0,0 +1,15 @@
+import subprocess
+import sys
+import os
+
+def run_script(script_path, *args):
+ result = subprocess.run([sys.executable, script_path, *args], capture_output=True, text=True)
+ return result
+
+def test_script_with_arguments():
+ script_path = os.path.join('baselines/IPPO/ippo_ff_mabrax.py')
+ result = run_script(script_path, 'TOTAL_TIMESTEPS=1e4', 'WANDB_MODE=disabled')
+ assert result.returncode == 0
+
+
+test_script_with_arguments()
\ No newline at end of file
diff --git a/tests/baselines/test_mappo_smax.py b/tests/baselines/test_mappo_smax.py
new file mode 100644
index 00000000..19872876
--- /dev/null
+++ b/tests/baselines/test_mappo_smax.py
@@ -0,0 +1,14 @@
+import subprocess
+import sys
+import os
+
+def run_script(script_path, *args):
+ result = subprocess.run([sys.executable, script_path, *args], capture_output=True, text=True)
+ return result
+
+def test_script_with_arguments():
+ script_path = os.path.join('baselines/MAPPO/mappo_rnn_smax.py')
+ result = run_script(script_path, 'TOTAL_TIMESTEPS=1e4', 'WANDB_MODE=disabled')
+ assert result.returncode == 0
+
+test_script_with_arguments()
\ No newline at end of file
diff --git a/tests/baselines/test_qmix_mpe.py b/tests/baselines/test_qmix_mpe.py
new file mode 100644
index 00000000..cb933ab1
--- /dev/null
+++ b/tests/baselines/test_qmix_mpe.py
@@ -0,0 +1,15 @@
+import subprocess
+import sys
+import os
+
+def run_script(script_path, *args):
+ result = subprocess.run([sys.executable, script_path, *args], capture_output=True, text=True)
+ return result
+
+def test_script_with_arguments():
+ script_path = os.path.join('baselines/QLearning/qmix.py')
+ result = run_script(script_path, '+alg=qmix_mpe', '+env=mpe_spread','alg.TOTAL_TIMESTEPS=1e4', 'WANDB_MODE=disabled')
+ assert result.returncode == 0
+
+
+test_script_with_arguments()
\ No newline at end of file
diff --git a/tests/brax/test_brax_rand_acts.py b/tests/brax/test_brax_rand_acts.py
new file mode 100644
index 00000000..43671532
--- /dev/null
+++ b/tests/brax/test_brax_rand_acts.py
@@ -0,0 +1,25 @@
+"""
+Check that the environment can be reset and stepped with random actions.
+TODO: replace this with proper unit tests.
+"""
+import jax
+# import pytest
+
+from jaxmarl.environments.mabrax import MABraxEnv
+
+env = MABraxEnv("ant_4x2")
+
+def test_random_rollout():
+
+
+
+ rng = jax.random.PRNGKey(0)
+ rng, rng_reset = jax.random.split(rng)
+
+ _, state = env.reset(rng_reset)
+
+ for _ in range(10):
+ rng, rng_act = jax.random.split(rng)
+ rng_act = jax.random.split(rng_act, env.num_agents)
+ actions = {a: env.action_space(a).sample(rng_act[i]) for i, a in enumerate(env.agents)}
+ _, state, _, _, _ = env.step(rng, state, actions)
diff --git a/tests/coin_game/test_coin_game_rand_acts.py b/tests/coin_game/test_coin_game_rand_acts.py
new file mode 100644
index 00000000..d64476c6
--- /dev/null
+++ b/tests/coin_game/test_coin_game_rand_acts.py
@@ -0,0 +1,29 @@
+"""
+Check that the environment can be reset and stepped with random actions.
+TODO: replace this with proper unit tests.
+"""
+import jax
+# import pytest
+
+from jaxmarl.environments.coin_game import CoinGame
+
+env = CoinGame()
+
+def test_random_rollout():
+
+
+
+ rng = jax.random.PRNGKey(0)
+ rng, rng_reset = jax.random.split(rng)
+
+ _, state = env.reset(rng_reset)
+
+ for _ in range(10):
+ rng, rng_act = jax.random.split(rng)
+ rng_act = jax.random.split(rng_act, env.num_agents)
+ actions = {a: env.action_space(a).sample(rng_act[i]) for i, a in enumerate(env.agents)}
+ _, state, _, _, _ = env.step(rng, state, actions)
+
+
+
+
\ No newline at end of file
diff --git a/tests/jaxnav/test_jaxnav_gridpolymap.py b/tests/jaxnav/test_jaxnav_gridpolymap.py
new file mode 100644
index 00000000..841fd2ca
--- /dev/null
+++ b/tests/jaxnav/test_jaxnav_gridpolymap.py
@@ -0,0 +1,272 @@
+import jax
+import jax.numpy as jnp
+import matplotlib.pyplot as plt
+import pytest
+
+from jaxmarl.environments.jaxnav.maps.grid_map import GridMapPolygonAgents
+
+@pytest.mark.parametrize(
+ ("num_agents", "pos", "theta", "map_data", "cell_size", "disable_jit", "outcome"),
+ [
+ (
+ 1,
+ jnp.array([[1.5, 3.1]]),
+ jnp.array([jnp.pi/4]),
+ jnp.array([
+ [1, 1, 1, 1, 1],
+ [1, 0, 0, 0, 1],
+ [1, 1, 0, 0, 1],
+ [1, 0, 0, 0, 1],
+ [1, 1, 1, 1, 1]
+ ]),
+ 1.0,
+ False,
+ True,
+ ),
+ (
+ 1,
+ jnp.array([[1.5, 3.1]]),
+ jnp.array([jnp.pi/4]),
+ jnp.array([
+ [1, 1, 1, 1, 1],
+ [1, 0, 0, 0, 1],
+ [1, 1, 0, 0, 1],
+ [1, 0, 0, 0, 1],
+ [1, 1, 1, 1, 1]
+ ]),
+ 1.0,
+ True,
+ True,
+ ),
+ (
+ 1,
+ jnp.array([[3.1, 1.5]]),
+ jnp.array([jnp.pi/4]),
+ jnp.array([
+ [1, 1, 1, 1, 1],
+ [1, 0, 0, 0, 1],
+ [1, 1, 0, 0, 1],
+ [1, 0, 0, 0, 1],
+ [1, 1, 1, 1, 1]
+ ]),
+ 1.0,
+ False,
+ False,
+ ),
+ (
+ 1,
+ jnp.array([[3.1, 1.5]]),
+ jnp.array([jnp.pi/4]),
+ jnp.array([
+ [1, 1, 1, 1, 1],
+ [1, 0, 0, 0, 1],
+ [1, 1, 0, 0, 1],
+ [1, 0, 0, 0, 1],
+ [1, 1, 1, 1, 1]
+ ]),
+ 1.0,
+ True,
+ False,
+ ),
+ (
+ 1,
+ jnp.array([[1.5, 2.5]]),
+ jnp.array([0.0]),
+ jnp.array([
+ [1, 1, 1, 1, 1],
+ [1, 0, 0, 0, 1],
+ [1, 1, 0, 0, 1],
+ [1, 0, 0, 0, 1],
+ [1, 1, 1, 1, 1]
+ ]),
+ 1.0,
+ True,
+ True,
+ ),
+ (
+ 1,
+ jnp.array([[1.5, 2.5]]),
+ jnp.array([0.0]),
+ jnp.array([
+ [1, 1, 1, 1, 1],
+ [1, 0, 0, 0, 1],
+ [1, 1, 0, 0, 1],
+ [1, 0, 0, 0, 1],
+ [1, 1, 1, 1, 1]
+ ]),
+ 1.0,
+ False,
+ True,
+ ),
+ (
+ 2,
+ jnp.array([[3.1, 1.5],
+ [1.5, 3.1]]),
+ jnp.array([jnp.pi/4, 0]),
+ jnp.array([
+ [1, 1, 1, 1, 1],
+ [1, 0, 0, 0, 1],
+ [1, 1, 0, 0, 1],
+ [1, 0, 0, 0, 1],
+ [1, 1, 1, 1, 1]
+ ]),
+ 1.0,
+ True,
+ jnp.array([False, True]),
+ ),
+ (
+ 2,
+ jnp.array([[3.1, 1.5],
+ [1.5, 3.1]]),
+ jnp.array([jnp.pi/4, 0]),
+ jnp.array([
+ [1, 1, 1, 1, 1],
+ [1, 0, 0, 0, 1],
+ [1, 1, 0, 0, 1],
+ [1, 0, 0, 0, 1],
+ [1, 1, 1, 1, 1]
+ ]),
+ 1.0,
+ False,
+ jnp.array([False, True]),
+ ),
+ ]
+)
+def test_square_agent_grid_map_collisions(
+ num_agents,
+ pos,
+ theta,
+ map_data,
+ cell_size,
+ disable_jit: bool,
+ outcome: bool,
+):
+ with jax.disable_jit(disable_jit):
+ map_obj = GridMapPolygonAgents(
+ num_agents=num_agents,
+ rad=0.3,
+ map_size=map_data.shape,
+ cell_size=cell_size,
+ )
+
+ c = jax.vmap(
+ map_obj.check_agent_map_collision,
+ in_axes=(0, 0, None))(
+ pos,
+ theta,
+ map_data,
+ )
+ assert jnp.all(c == outcome)
+
+
+@pytest.mark.parametrize(
+ ("num_agents", "agent_coords", "pos", "theta", "map_data", "cell_size", "disable_jit", "outcome"),
+ [
+ (
+ 1,
+ jnp.array([
+ [-0.25, -0.25],
+ [-0.25, 0.25],
+ [0.25, 0.25],
+ [0.25, -0.25],
+ ]),
+ jnp.array([[1.5, 3.1]]),
+ jnp.array([jnp.pi/4]),
+ jnp.array([
+ [1, 1, 1, 1, 1],
+ [1, 0, 0, 0, 1],
+ [1, 1, 0, 0, 1],
+ [1, 0, 0, 0, 1],
+ [1, 1, 1, 1, 1]
+ ]),
+ 1.0,
+ False,
+ jnp.array([
+ [1, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0],
+ [0, 1, 0, 0, 0],
+ [0, 1, 0, 0, 0],
+ [0, 0, 0, 0, 0]
+ ]),
+ ),
+ ]
+)
+def test_square_agent_grid_map_occupancy_mask(
+ num_agents,
+ agent_coords,
+ pos,
+ theta,
+ map_data,
+ cell_size,
+ disable_jit: bool,
+ outcome: jnp.ndarray,
+):
+ with jax.disable_jit(disable_jit):
+ map_obj = GridMapPolygonAgents(
+ num_agents=num_agents,
+ rad=0.3,
+ map_size=map_data.shape,
+ cell_size=cell_size,
+ agent_coords=agent_coords,
+ )
+
+ c = jax.vmap(
+ map_obj.get_agent_map_occupancy_mask,
+ in_axes=(0, 0, None))(
+ pos,
+ theta,
+ map_data,
+ )
+ assert jnp.all(c == outcome)
+
+if __name__=="__main__":
+
+ rng = jax.random.PRNGKey(0)
+
+ num_agents = 1
+ rad = 0.3
+ map_params = {
+ "map_size": (10, 10),
+ "fill": 0.4
+ }
+ pos = jnp.array([[1.5, 3.1]])
+ theta = jnp.array([-jnp.pi/4])
+ goal = jnp.array([[9.5, 9.5]])
+ done = jnp.array([False])
+
+ map_obj = GridMapPolygonAgents(
+ num_agents=num_agents,
+ rad=rad,
+ grid_size=1.0,
+ **map_params
+ )
+
+ map_data = map_obj.sample_map(rng)
+ print('map_data: ', map_data)
+
+ c = map_obj.check_agent_map_collision(
+ pos,
+ theta,
+ map_data,
+ )
+ print('c', c)
+
+ with jax.disable_jit(False):
+ c = map_obj.get_agent_map_occupancy_mask(
+ pos,
+ theta,
+ map_data
+ )
+ print('c', c)
+
+ plt, ax = plt.subplots()
+
+ map_obj.plot_map(ax, map_data)
+ map_obj.plot_agents(ax,
+ pos,
+ theta,
+ goal,
+ done=done,
+ plot_line_to_goal=False)
+
+ plt.savefig('test_map.png')
\ No newline at end of file
diff --git a/tests/jaxnav/test_jaxnav_rand_acts.py b/tests/jaxnav/test_jaxnav_rand_acts.py
new file mode 100644
index 00000000..ce8f350b
--- /dev/null
+++ b/tests/jaxnav/test_jaxnav_rand_acts.py
@@ -0,0 +1,27 @@
+"""
+Check that the environment can be reset and stepped with random actions.
+TODO: replace this with proper unit tests.
+"""
+import jax
+# import pytest
+
+from jaxmarl.environments.jaxnav import JaxNav
+
+env = JaxNav(4)
+
+def test_random_rollout():
+
+ rng = jax.random.PRNGKey(0)
+ rng, rng_reset = jax.random.split(rng)
+
+ _, state = env.reset(rng_reset)
+
+ for _ in range(10):
+ rng, rng_act = jax.random.split(rng)
+ rng_act = jax.random.split(rng_act, env.num_agents)
+ actions = {a: env.action_space(a).sample(rng_act[i]) for i, a in enumerate(env.agents)}
+ _, state, _, _, _ = env.step(rng, state, actions)
+
+test_random_rollout()
+
+
\ No newline at end of file
diff --git a/tests/mpe/mpe_policy_transfer.py b/tests/mpe/mpe_policy_transfer.py
index 31176278..6572f69c 100644
--- a/tests/mpe/mpe_policy_transfer.py
+++ b/tests/mpe/mpe_policy_transfer.py
@@ -14,7 +14,7 @@
from jax import numpy as jnp
from jaxmarl import make
from baselines.QLearning.utils import load_params, get_space_dim
-from jaxmarl.baselines.QLearning.iql import AgentRNN, ScannedRNN
+from baselines.QLearning.iql import AgentRNN, ScannedRNN
from pettingzoo.mpe import simple_speaker_listener_v4, simple_spread_v3, simple_adversary_v3
import tqdm
diff --git a/tests/overcooked/test_overcooked_rand_acts.py b/tests/overcooked/test_overcooked_rand_acts.py
new file mode 100644
index 00000000..60514111
--- /dev/null
+++ b/tests/overcooked/test_overcooked_rand_acts.py
@@ -0,0 +1,29 @@
+"""
+Check that the environment can be reset and stepped with random actions.
+TODO: replace this with proper unit tests.
+"""
+import jax
+# import pytest
+
+from jaxmarl.environments.overcooked import Overcooked
+
+env = Overcooked()
+
+def test_random_rollout():
+
+
+
+ rng = jax.random.PRNGKey(0)
+ rng, rng_reset = jax.random.split(rng)
+
+ _, state = env.reset(rng_reset)
+
+ for _ in range(10):
+ rng, rng_act = jax.random.split(rng)
+ rng_act = jax.random.split(rng_act, env.num_agents)
+ actions = {a: env.action_space(a).sample(rng_act[i]) for i, a in enumerate(env.agents)}
+ _, state, _, _, _ = env.step(rng, state, actions)
+
+
+
+
\ No newline at end of file
diff --git a/tests/storm/test_storm_rand_acts.py b/tests/storm/test_storm_rand_acts.py
new file mode 100644
index 00000000..65e3a402
--- /dev/null
+++ b/tests/storm/test_storm_rand_acts.py
@@ -0,0 +1,29 @@
+"""
+Check that the environment can be reset and stepped with random actions.
+TODO: replace this with proper unit tests.
+"""
+import jax
+# import pytest
+
+from jaxmarl.environments.storm import InTheGrid
+
+env = InTheGrid()
+
+def test_random_rollout():
+
+
+
+ rng = jax.random.PRNGKey(0)
+ rng, rng_reset = jax.random.split(rng)
+
+ _, state = env.reset(rng_reset)
+
+ for _ in range(10):
+ rng, rng_act = jax.random.split(rng)
+ rng_act = jax.random.split(rng_act, env.num_agents)
+ actions = {a: env.action_space(a).sample(rng_act[i]) for i, a in enumerate(env.agents)}
+ _, state, _, _, _ = env.step(rng, state, actions)
+
+
+
+
\ No newline at end of file
diff --git a/tests/test_jaxmarl_api.py b/tests/test_jaxmarl_api.py
new file mode 100644
index 00000000..c19d8105
--- /dev/null
+++ b/tests/test_jaxmarl_api.py
@@ -0,0 +1,38 @@
+"""
+Test auto reseting works as expected
+"""
+import jax
+import jax.numpy as jnp
+from jaxmarl import make
+
+def test_auto_reset_to_specific_state():
+
+ def _test_leaf(x, y, outcome=True):
+ x = jnp.array_equal(x, y)
+ assert x==outcome
+
+ env = make("MPE_simple_spread_v3")
+
+ rng = jax.random.PRNGKey(0)
+ rng, rng_reset1, rng_reset2 = jax.random.split(rng, 3)
+
+ _, state1 = env.reset(rng_reset1)
+ _, state2 = env.reset(rng_reset2)
+ # normal step
+ rng, rng_act = jax.random.split(rng)
+ rng_act = jax.random.split(rng_act, env.num_agents)
+ actions = {a: env.action_space(a).sample(rng_act[i]) for i, a in enumerate(env.agents)}
+ _, next_state, _, dones, _ = env.step(rng, state1, actions, reset_state=state2)
+ assert not dones["__all__"]
+ assert not jnp.array_equal(state2.p_pos, next_state.p_pos)
+
+ # auto reset to specific state
+ state1 = state1.replace(
+ step = env.max_steps,
+ )
+ rng, rng_act = jax.random.split(rng)
+ rng_act = jax.random.split(rng_act, env.num_agents)
+ actions = {a: env.action_space(a).sample(rng_act[i]) for i, a in enumerate(env.agents)}
+ _, next_state, _, dones, _ = env.step(rng, state1, actions, reset_state=state2)
+ assert dones["__all__"]
+ jax.tree_map(_test_leaf, state2, next_state)