【模型压缩】网络层与算子融合

news/2024/5/19 14:23:49 标签: 笔记, 人工智能, 深度学习, 边缘计算

由于深度学习网络层数深,结构复杂,生成的算子数量众多,带了巨大的计算资源在和时间的消耗。业界对于加速算子的计算展开了一定研究,比较经典的方法是将多个算子重新组合成一个新的算子,同时对生成的代码进行底层的性能优化,融合成新算子后计算相对于多个单算子分别计算的好处是可以实现内存复用,并提高GPU、CPU、寄存器等计算资源的利用率。

一个常规的例子:

一个原始的Inception Block,首先将神经网络的conv、BN、Relu三个层融合为了一个层,简称CBR,另外TensorRT还可以对网络做水平组合,水平组合是指将输入为相同张量和执行相同操作的层融合一起,下面的Figure3即是将三个相连的CBR为一个大的的CBR。最后,对于concat层,将contact层的输入直接送入下面的操作中,不用单独进行concat后在输入计算,相当于减少了一次传输吞吐,然后就获得了如Figure4所示的最终计算图。

优化前:

 

优化后:

 

正文开始,在讲算子融合之前,我们先看一下代数转换部分,详情见论文DNNFusion

1.代数转换

代数化简是一个在算子层加速中较通用的手段,即将某种特定的tensor计算转化为数学上等价的tensor计算,转化后的计算方式须较转化之前的计算方式计算量更小。

 作者将代数化简分为结合律、分配律和交换律三类,共包含45条结合律化简,38条分配律化简和66条交换律化简。下图具体列举了一些实例。

 

2.算子融合

算子融合时将多个算子融合成一个新算子,可以实现内存复用,并提高计算机资源的利用率。经典工具如TVM 采用相对固定的schedule模板,本文提出了将算子类型分类,并根据类别进行融合的方法,提高了处理过程中的灵活性和覆盖性。

对于算子类别,本文作者将其分为One-to-One,One-to-Many, Many-to-Many, Reorganize, Shuffle五类,具体如下:

 

这里面我们考虑input和output tensor中不同index上的数字的映射关系,One-to-One表示input上的i位与output上i位存在一对一映射关系;One-to-Many表示input上的i位与output上多位是映射关系;Many-to-Many表示input上多位与output上多位是映射关系;Reorganize和Shuffle同样也是一对一映射关系,但是二者相对One-to-One的不同是,从input到output的映射存在index的转换,比如input的i位数字可能映射到output的j位数字,其中Shuffle相对于Reorganize更为严格,其index映射关系须为permutation类。

2.1 融合策略

在指定融合策略的时候我们只需考虑相邻两个算子是否融合即可,因为当相邻两个算子形成一个融合算子后,我们将这个融合算子视为一个新的独立算子,并通过递归继续考虑其与前后算子是否可以继续融合。

2.2 融合后算子类型

如果A和B都是同一种类型x,那么融合后的AB依旧是x;如何A是x,B是y, 那么选取x和y中较复杂的类型来作为AB的类型。具体来说One-to-One的复杂度最低,Reorganize和Shuffle的复杂度居中,One-to-Many 和Many-to-Many的复杂度最高。具体转换关系见下图:

2.3 融合的时机

上图中的绿色区域代表一定可以融合的场景,红色区域代表(One-to-Many与Many-to-Many, Many-to-Many与Many-to-Many)代表一定不能融的场景,橙色区域部分代表不确定融合后是否有收益,作者用了ML的方式来根据具体场景来判断是否融合。

具体为什么这么做的原因作者给予了解释:

One-to-One与其它:One-to-One类型与其他类型融合可减少多余数据拷贝,占用寄存器少,融合有收益;

  1. Reorganize、Shuffle与其它:Reorganize、Shuffle仅在One-to-One上加了特殊点对点映射函数,本质没有变化,理由同a
  2. One-to-Many与Many-to-Many:Expand+Conv为例,Conv希望可以连续的读取内存数据,而Expand可能将该数据打散,性能劣化;
  3. Many-to-Many 与Many-to-Many: Conv+Conv为例,算子过于复杂,影响cache和寄存器的合理使用,性能劣化
  4. Many-to-Many 与One-to-Many:需要分情况讨论:例如Conv+Expand,如Expand只针对一维扩展,则影响不到conv的计算;而Conv+Resize,Resize会影响多个维度的数据,从而影响conv的计算;所以这种模式是否有收益待定。

3. 其他优化

主要包括两点,第一是一些变形算子的消除,比如在一些情况下变形算子的output只被一个算子用,变形后的data locality的好处并不能抵消拷贝数据带来的时间消耗,这种情况会对变形算子进行消除,下图为一个例子:

 第二点是从全局角度消除不必要format转化,以Transpose为例,从算子图全局考虑往往会有多个transpose等算子对tensor进行format转换,作者希望在保证结果一致的前提下,尽可能消除不必要的format转换。作者使用了贪心算法,从受format影响最大的复杂算子(Conv、GEMM、Softmax等)出发,为其选取最优format,延伸开来,统一其他算子的format并消除不必要format转换。

实验结果:

 

 

 


http://www.niftyadmin.cn/n/4961598.html

相关文章

CSS行内,内部,外部以及优先级

1.内联样式表&#xff1a; 将样式编写到style标签里 <style>.context {color: red;} </style> 2. 行内样式&#xff1a; 在 HTML 标签中使用 style 属性设置 CSS 样式 <div style"font-size: 18px;">行内样式</div> 3.外联样式&#xff1…

2023国赛数学建模思路 - 案例:粒子群算法

文章目录 1 什么是粒子群算法&#xff1f;2 举个例子3 还是一个例子算法流程算法实现建模资料 # 0 赛题思路 &#xff08;赛题出来以后第一时间在CSDN分享&#xff09; https://blog.csdn.net/dc_sinor?typeblog 1 什么是粒子群算法&#xff1f; 粒子群算法&#xff08;Pa…

solidity0.8.0的应用案例14:空投合约

空投是币圈中一种营销策略,项目方将代币免费发放给特定用户群体。为了拿到空投资格,用户通常需要完成一些简单的任务,如测试产品、分享新闻、介绍朋友等。项目方通过空投可以获得种子用户,而用户可以获得一笔财富,两全其美。 因为每次接收空投的用户很多,项目方不可能一…

vue使用插件vue-seamless-scroll无限滚动列表

链接: vue-seamless-scroll插件文档 安装vue-seamless-scroll npm install vue-seamless-scroll --save引入 1、main.js全局引入 import scroll from vue-seamless-scroll Vue.use(scroll)2、局部引入 import vueSeamlessScroll from vue-seamless-scrollcomponents: {vueS…

Simulink仿真模块 - Clock

Clock&#xff1a;显示并提供仿真时间 库&#xff1a; Simulink / Sources 模型为&#xff1a; 说明 Clock 模块在每个仿真时间步输出当前仿真时间。此模块对需要仿真时间的其他模块非常有用。 当在离散系统中需要当前时间时&#xff0c;请使用Digital Clock模块。 实例 模块…

使用mysql:5.6和 owncloud 镜像,构建一个个人网盘

一.拉取镜像 docker pull mysql:5.7 docker pull owncloud 二.创建容器 1.MySQL容器 docker run -d --name db1 -p 3306:3306 -e MYSQL_ROOT_PASSWORD123456. -e MYSQL_DATABASEowncloud -e MYSQL_USERowncloud -e MYSQL_PASSWORDowncloud mysql:5.7 docker run: 创建和运行…

商业大厦烟感监控,效果出乎意料!

烟感监控是现代安全技术中至关重要的一环&#xff0c;其在预防火灾、保护生命和财产方面发挥着关键作用。通过使用先进的烟雾探测器和智能报警系统&#xff0c;烟感监控能够及早发现烟雾和火源&#xff0c;并在火灾爆发前提供必要的警示和警报。 通过其精密的传感技术和联动装置…

Java中使用MongoTemplate 简单操作MongoDB

Autowired private MongoTemplate mongoTemplate; User&#xff1a;封装的对象 插入&#xff1a;mongoTemplate.insert(user); 根据id查询&#xff1a;mongoTemplate.findById(id, User.class); 查询所有&#xff1a;mongoTemplate.findAll(User.class); 条件查询&#…