簡體   English   中英

如何將 Float 數組/列表轉換為 TFRecord?

[英]How to convert Float array/list to TFRecord?

這是用於將數據轉換為 TFRecord 的代碼

def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

 def _bytes_feature(value):
   return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _floats_feature(value):
   return tf.train.Feature(float_list=tf.train.FloatList(value=value))

with tf.python_io.TFRecordWriter("train.tfrecords") as writer:
    for row in train_data:
        prices, label, pip = row[0],row[1],row[2]
        prices = np.asarray(prices).astype(np.float32)
        example = tf.train.Example(features=tf.train.Features(feature={
                                           'prices': _floats_feature(prices),
                                           'label': _int64_feature(label[0]),
                                           'pip': _floats_feature(pip)
    }))
        writer.write(example.SerializeToString())

特征價格是一個形狀數組(1,288)。 轉換成功了! 但是當使用解析函數和數據集 API 解碼數據時。

def parse_func(serialized_data):
    keys_to_features = {'prices': tf.FixedLenFeature([], tf.float32),
                    'label': tf.FixedLenFeature([], tf.int64)}

    parsed_features = tf.parse_single_example(serialized_data, keys_to_features)
    return parsed_features['prices'],tf.one_hot(parsed_features['label'],2)

它給了我錯誤

C:\\tf_jenkins\\workspace\\rel-win\\M\\windows-gpu\\PY\\36\\tensorflow\\core\\framework\\op_kernel.cc:1202] OP_REQUIRES 在 example_parsing_ops.cc:240 失敗:參數無效:鍵:價格。 無法解析序列化示例。 2018-03-31 15:37:11.443073: WC:\\tf_jenkins\\workspace\\rel-win\\M\\windows-gpu\\PY\\36\\tensorflow\\core\\framework\\op_kernel.cc:1202] OP_REQUIRES 在 example_parsing_op 失敗:240 : 無效參數:鍵:價格。 無法解析序列化示例。 2018-03-31 15:37:11.443313: WC:\\tf_jenkins\\workspace\\rel-win\\M\\windows-gpu\\ raise type(e)(node_def, op, message) PY\\36\\tensortensorflow.python.framework。 errors_impl.InvalidArgumentError: Key: 價格。 無法解析序列化示例。 [[節點:ParseSingleExample/ParseSingleExample = ParseSingleExample[Tdense=[DT_INT64, DT_FLOAT],dense_keys=["label", "prices"],dense_shapes=[[], []], num_sparse=0, sparse_keys=[types], =[]](arg0, ParseSingleExample/Const, ParseSingleExample/Const_1)]] [[節點:IteratorGetNext_1 = IteratorGetNextoutput_shapes=[[?], [?,2]], output_types=[DT_FLOAT, DT_FLOAT], _device="/job :localhost/replica:0/task:0/device:CPU:0"]]fflow\\core\\framework\\op_kernel.cc:1202] OP_REQUIRES 在 example_parsing_ops.cc:240 失敗:無效參數:關鍵:價格。 無法解析序列化示例。

我發現了問題。 不使用tf.io.FixedLenFeature來解析數組,而是使用tf.io.FixedLenSequenceFeature
(對於 TensorFlow 1,使用tf.而不是tf.io.

如果您的功能是固定的一維數組,則使用 tf.FixedLenSequenceFeature 根本不正確。 正如文檔中提到的, tf.FixedLenSequenceFeature 用於維度 2 及更高維度的輸入數據。 在此示例中,您需要將價格數組展平為 (288,),然后對於解碼部分,您需要提及數組維度。

編碼:

example = tf.train.Example(features=tf.train.Features(feature={
                                       'prices': _floats_feature(prices.tolist()),
                                       'label': _int64_feature(label[0]),
                                       'pip': _floats_feature(pip)

解碼:

keys_to_features = {'prices': tf.FixedLenFeature([288], tf.float32),
                'label': tf.FixedLenFeature([], tf.int64)}

您不能將 n 維數組存儲為浮點特征,因為浮點特征是簡單的列表。 您必須通過執行prices.tolist()prices壓平到列表中。 如果您需要從展平的浮點特征中恢復 n 維數組,那么您可以執行prices = np.reshape(float_feature, original_shape)

我在不小心修改一些腳本時遇到了同樣的問題,這是由數據形狀略有不同引起的。 我不得不改變形狀以匹配預期的形狀,例如(A, B)(1, A, B) 我使用np.ravel()進行展平。

TFrecord文件中讀取float32數據列表也發生了同樣的事情。

我得到無法解析執行時連載例sess.run([time_tensor, frequency_tensor, frequency_weight_tensor])tf.FixedLenFeature ,雖然tf.FixedLenSequenceFeature似乎是工作的罰款。

我讀取文件的特征格式(工作的)如下: feature_format = { 'time': tf.FixedLenSequenceFeature([], tf.float32, allow_missing = True), 'frequencies': tf.FixedLenSequenceFeature([], tf.float32, allow_missing = True), 'frequency_weights': tf.FixedLenSequenceFeature([], tf.float32, allow_missing = True) }

編碼部分是:

feature = { 'time': tf.train.Feature(float_list=tf.train.FloatList(value=[*some single value*]) ), 'frequencies': tf.train.Feature(float_list=tf.train.FloatList(value=*some_list*) ), 'frequency_weights': tf.train.Feature(float_list=tf.train.FloatList(value=*some_list*) ) }

這發生在 Debian 機器上的 TensorFlow 1.12 上,沒有 GPU 卸載(即只有 CPU 與 TensorFlow 一起使用)

我這邊有沒有誤用? 還是代碼或文檔中的錯誤? 如果對任何人都有好處,我可以考慮貢獻/上傳任何修復程序......

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM