LSTM、RNN
LSTM::只需研究一次
作者:elfin 资料来源:torch.nn.LSTM
1、简述RNN
在传统的统计学中,有一门专门介绍时间序列的课程。其主要研究事件的发生与时间(可以是广义的)有较强的关联,这时传统机器学习算法并不能很好地解决这种带有时序的数据预测、特征挖掘。随着深度学习的火爆,从AR到RNN,再到LSTM及其变种,关于序列特征的挖掘取得了很大的进步。
RNN的结构如下:

注意这里的参数是共享的,即W,U,V只有一份,不会随着单元数的增加而增加。
关于隐藏层与输出层的公式为:
从上式可以发现输出只与当前的隐藏层及其输出权值矩阵V有关;但是隐藏层不仅与当前输入X有关,还与上一个单元的隐藏层有关。这与AR:Xt=λXt−1+εt,非常类似,传统的AR模型在RNN中主要对应了隐状态S的传递,而RNN模型的隐含层相对于传统自回归模型添加了当前输入的调制,以及当前输出是对历史信号+当前调制信号的再调制。有一点我们会发现:传统的自回归模型,它能够调制q阶信号(最近的q个历史节点),但是RNN只是当前与最后一个节点,这对于长距离依赖肯定是不好的。
关于RNN有如下的问题:
- 无法挖掘长距离依赖;
- W,U,V对不同时刻的信号调制是无差别的;
- 梯度计算非常麻烦!它需要累加各个单元节点的梯度,繁琐的函数递归,梯度爆炸与梯度消失的风险较大;
- 相对于最新的技术,RNN-style的模型有致命的缺陷:无法并行计算。
2、LSTM 长短时记忆神经网络
为什么要开发这类 LSTM-style 模型?如上述RNN缺点:RNN模型无法解决长期依赖(长距离依赖)问题,但是序列模型的一个重要特点就是具有长期依赖。如文字序列的上下文、最近一段时间的天气、股票等等。LSTM的提出就是为了解决这个问题!
LSTM解决问题的关键点在于门限(gates,亦称门)技术。它有三个门,分别为:遗忘门、输入门、输出门。
三种门的简单介绍:
-
输入门
输入门的作用是对输入的信号进行调制,那么它是如何调制的呢?调制对象、调制量分别是什么?
首先,调制对象是当前的输入信号与上一个节点单元的输出信号:[ht−1,xt];
其次,调制量是Wi,WC,由如下的公式我们可以获得更新量、更新对象:
it=σ(Wi⋅[ht−1,xt]+bi)˜Ct=tanh(WC⋅[ht−1,xt]+bC)输入门的可视化:
注:图片来源于http://colah.github.io/posts/2015-08-Understanding-LSTMs/
将更新值 it 与候选值 ˜Ct 相乘即得到输入门的 “输入”。
-
遗忘门
遗忘门的可视化:
遗忘门是对历史信息的调制。那么它又是如何调制信号的呢?
首先,调制的对象是细胞状态Ct−1(历史信息);
其次,调制量是Wf,由调制量得到当前节点的遗忘值:
ft=σ(Wf⋅[ht−1,xt]+bf)细胞状态更新:
将遗忘值乘以上一时刻的细胞状态得到当前细胞单元要使用的历史信息,再加上输出门的调制信息,得到当前的细胞状态Ct,即:
Ct=ft∗Ct−1+it∗˜Ct状态更新可视化:
-
输出门
输出门是调制信号得到输出信息的门。
首先,由输入对象加上一时刻输出组合 [ht−1,xt] 获取输出候选;
其次,由当前细胞状态Ct使用tanh激活获取输出更新权值,具体公式如下:
ot=σ(Wo⋅[ht−1,xt]+bo)ht=ot∗tanh(Ct)输出门可视化:
小结:LSTM的计算公式:
3、LSTM的优缺点
LSTM加入了门对历史信息、上个节点的输出、当前节点的输入信息进行调制,相当于对历史信息、当前输入都是有选择地接收处理而不是全盘吸收,有点局部注意力的感觉。注意这里的参数仍然是不随节点的变化而变换的,但是由于门的存在,抵消了硬值对信息的损坏。
因为对历史信息的保留及调制,一定程度上可以将其看作在长期依赖上的改进,但是LSTM仍然有一些缺点:
- 相较于RNN模型,结构复杂,参数量大;
- 同RNN类似,LSTM有梯度消失、梯度爆炸的风险,相比于RNN模型肯定是减小了很多;
- RNN-style的模型,无法并行计算;
4、torch.nn.LSTM
这个类是对输入序列应用多层长短时记忆(LSTM)RNN。
根据第2章,我们已经知道了单层LSTM的参数计算公式,在pytorch中,类似的计算公式为:
其中:
- ht为隐状态,也即t时刻的输出;
- ct是t时刻的细胞状态;
对于一个多层LSTM,第l(l≥2)层的输入x(l)t是前一层对应节点的隐状态h(l−1)t乘以dropout变量:δ(l−1)t,其中δ(l−1)t是Bernoulli随机变量,这里的dropout的概率为0。
4.1 参数
- input_size:输入x的期望特征数;
- hidden_size:隐状态h的特征数;
- num_layers:LSTM的层数,默认为1,否则即为简单的LSTM结构堆叠;
- bias:默认为True;否则在在σ激活时,不添加bhi、bhh;
- batch_first:默认False;如果为真,则输入输出的tensors是(batch, seq, feature);
- dropout:默认为0;不为0时,标识Bernoulli随机变量δ(l−1)t 的dropout概率;
- bidirectional:默认False;如果为真,则变成双向LSTM(BiLSTM)
注意:bidirectional=True 并不是等价于 num_layers=2
4.2 LSTM的输入
LSTM的输入有 input 、h_0、c_0
- input的shape为 (seq_len, batch, input_size) ,可以指定batch_first=True进行调整;输入也可以是压缩可变长度序列,可参考
torch.nn.utils.rnn.pack_padded_sequence()
或torch.nn.utils.rnn.pack_sequence()
- h_0的shape为 (num_layers * num_directions, batch, hidden_size) ,hidden_size是特征数量,不是cell单元数;如果是双向LSTM,因为num_layers=1,所以要设置num_directions=2。
- c_0的shape为 (num_layers * num_directions, batch, hidden_size) 。
如果(h_0, c_0)没有提供,默认为0。
4.3 LSTM的输出
LSTM的输出有output、h_n、c_n
-
output的shape为 (seq_len, batch, num_directions * hidden_size) ,即最后一层每个时刻t的隐状态h_t,即(h_1, **h_2 **, ⋯, h_n);
如果input中给了
torch.nn.utils.rnn.PackedSequence
,output也将成为packed sequence。 -
h_n的shape为 (num_layers * num_directions, batch, hidden_size) ;
-
c_n的shape为 (num_layers * num_directions, batch, hidden_size) 。
4.4 LSTM对象的变量
LSTM.weight_ih_l[k]
第k层输入x对应的权值 (W_ii|W_if|W_ig|W_io) ,当k = 0时shape为 (4*hidden_size, input_size),否则为 (4*hidden_size, num_directions * hidden_size)。
LSTM.weight_hh_l[k]
第k层隐状态ht对应的权值 (W_hi|W_hf|W_hg|W_ho),shape为 (4*hidden_size, hidden_size)。
LSTM.bias_ih_l[k]
第k层输入的偏置 (b_ii|b_if|b_ig|b_io) ,shape为 (4*hidden_size)。
LSTM.bias_hh_l[k]
第k层隐状态的偏置 (b_hi|b_hf|b_hg|b_ho),shape为 (4*hidden_size)。
4.5 NOTE
所有权值、偏置的初始化都是使用均匀分布:U(−√k,√k),其中k=1hidden_size
如果下面的条件都满足:
- cudnn可以使用;
- 输入的数据在GPU上;
- 输入张量的数据类型是
torch.float16
; - V100 GPU可以使用;
- 输入数据不是PackedSequence格式
则,persistent算法可用于提高性能。
4.6 案例
# 官方案例
>>> rnn = nn.LSTM(10, 20, 2)
>>> input = torch.randn(5, 3, 10)
>>> h0 = torch.randn(2, 3, 20)
>>> c0 = torch.randn(2, 3, 20)
>>> output, (hn, cn) = rnn(input, (h0, c0))
>>> type(output)
Out[9]: torch.Tensor
>>> output.shape
Out[10]: torch.Size([5, 3, 20])
>>> input0 = torch.randn(15, 3, 10)
>>> output, (hn, cn) = rnn(input0, (h0, c0))
>>> output.shape
Out[11]: torch.Size([15, 3, 20])
# 注意这里的3是batch,15是seq的长度。也即序列是可变长的,输出的序列单元数是和输入进行匹配的。
参考文献:
- LSTM为何如此有效?
- LSTM原理及实现(一)
- 一文搞懂RNN(循环神经网络)基础篇
- http://colah.github.io/posts/2015-08-Understanding-LSTMs/
- torch.nn.LSTM
待更……
【推荐】100%开源!大型工业跨平台软件C++源码提供,建模,组态!
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】Flutter适配HarmonyOS 5知识地图,实战解析+高频避坑指南
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 为 Java 虚拟机分配堆内存大于机器物理内存会怎么样?
· .NET程序启动就报错,如何截获初期化时的问题json
· 理解 C# 中的各类指针
· C#多线程编程精要:从用户线程到线程池的效能进化论
· 如何反向绘制出 .NET程序 异步方法调用栈
· 换个方式用C#开发微信小程序
· 实现远程磁盘:像访问自己的电脑硬盘一样访问对方的电脑硬盘 (附Demo源码)
· 【.NET必读】RabbitMQ 4.0+重大变更!C#开发者必须掌握的6大升级要点
· .NET 10 Preview 4中ASP.NET Core 改进
· C#网络编程(四)----HttpClient