簡體   English   中英

Tensorflow 2:序列化和解碼時形狀不匹配

[英]Tensorflow 2: shape mismatch when serialize and decode it back

我有一個形狀為 (300,256,256) 的張量 A。 我想序列化 A 以保存為 tfrecord 格式。 但我無法將其轉換回具有相同形狀的張量。

A = tf.convert_to_tensor( *a numpy array with float32 type* )
B = tf.io.serialize_tensor(A)
C = tf.reshape(tf.io.decode_raw(B, out_type=tf.float32),[300,256,256])

如果我運行上面的代碼,我會得到一個形狀錯誤:

tensorflow.python.framework.errors_impl.InvalidArgumentError:reshape 的輸入是一個具有 19660806 值的張量,但請求的形狀有 19660800 [Op:Reshape]

似乎當我序列化或解碼時,添加了 6 個浮點數。 (很奇怪)

嘗試使用: tf.io.parse_tensor() ,而不是tf.io.decode_raw()

https://www.tensorflow.org/api_docs/python/tf/io/parse_tensor

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM