[英]Tensorflow 2: shape mismatch when serialize and decode it back
我有一個形狀為 (300,256,256) 的張量 A。 我想序列化 A 以保存為 tfrecord 格式。 但我無法將其轉換回具有相同形狀的張量。
A = tf.convert_to_tensor( *a numpy array with float32 type* )
B = tf.io.serialize_tensor(A)
C = tf.reshape(tf.io.decode_raw(B, out_type=tf.float32),[300,256,256])
如果我運行上面的代碼,我會得到一個形狀錯誤:
tensorflow.python.framework.errors_impl.InvalidArgumentError:reshape 的輸入是一個具有 19660806 值的張量,但請求的形狀有 19660800 [Op:Reshape]
似乎當我序列化或解碼時,添加了 6 個浮點數。 (很奇怪)
嘗試使用: tf.io.parse_tensor()
,而不是tf.io.decode_raw()
。
https://www.tensorflow.org/api_docs/python/tf/io/parse_tensor
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.