FlashAttention-3:突破GPU极限,加速大型语言模型的新纪元

FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision

摘要

本文介绍了一种名为FlashAttention-3的新型注意力机制,旨在通过利用GPU硬件的异步性和低精度特性,显著提高大型语言模型和长上下文应用的处理速度。FlashAttention-3通过三种主要技术实现了这一目标:利用Tensor Cores和Tensor Memory Accelerator(TMA)的异步性,通过warp-specialization重叠计算和数据移动;交错块矩阵乘法和softmax操作;以及利用FP8低精度的硬件支持进行块量化和不连贯处理。实验结果显示,FlashAttention-3在H100 GPU上实现了1.5-2.0倍的加速,FP16精度下达到740 TFLOPs/s(75%利用率),FP8精度下接近1.2 PFLOPs/s。此外,FP8版本的FlashAttention-3在数值误差方面比基准FP8注意力降低了2.6倍。

原理

FlashAttention-3的核心在于充分利用了NVIDIA Hopper架构中的Tensor Cores和Tensor Memory Accelerator(TMA)的异步特性,以及FP8低精度计算的优势。通过warp-specialization技术,FlashAttention-3能够将数据移动和计算操作分离到不同的warp中,从而实现计算和数据加载的重叠,有效隐藏了内存和指令发布的延迟。此外,FlashAttention-3通过交错块矩阵乘法(GEMM)和softmax操作,进一步优化了计算流程,特别是在处理softmax这类低吞吐量操作时,通过与GEMM操作的重叠执行,提高了整体的计算效率。最后,通过块量化和不连贯处理技术,FlashAttention-3在FP8精度下实现了更低的数值误差,同时保持了高吞吐量的计算性能。

流程

FlashAttention-3的工作流程主要包括三个阶段:首先,通过warp-specialization技术,将数据加载和计算任务分配到不同的warp中,实现数据移动和计算的重叠;其次,在计算阶段,通过交错执行块矩阵乘法和softmax操作,优化计算流程,提高硬件利用率;最后,在输出阶段,通过块量化和不连贯处理技术,减少FP8精度下的数值误差,同时保持高吞吐量的计算性能。具体实现中,FlashAttention-3通过算法1和算法2详细描述了前向传播的流程,包括数据加载、计算和输出的具体步骤。

应用

FlashAttention-3的优化技术不仅适用于当前的NVIDIA Hopper架构,也具有广泛的适用性,可以推广到其他支持异步执行和低精度计算的GPU架构。其高效的计算性能和低数值误差特性,使得FlashAttention-3在处理长上下文任务、高分辨率图像、音频和视频处理以及大型代码库分析等应用场景中具有巨大的潜力。随着大型语言模型和长上下文应用的不断发展,FlashAttention-3有望在这些领域发挥重要作用,推动相关技术的进步和应用的广泛部署。