Correct a few issues for running externally, and push TF directory.
PiperOrigin-RevId: 691010806
psc-g committed Oct 29, 2024
1 parent a804d3c commit 02a3547
Showing 13 changed files with 72 additions and 69 deletions.
9 changes: 4 additions & 5 deletions README.md
@@ -55,11 +55,10 @@ environments you intend to use before you install Dopamine:

**Atari**

-1. Install the atari roms following the instructions from
-[atari-py](https://github.com/openai/atari-py#roms).
-2. `pip install ale-py` (we recommend using a [virtual environment](virtualenv)):
-3. `unzip $ROM_DIR/ROMS.zip -d $ROM_DIR && ale-import-roms $ROM_DIR/ROMS`
-(replace $ROM_DIR with the directory you extracted the ROMs to).
+1. These should now come packaged with
+[ale_py](https://github.com/Farama-Foundation/Arcade-Learning-Environment).
+1. You may need to manually run some steps to properly install `baselines`, see
+[instructions](https://github.com/openai/baselines).

**Mujoco**

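For anyone migrating an existing setup, here is a minimal sketch of the ROM-free flow the new instructions describe, assuming `ale-py >= 0.10` and `gymnasium` are installed; the game choice is illustrative, and the registration call is the same one `atari_lib.py` uses below:

```python
# ale_py now ships the Atari ROMs, so the old unzip/ale-import-roms steps
# are no longer needed.
import ale_py
import gymnasium as gym

gym.register_envs(ale_py)      # expose the ALE/* ids to gymnasium
env = gym.make('ALE/Pong-v5')  # illustrative game choice
observation, info = env.reset()
```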
1 change: 0 additions & 1 deletion dopamine/discrete_domains/atari_lib.py
@@ -108,7 +108,6 @@ def create_atari_environment(
full_game_name = f'{game_name}NoFrameskip-{game_version}'
env = legacy_gym.make(full_game_name)
else:
-gym.register_envs(ale_py)
gym.register_envs(ale_py)
full_game_name = f'ALE/{game_name}-v5'
repeat_action_probability = 0.25 if sticky_actions else 0.0
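The hunk above drops a duplicated `gym.register_envs(ale_py)` call; the remaining one is all that is needed. For context, a hedged sketch of the two environment-naming schemes the function switches between (example values, not from the diff):

```python
game_name, game_version, sticky_actions = 'Pong', 'v0', True  # example values

legacy_name = f'{game_name}NoFrameskip-{game_version}'  # 'PongNoFrameskip-v0'
ale_name = f'ALE/{game_name}-v5'                        # 'ALE/Pong-v5'

# In the ALE path, stickiness becomes a constructor probability rather than
# part of the environment id:
repeat_action_probability = 0.25 if sticky_actions else 0.0
```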
15 changes: 0 additions & 15 deletions dopamine/discrete_domains/run_experiment.py
@@ -33,9 +33,6 @@
from dopamine.jax.agents.ppo import ppo_agent
from dopamine.jax.agents.quantile import quantile_agent as jax_quantile_agent
from dopamine.jax.agents.rainbow import rainbow_agent as jax_rainbow_agent
-from dopamine.labs.moes.agents import dqn_moe_agent
-from dopamine.labs.moes.agents import full_rainbow_moe_agent
-from dopamine.labs.moes.agents import rainbow_100k_moe_agent
from dopamine.metrics import collector_dispatcher
from dopamine.metrics import statistics_instance
from dopamine.tf.agents.dqn import dqn_agent
@@ -124,18 +121,6 @@ def create_agent(
return jax_implicit_quantile_agent.JaxImplicitQuantileAgent(
num_actions=environment.action_space.n, summary_writer=summary_writer
)
-elif agent_name == 'moe_dqn':
-return dqn_moe_agent.DQNMoEAgent(
-num_actions=environment.action_space.n, summary_writer=summary_writer
-)
-elif agent_name == 'moe_full_rainbow':
-return full_rainbow_moe_agent.JaxFullRainbowMoEAgent(
-num_actions=environment.action_space.n, summary_writer=summary_writer
-)
-elif agent_name == 'moe_der':
-return rainbow_100k_moe_agent.Atari100kRainbowMoEAgent(
-num_actions=environment.action_space.n, summary_writer=summary_writer
-)
elif agent_name == 'ppo':
return ppo_agent.PPOAgent(
action_shape=environment.action_space.n,
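With the `dopamine.labs.moes` imports gone, `create_agent` now dispatches only on the built-in agent names. A hypothetical call for illustration; the full signature is truncated in this diff, so the keyword arguments are assumptions:

```python
# Illustrative only: argument names are guesses based on the visible body.
agent = create_agent(
    environment=env,
    agent_name='ppo',  # 'moe_dqn', 'moe_full_rainbow', 'moe_der' no longer recognized
    summary_writer=None,
)
```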
2 changes: 1 addition & 1 deletion dopamine/jax/replay_memory/accumulator.py
@@ -67,7 +67,7 @@ def __init__(
maxlen=self._update_horizon + self._stack_size
)

-def _make_replay_element(self) -> elements.ReplayElement | None:
+def _make_replay_element(self) -> 'elements.ReplayElement | None':
trajectory_len = len(self._trajectory)

last_transition = self._trajectory[-1]
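The quoting above, repeated throughout the hunks below, looks like the heart of the "running externally" fix: an unquoted `X | Y` annotation is evaluated at import time and needs Python 3.10 (PEP 604), while a quoted one is a forward reference that stays a plain string unless type hints are resolved. A minimal illustration with placeholder types:

```python
# On Python 3.9 the unquoted form fails at import time with
# "TypeError: unsupported operand type(s) for |":
#
#     def make() -> int | None: ...
#
# The quoted form is just a string at runtime, so the module imports cleanly.
def make() -> 'int | None':
  return None

# On 3.10+, typing.get_type_hints(make) would evaluate the string if needed.
```

A `from __future__ import annotations` (PEP 563) per module would achieve the same for annotations, though not for the module-level alias fixed in `types.py` further down.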
8 changes: 4 additions & 4 deletions dopamine/jax/replay_memory/elements.py
@@ -98,11 +98,11 @@ class ReplayElement(ReplayElementProtocol, struct.PyTreeNode):
"""A single replay transition element supporting compression."""

state: npt.NDArray[np.float64]
-action: npt.NDArray[np.int_] | npt.NDArray[np.float64] | int
-reward: npt.NDArray[np.float64] | float
+action: 'npt.NDArray[np.int_] | npt.NDArray[np.float64] | int'
+reward: 'npt.NDArray[np.float64] | float'
next_state: npt.NDArray[np.float64]
-is_terminal: npt.NDArray[np.bool_] | bool
-episode_end: npt.NDArray[np.bool_] | bool
+is_terminal: 'npt.NDArray[np.bool_] | bool'
+episode_end: 'npt.NDArray[np.bool_] | bool'

def pack(self) -> 'ReplayElement':
# NOTE: pytype has a problem subclassing generics.
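`ReplayElement` is the pytree node those quoted fields belong to. A hypothetical construction and `pack()` call for context; `pack()` appears in the diff, but the shapes and dtypes here are assumptions:

```python
import numpy as np

element = ReplayElement(
    state=np.zeros((84, 84, 4), dtype=np.float64),  # assumed Atari-style shape
    action=0,
    reward=1.0,
    next_state=np.zeros((84, 84, 4), dtype=np.float64),
    is_terminal=False,
    episode_end=False,
)
packed = element.pack()  # compresses the array fields for storage
```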
8 changes: 4 additions & 4 deletions dopamine/jax/replay_memory/replay_buffer.py
@@ -108,7 +108,7 @@ def add(self, transition: elements.TransitionElement, **kwargs: Any) -> None:
@typing.overload
def sample(
self,
-size: int | None = None,
+size: 'int | None' = None,
*,
with_sample_metadata: Literal[False] = False,
) -> ReplayElementT:
@@ -125,10 +125,10 @@ def sample(

def sample(
self,
-size: int | None = None,
+size: 'int | None' = None,
*,
with_sample_metadata: bool = False,
-) -> ReplayElementT | tuple[ReplayElementT, samplers.SampleMetadata]:
+) -> 'ReplayElementT | tuple[ReplayElementT, samplers.SampleMetadata]':
"""Sample a batch of elements from the replay buffer."""
if self.add_count < 1:
raise ValueError('No samples in replay buffer!')
@@ -161,7 +161,7 @@ def update(self, keys: ReplayItemID, **kwargs: Any) -> None:

def update(
self,
-keys: npt.NDArray[ReplayItemID] | ReplayItemID,
+keys: 'npt.NDArray[ReplayItemID] | ReplayItemID',
**kwargs: Any,
) -> None:
self._sampling_distribution.update(keys, **kwargs)
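The `@typing.overload` stubs exist so type checkers can narrow `sample`'s return type on the `with_sample_metadata` flag. A self-contained sketch of the same pattern, with `str` and `dict` standing in for `ReplayElementT` and `SampleMetadata`:

```python
import typing
from typing import Literal


@typing.overload
def sample(size: 'int | None' = None, *,
           with_sample_metadata: Literal[False] = False) -> str: ...


@typing.overload
def sample(size: 'int | None' = None, *,
           with_sample_metadata: Literal[True]) -> 'tuple[str, dict]': ...


def sample(size=None, *, with_sample_metadata=False):
  batch = 'batch'  # toy payload
  return (batch, {'keys': []}) if with_sample_metadata else batch
```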
18 changes: 9 additions & 9 deletions dopamine/jax/replay_memory/samplers.py
@@ -29,7 +29,7 @@
ReplayItemID = elements.ReplayItemID


-@dataclasses.dataclass(frozen=True, kw_only=True)
+@dataclasses.dataclass(frozen=True)
class SampleMetadata:
keys: npt.NDArray[ReplayItemID]

@@ -50,7 +50,7 @@ def update(self, key: ReplayItemID, **kwargs: Any) -> None:
...

def update(
-self, keys: npt.NDArray[ReplayItemID] | ReplayItemID, **kwargs: Any
+self, keys: 'npt.NDArray[ReplayItemID] | ReplayItemID', **kwargs: Any
) -> None:
...

@@ -68,7 +68,7 @@ class UniformSamplingDistribution(SamplingDistribution):
"""A uniform sampling distribution."""

def __init__(
-self, seed: np.random.Generator | np.random.SeedSequence | int | None
+self, seed: 'np.random.Generator | np.random.SeedSequence | int | None'
) -> None:
# RNG generator
self._rng = np.random.default_rng(seed)
@@ -123,7 +123,7 @@ def update(self, key: ReplayItemID, *args: Any, **kwargs: Any) -> None:

def update(
self,
-keys: npt.NDArray[ReplayItemID] | ReplayItemID,
+keys: 'npt.NDArray[ReplayItemID] | ReplayItemID',
*args: Any,
**kwargs: Any,
) -> None:
@@ -180,7 +180,7 @@ def from_state_dict(self, state_dict: dict[str, Any]) -> None:
self._rng.bit_generator.state = state_dict['rng_state']


-@dataclasses.dataclass(frozen=True, kw_only=True)
+@dataclasses.dataclass(frozen=True)
class PrioritizedSampleMetadata(SampleMetadata):
probabilities: npt.NDArray[np.float64]

@@ -191,7 +191,7 @@ class PrioritizedSamplingDistribution(UniformSamplingDistribution):

def __init__(
self,
-seed: np.random.SeedSequence | int | None,
+seed: 'np.random.SeedSequence | int | None',
*,
priority_exponent: float = 1.0,
max_capacity: int,
@@ -227,9 +227,9 @@ def update(self, keys: ReplayItemID, *, priorities: float) -> None:

def update(
self,
-keys: npt.NDArray[ReplayItemID] | ReplayItemID,
+keys: 'npt.NDArray[ReplayItemID] | ReplayItemID',
*,
-priorities: npt.NDArray[np.float64] | float,
+priorities: 'npt.NDArray[np.float64] | float',
) -> None:
if not isinstance(keys, np.ndarray):
keys = np.asarray([keys], dtype=np.int32)
@@ -297,7 +297,7 @@ class SequentialSamplingDistribution(UniformSamplingDistribution):

def __init__(
self,
-seed: np.random.Generator | np.random.SeedSequence | int,
+seed: 'np.random.Generator | np.random.SeedSequence | int',
sort_samples: bool = True,
):
super().__init__(seed)
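Alongside the quoting, note that both metadata dataclasses dropped `kw_only=True`: that flag was added to `dataclasses` in Python 3.10, so removing it is another pre-3.10 compatibility fix. The observable difference is that positional construction is accepted again:

```python
import dataclasses

import numpy as np


@dataclasses.dataclass(frozen=True)
class SampleMetadata:
  keys: np.ndarray


SampleMetadata(keys=np.array([0, 1]))  # worked before and still works
SampleMetadata(np.array([0, 1]))       # rejected under kw_only=True, accepted now
```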
12 changes: 6 additions & 6 deletions dopamine/jax/replay_memory/sum_tree.py
@@ -46,8 +46,8 @@ def set(

def set(
self,
-indices: npt.NDArray[np.int_] | int,
-values: npt.NDArray[np.float64] | float,
+indices: 'npt.NDArray[np.int_] | int',
+values: 'npt.NDArray[np.float64] | float',
) -> None:
"""Set the value at a given leaf node index."""
if isinstance(indices, (int, np.integer)):
@@ -86,8 +86,8 @@ def get(self, index: npt.NDArray[np.int_]) -> npt.NDArray[np.float64]:
...

def get(
-self, index: npt.NDArray[np.int_] | int
-) -> npt.NDArray[np.float64] | float:
+self, index: 'npt.NDArray[np.int_] | int'
+) -> 'npt.NDArray[np.float64] | float':
"""Get the value at a given leaf node index."""
return self._nodes[self._first_leaf_offset + index]

@@ -105,8 +105,8 @@ def query(self, target: npt.NDArray[np.float64]) -> npt.NDArray[np.int_]:
...

def query(
-self, targets: npt.NDArray[np.float64] | float
-) -> npt.NDArray[np.int_] | int:
+self, targets: 'npt.NDArray[np.float64] | float'
+) -> 'npt.NDArray[np.int_] | int':
"""Find the smallest index where target < cumulative value up to index.
This functions like the CDF for a multi-nomial distribution allowing us
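The truncated docstring is describing inverse-CDF sampling: `query(t)` returns the smallest leaf index whose cumulative value exceeds `t`. A worked example using a plain prefix sum as a stand-in for the tree, which computes the same answer in O(log n):

```python
import numpy as np

priorities = np.array([3.0, 1.0, 4.0, 2.0])  # leaf values
cumulative = np.cumsum(priorities)           # [3., 4., 8., 10.]


def query(target: float) -> int:
  # Smallest index i such that target < cumulative[i].
  return int(np.searchsorted(cumulative, target, side='right'))


assert query(2.9) == 0  # the mass [0, 3) belongs to leaf 0
assert query(3.0) == 1  # boundary case: 3.0 < cumulative[1] == 4.0
assert query(7.5) == 2
# Drawing target ~ Uniform[0, 10) selects leaf i with probability
# priorities[i] / priorities.sum(), i.e. multinomial sampling via the CDF.
```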
2 changes: 1 addition & 1 deletion dopamine/labs/moes/architectures/networks.py
@@ -155,7 +155,7 @@ def _maybe_create_moe_module(
rng_key: jax.Array,
expert_type: str = 'SMALL',
encoder_type: str = 'IMPALA',
-) -> nn.Module | None:
+) -> 'nn.Module | None':
"""Try to create an MoE module, or return None."""
del routing_type
moe_type = MoEType[moe_type]
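`_maybe_create_moe_module` now advertises that it may return `None`, so callers need an explicit fallback. A hedged sketch; the `moe_type` and `routing_type` arguments are inferred from the visible body, and `'SOFTMOE'` is a hypothetical `MoEType` member:

```python
import jax

maybe_moe = _maybe_create_moe_module(  # argument names are assumptions
    moe_type='SOFTMOE',
    routing_type='TOP_K',  # unused by the body (`del routing_type`)
    rng_key=jax.random.PRNGKey(0),
)
if maybe_moe is None:
  ...  # fall back to a plain dense torso instead of an MoE block
```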
30 changes: 15 additions & 15 deletions dopamine/labs/moes/architectures/routers.py
@@ -26,26 +26,26 @@
class RandomRouter(nn.Module):
"""Route tokens randomly."""

-num_experts: int | None = None
+num_experts: 'int | None' = None
k: int = 1

def setup(self):
logging.info("Creating a %s", self.__class__.__name__)
logging.info('Creating a %s', self.__class__.__name__)

@nn.compact
def __call__(
self,
x: jax.Array,
*,
-num_experts: int | None = None,
-k: int | None = None,
-route_tokens: int | None = None,
-key: jax.Array | None = None
+num_experts: 'int | None' = None,
+k: 'int | None' = None,
+route_tokens: 'int | None' = None,
+key: 'jax.Array | None' = None
) -> types.RouterReturn:
chex.assert_rank(x, 2)

num_experts = nn.merge_param("num_experts", num_experts, self.num_experts)
k = nn.merge_param("k", k, self.k)
num_experts = nn.merge_param('num_experts', num_experts, self.num_experts)
k = nn.merge_param('k', k, self.k)
sequence_length = x.shape[0]

# probs are set randomly.
@@ -65,27 +65,27 @@ class TopKRouter(nn.Module):
"""A simple router that linearly projects assignments."""

k: int
-num_experts: int | None = None
+num_experts: 'int | None' = None
noisy_routing: bool = False
noise_std: float = 1.0

def setup(self):
logging.info("Creating a %s", self.__class__.__name__)
logging.info('Creating a %s', self.__class__.__name__)

@nn.compact
def __call__(
self,
x: jax.Array,
*,
-num_experts: int | None = None,
-k: int | None = None,
-key: jax.Array | None = None,
+num_experts: 'int | None' = None,
+k: 'int | None' = None,
+key: 'jax.Array | None' = None,
**kwargs
) -> types.RouterReturn:
chex.assert_rank(x, 2)

-num_experts = nn.merge_param("num_experts", num_experts, self.num_experts)
-k = nn.merge_param("k", k, self.k)
+num_experts = nn.merge_param('num_experts', num_experts, self.num_experts)
+k = nn.merge_param('k', k, self.k)
sequence_length = x.shape[0]

x = nn.Dense(num_experts, use_bias=False)(x)
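Both routers rely on flax's `nn.merge_param`, which lets a value arrive either as a module attribute or as a call argument, with exactly one source required. A runnable miniature of the pattern, assuming flax and jax are installed:

```python
import flax.linen as nn
import jax
import jax.numpy as jnp


class Demo(nn.Module):
  num_experts: 'int | None' = None

  @nn.compact
  def __call__(self, x, num_experts: 'int | None' = None):
    # Take whichever source was provided, mirroring the routers above.
    num_experts = nn.merge_param('num_experts', num_experts, self.num_experts)
    return nn.Dense(num_experts, use_bias=False)(x)


model = Demo(num_experts=4)
params = model.init(jax.random.PRNGKey(0), jnp.ones((2, 8)))
logits = model.apply(params, jnp.ones((2, 8)))  # shape (2, 4)
```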
16 changes: 9 additions & 7 deletions dopamine/labs/moes/architectures/types.py
@@ -48,7 +48,7 @@ def router_unflatten(aux_data, children):
class MoEModuleReturn:
values: jax.Array
router_out: RouterReturn
-experts_hidden: jax.Array | None = None
+experts_hidden: 'jax.Array | None' = None


def module_flatten(v):
@@ -73,9 +73,9 @@ def module_unflatten(aux_data, children):
class MoENetworkReturn:
q_values: jax.Array
moe_out: MoEModuleReturn
-logits: jax.Array | None = None
-probabilities: jax.Array | None = None
-hidden_act: jax.Array | None = None
+logits: 'jax.Array | None' = None
+probabilities: 'jax.Array | None' = None
+hidden_act: 'jax.Array | None' = None


def network_flatten(v):
@@ -100,8 +100,8 @@ def network_unflatten(aux_data, children):
class BaselineNetworkReturn:
q_values: jax.Array
hidden_act: jax.Array
-logits: jax.Array | None = None
-probabilities: jax.Array | None = None
+logits: 'jax.Array | None' = None
+probabilities: 'jax.Array | None' = None


def baseline_network_flatten(v):
@@ -122,4 +122,6 @@ def baseline_network_unflatten(aux_data, children):
)


-NetworkReturn = MoENetworkReturn | BaselineNetworkReturn
+# pylint: disable=invalid-name
+NetworkReturn = 'MoENetworkReturn | BaselineNetworkReturn'
+# pylint: enable=invalid-name
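The alias needs the string trick too, because it is a runtime assignment rather than an annotation: evaluating `MoENetworkReturn | BaselineNetworkReturn` at import time would fail before Python 3.10, while the quoted version stays a plain string that type checkers generally treat as a forward-referenced union alias. A sketch with placeholder classes:

```python
class MoEOut:  # placeholder for MoENetworkReturn
  pass


class BaselineOut:  # placeholder for BaselineNetworkReturn
  pass


# Runtime: an ordinary string. Type checkers: an alias for the union.
Result = 'MoEOut | BaselineOut'


def consume(out: Result) -> None:  # the annotation holds the string at runtime
  ...
```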
14 changes: 14 additions & 0 deletions dopamine/tf/__init__.py
@@ -0,0 +1,14 @@
# coding=utf-8
# Copyright 2024 The Dopamine Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
6 changes: 5 additions & 1 deletion setup.py
@@ -31,17 +31,21 @@
'tensorflow >= 2.2.0',
'gin-config >= 0.3.0',
'absl-py >= 0.9.0',
'ale_py >= 0.10.1',
'opencv-python >= 3.4.8.29',
'gym <= 0.25.2',
'gymnasium >= 1.0.0',
'flax >= 0.2.0',
'jax >= 0.1.72',
'jaxlib >= 0.1.51',
'Pillow >= 7.0.0',
'numpy >= 1.16.4',
'pygame >= 1.9.2',
'pandas >= 0.24.2',
'python-snappy >= 0.7.3',
'tf_slim >= 1.0',
'tensorflow-probability >= 0.13.0',
'tf-keras >= 2.18.0',
'tqdm >= 4.64.1',
]

@@ -51,7 +55,7 @@

setup(
name='dopamine_rl',
-version='4.1.0',
+version='4.1.1',
description=dopamine_description,
long_description=long_description,
long_description_content_type='text/markdown',