From 02a3547d9fb4f6977e491c3bed1c82859a95b3e3 Mon Sep 17 00:00:00 2001
From: Pablo Samuel Castro
Date: Tue, 29 Oct 2024 14:27:31 +0000
Subject: [PATCH] Correct a few issues for running externally, and push TF directory.

PiperOrigin-RevId: 691010806
---
 README.md                                    |  9 +++---
 dopamine/discrete_domains/atari_lib.py       |  1 -
 dopamine/discrete_domains/run_experiment.py  | 15 ----------
 dopamine/jax/replay_memory/accumulator.py    |  2 +-
 dopamine/jax/replay_memory/elements.py       |  8 +++---
 dopamine/jax/replay_memory/replay_buffer.py  |  8 +++---
 dopamine/jax/replay_memory/samplers.py       | 18 ++++++------
 dopamine/jax/replay_memory/sum_tree.py       | 12 ++++----
 dopamine/labs/moes/architectures/networks.py |  2 +-
 dopamine/labs/moes/architectures/routers.py  | 30 ++++++++++----------
 dopamine/labs/moes/architectures/types.py    | 16 ++++++-----
 dopamine/tf/__init__.py                      | 14 +++++++++
 setup.py                                     |  6 +++-
 13 files changed, 72 insertions(+), 69 deletions(-)
 create mode 100644 dopamine/tf/__init__.py

diff --git a/README.md b/README.md
index 4548b7b5..581a9714 100644
--- a/README.md
+++ b/README.md
@@ -55,11 +55,10 @@ environments you intend to use before you install Dopamine:
 
 **Atari**
 
-1. Install the atari roms following the instructions from
-[atari-py](https://github.com/openai/atari-py#roms).
-2. `pip install ale-py` (we recommend using a [virtual environment](virtualenv)):
-3. `unzip $ROM_DIR/ROMS.zip -d $ROM_DIR && ale-import-roms $ROM_DIR/ROMS`
-(replace $ROM_DIR with the directory you extracted the ROMs to).
+1. These should now come packaged with
+   [ale_py](https://github.com/Farama-Foundation/Arcade-Learning-Environment).
+1. You may need to manually run some steps to properly install `baselines`, see
+   [instructions](https://github.com/openai/baselines).
 
 **Mujoco**
 
diff --git a/dopamine/discrete_domains/atari_lib.py b/dopamine/discrete_domains/atari_lib.py
index f9287692..4781cc26 100644
--- a/dopamine/discrete_domains/atari_lib.py
+++ b/dopamine/discrete_domains/atari_lib.py
@@ -108,7 +108,6 @@ def create_atari_environment(
     full_game_name = f'{game_name}NoFrameskip-{game_version}'
     env = legacy_gym.make(full_game_name)
   else:
-    gym.register_envs(ale_py)
     gym.register_envs(ale_py)
     full_game_name = f'ALE/{game_name}-v5'
     repeat_action_probability = 0.25 if sticky_actions else 0.0
diff --git a/dopamine/discrete_domains/run_experiment.py b/dopamine/discrete_domains/run_experiment.py
index 9ee29bdd..c0c8ef6f 100644
--- a/dopamine/discrete_domains/run_experiment.py
+++ b/dopamine/discrete_domains/run_experiment.py
@@ -33,9 +33,6 @@
 from dopamine.jax.agents.ppo import ppo_agent
 from dopamine.jax.agents.quantile import quantile_agent as jax_quantile_agent
 from dopamine.jax.agents.rainbow import rainbow_agent as jax_rainbow_agent
-from dopamine.labs.moes.agents import dqn_moe_agent
-from dopamine.labs.moes.agents import full_rainbow_moe_agent
-from dopamine.labs.moes.agents import rainbow_100k_moe_agent
 from dopamine.metrics import collector_dispatcher
 from dopamine.metrics import statistics_instance
 from dopamine.tf.agents.dqn import dqn_agent
@@ -124,18 +121,6 @@ def create_agent(
     return jax_implicit_quantile_agent.JaxImplicitQuantileAgent(
         num_actions=environment.action_space.n, summary_writer=summary_writer
     )
-  elif agent_name == 'moe_dqn':
-    return dqn_moe_agent.DQNMoEAgent(
-        num_actions=environment.action_space.n, summary_writer=summary_writer
-    )
-  elif agent_name == 'moe_full_rainbow':
-    return full_rainbow_moe_agent.JaxFullRainbowMoEAgent(
-        num_actions=environment.action_space.n, summary_writer=summary_writer
-    )
-  elif agent_name == 'moe_der':
-    return rainbow_100k_moe_agent.Atari100kRainbowMoEAgent(
-        num_actions=environment.action_space.n, summary_writer=summary_writer
-    )
   elif agent_name == 'ppo':
     return ppo_agent.PPOAgent(
         action_shape=environment.action_space.n,
diff --git a/dopamine/jax/replay_memory/accumulator.py b/dopamine/jax/replay_memory/accumulator.py
index 76207c06..3009bb01 100644
--- a/dopamine/jax/replay_memory/accumulator.py
+++ b/dopamine/jax/replay_memory/accumulator.py
@@ -67,7 +67,7 @@ def __init__(
         maxlen=self._update_horizon + self._stack_size
     )
 
-  def _make_replay_element(self) -> elements.ReplayElement | None:
+  def _make_replay_element(self) -> 'elements.ReplayElement | None':
     trajectory_len = len(self._trajectory)
     last_transition = self._trajectory[-1]
 
diff --git a/dopamine/jax/replay_memory/elements.py b/dopamine/jax/replay_memory/elements.py
index 89dd44ee..e313f984 100644
--- a/dopamine/jax/replay_memory/elements.py
+++ b/dopamine/jax/replay_memory/elements.py
@@ -98,11 +98,11 @@ class ReplayElement(ReplayElementProtocol, struct.PyTreeNode):
   """A single replay transition element supporting compression."""
 
   state: npt.NDArray[np.float64]
-  action: npt.NDArray[np.int_] | npt.NDArray[np.float64] | int
-  reward: npt.NDArray[np.float64] | float
+  action: 'npt.NDArray[np.int_] | npt.NDArray[np.float64] | int'
+  reward: 'npt.NDArray[np.float64] | float'
   next_state: npt.NDArray[np.float64]
-  is_terminal: npt.NDArray[np.bool_] | bool
-  episode_end: npt.NDArray[np.bool_] | bool
+  is_terminal: 'npt.NDArray[np.bool_] | bool'
+  episode_end: 'npt.NDArray[np.bool_] | bool'
 
   def pack(self) -> 'ReplayElement':
     # NOTE: pytype has a problem subclassing generics.
diff --git a/dopamine/jax/replay_memory/replay_buffer.py b/dopamine/jax/replay_memory/replay_buffer.py
index 5a6a9bc7..91a95276 100644
--- a/dopamine/jax/replay_memory/replay_buffer.py
+++ b/dopamine/jax/replay_memory/replay_buffer.py
@@ -108,7 +108,7 @@ def add(self, transition: elements.TransitionElement, **kwargs: Any) -> None:
   @typing.overload
   def sample(
       self,
-      size: int | None = None,
+      size: 'int | None' = None,
       *,
       with_sample_metadata: Literal[False] = False,
   ) -> ReplayElementT:
@@ -125,10 +125,10 @@
 
   def sample(
       self,
-      size: int | None = None,
+      size: 'int | None' = None,
       *,
       with_sample_metadata: bool = False,
-  ) -> ReplayElementT | tuple[ReplayElementT, samplers.SampleMetadata]:
+  ) -> 'ReplayElementT | tuple[ReplayElementT, samplers.SampleMetadata]':
     """Sample a batch of elements from the replay buffer."""
     if self.add_count < 1:
       raise ValueError('No samples in replay buffer!')
@@ -161,7 +161,7 @@ def update(self, keys: ReplayItemID, **kwargs: Any) -> None:
 
   def update(
       self,
-      keys: npt.NDArray[ReplayItemID] | ReplayItemID,
+      keys: 'npt.NDArray[ReplayItemID] | ReplayItemID',
       **kwargs: Any,
   ) -> None:
     self._sampling_distribution.update(keys, **kwargs)
diff --git a/dopamine/jax/replay_memory/samplers.py b/dopamine/jax/replay_memory/samplers.py
index 0951a294..960d8a81 100644
--- a/dopamine/jax/replay_memory/samplers.py
+++ b/dopamine/jax/replay_memory/samplers.py
@@ -29,7 +29,7 @@
 ReplayItemID = elements.ReplayItemID
 
 
-@dataclasses.dataclass(frozen=True, kw_only=True)
+@dataclasses.dataclass(frozen=True)
 class SampleMetadata:
   keys: npt.NDArray[ReplayItemID]
 
@@ -50,7 +50,7 @@ def update(self, key: ReplayItemID, **kwargs: Any) -> None:
     ...
 
   def update(
-      self, keys: npt.NDArray[ReplayItemID] | ReplayItemID, **kwargs: Any
+      self, keys: 'npt.NDArray[ReplayItemID] | ReplayItemID', **kwargs: Any
   ) -> None:
     ...
 
@@ -68,7 +68,7 @@ class UniformSamplingDistribution(SamplingDistribution):
   """A uniform sampling distribution."""
 
   def __init__(
-      self, seed: np.random.Generator | np.random.SeedSequence | int | None
+      self, seed: 'np.random.Generator | np.random.SeedSequence | int | None'
   ) -> None:
     # RNG generator
     self._rng = np.random.default_rng(seed)
@@ -123,7 +123,7 @@ def update(self, key: ReplayItemID, *args: Any, **kwargs: Any) -> None:
 
   def update(
       self,
-      keys: npt.NDArray[ReplayItemID] | ReplayItemID,
+      keys: 'npt.NDArray[ReplayItemID] | ReplayItemID',
       *args: Any,
       **kwargs: Any,
   ) -> None:
@@ -180,7 +180,7 @@ def from_state_dict(self, state_dict: dict[str, Any]) -> None:
     self._rng.bit_generator.state = state_dict['rng_state']
 
 
-@dataclasses.dataclass(frozen=True, kw_only=True)
+@dataclasses.dataclass(frozen=True)
 class PrioritizedSampleMetadata(SampleMetadata):
   probabilities: npt.NDArray[np.float64]
 
@@ -191,7 +191,7 @@ class PrioritizedSamplingDistribution(UniformSamplingDistribution):
 
   def __init__(
       self,
-      seed: np.random.SeedSequence | int | None,
+      seed: 'np.random.SeedSequence | int | None',
       *,
       priority_exponent: float = 1.0,
       max_capacity: int,
@@ -227,9 +227,9 @@ def update(self, keys: ReplayItemID, *, priorities: float) -> None:
 
   def update(
       self,
-      keys: npt.NDArray[ReplayItemID] | ReplayItemID,
+      keys: 'npt.NDArray[ReplayItemID] | ReplayItemID',
       *,
-      priorities: npt.NDArray[np.float64] | float,
+      priorities: 'npt.NDArray[np.float64] | float',
   ) -> None:
     if not isinstance(keys, np.ndarray):
       keys = np.asarray([keys], dtype=np.int32)
@@ -297,7 +297,7 @@ class SequentialSamplingDistribution(UniformSamplingDistribution):
 
   def __init__(
       self,
-      seed: np.random.Generator | np.random.SeedSequence | int,
+      seed: 'np.random.Generator | np.random.SeedSequence | int',
       sort_samples: bool = True,
   ):
     super().__init__(seed)
diff --git a/dopamine/jax/replay_memory/sum_tree.py b/dopamine/jax/replay_memory/sum_tree.py
index 649682fb..a9de361b 100644
--- a/dopamine/jax/replay_memory/sum_tree.py
+++ b/dopamine/jax/replay_memory/sum_tree.py
@@ -46,8 +46,8 @@ def set(
 
   def set(
       self,
-      indices: npt.NDArray[np.int_] | int,
-      values: npt.NDArray[np.float64] | float,
+      indices: 'npt.NDArray[np.int_] | int',
+      values: 'npt.NDArray[np.float64] | float',
   ) -> None:
     """Set the value at a given leaf node index."""
     if isinstance(indices, (int, np.integer)):
@@ -86,8 +86,8 @@ def get(self, index: npt.NDArray[np.int_]) -> npt.NDArray[np.float64]:
     ...
 
   def get(
-      self, index: npt.NDArray[np.int_] | int
-  ) -> npt.NDArray[np.float64] | float:
+      self, index: 'npt.NDArray[np.int_] | int'
+  ) -> 'npt.NDArray[np.float64] | float':
     """Get the value at a given leaf node index."""
     return self._nodes[self._first_leaf_offset + index]
 
@@ -105,8 +105,8 @@ def query(self, target: npt.NDArray[np.float64]) -> npt.NDArray[np.int_]:
     ...
 
   def query(
-      self, targets: npt.NDArray[np.float64] | float
-  ) -> npt.NDArray[np.int_] | int:
+      self, targets: 'npt.NDArray[np.float64] | float'
+  ) -> 'npt.NDArray[np.int_] | int':
     """Find the smallest index where target < cumulative value up to index.
 
     This functions like the CDF for a multi-nomial distribution allowing us
diff --git a/dopamine/labs/moes/architectures/networks.py b/dopamine/labs/moes/architectures/networks.py
index 68f3a26f..ba076bb3 100644
--- a/dopamine/labs/moes/architectures/networks.py
+++ b/dopamine/labs/moes/architectures/networks.py
@@ -155,7 +155,7 @@ def _maybe_create_moe_module(
     rng_key: jax.Array,
     expert_type: str = 'SMALL',
     encoder_type: str = 'IMPALA',
-) -> nn.Module | None:
+) -> 'nn.Module | None':
   """Try to create an MoE module, or return None."""
   del routing_type
   moe_type = MoEType[moe_type]
diff --git a/dopamine/labs/moes/architectures/routers.py b/dopamine/labs/moes/architectures/routers.py
index f3ec2263..9d21beb2 100644
--- a/dopamine/labs/moes/architectures/routers.py
+++ b/dopamine/labs/moes/architectures/routers.py
@@ -26,26 +26,26 @@
 class RandomRouter(nn.Module):
   """Route tokens randomly."""
 
-  num_experts: int | None = None
+  num_experts: 'int | None' = None
   k: int = 1
 
   def setup(self):
-    logging.info("Creating a %s", self.__class__.__name__)
+    logging.info('Creating a %s', self.__class__.__name__)
 
   @nn.compact
   def __call__(
       self,
       x: jax.Array,
       *,
-      num_experts: int | None = None,
-      k: int | None = None,
-      route_tokens: int | None = None,
-      key: jax.Array | None = None
+      num_experts: 'int | None' = None,
+      k: 'int | None' = None,
+      route_tokens: 'int | None' = None,
+      key: 'jax.Array | None' = None
   ) -> types.RouterReturn:
     chex.assert_rank(x, 2)
 
-    num_experts = nn.merge_param("num_experts", num_experts, self.num_experts)
-    k = nn.merge_param("k", k, self.k)
+    num_experts = nn.merge_param('num_experts', num_experts, self.num_experts)
+    k = nn.merge_param('k', k, self.k)
 
     sequence_length = x.shape[0]
     # probs are set randomly.
@@ -65,27 +65,27 @@ class TopKRouter(nn.Module):
   """A simple router that linearly projects assignments."""
 
   k: int
-  num_experts: int | None = None
+  num_experts: 'int | None' = None
   noisy_routing: bool = False
   noise_std: float = 1.0
 
   def setup(self):
-    logging.info("Creating a %s", self.__class__.__name__)
+    logging.info('Creating a %s', self.__class__.__name__)
 
   @nn.compact
   def __call__(
       self,
       x: jax.Array,
       *,
-      num_experts: int | None = None,
-      k: int | None = None,
-      key: jax.Array | None = None,
+      num_experts: 'int | None' = None,
+      k: 'int | None' = None,
+      key: 'jax.Array | None' = None,
       **kwargs
   ) -> types.RouterReturn:
     chex.assert_rank(x, 2)
 
-    num_experts = nn.merge_param("num_experts", num_experts, self.num_experts)
-    k = nn.merge_param("k", k, self.k)
+    num_experts = nn.merge_param('num_experts', num_experts, self.num_experts)
+    k = nn.merge_param('k', k, self.k)
 
     sequence_length = x.shape[0]
     x = nn.Dense(num_experts, use_bias=False)(x)
diff --git a/dopamine/labs/moes/architectures/types.py b/dopamine/labs/moes/architectures/types.py
index f6d99486..ce668df1 100644
--- a/dopamine/labs/moes/architectures/types.py
+++ b/dopamine/labs/moes/architectures/types.py
@@ -48,7 +48,7 @@ def router_unflatten(aux_data, children):
 class MoEModuleReturn:
   values: jax.Array
   router_out: RouterReturn
-  experts_hidden: jax.Array | None = None
+  experts_hidden: 'jax.Array | None' = None
 
 
 def module_flatten(v):
@@ -73,9 +73,9 @@ def module_unflatten(aux_data, children):
 class MoENetworkReturn:
   q_values: jax.Array
   moe_out: MoEModuleReturn
-  logits: jax.Array | None = None
-  probabilities: jax.Array | None = None
-  hidden_act: jax.Array | None = None
+  logits: 'jax.Array | None' = None
+  probabilities: 'jax.Array | None' = None
+  hidden_act: 'jax.Array | None' = None
 
 
 def network_flatten(v):
@@ -100,8 +100,8 @@ def network_unflatten(aux_data, children):
 class BaselineNetworkReturn:
   q_values: jax.Array
   hidden_act: jax.Array
-  logits: jax.Array | None = None
-  probabilities: jax.Array | None = None
+  logits: 'jax.Array | None' = None
+  probabilities: 'jax.Array | None' = None
 
 
 def baseline_network_flatten(v):
@@ -122,4 +122,6 @@ def baseline_network_unflatten(aux_data, children):
   )
 
 
-NetworkReturn = MoENetworkReturn | BaselineNetworkReturn
+# pylint: disable=invalid-name
+NetworkReturn = 'MoENetworkReturn | BaselineNetworkReturn'
+# pylint: enable=invalid-name
diff --git a/dopamine/tf/__init__.py b/dopamine/tf/__init__.py
new file mode 100644
index 00000000..a98a65d0
--- /dev/null
+++ b/dopamine/tf/__init__.py
@@ -0,0 +1,14 @@
+# coding=utf-8
+# Copyright 2024 The Dopamine Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/setup.py b/setup.py
index ec481956..56e7a89c 100644
--- a/setup.py
+++ b/setup.py
@@ -31,8 +31,10 @@
     'tensorflow >= 2.2.0',
     'gin-config >= 0.3.0',
     'absl-py >= 0.9.0',
+    'ale_py >= 0.10.1',
    'opencv-python >= 3.4.8.29',
     'gym <= 0.25.2',
+    'gymnasium >= 1.0.0',
     'flax >= 0.2.0',
     'jax >= 0.1.72',
     'jaxlib >= 0.1.51',
@@ -40,8 +42,10 @@
     'numpy >= 1.16.4',
     'pygame >= 1.9.2',
     'pandas >= 0.24.2',
+    'python-snappy >= 0.7.3',
     'tf_slim >= 1.0',
     'tensorflow-probability >= 0.13.0',
+    'tf-keras >= 2.18.0',
     'tqdm >= 4.64.1',
 ]
 
@@ -51,7 +55,7 @@
 
 setup(
     name='dopamine_rl',
-    version='4.1.0',
+    version='4.1.1',
     description=dopamine_description,
     long_description=long_description,
     long_description_content_type='text/markdown',
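
For anyone running this externally, the practical effect of the atari_lib.py and setup.py changes is that Atari environments now come from gymnasium plus ale_py rather than atari-py ROM imports. The snippet below is a minimal sketch of that flow, assuming gymnasium >= 1.0.0 and ale_py >= 0.10.1 (the new pins in setup.py); the game id and the sticky-actions choice are illustrative and not taken from the patch.

import ale_py
import gymnasium as gym

# Registering ale_py makes the ALE/* environment ids visible to gymnasium,
# mirroring the gym.register_envs(ale_py) call kept in create_atari_environment.
gym.register_envs(ale_py)

sticky_actions = True  # illustrative choice
env = gym.make(
    'ALE/Pong-v5',  # hypothetical game; any ALE game name works
    repeat_action_probability=0.25 if sticky_actions else 0.0,
)
observation, info = env.reset()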