返回

TFRecord 文件格式:深入浅出

人工智能

TensorFlow Record (TFRecord) 文件格式是 TensorFlow 数据集 API 推荐的二进制文件格式,用于存储和管理机器学习数据。它提供了高效和可扩展的方式来处理大型和复杂的训练数据集。

理解 TFRecord 文件

TFRecord 文件本质上是一个包含多个“示例”的序列化的协议缓冲区(Protobuf)文件。每个示例都是一个包含一组键值对的二进制记录。键是字符串,而值可以是各种类型,例如整数、浮点数或字符串。

TFRecord 文件的结构如下:

<header>
<example_length>
<example_data>
<example_length>
<example_data>
...

其中:

  • :包含有关 TFRecord 文件本身的信息,例如版本和文件模式。
  • <example_length> :一个 4 字节整数,指定随后 <example_data> 块的大小。
  • <example_data> :示例的序列化 Protobuf 数据。

TFRecord 文件的优势

使用 TFRecord 文件格式存储数据集具有以下优势:

  • 高效: TFRecord 文件是二进制格式,因此比其他格式(如 CSV 或 JSON)更紧凑和更高效。
  • 可扩展: TFRecord 文件可以处理大型数据集,而不会遇到性能问题。
  • 灵活: TFRecord 文件支持各种数据类型,允许存储复杂和多样化的数据集。
  • 可压缩: TFRecord 文件可以压缩,以进一步减少存储空间需求。
  • 标准化: TFRecord 是 TensorFlow 数据集 API 推荐的格式,确保了与 TensorFlow 生态系统的广泛兼容性。

创建和读取 TFRecord 文件

TensorFlow 提供了用于创建和读取 TFRecord 文件的实用程序。

创建 TFRecord 文件:

import tensorflow as tf

# 创建一个 TFRecordWriter 对象
writer = tf.io.TFRecordWriter("my_data.tfrecord")

# 创建一个示例
example = tf.train.Example(features=tf.train.Features(feature={
    "image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_data])),
    "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))
}))

# 将示例序列化并写入 TFRecord 文件
writer.write(example.SerializeToString())

# 关闭 TFRecordWriter 对象
writer.close()

读取 TFRecord 文件:

import tensorflow as tf

# 创建一个 TFRecordReader 对象
reader = tf.io.TFRecordReader()

# 从文件读取一个示例
example = reader.read("my_data.tfrecord")

# 解析示例
features = tf.io.parse_single_example(example, features={
    "image": tf.io.FixedLenFeature([], tf.string),
    "label": tf.io.FixedLenFeature([], tf.int64)
})

# 提取数据
image_data = features["image"].numpy()
label = features["label"].numpy()