From 9acbb3c57ed8c0c1eba05a19c93f542551fca3e9 Mon Sep 17 00:00:00 2001 From: YdrMaster Date: Wed, 24 Jan 2024 13:56:28 +0800 Subject: [PATCH] =?UTF-8?q?docs:=20=E5=88=9D=E6=AD=A5=E6=B7=BB=E5=8A=A0=20?= =?UTF-8?q?attention=20=E7=AE=97=E5=AD=90=E5=AE=9A=E4=B9=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: YdrMaster --- src/08-01llm/README.md | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/src/08-01llm/README.md b/src/08-01llm/README.md index e14641fa..969f05db 100644 --- a/src/08-01llm/README.md +++ b/src/08-01llm/README.md @@ -25,3 +25,39 @@ y = (x^2 + δ)^(-1/2) * w * x 1 Output: - **Y(heterogeneous) - T**: 输出张量。形状与 `X` 相同。 + +## Attention + +### Summary + +Multi-head Self Attention 的封装形式,用于 transformer 模型。 + +支持使用 kv cache,使用条件由输入和属性综合决定。有以下 种情况: + +| 序号 | 输入数量 | `max_seq_len` | 使用 kv cache | 输出数量 | cache s 维度 | 备注 +|:-:|:-:|:-----:|:-------:|:-:|:------------------------:|:- +| 0 | 3 | 0 | none | 1 | - | +| 1 | 3 | S > 0 | init | 3 | `S` | `assert(S >= seq_len)` +| 2 | 4 | 0 | inplace | 3 | `past_seq_len + seq_len` | `past_seq_len` 必须是常量 +| 3 | 4 | S > 0 | inplace | 3 | `S` | `assert(S >= past_seq_len + seq_len)` +| 4 | 6 | 0 | copy | 3 | `past_seq_len + seq_len` | `past_seq_len` 必须是常量 +| 5 | 6 | S > 0 | copy | 3 | `S` | `assert(S >= past_seq_len + seq_len)` + +### Attributes + +- **max_seq_len - INT** (default is `0`): 最大序列长度,用于初始化 kv cache。 + +### Inputs + +- **query(heterogeneous) - T**: 形状为 `N x n_head x seq_len x head_dim`。 +- **key(heterogeneous) - T**: 形状为 `N x n_kv_head x seq_len x head_dim`。 +- **value(heterogeneous) - T**: 形状为 `N x n_kv_head x seq_len x head_dim`。 +- **past_seq_len(optional) -int64**: 要连接的历史序列长度,必须为标量。不使用 kv cache 时留空。 +- **k_cache(optional, heterogeneous) -T**: k 缓存的初始值,形状为 `N x n_kv_head x s x head_dim`,`s` 为不小于 `past_seq_len` 的任意值。不使用或不重置 kv cache 时留空。 +- **v_cache(optional, heterogeneous) -T**: v 缓存的初始值,形状为 `N x n_kv_head x s x head_dim`,`s` 为不小于 `past_seq_len` 的任意值。不使用或不重置 kv cache 时留空。 + +### Outputs + +- **output(heterogeneous) - T**: 形状与 `query` 相同。 +- **k_cache(optional, heterogeneous) - T**: 形状为 `N x n_kv_head x s x head_dim`。`s` 的值根据 `Summary` 的描述计算。 +- **v_cache(optional, heterogeneous) - T**: 形状为 `N x n_kv_head x s x head_dim`。`s` 的值根据 `Summary` 的描述计算。