From 16b1cb0817532580adebbaa8af2006f653297fe4 Mon Sep 17 00:00:00 2001 From: Federico Berto Date: Tue, 19 Sep 2023 19:16:03 +0900 Subject: [PATCH] [Feat] FlashAttention2; remove unused kwarg --- rl4co/models/nn/graph/attnnet.py | 18 ++++------ rl4co/models/rl/common/critic.py | 8 ++--- .../zoo/common/autoregressive/decoder.py | 33 +++++++++++++------ .../zoo/common/autoregressive/encoder.py | 11 ++++--- .../zoo/common/autoregressive/policy.py | 9 +++-- rl4co/models/zoo/ham/encoder.py | 2 +- 6 files changed, 46 insertions(+), 35 deletions(-) diff --git a/rl4co/models/nn/graph/attnnet.py b/rl4co/models/nn/graph/attnnet.py index 0373e768..0371d883 100644 --- a/rl4co/models/nn/graph/attnnet.py +++ b/rl4co/models/nn/graph/attnnet.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Callable, Optional import torch.nn as nn @@ -19,7 +19,7 @@ class MultiHeadAttentionLayer(nn.Sequential): embed_dim: dimension of the embeddings feed_forward_hidden: dimension of the hidden layer in the feed-forward layer normalization: type of normalization to use (batch, layer, none) - force_flash_attn: whether to force FlashAttention (move to half precision) + sdpa_fn: scaled dot product attention function (SDPA) """ def __init__( @@ -28,14 +28,10 @@ def __init__( embed_dim: int, feed_forward_hidden: int = 512, normalization: Optional[str] = "batch", - force_flash_attn: bool = False, + sdpa_fn: Optional[Callable] = None, ): super(MultiHeadAttentionLayer, self).__init__( - SkipConnection( - MultiHeadAttention( - embed_dim, num_heads, force_flash_attn=force_flash_attn - ) - ), + SkipConnection(MultiHeadAttention(embed_dim, num_heads, sdpa_fn=sdpa_fn)), Normalization(embed_dim, normalization), SkipConnection( nn.Sequential( @@ -60,7 +56,7 @@ class GraphAttentionNetwork(nn.Module): num_layers: number of MHA layers normalization: type of normalization to use (batch, layer, none) feed_forward_hidden: dimension of the hidden layer in the feed-forward layer - force_flash_attn: whether to force FlashAttention (move to half precision) + sdpa_fn: scaled dot product attention function (SDPA) """ def __init__( @@ -70,7 +66,7 @@ def __init__( num_layers: int, normalization: str = "batch", feed_forward_hidden: int = 512, - force_flash_attn: bool = False, + sdpa_fn: Optional[Callable] = None, ): super(GraphAttentionNetwork, self).__init__() @@ -81,7 +77,7 @@ def __init__( embedding_dim, feed_forward_hidden=feed_forward_hidden, normalization=normalization, - force_flash_attn=force_flash_attn, + sdpa_fn=sdpa_fn, ) for _ in range(num_layers) ) diff --git a/rl4co/models/rl/common/critic.py b/rl4co/models/rl/common/critic.py index 105c229e..96045674 100644 --- a/rl4co/models/rl/common/critic.py +++ b/rl4co/models/rl/common/critic.py @@ -1,4 +1,4 @@ -from typing import Union +from typing import Callable, Optional, Union from tensordict import TensorDict from torch import Tensor, nn @@ -20,7 +20,7 @@ class CriticNetwork(nn.Module): num_layers: Number of layers for the encoder num_heads: Number of heads for the attention normalization: Normalization to use for the attention - force_flash_attn: Whether to force the use of flash attention. 
If True, cast to fp16 + sdpa_fn: Scaled dot product function to use for the attention """ def __init__( @@ -32,7 +32,7 @@ def __init__( num_layers: int = 3, num_heads: int = 8, normalization: str = "batch", - force_flash_attn: bool = False, + sdpa_fn: Optional[Callable] = None, **unused_kwargs, ): super(CriticNetwork, self).__init__() @@ -51,7 +51,7 @@ def __init__( num_layers=num_layers, normalization=normalization, feed_forward_hidden=hidden_dim, - force_flash_attn=force_flash_attn, + sdpa_fn=sdpa_fn, ) if encoder is None else encoder diff --git a/rl4co/models/zoo/common/autoregressive/decoder.py b/rl4co/models/zoo/common/autoregressive/decoder.py index e9450016..ba0d9efe 100644 --- a/rl4co/models/zoo/common/autoregressive/decoder.py +++ b/rl4co/models/zoo/common/autoregressive/decoder.py @@ -3,6 +3,7 @@ import torch import torch.nn as nn + from einops import rearrange from tensordict import TensorDict from torch import Tensor @@ -57,7 +58,7 @@ class AutoregressiveDecoder(nn.Module): def __init__( self, - env_name: str, + env_name: [str, RL4COEnvBase], embedding_dim: int, num_heads: int, use_graph_context: bool = True, @@ -69,6 +70,8 @@ def __init__( ): super().__init__() + if isinstance(env_name, RL4COEnvBase): + env_name = env_name.name self.env_name = env_name self.embedding_dim = embedding_dim self.num_heads = num_heads @@ -88,11 +91,17 @@ def __init__( self.use_graph_context = use_graph_context # For each node we compute (glimpse key, glimpse value, logit key) so 3 * embedding_dim - self.project_node_embeddings = nn.Linear(embedding_dim, 3 * embedding_dim, bias=linear_bias) - self.project_fixed_context = nn.Linear(embedding_dim, embedding_dim, bias=linear_bias) + self.project_node_embeddings = nn.Linear( + embedding_dim, 3 * embedding_dim, bias=linear_bias + ) + self.project_fixed_context = nn.Linear( + embedding_dim, embedding_dim, bias=linear_bias + ) # MHA - self.logit_attention = LogitAttention(embedding_dim, num_heads, **logit_attn_kwargs) + self.logit_attention = LogitAttention( + embedding_dim, num_heads, **logit_attn_kwargs + ) self.select_start_nodes_fn = select_start_nodes_fn @@ -136,7 +145,9 @@ def forward( else: if num_starts is not None: if num_starts > 1: - log.warn(f"num_starts={num_starts} is ignored for decode_type={decode_type}") + log.warn( + f"num_starts={num_starts} is ignored for decode_type={decode_type}" + ) num_starts = 0 # Compute keys, values for the glimpse and keys for the logits once as they can be reused in every step @@ -187,7 +198,9 @@ def forward( return outputs, actions, td - def _precompute_cache(self, embeddings: Tensor, num_starts: int = 0, td: TensorDict = None): + def _precompute_cache( + self, embeddings: Tensor, num_starts: int = 0, td: TensorDict = None + ): """Compute the cached embeddings for the attention Args: @@ -202,9 +215,7 @@ def _precompute_cache(self, embeddings: Tensor, num_starts: int = 0, td: TensorD glimpse_key_fixed, glimpse_val_fixed, logit_key_fixed, - ) = self.project_node_embeddings( - embeddings - ).chunk(3, dim=-1) + ) = self.project_node_embeddings(embeddings).chunk(3, dim=-1) # Optionally disable the graph context from the initial embedding as done in POMO if self.use_graph_context: @@ -262,7 +273,9 @@ def _get_log_p( mask = ~td_unbatch["action_mask"] # Compute logits - log_p = self.logit_attention(glimpse_q, glimpse_k, glimpse_v, logit_k, mask, softmax_temp) + log_p = self.logit_attention( + glimpse_q, glimpse_k, glimpse_v, logit_k, mask, softmax_temp + ) # Now we need to reshape the logits and log_p to 
[batch_size*num_starts, num_nodes] # Note that rearranging order is important here diff --git a/rl4co/models/zoo/common/autoregressive/encoder.py b/rl4co/models/zoo/common/autoregressive/encoder.py index 0e125aaf..6dce3913 100644 --- a/rl4co/models/zoo/common/autoregressive/encoder.py +++ b/rl4co/models/zoo/common/autoregressive/encoder.py @@ -5,6 +5,7 @@ from tensordict import TensorDict from torch import Tensor +from rl4co.envs import RL4COEnvBase from rl4co.models.nn.env_embeddings import env_init_embedding from rl4co.models.nn.graph.attnnet import GraphAttentionNetwork @@ -19,23 +20,25 @@ class GraphAttentionEncoder(nn.Module): num_layers: Number of layers for the encoder normalization: Normalization to use for the attention feed_forward_hidden: Hidden dimension for the feed-forward network - force_flash_attn: Whether to force the use of flash attention. If True, cast to fp16 init_embedding: Model to use for the initial embedding. If None, use the default embedding for the environment + sdpa_fn: Scaled dot product function to use for the attention """ def __init__( self, - env_name: str, + env_name: [str, RL4COEnvBase], num_heads: int, embedding_dim: int, num_layers: int, normalization: str = "batch", feed_forward_hidden: int = 512, - force_flash_attn: bool = False, init_embedding: nn.Module = None, + sdpa_fn=None, ): super(GraphAttentionEncoder, self).__init__() + if isinstance(env_name, RL4COEnvBase): + env_name = env_name.name self.env_name = env_name self.init_embedding = ( @@ -50,7 +53,7 @@ def __init__( num_layers, normalization, feed_forward_hidden, - force_flash_attn, + sdpa_fn=sdpa_fn, ) def forward( diff --git a/rl4co/models/zoo/common/autoregressive/policy.py b/rl4co/models/zoo/common/autoregressive/policy.py index 25a60a34..0a40201f 100644 --- a/rl4co/models/zoo/common/autoregressive/policy.py +++ b/rl4co/models/zoo/common/autoregressive/policy.py @@ -1,4 +1,4 @@ -from typing import Union +from typing import Callable, Optional, Union import torch.nn as nn @@ -39,7 +39,7 @@ class AutoregressivePolicy(nn.Module): normalization: Normalization type in the attention layers mask_inner: Whether to mask the inner diagonal in the attention layers use_graph_context: Whether to use the initial graph context to modify the query - force_flash_attn: Whether to force the use of flash attention in the attention layers + sdpa_fn: Scaled dot product function to use for the attention train_decode_type: Type of decoding during training val_decode_type: Type of decoding during validation test_decode_type: Type of decoding during testing @@ -60,7 +60,7 @@ def __init__( normalization: str = "batch", mask_inner: bool = True, use_graph_context: bool = True, - force_flash_attn: bool = False, + sdpa_fn: Optional[Callable] = None, train_decode_type: str = "sampling", val_decode_type: str = "greedy", test_decode_type: str = "greedy", @@ -83,8 +83,8 @@ def __init__( embedding_dim=embedding_dim, num_layers=num_encoder_layers, normalization=normalization, - force_flash_attn=force_flash_attn, init_embedding=init_embedding, + sdpa_fn=sdpa_fn, ) else: self.encoder = encoder @@ -97,7 +97,6 @@ def __init__( num_heads=num_heads, use_graph_context=use_graph_context, mask_inner=mask_inner, - force_flash_attn=force_flash_attn, context_embedding=context_embedding, dynamic_embedding=dynamic_embedding, ) diff --git a/rl4co/models/zoo/ham/encoder.py b/rl4co/models/zoo/ham/encoder.py index 736ed9a6..b7e756b4 100644 --- a/rl4co/models/zoo/ham/encoder.py +++ b/rl4co/models/zoo/ham/encoder.py @@ -38,7 +38,7 @@ def 
__init__( env_name=None, normalization="batch", feed_forward_hidden=512, - force_flash_attn=False, + sdpa_fn=None, ): super(GraphHeterogeneousAttentionEncoder, self).__init__()
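
Not part of the patch, but a minimal usage sketch of the new sdpa_fn hook introduced above, under two assumptions: (1) the attention layers invoke sdpa_fn with the same signature as torch.nn.functional.scaled_dot_product_attention, i.e. (q, k, v, attn_mask=..., dropout_p=...), and (2) AutoregressivePolicy forwards the keyword down to the encoder's MultiHeadAttention layers as the diff suggests. Verify both against the MultiHeadAttention call site before relying on this; the constructor arguments below are illustrative only.

import torch
import torch.nn.functional as F

from rl4co.models.zoo.common.autoregressive.policy import AutoregressivePolicy


def flash_sdpa(q, k, v, attn_mask=None, dropout_p=0.0):
    # Force PyTorch's FlashAttention backend for scaled dot product attention.
    # The flash kernel requires CUDA tensors in fp16/bf16 and, on older PyTorch
    # releases, no dense attention mask; with the other backends disabled,
    # PyTorch raises a RuntimeError if the inputs are not supported, so treat
    # this as an explicit opt-in rather than a drop-in default.
    with torch.backends.cuda.sdp_kernel(
        enable_flash=True, enable_math=False, enable_mem_efficient=False
    ):
        return F.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask, dropout_p=dropout_p
        )


# Illustrative constructor call (hypothetical environment name); leaving
# sdpa_fn=None is assumed to keep the library's default scaled dot product
# attention implementation.
policy = AutoregressivePolicy(env_name="tsp", sdpa_fn=flash_sdpa)

A wrapper around flash_attn.flash_attn_func (FlashAttention-2) could be passed through the same hook, but that kernel expects (batch, seq_len, num_heads, head_dim) half-precision inputs and does not take a dense attention mask, so the wrapper would need explicit rearranges and dtype casts before and after the call.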