繁体   English   中英

一批二维张量中前 n 个分位数的二元掩码,但每个张量都有单独的 n

[英]Binary mask of top n-th quantile in a batch of 2D tensors, but with individual n for each tensor

我有一个形状为(100, 16, 16)张量 A 和形状为(100)张量 B ,其中 100 是批量大小。 我想创建一个具有形状的二进制掩码(100, 16, 16)在各元件,其中(元件具有的形状(1, 16, 16)所述掩模的,该值是1 ,如果该元素是大于计算出的分位数值,否则为0 张量 B 中的每个元素依次表示 A 中每个单独元素的百分位值。 如果 B 只是一个标量,我可以使用:

flat_A = torch.reshape(A, (100, -1))
quants = torch.quantile(flat_A, B, dim=1)
quants = torch.reshape(quants, (100, 1, 1))
mask = torch.where(A >= quants, 1, 0)
# quants will have shape (100, 1, 1)

问题是:如果 B 是我上面所说的形状为(100)一维张量,我如何计算 A 中每个单独元素的百分位值? 我尝试了以下操作,但结果并不像我预期的那样:

>>> torch.quantile(flat_A, B, dim=1).shape
torch.Size([100, 100])
>>> torch.quantile(flat_A, B, dim=0).shape
torch.Size([100, 256])

我认为结果的形状应该是(100) ,所以我可以使用mask = torch.where(A >= quants, 1, 0) ,或者我误解了它?

对于更多上下文,这个问题也是我之前在这里提出的标量 B 值问题的扩展。

这是使用torch.quantile()函数的一种方式。 请注意,为简单起见,这里我使用形状为(5, 2, 2)而不是(100, 16, 16) 的张量。

import torch
# Generate some data of shape (5, 2, 2)
A = torch.arange(5 * 2 * 2).reshape(5, 2, 2) + 1.0
B = torch.linspace(0, 1, 5) # 5 quantile values for each element in A

Af = A.reshape(A.shape[0], -1) # flattens A to a 2D tensor
quantiles = torch.quantile(Af, B, dim = 1, keepdim = True)
quants = quantiles[torch.arange(A.shape[0]), torch.arange(A.shape[0]), 0]

mask = (A >= quants[:, None, None]).type(torch.uint8)

这里的张量quantiles是形状的torch.Size([5, 5, 1])因为它存储用于在每个位数值阈值B对中的每个元素A (或Af )。 由于我们有 5 个分位数值,因此我们为A每个元素获得了 5 个阈值。

例如, quantiles[i, j, 0]具有A[j]Af[j] B[i] th 分位数的阈值,并且您基本上需要值quantiles[k, k, 0] for k in range批量大小或 5 在这里。

现在要满足您需要B相应分位数和A元素的阈值的要求,只需从quantiles索引出对角线元素并填充形状为torch.Size([5])quants

最后得到mask ,将A与每个元素的相应阈值进行比较。 请注意,这使用了与阈值的广播元素比较。 mask具有所需的torch.Size([5, 2, 2])形状。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM