簡體   English   中英

如何從 Tensorflow 張量中刪除元素列表

[英]How to remove list of elements from a Tensorflow tensor

對於以下張量:

 <tf.Tensor: shape=(2, 10, 6), dtype=int64, numpy=
 array([[[  3,  16,  43,  10,   7, 431],
         [  3,   2,   6,   5,   7,   2],
         [  3,  37,   5,   7,   2,  12],
         [  3,   2,  11,   5,   7,   2],
         [  3,   2,   6,  18,  14, 195],
         [  3,   2,   6,   5,   7, 195],
         [  3,   2,   6,   5,   7,   9],
         [  3,   2,  11,   7,   2,  12],
         [  3,  16,  52,  92, 177, 923],
         [  3,   9,  43,  10,   7,   9]],
 
        [[  3,   2,  22, 495, 230,   4],
         [  3,   2,  22,   5, 102, 122],
         [  3,   2,  22,   5, 102, 230],
         [  3,   2,  22,   5,  70, 908],
         [  3,   2,  22,   5,  70, 450],
         [  3,   2,  22,   5,  70, 122],
         [  3,   2,  22,   5,  70, 122],
         [  3,   2,  22,   5,  70, 230],
         [  3,   2,  22,  70,  34, 470],
         [  3,   2,  22, 855, 450,   4]]], dtype=int64)>)

我想刪除張量中的最后一個列表[ 3, 2, 22, 855, 450, 4] 我試過tf.unstack但它沒有用。

您也可以簡單地使用tf.ragged.boolean_mask來排除您不想要的行:

import tensorflow as tf

x = tf.constant([[[  3,  16,  43,  10,   7, 431],
         [  3,   2,   6,   5,   7,   2],
         [  3,  37,   5,   7,   2,  12],
         [  3,   2,  11,   5,   7,   2],
         [  3,   2,   6,  18,  14, 195],
         [  3,   2,   6,   5,   7, 195],
         [  3,   2,   6,   5,   7,   9],
         [  3,   2,  11,   7,   2,  12],
         [  3,  16,  52,  92, 177, 923],
         [  3,   9,  43,  10,   7,   9]],
 
        [[  3,   2,  22, 495, 230,   4],
         [  3,   2,  22,   5, 102, 122],
         [  3,   2,  22,   5, 102, 230],
         [  3,   2,  22,   5,  70, 908],
         [  3,   2,  22,   5,  70, 450],
         [  3,   2,  22,   5,  70, 122],
         [  3,   2,  22,   5,  70, 122],
         [  3,   2,  22,   5,  70, 230],
         [  3,   2,  22,  70,  34, 470],
         [  3,   2,  22, 855, 450,   4]]])
x_shape = tf.shape(x)
remove = tf.constant([3, 2, 22, 855, 450, 4])

mask = tf.reduce_all(tf.equal(x, remove), axis=-1)
x = tf.ragged.boolean_mask(x, ~mask)
print(x)
<tf.RaggedTensor [[[3, 16, 43, 10, 7, 431],
  [3, 2, 6, 5, 7, 2],
  [3, 37, 5, 7, 2, 12],
  [3, 2, 11, 5, 7, 2],
  [3, 2, 6, 18, 14, 195],
  [3, 2, 6, 5, 7, 195],
  [3, 2, 6, 5, 7, 9],
  [3, 2, 11, 7, 2, 12],
  [3, 16, 52, 92, 177, 923],
  [3, 9, 43, 10, 7, 9]]     , [[3, 2, 22, 495, 230, 4],
                               [3, 2, 22, 5, 102, 122],
                               [3, 2, 22, 5, 102, 230],
                               [3, 2, 22, 5, 70, 908],
                               [3, 2, 22, 5, 70, 450],
                               [3, 2, 22, 5, 70, 122],
                               [3, 2, 22, 5, 70, 122],
                               [3, 2, 22, 5, 70, 230],
                               [3, 2, 22, 70, 34, 470]]]>

您可以在下面嘗試從張量中刪除最后一個列表:

sliced_tensor = tf.slice(tensor, [0, 0, 0], [2, 9, 6]) 

試試這個

new_tensor = tf.slice(tensor, [0,0,0], [2,9,6], [1,1,1])

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM