FlashDecoding 原理 | 深度学习算法

FlashDecoding 在 FlashAttention 2 的基础上针对 LLM 推理时的 decoding 步骤进行了进一步的性能优化,其计算结果仍然是严格对齐的。

背景

在 LLM 推理场景中,主要包含以下操作:

  • linear projection:1、5
  • attention:2、3、4
  • feedforward network:6

LLM 推理时对 prompt 的处理过程称为 prefill phase,第二阶段预测过程称为 decode phase。这两个阶段的算子基本一致,主要区别在于是输入数据的形状。由于 decode phase 一次只处理一个 token,因此输入矩阵变成了一维矩阵(向量),如下图 decode phase 部分中和 KV cache 拼接的红色向量:

原理

正如上一篇博客中中所提到的,FlashAttention 2 通过在序列长度维度进行并行来缓解 batch size 比较小时并行化程度不高的问题。在 LLM decode 步骤中,batch size 直接降为 1,这进一步导致了 GPU 率下降。

FlashDecoding 通过对 FlashAttention 2 的内层循环进行并行拆分进一步提高了推理场景的并行化程度。

先来回顾一下 FlashAttention 2 的推理流程:

在 decode 步骤中,外层循环次数直接变为 1,只剩下串行的内层循环:

为了保证并行度,只能继续切分内层循环。但是内层循环需要做规约操作,不同 block 任务之间需要通信。因此,就有了下图的执行逻辑:

内层循环被分成 5 份,每一块被调度到一个 block 上计算局部结果,最后 5 个 block 还需要进行汇总得到最终结果。

参考

Flash-Decoding for long-context inference

大模型推理加速之Flash Decoding:更小子任务提升并行度

【FlashAttention-V4,非官方】FlashDecoding++

FlashDecoding 原理 | 深度学习算法

http://www.zh0ngtian.tech/posts/b2e55160.html

作者

zhongtian

发布于

2024-01-27

更新于

2024-02-25

许可协议

评论