From be60f56a2ee06290b9887349281bd83c047ac8cb Mon Sep 17 00:00:00 2001
From: Federico Berto
Date: Sat, 2 Mar 2024 17:55:14 +0900
Subject: [PATCH] [BugFix] do not save temporary decoding strategy #123

Co-authored-by: Chuanbo Hua
---
 .../zoo/common/autoregressive/decoder.py | 23 +++++++++++--------
 1 file changed, 14 insertions(+), 9 deletions(-)

diff --git a/rl4co/models/zoo/common/autoregressive/decoder.py b/rl4co/models/zoo/common/autoregressive/decoder.py
index ca3e9022..c8dabfbd 100644
--- a/rl4co/models/zoo/common/autoregressive/decoder.py
+++ b/rl4co/models/zoo/common/autoregressive/decoder.py
@@ -45,7 +45,7 @@ class AutoregressiveDecoder(nn.Module):
     We suppose environments in the `done` state are still available for sampling. This is because in NCO we need to
     wait for all the environments to reach a terminal state before we can stop the decoding process. This is in
     contrast with the TorchRL framework (at the moment) where the `env.rollout` function automatically resets.
-    You may follow tighter integration with TorchRL here: https://github.com/kaist-silab/rl4co/issues/72.
+    You may follow tighter integration with TorchRL here: https://github.com/ai4co/rl4co/issues/72.
 
     Args:
         env_name: environment name to solve
@@ -60,7 +60,7 @@
 
     def __init__(
         self,
-        env_name: [str, RL4COEnvBase],
+        env_name: Union[str, RL4COEnvBase],
         embedding_dim: int,
         num_heads: int,
         use_graph_context: bool = True,
@@ -111,8 +111,6 @@ def __init__(
 
         self.select_start_nodes_fn = select_start_nodes_fn
 
-        self.decode_strategy = None
-
     def forward(
         self,
         td: TensorDict,
@@ -154,19 +152,26 @@ def forward(
         # Compute keys, values for the glimpse and keys for the logits once as they can be reused in every step
         cached_embeds = self._precompute_cache(embeddings, td=td)
 
-        # setup decoding strategy
-        self.decode_strategy: DecodingStrategy = get_decoding_strategy(
+        # If `select_start_nodes_fn` is not being passed, we use the class attribute
+        if "select_start_nodes_fn" not in strategy_kwargs:
+            strategy_kwargs["select_start_nodes_fn"] = self.select_start_nodes_fn
+
+        # Setup decoding strategy
+        decode_strategy: DecodingStrategy = get_decoding_strategy(
             decode_type, **strategy_kwargs
         )
-        td, env, num_starts = self.decode_strategy.pre_decoder_hook(td, env)
+
+        # Pre-decoding hook: used for the initial step(s) of the decoding strategy
+        td, env, num_starts = decode_strategy.pre_decoder_hook(td, env)
 
         # Main decoding: loop until all sequences are done
         while not td["done"].all():
             log_p, mask = self._get_log_p(cached_embeds, td, softmax_temp, num_starts)
-            td = self.decode_strategy.step(log_p, mask, td)
+            td = decode_strategy.step(log_p, mask, td)
             td = env.step(td)["next"]
 
-        outputs, actions, td, env = self.decode_strategy.post_decoder_hook(td, env)
+        # Post-decoding hook: used for the final step(s) of the decoding strategy
+        outputs, actions, td, env = decode_strategy.post_decoder_hook(td, env)
 
         if calc_reward:
             td.set("reward", env.get_reward(td, actions))
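
Note for reviewers: below is a minimal, self-contained sketch of the pattern this patch moves to. It uses plain PyTorch only; the class and function names (ToyGreedyStrategy, ToyDecoder) are illustrative stand-ins, not rl4co APIs. The point is that the decoding strategy is built as a local variable inside forward rather than stored on the module, so each rollout gets a fresh strategy and no temporary per-rollout buffers linger on the nn.Module between calls or end up referenced by checkpoints.

import torch
import torch.nn as nn

# Illustrative stand-in for a decoding strategy: it accumulates per-rollout
# state (log-probabilities and actions), which is exactly the kind of
# temporary data that should not outlive a single forward pass.
class ToyGreedyStrategy:
    def __init__(self):
        self.logps, self.actions = [], []

    def step(self, log_p: torch.Tensor):
        action = log_p.argmax(dim=-1)
        self.logps.append(log_p.gather(-1, action.unsqueeze(-1)).squeeze(-1))
        self.actions.append(action)
        return action

    def finalize(self):
        return torch.stack(self.logps, dim=1), torch.stack(self.actions, dim=1)


class ToyDecoder(nn.Module):
    """Illustrative decoder that builds the strategy locally, as the patch does."""

    def __init__(self, hidden_dim: int = 8, num_nodes: int = 5):
        super().__init__()
        self.proj = nn.Linear(hidden_dim, num_nodes)

    def forward(self, hidden: torch.Tensor, num_steps: int = 3):
        # Local variable: created per call, garbage-collected afterwards,
        # never serialized with the module, and safe for back-to-back rollouts.
        decode_strategy = ToyGreedyStrategy()
        for _ in range(num_steps):
            log_p = self.proj(hidden).log_softmax(dim=-1)
            decode_strategy.step(log_p)
        return decode_strategy.finalize()


if __name__ == "__main__":
    decoder = ToyDecoder()
    logps, actions = decoder(torch.randn(4, 8))
    print(logps.shape, actions.shape)  # torch.Size([4, 3]) torch.Size([4, 3])
    # Nothing strategy-related is left on the module after decoding
    assert not hasattr(decoder, "decode_strategy")

Keeping the strategy out of self also means the only persistent state on the decoder is its learned parameters, which avoids surprises when the module is checkpointed or reused across consecutive rollouts.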