繁体   English   中英

如何将 4 维 PyTorch 张量乘以 1 维张量?

[英]How to 4 dimension PyTorch tensor multiply by 1 dimension tensor?

我正在尝试编写 function 用于混合训练。 在这个网站上我找到了一些代码并适应了我以前的代码。 但在原始代码中,批次(64)只生成一个随机变量。 但我想要批量每张图片的随机值。 批处理一个变量的代码:

def mixup_data(x, y, alpha=1.0):
    lam = np.random.beta(alpha, alpha)
    batch_size = x.size()[0]
    index = torch.randperm(batch_size)

    mixed_x = lam * x + (1 - lam) * x[index,:]
    mixed_y = lam * y + (1 - lam) * y[index,:]

    return mixed_x, mixed_y

输入的 x 和 y 来自 pytorch DataLoader。 x 输入尺寸: torch.Size([64, 3, 256, 256]) y 输入尺寸: torch.Size([64, 3474])

这段代码效果很好。 然后我把它改成这样:

def mixup_data(x, y):
    batch_size = x.size()[0]
    lam = torch.rand(batch_size)
    index = torch.randperm(batch_size)

    mixed_x = lam[index] * x + (1 - lam[index]) * x[index,:]
    mixed_y = lam[index] * y + (1 - lam[index]) * y[index,:]

    return mixed_x, mixed_y

但它给出了一个错误: RuntimeError: The size of tensor a (64) must match the size of tensor b (256) at non-singleton dimension 3

我如何理解代码的工作原理是它批量获取第一张图像并乘以lam张量中的第一个值(长 64 个值)。 我该怎么做?

您需要替换以下行:

lam = torch.rand(batch_size)

经过

lam = torch.rand(batch_size, 1, 1, 1)

使用您当前的代码, lam[index] * x乘法是不可能的,因为lam[index]的大小是torch.Size([64])x的大小是torch.Size([64, 3, 256, 256]) . 因此,您需要将lam[index]的大小设置为torch.Size([64, 1, 1, 1])以便它可以广播。

为了应对以下陈述:

mixed_y = lam[index] * y + (1 - lam[index]) * y[index, :]

我们可以在声明之前重塑lam张量。

lam = lam.reshape(batch_size, 1)
mixed_y = lam[index] * y + (1 - lam[index]) * y[index, :]

问题是相乘的两个张量的大小不匹配。 我们以lam[index] * x为例。 尺寸如下:

  • x : torch.Size([64, 3, 256, 256])
  • lam[index]torch.Size([64])

为了将它们相乘,它们应该具有相同的大小,其中lam[index]对每个批次的[3, 256, 256]使用相同的值,因为您想将该批次中的每个元素乘以相同的值,但是每个批次都不一样。

lam[index].view(batch_size, 1, 1, 1).expand_as(x)
# => Size: torch.Size([64, 3, 256, 256])

.expand_as(x)重复单个维度,使其具有与 x 相同的大小,有关详细信息,请参阅.expand()文档

您不需要扩展张量,因为如果存在奇异维度,PyTorch 会自动为您执行此操作。 这就是所谓的广播: PyTorch - 广播语义 因此torch.Size([64, 1, 1, 1])将其与x相乘就足够了。

lam[index].view(batch_size, 1, 1, 1) * x

这同样适用于y但大小为torch.Size([64, 1]) ,因为y的大小torch.Size([64, 3474])

mixed_x = lam[index].view(batch_size, 1, 1, 1) * x + (1 - lam[index]).view(batch_size, 1, 1, 1) * x[index, :]
mixed_y = lam[index].view(batch_size, 1) * y + (1 - lam[index]).view(batch_size, 1) * y[index, :]

只是一个小旁注, lam[index]仅重新排列lam的元素,但是由于您是随机创建的,因此无论您是否重新排列它都没有任何区别。 唯一重要的是重新排列xy ,就像在原始代码中一样。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM