探索非线性网络的线性近似:一种新颖的泛化边界方法

A Generalization Bound for Nearly-Linear Networks

摘要

本文探讨了非线性网络作为线性网络的扰动问题,并提出了一种新颖的泛化边界,该边界对于接近线性的网络变得非空洞。与先前提出非空洞泛化边界的工作相比,本文的边界具有先验性,即在实际训练之前就可以评估这些边界,而不需要依赖于尚未完全理解的隐式偏差现象。本文的主要贡献是一个泛化边界,适用于全连接网络、梯度下降(梯度流)以及使用均方误差损失的二分类问题。实验验证了该边界在降采样的MNIST数据集上的有效性。

原理

本文的核心思想是将非线性网络视为线性网络的扰动。通过这种方法,可以利用线性网络的参数来构建一个代理模型,该模型的泛化差距可以通过经典方法有意义地界定。具体来说,如果初始权重固定,通过最小化数据集上的平方损失来学习线性网络的结果仅由YX⊤和XX⊤决定,这两个矩阵的参数数量远少于网络中的总参数数量,使得基于计数的经典方法变得有意义。

流程

论文中提出的工作流程包括以下步骤:

  1. 模型定义:定义一个全连接的LeakyReLU网络,具有L层。
  2. 数据准备:数据点(x, y)来自分布D,假设所有x在单位球内,所有y等于±1。
  3. 训练过程:使用梯度流在白化数据上训练模型,优化平方损失。
  4. 推理过程:为了与训练过程一致,模型输出在点x上的结果为fϵ θ(Σ−1/2x),其中Σ是特征相关矩阵。
  5. 性能度量:采用误分类损失作为风险函数,定义了经验风险和分布风险。

应用

本文提出的泛化边界适用于全连接网络、梯度下降(梯度流)以及使用均方误差损失的二分类问题。这些边界在接近线性的网络中变得非空洞,因此具有广泛的应用前景,特别是在需要理解深度学习模型泛化能力的场景中。