TensorFlow – Freezing Models
Find Output Node(s) From Checkpoint
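- Read the .meta file in the checkpoint.
- Use TensorBoard to inspect the graph.
- Run the summarize_graph tool (tensorflow/tools/graph_transforms:summarize_graph), which lists the likely inputs and outputs of a GraphDef.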
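A rough heuristic that complements the three approaches above: output nodes are usually sinks, meaning nodes whose outputs no other node consumes. The sketch below works on plain (name, inputs) pairs so it runs without TensorFlow. With TensorFlow 1.x installed, you could build those pairs after `tf.train.import_meta_graph(...)` with `[(n.name, list(n.input)) for n in tf.get_default_graph().as_graph_def().node]`. The helper names and the toy graph are hypothetical. Sinks also include bookkeeping nodes such as the Saver's `save/restore_all`, so treat the result as a shortlist to confirm in TensorBoard, not a final answer.

```python
from typing import Iterable, List, Tuple


def node_name(input_name: str) -> str:
    """Map a GraphDef input string to the name of the node that produces it.

    Inputs look like "node" (output 0), "node:1" (output 1) or
    "^node" (control dependency).
    """
    return input_name.lstrip("^").split(":")[0]


def candidate_output_nodes(nodes: Iterable[Tuple[str, List[str]]]) -> List[str]:
    """Return the nodes that no other node consumes (the graph's sinks)."""
    nodes = list(nodes)
    consumed = {node_name(i) for _, inputs in nodes for i in inputs}
    return [name for name, _ in nodes if name not in consumed]


# Toy graph (hypothetical names): a tiny model plus a fragment of Saver nodes.
toy_graph = [
    ("x", []),
    ("w", []),
    ("matmul", ["x", "w"]),
    ("softmax", ["matmul"]),
    ("save/Const", []),
    ("save/restore_all", ["^save/Const"]),
]
print(candidate_output_nodes(toy_graph))  # ['softmax', 'save/restore_all']
```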
Freeze models
python /work/tensorflow/tensorflow/python/tools/freeze_graph.py \
  --input_checkpoint="model.ckpt-0" \
  --clear_devices=true \
  --input_meta_graph=./model.ckpt-0.meta \
  --output_node_names="save/restore_all" \
  --output_graph='/home/jaszha02/mnist_fp32.pb' \
  --input_binary=true
Transform (Quantize) models
bazel build tensorflow/tools/graph_transforms:transform_graph
bazel-bin/tensorflow/tools/graph_transforms/transform_graph \
  --in_graph=tensorflow_inception_graph.pb \
  --out_graph=optimized_inception_graph.pb \
  --inputs='Mul' \
  --outputs='softmax' \
  --transforms='
    add_default_attributes
    strip_unused_nodes(type=float, shape="1,299,299,3")
    remove_nodes(op=Identity, op=CheckNumerics)
    fold_constants(ignore_errors=true)
    fold_batch_norms
    fold_old_batch_norms
    quantize_weights
    quantize_nodes
    strip_unused_nodes
    sort_by_execution_order'
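The nested quotes in the transform_graph command (for example `shape="1,299,299,3"` inside a single-quoted `--transforms` value) are easy to break when editing the shell line. One option, sketched below, is to build the argument list in Python, where each flag is a single argv entry and no shell quoting is involved. The function name `transform_graph_argv` is hypothetical; the flags are the ones used in the command above, with `--inputs`/`--outputs` as comma-separated node names and `--transforms` as a space-separated list. The `subprocess` call is left commented out because it needs the bazel-built binary.

```python
import shlex
from typing import List, Sequence

# Path produced by `bazel build tensorflow/tools/graph_transforms:transform_graph`.
TRANSFORM_GRAPH = "bazel-bin/tensorflow/tools/graph_transforms/transform_graph"


def transform_graph_argv(in_graph: str, out_graph: str, inputs: Sequence[str],
                         outputs: Sequence[str], transforms: Sequence[str]) -> List[str]:
    """Build the transform_graph argument list; each item is one argv entry."""
    return [
        TRANSFORM_GRAPH,
        f"--in_graph={in_graph}",
        f"--out_graph={out_graph}",
        f"--inputs={','.join(inputs)}",          # comma-separated input nodes
        f"--outputs={','.join(outputs)}",        # comma-separated output nodes
        f"--transforms={' '.join(transforms)}",  # space-separated transforms
    ]


argv = transform_graph_argv(
    "tensorflow_inception_graph.pb", "optimized_inception_graph.pb",
    inputs=["Mul"], outputs=["softmax"],
    transforms=[
        "add_default_attributes",
        'strip_unused_nodes(type=float, shape="1,299,299,3")',
        "remove_nodes(op=Identity, op=CheckNumerics)",
        "fold_constants(ignore_errors=true)",
        "fold_batch_norms",
        "fold_old_batch_norms",
        "quantize_weights",
        "quantize_nodes",
        "strip_unused_nodes",
        "sort_by_execution_order",
    ],
)
print(shlex.join(argv))  # shell-quoted equivalent, handy for logging
# import subprocess; subprocess.run(argv, check=True)  # once the binary is built
```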
Structure of checkpoint
- meta file: describes the saved graph structure and contains the GraphDef, SaverDef, and so on. Calling tf.train.import_meta_graph('/tmp/model.ckpt.meta') restores the Graph and returns a Saver; the variable values are then loaded with saver.restore(sess, '/tmp/model.ckpt').
- index file: an immutable string-to-string table (tensorflow::table::Table). Each key is the name of a tensor and its value is a serialized BundleEntryProto. Each BundleEntryProto describes the metadata of a tensor: which of the "data" files contains the tensor's content, the offset into that file, a checksum, some auxiliary data, etc.
- data file: a TensorBundle collection that stores the values of all variables.
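To make the index/data split concrete, here is a toy in-memory model of the lookup described above: the index maps a tensor name to an entry (shard, offset, size, checksum), and the bytes are sliced out of the matching data file. Everything here is a hypothetical simplification, not TensorFlow's on-disk format: `Entry` keeps only a few BundleEntryProto-like fields, tensors are raw bytes, shards are assigned round-robin, and `zlib.crc32` stands in for the CRC32C used by real checkpoints. On a real checkpoint (files such as `model.ckpt.index` and `model.ckpt.data-00000-of-00001`), `tf.train.list_variables` lists the stored names and shapes.

```python
import zlib
from dataclasses import dataclass
from typing import Dict, List, Tuple


@dataclass(frozen=True)
class Entry:
    """Simplified stand-in for BundleEntryProto: where a tensor's bytes live."""
    shard_id: int  # which data file holds the tensor
    offset: int    # byte offset into that data file
    size: int      # number of bytes
    checksum: int  # zlib.crc32 here; real checkpoints use CRC32C


def write_bundle(tensors: Dict[str, bytes],
                 num_shards: int = 1) -> Tuple[Dict[str, Entry], List[bytes]]:
    """Lay tensors out across shards and build the name -> Entry index."""
    shards = [bytearray() for _ in range(num_shards)]
    index: Dict[str, Entry] = {}
    for i, name in enumerate(sorted(tensors)):  # the real index is sorted by key
        data = tensors[name]
        shard = i % num_shards  # hypothetical round-robin placement
        index[name] = Entry(shard, len(shards[shard]), len(data), zlib.crc32(data))
        shards[shard].extend(data)
    return index, [bytes(s) for s in shards]


def read_tensor(index: Dict[str, Entry], shards: List[bytes], name: str) -> bytes:
    """Look up a tensor in the index, then slice its bytes out of a data file."""
    entry = index[name]
    data = shards[entry.shard_id][entry.offset:entry.offset + entry.size]
    if zlib.crc32(data) != entry.checksum:
        raise ValueError(f"checksum mismatch for {name!r}")
    return data


# Hypothetical variables serialized as raw bytes, split across two data files.
index, shards = write_bundle(
    {"dense/kernel": b"\x00\x01\x02\x03", "dense/bias": b"\x04\x05"}, num_shards=2)
print(read_tensor(index, shards, "dense/kernel"))  # b'\x00\x01\x02\x03'
```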
Freeze graph from code
import tensorflow as tf
from tensorflow.python.framework import graph_util
# `sess`: an open tf.Session in which the model's variables are initialized or restored.
input_graph_def = tf.get_default_graph().as_graph_def()
output_graph_def = graph_util.convert_variables_to_constants(sess, input_graph_def, output_node_names=['output'])
tf.train.write_graph(output_graph_def, './', 'dc_tts.pb', False)  # False: write a binary .pb
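Conceptually, convert_variables_to_constants does two things: it keeps only the part of the graph that the requested output nodes depend on, then replaces each remaining variable with a Const node holding the value read from `sess`. The sketch below illustrates those two steps on a toy dict-based graph so it runs without TensorFlow. `Node`, `freeze`, and the `values` dict (a stand-in for reading variables through the session) are hypothetical; the real function handles more cases, such as resource variables.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List

VARIABLE_OPS = {"Variable", "VariableV2"}


@dataclass
class Node:
    op: str
    inputs: List[str] = field(default_factory=list)
    value: Any = None  # payload for Const nodes only


def freeze(graph: Dict[str, Node], values: Dict[str, Any],
           outputs: List[str]) -> Dict[str, Node]:
    """Keep what the outputs need, then turn variables into constants."""
    # 1. Walk backwards from the outputs; anything unreachable (initializers,
    #    Assign ops, and so on) is dropped.
    keep, stack = set(), list(outputs)
    while stack:
        name = stack.pop()
        if name not in keep:
            keep.add(name)
            stack.extend(i.lstrip("^").split(":")[0] for i in graph[name].inputs)
    # 2. Swap each remaining variable for a Const holding its current value.
    frozen: Dict[str, Node] = {}
    for name in keep:
        node = graph[name]
        if node.op in VARIABLE_OPS:
            frozen[name] = Node("Const", [], values[name])
        else:
            frozen[name] = Node(node.op, list(node.inputs), node.value)
    return frozen


# Toy graph shaped like a TF1 graph: reads of "w" go through "w/read".
graph = {
    "x": Node("Placeholder"),
    "w": Node("VariableV2"),
    "w/initial_value": Node("Const", value=[[1.0]]),
    "w/Assign": Node("Assign", ["w", "w/initial_value"]),
    "w/read": Node("Identity", ["w"]),
    "matmul": Node("MatMul", ["x", "w/read"]),
    "output": Node("Identity", ["matmul"]),
}
frozen = freeze(graph, values={"w": [[0.5]]}, outputs=["output"])
print(sorted(frozen))  # ['matmul', 'output', 'w', 'w/read', 'x']
```

The Assign and initializer nodes disappear because nothing on the path to `output` consumes them, which is why a frozen graph is smaller than the training graph.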
Load frozen model (.pb):
import tensorflow as tf

def load_graph(filename):
    """Unpersists graph from file as default graph."""
    with tf.gfile.GFile(filename, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    # name='' keeps the original node names (no "import/" prefix).
    tf.import_graph_def(graph_def, name='')
Print All Nodes in the Graph
print([n.name for n in tf.get_default_graph().as_graph_def().node])
Replace Input Node
import tensorflow as tf
from tensorflow.python.framework import graph_util

model_path = "DS_CNN_L.pb"

# Read the serialized GraphDef of the frozen model.
with tf.gfile.GFile(model_path, "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

imported_graph = tf.Graph()
with imported_graph.as_default():
    input_node = tf.placeholder(tf.float32, shape=(None, 49, 10, 1), name="input_node")
    # Replace the reshape node and the nodes above it with the newly defined placeholder.
    # tf.import_graph_def returns None here, so keep working with imported_graph.
    tf.import_graph_def(graph_def, input_map={"Reshape_1:0": input_node}, name="")

with tf.Session(graph=imported_graph) as sess:
    input_graph_def = imported_graph.as_graph_def()
    # Only nodes needed for labels_softmax are kept, so the old front end is pruned.
    output_graph_def = graph_util.convert_variables_to_constants(
        sess, input_graph_def, output_node_names=['labels_softmax'])
    tf.train.write_graph(output_graph_def, './', 'ds_cnn_l_removed_mfcc.pb', False)
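Replacing the input works in two stages: `input_map` rewires every consumer of `Reshape_1:0` to the new placeholder, and the subgraph extraction inside convert_variables_to_constants then drops the nodes that fed `Reshape_1`, since nothing on the path to `labels_softmax` needs them anymore. The toy sketch below mirrors those stages on a dict-based graph, including the detail that a NodeDef input written as `Reshape_1` refers to the same tensor as `Reshape_1:0`. Only `Reshape_1`, `input_node`, and `labels_softmax` come from the code above; `wav`, `mfcc`, `conv`, and the helper functions are hypothetical.

```python
from typing import Dict, List

Graph = Dict[str, List[str]]  # node name -> input strings


def canonical(tensor: str) -> str:
    """'Reshape_1' and 'Reshape_1:0' name the same tensor; control inputs stay as-is."""
    if tensor.startswith("^") or ":" in tensor:
        return tensor
    return tensor + ":0"


def remap_inputs(graph: Graph, input_map: Dict[str, str]) -> Graph:
    """Point every consumer of a mapped tensor at its replacement."""
    return {name: [input_map.get(canonical(i), i) for i in inputs]
            for name, inputs in graph.items()}


def prune(graph: Graph, outputs: List[str]) -> Graph:
    """Keep only the nodes the outputs still depend on."""
    keep, stack = set(), list(outputs)
    while stack:
        name = stack.pop()
        if name not in keep:
            keep.add(name)
            stack.extend(i.lstrip("^").split(":")[0] for i in graph[name])
    return {name: graph[name] for name in graph if name in keep}


graph: Graph = {
    "wav": [],
    "mfcc": ["wav"],
    "Reshape_1": ["mfcc"],
    "conv": ["Reshape_1"],
    "labels_softmax": ["conv"],
    "input_node": [],  # the new placeholder
}
rewired = remap_inputs(graph, {"Reshape_1:0": "input_node"})
print(sorted(prune(rewired, ["labels_softmax"])))  # ['conv', 'input_node', 'labels_softmax']
```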