Commit

removed unnecessary state_len flag and placed shape checking in if-clause
Artur-Galstyan committed Dec 2, 2023
1 parent d4d3b60 commit f31eae5
Showing 1 changed file with 13 additions and 15 deletions.
28 changes: 13 additions & 15 deletions equinox/nn/_attention.py
@@ -132,7 +132,6 @@ class MultiheadAttention(Module):
     use_key_bias: bool = field(static=True)
     use_value_bias: bool = field(static=True)
     use_output_bias: bool = field(static=True)
-    state_length: Optional[int] = field(static=True)
 
     def __init__(
         self,
@@ -169,8 +168,6 @@ def __init__(
         - `use_key_bias`: Whether to use a bias term in the key projections.
         - `use_value_bias`: Whether to use a bias term in the value projections.
         - `use_output_bias`: Whether to use a bias term in the output projection.
-        - `state_length`: Used when RoPE embeddings should be applied. This is the size
-            of the key and value buffers that are updated each time the module is called
         - `dropout_p`: Dropout probability on attention weights.
         - `inference`: Whether to actually apply dropout at all. If `True` then dropout
             is not applied. If `False` then dropout is applied. This may be toggled
@@ -291,23 +288,24 @@ def __call__(
         key_heads = self._project(self.key_proj, key_)
         value_heads = self._project(self.value_proj, value)
 
-        q_shape, k_shape, v_shape = (
-            query_heads.shape,
-            key_heads.shape,
-            value_heads.shape,
-        )
-
         if process_heads is not None:
+            q_shape, k_shape, v_shape = (
+                query_heads.shape,
+                key_heads.shape,
+                value_heads.shape,
+            )
             query_heads, key_heads, value_heads = process_heads(
                 query_heads, key_heads, value_heads
             )
 
-        if (
-            query_heads.shape != q_shape
-            or key_heads.shape != k_shape
-            or value_heads.shape != v_shape
-        ):
-            raise ValueError("process_heads must not change the shape of the heads.")
+            if (
+                query_heads.shape != q_shape
+                or key_heads.shape != k_shape
+                or value_heads.shape != v_shape
+            ):
+                raise ValueError(
+                    "process_heads must not change the shape of the heads."
+                )
 
         attn_fn = partial(
             dot_product_attention, dropout=self.dropout, inference=inference
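With `state_length` removed, positional transforms such as RoPE are expected to go through the `process_heads` hook, which (per the check above) must leave the head shapes unchanged. Below is a minimal usage sketch, not part of this commit: it assumes a recent Equinox with `eqx.nn.RotaryPositionalEmbedding`, and the sizes (`num_heads`, `size`, `seq_len`) are purely illustrative.

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr

num_heads, size, seq_len = 4, 64, 16
mha = eqx.nn.MultiheadAttention(num_heads=num_heads, query_size=size, key=jr.PRNGKey(0))
# RoPE acts on one head at a time, shape (seq_len, head_size); here head_size = size // num_heads.
rope = eqx.nn.RotaryPositionalEmbedding(embedding_size=size // num_heads)

def process_heads(query_heads, key_heads, value_heads):
    # Heads arrive as (seq_len, num_heads, head_size), so vmap RoPE over the heads axis.
    # The shapes must be left unchanged, or __call__ raises the ValueError shown in the diff.
    query_heads = jax.vmap(rope, in_axes=1, out_axes=1)(query_heads)
    key_heads = jax.vmap(rope, in_axes=1, out_axes=1)(key_heads)
    return query_heads, key_heads, value_heads

x = jnp.zeros((seq_len, size))
out = mha(x, x, x, process_heads=process_heads)  # output shape (seq_len, size)

Since the shape check now lives inside the `if process_heads is not None:` branch, calls that do not pass `process_heads` skip the comparison entirely.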
