TensorFlow – Freezing Models

Find Output Node(s) From Checkpoint

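  1. Read the .meta file in the checkpoint (a MetaGraphDef), e.g. with tf.train.import_meta_graph, and list the graph's nodes.
  2. Load the graph in TensorBoard and look for the last nodes of the inference path.
  3. Run the summarize_graph tool (bazel target tensorflow/tools/graph_transforms:summarize_graph) on a GraphDef; it reports likely input and output nodes.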

Freeze models

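python /work/tensorflow/tensorflow/python/tools/freeze_graph.py \
--input_checkpoint="model.ckpt-0" \
--clear_devices=true \
--input_meta_graph=./model.ckpt-0.meta \
--output_node_names="save/restore_all" \
--output_graph='/home/jaszha02/mnist_fp32.pb' \
--input_binary=true

--output_node_names takes comma-separated node names (no ":0" suffix) and should name the model's inference output(s), e.g. its softmax node. "save/restore_all" is the Saver's restore op, not a model output, so replace it with the real output node found in the previous section.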

Transform (Quantize) models

bazel build tensorflow/tools/graph_transforms:transform_graph
bazel-bin/tensorflow/tools/graph_transforms/transform_graph \
--in_graph=tensorflow_inception_graph.pb \
--out_graph=optimized_inception_graph.pb \
--inputs='Mul' \
--outputs='softmax' \
--transforms='
  add_default_attributes
  strip_unused_nodes(type=float, shape="1,299,299,3")
  remove_nodes(op=Identity, op=CheckNumerics)
  fold_constants(ignore_errors=true)
  fold_batch_norms
  fold_old_batch_norms
  quantize_weights
  quantize_nodes
  strip_unused_nodes
  sort_by_execution_order'
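quantize_weights stores large float constants as eight-bit values together with the tensor's min and max, followed by a conversion back to float so later nodes still see floats. The arithmetic of that kind of min/max linear quantization is sketched below in plain Python; it illustrates the idea and is not the transform's actual code (its exact range and rounding handling may differ). The function names are hypothetical.

```python
def quantize_uint8(values):
    """Map floats linearly onto 0..255 using the tensor's own min/max range."""
    lo, hi = min(values), max(values)
    # A constant tensor has no range; any nonzero scale gives all-zero codes.
    scale = (hi - lo) / 255.0 if hi > lo else 1.0
    q = [int(round((v - lo) / scale)) for v in values]
    return q, lo, scale


def dequantize_uint8(q, lo, scale):
    """Inverse mapping, as the float-conversion step does at run time."""
    return [lo + x * scale for x in q]


weights = [-1.0, -0.25, 0.0, 0.5, 1.0]
q, lo, scale = quantize_uint8(weights)
restored = dequantize_uint8(q, lo, scale)
print(q)
# Worst-case round-trip error is about half a quantization step (scale / 2).
print(max(abs(a - b) for a, b in zip(weights, restored)))
```

The model file shrinks roughly 4x because each weight costs one byte instead of four, at the price of this bounded rounding error.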

Structure of checkpoint

  • meta file: describes the saved graph structure, including the GraphDef, SaverDef, and so on. Calling tf.train.import_meta_graph('/tmp/model.ckpt.meta') restores both the Saver and the Graph.
  • index file: an immutable string-to-string table (tensorflow::table::Table). Each key is the name of a tensor and each value is a serialized BundleEntryProto, which describes the tensor's metadata: which of the "data" files holds its content, the offset into that file, a checksum, some auxiliary data, etc.
  • data file: a TensorBundle collection that stores the values of all variables.
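The split between the index and data files can be pictured with a toy, in-memory model: the index maps each tensor name to where its bytes live in a data shard plus a checksum, and a data shard is just concatenated raw bytes. Everything here is a simplifying assumption (the `Entry` tuple, the function names, float32 little-endian encoding, `zlib.crc32` as the checksum); the real format uses BundleEntryProto records in a tensorflow::table::Table and its own checksum scheme. To inspect a real checkpoint, `tf.train.list_variables(ckpt_path)` lists the (name, shape) pairs recorded in the index.

```python
import struct
import zlib
from collections import namedtuple

# Toy stand-in for BundleEntryProto: which shard, where, how long, checksum.
Entry = namedtuple("Entry", ["shard_id", "offset", "size", "crc"])


def write_bundle(tensors):
    """Pack {name: [floats]} into one data shard plus an index dict."""
    data = bytearray()
    index = {}
    for name in sorted(tensors):  # keys in sorted order, like a sorted table
        raw = struct.pack("<%df" % len(tensors[name]), *tensors[name])
        index[name] = Entry(0, len(data), len(raw), zlib.crc32(raw))
        data += raw
    return index, [bytes(data)]


def read_tensor(index, shards, name):
    """Look the name up in the index, then slice and verify the data shard."""
    e = index[name]
    raw = shards[e.shard_id][e.offset:e.offset + e.size]
    if zlib.crc32(raw) != e.crc:
        raise ValueError("checksum mismatch for %s" % name)
    return list(struct.unpack("<%df" % (e.size // 4), raw))


index, shards = write_bundle({"dense/kernel": [0.5, -1.0], "dense/bias": [0.25]})
print(index["dense/kernel"])
print(read_tensor(index, shards, "dense/kernel"))
```

Reading one variable only touches its index entry and one byte range of one shard, which is why the metadata and the values are kept in separate files.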

Freeze graph from code

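import tensorflow as tf
from tensorflow.python.framework import graph_util

# `sess` must be a live tf.Session whose variables are already restored or initialized.
input_graph_def = tf.get_default_graph().as_graph_def()
# Keeps only the subgraph needed for 'output' and turns its variables into constants.
output_graph_def = graph_util.convert_variables_to_constants(sess, input_graph_def, output_node_names=['output'])
tf.train.write_graph(output_graph_def, './', 'dc_tts.pb', as_text=False)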
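output_node_names expects node names, not tensor names: 'output' works, while 'output:0' is reported as not in the graph. A small helper (hypothetical, not part of TensorFlow) can normalize names copied from `tensor.name` before passing them in:

```python
def to_node_names(names):
    """Turn tensor names like 'output:0' or '^output' into node names like 'output'.

    Duplicates are dropped while keeping the original order.
    """
    result = []
    for name in names:
        node = name.lstrip("^").split(":")[0]
        if node not in result:
            result.append(node)
    return result


print(to_node_names(["output:0", "logits:1", "output"]))  # ['output', 'logits']
```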

Load frozen model (.pb)

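import tensorflow as tf

def load_graph(filename):
    """Unpersists graph from file as default graph."""
    with tf.gfile.GFile(filename, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    # name='' keeps the original node names (no "import/" prefix).
    tf.import_graph_def(graph_def, name='')
    return tf.get_default_graph()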
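After loading a frozen graph, its feed points are usually the Placeholder nodes. The hypothetical helper below only reads the `.name` and `.op` fields, so it works on `graph.as_graph_def().node`; a namedtuple stands in for NodeDef so the sketch runs without TensorFlow.

```python
from collections import namedtuple


def find_placeholders(nodes):
    """Return names of placeholder nodes, the usual inputs of a frozen graph."""
    return [n.name for n in nodes if n.op in ("Placeholder", "PlaceholderWithDefault")]


# Hypothetical stand-in for NodeDef.
Node = namedtuple("Node", ["name", "op"])
toy = [Node("input", "Placeholder"), Node("w", "Const"), Node("out", "Softmax")]
print(find_placeholders(toy))  # ['input']
```

Each returned name can then be turned into a feedable tensor with `graph.get_tensor_by_name(name + ":0")`.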

Print All Nodes in the Graph

print([n.name for n in tf.get_default_graph().as_graph_def().node])
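For large graphs a flat list of names is hard to scan. Counting nodes per op type, similar in spirit to the op summary that summarize_graph prints, gives a quicker overview. The helper is hypothetical and works on any nodes with an `.op` field, such as `tf.get_default_graph().as_graph_def().node`.

```python
from collections import Counter, namedtuple


def op_histogram(nodes):
    """Return (op, count) pairs, most frequent first."""
    return Counter(n.op for n in nodes).most_common()


# Hypothetical stand-in for NodeDef.
Node = namedtuple("Node", ["name", "op"])
toy = [Node("a", "Const"), Node("b", "Const"), Node("c", "MatMul")]
for op, count in op_histogram(toy):
    print(op, count)
```

After freezing, the histogram should show no Variable or VariableV2 ops, which is a quick sanity check that every variable became a Const.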

Replace Input Node

import tensorflow as tf
from tensorflow.python.framework import graph_util

model_path = "DS_CNN_L.pb"

imported_graph = tf.Graph()
with imported_graph.as_default():
    with tf.gfile.GFile(model_path, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    input_node = tf.placeholder(tf.float32, shape=(None, 49, 10, 1), name="input_node")
    # Feed the new placeholder wherever the graph consumed "Reshape_1:0".
    # The nodes that produced Reshape_1 (the MFCC front end) become unreachable
    # from the output and are pruned below. import_graph_def returns None here,
    # so its result must not be assigned to imported_graph.
    tf.import_graph_def(graph_def, input_map={"Reshape_1:0": input_node}, name="")

with tf.Session(graph=imported_graph) as sess:
    input_graph_def = imported_graph.as_graph_def()
    # Extracts only the subgraph needed for 'labels_softmax', dropping unused nodes.
    output_graph_def = graph_util.convert_variables_to_constants(
        sess, input_graph_def, output_node_names=['labels_softmax'])
    tf.train.write_graph(output_graph_def, './', 'ds_cnn_l_removed_mfcc.pb', as_text=False)
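Before choosing the input_map key (here "Reshape_1:0"), it helps to check which nodes read that tensor, since those are exactly the inputs that get rewired to the new placeholder. A hypothetical helper over `graph_def.node`, with made-up toy node names standing in for the real model:

```python
from collections import namedtuple


def consumers_of(nodes, tensor_name):
    """Names of nodes that take `tensor_name` (e.g. 'Reshape_1:0') as a data input."""
    node, _, index = tensor_name.partition(":")
    index = index or "0"
    matches = {node + ":" + index}
    if index == "0":
        matches.add(node)  # NodeDef inputs usually omit ':0' for the first output
    return [n.name for n in nodes if any(i in matches for i in n.input)]


# Hypothetical stand-in for NodeDef and a toy graph.
Node = namedtuple("Node", ["name", "input"])
toy = [
    Node("Mfcc", ["audio"]),
    Node("Reshape_1", ["Mfcc", "shape"]),
    Node("Conv2D", ["Reshape_1", "weights"]),
    Node("labels_softmax", ["Conv2D"]),
]
print(consumers_of(toy, "Reshape_1:0"))  # ['Conv2D']
```

Control inputs ("^Reshape_1") are deliberately not matched, because a "Reshape_1:0" key in input_map only remaps data inputs.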