From a86ddfcb0aadbec15f4ce03966ebfe816160c56c Mon Sep 17 00:00:00 2001 From: Federico Berto Date: Sat, 2 Mar 2024 17:13:42 +0900 Subject: [PATCH 1/6] [Minor] align SVRP specs with future TorchRL Co-authored-by: ngastzepeda --- rl4co/envs/routing/svrp.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/rl4co/envs/routing/svrp.py b/rl4co/envs/routing/svrp.py index 73a713a6..cab0bc8c 100644 --- a/rl4co/envs/routing/svrp.py +++ b/rl4co/envs/routing/svrp.py @@ -64,8 +64,8 @@ def _make_spec(self, td_params: TensorDict = None): """Make the observation and action specs from the parameters.""" self.observation_spec = CompositeSpec( locs=BoundedTensorSpec( - minimum=self.min_loc, - maximum=self.max_loc, + low=self.min_loc, + high=self.max_loc, shape=(self.num_loc + 1, 2), dtype=torch.float32, ), @@ -74,8 +74,8 @@ def _make_spec(self, td_params: TensorDict = None): dtype=torch.int64, ), skills=BoundedTensorSpec( - minimum=self.min_skill, - maximum=self.max_skill, + low=self.min_skill, + high=self.max_skill, shape=(self.num_loc, 1), dtype=torch.float32, ), @@ -88,8 +88,8 @@ def _make_spec(self, td_params: TensorDict = None): self.action_spec = BoundedTensorSpec( shape=(1,), dtype=torch.int64, - minimum=0, - maximum=self.num_loc + 1, + low=0, + high=self.num_loc + 1, ) self.reward_spec = UnboundedContinuousTensorSpec(shape=(1,), dtype=torch.float32) self.done_spec = UnboundedDiscreteTensorSpec(shape=(1,), dtype=torch.bool) From be60f56a2ee06290b9887349281bd83c047ac8cb Mon Sep 17 00:00:00 2001 From: Federico Berto Date: Sat, 2 Mar 2024 17:55:14 +0900 Subject: [PATCH 2/6] [BugFix] do not save temporary decoding strategy #123 Co-authored-by: Chuanbo Hua --- .../zoo/common/autoregressive/decoder.py | 23 +++++++++++-------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/rl4co/models/zoo/common/autoregressive/decoder.py b/rl4co/models/zoo/common/autoregressive/decoder.py index ca3e9022..c8dabfbd 100644 --- a/rl4co/models/zoo/common/autoregressive/decoder.py +++ b/rl4co/models/zoo/common/autoregressive/decoder.py @@ -45,7 +45,7 @@ class AutoregressiveDecoder(nn.Module): We suppose environments in the `done` state are still available for sampling. This is because in NCO we need to wait for all the environments to reach a terminal state before we can stop the decoding process. This is in contrast with the TorchRL framework (at the moment) where the `env.rollout` function automatically resets. - You may follow tighter integration with TorchRL here: https://github.com/kaist-silab/rl4co/issues/72. + You may follow tighter integration with TorchRL here: https://github.com/ai4co/rl4co/issues/72. 
Args: env_name: environment name to solve @@ -60,7 +60,7 @@ class AutoregressiveDecoder(nn.Module): def __init__( self, - env_name: [str, RL4COEnvBase], + env_name: Union[str, RL4COEnvBase], embedding_dim: int, num_heads: int, use_graph_context: bool = True, @@ -111,8 +111,6 @@ def __init__( self.select_start_nodes_fn = select_start_nodes_fn - self.decode_strategy = None - def forward( self, td: TensorDict, @@ -154,19 +152,26 @@ def forward( # Compute keys, values for the glimpse and keys for the logits once as they can be reused in every step cached_embeds = self._precompute_cache(embeddings, td=td) - # setup decoding strategy - self.decode_strategy: DecodingStrategy = get_decoding_strategy( + # If `select_start_nodes_fn` is not being passed, we use the class attribute + if "select_start_nodes_fn" not in strategy_kwargs: + strategy_kwargs["select_start_nodes_fn"] = self.select_start_nodes_fn + + # Setup decoding strategy + decode_strategy: DecodingStrategy = get_decoding_strategy( decode_type, **strategy_kwargs ) - td, env, num_starts = self.decode_strategy.pre_decoder_hook(td, env) + + # Pre-decoding hook: used for the initial step(s) of the decoding strategy + td, env, num_starts = decode_strategy.pre_decoder_hook(td, env) # Main decoding: loop until all sequences are done while not td["done"].all(): log_p, mask = self._get_log_p(cached_embeds, td, softmax_temp, num_starts) - td = self.decode_strategy.step(log_p, mask, td) + td = decode_strategy.step(log_p, mask, td) td = env.step(td)["next"] - outputs, actions, td, env = self.decode_strategy.post_decoder_hook(td, env) + # Post-decoding hook: used for the final step(s) of the decoding strategy + outputs, actions, td, env = decode_strategy.post_decoder_hook(td, env) if calc_reward: td.set("reward", env.get_reward(td, actions)) From 37eaa40f2ff460bc776f5bea6b351436e61f0d8f Mon Sep 17 00:00:00 2001 From: Federico Berto Date: Sat, 2 Mar 2024 18:07:31 +0900 Subject: [PATCH 3/6] [Feat, Doc] enable passing kwargs, select start nodes function in `dec_strategies` --- rl4co/models/nn/dec_strategies.py | 33 ++++++++++++++++++++++++++----- 1 file changed, 28 insertions(+), 5 deletions(-) diff --git a/rl4co/models/nn/dec_strategies.py b/rl4co/models/nn/dec_strategies.py index 2562b2f8..5a1a2fc3 100644 --- a/rl4co/models/nn/dec_strategies.py +++ b/rl4co/models/nn/dec_strategies.py @@ -32,20 +32,39 @@ def get_decoding_strategy(decoding_strategy, **config): class DecodingStrategy: + """Base class for decoding strategies. Subclasses should implement the :meth:`_step` method. + Includes hooks for pre and post main decoding operations. + + Args: + multistart (bool, optional): Whether to use multistart decoding. Defaults to False. + num_starts (int, optional): Number of starts for multistart decoding. Defaults to None. + select_start_nodes_fn (Callable, optional): Function to select start nodes. Defaults to select_start_nodes. + """ + name = "base" - def __init__(self, multistart=False, num_starts=None, **kwargs) -> None: + def __init__( + self, + multistart=False, + num_starts=None, + select_start_nodes_fn=select_start_nodes, + **kwargs, + ) -> None: + self.actions = [] self.logp = [] self.multistart = multistart self.num_starts = num_starts + self.select_start_nodes_fn = select_start_nodes_fn def _step( self, logp: torch.Tensor, td: TensorDict, **kwargs ) -> Tuple[torch.Tensor, torch.Tensor]: + """Main decoding operation. 
This method should be implemented by subclasses.""" raise NotImplementedError("Must be implemented by subclass") def pre_decoder_hook(self, td: TensorDict, env: RL4COEnvBase): + """Pre decoding hook. This method is called before the main decoding operation.""" # Multi-start decoding. If num_starts is None, we use the number of actions in the action mask if self.multistart: if self.num_starts is None: @@ -61,7 +80,7 @@ def pre_decoder_hook(self, td: TensorDict, env: RL4COEnvBase): # Multi-start decoding: first action is chosen by ad-hoc node selection if self.num_starts > 1: - action = select_start_nodes(td, env, self.num_starts) + action = self.select_start_nodes_fn(td, env, self.num_starts) # Expand td to batch_size * num_starts td = batchify(td, self.num_starts) @@ -78,6 +97,7 @@ def pre_decoder_hook(self, td: TensorDict, env: RL4COEnvBase): return td, env, self.num_starts def post_decoder_hook(self, td, env): + """Post decoding hook. This method is called after the main decoding operation.""" assert ( len(self.logp) > 0 ), "No outputs were collected because all environments were done. Check your initial state" @@ -87,6 +107,7 @@ def post_decoder_hook(self, td, env): def step( self, logp: torch.Tensor, mask: torch.Tensor, td: TensorDict, **kwargs ) -> Tuple[torch.Tensor, torch.Tensor, TensorDict]: + """Main decoding operation. This method calls the :meth:`_step` method and collects the outputs.""" assert not logp.isinf().all(1).any() logp, selected_actions, td = self._step(logp, mask, td, **kwargs) @@ -103,11 +124,12 @@ class Greedy(DecodingStrategy): name = "greedy" def __init__(self, multistart=False, num_starts=None, **kwargs) -> None: - super().__init__(multistart=multistart, num_starts=num_starts) + super().__init__(multistart=multistart, num_starts=num_starts, **kwargs) def _step( self, logp: torch.Tensor, mask: torch.Tensor, td: TensorDict, **kwargs ) -> Tuple[torch.Tensor, torch.Tensor, TensorDict]: + """Select the action with the highest log probability.""" # [BS], [BS] _, selected = logp.max(1) @@ -122,11 +144,12 @@ class Sampling(DecodingStrategy): name = "sampling" def __init__(self, multistart=False, num_starts=None, **kwargs) -> None: - super().__init__(multistart=multistart, num_starts=num_starts) + super().__init__(multistart=multistart, num_starts=num_starts, **kwargs) def _step( self, logp: torch.Tensor, mask: torch.Tensor, td: TensorDict, **kwargs ) -> Tuple[torch.Tensor, torch.Tensor, TensorDict]: + """Sample an action with a multinomial distribution given by the log probabilities.""" probs = logp.exp() selected = torch.multinomial(probs, 1).squeeze(1) @@ -171,7 +194,7 @@ def pre_decoder_hook(self, td: TensorDict, env: RL4COEnvBase, **kwargs): self.beam_width = get_num_starts(td, env.name) # select start nodes. 
TODO: include first step in beam search as well - action = select_start_nodes(td, env, self.beam_width) + action = self.select_start_nodes_fn(td, env, self.beam_width) # Expand td to batch_size * beam_width td = batchify(td, self.beam_width) From 627ed4784ffd1183f776ff5b38a68ae6ac924280 Mon Sep 17 00:00:00 2001 From: Federico Berto Date: Sat, 2 Mar 2024 18:08:49 +0900 Subject: [PATCH 4/6] [Test] pass `select_start_nodes_fn` too --- tests/test_policy.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/tests/test_policy.py b/tests/test_policy.py index f387b19a..43067697 100644 --- a/tests/test_policy.py +++ b/tests/test_policy.py @@ -2,6 +2,7 @@ from rl4co.models import AutoregressivePolicy, PointerNetworkPolicy from rl4co.utils.test_utils import generate_env_data +from rl4co.utils.ops import select_start_nodes # Main autorergressive policy: rollout over multiple envs since it is the base @@ -25,7 +26,13 @@ def test_base_policy_multistart(env_name, size=20, batch_size=2): td = env.reset(x) policy = AutoregressivePolicy(env.name) num_starts = size // 2 if env.name in ["pdp"] else size - out = policy(td, env, decode_type="multistart_greedy", num_starts=num_starts) + out = policy( + td, + env, + decode_type="multistart_greedy", + num_starts=num_starts, + select_start_nodes_fn=select_start_nodes, + ) assert out["reward"].shape == ( batch_size * num_starts, ) # to evaluate, we could just unbatchify From 99454fd5f03411015ac44c4f8cc960f95a6ccc6a Mon Sep 17 00:00:00 2001 From: Federico Berto Date: Sat, 2 Mar 2024 18:09:35 +0900 Subject: [PATCH 5/6] [Minor] linting --- rl4co/models/zoo/common/autoregressive/policy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rl4co/models/zoo/common/autoregressive/policy.py b/rl4co/models/zoo/common/autoregressive/policy.py index 1de2dbb4..4c100d7e 100644 --- a/rl4co/models/zoo/common/autoregressive/policy.py +++ b/rl4co/models/zoo/common/autoregressive/policy.py @@ -51,7 +51,7 @@ class AutoregressivePolicy(nn.Module): def __init__( self, - env_name: [str, RL4COEnvBase] = "tsp", + env_name: Union[str, RL4COEnvBase] = "tsp", encoder: nn.Module = None, decoder: nn.Module = None, init_embedding: nn.Module = None, From 91dcc1fd953fa895011483a92c087b20a1ee52d2 Mon Sep 17 00:00:00 2001 From: Federico Berto Date: Sat, 2 Mar 2024 18:10:05 +0900 Subject: [PATCH 6/6] [Version] 0.3.3 --- rl4co/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rl4co/__init__.py b/rl4co/__init__.py index f9aa3e11..e19434e2 100644 --- a/rl4co/__init__.py +++ b/rl4co/__init__.py @@ -1 +1 @@ -__version__ = "0.3.2" +__version__ = "0.3.3"
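
Taken together, PATCH 2/6 through 4/6 let callers forward a custom `select_start_nodes_fn` through `strategy_kwargs` into the decoding strategy. Below is a minimal usage sketch, not part of the patches: the `my_start_nodes` helper, the environment size, and the start-node ordering are illustrative assumptions, while the policy call itself mirrors the updated `tests/test_policy.py`.

    # Usage sketch: pass a custom start-node selector to the policy;
    # it is forwarded via `strategy_kwargs` to the DecodingStrategy.
    import torch

    from rl4co.envs import TSPEnv
    from rl4co.models import AutoregressivePolicy


    def my_start_nodes(td, env, num_starts):
        # Hypothetical selector: start from nodes 0..num_starts-1, repeated per
        # batch instance. Assumes the same [num_starts * batch_size] shape
        # contract as rl4co's default `select_start_nodes`.
        batch_size = td["locs"].shape[0]
        return torch.arange(num_starts, device=td.device).repeat_interleave(batch_size)


    env = TSPEnv(num_loc=20)
    td = env.reset(batch_size=[2])
    policy = AutoregressivePolicy(env.name)

    out = policy(
        td,
        env,
        decode_type="multistart_greedy",
        num_starts=10,
        select_start_nodes_fn=my_start_nodes,  # kwarg enabled by PATCH 2/6 and 3/6
    )
    print(out["reward"].shape)  # (batch_size * num_starts,) == (20,)

Since the decoder no longer stores `decode_strategy` on `self` (PATCH 2/6), repeated calls with different `decode_type` or `strategy_kwargs` do not leak state between rollouts.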