TFRecord

1. What’s TFRecord?

A TFRecord file is a binary file that stores a sequence of records, each of which is a byte string. Data must be serialized (encoded as a byte string) before it is written to a TFRecord.

The most convenient way to serialize data in TensorFlow is to wrap it in a tf.train.Example. This is a record format based on Google’s protocol buffers (protobufs) but designed for TensorFlow. It is more or less a dict with some type annotations.

Protocol buffers are a cross-platform, cross-language library for efficient serialization of structured data.

We first create a tf.train.Example, then serialize it, and finally write it into a .tfrecords file.

Benefits:

  1. The data can be split into multiple files, so reading and preprocessing can be parallelized.

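Recommendation:

The official TensorFlow guide suggests that, if the data is read by N hosts, you split it into at least 10 * N files, and that each file be larger than 10 MB (ideally 100 MB or more) so that I/O prefetching is effective.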
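A minimal sketch of that rule of thumb, assuming a toy dataset and hypothetical names (NUM_HOSTS, write_shards and the data-XXXXX-of-YYYYY file pattern are made up for illustration). Each of the 10 * N shards takes every num_shards-th element via Dataset.shard and is written with tf.io.TFRecordWriter. The toy files are far below 10 MB; only the file layout is illustrated.

```python
import os
import tempfile

import tensorflow as tf

NUM_HOSTS = 2                # hypothetical number of hosts reading the data
NUM_SHARDS = 10 * NUM_HOSTS  # rule of thumb: at least 10 files per host


def write_shards(dataset, out_dir, num_shards):
    """Hypothetical helper: writes a dataset of byte strings into num_shards TFRecord files."""
    paths = []
    for index in range(num_shards):
        # Dataset.shard keeps every num_shards-th element, starting at `index`.
        shard = dataset.shard(num_shards=num_shards, index=index)
        path = os.path.join(out_dir, f"data-{index:05d}-of-{num_shards:05d}.tfrecord")
        with tf.io.TFRecordWriter(path) as writer:
            for record in shard:
                writer.write(record.numpy())
        paths.append(path)
    return paths


# 100 byte strings standing in for serialized tf.train.Example protos.
records = tf.data.Dataset.range(100).map(lambda i: tf.strings.as_string(i))
out_dir = tempfile.mkdtemp()
paths = write_shards(records, out_dir, NUM_SHARDS)
print(len(paths))  # 20
```

Round-robin sharding keeps the shards roughly equal in size, which helps each host finish its files at about the same time.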

A tf.train.Example is essentially a dictionary (stored in its features.feature field) of the form:

{"string": tf.train.Feature}
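To make the dictionary analogy concrete, here is a small sketch (the keys "label" and "name" are made up for illustration). The map is built as a plain dict, wrapped in tf.train.Features and tf.train.Example, and then read back through example.features.feature.

```python
import tensorflow as tf

# Build the {"string": tf.train.Feature} map, then wrap it in Features and Example.
feature_map = {
    "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[7])),
    "name": tf.train.Feature(bytes_list=tf.train.BytesList(value=[b"cat"])),
}
example = tf.train.Example(features=tf.train.Features(feature=feature_map))

# The dict-like map is reachable as example.features.feature.
label = example.features.feature["label"].int64_list.value[0]
name = example.features.feature["name"].bytes_list.value[0]
print(label, name)                              # 7 b'cat'
print(sorted(example.features.feature.keys()))  # ['label', 'name']
```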

A tf.train.Feature message can hold one of three list types:

  1. tf.train.BytesList – string, byte.
  2. tf.train.FloatList – float (float32), double (float64).
  3. tf.train.Int64List – bool, enum, int32, uint32, int64, uint64.
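A short sketch of what the three list types mean in practice (the values are arbitrary). A bool is stored as an integer in an Int64List, a Python float (a double) goes through FloatList, which keeps only float32 precision, and text must be encoded to bytes first because BytesList holds raw bytes. The FloatList feature is serialized and parsed back so that the float32 rounding is visible.

```python
import tensorflow as tf

flag = tf.train.Feature(int64_list=tf.train.Int64List(value=[True]))   # bool -> int64
score = tf.train.Feature(float_list=tf.train.FloatList(value=[0.1]))   # double -> float32
text = tf.train.Feature(bytes_list=tf.train.BytesList(value=["héllo".encode("utf-8")]))

# Round-trip the float feature through its wire format to see the stored value.
restored = tf.train.Feature.FromString(score.SerializeToString())

stored_flag = flag.int64_list.value[0]
stored_score = restored.float_list.value[0]
stored_text = text.bytes_list.value[0].decode("utf-8")

print(stored_flag)                     # 1
print(stored_score == 0.1)             # False: 0.1 was rounded to the nearest float32
print(abs(stored_score - 0.1) < 1e-7)  # True
print(stored_text)                     # héllo
```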

The following helper functions convert a value into a tf.train.Feature:

import tensorflow as tf

# The following functions can be used to convert a value to a type compatible
# with tf.train.Example.

def _bytes_feature(value):
  """Returns a bytes_list from a string / byte."""
  if isinstance(value, type(tf.constant(0))):
    value = value.numpy() # BytesList won't unpack a string from an EagerTensor.
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _float_feature(value):
  """Returns a float_list from a float / double."""
  return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

def _int64_feature(value):
  """Returns an int64_list from a bool / enum / int / uint."""
  return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
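On the reading side, each list type maps to one dtype in the parsing spec. The sketch below (the keys "name", "score" and "label" are illustrative) builds an Example with one feature of each type, the same shape the helpers produce, and parses it back with tf.io.parse_single_example: BytesList becomes tf.string, FloatList becomes tf.float32 and Int64List becomes tf.int64.

```python
import tensorflow as tf

# One feature of each type, as produced by the helper functions.
example = tf.train.Example(features=tf.train.Features(feature={
    "name": tf.train.Feature(bytes_list=tf.train.BytesList(value=[b"cat"])),
    "score": tf.train.Feature(float_list=tf.train.FloatList(value=[0.5])),
    "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[3])),
}))
serialized = example.SerializeToString()

# Shape [] means "exactly one scalar value per record".
feature_spec = {
    "name": tf.io.FixedLenFeature([], tf.string),
    "score": tf.io.FixedLenFeature([], tf.float32),
    "label": tf.io.FixedLenFeature([], tf.int64),
}
parsed = tf.io.parse_single_example(serialized, feature_spec)

print(parsed["name"].numpy(), parsed["score"].numpy(), parsed["label"].numpy())
# b'cat' 0.5 3
```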

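The helpers above only handle scalar inputs. The simplest way to handle non-scalar features is tf.io.serialize_tensor, which converts a tensor into a binary string that can be stored in a BytesList. tf.io.parse_tensor converts the binary string back into a tensor (you must pass the original dtype as out_type).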
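A minimal round-trip sketch, assuming a small float32 matrix and a made-up feature key "matrix": the tensor is serialized into a single bytes feature, parsed as a tf.string, and restored with tf.io.parse_tensor.

```python
import numpy as np
import tensorflow as tf

matrix = tf.constant(np.arange(6, dtype=np.float32).reshape(2, 3))

# serialize_tensor returns a scalar tf.string tensor holding the raw bytes.
serialized = tf.io.serialize_tensor(matrix)
feature = tf.train.Feature(bytes_list=tf.train.BytesList(value=[serialized.numpy()]))
example = tf.train.Example(features=tf.train.Features(feature={"matrix": feature}))

# Reading back: parse the bytes feature as a string, then restore the tensor.
parsed = tf.io.parse_single_example(
    example.SerializeToString(),
    {"matrix": tf.io.FixedLenFeature([], tf.string)},
)
# out_type must match the dtype that was serialized (float32 here).
restored = tf.io.parse_tensor(parsed["matrix"], out_type=tf.float32)
print(restored.shape)  # (2, 3)
```

Inside a tf.data map function the restored tensor has no static shape, so you may need tf.ensure_shape (or set_shape) before batching or feeding it to a model.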

2. How to create a TFRecord?

To create a TFRecord dataset, the whole process might look something like:

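  1. Build a dataset with tf.data.Dataset, for example with the from_generator or from_tensor_slices methods.
  2. Serialize the dataset by iterating over it and converting each element with make_example, a user-defined function that builds a tf.train.Example from one element and returns its serialized bytes.
  3. Write the serialized records to TFRecord files with tf.io.TFRecordWriter or tf.data.experimental.TFRecordWriter.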
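
import tensorflow as tf

# image_feature, label_feature and class_name_feature are tf.train.Feature
# objects, for example built with the _bytes_feature / _int64_feature helpers.
features = tf.train.Features(feature={
    'image': image_feature,
    'label': label_feature,
    'class_name': class_name_feature,
})

# Wrap with Example
example = tf.train.Example(features=features)

# Serialize to a byte string that a TFRecord writer can store
example_bytes = example.SerializeToString()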

Example – convert ImageNet to TFRecords.
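A minimal end-to-end sketch of the three steps, under stated assumptions: the images are synthetic 8x8 RGB arrays, they are stored as PNG bytes, and make_example, toy_images and the output path are hypothetical names chosen for illustration. tf.data.Dataset.from_tensor_slices builds the dataset, make_example turns each element into a serialized tf.train.Example, and tf.io.TFRecordWriter writes the bytes.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


def make_example(image, label, class_name):
    """Hypothetical helper: one (image, label, class_name) element -> serialized Example."""
    # Encode the uint8 image as PNG bytes so it fits in a BytesList.
    image_bytes = tf.io.encode_png(image).numpy()
    features = tf.train.Features(feature={
        "image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_bytes])),
        "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[int(label)])),
        "class_name": tf.train.Feature(bytes_list=tf.train.BytesList(value=[class_name])),
    })
    return tf.train.Example(features=features).SerializeToString()


# Step 1: build a tf.data.Dataset from synthetic images, labels and class names.
rng = np.random.default_rng(0)
toy_images = rng.integers(0, 256, size=(4, 8, 8, 3), dtype=np.uint8)
dataset = tf.data.Dataset.from_tensor_slices(
    (toy_images, [0, 1, 0, 1], [b"cat", b"dog", b"cat", b"dog"])
)

# Steps 2 and 3: serialize each element and write it to a TFRecord file.
path = os.path.join(tempfile.mkdtemp(), "toy.tfrecord")
with tf.io.TFRecordWriter(path) as writer:
    for image, label, class_name in dataset:
        writer.write(make_example(image, label.numpy(), class_name.numpy()))

print(sum(1 for _ in tf.data.TFRecordDataset(path)))  # 4
```

PNG is used here only because it is lossless, which makes the round trip easy to check; ImageNet-style records usually keep the original JPEG bytes instead of re-encoding.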

3. How to load a TFRecord?

import tensorflow as tf
from functools import partial

FILENAMES = tf.io.gfile.glob("./validation_tf_records/validation-*")
split_ind = int(0.9 * len(FILENAMES))
TRAINING_FILENAMES, VALID_FILENAMES = FILENAMES[:split_ind], FILENAMES[split_ind:]

BATCH_SIZE = 32
IMAGE_SIZE = (224, 224)
AUTOTUNE = tf.data.AUTOTUNE  # tf.data.experimental.AUTOTUNE in older TF versions

def decode_image(image):
    # expand_animations=False guarantees a 3-D (height, width, channels) tensor
    image = tf.image.decode_image(image, channels=3, expand_animations=False)
    image = tf.cast(image, tf.float32)
    image = tf.image.resize_with_crop_or_pad(image, *IMAGE_SIZE)
    # image = image / 255.  # optionally scale pixel values to [0, 1]
    return image

def read_tfrecord(example, labeled):
    tfrecord_format = (
        {
            "image/encoded": tf.io.FixedLenFeature([], tf.string),
            "image/class/label": tf.io.FixedLenFeature([], tf.int64),
        }
        if labeled
        else {"image/encoded": tf.io.FixedLenFeature([], tf.string)}
    )
    example = tf.io.parse_single_example(example, tfrecord_format)
    image = decode_image(example["image/encoded"])
    if labeled:
        label = tf.cast(example["image/class/label"], tf.int32)
        # labels in these files start at 1; shift them to start at 0
        return image, label - 1
    return image

def load_dataset(filenames, labeled=True):
    ignore_order = tf.data.Options()
    # disable order, increase speed (experimental_deterministic in TF < 2.6)
    ignore_order.deterministic = False
    # num_parallel_reads interleaves reads from multiple files;
    # without it, the files are read one after another
    dataset = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE)
    # use data as soon as it streams in, rather than in its original order
    dataset = dataset.with_options(ignore_order)
    dataset = dataset.map(
        partial(read_tfrecord, labeled=labeled), num_parallel_calls=AUTOTUNE
    )
    # returns a dataset of (image, label) pairs if labeled=True or just images if labeled=False
    return dataset

def get_dataset(filenames, labeled=True):
    dataset = load_dataset(filenames, labeled=labeled)
    dataset = dataset.shuffle(2048)
    dataset = dataset.batch(BATCH_SIZE)
    # prefetch last, so that whole batches are prepared while the model trains
    dataset = dataset.prefetch(buffer_size=AUTOTUNE)
    return dataset

train_dataset = get_dataset(TRAINING_FILENAMES)
valid_dataset = get_dataset(VALID_FILENAMES)

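The code above is an example of reading ImageNet TFRecords: it splits the validation shards 90/10 into training and validation file lists and builds a shuffled, batched, prefetched tf.data pipeline for each.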
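Because the ImageNet shards are large, one way to sanity-check the parsing logic is to run it on a tiny synthetic file that uses the same two keys ("image/encoded" and "image/class/label") and 1-based labels. The sketch below is a stand-in under stated assumptions: random 32x32 JPEG images, a temporary file, and a simplified copy of read_tfrecord covering only the labeled case. It checks that images come out as 224x224x3 float32 batches and that labels are shifted to start at 0.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

AUTOTUNE = tf.data.AUTOTUNE

# Write 3 synthetic records with the ImageNet-style keys and 1-based labels.
path = os.path.join(tempfile.mkdtemp(), "synthetic-00000")
rng = np.random.default_rng(0)
with tf.io.TFRecordWriter(path) as writer:
    for label in (1, 2, 3):
        pixels = rng.integers(0, 256, size=(32, 32, 3), dtype=np.uint8)
        jpeg = tf.io.encode_jpeg(pixels).numpy()
        example = tf.train.Example(features=tf.train.Features(feature={
            "image/encoded": tf.train.Feature(bytes_list=tf.train.BytesList(value=[jpeg])),
            "image/class/label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
        }))
        writer.write(example.SerializeToString())


def read_tfrecord(serialized):
    """Simplified copy of read_tfrecord (labeled case only)."""
    parsed = tf.io.parse_single_example(serialized, {
        "image/encoded": tf.io.FixedLenFeature([], tf.string),
        "image/class/label": tf.io.FixedLenFeature([], tf.int64),
    })
    image = tf.image.decode_image(parsed["image/encoded"], channels=3, expand_animations=False)
    image = tf.image.resize_with_crop_or_pad(tf.cast(image, tf.float32), 224, 224)
    label = tf.cast(parsed["image/class/label"], tf.int32) - 1
    return image, label


# Order is preserved here (deterministic map), so the labels are predictable.
dataset = tf.data.TFRecordDataset([path]).map(read_tfrecord, num_parallel_calls=AUTOTUNE)
dataset = dataset.batch(2).prefetch(AUTOTUNE)

for images, labels in dataset.take(1):
    print(images.shape, labels.numpy())  # (2, 224, 224, 3) [0 1]
```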

4. How to inspect a TFRecord?

import tensorflow as tf

raw_dataset = tf.data.TFRecordDataset("path-to-file")

# Decode the first record into a tf.train.Example and print it in text format
for raw_record in raw_dataset.take(1):
    example = tf.train.Example()
    example.ParseFromString(raw_record.numpy())
    print(example)
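Printing a whole Example can produce a very long output when it contains encoded images. A sketch that only summarizes the schema of the first record (each key, its list type and its number of values) relies on the protobuf oneof named kind inside tf.train.Feature. The helper name summarize_schema is hypothetical, and the throwaway file written here stands in for "path-to-file".

```python
import os
import tempfile

import tensorflow as tf

# Throwaway file with one record, standing in for "path-to-file".
path = os.path.join(tempfile.mkdtemp(), "inspect.tfrecord")
with tf.io.TFRecordWriter(path) as writer:
    example = tf.train.Example(features=tf.train.Features(feature={
        "image/encoded": tf.train.Feature(bytes_list=tf.train.BytesList(value=[b"\xff\xd8"])),
        "image/class/label": tf.train.Feature(int64_list=tf.train.Int64List(value=[42])),
    }))
    writer.write(example.SerializeToString())


def summarize_schema(filename):
    """Hypothetical helper: {feature key: (list type, number of values)} for the first record."""
    for raw_record in tf.data.TFRecordDataset(filename).take(1):
        example = tf.train.Example.FromString(raw_record.numpy())
        schema = {}
        for key, feature in example.features.feature.items():
            kind = feature.WhichOneof("kind")  # "bytes_list", "float_list", "int64_list" or None
            schema[key] = (kind, len(getattr(feature, kind).value) if kind else 0)
        return schema
    return {}  # the file contains no records


print(summarize_schema(path))
```

The result is a quick way to derive the tf.io.FixedLenFeature spec needed for parsing a file whose layout you do not know.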

References

  1. TFRecords Basics
  2. TFRecord and tf.train.Example
  3. ImageNet to TFRecords
  4. How to train a Keras model on TFRecord files