diff --git a/README.md b/README.md index 814a9b2..449ee93 100644 --- a/README.md +++ b/README.md @@ -60,6 +60,8 @@ ### LLM 推理技术 🤖 ++ [Flash Attention v1 - 原理篇](./docs/17_flash_attn/01_flash_attn_v1_part1.md) ++ [Flash Attention v1 - 实现篇](./docs/17_flash_attn/02_flash_attn_v1_part2.md) + [连续批处理](./docs/13_continuous_batch/README.md) + [Page Attention - 原理篇](./docs/14_page_attention/README.md) + [Page Attention - 源码解析](./docs/15_vllm_page_attention/README.md) diff --git a/docs/17_flash_attn/01_flash_attn_v1_part1.md b/docs/17_flash_attn/01_flash_attn_v1_part1.md new file mode 100644 index 0000000..62940ef --- /dev/null +++ b/docs/17_flash_attn/01_flash_attn_v1_part1.md @@ -0,0 +1,167 @@ +# Flash Attention v1 理论篇 + +Transformer 模型的时候处理速度较慢,且会占用大量的显存。自注意力的时间和内存复杂度是与序列长度的平方成正比。在 GPU 上,计算速度已超过内存速度,并且 +Transformers 中的**大多数操作都受到内存访问的瓶颈**。因此,内存访问模式的优化是加速 Transformer 模型的关键。Flash Attention 是一种高效的自注意力实现,它通过将内存访问模式与计算结合起来,减少了内存带宽的使用,从而提高了性能。 + +在这个系列中,我们将介绍 Flash Attention 系列的原理和实现。 + +## 1. GPU 的层次结构 + +老规矩,我们还是先回顾一下 GPU 的层次结构。GPU 内存层次结构由不同大小和速度的多种形式的内存组成,较小的内存速度较快。例如,A100 GPU 具有 40-80GB 的高带宽内存(HBM),带宽为 1.5-2.0TB/s,并且每个 108 个流处理器有 192KB 的片上 SRAM,其带宽估计约为 19TB/s [44, 45]。片上 SRAM 的速度比 HBM 快一个数量级,但其大小小了多个数量级。 + + +![picture 0](images/9626ab9b79ea64fb08e8f204c67d0e588f6ef384ab788cad38030846a21314c5.png) + +## 2. 标准 Attention + +给定输入序列 $\mathbf{Q}, \mathbf{K}, \mathbf{V} \in \mathbb{R}^{N \times d}$,其中 $N$ 是序列长度,而 $d$ 是头部维度,我们希望计算注意力输出 $\mathbf{O} \in \mathbb{R}^{N \times d}$: + +$$ +\mathbf{S}=\mathbf{Q} K^{\top} \in \mathbb{R}^{N \times N}, \quad \mathbf{P}=\operatorname{softmax}(\mathbf{S}) \in \mathbb{R}^{N \times N}, \quad \mathbf{O}=\mathbf{P V} \in \mathbb{R}^{N \times d}, +$$ + +![picture 4](images/865f289b12429f1cf8de42cf2b6b019ecaef55ab09cb04c590ca66ac9b9f9ce7.png) + + +在标准的注意力机制实现中,矩阵 $\mathbf{S}$ 和 $\mathbf{P}$ 需要被显式地存储在高速但容量有限的高带宽内存(HBM)中。这种存储方式带来了 $O(N^2)$ 的内存开销,这在处理大规模输入时尤其值得关注。 + +以一个具体实例来看,在 GPT-2 模型中,序列长度 $N$ 为 1024,而每个特征的维度 $d$ 仅为 64,即 $N \gg d$。由于注意力机制的核心操作(如 softmax 函数)大多受限于内存访问速度,对 HBM 的高频访问不仅增加了内存带宽压力,还显著延长了计算的整体墙钟时间(wall-clock time),从而降低了模型的运行效率。 + + +## 3. Flash Attention + +FlashAttention 的核心思想可以用两个关键词来概括:**分块计算** 和 **动态重计算**。这两种技术的结合,使得注意力机制在保持高效的同时,显著减少了内存占用。 + +### 3.1 分块计算:化整为零 + +传统的 Softmax 需要一次性加载整个输入数据,才能计算全局的最大值和归一化系数。而 FlashAttention 采用了 **增量式计算** 的方式,将输入数据分成小块,依次加载到 GPU 的片上缓存(SRAM)中。 + +我们首先定义一些变量方便后续的讨论: + +| **变量** | **尺寸(shape)** | **说明** | +|-------------------------|--------------------------|-----------------------------| +| $\mathbf{Q}, \mathbf{K}, \mathbf{V}$ | $N \times d$ | 输入矩阵 | +| $\mathbf{Q}_i$ | $B_r \times d$ | $\mathbf{Q}$ 的第 $i$ 个行分块 | +| $\mathbf{K}_j, \mathbf{V}_j$ | $B_c \times d$ | $\mathbf{K}, \mathbf{V}$ 的第 $j$ 个行分块 | +| $\mathbf{S}_{ij}$ | $B_r \times B_c$ | 局部注意力分数矩阵 | +| $\tilde{m}_{ij}$ | $B_r$ | 局部行最大值向量 | +| $\tilde{\mathbf{P}}_{ij}$ | $B_r \times B_c$ | 局部未归一化的注意力权重 | +| $\tilde{\ell}_{ij}$ | $B_r$ | 局部行和向量 | +| $m_i^{\mathrm{new}}$ | $B_r$ | 更新后的全局行最大值 | +| $\ell_i^{\mathrm{new}}$ | $B_r$ | 更新后的全局行和 | +| $\mathbf{O}_i$ | $B_r \times d$ | 输出的第 $i$ 个分块 | + +首先,FlashAttention 将输入矩阵 $\mathbf{Q}, \mathbf{K}, \mathbf{V}$ 划分为若干小块。假设片上缓存的大小为 $M$,则 $\mathbf{Q}$ 被划分为 $T_r = \lceil N/B_r \rceil$ 个块,每块大小为 $B_r \times d$;$\mathbf{K}$ 和 $\mathbf{V}$ 被划分为 $T_c = \lceil N/B_c \rceil$ 个块,每块大小为 $B_c \times d$。这里 $B_r$ 和 $B_c$ 的选择基于缓存的大小和特征维度 $d$。 + +对于每一块 $\mathbf{K}_j$ 和 $\mathbf{V}_j$,FlashAttention 将其从 HBM 加载到 SRAM,然后与每一块 $\mathbf{Q}_i$ 计算局部注意力分数 $\mathbf{S}_{ij} = \mathbf{Q}_i \mathbf{K}_j^T$。$\mathbf{S}_{ij}$ 的大小为 $B_r \times B_c$,远小于全局矩阵 $N \times N$。 + +为了在分块计算中保持数值稳定性,FlashAttention 维护两个全局统计量:行最大值 $m_i \in \mathbb{R}^{B_r}$ 和行和 $\ell_i \in \mathbb{R}^{B_r}$。对于每一块 $\mathbf{S}_{ij}$,计算局部最大值 $\tilde{m}_{ij}$ 和局部归一化系数 $\tilde{\ell}_{ij}$,并根据这些值动态更新全局统计量。 + +在更新输出矩阵 $\mathbf{O}_i$ 时,FlashAttention 采用增量式的方法,将每一块的计算结果逐步累加。具体公式为: + +$$ +\mathbf{O}_i \leftarrow \text{diag}(\ell_i^{\text{new}})^{-1} \left( \text{diag}(\ell_i) e^{m_i - m_i^{\text{new}}} \mathbf{O}_i + e^{\tilde{m}_{ij} - m_i^{\text{new}}} \tilde{\mathbf{P}}_{ij} \mathbf{V}_j \right) +$$ + +这一公式确保了输出结果与全局计算等价。在每一块计算完成后,更新后的 $\mathbf{O}_i$、$\ell_i$ 和 $m_i$ 被写回 HBM,供后续计算使用。 + +:::note + +本文不探讨公式的具体推导过程,感兴趣的读者可以参考 [2] + +::: + +![picture 6](images/5dfce77ba3c57779bce60c8ebc552aa40304c6c5f36bb2d207d6b102a4d8026e.png) + +上图是 FlashAttention 的分块计算的示意图,外层循环中会对 $\mathbf{K}$ 和 $\mathbf{V}$ 进行分块,而内层循环中会对 $\mathbf{Q}$ 进行分块。每个外层循环中都会计算得到 $\mathbf{O_{i,j}}$,并将其根据公式更新到 $\mathbf{O}$ 中。 + +这里我们以一个最简单的例子来说明更新的过程。 + + +我们以 **序列长度 $ N = 4 $**、**特征维度 $ d = 2 $** 为例,将输入矩阵 $\mathbf{Q}, \mathbf{K}, \mathbf{V}$ 均分为 **2 块**,展示 FlashAttention 的分块计算和流式更新过程。假设: +- $\mathbf{Q} \in \mathbb{R}^{4 \times 2}$,分为 2 块:$\mathbf{Q}_1 \in \mathbb{R}^{2 \times 2}$, $\mathbf{Q}_2 \in \mathbb{R}^{2 \times 2}$(每块行数 $ B_r = 2 $)。 +- $\mathbf{K}, \mathbf{V} \in \mathbb{R}^{4 \times 2}$,分为 2 块:$\mathbf{K}_1, \mathbf{V}_1 \in \mathbb{R}^{2 \times 2}$, $\mathbf{K}_2, \mathbf{V}_2 \in \mathbb{R}^{2 \times 2}$(每块行数 $ B_c = 2 $)。 + +初始状态下: +- 输出矩阵 $\mathbf{O} = \begin{bmatrix} 0 & 0 \\ 0 & 0 \\ 0 & 0 \\ 0 & 0 \end{bmatrix}$。 +- 全局统计量:$\ell = [0, 0, 0, 0]^T$, $m = [-\infty, -\infty, -\infty, -\infty]^T$。 + +--- + +**步骤 1:外层循环 $ j=1 $,处理块 $\mathbf{K}_1, \mathbf{V}_1$** + +1. **加载 $\mathbf{K}_1, \mathbf{V}_1$ 到 SRAM**: + + $$ + \mathbf{K}_1 = \begin{bmatrix} k_{11} & k_{12} \\ k_{21} & k_{22} \end{bmatrix}, \quad \mathbf{V}_1 = \begin{bmatrix} v_{11} & v_{12} \\ v_{21} & v_{22} \end{bmatrix} + $$ + +2. **内层循环 $ i=1 $,处理块 $\mathbf{Q}_1$**: + - **加载数据**: + $$ + \mathbf{Q}_1 = \begin{bmatrix} q_{11} & q_{12} \\ q_{21} & q_{22} \end{bmatrix}, \quad \mathbf{O}_1 = \begin{bmatrix} 0 & 0 \\ 0 & 0 \end{bmatrix}, \quad \ell_1 = [0, 0]^T, \quad m_1 = [-\infty, -\infty]^T + $$ + - **计算局部注意力分数**: + $$ + \mathbf{S}_{11} = \mathbf{Q}_1 \mathbf{K}_1^T = \begin{bmatrix} q_{11}k_{11} + q_{12}k_{12} & q_{11}k_{21} + q_{12}k_{22} \\ q_{21}k_{11} + q_{22}k_{12} & q_{21}k_{21} + q_{22}k_{22} \end{bmatrix} \in \mathbb{R}^{2 \times 2} + $$ + - **局部统计量**: + - 逐行最大值 $\tilde{m}_{11} = [\max(\mathbf{S}_{11}[1,:]), \max(\mathbf{S}_{11}[2,:])]^T$。 + - 未归一化注意力权重 $\tilde{\mathbf{P}}_{11} = \exp(\mathbf{S}_{11} - \tilde{m}_{11})$。 + - 逐行和 $\tilde{\ell}_{11} = [\text{sum}(\tilde{\mathbf{P}}_{11}[1,:]), \text{sum}(\tilde{\mathbf{P}}_{11}[2,:])]^T$。 + - **更新全局统计量**: + - 全局最大值 $m_1^{\text{new}} = \max(m_1, \tilde{m}_{11})$。 + - 全局行和 $\ell_1^{\text{new}} = e^{m_1 - m_1^{\text{new}}} \ell_1 + e^{\tilde{m}_{11} - m_1^{\text{new}}} \tilde{\ell}_{11}$。 + - **更新输出**: + $$ + \mathbf{O}_1 \leftarrow \text{diag}(\ell_1^{\text{new}})^{-1} \left( \text{diag}(\ell_1) e^{m_1 - m_1^{\text{new}}} \mathbf{O}_1 + e^{\tilde{m}_{11} - m_1^{\text{new}}} \tilde{\mathbf{P}}_{11} \mathbf{V}_1 \right) + $$ + - **写回 HBM**:更新后的 $\mathbf{O}_1$ 对应前两行,$\ell_1$ 和 $m_1$ 同步更新。 + +3. **内层循环 $ i=2 $,处理块 $\mathbf{Q}_2$**: + - 类似地,加载 $\mathbf{Q}_2 = \begin{bmatrix} q_{31} & q_{32} \\ q_{41} & q_{42} \end{bmatrix}$,计算 $\mathbf{S}_{21} = \mathbf{Q}_2 \mathbf{K}_1^T$,更新后两行 $\mathbf{O}_2$。 + + +**步骤 2:外层循环 $ j=2 $,处理块 $\mathbf{K}_2, \mathbf{V}_2$** + +1. **加载 $\mathbf{K}_2, \mathbf{V}_2$ 到 SRAM**: + $$ + \mathbf{K}_2 = \begin{bmatrix} k_{31} & k_{32} \\ k_{41} & k_{42} \end{bmatrix}, \quad \mathbf{V}_2 = \begin{bmatrix} v_{31} & v_{32} \\ v_{41} & v_{42} \end{bmatrix} + $$ + +2. **内层循环 $ i=1 $,处理块 $\mathbf{Q}_1$**: + - **加载数据**:当前 $\mathbf{O}_1$ 已包含来自 $\mathbf{V}_1$ 的贡献。 + - **计算局部注意力分数**: + $$ + \mathbf{S}_{12} = \mathbf{Q}_1 \mathbf{K}_2^T = \begin{bmatrix} q_{11}k_{31} + q_{12}k_{32} & q_{11}k_{41} + q_{12}k_{42} \\ q_{21}k_{31} + q_{22}k_{32} & q_{21}k_{41} + q_{22}k_{42} \end{bmatrix} \in \mathbb{R}^{2 \times 2} + $$ + - **更新统计量**:根据 $\mathbf{S}_{12}$ 的局部最大值和行和,更新 $m_1^{\text{new}}$ 和 $\ell_1^{\text{new}}$。 + - **更新输出**: + $$ + \mathbf{O}_1 \leftarrow \text{diag}(\ell_1^{\text{new}})^{-1} \left( \text{diag}(\ell_1) e^{m_1 - m_1^{\text{new}}} \mathbf{O}_1 + e^{\tilde{m}_{11} - m_1^{\text{new}}} \tilde{\mathbf{P}}_{12} \mathbf{V}_2 \right) + $$ + - **结果等价于全局 Softmax**:最终 $\mathbf{O}_1$ 为前两行注意力结果的加权和。 + +3. **内层循环 $ i=2 $,处理块 $\mathbf{Q}_2$**: + - 类似地,计算 $\mathbf{S}_{22} = \mathbf{Q}_2 \mathbf{K}_2^T$,更新后两行 $\mathbf{O}_2$。 + + + + +通过这种分阶段、分块处理的方式,FlashAttention 在不牺牲计算精度的前提下,显著提升了注意力机制的效率,成为处理长序列任务的利器。 + + +### 3.2 动态重计算:用时间换空间 + +在反向传播阶段,传统的注意力机制需要存储前向传播生成的完整注意力矩阵,这进一步加剧了内存压力。FlashAttention 采用了 **动态重计算** 的策略:在前向传播中,只存储必要的中间结果(如最大值和归一化系数),而在反向传播时,按需重新计算注意力矩阵。 + +我们的文章里面展示只实现前向传播的计算,反向传播的详细过程可以参考 [2]。 + + + +## 参考文献 + +[1] Andrei Ivanov, Nikoli Dryden, Tal Ben-Nun, Shigang Li, and Torsten Hoefler. Data movement is all you need: A case study on optimizing transformers. Proceedings of Machine Learning and Systems, 3:711–732, 2021 +[2] https://zhuanlan.zhihu.com/p/669926191 +[3] http://www.zh0ngtian.tech/posts/49b73eba.html + diff --git a/docs/17_flash_attn/02_flash_attn_v1_part2.md b/docs/17_flash_attn/02_flash_attn_v1_part2.md new file mode 100644 index 0000000..d91b05c --- /dev/null +++ b/docs/17_flash_attn/02_flash_attn_v1_part2.md @@ -0,0 +1,285 @@ +# Flash Attention v1 实现篇 + +在上一篇中,我们介绍了 Flash Attention v1 的基本原理和分块计算的思想。本文将深入探讨 Flash Attention v1 的实现细节。 + +## 1. 整体设计思路 + +的是 Flash Attention 的 CUDA 内核接收三个输入张量 $\mathbf{Q}$(Q)、$\mathbf{K}$(键)、$\mathbf{V}$(值),计算 $\mathbf{QK^T}$,然后经过 softmax 归一化得到注意力权重,最终与 $\mathbf{V}$ 相乘得到输出 $\mathbf{O}$。在这个过程中,还维护了中间状态 $\mathbf{l}$ 和 $\mathbf{m}$(分别代表累积指数和归一化系数),以便分块累积计算长序列时保持数值稳定性。总体上,每个线程块(block)负责处理一个 batch 的一个 head 内的一部分数据,利用共享内存(sram)来减少全局内存访问延迟。整个内核通过嵌套两层循环来实现对大矩阵分块计算的过程。 + +## 2. 核函数外的主机代码与内核启动 + +核函数在启动前,主机代码首先要确定每个线程块(block)需要使用的共享内存大小。这里的计算公式为: + +```cpp +const int sram_size = (3 * Bc * D * sizeof(float)) + (Bc * Br * sizeof(float)); +``` + +这段代码由两部分组成: + +`(3 * Bc * D * sizeof(float))` 此处 3 代表共享内存中划分出来的三个区域:Qi、Kj 和 Vj。Bc 表示每个 block 中需要加载的元素个数(其实和线程数有关,每个线程负责加载 1 组 d 元素),D 就是每个向量的维度,也就是每个线程加载的数据条数。 sizeof(float) 是每个浮点数的字节数(通常为 4 字节)。 整体来看,这部分计算出的是存储 Qi、Kj 和 Vj 这三个数据块所需要的共享内存总字节数。 + +`(Bc * Br * sizeof(float))` 此部分对应共享内存中 S 区域,用于存储中间计算结果。 Br 则通常代表内层循环中维度的大小,同样乘上 Bc 与 float 的字节数,得到对应的共享内存所需大小。 + +将两个部分相加,就得到了每个 block 所需的共享内存总量 `sram_size`。如此计算可以确保在调用内核时把共享内存传递进去,从而保证内核中的动态共享内存可以正确使用。 + +为了避免请求的共享内存超过设备的最大允许值,程序调用了: + +```cpp +int max_sram_size; +cudaDeviceGetAttribute(&max_sram_size, cudaDevAttrMaxSharedMemoryPerBlock, 0); +printf("Max shared memory: %d, requested shared memory: %d \n", max_sram_size, sram_size); +``` + +如果请求的 sram_size 超过 max_sram_size,那么内核启动时将会失败,这时候我们需要调整 Bc、D、Br 的参数,找到平衡点,既能保证算法所需内存,又不会超过硬件限制。 + +接下来,我们需要设置 CUDA 内核的执行维度(也就是 gridDim 和 blockDim): + +```cpp +dim3 grid_dim(B, nh);  // B: batch 大小,nh: head 数 +dim3 block_dim(Bc);   // 每个块内有 Bc 个线程 +``` + +网格(grid)的维度设置为 dim3(B, nh),第一个维度 B 对应批量(batch)大小,每个 batch 分配一个 block 行。第二个维度 nh 则对应多头注意力中的 head 数,每个 head 分配到不同的 block 列。这样就保证了在一个 kernel 启动中,不同 batch 与不同 head 可以并行执行而互不影响。 + +线程块(block)设置为 dim3(Bc),意味着每个 block 中一维有 Bc 个线程,每个线程负责计算指定数据片段。 + +最后,内核调用语句为: + +```cpp +flash_attn_v1_kernel<<>>(d_Q, d_K, d_V, N, D, Tc, Tr, Bc, Br, softmax_scale, d_l, d_m, d_O); +``` + +## 3. 核函数内的计算逻辑 + +### 3.1. 线程与数据布局 + +Kernel 开头定义了以下变量: + +```cpp +int tx = threadIdx.x; +int bx = blockIdx.x; +int by = blockIdx.y;  // batch 和 head 索引 +``` + +这几行代码直接利用 CUDA 内置变量获取当前线程在线程块内的位置(tx)以及线程块在整个网格中的位置(bx 和 by)。其中: + +- tx 表示线程在当前 block 内部的索引 +- bx 用来标识当前线程块所属的 batch +- by 表示当前线程块所属的 head(多头注意力中每个 head 单独处理注意力计算) + +这样的编号方式使得每个线程块都能精确知道自己应当处理哪一部分数据。 + +紧接着,代码计算了两个偏移量: + +```cpp +int qkv_offset = (bx * gridDim.y * N * d) + (by * N * d); +int lm_offset = (bx * gridDim.y * N) + (by * N); +``` + +`qkv_offset` 这一偏移量用于确定当前线程块在全局内存中 Q、K、V 这三个张量的起始位置。`bx * gridDim.y` 表示每个 batch 内有 `gridDim.y` 个 head(也就是每个 batch 中的线程块数)。乘以 N * d 后,可以理解为一个 batch 内所有 head 占据的存储空间大小。于是 `bx * gridDim.y * N * d`就跳过了前面所有 batch 的数据。`(by * N * d)` 则是在当前 batch 内,按 head 顺序找到对应 head 的起始位置。 + +`lm_offset` 这一偏移量用于 l 和 m 两个中间状态数据的定位。`(bx * gridDim.y * N) + (by * N)` 与 `qkv_offset` 类似,区别在于与 Q、K、V 数据相比,l 和 m 的数据不涉及向量维度 d。同样 `bx * gridDim.y` 表示前面所有 batch 内所有 head 的数据总大小,乘以 N 后跳过了这些数据。 + +借助这两段偏移量的计算,每个线程块都能“知道”自己在全局数据结构中的准确位置。 + +### 3.2 共享内存的划分与加载 + +在 CUDA 内核中,我们经常利用共享内存(Shared Memory)来减少对全局内存的访问,以提高数据的访问速度。动态共享内存大小在内核启动时由主机代码传入(即 sram_size 参数),从而保证内核需要多少共享内存就申请多少内存空间。 + +为了更好地利用这块共享内存,代码通过手动划分来存储不同数据,其代码如下: + +```cpp +extern __shared__ float sram[]; +int tile_size = Bc * d; // size of Qi, Kj, Vj +float* Qi = sram; +float* Kj = &sram[tile_size]; +float* Vj = &sram[tile_size * 2]; +float* S = &sram[tile_size * 3]; +``` + +这里我们可以逐一拆解每个部分的作用: + +- **Qi 区域**: 用来存储当前片(tile)中从 Q 张量 Q 载入的数据。在内核后续计算 $QK^T$ 时,线程需要利用共享内存里的 Qi 进行向量点乘操作。Qi 的大小为 `tile_size`,即 Bc 个向量,每个向量的维度为 d。Bc 一般对应块内线程数量,每个线程负责处理一个向量或一个向量的一部分。 +- **Kj 区域**: Kj 用于存储从全局内存中加载的键(Key)张量的一部分。外层循环中会把总数据分块,每次将一块键数据载入共享内存。同 Qi 一样,也有 tile_size 大小(Bc * d 的数据量)。 +- **Vj 区域**: Vj 与 Kj 类似,不过它存储的是值(Value)张量的一部分。运算中配合 softmax 后的注意力权重对每个线程所对应的值进行加权求和,最终生成输出。 +- **S 区域**: S 区域专门用来存储计算结果——也就是 $QK^T$ 相乘得到的分数 Matrix S。在执行 softmax 操作之前,每个线程对自己对应的输出行内的所有元素,将点乘结果保存到 S 里。 + +### 3.3 外层循环:分块加载键和值 + +下面我们详细解析外层循环中分块加载键(K)和值(V)的代码,看看它是如何利用每个线程的协作,将全局内存中的 K、V 张量按照分块(tile)的方式加载到共享内存中,从而实现更高效的数据复用和计算。 + +在实现多头注意力计算时,整体的 K 和 V 张量往往规模较大,所以一次性把全部数据加载到共享内存是不可能的。为了解决这个问题,程序采用了分块加载的方式: + +- Tc 表示总共需要加载多少个“tile”块,每个 tile 包含一部分连续的键和值数据 +- 每个线程块在每次外层循环迭代过程中,只加载一个 tile 的数据,用于后续计算 + +```cpp +// 整个 K、V 张量被分成 Tc 个 tile +// 每个 tile 大小为 tile_size(定义为 Bc * d) +for (int j = 0; j < Tc; j++) { + // Load Kj, Vj from HBM to SRAM + for (int x = 0; x < d; x++) { + Kj[(tx * d) + x] = K[qkv_offset + (tile_size * j) + (tx * d) + x]; + Vj[(tx * d) + x] = V[qkv_offset + (tile_size * j) + (tx * d) + x]; + } + __syncthreads(); +``` + +这一段代码通过外层循环分块来加载 K 和 V 数据,每次循环由每个线程加载连续 d 个数据,并计算出正确的全局偏移地址,确保加载到共享内存中的 Kj 和 Vj 是当前 tile 对应的数据。通过调用 `__syncthreads()`,所有线程在继续下一步前等待该 tile 加载完成。 + +![picture 0](images/c397679ab7ad09d15680161bc1e2344c62a97161daf94c08cd9a1766036d05f2.png) + +### 3.4 内层循环:计算注意力分数 + +下面我们来看看内层循环所实现的 Q 加载和 softmax 计算过程,看看这一部分代码是如何在共享内存中加载 Q 向量 Qi,并利用加载好的键 Kj 进行内积计算,然后对计算结果执行 softmax 预处理,从而为后续注意力加权做好准备。 + +代码中对每个 Q 分块(tile),利用一个内层循环来处理,每次循环处理一部分 Q 数据以及对应的 softmax 计算: + +```cpp + for (int i = 0; i < Tr; i++) { + ... // 内部代码 + } +``` + +其中,Tr 表示 Q 被分成 Tr 个块,每个块依次处理,既可以应对长序列,也能在共享内存有限的情况下循环利用已经加载的数据。 + +内层循环开始时,每个线程首先将自己需要的查询数据从全局内存复制到共享内存区 Qi 中: + +```cpp + for (int x = 0; x < d; x++) { + Qi[(tx * d) + x] = Q[qkv_offset + (tile_size * i) + (tx * d) + x]; + } +``` + +这里的流程与之前加载 Kj、Vj 的步骤类似,每个线程把对应的 d 维查询向量数据加载到共享内存 Qi 区域。 + +加载完查询向量后,每个线程从全局内存读取之前累计的中间状态值,这些状态值用于 softmax 的数值稳定更新: + +```cpp + // 上一阶段累积中行内最大值(用于防止指数溢出) + float row_m_prev = m[lm_offset + (Br * i) + tx]; + // 上一阶段累积的归一化因子(行内所有 softmax 权重的总和) + float row_l_prev = l[lm_offset + (Br * i) + tx]; +``` + +接下来,每个线程计算查询向量 Qi 与所有加载在共享内存中键 Kj 的内积,从而获得每个注意力分数。代码实现如下: + +```cpp + // 计算 S = QK^T, 并找到当前行的最大值 row_m + float row_m = -INFINITY; + for (int y = 0; y < Bc; y++) { + float sum = 0; + for (int x = 0; x < d; x++) { + sum += Qi[(tx * d) + x] * Kj[(y * d) + x]; + } + sum *= softmax_scale; + S[(Bc * tx) + y] = sum; + + if (sum > row_m) + row_m = sum; + } +``` + +对于每个线程而言,它负责一行的计算,假设该行对应一个查询向量 Qi。外层循环(变量 y)遍历当前 tile 中加载的所有键数据,共有 Bc 个键向量。 内层循环(变量 x)计算 Qi 与第 y 个键向量之间的点积。由于每个向量有 d 个分量,所以内层进行 d 次乘法加法累加,得到 sum。计算完点积之后,乘以 softmax_scale 参数。这一步通常用于缩放内积结果,防止数值过大或过小。 + +将这个结果存储到共享内存的 S 数组中,对应的存储位置为 S[(Bc * tx) + y],保证每个线程负责的查询行中对应的所有注意力分数都存放在连续内存中。 + +同时,通过比较更新 row_m 的值,找到当前计算这一行中所有注意力分数的最大值。这个最大值在后续 softmax 算子中用于减值,保证数值稳定性,防止指数计算溢出。 + +完成点积计算后,接下来进行 softmax 的核心操作,对每个注意力分数,都先做一个指数计算,但在计算之前先减去 row_m,以实现数值稳定性。 + +```cpp + // P = exp(S - row_m), row_l = rowsum(P) + float row_l = 0; + for (int y = 0; y < Bc; y++) { + S[(Bc * tx) + y] = __expf(S[(Bc * tx) + y] - row_m); + row_l += S[(Bc * tx) + y]; + } +``` + +对 S 数组中存储的每个注意力分数,统一减去当前行的最大值 row_m。减去 row_m 后,再用 `__expf` 计算指数函数。每个线程遍历自己计算的这一行中所有的元素,**计算出的指数值重新覆盖 S 中对应位置**,同时累**加得到整行的归一化因子** row_l,也就是后续 softmax 中用来归一化的分母。 + +整个内层循环实现的关键步骤如下: + +1. 以每个分块查询为单位,通过内层循环加载 Q 相关数据到共享内存 Qi,确保快速访问 +2. 读取之前累积的状态(row_m_prev 和 row_l_prev),为跨块累计计算做准备 +3. 对于共享内存中加载的键 Kj,计算 Qi 与 Kj 的内积,得到未经归一化的注意力分数 S +4. 在计算过程中减去每行的最大值(row_m),利用 `__expf` 函数计算指数,同时累加这些指数值,得到 softmax 归一化所需的总和 row_l + +![picture 3](images/9bcbef4156d32c8ff9836bafd4132191a1c62eca89b4316c8f78b7c98e22f84b.png) + + + +### 3.5 状态更新与输出计算 + +在分块计算长序列的注意力时,由于不能一次性处理整个序列,所以将查询 Q、键 K、值 V 分成多个块。对于每个分块计算 softmax 部分后,我们需要将当前块的新计算结果与之前累积的结果合并。这就需要设计一种数值稳定的融合策略,既要保证计算结果正确,又要防止由于指数计算产生数值溢出或下溢。 + +在本段代码中,用两个中间状态: + +- row_m:当前块中求得的最大值(用于 softmax 数值平移) +- row_l:当前块完指数处理后的行和(用于归一化) + +同时,前一块累积保存的状态为 row_m_prev 与 row_l_prev。下面的步骤正是将这两部分状态融合,得到新的累计状态 row_m_new 与 row_l_new,再结合当前计算结果更新输出 O + +首先,我们计算新的最大值和归一化系数。代码如下: + +```cpp + float row_m_new = max(row_m_prev, row_m); + float row_l_new = (__expf(row_m_prev - row_m_new) * row_l_prev) + + (__expf(row_m - row_m_new) * row_l); +``` + +计算新的最大值 row_m_new 时,由于要融合两部分的 softmax 结果,新的最大值取的是前一块和当前块两者的较大值。这样保证在进行指数变换时不会因过大的差距导致数值的不稳定。 + +更新归一化系数 row_l_new 这一式子体现了两块归一化部分如何合并: + +- `__expf(row_m_prev - row_m_new)` 乘以前块的累积归一化系数 row_l_prev,得到前块在新归一化系数中所占的贡献; +- `__expf(row_m - row_m_new)` 乘以当前块的归一化系数 row_l,同理得到当前块的贡献; + +将两部分相加,结果就是新的归一化因子 row_l_new,这种更新方案也是流式 softmax 计算中的一种常见技巧,通过分块计算后在物理意义上“拼接”所有块的结果。关键在于: + +1. 乘以相应的指数因子,可以将两个块的非归一化权重能够在同一数值域下叠加求和 +2. 取新最大值后,对前后状态进行归一化,确保整体输出后续只需再除以这个新归一化系数即可 + +接下来,代码进入输出计算阶段。目的是将 softmax 计算得到的权重(经过指数处理后的 S)与对应 V 值进行加权求和,再融合前一块累积的部分。代码如下: + +```cpp + // Write O, l, m to HBM + for (int x = 0; x < d; x++) { + float pv = 0; // Pij * Vj + for (int y = 0; y < Bc; y++) { + pv += S[(Bc * tx) + y] * Vj[(y * d) + x]; + } + O[qkv_offset + (tile_size * i) + (tx * d) + x] = (1 / row_l_new) \ + * ((row_l_prev * __expf(row_m_prev - row_m_new) * O[qkv_offset + (tile_size * i) + (tx * d) + x]) \ + + (__expf(row_m - row_m_new) * pv)); + } +``` + +for 循环遍历向量维度 x,对于每个查询向量的每个元素,都需要进行加权求和。 + +循环内,pv 初始化为 0,然后通过遍历 y(即遍历当前块内所有键和值数据),累加 S 中 softmax 权重与 Vj 中对应的元素之积。这里 S 中存储的已做指数处理的加权权重,与 Vj 相乘,得到当前块对该查询向量在 x 维产生的部分输出。 + +接下来融合累积结果与当前计算的结果。前一块的累积结果已存储在全局内存的 O 中,对应部分在当前块中仍然可用。由于前一块的累积是在 row_l_prev 归一化下完成的,因此需要通过乘以 `__expf(row_m_prev - row_m_new)` 来转换到以新状态 row_m_new 为基准的数值域;同理,当前块的新计算 pv 也要乘以 `__expf(row_m - row_m_new)`。将这两部分累加后,再乘以 (1 / row_l_new) 得到最终的归一化输出。 + +在计算完当前块的输出 O 后,还要将新的中间状态写回全局内存,供后续块继续融合使用: + +```cpp + m[lm_offset + (Br * i) + tx] = row_m_new; + l[lm_offset + (Br * i) + tx] = row_l_new; +``` + +## 4. 编译运行 + +完整代码在仓库的同级目录下的 `flash_attn_v1.cu` 文件中。可以通过以下命令编译运行: + +```bash +nvcc flash_attn_v1.cu -o flash_attn_v1 +./flash_attn_v1 +``` + +## 参考 + +- [FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness](https://arxiv.org/abs/2006.04533) +- https://github.com/tspeterkim/flash-attention-minimal + diff --git a/docs/17_flash_attn/drawio/draw.drawio b/docs/17_flash_attn/drawio/draw.drawio new file mode 100644 index 0000000..da1540c --- /dev/null +++ b/docs/17_flash_attn/drawio/draw.drawio @@ -0,0 +1 @@ +7V1tc5s4Hv80nmlvph4hISReNkl7d7Pb3e5lb2/vVYbYOKF1TI6QJtlPf+JB2EjCgI0ENqSd1gGD8f/54ae/Zujy4fXvkfd4/yVc+usZBMvXGbqaQWhBCtl/yZG37IjrouzAXRQss0Nge+A6+MvPr+RHn4Ol/5Qfyw7FYbiOg8fywUW42fiLuHTMi6Lwpfy2Vbhelg48ene+dOB64a3lo/8JlvF9dhQDALYn/uEHd/exeObB4+/O7/F07y3Dl53vjD7N0GUUhnH26uH10l8n1CsT5nPF2eLJIn8TN7lg+fWflxfP1ub7n38+B1/Jtz/+WtEPlGS3+eGtn/OvnD9t/MZpEIXPm6Wf3MWaoYuX+yD2rx+9RXL2hXGdHbuPH9b56eJbguSX7368uOdn4ij87l+G6zBK74su0x92xosWOecZUS7yB/Kj2H+t/KpWQUAmen744MfRG3sLvwDBOc4uyuXOwjTny8uWjZZFCH/ffZmJuQTl0nNXfMKWvuxFTuI25Ib9kdv6nPzphsCuA8rkJURBXq7au7SFVBttLWO0XYWbeIey2U9yPFiv+fFNuPF1swFBgQ2OJbOhoPguGwredM4GC8hk95fMqOa/hlF8H96FG2/9aXv0YvEc/Sj4suWSRPnt1T+H4WN+8Jsfx2+5BfGe47DMRkbL6O3P/GbpL/9NfmGKn/969bp78uqNf67IuYv0T8G55Evt5xujQfgcLfw91CJu7tS86M6P95K1QhQif+3FwY/yo3TOVf6cJeVy1nFCp0dvU2K387/nxKelWvJh5T0Ea0bSj7NE4hzv4TElIEJ2wvXwIVgk13qbp8RlPknvSK5K/01v9pTyOLmVhR9ft5/EXt0l/8/nc/5U7FtmD5ad0WUJtq7LNq7sKpurVHZ9Nteh5xk+CJQGqEXwYOsitn2AXS2b0j6tp00p/vRZ5GqHppSnB3WWFPVpR1UxSmawEgNXbUZ3LB9RWb4ZtJO/+DJJO25XM3Lx24xc5Ue3JjH7kC5M4m6ow7i79Hy6Wqj47iyozx5on2p3obK4rLJQFe4jo7bROiQO6jb0qVbCDvUOnoLeqfIuPXr30wx/+n1UuicFJgpvqU5CLG26Zx3iLDUmIdWsqWJmhxqKGmoo6VNDkTENvU61E6S3ApLLVKpy9iC3kXDTESq3yrEaTjpUJUs9kvK1SlJkifBfH9mb34ly9n7MsuPYguxsM6Z66cG6XAM4XnxcZfGhKUsZQeMy37x1cLdhrxeMsD5j0UVC9mDhrT/mJx6C5TKLD332GN5tequEZY9hsIlTEuGLGb5K7sV8T/aoXbHRFdJhGyhMgIqH2ry7OV/xx6giObGsDwmVOF2RRWlTV3w8r6mS1+jqXBWWxdXl7pcNiMRHJRu1aayjW2NXkbdIHPOO7n5MdDf5Wzq6zo+mDn2xDGPBtY9L4y1LUnnZQ5tWeVVRuabXsMgIlwhKdHf7DiRMhezjwc6r97NSM2FvZ+LeX//wE41u3pFI5bN0LnvK5OQmjB68dfn0S07M5LydPWd6cs0SST/6wL7qItjcKa9PDNKH3AYlp3MzVDodMJnc5LcHO4+Wnowjb/O0Yjflt9/4xRtewmhZ/vTdy2+9xfe7VOI/CDSHNi1oDW13+xrvUH4ZPD2uvZzqwWYd7Hzwah168e4D1bvntRCgPz0/3ASCQldE5jfBe0nN93aMci9Qo9+d6KRdNt9QEXAVQdiuTiJd9pse33s4toJSKs5I1tJNf2aH9ytmHRZEcy7U11tstRw0Lrikl36MIu9t5w15hLG989fkwI54ieU6Lkqfm15AHSxIVPYMW/kqvswx5XNwgMyZLOHpwAM0lh14rOwcV07vIIkeRh9rD8indaTWhe0XgVXQkW2/YWCVVUjvQDSxsPFWycbXWHgz+tu0G4adXvVX1Q/To0YdKIXUYAJu7zA3lw5UJ4aIcoN2Q61we21B8cecQG77QG46FFrl5Qx3laA5JHwHFJQA2BAr6jbKOMHVRUGkwoiKNnKz/Jgsy2C/Ldbe01OwEIjWBsJeb5tk+r3sruKQqcOPtUvBpJQJQUfwWESgemYz88v2JF+Y1t0pM6rSndqmhdiu+CCtWR4vIg7Pjc6BVQ4v5xjTGmea/vbVjwJGlqQtYM7D4qYeNivvmq45ICCYK47XqxJ76QKHEP3SiKsDgK7T2nc/3Xx7rx8q1ji9Xa1WcKFMb5fOrcOylW4cP2kQybtGA3ncQYOxMc8Zw5P/f79ZvE/BIWPjP7aGx//q1uT+oP8A4NkN++cb7z+26T30VtLahv5WR4srJBOgiv1VIqAv9MfViDItIsDUf+RCINmB/oWAx9mDi0TFKHQIBR3cdNkiBWo50NzisgTkoYXqelzSFUaaXFiVMmtCsmXOR4ajqi3R2YYgUuUEAQXITVU50VY44U9kSgoS/zNoOfCdCjkg7i0AZY9ENImFzQPD/sSierEg+wCPnVnce9FTbn8575/j1QeqEoXWCKocPpVjp2qAU8Vdqord/HkqC937gFPVqKk6yNQ+vFQ1WGo/UmoPTKoJRqoASBXoqEpolBIXlTLz8KZAMxxTFMbMQYcJltUF3aiYBQRYqoUU2Z+tUDFtuCang4LPuBcDWEi1IMgktthRJfBt2w6VnYaDVqv31XaQgN/bvKpt30G+lTTYoKLx0FVs7NTGxitvUVbNy9z7XGfe58v1rmK2XygAFoVAbA+iRfojq3z4nGFw10mGd6barxisBWxZ/y0VOFWbASCw3gCcwlAtCfs1hKFa1Vj8QSphsNmMUQmxIpAyq4S0i+b/eXjhIiOU9LitExbvBAE264NdczP19JpWaaTTAGyr28Ey1pEnH1QFBDVp9txDW0cTZLAd68VZAkOADLqHDkVsX7dlr+NgvfTzoRRTI5lJhOhm++8hFiFXPxIx9ZWp6CCGIBPVZfzJRZgZnSo5jwGMTi3qa52Hf9kVZxkAbvFinJOuYm202dEWoDoEXAY/lGw8oD3X0EvsL83sTjVQOph3v8zg5VJeGF8WogozURxOv/ZR3SdJCruQHWFEKMaKDFAlOmLO3aHoVFfXSjQchmwsmWz8cp6yIS5Iauoe9InGnjFlAxSNX8YjGiqzYVo2qiPKSTZMyoYYVw5BNuCpycaZhhsioGwA4caeuYmTbBgNRcVW7xCEw9xWHsdsn/TJuQJXnzpig9hwp1DBBqCqF2kDfVrA3HBTRQkxqKoqm8ABS8ym8BY5zkyuLWKfLu2OfLhYUVAJgVkRsMyNIhiKKkqrwlQ7mam4YGvbbofUQl+6XRU2PtWTmN676hFzC8B/uwlOZ72FMWdrdIEFz8ym1VfGpEBaZtO/FLjVGXPXUvBr73a+n6GGEvqqf64Tk2MfplEfLSI6jVvTwgYg1VPcrlLEAzMFk6eBEQKV+1VyiLeGbTqq69TTlNjjYimJ47bVdP6bPkQDUpUu9JC9CwMFJSJCu+H2RBptFGqwnG04RCSQinLYFFetr36CGtRPRjOIkEpCDoWNO5qPIqy/V+thhF2vWLCwgteDmOOyM7rloHHVxRYHTWWuwRwXC+XqW7//PK6QY0P7JiNz1bihxetmYgpEGrlDwxG7cqbs+KbzmREBbA1RBPhQzqEYdDOG2YaNDXO/G9rzB52KqMb0VA66HdQs6NYXc9sG9+Q8gRlW7Va/dDTDSpILhyMFe5SL6o7aNMRqGmJ1ckOsCEFlHcOKCMnoDCvLngYgFev/RO4UZdG2xQ7pTo7brNTRWSHDrt3nflijV8Yx/0hstLiKsrvRwSt5YWaUlRFTncxG3RbTabHhgHvQeZeh3Qvw8PKuPdtXmFxdPm113fVW1+1UGUKVKv8cesuEDmESE/zr45e2yygOCro7UDSHz0UvypBUkcg6eE6hItDWVuo6ZKuQaWP6UWirFCnlq5fazmJppmICah+kPx1FuNSZY4AwdG3LIo5NiZAGUXsOQHqKAEy5Au7GQY5CJ/X5vwMGcE06OU6dZL7w8mJxgkqJKGZKt1VKwTtiiIelk9W9+kknJ50s6eTF4kQdJeZqVmghcueKqoBK8/RFqOcJcBb6WiiVVJHQRAFupryI2j2plZv2mCl24OSPigVO+iPLfbGsswtmCEUQ5DRdy6EReazYKucvXdzogIY2EmiIgOyzTQOPHWhMoDsgoeuUS7KIKMIe08Bjl0oUOzPsqWBYPrMfZ1td7Qb6xDlbD30i/UKfnPGuITMDdLJpAztpuO+i3G5oQqTqEQBHjPyGIABkWEbeDB7VabxQIJuh0J9RNrdQYPh9UTNaKsViFDVcjKgvFKvfPWzCo5YFQceeqoJcuLT3fjlRpe0THnXCo16cJh7VAnzBbaFkfHxYb4BU5YZ8A0YujmPTOCyA2BDEiiqqJQuKvliNmBsY0AtgTMpfba4a/aUvyhFnZ0Rz4g4vGCUTcGhqiJ49cAiWMAoCcAhR6g4KpEAm4NCkk+cOHLLdfcAhZFvOsHRyAg5NOnnuwCFqCVroWlbvwCE+L3AoTY1t53q227fetrEH2bkmOWfrmyS032lKxFx1fDhzrk3NZhTqcY4jl1kM55/KQdeiuo9jhXixWn4LXztwhbh0JxcYXiHOF0WdaykH8c2shqRK1bXtKUSdQtRyiBqdaIiKkLMvbSS2PccDShtddwpejw9eeW++Pnh1Sa/BKz37rpHQwWBur2GC6CJtRK+e4zX5vcnvyX7vROulQsSZurpGqufY2lSvuns4GMjSVjUr1XCfDlYrYJ327VO9ar3br3R7NK6JuumAIjWuq3SvX0KwcVxgiQkLLIu4EdmCo7PA3HWLs1iBFuTj3c1ElrS6STip3mhVb+9inGEroA0TBYRFaieUsGzXYgqYZ30E8NO7CqhalahRARvsSjySymW6vpy42x+BdbSYUtd6Y4/9d8bUEu+su7LpTgXrnR1chJA0mdJPD+O0zT9/3820M7dV/g6Oyd8lTK7xfY2lEjaVbSrki1VKEFBHl011WyFAe6S/Brg8InROxPIKUA3ugFSR5GlkygFzHKf6yijrK9enF3Zi0e84liK3IyYjS7dV8/QoK1iFQ5Cso2QKWa7MuKDNFckLhZTdVI1WD/bGBL6PvcQE0etITLGvXLezgqMjMAVDRc5lmCmtdqDUFx80g+xoCBKwheQgQRG3WY5qQ11tjIF8f9Wpz3lMn9PNrV5tnzN3SL31Od0G25iecp/TcWxZzRQIHxfNoULPXF0NFxbWTrH4FIs3isXZ6zhYL3328mvj+Q/DjtUJQIJOqmJ1ZDJYh3zqz6SRk0aeL/rAIUByiAMAIEAwsMUhvcaPEDTFyUHgHBk/ppdW752urvETKvRvnDxSqmoKyBfkBr9ybzQkVhLKF7AX2UN32jdgJnFyApMTaOQE/NfHd9czfJEsEgR/S75d+qTwIhVazAQEfEgeL3y5eTDkKLoM0RxX6Nsrq0aGQ7TauZgVM3lYsrmSefp7QuMhzuLRyFZqiWzlc6132Mptb6nkRHVx1ZraUpPNbWZz/7j5dnqWlHCURjGmBANFvG3YlnIdH0rAfZKF3qKfWB+oQ/fIQP1II3vmhV5pJJfDp3D2to4T8iEMfTe5+gDBJKAvqc5AsQoEYxttcFkN0J6m28ESZyi8Ramt0sEZV3RHqoGBihmN+njC1fKUFCWdeNBRj0ocF6DC79nUbCMYmENQDgO70ogNhvXCHH5omCgJConIEwWeyDBEgjNhylenfLUmX/31FBfQURvNbbdyqYZjyfAJw7krbIAuH8nSDWn8swOE9fmNV+jwUT+Vd9K8aANCVSFwnGx1iIxbIhY4jLPEpjV30s7ZBnDPkXCWWqCAmhUZGMSHrqsrBoLtuZd27jaoLY2Fu1gANu3UetuyVrqXa0nLL7WztlUJ6/QmKUrbB/U//q2Im06jQqWFCaosj4edhnDwEz53SvIaJnmnl+JZAJC5O+gcD6lyvHqsR46V3N2nqfONoLBqKvG/H5denMhRHLJ/fh0brsQVgxUVottVSBBsL0Hs1yhMiL8NdZjy3H8Jl37yjv8D \ No newline at end of file diff --git a/docs/17_flash_attn/images/5dfce77ba3c57779bce60c8ebc552aa40304c6c5f36bb2d207d6b102a4d8026e.png b/docs/17_flash_attn/images/5dfce77ba3c57779bce60c8ebc552aa40304c6c5f36bb2d207d6b102a4d8026e.png new file mode 100644 index 0000000..3e78a21 Binary files /dev/null and b/docs/17_flash_attn/images/5dfce77ba3c57779bce60c8ebc552aa40304c6c5f36bb2d207d6b102a4d8026e.png differ diff --git a/docs/17_flash_attn/images/865f289b12429f1cf8de42cf2b6b019ecaef55ab09cb04c590ca66ac9b9f9ce7.png b/docs/17_flash_attn/images/865f289b12429f1cf8de42cf2b6b019ecaef55ab09cb04c590ca66ac9b9f9ce7.png new file mode 100644 index 0000000..4d1295a Binary files /dev/null and b/docs/17_flash_attn/images/865f289b12429f1cf8de42cf2b6b019ecaef55ab09cb04c590ca66ac9b9f9ce7.png differ diff --git a/docs/17_flash_attn/images/9626ab9b79ea64fb08e8f204c67d0e588f6ef384ab788cad38030846a21314c5.png b/docs/17_flash_attn/images/9626ab9b79ea64fb08e8f204c67d0e588f6ef384ab788cad38030846a21314c5.png new file mode 100644 index 0000000..f742dc0 Binary files /dev/null and b/docs/17_flash_attn/images/9626ab9b79ea64fb08e8f204c67d0e588f6ef384ab788cad38030846a21314c5.png differ diff --git a/docs/17_flash_attn/images/9bcbef4156d32c8ff9836bafd4132191a1c62eca89b4316c8f78b7c98e22f84b.png b/docs/17_flash_attn/images/9bcbef4156d32c8ff9836bafd4132191a1c62eca89b4316c8f78b7c98e22f84b.png new file mode 100644 index 0000000..055af6c Binary files /dev/null and b/docs/17_flash_attn/images/9bcbef4156d32c8ff9836bafd4132191a1c62eca89b4316c8f78b7c98e22f84b.png differ diff --git a/docs/17_flash_attn/images/c397679ab7ad09d15680161bc1e2344c62a97161daf94c08cd9a1766036d05f2.png b/docs/17_flash_attn/images/c397679ab7ad09d15680161bc1e2344c62a97161daf94c08cd9a1766036d05f2.png new file mode 100644 index 0000000..bfc2118 Binary files /dev/null and b/docs/17_flash_attn/images/c397679ab7ad09d15680161bc1e2344c62a97161daf94c08cd9a1766036d05f2.png differ diff --git a/docs/index.md b/docs/index.md index 447d463..72aed9c 100644 --- a/docs/index.md +++ b/docs/index.md @@ -50,10 +50,13 @@ + ...(补充中) ### 大师系列 💡 + 我现在还不知道写啥,毕竟我现在还是菜鸡~~ ### LLM 推理技术 🤖 ++ [FlashAttention v1 - 原理篇](/flash_attn/flash_attn_v1_part1) ++ [FlashAttention v1 - 实现篇](/flash_attn/flash_attn_v1_part2) + [连续批处理](/continuous_batch) + [Page Attention - 原理篇](/page_attention) + [Page Attention - 源码解析](/vllm_page_attention) diff --git a/sidebars.js b/sidebars.js index 5180e8d..bbe0e53 100644 --- a/sidebars.js +++ b/sidebars.js @@ -172,6 +172,16 @@ const sidebars = { type: 'category', label: 'LLM 推理技术', items: [ + { + type: 'category', + label: 'Flash Attention 专题', + items: [ + { + type: 'autogenerated', + dirName: '17_flash_attn' + }, + ], + }, { type: 'autogenerated', dirName: '13_continuous_batch'