diff --git a/equinox/nn/_attention.py b/equinox/nn/_attention.py
index b39e612f..448405e7 100644
--- a/equinox/nn/_attention.py
+++ b/equinox/nn/_attention.py
@@ -132,7 +132,6 @@ class MultiheadAttention(Module):
     use_key_bias: bool = field(static=True)
     use_value_bias: bool = field(static=True)
     use_output_bias: bool = field(static=True)
-    state_length: Optional[int] = field(static=True)
 
     def __init__(
         self,
@@ -169,8 +168,6 @@ def __init__(
         - `use_key_bias`: Whether to use a bias term in the key projections.
         - `use_value_bias`: Whether to use a bias term in the value projections.
         - `use_output_bias`: Whether to use a bias term in the output projection.
-        - `state_length`: Used when RoPE embeddings should be applied. This is the size
-          of the key and value buffers that are updated each time the module is called
         - `dropout_p`: Dropout probability on attention weights.
        - `inference`: Whether to actually apply dropout at all. If `True` then dropout
           is not applied. If `False` then dropout is applied. This may be toggled
@@ -291,23 +288,24 @@ def __call__(
         key_heads = self._project(self.key_proj, key_)
         value_heads = self._project(self.value_proj, value)
 
-        q_shape, k_shape, v_shape = (
-            query_heads.shape,
-            key_heads.shape,
-            value_heads.shape,
-        )
-
         if process_heads is not None:
+            q_shape, k_shape, v_shape = (
+                query_heads.shape,
+                key_heads.shape,
+                value_heads.shape,
+            )
             query_heads, key_heads, value_heads = process_heads(
                 query_heads, key_heads, value_heads
             )
 
-        if (
-            query_heads.shape != q_shape
-            or key_heads.shape != k_shape
-            or value_heads.shape != v_shape
-        ):
-            raise ValueError("process_heads must not change the shape of the heads.")
+            if (
+                query_heads.shape != q_shape
+                or key_heads.shape != k_shape
+                or value_heads.shape != v_shape
+            ):
+                raise ValueError(
+                    "process_heads must not change the shape of the heads."
+                )
 
         attn_fn = partial(
             dot_product_attention, dropout=self.dropout, inference=inference
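
Note (not part of the diff): with `state_length` removed, RoPE is applied through the existing `process_heads` hook rather than a stateful buffer. A minimal sketch of that usage, assuming equinox's public `MultiheadAttention` and `RotaryPositionalEmbedding` APIs and head arrays of shape `(seq_len, num_heads, head_size)`; the constants and variable names below are illustrative only:

import equinox as eqx
import jax
import jax.random as jr

num_heads, query_size, seq_len = 4, 64, 128
head_size = query_size // num_heads  # default qk_size

k1, k2 = jr.split(jr.PRNGKey(0))
mha = eqx.nn.MultiheadAttention(num_heads=num_heads, query_size=query_size, key=k1)
rope = eqx.nn.RotaryPositionalEmbedding(embedding_size=head_size)


def process_heads(query_heads, key_heads, value_heads):
    # Rotate queries and keys per head; the returned shapes must be unchanged,
    # otherwise the check in this diff raises a ValueError.
    query_heads = jax.vmap(rope, in_axes=1, out_axes=1)(query_heads)
    key_heads = jax.vmap(rope, in_axes=1, out_axes=1)(key_heads)
    return query_heads, key_heads, value_heads


x = jr.normal(k2, (seq_len, query_size))
out = mha(x, x, x, process_heads=process_heads)  # shape (seq_len, query_size)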