[英]How to 4 dimension PyTorch tensor multiply by 1 dimension tensor?
我正在尝试编写 function 用于混合训练。 在这个网站上我找到了一些代码并适应了我以前的代码。 但在原始代码中,批次(64)只生成一个随机变量。 但我想要批量每张图片的随机值。 批处理一个变量的代码:
def mixup_data(x, y, alpha=1.0):
lam = np.random.beta(alpha, alpha)
batch_size = x.size()[0]
index = torch.randperm(batch_size)
mixed_x = lam * x + (1 - lam) * x[index,:]
mixed_y = lam * y + (1 - lam) * y[index,:]
return mixed_x, mixed_y
输入的 x 和 y 来自 pytorch DataLoader。 x 输入尺寸: torch.Size([64, 3, 256, 256])
y 输入尺寸: torch.Size([64, 3474])
这段代码效果很好。 然后我把它改成这样:
def mixup_data(x, y):
batch_size = x.size()[0]
lam = torch.rand(batch_size)
index = torch.randperm(batch_size)
mixed_x = lam[index] * x + (1 - lam[index]) * x[index,:]
mixed_y = lam[index] * y + (1 - lam[index]) * y[index,:]
return mixed_x, mixed_y
但它给出了一个错误: RuntimeError: The size of tensor a (64) must match the size of tensor b (256) at non-singleton dimension 3
我如何理解代码的工作原理是它批量获取第一张图像并乘以lam
张量中的第一个值(长 64 个值)。 我该怎么做?
您需要替换以下行:
lam = torch.rand(batch_size)
经过
lam = torch.rand(batch_size, 1, 1, 1)
使用您当前的代码, lam[index] * x
乘法是不可能的,因为lam[index]
的大小是torch.Size([64])
而x
的大小是torch.Size([64, 3, 256, 256])
. 因此,您需要将lam[index]
的大小设置为torch.Size([64, 1, 1, 1])
以便它可以广播。
为了应对以下陈述:
mixed_y = lam[index] * y + (1 - lam[index]) * y[index, :]
我们可以在声明之前重塑lam
张量。
lam = lam.reshape(batch_size, 1)
mixed_y = lam[index] * y + (1 - lam[index]) * y[index, :]
问题是相乘的两个张量的大小不匹配。 我们以lam[index] * x
为例。 尺寸如下:
x
: torch.Size([64, 3, 256, 256])
lam[index]
: torch.Size([64])
为了将它们相乘,它们应该具有相同的大小,其中lam[index]
对每个批次的[3, 256, 256]
使用相同的值,因为您想将该批次中的每个元素乘以相同的值,但是每个批次都不一样。
lam[index].view(batch_size, 1, 1, 1).expand_as(x)
# => Size: torch.Size([64, 3, 256, 256])
.expand_as(x)
重复单个维度,使其具有与 x 相同的大小,有关详细信息,请参阅.expand()
文档。
您不需要扩展张量,因为如果存在奇异维度,PyTorch 会自动为您执行此操作。 这就是所谓的广播: PyTorch - 广播语义。 因此torch.Size([64, 1, 1, 1])
将其与x
相乘就足够了。
lam[index].view(batch_size, 1, 1, 1) * x
这同样适用于y
但大小为torch.Size([64, 1])
,因为y
的大小torch.Size([64, 3474])
。
mixed_x = lam[index].view(batch_size, 1, 1, 1) * x + (1 - lam[index]).view(batch_size, 1, 1, 1) * x[index, :]
mixed_y = lam[index].view(batch_size, 1) * y + (1 - lam[index]).view(batch_size, 1) * y[index, :]
只是一个小旁注, lam[index]
仅重新排列lam
的元素,但是由于您是随机创建的,因此无论您是否重新排列它都没有任何区别。 唯一重要的是重新排列x
和y
,就像在原始代码中一样。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.