网站建设规划,做棋牌网站的步骤,织梦+和wordpress,权威解读当前经济热点问题目录 一、GRU提出的背景#xff1a;1.RNN存在的问题#xff1a;2.GRU的思想#xff1a; 二、更新门和重置门#xff1a;三、GRU网络架构#xff1a;1.更新门和重置门如何发挥作用#xff1a;1.1候选隐藏状态H~t#xff1a;1.2隐藏状态Ht#xff1a; 2.GRU: 四、底层源码… 目录 一、GRU提出的背景1.RNN存在的问题2.GRU的思想 二、更新门和重置门三、GRU网络架构1.更新门和重置门如何发挥作用1.1候选隐藏状态H~t1.2隐藏状态Ht 2.GRU: 四、底层源码五、Pytorch版代码 一、GRU提出的背景
1.RNN存在的问题
循环神经网络讲解文章
由于RNN的隐藏状态ht用于记录之前的所有序列信息而对于长序列问题来说ht会记录太多序列信息导致序列时序特征区分度很差最前面的序列特征因为进行了太多轮迭代往往不太好从ht中提取因此一些比较靠前但很重要的序列特征在ht中可能就不太被重视而一些比较靠后但没用的序列特征在ht中被过于关注。
2.GRU的思想
GRU的思想是如何将隐藏状态ht中一些重要的序列信息给予高的关注而一些不重要的序列信息给予低的关注。
对于需要关注的序列信息使用更新门来提高关注度对于需要遗忘的序列信息使用遗忘门来降低关注度
二、更新门和重置门
GRU提出更新门和重置门的思想来改变隐藏状态ht中不同序列信息的关注度。 更新门和重置门可以分别看做一个全连接层的隐藏层这样的话上图就等价于两个并排的隐藏层其中
每个隐藏层都接收之前时间步的隐藏状态Ht-1和当前时间步的输入batch。更新门和重置门有各自的可学习权重参数和偏置值公式含义类似传统RNN。Rt 和 Zt 都是根据过去的隐藏状态 Ht-1 和当前输入 Xt 计算得到的 [0,1] 之间的量激活函数。
三、GRU网络架构
1.更新门和重置门如何发挥作用
重置门对过去t个时间步的序列信息Ht-1进行选择更新门对当前一个时间步的序列信息Xt进行选择。具体原理如下
1.1候选隐藏状态H~t
候选隐藏状态既保留了之前的隐藏状态Ht-1又保留了当前一个时间步的序列信息Xt。 因为Rt是一个[0,1] 之间的量所以Rt×Ht-1是对之前的隐藏状态Ht-1进行一次选择Rt 在某个位置的值越趋近于0则表示Ht-1这个位置的序列信息越倾向于被丢弃反之保留。
综上重置门的作用是对过去的序列信息Ht-1进行选择Ht-1中哪些序列信息当前的输出是有用的应该被保存下来而哪些序列信息是不重要的应该被遗忘。
1.2隐藏状态Ht 因为Zt是一个[0,1] 之间的量如果Zt全为0则当前隐藏状态Ht为当前候选隐藏状态该候选隐藏状态不仅保留了之前的序列信息还保留了当前时间步batch的序列信息如果Zt全为1则当前隐藏状态Ht为上一个时间步的隐藏状态。
综上更新门的作用是决定当前一个时间步的序列信息是否保留如果Zt全为0则说明当前时间步batch的序列信息是有用的候选隐藏状态包含之前的序列信息和当前一个时间步的序列信息保留下来加入到隐藏状态Ht中如果Zt全为1则说明当前时间步batch的序列信息是没有用的丢弃当前batch的序列信息直接使用上一个时间步的隐藏状态Ht-1作为当前的隐藏状态Ht。Ht-1仅包含之前的序列信息不包含当前一个时间步的序列信息
2.GRU:
GRU网络架构如下可以看做是三个隐藏层并排的架构。
四、底层源码
代码中num_hiddens表示隐藏层神经元个数由于重置门、更新门的输出维度相同所以重置门和更新门两个隐藏层的神经元个数也是一样的num_hiddens。
import torch
from torch import nn
from d2l import torch as d2l# 数据预处理获取datalodaer和字典
batch_size, num_steps 32, 35
train_iter, vocab d2l.load_data_time_machine(batch_size, num_steps)# 初始化可学习参数
def get_params(vocab_size, num_hiddens, device):num_inputs num_outputs vocab_sizedef normal(shape):return torch.randn(sizeshape, devicedevice) * 0.01def three():return (normal((num_inputs, num_hiddens)), normal((num_hiddens, num_hiddens)),torch.zeros(num_hiddens, devicedevice))W_xz, W_hz, b_z three()W_xr, W_hr, b_r three()W_xh, W_hh, b_h three()W_hq normal((num_hiddens, num_outputs))b_q torch.zeros(num_outputs, devicedevice)params [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]for param in params:param.requires_grad_(True)return params# 初始化隐藏状态
def init_gru_state(batch_size, num_hiddens, device):return (torch.zeros((batch_size, num_hiddens), devicedevice),)# 定义门控循环单元模型
def gru(inputs, state, params):W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q paramsH, stateoutputs []for X in inputs:Z torch.sigmoid((X W_xz) (H W_hz) b_z)R torch.sigmoid((X W_xr) (H W_hr) b_r)H_tilda torch.tanh((X W_xh) ((R * H) W_hh) b_h)H Z * H (1 - Z) * H_tildaY H W_hq b_qoutputs.append(Y)return torch.cat(outputs, dim0), (H,)# 训练
vocab_size, num_hiddens, device len(vocab), 256, d2l.try_gpu()
num_epochs, lr 500, 1
model d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_params,init_gru_state, gru)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)五、Pytorch版代码
num_inputs vocab_size
# 调用pytorch构建网络结构
gru_layer nn.GRU(num_inputs, num_hiddens)
model d2l.RNNModel(gru_layer, len(vocab))
model model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)