"公理训练:教授Transformer因果推理的新范式"
摘要
本文探讨了文本型AI系统在现实世界中交互时因果推理的重要性。由于干预性数据生成成本高昂,研究了代理如何从被动数据中学习因果推理。具体而言,提出了一种基于公理训练的设置,其中代理通过多个因果公理(或规则)的演示来学习,而不是将公理作为归纳偏差或从数据值中推断。关键问题在于代理是否能从公理演示中泛化到新场景。例如,如果一个transformer模型在小型图上训练了因果传递性公理的演示,它是否能泛化到大型图上的传递性公理应用?基于一种新颖的公理训练方案的结果表明,这种泛化是可能的。考虑了在给定因果图结构的情况下推断一个变量是否导致另一个变量的任务。发现一个6700万参数的transformer模型,在训练了线性因果链(以及一些噪声变体)后,可以很好地泛化到新的图类型,包括更长的因果链、逆序的因果链和带有分支的图;即使在未明确针对这些设置进行训练的情况下也是如此。我们的模型性能与许多大型语言模型如GPT-4、Gemini Pro和Phi-3相当(甚至更好)。总体而言,我们的公理训练框架提供了一种从被动数据中学习因果推理的新范式,只要能够生成足够的演示,就可以用于学习任意公理。
原理
本文提出的公理训练方法的核心在于通过符号演示来教授transformer模型因果公理。具体来说,因果公理被表示为一个符号三元组⟨前提, 假设, 结果⟩,其中假设是一个因果声明,前提是决定该声明真伪的任何相关信息(结论)。结论可以简单地是“是”或“否”。例如,来自[16]的碰撞公理可以表示为:前提:“A ⊥⊥ B, B ̸⊥⊥ C, A ̸⊥⊥ C”;假设:“A是否导致C?”;结论为“是”。基于这个模板,可以通过改变变量名称、变量数量、顺序等生成大量合成三元组。关键问题是:如果一个模型在这样的数据上训练,它是否会学会将公理应用于新场景?
为了回答这个问题,本文训练了一个从头开始的transformer模型,使用因果无关公理[11]的符号演示。为了评估泛化能力,模型在大小为3-5节点的简单因果无关公理链上训练,并在多个不同的泛化方面进行测试,包括长度泛化(大小为7-15的链)、名称泛化(更长的变量名称)、顺序泛化(带有逆向边或节点乱序的链)和结构泛化(带有分支的图)。发现模型在简单链上训练后能够泛化到在更大的链上多次应用公理,但无法泛化到更复杂的场景,如顺序或结构泛化。然而,当模型在简单链和一些边随机逆向的链的组合数据集上训练时,发现模型在所有类型的评估场景中都能很好地泛化。扩展了在NLP任务中长度泛化的发现[17, 7, 13, 10],发现位置嵌入在确保因果泛化跨长度和其他方面起着关键作用。我们最好的模型没有位置编码,尽管我们发现正弦编码在某些场景中也工作得很好。
流程
- 数据生成:基于特定的公理,我们可以将给定前提的假设映射到其正确的标签(‘是’或‘否’)。为了创建训练数据集,我们枚举所有可能的三元组{(P, H, L)}N,其中P是前提,H是假设,L是标签(是/否)。
 - 模型训练:使用生成的数据集,定义基于每个三元组地面真实标签的损失函数,表示为E P,H,L∼Ptrain − log(P(L|P, H))。模型在不同的位置编码(如无位置编码、可学习位置编码和正弦位置编码)下进行训练。
 - 数据扰动:为了增强模型的泛化能力,引入多层次的训练数据扰动,包括节点名称、因果图拓扑和长度级别的扰动。
 - 评估:设计了多种评估集,涵盖了因果序列的长度、节点名称变化、顺序变化和分支复杂性,以评估模型在未见过的复杂结构上的泛化能力。
 
应用
本文提出的公理训练方法不仅限于因果推理任务,还可以应用于其他基于公理的正式系统,如逻辑推理任务。此外,该方法为训练语言模型提供了一个新的范式,可能有助于提高模型在因果推理等复杂任务上的性能。通过在预训练中融入因果公理演示,即使是小型语言模型也可能在因果任务上表现得与大型模型如GPT-4相当。
