Before learning about GRU and LSTM, you may want to review the basic RNN code in the blog post RNN - PyTorch.
1. Gated Recurrent Unit (GRU)
1. Reset gate and update gate
The reset gate lets us control how much of the past state we may still want to remember; the update gate lets us control how much of the new state is just a copy of the old state.
$$\pmb{R}_t=\sigma(\pmb{X}_t\pmb{W}_{xr}+\pmb{H}_{t-1}\pmb{W}_{hr}+\pmb{b}_r)$$
$$\pmb{Z}_t=\sigma(\pmb{X}_t\pmb{W}_{xz}+\pmb{H}_{t-1}\pmb{W}_{hz}+\pmb{b}_z)$$

The two outputs are given by two fully connected layers with a sigmoid activation function.
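As a quick sanity check, the two gates can be computed directly with tensor operations. The sketch below uses made-up sizes (batch size 2, input dimension 8, 4 hidden units) purely to verify the shapes and the (0, 1) range produced by the sigmoid.

```python
import torch

# Illustrative sizes: batch n=2, input dimension d=8, hidden units h=4
n, d, h = 2, 8, 4
X, H_prev = torch.randn(n, d), torch.randn(n, h)
W_xr, W_hr, b_r = torch.randn(d, h), torch.randn(h, h), torch.zeros(h)
W_xz, W_hz, b_z = torch.randn(d, h), torch.randn(h, h), torch.zeros(h)

R = torch.sigmoid(X @ W_xr + H_prev @ W_hr + b_r)  # reset gate
Z = torch.sigmoid(X @ W_xz + H_prev @ W_hz + b_z)  # update gate
print(R.shape, Z.shape)                 # both (n, h)
print(float(R.min()), float(R.max()))   # sigmoid keeps values in (0, 1)
```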
2. Candidate hidden states
The candidate hidden state $\pmb{\tilde H}_t \in \mathbb R^{n\times h}$ at time step $t$ is computed as follows:
$$\pmb{\tilde H}_t=\tanh(\pmb{X}_t\pmb{W}_{xh}+(\pmb{R}_{t}\odot \pmb{H}_{t-1})\pmb{W}_{hh}+\pmb{b}_h)$$

where the symbol $\odot$ is the Hadamard (element-wise) product operator. Multiplying $\pmb{R}_t$ and $\pmb{H}_{t-1}$ element-wise reduces the influence of previous states, and the nonlinear activation function $\tanh$ ensures that the values in the candidate hidden state remain in the interval (−1, 1).
3. Hidden state
$$\pmb{H}_t=\pmb{Z}_t\odot\pmb{H}_{t-1}+(1-\pmb{Z}_{t})\odot\pmb{\tilde H}_t$$
Whenever the update gate $\pmb{Z}_t$ is close to 1, the model tends to simply keep the old state; in that case the information from $\pmb{X}_t$ is essentially ignored. Conversely, when $\pmb{Z}_t$ is close to 0, the new hidden state $\pmb{H}_t$ approaches the candidate hidden state $\pmb{\tilde H}_t$.
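To make the role of the update gate concrete, the toy sketch below uses hand-set (not learned) gate values to check the two extreme cases: an update gate of all ones keeps the old state, all zeros adopts the candidate state.

```python
import torch

n, h = 2, 4                      # arbitrary batch size and hidden size for illustration
H_prev = torch.randn(n, h)       # old hidden state H_{t-1}
H_tilda = torch.randn(n, h)      # candidate hidden state \tilde H_t

Z = torch.ones(n, h)             # update gate close to 1
H = Z * H_prev + (1 - Z) * H_tilda
print(torch.allclose(H, H_prev))   # True: the old state is kept

Z = torch.zeros(n, h)            # update gate close to 0
H = Z * H_prev + (1 - Z) * H_tilda
print(torch.allclose(H, H_tilda))  # True: the candidate state is adopted
```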
Summary:
These designs help deal with the vanishing gradient problem in RNNs and better capture dependencies between time steps that are far apart in a sequence.
For example, if the update gate is close to 1 for all time steps of an entire subsequence, then regardless of the length of that subsequence, the old hidden state from its first time step will be easily retained and passed on to its end.
The gated recurrent unit has the following two salient features:
• Reset gates help capture short-term dependencies in a sequence.
• Update gates help capture long-term dependencies in sequences.
4. PyTorch code
- Implementation from scratch
```python
import torch
from torch import nn
from d2l import torch as d2l

batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

def get_params(vocab_size, num_hiddens, device):
    """Initialize model parameters"""
    num_inputs = num_outputs = vocab_size

    def normal(shape):
        return torch.randn(size=shape, device=device) * 0.01

    def three():
        return (normal((num_inputs, num_hiddens)),
                normal((num_hiddens, num_hiddens)),
                torch.zeros(num_hiddens, device=device))

    W_xz, W_hz, b_z = three()  # Update gate parameters
    W_xr, W_hr, b_r = three()  # Reset gate parameters
    W_xh, W_hh, b_h = three()  # Candidate hidden state parameters
    # Output layer parameters
    W_hq = normal((num_hiddens, num_outputs))
    b_q = torch.zeros(num_outputs, device=device)
    # Attach gradients
    params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]
    for param in params:
        param.requires_grad_(True)
    return params

def init_gru_state(batch_size, num_hiddens, device):
    """Initialize the hidden state"""
    return (torch.zeros((batch_size, num_hiddens), device=device), )

def gru(inputs, state, params):
    W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = params
    H, = state
    outputs = []
    for X in inputs:
        Z = torch.sigmoid((X @ W_xz) + (H @ W_hz) + b_z)
        R = torch.sigmoid((X @ W_xr) + (H @ W_hr) + b_r)
        H_tilda = torch.tanh((X @ W_xh) + ((R * H) @ W_hh) + b_h)
        H = Z * H + (1 - Z) * H_tilda
        Y = H @ W_hq + b_q
        outputs.append(Y)
    return torch.cat(outputs, dim=0), (H,)
```
- Training
```python
vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_params,
                            init_gru_state, gru)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
```
- Simple implementation
```python
# Simple implementation
num_inputs = vocab_size
gru_layer = nn.GRU(num_inputs, num_hiddens)
model = d2l.RNNModel(gru_layer, len(vocab))
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
```
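After training, the model can be used for character-level generation. The call below assumes the d2l helper `predict_ch8` (the prediction function used inside `train_ch8`) is available with its usual `(prefix, num_preds, net, vocab, device)` signature; treat it as a sketch rather than guaranteed API.

```python
# Generate 50 characters following a warm-up prefix
# (assumes d2l.predict_ch8 exists with this signature, as in the d2l RNN chapter).
print(d2l.predict_ch8('time traveller ', 50, model, vocab, device))
```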
2. Long Short-Term Memory Network (LSTM)
1. Input gate, forget gate and output gate
$$\pmb{I}_t=\sigma(\pmb{X}_t\pmb{W}_{xi}+\pmb{H}_{t-1}\pmb{W}_{hi}+\pmb{b}_i)$$
$$\pmb{F}_t=\sigma(\pmb{X}_t\pmb{W}_{xf}+\pmb{H}_{t-1}\pmb{W}_{hf}+\pmb{b}_f)$$
$$\pmb{O}_t=\sigma(\pmb{X}_t\pmb{W}_{xo}+\pmb{H}_{t-1}\pmb{W}_{ho}+\pmb{b}_o)$$

The input gate, forget gate, and output gate are computed by three fully connected layers with sigmoid activation functions, so all of their values lie in the range (0, 1).
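Since the three gates share exactly the same functional form and differ only in their parameters, they can be written with one small helper. The sketch below uses a hypothetical `gate` function and made-up sizes purely to illustrate that shared structure.

```python
import torch

def gate(X, H, W_x, W_h, b):
    """sigma(X W_x + H W_h + b): the common form of the input/forget/output gates."""
    return torch.sigmoid(X @ W_x + H @ W_h + b)

n, d, h = 2, 8, 4  # illustrative batch size, input size, hidden size
X, H_prev = torch.randn(n, d), torch.randn(n, h)

def rand_params():
    # Random (untrained) parameters, one set per gate
    return torch.randn(d, h), torch.randn(h, h), torch.zeros(h)

I = gate(X, H_prev, *rand_params())  # input gate
F = gate(X, H_prev, *rand_params())  # forget gate
O = gate(X, H_prev, *rand_params())  # output gate
print(I.shape, F.shape, O.shape)         # all (n, h)
print(bool(((0 < I) & (I < 1)).all()))   # sigmoid keeps values in (0, 1)
```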
2. Memory cells
Candidate memory cell:

$$\pmb{\tilde C}_t=\tanh(\pmb{X}_t\pmb{W}_{xc}+\pmb{H}_{t-1}\pmb{W}_{hc}+\pmb{b}_c)$$
Memory cell:

$$\pmb{C}_t=\pmb{F}_t\odot\pmb{C}_{t-1}+\pmb{I}_{t}\odot\pmb{\tilde C}_t$$
If the forget gate is always 1 and the input gate is always 0, the past memory cell $\pmb{C}_{t-1}$ will be preserved over time and passed to the current time step.
This design is introduced to alleviate the vanishing gradient problem and to better capture long-range dependencies in sequences.
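The toy check below (with hand-set gate values, not learned ones) illustrates this carry-through behavior: with the forget gate fixed at 1 and the input gate at 0, the memory cell is passed along unchanged for any number of steps.

```python
import torch

n, h, num_steps = 2, 4, 10  # arbitrary sizes for illustration
C = torch.randn(n, h)       # some initial memory cell C_0
C0 = C.clone()

F = torch.ones(n, h)        # forget gate always 1
I = torch.zeros(n, h)       # input gate always 0
for _ in range(num_steps):
    C_tilda = torch.randn(n, h)   # candidate memory cell (ignored here)
    C = F * C + I * C_tilda
print(torch.allclose(C, C0))      # True: C_0 survives all time steps
```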
3. Hidden state
$$\pmb{H}_t=\pmb{O}_t\odot\tanh(\pmb{C}_t)$$
The value of $\pmb{H}_t$ is always in the interval (−1, 1). Whenever the output gate is close to 1, we effectively pass all the memory information on to the prediction part; when the output gate is close to 0, we keep all the information only within the memory cell and do not update the hidden state.
Summary:
- There are three types of gates in long short-term memory networks: input gates, forget gates, and output gates.
- The hidden-layer output of a long short-term memory network consists of the "hidden state" and the "memory cell". Only the hidden state is passed to the output layer; the memory cell is entirely internal.
- A long short-term memory network can mitigate vanishing gradients and exploding gradients.
4. PyTorch code
- Implement LSTM from scratch
```python
import torch
from torch import nn
from d2l import torch as d2l

batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

def get_lstm_params(vocab_size, num_hiddens, device):
    num_inputs = num_outputs = vocab_size

    def normal(shape):
        return torch.randn(size=shape, device=device) * 0.01

    def three():
        return (normal((num_inputs, num_hiddens)),
                normal((num_hiddens, num_hiddens)),
                torch.zeros(num_hiddens, device=device))

    W_xi, W_hi, b_i = three()  # Input gate parameters
    W_xf, W_hf, b_f = three()  # Forget gate parameters
    W_xo, W_ho, b_o = three()  # Output gate parameters
    W_xc, W_hc, b_c = three()  # Candidate memory cell parameters
    # Output layer parameters
    W_hq = normal((num_hiddens, num_outputs))
    b_q = torch.zeros(num_outputs, device=device)
    # Attach gradients
    params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o,
              W_xc, W_hc, b_c, W_hq, b_q]
    for param in params:
        param.requires_grad_(True)
    return params

def init_lstm_state(batch_size, num_hiddens, device):
    return (torch.zeros((batch_size, num_hiddens), device=device),
            torch.zeros((batch_size, num_hiddens), device=device))

def lstm(inputs, state, params):
    [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o,
     W_xc, W_hc, b_c, W_hq, b_q] = params
    (H, C) = state
    outputs = []
    for X in inputs:
        I = torch.sigmoid((X @ W_xi) + (H @ W_hi) + b_i)
        F = torch.sigmoid((X @ W_xf) + (H @ W_hf) + b_f)
        O = torch.sigmoid((X @ W_xo) + (H @ W_ho) + b_o)
        C_tilda = torch.tanh((X @ W_xc) + (H @ W_hc) + b_c)
        C = F * C + I * C_tilda
        H = O * torch.tanh(C)
        Y = (H @ W_hq) + b_q
        outputs.append(Y)
    return torch.cat(outputs, dim=0), (H, C)

vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_lstm_params,
                            init_lstm_state, lstm)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
```
- Simple implementation
```python
# Simple implementation
num_inputs = vocab_size
lstm_layer = nn.LSTM(num_inputs, num_hiddens)
model = d2l.RNNModel(lstm_layer, len(vocab))
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
```
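For reference, a bare `nn.LSTM` layer returns the per-step outputs together with an `(h_n, c_n)` state tuple. The shapes below follow PyTorch's documented `(num_steps, batch_size, feature)` convention with the default `batch_first=False`; the concrete sizes are made up.

```python
import torch
from torch import nn

num_steps, batch_size, num_inputs, num_hiddens = 35, 32, 28, 256  # illustrative sizes
lstm = nn.LSTM(num_inputs, num_hiddens)
X = torch.randn(num_steps, batch_size, num_inputs)

output, (h_n, c_n) = lstm(X)
print(output.shape)  # (num_steps, batch_size, num_hiddens): hidden state at every step
print(h_n.shape)     # (1, batch_size, num_hiddens): final hidden state
print(c_n.shape)     # (1, batch_size, num_hiddens): final memory cell
```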
3. Deep Recurrent Neural Networks
The hidden state of layer $l$ at time step $t$ is computed as:

$$\pmb{H}_t^{(l)}=\phi_l(\pmb{H}_t^{(l-1)}\pmb{W}_{xh}^{(l)}+\pmb{H}_{t-1}^{(l)}\pmb{W}_{hh}^{(l)}+\pmb{b}_h^{(l)})$$

where the input of the first layer is $\pmb{H}_t^{(0)}=\pmb{X}_t$.
The output layer uses only the hidden state of the final layer:

$$\pmb{O}_t=\pmb{H}_t^{(l)}\pmb{W}_{hq}+\pmb{b}_q$$
```python
import torch
from torch import nn
from d2l import torch as d2l

batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

vocab_size, num_hiddens, num_layers = len(vocab), 256, 2
num_inputs = vocab_size
device = d2l.try_gpu()
lstm_layer = nn.LSTM(num_inputs, num_hiddens, num_layers)
model = d2l.RNNModel(lstm_layer, len(vocab))
model = model.to(device)

num_epochs, lr = 500, 2
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
```
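With `num_layers=2`, the state tensors gain a layer dimension: `h_n` and `c_n` have shape `(num_layers, batch_size, num_hiddens)`, while the per-step output still contains only the top layer's hidden states. A small shape check with made-up sizes:

```python
import torch
from torch import nn

num_steps, batch_size, num_inputs, num_hiddens, num_layers = 35, 32, 28, 256, 2
deep_lstm = nn.LSTM(num_inputs, num_hiddens, num_layers)
X = torch.randn(num_steps, batch_size, num_inputs)

output, (h_n, c_n) = deep_lstm(X)
print(output.shape)  # (num_steps, batch_size, num_hiddens): top layer only
print(h_n.shape)     # (num_layers, batch_size, num_hiddens)
print(c_n.shape)     # (num_layers, batch_size, num_hiddens)
```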
4. Bidirectional Recurrent Neural Network
$$\overrightarrow{\pmb H}_t=\phi(\pmb{X}_t\pmb{W}_{xh}^{(f)}+\overrightarrow{\pmb H}_{t-1}\pmb{W}_{hh}^{(f)}+\pmb{b}_h^{(f)})$$
$$\overleftarrow{\pmb H}_t=\phi(\pmb{X}_t\pmb{W}_{xh}^{(b)}+\overleftarrow{\pmb H}_{t+1}\pmb{W}_{hh}^{(b)}+\pmb{b}_h^{(b)})$$

Note that the backward recursion depends on $\overleftarrow{\pmb H}_{t+1}$, since that direction processes the sequence from the end.
The forward hidden state $\overrightarrow{\pmb H}_t$ and the backward hidden state $\overleftarrow{\pmb H}_t$ are concatenated to obtain the hidden state $\pmb H_t\in\mathbb R^{n \times 2h}$ that is fed to the output layer:
$$\pmb{H}_t=\overrightarrow{\pmb H}_t\oplus\overleftarrow{\pmb H}_t$$
$$\pmb{O}_t=\pmb{H}_t\pmb{W}_{hq}+\pmb{b}_q$$
- Note: Bidirectional recurrent neural networks are generally not used for prediction, because the future part of the sequence, which they rely on, cannot be seen at prediction time.
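In PyTorch a bidirectional layer is obtained by passing `bidirectional=True`; the forward and backward hidden states are concatenated along the feature dimension, so the output width is `2 * num_hiddens`, matching $\pmb H_t \in \mathbb R^{n\times 2h}$ above. A quick shape check with made-up sizes:

```python
import torch
from torch import nn

num_steps, batch_size, num_inputs, num_hiddens = 35, 32, 28, 256  # illustrative sizes
bi_lstm = nn.LSTM(num_inputs, num_hiddens, bidirectional=True)
X = torch.randn(num_steps, batch_size, num_inputs)

output, (h_n, c_n) = bi_lstm(X)
print(output.shape)  # (num_steps, batch_size, 2 * num_hiddens): both directions concatenated
print(h_n.shape)     # (2, batch_size, num_hiddens): one final state per direction
```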