TFRecord
1. What’s TFRecord?
A TFRecord is a binary file that contains a sequence of byte-strings. Data needs to be serialized (encoded as a byte-string) before being written into a TFRecord.
The most convenient way of serializing data in TensorFlow is to wrap it with tf.Example, a record format based on Google's protocol buffers but designed for TensorFlow. It is more or less like a dict with some type annotations.
Protocol buffers are a cross-platform, cross-language library for efficient serialization of structured data.
We first create a tf.train.Example, then serialize it and write it to a .tfrecords file.
Benefits:
- The data can be split into multiple files, which lets us parallelize reading and processing.
Recommendation:
TensorFlow officially recommends that if you have N hosts, you split the data into 10 * N files, and that each file be at least 10 MB in size.
tf.train.Example is a kind of dictionary of the form:
{"string": tf.train.Feature}
A tf.train.Feature message can take one of three forms:
- tf.train.BytesList – string, byte.
- tf.train.FloatList – float (float32), double (float64).
- tf.train.Int64List – bool, enum, int32, uint32, int64, uint64.
The various data types can be converted to tf.train.Feature as follows:
# The following functions can be used to convert a value to a type compatible
# with tf.train.Example.
def _bytes_feature(value):
    """Returns a bytes_list from a string / byte."""
    if isinstance(value, type(tf.constant(0))):
        value = value.numpy()  # BytesList won't unpack a string from an EagerTensor.
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _float_feature(value):
    """Returns a float_list from a float / double."""
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

def _int64_feature(value):
    """Returns an int64_list from a bool / enum / int / uint."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
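As a quick check, the helpers above wrap plain Python values into tf.train.Feature messages; the populated list can be read back off the proto (a minimal sketch, assuming TensorFlow 2.x eager mode):

```python
import tensorflow as tf

def _bytes_feature(value):
    """Returns a bytes_list from a string / byte."""
    if isinstance(value, type(tf.constant(0))):
        value = value.numpy()  # BytesList won't unpack a string from an EagerTensor.
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _int64_feature(value):
    """Returns an int64_list from a bool / enum / int / uint."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

# Each helper returns a tf.train.Feature proto holding a single value.
f = _int64_feature(7)
print(list(f.int64_list.value))   # [7]

b = _bytes_feature(b"hello")
print(list(b.bytes_list.value))   # [b'hello']
```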
The helpers above handle scalar inputs. For non-scalar features, the simplest approach is tf.io.serialize_tensor, which converts a tensor into a binary string. tf.io.parse_tensor converts the binary string back into a tensor.
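For example, a non-scalar tensor can round-trip through a byte string like this (a minimal sketch):

```python
import tensorflow as tf

t = tf.constant([[1.0, 2.0], [3.0, 4.0]])

# Serialize the tensor to a scalar byte-string tensor.
serialized = tf.io.serialize_tensor(t)

# The resulting byte string can be stored in a BytesList feature.
feature = tf.train.Feature(
    bytes_list=tf.train.BytesList(value=[serialized.numpy()])
)

# Parse the byte string back into a tensor; the dtype must be supplied.
restored = tf.io.parse_tensor(serialized, out_type=tf.float32)
print(restored.shape)  # (2, 2)
```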
2. How to create a TFRecord?
To create a TFRecord dataset, the whole process might look something like:
- Build a dataset with tf.data.Dataset. You could use the from_generator or from_tensor_slices methods.
- Serialize the dataset by iterating over it with make_example.
- Write the dataset to TFRecords with io.TFRecordWriter or data.TFRecordWriter.
# image_feature, label_feature and class_name_feature are tf.train.Feature
# values built with the helper functions above.
features = tf.train.Features(feature={
    'image': image_feature,
    'label': label_feature,
    'class_name': class_name_feature,
})
# Wrap with Example and serialize
example = tf.train.Example(features=features)
example_bytes = example.SerializeToString()
Example: converting ImageNet to TFRecord.
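Putting the pieces together, writing records to a TFRecord file might look like the sketch below. The data is made up (two blank JPEGs) and the filename demo.tfrecords is illustrative; the feature keys mirror those used by the reading code in the next section:

```python
import tensorflow as tf

def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def make_example(image_bytes, label):
    """Wrap one encoded image and its label into a tf.train.Example."""
    features = tf.train.Features(feature={
        "image/encoded": _bytes_feature(image_bytes),
        "image/class/label": _int64_feature(label),
    })
    return tf.train.Example(features=features)

# Write two fake records; real code would use JPEG/PNG bytes read from disk.
with tf.io.TFRecordWriter("demo.tfrecords") as writer:
    for label in (1, 2):
        image_bytes = tf.io.encode_jpeg(
            tf.zeros([8, 8, 3], dtype=tf.uint8)).numpy()
        example = make_example(image_bytes, label)
        writer.write(example.SerializeToString())
```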
3. How to load a TFRecord?
import tensorflow as tf
from functools import partial

FILENAMES = tf.io.gfile.glob("./validation_tf_records/validation-*")
split_ind = int(0.9 * len(FILENAMES))
TRAINING_FILENAMES, VALID_FILENAMES = FILENAMES[:split_ind], FILENAMES[split_ind:]
BATCH_SIZE = 32
IMAGE_SIZE = (224, 224)
AUTOTUNE = tf.data.experimental.AUTOTUNE
def decode_image(image):
    # expand_animations=False guarantees a rank-3 tensor, which
    # resize_with_crop_or_pad requires.
    image = tf.image.decode_image(image, channels=3, expand_animations=False)
    image = tf.cast(image, tf.float32)
    image = tf.image.resize_with_crop_or_pad(image, 224, 224)
    # image = image / 255.
    return image
def read_tfrecord(example, labeled):
    tfrecord_format = (
        {
            "image/encoded": tf.io.FixedLenFeature([], tf.string),
            "image/class/label": tf.io.FixedLenFeature([], tf.int64),
        }
        if labeled
        else {"image/encoded": tf.io.FixedLenFeature([], tf.string)}
    )
    example = tf.io.parse_single_example(example, tfrecord_format)
    image = decode_image(example["image/encoded"])
    if labeled:
        label = tf.cast(example["image/class/label"], tf.int32)
        return image, label - 1  # ImageNet labels are 1-based; shift to 0-based
    return image
def load_dataset(filenames, labeled=True):
    ignore_order = tf.data.Options()
    ignore_order.experimental_deterministic = False  # disable order, increase speed
    # Automatically interleaves reads from multiple files.
    dataset = tf.data.TFRecordDataset(filenames)
    # Use data as soon as it streams in, rather than in its original order.
    dataset = dataset.with_options(ignore_order)
    dataset = dataset.map(
        partial(read_tfrecord, labeled=labeled), num_parallel_calls=AUTOTUNE
    )
    # Returns a dataset of (image, label) pairs if labeled=True, or just images otherwise.
    return dataset
def get_dataset(filenames, labeled=True):
    dataset = load_dataset(filenames, labeled=labeled)
    dataset = dataset.shuffle(2048)
    dataset = dataset.batch(BATCH_SIZE)
    dataset = dataset.prefetch(buffer_size=AUTOTUNE)  # prefetch last, after batching
    return dataset
train_dataset = get_dataset(TRAINING_FILENAMES)
valid_dataset = get_dataset(VALID_FILENAMES)
The code above is an example of reading ImageNet TFRecords.
4. Inspect TFRecord
import tensorflow as tf
raw_dataset = tf.data.TFRecordDataset("path-to-file")
for raw_record in raw_dataset.take(1):
    example = tf.train.Example()
    example.ParseFromString(raw_record.numpy())
    print(example)
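Beyond printing the whole proto, individual features can be read off it directly. A sketch using an Example built in memory (a record parsed from a file with ParseFromString, as above, yields the same kind of proto; the key name follows this article's ImageNet records and should be adjusted to whatever keys your file actually contains):

```python
import tensorflow as tf

# Build an Example in memory to stand in for one parsed from a TFRecord file.
example = tf.train.Example(features=tf.train.Features(feature={
    "image/class/label": tf.train.Feature(
        int64_list=tf.train.Int64List(value=[5])),
}))

feature_map = example.features.feature  # dict-like: name -> tf.train.Feature
# A Feature is a oneof; read the list that is actually populated.
label = feature_map["image/class/label"].int64_list.value[0]
print(label)  # 5
```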