The simplest online-softmax notebook for explaining Flash Attention.
Blog link: 手撕Online-Softmax (Online-Softmax from scratch)
Run `online_softmax_torch.ipynb`. Below we show the block online-softmax result:
```python
import torch
import torch.nn.functional as F

X = torch.tensor([-0.3, 0.2, 0.5, 0.7, 0.1, 0.8])
X_softmax = F.softmax(X, dim=0)
print(X_softmax)

X_block = torch.split(X, split_size_or_sections=3, dim=0)
# we compute each block's max & sum in parallel
X_block_0_max = X_block[0].max()
X_block_0_sum = torch.exp(X_block[0] - X_block_0_max).sum()
X_block_1_max = X_block[1].max()
X_block_1_sum = torch.exp(X_block[1] - X_block_1_max).sum()

# online update of the global max & sum
X_block_1_max_update = torch.max(X_block_0_max, X_block_1_max)  # X_block[1] is the new block
X_block_1_sum_update = X_block_0_sum * torch.exp(X_block_0_max - X_block_1_max_update) \
    + torch.exp(X_block[1] - X_block_1_max_update).sum()  # rescale old sum, add new block sum
X_block_online_softmax = torch.exp(X - X_block_1_max_update) / X_block_1_sum_update
print(X_block_online_softmax)
```
The output is:
```
tensor([0.0827, 0.1364, 0.1841, 0.2249, 0.1234, 0.2485])
tensor([0.0827, 0.1364, 0.1841, 0.2249, 0.1234, 0.2485])
```
The block online softmax matches the standard softmax.
The numerically stable (safe) softmax over $x_{:N}$ is
$$ \tilde{x}_i=\frac{e^{x_i-\max(x_{:N})}}{\sum_j^N e^{x_j-\max(x_{:N})}} $$
Note
- We first compute the maximum $\max(x_{:N})$ and the softmax denominator $l_N$ over the first $N$ elements $x_{:N}$.
- When we add a new element $x_{N+1}$, we update $\max(x_{:N+1})$ and $l_{N+1}$ as follows. We cannot reuse $l_N$ directly because the maximum may have changed, so we rescale the old denominator before adding the new term (a minimal PyTorch sketch follows this note):
$$ l_{N+1}=l_N\,e^{\max(x_{:N})-\max(x_{:N+1})}+e^{x_{N+1}-\max(x_{:N+1})} $$
- We can then apply the softmax function using the adjusted numerator and denominator values:
$$ \tilde{x}_{i}=\frac{e^{x_i-\max(x_{:N+1})}}{l_{N+1}} $$
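A minimal sketch of this one-element-at-a-time update in PyTorch (the running variables `m` and `l` and the loop are my own illustration, not taken from the notebook); the result should match `F.softmax`:

```python
import torch
import torch.nn.functional as F

X = torch.tensor([-0.3, 0.2, 0.5, 0.7, 0.1, 0.8])

m = torch.tensor(float("-inf"))  # running max of x_{:N}
l = torch.tensor(0.0)            # running denominator l_N
for x_new in X:                  # add one new element x_{N+1} at a time
    m_new = torch.max(m, x_new)
    # rescale the old denominator to the new max, then add the new term
    l = l * torch.exp(m - m_new) + torch.exp(x_new - m_new)
    m = m_new

X_online_softmax = torch.exp(X - m) / l
print(torch.allclose(X_online_softmax, F.softmax(X, dim=0)))  # should print: True
```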
Online softmax computes the sum cumulatively, block by block:
- We compute each block's $l^{(t)}$ and $m^{(t)}$ separately.
- It is easy to update the global $m$ from the block maxima:
$$
\begin{align}
m=\max({x_{:2N}})&=\max(\max({x_{:N}}),\max(x_{N+1:2N}))\\
&=\max(m^{(1)},m^{(2)})
\end{align}
$$
but the global $l$ is NOT simply the sum of the block sums:
$$ l=l_{:2N}\neq l^{(1)}+l^{(2)} $$
- So we use the block sums $l^{(t)}$ and block maxima $m^{(t)}$ to update the global $l$ online.
- The block softmax is updated with a for-loop over blocks (see the sketch after this list). Denote the current block's max/sum as $m_\text{new},l_\text{new}$, the previous global max/sum as $m_\text{old},l_\text{old}$, and the updated global max as $m=\max(m_\text{old},m_\text{new})$; then the global sum is updated as
$$
l = l_\text{old}\, e^{m_\text{old}-m} + l_\text{new}\, e^{m_{\text{new}}-m}
$$
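A minimal sketch of this multi-block for-loop update (the block size of 2 is an arbitrary choice; the names `m_old`, `l_old`, `m_new`, `l_new` simply mirror the formula above):

```python
import torch
import torch.nn.functional as F

X = torch.tensor([-0.3, 0.2, 0.5, 0.7, 0.1, 0.8])
X_blocks = torch.split(X, split_size_or_sections=2, dim=0)

m_old = torch.tensor(float("-inf"))  # global running max
l_old = torch.tensor(0.0)            # global running denominator
for X_blk in X_blocks:
    m_new = X_blk.max()                      # current block max m^(t)
    l_new = torch.exp(X_blk - m_new).sum()   # current block sum l^(t)
    m = torch.max(m_old, m_new)              # updated global max
    l = l_old * torch.exp(m_old - m) + l_new * torch.exp(m_new - m)
    m_old, l_old = m, l

X_multi_block_softmax = torch.exp(X - m_old) / l_old
print(torch.allclose(X_multi_block_softmax, F.softmax(X, dim=0)))  # should print: True
```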
In the attention mechanism, we need a softmax over the attention score matrix
$$
S=QK^T,S\in\mathbb{R}^{N\times N}
$$
Each query corresponds to one row of $S$, and the softmax is applied row-wise to get the attention probabilities:
$$
P_{i,:}=\text{softmax}(S_{i,:})
$$
With online softmax, we can update each row's running max and sum block by block along the key dimension, and process all rows in parallel; this is the bookkeeping Flash Attention builds on.
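A sketch of this row-wise online update applied to $S=QK^T$ (the shapes `N`, `d` and the block size `BLOCK` are made up for illustration; only the softmax bookkeeping is shown, with no value accumulation):

```python
import torch
import torch.nn.functional as F

N, d, BLOCK = 8, 4, 2
Q, K = torch.randn(N, d), torch.randn(N, d)
S = Q @ K.T                                   # attention scores, shape (N, N)

m = torch.full((N, 1), float("-inf"))         # per-row running max
l = torch.zeros(N, 1)                         # per-row running denominator
for S_blk in torch.split(S, BLOCK, dim=1):    # iterate over key blocks
    m_blk = S_blk.max(dim=1, keepdim=True).values
    m_upd = torch.maximum(m, m_blk)
    # rescale the old denominator and add the current block's contribution
    l = l * torch.exp(m - m_upd) + torch.exp(S_blk - m_upd).sum(dim=1, keepdim=True)
    m = m_upd

P = torch.exp(S - m) / l                      # row-wise softmax of S
print(torch.allclose(P, F.softmax(S, dim=1)))  # should print: True
```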