Deep learning practice part (10) -- TensorFlow learning path

Window of knowledge

PyTorch is an open source Python machine learning library based on Torch, which is used for applications such as natural language processing.

In January 2017, Facebook AI Research (FAIR) released PyTorch, which is built on Torch. It is a Python-based scientific computing package that provides two high-level features: 1. tensor computation (like NumPy) with strong GPU acceleration; 2. deep neural networks built on an automatic differentiation system.

PyTorch's predecessor is Torch. PyTorch shares its underlying layer with the Torch framework, but rewrites much of the code in Python. This makes it more flexible, adds support for dynamic graphs, and provides a Python interface. Developed by the Torch7 team, it is a Python-first deep learning framework that offers strong GPU acceleration and supports dynamic neural networks.

PyTorch can be regarded as NumPy with GPU support, or as a powerful deep learning framework with automatic differentiation. Besides Facebook, it has been adopted by Twitter, CMU and Salesforce.


In last week's article, we put all the code together (data preprocessing, network model, training code) and ran the actual training. Keep in mind that training results are only a small part of the picture. Beyond judging the quality and effectiveness of the model, we also need to test the trained model in practice, and eventually deploy it as an application. Of course, we will not deploy it directly: optimization, compression and pruning should also be considered.

1, Model prediction

Implementation steps:

1. save the model during training

2. write the test code (data processing, model call, data test)

3. output the model results and map them to real labels

1. save the model during training

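import tensorflow as tf  # TensorFlow 1.x API

# Add before training, after the network graph has been built
# (the Saver collects the variables that exist when it is created)
# Generate a saver to store the trained model
saver = tf.train.Saver()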

After each batch is trained, the whole validation set is evaluated (more commonly, validation is run once per epoch). When the validation pass finishes, the model is saved only if its accuracy is higher than the previously recorded accuracy (pre_test_acc) and also above 80%, so the checkpoint that remains at the end is the best model.

if avg_test_acc > pre_test_acc and avg_test_acc > 0.80:
    checkpoint_path = os.path.join(logs_checkpoint,
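The saving rule above can be sketched in plain Python to show which epochs would actually produce a checkpoint. This is a minimal illustration, not the author's training loop: the helper select_checkpoints and the accuracy values are hypothetical, and it assumes pre_test_acc is updated to the new accuracy each time a model is saved (so it tracks the best accuracy so far). In the real TensorFlow 1.x code, the point where the sketch records an epoch is where saver.save(sess, checkpoint_path, global_step=step) would be called.

```python
def select_checkpoints(val_accuracies, threshold=0.80):
    """Return the epoch indices at which a checkpoint would be saved.

    Hypothetical helper mirroring the rule: save only if the new validation
    accuracy beats the best accuracy so far AND exceeds the threshold.
    """
    pre_test_acc = 0.0  # assumption: tracks the best saved accuracy so far
    saved_epochs = []
    for epoch, avg_test_acc in enumerate(val_accuracies):
        if avg_test_acc > pre_test_acc and avg_test_acc > threshold:
            # In TF 1.x, saver.save(sess, checkpoint_path, global_step=...) would run here
            saved_epochs.append(epoch)
            pre_test_acc = avg_test_acc
    return saved_epochs


if __name__ == "__main__":
    # Made-up validation accuracies, one per epoch
    accs = [0.65, 0.78, 0.83, 0.81, 0.88, 0.86]
    print(select_checkpoints(accs))  # -> [2, 4]
```

Because only improvements above the threshold trigger a save, the last checkpoint written is always the best one seen, which is what the restore step later picks up.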

2. test code

1. data preprocessing:

This part is the same as during training.

import tensorflow as tf  # TensorFlow 1.x API

# Get a picture
def get_one_image(img_dir):
    # Input parameter: img_dir, path of the picture to test
    # Return value: image, a float32 tensor of shape [1, 150, 150, 3], scaled by 1/255
    # Earlier version: randomly select a picture from the training list for the test
    #n = len(train)
    #ind = np.random.randint(0, n)
    #img_dir = train[ind]

    img =
    #imag = img.resize([150, 150])  # Can be omitted, since the picture is resized during preprocessing
    imge = tf.image.resize_images(img, (150, 150))
    # Add the batch dimension: one 150x150 picture with 3 color channels
    image = tf.reshape(imge, [1, 150, 150, 3])

    # Scale pixel values from [0, 255] to [0, 1], as during training
    image = image / 255
    image = tf.cast(image, tf.float32)

    return image
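To make the shape bookkeeping concrete, here is a framework-free sketch of the last two preprocessing steps: adding the batch dimension and scaling pixel values to [0, 1]. It uses nested Python lists instead of tensors, the names to_model_input and shape are hypothetical, and it skips resizing, which the code above does with tf.image.resize_images.

```python
def to_model_input(pixels):
    """Scale an H x W x 3 image of 0-255 values to [0, 1] and add a batch axis.

    pixels: nested lists indexed as pixels[row][col][channel].
    Returns nested lists of shape [1, H, W, 3].
    """
    scaled = [[[value / 255 for value in pixel] for pixel in row] for row in pixels]
    return [scaled]  # leading batch dimension of size 1


def shape(nested):
    """Return the dimensions of a rectangular nested list."""
    dims = []
    while isinstance(nested, list):
        dims.append(len(nested))
        nested = nested[0]
    return dims


if __name__ == "__main__":
    # A made-up 2x2 RGB picture
    img = [[[0, 128, 255], [255, 255, 255]],
           [[0, 0, 0], [51, 102, 204]]]
    batch = to_model_input(img)
    print(shape(batch))    # -> [1, 2, 2, 3]
    print(batch[0][1][1])  # -> [0.2, 0.4, 0.8]
```

For a 150x150 picture the same steps give shape [1, 150, 150, 3], which matches the placeholder the network expects at test time.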

2. model call

Calling the model simply means loading the saved model parameters into the current network for testing.

Here the network only performs forward propagation; there is no backpropagation.

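saver = tf.train.Saver()

with tf.Session() as sess:
    img_array =

    print("Reading checkpoints...")
    # Look up the latest checkpoint recorded in logs_train_dir
    ckpt = tf.train.get_checkpoint_state(logs_train_dir)
    if ckpt and ckpt.model_checkpoint_path:
        # The step number is the suffix after the last '-' in the checkpoint file name
        global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
        # Load the trained weights into the current graph
        saver.restore(sess, ckpt.model_checkpoint_path)
        print('Loading success, global_step is %s' % global_step)
    else:
        print('No checkpoint file found')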
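The global_step line relies on how TF 1.x names checkpoints: Saver.save(..., global_step=N) appends "-N" to the save path. The sketch below isolates that string parsing in plain Python. The helper name global_step_from_path and the example path are hypothetical; using os.path.basename instead of split('/') is a small design choice that also handles Windows-style separators.

```python
import os


def global_step_from_path(model_checkpoint_path):
    """Return the step suffix of a TF 1.x checkpoint prefix as a string.

    Checkpoints saved with global_step=N are named '<prefix>-N', so the
    step is whatever follows the last '-' in the file name.
    """
    file_name = os.path.basename(model_checkpoint_path)  # e.g. 'model.ckpt-12000'
    return file_name.rsplit('-', 1)[-1]


if __name__ == "__main__":
    # Hypothetical checkpoint path, for illustration only
    print(global_step_from_path("logs/train/model.ckpt-12000"))  # -> 12000
```

Note that the result is a string, which is fine for printing; convert it with int() if you need to compute with it.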

3. data test

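# Test picture
def evaluate_one_image(image_array):
    global graph
    graph = tf.get_default_graph()
    with graph.as_default():
        #image = tf.cast(image_array, tf.float32)

        # Placeholder for a batch of one 150x150 RGB picture
        x = tf.placeholder(tf.float32, shape=[1, 150, 150, 3])

        # Forward propagation only: compute the network output for x
        logit = model.inference(x, BATCH_SIZE, N_CLASSES, 1)

        # Convert the raw outputs (logits) into class probabilities
        logit = tf.nn.softmax(logit)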
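tf.nn.softmax turns the raw network outputs (logits) into probabilities that sum to 1. The plain-Python version below only illustrates the formula softmax(z)_i = exp(z_i) / sum_j exp(z_j); it is not TensorFlow's implementation, and the logits are made up. Subtracting the largest logit first is a standard trick that avoids overflow without changing the result.

```python
import math


def softmax(logits):
    """Convert a list of logits into probabilities that sum to 1."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


if __name__ == "__main__":
    probs = softmax([1.0, 3.0])  # made-up logits for (cat, dog)
    print([round(p, 4) for p in probs])  # -> [0.1192, 0.8808]
```

Since softmax preserves the ordering of its inputs, the class with the largest logit also has the largest probability; softmax changes the scores into probabilities but not which class wins.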

4. output results:

prediction =,feed_dict={x: img_array})
max_index = np.argmax(prediction)  # index of the most probable class
# print(max_index)
# Label mapping can use either a dictionary or a list
label_dict = {0: 'cat', 1: 'dog'}
label_list = ['cat', 'dog']
print("The output of the model is {}, the corresponding real label is: {}".format(max_index, label_list[max_index]))
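The argmax-and-lookup step can be sketched without NumPy. This assumes, as in the code above, that the model output has shape [1, N_CLASSES] for a single picture; the helper decode_prediction and the probability values are hypothetical. Like np.argmax, Python's max returns the first maximal index when there is a tie.

```python
def decode_prediction(prediction, label_list=('cat', 'dog')):
    """Map a [1, N_CLASSES] probability output to (class index, label name)."""
    probs = prediction[0]  # take the single picture in the batch
    max_index = max(range(len(probs)), key=lambda i: probs[i])  # first max wins on ties
    return max_index, label_list[max_index]


if __name__ == "__main__":
    max_index, label = decode_prediction([[0.12, 0.88]])  # made-up model output
    print("The output of the model is {}, the corresponding real label is: {}".format(max_index, label))
    # -> The output of the model is 1, the corresponding real label is: dog
```

For a batch of one, np.argmax without an axis argument flattens the [1, 2] array and returns the same class index, which is why the original code can index label_list with it directly.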

Complete test code:

Actual prediction result:

As we can see, the test picture we read is a dog. The network predicts label 1, and since the dog class was assigned label 1, the prediction maps to the real label "dog". The prediction is correct.


That concludes this installment, which covers the complete workflow of an image classification project: from data processing to network construction, training, and calling the model for prediction. We have shared the details and annotated the code, so it should be easy to follow. If you have any questions, feel free to leave us a message.

Although this project is finished, some parts are probably still not entirely clear, whether in data processing or network construction; neither is as simple as it may look. That's fine: next time, the editor will summarize the loose ends of this project as a wrap-up of the image classification series. You are also welcome to ask questions so we can make progress together.

Have a nice weekend. See you next time!

Editor: Yue yijushi reviewed by: Xiaoquan Jushi

Posted by gobbles on Wed, 01 Jun 2022 17:57:09 +0530