有很多关于如何创建和使用 TensorFlow 数据集的示例,例如 我的问题是如何以 numpy 形式从 TF 数据集中取回数据/标签? 换句话说,想要的是上面那一行的反向操作,即我有一个 TF 数据集,想从中取回图像和标签。 ...
提示:本站收集StackOverFlow近2千万问答,支持中英文搜索,鼠标放在语句上弹窗显示对应的参考中文或英文, 本站还提供 中文繁体 英文版本 中英对照 版本,有任何建议请联系yoyou2525@163.com。
我有一个序列化的TensorFlow示例协议缓冲区的TFRecord文件数据集,每个注释带有一个示例原型,可从https://magenta.tensorflow.org/datasets/nsynth下载。 我正在使用大约1 Gb的测试仪,以防有人要下载它,请检查下面的代码。 每个示例都包含许多功能:音高,乐器...
读取此数据的代码是:
import tensorflow as tf
import numpy as np
sess = tf.InteractiveSession()
# Reading input data
dataset = tf.data.TFRecordDataset('../data/nsynth-test.tfrecord')
# Convert features into tensors
features = {
"pitch": tf.FixedLenFeature([1], dtype=tf.int64),
"audio": tf.FixedLenFeature([64000], dtype=tf.float32),
"instrument_family": tf.FixedLenFeature([1], dtype=tf.int64)}
parse_function = lambda example_proto: tf.parse_single_example(example_proto,features)
dataset = dataset.map(parse_function)
# Consuming TFRecord data.
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.batch(batch_size=3)
dataset = dataset.repeat()
iterator = dataset.make_one_shot_iterator()
batch = iterator.get_next()
sess.run(batch)
现在,音高的范围是21到108。但是,我只考虑给定音高的数据,例如,pitch =51。如何从整个数据集中提取“ pitch = 51”子集? 或者,如何使我的迭代器仅遍历此子集?
您所拥有的看起来不错,所缺少的只是一个过滤器功能。
例如,如果您只想提取pitch = 51,则应在地图函数之后添加
dataset = dataset.filter(lambda example: tf.equal(example["pitch"][0], 51))
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.