繁体   English   中英

将数组每一行中的所有元素乘以一维数组中的数字

[英]Multiply all elements in each row of an array by numbers in a 1D array

我有一个形状为 [16,3,32,32]、16 张图像、3 个颜色通道 32x32 的火炬张量 (x)。 我正在进行扩散,需要将以下公式应用于图像

return sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * error

误差与 x 具有相同的维度。 当 sqrt_alpha_hat 和 sqrt_one_minus_alpha_hat 是整数时,这很好用,张量都乘以数字然后相加。 我想将每个图像乘以不同的值。 所以我的 sqrt_alpha_hat 和 sqrt_one_minus_alpha_hat 是 1D arrays,大小为 32,每个图像一个数字。 请记住,此数组位于 CUDA 中,因此某些 np 函数将不起作用。

我尝试使用 np.fill 创建一个具有以下格式的大型数组:

[[[1...1],...[1...1](32 列)...(32 行)[1...1],...[1...1]] ,

...(3 个颜色通道)

[[1...1],...[1...1]...[1...1],...[1...1]]]

...(16 张图片)

[[[16...16],...[16...16]...[16...16],...[16...16]],

...

[[16...16],...[16...16]...[16...16],...[16...16]]]

但这没有用。 肯定有一种更简单的方法可以做到这一点。

用过的

sqrt_alpha_hat_table = torch.stack([torch.full(x.shape[1:], sqrt_alpha_hat[i]) for i in range(x.shape[0])]).to(device)

执行此操作的“正确”方法(矢量化而不是基于循环,并且没有为重复行向量分配大量 memory)是使用expand() 我假设您的意思是sqrt_alpha_hat的大小为 [16],或者有 32 张图像,或者您在描述的其他地方犯了语义错误。

# transform from size [n_images] to size [n_images,1,1,1]
sqrt_alpha_hat = sqrt_alpha_hat.unsqueeze(1).unsqueeze(1).unsqueeze(1)

# broadcast (view tensor view, rather than copying values) across new 
dimensions to size [n_images,3,32,32] 
sqrt_alpha_hat = sqrt_alpha_hat.expand(n_images,3,32,32)

# same for sqrt_one_minus_alpha_hat
... 

# now you can multiply and add easily because the dimensions of all arrays match
return sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * error

作为一般规则,在使用 pytorch 张量时最好坚持使用 torch 函数而不是 np 函数,以避免 CUDA 不兼容等问题。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM