
Clip or threshold a tensor using condition and zero pad the result in PyTorch

Let's say I have a tensor like this:

w = [[0.1, 0.7, 0.7, 0.8, 0.3],
    [0.3, 0.2, 0.9, 0.1, 0.5],
    [0.1, 0.4, 0.8, 0.3, 0.4]]

Now I want to eliminate certain values based on some condition (for example, greater than 0.5 or not):

w = [[0.1, 0.3],
     [0.3, 0.2, 0.1],
     [0.1, 0.4, 0.3, 0.4]]

Then pad it to equal length:

w = [[0.1, 0.3, 0, 0],
     [0.3, 0.2, 0.1, 0],
     [0.1, 0.4, 0.3, 0.4]]

And this is how I implemented it in PyTorch:

w = torch.rand(3, 5)
condition = w <= 0.5
w = [w[i][condition[i]] for i in range(3)]
w = torch.nn.utils.rnn.pad_sequence(w, batch_first=True)

But apparently this is going to be extremely slow, mainly because of the list comprehension. Is there any better way to do it?

Here's one straightforward way using boolean masking, tensor splitting, and then padding the split tensors using torch.nn.utils.rnn.pad_sequence(...).

# input tensor to work with
In [213]: w 
Out[213]: 
tensor([[0.1000, 0.7000, 0.7000, 0.8000, 0.3000],
        [0.3000, 0.2000, 0.9000, 0.1000, 0.5000],
        [0.1000, 0.4000, 0.8000, 0.3000, 0.4000]])

# values above this should be clipped from the input tensor
In [214]: clip_value = 0.5 

# generate a boolean mask that satisfies the condition
In [215]: boolean_mask = (w <= clip_value) 

# we need to sum the mask along axis 1 (needed for splitting)
In [216]: summed_mask = boolean_mask.sum(dim=1) 

# a sequence of split tensors
In [217]: split_tensors = torch.split(w[boolean_mask], summed_mask.tolist())

# finally, pad them along dimension 1 (batch_first=True keeps rows as rows)
In [219]: torch.nn.utils.rnn.pad_sequence(split_tensors, batch_first=True)
Out[219]: 
tensor([[0.1000, 0.3000, 0.0000, 0.0000],
        [0.3000, 0.2000, 0.1000, 0.5000],
        [0.1000, 0.4000, 0.3000, 0.4000]])

A short note on efficiency: using torch.split() is efficient since it returns the split tensors as views of its input (i.e. no copy is made during the split itself; the boolean indexing w[boolean_mask] does allocate a new tensor, though).
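The view behavior can be checked directly: writing through a split chunk is visible in the original tensor, confirming that torch.split() does not copy (a minimal sketch with a throwaway tensor):

```python
import torch

x = torch.arange(6, dtype=torch.float32)  # [0., 1., 2., 3., 4., 5.]
chunks = torch.split(x, [2, 4])           # views into x, no copy

# An in-place write through a chunk changes the original tensor,
# so the chunks share storage with x.
chunks[0][0] = 99.0
print(x[0].item())  # 99.0
```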

