![](/img/trans.png)
[英]Pytorch compute the mean of a 2D tensor at specific rows given a condition
[英]how to get sub 2d tensor with specific value condition in pytorch?
我想将二维火炬张量复制到仅包含值的目标张量,直到第一次出现 202 值和 rest 的零值,如下所示:
source_t=tensor[[101,2001,2034,1045,202,3454,3453,1234,202]
,[101,1999,2808,202,17658,3454,202,0,0]
,[101,2012,3832,4027,3454,202,3454,9987,202]]
destination_t=tensor[[101,2001,2034,1045,202,0,0,0,0]
,[101,1999,2808,202,0,0,0,0,0]
,[101,2012,3832,4027,3454,202,0,0,0]]
我该怎么做?
我制定了有效且非常有效的解决方案。
我制作了一些更复杂的源张量,在不同的地方添加了 202 行:
import copy
import torch
source_t = torch.tensor([[101, 2001, 2034, 1045, 202, 3454, 3453, 1234, 202],
[101, 1999, 2808, 202, 17658, 3454, 202, 0, 0],
[101, 2012, 3832, 4027, 3454, 2020, 3454, 9987, 2020],
[101, 2012, 3832, 4027, 3454, 202, 3454, 9987, 202],
[101, 2012, 3832, 4027, 3454, 2020, 3454, 9987, 202]
])
一开始,我们应该找到第一个 202 的出现。我们可以找到所有出现,然后选择第一个:
index_202 = (source_t == 202).nonzero(as_tuple=False).numpy()
rows_for_replace = list()
columns_to_replace = list()
elements = source_t.shape[1]
current_ind = 0
while current_ind < len(index_202)-1:
current = index_202[current_ind]
element_ind = current[1] + 1
rows_for_replace.extend([current[0]]*(elements-element_ind))
while element_ind < elements:
columns_to_replace.append(element_ind)
element_ind += 1
if current[0] == index_202[current_ind+1][0]:
current_ind += 1
current_ind += 1
在此操作之后,我们拥有了应该用零替换的所有索引。 第一行 4 个元素,第二行 5 个元素,第四行 3 个元素,第三和第五行没有。 rows_for_replace, columns_to_replace
([0, 0, 0, 0, 1, 1, 1, 1, 1, 3, 3, 3], [5, 5, 5, 5, 4, 4, 4, 4, 4, 6, 6, 6])
然后我们只需复制源张量并设置新值:
destination_t = copy.deepcopy(source_t)
destination_t[rows_for_replace, columns_to_replace] = 0
摘要:source_t
tensor([[ 101, 2001, 2034, 1045, 202, 3454, 3453, 1234, 202],
[ 101, 1999, 2808, 202, 17658, 3454, 202, 0, 0],
[ 101, 2012, 3832, 4027, 3454, 2020, 3454, 9987, 2020],
[ 101, 2012, 3832, 4027, 3454, 202, 3454, 9987, 202],
[ 101, 2012, 3832, 4027, 3454, 2020, 3454, 9987, 202]])
目的地_t
tensor([[ 101, 2001, 2034, 1045, 202, 0, 0, 0, 0],
[ 101, 1999, 2808, 202, 0, 0, 0, 0, 0],
[ 101, 2012, 3832, 4027, 3454, 2020, 3454, 9987, 2020],
[ 101, 2012, 3832, 4027, 3454, 202, 0, 0, 0],
[ 101, 2012, 3832, 4027, 3454, 2020, 3454, 9987, 202]])
我认为有一个更好的解决方案,它要求每一行都有一个“202”,否则你必须删除不是这种情况的行。
import torch
t = torch.tensor([[101,2001,2034,1045,202,3454,3453,1234,202],
[101,1999,2808,202,17658,3454,202,0,0],
[101,2012,3832,4027,3454,202,3454,9987,202]])
out = t.clone() # make copy
检查张量等于 202 的位置,将 boolean 转换为 int 并为每一行获取 argmax,这意味着我们有第一个 1 出现的列,它对应于第一个 202。
然后遍历每一行
cols = t.eq(202).int().argmax(1)
k = t.shape[1] # number of columns
for idx, c in enumerate(cols):
if c + 1 < k:
out[idx, c+1:] = 0 # make all values right of c equal to zero
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.