Introduction to GAN

Introduction

GANs have long been popular deep learning models and have been applied in many fields. Whether in CV, NLP, AR, or VR, adding a GAN makes the results more lifelike and vivid. Recently, the concept of the metaverse has been heating up, and GANs play a large role in it. Let's take a look at some examples of what GANs can do.

Yes, these were all generated by GANs. Let's take a look at what a GAN actually is.

What is a GAN

A GAN (Generative Adversarial Network) is a deep learning model, and in recent years it has been one of the most promising methods for unsupervised learning on complex distributions.

A GAN generally contains two modules: a generative model (the generator) and a discriminative model (the discriminator). The two play a game against each other and learn from each other, improving together until the generator's output is hard to distinguish from real data.
The generator is like a counterfeiter who prints fake money, and the discriminator is like a police officer who checks for counterfeit bills. At first the counterfeiter's fake money is of poor quality, and the police can spot it at a glance; the counterfeiter therefore has to improve his technique to fool the police. As the fake money becomes more and more like the real thing, it becomes hard for the police to tell real from fake, so the police in turn have to improve their detection skills. The counterfeiter and the police keep challenging each other and learning in this way, until eventually the counterfeiter produces bills that are almost impossible to tell from real ones.

Network framework of GAN

Taking image generation as an example, the Generator is a generative network that receives random noise and uses it to generate an image, which we record as the generated data. The Discriminator is a discriminative network that judges whether an image is real. Its output is a probability value: 1 means the image is certainly real, and 0 means it is certainly fake.

Generator implementation code

Taking the handwritten-digit dataset MNIST as an example, the input of the generator network is a vector of normally distributed random numbers, i.e. a one-dimensional vector of length N, and its output is a (28, 28, 1) image. The following code is based on the Keras framework.

def build_generator(self):
    # --------------------------------- #
    #   Generator: maps a latent noise vector to an image
    # --------------------------------- #
    model = Sequential()

    model.add(Dense(256, input_dim=self.latent_dim))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))

    model.add(Dense(512))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))

    model.add(Dense(1024))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))

    model.add(Dense(np.prod(self.img_shape), activation='tanh'))
    model.add(Reshape(self.img_shape))

    noise = Input(shape=(self.latent_dim,))
    img = model(noise)

    return Model(noise, img)
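
As a quick sanity check, here is a minimal sketch (assuming the GAN class and imports from the full listing at the end of this article) that feeds the generator a single noise vector and confirms the output shape:

gan = GAN()
# One latent vector of length latent_dim (100)
noise = np.random.normal(0, 1, (1, gan.latent_dim))
img = gan.generator.predict(noise)
print(img.shape)  # (1, 28, 28, 1); pixel values lie in [-1, 1] because of the tanh output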

Discriminator implementation code

The purpose of the discriminator is to judge whether an input image is real or fake. Its input is therefore a (28, 28, 1) image, and its output is a number between 0 and 1, where 1 means the image is real and 0 means it is fake.

def build_discriminator(self):
    # ----------------------------------- #
    #   Discriminator: scores the input image
    # ----------------------------------- #
    model = Sequential()
    # Flatten the input image
    model.add(Flatten(input_shape=self.img_shape))
    model.add(Dense(512))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(256))
    model.add(LeakyReLU(alpha=0.2))
    # Output the probability that the image is real
    model.add(Dense(1, activation='sigmoid'))

    img = Input(shape=self.img_shape)
    validity = model(img)

    return Model(img, validity)
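
Continuing the sketch above (again assuming the full listing below), the discriminator maps an image to a single probability:

score = gan.discriminator.predict(img)  # img from the generator sketch above
print(score.shape)  # (1, 1)
print(score[0, 0])  # a value in [0, 1]; close to 1 means "looks real"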

Optimization training of GAN model

During training, the goal of the generator is to produce images realistic enough to fool the discriminator, while the goal of the discriminator is to tell the generator's fake images apart from the real ones as accurately as possible. The generator and the discriminator thus form a dynamic game. To understand this game more deeply, let's first look at what a Nash equilibrium is.

A Nash equilibrium is a situation in a game in which no participant can gain by changing only his own strategy while the other participants keep theirs unchanged. In a GAN, the equilibrium corresponds to the generator producing data whose distribution matches the real data, so that the discriminator can no longer tell the results apart: its accuracy drops to 50%, no better than random guessing. At this point both networks have maximized their payoff, and neither has an incentive to change strategy, i.e. neither updates its network weights any further.
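
This intuition can be made precise. The original GAN paper shows that, for a fixed generator G, the optimal discriminator is

D*(x) = p_data(x) / (p_data(x) + p_g(x))

where p_data is the real data distribution and p_g is the distribution of the generated samples. When p_g = p_data, this gives D*(x) = 1/2 everywhere, which is exactly the 50% accuracy described above.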

The objective function of the GAN model is as follows:
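
min_G max_D V(D, G) = E_{x ~ p_data(x)}[log D(x)] + E_{z ~ p_z(z)}[log(1 - D(G(z)))]

Here the discriminator D tries to maximize V by assigning high scores to real samples x and low scores to generated samples G(z), while the generator G tries to minimize V by making D(G(z)) as large as possible.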

After such adversarial training, the distributions pass through several stages. The figure drawn in the original paper is as follows:

The black dotted line represents the real data distribution p_data(x), the green line represents the distribution p_g of the generated data, and the blue dashed line represents the discriminator's output.
The overall training algorithm from the original paper:
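
Paraphrasing Algorithm 1 of the original paper, each training iteration does the following:

1. Repeat k times (the code below uses the equivalent of k = 1): sample a minibatch of m noise vectors z and m real examples x, and update the discriminator by ascending its stochastic gradient on log D(x) + log(1 - D(G(z))).
2. Sample a fresh minibatch of m noise vectors z and update the generator by descending its stochastic gradient on log(1 - D(G(z))).

The Keras implementation below follows this pattern, except that, as in most practical implementations, the generator is instead trained to maximize log D(G(z)) (the non-saturating loss); this is what training the combined model against the "valid" labels achieves.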

Full code

from __future__ import print_function, division

from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam

import matplotlib.pyplot as plt

import sys
import os
import numpy as np

class GAN():
    def __init__(self):
        # --------------------------------- #
        #   28 rows and 28 columns: the shape of MNIST images
        # --------------------------------- #
        self.img_rows = 28
        self.img_cols = 28
        self.channels = 1
        # 28,28,1
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.latent_dim = 100
        # Adam optimizer (learning rate 0.0002, beta_1 0.5)
        optimizer = Adam(0.0002, 0.5)

        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy',
            optimizer=optimizer,
            metrics=['accuracy'])

        self.generator = self.build_generator()
        gan_input = Input(shape=(self.latent_dim,))
        img = self.generator(gan_input)
        # Freeze the discriminator while training the generator
        self.discriminator.trainable = False
        # Score the generated fake images
        validity = self.discriminator(img)
        self.combined = Model(gan_input, validity)
        self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)


    def build_generator(self):
        # --------------------------------- #
        #   Generator: maps a latent noise vector to an image
        # --------------------------------- #
        model = Sequential()

        model.add(Dense(256, input_dim=self.latent_dim))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))

        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))

        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))

        model.add(Dense(np.prod(self.img_shape), activation='tanh'))
        model.add(Reshape(self.img_shape))

        noise = Input(shape=(self.latent_dim,))
        img = model(noise)

        return Model(noise, img)

    def build_discriminator(self):
        # ----------------------------------- #
        #   Discriminator: scores the input image
        # ----------------------------------- #
        model = Sequential()
        # Flatten the input image
        model.add(Flatten(input_shape=self.img_shape))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(256))
        model.add(LeakyReLU(alpha=0.2))
        # Output the probability that the image is real
        model.add(Dense(1, activation='sigmoid'))

        img = Input(shape=self.img_shape)
        validity = model(img)

        return Model(img, validity)

    def train(self, epochs, batch_size=128, sample_interval=50):
        # Get data
        (X_train, _), (_, _) = mnist.load_data()

        # Normalize pixel values to [-1, 1] (matching the generator's tanh output)
        X_train = X_train / 127.5 - 1.
        X_train = np.expand_dims(X_train, axis=3)

        # Create labels: 1 for real, 0 for fake
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        for epoch in range(epochs):

            # --------------------------- #
            #   Randomly select batch_size pictures
            #   Train discriminator
            # --------------------------- #
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            imgs = X_train[idx]

            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

            gen_imgs = self.generator.predict(noise)

            d_loss_real = self.discriminator.train_on_batch(imgs, valid)
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # --------------------------- #
            #  Training generator
            # --------------------------- #
            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
            g_loss = self.combined.train_on_batch(noise, valid)
            print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))

            if epoch % sample_interval == 0:
                self.sample_images(epoch)

    def sample_images(self, epoch):

        r, c = 5, 5
        noise = np.random.normal(0, 1, (r * c, self.latent_dim))
        gen_imgs = self.generator.predict(noise)

        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')
                axs[i,j].axis('off')
                cnt += 1
        fig.savefig("images/%d.png" % epoch)
        plt.close()


if __name__ == '__main__':
    if not os.path.exists("./images"):
        os.makedirs("./images")
    gan = GAN()
    gan.train(epochs=30000, batch_size=256, sample_interval=200)

Finally, the images generated by the GAN look as follows (a 5x5 sample grid is saved to the images/ folder every 200 epochs):
