DMTG:革命性的多任务学习分组方法,一次训练实现高效任务分组与模型优化

DMTG: One-Shot Differentiable Multi-Task Grouping

摘要

本文介绍了一种名为DMTG(One-Shot Differentiable Multi-Task Grouping)的创新方法,旨在解决多任务学习(MTL)中任务数量庞大的问题。DMTG通过同时识别最佳任务组并训练模型权重,充分利用高阶任务亲和性,从而提高训练效率并减少由顺序处理可能导致的次优解。该方法将多任务分组(MTG)问题形式化为一个可微分的网络剪枝问题,通过学习一个连续放松的可微分分类分布,确保每个任务被唯一且独占地分类到一个分支中。实验结果表明,DMTG在CelebA和Taskonomy数据集上表现出色,具有高效性和准确性。

原理

DMTG的核心在于将多任务分组问题转化为网络剪枝问题。该方法首先设置K个分支,每个分支连接到所有N个任务头,以探索高阶任务亲和性。随后,通过学习一个放松的可微分分类分布,逐步将KN个头剪枝至N个,确保每个任务被唯一且独占地分类到一个分支中。这一过程通过Gumbel softmax优化,使得分类分布在训练过程中连续放松,从而实现高效的联合优化。DMTG的关键创新在于其能够在一次训练中同时完成任务分组和模型权重训练,避免了传统方法中分组和训练的顺序依赖,从而提高了效率和性能。

流程

DMTG的工作流程包括初始化、剪枝和收敛三个阶段。在初始化阶段,每个任务组(分支)连接到所有任务头,确保高阶任务亲和性的全面探索。在剪枝阶段,通过学习可微分的分类分布,逐步减少任务头的数量,直到每个任务被唯一分类到一个分支。在收敛阶段,每个任务独占地属于一个任务组,实现了任务的优化分组。具体示例中,如将4个任务分类到3个组中,初始阶段每个组连接到所有任务头,训练过程中逐步剪枝,最终每个任务独占一个组。

应用

DMTG方法在多任务学习领域具有广泛的应用前景,特别是在需要同时处理大量任务的复杂系统中,如自动驾驶、图像识别和自然语言处理等。其高效的任务分组和模型训练能力,使得DMTG能够显著提升多任务处理的效率和性能,为未来的智能系统提供了强大的技术支持。