This repository has been archived on 2020-04-25. You can view files and clone it, but cannot push or open issues or pull requests.
ml/nn/ResNet.md
2020-02-23 22:14:06 +08:00

69 lines
3.6 KiB
Markdown
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# ResNet 详解
## 背景简介
深度网络随着层数不断加深,可能会引起梯度消失/梯度爆炸的问题:
1. “梯度消失”指的是即当梯度小于1.0)在被反向传播到前面的层时,重复的相乘可能会使梯度变得无限小。
2. “梯度爆炸”指的是即当梯度大于1.0)在被反向传播到前面的层时,重复的相乘可能会使梯度变得非常大甚至无限大导致溢出。
随着网络深度的不断增加,常常会出现以下两个问题:
1. 长时间训练但是网络收敛变得非常困难甚至不收敛
2. 网络性能会逐渐趋于饱和甚至还会开始下降可以观察到下图中56层的误差比20层的更多故这种现象并不是由于过拟合造成的。
这种现象称为深度网络的退化问题。
![pic](http://www.zeekling.cn/gogsPics/ml/nn/13.png)
ResNet深度残差网络成功解决了此类问题使得即使在网络层数很深(甚至在1000多层)的情况下,网络依然可以得到很好的性能与效
率。
## 参差网络
ResNet引入残差网络结构residual network即在输入与输出之间称为堆积层引入一个前向反馈的shortcut connection这有
点类似与电路中的“短路”也是文中提到identity mapping恒等映射y=x。原来的网络是学习输入到输出的映射H(x),而残差网络学
习的是$$ F(x)=H(x)x $$。残差学习的结构如下图所示:
![pic](http://www.zeekling.cn/gogsPics/ml/nn/14.png)
另外我们可以从数学的角度来分析这个问题,首先残差单元可以表示为:
$$
y_l=h(x_l) + F(x_l, W-l) \\
x_{l+1} = f(y_l)
$$
其中$$ x_{l} $$ 和$$ x_{l+1} $$ 分别表示的是第 l 个残差单元的输入和输出,注意每个残差单元一般包含多层结构。 F 是残差函
表示学习到的残差而h表示恒等映射 f 是ReLU激活函数。基于上式我们求得从浅层 l 到深层 L 的学习特征为:
$$ x_L = x_l + \sum_{i=l}^{L-1} F(x_i, W_i) $$
利用链式规则,可以求得反向过程的梯度:
$$
\frac{\alpha{loss}}{\alpha{x_l}} = \frac{\alpha{loss}}{\alpha{x_L}} * \frac{\alpha{x_L}}{\alpha{x_l}}
= \frac{\alpha{loss}}{\alpha{x_L}} * (1 + \frac{\alpha}{\alpha{x_L}} \sum_{i=l}^{L-1} F(x_i, W_i) )
$$
式子的第一个因子表示的损失函数到达 L 的梯度小括号中的1表明短路机制可以无损地传播梯度而另外一项残差梯度则需要经过带
有weights的层梯度不是直接传递过来的。残差梯度不会那么巧全为-1而且就算其比较小有1的存在也不会导致梯度消失。所以残
差学习会更容易。
实线、虚线就是为了区分这两种情况的:
1. 实线的Connection部分表示通道相同如上图的第一个粉色矩形和第三个粉色矩形都是3x3x64的特征图由于通道相同所以采
用计算方式为$$ H(x)=F(x)+x $$
2. 虚线的的Connection部分表示通道不同如上图的第一个绿色矩形和第三个绿色矩形分别是3x3x64和3x3x128的特征图通道不
同,采用的计算方式为$$ H(x)=F(x)+Wx $$其中W是卷积操作用来调整x维度的。
![pic](http://www.zeekling.cn/gogsPics/ml/nn/16.png)
## 残差学习的本质
![pic](http://www.zeekling.cn/gogsPics/ml/nn/15.png)
残差网络的确解决了退化的问题,在训练集和校验集上,都证明了的更深的网络错误率越小:
![pic](http://www.zeekling.cn/gogsPics/ml/nn/17.png)
下面是resnet的成绩单, 在imagenet2015夺得冠军:
![pic](http://www.zeekling.cn/gogsPics/ml/nn/18.png)