[Feat] Decoding refactoring #152

Closed
wants to merge 7 commits into from
2 changes: 1 addition & 1 deletion configs/experiment/routing/am-xl.yaml
@@ -14,7 +14,7 @@ logger:
wandb:
project: "rl4co"
tags: ["am", "${env.name}"]
-group: ${env.name}${env.num_loc}"
+group: "${env.name}${env.num_loc}"
name: "am-xl-${env.name}${env.num_loc}"

model:
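The same one-character fix is applied to pomo.yaml and symnco.yaml below: the old `group:` line was an unquoted plain scalar with a stray trailing quote, so that quote ended up inside the W&B group name instead of delimiting it. A minimal sketch of the difference, parsing the two variants with plain PyYAML (Hydra/OmegaConf would additionally resolve the `${...}` interpolations; the snippet is illustrative only):

import yaml  # PyYAML keeps ${...} interpolation strings verbatim instead of resolving them

# Old line: plain scalar, so the trailing quote becomes part of the value
old = yaml.safe_load('group: ${env.name}${env.num_loc}"')
print(old["group"])   # ${env.name}${env.num_loc}"  <- stray quote inside the group name

# Fixed line: a properly quoted scalar
new = yaml.safe_load('group: "${env.name}${env.num_loc}"')
print(new["group"])   # ${env.name}${env.num_loc}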
2 changes: 1 addition & 1 deletion configs/experiment/routing/pomo.yaml
@@ -14,7 +14,7 @@ logger:
wandb:
project: "rl4co"
tags: ["pomo", "${env.name}"]
-group: ${env.name}${env.num_loc}"
+group: "${env.name}${env.num_loc}"
name: "pomo-${env.name}${env.num_loc}"

model:
2 changes: 1 addition & 1 deletion configs/experiment/routing/symnco.yaml
@@ -14,7 +14,7 @@ logger:
wandb:
project: "rl4co"
tags: ["symnco", "${env.name}"]
-group: ${env.name}${env.num_loc}"
+group: "${env.name}${env.num_loc}"
name: "symnco-${env.name}${env.num_loc}"

model:
2 changes: 1 addition & 1 deletion rl4co/__init__.py
@@ -1 +1 @@
__version__ = "0.4.0dev0"
__version__ = "0.4.0dev1"
90 changes: 49 additions & 41 deletions rl4co/models/nn/attention.py
@@ -1,4 +1,5 @@
import math
+import warnings

from typing import Callable, Optional

@@ -129,97 +130,104 @@ def forward(self, x, key_padding_mask=None):
return self.out_proj(rearrange(out, "b h s d -> b s (h d)"))


-class LogitAttention(nn.Module):
+class PointerAttention(nn.Module):
"""Calculate logits given query, key and value and logit key.
This follows the pointer mechanism of Vinyals et al. (2015) (https://arxiv.org/abs/1506.03134).

Note:
With Flash Attention, masking is not supported

-Perform the following:
+Performs the following:
1. Apply cross attention to get the heads
2. Project heads to get glimpse
3. Compute attention score between glimpse and logit key
4. Normalize and mask

Args:
embed_dim: total dimension of the model
num_heads: number of heads
tanh_clipping: tanh clipping value
mask_inner: whether to mask inner attention
mask_logits: whether to mask logits
normalize: whether to normalize logits
softmax_temp: softmax temperature
linear_bias: whether to use bias in linear projection
sdp_fn: scaled dot product attention function (SDPA)
check_nan: whether to check for NaNs in logits
"""

def __init__(
self,
embed_dim: int,
num_heads: int,
tanh_clipping: float = 10.0,
mask_inner: bool = True,
mask_logits: bool = True,
normalize: bool = True,
softmax_temp: float = 1.0,
out_bias: bool = False,
sdp_fn=scaled_dot_product_attention,
check_nan: bool = True,
**unused_kwargs,
):
-super(LogitAttention, self).__init__()
+super(PointerAttention, self).__init__()
self.num_heads = num_heads
self.mask_logits = mask_logits
self.mask_inner = mask_inner
self.tanh_clipping = tanh_clipping
self.normalize = normalize
self.softmax_temp = softmax_temp

# Projection - query, key, value already include projections
self.project_out = nn.Linear(embed_dim, embed_dim, bias=out_bias)
self.sdp_fn = sdp_fn

-def forward(self, query, key, value, logit_key, mask, softmax_temp=None):
+self.check_nan = check_nan

+# Check unused kwargs
+if unused_kwargs:
+log.warning(f"Unused kwargs: {unused_kwargs}")

+def forward(self, query, key, value, logit_key, attn_mask=None):
"""Compute attention logits given query, key, value, logit key and attention mask.

Args:
query: query tensor of shape [B, ..., L, E]
key: key tensor of shape [B, ..., S, E]
value: value tensor of shape [B, ..., S, E]
logit_key: logit key tensor of shape [B, ..., S, E]
attn_mask: attention mask tensor of shape [B, ..., S]. Note that `True` means that the value _should_ take part in attention
as described in the [PyTorch Documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
"""
# Compute inner multi-head attention with no projections.
-heads = self._inner_mha(query, key, value, mask)
+heads = self._inner_mha(query, key, value, attn_mask)
glimpse = self.project_out(heads)

# Batch matrix multiplication to compute logits (batch_size, num_steps, graph_size)
# bmm is slightly faster than einsum and matmul
-logits = (
-torch.bmm(glimpse, logit_key.squeeze(1).transpose(-2, -1))
-/ math.sqrt(glimpse.size(-1))
-).squeeze(1)

-# From the logits compute the probabilities by clipping, masking and softmax
-if self.tanh_clipping > 0:
-logits = torch.tanh(logits) * self.tanh_clipping

-if self.mask_logits:
-logits[mask] = float("-inf")

-# Normalize with softmax and apply temperature
-if self.normalize:
-softmax_temp = softmax_temp if softmax_temp is not None else self.softmax_temp
-logits = torch.log_softmax(logits / softmax_temp, dim=-1)
+logits = (torch.bmm(glimpse, logit_key.squeeze(-2).transpose(-2, -1))).squeeze(
+-2
+) / math.sqrt(glimpse.size(-1))

-assert not torch.isnan(logits).any(), "Logits contain NaNs"
+if self.check_nan:
+assert not torch.isnan(logits).any(), "Logits contain NaNs"

return logits

-def _inner_mha(self, query, key, value, mask):
+def _inner_mha(self, query, key, value, attn_mask):
q = self._make_heads(query)
k = self._make_heads(key)
v = self._make_heads(value)

if self.mask_inner:
-# need to invert mask: (N L S) -> (N 1 L S)
+# make mask the same number of dimensions as q
attn_mask = (
-~mask.unsqueeze(1) if mask.ndim == 3 else ~mask.unsqueeze(1).unsqueeze(2)
+attn_mask.unsqueeze(1)
+if attn_mask.ndim == 3
+else attn_mask.unsqueeze(1).unsqueeze(2)
)
else:
attn_mask = None

heads = self.sdp_fn(q, k, v, attn_mask=attn_mask)
return rearrange(heads, "... h n g -> ... n (h g)", h=self.num_heads)

def _make_heads(self, v):
return rearrange(v, "... g (h s) -> ... h g s", h=self.num_heads)


# Deprecated
class LogitAttention(PointerAttention):
def __init__(self, *args, **kwargs):
warnings.simplefilter("always", DeprecationWarning)
warnings.warn(
"LogitAttention is deprecated and will be removed in a future release. "
"Please use PointerAttention instead."
"Note that several components of the previous LogitAttention have moved to `rl4co.models.nn.dec_strategies`.",
category=DeprecationWarning,
)
super(LogitAttention, self).__init__(*args, **kwargs)
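
For reference, a minimal usage sketch of the renamed PointerAttention, based only on the signatures and docstring shapes in this diff; the batch size, number of nodes, and the masked node are illustrative values, not part of the PR. Note the flipped mask semantics: attn_mask follows PyTorch's scaled_dot_product_attention convention, where True marks entries that may be attended, whereas the old mask argument marked entries to exclude and was inverted internally.

import torch

from rl4co.models.nn.attention import PointerAttention

batch_size, num_nodes, embed_dim, num_heads = 2, 20, 128, 8
attn = PointerAttention(embed_dim=embed_dim, num_heads=num_heads)

# One decoding step: a single query attends over all node embeddings.
query = torch.randn(batch_size, 1, embed_dim)              # [B, L=1, E]
key = torch.randn(batch_size, num_nodes, embed_dim)        # [B, S, E]
value = torch.randn(batch_size, num_nodes, embed_dim)      # [B, S, E]
logit_key = torch.randn(batch_size, num_nodes, embed_dim)  # [B, S, E]

# True = may be attended (SDPA convention); forbid node 0 as an example.
attn_mask = torch.ones(batch_size, num_nodes, dtype=torch.bool)
attn_mask[:, 0] = False

logits = attn(query, key, value, logit_key, attn_mask=attn_mask)  # [B, S]

Tanh clipping, logit masking, and softmax normalization are no longer applied inside this module; per the deprecation message above, those steps now live with the decoding strategies in rl4co.models.nn.dec_strategies.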
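The backward-compatibility shim keeps the old class importable while steering users toward the new name. A quick check of that behavior, assuming both names are exported from rl4co.models.nn.attention as in this diff:

import warnings

from rl4co.models.nn.attention import LogitAttention, PointerAttention

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    old_style = LogitAttention(embed_dim=128, num_heads=8)

# The old name still works: it is now a thin subclass of PointerAttention ...
assert isinstance(old_style, PointerAttention)
# ... but constructing it emits a DeprecationWarning pointing at the new class.
assert any(issubclass(w.category, DeprecationWarning) for w in caught)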