加速化学反应模型推理:投机解码技术在工业应用中的突破

Accelerating the inference of string generation-based chemical reaction models for industrial applications

摘要

本文介绍了一种加速基于字符串生成化学反应模型的推理方法,特别适用于工业应用中的计算机辅助合成规划系统。该方法通过投机解码技术,将查询字符串的子序列复制到目标字符串中,从而加速自回归SMILES生成器的推理速度。研究团队在Pytorch Lightning中重新实现了分子变换器,并在反应预测和单步逆合成中实现了超过3倍的推理加速,且未损失准确性。

原理

本文提出的方法基于投机解码(speculative decoding)技术,这是一种用于加速大型语言模型推理的通用技术。在每个生成步骤中,模型会附加一些草稿序列到当前生成的序列,并检查模型是否“接受”这些草稿令牌。如果草稿序列长度为N,在最优情况下,模型在一次前向传递中可以添加N+1个令牌到生成的序列,而在最差情况下,它只添加一个令牌,类似于标准的自回归生成。投机解码不改变预测序列的内容,与令牌逐个解码相比,没有任何影响。

流程

投机解码的工作流程包括以下步骤:

  1. 准备草稿序列:从源序列(反应物令牌)中提取多个子序列,使用滑动窗口和步长为1的方法。
  2. 生成目标字符串:在每个生成步骤中,模型尝试所有草稿序列,以查看是否可以从其中一个复制最多4个令牌。
  3. 验证接受率:计算草稿令牌的接受率,确保加速的同时不影响模型的准确性。

例如,在产品预测中,模型在生成目标字符串之前,会准备一个包含源序列子序列的列表,然后在每个生成步骤中,模型可以尝试所有草稿序列,以查看是否可以从其中一个复制最多4个令牌。

应用

该方法的应用前景广泛,特别是在计算机辅助合成规划系统中。通过加速模型的推理过程,可以显著提高化学反应预测和单步逆合成的效率,从而加速药物发现和其他化学工业应用的进程。此外,该方法的通用性意味着它也可以应用于其他需要快速推理的领域。