[BugFix] do not save temporary decoding strategy #123
Co-authored-by: Chuanbo Hua <cbhua@kaist.ac.kr>
fedebotu and cbhua committed Mar 2, 2024
1 parent a86ddfc commit be60f56
Showing 1 changed file with 14 additions and 9 deletions.
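In short: the decoder previously stashed its decoding strategy on the module as `self.decode_strategy`, even though the object is only needed for the duration of a single `forward` call. Keeping it as an instance attribute leaks per-call state between invocations and keeps references alive after decoding finishes, which can also bloat or break `deepcopy`/pickling of the module. The fix makes the strategy a plain local variable. A minimal, self-contained sketch of the difference follows; `GreedyStrategy`, `BuggyDecoder`, and `FixedDecoder` are illustrative stand-ins, not the RL4CO API:

import torch
import torch.nn as nn


class GreedyStrategy:
    """Stand-in for a decoding strategy: picks the argmax action."""

    def step(self, log_p: torch.Tensor) -> torch.Tensor:
        return log_p.argmax(dim=-1)


class BuggyDecoder(nn.Module):
    def forward(self, log_p: torch.Tensor) -> torch.Tensor:
        # Bug: the temporary helper is saved on the module and
        # outlives this forward pass.
        self.decode_strategy = GreedyStrategy()
        return self.decode_strategy.step(log_p)


class FixedDecoder(nn.Module):
    def forward(self, log_p: torch.Tensor) -> torch.Tensor:
        # Fix: a plain local variable; the module carries no
        # per-call state once forward returns.
        decode_strategy = GreedyStrategy()
        return decode_strategy.step(log_p)


log_p = torch.randn(2, 5).log_softmax(-1)
buggy, fixed = BuggyDecoder(), FixedDecoder()
buggy(log_p), fixed(log_p)
print(hasattr(buggy, "decode_strategy"))  # True: stale state left behind
print(hasattr(fixed, "decode_strategy"))  # False: module stays stateless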
23 changes: 14 additions & 9 deletions rl4co/models/zoo/common/autoregressive/decoder.py
@@ -45,7 +45,7 @@ class AutoregressiveDecoder(nn.Module):
     We suppose environments in the `done` state are still available for sampling. This is because in NCO we need to
     wait for all the environments to reach a terminal state before we can stop the decoding process. This is in
     contrast with the TorchRL framework (at the moment) where the `env.rollout` function automatically resets.
-    You may follow tighter integration with TorchRL here: https://github.com/kaist-silab/rl4co/issues/72.
+    You may follow tighter integration with TorchRL here: https://github.com/ai4co/rl4co/issues/72.
 
     Args:
         env_name: environment name to solve
@@ -60,7 +60,7 @@ class AutoregressiveDecoder(nn.Module):
 
     def __init__(
         self,
-        env_name: [str, RL4COEnvBase],
+        env_name: Union[str, RL4COEnvBase],
        embedding_dim: int,
        num_heads: int,
        use_graph_context: bool = True,
@@ -111,8 +111,6 @@ def __init__(
 
         self.select_start_nodes_fn = select_start_nodes_fn
 
-        self.decode_strategy = None
-
     def forward(
         self,
         td: TensorDict,
@@ -154,19 +152,26 @@ def forward(
         # Compute keys, values for the glimpse and keys for the logits once as they can be reused in every step
         cached_embeds = self._precompute_cache(embeddings, td=td)
 
-        # setup decoding strategy
-        self.decode_strategy: DecodingStrategy = get_decoding_strategy(
+        # If `select_start_nodes_fn` is not being passed, we use the class attribute
+        if "select_start_nodes_fn" not in strategy_kwargs:
+            strategy_kwargs["select_start_nodes_fn"] = self.select_start_nodes_fn
+
+        # Setup decoding strategy
+        decode_strategy: DecodingStrategy = get_decoding_strategy(
             decode_type, **strategy_kwargs
         )
-        td, env, num_starts = self.decode_strategy.pre_decoder_hook(td, env)
+
+        # Pre-decoding hook: used for the initial step(s) of the decoding strategy
+        td, env, num_starts = decode_strategy.pre_decoder_hook(td, env)
 
         # Main decoding: loop until all sequences are done
         while not td["done"].all():
             log_p, mask = self._get_log_p(cached_embeds, td, softmax_temp, num_starts)
-            td = self.decode_strategy.step(log_p, mask, td)
+            td = decode_strategy.step(log_p, mask, td)
             td = env.step(td)["next"]
 
-        outputs, actions, td, env = self.decode_strategy.post_decoder_hook(td, env)
+        # Post-decoding hook: used for the final step(s) of the decoding strategy
+        outputs, actions, td, env = decode_strategy.post_decoder_hook(td, env)
 
         if calc_reward:
             td.set("reward", env.get_reward(td, actions))
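For context on the surrounding loop: a decoding strategy follows the three-phase protocol visible in the hunk above — `pre_decoder_hook` (which may, for example, expand the batch for multistart decoding), a per-step `step` that selects actions and records log-probabilities, and `post_decoder_hook` that assembles the collected outputs. Below is a toy, runnable sketch of that control flow under illustrative names; the dict-based state and random policy head are stand-ins, not the RL4CO API:

import torch


class ToyGreedyStrategy:
    """Illustrative strategy following the pre-hook / step / post-hook protocol."""

    def __init__(self):
        self.log_ps, self.actions = [], []

    def pre_decoder_hook(self, td):
        # Real strategies may modify the state here, e.g. expand the
        # batch for multistart decoding.
        return td

    def step(self, log_p, mask, td):
        log_p = log_p.masked_fill(~mask, float("-inf"))  # rule out infeasible actions
        action = log_p.argmax(dim=-1)                    # greedy selection
        self.log_ps.append(log_p.gather(-1, action[:, None]).squeeze(-1))
        self.actions.append(action)
        td["action"] = action
        return td

    def post_decoder_hook(self, td):
        # Assemble per-step records into (batch, num_steps) tensors
        return torch.stack(self.log_ps, 1), torch.stack(self.actions, 1), td


# Toy environment state: visit each of n nodes exactly once
batch, n = 2, 4
td = {
    "visited": torch.zeros(batch, n, dtype=torch.bool),
    "done": torch.zeros(batch, dtype=torch.bool),
}

strategy = ToyGreedyStrategy()
td = strategy.pre_decoder_hook(td)
while not td["done"].all():
    log_p = torch.randn(batch, n).log_softmax(-1)  # stand-in for the policy head
    mask = ~td["visited"]                          # feasible = not yet visited
    td = strategy.step(log_p, mask, td)
    td["visited"].scatter_(1, td["action"][:, None], True)
    td["done"] = td["visited"].all(-1)

log_ps, actions, td = strategy.post_decoder_hook(td)
print(actions)  # each row is a permutation of 0..3

As a side note, the `select_start_nodes_fn` defaulting added in the hunk is the standard keyword-fallback pattern; `strategy_kwargs.setdefault("select_start_nodes_fn", self.select_start_nodes_fn)` would be an equivalent one-liner.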
