繁体   English   中英

使用 torch.max() 时,为批次中的每个条目屏蔽某些索引

[英]Mask certain indices for every entry in a batch, when using torch.max()

我正在对batch大小为torch.Size([n, 8])的样本进行增量采样。

我还有一个长度为 n 的valid_indices列表,其中包含对批次中的每个条目都有效的索引元组。

例如batch valid_indices[0]可能看起来像这样: (0,1,3,4,5,7) ,这表明索引 2 和 6 应该从 dim 1 中的第一个条目中排除。

特别是当我使用torch.max(batch, dim=1, keepdim=True)时,我需要排除这些值。

要排除的指数(如果有)可能因批次中的条目而异。

有任何想法吗? 提前致谢。

我假设你正在变老

IndexError: too many indices for tensor of dimension 1

直接在张量上使用元组索引时出错。 至少这是我在执行以下行时能够重现的错误

 t[0][valid_idx0]

其中 t 是大小为 (10,8) 的随机张量,valid_idx0 是包含 4 个元素的元组。

但是,当您将元组转换为如下列表时,同一行工作得很好

 t[0][list(valid_idx0)]

 >>> tensor([0.1847, 0.1028, 0.7130, 0.5093])

但是当涉及到将这些索引应用于二维张量时,情况会有所不同,因为我们需要保留张量的结构以进行批处理。

因此,将我们的索引转换为掩码 arrays是合理的。

假设我们手头有一个元组valid_indices列表。 首先是将其转换为列表列表。

valid_idx_list = [list(tup) for tup in valid_indices]

第二件事是将它们转换为掩码 arrays。

masks = np.zeros((t.size()))
for i, indices in enumerate(valid_idx_list):
  masks[i][indices] = 1

完毕。 现在我们可以应用掩码并在掩码张量上使用 torch.max。

torch.max(t*masks)

请查看我用来重现该问题的 colab notebook。

https://colab.research.google.com/drive/1BhKKgxk3gRwUjM8ilmiqgFvo0sfXMGiK?usp=sharing

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM