[英]How to remove list of elements from a Tensorflow tensor
對於以下張量:
<tf.Tensor: shape=(2, 10, 6), dtype=int64, numpy=
array([[[ 3, 16, 43, 10, 7, 431],
[ 3, 2, 6, 5, 7, 2],
[ 3, 37, 5, 7, 2, 12],
[ 3, 2, 11, 5, 7, 2],
[ 3, 2, 6, 18, 14, 195],
[ 3, 2, 6, 5, 7, 195],
[ 3, 2, 6, 5, 7, 9],
[ 3, 2, 11, 7, 2, 12],
[ 3, 16, 52, 92, 177, 923],
[ 3, 9, 43, 10, 7, 9]],
[[ 3, 2, 22, 495, 230, 4],
[ 3, 2, 22, 5, 102, 122],
[ 3, 2, 22, 5, 102, 230],
[ 3, 2, 22, 5, 70, 908],
[ 3, 2, 22, 5, 70, 450],
[ 3, 2, 22, 5, 70, 122],
[ 3, 2, 22, 5, 70, 122],
[ 3, 2, 22, 5, 70, 230],
[ 3, 2, 22, 70, 34, 470],
[ 3, 2, 22, 855, 450, 4]]], dtype=int64)>)
我想刪除張量中的最后一個列表[ 3, 2, 22, 855, 450, 4]
。 我試過tf.unstack
但它沒有用。
您也可以簡單地使用tf.ragged.boolean_mask
來排除您不想要的行:
import tensorflow as tf
x = tf.constant([[[ 3, 16, 43, 10, 7, 431],
[ 3, 2, 6, 5, 7, 2],
[ 3, 37, 5, 7, 2, 12],
[ 3, 2, 11, 5, 7, 2],
[ 3, 2, 6, 18, 14, 195],
[ 3, 2, 6, 5, 7, 195],
[ 3, 2, 6, 5, 7, 9],
[ 3, 2, 11, 7, 2, 12],
[ 3, 16, 52, 92, 177, 923],
[ 3, 9, 43, 10, 7, 9]],
[[ 3, 2, 22, 495, 230, 4],
[ 3, 2, 22, 5, 102, 122],
[ 3, 2, 22, 5, 102, 230],
[ 3, 2, 22, 5, 70, 908],
[ 3, 2, 22, 5, 70, 450],
[ 3, 2, 22, 5, 70, 122],
[ 3, 2, 22, 5, 70, 122],
[ 3, 2, 22, 5, 70, 230],
[ 3, 2, 22, 70, 34, 470],
[ 3, 2, 22, 855, 450, 4]]])
x_shape = tf.shape(x)
remove = tf.constant([3, 2, 22, 855, 450, 4])
mask = tf.reduce_all(tf.equal(x, remove), axis=-1)
x = tf.ragged.boolean_mask(x, ~mask)
print(x)
<tf.RaggedTensor [[[3, 16, 43, 10, 7, 431],
[3, 2, 6, 5, 7, 2],
[3, 37, 5, 7, 2, 12],
[3, 2, 11, 5, 7, 2],
[3, 2, 6, 18, 14, 195],
[3, 2, 6, 5, 7, 195],
[3, 2, 6, 5, 7, 9],
[3, 2, 11, 7, 2, 12],
[3, 16, 52, 92, 177, 923],
[3, 9, 43, 10, 7, 9]] , [[3, 2, 22, 495, 230, 4],
[3, 2, 22, 5, 102, 122],
[3, 2, 22, 5, 102, 230],
[3, 2, 22, 5, 70, 908],
[3, 2, 22, 5, 70, 450],
[3, 2, 22, 5, 70, 122],
[3, 2, 22, 5, 70, 122],
[3, 2, 22, 5, 70, 230],
[3, 2, 22, 70, 34, 470]]]>
您可以在下面嘗試從張量中刪除最后一個列表:
sliced_tensor = tf.slice(tensor, [0, 0, 0], [2, 9, 6])
試試這個
new_tensor = tf.slice(tensor, [0,0,0], [2,9,6], [1,1,1])
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.