[英]Remove a set of tensors from a tensor in Tensorflow
我正在尋找一種簡單的方法來從Tensorflow中的當前張量中刪除一組張量,並且我有一個困難的合理解決方案。
例如,假設我具有以下當前張量:
a = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=[2, 3], name='a')
我想從該張量中刪除兩個項目(2.0和5.0)。
創建張量后將其轉換為[1.0、3.0、4.0、6.0]的最佳方法是什么?
提前謝謝了。
您可以調用tf.unstack
以獲得子張量列表。 然后,您可以修改列表並調用tf.stack
從列表構造張量。 例如,以下代碼從以下位置刪除[2.0,5.0]列:
a = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=[2, 3], name='a')
a_vecs = tf.unstack(a, axis=1)
del a_vecs[1]
a_new = tf.stack(a_vecs, 1)
另一種方法是使用split或slice函數。 這在張量很大時特別有用。
方法1:使用拆分功能。
a = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=[2, 3], name='a')
split1, split2, split3 = tf.split(a, [1, 1, 1], 1)
a_new = tf.concat([split1, split3], 1)
方法2:使用切片功能。
slice1 = tf.slice(a, [0, 0], [2, 1])
slice2 = tf.slice(a, [0, 2], [2, 1])
a_new = tf.concat([slice1, slice2], 1)
在這兩種情況下,a_new將具有
[[ 1. 3.]
[ 4. 6.]]
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.