How to reshape and zero-pad a tf.Tensor
The following code tries to convert a tensor into an (x, y)-shaped tensor in TensorFlow, padding with zeros or truncating as needed.
With this code, "a" can be converted to "b", but converting "c" fails.
Here is the test code:
import tensorflow as tf

def reshape_array(old_array, x, y):
    new_array = tf.reshape(old_array, [-1])
    current_size = tf.size(new_array)
    reshape_size = tf.math.multiply(x, y)
    diff = tf.math.subtract(reshape_size, current_size)
    if tf.greater_equal(diff, tf.constant([0])):
        new_array = tf.pad(new_array, [[0,0],[0, diff]], mode='CONSTANT', constant_values=0)
        new_array = tf.reshape(new_array, (x, y))
    else:
        new_array = tf.slice(new_array, begin=[0], size=[reshape_size])
        new_array = tf.reshape(new_array, (x, y))
    return tf.cast(new_array, old_array.dtype)

a = tf.zeros(256*192*1)
print("a.shape: {}".format(a.shape))
b = reshape_array(a, 28, 28)
print("b.shape: {}".format(b.shape))
c = tf.constant([1, 2, 3, 4, 5, 6])
print("c.shape: {}".format(c.shape))
d = reshape_array(c, 28, 28)
print("d.shape: {}".format(d.shape))
Here is the output:
a.shape: (49152,)
b.shape: (28, 28)
c.shape: (6,)
---------------------------------------------------------------------------
InvalidArgumentError Traceback (most recent call last)
/tmp/ipykernel_7071/4036910860.py in <cell line: 26>()
24 c = tf.constant([1, 2, 3, 4, 5, 6])
25 print("c.shape: {}".format(c.shape))
---> 26 d = reshape_array(c, 28, 28)
27 print("d.shape: {}".format(d.shape))
/tmp/ipykernel_7071/4036910860.py in reshape_array(old_array, x, y)
9 diff = tf.math.subtract(reshape_size, current_size)
10 if tf.greater_equal(diff, tf.constant([0])):
---> 11 new_array = tf.pad(new_array, [[0,0],[0, diff]], mode='CONSTANT', constant_values=0)
12 new_array = tf.reshape(new_array, (x, y))
13 else:
/usr/local/lib/python3.8/site-packages/tensorflow/python/util/traceback_utils.py in error_handler(*args, **kwargs)
151 except Exception as e:
152 filtered_tb = _process_traceback_frames(e.__traceback__)
--> 153 raise e.with_traceback(filtered_tb) from None
154 finally:
155 del filtered_tb
/usr/local/lib/python3.8/site-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
52 try:
53 ctx.ensure_initialized()
---> 54 tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
55 inputs, attrs, num_outputs)
56 except core._NotOkStatusException as e:
InvalidArgumentError: The first dimension of paddings must be the rank of inputs[2,2] [6] [Op:Pad]
What's wrong with my code, and how do I fix it?
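After tf.reshape(old_array, [-1]) you are always working with a 1D tensor, and tf.pad expects paddings of shape [rank, 2], that is, one [before, after] pair per dimension. Your paddings have two pairs, which only fits a 2D tensor. The first example works only because "a" has more than 28*28 elements, so it takes the slice branch and never reaches tf.pad. Use a single pair instead: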
import tensorflow as tf

def reshape_array(old_array, x, y):
    new_array = tf.reshape(old_array, [-1])
    current_size = tf.size(new_array)
    reshape_size = tf.math.multiply(x, y)
    diff = tf.math.subtract(reshape_size, current_size)
    if tf.greater_equal(diff, tf.constant([0])):
        print(diff)  # debug: number of zeros to append
        # One [before, after] pair, because new_array is 1D.
        new_array = tf.pad(new_array, [[0, diff]], mode='CONSTANT', constant_values=0)
        new_array = tf.reshape(new_array, (x, y))
    else:
        new_array = tf.slice(new_array, begin=[0], size=[reshape_size])
        new_array = tf.reshape(new_array, (x, y))
    return tf.cast(new_array, old_array.dtype)

a = tf.zeros(256*192*1)
print("a.shape: {}".format(a.shape))
b = reshape_array(a, 28, 28)
print("b.shape: {}".format(b.shape))
c = tf.constant([1, 2, 3, 4, 5, 6])
print("c.shape: {}".format(c.shape))
d = reshape_array(c, 28, 28)
print("d.shape: {}".format(d.shape))
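The shape rule behind the error message can be sketched without TensorFlow. The snippet below is plain Python and only mirrors the rule that paddings needs one [before, after] pair per input dimension; padded_shape is a hypothetical helper written for illustration, not a TensorFlow function.

```python
def padded_shape(input_shape, paddings):
    """Return the shape a constant pad would produce.

    Hypothetical helper for illustration only; it is not part of TensorFlow.
    It mirrors the rule that paddings must have shape [rank, 2].
    """
    rank = len(input_shape)
    # One [before, after] pair is required for every dimension of the input.
    if len(paddings) != rank or any(len(pair) != 2 for pair in paddings):
        raise ValueError(
            f"paddings must have shape [{rank}, 2], got {len(paddings)} row(s)"
        )
    return [before + dim + after
            for dim, (before, after) in zip(input_shape, paddings)]


target = 28 * 28   # 784 elements wanted
diff = target - 6  # "c" has 6 elements, so 778 zeros are missing

# One pair for a 1D input: accepted.
print(padded_shape([6], [[0, diff]]))

# Two pairs for a 1D input, as in the question: rejected.
try:
    padded_shape([6], [[0, 0], [0, diff]])
except ValueError as err:
    print("rejected:", err)
```

Read this way, the "[2,2] [6]" in the TensorFlow error is the shape of the paddings argument followed by the shape of the input.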
In your case, I would generally prefer using tf.concat for padding:
def reshape_array(old_array, x, y):
    new_array = tf.reshape(old_array, [-1])
    current_size = tf.size(new_array)
    reshape_size = tf.math.multiply(x, y)
    diff = tf.math.subtract(reshape_size, current_size)
    if tf.greater_equal(diff, tf.constant([0])):
        # Cast the zeros so tf.concat also works for non-integer inputs.
        zeros = tf.cast(tf.repeat([0], repeats=diff), new_array.dtype)
        new_array = tf.concat([new_array, zeros], axis=0)
        new_array = tf.reshape(new_array, (x, y))
    else:
        new_array = tf.slice(new_array, begin=[0], size=[reshape_size])
        new_array = tf.reshape(new_array, (x, y))
    return tf.cast(new_array, old_array.dtype)
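To sanity-check either version, the intended result can be written out with plain Python lists. This is only a reference for the expected values, not TensorFlow code; reshape_reference is a hypothetical name, and the sketch assumes the input is already a flat sequence.

```python
def reshape_reference(values, x, y, fill=0):
    """Pad with `fill` or truncate to x*y items, then split into x rows of y.

    Hypothetical reference helper; assumes `values` is a flat sequence.
    """
    flat = list(values)
    target = x * y
    # Truncate if too long, then append however many fill values are missing.
    flat = flat[:target] + [fill] * max(target - len(flat), 0)
    return [flat[row * y:(row + 1) * y] for row in range(x)]


d = reshape_reference([1, 2, 3, 4, 5, 6], 28, 28)
print(len(d), len(d[0]))  # number of rows, number of columns
print(d[0][:8])           # the six original values followed by zeros
```

The first row of the tensor returned by reshape_array(c, 28, 28) should start with the same values, and every remaining entry should be zero.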