
How to properly get the shape in Tensorflow so that I can reshape again?

I am trying to run the following code:

def f(x):
    (_, H, W, C) = tf.shape(x)
    x_reshaped = tf.reshape(x, (-1,C))
    res =  x_reshaped/(H*W*C)
    return res

The problem is that I don't know H and W in advance, so they show up as ?, ? in the static shape, and the reshape and the multiplication don't work. How do I write this computation correctly, so that res is a valid TensorFlow node that can be evaluated later in a Session?

The following should work:

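import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (tf.placeholder, tf.Session)

X = tf.placeholder(tf.float32, shape=[None, None, None, 40])

def f(x):
    s = tf.shape(x)  # dynamic shape: a 1-D int32 tensor evaluated at run time
    x_reshaped = tf.reshape(x, [-1, s[3]])
    # Divide by H*W*C, as in the question; cast because s is int32
    res = x_reshaped / tf.cast(s[1] * s[2] * s[3], tf.float32)
    return res

out = f(X)

sess = tf.Session()
sess.run(out, {X: np.random.normal(size=(10, 20, 30, 40))})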
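
As a sanity check, here is a minimal NumPy sketch of the computation the question describes: reshape to (-1, C), then divide by H*W*C. It is a stand-in for comparing shapes and values, not TensorFlow code. The name f_reference and the small input shape are assumptions made for illustration.

```python
import numpy as np

def f_reference(x):
    # x has shape (batch, H, W, C); in NumPy every dimension is known here
    _, H, W, C = x.shape
    x_reshaped = x.reshape(-1, C)          # shape (batch*H*W, C)
    return x_reshaped / float(H * W * C)   # scale every entry by 1/(H*W*C)

x = np.ones((2, 3, 4, 5), dtype=np.float32)
res = f_reference(x)
print(res.shape)
```

Feeding the same array to the TensorFlow graph should give a result of the same shape, provided the graph divides by the same H*W*C factor.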

I'm assuming you want x to have shape (batch_size, H*W*C), so that each row of x is one "flattened" image. In that case the correct code would be:

x_reshaped = tf.reshape(x, (-1, H*W*C))

But without seeing more of your code I can't be sure. For example, if your network is convolutional, you should not reshape at all, because convolution layers expect the 4-D (batch_size, H, W, C) layout.
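
To make the difference between the two layouts concrete, here is a small NumPy sketch (NumPy's reshape uses the same row-major ordering as tf.reshape). The array sizes and variable names are assumptions chosen for illustration only.

```python
import numpy as np

batch, H, W, C = 2, 3, 4, 5
x = np.arange(batch * H * W * C, dtype=np.float32).reshape(batch, H, W, C)

# Layout from the question: one row per pixel, one column per channel
per_pixel = x.reshape(-1, C)

# Layout assumed in this answer: one row per image, all values flattened
per_image = x.reshape(-1, H * W * C)

print(per_pixel.shape)
print(per_image.shape)

# Row i of per_image is exactly image i flattened in row-major order
first_image_matches = np.array_equal(per_image[0], x[0].ravel())
```

Which layout is appropriate depends on what the following layer expects: a per-pixel operation can work on the first, while a fully connected layer over whole images needs the second.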
