DisCo-Diff: 通过离散潜在变量增强连续扩散模型的新方法

DisCo-Diff: Enhancing Continuous Diffusion Models with Discrete Latents

摘要

本文介绍了一种名为DisCo-Diff的新型扩散模型,该模型通过引入离散潜在变量来增强连续扩散模型。DisCo-Diff通过引入可学习的离散潜在变量,简化了将复杂数据分布编码为单一高斯分布的学习问题。这些离散潜在变量通过编码器网络推断,并与扩散模型(DM)和编码器进行端到端的训练。DisCo-Diff不依赖于预训练网络,因此具有普遍适用性。实验证明,引入离散潜在变量可以显著简化DM的复杂噪声到数据映射学习,减少生成ODE的曲率,并在多个图像合成任务和分子对接任务中提高模型性能。

原理

DisCo-Diff的核心创新在于通过引入离散潜在变量来增强连续扩散模型。传统的扩散模型(DM)通过扩散过程将数据编码为简单的高斯分布,但将复杂、多模态的数据分布编码为单一连续高斯分布是一个挑战。DisCo-Diff通过引入离散潜在变量,将单一的噪声到数据映射分解为一组更简单的映射,每个映射的生成ODE曲率更小。此外,DisCo-Diff还包括一个自回归变换器,用于建模离散潜在变量的分布,这使得模型能够处理只有少数离散变量和小码本的情况。

流程

DisCo-Diff的工作流程分为两个阶段:首先,通过编码器网络推断离散潜在变量,并将其与扩散模型(DM)和编码器进行端到端的训练。其次,训练一个自回归模型来捕捉离散潜在变量的分布。在推理时,首先从自回归模型中采样离散潜在变量,然后使用ODE或SDE求解器进行采样。具体步骤包括:1) 使用编码器推断离散潜在变量;2) 将离散潜在变量输入到DM中进行训练;3) 训练自回归模型来捕捉离散潜在变量的分布;4) 在推理时,先采样离散潜在变量,再使用DM生成新样本。

应用

DisCo-Diff在多个领域具有广泛的应用前景,特别是在图像合成和分子对接任务中。在图像合成方面,DisCo-Diff能够生成高质量的图像,并在多个ImageNet生成基准上达到最先进的性能。在分子对接任务中,DisCo-Diff通过学习指示关键原子的离散潜在变量,提高了对接性能。此外,DisCo-Diff的框架还可以应用于其他“迭代”生成模型,如Poisson Flow生成模型,显示出类似的性能提升。