TensorFlow - use a tensor as an index
我想使用向后累積和功能:
import numpy as np
import tensorflow as tf

def _backwards_cumsum(x, length, batch_size):
    upper_triangular_ones = np.float32(np.triu(np.ones((length, length))))
    repeated_tri = np.float32(np.kron(np.eye(batch_size), upper_triangular_ones))
    return tf.matmul(repeated_tri,
                     tf.reshape(x, [length, 1]))
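To see what the matrix product computes, here is a NumPy-only sketch. It assumes a single sequence and a concrete Python int length, which is exactly what the placeholder is not; the helper name backwards_cumsum_np is hypothetical. Multiplying by an upper-triangular matrix of ones replaces each entry with the sum of itself and every entry after it.

```python
import numpy as np

def backwards_cumsum_np(x):
    """Hypothetical helper: reverse cumulative sum of a 1-D array via a matmul."""
    length = x.shape[0]  # a concrete int, so np.ones accepts it as a shape
    upper_triangular_ones = np.triu(np.ones((length, length), dtype=np.float32))
    # Row i has ones in columns i..length-1, so row i of the product is sum(x[i:])
    return upper_triangular_ones @ x.reshape(length, 1)

x = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
result = backwards_cumsum_np(x).ravel()  # holds 10, 9, 7, 4

# Same quantity without building a matrix: reverse, cumsum, reverse back
reference = np.cumsum(x[::-1])[::-1]
```

The matrix costs O(length²) memory. For comparison only (it is not what the answer uses), TensorFlow's tf.cumsum also takes a reverse argument that computes this directly on a tensor.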
But length is a placeholder:
length = tf.placeholder("int32", name='xx')
so _backwards_cumsum should be recomputed every time a new value is fed in.
When I try to run the function, I get this error:
TypeError: 'Tensor' object cannot be interpreted as an index
Full traceback:
TypeError Traceback (most recent call last)
<ipython-input-561-970ae9e96aa1> in <module>()
----> 1 rewards = _backwards_cumsum(tf.reshape(tf.reshape(decays,[-1,1]) * tf.sigmoid(disc_pred_gen_ph), [-1]), _maxx, batch_size)
<ipython-input-546-5c6928fac357> in _backwards_cumsum(x, length, batch_size)
1 def _backwards_cumsum(x, length, batch_size):
2
----> 3 upper_triangular_ones = np.float32(np.triu(np.ones((length, length))))
4 repeated_tri = np.float32(np.kron(np.eye(batch_size), upper_triangular_ones))
5 return tf.matmul(repeated_tri,
/Users/onivron/anaconda/envs/tensorflow/lib/python2.7/site-packages/numpy/core/numeric.pyc in ones(shape, dtype, order)
190
191 """
--> 192 a = empty(shape, dtype, order)
193 multiarray.copyto(a, 1, casting='unsafe')
194 return a
Here _maxx is the same as the length placeholder above.
Is there any workaround?
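The error comes from length being a tensor object that you unknowingly pass to a NumPy function: np.ones needs a concrete integer shape when the graph is built, but a placeholder only has a value inside sess.run. A straightforward way to use NumPy functionality in TensorFlow is tf.py_func, which calls a Python function on the actual tensor values when the graph is executed.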
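The failure can be reproduced without TensorFlow. In this sketch, SymbolicLength is a hypothetical class standing in for a placeholder: like a graph tensor before sess.run, it carries no integer value that NumPy can use as a shape, so np.ones rejects it with a TypeError (the exact message differs between NumPy versions).

```python
import numpy as np

class SymbolicLength:
    """Hypothetical stand-in for a graph tensor: it has no concrete integer value."""
    pass

def try_build_matrix(length):
    # Return the array on success, or the exception object on failure
    try:
        return np.ones((length, length))
    except TypeError as err:
        return err

ok = try_build_matrix(3)                     # a plain int works: a 3x3 array of ones
failed = try_build_matrix(SymbolicLength())  # NumPy cannot turn the object into an int
```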
import numpy as np
import tensorflow as tf

# Define a new function that only depends on NumPy (no TensorFlow graph objects)
def get_repeated_tri(length, batch_size):
    # tf.py_func passes the fed values in as NumPy values, so they are valid shapes here
    upper_triangular_ones = np.float32(np.triu(np.ones((length, length))))
    repeated_tri = np.float32(np.kron(np.eye(batch_size), upper_triangular_ones))
    return repeated_tri

# length_ is fed at run time; batch_size is a plain Python int
length_ = tf.placeholder(tf.int32, name='length')
# The output dtype must match what get_repeated_tri returns (float32, not int32)
repeated_tri = tf.py_func(get_repeated_tri, [length_, batch_size], tf.float32)
repeated_tri.set_shape([None, None])  # py_func output has an unknown shape; declare it 2-D

# There are also some size mismatches in your tf.matmul call:
# x has length * batch_size elements, so reshape it to [length * batch_size, 1]
def _backwards_cumsum(repeated_tri, x, length_, batch_size):
    return tf.matmul(repeated_tri, tf.reshape(x, [length_ * batch_size, -1]))

# x is a float32 TensorFlow tensor; length is the concrete value to feed
some_tensor_out = _backwards_cumsum(repeated_tri, x, length_, batch_size)
some_tensor_out_ = sess.run(some_tensor_out, {length_: length})
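The Python function given to tf.py_func can be checked on its own with NumPy values, which is roughly what py_func hands it at run time. This sketch assumes x holds batch_size sequences of length length laid end to end; the numbers are made up for illustration. It shows why the output dtype is float32 and why x is reshaped to [length * batch_size, 1]: the Kronecker product builds a block-diagonal matrix, so each block sums only within its own sequence.

```python
import numpy as np

def get_repeated_tri(length, batch_size):
    # Same body as in the answer above
    upper_triangular_ones = np.float32(np.triu(np.ones((length, length))))
    repeated_tri = np.float32(np.kron(np.eye(batch_size), upper_triangular_ones))
    return repeated_tri

length, batch_size = np.int32(3), np.int32(2)
repeated_tri = get_repeated_tri(length, batch_size)  # shape (6, 6), block diagonal

# Two sequences of length 3, flattened back to back
x = np.array([[1, 2, 3], [10, 20, 30]], dtype=np.float32)
out = repeated_tri @ x.reshape(length * batch_size, 1)
# Reverse cumulative sums per sequence: [6, 5, 3] and [60, 50, 30]
per_sequence = out.reshape(batch_size, length)
```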