offer cosine sim flash attention variant
lucidrains committed Jul 27, 2022
1 parent 9b8e9c4 commit e429d05
Showing 4 changed files with 186 additions and 2 deletions.
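For orientation, here is a minimal usage sketch of the two new entry points (the shapes, PRNG keys, and tolerance below are illustrative assumptions, not taken from the commit):

import jax
from jax import numpy as jnp
from flash_attention_jax import cosine_sim_attention, cosine_sim_flash_attention

q = jax.random.normal(jax.random.PRNGKey(0), (1024, 64))    # (seq_len, dim)
k = jax.random.normal(jax.random.PRNGKey(1), (1024, 64))
v = jax.random.normal(jax.random.PRNGKey(2), (1024, 64))
key_mask = jnp.ones((1024,), dtype = bool)                   # attend to every key position

out_reference = cosine_sim_attention(q, k, v, key_mask)      # plain O(n^2) cosine-sim attention
out_flash, _ = cosine_sim_flash_attention(q, k, v, key_mask) # chunked flash variant, which also returns its row sums

assert jnp.allclose(out_reference, out_flash, atol = 1e-5)
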
3 changes: 2 additions & 1 deletion flash_attention_jax/__init__.py
@@ -1,7 +1,8 @@
from flash_attention_jax.flash_attention import flash_attention
from flash_attention_jax.cosine_sim_flash_attention import cosine_sim_flash_attention
from flash_attention_jax.causal_flash_attention import causal_flash_attention
from flash_attention_jax.rabe_attention import rabe_attention
from flash_attention_jax.attention import attention, causal_attention
from flash_attention_jax.attention import attention, causal_attention, cosine_sim_attention

from flash_attention_jax.utils import value_and_grad_difference, PRNGKeyGenerator

20 changes: 20 additions & 0 deletions flash_attention_jax/attention.py
@@ -2,7 +2,9 @@
from jax import nn
from jax import jit, numpy as jnp

EPSILON = 1e-10
MASK_VALUE = -1e10
COSINE_SIM_SCALE = 16

@jit
def attention(q, k, v, key_mask):
@@ -30,3 +32,21 @@ def causal_attention(q, k, v):

    attn = nn.softmax(sim, axis = -1)
    return attn @ v

# cosine sim attention

@jit
def l2norm(t):
    # normalize along the last dimension to unit length, with a small epsilon for numerical stability
    return t / (jnp.linalg.norm(t, axis = -1, keepdims = True) + EPSILON)

@jit
def cosine_sim_attention(q, k, v, key_mask):
    q, k = map(l2norm, (q, k))

    # with unit-norm queries and keys the logits are bounded, so a fixed scale replaces 1 / sqrt(dim)
    sim = q @ k.transpose() * COSINE_SIM_SCALE

    sim = jnp.where(key_mask, sim, MASK_VALUE)

    attn = nn.softmax(sim, axis = -1)
    return attn @ v
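
A hedged check sketch (not part of the commit): the dense cosine_sim_attention above is the reference that the chunked flash variant in the next file should match, in values and in gradients.

import jax
from jax import numpy as jnp
from flash_attention_jax import cosine_sim_attention, cosine_sim_flash_attention

q, k, v = (jax.random.normal(jax.random.PRNGKey(i), (512, 64)) for i in range(3))
key_mask = jnp.ones((512,), dtype = bool)

ref_loss   = lambda q: cosine_sim_attention(q, k, v, key_mask).sum()
flash_loss = lambda q: cosine_sim_flash_attention(q, k, v, key_mask)[0].sum()  # index 0 is the attention output

dq_reference = jax.grad(ref_loss)(q)
dq_flash = jax.grad(flash_loss)(q)
print(jnp.abs(dq_reference - dq_flash).max())  # expected to be tiny if the custom vjp matches the reference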
163 changes: 163 additions & 0 deletions flash_attention_jax/cosine_sim_flash_attention.py
@@ -0,0 +1,163 @@
import math
import jax
from functools import partial
from jax import nn
from jax import custom_vjp
from jax import numpy as jnp, lax, jit

# constants

EPSILON = 1e-10
MASK_VALUE = -1e10

Q_CHUNK_SIZE = 1024
K_CHUNK_SIZE = 1024
COSINE_SIM_SCALE = 16

# flash attention

def _query_chunk_flash_attention(chunk_idx, q, k, v, key_mask):
    q_len, k_len, dim, v_dim = q.shape[-2], *k.shape, v.shape[-1]

    def chunk_scanner(carries, _):
        chunk_idx, out, row_sum = carries
        k_chunk_sizes = min(K_CHUNK_SIZE, k_len)

        k_chunk = lax.dynamic_slice(k, (chunk_idx, 0), slice_sizes=(k_chunk_sizes, dim))
        v_chunk = lax.dynamic_slice(v, (chunk_idx, 0), slice_sizes=(k_chunk_sizes, v_dim))
        key_mask_chunk = lax.dynamic_slice(key_mask, (chunk_idx,), slice_sizes=(k_chunk_sizes,))

        attn_weights = (q @ k_chunk.transpose() * COSINE_SIM_SCALE) - COSINE_SIM_SCALE # the output of this will range from [-2 * scale, 0], and the row sums are now bounded by the key/value sequence length - you can also shift this more if you wish to tailor the normalization constant (in the case of extreme sequence lengths)

        attn_weights = jnp.where(key_mask_chunk, attn_weights, MASK_VALUE)

        exp_weights = jnp.exp(attn_weights)
        exp_weights = jnp.where(key_mask_chunk, exp_weights, 0.)

        block_row_sum = jnp.sum(exp_weights, axis = -1, keepdims = True)

        exp_values = exp_weights @ v_chunk

        chunk_out = exp_values / k_len

        return (chunk_idx + k_chunk_sizes, out + chunk_out, row_sum + block_row_sum), None

    out = jnp.zeros((q_len, v_dim)) # the accumulator must match the value head dimension
    row_sum = jnp.zeros((q_len, 1))

    (_, out, row_sum), _ = lax.scan(chunk_scanner, init = (0, out, row_sum), xs = None, length = math.ceil(k_len / K_CHUNK_SIZE))

    out = out * (k_len / (row_sum + EPSILON)) # renormalize after acquiring all the correct row sums

    out = out.reshape(q_len, v_dim)
    row_sum = row_sum.reshape(q_len)

    return out, row_sum

@jit
def l2norm(t):
    # same per-row l2 normalization as in attention.py
    return t / (jnp.linalg.norm(t, axis = -1, keepdims = True) + EPSILON)

@jit
def cosine_sim_flash_attention(q, k, v, key_mask):
    q, k = map(l2norm, (q, k))
    return cosine_sim_flash_attention_after_l2norm(q, k, v, key_mask)

@custom_vjp
def cosine_sim_flash_attention_after_l2norm(q, k, v, key_mask):
    q_len, dim, v_dim = *q.shape, v.shape[-1]

    def chunk_scanner(chunk_idx, _):
        chunk_sizes = min(Q_CHUNK_SIZE, q_len)

        q_chunk = lax.dynamic_slice(q, (chunk_idx, 0), slice_sizes = (chunk_sizes, dim))

        return (chunk_idx + chunk_sizes, _query_chunk_flash_attention(chunk_idx, q_chunk, k, v, key_mask))

    _, (out, row_sum) = lax.scan(chunk_scanner, init = 0, xs = None, length = math.ceil(q_len / Q_CHUNK_SIZE))

    out = out.reshape(q_len, v_dim)
    row_sum = row_sum.reshape(q_len)

    return out, (row_sum,)

@jit
def flash_attention_forward(q, k, v, key_mask):
    out, (row_sum,) = cosine_sim_flash_attention_after_l2norm(q, k, v, key_mask)
    # the first element mirrors the primal output structure (out, (row_sum,)); the rest are residuals for the backward pass
    return (out, (row_sum,)), (q, k, v, key_mask, out, row_sum)

def _query_chunk_flash_attention_backward(q, k, v, key_mask, o, do, l):
    q_len, dim, k_len, v_dim = *q.shape, *v.shape

    def chunk_scanner(carries, _):
        chunk_idx, dq = carries
        k_chunk_sizes = min(K_CHUNK_SIZE, k_len)

        k_chunk = lax.dynamic_slice(k, (chunk_idx, 0), slice_sizes=(k_chunk_sizes, dim))
        v_chunk = lax.dynamic_slice(v, (chunk_idx, 0), slice_sizes=(k_chunk_sizes, v_dim))
        key_mask_chunk = lax.dynamic_slice(key_mask, (chunk_idx,), slice_sizes=(k_chunk_sizes,))

        attn_weights = q @ k_chunk.transpose() * COSINE_SIM_SCALE - COSINE_SIM_SCALE

        exp_attn_weights = jnp.exp(attn_weights)

        exp_attn_weights = jnp.where(key_mask_chunk, exp_attn_weights, 0.)

        p = exp_attn_weights / l

        dv_chunk = p.transpose() @ do
        dp = do @ v_chunk.transpose()

        D = jnp.sum(do * o, axis = -1, keepdims = True)
        ds = p * COSINE_SIM_SCALE * (dp - D) # rescale by the same fixed scale used in the forward pass, not 1 / sqrt(dim)

        dq_chunk = ds @ k_chunk
        dk_chunk = ds.transpose() @ q

        return (chunk_idx + k_chunk_sizes, dq + dq_chunk), (dk_chunk, dv_chunk)

    dq = jnp.zeros_like(q)

    (_, dq), (dk, dv) = lax.scan(chunk_scanner, init = (0, dq), xs = None, length = math.ceil(k_len / K_CHUNK_SIZE))

    dq = dq.reshape(q_len, dim)
    dk = dk.reshape(k_len, dim)
    dv = dv.reshape(k_len, v_dim)

    return dq, dk, dv

@jit
def flash_attention_backward(res, do):
    q, k, v, key_mask, o, l = res # only the row sum is saved; this variant tracks no running row max

    do, _ = do # the incoming cotangent mirrors the (out, (row_sum,)) output; the row-sum cotangent is not needed

    q_len, dim = q.shape

    dk = jnp.zeros_like(k)
    dv = jnp.zeros_like(v)

    l = l.reshape(q_len, 1)

    def chunk_scanner(carries, _):
        chunk_idx, dk, dv = carries

        chunk_sizes = min(Q_CHUNK_SIZE, q_len)

        q_chunk = lax.dynamic_slice(q, (chunk_idx, 0), slice_sizes = (chunk_sizes, q.shape[-1]))
        l_chunk = lax.dynamic_slice(l, (chunk_idx, 0), slice_sizes = (chunk_sizes, 1))
        o_chunk = lax.dynamic_slice(o, (chunk_idx, 0), slice_sizes = (chunk_sizes, o.shape[-1]))
        do_chunk = lax.dynamic_slice(do, (chunk_idx, 0), slice_sizes = (chunk_sizes, do.shape[-1]))

        dq_chunk, dk_chunk, dv_chunk = _query_chunk_flash_attention_backward(q_chunk, k, v, key_mask, o_chunk, do_chunk, l_chunk)
        return (chunk_idx + chunk_sizes, dk + dk_chunk, dv + dv_chunk), dq_chunk

    (_, dk, dv), dq = lax.scan(chunk_scanner, init = (0, dk, dv), xs = None, length = math.ceil(q_len / Q_CHUNK_SIZE))

    dq = dq.reshape(q_len, dim)

    return dq, dk, dv, None

cosine_sim_flash_attention_after_l2norm.defvjp(flash_attention_forward, flash_attention_backward)
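
The inline comment in _query_chunk_flash_attention is the key observation behind this variant: with unit-norm queries and keys the shifted logits lie in [-2 * COSINE_SIM_SCALE, 0], so every exponentiated weight is at most 1 and every row sum is at most the key sequence length, which is why no running row maximum has to be tracked. A small numeric sketch of that bound (not part of the commit; shapes and keys are arbitrary):

import jax
from jax import numpy as jnp

scale = 16  # COSINE_SIM_SCALE
l2norm = lambda t: t / (jnp.linalg.norm(t, axis = -1, keepdims = True) + 1e-10)

q = l2norm(jax.random.normal(jax.random.PRNGKey(0), (8, 64)))
k = l2norm(jax.random.normal(jax.random.PRNGKey(1), (4096, 64)))

logits = q @ k.transpose() * scale - scale    # every entry falls in [-2 * scale, 0]
print(bool(logits.min() >= -2 * scale), bool(logits.max() <= 0))

row_sums = jnp.exp(logits).sum(axis = -1)     # each exp weight is <= 1, so row sums are <= k_len
print(bool((row_sums <= k.shape[0]).all()))
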
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'flash-attention-jax',
packages = find_packages(exclude=[]),
version = '0.0.6',
version = '0.0.7',
license='MIT',
description = 'Flash Attention - in Jax',
author = 'Phil Wang',
