"测试时训练:RNN的新纪元 - 通过自监督学习实现高效序列建模"

Learning to (Learn at Test Time): RNNs with Expressive Hidden States

摘要

本文提出了一种新型的序列建模层,称为Test-Time Training (TTT)层,其核心思想是将隐藏状态本身作为一个机器学习模型,并且更新规则为自监督学习的一个步骤。TTT层在测试序列上的更新过程等同于在测试时训练模型。文章介绍了两种实例化:TTT-Linear和TTT-MLP,它们的隐藏状态分别是线性模型和两层多层感知机(MLP)。实验表明,这两种实例化在125M到1.3B参数的规模上,与强大的Transformer和现代RNN模型Mamba相比,都能匹配或超越基准性能。特别是在长上下文处理中,TTT-Linear和TTT-MLP显示出更大的潜力,为未来的研究指明了一个有希望的方向。

原理

TTT层的工作原理是将隐藏状态设计为一个模型,更新规则为自监督学习的一个步骤。具体来说,隐藏状态被视为一个模型f,其权重为W,更新规则是基于自监督损失ℓ的梯度下降步骤。在测试序列上更新隐藏状态的过程等同于在测试时训练模型f。这种设计使得TTT层能够在保持线性复杂度的同时,增强隐藏状态的表达能力,从而在长上下文处理中表现出色。

流程

TTT层的工作流程包括初始化隐藏状态为模型f的参数W0,然后在每个输入令牌上执行以下步骤:

  1. 使用当前的隐藏状态Wt-1和输入令牌xt计算输出令牌zt。
  2. 根据自监督损失ℓ计算梯度,并更新隐藏状态Wt。
  3. 重复上述步骤,直到处理完所有输入令牌。 具体实例化TTT-Linear和TTT-MLP在处理长序列时,能够通过条件更多的令牌来持续降低困惑度,而Mamba在16k上下文后则无法做到这一点。

应用

TTT层的关键内容在序列建模领域具有广泛的应用前景。由于其能够在测试时进行训练,TTT层特别适合处理动态和变化的数据流,如实时语音识别、在线文本分析和动态环境下的机器人控制等。此外,TTT层的线性复杂度使其在处理长序列时比传统的Transformer模型更高效,因此在需要处理大量数据的场景中具有显著优势。