"WGQA:提升Transformer模型推理效率的新策略"

Weighted Grouped Query Attention in Transformers

摘要

本文由Sai Sena Chinnakonduru和Astarag Mohapatra提出,针对Transformer语言模型中的注意力机制进行了创新改进。随着模型规模的扩大,传统的多头注意力(MHA)在硬件内存限制下,推理成本显著增加。为了解决这一问题,作者提出了Weighted Grouped-Query Attention(WGQA),通过引入新的可学习参数,使得模型在微调过程中能够进行加权平均,从而在不增加额外推理开销的情况下,实现了与MHA相当的性能,平均提升了0.53%。此外,论文还探讨了不同参数聚合方式对模型性能的影响,并通过实验验证了WGQA在多个数据集上的有效性。

原理

WGQA的核心创新在于对Grouped-Query Attention(GQA)的改进。在GQA中,查询头被分组,从而减少了键值头的数量。WGQA在此基础上引入了可学习的权重参数,这些参数应用于T5解码器注意力块中的每个键和值头,使得模型在微调时能够根据这些权重进行加权平均。这种加权机制使得模型能够更好地学习分组机制,从而在训练过程中增强性能。具体来说,WGQA通过在键值头的聚合过程中引入额外的标量或向量参数,实现了对键值头的动态加权,这种加权方式在推理过程中不增加任何额外开销。

流程

WGQA的工作流程主要包括以下几个步骤:

  1. 初始化:使用预训练的T5模型权重进行初始化。
  2. 分组:在解码器的注意力块中,将查询头分组,并对键值头进行加权。
  3. 微调:使用AdamW优化器对模型进行微调,引入新的可学习参数,这些参数在微调过程中被优化。
  4. 评估:在多个数据集上评估模型的性能,包括CNN/Daily Mail、WMT 2014德英翻译和Multi-news数据集。

具体示例:在T5-base模型上,WGQA通过引入额外的权重参数,实现了在Multi-news数据集上ROUGE分数从43.5(GQA)提升到43.7(WGQA),在CNN/Daily Mail数据集上R1分数从41.7(GQA)提升到41.9(WGQA)。

应用

WGQA的提出为大型语言模型(LLM)的推理效率提供了新的解决方案。通过在现有模型基础上引入可学习的权重参数,WGQA不仅提高了模型的性能,还保持了较低的推理成本。这一技术有望在需要高效推理的语言模型应用场景中得到广泛应用,如实时翻译、自动摘要和对话系统等。随着模型规模的进一步扩大,WGQA的性能优势将更加明显,为未来语言模型的发展提供了新的方向。