簡體   English   中英

從Tensorflow中的張量中刪除一組張量

[英]Remove a set of tensors from a tensor in Tensorflow

我正在尋找一種簡單的方法來從Tensorflow中的當前張量中刪除一組張量,並且我有一個困難的合理解決方案。

例如,假設我具有以下當前張量:

a = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=[2, 3], name='a')

我想從該張量中刪除兩個項目(2.0和5.0)。

創建張量后將其轉換為[1.0、3.0、4.0、6.0]的最佳方法是什么?

提前謝謝了。

您可以調用tf.unstack以獲得子張量列表。 然后,您可以修改列表並調用tf.stack從列表構造張量。 例如,以下代碼從以下位置刪除[2.0,5.0]列:

a = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=[2, 3], name='a')
a_vecs = tf.unstack(a, axis=1)
del a_vecs[1]
a_new = tf.stack(a_vecs, 1)

另一種方法是使用split或slice函數。 這在張量很大時特別有用。

方法1:使用拆分功能。

a = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=[2, 3], name='a')
split1, split2, split3 = tf.split(a, [1, 1, 1], 1)
a_new = tf.concat([split1, split3], 1)

方法2:使用切片功能。

slice1 = tf.slice(a, [0, 0], [2, 1])
slice2 = tf.slice(a, [0, 2], [2, 1])
a_new = tf.concat([slice1, slice2], 1)

在這兩種情況下,a_new將具有

[[ 1.  3.]
 [ 4.  6.]]

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM