
How to load a pretrained model in TensorFlow

At work I struggled for days to load a pretrained model into TensorFlow before giving up. Then a few weeks later I tried again and it worked fine. I’m going to try to do it again here to make sure I can get it to work.

Why you would use a pretrained model

Pretrained models are especially useful for image classification. The convolutional neural networks (CNNs) used for image classification often have eight or more layers and over a million parameters. Training a network this large requires a massive dataset and a lot of time. I don’t have these kinds of resources, but I can use a pretrained model and adapt it to my needs. The lower layers of a CNN generally just find edges, lines, and basic shapes, regardless of what images are given as input. Thus it would be a waste of my time to retrain these basic concepts when I can use a pretrained network and just change the higher layers.

import tensorflow as tf
## C:\Users\cbe117\AppData\Local\CONTIN~1\ANACON~1\lib\site-packages\h5py\__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.
##   from ._conv import register_converters as _register_converters
print(tf.__version__)
## 1.9.0

How to save and load models in TensorFlow

TensorFlow has a guide on how to save and load models here, and a guide on how to export and import MetaGraphs here.

However, TensorFlow has terrible documentation on how to get pretrained models working. They have a list of pretrained models here. If you just have your images in folders for each label, then it looks like it should be pretty easy to use these models. However, I want to take a more hands-on approach.

Loading Inception-ResNet-v2

I’m going to follow the Stack Overflow question from here. First I download the inception_resnet_v2.py file. This file allows us to load the network structure into TF. If it’s not in your current working directory, you need to add its folder to your Python path.

import sys
sys.path.insert(0, 'C://Users//cbe117//Documents//GitHub//website-hugo//static//post//2018-08-04-how-to-load-a-pretrained-model-in-tensorflow')
from inception_resnet_v2 import inception_resnet_v2, inception_resnet_v2_arg_scope
print(inception_resnet_v2)
## <function inception_resnet_v2 at 0x0000000039CAB048>
slim = tf.contrib.slim
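
A slightly more defensive way to make the downloaded file importable is sketched below. The helper name add_to_path and the "." folder passed to it are hypothetical stand-ins; only sys.path and os.path come from the standard library.

```python
import os
import sys

def add_to_path(folder):
    """Put folder at the front of sys.path unless it is already listed."""
    folder = os.path.abspath(folder)
    if folder not in sys.path:
        sys.path.insert(0, folder)
    return folder

# Hypothetical usage: replace "." with the folder that holds inception_resnet_v2.py.
model_dir = add_to_path(".")
```

Calling the helper twice with the same folder leaves sys.path unchanged the second time, which keeps the list from growing when a notebook cell is re-run.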

The model was trained on images that were 299 by 299 pixels with three color channels, and it predicts which of 1001 classes each image belongs to. We should get the original classes (in the same order) since we set num_classes=1001. We can load the model structure with the following:

height = 299
width = 299
channels = 3
X = tf.placeholder(tf.float32, shape=[None, height, width, channels])
with slim.arg_scope(inception_resnet_v2_arg_scope()):
    logits, end_points = inception_resnet_v2(X, num_classes=1001, is_training=False)
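
The placeholder expects a batch of shape [None, 299, 299, 3], so a single image needs a leading batch axis before it is fed in. Here is a minimal NumPy-only sketch of a shape check that fails with a readable message when an image has the wrong size or channel count. The name to_batch is hypothetical, and the image is a synthetic array of zeros.

```python
import numpy as np

HEIGHT, WIDTH, CHANNELS = 299, 299, 3

def to_batch(image):
    """Return image as a float32 batch of shape (1, 299, 299, 3)."""
    image = np.asarray(image)
    if image.shape != (HEIGHT, WIDTH, CHANNELS):
        raise ValueError(
            "expected shape %s, got %s" % ((HEIGHT, WIDTH, CHANNELS), image.shape)
        )
    # Add the leading batch axis that corresponds to None in the placeholder.
    return image[np.newaxis, ...].astype(np.float32)

# Synthetic stand-in for a real 299 by 299 RGB image.
fake_image = np.zeros((HEIGHT, WIDTH, CHANNELS), dtype=np.uint8)
batch = to_batch(fake_image)
print(batch.shape)  # (1, 299, 299, 3)
```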

Next we can load the saved weights from the pretrained model. You can download the weights from here, picking inception_resnet_v2_2016_08_30.tar.gz.
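
The download is a .tar.gz archive, so it has to be unpacked before the .ckpt file can be restored. A minimal sketch using only the standard library follows. The function name extract_checkpoint and the paths in the usage comment are hypothetical.

```python
import os
import shutil
import tarfile

def extract_checkpoint(archive_path, dest_dir):
    """Extract the regular files of a .tar.gz archive into dest_dir.

    Only base names are kept, so nothing is written outside dest_dir.
    Returns the list of extracted file paths.
    """
    os.makedirs(dest_dir, exist_ok=True)
    extracted = []
    with tarfile.open(archive_path, "r:gz") as tar:
        for member in tar.getmembers():
            if not member.isfile():
                continue
            target = os.path.join(dest_dir, os.path.basename(member.name))
            # Stream the member to disk instead of reading it all into memory.
            with tar.extractfile(member) as src, open(target, "wb") as dst:
                shutil.copyfileobj(src, dst)
            extracted.append(target)
    return extracted

# Hypothetical usage:
# extract_checkpoint("inception_resnet_v2_2016_08_30.tar.gz", "weights")
```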

This loads the network weights.

saver = tf.train.Saver()
sess = tf.Session()
saver.restore(sess, "C://Users/cbe117/Documents/GitHub/website-hugo/static/post/2018-08-04-how-to-load-a-pretrained-model-in-tensorflow/inception_resnet_v2_2016_08_30.ckpt")

Predict on an image

Now I’ll use a picture of a cat as input and see what the network outputs. I cropped the image to get mostly the head of the cat and resized it to 299 by 299.

import sys
print(sys.path)
## ['C://Users//cbe117//Documents//GitHub//website-hugo//static//post//2018-08-04-how-to-load-a-pretrained-model-in-tensorflow', 'C:\\Users\\cbe117\\AppData\\Local\\CONTIN~1\\ANACON~1', 'C:\\Users\\cbe117\\AppData\\Local\\CONTIN~1\\ANACON~1\\python36.zip', 'C:\\Users\\cbe117\\AppData\\Local\\CONTIN~1\\ANACON~1\\DLLs', 'C:\\Users\\cbe117\\AppData\\Local\\CONTIN~1\\ANACON~1\\lib', 'C:\\PROGRA~1\\R\\R-35~1.1\\bin\\x64', 'C:\\Users\\cbe117\\AppData\\Local\\CONTIN~1\\ANACON~1', 'C:\\Users\\cbe117\\AppData\\Local\\CONTIN~1\\ANACON~1\\lib\\site-packages', 'C:\\Users\\cbe117\\AppData\\Local\\CONTIN~1\\ANACON~1\\lib\\site-packages\\win32', 'C:\\Users\\cbe117\\AppData\\Local\\CONTIN~1\\ANACON~1\\lib\\site-packages\\win32\\lib', 'C:\\Users\\cbe117\\AppData\\Local\\CONTIN~1\\ANACON~1\\lib\\site-packages\\Pythonwin', 'C:/Users/cbe117/Documents/R/win-library/3.5/reticulate/python']
import matplotlib.pyplot as plt
cat = plt.imread('C://Users/cbe117/Documents/GitHub/website-hugo/static//post//2018-08-04-how-to-load-a-pretrained-model-in-tensorflow//cat299.jpg')

This is what the cat looks like:

knitr::include_graphics('/post/2018-08-04-how-to-load-a-pretrained-model-in-tensorflow/cat299.jpg')

catlogits = sess.run(logits, feed_dict={X:cat.reshape(1,299,299,3)})
print(catlogits.shape)
## (1, 1001)
print(catlogits)
## [[ -51.25375   -19.919136   54.14626  ...  222.14777  -177.31622
##     34.811455]]
import numpy as np
print(np.sort(catlogits[0,:])[-5:])
## [564.11554 588.2493  615.571   616.497   738.8386 ]
print(np.argsort(catlogits[0,:])[-5:])
## [675 741 883 424 918]
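
np.argsort sorts in ascending order, which is why the last five indices are the top five classes, with the most likely one last. The sketch below returns them most likely first instead. The helper name top_k is hypothetical, and the scores are made up rather than taken from the network.

```python
import numpy as np

def top_k(scores, k=5):
    """Return (indices, values) of the k largest scores, largest first."""
    scores = np.asarray(scores)
    # Ascending sort, keep the last k positions, then reverse them.
    order = np.argsort(scores)[-k:][::-1]
    return order, scores[order]

# Made-up scores, not output from the network.
example = np.array([0.1, 2.5, -1.0, 7.0, 3.2, 0.0])
indices, values = top_k(example, k=3)
print(indices)  # [3 4 1]
```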

Finding the class names

To figure out which classes these are, we need a lookup table for our network. It can be found here. Here I’ll create a list with the class names from this file. Note that the first class should be an unknown class, so we need to add one to the front.

with open("C://Users//cbe117//Documents//GitHub//website-hugo//static//post//2018-08-04-how-to-load-a-pretrained-model-in-tensorflow//imagenet1000_clsid_to_human.txt", "r") as file:
    lines = [line for line in file]
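# Strip whitespace, single quotes, and the braces of the dict literal.
linesclean = [line.strip().replace("'",'').replace('{','').replace('}','') for line in lines]
# Keep the text after the "id:" prefix and drop any double quotes.
linesclean2 = [line.split(":")[1].strip().replace('"','') for line in linesclean]
# Index 0 of the network output is an extra "unknown" class.
classes = ["unknown"] + linesclean2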
print(len(classes))
## 1001
print(classes[0:10])
## ['unknown', 'tench, Tinca tinca,', 'goldfish, Carassius auratus,', 'great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias,', 'tiger shark, Galeocerdo cuvieri,', 'hammerhead, hammerhead shark,', 'electric ray, crampfish, numbfish, torpedo,', 'stingray,', 'cock,', 'hen,']
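
If the lookup file is a single Python dict literal, which the brace and quote stripping above suggests, then ast.literal_eval from the standard library can parse it without string surgery. The sketch below runs on a three-entry stand-in string rather than the real file, and build_class_list is a hypothetical name.

```python
import ast

def build_class_list(dict_text):
    """Parse a dict literal {id: name} and return the names ordered by id,
    with an 'unknown' class prepended."""
    id_to_name = ast.literal_eval(dict_text)
    names = [id_to_name[i] for i in sorted(id_to_name)]
    return ["unknown"] + names

# Stand-in for the contents of the lookup file.
sample_text = """{0: 'tench, Tinca tinca',
 1: 'goldfish, Carassius auratus',
 2: 'great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias'}"""

sample_classes = build_class_list(sample_text)
print(sample_classes[:2])  # ['unknown', 'tench, Tinca tinca']
```

Parsing this way also avoids the trailing commas and keeps apostrophes or colons inside a class name intact.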

Now we can see what the predicted classes were.

# 5 most likely classes, most likely is last
print([classes[i] for i in np.argsort(catlogits[0,:])[-5:]])
## ['mousetrap,', 'power drill,', 'vacuum, vacuum cleaner,', 'barber chair,', 'comic book,']

The most likely class according to the prediction is comic book. This is really bad.

Preprocessing the images

After some more searching online, I figured out the problem. The inputs to Inception-ResNet-v2 should be rescaled to the range -1 to 1, not left in the range 0 to 255.

def preprocess_input(x):
    x = np.divide(x, 255.0)
    x = np.subtract(x, 0.5)
    x = np.multiply(x, 2.0)
    return x
   
catlogits2= sess.run(logits, feed_dict={X:preprocess_input(cat.reshape(1,299,299,3))})
print(sess.run(end_points, feed_dict={X:preprocess_input(cat.reshape(1,299,299,3))})['Predictions'][0,np.argsort(catlogits2[0,:])[-5:]])
## [0.00238993 0.03550951 0.20014726 0.20269492 0.49091274]
print([classes[i] for i in np.argsort(catlogits2[0,:])[-5:]])
## ['Siamese cat, Siamese,', 'lynx, catamount,', 'Egyptian cat,', 'tiger cat,', 'tabby, tabby cat,']

Now it looks right. The top five prediction probabilities and categories are printed above in ascending order, so the most likely class comes last. The top five classes are all cats: the most likely is tabby cat at only 49% probability, followed by tiger cat and Egyptian cat at about 20% each. I’m not sure what kind of cat it actually is, but it looks like a tabby or tiger cat, so these predictions are very good.
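
The rescaling is easy to sanity-check on a few synthetic pixel values. Dividing by 255, subtracting 0.5, and multiplying by 2 is the same as dividing by 127.5 and subtracting 1, which is the form used in this sketch. The name scale_to_unit_range is hypothetical.

```python
import numpy as np

def scale_to_unit_range(x):
    """Map pixel values from [0, 255] to [-1, 1]."""
    return np.asarray(x, dtype=np.float32) / 127.5 - 1.0

# The smallest, middle, and largest possible pixel values.
pixels = np.array([0, 127.5, 255])
scaled = scale_to_unit_range(pixels)
print(scaled.min(), scaled.max())
```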

Predicting on another cat

Let’s try on another cat, from here.

knitr::include_graphics('/post/2018-08-04-how-to-load-a-pretrained-model-in-tensorflow/catb299.jpg')

catb = plt.imread('C://Users//cbe117//Documents//GitHub//website-hugo//static//post//2018-08-04-how-to-load-a-pretrained-model-in-tensorflow//catb299.jpg')
catblogits2= sess.run(logits, feed_dict={X:preprocess_input(catb.reshape(1,299,299,3))})
print(np.sort(catblogits2[0,:]))
## [-1.5238764 -1.2445782 -1.2359526 ...  4.8217397  5.4154243  9.657993 ]
print(sess.run(end_points, feed_dict={X:preprocess_input(catb.reshape(1,299,299,3))})['Predictions'][0,np.argsort(catblogits2[0,:])[-5:]])
## [5.4889981e-04 8.1695215e-04 7.2656688e-03 1.3155567e-02 9.1544843e-01]
print([classes[i] for i in np.argsort(catblogits2[0,:])[-5:]])
## ['puck, hockey puck,', 'lynx, catamount,', 'tiger cat,', 'tabby, tabby cat,', 'Egyptian cat,']

I don’t know cats well, but some searching leads me to believe this is an Egyptian Mau, so the network got it right. It assigned this class 91.5% probability, so it was very confident. It may seem strange that hockey puck was the fifth most likely category, but it, like every class below it, was assigned a minuscule probability.
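
As far as I can tell, the 'Predictions' entry of end_points is a softmax over the logits. The printed numbers are consistent with that: the gap between the top two logits for this cat gives the same ratio, roughly 70, as the top two probabilities. Below is a minimal NumPy sketch of a numerically stable softmax, run on made-up logits, followed by that ratio check.

```python
import numpy as np

def softmax(logits):
    """Numerically stable softmax over the last axis."""
    logits = np.asarray(logits, dtype=np.float64)
    # Subtracting the row maximum avoids overflow in exp without changing the result.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exps = np.exp(shifted)
    return exps / exps.sum(axis=-1, keepdims=True)

# Made-up logits for a four-class problem.
probs = softmax([[2.0, 1.0, 0.1, -1.0]])
print(probs.sum())

# Ratio check using the top two logits and top two probabilities printed above.
logit_ratio = np.exp(9.657993 - 5.4154243)
prob_ratio = 9.1544843e-01 / 1.3155567e-02
print(logit_ratio, prob_ratio)
```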

Predicting a turtle

Let me try a very different type of animal. I’m going to use a loggerhead turtle image from here. Again, I already resized it to 299 by 299.

knitr::include_graphics('/post/2018-08-04-how-to-load-a-pretrained-model-in-tensorflow/loggerhead299.jpeg')

turtle = plt.imread('C://Users//cbe117//Documents//GitHub//website-hugo//static//post//2018-08-04-how-to-load-a-pretrained-model-in-tensorflow//loggerhead299.jpeg')
turtlelogits2= sess.run(logits, feed_dict={X:preprocess_input(turtle.reshape(1,299,299,3))})
print(np.sort(turtlelogits2[0,:]))
## [-1.1727543 -1.1719702 -1.1606838 ...  4.608715   7.4410925  8.898018 ]
print(sess.run(end_points, feed_dict={X:preprocess_input(turtle.reshape(1,299,299,3))})['Predictions'][0,np.argsort(turtlelogits2[0,:])[-5:]])
## [0.00210884 0.00490913 0.00979606 0.1663939  0.7142859 ]
print([classes[i] for i in np.argsort(turtlelogits2[0,:])[-5:]])
## ['coral reef,', 'scuba diver,', 'terrapin,', 'leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea,', 'loggerhead, loggerhead turtle, Caretta caretta,']

Again it gets it right. All of the top five classes make sense, with about 71% of the probability on loggerhead turtle, 17% on leatherback turtle, and less than 1% on any other class.
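
To read results at a glance, it can help to pair each label with its probability, most likely first. The sketch below is plain Python that also accepts a one-dimensional NumPy array. The helper name top_predictions and the toy inputs are hypothetical, not real network output.

```python
def top_predictions(probs, classes, k=5):
    """Return the k most probable (label, probability) pairs, most likely first."""
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    # rstrip(",") drops the trailing comma left over from parsing the lookup file.
    return [(classes[i].rstrip(","), float(probs[i])) for i in ranked[:k]]

# Toy four-class example, not real network output.
toy_classes = ["unknown", "terrapin,", "leatherback turtle,", "loggerhead,"]
toy_probs = [0.01, 0.05, 0.20, 0.74]
for label, p in top_predictions(toy_probs, toy_classes, k=2):
    print("%-20s %.2f" % (label, p))
```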

Conclusion

Here I have shown how to load a pretrained network in TensorFlow. Specifically, I used Inception-ResNet-v2; see this blog post from Google for more details. This net classifies images into 1001 categories. Using a pretrained network can be especially helpful when you want to train a net for your own categories, since it gives you a good warm start.

While nearing the end of writing this post, I found another Stack Overflow answer that goes over this same network with full code and does a better job than what I’ve put together here.