Commit

add batch and multihead support for non-causal flash attention, for @…
lucidrains committed Sep 21, 2022
1 parent f000c8a commit 874af93
Showing 5 changed files with 73 additions and 55 deletions.
9 changes: 5 additions & 4 deletions README.md
@@ -18,10 +18,10 @@ from flash_attention_jax import flash_attention

rng_key = random.PRNGKey(42)

q = random.normal(rng_key, (131072, 512))
k = random.normal(rng_key, (131072, 512))
v = random.normal(rng_key, (131072, 512))
mask = random.randint(rng_key, (131072,), 0, 2)
q = random.normal(rng_key, (1, 2, 131072, 512)) # (batch, heads, seq, dim)
k = random.normal(rng_key, (1, 2, 131072, 512))
v = random.normal(rng_key, (1, 2, 131072, 512))
mask = random.randint(rng_key, (1, 131072,), 0, 2) # (batch, seq)

out, _ = flash_attention(q, k, v, mask)
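
A quick way to sanity-check the new leading dimensions is to compare against the plain reference attention in `flash_attention_jax/attention.py`. The sketch below is illustrative, not from the repo: the import path is inferred from the file layout in this diff, the sizes are scaled down (sequence length kept a power of two so it divides the library's internal chunk sizes), and the tolerance is a guess.

```python
from jax import random, numpy as jnp
from flash_attention_jax import flash_attention
from flash_attention_jax.attention import attention  # reference implementation (path assumed)

rng_key = random.PRNGKey(42)

# the layout introduced by this commit: (batch, heads, seq, dim) and a (batch, seq) key mask
q = random.normal(rng_key, (1, 2, 4096, 64))
k = random.normal(rng_key, (1, 2, 4096, 64))
v = random.normal(rng_key, (1, 2, 4096, 64))
mask = random.randint(rng_key, (1, 4096), 0, 2)

flash_out, _ = flash_attention(q, k, v, mask)
ref_out = attention(q, k, v, mask)

print(flash_out.shape)                              # (1, 2, 4096, 64)
print(jnp.allclose(flash_out, ref_out, atol=1e-3))  # should print True, up to fp32 error
```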

@@ -98,6 +98,7 @@ out.shape # (131072, 512)
- [ ] figure out issue with jit and static argnums
- [ ] comment with references to paper algorithms and explanations
- [ ] make sure it can work one-headed key / values, as in PaLM
- [ ] leading dimensions for causal flash attention variant

## Citations

17 changes: 11 additions & 6 deletions flash_attention_jax/attention.py
@@ -1,6 +1,9 @@
import jax
from jax import nn
from jax import jit, numpy as jnp
from jax.numpy import einsum

from einops import rearrange

EPSILON = 1e-10
MASK_VALUE = -1e10
@@ -12,26 +15,27 @@ def attention(q, k, v, key_mask):
scale = 1 / jnp.sqrt(dim)

q = q * scale
sim = q @ k.transpose()
sim = einsum('... i d, ... j d -> ... i j', q, k)

key_mask = rearrange(key_mask, 'b j -> b 1 1 j')
sim = jnp.where(key_mask, sim, MASK_VALUE)

attn = nn.softmax(sim, axis = -1)
return attn @ v

@jit
def causal_attention(q, k, v):
q_len, dim, k_len = *q.shape, k.shape[-2]
q_len, dim, k_len = *q.shape[-2:], k.shape[-2]
scale = 1 / jnp.sqrt(dim)

q = q * scale
sim = q @ k.transpose()
sim = einsum('... i d, ... j d -> ... i j', q, k)

causal_mask = jnp.triu(jnp.ones((q_len, k_len)), k_len - q_len + 1)
sim = jnp.where(causal_mask, MASK_VALUE, sim)

attn = nn.softmax(sim, axis = -1)
return attn @ v
return einsum('... i j, ... j d -> ... i d', attn, v)
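
For readers puzzling over the `jnp.triu` offset in `causal_attention`, a toy illustration (sizes are mine, not from the repo) of the mask it builds when `q_len != k_len`:

```python
from jax import numpy as jnp

q_len, k_len = 3, 5
causal_mask = jnp.triu(jnp.ones((q_len, k_len)), k_len - q_len + 1)
print(causal_mask)
# [[0. 0. 0. 1. 1.]
#  [0. 0. 0. 0. 1.]
#  [0. 0. 0. 0. 0.]]
# 1s mark positions that receive MASK_VALUE, so query i attends to keys j <= i + (k_len - q_len)
```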

# cosine sim attention

@@ -44,9 +48,10 @@ def cosine_sim_attention(q, k, v, key_mask):
dim, k_len = q.shape[-1], k.shape[-2]
q, k = map(l2norm, (q, k))

sim = q @ k.transpose() * COSINE_SIM_SCALE
sim = einsum('... i d, ... j d -> ... i j', q, k) * COSINE_SIM_SCALE

key_mask = rearrange(key_mask, 'b j -> b 1 1 j')
sim = jnp.where(key_mask, sim, MASK_VALUE)

attn = nn.softmax(sim, axis = -1)
return attn @ v
return einsum('... i j, ... j d -> ... i d', attn, v)
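
The pattern behind the diff above: every `q @ k.transpose()` becomes an einsum over named trailing axes, so any leading `(batch, heads)` axes broadcast for free, and the key mask is expanded to `b 1 1 j` so it broadcasts over heads and query positions. A minimal stand-alone sketch (toy sizes, not from the repo):

```python
from jax import random, numpy as jnp
from jax.numpy import einsum
from einops import rearrange

rng = random.PRNGKey(0)
b, h, i, j, d = 2, 4, 16, 32, 8

q = random.normal(rng, (b, h, i, d))
k = random.normal(rng, (b, h, j, d))
key_mask = random.randint(rng, (b, j), 0, 2)

sim = einsum('... i d, ... j d -> ... i j', q, k)   # (b, h, i, j)
mask = rearrange(key_mask, 'b j -> b 1 1 j')        # broadcasts over heads and query positions
sim = jnp.where(mask, sim, -1e10)

# the old single-sequence call pattern still goes through the same einsum
sim_2d = einsum('... i d, ... j d -> ... i j', q[0, 0], k[0, 0])
print(sim.shape, sim_2d.shape)                      # (2, 4, 16, 32) (16, 32)
```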
89 changes: 49 additions & 40 deletions flash_attention_jax/flash_attention.py
@@ -4,6 +4,9 @@
from jax import nn
from jax import custom_vjp
from jax import numpy as jnp, lax, jit
from jax.numpy import einsum

from einops import rearrange

# constants

@@ -16,19 +19,21 @@
# flash attention

def _query_chunk_flash_attention(chunk_idx, q, k, v, key_mask):
q_len, k_len, dim, v_dim = q.shape[-2], *k.shape, v.shape[-1]
q_len, batch, heads, dim, k_len, v_dim = *q.shape, k.shape[0], v.shape[-1]
scale = 1 / jnp.sqrt(dim)
q_scaled = q * scale

def chunk_scanner(carries, _):
chunk_idx, out, row_sum, row_max = carries
k_chunk_sizes = min(K_CHUNK_SIZE, k_len)

k_chunk = lax.dynamic_slice(k, (chunk_idx, 0), slice_sizes=(k_chunk_sizes, dim))
v_chunk = lax.dynamic_slice(v, (chunk_idx, 0), slice_sizes=(k_chunk_sizes, v_dim))
key_mask_chunk = lax.dynamic_slice(key_mask, (chunk_idx,), slice_sizes=(k_chunk_sizes,))
k_chunk = lax.dynamic_slice(k, (chunk_idx, 0, 0, 0), slice_sizes=(k_chunk_sizes, batch, heads, dim))
v_chunk = lax.dynamic_slice(v, (chunk_idx, 0, 0, 0), slice_sizes=(k_chunk_sizes, batch, heads, v_dim))
key_mask_chunk = lax.dynamic_slice(key_mask, (chunk_idx, 0), slice_sizes=(k_chunk_sizes, batch))

attn_weights = einsum('i ... d, j ... d -> i ... j', q_scaled, k_chunk)

attn_weights = q_scaled @ k_chunk.transpose()
key_mask_chunk = rearrange(key_mask_chunk, 'j b -> 1 b 1 j')
attn_weights = jnp.where(key_mask_chunk, attn_weights, MASK_VALUE)

block_row_max = jnp.max(attn_weights, axis = -1, keepdims = True)
@@ -38,7 +43,7 @@ def chunk_scanner(carries, _):
exp_weights = jnp.where(key_mask_chunk, exp_weights, 0.)
block_row_sum = jnp.sum(exp_weights, axis = -1, keepdims = True) + EPSILON

exp_values = exp_weights @ v_chunk
exp_values = einsum('i ... j, j ... d -> i ... d', exp_weights, v_chunk)

new_row_max = jnp.maximum(block_row_max, row_max)

@@ -52,35 +57,37 @@ def chunk_scanner(carries, _):

return (chunk_idx + k_chunk_sizes, out, new_row_sum, new_row_max), None

out = jnp.zeros((q_len, dim))
row_sum = jnp.zeros((q_len, 1))
row_max = jnp.ones((q_len, 1)) * -1e6
out = jnp.zeros((q_len, batch, heads, dim))
row_sum = jnp.zeros((q_len, batch, heads, 1))
row_max = jnp.ones((q_len, batch, heads, 1)) * -1e6

(_, out, row_sum, row_max), _ = lax.scan(chunk_scanner, init = (0, out, row_sum, row_max), xs = None, length = math.ceil(k_len / K_CHUNK_SIZE))

out = out.reshape(q_len, v_dim)
row_sum = row_sum.reshape(q_len)
row_max = row_max.reshape(q_len)
row_sum = rearrange(row_sum, 'n ... 1 -> n ...')
row_max = rearrange(row_max, 'n ... 1 -> n ...')

return out, row_sum, row_max
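
Part of this hunk is collapsed above, so here is a stand-alone sketch of the streaming-softmax recurrence that the `(out, row_sum, row_max)` carries implement. The helper name and toy shapes are mine, not the repo's; the recurrence is the standard flash-attention update, shown checking against an ordinary softmax over both key chunks at once:

```python
from jax import random, nn, numpy as jnp

def online_softmax_step(carry, scores, values):
    out, row_sum, row_max = carry                      # running output, denominator, and max per query row
    block_max = jnp.max(scores, axis=-1, keepdims=True)
    new_max = jnp.maximum(block_max, row_max)

    exp_scores = jnp.exp(scores - new_max)
    block_sum = jnp.sum(exp_scores, axis=-1, keepdims=True)

    scale_old = jnp.exp(row_max - new_max)             # rescale what was accumulated so far
    new_sum = row_sum * scale_old + block_sum
    new_out = (out * row_sum * scale_old + exp_scores @ values) / new_sum
    return new_out, new_sum, new_max

# check against a plain softmax over both key chunks at once
k1, k2, k3, k4 = random.split(random.PRNGKey(0), 4)
s1, s2 = random.normal(k1, (4, 8)), random.normal(k2, (4, 8))
v1, v2 = random.normal(k3, (8, 3)), random.normal(k4, (8, 3))

carry = (jnp.zeros((4, 3)), jnp.zeros((4, 1)), jnp.full((4, 1), -1e6))
carry = online_softmax_step(carry, s1, v1)
out, _, _ = online_softmax_step(carry, s2, v2)

ref = nn.softmax(jnp.concatenate([s1, s2], axis=-1), axis=-1) @ jnp.concatenate([v1, v2], axis=0)
print(jnp.allclose(out, ref, atol=1e-5))   # expected True
```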

@custom_vjp
@jit
def flash_attention(q, k, v, key_mask):
q_len, dim, v_dim = *q.shape, v.shape[-1]
batch, heads, q_len, dim, v_dim = *q.shape, v.shape[-1]

def chunk_scanner(chunk_idx, _):
chunk_sizes = min(Q_CHUNK_SIZE, q_len)

q_chunk = lax.dynamic_slice(q, (chunk_idx, 0), slice_sizes = (chunk_sizes, dim))
q_chunk = lax.dynamic_slice(q, (chunk_idx, 0, 0, 0), slice_sizes = (chunk_sizes, batch, heads, dim))

return (chunk_idx + chunk_sizes, _query_chunk_flash_attention(chunk_idx, q_chunk, k, v, key_mask))

q, k, v = map(lambda t: rearrange(t, 'b h n d -> n b h d'), (q, k, v))
key_mask = rearrange(key_mask, 'b j -> j b')

_, (out, row_sum, row_max) = lax.scan(chunk_scanner, init = 0, xs = None, length = math.ceil(q_len / Q_CHUNK_SIZE))

out = out.reshape(q_len, v_dim)
row_sum = row_sum.reshape(q_len)
row_max = row_max.reshape(q_len)
out = rearrange(out, 'c n b h d -> b h (c n) d')
row_sum = rearrange(row_sum, 'c n b h -> b h (c n)')
row_max = rearrange(row_max, 'c n b h -> b h (c n)')

return out, (row_sum, row_max)
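
The forward pass now moves everything to a sequence-first `n b h d` layout so that `lax.dynamic_slice` and `lax.scan` can chunk along the leading axis, then folds the per-chunk outputs back with `rearrange('c n b h d -> b h (c n) d')`. A tiny round-trip sketch (illustrative sizes, identity work per chunk, not the repo's constants) of why that reassembly recovers the original ordering:

```python
from jax import random, numpy as jnp, lax
from einops import rearrange

b, h, n, d, chunk = 2, 4, 16, 8, 4
x = random.normal(random.PRNGKey(0), (b, h, n, d))

x_seq_first = rearrange(x, 'b h n d -> n b h d')

def scanner(idx, _):
    # identity "attention" per query chunk, just to show the slicing and restacking
    sliced = lax.dynamic_slice(x_seq_first, (idx, 0, 0, 0), (chunk, b, h, d))
    return idx + chunk, sliced

_, chunks = lax.scan(scanner, 0, None, length=n // chunk)   # (c, chunk, b, h, d)
out = rearrange(chunks, 'c n b h d -> b h (c n) d')

print(jnp.array_equal(out, x))   # True: the chunking and reassembly are a pure permutation
```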

@@ -90,7 +97,7 @@ def flash_attention_forward(q, k, v, key_mask):
return out, (q, k, v, key_mask, out, row_sum, row_max)

def _query_chunk_flash_attention_backward(q, k, v, key_mask,o, do, l, m):
q_len, dim, k_len, v_dim = *q.shape, *v.shape
q_len, batch, heads, dim, k_len, v_dim = *q.shape, v.shape[0], v.shape[-1]

scale = 1 / jnp.sqrt(dim)
q_scaled = q * scale
@@ -99,68 +106,70 @@ def chunk_scanner(carries, _):
chunk_idx, dq = carries
k_chunk_sizes = min(K_CHUNK_SIZE, k_len)

k_chunk = lax.dynamic_slice(k, (chunk_idx, 0), slice_sizes=(k_chunk_sizes, dim))
v_chunk = lax.dynamic_slice(v, (chunk_idx, 0), slice_sizes=(k_chunk_sizes, v_dim))
key_mask_chunk = lax.dynamic_slice(key_mask, (chunk_idx,), slice_sizes=(k_chunk_sizes,))
k_chunk = lax.dynamic_slice(k, (chunk_idx, batch, heads, 0), slice_sizes=(k_chunk_sizes, batch, heads, dim))
v_chunk = lax.dynamic_slice(v, (chunk_idx, batch, heads, 0), slice_sizes=(k_chunk_sizes, batch, heads, v_dim))
key_mask_chunk = lax.dynamic_slice(key_mask, (chunk_idx, batch), slice_sizes=(k_chunk_sizes, batch))

attn_weights = q_scaled @ k_chunk.transpose()
attn_weights = einsum('i ... d, j ... d -> i ... j', q_scaled, k_chunk)

exp_attn_weights = jnp.exp(attn_weights - m)

key_mask_chunk = rearrange(key_mask_chunk, 'j b -> 1 b 1 j')
exp_attn_weights = jnp.where(key_mask_chunk, exp_attn_weights, 0.)

p = exp_attn_weights / l

dv_chunk = p.transpose() @ do
dp = do @ v_chunk.transpose()
dv_chunk = einsum('i ... j, i ... d -> j ... d', p, do)
dp = einsum('i ... d, j ... d -> i ... j', do, v_chunk)

D = jnp.sum(do * o, axis = -1, keepdims = True)
ds = p * scale * (dp - D)

dq_chunk = ds @ k_chunk
dk_chunk = ds.transpose() @ q
dq_chunk = einsum('i ... j, j ... d -> i ... d', ds, k_chunk)
dk_chunk = einsum('i ... j, i ... d -> j ... d', ds, q)

return (chunk_idx + k_chunk_sizes, dq + dq_chunk), (dk_chunk, dv_chunk)

dq = jnp.zeros_like(q)

(_, dq), (dk, dv) = lax.scan(chunk_scanner, init = (0, dq), xs = None, length = math.ceil(k_len / K_CHUNK_SIZE))

dq = dq.reshape(q_len, dim)
dk = dk.reshape(k_len, v_dim)
dv = dv.reshape(k_len, v_dim)

dk = rearrange(dk, 'c n ... -> (c n) ...')
dv = rearrange(dv, 'c n ... -> (c n) ...')
return dq, dk, dv
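
The backward chunk above relies on the closed-form softmax-attention gradients: `dv = p.T @ do`, `dp = do @ v.T`, and `ds = p * scale * (dp - rowsum(do * o))`. A small dense sketch (assumed names, toy shapes) checking those identities against JAX autodiff:

```python
import jax
from jax import random, nn, numpy as jnp

def dense_attention(q, k, v, scale):
    p = nn.softmax((q * scale) @ k.T, axis=-1)
    return p @ v

kq, kk, kv, kdo = random.split(random.PRNGKey(0), 4)
i, j, d = 4, 6, 8
q, k, v = random.normal(kq, (i, d)), random.normal(kk, (j, d)), random.normal(kv, (j, d))
do = random.normal(kdo, (i, d))
scale = 1 / jnp.sqrt(d)

# autodiff reference
_, vjp = jax.vjp(lambda q, k, v: dense_attention(q, k, v, scale), q, k, v)
dq_ref, dk_ref, dv_ref = vjp(do)

# closed-form backward, mirroring the chunked code above
p = nn.softmax((q * scale) @ k.T, axis=-1)
o = p @ v
dv = p.T @ do
dp = do @ v.T
D = jnp.sum(do * o, axis=-1, keepdims=True)   # rowwise correction term
ds = p * scale * (dp - D)
dq = ds @ k
dk = ds.T @ q

print(jnp.allclose(dq, dq_ref, atol=1e-5),
      jnp.allclose(dk, dk_ref, atol=1e-5),
      jnp.allclose(dv, dv_ref, atol=1e-5))    # expected: True True True
```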

@jit
def flash_attention_backward(res, do):
q, k, v, key_mask, o, l, m = res

q_len, dim = q.shape
batch, heads, q_len, dim = q.shape

m, l = map(lambda t: rearrange(t, 'b h n -> n b h 1'), (m, l))

q, k, v, o, do = map(lambda t: rearrange(t, 'b h n d -> n b h d'), (q, k, v, o, do))
key_mask = rearrange(key_mask, 'b j -> j b')

dk = jnp.zeros_like(k)
dv = jnp.zeros_like(v)

m = m.reshape(q_len, 1)
l = l.reshape(q_len, 1)

def chunk_scanner(carries, _):
chunk_idx, dk, dv = carries

chunk_sizes = min(Q_CHUNK_SIZE, q_len)

q_chunk = lax.dynamic_slice(q, (chunk_idx, 0), slice_sizes = (chunk_sizes, q.shape[-1]))
m_chunk = lax.dynamic_slice(m, (chunk_idx, 0), slice_sizes = (chunk_sizes, 1))
l_chunk = lax.dynamic_slice(l, (chunk_idx, 0), slice_sizes = (chunk_sizes, 1))
o_chunk = lax.dynamic_slice(o, (chunk_idx, 0), slice_sizes = (chunk_sizes, o.shape[-1]))
do_chunk = lax.dynamic_slice(do, (chunk_idx, 0), slice_sizes = (chunk_sizes, do.shape[-1]))
q_chunk = lax.dynamic_slice(q, (chunk_idx, batch, heads, 0), slice_sizes = (chunk_sizes, batch, heads, q.shape[-1]))
m_chunk = lax.dynamic_slice(m, (chunk_idx, batch, heads, 0), slice_sizes = (chunk_sizes, batch, heads, 1))
l_chunk = lax.dynamic_slice(l, (chunk_idx, batch, heads, 0), slice_sizes = (chunk_sizes, batch, heads, 1))
o_chunk = lax.dynamic_slice(o, (chunk_idx, batch, heads, 0), slice_sizes = (chunk_sizes, batch, heads, o.shape[-1]))
do_chunk = lax.dynamic_slice(do, (chunk_idx, batch, heads, 0), slice_sizes = (chunk_sizes, batch, heads, do.shape[-1]))

dq_chunk, dk_chunk, dv_chunk = _query_chunk_flash_attention_backward(q_chunk, k, v, key_mask, o_chunk, do_chunk, l_chunk, m_chunk)
return (chunk_idx + chunk_sizes, dk + dk_chunk, dv + dv_chunk), dq_chunk

(_, dk, dv), dq = lax.scan(chunk_scanner, init = (0, dk, dv), xs = None, length = math.ceil(q_len / Q_CHUNK_SIZE))

dq = dq.reshape(q_len, dim)
dq = rearrange(dq, 'c n b h d -> b h (c n) d')
dk, dv = map(lambda t: rearrange(t, 'n b h d -> b h n d'), (dk, dv))

return dq, dk, dv, None

10 changes: 6 additions & 4 deletions flash_attention_jax/utils.py
@@ -23,18 +23,20 @@ def value_and_grad_difference(
fn1,
fn2,
seed = 42,
batch = 2,
heads = 4,
q_seq_len = 4096,
k_seq_len = 8192,
add_key_mask = True,
dim = 512
):
key_gen = PRNGKeyGenerator(seed)

q = random.normal(next(key_gen), (q_seq_len, dim))
k = random.normal(next(key_gen), (k_seq_len, dim))
v = random.normal(next(key_gen), (k_seq_len, dim))
q = random.normal(next(key_gen), (batch, heads, q_seq_len, dim))
k = random.normal(next(key_gen), (batch, heads, k_seq_len, dim))
v = random.normal(next(key_gen), (batch, heads, k_seq_len, dim))

key_mask = random.randint(next(key_gen), (k_seq_len,), 0, 2) == 1
key_mask = random.randint(next(key_gen), (batch, k_seq_len), 0, 2) == 1

fn1_value_and_grad, fn2_value_and_grad = map(partial(value_and_grad_wrapper, argnums = (0, 1, 2)), (fn1, fn2))

3 changes: 2 additions & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'flash-attention-jax',
packages = find_packages(exclude=[]),
version = '0.0.9',
version = '0.0.10',
license='MIT',
description = 'Flash Attention - in Jax',
author = 'Phil Wang',
@@ -18,6 +18,7 @@
'jax'
],
install_requires=[
'einops',
'jax>=0.2.20'
],
classifiers=[
