NN & BP Algorithm
神经网络 & 反向传播算法
神经网络基本的表示方法网上有许多资料,我就不再赘述了。对我来说最难的部分是反向传播算法,我花了许多时间去理解反向传播算法。
链式求导
反向传播算法就是复合函数链式求导的一个应用,什么是链式求导呢?
以求 $ e=(a+b)*(b+1) $ 的偏导数为例子
根据链式法制,
$\frac{\delta e}{\delta a} = \frac{\delta e}{\delta c} \frac{\delta c}{\delta a}$
$\frac{\delta e}{\delta b} = \frac{\delta e}{\delta c} \frac{\delta c}{\delta b} + \frac{\delta e}{\delta d} \frac{\delta d}{\delta b}$
不难发现, $\frac{\delta e}{\delta a}$ 和 $\frac{\delta e}{\delta b}$ 的计算路径有重复项,反向传播算法避免了重复计算,有点类似动态规划的思想。反向传播算法反向逐层计算偏导数,实现只通过一次计算,得出代价函数对所有参数的梯度。
反向传播算法过程
编程作业4给的神经网络模型如下
如何利用反向传播算法求 $\theta^{(1)}$ 和 $\theta^{(2)}$ 呢?
首先假设代价函数的符号为 E,并且有如下表达式
下面内容就是我自己根据链式求导法则推的,因为很多文章直接引入一个叫“误差项”的概念,非常难理解。
根据链式求导法则,可以得出:
$\frac{\delta E}{\delta \theta^{(2)}} = \frac{\delta E}{\delta a^{(3)}} \frac{\delta a^{(3)}}{\delta z^{(3)}} \frac{\delta z^{(3)}}{\delta \theta^{(2)}}$
$\frac{\delta E}{\delta \theta^{(1)}} = \frac{\delta E}{\delta a^{(3)}} \frac{\delta a^{(3)}}{\delta z^{(3)}} \frac{\delta z^{(3)}}{\delta a^{(2)}} \frac{\delta a^{(2)}}{\delta z^{(2)}} \frac{\delta a^{(2)}}{\delta \theta^{(1)}}$
令
$\delta^{(3)} = \frac{\delta E}{\delta a^{(3)}}\frac{\delta a^{(3)}}{\delta z^{(3)}} = a^{(3)} - y$
$\delta^{(2)} = \frac{\delta E}{\delta a^{(3)}} \frac{\delta a^{(3)}}{\delta z^{(3)}} \frac{\delta z^{(3)}}{\delta a^{(2)}} \frac{\delta a^{(2)}}{\delta z^{(2)}} = (\theta^{(2)})^T\delta^{(3)}g’(z^{(2)})$
代入偏导数公式合并
$\frac{\delta E}{\delta \theta^{(2)}} = \delta^{(3)}g’(z^{3})$
$\frac{\delta E}{\delta \theta^{(1)}} = \delta^{(2)}g’(z^{2})$
用上面这两个公式就可以算梯度了,每一轮迭代算出所有训练集的梯度和,最后用平均梯度来代替当前次迭代的梯度。
$\delta^{(3)}$ 的求导可以看参考文献[4],交叉商的求导。看起来很像平方误差的导数….其实并不是。
lab4 代码
向量化实现 nnCostFunction 代码如下
1 |
|
参考
[1] https://zhuanlan.zhihu.com/p/40378224 “Back Propagation(梯度反向传播)实例讲解”
[2] https://www.zhihu.com/question/27239198/answer/89853077 “如何直观地解释 backpropagation 算法?”
[3] http://neuralnetworksanddeeplearning.com/chap2.html “How the backpropagation algorithm works”
[4] https://blog.csdn.net/Jerry_Lu_ruc/article/details/107974072 “关于交叉熵下softmax和sigmoid的求导”
本博客所有文章除特别声明外,均采用 CC BY-SA 4.0 协议 ,转载请注明出处!