From c5b90454004db8bf5b7eb3afe7c11be2d230d8b7 Mon Sep 17 00:00:00 2001
From: Federico Berto
Date: Mon, 28 Oct 2024 17:37:54 +0900
Subject: [PATCH] [BugFix,Temp] quick fix for edge case of #228

---
 rl4co/models/nn/attention.py | 22 ++++++++++++++++++++--
 1 file changed, 20 insertions(+), 2 deletions(-)

diff --git a/rl4co/models/nn/attention.py b/rl4co/models/nn/attention.py
index 0dfa5973..088c7c75 100644
--- a/rl4co/models/nn/attention.py
+++ b/rl4co/models/nn/attention.py
@@ -249,7 +249,7 @@ def __init__(
         mask_inner: bool = True,
         out_bias: bool = False,
         check_nan: bool = True,
-        sdpa_fn: Optional[Callable] = None,
+        sdpa_fn: Optional[Union[Callable, str]] = "default",
         **kwargs,
     ):
         super(PointerAttention, self).__init__()
@@ -258,9 +258,27 @@ def __init__(
 
         # Projection - query, key, value already include projections
         self.project_out = nn.Linear(embed_dim, embed_dim, bias=out_bias)
-        self.sdpa_fn = sdpa_fn if sdpa_fn is not None else scaled_dot_product_attention
         self.check_nan = check_nan
 
+        # Defaults for sdpa_fn implementation
+        # see https://github.com/ai4co/rl4co/issues/228
+        if isinstance(sdpa_fn, str):
+            if sdpa_fn == "default":
+                sdpa_fn = scaled_dot_product_attention
+            elif sdpa_fn == "simple":
+                sdpa_fn = scaled_dot_product_attention_simple
+            else:
+                raise ValueError(
+                    f"Unknown sdpa_fn: {sdpa_fn}. Available options: ['default', 'simple']"
+                )
+        else:
+            if sdpa_fn is None:
+                sdpa_fn = scaled_dot_product_attention
+                log.info(
+                    "Using default scaled_dot_product_attention for PointerAttention"
+                )
+        self.sdpa_fn = sdpa_fn
+
     def forward(self, query, key, value, logit_key, attn_mask=None):
         """Compute attention logits given query, key, value, logit key and attention mask.
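
Usage note (illustrative, not part of the patch): with this change, `sdpa_fn` accepts either
a callable or a string selector ("default" or "simple"), while `None` still falls back to the
default implementation with an informational log message. A minimal sketch of how the option
might be used, assuming the `PointerAttention` import path from the diff above; the
`embed_dim`/`num_heads` constructor arguments are not shown in this hunk, so treat their
names and values here as assumptions:

    import torch.nn.functional as F
    from rl4co.models.nn.attention import PointerAttention

    # String selectors introduced by this patch:
    #   "default" -> scaled_dot_product_attention (previous behavior)
    #   "simple"  -> scaled_dot_product_attention_simple, the fallback referenced
    #                in https://github.com/ai4co/rl4co/issues/228
    attn_default = PointerAttention(embed_dim=128, num_heads=8, sdpa_fn="default")
    attn_simple = PointerAttention(embed_dim=128, num_heads=8, sdpa_fn="simple")

    # Any other string raises a ValueError listing the available options.
    # A custom callable is still accepted and used as-is, e.g. PyTorch's fused kernel:
    attn_custom = PointerAttention(
        embed_dim=128, num_heads=8, sdpa_fn=F.scaled_dot_product_attention
    )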