首页 - 设备 - DepGraph:任意架构的结构化剪枝,适用于CNN、Transformer、GNN等!

DepGraph:任意架构的结构化剪枝,适用于CNN、Transformer、GNN等!

2023-10-04 12:40

1. 内容概要

这项工作提出了一种非深度图算法DepGraph,它以通用架构实现了结构化剪枝,适用于CNN、Transfmers、RNN, GNN 和其他网络。该算法可以自动分析复杂结构耦合,从而正确去除参数,实现网络加速。基于DepGraph算法,我们开发了PyTorch结构化剪枝框架Torch-Pruning。与依赖Masking的“SimulationPruning”不同,该框架实际上可以去除参数和通道,降低模型推理成本。借助DepGraph,研究人员和工程师不再需要纠结于复杂的网络结构,一键轻松完成复杂的模型剪枝。

论文标题:DepGraph:走向任何结构修剪
论文链接:https://www.gsm-guard.net/abs/2301.12900
项目地址:https://www.gsm-guard.net/VainF/Torch-Pruning

2.背景介绍

结构剪枝是一种重要的模型压缩算法。它通过去除神经网络中的冗余结构来减少参数量,从而降低模型推理的时间和空间成本。过去几年,结构化剪枝技术被广泛用于加速各种神经网络,涵盖了ResNet、VGG、Transformer等流行架构。然而,现有的剪枝技术仍然存在一个棘手的问题,即算法实现与网络结构之间的强绑定性,这导致需要针对不同的模型开发专用且复杂的剪枝程序。

那么这种强绑定从何而来呢?在网络中,每个神经元上通常有多个参数连接。如下图1(a)所示,当我们想要通过剪枝神经元(粗体突出显示)来实现加速时,需要同时删除与该神经元连接的多组参数。这些参数构成了化学修剪的最小单位,通常称为组。然而,不同网络架构中参数的分组方式通常差异很大。图1(b)-(d)分别可视化了由残差结构、拼接结构和降维结构引起的参数分组。这些结构甚至可以相互嵌套,从而产生更复杂的分组模式。因此,参数分组也是结构化剪枝算法实现的一个难题。

图1:各种结构中的参数耦合,其中突出显示的神经元和参数连接需要同时修剪

3.本文的方法

3.1 参数分组

在没有自动参数分组的情况下,本文提出了一种称为 DepGraph 的(非深度)图算法来对任意网络中的参数依赖关系进行建模。在结构化剪枝中,同一组中的参数是成对耦合的。当我们想要删除其中之一时,需要删除属于该组的所有参数,以确保结构的正确性。理想情况下,我们是否可以直接构造一个二元分组矩阵G来记录所有参数对之间的耦合关系?如果第i层参数和第j层参数是耦合的,我们用它们来表示。这样,参数分组就可以简单地建模为一个查询问题:

然而,参数是否相互依赖不仅由参数本身决定,还受到它们之间的中间层的影响。然而中间层的结构存在无限的可能性,这使得我们很难根据规则直接判断参数的耦合情况。在分析参数依赖关系的过程中,我们发现了一个重要的现象,即相邻层之间的依赖关系是可以递归的。例如,相邻层A和B之间存在依赖关系,相邻层B和C之间也存在依赖关系。那么我们可以递归地得出A和C之间也存在依赖关系。虽然A,C不是直接连接。这就引出了本文算法的核心,就是“利用相邻层的局部依赖关系,递归推导我们需要的分组矩阵”。这种相邻层之间的局部依赖关系称为依赖图(Dependency Graph),记为。依赖图是稀疏的局部关系图,因为它仅对直接连接的层上的依赖关系进行建模。由此,分组问题可以简化为路径搜索问题。当依赖图中节点 i 和节点 j 之间存在路径时,我们可以知道 i 和 j 属于同一组。

3.2 依赖图建模

然而,当我们将这个简单的想法应用到实际网络中时,我们发现了一个新问题。在结构化剪枝中,同一层可能有两种剪枝方法,即输入剪枝和输出剪枝。对于卷积层,我们可以独立地修剪参数的不同维度,以分别修剪输入通道或输出通道。然而,上述依赖图无法模拟这种现象。为此,我们提出了一种更细粒度的模型描述符,该描述符在逻辑上将每一层分解为输入和输出。基于此描述,一个简单的堆叠网络可以描述为:

这些符号代表网络连接。还记得依赖图模型的关系是什么吗?答案是相邻层的局部依赖!在新的模型描述方法中,“相邻层”的定义更广泛,我们将同一层的输入和输出视为相邻。尽管神经网络包含多种层和算子,但我们仍然从上面的公式中抽象出两种基本类型的依赖关系,即层间依赖关系和层内依赖关系。 )。

层间依赖:首先,我们考虑层间依赖。这种依赖性是由层之间的直接连接引起的,并且与层类型无关。由于一层的输出和下一层的输入对应相同的中间特征(Feature),因此两者需要同时进行剪枝。例如,在通道剪枝中,“某一层的输出通道剪枝”和“相邻后续层的输入通道剪枝”是等价的。

层内依赖:其次,我们分析层内依赖,这与层本身的性质有关。在神经网络中,我们可以将各个层分为两类:第一类层的输入和输出可以独立剪枝,具有不同的剪枝布局(pruning shcme),表示为 or 。例如,对于全连接层的2D参数矩阵,我们可以得到两种不同的布局。在这种情况下,输入和输出在依赖图中是独立且解耦的;而另一种类型的层的输入和输出之间存在耦合,例如element-wise操作、Batch Normalization等。它们的参数(如果有)只有一个剪枝布局,并且影响输入和输出维度。事实上,与复杂的参数分组模型相比,深度网络中的层类型非常有限。我们可以预先定义不同层的剪枝布局来确定图中的依赖关系。

综上所述,依赖图的构建可以基于两个简洁的规则来实现,其形式化描述为:

其中 和 分别表示逻辑“OR”和“AND”。我们在算法1和算法2中总结了依赖图构建和参数分组的过程,其中参数分组是一种递归连通分量(Connected Component)搜索,可以通过简单的深度或有限的宽度来完成搜索完成。

将上述算法应用于特定的残差结构块,我们可以得到以下可视化结果。在具体剪枝时,我们以任意一个节点为起点,例如作为起点,递归搜索所有其他可以访问的节点,并将它们归入同一组进行剪枝。值得注意的是,卷积网络对输入和输出使用不同的剪枝布局(),并且其输入和输出节点在深度图中不存在依赖关系。其他层(例如批量归一化)具有依赖性。

图2:残差结构的依赖图建模

3.3 使用依赖图进行剪枝

图 3:不同稀疏度的图示。该方法根据依赖关系对耦合​​参数进行同步和稀疏,从而保证剪枝后的参数一致且“冗余”

依赖图的一个重要作用是自动对参数进行分组,以实现任何架构的模型剪枝。事实上,依赖图的自动分组能力也可以帮助设计组级剪枝。在结构化剪枝中,属于同一组的参数将同时被删除。在这种情况下,我们需要确保这些被移除的参数是“一致冗余的”。然而,传统训练的网络显然无法满足这一要求。这就需要我们对稀疏参数引入稀疏学习方法。这里也有一个问题。传统的逐层独立稀疏技术实际上无法实现这一目标,因为逐层算法没有考虑层间依赖关系,导致出现图2(b)中的非均匀稀疏情况。为了解决这个问题,我们将参数按照依赖关系进行打包,如图2(c)所示,并进行一致的稀疏训练(将虚线框中的参数推为0),使得耦合后的参数呈现一贯的重要性。在具体技术方面,我们采用简单的L2正则化项,通过为参数组分配不同的正则化权重来进行组稀疏化。

其中,k用于可剪枝参数的切片,用于定位当前参数内的第k组参数子矩阵。上述稀疏算法将得到k组不同稀疏程度的耦合参数。我们选择总体 L2 范数最小的耦合参数。进行修剪。事实上,依赖图还可以用来设计各种更强大的组剪枝方法。但由于稀疏训练、重要性评估等技术不是本文的主要内容,因此这里不再详细讨论。

4 个实验

4.1 基准

本文的实验主要由两部分组成。第一部分测试流行的CIFAR数据集和ImageNet数据集。我们验证了各种模型的结构化剪枝效果。我们使用 DepGraph 和一致稀疏性构建了一个非常简单的剪枝处理器,可以在这两个数据集上取得良好的性能。

4.2 分析实验

一致稀疏性:在分析实验中,我们首先评估了一致稀疏性和逐层独立稀疏性之间的差异。结论与3.3中的分析一致,即逐层算法无法实现取决于参数的一致稀疏性。例如下图中的绿色直方图就代表了传统的逐层稀疏策略。与本文提出的一致稀疏性相比,其整体稀疏性性能较差。

分组策略/稀疏性分配:我们还评估了分组策略。我们考虑了无分组(No Grouping)、卷积分组(Conv-only)和完全分组(Full Grouping)。在该策略中,没有分组执行独立的稀疏参数,卷积分组考虑卷积层并忽略其他参数化层,最终的完整分组在所有参数化层上执行一致的稀疏性。实验表明,完全稀疏可以取得更好的效果,同时剪枝的稳定性更高,不容易出现过度剪枝(性能明显下降)。

另外,如何分配剪枝的稀疏度也是一个重要的问题。我们逐层测试了算法在相同稀疏度(Uniform Sparsity)和可学习稀疏度(Learned Sparsity)下的性能。根据稀疏参数L2 Norm对可学习的稀疏度进行全局排序,以确定稀疏度。该方法可以假设参数冗余不是均匀分布在所有层中,因此可以获得更好的性能。但与此同时,可学习的稀疏性存在过度剪枝的风险,即在某一层中删除过多的参数。

依赖图可视化:下图中,我们可视化了DenseNet-121、ResNet-18和ViT-Base的依赖图以及递归推导得到的分组矩阵。可以发现,不同网络的参数依赖关系复杂且不同。

非图像模型的结构化剪枝:深度模型不仅仅是CNN和Transformer。我们还对其他架构的深度模型进行了初步验证,包括用于文本分类的LSTM、用于3D点云分类的DGCNN和用于图像分类的DGCNN。对于GAT数据,我们的方法取得了令人满意的结果。

5 说话很便宜

5.1 一个最小的例子

在本节中,我们将展示一个最简单的 DepGraph 示例。这里我们要在标准 ResNet-18 的第一层上执行通道剪枝:

通过调用DG.get_pruning_group我们可以得到包含model.conv1的最小剪枝单元pruning_group,然后通过调用prune_group.prune()实现基于组的剪枝。通过打印这个分组,我们可以看到对 model.conv1 进行简单操作所导致的复杂耦合:

这时候如果我们不依赖DepGraph,就需要手动逐层剪枝。然而,这通常需要开发人员非常熟悉网络结构,还需要手动分析和分组依赖关系。

5.2 高级修剪器

基于DepGraph,我们在项目中支持更简单的剪枝器,可以对任何架构进行一键剪枝。目前,我们已经支持传统的权重剪枝(MagnitudePruner)、BN剪枝(BNScalePruner)以及本文中使用的剪枝。分组剪枝(GroupNormPruner)、随机剪枝(RandomPruner)等。使用DepGraph,这些剪枝器可以快速应用于不同的模型,降低开发成本。

6 总结

本文提出了DepGraph,一种适用于任何架构的结构化剪枝技术,极大地简化了剪枝过程。目前,我们的框架已经覆盖了Torchvision模型库中85%的模型,涵盖了分类、分割、检测等任务。总体而言,本文的工作只能作为“任意节点架构剪枝”问题的初步探索性工作。无论是工程还是算法设计都有很大的改进空间。此外,当前大多数剪枝算法都是针对单层设计的,我们的工作为未来“组级剪枝”的研究提供了一些有用的基础资源。

审稿编辑:李谦

-->