1. What’s TFRecord?

A TFRecord is a binary file that contains sequences of byte-strings. Data needs to be serialized (encoded as a byte-string) before being written into a TFRecord.
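As a minimal sketch of the "sequence of byte-strings" idea, the snippet below writes three raw byte-strings to a TFRecord file and reads them back. The temporary path and the record contents are made up for illustration.

```python
import os
import tempfile

import tensorflow as tf

# Hypothetical output location, used only for this illustration.
path = os.path.join(tempfile.mkdtemp(), "demo.tfrecord")

# Each call to write() appends one byte-string record to the file.
records = [b"first record", b"second record", b"third record"]
with tf.io.TFRecordWriter(path) as writer:
    for record in records:
        writer.write(record)

# Reading the file yields the same byte-strings, in the order written.
read_back = [r.numpy() for r in tf.data.TFRecordDataset(path)]
print(read_back)
```

The file itself knows nothing about the structure of each record; any structure has to come from how the bytes were serialized.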

The most convenient way of serializing data in TensorFlow is to wrap the data with tf.train.Example. This is a record format based on Google’s protobufs but designed for TensorFlow. It’s more or less like a dict with some type annotations.

Protocol buffers are a cross-platform, cross-language library for efficient serialization of structured data.

We first create a tf.train.Example, then serialize it and write it to a .tfrecords file.


  1. We can split the data across multiple files and then parallelize the computation.


The official TensorFlow guidance is that with N hosts you should split the data into 10 * N files, with each file ideally larger than 10 MB.
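The sharding idea can be sketched as follows. The helper `write_shards`, the file-name pattern, and the toy records are assumptions made for illustration; the files here are far smaller than the recommended size.

```python
import os
import tempfile

import tensorflow as tf


def write_shards(records, out_dir, num_shards, prefix="data"):
    """Distribute byte-string records round-robin over num_shards files."""
    paths = [
        os.path.join(out_dir, f"{prefix}-{i:05d}-of-{num_shards:05d}.tfrecord")
        for i in range(num_shards)
    ]
    writers = [tf.io.TFRecordWriter(p) for p in paths]
    try:
        for index, record in enumerate(records):
            writers[index % num_shards].write(record)
    finally:
        for writer in writers:
            writer.close()
    return paths


num_hosts = 1                 # assumed number of hosts
num_shards = 10 * num_hosts   # the 10 * N rule of thumb
records = [f"record-{i}".encode() for i in range(95)]
shard_paths = write_shards(records, tempfile.mkdtemp(), num_shards)

# Read several shards concurrently by interleaving one reader per file.
files = tf.data.Dataset.from_tensor_slices(shard_paths)
dataset = files.interleave(
    tf.data.TFRecordDataset,
    cycle_length=4,
    num_parallel_calls=tf.data.AUTOTUNE,
)
read_back = sorted(r.numpy() for r in dataset)
print(len(shard_paths), len(read_back))
```

Interleaving mixes records from different shards, so the read order differs from the write order even though every record is still read exactly once.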

A tf.train.Example is a kind of dictionary of the form:

{"string": tf.train.Feature}

A tf.train.Feature message can hold one of three types:

  1. tf.train.BytesList – string, byte.
  2. tf.train.FloatList – float (float32), double (float64).
  3. tf.train.Int64List – bool, enum, int32, uint32, int64, uint64.
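As a small illustration of the three list types, the sketch below builds one feature of each kind and wraps them in a tf.train.Example. The feature names and values are made up.

```python
import tensorflow as tf

# One feature per list type; each list may hold several values.
bytes_feature = tf.train.Feature(bytes_list=tf.train.BytesList(value=[b"cat"]))
float_feature = tf.train.Feature(float_list=tf.train.FloatList(value=[0.5, 1.5]))
int64_feature = tf.train.Feature(int64_list=tf.train.Int64List(value=[1, 0, 7]))

# An Example maps string keys to Feature messages.
example = tf.train.Example(
    features=tf.train.Features(
        feature={
            "name": bytes_feature,
            "scores": float_feature,
            "ids": int64_feature,
        }
    )
)

# WhichOneof("kind") reports which of the three list types a Feature holds.
kinds = {
    key: example.features.feature[key].WhichOneof("kind")
    for key in sorted(example.features.feature)
}
print(kinds)
```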

We can convert various kinds of data to tf.train.Feature as follows:

import tensorflow as tf

# The following functions can be used to convert a value to a type compatible
# with tf.train.Example.

def _bytes_feature(value):
  """Returns a bytes_list from a string / byte."""
  if isinstance(value, type(tf.constant(0))):
    value = value.numpy() # BytesList won't unpack a string from an EagerTensor.
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _float_feature(value):
  """Returns a float_list from a float / double."""
  return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

def _int64_feature(value):
  """Returns an int64_list from a bool / enum / int / uint."""
  return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

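The functions above handle scalar inputs. The simplest way to handle a non-scalar feature is to convert the tensor to a binary string with tf.io.serialize_tensor. We can then use tf.io.parse_tensor to convert the binary string back to a tensor.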
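A minimal round trip for a non-scalar feature might look like this. The feature key "matrix" and the tensor values are assumptions for illustration.

```python
import numpy as np
import tensorflow as tf

tensor = tf.constant([[1.0, 2.0], [3.0, 4.0]], dtype=tf.float32)

# Serialize the tensor to a scalar tf.string, then store its bytes in a BytesList.
serialized = tf.io.serialize_tensor(tensor)
feature = tf.train.Feature(
    bytes_list=tf.train.BytesList(value=[serialized.numpy()])
)
example = tf.train.Example(
    features=tf.train.Features(feature={"matrix": feature})
)
payload = example.SerializeToString()

# Parse the Example, then decode the byte-string back into a float32 tensor.
parsed = tf.io.parse_single_example(
    payload, {"matrix": tf.io.FixedLenFeature([], tf.string)}
)
restored = tf.io.parse_tensor(parsed["matrix"], out_type=tf.float32)
print(restored.numpy())
```

The dtype has to be supplied again when parsing (out_type), so it needs to be known or recorded alongside the data.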

2. How to create a TFRecord?

To create a TFRecord dataset, the whole process might look something like:

  1. Build a dataset with tf.data.Dataset. You could use the from_generator or from_tensor_slices methods.
  2. Serialize the dataset by iterating over it and converting each element with a make_example helper.
  3. Write the dataset to TFRecords with tf.io.TFRecordWriter or tf.data.experimental.TFRecordWriter.
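
# image_feature, label_feature, and class_name_feature are tf.train.Feature
# objects, e.g. built with the helper functions above.
features = tf.train.Features(feature={
    'image': image_feature,
    'label': label_feature,
    'class_name': class_name_feature,
})

# Wrap with Example
example = tf.train.Example(features=features)

# Serialize to a byte-string that can be written to a TFRecord
example_bytes = example.SerializeToString()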
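Putting the three steps together, a sketch of the whole process could look like the following. The `make_example` helper, the placeholder image bytes, and the output path are assumptions for illustration; real code would store encoded image files (for example JPEG bytes) instead.

```python
import os
import tempfile

import tensorflow as tf


def make_example(image_bytes, label, class_name):
    """Hypothetical helper: wrap one sample in a tf.train.Example."""
    features = tf.train.Features(feature={
        "image": tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[image_bytes])),
        "label": tf.train.Feature(
            int64_list=tf.train.Int64List(value=[label])),
        "class_name": tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[class_name])),
    })
    return tf.train.Example(features=features)


# 1. Build a dataset (placeholder bytes stand in for encoded images).
images = [b"\x00\x01", b"\x02\x03", b"\x04\x05"]
labels = [0, 1, 0]
names = [b"cat", b"dog", b"cat"]
dataset = tf.data.Dataset.from_tensor_slices((images, labels, names))

# 2. Serialize each element and 3. write it to a TFRecord file.
path = os.path.join(tempfile.mkdtemp(), "samples.tfrecord")
with tf.io.TFRecordWriter(path) as writer:
    for image, label, name in dataset:
        example = make_example(image.numpy(), int(label.numpy()), name.numpy())
        writer.write(example.SerializeToString())

num_written = sum(1 for _ in tf.data.TFRecordDataset(path))
print(num_written)
```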

Example: converting ImageNet to TFRecord.

3. How to load a TFRecord?

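from functools import partial

import tensorflow as tf

AUTOTUNE = tf.data.AUTOTUNE
BATCH_SIZE = 64  # assumed value; not given in the original

FILENAMES = tf.io.gfile.glob("./validation_tf_records/validation-*")
split_ind = int(0.9 * len(FILENAMES))
# Use the first 90% of the files for training and the rest for validation.
TRAINING_FILENAMES, VALID_FILENAMES = FILENAMES[:split_ind], FILENAMES[split_ind:]

IMAGE_SIZE = (224, 224)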

def decode_image(image):
    image = tf.image.decode_image(image, channels=3)
    image = tf.cast(image, tf.float32)
    image = tf.image.resize_with_crop_or_pad(image, 224, 224)
#     image = image / 255.
    return image

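def read_tfrecord(example, labeled):
    # Schema of one serialized example; the label is parsed only if present.
    tfrecord_format = (
        {
            "image/encoded": tf.io.FixedLenFeature([], tf.string),
            "image/class/label": tf.io.FixedLenFeature([], tf.int64),
        }
        if labeled
        else {"image/encoded": tf.io.FixedLenFeature([], tf.string)}
    )
    example = tf.io.parse_single_example(example, tfrecord_format)
    image = decode_image(example["image/encoded"])
    if labeled:
        label = tf.cast(example["image/class/label"], tf.int32)
        return image, label - 1  # shift 1-based labels to 0-based
    return image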

def load_dataset(filenames, labeled=True):
    ignore_order = tf.data.Options()
    ignore_order.experimental_deterministic = False  # disable order, increase speed
    dataset = tf.data.TFRecordDataset(
        filenames
    )  # reads the records from the listed files
    dataset = dataset.with_options(
        ignore_order
    )  # uses data as soon as it streams in, rather than in its original order
    dataset = dataset.map(
        partial(read_tfrecord, labeled=labeled), num_parallel_calls=AUTOTUNE
    )
    # returns a dataset of (image, label) pairs if labeled=True or just images if labeled=False
    return dataset

def get_dataset(filenames, labeled=True):
    dataset = load_dataset(filenames, labeled=labeled)
    dataset = dataset.shuffle(2048)
    dataset = dataset.batch(BATCH_SIZE)
    dataset = dataset.prefetch(buffer_size=AUTOTUNE)  # prefetch last so whole batches are prepared ahead of time
    return dataset

train_dataset = get_dataset(TRAINING_FILENAMES)
valid_dataset = get_dataset(VALID_FILENAMES)

The code above is an example of reading ImageNet TFRecords.
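The parsing step can be tried without ImageNet on a tiny, self-written file. The sketch below reuses the same feature keys, but the records hold placeholder bytes rather than real encoded images, so no image decoding is done; the path and values are made up.

```python
import os
import tempfile

import tensorflow as tf

# Write three toy records with the same keys as the ImageNet example.
path = os.path.join(tempfile.mkdtemp(), "toy.tfrecord")
with tf.io.TFRecordWriter(path) as writer:
    for payload, label in [(b"a", 1), (b"b", 2), (b"c", 3)]:
        example = tf.train.Example(features=tf.train.Features(feature={
            "image/encoded": tf.train.Feature(
                bytes_list=tf.train.BytesList(value=[payload])),
            "image/class/label": tf.train.Feature(
                int64_list=tf.train.Int64List(value=[label])),
        }))
        writer.write(example.SerializeToString())

# The schema tells the parser the type and shape of each feature.
schema = {
    "image/encoded": tf.io.FixedLenFeature([], tf.string),
    "image/class/label": tf.io.FixedLenFeature([], tf.int64),
}


def parse(serialized):
    parsed = tf.io.parse_single_example(serialized, schema)
    label = tf.cast(parsed["image/class/label"], tf.int32) - 1  # 1-based to 0-based
    return parsed["image/encoded"], label


dataset = (
    tf.data.TFRecordDataset([path])
    .map(parse, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(2)
)
batches = [(x.numpy().tolist(), y.numpy().tolist()) for x, y in dataset]
print(batches)
```

Because this sketch keeps the default deterministic ordering, the batches come back in write order; with the non-deterministic option used above, the order may vary.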

4. Inspect TFRecord

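import tensorflow as tf
raw_dataset = tf.data.TFRecordDataset("path-to-file")

for raw_record in raw_dataset.take(1):
    example = tf.train.Example()
    example.ParseFromString(raw_record.numpy())  # decode the serialized bytes
    print(example)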
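When the schema of a file is unknown, it can help to list only the feature keys and their list types rather than print a whole record. The helper `describe_first_record` and the toy file below are assumptions for illustration.

```python
import os
import tempfile

import tensorflow as tf


def describe_first_record(path):
    """Hypothetical helper: map each feature key of the first record to its list type."""
    for raw_record in tf.data.TFRecordDataset(path).take(1):
        example = tf.train.Example()
        example.ParseFromString(raw_record.numpy())
        return {
            key: feature.WhichOneof("kind")
            for key, feature in example.features.feature.items()
        }
    return {}  # the file contains no records


# Toy file so that the sketch is self-contained.
path = os.path.join(tempfile.mkdtemp(), "inspect.tfrecord")
with tf.io.TFRecordWriter(path) as writer:
    example = tf.train.Example(features=tf.train.Features(feature={
        "image/encoded": tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[b"\x00"])),
        "image/class/label": tf.train.Feature(
            int64_list=tf.train.Int64List(value=[5])),
    }))
    writer.write(example.SerializeToString())

summary = describe_first_record(path)
print(summary)
```

The resulting mapping is a starting point for writing the feature description passed to tf.io.parse_single_example.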


