繁体   English   中英

如何在 pytorch 中获得具有特定值条件的子二维张量?

[英]how to get sub 2d tensor with specific value condition in pytorch?

我想将二维火炬张量复制到仅包含值的目标张量,直到第一次出现 202 值和 rest 的零值,如下所示:

source_t=tensor[[101,2001,2034,1045,202,3454,3453,1234,202]
         ,[101,1999,2808,202,17658,3454,202,0,0]
         ,[101,2012,3832,4027,3454,202,3454,9987,202]]

destination_t=tensor[[101,2001,2034,1045,202,0,0,0,0]
                    ,[101,1999,2808,202,0,0,0,0,0]
                    ,[101,2012,3832,4027,3454,202,0,0,0]]

我该怎么做?

我制定了有效且非常有效的解决方案。

我制作了一些更复杂的源张量,在不同的地方添加了 202 行:

import copy
import torch

source_t = torch.tensor([[101, 2001, 2034, 1045, 202, 3454, 3453, 1234, 202],
                         [101, 1999, 2808, 202, 17658, 3454, 202, 0, 0],
                         [101, 2012, 3832, 4027, 3454, 2020, 3454, 9987, 2020],
                         [101, 2012, 3832, 4027, 3454, 202, 3454, 9987, 202],
                         [101, 2012, 3832, 4027, 3454, 2020, 3454, 9987, 202]
                         ])

一开始,我们应该找到第一个 202 的出现。我们可以找到所有出现,然后选择第一个:

index_202 = (source_t == 202).nonzero(as_tuple=False).numpy()
rows_for_replace = list()
columns_to_replace = list()

elements = source_t.shape[1]
current_ind = 0

while current_ind < len(index_202)-1:
    current = index_202[current_ind]
    element_ind = current[1] + 1
    rows_for_replace.extend([current[0]]*(elements-element_ind))
    while element_ind < elements:
        columns_to_replace.append(element_ind)
        element_ind += 1
    if current[0] == index_202[current_ind+1][0]:
        current_ind += 1
    current_ind += 1

在此操作之后,我们拥有了应该用零替换的所有索引。 第一行 4 个元素,第二行 5 个元素,第四行 3 个元素,第三和第五行没有。 rows_for_replace, columns_to_replace

([0, 0, 0, 0, 1, 1, 1, 1, 1, 3, 3, 3], [5, 5, 5, 5, 4, 4, 4, 4, 4, 6, 6, 6])

然后我们只需复制源张量并设置新值:

destination_t = copy.deepcopy(source_t)
destination_t[rows_for_replace, columns_to_replace] = 0

摘要:source_t

tensor([[  101,  2001,  2034,  1045,   202,  3454,  3453,  1234,   202],
        [  101,  1999,  2808,   202, 17658,  3454,   202,     0,     0],
        [  101,  2012,  3832,  4027,  3454,  2020,  3454,  9987,  2020],
        [  101,  2012,  3832,  4027,  3454,   202,  3454,  9987,   202],
        [  101,  2012,  3832,  4027,  3454,  2020,  3454,  9987,   202]])

目的地_t

tensor([[ 101, 2001, 2034, 1045,  202,    0,    0,    0,    0],
        [ 101, 1999, 2808,  202,    0,    0,    0,    0,    0],
        [ 101, 2012, 3832, 4027, 3454, 2020, 3454, 9987, 2020],
        [ 101, 2012, 3832, 4027, 3454,  202,    0,    0,    0],
        [ 101, 2012, 3832, 4027, 3454, 2020, 3454, 9987,  202]])

我认为有一个更好的解决方案,它要求每一行都有一个“202”,否则你必须删除不是这种情况的行。

import torch

t = torch.tensor([[101,2001,2034,1045,202,3454,3453,1234,202],
                  [101,1999,2808,202,17658,3454,202,0,0],
                  [101,2012,3832,4027,3454,202,3454,9987,202]])
out = t.clone() # make copy

检查张量等于 202 的位置,将 boolean 转换为 int 并为每一行获取 argmax,这意味着我们有第一个 1 出现的列,它对应于第一个 202。

然后遍历每一行

cols = t.eq(202).int().argmax(1)
k = t.shape[1] # number of columns

for idx, c in enumerate(cols):
    if c + 1 < k:
        out[idx, c+1:] = 0 # make all values right of c equal to zero

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM