[Feat,Refactor] nucleus sampling; default returning logits; move sampling strategies in decoding strategies
fedebotu committed Apr 5, 2024
1 parent ada0d2f commit af131a8
Showing 10 changed files with 122 additions and 86 deletions.
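
Editor's note: the headline feature, nucleus (top-p) sampling, is added to the decoding strategies (presumably in `rl4co/models/nn/dec_strategies.py`, one of the files not expanded below). As a rough, hypothetical illustration of the technique — not the repository's actual code — top-p filtering keeps only the smallest set of highest-probability actions whose cumulative probability exceeds `p` before sampling:

```python
import torch

def top_p_filter(logits: torch.Tensor, top_p: float = 0.9) -> torch.Tensor:
    """Hypothetical nucleus-sampling filter: drop logits outside the top-p nucleus."""
    sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
    cum_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
    # Remove actions once the cumulative probability exceeds top_p,
    # shifted right so the first action crossing the threshold is kept
    remove = cum_probs > top_p
    remove[..., 1:] = remove[..., :-1].clone()
    remove[..., 0] = False  # always keep the most likely action
    mask = remove.scatter(-1, sorted_idx, remove)
    return logits.masked_fill(mask, float("-inf"))

# Sampling then proceeds from softmax(filtered_logits), e.g. via torch.multinomial.
```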
2 changes: 1 addition & 1 deletion rl4co/__init__.py
@@ -1 +1 @@
__version__ = "0.4.0dev0"
__version__ = "0.4.0dev1"
73 changes: 49 additions & 24 deletions rl4co/models/zoo/common/autoregressive/decoder.py
@@ -9,8 +9,12 @@
from torch import Tensor

from rl4co.envs import RL4COEnvBase, get_env
from rl4co.models.nn.attention import LogitAttention
from rl4co.models.nn.dec_strategies import DecodingStrategy, get_decoding_strategy
from rl4co.models.nn.attention import PointerAttention
from rl4co.models.nn.dec_strategies import (
DecodingStrategy,
get_decoding_strategy,
logits_to_probs,
)
from rl4co.models.nn.env_embeddings import env_context_embedding, env_dynamic_embedding
from rl4co.models.nn.env_embeddings.dynamic import StaticEmbedding
from rl4co.models.nn.utils import get_log_likelihood
@@ -55,6 +59,9 @@ class AutoregressiveDecoder(nn.Module):
linear_bias: Whether to use a bias in the linear projection of the embeddings
context_embedding: Module to compute the context embedding. If None, the default is used
dynamic_embedding: Module to compute the dynamic embedding. If None, the default is used
temperature: Temperature for the softmax in the decoder
tanh_clipping: Clipping value for the tanh in the decoder
mask_logits: Whether to mask the logits in the decoder
"""

def __init__(
@@ -66,7 +73,10 @@ def __init__(
linear_bias: bool = False,
context_embedding: nn.Module = None,
dynamic_embedding: nn.Module = None,
**logit_attn_kwargs,
temperature: float = 1.0,
tanh_clipping: float = 10.0,
mask_logits: bool = True,
**pointer_attn_kwargs,
):
super().__init__()

@@ -75,6 +85,9 @@
self.env_name = env_name
self.embedding_dim = embedding_dim
self.num_heads = num_heads
self.temperature = temperature
self.tanh_clipping = tanh_clipping
self.mask_logits = mask_logits

assert embedding_dim % num_heads == 0

@@ -103,16 +116,18 @@ def __init__(
)

# MHA with Pointer mechanism (https://arxiv.org/abs/1506.03134)
self.pointer = LogitAttention(embedding_dim, num_heads, **logit_attn_kwargs)
self.pointer = PointerAttention(embedding_dim, num_heads, **pointer_attn_kwargs)

def forward(
self,
td: TensorDict,
embeddings: Tensor,
env: Union[str, RL4COEnvBase] = None,
decode_type: str = "sampling",
softmax_temp: float = None,
calc_reward: bool = True,
temperature: float = None,
tanh_clipping: float = None,
mask_logits: bool = None,
**strategy_kwargs,
) -> Tuple[Tensor, Tensor, TensorDict]:
"""Forward pass of the decoder
@@ -129,7 +144,6 @@ def forward(
- "multistart_sampling": sample as sampling, but with multi-start decoding
- "multistart_greedy": sample as greedy, but with multi-start decoding
- "beam_search": perform beam search
softmax_temp: Temperature for the softmax. If None, default softmax is used from the `LogitAttention` module
calc_reward: Whether to calculate the reward for the decoded sequence
strategy_kwargs: Keyword arguments for the decoding strategy. See :class:`rl4co.models.nn.dec_strategies.DecodingStrategy`
@@ -146,18 +160,27 @@ def forward(
# Compute keys, values for the glimpse and keys for the logits once as they can be reused in every step
cached_embeds = self._precompute_cache(embeddings, td=td)

# Get default values if not provided
temperature = temperature if temperature is not None else self.temperature
tanh_clipping = tanh_clipping if tanh_clipping is not None else self.tanh_clipping
mask_logits = mask_logits if mask_logits is not None else self.mask_logits

# Setup decoding strategy
decode_strategy: DecodingStrategy = get_decoding_strategy(
decode_type, **strategy_kwargs
decode_type,
temperature=temperature,
tanh_clipping=tanh_clipping,
mask_logits=mask_logits,
**strategy_kwargs,
)

# Pre-decoding hook: used for the initial step(s) of the decoding strategy
td, env, num_starts = decode_strategy.pre_decoder_hook(td, env)

# Main decoding: loop until all sequences are done
while not td["done"].all():
log_p, mask = self._get_log_p(cached_embeds, td, softmax_temp, num_starts)
td = decode_strategy.step(log_p, mask, td)
logits, mask = self._get_logits(cached_embeds, td, num_starts)
td = decode_strategy.step(logits, mask, td)
td = env.step(td)["next"]

# Post-decoding hook: used for the final step(s) of the decoding strategy
@@ -205,19 +228,17 @@ def _precompute_cache(

return cached_embeds

def _get_log_p(
def _get_logits(
self,
cached: PrecomputedCache,
td: TensorDict,
softmax_temp: float = None,
num_starts: int = 0,
):
"""Compute the log probabilities of the next actions given the current state
"""Compute the logits of the next actions given the current state.
Args:
cache: Precomputed embeddings
td: TensorDict with the current environment state
softmax_temp: Temperature for the softmax
num_starts: Number of starts for the multi-start decoding
"""

@@ -264,18 +285,16 @@ def _get_log_p(
glimpse_v = glimpse_v_stat + glimpse_v_dyn
logit_k = logit_k_stat + logit_k_dyn

# Get the mask
mask = ~td["action_mask"]

# Compute logits
log_p = self.pointer(glimpse_q, glimpse_k, glimpse_v, logit_k, mask, softmax_temp)
mask = td["action_mask"]
logits = self.pointer(glimpse_q, glimpse_k, glimpse_v, logit_k, mask)

# Now we need to reshape the logits and log_p to [B*S,N,...] is num_starts > 1 without dynamic embeddings
# Now we need to reshape the logits and mask to [B*S, N, ...] if num_starts > 1 without dynamic embeddings
# note that rearranging order is important here
if num_starts > 1 and not has_dyn_emb_multi_start:
log_p = rearrange(log_p, "b s l -> (s b) l", s=num_starts)
logits = rearrange(logits, "b s l -> (s b) l", s=num_starts)
mask = rearrange(mask, "b s l -> (s b) l", s=num_starts)
return log_p, mask
return logits, mask

def evaluate_action(
self,
@@ -305,23 +324,29 @@ def evaluate_action(
# Compute keys, values for the glimpse and keys for the logits once as they can be reused in every step
cached_embeds = self._precompute_cache(embeddings)

log_p = []
probs = []
decode_step = 0
while not td["done"].all():
log_p_, _ = self._get_log_p(cached_embeds, td)
logits, _ = self._get_logits(cached_embeds, td)
probs_ = logits_to_probs(
logits,
td["action_mask"],
tanh_clipping=self.tanh_clipping,
mask_logits=self.mask_logits,
)
action_ = action[..., decode_step]

td.set("action", action_)
td = env.step(td)["next"]
log_p.append(log_p_)
probs.append(probs_)

decode_step += 1

# Note that the decoding steps may not be equal to the decoding steps of actions
# due to the padded zeros in the actions

# Compute log likelihood of the actions
log_p = torch.stack(log_p, 1) # [batch_size, decoding steps, num_nodes]
log_p = torch.stack(probs, 1).log() # [batch_size, decoding steps, num_nodes]
ll = get_log_likelihood(
log_p, action[..., :decode_step], mask=None, return_sum=False
) # [batch_size, decoding steps]
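
Editor's note: `logits_to_probs` is newly imported from `rl4co.models.nn.dec_strategies`, but its definition is not part of the expanded hunks. A minimal sketch of what it presumably does, inferred from the call site in `evaluate_action` and the `temperature` / `tanh_clipping` / `mask_logits` parameters above (the actual implementation may differ):

```python
import torch
import torch.nn.functional as F

def logits_to_probs(logits, mask, temperature=1.0, tanh_clipping=0.0, mask_logits=True):
    # Hypothetical sketch; signature inferred from the call in evaluate_action above.
    if tanh_clipping > 0:
        # Bound logits to [-tanh_clipping, tanh_clipping] (the usual attention-model trick)
        logits = tanh_clipping * torch.tanh(logits)
    if mask_logits:
        # `mask` is td["action_mask"]: True for feasible actions, so mask out the rest
        logits = logits.masked_fill(~mask, float("-inf"))
    return F.softmax(logits / temperature, dim=-1)
```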
9 changes: 9 additions & 0 deletions rl4co/models/zoo/common/autoregressive/policy.py
@@ -44,6 +44,9 @@ class AutoregressivePolicy(nn.Module):
train_decode_type: Type of decoding during training
val_decode_type: Type of decoding during validation
test_decode_type: Type of decoding during testing
temperature: Temperature for the softmax in the decoder
tanh_clipping: Clipping value for the tanh in the decoder
mask_logits: Whether to mask the logits in the decoder
**unused_kw: Unused keyword arguments
"""

@@ -65,6 +68,9 @@ def __init__(
train_decode_type: str = "sampling",
val_decode_type: str = "greedy",
test_decode_type: str = "greedy",
temperature: float = 1.0,
tanh_clipping: float = 10.0,
mask_logits: bool = True,
**unused_kw,
):
super(AutoregressivePolicy, self).__init__()
@@ -100,6 +106,9 @@ def __init__(
mask_inner=mask_inner,
context_embedding=context_embedding,
dynamic_embedding=dynamic_embedding,
temperature=temperature,
tanh_clipping=tanh_clipping,
mask_logits=mask_logits,
)
else:
self.decoder = decoder
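
Editor's note: a hypothetical usage example showing how the new decoding parameters propagate from the policy to the decoder. Argument names and defaults are taken from the diff above; other constructor arguments are omitted, and the full signature may differ:

```python
# Sketch only: the real AutoregressivePolicy constructor takes more arguments than shown.
policy = AutoregressivePolicy(
    env_name="tsp",
    temperature=1.0,     # softmax temperature used when turning logits into probabilities
    tanh_clipping=10.0,  # clip logits to [-10, 10] via tanh before masking
    mask_logits=True,    # forbid infeasible actions before normalization
)
```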
33 changes: 15 additions & 18 deletions rl4co/models/zoo/common/nonautoregressive/decoder.py
@@ -60,29 +60,29 @@ def forward(self, graph: Batch) -> Tensor: # type: ignore
edge_attr = self.act(layer(edge_attr))
graph.edge_attr = torch.sigmoid(self.output(edge_attr)) * 10 # type: ignore

heatmaps_logp = self._make_heatmaps(graph)
return heatmaps_logp
heatmaps_logits = self._make_heatmaps(graph)
return heatmaps_logits

def _make_heatmaps(self, batch_graph: Batch) -> Tensor: # type: ignore
graphs = batch_graph.to_data_list()
device = graphs[0].edge_attr.device
batch_size = len(graphs)
num_nodes = graphs[0].x.shape[0]

heatmaps_logp = torch.zeros(
heatmaps_logits = torch.zeros(
(batch_size, num_nodes, num_nodes),
device=device,
dtype=graphs[0].edge_attr.dtype,
)

for index, graph in enumerate(graphs):
edge_index, edge_attr = graph.edge_index, graph.edge_attr
heatmaps_logp[index, edge_index[0], edge_index[1]] = edge_attr.flatten()
heatmaps_logits[index, edge_index[0], edge_index[1]] = edge_attr.flatten()

if self.undirected_graph:
heatmaps_logp = (heatmaps_logp + heatmaps_logp.transpose(1, 2)) * 0.5
heatmaps_logits = (heatmaps_logits + heatmaps_logits.transpose(1, 2)) * 0.5

return heatmaps_logp
return heatmaps_logits


class NonAutoregressiveDecoder(nn.Module):
@@ -146,7 +146,7 @@ def forward(
env = get_env(env_name)

# calculate heatmap
heatmaps_logp = self.heatmap_generator(graph)
heatmaps_logits = self.heatmap_generator(graph)

# setup decoding strategy
self.decode_strategy: DecodingStrategy = get_decoding_strategy(
@@ -156,8 +156,8 @@

# Main decoding: loop until all sequences are done
while not td["done"].all():
log_p, mask = self._get_log_p(td, heatmaps_logp, num_starts)
td = self.decode_strategy.step(log_p, mask, td)
logits, mask = self._get_logits(td, heatmaps_logits, num_starts)
td = self.decode_strategy.step(logits, mask, td)
td = env.step(td)["next"]

outputs, actions, td, env = self.decode_strategy.post_decoder_hook(td, env)
@@ -168,21 +168,18 @@ def forward(
return outputs, actions, td

@classmethod
def _get_log_p(cls, td: TensorDict, heatmaps_logp: Tensor, num_starts: int):
def _get_logits(cls, td: TensorDict, heatmaps_logits: Tensor, num_starts: int):
# Get the mask
mask = ~td["action_mask"]
action_mask = td["action_mask"]

current_action = td.get("action", None)
if current_action is None:
log_p = heatmaps_logp.mean(-1)
logits = heatmaps_logits.mean(-1)
else:
batch_size = heatmaps_logp.shape[0]
batch_size = heatmaps_logits.shape[0]
_indexer = cls._multistart_batched_index(batch_size, num_starts)
log_p = heatmaps_logp[_indexer, current_action, :]

log_p[mask] = -torch.inf
log_p = nn.functional.log_softmax(log_p, -1)
return log_p, mask
logits = heatmaps_logits[_indexer, current_action, :]
return logits, action_mask

@staticmethod
@lru_cache(10)
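
Editor's note: the masking and `log_softmax` that used to live in `_get_log_p` (removed above) now have to happen inside the decoding strategy, which receives raw logits plus the positive action mask. A heavily simplified, hypothetical sketch of a sampling step under that contract — reusing the `logits_to_probs` sketch from earlier; the real `DecodingStrategy.step` in `rl4co/models/nn/dec_strategies.py` may differ:

```python
def step(self, logits, mask, td):
    # Hypothetical: normalize raw logits with the strategy's own settings, then sample
    probs = logits_to_probs(
        logits, mask, temperature=self.temperature,
        tanh_clipping=self.tanh_clipping, mask_logits=self.mask_logits,
    )
    action = torch.multinomial(probs, 1).squeeze(-1)
    td.set("action", action)
    # Keep the per-step probabilities so the log-likelihood can be computed later
    self.probs.append(probs)
    return td
```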
7 changes: 4 additions & 3 deletions rl4co/models/zoo/deepaco/antsystem.py
@@ -118,14 +118,15 @@ def _sampling(
):
# Sample from heatmaps
# p = phe**alpha * heu**beta <==> log(p) = alpha*log(phe) + beta*log(heu)
heatmaps_logp = (
heatmaps_logits = (
self.alpha * torch.log(self.pheromone) + self.beta * self.log_heuristic
)
self.decode_strategy = Sampling(multistart=True, num_starts=self.n_ants)
td, env, num_starts = self.decode_strategy.pre_decoder_hook(td, env)
while not td["done"].all():
log_p, mask = NARDecoder._get_log_p(td, heatmaps_logp, num_starts)
td = self.decode_strategy.step(log_p, mask, td)
logits, mask = NARDecoder._get_logits(td, heatmaps_logits, num_starts)
td = self.decode_strategy.step(logits, mask, td)
# TODO: check, do we need to run the logits normalization via step here?
td = env.step(td)["next"]

outputs, actions, td, env = self.decode_strategy.post_decoder_hook(td, env)