I am trying to execute the following code:
```python
def f(x):
    (_, H, W, C) = tf.shape(x)
    x_reshaped = tf.reshape(x, (-1, C))
    res = x_reshaped / (H * W * C)
    return res
```
But the problem, obviously, is that I don't know H and W in advance, so they show up as `?, ?`, and the reshape and the division don't work. How do I correctly do the above computation so that `res` is a proper TensorFlow node that can be evaluated later in a Session?
The following should work (TensorFlow 1.x graph mode, using `tf.placeholder` and `tf.Session`):
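```python
import numpy as np
import tensorflow as tf  # TensorFlow 1.x API

X = tf.placeholder(tf.float32, shape=[None, None, None, 40])

def f(x):
    s = tf.shape(x)  # dynamic shape: an int32 tensor whose values are known only at run time
    x_reshaped = tf.reshape(x, [-1, s[3]])
    # divide by H*W*C (s[1]*s[2]*s[3]), cast to float32 to match x
    res = tf.divide(x_reshaped, tf.cast(s[1] * s[2] * s[3], tf.float32))
    return res

out = f(X)
with tf.Session() as sess:
    result = sess.run(out, {X: np.random.normal(size=(10, 20, 30, 40))})
    print(result.shape)  # (6000, 40)
```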
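As a sanity check, here is a NumPy sketch of the computation the question describes: flatten to `(-1, C)`, then divide by `H*W*C`. `f_reference` is a hypothetical helper name, not part of either answer. Because NumPy shapes are plain Python integers, the `(_, H, W, C)` unpacking from the question works directly here.

```python
import numpy as np

def f_reference(x):
    """NumPy reference: reshape (N, H, W, C) -> (N*H*W, C), then divide by H*W*C."""
    _, h, w, c = x.shape  # concrete integers, so tuple unpacking is fine
    return x.reshape(-1, c) / (h * w * c)

x = np.random.normal(size=(10, 20, 30, 40)).astype(np.float32)
res = f_reference(x)
print(res.shape)  # (6000, 40)
```

If the graph divides by `s[1]*s[2]*s[3]`, feeding the same `x` should give `np.allclose(sess.run(out, {X: x}), f_reference(x))`.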
I'm assuming you want `x_reshaped` to have shape `(batch_size, H*W*C)`, which means that every row is one flattened image. In that case the correct code would be:
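```python
s = tf.shape(x)  # H and W are only known at run time, so read them from the dynamic shape
x_reshaped = tf.reshape(x, (-1, s[1] * s[2] * s[3]))  # (batch_size, H*W*C)
```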
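The two answers reshape along different axes. A small NumPy sketch (zeros used only as placeholder data) shows the difference between one row per pixel, as in the first answer, and one row per image, as assumed here:

```python
import numpy as np

x = np.zeros((10, 20, 30, 40), dtype=np.float32)  # (batch, H, W, C)

per_pixel = x.reshape(-1, x.shape[3])   # one row per pixel:  (batch*H*W, C)
per_image = x.reshape(x.shape[0], -1)   # one row per image:  (batch, H*W*C)

print(per_pixel.shape)  # (6000, 40)
print(per_image.shape)  # (10, 24000)
```

In a TensorFlow graph, the per-image version would presumably take its batch size from `tf.shape(x)[0]` instead of `x.shape[0]`, since the static batch dimension is `None`.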
But without seeing more of your code I can't be sure. For example, if your neural network is convolutional, it's wrong to reshape at all.