[Feat] FlashAttention2; remove unused kwarg
fedebotu committed Sep 19, 2023
1 parent 82ef231 commit 16b1cb0
Showing 6 changed files with 46 additions and 35 deletions.
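The commit swaps the old `force_flash_attn` flag for an `sdpa_fn` argument: instead of forcing a half-precision FlashAttention path, callers can now inject any scaled dot-product attention callable. Below is a minimal sketch of such a callable. It assumes `sdpa_fn` is invoked with query/key/value tensors (plus an optional mask) the way `torch.nn.functional.scaled_dot_product_attention` is; the exact calling contract is not spelled out in this diff. PyTorch's fused SDPA already dispatches to FlashAttention-2 kernels where the dtype, device, and mask allow it, and a wrapper around `flash_attn.flash_attn_func` with the same call shape would be another option.

```python
# Sketch only: assumes rl4co calls sdpa_fn like
# torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=...).
import torch.nn.functional as F


def my_sdpa(q, k, v, attn_mask=None, dropout_p=0.0):
    """Delegate to PyTorch's fused SDPA, which selects a FlashAttention-2
    kernel automatically when the inputs and hardware support it."""
    return F.scaled_dot_product_attention(
        q, k, v, attn_mask=attn_mask, dropout_p=dropout_p
    )
```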
18 changes: 7 additions & 11 deletions rl4co/models/nn/graph/attnnet.py
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Callable, Optional

import torch.nn as nn

@@ -19,7 +19,7 @@ class MultiHeadAttentionLayer(nn.Sequential):
embed_dim: dimension of the embeddings
feed_forward_hidden: dimension of the hidden layer in the feed-forward layer
normalization: type of normalization to use (batch, layer, none)
force_flash_attn: whether to force FlashAttention (move to half precision)
sdpa_fn: scaled dot product attention function (SDPA)
"""

def __init__(
@@ -28,14 +28,10 @@ def __init__(
embed_dim: int,
feed_forward_hidden: int = 512,
normalization: Optional[str] = "batch",
force_flash_attn: bool = False,
sdpa_fn: Optional[Callable] = None,
):
super(MultiHeadAttentionLayer, self).__init__(
SkipConnection(
MultiHeadAttention(
embed_dim, num_heads, force_flash_attn=force_flash_attn
)
),
SkipConnection(MultiHeadAttention(embed_dim, num_heads, sdpa_fn=sdpa_fn)),
Normalization(embed_dim, normalization),
SkipConnection(
nn.Sequential(
@@ -60,7 +56,7 @@ class GraphAttentionNetwork(nn.Module):
num_layers: number of MHA layers
normalization: type of normalization to use (batch, layer, none)
feed_forward_hidden: dimension of the hidden layer in the feed-forward layer
force_flash_attn: whether to force FlashAttention (move to half precision)
sdpa_fn: scaled dot product attention function (SDPA)
"""

def __init__(
@@ -70,7 +66,7 @@ def __init__(
num_layers: int,
normalization: str = "batch",
feed_forward_hidden: int = 512,
force_flash_attn: bool = False,
sdpa_fn: Optional[Callable] = None,
):
super(GraphAttentionNetwork, self).__init__()

@@ -81,7 +77,7 @@ def __init__(
embedding_dim,
feed_forward_hidden=feed_forward_hidden,
normalization=normalization,
force_flash_attn=force_flash_attn,
sdpa_fn=sdpa_fn,
)
for _ in range(num_layers)
)
8 changes: 4 additions & 4 deletions rl4co/models/rl/common/critic.py
@@ -1,4 +1,4 @@
from typing import Union
from typing import Callable, Optional, Union

from tensordict import TensorDict
from torch import Tensor, nn
@@ -20,7 +20,7 @@ class CriticNetwork(nn.Module):
num_layers: Number of layers for the encoder
num_heads: Number of heads for the attention
normalization: Normalization to use for the attention
force_flash_attn: Whether to force the use of flash attention. If True, cast to fp16
sdpa_fn: Scaled dot product function to use for the attention
"""

def __init__(
@@ -32,7 +32,7 @@ def __init__(
num_layers: int = 3,
num_heads: int = 8,
normalization: str = "batch",
force_flash_attn: bool = False,
sdpa_fn: Optional[Callable] = None,
**unused_kwargs,
):
super(CriticNetwork, self).__init__()
@@ -51,7 +51,7 @@ def __init__(
num_layers=num_layers,
normalization=normalization,
feed_forward_hidden=hidden_dim,
force_flash_attn=force_flash_attn,
sdpa_fn=sdpa_fn,
)
if encoder is None
else encoder
33 changes: 23 additions & 10 deletions rl4co/models/zoo/common/autoregressive/decoder.py
@@ -3,6 +3,7 @@

import torch
import torch.nn as nn

from einops import rearrange
from tensordict import TensorDict
from torch import Tensor
@@ -57,7 +58,7 @@ class AutoregressiveDecoder(nn.Module):

def __init__(
self,
env_name: str,
env_name: [str, RL4COEnvBase],
embedding_dim: int,
num_heads: int,
use_graph_context: bool = True,
@@ -69,6 +70,8 @@ def __init__(
):
super().__init__()

if isinstance(env_name, RL4COEnvBase):
env_name = env_name.name
self.env_name = env_name
self.embedding_dim = embedding_dim
self.num_heads = num_heads
@@ -88,11 +91,17 @@ def __init__(
self.use_graph_context = use_graph_context

# For each node we compute (glimpse key, glimpse value, logit key) so 3 * embedding_dim
self.project_node_embeddings = nn.Linear(embedding_dim, 3 * embedding_dim, bias=linear_bias)
self.project_fixed_context = nn.Linear(embedding_dim, embedding_dim, bias=linear_bias)
self.project_node_embeddings = nn.Linear(
embedding_dim, 3 * embedding_dim, bias=linear_bias
)
self.project_fixed_context = nn.Linear(
embedding_dim, embedding_dim, bias=linear_bias
)

# MHA
self.logit_attention = LogitAttention(embedding_dim, num_heads, **logit_attn_kwargs)
self.logit_attention = LogitAttention(
embedding_dim, num_heads, **logit_attn_kwargs
)

self.select_start_nodes_fn = select_start_nodes_fn

@@ -136,7 +145,9 @@ def forward(
else:
if num_starts is not None:
if num_starts > 1:
log.warn(f"num_starts={num_starts} is ignored for decode_type={decode_type}")
log.warn(
f"num_starts={num_starts} is ignored for decode_type={decode_type}"
)
num_starts = 0

# Compute keys, values for the glimpse and keys for the logits once as they can be reused in every step
@@ -187,7 +198,9 @@ def forward(

return outputs, actions, td

def _precompute_cache(self, embeddings: Tensor, num_starts: int = 0, td: TensorDict = None):
def _precompute_cache(
self, embeddings: Tensor, num_starts: int = 0, td: TensorDict = None
):
"""Compute the cached embeddings for the attention
Args:
@@ -202,9 +215,7 @@ def _precompute_cache(self, embeddings: Tensor, num_starts: int = 0, td: TensorD
glimpse_key_fixed,
glimpse_val_fixed,
logit_key_fixed,
) = self.project_node_embeddings(
embeddings
).chunk(3, dim=-1)
) = self.project_node_embeddings(embeddings).chunk(3, dim=-1)

# Optionally disable the graph context from the initial embedding as done in POMO
if self.use_graph_context:
@@ -262,7 +273,9 @@ def _get_log_p(
mask = ~td_unbatch["action_mask"]

# Compute logits
log_p = self.logit_attention(glimpse_q, glimpse_k, glimpse_v, logit_k, mask, softmax_temp)
log_p = self.logit_attention(
glimpse_q, glimpse_k, glimpse_v, logit_k, mask, softmax_temp
)

# Now we need to reshape the logits and log_p to [batch_size*num_starts, num_nodes]
# Note that rearranging order is important here
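With this change the decoder accepts either an environment name or an environment instance; when an `RL4COEnvBase` is passed, only its `name` is kept. A hedged usage sketch follows; the `TSPEnv` import path, the `"tsp"` name, and the keyword values beyond those visible in the diff are assumptions.

```python
from rl4co.envs import TSPEnv  # assumed concrete env; any RL4COEnvBase should do
from rl4co.models.zoo.common.autoregressive.decoder import AutoregressiveDecoder

# Both forms should now be equivalent: when an env instance is passed,
# the decoder stores only env.name internally.
dec_from_name = AutoregressiveDecoder(env_name="tsp", embedding_dim=128, num_heads=8)
dec_from_env = AutoregressiveDecoder(env_name=TSPEnv(), embedding_dim=128, num_heads=8)
```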
11 changes: 7 additions & 4 deletions rl4co/models/zoo/common/autoregressive/encoder.py
@@ -5,6 +5,7 @@
from tensordict import TensorDict
from torch import Tensor

from rl4co.envs import RL4COEnvBase
from rl4co.models.nn.env_embeddings import env_init_embedding
from rl4co.models.nn.graph.attnnet import GraphAttentionNetwork

@@ -19,23 +20,25 @@ class GraphAttentionEncoder(nn.Module):
num_layers: Number of layers for the encoder
normalization: Normalization to use for the attention
feed_forward_hidden: Hidden dimension for the feed-forward network
force_flash_attn: Whether to force the use of flash attention. If True, cast to fp16
init_embedding: Model to use for the initial embedding. If None, use the default embedding for the environment
sdpa_fn: Scaled dot product function to use for the attention
"""

def __init__(
self,
env_name: str,
env_name: [str, RL4COEnvBase],
num_heads: int,
embedding_dim: int,
num_layers: int,
normalization: str = "batch",
feed_forward_hidden: int = 512,
force_flash_attn: bool = False,
init_embedding: nn.Module = None,
sdpa_fn=None,
):
super(GraphAttentionEncoder, self).__init__()

if isinstance(env_name, RL4COEnvBase):
env_name = env_name.name
self.env_name = env_name

self.init_embedding = (
@@ -50,7 +53,7 @@ def __init__(
num_layers,
normalization,
feed_forward_hidden,
force_flash_attn,
sdpa_fn=sdpa_fn,
)

def forward(
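The encoder's full constructor is visible in this hunk, so wiring in a custom SDPA callable is direct. A hedged sketch using PyTorch's built-in SDPA; the `"tsp"` environment name and the specific dimensions are assumed values, and the assumption again is that `sdpa_fn` follows the `scaled_dot_product_attention` call shape.

```python
import torch.nn.functional as F

from rl4co.models.zoo.common.autoregressive.encoder import GraphAttentionEncoder

# sdpa_fn defaults to None (the module's own attention path); passing a callable
# lets you pin a specific kernel, e.g. a FlashAttention-2 wrapper with the same
# (q, k, v, attn_mask) signature.
encoder = GraphAttentionEncoder(
    env_name="tsp",  # or an RL4COEnvBase instance, mirroring the decoder change
    num_heads=8,
    embedding_dim=128,
    num_layers=3,
    sdpa_fn=F.scaled_dot_product_attention,
)
```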
9 changes: 4 additions & 5 deletions rl4co/models/zoo/common/autoregressive/policy.py
@@ -1,4 +1,4 @@
from typing import Union
from typing import Callable, Optional, Union

import torch.nn as nn

@@ -39,7 +39,7 @@ class AutoregressivePolicy(nn.Module):
normalization: Normalization type in the attention layers
mask_inner: Whether to mask the inner diagonal in the attention layers
use_graph_context: Whether to use the initial graph context to modify the query
force_flash_attn: Whether to force the use of flash attention in the attention layers
sdpa_fn: Scaled dot product function to use for the attention
train_decode_type: Type of decoding during training
val_decode_type: Type of decoding during validation
test_decode_type: Type of decoding during testing
@@ -60,7 +60,7 @@ def __init__(
normalization: str = "batch",
mask_inner: bool = True,
use_graph_context: bool = True,
force_flash_attn: bool = False,
sdpa_fn: Optional[Callable] = None,
train_decode_type: str = "sampling",
val_decode_type: str = "greedy",
test_decode_type: str = "greedy",
@@ -83,8 +83,8 @@ def __init__(
embedding_dim=embedding_dim,
num_layers=num_encoder_layers,
normalization=normalization,
force_flash_attn=force_flash_attn,
init_embedding=init_embedding,
sdpa_fn=sdpa_fn,
)
else:
self.encoder = encoder
@@ -97,7 +97,6 @@ def __init__(
num_heads=num_heads,
use_graph_context=use_graph_context,
mask_inner=mask_inner,
force_flash_attn=force_flash_attn,
context_embedding=context_embedding,
dynamic_embedding=dynamic_embedding,
)
2 changes: 1 addition & 1 deletion rl4co/models/zoo/ham/encoder.py
@@ -38,7 +38,7 @@ def __init__(
env_name=None,
normalization="batch",
feed_forward_hidden=512,
force_flash_attn=False,
sdpa_fn=None,
):
super(GraphHeterogeneousAttentionEncoder, self).__init__()

