TFRecord
1. What is a TFRecord?
A TFRecord is a binary file that contains sequences of byte-strings. Data needs to be serialized (encoded as a byte-string) before being written into a TFRecord.
The most convenient way of serializing data in TensorFlow is to wrap it in a tf.Example. This is a record format based on Google's protocol buffers but designed for TensorFlow. It is more or less a dict with some type annotations.
Protocol buffers are a cross-platform, cross-language library for efficient serialization of structured data.
We first create a tf.train.Example, then serialize it and write it to a .tfrecords file.
Benefits:
- The data can be split across multiple files, which lets reads and preprocessing run in parallel.

Recommendation:
TensorFlow officially recommends splitting the data into 10 * N files for N hosts, with each file ideally larger than 10 MB.
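A minimal sketch of such sharding, assuming one host (so 10 shards) and using made-up byte strings in place of real serialized examples:

```python
import os
import tempfile

import tensorflow as tf

# Hypothetical payloads: in a real pipeline these would be serialized
# tf.train.Example protos.
serialized_examples = [f"record-{i}".encode() for i in range(100)]

NUM_SHARDS = 10  # the 10 * N rule of thumb, with N = 1 host here
out_dir = tempfile.mkdtemp()
shard_paths = []
for shard in range(NUM_SHARDS):
    path = os.path.join(out_dir, f"data-{shard:05d}-of-{NUM_SHARDS:05d}.tfrecord")
    shard_paths.append(path)
    with tf.io.TFRecordWriter(path) as writer:
        # Distribute records across shards round-robin.
        for record in serialized_examples[shard::NUM_SHARDS]:
            writer.write(record)
```

tf.data.TFRecordDataset accepts the whole list of shard paths and can interleave reads across them.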
tf.train.Example is a kind of dictionary of the form:
{"string": tf.train.Feature}
A tf.train.Feature message can take one of three forms:
- tf.train.BytesList – string, byte.
- tf.train.FloatList – float (float32), double (float64).
- tf.train.Int64List – bool, enum, int32, uint32, int64, uint64.
Various kinds of data can be converted to tf.train.Feature as follows:
```python
# The following functions can be used to convert a value to a type compatible
# with tf.train.Example.

def _bytes_feature(value):
    """Returns a bytes_list from a string / byte."""
    if isinstance(value, type(tf.constant(0))):
        value = value.numpy()  # BytesList won't unpack a string from an EagerTensor.
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _float_feature(value):
    """Returns a float_list from a float / double."""
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

def _int64_feature(value):
    """Returns an int64_list from a bool / enum / int / uint."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
```
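A quick check of what such a helper produces (the `_int64_feature` definition is repeated here so the snippet is self-contained): every tf.train.Feature is a protobuf message, so it can be serialized to bytes and parsed back.

```python
import tensorflow as tf

def _int64_feature(value):
    """Returns an int64_list from a bool / enum / int / uint."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

feature = _int64_feature(7)
# Round-trip through the protobuf wire format.
serialized = feature.SerializeToString()
restored = tf.train.Feature.FromString(serialized)
```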
The functions above handle scalar inputs. For non-scalar features, the simplest approach is tf.io.serialize_tensor, which converts a tensor into a binary string; tf.io.parse_tensor converts the binary string back into a tensor.
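A round-trip sketch of the non-scalar case: the serialized result is a scalar tf.string tensor, so its value can go into a BytesList like any other byte string.

```python
import tensorflow as tf

t = tf.constant([[1.0, 2.0], [3.0, 4.0]])

# Tensor -> binary string (a scalar tf.string tensor).
serialized = tf.io.serialize_tensor(t)

# Binary string -> tensor; the dtype must be supplied when parsing.
restored = tf.io.parse_tensor(serialized, out_type=tf.float32)
```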
2. How to create a TFRecord?
To create a TFRecord dataset, the whole process might look something like:
- Build a dataset with tf.data.Dataset. You could use the from_generator or from_tensor_slices methods.
- Serialize the dataset by iterating over it with make_example.
- Write the dataset to TFRecords with io.TFRecordWriter or data.TFRecordWriter.
```python
# Features and Example here are tf.train.Features and tf.train.Example.
features = Features(feature={
    'image': image_feature,
    'label': label_feature,
    'class_name': class_name_feature,
})
# Wrap with Example
example = Example(features=features)
example_bytes = example.SerializeToString()
```
Example: converting ImageNet images to TFRecord files.
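The three steps above can be sketched end to end. This is a minimal sketch using a tiny synthetic in-memory dataset instead of ImageNet; the feature keys ("image", "label"), the filename, and the `make_example` signature are made up for illustration:

```python
import os
import tempfile

import tensorflow as tf

def make_example(image_bytes, label):
    """Step 2: wrap raw values into a serialized tf.train.Example."""
    features = tf.train.Features(feature={
        "image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_bytes])),
        "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
    })
    return tf.train.Example(features=features).SerializeToString()

# Step 1: build a toy dataset of (image_bytes, label) pairs.
samples = [(b"\x00" * 16, 0), (b"\x01" * 16, 1)]
dataset = tf.data.Dataset.from_tensor_slices(
    ([s[0] for s in samples], [s[1] for s in samples])
)

# Step 3: write the serialized examples to a .tfrecords file.
path = os.path.join(tempfile.mkdtemp(), "toy.tfrecords")
with tf.io.TFRecordWriter(path) as writer:
    for image_bytes, label in dataset.as_numpy_iterator():
        writer.write(make_example(image_bytes, int(label)))
```

For real images, `image_bytes` would typically be the JPEG-encoded file contents rather than raw pixels.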
3. How to load a TFRecord?
```python
from functools import partial

import tensorflow as tf

FILENAMES = tf.io.gfile.glob("./validation_tf_records/validation-*")
split_ind = int(0.9 * len(FILENAMES))
TRAINING_FILENAMES, VALID_FILENAMES = FILENAMES[:split_ind], FILENAMES[split_ind:]

BATCH_SIZE = 32
IMAGE_SIZE = (224, 224)
AUTOTUNE = tf.data.experimental.AUTOTUNE

def decode_image(image):
    image = tf.image.decode_image(image, channels=3)
    image = tf.cast(image, tf.float32)
    image = tf.image.resize_with_crop_or_pad(image, 224, 224)
    # image = image / 255.
    return image

def read_tfrecord(example, labeled):
    tfrecord_format = (
        {
            "image/encoded": tf.io.FixedLenFeature([], tf.string),
            "image/class/label": tf.io.FixedLenFeature([], tf.int64),
        }
        if labeled
        else {"image/encoded": tf.io.FixedLenFeature([], tf.string)}
    )
    example = tf.io.parse_single_example(example, tfrecord_format)
    image = decode_image(example["image/encoded"])
    if labeled:
        label = tf.cast(example["image/class/label"], tf.int32)
        return image, label - 1
    return image

def load_dataset(filenames, labeled=True):
    ignore_order = tf.data.Options()
    ignore_order.experimental_deterministic = False  # disable order, increase speed
    dataset = tf.data.TFRecordDataset(
        filenames
    )  # automatically interleaves reads from multiple files
    dataset = dataset.with_options(
        ignore_order
    )  # uses data as soon as it streams in, rather than in its original order
    dataset = dataset.map(
        partial(read_tfrecord, labeled=labeled), num_parallel_calls=AUTOTUNE
    )  # returns a dataset of (image, label) pairs if labeled=True, or just images if labeled=False
    return dataset

def get_dataset(filenames, labeled=True):
    dataset = load_dataset(filenames, labeled=labeled)
    dataset = dataset.shuffle(2048)
    dataset = dataset.prefetch(buffer_size=AUTOTUNE)
    dataset = dataset.batch(BATCH_SIZE)
    return dataset

train_dataset = get_dataset(TRAINING_FILENAMES)
valid_dataset = get_dataset(VALID_FILENAMES)
```
The code above is an example of reading ImageNet TFRecords.
4. Inspect TFRecord
```python
import tensorflow as tf

raw_dataset = tf.data.TFRecordDataset("path-to-file")
for raw_record in raw_dataset.take(1):
    example = tf.train.Example()
    example.ParseFromString(raw_record.numpy())
    print(example)
```
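Beyond printing the whole proto, individual fields can be read off the parsed message. A self-contained sketch (the record is built in memory so no file is needed, and the feature key "label" is made up):

```python
import tensorflow as tf

# Build a record in memory in place of a raw record read from disk.
example = tf.train.Example(features=tf.train.Features(feature={
    "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[3])),
}))
raw_bytes = example.SerializeToString()

# Parse it back, as in the inspection loop above, and pull out one field.
parsed = tf.train.Example()
parsed.ParseFromString(raw_bytes)
label = parsed.features.feature["label"].int64_list.value[0]
```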