Commit

Refine the algorithm
IvanUkhov committed Jan 16, 2024
1 parent f411e4e commit a73fff4
Showing 1 changed file with 25 additions and 10 deletions.
35 changes: 25 additions & 10 deletions _drafts/2024-02-01-relative-positional-embedding.md
@@ -141,21 +141,36 @@ More generally, the algorithm can be summarized as follows:

$$
S = \text{transpose}\left(
- \text{stack-diagonals}\left(
- QE, \, 0, \, n_{t_3} - 1
+ \text{diagonal}\left(
+ QE, \, \text{lower}=0, \, \text{upper}=n_{t_3} - 1
\right)
\right)
$$

- where $$\text{stack-diagonals}$$ is a function taking a tensor and stacking its
- diagonals specified by a range with two offsets relative to the main diagonal
- from bottom up, and $$\text{transpose}$$ is a function taking a tensor and
- permuting its last two dimensions.
+ where $$\text{diagonal}$$ is a function taking a tensor and stacking its
+ diagonals, specified by a range with two offsets relative to the main
+ diagonal, from bottom up, and $$\text{transpose}$$ is a function taking a tensor
+ and transposing it. Both functions operate on the last two dimensions of the
+ given tensor. The resulting matrix can then be plugged into Equation (1) to
+ complete the calculation.

- The matrix can then be plugged into Equation (1) to complete the calculation. In
- case the queries are shorter than the keys and values, which is what happens in
- the prediction mode, $$S$$ will have the right number of rows but the last
- columns will be superfluous and hence have to be discarded.
+ In case the keys and values are shorter than the maximum allowed relative
+ position, that is, $$n_{t_1} < n_{t_3}$$, $$S$$ should be truncated to shape
+ $$n_s \times n_h \times n_{t_2} \times n_{t_1}$$:
+
+ $$
+ S = \text{truncate}\left(
+ \text{transpose}\left(
+ \text{diagonal}\left(
+ QE, \, \text{lower}=0, \, \text{upper}=n_{t_3} - 1
+ \right)
+ \right),
+ \text{keep} = n_{t_1}
+ \right)
+ $$
+
+ where $$\text{truncate}$$ is a function taking a tensor and keeping only the
+ specified number of first elements in the last dimension, discarding the rest.
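
The revised form of the algorithm can be sketched in NumPy as follows. This is only an illustration, not the author's implementation. It assumes that $$QE$$ has shape $$n_s \times n_h \times n_{t_2} \times n_{t_3}$$ with $$n_{t_2} \leq n_{t_3}$$, and that diagonals shorter than the longest one are left-aligned and padded with zeros; the draft shown here specifies neither. The function names simply mirror the operators in the formula.

```python
import numpy as np


def diagonal(x, lower, upper):
    """Stack the diagonals of the last two dimensions, lowest offset first."""
    rows, columns = x.shape[-2:]
    length = min(rows, columns)  # Length of the longest possible diagonal
    count = upper - lower + 1
    result = np.zeros(x.shape[:-2] + (count, length), dtype=x.dtype)
    for index, offset in enumerate(range(lower, upper + 1)):
        values = np.diagonal(x, offset=offset, axis1=-2, axis2=-1)
        # Assumption: shorter diagonals are left-aligned and zero-padded.
        result[..., index, : values.shape[-1]] = values
    return result


def transpose(x):
    """Swap the last two dimensions."""
    return np.swapaxes(x, -1, -2)


def truncate(x, keep):
    """Keep only the first `keep` elements of the last dimension."""
    return x[..., :keep]


def relative_scores(qe, n_t1):
    """Compute S from QE, truncating when n_t1 < n_t3."""
    n_t3 = qe.shape[-1]
    s = transpose(diagonal(qe, lower=0, upper=n_t3 - 1))
    return truncate(s, keep=n_t1) if n_t1 < n_t3 else s


# Toy input with n_s = n_h = 1, n_t2 = 2, and n_t3 = 3.
qe = np.arange(6).reshape(1, 1, 2, 3)
s = relative_scores(qe, n_t1=2)
print(s.shape)  # (1, 1, 2, 2)
```

Under these assumptions, element $$(i, j)$$ of the untruncated result equals element $$(i, i + j)$$ of $$QE$$ wherever that index exists, and zero elsewhere.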

# References
