Save & load TensorFlow models, variables and ops

The key to saving and restoring TensorFlow objects is the class tf.train.Saver, which provides methods for saving and restoring models. The tf.train.Saver constructor adds save and restore ops to the graph for all of the variables in the graph, or for a specified list of them. The Saver object provides methods to run these ops, specifying the paths of the checkpoint files to write to or read from. Checkpoints are binary files in a proprietary format that map variable names to tensor values. The best way to examine the contents of a checkpoint is to load it using a Saver. When you save your model, you will see these files:

model.ckpt.data-00000-of-00001             

model.ckpt.index

model.ckpt.meta

The data file stores the values of the variables from the TensorFlow model. The index file records, for each variable name, where its value is stored in the data file. The meta file stores the graph: all the ops from the TensorFlow model, with their names and connections.
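
To see these files for yourself, here is a minimal sketch that saves two variables to a temporary directory and then inspects what was written. It is an illustration under assumptions, not the author's code: it uses the tf.compat.v1 graph-mode API so it should run on TensorFlow 1.13+ and 2.x, and the temporary directory and the model.ckpt prefix are arbitrary choices. Besides the data, index and meta files you should also see a small text file named checkpoint, which records the most recent checkpoint and is what tf.train.latest_checkpoint reads. Instead of a Saver, it uses tf.train.list_variables and tf.train.load_checkpoint, which read the index and data files directly without rebuilding the graph.

```python
import os
import tempfile

import tensorflow as tf

# Build the graph explicitly so the tf.compat.v1 graph-mode APIs work in TF 1.x and 2.x.
graph = tf.Graph()
with graph.as_default():
    a = tf.compat.v1.Variable([[1.0, 1.0], [1.0, 1.0]], name="a")
    b = tf.compat.v1.Variable([[2.0, 2.0], [2.0, 2.0]], name="b")
    init_op = tf.compat.v1.global_variables_initializer()
    saver = tf.compat.v1.train.Saver()

# A throwaway directory stands in for "./" (assumption for this sketch).
ckpt_dir = tempfile.mkdtemp()
with tf.compat.v1.Session(graph=graph) as sess:
    sess.run(init_op)
    # save() returns the checkpoint prefix it wrote: <ckpt_dir>/model.ckpt
    prefix = saver.save(sess, os.path.join(ckpt_dir, "model.ckpt"))

# Data, index and meta files, plus the "checkpoint" state file.
files = sorted(os.listdir(ckpt_dir))
print(files)

# (name, shape) pairs read from the checkpoint, no graph needed.
stored = tf.train.list_variables(prefix)
print(stored)

# Read one stored value back as a NumPy array.
reader = tf.train.load_checkpoint(prefix)
a_value = reader.get_tensor("a")
print(a_value)
```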

1. Saving Model Example

import tensorflow as tf  # TensorFlow 1.x API (tf.Session, tf.train.Saver)
save_path = "./"
save_file = "model"

# Create some variables.
# NOTE: It is important to name the variables and ops you would like to execute later.
a = tf.Variable([[1,1], [1,1]], dtype=tf.float32, name="a")
b = tf.Variable([[2,2], [2,2]], dtype=tf.float32, name="b")
c = tf.matmul(a, b, name="op_matmul")

# Initialize all variables.
init_op = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init_op)
    output = sess.run(c)
    print("output: {}".format(output))
    print("a: {}".format(a.eval()))

    # Saver to save and restore all the variables.
    saver = tf.train.Saver()
    # save() returns the checkpoint prefix it wrote, here "./model".
    saved_path = saver.save(sess, save_path + save_file)
    print("Model saved in file: %s" % saved_path)

Output:

output: [[ 4.  4.][ 4.  4.]]
a: [[ 1.  1.][ 1.  1.]]
Model saved in file: ./model
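
As noted at the start, a Saver can cover all variables or only a specified list. The hedged sketch below (tf.compat.v1 API so it should also run under TensorFlow 2.x; a temporary directory stands in for "./") passes var_list=[a] so only a is written, and adds global_step=100, which Saver.save appends to the prefix as model-100. tf.train.latest_checkpoint then finds that prefix through the checkpoint state file in the directory.

```python
import os
import tempfile

import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    a = tf.compat.v1.Variable([[1.0, 1.0], [1.0, 1.0]], name="a")
    b = tf.compat.v1.Variable([[2.0, 2.0], [2.0, 2.0]], name="b")
    init_op = tf.compat.v1.global_variables_initializer()
    # Only "a" is handed to this Saver, so only "a" is written to the checkpoint.
    saver_a = tf.compat.v1.train.Saver(var_list=[a])

ckpt_dir = tempfile.mkdtemp()  # stands in for "./"
with tf.compat.v1.Session(graph=graph) as sess:
    sess.run(init_op)
    # global_step is appended to the prefix: <ckpt_dir>/model-100
    prefix = saver_a.save(sess, os.path.join(ckpt_dir, "model"), global_step=100)

stored = tf.train.list_variables(prefix)
# latest_checkpoint reads the "checkpoint" state file written by save().
latest = tf.train.latest_checkpoint(ckpt_dir)
print(prefix)
print(stored)
print(latest)
```

Numbering checkpoints this way lets you keep several of them side by side while latest_checkpoint still points at the newest one.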

2. Restore Variables

Note: TensorFlow matches variables to checkpoint values by name, so the variables you restore must have the same names they had when they were saved.

# Reset the graph
tf.reset_default_graph()

# Note: the variables are re-created with the same names ("a" and "b") they had
# when they were saved, so the Saver can match them to the stored values.
# Their zero initial values are overwritten by restore().
a = tf.Variable([[0,0], [0,0]], dtype=tf.float32, name="a")
b = tf.Variable([[0,0], [0,0]], dtype=tf.float32, name="b")

# Initialize all variables (optional here: restore() assigns every saved variable).
init_op = tf.global_variables_initializer()

sess = tf.Session()
sess.run(init_op)
# Saver to save and restore all the variables.
saver = tf.train.Saver()
saver.restore(sess, tf.train.latest_checkpoint(save_path))

print("a: {}".format(a.eval(sess)))
print("b: {}".format(b.eval(sess)))

Output:

INFO:tensorflow:Restoring parameters from ./model
a: [[ 1.  1.][ 1.  1.]]
b: [[ 2.  2.][ 2.  2.]]
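
To see why the names matter, here is a minimal sketch of a mismatch, assuming the tf.compat.v1 API and a temporary directory. The checkpoint holds a variable called "a", while the new graph only has one called "a_renamed" (a hypothetical name), so Saver.restore raises tf.errors.NotFoundError because the checkpoint has no entry for that key.

```python
import os
import tempfile

import tensorflow as tf

ckpt_dir = tempfile.mkdtemp()
prefix = os.path.join(ckpt_dir, "model")

# Save a checkpoint that contains a single variable named "a".
g_save = tf.Graph()
with g_save.as_default():
    a = tf.compat.v1.Variable([[1.0, 1.0], [1.0, 1.0]], name="a")
    init_op = tf.compat.v1.global_variables_initializer()
    saver = tf.compat.v1.train.Saver()
with tf.compat.v1.Session(graph=g_save) as sess:
    sess.run(init_op)
    saver.save(sess, prefix)

# Re-create the variable under a different (hypothetical) name.
g_load = tf.Graph()
with g_load.as_default():
    a_renamed = tf.compat.v1.Variable([[0.0, 0.0], [0.0, 0.0]], name="a_renamed")
    loader = tf.compat.v1.train.Saver()

restore_failed = False
with tf.compat.v1.Session(graph=g_load) as sess:
    try:
        loader.restore(sess, prefix)
    except tf.errors.NotFoundError:
        # The checkpoint has no key "a_renamed", so the value cannot be matched.
        restore_failed = True
print("restore failed:", restore_failed)
```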

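3. Restore and assign a specific variable

Passing a dictionary to tf.train.Saver maps checkpoint names (the keys) to variables in the current graph (the values). The example below swaps them: the value saved as "a" is loaded into b, and the value saved as "b" is loaded into a.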

# Swap: checkpoint "a" -> variable b, checkpoint "b" -> variable a.
new_saver = tf.train.Saver({'a': b, 'b': a})
new_saver.restore(sess, tf.train.latest_checkpoint(save_path))
print("a: {}".format(a.eval(sess)))
print("b: {}".format(b.eval(sess)))

Output:

INFO:tensorflow:Restoring parameters from ./model
a: [[ 2.  2.][ 2.  2.]]
b: [[ 1.  1.][ 1.  1.]]
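
The dictionary form is also how you can load a checkpoint value into a variable whose name differs from the saved one, or restore only part of a checkpoint. In this sketch (tf.compat.v1 API and a temporary directory assumed; weights and bias are hypothetical names) the checkpoint value "b" is loaded into weights, while bias is not handled by the Saver and keeps the value from its initializer. Checkpoint keys that are not in the dictionary, such as "a" here, are simply ignored.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

ckpt_dir = tempfile.mkdtemp()
prefix = os.path.join(ckpt_dir, "model")

# Save "a" and "b" as in the examples above.
g_save = tf.Graph()
with g_save.as_default():
    a = tf.compat.v1.Variable([[1.0, 1.0], [1.0, 1.0]], name="a")
    b = tf.compat.v1.Variable([[2.0, 2.0], [2.0, 2.0]], name="b")
    init_op = tf.compat.v1.global_variables_initializer()
    saver = tf.compat.v1.train.Saver()
with tf.compat.v1.Session(graph=g_save) as sess:
    sess.run(init_op)
    saver.save(sess, prefix)

g_load = tf.Graph()
with g_load.as_default():
    # Hypothetical new names; the dict keys below are the names stored in the checkpoint.
    weights = tf.compat.v1.Variable(tf.zeros([2, 2]), name="weights")
    bias = tf.compat.v1.Variable(tf.zeros([2, 2]), name="bias")
    init_op = tf.compat.v1.global_variables_initializer()
    # Restore only checkpoint entry "b", into weights.
    partial_saver = tf.compat.v1.train.Saver({"b": weights})

with tf.compat.v1.Session(graph=g_load) as sess:
    sess.run(init_op)  # bias is not in the Saver, so it still needs initializing
    partial_saver.restore(sess, prefix)
    weights_val, bias_val = sess.run([weights, bias])
print(weights_val)
print(bias_val)
```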

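If you try to run an op from the original graph without restoring that graph, you get an error like the one below. The tensor c still belongs to the graph that tf.reset_default_graph() discarded, while sess runs the new default graph, which has no op_matmul op: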

output = sess.run(c)
print("output: {}".format(output))

Output:

ValueError: Fetch argument <tf.Tensor 'op_matmul:0' shape=(2, 2) dtype=float32> cannot be interpreted as a Tensor. (Tensor Tensor("op_matmul:0", shape=(2, 2), dtype=float32) is not an element of this graph.)
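
The failure comes from graph membership: every tensor belongs to the graph it was created in, and a session can only fetch elements of its own graph. A small sketch reproducing this with two explicit tf.Graph objects instead of tf.reset_default_graph() (tf.compat.v1 API assumed). The zeros op named "a" in the new graph is only there because a session refuses to run an empty graph.

```python
import tensorflow as tf

# The graph the op was built in (like the default graph before the reset).
old_graph = tf.Graph()
with old_graph.as_default():
    a = tf.ones([2, 2])
    b = tf.fill([2, 2], 2.0)
    c = tf.matmul(a, b, name="op_matmul")

# A separate graph plays the role of the graph after tf.reset_default_graph().
new_graph = tf.Graph()
with new_graph.as_default():
    tf.zeros([2, 2], name="a")  # non-empty, but it has no op_matmul

error_raised = False
with tf.compat.v1.Session(graph=new_graph) as sess:
    same_graph = c.graph is sess.graph
    try:
        sess.run(c)
    except ValueError as err:
        error_raised = True
        print(err)
print("c belongs to the session's graph:", same_graph)
```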

4. Restore graph and execute restored ops

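In this example we do not need to declare the variables again: tf.train.import_meta_graph re-creates the graph, including its variables and ops, from the .meta file, and saver.restore loads the saved variable values from the checkpoint.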

import tensorflow as tf
# Reset the graph
tf.reset_default_graph()
save_path = "./"
save_file = "model"

sess = tf.Session()
# Re-create the graph (variables and ops) from the .meta file...
saver = tf.train.import_meta_graph(save_path + save_file + '.meta')
# ...then load the saved variable values into it.
saver.restore(sess, tf.train.latest_checkpoint(save_path))
# NOTE: It's important to name the ops you would like to execute; they are fetched by name.
output = sess.run('op_matmul:0')

print("output: {}".format(output))

Output:

INFO:tensorflow:Restoring parameters from ./model
output: [[ 4.  4.][ 4.  4.]]
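
Because the snippet above depends on files left behind by the first example, here is a self-contained sketch of the same round trip, assuming the tf.compat.v1 API and a temporary directory in place of "./". It saves the graph from example 1, imports it into a new tf.Graph with import_meta_graph, restores the values, and looks the op up with get_operation_by_name instead of fetching the tensor-name string.

```python
import os
import tempfile

import tensorflow as tf

ckpt_dir = tempfile.mkdtemp()
prefix = os.path.join(ckpt_dir, "model")

# Build, initialize and save the graph from example 1.
g_save = tf.Graph()
with g_save.as_default():
    a = tf.compat.v1.Variable([[1.0, 1.0], [1.0, 1.0]], name="a")
    b = tf.compat.v1.Variable([[2.0, 2.0], [2.0, 2.0]], name="b")
    c = tf.matmul(a, b, name="op_matmul")
    init_op = tf.compat.v1.global_variables_initializer()
    saver = tf.compat.v1.train.Saver()
with tf.compat.v1.Session(graph=g_save) as sess:
    sess.run(init_op)
    saver.save(sess, prefix)

# In a fresh graph, rebuild everything from the .meta file; no Python variables needed.
g_load = tf.Graph()
with g_load.as_default():
    restorer = tf.compat.v1.train.import_meta_graph(prefix + ".meta")
with tf.compat.v1.Session(graph=g_load) as sess:
    restorer.restore(sess, tf.train.latest_checkpoint(ckpt_dir))
    matmul_op = g_load.get_operation_by_name("op_matmul")
    output = sess.run(matmul_op.outputs[0])
print(output)
```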

5. Restore placeholders and operations

import tensorflow as tf

tf.reset_default_graph()

# Directory containing model.ckpt.meta and its checkpoint files (assumed; adjust to your setup).
model_dir = "./"

# Initialize the session
sess = tf.Session()

# Restore the graph and all the variables.
new_saver = tf.train.import_meta_graph(model_dir + 'model.ckpt.meta')
new_saver.restore(sess, tf.train.latest_checkpoint(model_dir))

# Get the restored graph
graph = tf.get_default_graph()

# Load the placeholders by name
x = graph.get_tensor_by_name("x:0")
y = graph.get_tensor_by_name("y:0")

# Load the op by name and run it, feeding the placeholders
op_operation = graph.get_tensor_by_name("op_accuracy_operation:0")
batch_x, batch_y = ...
accuracy = sess.run(op_operation, feed_dict={x: batch_x, y: batch_y})

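
The snippet above restores a model the post does not build (model.ckpt.meta with op_accuracy_operation). Here is a self-contained sketch of the same pattern, assuming the tf.compat.v1 API and a temporary directory; the loss op op_loss, the variable scale and the small batch are hypothetical stand-ins. A Saver needs at least one variable to save, which is why scale is there.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

ckpt_dir = tempfile.mkdtemp()
prefix = os.path.join(ckpt_dir, "model.ckpt")

g_save = tf.Graph()
with g_save.as_default():
    # Named placeholders so they can be looked up after restoring.
    x = tf.compat.v1.placeholder(tf.float32, shape=[None], name="x")
    y = tf.compat.v1.placeholder(tf.float32, shape=[None], name="y")
    # Hypothetical model: one scale variable and a named op to fetch later.
    scale = tf.compat.v1.Variable(2.0, name="scale")
    tf.reduce_mean(tf.square(x * scale - y), name="op_loss")
    init_op = tf.compat.v1.global_variables_initializer()
    saver = tf.compat.v1.train.Saver()
with tf.compat.v1.Session(graph=g_save) as sess:
    sess.run(init_op)
    saver.save(sess, prefix)

g_load = tf.Graph()
with g_load.as_default():
    restorer = tf.compat.v1.train.import_meta_graph(prefix + ".meta")
with tf.compat.v1.Session(graph=g_load) as sess:
    restorer.restore(sess, tf.train.latest_checkpoint(ckpt_dir))
    # Look up the restored placeholders and op by name.
    x_in = g_load.get_tensor_by_name("x:0")
    y_in = g_load.get_tensor_by_name("y:0")
    loss = g_load.get_tensor_by_name("op_loss:0")
    batch_x = np.array([1.0, 2.0, 3.0], dtype=np.float32)
    batch_y = np.array([2.0, 4.0, 6.0], dtype=np.float32)
    loss_value = sess.run(loss, feed_dict={x_in: batch_x, y_in: batch_y})
print(loss_value)
```

Since batch_y is exactly 2 * batch_x and the restored scale is 2.0, the loss should come out as zero.
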
Manuel Cuevas

Hello, I'm Manuel Cuevas, a software engineer with a background in machine learning and artificial intelligence.
