diff --git a/docs/17_flash_attn/02_flash_attn_v1_part2.md b/docs/17_flash_attn/02_flash_attn_v1_part2.md index d91b05c..b793234 100644 --- a/docs/17_flash_attn/02_flash_attn_v1_part2.md +++ b/docs/17_flash_attn/02_flash_attn_v1_part2.md @@ -32,6 +32,17 @@ printf("Max shared memory: %d, requested shared memory: %d \n", max_sram_size, s 如果请求的 sram_size 超过 max_sram_size,那么内核启动时将会失败,这时候我们需要调整 Bc、D、Br 的参数,找到平衡点,既能保证算法所需内存,又不会超过硬件限制。 +这里为了简单起见,在代码中直接将 Bc 和 Br 写成了固定值。值得注意的是,这个 Br 和 Bc 的值是可以不一样的,并且一定有$Br \leq Bc$。 + +```cpp +const int Bc = 32; +const int Br = 16; +``` + +至于为什么一定会有$Br \leq Bc$,则可以回到在 Flash Attention V1 的论文里,其计算方式为 $Bc=\lceil \frac{M}{4d} \rceil$,$Br= min(\lceil \frac{M}{4d} \rceil, d)$。其中$M$是设备每个 SM 所能使用的最大共享内存空间大小,$d$是每个向量的维度。$4d$表示的是 Q,K,V,S 使用共享内存的子块大小之和。这里会发现当$\lceil \frac{M}{4d} \rceil > d$时,$Br = d < \lceil \frac{M}{4d} \rceil = Bc$。当$\lceil \frac{M}{4d} \rceil \leq d$时,$Br = \lceil \frac{M}{4d} \rceil = Bc$。所以一定会有$Br \leq Bc$。 + +根据这个性质,当 Br 与 Bc 不相等时时,也可以只用简单的 if 语句就可以完成 Q 子块的加载,但设置 Bc 和 Br 的时候最好是相等的,可以提高 GPU 线程的利用率。 + 接下来,我们需要设置 CUDA 内核的执行维度(也就是 gridDim 和 blockDim): ```cpp @@ -90,17 +101,18 @@ int lm_offset = (bx * gridDim.y * N) + (by * N); ```cpp extern __shared__ float sram[]; -int tile_size = Bc * d; // size of Qi, Kj, Vj -float* Qi = sram; -float* Kj = &sram[tile_size]; -float* Vj = &sram[tile_size * 2]; -float* S = &sram[tile_size * 3]; +const int KV_TILE_SIZE = Bc * d; // size of Kj, Vj +const int Q_TILE_SIZE = Br * d; // size of Qi +float *Qi = sram; +float *Kj = &sram[Q_TILE_SIZE]; +float *Vj = &sram[Q_TILE_SIZE + KV_TILE_SIZE]; +float *S = &sram[Q_TILE_SIZE + KV_TILE_SIZE * 2]; ``` 这里我们可以逐一拆解每个部分的作用: -- **Qi 区域**: 用来存储当前片(tile)中从 Q 张量 Q 载入的数据。在内核后续计算 $QK^T$ 时,线程需要利用共享内存里的 Qi 进行向量点乘操作。Qi 的大小为 `tile_size`,即 Bc 个向量,每个向量的维度为 d。Bc 一般对应块内线程数量,每个线程负责处理一个向量或一个向量的一部分。 -- **Kj 区域**: Kj 用于存储从全局内存中加载的键(Key)张量的一部分。外层循环中会把总数据分块,每次将一块键数据载入共享内存。同 Qi 一样,也有 tile_size 大小(Bc * d 的数据量)。 +- **Qi 区域**: 用来存储当前片(tile)中从 Q 张量 Q 载入的数据。在内核后续计算 $QK^T$ 时,线程需要利用共享内存里的 Qi 进行向量点乘操作。Qi 的大小为 `Q_TILE_SIZE`,即 Br 个向量,每个向量的维度为 d。Br 一般对应块内线程数量,每个线程负责处理一个向量或一个向量的一部分。 +- **Kj 区域**: Kj 用于存储从全局内存中加载的键(Key)张量的一部分。外层循环中会把总数据分块,每次将一块键数据载入共享内存。同 Qi 一样,也有 `KV_TILE_SIZE` 大小(Bc * d 的数据量)。 - **Vj 区域**: Vj 与 Kj 类似,不过它存储的是值(Value)张量的一部分。运算中配合 softmax 后的注意力权重对每个线程所对应的值进行加权求和,最终生成输出。 - **S 区域**: S 区域专门用来存储计算结果——也就是 $QK^T$ 相乘得到的分数 Matrix S。在执行 softmax 操作之前,每个线程对自己对应的输出行内的所有元素,将点乘结果保存到 S 里。 @@ -115,12 +127,12 @@ float* S = &sram[tile_size * 3]; ```cpp // 整个 K、V 张量被分成 Tc 个 tile -// 每个 tile 大小为 tile_size(定义为 Bc * d) +// 每个 KV tile 大小为 KV_TILE_SIZE(定义为 Bc * d) for (int j = 0; j < Tc; j++) { // Load Kj, Vj from HBM to SRAM for (int x = 0; x < d; x++) { - Kj[(tx * d) + x] = K[qkv_offset + (tile_size * j) + (tx * d) + x]; - Vj[(tx * d) + x] = V[qkv_offset + (tile_size * j) + (tx * d) + x]; + Kj[(tx * d) + x] = K[qkv_offset + (KV_TILE_SIZE * j) + (tx * d) + x]; + Vj[(tx * d) + x] = V[qkv_offset + (KV_TILE_SIZE * j) + (tx * d) + x]; } __syncthreads(); ``` @@ -137,7 +149,10 @@ for (int j = 0; j < Tc; j++) { ```cpp for (int i = 0; i < Tr; i++) { - ... // 内部代码 + // 这个就是处理Br和Bc不相等的情况 + if (tx < Br){ + ... // 内部代码 + } } ``` @@ -147,7 +162,7 @@ for (int j = 0; j < Tc; j++) { ```cpp for (int x = 0; x < d; x++) { - Qi[(tx * d) + x] = Q[qkv_offset + (tile_size * i) + (tx * d) + x]; + Qi[(tx * d) + x] = Q[qkv_offset + (Q_TILE_SIZE * i) + (tx * d) + x]; } ``` @@ -250,8 +265,8 @@ for (int j = 0; j < Tc; j++) { for (int y = 0; y < Bc; y++) { pv += S[(Bc * tx) + y] * Vj[(y * d) + x]; } - O[qkv_offset + (tile_size * i) + (tx * d) + x] = (1 / row_l_new) \ - * ((row_l_prev * __expf(row_m_prev - row_m_new) * O[qkv_offset + (tile_size * i) + (tx * d) + x]) \ + O[qkv_offset + (Q_TILE_SIZE * i) + (tx * d) + x] = (1 / row_l_new) \ + * ((row_l_prev * __expf(row_m_prev - row_m_new) * O[qkv_offset + (Q_TILE_SIZE * i) + (tx * d) + x]) \ + (__expf(row_m - row_m_new) * pv)); } ``` diff --git a/docs/17_flash_attn/flash_attn_v1.cu b/docs/17_flash_attn/flash_attn_v1.cu index ed62a92..56c4567 100644 --- a/docs/17_flash_attn/flash_attn_v1.cu +++ b/docs/17_flash_attn/flash_attn_v1.cu @@ -60,11 +60,13 @@ __global__ void flash_attn_v1_kernel(const float *Q, // Define SRAM for Q,K,V,S extern __shared__ float sram[]; - int tile_size = Bc * d; // size of Qi, Kj, Vj + const int KV_TILE_SIZE = Bc * d; // size of Kj, Vj + const int Q_TILE_SIZE = Br * d; // size of Qi + // const int S_TILE_SIZE = Br * Bc; // size of Sij = softmax(Qi * Kj^T * softmax_scale) float *Qi = sram; - float *Kj = &sram[tile_size]; - float *Vj = &sram[tile_size * 2]; - float *S = &sram[tile_size * 3]; + float *Kj = &sram[Q_TILE_SIZE]; + float *Vj = &sram[Q_TILE_SIZE + KV_TILE_SIZE]; + float *S = &sram[Q_TILE_SIZE + KV_TILE_SIZE * 2]; // outer loop for (int j = 0; j < Tc; j++) @@ -72,61 +74,64 @@ __global__ void flash_attn_v1_kernel(const float *Q, // Load Kj, Vj from HBM to SRAM for (int x = 0; x < d; x++) { - Kj[(tx * d) + x] = K[qkv_offset + (tile_size * j) + (tx * d) + x]; - Vj[(tx * d) + x] = V[qkv_offset + (tile_size * j) + (tx * d) + x]; + Kj[(tx * d) + x] = K[qkv_offset + (KV_TILE_SIZE * j) + (tx * d) + x]; + Vj[(tx * d) + x] = V[qkv_offset + (KV_TILE_SIZE * j) + (tx * d) + x]; } __syncthreads(); for (int i = 0; i < Tr; i++) { - // Load Qi to SRAM, l and m to registers - for (int x = 0; x < d; x++) + if (tx < Br) { - Qi[(tx * d) + x] = Q[qkv_offset + (tile_size * i) + (tx * d) + x]; - } - float row_m_prev = m[lm_offset + (Br * i) + tx]; - float row_l_prev = l[lm_offset + (Br * i) + tx]; - - // S = QK^T, row_m = rowmax(S) - float row_m = -INFINITY; - for (int y = 0; y < Bc; y++) - { - float sum = 0; + // Load Qi to SRAM, l and m to registers for (int x = 0; x < d; x++) { - sum += Qi[(tx * d) + x] * Kj[(y * d) + x]; + Qi[(tx * d) + x] = Q[qkv_offset + (Q_TILE_SIZE * i) + (tx * d) + x]; } - sum *= softmax_scale; - S[(Bc * tx) + y] = sum; - - if (sum > row_m) - row_m = sum; - } + float row_m_prev = m[lm_offset + (Br * i) + tx]; + float row_l_prev = l[lm_offset + (Br * i) + tx]; - // P = exp(S - row_m), row_l = rowsum(P) - float row_l = 0; - for (int y = 0; y < Bc; y++) - { - S[(Bc * tx) + y] = __expf(S[(Bc * tx) + y] - row_m); - row_l += S[(Bc * tx) + y]; - } + // S = QK^T, row_m = rowmax(S) + float row_m = -INFINITY; + for (int y = 0; y < Bc; y++) + { + float sum = 0; + for (int x = 0; x < d; x++) + { + sum += Qi[(tx * d) + x] * Kj[(y * d) + x]; + } + sum *= softmax_scale; + S[(Bc * tx) + y] = sum; - // Compute new m and l - float row_m_new = max(row_m_prev, row_m); - float row_l_new = (__expf(row_m_prev - row_m_new) * row_l_prev) + (__expf(row_m - row_m_new) * row_l); + if (sum > row_m) + row_m = sum; + } - // Write O, l, m to HBM - for (int x = 0; x < d; x++) - { - float pv = 0; // Pij * Vj + // P = exp(S - row_m), row_l = rowsum(P) + float row_l = 0; for (int y = 0; y < Bc; y++) { - pv += S[(Bc * tx) + y] * Vj[(y * d) + x]; + S[(Bc * tx) + y] = __expf(S[(Bc * tx) + y] - row_m); + row_l += S[(Bc * tx) + y]; + } + + // Compute new m and l + float row_m_new = max(row_m_prev, row_m); + float row_l_new = (__expf(row_m_prev - row_m_new) * row_l_prev) + (__expf(row_m - row_m_new) * row_l); + + // Write O, l, m to HBM + for (int x = 0; x < d; x++) + { + float pv = 0; // Pij * Vj + for (int y = 0; y < Bc; y++) + { + pv += S[(Bc * tx) + y] * Vj[(y * d) + x]; + } + O[qkv_offset + (Q_TILE_SIZE * i) + (tx * d) + x] = (1 / row_l_new) * ((row_l_prev * __expf(row_m_prev - row_m_new) * O[qkv_offset + (Q_TILE_SIZE * i) + (tx * d) + x]) + (__expf(row_m - row_m_new) * pv)); } - O[qkv_offset + (tile_size * i) + (tx * d) + x] = (1 / row_l_new) * ((row_l_prev * __expf(row_m_prev - row_m_new) * O[qkv_offset + (tile_size * i) + (tx * d) + x]) + (__expf(row_m - row_m_new) * pv)); + m[lm_offset + (Br * i) + tx] = row_m_new; + l[lm_offset + (Br * i) + tx] = row_l_new; } - m[lm_offset + (Br * i) + tx] = row_m_new; - l[lm_offset + (Br * i) + tx] = row_l_new; } __syncthreads(); } @@ -234,7 +239,8 @@ int main() // split kv seq_len to Tc and Q seq_len to Tr const int Bc = 32; - const int Br = 32; + // const int Br = 32; + const int Br = 16; const int Tc = ceil((float)N / Bc); const int Tr = ceil((float)N / Br); @@ -305,7 +311,7 @@ int main() if (max_diff < 0.0001) { - printf("Results are correct! "); + printf("Results are correct! \n"); } else {