SNAS论文阅读笔记
条评论从Google第一篇AutoDL的论文开始,在这方面的论文层出不穷,目测已经成为一个很活跃的研究方向,从一开始上千块TPU发展到现在一天内就能搜到SOTA的结果。NAS已经火了有段时间了,近期才有时间系统的看一下AutoDL相关的内容,SNAS是我最近看的一篇论文,这里做下简单的整理和总结,主要也是为自己留下一份知识储备,方便以后快速回顾。
简介
这篇论文主要有两个贡献点或者创新点,一个是使用Gumbel分布将离散的One-hot同连续可微空间联系起来,再一个是增加资源限制的计算,搜索时会同时考虑进计算量,参数和内存访问的代价进去。也是通过cell的方式搭建网络。
实现
由于没有官方的代码,很多实现也是看论文以及darts的代码结合理解的。论文附录里有个图画的非常好,顺带说一下基本的流程和结构。
这里i,j都表示的是node,也就是tensor,$O^n_{i,j}$表示的是连接i,j之间的op的第n个候选op,可以是conv或者pooling或者其他各种op的组合。$z^n_{i,j}$可以理解为各个候选op对于输出结果的权重,不同的实现这里的处理差异很大,本文是使用了gumbel softmax将离散的one-hot转变为连续的gumbel分布,从而可以进行梯度下降。gumbel softmax在很多论文中都有提到,是一个非常重要的连接离散和连续空间的桥梁。one-hot简单说就是一个0-1向量,只有一个值为1,其他都为0.
这样的一个从i到j的结构在darts里是被称为mixop,这里暂且也叫它mixop吧。mixop的候选可以有很多,但太多的候选会导致一个是训练时间变长,再一个就是显存不一定够用。这里通常都是延续enas里面对于这种候选op的设定。
搜索空间和网络采样
$O_{i,j}$代表连接node(i)和node(j)的候选op,对于$x_j$就可以表示成如下:
$$x_j=\sum_{i<j}Õ_{i,j}(x_i)$$
这里的$Õ_{i,j}$表示的是选中的operation。由于最终网络只会选定一个op来连接i和j,这里论文将上式改写成每个候选op的结果乘以一个one-hot的结构。
$$x_j=\sum_{i<j}Õ_{i,j}(x_i)=\sum_{i<j}Z^T_{i,j}O_{i,j}(x_i)$$
这个地方的$Z^T_{i,j}$可以看作是各个候选op的权重,权重的不同将会直接影响mixop的连接,也相当于改变了cell的结构。所以问题就由找到一个合适的网络结构变成找到一个合适的one-hot变量Z。
结构和op的参数学习
cell结构
由于论文里也提到snas的cell结构和mixop候选同enas,darts这些网络很接近,所以这里就假设是一样的,顺便也说一下cell结构是什么样的。
cell主要有网络结构的共享和简化网络结构抽象的功能,一般将cell抽象为normal和reduction两种,对于normal的结构,输入和输出的channel数保持不变,对于reduction的cell,相当于过了一次stride=2,输出的channel数也会加倍。这些都是简单的一些设置,完全可以走不同的设置看看效果。
enas的cell有两个输入,两个输入为他之前的cells,cell的输出为内部的各个mixop的concat。然后继续接下一个cell,以此反复。
在所有cell的最前面会有stem的结构,这个结构简单的话可以用一个大kernel的conv,通常会对输入的image的channel调整到一个合适的数值。像resnet等的第一个conv都是7x7的大kernel卷积。
最后的部分通常就是head了,head可以实现很多种,如针对不同问题的分类的head,或者可以接detection的head等等,不过backbone部分就是使用snas或者其他方法找到的网络结构了。
参数学习方法, Gumbel Softmax
继续前面的搜索最优one-hot变量$Z_{i,j}$的过程。由于$Z$是一个离散的变量,梯度到了这一层没有办法传播,为了能够将搜索$Z$的过程,这里让$Z$成为gumbel-softmax的输出,表达式如下
$$Z^k_{i,j}=f_{\alpha_{i,j}}(G^k_{i,j})
=\frac{exp((log\alpha^k_{i,j} + G^k_{i,j})/\lambda)}{\sum^n_{l=0}exp((log\alpha^l_{i,j} + G^l_{i,j})/\lambda)}$$
这一层很像softmax的表达式,多了$G^k_{i,j}$和$\lambda$。
$G^k_{i,j}=-log(-log(U^k_{i,j}))$是第k个gumbel随机变量,这里$G$就是服从gumbel分布的变量了,$U^k_{i,j}$就是第k个均匀分布的随机变量,$U$就是服从均匀分布。
$\lambda$被称为softmax的温度,这个值越小,输出的$Z^k_{i,j}$就越接近一个one-hot的输出,所以可以猜到一开始的温度是一个相对较大的值,随着搜索的进行,温度会逐渐变小,直到$Z$完全变为one-hot的结果。论文里没有提到如果对$\lambda$进行调整,我试了一些方法,大概$\lambda$从5.0或10.0开始下降,大概到1e-3基本可以认为就是one-hot的变量了。
这里面还有一个变量$\alpha$就是需要训练得到的结构参数了,这个值直接影响输出的one-hot的值,既然有两种cell,可以理解为也有两组$\alpha$的值了,一组是normal的,一组是reduction的。
credit assignment
文中提到这个地方感觉像是被reviewer问到了,为了同enas做对比,主要是抓住强化学习难以及时将reward分配到每个mixop上面,而snas因为将离散的选择问题抽象为连续的gumbel分布,所以每个mixop上的梯度就可以当作其reward,当然这是把这个过程同强化学习的过程做对比,但所有走stochastic的方法都有这方面的便利,这里也就不展开讨论了。
资源限制
文中使用了一些对op的计算量和资源使用的抽象,通过对这些资源花费cost的优化达到资源限制的效果,这里记录下论文里的方法。
论文中记录了三个主要的候选参数,分别是1)参数量,2)float操作数、3)memory访问。参数量相当于计算weights部分的数值,float操作数按照各个op计算的过程估计,memory访问文中叫MAC,是输入和输出的尺寸。有了这些数值后,加上一个系数就。之前的loss就直接加上这个cost了。
$$E_{Zp_\alpha(Z)}[L_\theta(Z)+C(Z)]=E_{Zp_\alpha(Z)}[L_\theta(Z)+E_{Z~p_\alpha(Z)}[C(Z)]]$$
这个cost也是跟前面的one-hot相关的一个值,因此对于loss函数整体的优化还是会落到对$Z$的优化上,就和前面接起来了。
$$C(Z)=\sum_{i,j}C(Z_{i,j})=\sum_{i,j}Z^T_{i,j}C(O_{i,j})$$
论文里处理的比较粗暴,也没太考虑不同op差异巨大的问题,这里可以使用一些方法做一定的normalize,可能效果会更好
效果
这里简单展示一下论文里提到的效果。
cifar10的实验结果
Architecture | Test error | Params(M) | Search Cost (GPU days) |
---|---|---|---|
snas(single-level) + mild constraint | 2.98 | 2.9 | 1.5 |
snas(single-level) + moderate constraint | 2.85 | 2.8 | 1.5 |
snas(single-level) + aggressive constraint | 3.10 | 2.3 | 1.5 |
imagenet的实验结果
Architecture | Test error(top1) | Test error(top5) | Params(M) | Search Cost (GPU days) |
---|---|---|---|---|
snas(single-level) + mild constraint | 27.3 | 9.2 | 4.3 | 1.5 |
思考
snas上能够看到enas,darts的影子,也能嗅到MNAS和更之后的FBNet的味道,这里重点也记录一下这几篇工作的一些要点,为进一步阅读抛砖引玉。
ENAS,DARTS
enas感觉是第一篇把nas推向实际一点的工作,从enas开始,网络搜索开始逐渐平民化,不再是用超多GPU或TPU怼出来的离实际很远的东西。enas里面的cell,和mixop的方法到现在都一直沿用,enas的代码也开源,对于推动nas发展起了很重要的作用,enas是使用强化学习的一种方法,有很多思想可以借鉴。
darts是这两年把网络搜索压缩到单卡单天的一个工作,代码也是开源的,使用的也是基于梯度的方法,沿用了很多enas的设置,也采用cell和mixop的方法对网络搜索建模,之前跑过他们的代码,对于想快速学习是非常不错的选择。
MNAS,FBNet
MNAS和FBNet还有很多这方面的工作,属于将实际落地中非常看中的延时和模型尺寸考虑进去,不仅注重网络搜索过程,也将实际手机上运行的延时结果反馈给网络搜索过程。
SNAS是一篇承上启下的论文,现在基于梯度的方法大有增长的趋势,到了darts和snas基本已经到了1个GPU day的水平,使用多目标优化的方法如把硬件和op的限制考虑进网络搜索中的也在慢慢兴起,未来这种类型的结合也肯定会层出不穷,如果这个时候想切入nas的领域,snas是一篇不错的切入点,缺点就是作者没有开出源码,如果有源码那就更好了。
至此,撒花~
本文标题:SNAS论文阅读笔记
文章作者:throneclay
发布时间:2019-07-18
最后更新:2022-08-03
原始链接:http://blog.throneclay.top/2019/07/18/SNAS/
版权声明:本博客所有文章除特别声明外,均采用 CC BY-NC-SA 3.0 CN 许可协议。转载请注明出处!