"解释蒸馏:提升深度学习模型泛化能力的新策略"

Explanation is All You Need in Distillation: Mitigating Bias and Shortcut Learning

摘要

本文探讨了深度神经网络(DNNs)中的偏差和虚假相关性导致的快捷学习问题,特别是在分布外(OOD)泛化方面。传统的解决方法通常需要在训练过程中使用无偏数据或进行超参数调整以对抗快捷学习。本文提出了一种新的方法——解释蒸馏(Explanation Distillation),该方法不依赖于无偏数据,允许任意大小的学生网络学习无偏教师网络(如视觉-语言模型或处理去偏图像的网络)的决策原因。实验结果表明,仅通过解释(例如通过层相关传播LRP)蒸馏训练的神经网络能够高度抵抗快捷学习,超越了组不变学习、解释背景最小化和替代蒸馏技术。在COLOURED MNIST数据集中,LRP蒸馏达到了98.2%的OOD准确率,而深度特征蒸馏和IRM分别达到了92.1%和60.2%。在COCO-on-Places数据集中,LRP蒸馏在分布内和OOD准确率之间的不良泛化差距仅为4.4%,而其他两种技术的差距分别为15.1%和52.1%。

原理

解释蒸馏的核心在于利用一个已经训练好的、具有高度分布鲁棒性的教师模型来指导学生模型的训练,从而避免学生在训练过程中学习到数据中的偏差或快捷路径。具体来说,教师模型在训练过程中不会受到数据偏差的影响,因此其决策是基于正确的特征而非虚假的相关性。学生模型通过学习教师模型的决策解释(如LRP生成的相关性热图),可以模仿教师的决策过程,同时避免学习到数据中的偏差。这种方法的关键在于,学生模型不仅学习教师的输出结果,更重要的是学习教师是如何得出这些结果的,从而在面对OOD数据时能够保持良好的泛化能力。

流程

解释蒸馏的工作流程包括以下几个步骤:

  1. 选择一个已经训练好的、无偏的教师模型,该模型能够准确分类训练数据而不受其虚假相关性的影响。
  2. 定义一个学生模型,该模型的大小和结构可以根据需要进行调整。
  3. 在训练过程中,教师模型提供其对输入样本的决策解释(如LRP热图),学生模型则尝试生成与教师模型相似的解释热图。
  4. 使用一个损失函数来衡量教师和学生模型解释热图之间的差异,并通过反向传播调整学生模型的参数,使其解释热图尽可能接近教师模型的解释热图。
  5. 重复上述过程,直到学生模型能够在不依赖数据偏差的情况下,准确地模仿教师模型的决策过程。

例如,在COLOURED MNIST数据集中,教师模型是一个在随机颜色版本的MNIST上训练的ResNet34,学生模型是一个在COLOURED MNIST 100%上训练的ResNet18。通过解释蒸馏,学生模型能够学习到教师模型如何基于数字形状而非颜色进行分类,从而在OOD数据上保持高准确率。

应用

解释蒸馏技术具有广泛的应用前景,特别是在需要高度泛化能力和抵抗数据偏差的场景中。例如,在医疗图像分析、自动驾驶、金融风险评估等领域,模型需要能够准确识别和处理各种未见过的数据分布。解释蒸馏通过确保模型学习到正确的决策逻辑,而非仅仅依赖于训练数据中的虚假相关性,从而提高了模型在这些领域的可靠性和安全性。此外,由于解释蒸馏不依赖于无偏数据,它可以在数据稀缺或难以获取无偏数据的情况下发挥重要作用。