Filter out NaN values from a PyTorch N-Dimensional tensor

This question is very similar to the one about filtering np.nan values out of a 1-dimensional PyTorch tensor. The difference is that I want to apply the same concept to tensors of 2 or higher dimensions.

I have a tensor that looks like this:

import torch

tensor = torch.Tensor(
[[1, 1, 1, 1, 1],
 [float('nan'), float('nan'), float('nan'), float('nan'), float('nan')],
 [2, 2, 2, 2, 2]]
)
>>> tensor.shape
torch.Size([3, 5])

I would like to find the most Pythonic / PyTorch way to filter out (remove) the rows of the tensor which are nan. By filtering this tensor along the first (0th) axis, I want to obtain a filtered_tensor which looks like this:

>>> print(filtered_tensor)
tensor([[1., 1., 1., 1., 1.],
        [2., 2., 2., 2., 2.]])
>>> filtered_tensor.shape
torch.Size([2, 5])

Use PyTorch's isnan() together with any() to build a boolean mask over the rows, then index tensor with that mask:

filtered_tensor = tensor[~torch.any(tensor.isnan(), dim=1)]

Note that this will drop any row that has at least one nan value in it. If you want to drop only rows where all values are nan, replace torch.any with torch.all.
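To make the difference between the two masks concrete, here is a small sketch. The example tensor x and the variable names are made up for illustration and are not part of the original question.

```python
import torch

nan = float('nan')

# Hypothetical example: row 1 is partly NaN, row 2 is entirely NaN.
x = torch.tensor([[1.0, 1.0, 1.0],
                  [nan, 2.0, 2.0],
                  [nan, nan, nan],
                  [3.0, 3.0, 3.0]])

nan_mask = torch.isnan(x)

# Drop rows that contain at least one NaN.
drop_any = x[~nan_mask.any(dim=1)]

# Drop only rows in which every value is NaN.
drop_all = x[~nan_mask.all(dim=1)]

print(drop_any.shape)  # torch.Size([2, 3])
print(drop_all.shape)  # torch.Size([3, 3])
```

With any, both the partly-NaN row and the all-NaN row are removed; with all, the partly-NaN row survives and still carries its nan.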

For an N-dimensional tensor you can flatten all the dimensions apart from the first one and apply the same procedure as above:

# Flatten everything except the first dimension:
shape = tensor.shape
tensor_reshaped = tensor.reshape(shape[0], -1)
# Drop all rows containing any nan:
tensor_reshaped = tensor_reshaped[~torch.any(tensor_reshaped.isnan(), dim=1)]
# Reshape back to the original trailing dimensions:
tensor = tensor_reshaped.reshape(tensor_reshaped.shape[0], *shape[1:])
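The same idea can be wrapped in a small helper. This is only a sketch: drop_nan_slices and its how parameter are hypothetical names introduced here, and it assumes the input has at least 2 dimensions. It flattens only the boolean mask rather than the data, since indexing along dim 0 with a 1-D boolean mask leaves the trailing dimensions untouched.

```python
import torch


def drop_nan_slices(t: torch.Tensor, how: str = "any") -> torch.Tensor:
    """Hypothetical helper: remove entries along dim 0 that contain NaN.

    how="any" drops a slice if it holds at least one NaN,
    how="all" drops it only if every value is NaN.
    Assumes t has at least 2 dimensions.
    """
    if how not in ("any", "all"):
        raise ValueError("how must be 'any' or 'all'")
    # Flatten the mask so each slice along dim 0 becomes one row.
    nan_mask = torch.isnan(t).flatten(start_dim=1)
    drop = nan_mask.any(dim=1) if how == "any" else nan_mask.all(dim=1)
    # A 1-D boolean mask indexes dim 0 only, so no reshape back is needed.
    return t[~drop]


# Made-up 3-D example: slice 1 has a single NaN, slice 3 is all NaN.
x = torch.ones(4, 2, 3)
x[1, 0, 2] = float('nan')
x[3] = float('nan')

print(drop_nan_slices(x).shape)         # torch.Size([2, 2, 3])
print(drop_nan_slices(x, "all").shape)  # torch.Size([3, 2, 3])
```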
