[英]How to convert Float array/list to TFRecord?
这是用于将数据转换为 TFRecord 的代码
def _int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
def _bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def _floats_feature(value):
return tf.train.Feature(float_list=tf.train.FloatList(value=value))
with tf.python_io.TFRecordWriter("train.tfrecords") as writer:
for row in train_data:
prices, label, pip = row[0],row[1],row[2]
prices = np.asarray(prices).astype(np.float32)
example = tf.train.Example(features=tf.train.Features(feature={
'prices': _floats_feature(prices),
'label': _int64_feature(label[0]),
'pip': _floats_feature(pip)
}))
writer.write(example.SerializeToString())
特征价格是一个形状数组(1,288)。 转换成功了! 但是当使用解析函数和数据集 API 解码数据时。
def parse_func(serialized_data):
keys_to_features = {'prices': tf.FixedLenFeature([], tf.float32),
'label': tf.FixedLenFeature([], tf.int64)}
parsed_features = tf.parse_single_example(serialized_data, keys_to_features)
return parsed_features['prices'],tf.one_hot(parsed_features['label'],2)
它给了我错误
C:\\tf_jenkins\\workspace\\rel-win\\M\\windows-gpu\\PY\\36\\tensorflow\\core\\framework\\op_kernel.cc:1202] OP_REQUIRES 在 example_parsing_ops.cc:240 失败:参数无效:键:价格。 无法解析序列化示例。 2018-03-31 15:37:11.443073: WC:\\tf_jenkins\\workspace\\rel-win\\M\\windows-gpu\\PY\\36\\tensorflow\\core\\framework\\op_kernel.cc:1202] OP_REQUIRES 在 example_parsing_op 失败:240 : 无效参数:键:价格。 无法解析序列化示例。 2018-03-31 15:37:11.443313: WC:\\tf_jenkins\\workspace\\rel-win\\M\\windows-gpu\\ raise type(e)(node_def, op, message) PY\\36\\tensortensorflow.python.framework。 errors_impl.InvalidArgumentError: Key: 价格。 无法解析序列化示例。 [[节点:ParseSingleExample/ParseSingleExample = ParseSingleExample[Tdense=[DT_INT64, DT_FLOAT],dense_keys=["label", "prices"],dense_shapes=[[], []], num_sparse=0, sparse_keys=[types], =[]](arg0, ParseSingleExample/Const, ParseSingleExample/Const_1)]] [[节点:IteratorGetNext_1 = IteratorGetNextoutput_shapes=[[?], [?,2]], output_types=[DT_FLOAT, DT_FLOAT], _device="/job :localhost/replica:0/task:0/device:CPU:0"]]fflow\\core\\framework\\op_kernel.cc:1202] OP_REQUIRES 在 example_parsing_ops.cc:240 失败:无效参数:关键:价格。 无法解析序列化示例。
我发现了问题。 不使用tf.io.FixedLenFeature
来解析数组,而是使用tf.io.FixedLenSequenceFeature
(对于 TensorFlow 1,使用tf.
而不是tf.io.
)
如果您的功能是固定的一维数组,则使用 tf.FixedLenSequenceFeature 根本不正确。 正如文档中提到的, tf.FixedLenSequenceFeature 用于维度 2 及更高维度的输入数据。 在此示例中,您需要将价格数组展平为 (288,),然后对于解码部分,您需要提及数组维度。
编码:
example = tf.train.Example(features=tf.train.Features(feature={
'prices': _floats_feature(prices.tolist()),
'label': _int64_feature(label[0]),
'pip': _floats_feature(pip)
解码:
keys_to_features = {'prices': tf.FixedLenFeature([288], tf.float32),
'label': tf.FixedLenFeature([], tf.int64)}
您不能将 n 维数组存储为浮点特征,因为浮点特征是简单的列表。 您必须通过执行prices.tolist()
将prices
压平到列表中。 如果您需要从展平的浮点特征中恢复 n 维数组,那么您可以执行prices = np.reshape(float_feature, original_shape)
。
我在不小心修改一些脚本时遇到了同样的问题,这是由数据形状略有不同引起的。 我不得不改变形状以匹配预期的形状,例如(A, B)
到(1, A, B)
。 我使用np.ravel()
进行展平。
从TFrecord
文件中读取float32
数据列表也发生了同样的事情。
我得到无法解析执行时连载例sess.run([time_tensor, frequency_tensor, frequency_weight_tensor])
与tf.FixedLenFeature
,虽然tf.FixedLenSequenceFeature
似乎是工作的罚款。
我读取文件的特征格式(工作的)如下: feature_format = { 'time': tf.FixedLenSequenceFeature([], tf.float32, allow_missing = True), 'frequencies': tf.FixedLenSequenceFeature([], tf.float32, allow_missing = True), 'frequency_weights': tf.FixedLenSequenceFeature([], tf.float32, allow_missing = True) }
编码部分是:
feature = { 'time': tf.train.Feature(float_list=tf.train.FloatList(value=[*some single value*]) ), 'frequencies': tf.train.Feature(float_list=tf.train.FloatList(value=*some_list*) ), 'frequency_weights': tf.train.Feature(float_list=tf.train.FloatList(value=*some_list*) ) }
这发生在 Debian 机器上的 TensorFlow 1.12 上,没有 GPU 卸载(即只有 CPU 与 TensorFlow 一起使用)
我这边有没有误用? 还是代码或文档中的错误? 如果对任何人都有好处,我可以考虑贡献/上传任何修复程序......
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.