[英]An invalid combination of arguments in pytorch
I'm having some trouble to identify the error. 我在识别错误时遇到了一些麻烦。 This is the error message that I'm getting: 这是我收到的错误消息:
return torch.exp(-((x-2)**2)) + 0.8*torch.exp(-(x+2)**2) 返回torch.exp(-((x-2)** 2))+ 0.8 * torch.exp(-(x + 2)** 2)
TypeError: torch.exp received an invalid combination of arguments - got (!float!), but expected (torch.FloatTensor source) TypeError:torch.exp接收到无效的参数组合-得到了(!float!),但是是预期的(torch.FloatTensor源代码)
import torch
import time
dtype = torch.FloatTensor
def functionFit(x):
return torch.exp(-((x-2)**2)) + 0.8*torch.exp(-(x+2)**2)
def createSample(mu,sig,N):
return mu + sig*torch.randn(N,1).type(dtype)
def updateMU(alpha,rho,x,N,I,mu):
return alpha*torch.mean(x[I[int((1-rho)*N):N,0]]) + (1-alpha)*mu
def updateSIG(alpha,rho,x,N,I,sig):
return alpha*torch.std(x[I[int((1-rho)*N):N,0]]) + (1-alpha)*sig
def CE(N,rho,alpha,epsilon,mu,sig): # initial std dev.
start = time.time()
k = 0
while (sig > epsilon):
x = createSample(mu,sig,N)
S = functionFit(x)
sorted_v , I = torch.sort(S,0)
mu = updateMU(alpha,rho,x,N,I,mu)
sig = updateSIG(alpha,rho,x,N,I,sig)
k = k + 1
end = time.time()
xm = torch.mean(x)
ym = functionFit(xm)
print('x =',xm)
print('y =',ym)
print('time =',end - start,'s')
print('iter =',k)
if __name__ == '__main__':
N = 50
rho = 0.5
alpha = 0.9
epsilon = 0.001
mu = 20*torch.rand(1,1).type(dtype)-10 # init mu
sig = 5
CE(N,rho,alpha,epsilon,mu,sig)
In the CE
function, the following two lines are causing the error. 在CE
功能中,以下两行引起该错误。
xm = torch.mean(x)
ym = functionFit(xm)
Here, x
is a 50 x 1
FloatTensor but when you call torch.mean()
, it returns a float value which is causing error when you call functionFit(xm)
. 在这里, x
是一个50 x 1
FloatTensor,但是当您调用torch.mean()
,它将返回一个float值,当您调用functionFit(xm)
时,它将导致错误。
By the way, just for your information, torch.mean() returns a float value and torch.exp() expects a tensor as input. 顺便说一句,仅作为参考, torch.mean()返回一个浮点值, torch.exp()需要一个张量作为输入。 You can simply check the type of the parameter in functionFit()
and compute exponent using numpy if the parameter is a float value instead of a tensor. 您可以简单地在functionFit()
检查参数的类型,如果参数是浮点值而不是张量,则可以使用numpy计算指数。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.