GRU, LSTM, Bidirectional Recurrent Neural Network

Before studying GRU and LSTM, you may want to review the basic RNN code in this blog's earlier post, RNN - PyTorch.

1. Gated Recurrent Unit (GRU)

1. Reset gate and update gate

The reset gate controls how much of the previous state we might still want to remember; the update gate controls how much of the new state is just a copy of the old state.

$$\pmb{R}_t=\sigma(\pmb{X}_t\pmb{W}_{xr}+\pmb{H}_{t-1}\pmb{W}_{hr}+\pmb{b}_r)$$
$$\pmb{Z}_t=\sigma(\pmb{X}_t\pmb{W}_{xz}+\pmb{H}_{t-1}\pmb{W}_{hz}+\pmb{b}_z)$$

Both outputs are given by two fully connected layers with a sigmoid activation function, so their values lie in the interval (0, 1).

2. Candidate hidden states

The candidate hidden state at time step $t$ is $\pmb{\tilde H}_t \in \mathbb{R}^{n\times h}$, computed as follows:

$$\pmb{\tilde H}_t=\tanh(\pmb{X}_t\pmb{W}_{xh}+(\pmb{R}_{t}\odot \pmb{H}_{t-1})\pmb{W}_{hh}+\pmb{b}_h)$$

where the symbol $\odot$ denotes the Hadamard (element-wise) product. The element-wise multiplication of $\pmb{R}_t$ and $\pmb{H}_{t-1}$ reduces the influence of previous states, and the nonlinear activation function $\tanh$ keeps the values of the candidate hidden state in the interval (−1, 1).

3. Hidden state

$$\pmb{H}_t=\pmb{Z}_t\odot\pmb{H}_{t-1}+(1-\pmb{Z}_{t})\odot\pmb{\tilde H}_t$$

Whenever the update gate $\pmb{Z}_t$ is close to 1, the model tends to keep the old state, and the information from $\pmb{X}_t$ is essentially ignored. Conversely, when $\pmb{Z}_t$ is close to 0, the new hidden state $\pmb{H}_t$ approaches the candidate hidden state $\pmb{\tilde H}_t$.
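
To make the three equations above concrete, here is a minimal one-step sketch (not part of the original post; the shapes n=2, d=4, h=3 and the random parameters are made up for illustration):

import torch

# Toy shapes: batch size n, input dimension d, hidden dimension h (assumed)
n, d, h = 2, 4, 3
X_t = torch.randn(n, d)      # input at time step t
H_prev = torch.randn(n, h)   # previous hidden state H_{t-1}

# One (W_x*, W_h*, b_*) triple per equation, randomly initialized
W_xr, W_hr, b_r = torch.randn(d, h), torch.randn(h, h), torch.zeros(h)
W_xz, W_hz, b_z = torch.randn(d, h), torch.randn(h, h), torch.zeros(h)
W_xh, W_hh, b_h = torch.randn(d, h), torch.randn(h, h), torch.zeros(h)

R_t = torch.sigmoid(X_t @ W_xr + H_prev @ W_hr + b_r)           # reset gate
Z_t = torch.sigmoid(X_t @ W_xz + H_prev @ W_hz + b_z)           # update gate
H_tilda = torch.tanh(X_t @ W_xh + (R_t * H_prev) @ W_hh + b_h)  # candidate state
H_t = Z_t * H_prev + (1 - Z_t) * H_tilda                        # new hidden state
print(H_t.shape)  # torch.Size([2, 3])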

Summary:

These mechanisms help deal with the vanishing gradient problem in RNNs and better capture dependencies between time steps that are far apart.

For example, if the update gate is close to 1 for all time steps of an entire subsequence, then regardless of the sequence length, the old hidden state from the beginning of the subsequence is easily preserved and passed on to its end.
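
A tiny sketch of that extreme case (toy values invented for illustration): pinning the update gate to 1 makes the hidden state an exact copy of the old state at every step, so it survives arbitrarily many time steps unchanged.

import torch

H = torch.tensor([[0.5, -1.0, 2.0]])  # hidden state at the start of the subsequence
Z = torch.ones_like(H)                # update gate fixed at 1 for every time step
for t in range(1000):
    H_tilda = torch.randn_like(H)     # whatever the candidate state is...
    H = Z * H + (1 - Z) * H_tilda     # ...it is ignored when Z = 1
print(H)  # still tensor([[ 0.5000, -1.0000,  2.0000]])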

The gated recurrent unit has the following two salient features:

• Reset gates help capture short-term dependencies in a sequence.

• Update gates help capture long-term dependencies in sequences.

4. PyTorch code

  • Implement GRU from scratch
import torch
from torch import nn
from d2l import torch as d2l

batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)


def get_params(vocab_size, num_hiddens, device):
    """Initialize model parameters"""
    num_inputs = num_outputs = vocab_size

    def normal(shape):
        return torch.randn(size=shape, device=device)*0.01

    def three():
        return (normal((num_inputs, num_hiddens)),
                normal((num_hiddens, num_hiddens)),
                torch.zeros(num_hiddens, device=device))

    W_xz, W_hz, b_z = three()  # Update gate parameters
    W_xr, W_hr, b_r = three()  # reset gate parameters
    W_xh, W_hh, b_h = three()  # Candidate hidden state parameters
    # output layer parameters
    W_hq = normal((num_hiddens, num_outputs))
    b_q = torch.zeros(num_outputs, device=device)
    # additional gradient
    params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]
    for param in params:
        param.requires_grad_(True)
    return params

def init_gru_state(batch_size, num_hiddens, device):
    """The initialization function of the hidden state"""
    return (torch.zeros((batch_size, num_hiddens), device=device), )


def gru(inputs, state, params):
    W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = params
    H, = state
    outputs = []
    for X in inputs:
        Z = torch.sigmoid((X @ W_xz) + (H @ W_hz) + b_z)
        R = torch.sigmoid((X @ W_xr) + (H @ W_hr) + b_r)
        H_tilda = torch.tanh((X @ W_xh) + ((R * H) @ W_hh) + b_h)
        H = Z * H + (1 - Z) * H_tilda
        Y = H @ W_hq + b_q
        outputs.append(Y)
    return torch.cat(outputs, dim=0), (H,)
  • Train
vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_params,
                            init_gru_state, gru)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

  • Simple implementation
# Simple implementation
num_inputs = vocab_size
gru_layer = nn.GRU(num_inputs, num_hiddens)
model = d2l.RNNModel(gru_layer, len(vocab))
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
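
As a quick shape check (a small sketch with made-up sizes, independent of the training code above): by default nn.GRU takes input of shape (num_steps, batch_size, input_size) and returns the per-step hidden states together with the final hidden state.

import torch
from torch import nn

gru = nn.GRU(input_size=28, hidden_size=256)  # toy sizes, assumed for illustration
X = torch.randn(35, 32, 28)                   # (num_steps, batch_size, input_size)
output, h_n = gru(X)                          # initial state defaults to zeros
print(output.shape)  # torch.Size([35, 32, 256]) -- H_t for every time step
print(h_n.shape)     # torch.Size([1, 32, 256])  -- final hidden state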

2. Long Short-Term Memory Network (LSTM)

1. Input gate, forget gate and output gate


$$\pmb{I}_t=\sigma(\pmb{X}_t\pmb{W}_{xi}+\pmb{H}_{t-1}\pmb{W}_{hi}+\pmb{b}_i)$$
$$\pmb{F}_t=\sigma(\pmb{X}_t\pmb{W}_{xf}+\pmb{H}_{t-1}\pmb{W}_{hf}+\pmb{b}_f)$$
$$\pmb{O}_t=\sigma(\pmb{X}_t\pmb{W}_{xo}+\pmb{H}_{t-1}\pmb{W}_{ho}+\pmb{b}_o)$$

The input gate, forget gate, and output gate are computed by three fully connected layers with sigmoid activation functions, so all of their values lie in the range (0, 1).

2. Memory cell

Candidate memory cell:
$$\pmb{\tilde C}_t=\tanh(\pmb{X}_t\pmb{W}_{xc}+\pmb{H}_{t-1}\pmb{W}_{hc}+\pmb{b}_c)$$
Memory cell:
$$\pmb{C}_t=\pmb{F}_t\odot\pmb{C}_{t-1}+\pmb{I}_{t}\odot\pmb{\tilde C}_t$$
If the forget gate is always 1 and the input gate is always 0, the past memory cell $\pmb{C}_{t-1}$ is preserved over time and passed on to the current time step.

This design is introduced to alleviate the vanishing gradient problem and better capture long distance dependencies in sequences.

3. Hidden state

$$\pmb{H}_t=\pmb{O}_t\odot \tanh(\pmb{C}_t)$$

The value of $\pmb{H}_t$ is therefore always in the interval (−1, 1). Whenever the output gate is close to 1, all the memory information is effectively passed on to the prediction part; when the output gate is close to 0, all the information is retained only inside the memory cell and the hidden state is not updated.
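
Analogously to the GRU sketch above, here is a minimal one-step LSTM update with made-up toy shapes (not part of the original post):

import torch

n, d, h = 2, 4, 3                                      # toy sizes (assumed)
X_t = torch.randn(n, d)                                # input at time step t
H_prev, C_prev = torch.randn(n, h), torch.randn(n, h)  # H_{t-1} and C_{t-1}

# One (W_x*, W_h*, b_*) triple per equation, randomly initialized
W_xi, W_hi, b_i = torch.randn(d, h), torch.randn(h, h), torch.zeros(h)
W_xf, W_hf, b_f = torch.randn(d, h), torch.randn(h, h), torch.zeros(h)
W_xo, W_ho, b_o = torch.randn(d, h), torch.randn(h, h), torch.zeros(h)
W_xc, W_hc, b_c = torch.randn(d, h), torch.randn(h, h), torch.zeros(h)

I_t = torch.sigmoid(X_t @ W_xi + H_prev @ W_hi + b_i)   # input gate
F_t = torch.sigmoid(X_t @ W_xf + H_prev @ W_hf + b_f)   # forget gate
O_t = torch.sigmoid(X_t @ W_xo + H_prev @ W_ho + b_o)   # output gate
C_tilda = torch.tanh(X_t @ W_xc + H_prev @ W_hc + b_c)  # candidate memory cell
C_t = F_t * C_prev + I_t * C_tilda                      # memory cell
H_t = O_t * torch.tanh(C_t)                             # hidden state, values in (-1, 1)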

Summary

  • There are three kinds of gates in a long short-term memory network: input gates, forget gates, and output gates.
  • The hidden-layer output of a long short-term memory network consists of the hidden state and the memory cell. Only the hidden state is passed to the output layer; the memory cell is entirely internal.
  • A long short-term memory network can mitigate vanishing gradients and exploding gradients.

4. PyTorch code

  • Implement LSTM from scratch
import torch
from torch import nn
from d2l import torch as d2l

batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

def get_lstm_params(vocab_size, num_hiddens, device):
    num_inputs = num_outputs = vocab_size

    def normal(shape):
        return torch.randn(size=shape, device=device)*0.01

    def three():
        return (normal((num_inputs, num_hiddens)),
                normal((num_hiddens, num_hiddens)),
                torch.zeros(num_hiddens, device=device))

    W_xi, W_hi, b_i = three()  # input gate parameters
    W_xf, W_hf, b_f = three()  # forget gate parameters
    W_xo, W_ho, b_o = three()  # Output gate parameters
    W_xc, W_hc, b_c = three()  # Candidate memory parameters
    # output layer parameters
    W_hq = normal((num_hiddens, num_outputs))
    b_q = torch.zeros(num_outputs, device=device)
    # additional gradient
    params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc,
              b_c, W_hq, b_q]
    for param in params:
        param.requires_grad_(True)
    return params

def init_lstm_state(batch_size, num_hiddens, device):
    return (torch.zeros((batch_size, num_hiddens), device=device),
            torch.zeros((batch_size, num_hiddens), device=device))

def lstm(inputs, state, params):
    [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c,
     W_hq, b_q] = params
    (H, C) = state
    outputs = []
    for X in inputs:
        I = torch.sigmoid((X @ W_xi) + (H @ W_hi) + b_i)
        F = torch.sigmoid((X @ W_xf) + (H @ W_hf) + b_f)
        O = torch.sigmoid((X @ W_xo) + (H @ W_ho) + b_o)
        C_tilda = torch.tanh((X @ W_xc) + (H @ W_hc) + b_c)
        C = F * C + I * C_tilda
        H = O * torch.tanh(C)
        Y = (H @ W_hq) + b_q
        outputs.append(Y)
    return torch.cat(outputs, dim=0), (H, C)


vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_lstm_params,
                            init_lstm_state, lstm)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
  • Simple implementation
# Simple implementation
num_inputs = vocab_size
lstm_layer = nn.LSTM(num_inputs, num_hiddens)
model = d2l.RNNModel(lstm_layer, len(vocab))
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
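
A short usage sketch (toy sizes assumed, independent of the training code above): unlike nn.GRU, nn.LSTM carries its state as a tuple (h, c), and both parts are returned together.

import torch
from torch import nn

lstm = nn.LSTM(input_size=28, hidden_size=256)  # toy sizes, assumed for illustration
X = torch.randn(35, 32, 28)                     # (num_steps, batch_size, input_size)
output, (h_n, c_n) = lstm(X)                    # state defaults to zeros when omitted
print(output.shape)          # torch.Size([35, 32, 256]) -- hidden states only
print(h_n.shape, c_n.shape)  # both torch.Size([1, 32, 256]) -- final (H, C)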

3. Deep Recurrent Neural Networks

In a deep recurrent neural network with $L$ hidden layers, the hidden state of layer $l$ takes the hidden state of layer $l-1$ at the same time step as its input (with $\pmb{H}_t^{(0)}=\pmb{X}_t$):

$$\pmb{H}_t^{(l)}=\phi_l(\pmb{H}_t^{(l-1)}\pmb{W}_{xh}^{(l)}+\pmb{H}_{t-1}^{(l)}\pmb{W}_{hh}^{(l)}+\pmb{b}_h^{(l)})$$

The output layer uses only the hidden state of the final layer $L$:

$$\pmb{O}_t=\pmb{H}_t^{(L)}\pmb{W}_{hq}+\pmb{b}_q$$

import torch
from torch import nn
from d2l import torch as d2l

batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

vocab_size, num_hiddens, num_layers = len(vocab), 256, 2
num_inputs = vocab_size
device = d2l.try_gpu()
lstm_layer = nn.LSTM(num_inputs, num_hiddens, num_layers)
model = d2l.RNNModel(lstm_layer, len(vocab))
model = model.to(device)

num_epochs, lr = 500, 2
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

4. Bidirectional Recurrent Neural Network

A bidirectional RNN runs two recurrences over the sequence, one forward in time and one backward:

$$\overrightarrow{\pmb H}_t=\phi(\pmb{X}_t\pmb{W}_{xh}^{(f)}+\overrightarrow{\pmb H}_{t-1}\pmb{W}_{hh}^{(f)}+\pmb{b}_h^{(f)})$$
$$\overleftarrow{\pmb H}_t=\phi(\pmb{X}_t\pmb{W}_{xh}^{(b)}+\overleftarrow{\pmb H}_{t+1}\pmb{W}_{hh}^{(b)}+\pmb{b}_h^{(b)})$$

The forward hidden state $\overrightarrow{\pmb H}_t$ and the backward hidden state $\overleftarrow{\pmb H}_t$ are concatenated to obtain the hidden state $\pmb H_t\in\mathbb{R}^{n \times 2h}$ that is fed to the output layer:

$$\pmb{H}_t=\overrightarrow{\pmb H}_t\oplus\overleftarrow{\pmb H}_t$$
$$\pmb{O}_t=\pmb{H}_t\pmb{W}_{hq}+\pmb{b}_q$$

  • Note:

    Bidirectional recurrent neural networks are generally not suitable for predicting the next token in a sequence, because the future context they rely on is not available at prediction time.
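
The post gives no code for this part, so here is a minimal sketch (toy sizes made up) of a bidirectional layer using PyTorch's bidirectional=True flag; the per-step outputs concatenate the forward and backward hidden states, so their width is 2h as in the equations above.

import torch
from torch import nn

bi_lstm = nn.LSTM(input_size=28, hidden_size=256, bidirectional=True)  # toy sizes
X = torch.randn(35, 32, 28)        # (num_steps, batch_size, input_size)
output, (h_n, c_n) = bi_lstm(X)
print(output.shape)  # torch.Size([35, 32, 512]) -- forward and backward states concatenated
print(h_n.shape)     # torch.Size([2, 32, 256]) -- one final hidden state per direction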
