诊断与重学习:平衡多模态学习的新策略
摘要
本文由Yake Wei等人在中国人民大学高岭人工智能学院提出,针对多模态学习中的不平衡问题,即模型倾向于特定模态的训练,提出了一种名为“诊断与重学习”的方法。该方法通过评估每个模态的单模态表示空间的可分离性来估计学习状态,并据此软重初始化相应的单模态编码器,从而避免了对信息量少的模态的过度强调,同时增强了较差学习的模态的编码器,有效地平衡和增强了多模态学习。实验结果显示,该方法在多种模态和多模态框架中表现出色,源代码和数据集已公开。
原理
“诊断与重学习”方法的核心在于通过评估单模态表示空间的可分离性来诊断每个模态的学习状态。具体来说,通过聚类算法评估训练和验证集的表示纯度,计算两者之间的纯度差距,以此作为重初始化强度的依据。重初始化过程中,使用tanh函数将纯度差距映射到[0,1)区间,确保重初始化强度与纯度差距成正比。这种方法不仅避免了信息量少的模态的过度训练,还通过软重初始化保留了之前学习到的知识,确保了模态间的协作知识不被完全丢弃。
流程
- 数据输入:每个模态的数据首先通过相应的单模态编码器提取特征。
 - 特征融合:提取的单模态特征通过融合策略(如拼接或跨模态交互)融合成多模态特征。
 - 诊断阶段:通过聚类算法评估每个模态的训练和验证表示纯度,计算纯度差距。
 - 重学习阶段:根据纯度差距计算重初始化强度,对单模态编码器进行软重初始化。
 - 模型优化:融合特征输入最终的多模态分类器,使用交叉熵损失函数优化模型。
 
应用
该方法适用于各种多模态学习场景,包括但不限于多模态识别、音频-视觉场景理解等。其灵活性和有效性使其能够适应不同的多模态框架,包括多模态Transformer。随着多模态学习在各个领域的广泛应用,该方法有望进一步提升多模态模型的性能和鲁棒性。
