[英]Is there a Numpy or pyTorch function for this code?
基本上有一个 Numpy 或 PyTorch function 这样做:
vp_sa_s=mdp_data['sa_s'].detach().clone()
dims = vp_sa_s.size()
for i in range(dims[0]):
for j in range(dims[1]):
for k in range(dims[2]):
# to mimic matlab functionality: vp(mdp_data.sa_s)
try:
vp_sa_s[i,j,k] = vp[mdp_data['sa_s'][i,j,k]]
except:
pass
假设vp_sa_s
的大小为(10,5,5)
,并且每个值都是有效的索引 vp,即在 0-9 范围内。 vp 是大小(10,1)
,带有一堆随机值。
Matlab 使用vp(mdp_data.sa_s)
优雅而快速地完成它,这将形成一个新的(10,5,5)
矩阵。 如果mdp_data.sa_s
中的所有值都是 1,则结果将是(10,5,5)
张量,每个值都是vp
中的第一个值。
是否存在 function 或存在的方法可以在少于 O(N^3) 的时间内实现这一点,因为上述代码效率极低。
谢谢!
出什么问题了
result = vp[vp_sa_s, 0]
请注意,由于您的vp
的形状为(10, 1)
(它有一个尾随 singleton 维度),您需要在分配中添加, 0]
索引以摆脱这个额外的维度。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.