小白学大模型,揭秘FlashAttention加速秘籍,大模型高效加速三步走

FlashAttention 是一种用于加速大型语言模型(LLM)训练和推理的算法。它通过减少内存访问次数和优化计算过程来提高效率。FlashAttention 的主要思想是将注意力机制的计算分解为多个小步骤,从而减少对内存的依赖。以下是 FlashAttention 1/2/3 的简要介绍:
### FlashAttention 1 FlashAttention 1 是最初的版本,主要针对序列自注意力机制进行了优化。它的核心思想是将注意力计算分解为多个小步骤,从而减少内存访问次数。FlashAttention 1 通过以下方式实现加速: 1. "块状计算":将输入序列分成多个块,并在每个块上进行注意力计算。 2. "内存优化":通过减少对内存的访问次数,降低内存带宽的瓶颈。
### FlashAttention 2 FlashAttention 2 是对 FlashAttention 1 的改进版本,主要针对更复杂的注意力机制进行了优化。FlashAttention 2 的主要改进包括: 1. "支持多头注意力":FlashAttention 2 可以处理多头注意力机制,从而适用于更复杂的模型。 2. "进一步优化内存访问":通过更精细的内存管理,进一步减少内存访问次数。
### FlashAttention 3 FlashAttention 3 是最新的版本,进一步优化了计算效率和内存访问。FlashAttention 3 的主要改进包括: 1. "动态块大小":根据输入序列的长度动态调整块的大小,以实现更好的

相关阅读延伸:小白学大模型:大模型加速的秘密 FlashAttention 1/2/3

在 Transformer 架构中,注意力机制的计算复杂度与序列长度(即文本长度)呈 平方关系 。这意味着,当模型需要处理更长的文本时(比如从几千个词到几万个词),计算时间和所需的内存会急剧增加。最开始的标准注意力机制存在两个主要问题:

  1. 内存占用高 :模型需要生成一个巨大的注意力矩阵 (N×N)。这个矩阵需要被保存在 高带宽内存 (HBM) 中。对于长序列,这很快就会超出 GPU 的内存容量。

  2. 计算效率低 :标准实现会将注意力计算分解成多个独立的步骤(矩阵乘法、softmax 等)。每一步都需要将数据从速度较慢的 HBM 中读取,计算后又写回 HBM。这种频繁的数据移动( 内存读写 )成为了性能瓶颈,导致 GPU 的计算单元(如 Tensor Cores)利用率低下。

什么是 FlashAttention?

FlashAttention 使得处理长达数万甚至数十万个 token 的超长文本成为可能。这解锁了新的应用场景,例如分析法律文档、总结长篇小说或处理整个代码库。

FlashAttention 使得模型的训练和推理速度更快,尤其是在长序列场景下。例如,FlashAttention-2 在长序列上比标准实现快 10 倍,使得训练成本更低,用户体验更好。

最新的 FlashAttention-3 利用了新硬件(如 NVIDIA H100)的 FP8 精度,进一步提升了性能,同时通过特殊的算法保持了计算的准确性,让模型训练更加高效。

FlashAttention v1

FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness

https://arxiv.org/abs/2205.14135

许多研究提出了 近似注意力 方法,试图通过减少计算量(FLOPs)来提高效率。然而,这些方法通常忽略了GPU不同层级内存(如高速的片上SRAM和相对较慢的高带宽HBM)之间的I/O开销,导致它们在实际运行时并没有带来显著的加速。

FlashAttention的核心思想是 I/O感知 ,即在设计算法时,将数据在不同层级内存之间的读写开销考虑在内。论文指出,在现代GPU上,计算速度已经远超内存访问速度,因此大多数操作都受限于内存访问。FlashAttention通过以下两个关键技术来解决这一问题:

  • Tiling (平铺) :将输入数据(Q、K、V矩阵)分割成小块,并在GPU的片上SRAM中进行计算。这样可以避免将庞大的 N×N 注意力矩阵完整地写入到速度较慢的HBM中。

  • 内存优化 :在反向传播时,FlashAttention 不存储巨大的中间注意力矩阵,而是只保存前向传播中计算出的 Softmax归一化因子 。这样,反向传播时可以利用这些因子在SRAM中快速地重新计算注意力矩阵,从而避免了从HBM读取大矩阵的开销。

GPU内存层级

  • HBM (高带宽内存) :容量大(如A100 GPU的40-80 GB),但速度相对较慢(带宽1.5-2.0 TB/s)。

  • 片上SRAM (静态随机存取存储器) :容量小(每个流式多处理器有192 KB),但速度极快(带宽估计达19 TB/s),比HBM快一个数量级以上。

由于GPU的计算速度增长快于内存速度,许多操作的性能瓶颈在于 内存访问 ,而不是计算本身。因此,如何高效利用快速的SRAM变得至关重要。

运算类型

根据 算术强度 (每字节内存访问的算术运算次数),操作可分为两类:

  • 计算密集型 (Compute-bound) :运算时间由算术操作数量决定,内存访问时间相对较小。例如,大规模矩阵乘法。

  • 内存密集型 (Memory-bound) :运算时间由内存访问次数决定,计算时间相对较小。例如,大多数元素级操作(如激活函数、Dropout)和归约操作(如Softmax、LayerNorm)。

注意力实现改进

给定查询 Q、键 K 和值 V 矩阵,注意力的计算分三步:

  1. 相似度计算

  2. Softmax归一化

  3. 加权求和

标准实现(如“Algorithm 0”所示)将每一步都作为一个独立的GPU核函数,并 物化 (materialize)中间矩阵 S 和 P 到HBM中。

这种实现方式导致了两个主要问题:

  • 巨大的内存占用 :中间矩阵 S 和 P 的大小为 N×N,其内存占用与序列长度 N 的平方成正比。

  • 大量的HBM访问 :由于每个步骤都需要读写HBM,导致I/O开销巨大。论文指出,这种方法对HBM的访问次数是 O(N2) 级别的,这在长序列(通常 N≫d)时会成为主要的性能瓶颈,导致运行时间慢。

FlashAttention旨在减少对GPU高带宽内存(HBM)的读写,实现对 确切注意力 (exact attention)的快速、内存高效的计算。为此,它采用了两种关键技术:

  1. Tiling(分块) :将输入的 Q,K,V 矩阵分成若干小块。然后,在计算过程中,每次只将一小块数据从慢速的HBM加载到快速的片上SRAM进行计算,而不是一次性加载整个大矩阵。

  2. Recomputation(重计算) :为了避免在反向传播时存储 O(N2) 的中间注意力矩阵 S 和 P,FlashAttention只存储 Softmax 的归一化统计量(即 m 和 ℓ)。在反向传播时,它会利用这些统计量, 按需在SRAM中重新计算 必要的注意力矩阵块。

通过Tiling和Recomputation,FlashAttention能够将所有计算步骤(矩阵乘法、Softmax、可选的遮蔽和Dropout)融合成 一个单一的CUDA核函数 。这避免了在每个步骤之间反复地将数据写入HBM。

实现效果

lashAttention在BERT-large模型上的训练速度超过了MLPerf 1.1的记录保持者。与Nvidia的实现相比,FlashAttention的训练时间缩短了 15% ,这证明了其在标准长序列任务上的卓越性能。

FlashAttention在训练GPT-2模型时,相比于流行的HuggingFace和Megatron-LM实现,实现了显著的端到端加速。

  • 与Huggingface相比,速度提升高达 3倍

  • 与Megatron-LM相比,速度提升高达 1.7倍

  • 重要的是,FlashAttention在不改变模型定义的情况下,实现了与基线模型相同的困惑度(perplexity),证明了其 数值稳定性

在Long-Range Arena基准测试中,FlashAttention相比于标准的Transformer实现,实现了 2.4倍 的加速。此外, 块稀疏FlashAttention 的表现甚至优于所有已测试的近似注意力方法,证明了其在处理超长序列时的优越性。

lashAttention的内存占用与序列长度呈 线性关系 ,而标准实现是平方关系。这使得FlashAttention的内存效率比标准方法高出 20倍

FlashAttention v2

FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning

https://arxiv.org/abs/2307.08691

第一代 FlashAttention 通过利用 GPU 内存层次结构的特性,显著降低了内存占用(从二次方降为 线性 )并实现了 2-4 倍的加速,且没有引入任何近似。

然而,FlashAttention 的效率仍然不如优化的矩阵乘法(GEMM)操作,其浮点运算性能(FLOPs/s)仅能达到理论峰值的 25-40%。这主要是因为 FlashAttention 存在 不优化的工作划分 (work partitioning),导致 GPU 线程块(thread blocks)和线程束(warps)之间的并行度不足、占用率低或产生不必要的共享内存读写。

为了解决这些问题,论文提出了 FlashAttention-2 ,通过以下改进实现了更好的工作划分:

  1. 减少非矩阵乘法(non-matmul)的浮点运算 :虽然这类操作占总 FLOPs 的比例小,但执行起来很慢。

  2. 在序列长度维度上并行化 :即使对于单个注意力头,也将其计算任务分配给不同的线程块,以提高 GPU 的占用率。

  3. 优化线程块内部的工作分配 :在每个线程块内,重新分配线程束之间的工作,以减少通过共享内存进行的通信。

前向传播改进

FlashAttention-2对在线 Softmax 技巧进行了两处微调:

  1. 延迟归一化 :在每个循环迭代中, 不立即 对输出进行归一化。相反,它维护一个 “未缩放” 的中间结果,并在 整个循环结束时 仅进行一次最终的归一化。这减少了每个块的缩放操作,从而减少了非 matmul 的 FLOPs。

  2. 简化统计量 :为反向传播存储数据时,只保存 logsumexp 统计量 L(j)=m(j)+log(ℓ(j)),而不是同时存储最大值 m(j) 和指数和 ℓ(j)。

并行化改进

第一代 FlashAttention 仅在批处理大小和注意力头数量上进行并行化。当序列长度很长时,批处理大小通常很小,导致 GPU 资源的利用率(occupancy)不高。FlashAttention-2 通过 在序列长度维度上增加并行化 来解决这个问题。

  • 前向传播 :FlashAttention-2 将注意力矩阵的行块任务分配给不同的线程块,这些线程块之间无需通信。通过在行维度上并行,当批次大小和注意力头数较小时,GPU 的 SM(流式多处理器)能够被更充分地利用,从而提高整体吞吐量。

  • 后向传播 :类似地,后向传播则在注意力矩阵的列块上进行并行。由于反向传播中的某些更新需要跨线程块通信,作者使用了 原子加法 (atomic adds)来更新共享的梯度 dK 和 dV,确保了线程安全。

除了线程块级别的并行,FlashAttention-2 还优化了 线程块内部 线程束之间的工作分配,以减少共享内存的读写。

  • 前向传播

    • FlashAttention :采用“split-K”方案,将 K 和 V 矩阵的计算任务分配给不同的线程束。这要求所有线程束将中间结果写入共享内存,再进行同步和求和,导致不必要的共享内存访问。

    • FlashAttention-2 :改为将 Q 矩阵的计算任务分配给不同的线程束。每个线程束负责计算 Q 的一个分片与完整的 K 的乘积。这样,每个线程束可以独立地完成其部分输出,而无需与其他线程束进行共享内存通信,从而显著提高了效率。

  • 后向传播 :后向传播的依赖关系更复杂,但 FlashAttention-2 仍然通过避免“split-K”方案来减少共享内存的读写,实现了性能提升。

实现效果

FlashAttention-2 比第一代 FlashAttention 快 1.7-3.0 倍 ,比 Triton 实现的 FlashAttention 快 1.3-2.5 倍

在 A100 GPU 上,FlashAttention-2 在 前向传播 中达到了 230 TFLOPs/s 的峰值,相当于理论最大吞吐量的 73% 。在 后向传播 中,它达到了理论最大吞吐量的 63%。

FlashAttention v3

FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision

https://arxiv.org/abs/2407.08608

虽然之前的 FlashAttention 通过减少内存读写来加速计算,但它未能充分利用现代硬件(如 Hopper GPU)的新特性。例如,FlashAttention-2 在 H100 GPU 上的利用率仅为 35%。

与 FlashAttention-2 类似,FlashAttention-3 也将任务并行化到不同的线程块(CTA),但其创新之处在于 在单个线程块内部 ,将线程束(warps)划分为不同的角色。

  • 生产者(Producer) :负责将数据从 HBM(全局内存)异步加载到 SMEM(共享内存)。

  • 消费者(Consumer) :在数据加载完成后,从 SMEM 读取数据并执行计算。

生产者和消费者通过一个 循环缓冲区(circular buffer) 进行同步。生产者将数据放入缓冲区,消费者从中取出。当缓冲区中的一个“阶段”被消费后,生产者就可以继续向其中加载新数据。

线程内部的 GEMM 和 Softmax 重叠

在标准 FlashAttention 中,GEMM 和 Softmax 存在顺序依赖:Softmax 必须在第一个 GEMM 计算完成后才能开始,而第二个 GEMM 必须等待 Softmax 的结果。

FlashAttention-3 通过在 寄存器中 使用额外的缓冲区,打破了这种依赖关系。在每次循环中,它 异步 启动下一个 GEMM 的计算,而同时执行当前 GEMM 结果的 Softmax 和更新操作。这样,GEMM 和 Softmax 的执行就可以重叠,提高了效率。

FP8 低精度计算

FP8 的 WGMMA(Warp Group
Matrix-Multiply-Accumulate)指令要求输入矩阵具有特定的
k-major 布局 ,而输入张量通常是 mn-major 布局

FlashAttention-3 选择在 GPU 内核中(in-kernel)进行 转置 。它利用 LDSM/STSM 指令,这些指令能够高效地在 SMEM 和 RMEM(寄存器)之间进行数据传输,并在传输过程中完成布局转置,避免了代价高昂的 HBM 读写。

同于传统的逐张量(per-tensor)量化,FlashAttention-3 对每个 进行单独量化。这使得每个块可以有自己的缩放因子,从而更有效地处理离群值,减少量化误差。

实现效果

FlashAttention-3 的前向传播速度比 FlashAttention-2 快 1.5-2.0 倍 ,后向传播快 1.5-1.75 倍 。FP16 版本的 FlashAttention-3 达到了 740 TFLOPs/s 的峰值,相当于 H100 GPU 理论最大吞吐量的 **75%**。

在处理中长序列(1k 及以上)时,FlashAttention-3 的性能甚至超过了 NVIDIA 自家闭源、针对 H100 优化的 cuDNN 库。

# 学习大模型 & 讨论Kaggle #

△长按添加竞赛小助手

每天大模型、算法竞赛、干货资讯

36000+ 来自竞赛爱好者一起交流~

发布于 2025-12-16 12:33
收藏
1
上一篇:奥飞娱乐巨额减值风波,3家子公司涉嫌经营异常引关注 下一篇:第十域,英雄起源挑战风暴英雄与LOL,另辟蹊径,脑洞大开的新视角竞技盛宴