From af131a897d6d9b13dddc6767465bc11f51b6ca78 Mon Sep 17 00:00:00 2001 From: Federico Berto Date: Sat, 6 Apr 2024 04:49:25 +0900 Subject: [PATCH] [Feat,Refactor] nucleus sampling; default returning logits; move sampling strategies in decoding strategies --- rl4co/__init__.py | 2 +- .../zoo/common/autoregressive/decoder.py | 73 +++++++++++++------ .../zoo/common/autoregressive/policy.py | 9 +++ .../zoo/common/nonautoregressive/decoder.py | 33 ++++----- rl4co/models/zoo/deepaco/antsystem.py | 7 +- rl4co/models/zoo/eas/decoder.py | 48 +++++------- rl4co/models/zoo/eas/search.py | 6 +- rl4co/models/zoo/matnet/decoder.py | 23 ++++-- rl4co/models/zoo/mdam/decoder.py | 4 +- rl4co/models/zoo/ptrnet/decoder.py | 3 +- 10 files changed, 122 insertions(+), 86 deletions(-) diff --git a/rl4co/__init__.py b/rl4co/__init__.py index d9b25cba..99ba9969 100644 --- a/rl4co/__init__.py +++ b/rl4co/__init__.py @@ -1 +1 @@ -__version__ = "0.4.0dev0" +__version__ = "0.4.0dev1" diff --git a/rl4co/models/zoo/common/autoregressive/decoder.py b/rl4co/models/zoo/common/autoregressive/decoder.py index 68131f7f..b9beec27 100644 --- a/rl4co/models/zoo/common/autoregressive/decoder.py +++ b/rl4co/models/zoo/common/autoregressive/decoder.py @@ -9,8 +9,12 @@ from torch import Tensor from rl4co.envs import RL4COEnvBase, get_env -from rl4co.models.nn.attention import LogitAttention -from rl4co.models.nn.dec_strategies import DecodingStrategy, get_decoding_strategy +from rl4co.models.nn.attention import PointerAttention +from rl4co.models.nn.dec_strategies import ( + DecodingStrategy, + get_decoding_strategy, + logits_to_probs, +) from rl4co.models.nn.env_embeddings import env_context_embedding, env_dynamic_embedding from rl4co.models.nn.env_embeddings.dynamic import StaticEmbedding from rl4co.models.nn.utils import get_log_likelihood @@ -55,6 +59,9 @@ class AutoregressiveDecoder(nn.Module): linear_bias: Whether to use a bias in the linear projection of the embeddings context_embedding: Module to compute the context embedding. If None, the default is used dynamic_embedding: Module to compute the dynamic embedding. 
If None, the default is used + temperature: Temperature for the softmax in the decoder + tanh_clipping: Clipping value for the tanh in the decoder + mask_logits: Whether to mask the logits in the decoder """ def __init__( @@ -66,7 +73,10 @@ def __init__( linear_bias: bool = False, context_embedding: nn.Module = None, dynamic_embedding: nn.Module = None, - **logit_attn_kwargs, + temperature: float = 1.0, + tanh_clipping: float = 10.0, + mask_logits: bool = True, + **pointer_attn_kwargs, ): super().__init__() @@ -75,6 +85,9 @@ def __init__( self.env_name = env_name self.embedding_dim = embedding_dim self.num_heads = num_heads + self.temperature = temperature + self.tanh_clipping = tanh_clipping + self.mask_logits = mask_logits assert embedding_dim % num_heads == 0 @@ -103,7 +116,7 @@ def __init__( ) # MHA with Pointer mechanism (https://arxiv.org/abs/1506.03134) - self.pointer = LogitAttention(embedding_dim, num_heads, **logit_attn_kwargs) + self.pointer = PointerAttention(embedding_dim, num_heads, **pointer_attn_kwargs) def forward( self, @@ -111,8 +124,10 @@ def forward( embeddings: Tensor, env: Union[str, RL4COEnvBase] = None, decode_type: str = "sampling", - softmax_temp: float = None, calc_reward: bool = True, + temperature: float = None, + tanh_clipping: float = None, + mask_logits: bool = None, **strategy_kwargs, ) -> Tuple[Tensor, Tensor, TensorDict]: """Forward pass of the decoder @@ -129,7 +144,6 @@ def forward( - "multistart_sampling": sample as sampling, but with multi-start decoding - "multistart_greedy": sample as greedy, but with multi-start decoding - "beam_search": perform beam search - softmax_temp: Temperature for the softmax. If None, default softmax is used from the `LogitAttention` module calc_reward: Whether to calculate the reward for the decoded sequence strategy_kwargs: Keyword arguments for the decoding strategy. See :class:`rl4co.models.nn.dec_strategies.DecodingStrategy` @@ -146,9 +160,18 @@ def forward( # Compute keys, values for the glimpse and keys for the logits once as they can be reused in every step cached_embeds = self._precompute_cache(embeddings, td=td) + # Get default values if not provided + temperature = temperature if temperature is not None else self.temperature + tanh_clipping = tanh_clipping if tanh_clipping is not None else self.tanh_clipping + mask_logits = mask_logits if mask_logits is not None else self.mask_logits + # Setup decoding strategy decode_strategy: DecodingStrategy = get_decoding_strategy( - decode_type, **strategy_kwargs + decode_type, + temperature=temperature, + tanh_clipping=tanh_clipping, + mask_logits=mask_logits, + **strategy_kwargs, ) # Pre-decoding hook: used for the initial step(s) of the decoding strategy @@ -156,8 +179,8 @@ def forward( # Main decoding: loop until all sequences are done while not td["done"].all(): - log_p, mask = self._get_log_p(cached_embeds, td, softmax_temp, num_starts) - td = decode_strategy.step(log_p, mask, td) + logits, mask = self._get_logits(cached_embeds, td, num_starts) + td = decode_strategy.step(logits, mask, td) td = env.step(td)["next"] # Post-decoding hook: used for the final step(s) of the decoding strategy @@ -205,19 +228,17 @@ def _precompute_cache( return cached_embeds - def _get_log_p( + def _get_logits( self, cached: PrecomputedCache, td: TensorDict, - softmax_temp: float = None, num_starts: int = 0, ): - """Compute the log probabilities of the next actions given the current state + """Compute the logits of the next actions given the current state. 
Args: cache: Precomputed embeddings td: TensorDict with the current environment state - softmax_temp: Temperature for the softmax num_starts: Number of starts for the multi-start decoding """ @@ -264,18 +285,16 @@ def _get_log_p( glimpse_v = glimpse_v_stat + glimpse_v_dyn logit_k = logit_k_stat + logit_k_dyn - # Get the mask - mask = ~td["action_mask"] - # Compute logits - log_p = self.pointer(glimpse_q, glimpse_k, glimpse_v, logit_k, mask, softmax_temp) + mask = td["action_mask"] + logits = self.pointer(glimpse_q, glimpse_k, glimpse_v, logit_k, mask) - # Now we need to reshape the logits and log_p to [B*S,N,...] is num_starts > 1 without dynamic embeddings + # Now we need to reshape the logits and logits to [B*S,N,...] is num_starts > 1 without dynamic embeddings # note that rearranging order is important here if num_starts > 1 and not has_dyn_emb_multi_start: - log_p = rearrange(log_p, "b s l -> (s b) l", s=num_starts) + logits = rearrange(logits, "b s l -> (s b) l", s=num_starts) mask = rearrange(mask, "b s l -> (s b) l", s=num_starts) - return log_p, mask + return logits, mask def evaluate_action( self, @@ -305,15 +324,21 @@ def evaluate_action( # Compute keys, values for the glimpse and keys for the logits once as they can be reused in every step cached_embeds = self._precompute_cache(embeddings) - log_p = [] + probs = [] decode_step = 0 while not td["done"].all(): - log_p_, _ = self._get_log_p(cached_embeds, td) + logits, _ = self._get_logits(cached_embeds, td) + probs_ = logits_to_probs( + logits, + td["action_mask"], + tanh_clipping=self.tanh_clipping, + mask_logits=self.mask_logits, + ) action_ = action[..., decode_step] td.set("action", action_) td = env.step(td)["next"] - log_p.append(log_p_) + probs.append(probs_) decode_step += 1 @@ -321,7 +346,7 @@ def evaluate_action( # due to the padded zeros in the actions # Compute log likelihood of the actions - log_p = torch.stack(log_p, 1) # [batch_size, decoding steps, num_nodes] + log_p = torch.stack(probs, 1).log() # [batch_size, decoding steps, num_nodes] ll = get_log_likelihood( log_p, action[..., :decode_step], mask=None, return_sum=False ) # [batch_size, decoding steps] diff --git a/rl4co/models/zoo/common/autoregressive/policy.py b/rl4co/models/zoo/common/autoregressive/policy.py index 6da9cd18..1bd497f6 100644 --- a/rl4co/models/zoo/common/autoregressive/policy.py +++ b/rl4co/models/zoo/common/autoregressive/policy.py @@ -44,6 +44,9 @@ class AutoregressivePolicy(nn.Module): train_decode_type: Type of decoding during training val_decode_type: Type of decoding during validation test_decode_type: Type of decoding during testing + temperature: Temperature for the softmax in the decoder + tanh_clipping: Clipping value for the tanh in the decoder + mask_logits: Whether to mask the logits in the decoder **unused_kw: Unused keyword arguments """ @@ -65,6 +68,9 @@ def __init__( train_decode_type: str = "sampling", val_decode_type: str = "greedy", test_decode_type: str = "greedy", + temperature: float = 1.0, + tanh_clipping: float = 10.0, + mask_logits: bool = True, **unused_kw, ): super(AutoregressivePolicy, self).__init__() @@ -100,6 +106,9 @@ def __init__( mask_inner=mask_inner, context_embedding=context_embedding, dynamic_embedding=dynamic_embedding, + temperature=temperature, + tanh_clipping=tanh_clipping, + mask_logits=mask_logits, ) else: self.decoder = decoder diff --git a/rl4co/models/zoo/common/nonautoregressive/decoder.py b/rl4co/models/zoo/common/nonautoregressive/decoder.py index 3c50392d..ad85fec3 100644 --- 
a/rl4co/models/zoo/common/nonautoregressive/decoder.py +++ b/rl4co/models/zoo/common/nonautoregressive/decoder.py @@ -60,8 +60,8 @@ def forward(self, graph: Batch) -> Tensor: # type: ignore edge_attr = self.act(layer(edge_attr)) graph.edge_attr = torch.sigmoid(self.output(edge_attr)) * 10 # type: ignore - heatmaps_logp = self._make_heatmaps(graph) - return heatmaps_logp + heatmaps_logits = self._make_heatmaps(graph) + return heatmaps_logits def _make_heatmaps(self, batch_graph: Batch) -> Tensor: # type: ignore graphs = batch_graph.to_data_list() @@ -69,7 +69,7 @@ def _make_heatmaps(self, batch_graph: Batch) -> Tensor: # type: ignore batch_size = len(graphs) num_nodes = graphs[0].x.shape[0] - heatmaps_logp = torch.zeros( + heatmaps_logits = torch.zeros( (batch_size, num_nodes, num_nodes), device=device, dtype=graphs[0].edge_attr.dtype, @@ -77,12 +77,12 @@ def _make_heatmaps(self, batch_graph: Batch) -> Tensor: # type: ignore for index, graph in enumerate(graphs): edge_index, edge_attr = graph.edge_index, graph.edge_attr - heatmaps_logp[index, edge_index[0], edge_index[1]] = edge_attr.flatten() + heatmaps_logits[index, edge_index[0], edge_index[1]] = edge_attr.flatten() if self.undirected_graph: - heatmaps_logp = (heatmaps_logp + heatmaps_logp.transpose(1, 2)) * 0.5 + heatmaps_logits = (heatmaps_logits + heatmaps_logits.transpose(1, 2)) * 0.5 - return heatmaps_logp + return heatmaps_logits class NonAutoregressiveDecoder(nn.Module): @@ -146,7 +146,7 @@ def forward( env = get_env(env_name) # calculate heatmap - heatmaps_logp = self.heatmap_generator(graph) + heatmaps_logits = self.heatmap_generator(graph) # setup decoding strategy self.decode_strategy: DecodingStrategy = get_decoding_strategy( @@ -156,8 +156,8 @@ def forward( # Main decoding: loop until all sequences are done while not td["done"].all(): - log_p, mask = self._get_log_p(td, heatmaps_logp, num_starts) - td = self.decode_strategy.step(log_p, mask, td) + logits, mask = self._get_logits(td, heatmaps_logits, num_starts) + td = self.decode_strategy.step(logits, mask, td) td = env.step(td)["next"] outputs, actions, td, env = self.decode_strategy.post_decoder_hook(td, env) @@ -168,21 +168,18 @@ def forward( return outputs, actions, td @classmethod - def _get_log_p(cls, td: TensorDict, heatmaps_logp: Tensor, num_starts: int): + def _get_logits(cls, td: TensorDict, heatmaps_logits: Tensor, num_starts: int): # Get the mask - mask = ~td["action_mask"] + action_mask = td["action_mask"] current_action = td.get("action", None) if current_action is None: - log_p = heatmaps_logp.mean(-1) + logits = heatmaps_logits.mean(-1) else: - batch_size = heatmaps_logp.shape[0] + batch_size = heatmaps_logits.shape[0] _indexer = cls._multistart_batched_index(batch_size, num_starts) - log_p = heatmaps_logp[_indexer, current_action, :] - - log_p[mask] = -torch.inf - log_p = nn.functional.log_softmax(log_p, -1) - return log_p, mask + logits = heatmaps_logits[_indexer, current_action, :] + return logits, action_mask @staticmethod @lru_cache(10) diff --git a/rl4co/models/zoo/deepaco/antsystem.py b/rl4co/models/zoo/deepaco/antsystem.py index 2b6056d9..088b428a 100644 --- a/rl4co/models/zoo/deepaco/antsystem.py +++ b/rl4co/models/zoo/deepaco/antsystem.py @@ -118,14 +118,15 @@ def _sampling( ): # Sample from heatmaps # p = phe**alpha * heu**beta <==> log(p) = alpha*log(phe) + beta*log(heu) - heatmaps_logp = ( + heatmaps_logits = ( self.alpha * torch.log(self.pheromone) + self.beta * self.log_heuristic ) self.decode_strategy = Sampling(multistart=True, 
num_starts=self.n_ants) td, env, num_starts = self.decode_strategy.pre_decoder_hook(td, env) while not td["done"].all(): - log_p, mask = NARDecoder._get_log_p(td, heatmaps_logp, num_starts) - td = self.decode_strategy.step(log_p, mask, td) + logits, mask = NARDecoder._get_logits(td, heatmaps_logits, num_starts) + td = self.decode_strategy.step(logits, mask, td) + # TODO: check, do we need to run the logits normalization via step here? td = env.step(td)["next"] outputs, actions, td, env = self.decode_strategy.post_decoder_hook(td, env) diff --git a/rl4co/models/zoo/eas/decoder.py b/rl4co/models/zoo/eas/decoder.py index 8a3deaec..e3ce5197 100644 --- a/rl4co/models/zoo/eas/decoder.py +++ b/rl4co/models/zoo/eas/decoder.py @@ -7,14 +7,13 @@ from tensordict import TensorDict from rl4co.envs import RL4COEnvBase +from rl4co.models.nn.dec_strategies import logits_to_probs from rl4co.models.nn.utils import decode_probs from rl4co.utils.ops import batchify, unbatchify -def forward_logit_attn_eas_lay( - self, query, key, value, logit_key, mask, softmax_temp=None -): - """Add layer to the forward pass of logit attention, i.e. +def forward_pointer_attn_eas_lay(self, query, key, value, logit_key, mask): + """Add layer to the forward pass of pointer attention, i.e. Single-head attention. """ # Compute inner multi-head attention with no projections. @@ -33,20 +32,6 @@ def forward_logit_attn_eas_lay( / math.sqrt(glimpse.size(-1)) ).squeeze(1) - # From the logits compute the probabilities by clipping, masking and softmax - if self.tanh_clipping > 0: - logits = torch.tanh(logits) * self.tanh_clipping - - if self.mask_logits: - logits[mask] = float("-inf") - - # Normalize with softmax and apply temperature - if self.normalize: - softmax_temp = softmax_temp if softmax_temp is not None else self.softmax_temp - logits = torch.log_softmax(logits / softmax_temp, dim=-1) - - assert not torch.isnan(logits).any(), "Logits contain NaNs" - return logits @@ -59,7 +44,7 @@ def forward_eas( env: Union[str, RL4COEnvBase] = None, decode_type: str = "multistart_sampling", num_starts: int = None, - softmax_temp: float = None, + temperature: float = None, **unused_kwargs, ): """Forward pass of the decoder @@ -76,7 +61,7 @@ def forward_eas( - "multistart_sampling": sample as sampling, but with multi-start decoding - "multistart_greedy": sample as greedy, but with multi-start decoding num_starts: Number of multi-starts to use. If None, will be calculated from the action mask - softmax_temp: Temperature for the softmax. If None, default softmax is used from the `LogitAttention` module + temperature: Temperature for the softmax. 
If None, default softmax is used from the `PointerAttention` module calc_reward: Whether to calculate the reward for the decoded sequence """ @@ -99,20 +84,26 @@ def forward_eas( td.set("action", action) td = env.step(td)["next"] - log_p = torch.zeros_like( - td["action_mask"], device=td.device - ) # first log_p is 0, so p = log_p.exp() = 1 + probs = torch.ones_like(td["action_mask"], device=td.device) - outputs.append(log_p) + outputs.append(probs) actions.append(action) # Main decoding: loop until all sequences are done while not td["done"].all(): decode_step += 1 - log_p, mask = self._get_log_p(cached_embeds, td, softmax_temp, num_starts + 1) + logits, _ = self._get_logits(cached_embeds, td, num_starts + 1) + temperature = temperature if temperature is not None else self.temperature + probs = logits_to_probs( + logits, + mask=td["action_mask"], + temperature=temperature, + tanh_clipping=self.tanh_clipping, + mask_logits=self.mask_logits, + ) # Select the indices of the next nodes in the sequences, result (batch_size) long - action = decode_probs(log_p.exp(), mask, decode_type=decode_type) + action = decode_probs(probs, td["action_mask"], decode_type=decode_type) if iter_count > 0: # append incumbent solutions init_shp = action.shape @@ -124,9 +115,10 @@ def forward_eas( td = env.step(td)["next"] # Collect output of step - outputs.append(log_p) + outputs.append(probs) actions.append(action) - outputs, actions = torch.stack(outputs, 1), torch.stack(actions, 1) + # Note: we convert outputs (probs) to log-probs here + outputs, actions = torch.stack(outputs, 1).log(), torch.stack(actions, 1) rewards = env.get_reward(td, actions) return outputs, actions, td, rewards diff --git a/rl4co/models/zoo/eas/search.py b/rl4co/models/zoo/eas/search.py index a1d4db67..cff87d0e 100644 --- a/rl4co/models/zoo/eas/search.py +++ b/rl4co/models/zoo/eas/search.py @@ -12,7 +12,7 @@ from rl4co.data.transforms import StateAugmentation from rl4co.models.nn.utils import get_log_likelihood from rl4co.models.zoo.common.search import SearchBase -from rl4co.models.zoo.eas.decoder import forward_eas, forward_logit_attn_eas_lay +from rl4co.models.zoo.eas.decoder import forward_eas, forward_pointer_attn_eas_lay from rl4co.models.zoo.eas.nn import EASLayerNet from rl4co.utils.ops import batchify, gather_by_index, unbatchify from rl4co.utils.pylogger import get_pylogger @@ -166,7 +166,9 @@ def training_step(self, batch, batch_idx): # EASLay: replace forward of logit attention computation. 
EASLayer eas_layer = EASLayerNet(num_instances, decoder.embedding_dim).to(batch.device) decoder.pointer.eas_layer = partial(eas_layer, decoder.pointer) - decoder.pointer.forward = partial(forward_logit_attn_eas_lay, decoder.pointer) + decoder.pointer.forward = partial( + forward_pointer_attn_eas_lay, decoder.pointer + ) for param in eas_layer.parameters(): opt_params.append(param) if self.hparams.use_eas_embedding: diff --git a/rl4co/models/zoo/matnet/decoder.py b/rl4co/models/zoo/matnet/decoder.py index 9bb64e9c..2ddf09d0 100644 --- a/rl4co/models/zoo/matnet/decoder.py +++ b/rl4co/models/zoo/matnet/decoder.py @@ -7,6 +7,7 @@ from tensordict import TensorDict from torch import Tensor +from rl4co.models.nn.dec_strategies import logits_to_probs from rl4co.models.nn.env_embeddings.context import FFSPContext from rl4co.models.zoo.common.autoregressive.decoder import AutoregressiveDecoder @@ -53,7 +54,7 @@ def __init__( embedding_dim: int, num_heads: int, use_graph_context: bool = False, - **logit_attn_kwargs, + **kwargs, ): context_embedding = FFSPContext(embedding_dim) @@ -63,7 +64,7 @@ def __init__( num_heads=num_heads, use_graph_context=use_graph_context, context_embedding=context_embedding, - **logit_attn_kwargs, + **kwargs, ) self.no_job_emb = nn.Parameter( @@ -117,13 +118,13 @@ def __init__( embedding_dim: int, num_heads: int, use_graph_context: bool = True, - **logit_attn_kwargs, + **kwargs, ): super().__init__( embedding_dim=embedding_dim, num_heads=num_heads, use_graph_context=use_graph_context, - **logit_attn_kwargs, + **kwargs, ) self.cached_embs: PrecomputedCache = None # self.encoded_wait_op = nn.Parameter(torch.rand((1, 1, embedding_dim))) @@ -136,13 +137,21 @@ def forward( td: TensorDict, decode_type="sampling", num_starts: int = 1, - softmax_temp: float = None, + temperature: float = None, ) -> Tuple[Tensor, Tensor, TensorDict]: device = td.device batch_size = td.size(0) - log_p, _ = self._get_log_p(self.cached_embs, td, softmax_temp, num_starts) - all_job_probs = log_p.exp() + # Get logits for each job and probabilities + logits, _ = self._get_logits(self.cached_embs, td, num_starts) + temperature = temperature if temperature is not None else self.temperature + all_job_probs = logits_to_probs( + logits, + mask=td["action_mask"], + temperature=temperature, + tanh_clipping=self.tanh_clipping, + mask_logits=self.mask_logits, + ) if "sampling" in decode_type: # to fix pytorch.multinomial bug on selecting 0 probability elements diff --git a/rl4co/models/zoo/mdam/decoder.py b/rl4co/models/zoo/mdam/decoder.py index f74a5788..8399e838 100644 --- a/rl4co/models/zoo/mdam/decoder.py +++ b/rl4co/models/zoo/mdam/decoder.py @@ -10,7 +10,7 @@ from tensordict import TensorDict from rl4co.envs import RL4COEnvBase -from rl4co.models.nn.attention import LogitAttention +from rl4co.models.nn.attention import PointerAttention from rl4co.models.nn.env_embeddings import env_context_embedding, env_dynamic_embedding from rl4co.models.nn.utils import decode_probs, get_log_likelihood @@ -88,7 +88,7 @@ def __init__( # MHA with Pointer mechanism (https://arxiv.org/abs/1506.03134) self.pointer = [ - LogitAttention( + PointerAttention( embedding_dim, num_heads, mask_inner=mask_inner, diff --git a/rl4co/models/zoo/ptrnet/decoder.py b/rl4co/models/zoo/ptrnet/decoder.py index f2a69137..3b7cbbfb 100644 --- a/rl4co/models/zoo/ptrnet/decoder.py +++ b/rl4co/models/zoo/ptrnet/decoder.py @@ -157,7 +157,8 @@ def forward( ) # select the next inputs for the decoder [batch_size x hidden_dim] idxs = ( - 
decode_probs(probs, mask, decode_type=decode_type) + # note: `mask` here is the inverse of the usual action mask, hence the negation + decode_probs(probs, ~mask, decode_type=decode_type) if eval_tours is None else eval_tours[:, i] )
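
Note for reviewers: the decoder now imports a `logits_to_probs` helper from `rl4co.models.nn.dec_strategies`, but its body is not part of this diff. As a rough reference for what `evaluate_action`, `forward_eas`, and the MatNet decoder expect from it, here is a minimal sketch (assumptions: boolean action masks with True = feasible, and an optional `top_p` argument for the nucleus sampling mentioned in the commit title; the real helper may differ in signature and details):

import torch

def logits_to_probs_sketch(
    logits: torch.Tensor,   # [batch, num_actions]
    mask: torch.Tensor,     # [batch, num_actions], True = feasible action
    temperature: float = 1.0,
    tanh_clipping: float = 10.0,
    mask_logits: bool = True,
    top_p: float = 0.0,     # nucleus threshold; 0 disables filtering
) -> torch.Tensor:
    # Tanh clipping (Bello et al., 2016) keeps logits in [-c, c]
    if tanh_clipping > 0:
        logits = torch.tanh(logits) * tanh_clipping
    # Mask infeasible actions before normalizing
    if mask_logits:
        logits = logits.masked_fill(~mask, float("-inf"))
    # Temperature-scaled softmax
    probs = torch.softmax(logits / temperature, dim=-1)
    # Nucleus (top-p) filtering: keep the smallest set of actions whose
    # cumulative probability exceeds top_p, then renormalize
    if top_p > 0:
        sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
        cum_prev = sorted_probs.cumsum(dim=-1) - sorted_probs
        sorted_probs = sorted_probs.masked_fill(cum_prev > top_p, 0.0)
        probs = torch.zeros_like(probs).scatter_(-1, sorted_idx, sorted_probs)
        probs = probs / probs.sum(dim=-1, keepdim=True)
    return probs

Calling it with `top_p=0` reproduces the plain clipped/masked/temperature softmax that the call sites in this patch rely on.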
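
Relatedly, `decode_strategy.step(logits, mask, td)` now receives raw (unnormalized) logits plus the action mask, so the normalization that previously lived in the pointer attention forward has moved into the decoding strategies together with `temperature`, `tanh_clipping`, and `mask_logits`. A toy, self-contained stand-in for a sampling strategy under that contract (not the actual `rl4co.models.nn.dec_strategies.Sampling` class, whose internals are outside this diff):

import torch
from tensordict import TensorDict

class SamplingSketch:
    """Illustrative only: normalizes logits itself, then samples an action."""

    def __init__(self, temperature: float = 1.0, tanh_clipping: float = 10.0, mask_logits: bool = True):
        self.temperature = temperature
        self.tanh_clipping = tanh_clipping
        self.mask_logits = mask_logits
        self.log_probs = []
        self.actions = []

    def step(self, logits: torch.Tensor, mask: torch.Tensor, td: TensorDict) -> TensorDict:
        # The strategy, not the decoder, turns raw logits into probabilities
        if self.tanh_clipping > 0:
            logits = torch.tanh(logits) * self.tanh_clipping
        if self.mask_logits:
            logits = logits.masked_fill(~mask, float("-inf"))
        probs = torch.softmax(logits / self.temperature, dim=-1)

        # Sample and keep the chosen action's log-probability for the final log-likelihood
        action = probs.multinomial(1).squeeze(-1)
        self.log_probs.append(probs.gather(-1, action.unsqueeze(-1)).log().squeeze(-1))
        self.actions.append(action)

        td.set("action", action)
        return td

This mirrors why `_get_logits` can simply return `(logits, td["action_mask"])`: masking, clipping, temperature scaling, and (optionally) nucleus filtering all happen once, inside the strategy's step.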