Skip to content

Commit

Permalink
Merge pull request #6 from instadeepai/fix/reset-step-reward-shape
Browse files Browse the repository at this point in the history
Fix: Keep reward shape and dtype the same when resetting and stepping
  • Loading branch information
arnupretorius authored Jan 16, 2024
2 parents 5da1342 + 6ca1c9a commit 4c5d8aa
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 11 deletions.
14 changes: 8 additions & 6 deletions matrax/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]:
agent_obs=agent_obs,
step_count=state.step_count,
)
timestep = restart(observation=observation)
timestep = restart(observation=observation, shape=self.num_agents)
return state, timestep

def step(
Expand All @@ -122,7 +122,7 @@ def compute_reward(
actions: chex.Array, payoff_matrix_per_agent: chex.Array
) -> chex.Array:
reward_idx = tuple(actions)
return payoff_matrix_per_agent[reward_idx]
return payoff_matrix_per_agent[reward_idx].astype(float)

rewards = jax.vmap(functools.partial(compute_reward, actions))(
self.payoff_matrix
Expand All @@ -143,10 +143,12 @@ def compute_reward(

timestep = jax.lax.cond(
done,
termination,
transition,
rewards,
next_observation,
lambda: termination(
reward=rewards, observation=next_observation, shape=self.num_agents
),
lambda: transition(
reward=rewards, observation=next_observation, shape=self.num_agents
),
)

# create environment state
Expand Down
20 changes: 15 additions & 5 deletions matrax/env_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def test_matrix_game__reset(matrix_game_env: MatrixGame) -> None:

key1, key2 = random.PRNGKey(0), random.PRNGKey(1)
state1, timestep1 = reset_fn(key1)
state2, timestep2 = reset_fn(key2)
state2, _ = reset_fn(key2)

assert isinstance(timestep1, TimeStep)
assert isinstance(state1, State)
Expand Down Expand Up @@ -111,10 +111,10 @@ def test_matrix_game__step(matrix_game_env_with_state: MatrixGame) -> None:

# Check that rewards have the correct number of dimensions
assert jnp.ndim(timestep1.reward) == 1
assert jnp.ndim(timestep.reward) == 0
assert jnp.ndim(timestep.reward) == 1
# Check that discounts have the correct number of dimensions
assert jnp.ndim(timestep1.discount) == 0
assert jnp.ndim(timestep.discount) == 0
assert jnp.ndim(timestep1.discount) == 1
assert jnp.ndim(timestep.discount) == 1
# Check that the state is made of DeviceArrays, this is false for the non-jitted
# step function since unpacking random.split returns numpy arrays and not device arrays.
assert_is_jax_array_tree(new_state1)
Expand Down Expand Up @@ -157,7 +157,6 @@ def test_matrix_game__reward(matrix_game_env: MatrixGame) -> None:
state, timestep = matrix_game_env.reset(state_key)

state, timestep = step_fn(state, jnp.array([0, 0]))
jax.debug.print("rewards: {r}", r=timestep.reward)
assert jnp.array_equal(timestep.reward, jnp.array([11, 11]))

state, timestep = step_fn(state, jnp.array([1, 0]))
Expand All @@ -174,3 +173,14 @@ def test_matrix_game__reward(matrix_game_env: MatrixGame) -> None:

state, timestep = step_fn(state, jnp.array([2, 2]))
assert jnp.array_equal(timestep.reward, jnp.array([5, 5]))


def test_matrix_game__timesteps_equivalent(matrix_game_env: MatrixGame) -> None:
    """Check that reset and step emit structurally matching timesteps.

    The timestep returned by `reset` and the one returned by a jitted `step`
    are compared leaf by leaf: every leaf must have the same shape and dtype.
    """
    # Reset is called un-jitted; only the step function is compiled.
    initial_state, reset_timestep = matrix_game_env.reset(random.PRNGKey(10))

    jitted_step = jax.jit(matrix_game_env.step)
    joint_action = jnp.array([0, 0])
    _, stepped_timestep = jitted_step(initial_state, joint_action)

    chex.assert_trees_all_equal_shapes_and_dtypes(reset_timestep, stepped_timestep)

0 comments on commit 4c5d8aa

Please sign in to comment.